diff --git a/.clang-format b/.clang-format
index 11a44d587c..63ebecbce1 100644
--- a/.clang-format
+++ b/.clang-format
@@ -19,3 +19,6 @@ BreakBeforeTernaryOperators: false
 IndentWrappedFunctionNames: true
 ContinuationIndentWidth: 4
 ObjCSpaceBeforeProtocolList: true
+---
+Language: Cpp
+IncludeBlocks: Regroup
diff --git a/.gitignore b/.gitignore
index c43e108f51..a86b405fab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,7 @@
 *_proto_cpp.xml
 *~
 .*.sw?
+.cache
 .cipd
 .clangd
 .classpath
diff --git a/.gn b/.gn
index a2e2a90425..d6f84df080 100644
--- a/.gn
+++ b/.gn
@@ -11,37 +11,20 @@ import("//build/dotfile_settings.gni")
 # The location of the build configuration file.
 buildconfig = "//build/config/BUILDCONFIG.gn"
 
+# The python interpreter to use by default. On Windows, this will look
+# for python3.exe and python3.bat.
+script_executable = "python3"
+
 # The secondary source root is a parallel directory tree where
 # GN build files are placed when they can not be placed directly
 # in the source tree, e.g. for third party source trees.
 secondary_source = "//build/secondary/"
 
-# These are the targets to check headers for by default. The files in targets
-# matching these patterns (see "gn help label_pattern" for format) will have
+# These are the targets to skip header checking by default. The files in targets
+# matching these patterns (see "gn help label_pattern" for format) will not have
 # their includes checked for proper dependencies when you run either
 # "gn check" or "gn gen --check".
-check_targets = [
-  "//api/*",
-  "//audio/*",
-  "//backup/*",
-  "//call/*",
-  "//common_audio/*",
-  "//common_video/*",
-  "//examples/*",
-  "//logging/*",
-  "//media/*",
-  "//modules/*",
-  "//p2p/*",
-  "//pc/*",
-  "//rtc_base/*",
-  "//rtc_tools/*",
-  "//sdk/*",
-  "//stats/*",
-  "//system_wrappers/*",
-  "//test/*",
-  "//video/*",
-  "//third_party/libyuv/*",
-]
+no_check_targets = [ "//third_party/icu/*" ]
 
 # These are the list of GN files that run exec_script. This whitelist exists
 # to force additional review for new uses of exec_script, which is strongly
@@ -61,7 +44,7 @@ default_args = {
 
   mac_sdk_min = "10.12"
 
-  ios_deployment_target = "10.0"
+  ios_deployment_target = "12.0"
 
   # The SDK API level, in contrast, is set by build/android/AndroidManifest.xml.
   android32_ndk_api_level = 16
diff --git a/.vpython b/.vpython
index 92c9c51346..df838dccf8 100644
--- a/.vpython
+++ b/.vpython
@@ -52,7 +52,7 @@ wheel: <
 
 wheel: <
   name: "infra/python/wheels/six-py2_py3"
-  version: "version:1.10.0"
+  version: "version:1.15.0"
 >
 wheel: <
   name: "infra/python/wheels/pbr-py2_py3"
diff --git a/AUTHORS b/AUTHORS
index 74b1faef35..b4d4100c6a 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -11,6 +11,7 @@
 # Please keep the list sorted.
 
 # BEGIN individuals section.
+Aaron Clauson
 Adam Fedor
 Akshay Shah
 Alexander Brauckmann
@@ -21,6 +22,7 @@ Andrey Efremov
 Andrew Johnson
 Anil Kumar
 Ben Strong
+Berthold Herrmann
 Bob Withers
 Bridger Maxwell
 Christophe Dumez
@@ -30,10 +32,12 @@ Colin Plumb
 Cyril Lashkevich
 CZ Theng
 Danail Kirov
+Dave Cowart
 David Porter
 Dax Booysen
 Dennis Angelo
 Dharmesh Chauhan
+Di Wu
 Dirk-Jan C. Binnema
 Dmitry Lizin
 Eike Rathke
@@ -49,10 +53,12 @@ James H.
Brown Jan Grulich Jan Kalab Jens Nielsen +Jesús Leganés-Combarro Jiawei Ou Jie Mao Jiwon Kim Jose Antonio Olivera Ortega +Keiichi Enomoto Kiran Thind Korniltsev Anatoly Lennart Grahl diff --git a/BUILD.gn b/BUILD.gn index f8707dae8f..bc51df7c07 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -14,6 +14,7 @@ import("//build/config/linux/pkg_config.gni") import("//build/config/sanitizers/sanitizers.gni") +import("//third_party/google_benchmark/buildconfig.gni") import("webrtc.gni") if (rtc_enable_protobuf) { import("//third_party/protobuf/proto_library.gni") @@ -54,6 +55,7 @@ if (!build_with_chromium) { "modules/remote_bitrate_estimator:rtp_to_text", "modules/rtp_rtcp:test_packet_masks_metrics", "modules/video_capture:video_capture_internal_impl", + "net/dcsctp:dcsctp_unittests", "pc:peerconnection_unittests", "pc:rtc_pc_unittests", "rtc_tools:rtp_generator", @@ -136,6 +138,10 @@ config("common_inherited_config") { defines += [ "WEBRTC_ENABLE_AVX2" ] } + if (rtc_enable_win_wgc) { + defines += [ "RTC_ENABLE_WIN_WGC" ] + } + # Some tests need to declare their own trace event handlers. If this define is # not set, the first time TRACE_EVENT_* is called it will store the return # value for the current handler in an static variable, so that subsequent @@ -261,7 +267,7 @@ config("common_config") { } if (rtc_enable_sctp) { - defines += [ "HAVE_SCTP" ] + defines += [ "WEBRTC_HAVE_SCTP" ] } if (rtc_enable_external_auth) { @@ -348,6 +354,13 @@ config("common_config") { # recognize. cflags += [ "-Wunused-lambda-capture" ] } + + if (use_xcode_clang) { + # This may be removed if the clang version in xcode > 12.4 includes the + # fix https://reviews.llvm.org/D73007. + # https://bugs.llvm.org/show_bug.cgi?id=44556 + cflags += [ "-Wno-range-loop-analysis" ] + } } if (is_win && !is_clang) { @@ -422,10 +435,6 @@ config("common_config") { config("common_objc") { frameworks = [ "Foundation.framework" ] - - if (rtc_use_metal_rendering) { - defines = [ "RTC_SUPPORTS_METAL" ] - } } if (!build_with_chromium) { @@ -507,6 +516,10 @@ if (!build_with_chromium) { rtc_executable("webrtc_lib_link_test") { testonly = true + # This target is used for checking to link, so do not check dependencies + # on gn check. + check_includes = false # no-presubmit-check TODO(bugs.webrtc.org/12785) + sources = [ "webrtc_lib_link_test.cc" ] deps = [ # NOTE: Don't add deps here. 
If this test fails to link, it means you @@ -526,7 +539,7 @@ if (use_libfuzzer || use_afl) { } } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_test("rtc_unittests") { testonly = true @@ -539,6 +552,7 @@ if (rtc_include_tests) { "api/transport:stun_unittest", "api/video/test:rtc_api_video_unittests", "api/video_codecs/test:video_codecs_api_unittests", + "api/voip:compile_all_headers", "call:fake_network_pipe_unittests", "p2p:libstunprober_unittests", "p2p:rtc_p2p_unittests", @@ -553,7 +567,7 @@ if (rtc_include_tests) { "rtc_base:untyped_function_unittest", "rtc_base:weak_ptr_unittests", "rtc_base/experiments:experiments_unittests", - "rtc_base/synchronization:sequence_checker_unittests", + "rtc_base/system:file_wrapper_unittests", "rtc_base/task_utils:pending_task_safety_flag_unittests", "rtc_base/task_utils:to_queued_task_unittests", "sdk:sdk_tests", @@ -583,12 +597,14 @@ if (rtc_include_tests) { } } - rtc_test("benchmarks") { - testonly = true - deps = [ - "rtc_base/synchronization:mutex_benchmark", - "test:benchmark_main", - ] + if (enable_google_benchmarks) { + rtc_test("benchmarks") { + testonly = true + deps = [ + "rtc_base/synchronization:mutex_benchmark", + "test:benchmark_main", + ] + } } # This runs tests that must run in real time and therefore can take some @@ -698,6 +714,7 @@ if (rtc_include_tests) { rtc_test("voip_unittests") { testonly = true deps = [ + "api/voip:compile_all_headers", "api/voip:voip_engine_factory_unittests", "audio/voip/test:audio_channel_unittests", "audio/voip/test:audio_egress_unittests", diff --git a/DEPS b/DEPS index 2b218db045..c24608a98a 100644 --- a/DEPS +++ b/DEPS @@ -1,43 +1,49 @@ # This file contains dependencies for WebRTC. gclient_gn_args_file = 'src/build/config/gclient_args.gni' +gclient_gn_args = [ + 'generate_location_tags', +] vars = { # By default, we should check out everything needed to run on the main # chromium waterfalls. More info at: crbug.com/570091. 'checkout_configuration': 'default', 'checkout_instrumented_libraries': 'checkout_linux and checkout_configuration == "default"', - 'chromium_revision': '42ab9dc8c82834db1834e688311ddd31ba475296', + 'chromium_revision': '6d8828f6a6eea769a05fa1c0b7acf10aca631d4a', + + # Keep the Chromium default of generating location tags. + 'generate_location_tags': True, } deps = { # TODO(kjellander): Move this to be Android-only once the libevent dependency # in base/third_party/libevent is solved. 'src/base': - 'https://chromium.googlesource.com/chromium/src/base@a361323fd59a160a15e5ce050ab53612cfab4956', + 'https://chromium.googlesource.com/chromium/src/base@e1acc6a30942360d4789d6c245cf7933e7e9bbec', 'src/build': - 'https://chromium.googlesource.com/chromium/src/build@d64e5999e338496041478aaed1759a32f2d2ff21', + 'https://chromium.googlesource.com/chromium/src/build@826926008327af276adbaafcfa92b525eb5bf326', 'src/buildtools': - 'https://chromium.googlesource.com/chromium/src/buildtools@235cfe435ca5a9826569ee4ef603e226216bd768', + 'https://chromium.googlesource.com/chromium/src/buildtools@2500c1d8f3a20a66a7cbafe3f69079a2edb742dd', # Gradle 6.6.1. Used for testing Android Studio project generation for WebRTC. 
'src/examples/androidtests/third_party/gradle': { 'url': 'https://chromium.googlesource.com/external/github.com/gradle/gradle.git@f2d1fb54a951d8b11d25748e4711bec8d128d7e3', 'condition': 'checkout_android', }, 'src/ios': { - 'url': 'https://chromium.googlesource.com/chromium/src/ios@e13cd2916385804bc8549651a4fb09d8a708ef0d', + 'url': 'https://chromium.googlesource.com/chromium/src/ios@695a3541172406518e45c377048956a3e5270d7c', 'condition': 'checkout_ios', }, 'src/testing': - 'https://chromium.googlesource.com/chromium/src/testing@049cd24b1fb14d921ffda917ffa066ca89d4f403', + 'https://chromium.googlesource.com/chromium/src/testing@d749d1b98b475ea15face1c9d2311ed6b8e4b91f', 'src/third_party': - 'https://chromium.googlesource.com/chromium/src/third_party@e35e2377dbd560ad34a31663464fcac0ba7a0feb', + 'https://chromium.googlesource.com/chromium/src/third_party@c1d40d8b399db4c5ebab5e5022a002dca5b3dbb2', 'src/buildtools/linux64': { 'packages': [ { 'package': 'gn/gn/linux-amd64', - 'version': 'git_revision:595e3be7c8381d4eeefce62a63ec12bae9ce5140', + 'version': 'git_revision:24e2f7df92641de0351a96096fb2c490b2436bb8', } ], 'dep_type': 'cipd', @@ -47,7 +53,7 @@ deps = { 'packages': [ { 'package': 'gn/gn/mac-${{arch}}', - 'version': 'git_revision:595e3be7c8381d4eeefce62a63ec12bae9ce5140', + 'version': 'git_revision:24e2f7df92641de0351a96096fb2c490b2436bb8', } ], 'dep_type': 'cipd', @@ -57,7 +63,7 @@ deps = { 'packages': [ { 'package': 'gn/gn/windows-amd64', - 'version': 'git_revision:595e3be7c8381d4eeefce62a63ec12bae9ce5140', + 'version': 'git_revision:24e2f7df92641de0351a96096fb2c490b2436bb8', } ], 'dep_type': 'cipd', @@ -65,13 +71,13 @@ deps = { }, 'src/buildtools/clang_format/script': - 'https://chromium.googlesource.com/chromium/llvm-project/cfe/tools/clang-format.git@96636aa0e9f047f17447f2d45a094d0b59ed7917', + 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/clang/tools/clang-format.git@99803d74e35962f63a775f29477882afd4d57d94', 'src/buildtools/third_party/libc++/trunk': - 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxx.git@d9040c75cfea5928c804ab7c235fed06a63f743a', + 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxx.git@79a2e924d96e2fc1e4b937c42efd08898fa472d7', 'src/buildtools/third_party/libc++abi/trunk': - 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxxabi.git@196ba1aaa8ac285d94f4ea8d9836390a45360533', + 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libcxxabi.git@cb34896ebd62f93f708ff9aad26159cf11dde6f4', 'src/buildtools/third_party/libunwind/trunk': - 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libunwind.git@d999d54f4bca789543a2eb6c995af2d9b5a1f3ed', + 'https://chromium.googlesource.com/external/github.com/llvm/llvm-project/libunwind.git@e7ac0f84fc2f2f8bd2ad151a7348e7120d77648a', 'src/tools/clang/dsymutil': { 'packages': [ @@ -118,20 +124,22 @@ deps = { }, 'src/third_party/boringssl/src': - 'https://boringssl.googlesource.com/boringssl.git@bac5544e9832c65c95283e95062263c79a9a6733', + 'https://boringssl.googlesource.com/boringssl.git@a10017c548b0805eb98e7847c37370dbd37cd8d6', 'src/third_party/breakpad/breakpad': - 'https://chromium.googlesource.com/breakpad/breakpad.git@e3d485f73f5836fdd6fb287ab96973c4f63175e1', + 'https://chromium.googlesource.com/breakpad/breakpad.git@b95c4868b10f69e642666742233aede1eb653012', 'src/third_party/catapult': - 
'https://chromium.googlesource.com/catapult.git@178c01be65a4a5458769084606cbc6504dde10f7', + 'https://chromium.googlesource.com/catapult.git@3345f09ed65020a999e108ea37d30b49c87e14ed', 'src/third_party/ced/src': { 'url': 'https://chromium.googlesource.com/external/github.com/google/compact_enc_det.git@ba412eaaacd3186085babcd901679a48863c7dd5', }, 'src/third_party/colorama/src': 'https://chromium.googlesource.com/external/colorama.git@799604a1041e9b3bc5d2789ecbd7e8db2e18e6b8', + 'src/third_party/crc32c/src': + 'https://chromium.googlesource.com/external/github.com/google/crc32c.git@fa5ade41ee480003d9c5af6f43567ba22e4e17e6', 'src/third_party/depot_tools': - 'https://chromium.googlesource.com/chromium/tools/depot_tools.git@c603339365fb945069c39147b1110d94e1814c28', + 'https://chromium.googlesource.com/chromium/tools/depot_tools.git@a806594b95a39141fdbf1f359087a44ffb2deaaf', 'src/third_party/ffmpeg': - 'https://chromium.googlesource.com/chromium/third_party/ffmpeg.git@841aa72c9e153ae5f952e31e4b6406870555922d', + 'https://chromium.googlesource.com/chromium/third_party/ffmpeg.git@05c195662f0527913811827ba253cb93758ea4c0', 'src/third_party/findbugs': { 'url': 'https://chromium.googlesource.com/chromium/deps/findbugs.git@4275d9ac8610db6b1bc9a5e887f97e41b33fac67', 'condition': 'checkout_android', @@ -142,11 +150,11 @@ deps = { 'condition': 'checkout_linux', }, 'src/third_party/freetype/src': - 'https://chromium.googlesource.com/chromium/src/third_party/freetype2.git@03ceda9701cd8c08ea5b4ee0c2d558a98fc4ed7d', + 'https://chromium.googlesource.com/chromium/src/third_party/freetype2.git@d3dc2da9b27af5b90575d62989389cc65fe7977c', 'src/third_party/harfbuzz-ng/src': - 'https://chromium.googlesource.com/external/github.com/harfbuzz/harfbuzz.git@53806e5b83cee0e275eac038d0780f95ac56588c', + 'https://chromium.googlesource.com/external/github.com/harfbuzz/harfbuzz.git@cc9bb294919e846ef8a0731b5e9f304f95ef3bb8', 'src/third_party/google_benchmark/src': { - 'url': 'https://chromium.googlesource.com/external/github.com/google/benchmark.git@ffe1342eb2faa7d2e7c35b4db2ccf99fab81ec20', + 'url': 'https://chromium.googlesource.com/external/github.com/google/benchmark.git@e991355c02b93fe17713efe04cbc2e278e00fdbd', }, # WebRTC-only dependency (not present in Chromium). 
'src/third_party/gtest-parallel': @@ -162,21 +170,27 @@ deps = { 'dep_type': 'cipd', }, 'src/third_party/googletest/src': - 'https://chromium.googlesource.com/external/github.com/google/googletest.git@1b0cdaae57c046c87fb99cb4f69c312a7e794adb', + 'https://chromium.googlesource.com/external/github.com/google/googletest.git@4ec4cd23f486bf70efcc5d2caa40f24368f752e3', 'src/third_party/icu': { - 'url': 'https://chromium.googlesource.com/chromium/deps/icu.git@899e18383fd732b47e6978db2b960a1b2a80179b', + 'url': 'https://chromium.googlesource.com/chromium/deps/icu.git@b9dfc58bf9b02ea0365509244aca13841322feb0', }, 'src/third_party/jdk': { 'packages': [ { 'package': 'chromium/third_party/jdk', - 'version': 'PfRSnxe8Od6WU4zBXomq-zsgcJgWmm3z4gMQNB-r2QcC', + 'version': 'JhpgSvTpgVUkoKe56yQmYaR1jXNcY8NqlltA0mKIO4EC', }, + ], + 'condition': 'host_os == "linux" and checkout_android', + 'dep_type': 'cipd', + }, + 'src/third_party/jdk/extras': { + 'packages': [ { 'package': 'chromium/third_party/jdk/extras', - 'version': 'fkhuOQ3r-zKtWEdKplpo6k0vKkjl-LY_rJTmtzFCQN4C', + 'version': '-7m_pvgICYN60yQI3qmTj_8iKjtnT4NXicT0G_jJPqsC', }, - ], + ], 'condition': 'host_os == "linux" and checkout_android', 'dep_type': 'cipd', }, @@ -190,23 +204,23 @@ deps = { 'src/third_party/libFuzzer/src': 'https://chromium.googlesource.com/chromium/llvm-project/compiler-rt/lib/fuzzer.git@debe7d2d1982e540fbd6bd78604bf001753f9e74', 'src/third_party/libjpeg_turbo': - 'https://chromium.googlesource.com/chromium/deps/libjpeg_turbo.git@518d81558c797486e125e37cb529d65b560a6ea0', + 'https://chromium.googlesource.com/chromium/deps/libjpeg_turbo.git@e9e400e0af31baf72d235655850bc00e55b6c145', 'src/third_party/libsrtp': - 'https://chromium.googlesource.com/chromium/deps/libsrtp.git@7990ca64c616b150a9cb4714601c4a3b0c84fe91', + 'https://chromium.googlesource.com/chromium/deps/libsrtp.git@5b7c744eb8310250ccc534f3f86a2015b3887a0a', 'src/third_party/libaom/source/libaom': - 'https://aomedia.googlesource.com/aom.git@43927e4611e7c3062a67ebaca38a625faa9a39d6', + 'https://aomedia.googlesource.com/aom.git@aba245dde334bd51a20940eb009fa46b6ffd4511', 'src/third_party/libunwindstack': { - 'url': 'https://chromium.googlesource.com/chromium/src/third_party/libunwindstack.git@11659d420a71e7323b379ea8781f07c6f384bc7e', + 'url': 'https://chromium.googlesource.com/chromium/src/third_party/libunwindstack.git@b34a0059a648f179ef05da2c0927f564bdaea2b3', 'condition': 'checkout_android', }, 'src/third_party/perfetto': - 'https://android.googlesource.com/platform/external/perfetto.git@94ca9a9578a7eeb3df93a820955458fb9aff28fd', + 'https://android.googlesource.com/platform/external/perfetto.git@aecbd80f576686b67e29bdfae8c9c03bb9ce1996', 'src/third_party/libvpx/source/libvpx': - 'https://chromium.googlesource.com/webm/libvpx.git@b5d77a48d740e211a130c8e45d9353ef8c154a47', + 'https://chromium.googlesource.com/webm/libvpx.git@eebc5cd487a89c51ba148f6d6ac45779970f72d7', 'src/third_party/libyuv': - 'https://chromium.googlesource.com/libyuv/libyuv.git@93b1b332cd60b56ab90aea14182755e379c28a80', + 'https://chromium.googlesource.com/libyuv/libyuv.git@49ebc996aa8c4bdf89c1b5ea461eb677234c61cc', 'src/third_party/lss': { - 'url': 'https://chromium.googlesource.com/linux-syscall-support.git@29f7c7e018f4ce706a709f0b0afbf8bacf869480', + 'url': 'https://chromium.googlesource.com/linux-syscall-support.git@92a65a8f5d705d1928874420c8d0d15bde8c89e5', 'condition': 'checkout_android or checkout_linux', }, 'src/third_party/mockito/src': { @@ -216,7 +230,7 @@ deps = { # Used by boringssl. 
'src/third_party/nasm': { - 'url': 'https://chromium.googlesource.com/chromium/deps/nasm.git@19f3fad68da99277b2882939d3b2fa4c4b8d51d9' + 'url': 'https://chromium.googlesource.com/chromium/deps/nasm.git@e9be5fd6d723a435ca2da162f9e0ffcb688747c1' }, 'src/third_party/openh264/src': @@ -225,7 +239,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/r8', - 'version': 'DR3nwJggFDcmTDz7P8fJQCtRLO1nxDt26czkOqhtZJ8C', + 'version': 'Nu_mvQJe34CotIXadFlA3w732CJ9EvQGuVs4udcZedAC', }, ], 'condition': 'checkout_android', @@ -250,16 +264,16 @@ deps = { 'condition': 'checkout_android', }, 'src/third_party/usrsctp/usrsctplib': - 'https://chromium.googlesource.com/external/github.com/sctplab/usrsctp@a3c3ef666b7a5e4c93ebae5a7462add6f86f5cf2', + 'https://chromium.googlesource.com/external/github.com/sctplab/usrsctp@1ade45cbadfd19298d2c47dc538962d4425ad2dd', # Dependency used by libjpeg-turbo. 'src/third_party/yasm/binaries': { 'url': 'https://chromium.googlesource.com/chromium/deps/yasm/binaries.git@52f9b3f4b0aa06da24ef8b123058bb61ee468881', 'condition': 'checkout_win', }, 'src/tools': - 'https://chromium.googlesource.com/chromium/src/tools@03a8864bc66bd6bbc0014ab62551b4465251729e', + 'https://chromium.googlesource.com/chromium/src/tools@1a00526b21d46b8b86f13add37003fd33885f32b', 'src/tools/swarming_client': - 'https://chromium.googlesource.com/infra/luci/client-py.git@1a072711d4388c62e02480fabc26c68c24494be9', + 'https://chromium.googlesource.com/infra/luci/client-py.git@a32a1607f6093d338f756c7e7c7b4333b0c50c9c', 'src/third_party/accessibility_test_framework': { 'packages': [ @@ -350,10 +364,21 @@ deps = { }, 'src/third_party/android_ndk': { - 'url': 'https://chromium.googlesource.com/android_ndk.git@27c0a8d090c666a50e40fceb4ee5b40b1a2d3f87', + 'url': 'https://chromium.googlesource.com/android_ndk.git@401019bf85744311b26c88ced255cd53401af8b7', 'condition': 'checkout_android', }, + 'src/third_party/androidx': { + 'packages': [ + { + 'package': 'chromium/third_party/androidx', + 'version': '-umIXLPTAdxRy2iaK4QFSeOf4t7PAKglJP7ggvWhfRwC', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + 'src/third_party/android_sdk/public': { 'packages': [ { @@ -371,950 +396,200 @@ deps = { { 'package': 'chromium/third_party/android_sdk/public/patcher', 'version': 'I6FNMhrXlpB-E1lOhMlvld7xt9lBVNOO83KIluXDyA0C', - }, - { - 'package': 'chromium/third_party/android_sdk/public/platform-tools', - 'version': '8tF0AOj7Dwlv4j7_nfkhxWB0jzrvWWYjEIpirt8FIWYC', - }, - { - 'package': 'chromium/third_party/android_sdk/public/platforms/android-30', - 'version': 'YMUu9EHNZ__2Xcxl-KsaSf-dI5TMt_P62IseUVsxktMC', - }, - { - 'package': 'chromium/third_party/android_sdk/public/sources/android-29', - 'version': '4gxhM8E62bvZpQs7Q3d0DinQaW0RLCIefhXrQBFkNy8C', - }, - { - 'package': 'chromium/third_party/android_sdk/public/cmdline-tools', - 'version': 'V__2Ycej-H2-6AcXX5A3gi7sIk74SuN44PBm2uC_N1sC', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/icu4j': { - 'packages': [ - { - 'package': 'chromium/third_party/icu4j', - 'version': 'e87e5bed2b4935913ee26a3ebd0b723ee2344354', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/objenesis': { - 'packages': [ - { - 'package': 'chromium/third_party/objenesis', - 'version': 'tknDblENYi8IaJYyD6tUahUyHYZlzJ_Y74_QZSz4DpIC', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/robolectric': { - 'packages': [ - { - 'package': 
'chromium/third_party/robolectric', - 'version': 'iC6RDM5EH3GEAzR-1shW_Mg0FeeNE5shq1okkFfuuNQC', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/sqlite4java': { - 'packages': [ - { - 'package': 'chromium/third_party/sqlite4java', - 'version': 'LofjKH9dgXIAJhRYCPQlMFywSwxYimrfDeBmaHc-Z5EC', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/turbine': { - 'packages': [ - { - 'package': 'chromium/third_party/turbine', - 'version': '_iPtB_ThhxlMOt2TsYqVppwriEEn0mp-NUNRwDwYLUAC', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/turbine/src': { - 'url': 'https://chromium.googlesource.com/external/github.com/google/turbine.git' + '@' + '3c31e67ae25b5e43713fd868e3a9b535ff6298af', - 'condition': 'checkout_android', - }, - - 'src/third_party/xstream': { - 'packages': [ - { - 'package': 'chromium/third_party/xstream', - 'version': '4278b1b78b86ab7a1a29e64d5aec9a47a9aab0fe', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/tools/luci-go': { - 'packages': [ - { - 'package': 'infra/tools/luci/isolate/${{platform}}', - 'version': 'git_revision:77944aa535e42e29faadf6cfa81aee252807d468', - }, - { - 'package': 'infra/tools/luci/isolated/${{platform}}', - 'version': 'git_revision:77944aa535e42e29faadf6cfa81aee252807d468', - }, - { - 'package': 'infra/tools/luci/swarming/${{platform}}', - 'version': 'git_revision:77944aa535e42e29faadf6cfa81aee252807d468', - }, - ], - 'dep_type': 'cipd', - }, - - # Everything coming after this is automatically updated by the auto-roller. - # === ANDROID_DEPS Generated Code Start === - # Generated by //third_party/android_deps/fetch_all.py - 'src/third_party/android_deps/libs/android_arch_core_common': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_core_common', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_core_runtime': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_core_runtime', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_lifecycle_common': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_common', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_lifecycle_common_java8': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_common_java8', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_lifecycle_livedata': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_livedata', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_lifecycle_livedata_core': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_livedata_core', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_lifecycle_runtime': { - 'packages': [ - { - 
'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_runtime', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/android_arch_lifecycle_viewmodel': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_viewmodel', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_activity_activity': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_activity_activity', - 'version': 'version:1.1.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_annotation_annotation': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_annotation_annotation', - 'version': 'version:1.2.0-alpha01-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_annotation_annotation_experimental': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_annotation_annotation_experimental', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_appcompat_appcompat': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_appcompat_appcompat', - 'version': 'version:1.2.0-beta01-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_appcompat_appcompat_resources': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_appcompat_appcompat_resources', - 'version': 'version:1.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_arch_core_core_common': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_arch_core_core_common', - 'version': 'version:2.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_arch_core_core_runtime': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_arch_core_core_runtime', - 'version': 'version:2.1.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_asynclayoutinflater_asynclayoutinflater': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_asynclayoutinflater_asynclayoutinflater', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_cardview_cardview': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_cardview_cardview', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_collection_collection': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_collection_collection', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_concurrent_concurrent_futures': { - 
'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_concurrent_concurrent_futures', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_coordinatorlayout_coordinatorlayout': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_coordinatorlayout_coordinatorlayout', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_core_core': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_core_core', - 'version': 'version:1.5.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_core_core_animation': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_core_core_animation', - 'version': 'version:1.0.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_cursoradapter_cursoradapter': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_cursoradapter_cursoradapter', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_customview_customview': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_customview_customview', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_documentfile_documentfile': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_documentfile_documentfile', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_drawerlayout_drawerlayout': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_drawerlayout_drawerlayout', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_exifinterface_exifinterface': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_exifinterface_exifinterface', - 'version': 'version:1.4.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_fragment_fragment': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_fragment_fragment', - 'version': 'version:1.2.5-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_gridlayout_gridlayout': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_gridlayout_gridlayout', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_interpolator_interpolator': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_interpolator_interpolator', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_leanback_leanback': { 
- 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_leanback_leanback', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_leanback_leanback_preference': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_leanback_leanback_preference', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_legacy_legacy_preference_v14': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_legacy_legacy_preference_v14', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_legacy_legacy_support_core_ui': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_legacy_legacy_support_core_ui', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_legacy_legacy_support_core_utils': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_legacy_legacy_support_core_utils', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_legacy_legacy_support_v4': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_legacy_legacy_support_v4', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_common': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_common', - 'version': 'version:2.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_common_java8': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_common_java8', - 'version': 'version:2.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_livedata': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_livedata', - 'version': 'version:2.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_livedata_core': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_livedata_core', - 'version': 'version:2.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_runtime': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_runtime', - 'version': 'version:2.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_viewmodel': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_viewmodel', - 'version': 'version:2.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 
'src/third_party/android_deps/libs/androidx_lifecycle_lifecycle_viewmodel_savedstate': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_lifecycle_lifecycle_viewmodel_savedstate', - 'version': 'version:2.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_loader_loader': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_loader_loader', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_localbroadcastmanager_localbroadcastmanager': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_localbroadcastmanager_localbroadcastmanager', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_media_media': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_media_media', - 'version': 'version:1.3.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_mediarouter_mediarouter': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_mediarouter_mediarouter', - 'version': 'version:1.3.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_multidex_multidex': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_multidex_multidex', - 'version': 'version:2.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_palette_palette': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_palette_palette', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_preference_preference': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_preference_preference', - 'version': 'version:1.1.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_print_print': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_print_print', - 'version': 'version:1.1.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_recyclerview_recyclerview': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_recyclerview_recyclerview', - 'version': 'version:1.2.0-alpha06-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_savedstate_savedstate': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_savedstate_savedstate', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_slice_slice_builders': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_slice_slice_builders', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 
'src/third_party/android_deps/libs/androidx_slice_slice_core': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_slice_slice_core', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_slidingpanelayout_slidingpanelayout': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_slidingpanelayout_slidingpanelayout', - 'version': 'version:1.0.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_swiperefreshlayout_swiperefreshlayout': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_swiperefreshlayout_swiperefreshlayout', - 'version': 'version:1.2.0-SNAPSHOT-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_test_core': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_core', - 'version': 'version:1.2.0-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_test_espresso_espresso_contrib': { - 'packages': [ + }, { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_espresso_espresso_contrib', - 'version': 'version:3.2.0-cr0', + 'package': 'chromium/third_party/android_sdk/public/platform-tools', + 'version': '8tF0AOj7Dwlv4j7_nfkhxWB0jzrvWWYjEIpirt8FIWYC', }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_test_espresso_espresso_core': { - 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_espresso_espresso_core', - 'version': 'version:3.2.0-cr0', + 'package': 'chromium/third_party/android_sdk/public/platforms/android-30', + 'version': 'YMUu9EHNZ__2Xcxl-KsaSf-dI5TMt_P62IseUVsxktMC', + }, + { + 'package': 'chromium/third_party/android_sdk/public/sources/android-29', + 'version': '4gxhM8E62bvZpQs7Q3d0DinQaW0RLCIefhXrQBFkNy8C', }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/androidx_test_espresso_espresso_idling_resource': { - 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_espresso_espresso_idling_resource', - 'version': 'version:3.2.0-cr0', + 'package': 'chromium/third_party/android_sdk/public/cmdline-tools', + 'version': 'V__2Ycej-H2-6AcXX5A3gi7sIk74SuN44PBm2uC_N1sC', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_espresso_espresso_intents': { + 'src/third_party/icu4j': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_espresso_espresso_intents', - 'version': 'version:3.2.0-cr0', + 'package': 'chromium/third_party/icu4j', + 'version': 'e87e5bed2b4935913ee26a3ebd0b723ee2344354', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_espresso_espresso_web': { + 'src/third_party/objenesis': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_espresso_espresso_web', - 'version': 'version:3.2.0-cr0', + 'package': 'chromium/third_party/objenesis', + 'version': 'tknDblENYi8IaJYyD6tUahUyHYZlzJ_Y74_QZSz4DpIC', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_ext_junit': { + 
'src/third_party/robolectric': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_ext_junit', - 'version': 'version:1.1.1-cr0', + 'package': 'chromium/third_party/robolectric', + 'version': 'iC6RDM5EH3GEAzR-1shW_Mg0FeeNE5shq1okkFfuuNQC', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_monitor': { + 'src/third_party/sqlite4java': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_monitor', - 'version': 'version:1.2.0-cr0', + 'package': 'chromium/third_party/sqlite4java', + 'version': 'LofjKH9dgXIAJhRYCPQlMFywSwxYimrfDeBmaHc-Z5EC', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_rules': { + 'src/third_party/turbine': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_rules', - 'version': 'version:1.2.0-cr0', + 'package': 'chromium/third_party/turbine', + 'version': 'Om6yIEXgJxuqghErK29h9RcMH6VaymMbxwScwXmcN6EC', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_runner': { + 'src/tools/luci-go': { 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_runner', - 'version': 'version:1.2.0-cr0', - }, + { + 'package': 'infra/tools/luci/isolate/${{platform}}', + 'version': 'git_revision:2ac8bd9cbc20824bb04a39b0f1b77178ace930b3', + }, + { + 'package': 'infra/tools/luci/isolated/${{platform}}', + 'version': 'git_revision:2ac8bd9cbc20824bb04a39b0f1b77178ace930b3', + }, + { + 'package': 'infra/tools/luci/swarming/${{platform}}', + 'version': 'git_revision:2ac8bd9cbc20824bb04a39b0f1b77178ace930b3', + }, ], - 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_test_uiautomator_uiautomator': { + # TODO(crbug.com/1184780) Move this back to ANDROID_DEPS Generated Code + # section once org_robolectric_shadows_multidex is updated to a new version + # that does not need jetify. + 'src/third_party/android_deps/libs/org_robolectric_shadows_multidex': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_test_uiautomator_uiautomator', - 'version': 'version:2.2.0-cr0', + 'package': 'chromium/third_party/android_deps/libs/org_robolectric_shadows_multidex', + 'version': 'version:4.3.1-cr1', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_transition_transition': { + # Everything coming after this is automatically updated by the auto-roller. 
+ # === ANDROID_DEPS Generated Code Start === + # Generated by //third_party/android_deps/fetch_all.py + 'src/third_party/android_deps/libs/android_arch_core_common': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_transition_transition', - 'version': 'version:1.4.0-SNAPSHOT-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_core_common', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_tvprovider_tvprovider': { + 'src/third_party/android_deps/libs/android_arch_core_runtime': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_tvprovider_tvprovider', - 'version': 'version:1.1.0-SNAPSHOT-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_core_runtime', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_vectordrawable_vectordrawable': { + 'src/third_party/android_deps/libs/android_arch_lifecycle_common': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_vectordrawable_vectordrawable', - 'version': 'version:1.2.0-SNAPSHOT-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_common', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_vectordrawable_vectordrawable_animated': { + 'src/third_party/android_deps/libs/android_arch_lifecycle_common_java8': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_vectordrawable_vectordrawable_animated', - 'version': 'version:1.2.0-SNAPSHOT-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_common_java8', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_versionedparcelable_versionedparcelable': { + 'src/third_party/android_deps/libs/android_arch_lifecycle_livedata': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_versionedparcelable_versionedparcelable', - 'version': 'version:1.1.0-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_livedata', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_viewpager2_viewpager2': { + 'src/third_party/android_deps/libs/android_arch_lifecycle_livedata_core': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_viewpager2_viewpager2', - 'version': 'version:1.1.0-SNAPSHOT-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_livedata_core', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_viewpager_viewpager': { + 'src/third_party/android_deps/libs/android_arch_lifecycle_runtime': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_viewpager_viewpager', - 'version': 'version:1.1.0-SNAPSHOT-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_runtime', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/androidx_webkit_webkit': { + 'src/third_party/android_deps/libs/android_arch_lifecycle_viewmodel': { 
'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/androidx_webkit_webkit', - 'version': 'version:1.3.0-rc01-cr0', + 'package': 'chromium/third_party/android_deps/libs/android_arch_lifecycle_viewmodel', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', @@ -1325,7 +600,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/backport_util_concurrent_backport_util_concurrent', - 'version': 'version:3.1-cr0', + 'version': 'version:2@3.1.cr0', }, ], 'condition': 'checkout_android', @@ -1336,7 +611,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/classworlds_classworlds', - 'version': 'version:1.1-alpha-2-cr0', + 'version': 'version:2@1.1-alpha-2.cr0', }, ], 'condition': 'checkout_android', @@ -1347,7 +622,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_animated_vector_drawable', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1358,7 +633,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_appcompat_v7', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1369,7 +644,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_asynclayoutinflater', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1380,7 +655,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_cardview_v7', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1391,7 +666,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_collections', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1402,7 +677,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_coordinatorlayout', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1413,7 +688,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_cursoradapter', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1424,7 +699,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_customview', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1435,7 +710,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_design', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1446,7 +721,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_documentfile', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1457,7 +732,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_drawerlayout', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1468,7 +743,7 @@ deps = { 'packages': [ { 'package': 
'chromium/third_party/android_deps/libs/com_android_support_interpolator', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1479,7 +754,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_loader', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1490,7 +765,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_localbroadcastmanager', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1501,7 +776,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_multidex', - 'version': 'version:1.0.0-cr0', + 'version': 'version:2@1.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1512,7 +787,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_print', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1523,7 +798,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_recyclerview_v7', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1534,7 +809,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_slidingpanelayout', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1545,7 +820,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_annotations', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1556,7 +831,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_compat', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1567,7 +842,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_core_ui', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1578,7 +853,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_core_utils', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1589,7 +864,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_fragment', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1600,7 +875,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_media_compat', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1611,7 +886,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_v4', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1622,7 +897,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_support_vector_drawable', - 'version': 'version:28.0.0-cr0', + 
'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1633,7 +908,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_swiperefreshlayout', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1644,7 +919,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_transition', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1655,7 +930,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_versionedparcelable', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1666,51 +941,62 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_android_support_viewpager', - 'version': 'version:28.0.0-cr0', + 'version': 'version:2@28.0.0.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/com_android_tools_build_jetifier_jetifier_core': { + 'src/third_party/android_deps/libs/com_android_tools_common': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/com_android_tools_build_jetifier_jetifier_core', - 'version': 'version:1.0.0-beta08-cr0', + 'package': 'chromium/third_party/android_deps/libs/com_android_tools_common', + 'version': 'version:2@30.0.0-alpha10.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/com_android_tools_build_jetifier_jetifier_processor': { + 'src/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/com_android_tools_build_jetifier_jetifier_processor', - 'version': 'version:1.0.0-beta08-cr0', + 'package': 'chromium/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs': { + 'src/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs_configuration': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs', - 'version': 'version:1.0.10-cr0', + 'package': 'chromium/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs_configuration', + 'version': 'version:2@1.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs_configuration': { + 'src/third_party/android_deps/libs/com_android_tools_layoutlib_layoutlib_api': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/com_android_tools_desugar_jdk_libs_configuration', - 'version': 'version:1.0.10-cr0', + 'package': 'chromium/third_party/android_deps/libs/com_android_tools_layoutlib_layoutlib_api', + 'version': 'version:2@30.0.0-alpha10.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_android_tools_sdk_common': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_android_tools_sdk_common', + 'version': 'version:2@30.0.0-alpha10.cr0', }, ], 'condition': 'checkout_android', @@ -1721,7 +1007,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_github_ben_manes_caffeine_caffeine', - 
'version': 'version:2.8.0-cr0', + 'version': 'version:2@2.8.8.cr0', }, ], 'condition': 'checkout_android', @@ -1732,7 +1018,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_github_kevinstern_software_and_algorithms', - 'version': 'version:1.0-cr0', + 'version': 'version:2@1.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_android_datatransport_transport_api': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_android_datatransport_transport_api', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -1743,7 +1040,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_auth', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1754,7 +1051,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_auth_api_phone', - 'version': 'version:17.5.0-cr0', + 'version': 'version:2@17.5.0.cr0', }, ], 'condition': 'checkout_android', @@ -1765,7 +1062,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_auth_base', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1776,7 +1073,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_base', - 'version': 'version:17.1.0-cr0', + 'version': 'version:2@17.5.0.cr0', }, ], 'condition': 'checkout_android', @@ -1787,7 +1084,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_basement', - 'version': 'version:17.1.0-cr0', + 'version': 'version:2@17.5.0.cr0', }, ], 'condition': 'checkout_android', @@ -1798,7 +1095,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_cast', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1809,7 +1106,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_cast_framework', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1820,7 +1117,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_clearcut', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_android_gms_play_services_cloud_messaging': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_cloud_messaging', + 'version': 'version:2@16.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1831,7 +1139,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_fido', - 'version': 'version:18.1.0-cr0', + 'version': 'version:2@19.0.0-beta.cr0', }, ], 'condition': 'checkout_android', @@ -1842,7 +1150,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_flags', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 
'condition': 'checkout_android', @@ -1853,7 +1161,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_gcm', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1864,7 +1172,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_iid', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1875,7 +1183,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_instantapps', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1886,7 +1194,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_location', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1897,7 +1205,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_phenotype', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1908,7 +1216,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_places_placereport', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1919,7 +1227,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_stats', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1930,7 +1238,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_tasks', - 'version': 'version:17.0.0-cr0', + 'version': 'version:2@17.2.0.cr0', }, ], 'condition': 'checkout_android', @@ -1941,7 +1249,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_vision', - 'version': 'version:18.0.0-cr0', + 'version': 'version:2@18.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1952,7 +1260,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_gms_play_services_vision_common', - 'version': 'version:18.0.0-cr0', + 'version': 'version:2@18.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -1963,7 +1271,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_android_material_material', - 'version': 'version:1.2.0-alpha06-cr0', + 'version': 'version:2@1.4.0-rc01.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_android_play_core': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_android_play_core', + 'version': 'version:2@1.10.0.cr0', }, ], 'condition': 'checkout_android', @@ -1974,7 +1293,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_auto_auto_common', - 'version': 'version:0.10-cr0', + 'version': 'version:2@0.10.cr0', }, ], 'condition': 'checkout_android', @@ -1985,7 +1304,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_auto_service_auto_service', - 'version': 
'version:1.0-rc6-cr0', + 'version': 'version:2@1.0-rc6.cr0', }, ], 'condition': 'checkout_android', @@ -1996,7 +1315,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_auto_service_auto_service_annotations', - 'version': 'version:1.0-rc6-cr0', + 'version': 'version:2@1.0-rc6.cr0', }, ], 'condition': 'checkout_android', @@ -2007,18 +1326,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_auto_value_auto_value_annotations', - 'version': 'version:1.7-cr0', + 'version': 'version:2@1.7.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/com_google_code_findbugs_jFormatString': { + 'src/third_party/android_deps/libs/com_google_code_findbugs_jformatstring': { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_code_findbugs_jformatstring', - 'version': 'version:3.0.0-cr0', + 'version': 'version:2@3.0.0.cr0', }, ], 'condition': 'checkout_android', @@ -2029,7 +1348,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_code_findbugs_jsr305', - 'version': 'version:3.0.2-cr0', + 'version': 'version:2@3.0.2.cr0', }, ], 'condition': 'checkout_android', @@ -2040,7 +1359,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_code_gson_gson', - 'version': 'version:2.8.0-cr0', + 'version': 'version:2@2.8.0.cr0', }, ], 'condition': 'checkout_android', @@ -2051,7 +1370,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_dagger_dagger', - 'version': 'version:2.30-cr0', + 'version': 'version:2@2.30.cr0', }, ], 'condition': 'checkout_android', @@ -2062,7 +1381,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_dagger_dagger_compiler', - 'version': 'version:2.30-cr0', + 'version': 'version:2@2.30.cr0', }, ], 'condition': 'checkout_android', @@ -2073,7 +1392,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_dagger_dagger_producers', - 'version': 'version:2.30-cr0', + 'version': 'version:2@2.30.cr0', }, ], 'condition': 'checkout_android', @@ -2084,7 +1403,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_dagger_dagger_spi', - 'version': 'version:2.30-cr0', + 'version': 'version:2@2.30.cr0', }, ], 'condition': 'checkout_android', @@ -2095,7 +1414,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_errorprone_error_prone_annotation', - 'version': 'version:2.4.0-cr0', + 'version': 'version:2@2.7.1.cr0', }, ], 'condition': 'checkout_android', @@ -2106,7 +1425,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_errorprone_error_prone_annotations', - 'version': 'version:2.4.0-cr0', + 'version': 'version:2@2.7.1.cr0', }, ], 'condition': 'checkout_android', @@ -2117,7 +1436,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_errorprone_error_prone_check_api', - 'version': 'version:2.4.0-cr0', + 'version': 'version:2@2.7.1.cr0', }, ], 'condition': 'checkout_android', @@ -2128,7 +1447,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_errorprone_error_prone_core', - 'version': 'version:2.4.0-cr0', + 'version': 'version:2@2.7.1.cr0', }, ], 'condition': 'checkout_android', @@ -2139,7 +1458,7 @@ deps = { 'packages': [ { 'package': 
'chromium/third_party/android_deps/libs/com_google_errorprone_error_prone_type_annotations', - 'version': 'version:2.4.0-cr0', + 'version': 'version:2@2.7.1.cr0', }, ], 'condition': 'checkout_android', @@ -2150,7 +1469,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_errorprone_javac', - 'version': 'version:9+181-r4173-1-cr0', + 'version': 'version:2@9+181-r4173-1.cr0', }, ], 'condition': 'checkout_android', @@ -2161,7 +1480,128 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_errorprone_javac_shaded', - 'version': 'version:9-dev-r4023-3-cr0', + 'version': 'version:2@9-dev-r4023-3.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_annotations': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_annotations', + 'version': 'version:2@16.0.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_common': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_common', + 'version': 'version:2@19.5.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_components': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_components', + 'version': 'version:2@16.1.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_encoders': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_encoders', + 'version': 'version:2@16.1.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_encoders_json': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_encoders_json', + 'version': 'version:2@17.1.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_iid': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_iid', + 'version': 'version:2@21.0.1.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_iid_interop': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_iid_interop', + 'version': 'version:2@17.0.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_installations': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_installations', + 'version': 'version:2@16.3.5.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_installations_interop': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_installations_interop', + 'version': 'version:2@16.0.1.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 
'src/third_party/android_deps/libs/com_google_firebase_firebase_measurement_connector': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_measurement_connector', + 'version': 'version:2@18.0.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/com_google_firebase_firebase_messaging': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/com_google_firebase_firebase_messaging', + 'version': 'version:2@21.0.1.cr0', }, ], 'condition': 'checkout_android', @@ -2172,7 +1612,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_googlejavaformat_google_java_format', - 'version': 'version:1.5-cr0', + 'version': 'version:2@1.5.cr0', }, ], 'condition': 'checkout_android', @@ -2183,7 +1623,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_guava_failureaccess', - 'version': 'version:1.0.1-cr0', + 'version': 'version:2@1.0.1.cr0', }, ], 'condition': 'checkout_android', @@ -2194,7 +1634,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_guava_guava', - 'version': 'version:30.1-jre-cr0', + 'version': 'version:2@30.1-jre.cr0', }, ], 'condition': 'checkout_android', @@ -2205,7 +1645,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_guava_guava_android', - 'version': 'version:30.1-android-cr0', + 'version': 'version:2@30.1-android.cr0', }, ], 'condition': 'checkout_android', @@ -2216,7 +1656,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_guava_listenablefuture', - 'version': 'version:1.0-cr0', + 'version': 'version:2@1.0.cr0', }, ], 'condition': 'checkout_android', @@ -2227,7 +1667,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_j2objc_j2objc_annotations', - 'version': 'version:1.3-cr0', + 'version': 'version:2@1.3.cr0', }, ], 'condition': 'checkout_android', @@ -2238,7 +1678,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_protobuf_protobuf_java', - 'version': 'version:3.4.0-cr0', + 'version': 'version:2@3.4.0.cr0', }, ], 'condition': 'checkout_android', @@ -2249,7 +1689,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_google_protobuf_protobuf_javalite', - 'version': 'version:3.13.0-cr0', + 'version': 'version:2@3.13.0.cr0', }, ], 'condition': 'checkout_android', @@ -2260,7 +1700,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_googlecode_java_diff_utils_diffutils', - 'version': 'version:1.3.0-cr0', + 'version': 'version:2@1.3.0.cr0', }, ], 'condition': 'checkout_android', @@ -2271,7 +1711,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_squareup_javapoet', - 'version': 'version:1.13.0-cr0', + 'version': 'version:2@1.13.0.cr0', }, ], 'condition': 'checkout_android', @@ -2282,18 +1722,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/com_squareup_javawriter', - 'version': 'version:2.1.1-cr0', + 'version': 'version:2@2.1.1.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/commons_cli_commons_cli': { + 'src/third_party/android_deps/libs/io_github_java_diff_utils_java_diff_utils': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/commons_cli_commons_cli', - 'version': 
'version:1.3.1-cr0', + 'package': 'chromium/third_party/android_deps/libs/io_github_java_diff_utils_java_diff_utils', + 'version': 'version:2@4.0.cr0', }, ], 'condition': 'checkout_android', @@ -2304,7 +1744,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/javax_annotation_javax_annotation_api', - 'version': 'version:1.3.2-cr0', + 'version': 'version:2@1.3.2.cr0', }, ], 'condition': 'checkout_android', @@ -2315,7 +1755,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/javax_annotation_jsr250_api', - 'version': 'version:1.0-cr0', + 'version': 'version:2@1.0.cr0', }, ], 'condition': 'checkout_android', @@ -2326,7 +1766,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/javax_inject_javax_inject', - 'version': 'version:1-cr0', + 'version': 'version:2@1.cr0', }, ], 'condition': 'checkout_android', @@ -2337,18 +1777,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/nekohtml_nekohtml', - 'version': 'version:1.9.6.2-cr0', + 'version': 'version:2@1.9.6.2.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/nekohtml_xercesMinimal': { + 'src/third_party/android_deps/libs/nekohtml_xercesminimal': { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/nekohtml_xercesminimal', - 'version': 'version:1.9.6.2-cr0', + 'version': 'version:2@1.9.6.2.cr0', }, ], 'condition': 'checkout_android', @@ -2359,7 +1799,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/net_ltgt_gradle_incap_incap', - 'version': 'version:0.2-cr0', + 'version': 'version:2@0.2.cr0', }, ], 'condition': 'checkout_android', @@ -2370,7 +1810,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/net_sf_kxml_kxml2', - 'version': 'version:2.3.0-cr0', + 'version': 'version:2@2.3.0.cr0', }, ], 'condition': 'checkout_android', @@ -2381,7 +1821,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_ant_ant', - 'version': 'version:1.8.0-cr0', + 'version': 'version:2@1.8.0.cr0', }, ], 'condition': 'checkout_android', @@ -2392,7 +1832,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_ant_ant_launcher', - 'version': 'version:1.8.0-cr0', + 'version': 'version:2@1.8.0.cr0', }, ], 'condition': 'checkout_android', @@ -2403,7 +1843,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_ant_tasks', - 'version': 'version:2.1.3-cr0', + 'version': 'version:2@2.1.3.cr0', }, ], 'condition': 'checkout_android', @@ -2414,7 +1854,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_artifact', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2425,7 +1865,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_artifact_manager', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2436,7 +1876,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_error_diagnostics', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2447,7 +1887,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_model', - 'version': 
'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2458,7 +1898,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_plugin_registry', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2469,7 +1909,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_profile', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2480,7 +1920,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_project', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2491,7 +1931,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_repository_metadata', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2502,7 +1942,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_maven_settings', - 'version': 'version:2.2.1-cr0', + 'version': 'version:2@2.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2513,7 +1953,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_wagon_wagon_file', - 'version': 'version:1.0-beta-6-cr0', + 'version': 'version:2@1.0-beta-6.cr0', }, ], 'condition': 'checkout_android', @@ -2524,7 +1964,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_wagon_wagon_http_lightweight', - 'version': 'version:1.0-beta-6-cr0', + 'version': 'version:2@1.0-beta-6.cr0', }, ], 'condition': 'checkout_android', @@ -2535,7 +1975,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_wagon_wagon_http_shared', - 'version': 'version:1.0-beta-6-cr0', + 'version': 'version:2@1.0-beta-6.cr0', }, ], 'condition': 'checkout_android', @@ -2546,7 +1986,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_apache_maven_wagon_wagon_provider_api', - 'version': 'version:1.0-beta-6-cr0', + 'version': 'version:2@1.0-beta-6.cr0', }, ], 'condition': 'checkout_android', @@ -2557,7 +1997,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_ccil_cowan_tagsoup_tagsoup', - 'version': 'version:1.2.1-cr0', + 'version': 'version:2@1.2.1.cr0', }, ], 'condition': 'checkout_android', @@ -2568,7 +2008,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_checkerframework_checker_compat_qual', - 'version': 'version:2.5.5-cr0', + 'version': 'version:2@2.5.5.cr0', }, ], 'condition': 'checkout_android', @@ -2579,7 +2019,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_checkerframework_checker_qual', - 'version': 'version:3.5.0-cr0', + 'version': 'version:2@3.8.0.cr0', }, ], 'condition': 'checkout_android', @@ -2590,7 +2030,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_checkerframework_dataflow_shaded', - 'version': 'version:3.1.2-cr0', + 'version': 'version:2@3.11.0.cr0', }, ], 'condition': 'checkout_android', @@ -2601,7 +2041,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_codehaus_mojo_animal_sniffer_annotations', - 'version': 'version:1.17-cr0', + 'version': 'version:2@1.17.cr0', }, ], 
'condition': 'checkout_android', @@ -2612,7 +2052,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_codehaus_plexus_plexus_container_default', - 'version': 'version:1.0-alpha-9-stable-1-cr0', + 'version': 'version:2@1.0-alpha-9-stable-1.cr0', }, ], 'condition': 'checkout_android', @@ -2623,7 +2063,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_codehaus_plexus_plexus_interpolation', - 'version': 'version:1.11-cr0', + 'version': 'version:2@1.11.cr0', }, ], 'condition': 'checkout_android', @@ -2634,18 +2074,18 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_codehaus_plexus_plexus_utils', - 'version': 'version:1.5.15-cr0', + 'version': 'version:2@1.5.15.cr0', }, ], 'condition': 'checkout_android', 'dep_type': 'cipd', }, - 'src/third_party/android_deps/libs/org_jdom_jdom2': { + 'src/third_party/android_deps/libs/org_eclipse_jgit_org_eclipse_jgit': { 'packages': [ { - 'package': 'chromium/third_party/android_deps/libs/org_jdom_jdom2', - 'version': 'version:2.0.6-cr0', + 'package': 'chromium/third_party/android_deps/libs/org_eclipse_jgit_org_eclipse_jgit', + 'version': 'version:2@4.4.1.201607150455-r.cr0', }, ], 'condition': 'checkout_android', @@ -2656,7 +2096,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_annotations', - 'version': 'version:13.0-cr0', + 'version': 'version:2@13.0.cr0', }, ], 'condition': 'checkout_android', @@ -2667,7 +2107,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlin_kotlin_stdlib', - 'version': 'version:1.3.72-cr0', + 'version': 'version:2@1.5.10.cr0', }, ], 'condition': 'checkout_android', @@ -2678,7 +2118,51 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlin_kotlin_stdlib_common', - 'version': 'version:1.3.72-cr0', + 'version': 'version:2@1.5.10.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/org_jetbrains_kotlin_kotlin_stdlib_jdk7': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlin_kotlin_stdlib_jdk7', + 'version': 'version:2@1.5.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/org_jetbrains_kotlin_kotlin_stdlib_jdk8': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlin_kotlin_stdlib_jdk8', + 'version': 'version:2@1.5.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/org_jetbrains_kotlinx_kotlinx_coroutines_android': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlinx_kotlinx_coroutines_android', + 'version': 'version:2@1.5.0.cr0', + }, + ], + 'condition': 'checkout_android', + 'dep_type': 'cipd', + }, + + 'src/third_party/android_deps/libs/org_jetbrains_kotlinx_kotlinx_coroutines_core_jvm': { + 'packages': [ + { + 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlinx_kotlinx_coroutines_core_jvm', + 'version': 'version:2@1.5.0.cr0', }, ], 'condition': 'checkout_android', @@ -2689,7 +2173,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_jetbrains_kotlinx_kotlinx_metadata_jvm', - 'version': 'version:0.1.0-cr0', + 'version': 'version:2@0.1.0.cr0', }, ], 'condition': 'checkout_android', @@ -2700,7 +2184,7 @@ deps = { 'packages': 
[ { 'package': 'chromium/third_party/android_deps/libs/org_ow2_asm_asm', - 'version': 'version:7.0-cr0', + 'version': 'version:2@7.0.cr0', }, ], 'condition': 'checkout_android', @@ -2711,7 +2195,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_ow2_asm_asm_analysis', - 'version': 'version:7.0-cr0', + 'version': 'version:2@7.0.cr0', }, ], 'condition': 'checkout_android', @@ -2722,7 +2206,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_ow2_asm_asm_commons', - 'version': 'version:7.0-cr0', + 'version': 'version:2@7.0.cr0', }, ], 'condition': 'checkout_android', @@ -2733,7 +2217,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_ow2_asm_asm_tree', - 'version': 'version:7.0-cr0', + 'version': 'version:2@7.0.cr0', }, ], 'condition': 'checkout_android', @@ -2744,7 +2228,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_ow2_asm_asm_util', - 'version': 'version:7.0-cr0', + 'version': 'version:2@7.0.cr0', }, ], 'condition': 'checkout_android', @@ -2755,7 +2239,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_pcollections_pcollections', - 'version': 'version:2.1.2-cr0', + 'version': 'version:2@2.1.2.cr0', }, ], 'condition': 'checkout_android', @@ -2766,7 +2250,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_annotations', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2777,7 +2261,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_junit', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2788,7 +2272,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_pluginapi', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2799,7 +2283,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_plugins_maven_dependency_resolver', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2810,7 +2294,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_resources', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2821,7 +2305,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_robolectric', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2832,7 +2316,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_sandbox', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2843,7 +2327,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_shadowapi', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2854,18 +2338,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_shadows_framework', - 'version': 'version:4.3.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/org_robolectric_shadows_multidex': { - 
'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/org_robolectric_shadows_multidex', - 'version': 'version:4.3.1-cr1', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2876,7 +2349,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_shadows_playservices', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2887,7 +2360,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_utils', - 'version': 'version:4.3.1-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -2898,18 +2371,7 @@ deps = { 'packages': [ { 'package': 'chromium/third_party/android_deps/libs/org_robolectric_utils_reflector', - 'version': 'version:4.3.1-cr0', - }, - ], - 'condition': 'checkout_android', - 'dep_type': 'cipd', - }, - - 'src/third_party/android_deps/libs/org_threeten_threeten_extra': { - 'packages': [ - { - 'package': 'chromium/third_party/android_deps/libs/org_threeten_threeten_extra', - 'version': 'version:1.5.0-cr0', + 'version': 'version:2@4.3.1.cr0', }, ], 'condition': 'checkout_android', @@ -3145,6 +2607,16 @@ hooks = [ '--bucket', 'chromium-webrtc-resources', 'src/resources'], }, + { + 'name': 'Generate component metadata for tests', + 'pattern': '.', + 'action': [ + 'vpython', + 'src/testing/generate_location_tags.py', + '--out', + 'src/testing/location_tags.json', + ], + }, # Download and initialize "vpython" VirtualEnv environment packages. { 'name': 'vpython_common', @@ -3185,6 +2657,7 @@ include_rules = [ "+absl/base/const_init.h", "+absl/base/macros.h", "+absl/container/inlined_vector.h", + "+absl/functional/bind_front.h", "+absl/memory/memory.h", "+absl/meta/type_traits.h", "+absl/strings/ascii.h", diff --git a/DIR_METADATA b/DIR_METADATA new file mode 100644 index 0000000000..a002d0947f --- /dev/null +++ b/DIR_METADATA @@ -0,0 +1,3 @@ +monorail { + project: "webrtc" +} diff --git a/OWNERS b/OWNERS index cdd8ffc0ad..587c130ed7 100644 --- a/OWNERS +++ b/OWNERS @@ -1,4 +1,5 @@ henrika@webrtc.org +hta@webrtc.org juberti@webrtc.org mflodman@webrtc.org stefan@webrtc.org @@ -7,13 +8,10 @@ per-file .gitignore=* per-file .gn=mbonadei@webrtc.org per-file *.gn=mbonadei@webrtc.org per-file *.gni=mbonadei@webrtc.org +per-file .vpython=mbonadei@webrtc.org per-file AUTHORS=* per-file DEPS=* -per-file pylintrc=phoglund@webrtc.org +per-file pylintrc=mbonadei@webrtc.org per-file WATCHLISTS=* -per-file abseil-in-webrtc.md=danilchap@webrtc.org -per-file abseil-in-webrtc.md=mbonadei@webrtc.org -per-file style-guide.md=danilchap@webrtc.org per-file native-api.md=mbonadei@webrtc.org - -# COMPONENT: Internals>WebRTC +per-file *.lua=titovartem@webrtc.org diff --git a/PRESUBMIT.py b/PRESUBMIT.py index 12f87d7ff1..21875f61af 100755 --- a/PRESUBMIT.py +++ b/PRESUBMIT.py @@ -21,7 +21,7 @@ 'examples/objc', 'media/base/stream_params.h', 'media/base/video_common.h', - 'media/sctp/sctp_transport.cc', + 'media/sctp/usrsctp_transport.cc', 'modules/audio_coding', 'modules/audio_device', 'modules/audio_processing', @@ -146,9 +146,9 @@ def VerifyNativeApiHeadersListIsValid(input_api, output_api): if non_existing_paths: return [ output_api.PresubmitError( - 'Directories to native API headers have changed which has made the ' - 'list in PRESUBMIT.py outdated.\nPlease update it to the current ' - 'location of our native APIs.', non_existing_paths) + 'Directories to native API headers have changed which has made ' 
+ 'the list in PRESUBMIT.py outdated.\nPlease update it to the ' + 'current location of our native APIs.', non_existing_paths) ] return [] @@ -157,7 +157,7 @@ def VerifyNativeApiHeadersListIsValid(input_api, output_api): You seem to be changing native API header files. Please make sure that you: 1. Make compatible changes that don't break existing clients. Usually this is done by keeping the existing method signatures unchanged. - 2. Mark the old stuff as deprecated (see RTC_DEPRECATED macro). + 2. Mark the old stuff as deprecated (use the ABSL_DEPRECATED macro). 3. Create a timeline and plan for when the deprecated stuff will be removed. (The amount of time we give users to change their code should be informed by how much work it is for them. If they just @@ -212,10 +212,10 @@ def CheckNoIOStreamInHeaders(input_api, output_api, source_file_filter): if len(files): return [ output_api.PresubmitError( - 'Do not #include in header files, since it inserts static ' - + - 'initialization into every file including the header. Instead, ' - + '#include . See http://crbug.com/94794', files) + 'Do not #include in header files, since it inserts ' + 'static initialization into every file including the header. ' + 'Instead, #include . See http://crbug.com/94794', + files) ] return [] @@ -237,15 +237,15 @@ def CheckNoPragmaOnce(input_api, output_api, source_file_filter): return [ output_api.PresubmitError( 'Do not use #pragma once in header files.\n' - 'See http://www.chromium.org/developers/coding-style#TOC-File-headers', + 'See http://www.chromium.org/developers/coding-style' + '#TOC-File-headers', files) ] return [] - -def CheckNoFRIEND_TEST( +def CheckNoFRIEND_TEST(# pylint: disable=invalid-name input_api, - output_api, # pylint: disable=invalid-name + output_api, source_file_filter): """Make sure that gtest's FRIEND_TEST() macro is not used, the FRIEND_TEST_ALL_PREFIXES() macro from testsupport/gtest_prod_util.h should be @@ -263,9 +263,9 @@ def CheckNoFRIEND_TEST( return [] return [ output_api.PresubmitPromptWarning( - 'WebRTC\'s code should not use ' - 'gtest\'s FRIEND_TEST() macro. Include testsupport/gtest_prod_util.h and ' - 'use FRIEND_TEST_ALL_PREFIXES() instead.\n' + '\n'.join(problems)) + 'WebRTC\'s code should not use gtest\'s FRIEND_TEST() macro. ' + 'Include testsupport/gtest_prod_util.h and use ' + 'FRIEND_TEST_ALL_PREFIXES() instead.\n' + '\n'.join(problems)) ] @@ -346,9 +346,9 @@ def CheckNoSourcesAbove(input_api, gn_files, output_api): if violating_gn_files: return [ output_api.PresubmitError( - 'Referencing source files above the directory of the GN file is not ' - 'allowed. Please introduce new GN targets in the proper location ' - 'instead.\n' + 'Referencing source files above the directory of the GN file ' + 'is not allowed. Please introduce new GN targets in the proper ' + 'location instead.\n' 'Invalid source entries:\n' '%s\n' 'Violating GN files:' % '\n'.join(violating_source_entries), @@ -407,9 +407,9 @@ def _MoreThanOneSourceUsed(*sources_lists): gn_file_content = input_api.ReadFile(gn_file) for target_match in TARGET_RE.finditer(gn_file_content): # list_of_sources is a list of tuples of the form - # (c_files, cc_files, objc_files) that keeps track of all the sources - # defined in a target. A GN target can have more that on definition of - # sources (since it supports if/else statements). + # (c_files, cc_files, objc_files) that keeps track of all the + # sources defined in a target. 
A GN target can have more that + # on definition of sources (since it supports if/else statements). # E.g.: # rtc_static_library("foo") { # if (is_win) { @@ -454,7 +454,8 @@ def _MoreThanOneSourceUsed(*sources_lists): return [ output_api.PresubmitError( 'GN targets cannot mix .c, .cc and .m (or .mm) source files.\n' - 'Please create a separate target for each collection of sources.\n' + 'Please create a separate target for each collection of ' + 'sources.\n' 'Mixed sources: \n' '%s\n' 'Violating GN files:\n%s\n' % @@ -476,8 +477,8 @@ def CheckNoPackageBoundaryViolations(input_api, gn_files, output_api): if errors: return [ output_api.PresubmitError( - 'There are package boundary violations in the following GN files:', - long_text='\n\n'.join(str(err) for err in errors)) + 'There are package boundary violations in the following GN ' + 'files:', long_text='\n\n'.join(str(err) for err in errors)) ] return [] @@ -491,7 +492,7 @@ def CheckNoWarningSuppressionFlagsAreAdded(gn_files, input_api, output_api, error_formatter=_ReportFileAndLine): - """Make sure that warning suppression flags are not added wihtout a reason.""" + """Ensure warning suppression flags are not added wihtout a reason.""" msg = ('Usage of //build/config/clang:extra_warnings is discouraged ' 'in WebRTC.\n' 'If you are not adding this code (e.g. you are just moving ' @@ -674,7 +675,8 @@ def CheckGnGen(input_api, output_api): if errors: return [ output_api.PresubmitPromptWarning( - 'Some #includes do not match the build dependency graph. Please run:\n' + 'Some #includes do not match the build dependency graph. ' + 'Please run:\n' ' gn gen --check ', long_text='\n\n'.join(errors)) ] @@ -729,18 +731,20 @@ def CheckUnwantedDependencies(input_api, output_api, source_file_filter): if error_descriptions: results.append( output_api.PresubmitError( - 'You added one or more #includes that violate checkdeps rules.\n' - 'Check that the DEPS files in these locations contain valid rules.\n' - 'See https://cs.chromium.org/chromium/src/buildtools/checkdeps/ for ' - 'more details about checkdeps.', error_descriptions)) + 'You added one or more #includes that violate checkdeps rules.' + '\nCheck that the DEPS files in these locations contain valid ' + 'rules.\nSee ' + 'https://cs.chromium.org/chromium/src/buildtools/checkdeps/ ' + 'for more details about checkdeps.', error_descriptions)) if warning_descriptions: results.append( output_api.PresubmitPromptOrNotify( - 'You added one or more #includes of files that are temporarily\n' - 'allowed but being removed. Can you avoid introducing the\n' - '#include? See relevant DEPS file(s) for details and contacts.\n' - 'See https://cs.chromium.org/chromium/src/buildtools/checkdeps/ for ' - 'more details about checkdeps.', warning_descriptions)) + 'You added one or more #includes of files that are temporarily' + '\nallowed but being removed. Can you avoid introducing the\n' + '#include? See relevant DEPS file(s) for details and contacts.' + '\nSee ' + 'https://cs.chromium.org/chromium/src/buildtools/checkdeps/ ' + 'for more details about checkdeps.', warning_descriptions)) return results @@ -787,9 +791,10 @@ def CheckChangeHasBugField(input_api, output_api): else: return [ output_api.PresubmitError( - 'The "Bug: [bug number]" footer is mandatory. Please create a bug and ' - 'reference it using either of:\n' - ' * https://bugs.webrtc.org - reference it using Bug: webrtc:XXXX\n' + 'The "Bug: [bug number]" footer is mandatory. 
Please create a ' + 'bug and reference it using either of:\n' + ' * https://bugs.webrtc.org - reference it using Bug: ' + 'webrtc:XXXX\n' ' * https://crbug.com - reference it using Bug: chromium:XXXXXX' ) ] @@ -911,10 +916,19 @@ def CommonChecks(input_api, output_api): results.extend( input_api.canned_checks.CheckLicense(input_api, output_api, _LicenseHeader(input_api))) + + # TODO(bugs.webrtc.org/12114): Delete this filter and run pylint on + # all python files. This is a temporary solution. + python_file_filter = lambda f: (f.LocalPath().endswith('.py') and + source_file_filter(f)) + python_changed_files = [f.LocalPath() for f in input_api.AffectedFiles( + file_filter=python_file_filter)] + results.extend( input_api.canned_checks.RunPylint( input_api, output_api, + files_to_check=python_changed_files, files_to_skip=( r'^base[\\\/].*\.py$', r'^build[\\\/].*\.py$', @@ -932,12 +946,13 @@ def CommonChecks(input_api, output_api): pylintrc='pylintrc')) # TODO(nisse): talk/ is no more, so make below checks simpler? - # WebRTC can't use the presubmit_canned_checks.PanProjectChecks function since - # we need to have different license checks in talk/ and webrtc/ directories. + # WebRTC can't use the presubmit_canned_checks.PanProjectChecks function + # since we need to have different license checks + # in talk/ and webrtc/directories. # Instead, hand-picked checks are included below. - # .m and .mm files are ObjC files. For simplicity we will consider .h files in - # ObjC subdirectories ObjC headers. + # .m and .mm files are ObjC files. For simplicity we will consider + # .h files in ObjC subdirectories ObjC headers. objc_filter_list = (r'.+\.m$', r'.+\.mm$', r'.+objc\/.+\.h$') # Skip long-lines check for DEPS and GN files. build_file_filter_list = (r'.+\.gn$', r'.+\.gni$', 'DEPS') @@ -1163,9 +1178,9 @@ def CheckAbslMemoryInclude(input_api, output_api, source_file_filter): if len(files): return [ output_api.PresubmitError( - 'Please include "absl/memory/memory.h" header for absl::WrapUnique.\n' - 'This header may or may not be included transitively depending on the ' - 'C++ standard version.', files) + 'Please include "absl/memory/memory.h" header for ' + 'absl::WrapUnique.\nThis header may or may not be included ' + 'transitively depending on the C++ standard version.', files) ] return [] @@ -1319,10 +1334,10 @@ def _CalculateAddedDeps(os_path, old_contents, new_contents): def CheckAddedDepsHaveTargetApprovals(input_api, output_api): """When a dependency prefixed with + is added to a DEPS file, we - want to make sure that the change is reviewed by an OWNER of the - target file or directory, to avoid layering violations from being - introduced. This check verifies that this happens. - """ + want to make sure that the change is reviewed by an OWNER of the + target file or directory, to avoid layering violations from being + introduced. This check verifies that this happens. 
+ """ virtual_depended_on_files = set() file_filter = lambda f: not input_api.re.match( @@ -1343,13 +1358,15 @@ def CheckAddedDepsHaveTargetApprovals(input_api, output_api): if input_api.tbr: return [ output_api.PresubmitNotifyResult( - '--tbr was specified, skipping OWNERS check for DEPS additions' + '--tbr was specified, skipping OWNERS check for DEPS ' + 'additions' ) ] if input_api.dry_run: return [ output_api.PresubmitNotifyResult( - 'This is a dry run, skipping OWNERS check for DEPS additions' + 'This is a dry run, skipping OWNERS check for DEPS ' + 'additions' ) ] if not input_api.change.issue: @@ -1362,20 +1379,19 @@ def CheckAddedDepsHaveTargetApprovals(input_api, output_api): else: output = output_api.PresubmitNotifyResult - owners_db = input_api.owners_db owner_email, reviewers = ( input_api.canned_checks.GetCodereviewOwnerAndReviewers( input_api, - owners_db.email_regexp, + None, approval_needed=input_api.is_committing)) owner_email = owner_email or input_api.change.author_email - reviewers_plus_owner = set(reviewers) - if owner_email: - reviewers_plus_owner.add(owner_email) - missing_files = owners_db.files_not_covered_by(virtual_depended_on_files, - reviewers_plus_owner) + approval_status = input_api.owners_client.GetFilesApprovalStatus( + virtual_depended_on_files, reviewers.union([owner_email]), []) + missing_files = [ + f for f in virtual_depended_on_files + if approval_status[f] != input_api.owners_client.APPROVED] # We strip the /DEPS part that was added by # _FilesToCheckForIncomingDeps to fake a path to a file in a @@ -1394,11 +1410,12 @@ def StripDeps(path): if unapproved_dependencies: output_list = [ output( - 'You need LGTM from owners of depends-on paths in DEPS that were ' - 'modified in this CL:\n %s' % + 'You need LGTM from owners of depends-on paths in DEPS that ' + ' were modified in this CL:\n %s' % '\n '.join(sorted(unapproved_dependencies))) ] - suggested_owners = owners_db.reviewers_for(missing_files, owner_email) + suggested_owners = input_api.owners_client.SuggestOwners( + missing_files, exclude=[owner_email]) output_list.append( output('Suggested missing target path OWNERS:\n %s' % '\n '.join(suggested_owners or []))) diff --git a/README.md b/README.md index 4ffa4bae06..1ae9ea9cb7 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ native API header files. 
* Master source code repo: https://webrtc.googlesource.com/src * Samples and reference apps: https://github.com/webrtc * Mailing list: http://groups.google.com/group/discuss-webrtc - * Continuous build: http://build.chromium.org/p/client.webrtc - * [Coding style guide](style-guide.md) + * Continuous build: https://ci.chromium.org/p/webrtc/g/ci/console + * [Coding style guide](g3doc/style-guide.md) * [Code of conduct](CODE_OF_CONDUCT.md) * [Reporting bugs](docs/bug-reporting.md) diff --git a/api/BUILD.gn b/api/BUILD.gn index f02c5fd434..c775a1a871 100644 --- a/api/BUILD.gn +++ b/api/BUILD.gn @@ -29,7 +29,10 @@ rtc_source_set("call_api") { rtc_source_set("callfactory_api") { visibility = [ "*" ] sources = [ "call/call_factory_interface.h" ] - deps = [ "../rtc_base/system:rtc_export" ] + deps = [ + "../call:rtp_interfaces", + "../rtc_base/system:rtc_export", + ] } if (!build_with_chromium) { @@ -52,6 +55,7 @@ if (!build_with_chromium) { "../pc:peerconnection", "../rtc_base", "../rtc_base:rtc_base_approved", + "../rtc_base:threading", "audio:audio_mixer_api", "audio_codecs:audio_codecs_api", "task_queue:default_task_queue_factory", @@ -89,6 +93,7 @@ rtc_library("rtp_packet_info") { ":scoped_refptr", "../rtc_base:rtc_base_approved", "../rtc_base/system:rtc_export", + "units:timestamp", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } @@ -133,14 +138,8 @@ rtc_library("libjingle_peerconnection_api") { "jsep_ice_candidate.cc", "jsep_ice_candidate.h", "jsep_session_description.h", - "media_stream_proxy.h", - "media_stream_track_proxy.h", - "peer_connection_factory_proxy.h", "peer_connection_interface.cc", "peer_connection_interface.h", - "peer_connection_proxy.h", - "proxy.cc", - "proxy.h", "rtp_receiver_interface.cc", "rtp_receiver_interface.h", "rtp_sender_interface.cc", @@ -155,10 +154,11 @@ rtc_library("libjingle_peerconnection_api") { "stats_types.h", "turn_customizer.h", "uma_metrics.h", - "video_track_source_proxy.h", + "video_track_source_proxy_factory.h", ] deps = [ ":array_view", + ":async_dns_resolver", ":audio_options_api", ":callfactory_api", ":fec_controller_api", @@ -174,6 +174,9 @@ rtc_library("libjingle_peerconnection_api") { ":rtp_parameters", ":rtp_transceiver_direction", ":scoped_refptr", + ":sequence_checker", + "../call:rtp_interfaces", + "../rtc_base:network_constants", "adaptation:resource_adaptation_api", "audio:audio_mixer_api", "audio_codecs:audio_codecs_api", @@ -192,6 +195,7 @@ rtc_library("libjingle_peerconnection_api") { "units:data_rate", "units:timestamp", "video:encoded_image", + "video:video_bitrate_allocator_factory", "video:video_frame", "video:video_rtp_headers", @@ -203,8 +207,10 @@ rtc_library("libjingle_peerconnection_api") { "../modules/audio_processing:audio_processing_statistics", "../rtc_base", "../rtc_base:checks", - "../rtc_base:deprecation", + "../rtc_base:ip_address", "../rtc_base:rtc_base_approved", + "../rtc_base:socket_address", + "../rtc_base:threading", "../rtc_base/system:rtc_export", ] absl_deps = [ @@ -248,7 +254,18 @@ rtc_source_set("packet_socket_factory") { "packet_socket_factory.h", ] deps = [ + ":async_dns_resolver", + "../rtc_base:async_resolver_interface", "../rtc_base:rtc_base", + "../rtc_base:socket_address", + "../rtc_base/system:rtc_export", + ] +} + +rtc_source_set("async_dns_resolver") { + sources = [ "async_dns_resolver.h" ] + deps = [ + "../rtc_base:socket_address", "../rtc_base/system:rtc_export", ] } @@ -383,6 +400,7 @@ rtc_source_set("peer_connection_quality_test_fixture_api") { 
":video_quality_analyzer_api", "../media:rtc_media_base", "../rtc_base:rtc_base", + "../rtc_base:threading", "rtc_event_log", "task_queue", "transport:network_control", @@ -427,22 +445,6 @@ rtc_library("test_dependency_factory") { } if (rtc_include_tests) { - rtc_library("create_video_quality_test_fixture_api") { - visibility = [ "*" ] - testonly = true - sources = [ - "test/create_video_quality_test_fixture.cc", - "test/create_video_quality_test_fixture.h", - ] - deps = [ - ":fec_controller_api", - ":network_state_predictor_api", - ":scoped_refptr", - ":video_quality_test_fixture_api", - "../video:video_quality_test", - ] - } - # TODO(srte): Move to network_emulation sub directory. rtc_library("create_network_emulation_manager") { visibility = [ "*" ] @@ -457,21 +459,39 @@ if (rtc_include_tests) { ] } - rtc_library("create_peerconnection_quality_test_fixture") { - visibility = [ "*" ] - testonly = true - sources = [ - "test/create_peerconnection_quality_test_fixture.cc", - "test/create_peerconnection_quality_test_fixture.h", - ] + if (!build_with_chromium) { + rtc_library("create_video_quality_test_fixture_api") { + visibility = [ "*" ] + testonly = true + sources = [ + "test/create_video_quality_test_fixture.cc", + "test/create_video_quality_test_fixture.h", + ] + deps = [ + ":fec_controller_api", + ":network_state_predictor_api", + ":scoped_refptr", + ":video_quality_test_fixture_api", + "../video:video_quality_test", + ] + } - deps = [ - ":audio_quality_analyzer_api", - ":peer_connection_quality_test_fixture_api", - ":time_controller", - ":video_quality_analyzer_api", - "../test/pc/e2e:peerconnection_quality_test", - ] + rtc_library("create_peerconnection_quality_test_fixture") { + visibility = [ "*" ] + testonly = true + sources = [ + "test/create_peerconnection_quality_test_fixture.cc", + "test/create_peerconnection_quality_test_fixture.h", + ] + + deps = [ + ":audio_quality_analyzer_api", + ":peer_connection_quality_test_fixture_api", + ":time_controller", + ":video_quality_analyzer_api", + "../test/pc/e2e:peerconnection_quality_test", + ] + } } } @@ -541,6 +561,7 @@ rtc_source_set("rtc_stats_api") { deps = [ ":scoped_refptr", + "../api:refcountedbase", "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../rtc_base/system:rtc_export", @@ -568,6 +589,10 @@ rtc_library("transport_api") { "call/transport.cc", "call/transport.h", ] + deps = [ + ":refcountedbase", + ":scoped_refptr", + ] } rtc_source_set("bitrate_allocation") { @@ -600,6 +625,8 @@ rtc_source_set("network_emulation_manager_api") { ":time_controller", "../call:simulated_network", "../rtc_base", + "../rtc_base:network_constants", + "../rtc_base:threading", "test/network_emulation", "units:data_rate", "units:data_size", @@ -617,6 +644,7 @@ rtc_source_set("time_controller") { deps = [ "../modules/utility", "../rtc_base", + "../rtc_base:threading", "../rtc_base/synchronization:yield_policy", "../system_wrappers", "task_queue", @@ -655,7 +683,10 @@ rtc_source_set("array_view") { rtc_source_set("refcountedbase") { visibility = [ "*" ] sources = [ "ref_counted_base.h" ] - deps = [ "../rtc_base:rtc_base_approved" ] + deps = [ + "../rtc_base:macromagic", + "../rtc_base:refcount", + ] } rtc_library("ice_transport_factory") { @@ -670,6 +701,7 @@ rtc_library("ice_transport_factory") { ":scoped_refptr", "../p2p:rtc_p2p", "../rtc_base", + "../rtc_base:threading", "../rtc_base/system:rtc_export", "rtc_event_log:rtc_event_log", ] @@ -689,8 +721,18 @@ rtc_source_set("function_view") { deps = [ "../rtc_base:checks" ] } 
+rtc_source_set("sequence_checker") { + visibility = [ "*" ] + sources = [ "sequence_checker.h" ] + deps = [ + "../rtc_base:checks", + "../rtc_base:macromagic", + "../rtc_base/synchronization:sequence_checker_internal", + ] +} + if (rtc_include_tests) { - if (rtc_enable_protobuf) { + if (rtc_enable_protobuf && !build_with_chromium) { rtc_library("audioproc_f_api") { visibility = [ "*" ] testonly = true @@ -911,6 +953,15 @@ if (rtc_include_tests) { ] } + rtc_source_set("mock_async_dns_resolver") { + testonly = true + sources = [ "test/mock_async_dns_resolver.h" ] + deps = [ + ":async_dns_resolver", + "../test:test_support", + ] + } + rtc_source_set("mock_rtp") { visibility = [ "*" ] testonly = true @@ -1006,6 +1057,7 @@ if (rtc_include_tests) { ":time_controller", "../call", "../call:call_interfaces", + "../call:rtp_interfaces", "../test/time_controller", ] } @@ -1022,6 +1074,7 @@ if (rtc_include_tests) { "rtp_packet_infos_unittest.cc", "rtp_parameters_unittest.cc", "scoped_refptr_unittest.cc", + "sequence_checker_unittest.cc", "test/create_time_controller_unittest.cc", ] @@ -1035,11 +1088,13 @@ if (rtc_include_tests) { ":rtp_packet_info", ":rtp_parameters", ":scoped_refptr", + ":sequence_checker", ":time_controller", "../rtc_base:checks", "../rtc_base:gunit_helpers", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_task_queue", + "../rtc_base:task_queue_for_test", "../rtc_base/task_utils:repeating_task", "../test:fileutils", "../test:test_support", @@ -1047,6 +1102,7 @@ if (rtc_include_tests) { "units:time_delta", "units:timestamp", "units:units_unittests", + "video:rtp_video_frame_assembler_unittests", "video:video_unittests", ] } @@ -1060,6 +1116,7 @@ if (rtc_include_tests) { ":dummy_peer_connection", ":fake_frame_decryptor", ":fake_frame_encryptor", + ":mock_async_dns_resolver", ":mock_audio_mixer", ":mock_data_channel", ":mock_frame_decryptor", diff --git a/api/DEPS b/api/DEPS index 4b93438c3e..cdd17e9909 100644 --- a/api/DEPS +++ b/api/DEPS @@ -11,10 +11,12 @@ include_rules = [ "-common_video", "-data", "-examples", + "-g3doc", "-ios", "-infra", "-logging", "-media", + "-net", "-modules", "-out", "-p2p", @@ -40,12 +42,16 @@ include_rules = [ specific_include_rules = { # Some internal headers are allowed even in API headers: + + "call_factory_interface\.h": [ + "+call/rtp_transport_controller_send_factory_interface.h", + ], + ".*\.h": [ "+rtc_base/checks.h", "+rtc_base/system/rtc_export.h", "+rtc_base/system/rtc_export_template.h", "+rtc_base/units/unit_base.h", - "+rtc_base/deprecation.h", ], "array_view\.h": [ @@ -63,6 +69,10 @@ specific_include_rules = { "+rtc_base/async_resolver_interface.h", ], + "async_dns_resolver\.h": [ + "+rtc_base/socket_address.h", + ], + "candidate\.h": [ "+rtc_base/network_constants.h", "+rtc_base/socket_address.h", @@ -120,20 +130,22 @@ specific_include_rules = { "+rtc_base/async_packet_socket.h", ], - "peer_connection_factory_proxy\.h": [ - "+rtc_base/bind.h", - ], - "peer_connection_interface\.h": [ + "+call/rtp_transport_controller_send_factory_interface.h", "+media/base/media_config.h", "+media/base/media_engine.h", + "+p2p/base/port.h", "+p2p/base/port_allocator.h", + "+rtc_base/network.h", + "+rtc_base/network_constants.h", "+rtc_base/network_monitor_factory.h", + "+rtc_base/ref_count.h", "+rtc_base/rtc_certificate.h", "+rtc_base/rtc_certificate_generator.h", "+rtc_base/socket_address.h", "+rtc_base/ssl_certificate.h", "+rtc_base/ssl_stream_adapter.h", + "+rtc_base/thread.h", ], "proxy\.h": [ @@ -182,7 +194,6 @@ specific_include_rules = { 
"stats_types\.h": [ "+rtc_base/constructor_magic.h", "+rtc_base/ref_count.h", - "+rtc_base/string_encode.h", "+rtc_base/thread_checker.h", ], @@ -281,6 +292,11 @@ specific_include_rules = { "+rtc_base/ref_count.h", ], + "sequence_checker\.h": [ + "+rtc_base/synchronization/sequence_checker_internal.h", + "+rtc_base/thread_annotations.h", + ], + # .cc files in api/ should not be restricted in what they can #include, # so we re-add all the top-level directories here. (That's because .h # files leak their #includes to whoever's #including them, but .cc files diff --git a/api/OWNERS b/api/OWNERS index e18667970b..6ffb2588aa 100644 --- a/api/OWNERS +++ b/api/OWNERS @@ -11,15 +11,4 @@ per-file peer_connection*=hbos@webrtc.org per-file DEPS=mbonadei@webrtc.org -# Please keep this list in sync with Chromium's //base/metrics/OWNERS and -# send a CL when you notice any difference. -# Even if people in the list below cannot formally grant +1 on WebRTC, it -# is good to get their LGTM before sending the CL to one of the folder OWNERS. -per-file uma_metrics.h=asvitkine@chromium.org -per-file uma_metrics.h=bcwhite@chromium.org -per-file uma_metrics.h=caitlinfischer@google.com -per-file uma_metrics.h=holte@chromium.org -per-file uma_metrics.h=isherman@chromium.org -per-file uma_metrics.h=jwd@chromium.org -per-file uma_metrics.h=mpearson@chromium.org -per-file uma_metrics.h=rkaplow@chromium.org +per-file uma_metrics.h=kron@webrtc.org diff --git a/api/README.md b/api/README.md index 4cc799362d..7c1a27f512 100644 --- a/api/README.md +++ b/api/README.md @@ -1,6 +1,6 @@ # How to write code in the `api/` directory -Mostly, just follow the regular [style guide](../style-guide.md), but: +Mostly, just follow the regular [style guide](../g3doc/style-guide.md), but: * Note that `api/` code is not exempt from the “`.h` and `.cc` files come in pairs” rule, so if you declare something in `api/path/to/foo.h`, it should be @@ -17,7 +17,7 @@ it from a `.cc` file, so that users of our API headers won’t transitively For headers in `api/` that need to refer to non-public types, forward declarations are often a lesser evil than including non-public header files. The -usual [rules](../style-guide.md#forward-declarations) still apply, though. +usual [rules](../g3doc/style-guide.md#forward-declarations) still apply, though. `.cc` files in `api/` should preferably be kept reasonably small. If a substantial implementation is needed, consider putting it with our non-public diff --git a/api/async_dns_resolver.h b/api/async_dns_resolver.h new file mode 100644 index 0000000000..eabb41c11f --- /dev/null +++ b/api/async_dns_resolver.h @@ -0,0 +1,86 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef API_ASYNC_DNS_RESOLVER_H_ +#define API_ASYNC_DNS_RESOLVER_H_ + +#include +#include + +#include "rtc_base/socket_address.h" +#include "rtc_base/system/rtc_export.h" + +namespace webrtc { + +// This interface defines the methods to resolve a hostname asynchronously. +// The AsyncDnsResolverInterface class encapsulates a single name query. 
+// +// Usage: +// std::unique_ptr resolver = +// factory->Create(address-to-be-resolved, [r = resolver.get()]() { +// if (r->result.GetResolvedAddress(AF_INET, &addr) { +// // success +// } else { +// // failure +// error = r->result().GetError(); +// } +// // Release resolver. +// resolver_list.erase(std::remove_if(resolver_list.begin(), +// resolver_list.end(), +// [](refptr) { refptr.get() == r; }); +// }); +// resolver_list.push_back(std::move(resolver)); + +class AsyncDnsResolverResult { + public: + virtual ~AsyncDnsResolverResult() = default; + // Returns true iff the address from |Start| was successfully resolved. + // If the address was successfully resolved, sets |addr| to a copy of the + // address from |Start| with the IP address set to the top most resolved + // address of |family| (|addr| will have both hostname and the resolved ip). + virtual bool GetResolvedAddress(int family, + rtc::SocketAddress* addr) const = 0; + // Returns error from resolver. + virtual int GetError() const = 0; +}; + +class RTC_EXPORT AsyncDnsResolverInterface { + public: + virtual ~AsyncDnsResolverInterface() = default; + + // Start address resolution of the hostname in |addr|. + virtual void Start(const rtc::SocketAddress& addr, + std::function callback) = 0; + virtual const AsyncDnsResolverResult& result() const = 0; +}; + +// An abstract factory for creating AsyncDnsResolverInterfaces. This allows +// client applications to provide WebRTC with their own mechanism for +// performing DNS resolution. +class AsyncDnsResolverFactoryInterface { + public: + virtual ~AsyncDnsResolverFactoryInterface() = default; + + // Creates an AsyncDnsResolver and starts resolving the name. The callback + // will be called when resolution is finished. + // The callback will be called on the thread that the caller runs on. + virtual std::unique_ptr CreateAndResolve( + const rtc::SocketAddress& addr, + std::function callback) = 0; + // Creates an AsyncDnsResolver and does not start it. + // For backwards compatibility, will be deprecated and removed. + // One has to do a separate Start() call on the + // resolver to start name resolution. 
+ virtual std::unique_ptr Create() = 0; +}; + +} // namespace webrtc + +#endif // API_ASYNC_DNS_RESOLVER_H_ diff --git a/api/audio/echo_canceller3_config.cc b/api/audio/echo_canceller3_config.cc index 5f1923e90f..b38d6b5b7e 100644 --- a/api/audio/echo_canceller3_config.cc +++ b/api/audio/echo_canceller3_config.cc @@ -153,7 +153,7 @@ bool EchoCanceller3Config::Validate(EchoCanceller3Config* config) { res = res & Limit(&c->filter.config_change_duration_blocks, 0, 100000); res = res & Limit(&c->filter.initial_state_seconds, 0.f, 100.f); - res = res & Limit(&c->filter.coarse_reset_hangover_blocks, 0, 2500); + res = res & Limit(&c->filter.coarse_reset_hangover_blocks, 0, 250000); res = res & Limit(&c->erle.min, 1.f, 100000.f); res = res & Limit(&c->erle.max_l, 1.f, 100000.f); @@ -229,6 +229,12 @@ bool EchoCanceller3Config::Validate(EchoCanceller3Config* config) { res = res & Limit(&c->suppressor.nearend_tuning.max_dec_factor_lf, 0.f, 100.f); + res = res & Limit(&c->suppressor.last_permanent_lf_smoothing_band, 0, 64); + res = res & Limit(&c->suppressor.last_lf_smoothing_band, 0, 64); + res = res & Limit(&c->suppressor.last_lf_band, 0, 63); + res = res & + Limit(&c->suppressor.first_hf_band, c->suppressor.last_lf_band + 1, 64); + res = res & Limit(&c->suppressor.dominant_nearend_detection.enr_threshold, 0.f, 1000000.f); res = res & Limit(&c->suppressor.dominant_nearend_detection.snr_threshold, diff --git a/api/audio/echo_canceller3_config.h b/api/audio/echo_canceller3_config.h index 55281af93b..087e8da439 100644 --- a/api/audio/echo_canceller3_config.h +++ b/api/audio/echo_canceller3_config.h @@ -43,6 +43,7 @@ struct RTC_EXPORT EchoCanceller3Config { size_t hysteresis_limit_blocks = 1; size_t fixed_capture_delay_samples = 0; float delay_estimate_smoothing = 0.7f; + float delay_estimate_smoothing_delay_found = 0.7f; float delay_candidate_detection_threshold = 0.2f; struct DelaySelectionThresholds { int initial; @@ -90,6 +91,7 @@ struct RTC_EXPORT EchoCanceller3Config { bool conservative_initial_phase = false; bool enable_coarse_filter_output_usage = true; bool use_linear_filter = true; + bool high_pass_filter_echo_reference = false; bool export_linear_aec_output = false; } filter; @@ -108,6 +110,7 @@ struct RTC_EXPORT EchoCanceller3Config { float default_len = 0.83f; bool echo_can_saturate = true; bool bounded_erl = false; + bool erle_onset_compensation_in_dominant_nearend = false; } ep_strength; struct EchoAudibility { @@ -191,6 +194,12 @@ struct RTC_EXPORT EchoCanceller3Config { 2.0f, 0.25f); + bool lf_smoothing_during_initial_phase = true; + int last_permanent_lf_smoothing_band = 0; + int last_lf_smoothing_band = 5; + int last_lf_band = 5; + int first_hf_band = 8; + struct DominantNearendDetection { float enr_threshold = .25f; float enr_exit_threshold = 10.f; diff --git a/api/audio/echo_canceller3_config_json.cc b/api/audio/echo_canceller3_config_json.cc index 9d10da9949..263599c538 100644 --- a/api/audio/echo_canceller3_config_json.cc +++ b/api/audio/echo_canceller3_config_json.cc @@ -11,6 +11,7 @@ #include +#include #include #include @@ -156,9 +157,14 @@ void Aec3ConfigFromJsonString(absl::string_view json_string, *parsing_successful = true; Json::Value root; - bool success = Json::Reader().parse(std::string(json_string), root); + Json::CharReaderBuilder builder; + std::string error_message; + std::unique_ptr reader(builder.newCharReader()); + bool success = + reader->parse(json_string.data(), json_string.data() + json_string.size(), + &root, &error_message); if (!success) { - 
RTC_LOG(LS_ERROR) << "Incorrect JSON format: " << json_string; + RTC_LOG(LS_ERROR) << "Incorrect JSON format: " << error_message; *parsing_successful = false; return; } @@ -191,6 +197,8 @@ void Aec3ConfigFromJsonString(absl::string_view json_string, &cfg.delay.fixed_capture_delay_samples); ReadParam(section, "delay_estimate_smoothing", &cfg.delay.delay_estimate_smoothing); + ReadParam(section, "delay_estimate_smoothing_delay_found", + &cfg.delay.delay_estimate_smoothing_delay_found); ReadParam(section, "delay_candidate_detection_threshold", &cfg.delay.delay_candidate_detection_threshold); @@ -230,6 +238,8 @@ void Aec3ConfigFromJsonString(absl::string_view json_string, ReadParam(section, "enable_coarse_filter_output_usage", &cfg.filter.enable_coarse_filter_output_usage); ReadParam(section, "use_linear_filter", &cfg.filter.use_linear_filter); + ReadParam(section, "high_pass_filter_echo_reference", + &cfg.filter.high_pass_filter_echo_reference); ReadParam(section, "export_linear_aec_output", &cfg.filter.export_linear_aec_output); } @@ -251,6 +261,8 @@ void Aec3ConfigFromJsonString(absl::string_view json_string, ReadParam(section, "default_len", &cfg.ep_strength.default_len); ReadParam(section, "echo_can_saturate", &cfg.ep_strength.echo_can_saturate); ReadParam(section, "bounded_erl", &cfg.ep_strength.bounded_erl); + ReadParam(section, "erle_onset_compensation_in_dominant_nearend", + &cfg.ep_strength.erle_onset_compensation_in_dominant_nearend); } if (rtc::GetValueFromJsonObject(aec3_root, "echo_audibility", §ion)) { @@ -335,6 +347,15 @@ void Aec3ConfigFromJsonString(absl::string_view json_string, &cfg.suppressor.nearend_tuning.max_dec_factor_lf); } + ReadParam(section, "lf_smoothing_during_initial_phase", + &cfg.suppressor.lf_smoothing_during_initial_phase); + ReadParam(section, "last_permanent_lf_smoothing_band", + &cfg.suppressor.last_permanent_lf_smoothing_band); + ReadParam(section, "last_lf_smoothing_band", + &cfg.suppressor.last_lf_smoothing_band); + ReadParam(section, "last_lf_band", &cfg.suppressor.last_lf_band); + ReadParam(section, "first_hf_band", &cfg.suppressor.first_hf_band); + if (rtc::GetValueFromJsonObject(section, "dominant_nearend_detection", &subsection)) { ReadParam(subsection, "enr_threshold", @@ -421,6 +442,8 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) { << config.delay.fixed_capture_delay_samples << ","; ost << "\"delay_estimate_smoothing\": " << config.delay.delay_estimate_smoothing << ","; + ost << "\"delay_estimate_smoothing_delay_found\": " + << config.delay.delay_estimate_smoothing_delay_found << ","; ost << "\"delay_candidate_detection_threshold\": " << config.delay.delay_candidate_detection_threshold << ","; @@ -513,6 +536,9 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) { << ","; ost << "\"use_linear_filter\": " << (config.filter.use_linear_filter ? "true" : "false") << ","; + ost << "\"high_pass_filter_echo_reference\": " + << (config.filter.high_pass_filter_echo_reference ? "true" : "false") + << ","; ost << "\"export_linear_aec_output\": " << (config.filter.export_linear_aec_output ? "true" : "false"); @@ -537,8 +563,11 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) { ost << "\"echo_can_saturate\": " << (config.ep_strength.echo_can_saturate ? "true" : "false") << ","; ost << "\"bounded_erl\": " - << (config.ep_strength.bounded_erl ? "true" : "false"); - + << (config.ep_strength.bounded_erl ? 
"true" : "false") << ","; + ost << "\"erle_onset_compensation_in_dominant_nearend\": " + << (config.ep_strength.erle_onset_compensation_in_dominant_nearend + ? "true" + : "false"); ost << "},"; ost << "\"echo_audibility\": {"; @@ -637,6 +666,16 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) { ost << "\"max_dec_factor_lf\": " << config.suppressor.nearend_tuning.max_dec_factor_lf; ost << "},"; + ost << "\"lf_smoothing_during_initial_phase\": " + << (config.suppressor.lf_smoothing_during_initial_phase ? "true" + : "false") + << ","; + ost << "\"last_permanent_lf_smoothing_band\": " + << config.suppressor.last_permanent_lf_smoothing_band << ","; + ost << "\"last_lf_smoothing_band\": " + << config.suppressor.last_lf_smoothing_band << ","; + ost << "\"last_lf_band\": " << config.suppressor.last_lf_band << ","; + ost << "\"first_hf_band\": " << config.suppressor.first_hf_band << ","; ost << "\"dominant_nearend_detection\": {"; ost << "\"enr_threshold\": " << config.suppressor.dominant_nearend_detection.enr_threshold << ","; diff --git a/api/audio/echo_control.h b/api/audio/echo_control.h index 8d567bf2b8..74fbc27b12 100644 --- a/api/audio/echo_control.h +++ b/api/audio/echo_control.h @@ -48,6 +48,13 @@ class EchoControl { // Provides an optional external estimate of the audio buffer delay. virtual void SetAudioBufferDelay(int delay_ms) = 0; + // Specifies whether the capture output will be used. The purpose of this is + // to allow the echo controller to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. + // TODO(b/177830919): Make pure virtual. + virtual void SetCaptureOutputUsage(bool capture_output_used) {} + // Returns wheter the signal is altered. 
virtual bool ActiveProcessing() const = 0; diff --git a/api/audio/echo_detector_creator.cc b/api/audio/echo_detector_creator.cc index 4c3d9e61fe..04215b0deb 100644 --- a/api/audio/echo_detector_creator.cc +++ b/api/audio/echo_detector_creator.cc @@ -15,7 +15,7 @@ namespace webrtc { rtc::scoped_refptr CreateEchoDetector() { - return new rtc::RefCountedObject(); + return rtc::make_ref_counted(); } } // namespace webrtc diff --git a/api/audio/test/echo_canceller3_config_json_unittest.cc b/api/audio/test/echo_canceller3_config_json_unittest.cc index 4a952fe910..d6edd07d2e 100644 --- a/api/audio/test/echo_canceller3_config_json_unittest.cc +++ b/api/audio/test/echo_canceller3_config_json_unittest.cc @@ -21,6 +21,8 @@ TEST(EchoCanceller3JsonHelpers, ToStringAndParseJson) { cfg.delay.log_warning_on_delay_changes = true; cfg.filter.refined.error_floor = 2.f; cfg.filter.coarse_initial.length_blocks = 3u; + cfg.filter.high_pass_filter_echo_reference = + !cfg.filter.high_pass_filter_echo_reference; cfg.comfort_noise.noise_floor_dbfs = 100.f; cfg.echo_model.model_reverb_in_nonlinear_mode = false; cfg.suppressor.normal_tuning.mask_hf.enr_suppress = .5f; @@ -47,6 +49,8 @@ TEST(EchoCanceller3JsonHelpers, ToStringAndParseJson) { cfg_transformed.filter.coarse_initial.length_blocks); EXPECT_EQ(cfg.filter.refined.error_floor, cfg_transformed.filter.refined.error_floor); + EXPECT_EQ(cfg.filter.high_pass_filter_echo_reference, + cfg_transformed.filter.high_pass_filter_echo_reference); EXPECT_EQ(cfg.comfort_noise.noise_floor_dbfs, cfg_transformed.comfort_noise.noise_floor_dbfs); EXPECT_EQ(cfg.echo_model.model_reverb_in_nonlinear_mode, diff --git a/api/audio_codecs/BUILD.gn b/api/audio_codecs/BUILD.gn index b6292de570..5926f5ec2e 100644 --- a/api/audio_codecs/BUILD.gn +++ b/api/audio_codecs/BUILD.gn @@ -33,7 +33,6 @@ rtc_library("audio_codecs_api") { "..:bitrate_allocation", "..:scoped_refptr", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base:sanitizer", "../../rtc_base/system:rtc_export", diff --git a/api/audio_codecs/audio_decoder.cc b/api/audio_codecs/audio_decoder.cc index 97cda27a03..4b18b4ab52 100644 --- a/api/audio_codecs/audio_decoder.cc +++ b/api/audio_codecs/audio_decoder.cc @@ -162,7 +162,7 @@ AudioDecoder::SpeechType AudioDecoder::ConvertSpeechType(int16_t type) { case 2: return kComfortNoise; default: - assert(false); + RTC_NOTREACHED(); return kSpeech; } } diff --git a/api/audio_codecs/audio_decoder.h b/api/audio_codecs/audio_decoder.h index 557ffe2759..ce235946da 100644 --- a/api/audio_codecs/audio_decoder.h +++ b/api/audio_codecs/audio_decoder.h @@ -136,7 +136,7 @@ class AudioDecoder { // with the decoded audio on either side of the concealment. // Note: The default implementation of GeneratePlc will be deleted soon. All // implementations must provide their own, which can be a simple as a no-op. - // TODO(bugs.webrtc.org/9676): Remove default impementation. + // TODO(bugs.webrtc.org/9676): Remove default implementation. 
virtual void GeneratePlc(size_t requested_samples_per_channel, rtc::BufferT* concealment_audio); diff --git a/api/audio_codecs/audio_decoder_factory_template.h b/api/audio_codecs/audio_decoder_factory_template.h index e628cb62dc..388668d4c6 100644 --- a/api/audio_codecs/audio_decoder_factory_template.h +++ b/api/audio_codecs/audio_decoder_factory_template.h @@ -123,9 +123,8 @@ rtc::scoped_refptr CreateAudioDecoderFactory() { static_assert(sizeof...(Ts) >= 1, "Caller must give at least one template parameter"); - return rtc::scoped_refptr( - new rtc::RefCountedObject< - audio_decoder_factory_template_impl::AudioDecoderFactoryT>()); + return rtc::make_ref_counted< + audio_decoder_factory_template_impl::AudioDecoderFactoryT>(); } } // namespace webrtc diff --git a/api/audio_codecs/audio_encoder.h b/api/audio_codecs/audio_encoder.h index fd2d948863..92e42cf107 100644 --- a/api/audio_codecs/audio_encoder.h +++ b/api/audio_codecs/audio_encoder.h @@ -16,12 +16,12 @@ #include #include +#include "absl/base/attributes.h" #include "absl/types/optional.h" #include "api/array_view.h" #include "api/call/bitrate_allocation.h" #include "api/units/time_delta.h" #include "rtc_base/buffer.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -182,12 +182,11 @@ class AudioEncoder { // implementation does nothing. virtual void SetMaxPlaybackRate(int frequency_hz); - // This is to be deprecated. Please use |OnReceivedTargetAudioBitrate| - // instead. // Tells the encoder what average bitrate we'd like it to produce. The // encoder is free to adjust or disregard the given bitrate (the default // implementation does the latter). - RTC_DEPRECATED virtual void SetTargetBitrate(int target_bps); + ABSL_DEPRECATED("Use OnReceivedTargetAudioBitrate instead") + virtual void SetTargetBitrate(int target_bps); // Causes this encoder to let go of any other encoders it contains, and // returns a pointer to an array where they are stored (which is required to @@ -210,7 +209,8 @@ class AudioEncoder { virtual void OnReceivedUplinkPacketLossFraction( float uplink_packet_loss_fraction); - RTC_DEPRECATED virtual void OnReceivedUplinkRecoverablePacketLossFraction( + ABSL_DEPRECATED("") + virtual void OnReceivedUplinkRecoverablePacketLossFraction( float uplink_recoverable_packet_loss_fraction); // Provides target audio bitrate to this encoder to allow it to adapt. 
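Several hunks in this patch (echo_detector_creator.cc and the audio codec factory templates here, ice_transport_factory.cc and the factory unit tests further down) swap direct `new rtc::RefCountedObject<T>()` allocations for `rtc::make_ref_counted<T>()`. A minimal sketch of the two spellings side by side; it assumes `rtc::make_ref_counted` is declared in `rtc_base/ref_counted_object.h` at this revision, and `example::Thing` is a hypothetical class used only for illustration, not part of the patch:

```
#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/ref_counted_object.h"  // Assumed to declare rtc::make_ref_counted.

namespace example {

// Hypothetical interface-style class. AddRef()/Release() are left to
// rtc::RefCountedObject<Thing>, which both spellings below rely on.
class Thing : public rtc::RefCountInterface {
 public:
  explicit Thing(int value) : value_(value) {}
  int value() const { return value_; }

 protected:
  ~Thing() override = default;

 private:
  const int value_;
};

int Demo() {
  // Old spelling, as removed throughout this patch.
  rtc::scoped_refptr<Thing> a(new rtc::RefCountedObject<Thing>(1));

  // New spelling, as introduced throughout this patch: constructor arguments
  // are forwarded and the result is assignable to a scoped_refptr directly.
  rtc::scoped_refptr<Thing> b = rtc::make_ref_counted<Thing>(2);

  return a->value() + b->value();
}

}  // namespace example
```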
diff --git a/api/audio_codecs/audio_encoder_factory_template.h b/api/audio_codecs/audio_encoder_factory_template.h index 74cb053425..cdc7defd25 100644 --- a/api/audio_codecs/audio_encoder_factory_template.h +++ b/api/audio_codecs/audio_encoder_factory_template.h @@ -142,9 +142,8 @@ rtc::scoped_refptr CreateAudioEncoderFactory() { static_assert(sizeof...(Ts) >= 1, "Caller must give at least one template parameter"); - return rtc::scoped_refptr( - new rtc::RefCountedObject< - audio_encoder_factory_template_impl::AudioEncoderFactoryT>()); + return rtc::make_ref_counted< + audio_encoder_factory_template_impl::AudioEncoderFactoryT>(); } } // namespace webrtc diff --git a/api/audio_codecs/opus/audio_encoder_multi_channel_opus_config.cc b/api/audio_codecs/opus/audio_encoder_multi_channel_opus_config.cc index f01caf11b6..0052c429b2 100644 --- a/api/audio_codecs/opus/audio_encoder_multi_channel_opus_config.cc +++ b/api/audio_codecs/opus/audio_encoder_multi_channel_opus_config.cc @@ -38,7 +38,7 @@ operator=(const AudioEncoderMultiChannelOpusConfig&) = default; bool AudioEncoderMultiChannelOpusConfig::IsOk() const { if (frame_size_ms <= 0 || frame_size_ms % 10 != 0) return false; - if (num_channels < 0 || num_channels >= 255) { + if (num_channels >= 255) { return false; } if (bitrate_bps < kMinBitrateBps || bitrate_bps > kMaxBitrateBps) @@ -47,7 +47,7 @@ bool AudioEncoderMultiChannelOpusConfig::IsOk() const { return false; // Check the lengths: - if (num_channels < 0 || num_streams < 0 || coupled_streams < 0) { + if (num_streams < 0 || coupled_streams < 0) { return false; } if (num_streams < coupled_streams) { diff --git a/api/audio_codecs/opus/audio_encoder_opus_config.cc b/api/audio_codecs/opus/audio_encoder_opus_config.cc index 2f36d0261e..0e6f55ee65 100644 --- a/api/audio_codecs/opus/audio_encoder_opus_config.cc +++ b/api/audio_codecs/opus/audio_encoder_opus_config.cc @@ -61,7 +61,7 @@ bool AudioEncoderOpusConfig::IsOk() const { // well; we can add support for them when needed.) 
return false; } - if (num_channels < 0 || num_channels >= 255) { + if (num_channels >= 255) { return false; } if (!bitrate_bps) diff --git a/api/audio_codecs/test/audio_decoder_factory_template_unittest.cc b/api/audio_codecs/test/audio_decoder_factory_template_unittest.cc index 0e2e8c229f..464ecfd487 100644 --- a/api/audio_codecs/test/audio_decoder_factory_template_unittest.cc +++ b/api/audio_codecs/test/audio_decoder_factory_template_unittest.cc @@ -78,7 +78,7 @@ struct AudioDecoderFakeApi { TEST(AudioDecoderFactoryTemplateTest, NoDecoderTypes) { rtc::scoped_refptr factory( - new rtc::RefCountedObject< + rtc::make_ref_counted< audio_decoder_factory_template_impl::AudioDecoderFactoryT<>>()); EXPECT_THAT(factory->GetSupportedDecoders(), ::testing::IsEmpty()); EXPECT_FALSE(factory->IsSupportedDecoder({"foo", 8000, 1})); diff --git a/api/audio_codecs/test/audio_encoder_factory_template_unittest.cc b/api/audio_codecs/test/audio_encoder_factory_template_unittest.cc index 95ea85576d..110f9930bd 100644 --- a/api/audio_codecs/test/audio_encoder_factory_template_unittest.cc +++ b/api/audio_codecs/test/audio_encoder_factory_template_unittest.cc @@ -78,7 +78,7 @@ struct AudioEncoderFakeApi { TEST(AudioEncoderFactoryTemplateTest, NoEncoderTypes) { rtc::scoped_refptr factory( - new rtc::RefCountedObject< + rtc::make_ref_counted< audio_encoder_factory_template_impl::AudioEncoderFactoryT<>>()); EXPECT_THAT(factory->GetSupportedEncoders(), ::testing::IsEmpty()); EXPECT_EQ(absl::nullopt, factory->QueryAudioEncoder({"foo", 8000, 1})); diff --git a/api/call/transport.h b/api/call/transport.h index 2a2a87a5f6..8bff28825d 100644 --- a/api/call/transport.h +++ b/api/call/transport.h @@ -14,7 +14,8 @@ #include #include -#include +#include "api/ref_counted_base.h" +#include "api/scoped_refptr.h" namespace webrtc { @@ -30,7 +31,7 @@ struct PacketOptions { int packet_id = -1; // Additional data bound to the RTP packet for use in application code, // outside of WebRTC. - std::vector application_data; + rtc::scoped_refptr additional_data; // Whether this is a retransmission of an earlier packet. bool is_retransmit = false; bool included_in_feedback = false; diff --git a/api/candidate.cc b/api/candidate.cc index c857f89c3c..d5fe3a0672 100644 --- a/api/candidate.cc +++ b/api/candidate.cc @@ -12,6 +12,7 @@ #include "rtc_base/helpers.h" #include "rtc_base/ip_address.h" +#include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" namespace cricket { @@ -129,9 +130,21 @@ Candidate Candidate::ToSanitizedCopy(bool use_hostname_address, bool filter_related_address) const { Candidate copy(*this); if (use_hostname_address) { - rtc::SocketAddress hostname_only_addr(address().hostname(), - address().port()); - copy.set_address(hostname_only_addr); + rtc::IPAddress ip; + if (address().hostname().empty()) { + // IP needs to be redacted, but no hostname available. + rtc::SocketAddress redacted_addr("redacted-ip.invalid", address().port()); + copy.set_address(redacted_addr); + } else if (IPFromString(address().hostname(), &ip)) { + // The hostname is an IP literal, and needs to be redacted too. 
+ rtc::SocketAddress redacted_addr("redacted-literal.invalid", + address().port()); + copy.set_address(redacted_addr); + } else { + rtc::SocketAddress hostname_only_addr(address().hostname(), + address().port()); + copy.set_address(hostname_only_addr); + } } if (filter_related_address) { copy.set_related_address( diff --git a/api/data_channel_interface.h b/api/data_channel_interface.h index 5b2b1263ab..56bb6c98fb 100644 --- a/api/data_channel_interface.h +++ b/api/data_channel_interface.h @@ -44,11 +44,13 @@ struct DataChannelInit { // // Cannot be set along with |maxRetransmits|. // This is called |maxPacketLifeTime| in the WebRTC JS API. + // Negative values are ignored, and positive values are clamped to [0-65535] absl::optional maxRetransmitTime; // The max number of retransmissions. // // Cannot be set along with |maxRetransmitTime|. + // Negative values are ignored, and positive values are clamped to [0-65535] absl::optional maxRetransmits; // This is set by the application and opaque to the WebRTC implementation. diff --git a/api/g3doc/index.md b/api/g3doc/index.md new file mode 100644 index 0000000000..49637d191a --- /dev/null +++ b/api/g3doc/index.md @@ -0,0 +1,51 @@ + + + +# The WebRTC API + +The public API of the WebRTC library consists of the api/ directory and +its subdirectories. No other files should be depended on by webrtc users. + +Before starting to code against the API, it is important to understand +some basic concepts, such as: + +* Memory management, including webrtc's reference counted objects +* [Thread management](threading_design.md) + +## Using WebRTC through the PeerConnection class + +The +[PeerConnectionInterface](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/peer_connection_interface.h?q=webrtc::PeerConnectionInterface) +class is the recommended way to use the WebRTC library. + +It is closely modeled after the Javascript API documented in the [WebRTC +specification](https://w3c.github.io/webrtc-pc/). + +PeerConnections are created using the [PeerConnectionFactoryInterface](https://source.chromium.org/search?q=webrtc::PeerConnectionFactoryInterface). + +There are two levels of customization available: + +* Pass a PeerConnectionFactoryDependencies object to the function that creates + a PeerConnectionFactory. This object defines factories for a lot of internal + objects inside the PeerConnection, so that users can override them. + All PeerConnections using this interface will have the same options. +* Pass a PeerConnectionInterface::RTCConfiguration object to the + CreatePeerConnectionOrError() function on the + PeerConnectionFactoryInterface. These customizations will apply only to a + single PeerConnection. + +Most functions on the PeerConnection interface are asynchronous, and take a +callback that is executed when the function is finished. The callbacks are +mostly called on the thread that is passed as the "signaling thread" field of +the PeerConnectionFactoryDependencies, or the thread that called +PeerConnectionFactory::CreatePeerConnectionOrError() if no thread is given. + +See each class' module documentation for details. + +## Using WebRTC components without the PeerConnection class + +This needs to be done carefully, and in consultation with the WebRTC team. There +are non-obvious dependencies between many of the components. 
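The two customization points and the CreatePeerConnectionOrError() entry point described above can be combined roughly as in the sketch below. This is only an illustration under stated assumptions, not something this page mandates: `CreateConnection` is a hypothetical helper, the factory is assumed to have been built elsewhere (e.g. via CreateModularPeerConnectionFactory() with a filled-in PeerConnectionFactoryDependencies), and the Unified Plan SDP semantics choice is an arbitrary example of a per-connection setting.

```
#include <utility>

#include "api/peer_connection_interface.h"
#include "rtc_base/logging.h"

// |factory| is assumed to come from CreateModularPeerConnectionFactory() (or a
// similar helper) configured through PeerConnectionFactoryDependencies.
// |observer| is an application-provided PeerConnectionObserver; its callbacks
// are expected to run on the factory's signaling thread.
rtc::scoped_refptr<webrtc::PeerConnectionInterface> CreateConnection(
    webrtc::PeerConnectionFactoryInterface* factory,
    webrtc::PeerConnectionObserver* observer) {
  // Per-connection settings go into RTCConfiguration.
  webrtc::PeerConnectionInterface::RTCConfiguration config;
  config.sdp_semantics = webrtc::SdpSemantics::kUnifiedPlan;

  // Per-connection dependencies (port allocator, certificate generator, ...)
  // go into PeerConnectionDependencies; only the observer is mandatory.
  webrtc::PeerConnectionDependencies dependencies(observer);

  auto result =
      factory->CreatePeerConnectionOrError(config, std::move(dependencies));
  if (!result.ok()) {
    RTC_LOG(LS_ERROR) << "CreatePeerConnectionOrError failed: "
                      << result.error().message();
    return nullptr;
  }
  return result.MoveValue();
}
```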
+ + + diff --git a/api/DESIGN.md b/api/g3doc/threading_design.md similarity index 93% rename from api/DESIGN.md rename to api/g3doc/threading_design.md index 0a2f36eb2b..20c3539b22 100644 --- a/api/DESIGN.md +++ b/api/g3doc/threading_design.md @@ -1,4 +1,6 @@ -# Design considerations + + +# API Threading Design considerations The header files in this directory form the API to the WebRTC library that is intended for client applications' use. @@ -30,12 +32,12 @@ the two calls. sequential execution - other names for such constructs are task runners and sequenced task queues. -# Client threads and callbacks +## Client threads and callbacks At the moment, the API does not give any guarantee on which thread* the callbacks and events are called on. So it's best to write all callback and event handlers like this (pseudocode): -
+```
 void ObserverClass::Handler(event) {
   if (!called_on_client_thread()) {
     dispatch_to_client_thread(bind(handler(event)));
@@ -43,11 +45,11 @@ void ObserverClass::Handler(event) {
   }
   // Process event, we're now on the right thread
 }
-</pre>
+``` In the future, the implementation may change to always call the callbacks and event handlers on the client thread. -# Implementation considerations +## Implementation considerations The C++ classes that are part of the public API are also used to derive classes that form part of the implementation. diff --git a/api/ice_transport_factory.cc b/api/ice_transport_factory.cc index c32d7d2e11..26ef88bf1c 100644 --- a/api/ice_transport_factory.cc +++ b/api/ice_transport_factory.cc @@ -14,6 +14,7 @@ #include #include "p2p/base/ice_transport_internal.h" +#include "p2p/base/p2p_constants.h" #include "p2p/base/p2p_transport_channel.h" #include "p2p/base/port_allocator.h" #include "rtc_base/thread.h" @@ -41,7 +42,7 @@ class IceTransportWithTransportChannel : public IceTransportInterface { } private: - const rtc::ThreadChecker thread_checker_{}; + const SequenceChecker thread_checker_{}; const std::unique_ptr internal_ RTC_GUARDED_BY(thread_checker_); }; @@ -57,10 +58,18 @@ rtc::scoped_refptr CreateIceTransport( rtc::scoped_refptr CreateIceTransport( IceTransportInit init) { - return new rtc::RefCountedObject( - std::make_unique( - "", 0, init.port_allocator(), init.async_resolver_factory(), - init.event_log())); + if (init.async_resolver_factory()) { + // Backwards compatibility mode + return rtc::make_ref_counted( + std::make_unique( + "", cricket::ICE_CANDIDATE_COMPONENT_RTP, init.port_allocator(), + init.async_resolver_factory(), init.event_log())); + } else { + return rtc::make_ref_counted( + cricket::P2PTransportChannel::Create( + "", cricket::ICE_CANDIDATE_COMPONENT_RTP, init.port_allocator(), + init.async_dns_resolver_factory(), init.event_log())); + } } } // namespace webrtc diff --git a/api/ice_transport_interface.h b/api/ice_transport_interface.h index d2f1edc012..a3b364c87a 100644 --- a/api/ice_transport_interface.h +++ b/api/ice_transport_interface.h @@ -13,6 +13,7 @@ #include +#include "api/async_dns_resolver.h" #include "api/async_resolver_factory.h" #include "api/rtc_error.h" #include "api/rtc_event_log/rtc_event_log.h" @@ -52,11 +53,21 @@ struct IceTransportInit final { port_allocator_ = port_allocator; } + AsyncDnsResolverFactoryInterface* async_dns_resolver_factory() { + return async_dns_resolver_factory_; + } + void set_async_dns_resolver_factory( + AsyncDnsResolverFactoryInterface* async_dns_resolver_factory) { + RTC_DCHECK(!async_resolver_factory_); + async_dns_resolver_factory_ = async_dns_resolver_factory; + } AsyncResolverFactory* async_resolver_factory() { return async_resolver_factory_; } + ABSL_DEPRECATED("bugs.webrtc.org/12598") void set_async_resolver_factory( AsyncResolverFactory* async_resolver_factory) { + RTC_DCHECK(!async_dns_resolver_factory_); async_resolver_factory_ = async_resolver_factory; } @@ -65,8 +76,11 @@ struct IceTransportInit final { private: cricket::PortAllocator* port_allocator_ = nullptr; + AsyncDnsResolverFactoryInterface* async_dns_resolver_factory_ = nullptr; + // For backwards compatibility. Only one resolver factory can be set. AsyncResolverFactory* async_resolver_factory_ = nullptr; RtcEventLog* event_log_ = nullptr; + // TODO(https://crbug.com/webrtc/12657): Redesign to have const members. 
}; // TODO(qingsi): The factory interface is defined in this file instead of its diff --git a/api/jsep.h b/api/jsep.h index dcf821369e..b56cf1d15b 100644 --- a/api/jsep.h +++ b/api/jsep.h @@ -28,7 +28,6 @@ #include "absl/types/optional.h" #include "api/rtc_error.h" -#include "rtc_base/deprecation.h" #include "rtc_base/ref_count.h" #include "rtc_base/system/rtc_export.h" diff --git a/api/jsep_session_description.h b/api/jsep_session_description.h index e13d85e71c..70ac9398a6 100644 --- a/api/jsep_session_description.h +++ b/api/jsep_session_description.h @@ -23,7 +23,6 @@ #include "api/jsep.h" #include "api/jsep_ice_candidate.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/deprecation.h" namespace cricket { class SessionDescription; diff --git a/api/neteq/neteq.h b/api/neteq/neteq.h index 9781377ca8..ea7079e369 100644 --- a/api/neteq/neteq.h +++ b/api/neteq/neteq.h @@ -214,11 +214,15 @@ class NetEq { // |data_| in |audio_frame| is not written, but should be interpreted as being // all zeros. For testing purposes, an override can be supplied in the // |action_override| argument, which will cause NetEq to take this action - // next, instead of the action it would normally choose. + // next, instead of the action it would normally choose. An optional output + // argument for fetching the current sample rate can be provided, which + // will return the same value as last_output_sample_rate_hz() but will avoid + // additional synchronization. // Returns kOK on success, or kFail in case of an error. virtual int GetAudio( AudioFrame* audio_frame, bool* muted, + int* current_sample_rate_hz = nullptr, absl::optional action_override = absl::nullopt) = 0; // Replaces the current set of decoders with the given one. diff --git a/api/peer_connection_interface.cc b/api/peer_connection_interface.cc index e1d94dd8c7..230731c42d 100644 --- a/api/peer_connection_interface.cc +++ b/api/peer_connection_interface.cc @@ -10,8 +10,7 @@ #include "api/peer_connection_interface.h" -#include "api/dtls_transport_interface.h" -#include "api/sctp_transport_interface.h" +#include namespace webrtc { @@ -77,14 +76,27 @@ PeerConnectionFactoryInterface::CreatePeerConnection( std::unique_ptr allocator, std::unique_ptr cert_generator, PeerConnectionObserver* observer) { - return nullptr; + PeerConnectionDependencies dependencies(observer); + dependencies.allocator = std::move(allocator); + dependencies.cert_generator = std::move(cert_generator); + auto result = + CreatePeerConnectionOrError(configuration, std::move(dependencies)); + if (!result.ok()) { + return nullptr; + } + return result.MoveValue(); } rtc::scoped_refptr PeerConnectionFactoryInterface::CreatePeerConnection( const PeerConnectionInterface::RTCConfiguration& configuration, PeerConnectionDependencies dependencies) { - return nullptr; + auto result = + CreatePeerConnectionOrError(configuration, std::move(dependencies)); + if (!result.ok()) { + return nullptr; + } + return result.MoveValue(); } RTCErrorOr> diff --git a/api/peer_connection_interface.h b/api/peer_connection_interface.h index 92d965b328..5499b7d87c 100644 --- a/api/peer_connection_interface.h +++ b/api/peer_connection_interface.h @@ -67,19 +67,25 @@ #ifndef API_PEER_CONNECTION_INTERFACE_H_ #define API_PEER_CONNECTION_INTERFACE_H_ +#include #include +#include #include #include #include +#include "absl/base/attributes.h" +#include "absl/types/optional.h" #include "api/adaptation/resource.h" +#include "api/async_dns_resolver.h" #include "api/async_resolver_factory.h" #include 
"api/audio/audio_mixer.h" #include "api/audio_codecs/audio_decoder_factory.h" #include "api/audio_codecs/audio_encoder_factory.h" #include "api/audio_options.h" #include "api/call/call_factory_interface.h" +#include "api/candidate.h" #include "api/crypto/crypto_options.h" #include "api/data_channel_interface.h" #include "api/dtls_transport_interface.h" @@ -87,15 +93,18 @@ #include "api/ice_transport_interface.h" #include "api/jsep.h" #include "api/media_stream_interface.h" +#include "api/media_types.h" #include "api/neteq/neteq_factory.h" #include "api/network_state_predictor.h" #include "api/packet_socket_factory.h" #include "api/rtc_error.h" #include "api/rtc_event_log/rtc_event_log_factory_interface.h" #include "api/rtc_event_log_output.h" +#include "api/rtp_parameters.h" #include "api/rtp_receiver_interface.h" #include "api/rtp_sender_interface.h" #include "api/rtp_transceiver_interface.h" +#include "api/scoped_refptr.h" #include "api/sctp_transport_interface.h" #include "api/set_local_description_observer_interface.h" #include "api/set_remote_description_observer_interface.h" @@ -108,19 +117,26 @@ #include "api/transport/sctp_transport_factory_interface.h" #include "api/transport/webrtc_key_value_config.h" #include "api/turn_customizer.h" +#include "api/video/video_bitrate_allocator_factory.h" +#include "call/rtp_transport_controller_send_factory_interface.h" #include "media/base/media_config.h" #include "media/base/media_engine.h" // TODO(bugs.webrtc.org/7447): We plan to provide a way to let applications // inject a PacketSocketFactory and/or NetworkManager, and not expose -// PortAllocator in the PeerConnection api. +// PortAllocator in the PeerConnection api. This will let us remove nogncheck. +#include "p2p/base/port.h" // nogncheck #include "p2p/base/port_allocator.h" // nogncheck +#include "rtc_base/network.h" +#include "rtc_base/network_constants.h" #include "rtc_base/network_monitor_factory.h" +#include "rtc_base/ref_count.h" #include "rtc_base/rtc_certificate.h" #include "rtc_base/rtc_certificate_generator.h" #include "rtc_base/socket_address.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/system/rtc_export.h" +#include "rtc_base/thread.h" namespace rtc { class Thread; @@ -403,12 +419,6 @@ class RTC_EXPORT PeerConnectionInterface : public rtc::RefCountInterface { // from consideration for gathering ICE candidates. bool disable_link_local_networks = false; - // If set to true, use RTP data channels instead of SCTP. - // TODO(deadbeef): Remove this. We no longer commit to supporting RTP data - // channels, though some applications are still working on moving off of - // them. - bool enable_rtp_data_channel = false; - // Minimum bitrate at which screencast video tracks will be encoded at. // This means adding padding bits up to this bitrate, which can help // when switching from a static scene to one with motion. @@ -621,12 +631,8 @@ class RTC_EXPORT PeerConnectionInterface : public rtc::RefCountInterface { absl::optional crypto_options; // Configure if we should include the SDP attribute extmap-allow-mixed in - // our offer. Although we currently do support this, it's not included in - // our offer by default due to a previous bug that caused the SDP parser to - // abort parsing if this attribute was present. This is fixed in Chrome 71. - // TODO(webrtc:9985): Change default to true once sufficient time has - // passed. - bool offer_extmap_allow_mixed = false; + // our offer on session level. 
+ bool offer_extmap_allow_mixed = true; // TURN logging identifier. // This identifier is added to a TURN allocation @@ -643,6 +649,10 @@ class RTC_EXPORT PeerConnectionInterface : public rtc::RefCountInterface { // The delay before doing a usage histogram report for long-lived // PeerConnections. Used for testing only. absl::optional report_usage_pattern_delay_ms; + + // The ping interval (ms) when the connection is stable and writable. This + // parameter overrides the default value in the ICE implementation if set. + absl::optional stable_writable_connection_ping_interval_ms; // // Don't forget to update operator== if adding something. // @@ -909,9 +919,24 @@ class RTC_EXPORT PeerConnectionInterface : public rtc::RefCountInterface { // Also, calling CreateDataChannel is the only way to get a data "m=" section // in SDP, so it should be done before CreateOffer is called, if the // application plans to use data channels. + virtual RTCErrorOr> + CreateDataChannelOrError(const std::string& label, + const DataChannelInit* config) { + return RTCError(RTCErrorType::INTERNAL_ERROR, "dummy function called"); + } + // TODO(crbug.com/788659): Remove "virtual" below and default implementation + // above once mock in Chrome is fixed. + ABSL_DEPRECATED("Use CreateDataChannelOrError") virtual rtc::scoped_refptr CreateDataChannel( const std::string& label, - const DataChannelInit* config) = 0; + const DataChannelInit* config) { + auto result = CreateDataChannelOrError(label, config); + if (!result.ok()) { + return nullptr; + } else { + return result.MoveValue(); + } + } // NOTE: For the following 6 methods, it's only safe to dereference the // SessionDescriptionInterface on signaling_thread() (for example, calling @@ -1060,7 +1085,10 @@ class RTC_EXPORT PeerConnectionInterface : public rtc::RefCountInterface { // Removes a group of remote candidates from the ICE agent. Needed mainly for // continual gathering, to avoid an ever-growing list of candidates as - // networks come and go. + // networks come and go. Note that the candidates' transport_name must be set + // to the MID of the m= section that generated the candidate. + // TODO(bugs.webrtc.org/8395): Use IceCandidateInterface instead of + // cricket::Candidate, which would avoid the transport_name oddity. virtual bool RemoveIceCandidates( const std::vector& candidates) = 0; @@ -1078,13 +1106,11 @@ class RTC_EXPORT PeerConnectionInterface : public rtc::RefCountInterface { // playout of the underlying audio device but starts a task which will poll // for audio data every 10ms to ensure that audio processing happens and the // audio statistics are updated. - // TODO(henrika): deprecate and remove this. virtual void SetAudioPlayout(bool playout) {} // Enable/disable recording of transmitted audio streams. Enabled by default. // Note that even if recording is enabled, streams will only be recorded if // the appropriate SDP is also applied. - // TODO(henrika): deprecate and remove this. virtual void SetAudioRecording(bool recording) {} // Looks up the DtlsTransport associated with a MID value. @@ -1323,6 +1349,10 @@ struct RTC_EXPORT PeerConnectionDependencies final { // packet_socket_factory, not both. 
std::unique_ptr allocator; std::unique_ptr packet_socket_factory; + // Factory for creating resolvers that look up hostnames in DNS + std::unique_ptr + async_dns_resolver_factory; + // Deprecated - use async_dns_resolver_factory std::unique_ptr async_resolver_factory; std::unique_ptr ice_transport_factory; std::unique_ptr cert_generator; @@ -1369,6 +1399,8 @@ struct RTC_EXPORT PeerConnectionFactoryDependencies final { std::unique_ptr neteq_factory; std::unique_ptr sctp_factory; std::unique_ptr trials; + std::unique_ptr + transport_controller_send_factory; }; // PeerConnectionFactoryInterface is the factory interface used for creating @@ -1396,10 +1428,6 @@ class RTC_EXPORT PeerConnectionFactoryInterface // testing/debugging. bool disable_encryption = false; - // Deprecated. The only effect of setting this to true is that - // CreateDataChannel will fail, which is not that useful. - bool disable_sctp_data_channels = false; - // If set to true, any platform-supported network monitoring capability // won't be used, and instead networks will only be updated via polling. // @@ -1434,6 +1462,7 @@ class RTC_EXPORT PeerConnectionFactoryInterface PeerConnectionDependencies dependencies); // Deprecated creator - does not return an error code on error. // TODO(bugs.webrtc.org:12238): Deprecate and remove. + ABSL_DEPRECATED("Use CreatePeerConnectionOrError") virtual rtc::scoped_refptr CreatePeerConnection( const PeerConnectionInterface::RTCConfiguration& configuration, PeerConnectionDependencies dependencies); @@ -1447,6 +1476,7 @@ class RTC_EXPORT PeerConnectionFactoryInterface // responsibility of the caller to delete it. It can be safely deleted after // Close has been called on the returned PeerConnection, which ensures no // more observer callbacks will be invoked. + ABSL_DEPRECATED("Use CreatePeerConnectionOrError") virtual rtc::scoped_refptr CreatePeerConnection( const PeerConnectionInterface::RTCConfiguration& configuration, std::unique_ptr allocator, diff --git a/api/ref_counted_base.h b/api/ref_counted_base.h index a1761db851..931cb20762 100644 --- a/api/ref_counted_base.h +++ b/api/ref_counted_base.h @@ -10,8 +10,9 @@ #ifndef API_REF_COUNTED_BASE_H_ #define API_REF_COUNTED_BASE_H_ +#include + #include "rtc_base/constructor_magic.h" -#include "rtc_base/ref_count.h" #include "rtc_base/ref_counter.h" namespace rtc { @@ -30,6 +31,10 @@ class RefCountedBase { } protected: + // Provided for internal webrtc subclasses for corner cases where it's + // necessary to know whether or not a reference is exclusively held. + bool HasOneRef() const { return ref_count_.HasOneRef(); } + virtual ~RefCountedBase() = default; private: @@ -38,6 +43,55 @@ class RefCountedBase { RTC_DISALLOW_COPY_AND_ASSIGN(RefCountedBase); }; +// Template based version of `RefCountedBase` for simple implementations that do +// not need (or want) destruction via virtual destructor or the overhead of a +// vtable. +// +// To use: +// struct MyInt : public rtc::RefCountedNonVirtual { +// int foo_ = 0; +// }; +// +// rtc::scoped_refptr my_int(new MyInt()); +// +// sizeof(MyInt) on a 32 bit system would then be 8, int + refcount and no +// vtable generated. +template +class RefCountedNonVirtual { + public: + RefCountedNonVirtual() = default; + + void AddRef() const { ref_count_.IncRef(); } + RefCountReleaseStatus Release() const { + // If you run into this assert, T has virtual methods. 
There are two + // options: + // 1) The class doesn't actually need virtual methods, the type is complete + // so the virtual attribute(s) can be removed. + // 2) The virtual methods are a part of the design of the class. In this + // case you can consider using `RefCountedBase` instead or alternatively + // use `rtc::RefCountedObject`. + static_assert(!std::is_polymorphic::value, + "T has virtual methods. RefCountedBase is a better fit."); + const auto status = ref_count_.DecRef(); + if (status == RefCountReleaseStatus::kDroppedLastRef) { + delete static_cast(this); + } + return status; + } + + protected: + // Provided for internal webrtc subclasses for corner cases where it's + // necessary to know whether or not a reference is exclusively held. + bool HasOneRef() const { return ref_count_.HasOneRef(); } + + ~RefCountedNonVirtual() = default; + + private: + mutable webrtc::webrtc_impl::RefCounter ref_count_{0}; + + RTC_DISALLOW_COPY_AND_ASSIGN(RefCountedNonVirtual); +}; + } // namespace rtc #endif // API_REF_COUNTED_BASE_H_ diff --git a/api/rtc_event_log/rtc_event.cc b/api/rtc_event_log/rtc_event.cc index 81e6a4e6da..631188b915 100644 --- a/api/rtc_event_log/rtc_event.cc +++ b/api/rtc_event_log/rtc_event.cc @@ -14,6 +14,6 @@ namespace webrtc { -RtcEvent::RtcEvent() : timestamp_us_(rtc::TimeMicros()) {} +RtcEvent::RtcEvent() : timestamp_us_(rtc::TimeMillis() * 1000) {} } // namespace webrtc diff --git a/api/rtp_headers.h b/api/rtp_headers.h index b9a97c885d..cf3d909499 100644 --- a/api/rtp_headers.h +++ b/api/rtp_headers.h @@ -144,13 +144,12 @@ struct RTPHeaderExtension { VideoPlayoutDelay playout_delay; // For identification of a stream when ssrc is not signaled. See - // https://tools.ietf.org/html/draft-ietf-avtext-rid-09 - // TODO(danilchap): Update url from draft to release version. + // https://tools.ietf.org/html/rfc8852 std::string stream_id; std::string repaired_stream_id; // For identifying the media section used to interpret this RTP packet. 
See - // https://tools.ietf.org/html/draft-ietf-mmusic-sdp-bundle-negotiation-38 + // https://tools.ietf.org/html/rfc8843 std::string mid; absl::optional color_space; diff --git a/api/rtp_packet_info.cc b/api/rtp_packet_info.cc index a9ebd9df48..db818f7657 100644 --- a/api/rtp_packet_info.cc +++ b/api/rtp_packet_info.cc @@ -16,7 +16,7 @@ namespace webrtc { RtpPacketInfo::RtpPacketInfo() - : ssrc_(0), rtp_timestamp_(0), receive_time_ms_(-1) {} + : ssrc_(0), rtp_timestamp_(0), receive_time_(Timestamp::MinusInfinity()) {} RtpPacketInfo::RtpPacketInfo( uint32_t ssrc, @@ -24,19 +24,19 @@ RtpPacketInfo::RtpPacketInfo( uint32_t rtp_timestamp, absl::optional audio_level, absl::optional absolute_capture_time, - int64_t receive_time_ms) + Timestamp receive_time) : ssrc_(ssrc), csrcs_(std::move(csrcs)), rtp_timestamp_(rtp_timestamp), audio_level_(audio_level), absolute_capture_time_(absolute_capture_time), - receive_time_ms_(receive_time_ms) {} + receive_time_(receive_time) {} RtpPacketInfo::RtpPacketInfo(const RTPHeader& rtp_header, - int64_t receive_time_ms) + Timestamp receive_time) : ssrc_(rtp_header.ssrc), rtp_timestamp_(rtp_header.timestamp), - receive_time_ms_(receive_time_ms) { + receive_time_(receive_time) { const auto& extension = rtp_header.extension; const auto csrcs_count = std::min(rtp_header.numCSRCs, kRtpCsrcSize); @@ -49,12 +49,31 @@ RtpPacketInfo::RtpPacketInfo(const RTPHeader& rtp_header, absolute_capture_time_ = extension.absolute_capture_time; } +RtpPacketInfo::RtpPacketInfo( + uint32_t ssrc, + std::vector csrcs, + uint32_t rtp_timestamp, + absl::optional audio_level, + absl::optional absolute_capture_time, + int64_t receive_time_ms) + : RtpPacketInfo(ssrc, + csrcs, + rtp_timestamp, + audio_level, + absolute_capture_time, + Timestamp::Millis(receive_time_ms)) {} +RtpPacketInfo::RtpPacketInfo(const RTPHeader& rtp_header, + int64_t receive_time_ms) + : RtpPacketInfo(rtp_header, Timestamp::Millis(receive_time_ms)) {} + bool operator==(const RtpPacketInfo& lhs, const RtpPacketInfo& rhs) { return (lhs.ssrc() == rhs.ssrc()) && (lhs.csrcs() == rhs.csrcs()) && (lhs.rtp_timestamp() == rhs.rtp_timestamp()) && (lhs.audio_level() == rhs.audio_level()) && (lhs.absolute_capture_time() == rhs.absolute_capture_time()) && - (lhs.receive_time_ms() == rhs.receive_time_ms()); + (lhs.receive_time() == rhs.receive_time() && + (lhs.local_capture_clock_offset() == + rhs.local_capture_clock_offset())); } } // namespace webrtc diff --git a/api/rtp_packet_info.h b/api/rtp_packet_info.h index 639ba32770..605620d638 100644 --- a/api/rtp_packet_info.h +++ b/api/rtp_packet_info.h @@ -17,6 +17,7 @@ #include "absl/types/optional.h" #include "api/rtp_headers.h" +#include "api/units/timestamp.h" #include "rtc_base/system/rtc_export.h" namespace webrtc { @@ -35,8 +36,18 @@ class RTC_EXPORT RtpPacketInfo { uint32_t rtp_timestamp, absl::optional audio_level, absl::optional absolute_capture_time, - int64_t receive_time_ms); + Timestamp receive_time); + + RtpPacketInfo(const RTPHeader& rtp_header, Timestamp receive_time); + // TODO(bugs.webrtc.org/12722): Deprecated, remove once downstream projects + // are updated. 
+ RtpPacketInfo(uint32_t ssrc, + std::vector csrcs, + uint32_t rtp_timestamp, + absl::optional audio_level, + absl::optional absolute_capture_time, + int64_t receive_time_ms); RtpPacketInfo(const RTPHeader& rtp_header, int64_t receive_time_ms); RtpPacketInfo(const RtpPacketInfo& other) = default; @@ -64,8 +75,19 @@ class RTC_EXPORT RtpPacketInfo { absolute_capture_time_ = value; } - int64_t receive_time_ms() const { return receive_time_ms_; } - void set_receive_time_ms(int64_t value) { receive_time_ms_ = value; } + const absl::optional& local_capture_clock_offset() const { + return local_capture_clock_offset_; + } + + void set_local_capture_clock_offset(const absl::optional& value) { + local_capture_clock_offset_ = value; + } + + Timestamp receive_time() const { return receive_time_; } + void set_receive_time(Timestamp value) { receive_time_ = value; } + // TODO(bugs.webrtc.org/12722): Deprecated, remove once downstream projects + // are updated. + int64_t receive_time_ms() const { return receive_time_.ms(); } private: // Fields from the RTP header: @@ -80,10 +102,19 @@ class RTC_EXPORT RtpPacketInfo { // Fields from the Absolute Capture Time header extension: // http://www.webrtc.org/experiments/rtp-hdrext/abs-capture-time + // To not be confused with |local_capture_clock_offset_|, the + // |estimated_capture_clock_offset| in |absolute_capture_time_| should + // represent the clock offset between a remote sender and the capturer, and + // thus equals to the corresponding values in the received RTP packets, + // subjected to possible interpolations. absl::optional absolute_capture_time_; + // Clock offset against capturer's clock. Should be derived from the estimated + // capture clock offset defined in the Absolute Capture Time header extension. + absl::optional local_capture_clock_offset_; + // Local |webrtc::Clock|-based timestamp of when the packet was received. 
- int64_t receive_time_ms_; + Timestamp receive_time_; }; bool operator==(const RtpPacketInfo& lhs, const RtpPacketInfo& rhs); diff --git a/api/rtp_packet_info_unittest.cc b/api/rtp_packet_info_unittest.cc index fe79f6df3c..601d34f49e 100644 --- a/api/rtp_packet_info_unittest.cc +++ b/api/rtp_packet_info_unittest.cc @@ -37,7 +37,7 @@ TEST(RtpPacketInfoTest, Ssrc) { rhs = RtpPacketInfo(); EXPECT_NE(rhs.ssrc(), value); - rhs = RtpPacketInfo(value, {}, {}, {}, {}, {}); + rhs = RtpPacketInfo(value, {}, {}, {}, {}, Timestamp::Millis(0)); EXPECT_EQ(rhs.ssrc(), value); } @@ -64,7 +64,7 @@ TEST(RtpPacketInfoTest, Csrcs) { rhs = RtpPacketInfo(); EXPECT_NE(rhs.csrcs(), value); - rhs = RtpPacketInfo({}, value, {}, {}, {}, {}); + rhs = RtpPacketInfo({}, value, {}, {}, {}, Timestamp::Millis(0)); EXPECT_EQ(rhs.csrcs(), value); } @@ -91,7 +91,7 @@ TEST(RtpPacketInfoTest, RtpTimestamp) { rhs = RtpPacketInfo(); EXPECT_NE(rhs.rtp_timestamp(), value); - rhs = RtpPacketInfo({}, {}, value, {}, {}, {}); + rhs = RtpPacketInfo({}, {}, value, {}, {}, Timestamp::Millis(0)); EXPECT_EQ(rhs.rtp_timestamp(), value); } @@ -118,7 +118,7 @@ TEST(RtpPacketInfoTest, AudioLevel) { rhs = RtpPacketInfo(); EXPECT_NE(rhs.audio_level(), value); - rhs = RtpPacketInfo({}, {}, {}, value, {}, {}); + rhs = RtpPacketInfo({}, {}, {}, value, {}, Timestamp::Millis(0)); EXPECT_EQ(rhs.audio_level(), value); } @@ -145,12 +145,41 @@ TEST(RtpPacketInfoTest, AbsoluteCaptureTime) { rhs = RtpPacketInfo(); EXPECT_NE(rhs.absolute_capture_time(), value); - rhs = RtpPacketInfo({}, {}, {}, {}, value, {}); + rhs = RtpPacketInfo({}, {}, {}, {}, value, Timestamp::Millis(0)); EXPECT_EQ(rhs.absolute_capture_time(), value); } +TEST(RtpPacketInfoTest, LocalCaptureClockOffset) { + RtpPacketInfo lhs; + RtpPacketInfo rhs; + + EXPECT_TRUE(lhs == rhs); + EXPECT_FALSE(lhs != rhs); + + const absl::optional value = 10; + rhs.set_local_capture_clock_offset(value); + EXPECT_EQ(rhs.local_capture_clock_offset(), value); + + EXPECT_FALSE(lhs == rhs); + EXPECT_TRUE(lhs != rhs); + + lhs = rhs; + + EXPECT_TRUE(lhs == rhs); + EXPECT_FALSE(lhs != rhs); + + // Default local capture clock offset is null. + rhs = RtpPacketInfo(); + EXPECT_EQ(rhs.local_capture_clock_offset(), absl::nullopt); + + // Default local capture clock offset is null. 
+ rhs = RtpPacketInfo({}, {}, {}, {}, AbsoluteCaptureTime{12, 34}, + Timestamp::Millis(0)); + EXPECT_EQ(rhs.local_capture_clock_offset(), absl::nullopt); +} + TEST(RtpPacketInfoTest, ReceiveTimeMs) { - const int64_t value = 8868963877546349045LL; + const Timestamp timestamp = Timestamp::Micros(8868963877546349045LL); RtpPacketInfo lhs; RtpPacketInfo rhs; @@ -158,8 +187,8 @@ TEST(RtpPacketInfoTest, ReceiveTimeMs) { EXPECT_TRUE(lhs == rhs); EXPECT_FALSE(lhs != rhs); - rhs.set_receive_time_ms(value); - EXPECT_EQ(rhs.receive_time_ms(), value); + rhs.set_receive_time(timestamp); + EXPECT_EQ(rhs.receive_time(), timestamp); EXPECT_FALSE(lhs == rhs); EXPECT_TRUE(lhs != rhs); @@ -170,10 +199,10 @@ TEST(RtpPacketInfoTest, ReceiveTimeMs) { EXPECT_FALSE(lhs != rhs); rhs = RtpPacketInfo(); - EXPECT_NE(rhs.receive_time_ms(), value); + EXPECT_NE(rhs.receive_time(), timestamp); - rhs = RtpPacketInfo({}, {}, {}, {}, {}, value); - EXPECT_EQ(rhs.receive_time_ms(), value); + rhs = RtpPacketInfo({}, {}, {}, {}, {}, timestamp); + EXPECT_EQ(rhs.receive_time(), timestamp); } } // namespace webrtc diff --git a/api/rtp_packet_infos_unittest.cc b/api/rtp_packet_infos_unittest.cc index ce502ac378..e83358fc17 100644 --- a/api/rtp_packet_infos_unittest.cc +++ b/api/rtp_packet_infos_unittest.cc @@ -27,9 +27,12 @@ RtpPacketInfos::vector_type ToVector(Iterator begin, Iterator end) { } // namespace TEST(RtpPacketInfosTest, BasicFunctionality) { - RtpPacketInfo p0(123, {1, 2}, 89, 5, AbsoluteCaptureTime{45, 78}, 7); - RtpPacketInfo p1(456, {3, 4}, 89, 4, AbsoluteCaptureTime{13, 21}, 1); - RtpPacketInfo p2(789, {5, 6}, 88, 1, AbsoluteCaptureTime{99, 78}, 7); + RtpPacketInfo p0(123, {1, 2}, 89, 5, AbsoluteCaptureTime{45, 78}, + Timestamp::Millis(7)); + RtpPacketInfo p1(456, {3, 4}, 89, 4, AbsoluteCaptureTime{13, 21}, + Timestamp::Millis(1)); + RtpPacketInfo p2(789, {5, 6}, 88, 1, AbsoluteCaptureTime{99, 78}, + Timestamp::Millis(7)); RtpPacketInfos x({p0, p1, p2}); @@ -52,9 +55,12 @@ TEST(RtpPacketInfosTest, BasicFunctionality) { } TEST(RtpPacketInfosTest, CopyShareData) { - RtpPacketInfo p0(123, {1, 2}, 89, 5, AbsoluteCaptureTime{45, 78}, 7); - RtpPacketInfo p1(456, {3, 4}, 89, 4, AbsoluteCaptureTime{13, 21}, 1); - RtpPacketInfo p2(789, {5, 6}, 88, 1, AbsoluteCaptureTime{99, 78}, 7); + RtpPacketInfo p0(123, {1, 2}, 89, 5, AbsoluteCaptureTime{45, 78}, + Timestamp::Millis(7)); + RtpPacketInfo p1(456, {3, 4}, 89, 4, AbsoluteCaptureTime{13, 21}, + Timestamp::Millis(1)); + RtpPacketInfo p2(789, {5, 6}, 88, 1, AbsoluteCaptureTime{99, 78}, + Timestamp::Millis(7)); RtpPacketInfos lhs({p0, p1, p2}); RtpPacketInfos rhs = lhs; diff --git a/api/rtp_parameters.cc b/api/rtp_parameters.cc index 92f99e9bb8..5ce6780753 100644 --- a/api/rtp_parameters.cc +++ b/api/rtp_parameters.cc @@ -130,6 +130,7 @@ constexpr char RtpExtension::kColorSpaceUri[]; constexpr char RtpExtension::kMidUri[]; constexpr char RtpExtension::kRidUri[]; constexpr char RtpExtension::kRepairedRidUri[]; +constexpr char RtpExtension::kVideoFrameTrackingIdUri[]; constexpr int RtpExtension::kMinId; constexpr int RtpExtension::kMaxId; @@ -164,67 +165,126 @@ bool RtpExtension::IsSupportedForVideo(absl::string_view uri) { uri == webrtc::RtpExtension::kColorSpaceUri || uri == webrtc::RtpExtension::kRidUri || uri == webrtc::RtpExtension::kRepairedRidUri || - uri == webrtc::RtpExtension::kVideoLayersAllocationUri; + uri == webrtc::RtpExtension::kVideoLayersAllocationUri || + uri == webrtc::RtpExtension::kVideoFrameTrackingIdUri; } bool 
RtpExtension::IsEncryptionSupported(absl::string_view uri) { - return uri == webrtc::RtpExtension::kAudioLevelUri || - uri == webrtc::RtpExtension::kTimestampOffsetUri || -#if !defined(ENABLE_EXTERNAL_AUTH) - // TODO(jbauch): Figure out a way to always allow "kAbsSendTimeUri" - // here and filter out later if external auth is really used in - // srtpfilter. External auth is used by Chromium and replaces the - // extension header value of "kAbsSendTimeUri", so it must not be - // encrypted (which can't be done by Chromium). - uri == webrtc::RtpExtension::kAbsSendTimeUri || + return +#if defined(ENABLE_EXTERNAL_AUTH) + // TODO(jbauch): Figure out a way to always allow "kAbsSendTimeUri" + // here and filter out later if external auth is really used in + // srtpfilter. External auth is used by Chromium and replaces the + // extension header value of "kAbsSendTimeUri", so it must not be + // encrypted (which can't be done by Chromium). + uri != webrtc::RtpExtension::kAbsSendTimeUri && #endif - uri == webrtc::RtpExtension::kAbsoluteCaptureTimeUri || - uri == webrtc::RtpExtension::kVideoRotationUri || - uri == webrtc::RtpExtension::kTransportSequenceNumberUri || - uri == webrtc::RtpExtension::kTransportSequenceNumberV2Uri || - uri == webrtc::RtpExtension::kPlayoutDelayUri || - uri == webrtc::RtpExtension::kVideoContentTypeUri || - uri == webrtc::RtpExtension::kMidUri || - uri == webrtc::RtpExtension::kRidUri || - uri == webrtc::RtpExtension::kRepairedRidUri || - uri == webrtc::RtpExtension::kVideoLayersAllocationUri; + uri != webrtc::RtpExtension::kEncryptHeaderExtensionsUri; } -const RtpExtension* RtpExtension::FindHeaderExtensionByUri( +// Returns whether a header extension with the given URI exists. +// Note: This does not differentiate between encrypted and non-encrypted +// extensions, so use with care! +static bool HeaderExtensionWithUriExists( const std::vector& extensions, absl::string_view uri) { for (const auto& extension : extensions) { if (extension.uri == uri) { + return true; + } + } + return false; +} + +const RtpExtension* RtpExtension::FindHeaderExtensionByUri( + const std::vector& extensions, + absl::string_view uri, + Filter filter) { + const webrtc::RtpExtension* fallback_extension = nullptr; + for (const auto& extension : extensions) { + if (extension.uri != uri) { + continue; + } + + switch (filter) { + case kDiscardEncryptedExtension: + // We only accept an unencrypted extension. + if (!extension.encrypt) { + return &extension; + } + break; + + case kPreferEncryptedExtension: + // We prefer an encrypted extension but we can fall back to an + // unencrypted extension. + if (extension.encrypt) { + return &extension; + } else { + fallback_extension = &extension; + } + break; + + case kRequireEncryptedExtension: + // We only accept an encrypted extension. 
+ if (extension.encrypt) { + return &extension; + } + break; + } + } + + // Returning fallback extension (if any) + return fallback_extension; +} + +const RtpExtension* RtpExtension::FindHeaderExtensionByUri( + const std::vector& extensions, + absl::string_view uri) { + return FindHeaderExtensionByUri(extensions, uri, kPreferEncryptedExtension); +} + +const RtpExtension* RtpExtension::FindHeaderExtensionByUriAndEncryption( + const std::vector& extensions, + absl::string_view uri, + bool encrypt) { + for (const auto& extension : extensions) { + if (extension.uri == uri && extension.encrypt == encrypt) { return &extension; } } return nullptr; } -std::vector RtpExtension::FilterDuplicateNonEncrypted( - const std::vector& extensions) { +const std::vector RtpExtension::DeduplicateHeaderExtensions( + const std::vector& extensions, + Filter filter) { std::vector filtered; - for (auto extension = extensions.begin(); extension != extensions.end(); - ++extension) { - if (extension->encrypt) { - filtered.push_back(*extension); - continue; - } - // Only add non-encrypted extension if no encrypted with the same URI - // is also present... - if (std::any_of(extension + 1, extensions.end(), - [&](const RtpExtension& check) { - return extension->uri == check.uri; - })) { - continue; + // If we do not discard encrypted extensions, add them first + if (filter != kDiscardEncryptedExtension) { + for (const auto& extension : extensions) { + if (!extension.encrypt) { + continue; + } + if (!HeaderExtensionWithUriExists(filtered, extension.uri)) { + filtered.push_back(extension); + } } + } - // ...and has not been added before. - if (!FindHeaderExtensionByUri(filtered, extension->uri)) { - filtered.push_back(*extension); + // If we do not require encrypted extensions, add missing, non-encrypted + // extensions. + if (filter != kRequireEncryptedExtension) { + for (const auto& extension : extensions) { + if (extension.encrypt) { + continue; + } + if (!HeaderExtensionWithUriExists(filtered, extension.uri)) { + filtered.push_back(extension); + } } } + return filtered; } } // namespace webrtc diff --git a/api/rtp_parameters.h b/api/rtp_parameters.h index df0e7a93b1..a098bad6b0 100644 --- a/api/rtp_parameters.h +++ b/api/rtp_parameters.h @@ -246,6 +246,18 @@ struct RTC_EXPORT RtpHeaderExtensionCapability { // RTP header extension, see RFC8285. struct RTC_EXPORT RtpExtension { + enum Filter { + // Encrypted extensions will be ignored and only non-encrypted extensions + // will be considered. + kDiscardEncryptedExtension, + // Encrypted extensions will be preferred but will fall back to + // non-encrypted extensions if necessary. + kPreferEncryptedExtension, + // Encrypted extensions will be required, so any non-encrypted extensions + // will be discarded. + kRequireEncryptedExtension, + }; + RtpExtension(); RtpExtension(absl::string_view uri, int id); RtpExtension(absl::string_view uri, int id, bool encrypt); @@ -260,17 +272,28 @@ struct RTC_EXPORT RtpExtension { // Return "true" if the given RTP header extension URI may be encrypted. static bool IsEncryptionSupported(absl::string_view uri); - // Returns the named header extension if found among all extensions, - // nullptr otherwise. + // Returns the header extension with the given URI or nullptr if not found. 
+ static const RtpExtension* FindHeaderExtensionByUri( + const std::vector& extensions, + absl::string_view uri, + Filter filter); + ABSL_DEPRECATED( + "Use RtpExtension::FindHeaderExtensionByUri with filter argument") static const RtpExtension* FindHeaderExtensionByUri( const std::vector& extensions, absl::string_view uri); - // Return a list of RTP header extensions with the non-encrypted extensions - // removed if both the encrypted and non-encrypted extension is present for - // the same URI. - static std::vector FilterDuplicateNonEncrypted( - const std::vector& extensions); + // Returns the header extension with the given URI and encrypt parameter, + // if found, otherwise nullptr. + static const RtpExtension* FindHeaderExtensionByUriAndEncryption( + const std::vector& extensions, + absl::string_view uri, + bool encrypt); + + // Returns a list of extensions where any extension URI is unique. + static const std::vector DeduplicateHeaderExtensions( + const std::vector& extensions, + Filter filter); // Encryption of Header Extensions, see RFC 6904 for details: // https://tools.ietf.org/html/rfc6904 @@ -353,6 +376,15 @@ struct RTC_EXPORT RtpExtension { static constexpr char kRepairedRidUri[] = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id"; + // Header extension to propagate webrtc::VideoFrame id field + static constexpr char kVideoFrameTrackingIdUri[] = + "http://www.webrtc.org/experiments/rtp-hdrext/video-frame-tracking-id"; + + // Header extension for Mixer-to-Client audio levels per CSRC as defined in + // https://tools.ietf.org/html/rfc6465 + static constexpr char kCsrcAudioLevelsUri[] = + "urn:ietf:params:rtp-hdrext:csrc-audio-level"; + // Inclusive min and max IDs for two-byte header extensions and one-byte // header extensions, per RFC8285 Section 4.2-4.3. 
static constexpr int kMinId = 1; diff --git a/api/rtp_parameters_unittest.cc b/api/rtp_parameters_unittest.cc index 5928cbda63..51ad426748 100644 --- a/api/rtp_parameters_unittest.cc +++ b/api/rtp_parameters_unittest.cc @@ -23,28 +23,249 @@ static const RtpExtension kExtension1(kExtensionUri1, 1); static const RtpExtension kExtension1Encrypted(kExtensionUri1, 10, true); static const RtpExtension kExtension2(kExtensionUri2, 2); -TEST(RtpExtensionTest, FilterDuplicateNonEncrypted) { +TEST(RtpExtensionTest, DeduplicateHeaderExtensions) { std::vector extensions; std::vector filtered; + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension1Encrypted); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kDiscardEncryptedExtension); + EXPECT_EQ(1u, filtered.size()); + EXPECT_EQ(std::vector{kExtension1}, filtered); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension1Encrypted); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kPreferEncryptedExtension); + EXPECT_EQ(1u, filtered.size()); + EXPECT_EQ(std::vector{kExtension1Encrypted}, filtered); + + extensions.clear(); extensions.push_back(kExtension1); extensions.push_back(kExtension1Encrypted); - filtered = RtpExtension::FilterDuplicateNonEncrypted(extensions); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kRequireEncryptedExtension); + EXPECT_EQ(1u, filtered.size()); + EXPECT_EQ(std::vector{kExtension1Encrypted}, filtered); + + extensions.clear(); + extensions.push_back(kExtension1Encrypted); + extensions.push_back(kExtension1); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kDiscardEncryptedExtension); + EXPECT_EQ(1u, filtered.size()); + EXPECT_EQ(std::vector{kExtension1}, filtered); + + extensions.clear(); + extensions.push_back(kExtension1Encrypted); + extensions.push_back(kExtension1); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kPreferEncryptedExtension); EXPECT_EQ(1u, filtered.size()); EXPECT_EQ(std::vector{kExtension1Encrypted}, filtered); extensions.clear(); extensions.push_back(kExtension1Encrypted); extensions.push_back(kExtension1); - filtered = RtpExtension::FilterDuplicateNonEncrypted(extensions); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kRequireEncryptedExtension); EXPECT_EQ(1u, filtered.size()); EXPECT_EQ(std::vector{kExtension1Encrypted}, filtered); extensions.clear(); extensions.push_back(kExtension1); extensions.push_back(kExtension2); - filtered = RtpExtension::FilterDuplicateNonEncrypted(extensions); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kDiscardEncryptedExtension); + EXPECT_EQ(2u, filtered.size()); + EXPECT_EQ(extensions, filtered); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kPreferEncryptedExtension); EXPECT_EQ(2u, filtered.size()); EXPECT_EQ(extensions, filtered); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kRequireEncryptedExtension); + EXPECT_EQ(0u, filtered.size()); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension2); + extensions.push_back(kExtension1Encrypted); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kDiscardEncryptedExtension); 
+ EXPECT_EQ(2u, filtered.size()); + EXPECT_EQ((std::vector{kExtension1, kExtension2}), filtered); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kPreferEncryptedExtension); + EXPECT_EQ(2u, filtered.size()); + EXPECT_EQ((std::vector{kExtension1Encrypted, kExtension2}), + filtered); + filtered = RtpExtension::DeduplicateHeaderExtensions( + extensions, RtpExtension::Filter::kRequireEncryptedExtension); + EXPECT_EQ(1u, filtered.size()); + EXPECT_EQ((std::vector{kExtension1Encrypted}), filtered); +} + +TEST(RtpExtensionTest, FindHeaderExtensionByUriAndEncryption) { + std::vector extensions; + + extensions.clear(); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri1, false)); + + extensions.clear(); + extensions.push_back(kExtension1); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri1, false)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri1, true)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri2, false)); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension2); + extensions.push_back(kExtension1Encrypted); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri1, false)); + EXPECT_EQ(kExtension2, *RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri2, false)); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri1, true)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUriAndEncryption( + extensions, kExtensionUri2, true)); +} + +TEST(RtpExtensionTest, FindHeaderExtensionByUri) { + std::vector extensions; + + extensions.clear(); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kRequireEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kRequireEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kRequireEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension1Encrypted); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kDiscardEncryptedExtension)); + + 
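Outside of these tests, the expected call pattern is to deduplicate a negotiated extension list once and then look extensions up with the same Filter; a short sketch under that assumption (the function and variable names are illustrative):

#include <vector>

#include "api/rtp_parameters.h"

// Keeps at most one entry per URI, preferring encrypted variants, then looks
// one of them up with the same policy.
void UseNegotiatedExtensions(
    const std::vector<webrtc::RtpExtension>& negotiated) {
  const std::vector<webrtc::RtpExtension> unique =
      webrtc::RtpExtension::DeduplicateHeaderExtensions(
          negotiated, webrtc::RtpExtension::Filter::kPreferEncryptedExtension);
  const webrtc::RtpExtension* abs_send_time =
      webrtc::RtpExtension::FindHeaderExtensionByUri(
          unique, webrtc::RtpExtension::kAbsSendTimeUri,
          webrtc::RtpExtension::Filter::kPreferEncryptedExtension);
  if (abs_send_time != nullptr) {
    // abs_send_time->id and abs_send_time->encrypt are ready to use.
  }
}
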
extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension1Encrypted); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kPreferEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension1Encrypted); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kRequireEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1Encrypted); + extensions.push_back(kExtension1); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kDiscardEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1Encrypted); + extensions.push_back(kExtension1); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kPreferEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1Encrypted); + extensions.push_back(kExtension1); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kRequireEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension2); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kRequireEncryptedExtension)); + EXPECT_EQ(kExtension2, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(kExtension2, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kRequireEncryptedExtension)); + + extensions.clear(); + extensions.push_back(kExtension1); + extensions.push_back(kExtension2); + extensions.push_back(kExtension1Encrypted); + EXPECT_EQ(kExtension1, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(kExtension1Encrypted, + *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri1, + RtpExtension::Filter::kRequireEncryptedExtension)); + EXPECT_EQ(kExtension2, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kDiscardEncryptedExtension)); + EXPECT_EQ(kExtension2, *RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kPreferEncryptedExtension)); + EXPECT_EQ(nullptr, RtpExtension::FindHeaderExtensionByUri( + extensions, kExtensionUri2, + RtpExtension::Filter::kRequireEncryptedExtension)); } } // namespace webrtc diff --git a/api/rtp_receiver_interface.h b/api/rtp_receiver_interface.h index 786ea3aceb..327c9f2fee 100644 --- a/api/rtp_receiver_interface.h +++ 
b/api/rtp_receiver_interface.h @@ -22,11 +22,9 @@ #include "api/frame_transformer_interface.h" #include "api/media_stream_interface.h" #include "api/media_types.h" -#include "api/proxy.h" #include "api/rtp_parameters.h" #include "api/scoped_refptr.h" #include "api/transport/rtp/rtp_source.h" -#include "rtc_base/deprecation.h" #include "rtc_base/ref_count.h" #include "rtc_base/system/rtc_export.h" @@ -101,11 +99,13 @@ class RTC_EXPORT RtpReceiverInterface : public rtc::RefCountInterface { // before it is sent across the network. This will decrypt the entire frame // using the user provided decryption mechanism regardless of whether SRTP is // enabled or not. + // TODO(bugs.webrtc.org/12772): Remove. virtual void SetFrameDecryptor( rtc::scoped_refptr frame_decryptor); // Returns a pointer to the frame decryptor set previously by the // user. This can be used to update the state of the object. + // TODO(bugs.webrtc.org/12772): Remove. virtual rtc::scoped_refptr GetFrameDecryptor() const; // Sets a frame transformer between the depacketizer and the decoder to enable @@ -118,32 +118,6 @@ class RTC_EXPORT RtpReceiverInterface : public rtc::RefCountInterface { ~RtpReceiverInterface() override = default; }; -// Define proxy for RtpReceiverInterface. -// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. -BEGIN_SIGNALING_PROXY_MAP(RtpReceiver) -PROXY_SIGNALING_THREAD_DESTRUCTOR() -PROXY_CONSTMETHOD0(rtc::scoped_refptr, track) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, dtls_transport) -PROXY_CONSTMETHOD0(std::vector, stream_ids) -PROXY_CONSTMETHOD0(std::vector>, - streams) -BYPASS_PROXY_CONSTMETHOD0(cricket::MediaType, media_type) -BYPASS_PROXY_CONSTMETHOD0(std::string, id) -PROXY_CONSTMETHOD0(RtpParameters, GetParameters) -PROXY_METHOD1(void, SetObserver, RtpReceiverObserverInterface*) -PROXY_METHOD1(void, SetJitterBufferMinimumDelay, absl::optional) -PROXY_CONSTMETHOD0(std::vector, GetSources) -PROXY_METHOD1(void, - SetFrameDecryptor, - rtc::scoped_refptr) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, - GetFrameDecryptor) -PROXY_METHOD1(void, - SetDepacketizerToDecoderFrameTransformer, - rtc::scoped_refptr) -END_PROXY_MAP() - } // namespace webrtc #endif // API_RTP_RECEIVER_INTERFACE_H_ diff --git a/api/rtp_sender_interface.h b/api/rtp_sender_interface.h index a33b80042e..9ffad68644 100644 --- a/api/rtp_sender_interface.h +++ b/api/rtp_sender_interface.h @@ -23,7 +23,6 @@ #include "api/frame_transformer_interface.h" #include "api/media_stream_interface.h" #include "api/media_types.h" -#include "api/proxy.h" #include "api/rtc_error.h" #include "api/rtp_parameters.h" #include "api/scoped_refptr.h" @@ -101,33 +100,6 @@ class RTC_EXPORT RtpSenderInterface : public rtc::RefCountInterface { ~RtpSenderInterface() override = default; }; -// Define proxy for RtpSenderInterface. -// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. 
-BEGIN_SIGNALING_PROXY_MAP(RtpSender) -PROXY_SIGNALING_THREAD_DESTRUCTOR() -PROXY_METHOD1(bool, SetTrack, MediaStreamTrackInterface*) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, track) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, dtls_transport) -PROXY_CONSTMETHOD0(uint32_t, ssrc) -BYPASS_PROXY_CONSTMETHOD0(cricket::MediaType, media_type) -BYPASS_PROXY_CONSTMETHOD0(std::string, id) -PROXY_CONSTMETHOD0(std::vector, stream_ids) -PROXY_CONSTMETHOD0(std::vector, init_send_encodings) -PROXY_CONSTMETHOD0(RtpParameters, GetParameters) -PROXY_METHOD1(RTCError, SetParameters, const RtpParameters&) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, GetDtmfSender) -PROXY_METHOD1(void, - SetFrameEncryptor, - rtc::scoped_refptr) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, - GetFrameEncryptor) -PROXY_METHOD1(void, SetStreams, const std::vector&) -PROXY_METHOD1(void, - SetEncoderToPacketizerFrameTransformer, - rtc::scoped_refptr) -END_PROXY_MAP() - } // namespace webrtc #endif // API_RTP_SENDER_INTERFACE_H_ diff --git a/api/rtp_transceiver_interface.h b/api/rtp_transceiver_interface.h index 9b46846597..4799c4b153 100644 --- a/api/rtp_transceiver_interface.h +++ b/api/rtp_transceiver_interface.h @@ -14,6 +14,7 @@ #include #include +#include "absl/base/attributes.h" #include "absl/types/optional.h" #include "api/array_view.h" #include "api/media_types.h" @@ -111,8 +112,8 @@ class RTC_EXPORT RtpTransceiverInterface : public rtc::RefCountInterface { // https://w3c.github.io/webrtc-pc/#dom-rtcrtptransceiver-direction // TODO(hta): Deprecate SetDirection without error and rename // SetDirectionWithError to SetDirection, remove default implementations. - RTC_DEPRECATED virtual void SetDirection( - RtpTransceiverDirection new_direction); + ABSL_DEPRECATED("Use SetDirectionWithError instead") + virtual void SetDirection(RtpTransceiverDirection new_direction); virtual RTCError SetDirectionWithError(RtpTransceiverDirection new_direction); // The current_direction attribute indicates the current direction negotiated @@ -140,7 +141,7 @@ class RTC_EXPORT RtpTransceiverInterface : public rtc::RefCountInterface { // This is an internal function, and is exposed for historical reasons. // https://w3c.github.io/webrtc-pc/#dfn-stop-the-rtcrtptransceiver virtual void StopInternal(); - RTC_DEPRECATED virtual void Stop(); + ABSL_DEPRECATED("Use StopStandard instead") virtual void Stop(); // The SetCodecPreferences method overrides the default codec preferences used // by WebRTC for this transceiver. 
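With RTC_DEPRECATED replaced by ABSL_DEPRECATED above, callers are nudged toward the error-returning variants; a minimal sketch of the non-deprecated path (the helper name is illustrative):

#include "api/rtp_transceiver_interface.h"
#include "rtc_base/logging.h"

// Sets a transceiver to send-only via the non-deprecated API and logs any
// failure instead of ignoring it.
void SetSendOnly(webrtc::RtpTransceiverInterface* transceiver) {
  webrtc::RTCError error = transceiver->SetDirectionWithError(
      webrtc::RtpTransceiverDirection::kSendOnly);
  if (!error.ok()) {
    RTC_LOG(LS_WARNING) << "SetDirectionWithError failed: " << error.message();
  }
}
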
diff --git a/api/scoped_refptr.h b/api/scoped_refptr.h index fa4e83dbaf..4e3f0ebfc8 100644 --- a/api/scoped_refptr.h +++ b/api/scoped_refptr.h @@ -104,6 +104,7 @@ class scoped_refptr { T* get() const { return ptr_; } operator T*() const { return ptr_; } + T& operator*() const { return *ptr_; } T* operator->() const { return ptr_; } // Returns the (possibly null) raw pointer, and makes the scoped_refptr hold a diff --git a/api/sctp_transport_interface.h b/api/sctp_transport_interface.h index 6af0bfce34..7080889fcf 100644 --- a/api/sctp_transport_interface.h +++ b/api/sctp_transport_interface.h @@ -35,6 +35,8 @@ enum class SctpTransportState { // http://w3c.github.io/webrtc-pc/#rtcsctptransport-interface class RTC_EXPORT SctpTransportInformation { public: + SctpTransportInformation() = default; + SctpTransportInformation(const SctpTransportInformation&) = default; explicit SctpTransportInformation(SctpTransportState state); SctpTransportInformation( SctpTransportState state, diff --git a/api/sequence_checker.h b/api/sequence_checker.h new file mode 100644 index 0000000000..5db7b9e4df --- /dev/null +++ b/api/sequence_checker.h @@ -0,0 +1,116 @@ +/* + * Copyright 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef API_SEQUENCE_CHECKER_H_ +#define API_SEQUENCE_CHECKER_H_ + +#include "rtc_base/checks.h" +#include "rtc_base/synchronization/sequence_checker_internal.h" +#include "rtc_base/thread_annotations.h" + +namespace webrtc { + +// SequenceChecker is a helper class used to help verify that some methods +// of a class are called on the same task queue or thread. A +// SequenceChecker is bound to a a task queue if the object is +// created on a task queue, or a thread otherwise. +// +// +// Example: +// class MyClass { +// public: +// void Foo() { +// RTC_DCHECK_RUN_ON(&sequence_checker_); +// ... (do stuff) ... +// } +// +// private: +// SequenceChecker sequence_checker_; +// } +// +// In Release mode, IsCurrent will always return true. +class RTC_LOCKABLE SequenceChecker +#if RTC_DCHECK_IS_ON + : public webrtc_sequence_checker_internal::SequenceCheckerImpl { + using Impl = webrtc_sequence_checker_internal::SequenceCheckerImpl; +#else + : public webrtc_sequence_checker_internal::SequenceCheckerDoNothing { + using Impl = webrtc_sequence_checker_internal::SequenceCheckerDoNothing; +#endif + public: + // Returns true if sequence checker is attached to the current sequence. + bool IsCurrent() const { return Impl::IsCurrent(); } + // Detaches checker from sequence to which it is attached. Next attempt + // to do a check with this checker will result in attaching this checker + // to the sequence on which check was performed. + void Detach() { Impl::Detach(); } +}; + +} // namespace webrtc + +// RTC_RUN_ON/RTC_GUARDED_BY/RTC_DCHECK_RUN_ON macros allows to annotate +// variables are accessed from same thread/task queue. +// Using tools designed to check mutexes, it checks at compile time everywhere +// variable is access, there is a run-time dcheck thread/task queue is correct. 
+// +// class SequenceCheckerExample { +// public: +// int CalledFromPacer() RTC_RUN_ON(pacer_sequence_checker_) { +// return var2_; +// } +// +// void CallMeFromPacer() { +// RTC_DCHECK_RUN_ON(&pacer_sequence_checker_) +// << "Should be called from pacer"; +// CalledFromPacer(); +// } +// +// private: +// int pacer_var_ RTC_GUARDED_BY(pacer_sequence_checker_); +// SequenceChecker pacer_sequence_checker_; +// }; +// +// class TaskQueueExample { +// public: +// class Encoder { +// public: +// rtc::TaskQueueBase& Queue() { return encoder_queue_; } +// void Encode() { +// RTC_DCHECK_RUN_ON(&encoder_queue_); +// DoSomething(var_); +// } +// +// private: +// rtc::TaskQueueBase& encoder_queue_; +// Frame var_ RTC_GUARDED_BY(encoder_queue_); +// }; +// +// void Encode() { +// // Will fail at runtime when DCHECK is enabled: +// // encoder_->Encode(); +// // Will work: +// rtc::scoped_refptr encoder = encoder_; +// encoder_->Queue().PostTask([encoder] { encoder->Encode(); }); +// } +// +// private: +// rtc::scoped_refptr encoder_; +// } + +// Document if a function expected to be called from same thread/task queue. +#define RTC_RUN_ON(x) \ + RTC_THREAD_ANNOTATION_ATTRIBUTE__(exclusive_locks_required(x)) + +#define RTC_DCHECK_RUN_ON(x) \ + webrtc::webrtc_sequence_checker_internal::SequenceCheckerScope \ + seq_check_scope(x); \ + RTC_DCHECK((x)->IsCurrent()) \ + << webrtc::webrtc_sequence_checker_internal::ExpectationToString(x) + +#endif // API_SEQUENCE_CHECKER_H_ diff --git a/rtc_base/synchronization/sequence_checker_unittest.cc b/api/sequence_checker_unittest.cc similarity index 90% rename from rtc_base/synchronization/sequence_checker_unittest.cc rename to api/sequence_checker_unittest.cc index 6fcb522c54..21a0894a8e 100644 --- a/rtc_base/synchronization/sequence_checker_unittest.cc +++ b/api/sequence_checker_unittest.cc @@ -8,7 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ -#include "rtc_base/synchronization/sequence_checker.h" +#include "api/sequence_checker.h" #include #include @@ -17,7 +17,6 @@ #include "rtc_base/event.h" #include "rtc_base/platform_thread.h" #include "rtc_base/task_queue_for_test.h" -#include "rtc_base/thread_checker.h" #include "test/gtest.h" namespace webrtc { @@ -41,21 +40,14 @@ class CompileTimeTestForGuardedBy { }; void RunOnDifferentThread(rtc::FunctionView run) { - struct Object { - static void Run(void* obj) { - auto* me = static_cast(obj); - me->run(); - me->thread_has_run_event.Set(); - } - - rtc::FunctionView run; - rtc::Event thread_has_run_event; - } object{run}; - - rtc::PlatformThread thread(&Object::Run, &object, "thread"); - thread.Start(); - EXPECT_TRUE(object.thread_has_run_event.Wait(1000)); - thread.Stop(); + rtc::Event thread_has_run_event; + rtc::PlatformThread::SpawnJoinable( + [&] { + run(); + thread_has_run_event.Set(); + }, + "thread"); + EXPECT_TRUE(thread_has_run_event.Wait(1000)); } } // namespace diff --git a/api/stats/rtc_stats.h b/api/stats/rtc_stats.h index 5de5b7fbb0..9290e803fa 100644 --- a/api/stats/rtc_stats.h +++ b/api/stats/rtc_stats.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -237,6 +238,9 @@ class RTCStatsMemberInterface { kSequenceUint64, // std::vector kSequenceDouble, // std::vector kSequenceString, // std::vector + + kMapStringUint64, // std::map + kMapStringDouble, // std::map }; virtual ~RTCStatsMemberInterface() {} @@ -363,6 +367,13 @@ class RTCStatsMember : public RTCStatsMemberInterface { T value_; }; +namespace rtc_stats_internal { + +typedef std::map MapStringUint64; +typedef std::map MapStringDouble; + +} // namespace rtc_stats_internal + #define WEBRTC_DECLARE_RTCSTATSMEMBER(T) \ template <> \ RTC_EXPORT RTCStatsMemberInterface::Type RTCStatsMember::StaticType(); \ @@ -391,6 +402,8 @@ WEBRTC_DECLARE_RTCSTATSMEMBER(std::vector); WEBRTC_DECLARE_RTCSTATSMEMBER(std::vector); WEBRTC_DECLARE_RTCSTATSMEMBER(std::vector); WEBRTC_DECLARE_RTCSTATSMEMBER(std::vector); +WEBRTC_DECLARE_RTCSTATSMEMBER(rtc_stats_internal::MapStringUint64); +WEBRTC_DECLARE_RTCSTATSMEMBER(rtc_stats_internal::MapStringDouble); // Using inheritance just so that it's obvious from the member's declaration // whether it's standardized or not. 
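The new map-valued member types registered here back record-typed stats such as the quality_limitation_durations member added to RTCOutboundRTPStreamStats later in this change; a reading sketch (helper name illustrative, values assumed to be seconds per the spec):

#include "api/scoped_refptr.h"
#include "api/stats/rtc_stats_report.h"
#include "api/stats/rtcstats_objects.h"
#include "rtc_base/logging.h"

// Logs how long each quality limitation reason was active, per outbound RTP
// stream, once a stats report has been delivered.
void LogQualityLimitationDurations(
    const rtc::scoped_refptr<const webrtc::RTCStatsReport>& report) {
  for (const auto* stats :
       report->GetStatsOfType<webrtc::RTCOutboundRTPStreamStats>()) {
    if (!stats->quality_limitation_durations.is_defined())
      continue;
    for (const auto& reason_and_seconds :
         *stats->quality_limitation_durations) {
      RTC_LOG(LS_INFO) << stats->id() << " " << reason_and_seconds.first
                       << ": " << reason_and_seconds.second << " s";
    }
  }
}
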
@@ -455,6 +468,10 @@ extern template class RTC_EXPORT_TEMPLATE_DECLARE(RTC_EXPORT) RTCNonStandardStatsMember>; extern template class RTC_EXPORT_TEMPLATE_DECLARE(RTC_EXPORT) RTCNonStandardStatsMember>; +extern template class RTC_EXPORT_TEMPLATE_DECLARE(RTC_EXPORT) + RTCNonStandardStatsMember>; +extern template class RTC_EXPORT_TEMPLATE_DECLARE(RTC_EXPORT) + RTCNonStandardStatsMember>; } // namespace webrtc diff --git a/api/stats/rtc_stats_collector_callback.h b/api/stats/rtc_stats_collector_callback.h index c3e08245ea..506cc63e6f 100644 --- a/api/stats/rtc_stats_collector_callback.h +++ b/api/stats/rtc_stats_collector_callback.h @@ -17,7 +17,7 @@ namespace webrtc { -class RTCStatsCollectorCallback : public virtual rtc::RefCountInterface { +class RTCStatsCollectorCallback : public rtc::RefCountInterface { public: ~RTCStatsCollectorCallback() override = default; diff --git a/api/stats/rtc_stats_report.h b/api/stats/rtc_stats_report.h index 94bd813b07..0fe5ce91f9 100644 --- a/api/stats/rtc_stats_report.h +++ b/api/stats/rtc_stats_report.h @@ -19,9 +19,11 @@ #include #include +#include "api/ref_counted_base.h" #include "api/scoped_refptr.h" #include "api/stats/rtc_stats.h" -#include "rtc_base/ref_count.h" +// TODO(tommi): Remove this include after fixing iwyu issue in chromium. +// See: third_party/blink/renderer/platform/peerconnection/rtc_stats.cc #include "rtc_base/ref_counted_object.h" #include "rtc_base/system/rtc_export.h" @@ -29,7 +31,8 @@ namespace webrtc { // A collection of stats. // This is accessible as a map from |RTCStats::id| to |RTCStats|. -class RTC_EXPORT RTCStatsReport : public rtc::RefCountInterface { +class RTC_EXPORT RTCStatsReport final + : public rtc::RefCountedNonVirtual { public: typedef std::map> StatsMap; @@ -107,11 +110,11 @@ class RTC_EXPORT RTCStatsReport : public rtc::RefCountInterface { // listing all of its stats objects. std::string ToJson() const; - friend class rtc::RefCountedObject; + protected: + friend class rtc::RefCountedNonVirtual; + ~RTCStatsReport() = default; private: - ~RTCStatsReport() override; - int64_t timestamp_us_; StatsMap stats_; }; diff --git a/api/stats/rtcstats_objects.h b/api/stats/rtcstats_objects.h index ee3d70727f..2030380918 100644 --- a/api/stats/rtcstats_objects.h +++ b/api/stats/rtcstats_objects.h @@ -13,6 +13,7 @@ #include +#include #include #include #include @@ -161,6 +162,7 @@ class RTC_EXPORT RTCIceCandidatePairStats final : public RTCStats { // TODO(hbos): Support enum types? // "RTCStatsMember"? RTCStatsMember state; + // Obsolete: priority RTCStatsMember priority; RTCStatsMember nominated; // TODO(hbos): Collect this the way the spec describes it. We have a value for @@ -208,9 +210,11 @@ class RTC_EXPORT RTCIceCandidateStats : public RTCStats { ~RTCIceCandidateStats() override; RTCStatsMember transport_id; + // Obsolete: is_remote RTCStatsMember is_remote; RTCStatsMember network_type; RTCStatsMember ip; + RTCStatsMember address; RTCStatsMember port; RTCStatsMember protocol; RTCStatsMember relay_protocol; @@ -219,9 +223,6 @@ class RTC_EXPORT RTCIceCandidateStats : public RTCStats { RTCStatsMember priority; // TODO(hbos): Not collected by |RTCStatsCollector|. crbug.com/632723 RTCStatsMember url; - // TODO(hbos): |deleted = true| case is not supported by |RTCStatsCollector|. 
- // crbug.com/632723 - RTCStatsMember deleted; // = false protected: RTCIceCandidateStats(const std::string& id, @@ -374,34 +375,64 @@ class RTC_EXPORT RTCRTPStreamStats : public RTCStats { ~RTCRTPStreamStats() override; RTCStatsMember ssrc; - // TODO(hbos): Remote case not supported by |RTCStatsCollector|. - // crbug.com/657855, 657856 - RTCStatsMember is_remote; // = false - RTCStatsMember media_type; // renamed to kind. RTCStatsMember kind; + // Obsolete: track_id RTCStatsMember track_id; RTCStatsMember transport_id; RTCStatsMember codec_id; - // FIR and PLI counts are only defined for |media_type == "video"|. - RTCStatsMember fir_count; - RTCStatsMember pli_count; - // TODO(hbos): NACK count should be collected by |RTCStatsCollector| for both - // audio and video but is only defined in the "video" case. crbug.com/657856 - RTCStatsMember nack_count; - // TODO(hbos): Not collected by |RTCStatsCollector|. crbug.com/657854 - // SLI count is only defined for |media_type == "video"|. - RTCStatsMember sli_count; - RTCStatsMember qp_sum; + + // Obsolete + RTCStatsMember media_type; // renamed to kind. protected: RTCRTPStreamStats(const std::string& id, int64_t timestamp_us); RTCRTPStreamStats(std::string&& id, int64_t timestamp_us); }; +// https://www.w3.org/TR/webrtc-stats/#receivedrtpstats-dict* +class RTC_EXPORT RTCReceivedRtpStreamStats : public RTCRTPStreamStats { + public: + WEBRTC_RTCSTATS_DECL(); + + RTCReceivedRtpStreamStats(const RTCReceivedRtpStreamStats& other); + ~RTCReceivedRtpStreamStats() override; + + // TODO(hbos) The following fields need to be added and migrated + // both from RTCInboundRtpStreamStats and RTCRemoteInboundRtpStreamStats: + // packetsReceived, packetsDiscarded, packetsRepaired, burstPacketsLost, + // burstPacketDiscarded, burstLossCount, burstDiscardCount, burstLossRate, + // burstDiscardRate, gapLossRate, gapDiscardRate, framesDropped, + // partialFramesLost, fullFramesLost + // crbug.com/webrtc/12532 + RTCStatsMember jitter; + RTCStatsMember packets_lost; // Signed per RFC 3550 + + protected: + RTCReceivedRtpStreamStats(const std::string&& id, int64_t timestamp_us); + RTCReceivedRtpStreamStats(std::string&& id, int64_t timestamp_us); +}; + +// https://www.w3.org/TR/webrtc-stats/#sentrtpstats-dict* +class RTC_EXPORT RTCSentRtpStreamStats : public RTCRTPStreamStats { + public: + WEBRTC_RTCSTATS_DECL(); + + RTCSentRtpStreamStats(const RTCSentRtpStreamStats& other); + ~RTCSentRtpStreamStats() override; + + RTCStatsMember packets_sent; + RTCStatsMember bytes_sent; + + protected: + RTCSentRtpStreamStats(const std::string&& id, int64_t timestamp_us); + RTCSentRtpStreamStats(std::string&& id, int64_t timestamp_us); +}; + // https://w3c.github.io/webrtc-stats/#inboundrtpstats-dict* // TODO(hbos): Support the remote case |is_remote = true|. 
// https://bugs.webrtc.org/7065 -class RTC_EXPORT RTCInboundRTPStreamStats final : public RTCRTPStreamStats { +class RTC_EXPORT RTCInboundRTPStreamStats final + : public RTCReceivedRtpStreamStats { public: WEBRTC_RTCSTATS_DECL(); @@ -410,16 +441,13 @@ class RTC_EXPORT RTCInboundRTPStreamStats final : public RTCRTPStreamStats { RTCInboundRTPStreamStats(const RTCInboundRTPStreamStats& other); ~RTCInboundRTPStreamStats() override; + RTCStatsMember remote_id; RTCStatsMember packets_received; RTCStatsMember fec_packets_received; RTCStatsMember fec_packets_discarded; RTCStatsMember bytes_received; RTCStatsMember header_bytes_received; - RTCStatsMember packets_lost; // Signed per RFC 3550 RTCStatsMember last_packet_received_timestamp; - // TODO(hbos): Collect and populate this value for both "audio" and "video", - // currently not collected for "video". https://bugs.webrtc.org/7065 - RTCStatsMember jitter; RTCStatsMember jitter_buffer_delay; RTCStatsMember jitter_buffer_emitted_count; RTCStatsMember total_samples_received; @@ -471,6 +499,11 @@ class RTC_EXPORT RTCInboundRTPStreamStats final : public RTCRTPStreamStats { // TODO(hbos): This is only implemented for video; implement it for audio as // well. RTCStatsMember decoder_implementation; + // FIR and PLI counts are only defined for |media_type == "video"|. + RTCStatsMember fir_count; + RTCStatsMember pli_count; + RTCStatsMember nack_count; + RTCStatsMember qp_sum; }; // https://w3c.github.io/webrtc-stats/#outboundrtpstats-dict* @@ -508,10 +541,8 @@ class RTC_EXPORT RTCOutboundRTPStreamStats final : public RTCRTPStreamStats { // implement it for audio as well. RTCStatsMember total_packet_send_delay; // Enum type RTCQualityLimitationReason - // TODO(https://crbug.com/webrtc/10686): Also expose - // qualityLimitationDurations. Requires RTCStatsMember support for - // "record", see https://crbug.com/webrtc/10685. RTCStatsMember quality_limitation_reason; + RTCStatsMember> quality_limitation_durations; // https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-qualitylimitationresolutionchanges RTCStatsMember quality_limitation_resolution_changes; // https://henbos.github.io/webrtc-provisional-stats/#dom-rtcoutboundrtpstreamstats-contenttype @@ -519,18 +550,16 @@ class RTC_EXPORT RTCOutboundRTPStreamStats final : public RTCRTPStreamStats { // TODO(hbos): This is only implemented for video; implement it for audio as // well. RTCStatsMember encoder_implementation; + // FIR and PLI counts are only defined for |media_type == "video"|. + RTCStatsMember fir_count; + RTCStatsMember pli_count; + RTCStatsMember nack_count; + RTCStatsMember qp_sum; }; -// TODO(https://crbug.com/webrtc/10671): Refactor the stats dictionaries to have -// the same hierarchy as in the spec; implement RTCReceivedRtpStreamStats. -// Several metrics are shared between "outbound-rtp", "remote-inbound-rtp", -// "inbound-rtp" and "remote-outbound-rtp". In the spec there is a hierarchy of -// dictionaries that minimizes defining the same metrics in multiple places. -// From JavaScript this hierarchy is not observable and the spec's hierarchy is -// purely editorial. In C++ non-final classes in the hierarchy could be used to -// refer to different stats objects within the hierarchy. 
// https://w3c.github.io/webrtc-stats/#remoteinboundrtpstats-dict* -class RTC_EXPORT RTCRemoteInboundRtpStreamStats final : public RTCStats { +class RTC_EXPORT RTCRemoteInboundRtpStreamStats final + : public RTCReceivedRtpStreamStats { public: WEBRTC_RTCSTATS_DECL(); @@ -539,17 +568,6 @@ class RTC_EXPORT RTCRemoteInboundRtpStreamStats final : public RTCStats { RTCRemoteInboundRtpStreamStats(const RTCRemoteInboundRtpStreamStats& other); ~RTCRemoteInboundRtpStreamStats() override; - // In the spec RTCRemoteInboundRtpStreamStats inherits from RTCRtpStreamStats - // and RTCReceivedRtpStreamStats. The members here are listed based on where - // they are defined in the spec. - // RTCRtpStreamStats - RTCStatsMember ssrc; - RTCStatsMember kind; - RTCStatsMember transport_id; - RTCStatsMember codec_id; - // RTCReceivedRtpStreamStats - RTCStatsMember packets_lost; - RTCStatsMember jitter; // TODO(hbos): The following RTCReceivedRtpStreamStats metrics should also be // implemented: packetsReceived, packetsDiscarded, packetsRepaired, // burstPacketsLost, burstPacketsDiscarded, burstLossCount, burstDiscardCount, @@ -557,8 +575,25 @@ class RTC_EXPORT RTCRemoteInboundRtpStreamStats final : public RTCStats { // RTCRemoteInboundRtpStreamStats RTCStatsMember local_id; RTCStatsMember round_trip_time; - // TODO(hbos): The following RTCRemoteInboundRtpStreamStats metric should also - // be implemented: fractionLost. + RTCStatsMember fraction_lost; + RTCStatsMember total_round_trip_time; + RTCStatsMember round_trip_time_measurements; +}; + +// https://w3c.github.io/webrtc-stats/#remoteoutboundrtpstats-dict* +class RTC_EXPORT RTCRemoteOutboundRtpStreamStats final + : public RTCSentRtpStreamStats { + public: + WEBRTC_RTCSTATS_DECL(); + + RTCRemoteOutboundRtpStreamStats(const std::string& id, int64_t timestamp_us); + RTCRemoteOutboundRtpStreamStats(std::string&& id, int64_t timestamp_us); + RTCRemoteOutboundRtpStreamStats(const RTCRemoteOutboundRtpStreamStats& other); + ~RTCRemoteOutboundRtpStreamStats() override; + + RTCStatsMember local_id; + RTCStatsMember remote_timestamp; + RTCStatsMember reports_sent; }; // https://w3c.github.io/webrtc-stats/#dom-rtcmediasourcestats @@ -590,6 +625,8 @@ class RTC_EXPORT RTCAudioSourceStats final : public RTCMediaSourceStats { RTCStatsMember audio_level; RTCStatsMember total_audio_energy; RTCStatsMember total_samples_duration; + RTCStatsMember echo_return_loss; + RTCStatsMember echo_return_loss_enhancement; }; // https://w3c.github.io/webrtc-stats/#dom-rtcvideosourcestats @@ -604,7 +641,6 @@ class RTC_EXPORT RTCVideoSourceStats final : public RTCMediaSourceStats { RTCStatsMember width; RTCStatsMember height; - // TODO(hbos): Implement this metric. RTCStatsMember frames; RTCStatsMember frames_per_second; }; diff --git a/api/stats_types.cc b/api/stats_types.cc index 7dcbd134a1..6fdc7e85a5 100644 --- a/api/stats_types.cc +++ b/api/stats_types.cc @@ -15,6 +15,7 @@ #include "absl/algorithm/container.h" #include "rtc_base/checks.h" #include "rtc_base/ref_counted_object.h" +#include "rtc_base/string_encode.h" // TODO(tommi): Could we have a static map of value name -> expected type // and use this to RTC_DCHECK on correct usage (somewhat strongly typed values)? 
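Since remote-inbound stats now derive from RTCReceivedRtpStreamStats, fields such as jitter and packets_lost are read the same way as for inbound stats; a sketch for the remote-inbound case, including the newly added fraction_lost (helper name illustrative):

#include "api/scoped_refptr.h"
#include "api/stats/rtc_stats_report.h"
#include "api/stats/rtcstats_objects.h"
#include "rtc_base/logging.h"

// Reads loss and jitter figures reported by the remote end via RTCP receiver
// reports, using members inherited from the shared RTCReceivedRtpStreamStats
// base plus fraction_lost.
void LogRemoteInboundStats(
    const rtc::scoped_refptr<const webrtc::RTCStatsReport>& report) {
  for (const auto* stats :
       report->GetStatsOfType<webrtc::RTCRemoteInboundRtpStreamStats>()) {
    if (!stats->ssrc.is_defined() || !stats->fraction_lost.is_defined() ||
        !stats->jitter.is_defined())
      continue;
    RTC_LOG(LS_INFO) << "ssrc " << *stats->ssrc
                     << " fraction_lost=" << *stats->fraction_lost
                     << " jitter=" << *stats->jitter;
  }
}
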
diff --git a/api/stats_types.h b/api/stats_types.h index c1922a8a22..d032462da6 100644 --- a/api/stats_types.h +++ b/api/stats_types.h @@ -21,11 +21,10 @@ #include #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ref_count.h" -#include "rtc_base/string_encode.h" #include "rtc_base/system/rtc_export.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -344,7 +343,7 @@ class RTC_EXPORT StatsReport { const StatsValueName name; private: - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; mutable int ref_count_ RTC_GUARDED_BY(thread_checker_) = 0; const Type type_; @@ -447,7 +446,7 @@ class StatsCollection { private: Container list_; - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; }; } // namespace webrtc diff --git a/api/task_queue/task_queue_base.h b/api/task_queue/task_queue_base.h index 90b1efd31e..88419edd8f 100644 --- a/api/task_queue/task_queue_base.h +++ b/api/task_queue/task_queue_base.h @@ -27,12 +27,14 @@ class RTC_LOCKABLE RTC_EXPORT TaskQueueBase { // Starts destruction of the task queue. // On return ensures no task are running and no new tasks are able to start // on the task queue. - // Responsible for deallocation. Deallocation may happen syncrhoniously during + // Responsible for deallocation. Deallocation may happen synchronously during // Delete or asynchronously after Delete returns. // Code not running on the TaskQueue should not make any assumption when // TaskQueue is deallocated and thus should not call any methods after Delete. // Code running on the TaskQueue should not call Delete, but can assume // TaskQueue still exists and may call other methods, e.g. PostTask. + // Should be called on the same task queue or thread that this task queue + // was created on. virtual void Delete() = 0; // Schedules a task to execute. Tasks are executed in FIFO order. @@ -43,17 +45,20 @@ class RTC_LOCKABLE RTC_EXPORT TaskQueueBase { // TaskQueue or it may happen asynchronously after TaskQueue is deleted. // This may vary from one implementation to the next so assumptions about // lifetimes of pending tasks should not be made. + // May be called on any thread or task queue, including this task queue. virtual void PostTask(std::unique_ptr task) = 0; // Schedules a task to execute a specified number of milliseconds from when // the call is made. The precision should be considered as "best effort" // and in some cases, such as on Windows when all high precision timers have // been used up, can be off by as much as 15 millseconds. + // May be called on any thread or task queue, including this task queue. virtual void PostDelayedTask(std::unique_ptr task, uint32_t milliseconds) = 0; // Returns the task queue that is running the current thread. // Returns nullptr if this thread is not associated with any task queue. + // May be called on any thread or task queue, including this task queue. 
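A short sketch of the threading contract spelled out in the comments above, using the default factory: tasks may be posted from any thread, while deletion happens on the creating thread when the owning unique_ptr goes out of scope (ToQueuedTask is used here only as a convenient way to wrap the lambda):

#include <memory>

#include "api/task_queue/default_task_queue_factory.h"
#include "api/task_queue/task_queue_base.h"
#include "rtc_base/task_utils/to_queued_task.h"

// Creates a queue, posts one task, and lets the deleter call Delete() on the
// thread that created the queue.
void PostOneTask() {
  std::unique_ptr<webrtc::TaskQueueFactory> factory =
      webrtc::CreateDefaultTaskQueueFactory();
  auto queue = factory->CreateTaskQueue(
      "example_queue", webrtc::TaskQueueFactory::Priority::NORMAL);
  queue->PostTask(webrtc::ToQueuedTask([] {
    // Runs on |queue|; TaskQueueBase::Current() is non-null here.
  }));
  // |queue| is destroyed here, on the creating thread, as the comments above
  // require.
}
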
static TaskQueueBase* Current(); bool IsCurrent() const { return Current() == this; } diff --git a/api/test/DEPS b/api/test/DEPS index d97ac49df6..329076830c 100644 --- a/api/test/DEPS +++ b/api/test/DEPS @@ -8,9 +8,6 @@ specific_include_rules = { "dummy_peer_connection\.h": [ "+rtc_base/ref_counted_object.h", ], - "fake_constraints\.h": [ - "+rtc_base/string_encode.h", - ], "neteq_factory_with_codecs\.h": [ "+system_wrappers/include/clock.h", ], @@ -35,7 +32,4 @@ specific_include_rules = { "create_frame_generator\.h": [ "+system_wrappers/include/clock.h", ], - "videocodec_test_fixture\.h": [ - "+media/base/h264_profile_level_id.h" - ], } diff --git a/api/test/compile_all_headers.cc b/api/test/compile_all_headers.cc index 6f06742995..5ecdcc1eb8 100644 --- a/api/test/compile_all_headers.cc +++ b/api/test/compile_all_headers.cc @@ -30,6 +30,7 @@ #include "api/test/dummy_peer_connection.h" #include "api/test/fake_frame_decryptor.h" #include "api/test/fake_frame_encryptor.h" +#include "api/test/mock_async_dns_resolver.h" #include "api/test/mock_audio_mixer.h" #include "api/test/mock_data_channel.h" #include "api/test/mock_frame_decryptor.h" diff --git a/api/test/create_time_controller.cc b/api/test/create_time_controller.cc index a2c0cb713f..f7faeaab42 100644 --- a/api/test/create_time_controller.cc +++ b/api/test/create_time_controller.cc @@ -13,6 +13,8 @@ #include #include "call/call.h" +#include "call/rtp_transport_config.h" +#include "call/rtp_transport_controller_send_factory_interface.h" #include "test/time_controller/external_time_controller.h" #include "test/time_controller/simulated_time_controller.h" @@ -40,8 +42,13 @@ std::unique_ptr CreateTimeControllerBasedCallFactory( time_controller_->CreateProcessThread("CallModules"), [this]() { module_thread_ = nullptr; }); } + + RtpTransportConfig transportConfig = config.ExtractTransportConfig(); + return Call::Create(config, time_controller_->GetClock(), module_thread_, - time_controller_->CreateProcessThread("Pacer")); + config.rtp_transport_controller_send_factory->Create( + transportConfig, time_controller_->GetClock(), + time_controller_->CreateProcessThread("Pacer"))); } private: diff --git a/api/test/dummy_peer_connection.h b/api/test/dummy_peer_connection.h index 4d17aeddd0..80ae20c3c7 100644 --- a/api/test/dummy_peer_connection.h +++ b/api/test/dummy_peer_connection.h @@ -114,10 +114,10 @@ class DummyPeerConnection : public PeerConnectionInterface { } void ClearStatsCache() override {} - rtc::scoped_refptr CreateDataChannel( + RTCErrorOr> CreateDataChannelOrError( const std::string& label, const DataChannelInit* config) override { - return nullptr; + return RTCError(RTCErrorType::INTERNAL_ERROR, "Dummy function called"); } const SessionDescriptionInterface* local_description() const override { diff --git a/api/test/mock_async_dns_resolver.h b/api/test/mock_async_dns_resolver.h new file mode 100644 index 0000000000..e863cac6e6 --- /dev/null +++ b/api/test/mock_async_dns_resolver.h @@ -0,0 +1,54 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef API_TEST_MOCK_ASYNC_DNS_RESOLVER_H_ +#define API_TEST_MOCK_ASYNC_DNS_RESOLVER_H_ + +#include +#include + +#include "api/async_dns_resolver.h" +#include "test/gmock.h" + +namespace webrtc { + +class MockAsyncDnsResolverResult : public AsyncDnsResolverResult { + public: + MOCK_METHOD(bool, + GetResolvedAddress, + (int, rtc::SocketAddress*), + (const override)); + MOCK_METHOD(int, GetError, (), (const override)); +}; + +class MockAsyncDnsResolver : public AsyncDnsResolverInterface { + public: + MOCK_METHOD(void, + Start, + (const rtc::SocketAddress&, std::function), + (override)); + MOCK_METHOD(AsyncDnsResolverResult&, result, (), (const override)); +}; + +class MockAsyncDnsResolverFactory : public AsyncDnsResolverFactoryInterface { + public: + MOCK_METHOD(std::unique_ptr, + CreateAndResolve, + (const rtc::SocketAddress&, std::function), + (override)); + MOCK_METHOD(std::unique_ptr, + Create, + (), + (override)); +}; + +} // namespace webrtc + +#endif // API_TEST_MOCK_ASYNC_DNS_RESOLVER_H_ diff --git a/api/test/mock_peerconnectioninterface.h b/api/test/mock_peerconnectioninterface.h index be34df0b32..b5d94238c8 100644 --- a/api/test/mock_peerconnectioninterface.h +++ b/api/test/mock_peerconnectioninterface.h @@ -100,8 +100,8 @@ class MockPeerConnectionInterface GetSctpTransport, (), (const override)); - MOCK_METHOD(rtc::scoped_refptr, - CreateDataChannel, + MOCK_METHOD(RTCErrorOr>, + CreateDataChannelOrError, (const std::string&, const DataChannelInit*), (override)); MOCK_METHOD(const SessionDescriptionInterface*, diff --git a/api/test/network_emulation/BUILD.gn b/api/test/network_emulation/BUILD.gn index fb7bedc003..a8044d7230 100644 --- a/api/test/network_emulation/BUILD.gn +++ b/api/test/network_emulation/BUILD.gn @@ -12,6 +12,7 @@ rtc_library("network_emulation") { visibility = [ "*" ] sources = [ + "cross_traffic.h", "network_emulation_interfaces.cc", "network_emulation_interfaces.h", ] @@ -20,11 +21,32 @@ rtc_library("network_emulation") { "../..:array_view", "../../../rtc_base", "../../../rtc_base:checks", + "../../../rtc_base:ip_address", "../../../rtc_base:rtc_base_approved", + "../../../rtc_base:socket_address", "../../numerics", + "../../task_queue", "../../units:data_rate", "../../units:data_size", + "../../units:time_delta", "../../units:timestamp", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } + +rtc_library("create_cross_traffic") { + visibility = [ "*" ] + testonly = true + + sources = [ + "create_cross_traffic.cc", + "create_cross_traffic.h", + ] + + deps = [ + ":network_emulation", + "../..:network_emulation_manager_api", + "../../../rtc_base/task_utils:repeating_task", + "../../../test/network:emulated_network", + ] +} diff --git a/api/test/network_emulation/create_cross_traffic.cc b/api/test/network_emulation/create_cross_traffic.cc new file mode 100644 index 0000000000..36a535cec6 --- /dev/null +++ b/api/test/network_emulation/create_cross_traffic.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "api/test/network_emulation/create_cross_traffic.h" + +#include + +#include "rtc_base/task_utils/repeating_task.h" +#include "test/network/cross_traffic.h" + +namespace webrtc { + +std::unique_ptr CreateRandomWalkCrossTraffic( + CrossTrafficRoute* traffic_route, + RandomWalkConfig config) { + return std::make_unique(config, traffic_route); +} + +std::unique_ptr CreatePulsedPeaksCrossTraffic( + CrossTrafficRoute* traffic_route, + PulsedPeaksConfig config) { + return std::make_unique(config, traffic_route); +} + +std::unique_ptr CreateFakeTcpCrossTraffic( + EmulatedRoute* send_route, + EmulatedRoute* ret_route, + FakeTcpConfig config) { + return std::make_unique(config, send_route, + ret_route); +} + +} // namespace webrtc diff --git a/api/test/network_emulation/create_cross_traffic.h b/api/test/network_emulation/create_cross_traffic.h new file mode 100644 index 0000000000..42fc855392 --- /dev/null +++ b/api/test/network_emulation/create_cross_traffic.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef API_TEST_NETWORK_EMULATION_CREATE_CROSS_TRAFFIC_H_ +#define API_TEST_NETWORK_EMULATION_CREATE_CROSS_TRAFFIC_H_ + +#include + +#include "api/test/network_emulation/cross_traffic.h" +#include "api/test/network_emulation_manager.h" + +namespace webrtc { + +// This API is still in development and can be changed without prior notice. + +std::unique_ptr CreateRandomWalkCrossTraffic( + CrossTrafficRoute* traffic_route, + RandomWalkConfig config); + +std::unique_ptr CreatePulsedPeaksCrossTraffic( + CrossTrafficRoute* traffic_route, + PulsedPeaksConfig config); + +std::unique_ptr CreateFakeTcpCrossTraffic( + EmulatedRoute* send_route, + EmulatedRoute* ret_route, + FakeTcpConfig config); + +} // namespace webrtc + +#endif // API_TEST_NETWORK_EMULATION_CREATE_CROSS_TRAFFIC_H_ diff --git a/api/test/network_emulation/cross_traffic.h b/api/test/network_emulation/cross_traffic.h new file mode 100644 index 0000000000..85343e44d2 --- /dev/null +++ b/api/test/network_emulation/cross_traffic.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef API_TEST_NETWORK_EMULATION_CROSS_TRAFFIC_H_ +#define API_TEST_NETWORK_EMULATION_CROSS_TRAFFIC_H_ + +#include "api/task_queue/task_queue_base.h" +#include "api/test/network_emulation/network_emulation_interfaces.h" +#include "api/units/data_rate.h" +#include "api/units/data_size.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" + +namespace webrtc { + +// This API is still in development and can be changed without prior notice. + +// Represents the endpoint for cross traffic that is going through the network. +// It can be used to emulate unexpected network load. 
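The factories declared in create_cross_traffic.h above are the intended way to build generators for such a route; a minimal sketch, assuming the CrossTrafficRoute is obtained from (and the generator later handed back to) the owning NetworkEmulationManager (the helper name and config values are illustrative):

#include <memory>

#include "api/test/network_emulation/create_cross_traffic.h"
#include "api/test/network_emulation/cross_traffic.h"

// Builds a random-walk cross-traffic generator for an existing route; how the
// route was created and how the generator is scheduled are left to the
// NetworkEmulationManager that owns them.
std::unique_ptr<webrtc::CrossTrafficGenerator> MakeBackgroundTraffic(
    webrtc::CrossTrafficRoute* route) {
  webrtc::RandomWalkConfig config;
  config.peak_rate = webrtc::DataRate::KilobitsPerSec(300);
  config.update_interval = webrtc::TimeDelta::Millis(100);
  return webrtc::CreateRandomWalkCrossTraffic(route, config);
}
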
+class CrossTrafficRoute { + public: + virtual ~CrossTrafficRoute() = default; + + // Triggers sending of dummy packets with size |packet_size| bytes. + virtual void TriggerPacketBurst(size_t num_packets, size_t packet_size) = 0; + // Sends a packet over the nodes. The content of the packet is unspecified; + // only the size matters for emulation purposes. + virtual void SendPacket(size_t packet_size) = 0; + // Sends a packet over the nodes and runs |action| when it has been delivered. + virtual void NetworkDelayedAction(size_t packet_size, + std::function<void()> action) = 0; +}; + +// Describes a way of generating cross traffic on some route. Used by +// NetworkEmulationManager to produce cross traffic during some period of time. +class CrossTrafficGenerator { + public: + virtual ~CrossTrafficGenerator() = default; + + // Time between Process calls. + virtual TimeDelta GetProcessInterval() const = 0; + + // Called periodically by NetworkEmulationManager. Generates traffic on the + // route. + virtual void Process(Timestamp at_time) = 0; +}; + +// Config of a cross traffic generator. Generated traffic rises and falls +// randomly. +struct RandomWalkConfig { + int random_seed = 1; + DataRate peak_rate = DataRate::KilobitsPerSec(100); + DataSize min_packet_size = DataSize::Bytes(200); + TimeDelta min_packet_interval = TimeDelta::Millis(1); + TimeDelta update_interval = TimeDelta::Millis(200); + double variance = 0.6; + double bias = -0.1; +}; + +// Config of a cross traffic generator. Generated traffic has the form of +// periodic peaks alternating with periods of silence. +struct PulsedPeaksConfig { + DataRate peak_rate = DataRate::KilobitsPerSec(100); + DataSize min_packet_size = DataSize::Bytes(200); + TimeDelta min_packet_interval = TimeDelta::Millis(1); + TimeDelta send_duration = TimeDelta::Millis(100); + TimeDelta hold_duration = TimeDelta::Millis(2000); +}; + +struct FakeTcpConfig { + DataSize packet_size = DataSize::Bytes(1200); + DataSize send_limit = DataSize::PlusInfinity(); + TimeDelta process_interval = TimeDelta::Millis(200); + TimeDelta packet_timeout = TimeDelta::Seconds(1); +}; + +} // namespace webrtc + +#endif // API_TEST_NETWORK_EMULATION_CROSS_TRAFFIC_H_ diff --git a/api/test/network_emulation/network_emulation_interfaces.h b/api/test/network_emulation/network_emulation_interfaces.h index 36fb996549..c8e6ed053e 100644 --- a/api/test/network_emulation/network_emulation_interfaces.h +++ b/api/test/network_emulation/network_emulation_interfaces.h @@ -222,10 +222,23 @@ class EmulatedEndpoint : public EmulatedNetworkReceiverInterface { // |desired_port| != 0 and is free or will be the one, selected by endpoint) // or absl::nullopt if desired_port is in use. Also fails if there are no more // free ports to bind to. + // + // The Bind- and Unbind-methods must not be called from within a bound + // receiver's OnPacketReceived method. virtual absl::optional<uint16_t> BindReceiver( uint16_t desired_port, EmulatedNetworkReceiverInterface* receiver) = 0; + // Unbinds receiver from the specified port. Does nothing if no receiver was + // bound before. After this method returns, no more packets can be delivered + // to the receiver, and it is safe to destroy it. virtual void UnbindReceiver(uint16_t port) = 0; + // Binds a receiver that will accept all packets arriving on any port + // that has no bound receiver. + virtual void BindDefaultReceiver( + EmulatedNetworkReceiverInterface* receiver) = 0; + // Unbinds the default receiver. Does nothing if no default receiver was bound + // before.
+ virtual void UnbindDefaultReceiver() = 0; virtual rtc::IPAddress GetPeerLocalAddress() const = 0; private: diff --git a/api/test/network_emulation_manager.h b/api/test/network_emulation_manager.h index 80efb0e7d8..ec51b290e0 100644 --- a/api/test/network_emulation_manager.h +++ b/api/test/network_emulation_manager.h @@ -17,6 +17,7 @@ #include #include "api/array_view.h" +#include "api/test/network_emulation/cross_traffic.h" #include "api/test/network_emulation/network_emulation_interfaces.h" #include "api/test/simulated_network.h" #include "api/test/time_controller.h" @@ -55,6 +56,8 @@ struct EmulatedEndpointConfig { kDebug }; + // If specified, it will be used to name the endpoint for logging purposes. + absl::optional<std::string> name = absl::nullopt; IpAddressFamily generated_ip_family = IpAddressFamily::kIpv4; // If specified will be used as IP address for endpoint node. Must be unique // among all created nodes. @@ -65,6 +68,14 @@ struct EmulatedEndpointConfig { // Network type which will be used to represent endpoint to WebRTC. rtc::AdapterType type = rtc::AdapterType::ADAPTER_TYPE_UNKNOWN; StatsGatheringMode stats_gathering_mode = StatsGatheringMode::kDefault; + // Allows the endpoint to send packets specifying a source IP address + // different from the current endpoint IP address. If false, the endpoint + // will crash if an attempt is made to send such a packet. + bool allow_send_packet_with_different_source_ip = false; + // Allows the endpoint to receive packets with a destination IP address + // different from the current endpoint IP address. If false, the endpoint + // will crash if such a packet arrives. + bool allow_receive_packets_with_different_dest_ip = false; }; struct EmulatedTURNServerConfig { @@ -164,6 +175,8 @@ class NetworkEmulationManager { virtual ~NetworkEmulationManager() = default; virtual TimeController* time_controller() = 0; + // Returns the mode in which the underlying time controller operates. + virtual TimeMode time_mode() const = 0; // Creates an emulated network node, which represents single network in // the emulated network layer. Uses default implementation on network behavior @@ -221,9 +234,39 @@ class NetworkEmulationManager { virtual EmulatedRoute* CreateRoute( const std::vector<EmulatedNetworkNode*>& via_nodes) = 0; + // Creates a default route between endpoints going through specified network + // nodes. The default route is used for packets when there is no known route + // for the packet's destination IP. + // + // This route is single direction only and describes how traffic that was + // sent by network interface |from| has to be delivered when routing is + // otherwise unspecified. The returned object can be used to remove the + // created route. The route must contain at least one network node. + // + // Assuming that E{0-9} are endpoints and N{0-9} are network nodes, creation + // of the route has to follow these rules: + // 1. A route consists of a source endpoint, an ordered list of one or + // more network nodes, and a destination endpoint. + // 2. If (E1, ..., E2) is a route, then E1 != E2. + // In other words, the source and the destination may not be the same. + // 3. Given two simultaneously existing routes (E1, ..., E2) and + // (E3, ..., E4), either E1 != E3 or E2 != E4. + // In other words, there may be at most one route from any given source + // endpoint to any given destination endpoint. + // 4. Given two simultaneously existing routes (E1, ..., N1, ..., E2) + // and (E3, ..., N2, ..., E4), either N1 != N2 or E2 != E4.
+ // In other words, a network node may not belong to two routes that lead + // to the same destination endpoint. + // 5. Any node N can belong to only one default route. + virtual EmulatedRoute* CreateDefaultRoute( + EmulatedEndpoint* from, + const std::vector<EmulatedNetworkNode*>& via_nodes, + EmulatedEndpoint* to) = 0; + // Removes route previously created by CreateRoute(...). // Caller mustn't call this function with a route that has already been - // removed earlier. + // removed earlier. Removing a route that is currently in use will lead to + // packets being dropped. virtual void ClearRoute(EmulatedRoute* route) = 0; // Creates a simulated TCP connection using |send_route| for traffic and @@ -233,6 +276,20 @@ class NetworkEmulationManager { virtual TcpMessageRoute* CreateTcpRoute(EmulatedRoute* send_route, EmulatedRoute* ret_route) = 0; + // Creates a route over the given |via_nodes|. Returns an object that can be + // used to emulate network load with cross traffic over the created route. + virtual CrossTrafficRoute* CreateCrossTrafficRoute( + const std::vector<EmulatedNetworkNode*>& via_nodes) = 0; + + // Starts generating cross traffic using the given |generator|. Takes + // ownership of the generator. + virtual CrossTrafficGenerator* StartCrossTraffic( + std::unique_ptr<CrossTrafficGenerator> generator) = 0; + + // Stops generating cross traffic that was started using the given |generator|. + // The |generator| shouldn't be used afterwards and the reference may be invalid. + virtual void StopCrossTraffic(CrossTrafficGenerator* generator) = 0; + // Creates EmulatedNetworkManagerInterface which can then be used to inject // network emulation layer into PeerConnection. |endpoints| - are available // network interfaces for PeerConnection. If endpoint is enabled, it will be @@ -246,7 +303,7 @@ class NetworkEmulationManager { // |stats_callback|. Callback will be executed on network emulation // internal task queue. virtual void GetStats( - rtc::ArrayView<EmulatedEndpoint*> endpoints, + rtc::ArrayView<EmulatedEndpoint* const> endpoints, std::function<void(std::unique_ptr<EmulatedNetworkStats>)> stats_callback) = 0; diff --git a/api/test/peerconnection_quality_test_fixture.h b/api/test/peerconnection_quality_test_fixture.h index f370478956..8717e8f73d 100644 --- a/api/test/peerconnection_quality_test_fixture.h +++ b/api/test/peerconnection_quality_test_fixture.h @@ -220,11 +220,19 @@ class PeerConnectionE2EQualityTestFixture { // was captured during the test for this video stream on sender side. // It is useful when generator is used as input. absl::optional<std::string> input_dump_file_name; + // Used only if |input_dump_file_name| is set. Specifies the modulo for the + // video frames to be dumped. Modulo equals X means every Xth frame will be + // written to the dump file. The value must be greater than 0. + int input_dump_sampling_modulo = 1; // If specified this file will be used as output on the receiver side for // this stream. If multiple streams are produced by the input stream, // output files will be appended with indexes. The produced files contain // what was rendered for this video stream on receiver side. absl::optional<std::string> output_dump_file_name; + // Used only if |output_dump_file_name| is set. Specifies the modulo for the + // video frames to be dumped. Modulo equals X means every Xth frame will be + // written to the dump file. The value must be greater than 0. + int output_dump_sampling_modulo = 1; // If true will display input and output video on the user's screen. bool show_on_screen = false; // If specified, determines a sync group to which this video stream belongs.
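Illustrative usage of the cross-traffic API added above; this sketch is not part of the patch. It assumes a NetworkEmulationManager* and an EmulatedNetworkNode* obtained elsewhere (for example from CreateNetworkEmulationManager() and CreateEmulatedNode(), neither of which is shown in this diff); everything else uses only declarations introduced here.

#include "api/test/network_emulation/create_cross_traffic.h"
#include "api/test/network_emulation_manager.h"

// Sketch: generate random-walk background load over |node|. |net| and |node|
// are assumed to be provided by the caller.
void StartRandomWalkLoad(webrtc::NetworkEmulationManager* net,
                         webrtc::EmulatedNetworkNode* node) {
  // A route dedicated to cross traffic over the given node.
  webrtc::CrossTrafficRoute* route = net->CreateCrossTrafficRoute({node});
  // Random-walk generator peaking at 200 kbps (field defaults are declared in
  // cross_traffic.h above).
  webrtc::RandomWalkConfig config;
  config.peak_rate = webrtc::DataRate::KilobitsPerSec(200);
  // The manager takes ownership of the generator and calls Process() on it
  // every GetProcessInterval().
  net->StartCrossTraffic(webrtc::CreateRandomWalkCrossTraffic(route, config));
}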
diff --git a/api/test/simulated_network.h b/api/test/simulated_network.h index 3fba61f74d..fcac51f4ea 100644 --- a/api/test/simulated_network.h +++ b/api/test/simulated_network.h @@ -46,8 +46,7 @@ struct PacketDeliveryInfo { // for built-in network behavior that will be used by WebRTC if no custom // NetworkBehaviorInterface is provided. struct BuiltInNetworkBehaviorConfig { - BuiltInNetworkBehaviorConfig() {} - // Queue length in number of packets. + // Queue length in number of packets. size_t queue_length_packets = 0; // Delay in addition to capacity induced delay. int queue_delay_ms = 0; diff --git a/api/test/videocodec_test_fixture.h b/api/test/videocodec_test_fixture.h index 379d46d096..e0f804fe46 100644 --- a/api/test/videocodec_test_fixture.h +++ b/api/test/videocodec_test_fixture.h @@ -59,7 +59,7 @@ class VideoCodecTestFixture { class EncodedFrameChecker { public: virtual ~EncodedFrameChecker() = default; - virtual void CheckEncodedFrame(webrtc::VideoCodecType codec, + virtual void CheckEncodedFrame(VideoCodecType codec, const EncodedImage& encoded_frame) const = 0; }; @@ -123,16 +123,16 @@ class VideoCodecTestFixture { bool encode_in_real_time = false; // Codec settings to use. - webrtc::VideoCodec codec_settings; + VideoCodec codec_settings; // Name of the codec being tested. std::string codec_name; // H.264 specific settings. struct H264CodecSettings { - H264::Profile profile = H264::kProfileConstrainedBaseline; + H264Profile profile = H264Profile::kProfileConstrainedBaseline; H264PacketizationMode packetization_mode = - webrtc::H264PacketizationMode::NonInterleaved; + H264PacketizationMode::NonInterleaved; } h264_codec_settings; // Custom checker that will be called for each frame. diff --git a/api/test/videocodec_test_stats.cc b/api/test/videocodec_test_stats.cc index b2f88a4661..b973dc2d12 100644 --- a/api/test/videocodec_test_stats.cc +++ b/api/test/videocodec_test_stats.cc @@ -24,71 +24,91 @@ VideoCodecTestStats::FrameStatistics::FrameStatistics(size_t frame_number, std::string VideoCodecTestStats::FrameStatistics::ToString() const { rtc::StringBuilder ss; - ss << "frame_number " << frame_number; - ss << " decoded_width " << decoded_width; - ss << " decoded_height " << decoded_height; - ss << " spatial_idx " << spatial_idx; - ss << " temporal_idx " << temporal_idx; - ss << " inter_layer_predicted " << inter_layer_predicted; - ss << " non_ref_for_inter_layer_pred " << non_ref_for_inter_layer_pred; - ss << " frame_type " << static_cast(frame_type); - ss << " length_bytes " << length_bytes; - ss << " qp " << qp; - ss << " psnr " << psnr; - ss << " psnr_y " << psnr_y; - ss << " psnr_u " << psnr_u; - ss << " psnr_v " << psnr_v; - ss << " ssim " << ssim; - ss << " encode_time_us " << encode_time_us; - ss << " decode_time_us " << decode_time_us; - ss << " rtp_timestamp " << rtp_timestamp; - ss << " target_bitrate_kbps " << target_bitrate_kbps; - ss << " target_framerate_fps " << target_framerate_fps; + for (const auto& entry : ToMap()) { + if (ss.size() > 0) { + ss << " "; + } + ss << entry.first << " " << entry.second; + } return ss.Release(); } +std::map VideoCodecTestStats::FrameStatistics::ToMap() + const { + std::map map; + map["frame_number"] = std::to_string(frame_number); + map["decoded_width"] = std::to_string(decoded_width); + map["decoded_height"] = std::to_string(decoded_height); + map["spatial_idx"] = std::to_string(spatial_idx); + map["temporal_idx"] = std::to_string(temporal_idx); + map["inter_layer_predicted"] = std::to_string(inter_layer_predicted); + 
map["non_ref_for_inter_layer_pred"] = + std::to_string(non_ref_for_inter_layer_pred); + map["frame_type"] = std::to_string(static_cast(frame_type)); + map["length_bytes"] = std::to_string(length_bytes); + map["qp"] = std::to_string(qp); + map["psnr"] = std::to_string(psnr); + map["psnr_y"] = std::to_string(psnr_y); + map["psnr_u"] = std::to_string(psnr_u); + map["psnr_v"] = std::to_string(psnr_v); + map["ssim"] = std::to_string(ssim); + map["encode_time_us"] = std::to_string(encode_time_us); + map["decode_time_us"] = std::to_string(decode_time_us); + map["rtp_timestamp"] = std::to_string(rtp_timestamp); + map["target_bitrate_kbps"] = std::to_string(target_bitrate_kbps); + map["target_framerate_fps"] = std::to_string(target_framerate_fps); + return map; +} + std::string VideoCodecTestStats::VideoStatistics::ToString( std::string prefix) const { rtc::StringBuilder ss; - ss << prefix << "target_bitrate_kbps: " << target_bitrate_kbps; - ss << "\n" << prefix << "input_framerate_fps: " << input_framerate_fps; - ss << "\n" << prefix << "spatial_idx: " << spatial_idx; - ss << "\n" << prefix << "temporal_idx: " << temporal_idx; - ss << "\n" << prefix << "width: " << width; - ss << "\n" << prefix << "height: " << height; - ss << "\n" << prefix << "length_bytes: " << length_bytes; - ss << "\n" << prefix << "bitrate_kbps: " << bitrate_kbps; - ss << "\n" << prefix << "framerate_fps: " << framerate_fps; - ss << "\n" << prefix << "enc_speed_fps: " << enc_speed_fps; - ss << "\n" << prefix << "dec_speed_fps: " << dec_speed_fps; - ss << "\n" << prefix << "avg_delay_sec: " << avg_delay_sec; - ss << "\n" - << prefix << "max_key_frame_delay_sec: " << max_key_frame_delay_sec; - ss << "\n" - << prefix << "max_delta_frame_delay_sec: " << max_delta_frame_delay_sec; - ss << "\n" - << prefix << "time_to_reach_target_bitrate_sec: " - << time_to_reach_target_bitrate_sec; - ss << "\n" - << prefix << "avg_key_frame_size_bytes: " << avg_key_frame_size_bytes; - ss << "\n" - << prefix << "avg_delta_frame_size_bytes: " << avg_delta_frame_size_bytes; - ss << "\n" << prefix << "avg_qp: " << avg_qp; - ss << "\n" << prefix << "avg_psnr: " << avg_psnr; - ss << "\n" << prefix << "min_psnr: " << min_psnr; - ss << "\n" << prefix << "avg_ssim: " << avg_ssim; - ss << "\n" << prefix << "min_ssim: " << min_ssim; - ss << "\n" << prefix << "num_input_frames: " << num_input_frames; - ss << "\n" << prefix << "num_encoded_frames: " << num_encoded_frames; - ss << "\n" << prefix << "num_decoded_frames: " << num_decoded_frames; - ss << "\n" - << prefix - << "num_dropped_frames: " << num_input_frames - num_encoded_frames; - ss << "\n" << prefix << "num_key_frames: " << num_key_frames; - ss << "\n" << prefix << "num_spatial_resizes: " << num_spatial_resizes; - ss << "\n" << prefix << "max_nalu_size_bytes: " << max_nalu_size_bytes; + for (const auto& entry : ToMap()) { + if (ss.size() > 0) { + ss << "\n"; + } + ss << prefix << entry.first << ": " << entry.second; + } return ss.Release(); } +std::map VideoCodecTestStats::VideoStatistics::ToMap() + const { + std::map map; + map["target_bitrate_kbps"] = std::to_string(target_bitrate_kbps); + map["input_framerate_fps"] = std::to_string(input_framerate_fps); + map["spatial_idx"] = std::to_string(spatial_idx); + map["temporal_idx"] = std::to_string(temporal_idx); + map["width"] = std::to_string(width); + map["height"] = std::to_string(height); + map["length_bytes"] = std::to_string(length_bytes); + map["bitrate_kbps"] = std::to_string(bitrate_kbps); + map["framerate_fps"] = 
std::to_string(framerate_fps); + map["enc_speed_fps"] = std::to_string(enc_speed_fps); + map["dec_speed_fps"] = std::to_string(dec_speed_fps); + map["avg_delay_sec"] = std::to_string(avg_delay_sec); + map["max_key_frame_delay_sec"] = std::to_string(max_key_frame_delay_sec); + map["max_delta_frame_delay_sec"] = std::to_string(max_delta_frame_delay_sec); + map["time_to_reach_target_bitrate_sec"] = + std::to_string(time_to_reach_target_bitrate_sec); + map["avg_key_frame_size_bytes"] = std::to_string(avg_key_frame_size_bytes); + map["avg_delta_frame_size_bytes"] = + std::to_string(avg_delta_frame_size_bytes); + map["avg_qp"] = std::to_string(avg_qp); + map["avg_psnr"] = std::to_string(avg_psnr); + map["min_psnr"] = std::to_string(min_psnr); + map["avg_ssim"] = std::to_string(avg_ssim); + map["min_ssim"] = std::to_string(min_ssim); + map["num_input_frames"] = std::to_string(num_input_frames); + map["num_encoded_frames"] = std::to_string(num_encoded_frames); + map["num_decoded_frames"] = std::to_string(num_decoded_frames); + map["num_dropped_frames"] = + std::to_string(num_input_frames - num_encoded_frames); + map["num_key_frames"] = std::to_string(num_key_frames); + map["num_spatial_resizes"] = std::to_string(num_spatial_resizes); + map["max_nalu_size_bytes"] = std::to_string(max_nalu_size_bytes); + return map; +} + } // namespace test } // namespace webrtc diff --git a/api/test/videocodec_test_stats.h b/api/test/videocodec_test_stats.h index df1aed73aa..02a18a71d9 100644 --- a/api/test/videocodec_test_stats.h +++ b/api/test/videocodec_test_stats.h @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -33,6 +34,9 @@ class VideoCodecTestStats { std::string ToString() const; + // Returns name -> value text map of frame statistics. + std::map ToMap() const; + size_t frame_number = 0; size_t rtp_timestamp = 0; @@ -78,6 +82,9 @@ class VideoCodecTestStats { struct VideoStatistics { std::string ToString(std::string prefix) const; + // Returns name -> value text map of video statistics. 
+ std::map<std::string, std::string> ToMap() const; + size_t target_bitrate_kbps = 0; float input_framerate_fps = 0.0f; diff --git a/api/transport/BUILD.gn b/api/transport/BUILD.gn index 7bcda8b4a7..30955273b0 100644 --- a/api/transport/BUILD.gn +++ b/api/transport/BUILD.gn @@ -33,7 +33,6 @@ rtc_library("network_control") { deps = [ ":webrtc_key_value_config", - "../../rtc_base:deprecation", "../rtc_event_log", "../units:data_rate", "../units:data_size", @@ -89,8 +88,8 @@ rtc_library("goog_cc") { ":webrtc_key_value_config", "..:network_state_predictor_api", "../../modules/congestion_controller/goog_cc", - "../../rtc_base:deprecation", ] + absl_deps = [ "//third_party/abseil-cpp/absl/base:core_headers" ] } rtc_source_set("sctp_transport_factory_interface") { @@ -108,8 +107,10 @@ rtc_source_set("stun_types") { deps = [ "../../api:array_view", "../../rtc_base:checks", + "../../rtc_base:ip_address", "../../rtc_base:rtc_base", "../../rtc_base:rtc_base_approved", + "../../rtc_base:socket_address", ] absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] } @@ -147,6 +148,7 @@ if (rtc_include_tests) { ":stun_types", "../../rtc_base", "../../rtc_base:rtc_base_approved", + "../../rtc_base:socket_address", "../../test:test_support", "//testing/gtest", ] diff --git a/api/transport/data_channel_transport_interface.h b/api/transport/data_channel_transport_interface.h index 7b8c653c39..2b2f5d2e6d 100644 --- a/api/transport/data_channel_transport_interface.h +++ b/api/transport/data_channel_transport_interface.h @@ -47,15 +47,15 @@ struct SendDataParams { // If set, the maximum number of times this message may be // retransmitted by the transport before it is dropped. // Setting this value to zero disables retransmission. - // Must be non-negative. |max_rtx_count| and |max_rtx_ms| may not be set - // simultaneously. + // Valid values are in the range [0-UINT16_MAX]. + // |max_rtx_count| and |max_rtx_ms| may not be set simultaneously. absl::optional<int> max_rtx_count; // If set, the maximum number of milliseconds for which the transport // may retransmit this message before it is dropped. // Setting this value to zero disables retransmission. - // Must be non-negative. |max_rtx_count| and |max_rtx_ms| may not be set - // simultaneously. + // Valid values are in the range [0-UINT16_MAX]. + // |max_rtx_count| and |max_rtx_ms| may not be set simultaneously. absl::optional<int> max_rtx_ms; }; @@ -88,7 +88,7 @@ class DataChannelSink { // Callback issued when the data channel becomes unusable (closed). // TODO(https://crbug.com/webrtc/10360): Make pure virtual when all // consumers updated. - virtual void OnTransportClosed() {} + virtual void OnTransportClosed(RTCError error) {} }; // Transport for data channels.
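The updated SendDataParams comments above pin down the retransmission limits; here is a minimal standalone sketch (not part of the patch) of how a caller might set them, using only fields visible in this header.

#include "api/transport/data_channel_transport_interface.h"

// Sketch: a partially reliable message that may be retransmitted at most
// three times. Per the comments above, |max_rtx_count| and |max_rtx_ms| must
// not be set at the same time, values must stay within [0, UINT16_MAX], and
// zero disables retransmission entirely.
webrtc::SendDataParams MakePartiallyReliableParams() {
  webrtc::SendDataParams params;
  params.max_rtx_count = 3;
  return params;
}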
diff --git a/api/transport/goog_cc_factory.h b/api/transport/goog_cc_factory.h index b14d6dcd78..e12755d745 100644 --- a/api/transport/goog_cc_factory.h +++ b/api/transport/goog_cc_factory.h @@ -12,9 +12,9 @@ #define API_TRANSPORT_GOOG_CC_FACTORY_H_ #include +#include "absl/base/attributes.h" #include "api/network_state_predictor.h" #include "api/transport/network_control.h" -#include "rtc_base/deprecation.h" namespace webrtc { class RtcEventLog; @@ -31,8 +31,8 @@ class GoogCcNetworkControllerFactory : public NetworkControllerFactoryInterface { public: GoogCcNetworkControllerFactory() = default; - explicit RTC_DEPRECATED GoogCcNetworkControllerFactory( - RtcEventLog* event_log); + ABSL_DEPRECATED("") + explicit GoogCcNetworkControllerFactory(RtcEventLog* event_log); explicit GoogCcNetworkControllerFactory( NetworkStatePredictorFactoryInterface* network_state_predictor_factory); @@ -49,7 +49,8 @@ class GoogCcNetworkControllerFactory // Deprecated, use GoogCcFactoryConfig to enable feedback only mode instead. // Factory to create packet feedback only GoogCC, this can be used for // connections providing packet receive time feedback but no other reports. -class RTC_DEPRECATED GoogCcFeedbackNetworkControllerFactory +class ABSL_DEPRECATED("use GoogCcFactoryConfig instead") + GoogCcFeedbackNetworkControllerFactory : public GoogCcNetworkControllerFactory { public: explicit GoogCcFeedbackNetworkControllerFactory(RtcEventLog* event_log); diff --git a/api/transport/network_types.cc b/api/transport/network_types.cc index 88b67b3a47..7451940151 100644 --- a/api/transport/network_types.cc +++ b/api/transport/network_types.cc @@ -48,7 +48,7 @@ std::vector TransportPacketsFeedback::ReceivedWithSendInfo() const { std::vector res; for (const PacketResult& fb : packet_feedbacks) { - if (fb.receive_time.IsFinite()) { + if (fb.IsReceived()) { res.push_back(fb); } } @@ -58,7 +58,7 @@ std::vector TransportPacketsFeedback::ReceivedWithSendInfo() std::vector TransportPacketsFeedback::LostWithSendInfo() const { std::vector res; for (const PacketResult& fb : packet_feedbacks) { - if (fb.receive_time.IsPlusInfinity()) { + if (!fb.IsReceived()) { res.push_back(fb); } } @@ -74,7 +74,7 @@ std::vector TransportPacketsFeedback::SortedByReceiveTime() const { std::vector res; for (const PacketResult& fb : packet_feedbacks) { - if (fb.receive_time.IsFinite()) { + if (fb.IsReceived()) { res.push_back(fb); } } diff --git a/api/transport/network_types.h b/api/transport/network_types.h index 10fc0beedf..4e96b0f12e 100644 --- a/api/transport/network_types.h +++ b/api/transport/network_types.h @@ -19,7 +19,6 @@ #include "api/units/data_size.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -159,6 +158,8 @@ struct PacketResult { PacketResult(const PacketResult&); ~PacketResult(); + inline bool IsReceived() const { return !receive_time.IsPlusInfinity(); } + SentPacket sent_packet; Timestamp receive_time = Timestamp::PlusInfinity(); }; diff --git a/api/transport/stun.cc b/api/transport/stun.cc index e1bf03be62..1b5bf0c409 100644 --- a/api/transport/stun.cc +++ b/api/transport/stun.cc @@ -246,6 +246,31 @@ const StunUInt16ListAttribute* StunMessage::GetUnknownAttributes() const { GetAttribute(STUN_ATTR_UNKNOWN_ATTRIBUTES)); } +StunMessage::IntegrityStatus StunMessage::ValidateMessageIntegrity( + const std::string& password) { + password_ = password; + if (GetByteString(STUN_ATTR_MESSAGE_INTEGRITY)) { + if (ValidateMessageIntegrityOfType( + 
STUN_ATTR_MESSAGE_INTEGRITY, kStunMessageIntegritySize, + buffer_.c_str(), buffer_.size(), password)) { + integrity_ = IntegrityStatus::kIntegrityOk; + } else { + integrity_ = IntegrityStatus::kIntegrityBad; + } + } else if (GetByteString(STUN_ATTR_GOOG_MESSAGE_INTEGRITY_32)) { + if (ValidateMessageIntegrityOfType( + STUN_ATTR_GOOG_MESSAGE_INTEGRITY_32, kStunMessageIntegrity32Size, + buffer_.c_str(), buffer_.size(), password)) { + integrity_ = IntegrityStatus::kIntegrityOk; + } else { + integrity_ = IntegrityStatus::kIntegrityBad; + } + } else { + integrity_ = IntegrityStatus::kNoIntegrity; + } + return integrity_; +} + bool StunMessage::ValidateMessageIntegrity(const char* data, size_t size, const std::string& password) { @@ -353,11 +378,6 @@ bool StunMessage::AddMessageIntegrity(const std::string& password) { password.size()); } -bool StunMessage::AddMessageIntegrity(const char* key, size_t keylen) { - return AddMessageIntegrityOfType(STUN_ATTR_MESSAGE_INTEGRITY, - kStunMessageIntegritySize, key, keylen); -} - bool StunMessage::AddMessageIntegrity32(absl::string_view password) { return AddMessageIntegrityOfType(STUN_ATTR_GOOG_MESSAGE_INTEGRITY_32, kStunMessageIntegrity32Size, password.data(), @@ -395,6 +415,8 @@ bool StunMessage::AddMessageIntegrityOfType(int attr_type, // Insert correct HMAC into the attribute. msg_integrity_attr->CopyBytes(hmac, attr_size); + password_.assign(key, keylen); + integrity_ = IntegrityStatus::kIntegrityOk; return true; } @@ -473,6 +495,9 @@ bool StunMessage::AddFingerprint() { } bool StunMessage::Read(ByteBufferReader* buf) { + // Keep a copy of the buffer data around for later verification. + buffer_.assign(buf->Data(), buf->Length()); + if (!buf->ReadUInt16(&type_)) { return false; } diff --git a/api/transport/stun.h b/api/transport/stun.h index 8893b2a1ff..682a17a945 100644 --- a/api/transport/stun.h +++ b/api/transport/stun.h @@ -16,6 +16,7 @@ #include #include + #include #include #include @@ -149,15 +150,24 @@ class StunMessage { StunMessage(); virtual ~StunMessage(); + // The verification status of the message. This is checked on parsing, + // or set by AddMessageIntegrity. + enum class IntegrityStatus { + kNotSet, + kNoIntegrity, // Message-integrity attribute missing + kIntegrityOk, // Message-integrity checked OK + kIntegrityBad, // Message-integrity verification failed + }; + int type() const { return type_; } size_t length() const { return length_; } const std::string& transaction_id() const { return transaction_id_; } uint32_t reduced_transaction_id() const { return reduced_transaction_id_; } // Returns true if the message confirms to RFC3489 rather than - // RFC5389. The main difference between two version of the STUN + // RFC5389. The main difference between the two versions of the STUN // protocol is the presence of the magic cookie and different length - // of transaction ID. For outgoing packets version of the protocol + // of transaction ID. For outgoing packets the version of the protocol // is determined by the lengths of the transaction ID. bool IsLegacy() const; @@ -191,19 +201,27 @@ class StunMessage { // Remote all attributes and releases them. void ClearAttributes(); - // Validates that a raw STUN message has a correct MESSAGE-INTEGRITY value. - // This can't currently be done on a StunMessage, since it is affected by - // padding data (which we discard when reading a StunMessage). 
- static bool ValidateMessageIntegrity(const char* data, - size_t size, - const std::string& password); - static bool ValidateMessageIntegrity32(const char* data, - size_t size, - const std::string& password); + // Validates that a STUN message has a correct MESSAGE-INTEGRITY value. + // This uses the buffered raw-format message stored by Read(). + IntegrityStatus ValidateMessageIntegrity(const std::string& password); + + // Returns the current integrity status of the message. + IntegrityStatus integrity() const { return integrity_; } + + // Shortcut for checking if integrity is verified. + bool IntegrityOk() const { + return integrity_ == IntegrityStatus::kIntegrityOk; + } + + // Returns the password attribute used to set or check the integrity. + // Can only be called after adding or checking the integrity. + std::string password() const { + RTC_DCHECK(integrity_ != IntegrityStatus::kNotSet); + return password_; + } // Adds a MESSAGE-INTEGRITY attribute that is valid for the current message. bool AddMessageIntegrity(const std::string& password); - bool AddMessageIntegrity(const char* key, size_t keylen); // Adds a STUN_ATTR_GOOG_MESSAGE_INTEGRITY_32 attribute that is valid for the // current message. @@ -244,6 +262,30 @@ class StunMessage { bool EqualAttributes(const StunMessage* other, std::function attribute_type_mask) const; + // Expose raw-buffer ValidateMessageIntegrity function for testing. + static bool ValidateMessageIntegrityForTesting(const char* data, + size_t size, + const std::string& password) { + return ValidateMessageIntegrity(data, size, password); + } + // Expose raw-buffer ValidateMessageIntegrity function for testing. + static bool ValidateMessageIntegrity32ForTesting( + const char* data, + size_t size, + const std::string& password) { + return ValidateMessageIntegrity32(data, size, password); + } + // Validates that a STUN message in byte buffer form + // has a correct MESSAGE-INTEGRITY value. + // These functions are not recommended and will be deprecated; use + // ValidateMessageIntegrity(password) on the parsed form instead. + static bool ValidateMessageIntegrity(const char* data, + size_t size, + const std::string& password); + static bool ValidateMessageIntegrity32(const char* data, + size_t size, + const std::string& password); + protected: // Verifies that the given attribute is allowed for this message. virtual StunAttributeValueType GetAttributeValueType(int type) const; @@ -269,6 +311,10 @@ class StunMessage { std::string transaction_id_; uint32_t reduced_transaction_id_; uint32_t stun_magic_cookie_; + // The original buffer for messages created by Read(). + std::string buffer_; + IntegrityStatus integrity_ = IntegrityStatus::kNotSet; + std::string password_; }; // Base class for all STUN/TURN attributes. diff --git a/api/transport/stun_unittest.cc b/api/transport/stun_unittest.cc index bf2717e007..bf791f257d 100644 --- a/api/transport/stun_unittest.cc +++ b/api/transport/stun_unittest.cc @@ -1196,24 +1196,24 @@ TEST_F(StunTest, FailToReadRtcpPacket) { // Check our STUN message validation code against the RFC5769 test messages. TEST_F(StunTest, ValidateMessageIntegrity) { // Try the messages from RFC 5769. 
- EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleRequest), sizeof(kRfc5769SampleRequest), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleRequest), sizeof(kRfc5769SampleRequest), "InvalidPassword")); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleResponse), sizeof(kRfc5769SampleResponse), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleResponse), sizeof(kRfc5769SampleResponse), "InvalidPassword")); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleResponseIPv6), sizeof(kRfc5769SampleResponseIPv6), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleResponseIPv6), sizeof(kRfc5769SampleResponseIPv6), "InvalidPassword")); @@ -1222,40 +1222,40 @@ TEST_F(StunTest, ValidateMessageIntegrity) { ComputeStunCredentialHash(kRfc5769SampleMsgWithAuthUsername, kRfc5769SampleMsgWithAuthRealm, kRfc5769SampleMsgWithAuthPassword, &key); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleRequestLongTermAuth), sizeof(kRfc5769SampleRequestLongTermAuth), key)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kRfc5769SampleRequestLongTermAuth), sizeof(kRfc5769SampleRequestLongTermAuth), "InvalidPassword")); // Try some edge cases. - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithZeroLength), sizeof(kStunMessageWithZeroLength), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithExcessLength), sizeof(kStunMessageWithExcessLength), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithSmallLength), sizeof(kStunMessageWithSmallLength), kRfc5769SampleMsgPassword)); // Again, but with the lengths matching what is claimed in the headers. 
- EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithZeroLength), kStunHeaderSize + rtc::GetBE16(&kStunMessageWithZeroLength[2]), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithExcessLength), kStunHeaderSize + rtc::GetBE16(&kStunMessageWithExcessLength[2]), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithSmallLength), kStunHeaderSize + rtc::GetBE16(&kStunMessageWithSmallLength[2]), kRfc5769SampleMsgPassword)); // Check that a too-short HMAC doesn't cause buffer overflow. - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(kStunMessageWithBadHmacAtEnd), sizeof(kStunMessageWithBadHmacAtEnd), kRfc5769SampleMsgPassword)); @@ -1268,8 +1268,8 @@ TEST_F(StunTest, ValidateMessageIntegrity) { if (i > 0) buf[i - 1] ^= 0x01; EXPECT_EQ(i >= sizeof(buf) - 8, - StunMessage::ValidateMessageIntegrity(buf, sizeof(buf), - kRfc5769SampleMsgPassword)); + StunMessage::ValidateMessageIntegrityForTesting( + buf, sizeof(buf), kRfc5769SampleMsgPassword)); } } @@ -1291,7 +1291,7 @@ TEST_F(StunTest, AddMessageIntegrity) { rtc::ByteBufferWriter buf1; EXPECT_TRUE(msg.Write(&buf1)); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(buf1.Data()), buf1.Length(), kRfc5769SampleMsgPassword)); @@ -1309,7 +1309,7 @@ TEST_F(StunTest, AddMessageIntegrity) { rtc::ByteBufferWriter buf3; EXPECT_TRUE(msg2.Write(&buf3)); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(buf3.Data()), buf3.Length(), kRfc5769SampleMsgPassword)); } @@ -1317,40 +1317,40 @@ TEST_F(StunTest, AddMessageIntegrity) { // Check our STUN message validation code against the RFC5769 test messages. TEST_F(StunTest, ValidateMessageIntegrity32) { // Try the messages from RFC 5769. - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kSampleRequestMI32), sizeof(kSampleRequestMI32), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kSampleRequestMI32), sizeof(kSampleRequestMI32), "InvalidPassword")); // Try some edge cases. - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithZeroLength), sizeof(kStunMessageWithZeroLength), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithExcessLength), sizeof(kStunMessageWithExcessLength), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithSmallLength), sizeof(kStunMessageWithSmallLength), kRfc5769SampleMsgPassword)); // Again, but with the lengths matching what is claimed in the headers. 
- EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithZeroLength), kStunHeaderSize + rtc::GetBE16(&kStunMessageWithZeroLength[2]), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithExcessLength), kStunHeaderSize + rtc::GetBE16(&kStunMessageWithExcessLength[2]), kRfc5769SampleMsgPassword)); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithSmallLength), kStunHeaderSize + rtc::GetBE16(&kStunMessageWithSmallLength[2]), kRfc5769SampleMsgPassword)); // Check that a too-short HMAC doesn't cause buffer overflow. - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(kStunMessageWithBadHmacAtEnd), sizeof(kStunMessageWithBadHmacAtEnd), kRfc5769SampleMsgPassword)); @@ -1363,7 +1363,7 @@ TEST_F(StunTest, ValidateMessageIntegrity32) { if (i > 0) buf[i - 1] ^= 0x01; EXPECT_EQ(i >= sizeof(buf) - 8, - StunMessage::ValidateMessageIntegrity32( + StunMessage::ValidateMessageIntegrity32ForTesting( buf, sizeof(buf), kRfc5769SampleMsgPassword)); } } @@ -1384,7 +1384,7 @@ TEST_F(StunTest, AddMessageIntegrity32) { rtc::ByteBufferWriter buf1; EXPECT_TRUE(msg.Write(&buf1)); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(buf1.Data()), buf1.Length(), kRfc5769SampleMsgPassword)); @@ -1402,7 +1402,7 @@ TEST_F(StunTest, AddMessageIntegrity32) { rtc::ByteBufferWriter buf3; EXPECT_TRUE(msg2.Write(&buf3)); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(buf3.Data()), buf3.Length(), kRfc5769SampleMsgPassword)); } @@ -1420,14 +1420,14 @@ TEST_F(StunTest, AddMessageIntegrity32AndMessageIntegrity) { rtc::ByteBufferWriter buf1; EXPECT_TRUE(msg.Write(&buf1)); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(buf1.Data()), buf1.Length(), "password1")); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + EXPECT_TRUE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(buf1.Data()), buf1.Length(), "password2")); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity32ForTesting( reinterpret_cast(buf1.Data()), buf1.Length(), "password2")); - EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + EXPECT_FALSE(StunMessage::ValidateMessageIntegrityForTesting( reinterpret_cast(buf1.Data()), buf1.Length(), "password1")); } diff --git a/api/uma_metrics.h b/api/uma_metrics.h index 30543b68b1..a975b82aeb 100644 --- a/api/uma_metrics.h +++ b/api/uma_metrics.h @@ -167,6 +167,52 @@ enum SimulcastApiVersion { kSimulcastApiVersionMax }; +// Metrics for reporting usage of BUNDLE. +// These values are persisted to logs. Entries should not be renumbered and +// numeric values should never be reused. +enum BundleUsage { + // There are no m-lines in the SDP, only a session description. + kBundleUsageEmpty = 0, + // Only a data channel is negotiated but BUNDLE is not negotiated. 
+ kBundleUsageNoBundleDatachannelOnly = 1, + // BUNDLE is not negotiated and there is at most one m-line per media type, + kBundleUsageNoBundleSimple = 2, + // BUNDLE is not negotiated and there are multiple m-lines per media type, + kBundleUsageNoBundleComplex = 3, + // Only a data channel is negotiated and BUNDLE is negotiated. + kBundleUsageBundleDatachannelOnly = 4, + // BUNDLE is negotiated but there is at most one m-line per media type, + kBundleUsageBundleSimple = 5, + // BUNDLE is negotiated and there are multiple m-lines per media type, + kBundleUsageBundleComplex = 6, + // Legacy plan-b metrics. + kBundleUsageNoBundlePlanB = 7, + kBundleUsageBundlePlanB = 8, + kBundleUsageMax +}; + +// Metrics for reporting configured BUNDLE policy, mapping directly to +// https://w3c.github.io/webrtc-pc/#rtcbundlepolicy-enum +// These values are persisted to logs. Entries should not be renumbered and +// numeric values should never be reused. +enum BundlePolicyUsage { + kBundlePolicyUsageBalanced = 0, + kBundlePolicyUsageMaxBundle = 1, + kBundlePolicyUsageMaxCompat = 2, + kBundlePolicyUsageMax +}; + +// Metrics for provisional answers as described in +// https://datatracker.ietf.org/doc/html/rfc8829#section-4.1.10.1 +// These values are persisted to logs. Entries should not be renumbered and +// numeric values should never be reused. +enum ProvisionalAnswerUsage { + kProvisionalAnswerNotUsed = 0, + kProvisionalAnswerLocal = 1, + kProvisionalAnswerRemote = 2, + kProvisionalAnswerMax +}; + // When adding new metrics please consider using the style described in // https://chromium.googlesource.com/chromium/src.git/+/HEAD/tools/metrics/histograms/README.md#usage // instead of the legacy enums used above. diff --git a/api/video/BUILD.gn b/api/video/BUILD.gn index d50a334635..ec90bc137e 100644 --- a/api/video/BUILD.gn +++ b/api/video/BUILD.gn @@ -43,6 +43,8 @@ rtc_library("video_frame") { sources = [ "i420_buffer.cc", "i420_buffer.h", + "nv12_buffer.cc", + "nv12_buffer.h", "video_codec_type.h", "video_frame.cc", "video_frame.h", @@ -90,23 +92,6 @@ rtc_library("video_frame_i010") { ] } -rtc_library("video_frame_nv12") { - visibility = [ "*" ] - sources = [ - "nv12_buffer.cc", - "nv12_buffer.h", - ] - deps = [ - ":video_frame", - "..:scoped_refptr", - "../../rtc_base", - "../../rtc_base:checks", - "../../rtc_base/memory:aligned_malloc", - "../../rtc_base/system:rtc_export", - "//third_party/libyuv", - ] -} - rtc_source_set("recordable_encoded_frame") { visibility = [ "*" ] sources = [ "recordable_encoded_frame.h" ] @@ -142,7 +127,6 @@ rtc_library("encoded_image") { "..:rtp_packet_info", "..:scoped_refptr", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base/system:rtc_export", ] @@ -159,6 +143,41 @@ rtc_library("encoded_frame") { deps = [ "../../modules/video_coding:encoded_frame" ] } +rtc_library("rtp_video_frame_assembler") { + visibility = [ "*" ] + sources = [ + "rtp_video_frame_assembler.cc", + "rtp_video_frame_assembler.h", + ] + + deps = [ + ":encoded_frame", + "../../modules/rtp_rtcp:rtp_rtcp", + "../../modules/rtp_rtcp:rtp_rtcp_format", + "../../modules/video_coding:video_coding", + "../../rtc_base:logging", + ] + + absl_deps = [ + "//third_party/abseil-cpp/absl/container:inlined_vector", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("rtp_video_frame_assembler_unittests") { + testonly = true + sources = [ "rtp_video_frame_assembler_unittests.cc" ] + + deps = [ + ":rtp_video_frame_assembler", + "..:array_view", + 
"../../modules/rtp_rtcp:rtp_packetizer_av1_test_helper", + "../../modules/rtp_rtcp:rtp_rtcp", + "../../modules/rtp_rtcp:rtp_rtcp_format", + "../../test:test_support", + ] +} + rtc_source_set("video_codec_constants") { visibility = [ "*" ] sources = [ "video_codec_constants.h" ] diff --git a/api/video/DEPS b/api/video/DEPS index 1cb8ad83cb..cf6770dce0 100644 --- a/api/video/DEPS +++ b/api/video/DEPS @@ -40,4 +40,8 @@ specific_include_rules = { "video_stream_encoder_create.cc": [ "+video/video_stream_encoder.h", ], + + "rtp_video_frame_assembler.h": [ + "+modules/rtp_rtcp/source/rtp_packet_received.h", + ], } diff --git a/api/video/encoded_frame.cc b/api/video/encoded_frame.cc index 26a794ec02..42d6b06b84 100644 --- a/api/video/encoded_frame.cc +++ b/api/video/encoded_frame.cc @@ -11,11 +11,9 @@ #include "api/video/encoded_frame.h" namespace webrtc { -namespace video_coding { bool EncodedFrame::delayed_by_retransmission() const { return 0; } -} // namespace video_coding } // namespace webrtc diff --git a/api/video/encoded_frame.h b/api/video/encoded_frame.h index 6a2b1f82e5..5f046327fa 100644 --- a/api/video/encoded_frame.h +++ b/api/video/encoded_frame.h @@ -17,37 +17,6 @@ #include "modules/video_coding/encoded_frame.h" namespace webrtc { -namespace video_coding { - -// NOTE: This class is still under development and may change without notice. -struct VideoLayerFrameId { - // TODO(philipel): The default ctor is currently used internaly, but have a - // look if we can remove it. - VideoLayerFrameId() : picture_id(-1), spatial_layer(0) {} - VideoLayerFrameId(int64_t picture_id, uint8_t spatial_layer) - : picture_id(picture_id), spatial_layer(spatial_layer) {} - - bool operator==(const VideoLayerFrameId& rhs) const { - return picture_id == rhs.picture_id && spatial_layer == rhs.spatial_layer; - } - - bool operator!=(const VideoLayerFrameId& rhs) const { - return !(*this == rhs); - } - - bool operator<(const VideoLayerFrameId& rhs) const { - if (picture_id == rhs.picture_id) - return spatial_layer < rhs.spatial_layer; - return picture_id < rhs.picture_id; - } - - bool operator<=(const VideoLayerFrameId& rhs) const { return !(rhs < *this); } - bool operator>(const VideoLayerFrameId& rhs) const { return rhs < *this; } - bool operator>=(const VideoLayerFrameId& rhs) const { return rhs <= *this; } - - int64_t picture_id; - uint8_t spatial_layer; -}; // TODO(philipel): Remove webrtc::VCMEncodedFrame inheritance. // TODO(philipel): Move transport specific info out of EncodedFrame. @@ -73,7 +42,8 @@ class EncodedFrame : public webrtc::VCMEncodedFrame { bool is_keyframe() const { return num_references == 0; } - VideoLayerFrameId id; + void SetId(int64_t id) { id_ = id; } + int64_t Id() const { return id_; } // TODO(philipel): Add simple modify/access functions to prevent adding too // many |references|. @@ -82,9 +52,13 @@ class EncodedFrame : public webrtc::VCMEncodedFrame { // Is this subframe the last one in the superframe (In RTP stream that would // mean that the last packet has a marker bit set). bool is_last_spatial_layer = true; + + private: + // The ID of the frame is determined from RTP level information. The IDs are + // used to describe order and dependencies between frames. 
+ int64_t id_ = -1; }; -} // namespace video_coding } // namespace webrtc #endif // API_VIDEO_ENCODED_FRAME_H_ diff --git a/api/video/encoded_image.cc b/api/video/encoded_image.cc index 1c73bdabe6..fc77b9415b 100644 --- a/api/video/encoded_image.cc +++ b/api/video/encoded_image.cc @@ -32,13 +32,13 @@ EncodedImageBuffer::~EncodedImageBuffer() { // static rtc::scoped_refptr EncodedImageBuffer::Create(size_t size) { - return new rtc::RefCountedObject(size); + return rtc::make_ref_counted(size); } // static rtc::scoped_refptr EncodedImageBuffer::Create( const uint8_t* data, size_t size) { - return new rtc::RefCountedObject(data, size); + return rtc::make_ref_counted(data, size); } const uint8_t* EncodedImageBuffer::data() const { @@ -66,21 +66,11 @@ EncodedImage::EncodedImage() = default; EncodedImage::EncodedImage(EncodedImage&&) = default; EncodedImage::EncodedImage(const EncodedImage&) = default; -EncodedImage::EncodedImage(uint8_t* buffer, size_t size, size_t capacity) - : size_(size), buffer_(buffer), capacity_(capacity) {} - EncodedImage::~EncodedImage() = default; EncodedImage& EncodedImage::operator=(EncodedImage&&) = default; EncodedImage& EncodedImage::operator=(const EncodedImage&) = default; -void EncodedImage::Retain() { - if (buffer_) { - encoded_data_ = EncodedImageBuffer::Create(buffer_, size_); - buffer_ = nullptr; - } -} - void EncodedImage::SetEncodeTime(int64_t encode_start_ms, int64_t encode_finish_ms) { timing_.encode_start_ms = encode_start_ms; diff --git a/api/video/encoded_image.h b/api/video/encoded_image.h index 650766ab64..dae4e3a60a 100644 --- a/api/video/encoded_image.h +++ b/api/video/encoded_image.h @@ -26,7 +26,6 @@ #include "api/video/video_rotation.h" #include "api/video/video_timing.h" #include "rtc_base/checks.h" -#include "rtc_base/deprecation.h" #include "rtc_base/ref_count.h" #include "rtc_base/system/rtc_export.h" @@ -73,12 +72,10 @@ class RTC_EXPORT EncodedImage { EncodedImage(); EncodedImage(EncodedImage&&); EncodedImage(const EncodedImage&); - RTC_DEPRECATED EncodedImage(uint8_t* buffer, size_t length, size_t capacity); ~EncodedImage(); EncodedImage& operator=(EncodedImage&&); - // Discouraged: potentially expensive. EncodedImage& operator=(const EncodedImage&); // TODO(nisse): Change style to timestamp(), set_timestamp(), for consistency @@ -112,6 +109,15 @@ class RTC_EXPORT EncodedImage { color_space_ = color_space; } + // These methods along with the private member video_frame_tracking_id_ are + // meant for media quality testing purpose only. + absl::optional VideoFrameTrackingId() const { + return video_frame_tracking_id_; + } + void SetVideoFrameTrackingId(absl::optional tracking_id) { + video_frame_tracking_id_ = tracking_id; + } + const RtpPacketInfos& PacketInfos() const { return packet_infos_; } void SetPacketInfos(RtpPacketInfos packet_infos) { packet_infos_ = std::move(packet_infos); @@ -128,34 +134,26 @@ class RTC_EXPORT EncodedImage { RTC_DCHECK_LE(new_size, new_size == 0 ? 0 : capacity()); size_ = new_size; } + void SetEncodedData( rtc::scoped_refptr encoded_data) { encoded_data_ = encoded_data; size_ = encoded_data->size(); - buffer_ = nullptr; } void ClearEncodedData() { encoded_data_ = nullptr; size_ = 0; - buffer_ = nullptr; - capacity_ = 0; } rtc::scoped_refptr GetEncodedData() const { - RTC_DCHECK(buffer_ == nullptr); return encoded_data_; } const uint8_t* data() const { - return buffer_ ? buffer_ - : (encoded_data_ ? encoded_data_->data() : nullptr); + return encoded_data_ ? 
encoded_data_->data() : nullptr; } - // Hack to workaround lack of ownership of the encoded data. If we don't - // already own the underlying data, make an owned copy. - void Retain(); - uint32_t _encodedWidth = 0; uint32_t _encodedHeight = 0; // NTP time of the capture time in local timebase in milliseconds. @@ -185,22 +183,17 @@ class RTC_EXPORT EncodedImage { } timing_; private: - size_t capacity() const { - return buffer_ ? capacity_ : (encoded_data_ ? encoded_data_->size() : 0); - } + size_t capacity() const { return encoded_data_ ? encoded_data_->size() : 0; } - // TODO(bugs.webrtc.org/9378): We're transitioning to always owning the - // encoded data. rtc::scoped_refptr encoded_data_; size_t size_ = 0; // Size of encoded frame data. - // Non-null when used with an un-owned buffer. - uint8_t* buffer_ = nullptr; - // Allocated size of _buffer; relevant only if it's non-null. - size_t capacity_ = 0; uint32_t timestamp_rtp_ = 0; absl::optional spatial_index_; std::map spatial_layer_frame_size_bytes_; absl::optional color_space_; + // This field is meant for media quality testing purpose only. When enabled it + // carries the webrtc::VideoFrame id field from the sender to the receiver. + absl::optional video_frame_tracking_id_; // Information about packets used to assemble this video frame. This is needed // by |SourceTracker| when the frame is delivered to the RTCRtpReceiver's // MediaStreamTrack, in order to implement getContributingSources(). See: diff --git a/api/video/i010_buffer.cc b/api/video/i010_buffer.cc index 7286676ded..74d37d1b57 100644 --- a/api/video/i010_buffer.cc +++ b/api/video/i010_buffer.cc @@ -56,8 +56,8 @@ I010Buffer::~I010Buffer() {} // static rtc::scoped_refptr I010Buffer::Create(int width, int height) { - return new rtc::RefCountedObject( - width, height, width, (width + 1) / 2, (width + 1) / 2); + return rtc::make_ref_counted(width, height, width, + (width + 1) / 2, (width + 1) / 2); } // static diff --git a/api/video/i420_buffer.cc b/api/video/i420_buffer.cc index 2a52217ce3..8783a4a313 100644 --- a/api/video/i420_buffer.cc +++ b/api/video/i420_buffer.cc @@ -60,7 +60,7 @@ I420Buffer::~I420Buffer() {} // static rtc::scoped_refptr I420Buffer::Create(int width, int height) { - return new rtc::RefCountedObject(width, height); + return rtc::make_ref_counted(width, height); } // static @@ -69,8 +69,8 @@ rtc::scoped_refptr I420Buffer::Create(int width, int stride_y, int stride_u, int stride_v) { - return new rtc::RefCountedObject(width, height, stride_y, - stride_u, stride_v); + return rtc::make_ref_counted(width, height, stride_y, stride_u, + stride_v); } // static diff --git a/api/video/nv12_buffer.cc b/api/video/nv12_buffer.cc index cfa85ac52e..37d688b88b 100644 --- a/api/video/nv12_buffer.cc +++ b/api/video/nv12_buffer.cc @@ -49,7 +49,7 @@ NV12Buffer::~NV12Buffer() = default; // static rtc::scoped_refptr NV12Buffer::Create(int width, int height) { - return new rtc::RefCountedObject(width, height); + return rtc::make_ref_counted(width, height); } // static @@ -57,8 +57,7 @@ rtc::scoped_refptr NV12Buffer::Create(int width, int height, int stride_y, int stride_uv) { - return new rtc::RefCountedObject(width, height, stride_y, - stride_uv); + return rtc::make_ref_counted(width, height, stride_y, stride_uv); } // static @@ -145,11 +144,10 @@ void NV12Buffer::CropAndScaleFrom(const NV12BufferInterface& src, const uint8_t* uv_plane = src.DataUV() + src.StrideUV() * uv_offset_y + uv_offset_x * 2; - // kFilterBox is unsupported in libyuv, so using kFilterBilinear instead. 
int res = libyuv::NV12Scale(y_plane, src.StrideY(), uv_plane, src.StrideUV(), crop_width, crop_height, MutableDataY(), StrideY(), MutableDataUV(), StrideUV(), width(), - height(), libyuv::kFilterBilinear); + height(), libyuv::kFilterBox); RTC_DCHECK_EQ(res, 0); } diff --git a/api/video/recordable_encoded_frame.h b/api/video/recordable_encoded_frame.h index db59964f26..b4ad83a344 100644 --- a/api/video/recordable_encoded_frame.h +++ b/api/video/recordable_encoded_frame.h @@ -26,8 +26,10 @@ class RecordableEncodedFrame { public: // Encoded resolution in pixels struct EncodedResolution { - unsigned width; - unsigned height; + bool empty() const { return width == 0 && height == 0; } + + unsigned width = 0; + unsigned height = 0; }; virtual ~RecordableEncodedFrame() = default; diff --git a/api/video/rtp_video_frame_assembler.cc b/api/video/rtp_video_frame_assembler.cc new file mode 100644 index 0000000000..8f3d04c30b --- /dev/null +++ b/api/video/rtp_video_frame_assembler.cc @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "api/video/rtp_video_frame_assembler.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/optional.h" +#include "modules/rtp_rtcp/source/rtp_dependency_descriptor_extension.h" +#include "modules/rtp_rtcp/source/rtp_generic_frame_descriptor_extension.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" +#include "modules/rtp_rtcp/source/video_rtp_depacketizer_av1.h" +#include "modules/rtp_rtcp/source/video_rtp_depacketizer_generic.h" +#include "modules/rtp_rtcp/source/video_rtp_depacketizer_h264.h" +#include "modules/rtp_rtcp/source/video_rtp_depacketizer_raw.h" +#include "modules/rtp_rtcp/source/video_rtp_depacketizer_vp8.h" +#include "modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.h" +#include "modules/video_coding/frame_object.h" +#include "modules/video_coding/packet_buffer.h" +#include "modules/video_coding/rtp_frame_reference_finder.h" +#include "rtc_base/logging.h" + +namespace webrtc { +namespace { +std::unique_ptr CreateDepacketizer( + RtpVideoFrameAssembler::PayloadFormat payload_format) { + switch (payload_format) { + case RtpVideoFrameAssembler::kRaw: + return std::make_unique(); + case RtpVideoFrameAssembler::kH264: + return std::make_unique(); + case RtpVideoFrameAssembler::kVp8: + return std::make_unique(); + case RtpVideoFrameAssembler::kVp9: + return std::make_unique(); + case RtpVideoFrameAssembler::kAv1: + return std::make_unique(); + case RtpVideoFrameAssembler::kGeneric: + return std::make_unique(); + } + RTC_NOTREACHED(); + return nullptr; +} +} // namespace + +class RtpVideoFrameAssembler::Impl { + public: + explicit Impl(std::unique_ptr depacketizer); + ~Impl() = default; + + FrameVector InsertPacket(const RtpPacketReceived& packet); + + private: + using RtpFrameVector = + absl::InlinedVector, 3>; + + RtpFrameVector AssembleFrames( + video_coding::PacketBuffer::InsertResult insert_result); + FrameVector FindReferences(RtpFrameVector frames); + FrameVector UpdateWithPadding(uint16_t seq_num); + bool ParseDependenciesDescriptorExtension(const RtpPacketReceived& rtp_packet, + 
RTPVideoHeader& video_header); + bool ParseGenericDescriptorExtension(const RtpPacketReceived& rtp_packet, + RTPVideoHeader& video_header); + void ClearOldData(uint16_t incoming_seq_num); + + std::unique_ptr video_structure_; + SeqNumUnwrapper frame_id_unwrapper_; + absl::optional video_structure_frame_id_; + std::unique_ptr depacketizer_; + video_coding::PacketBuffer packet_buffer_; + RtpFrameReferenceFinder reference_finder_; +}; + +RtpVideoFrameAssembler::Impl::Impl( + std::unique_ptr depacketizer) + : depacketizer_(std::move(depacketizer)), + packet_buffer_(/*start_buffer_size=*/2048, /*max_buffer_size=*/2048) {} + +RtpVideoFrameAssembler::FrameVector RtpVideoFrameAssembler::Impl::InsertPacket( + const RtpPacketReceived& rtp_packet) { + absl::optional parsed_payload = + depacketizer_->Parse(rtp_packet.PayloadBuffer()); + + if (parsed_payload == absl::nullopt) { + return {}; + } + + if (parsed_payload->video_payload.size() == 0) { + ClearOldData(rtp_packet.SequenceNumber()); + return UpdateWithPadding(rtp_packet.SequenceNumber()); + } + + if (rtp_packet.HasExtension()) { + if (!ParseDependenciesDescriptorExtension(rtp_packet, + parsed_payload->video_header)) { + return {}; + } + } else if (rtp_packet.HasExtension()) { + if (!ParseGenericDescriptorExtension(rtp_packet, + parsed_payload->video_header)) { + return {}; + } + } + + parsed_payload->video_header.is_last_packet_in_frame |= rtp_packet.Marker(); + + auto packet = std::make_unique( + rtp_packet, parsed_payload->video_header); + packet->video_payload = std::move(parsed_payload->video_payload); + + ClearOldData(rtp_packet.SequenceNumber()); + return FindReferences( + AssembleFrames(packet_buffer_.InsertPacket(std::move(packet)))); +} + +void RtpVideoFrameAssembler::Impl::ClearOldData(uint16_t incoming_seq_num) { + constexpr uint16_t kOldSeqNumThreshold = 2000; + uint16_t old_seq_num = incoming_seq_num - kOldSeqNumThreshold; + packet_buffer_.ClearTo(old_seq_num); + reference_finder_.ClearTo(old_seq_num); +} + +RtpVideoFrameAssembler::Impl::RtpFrameVector +RtpVideoFrameAssembler::Impl::AssembleFrames( + video_coding::PacketBuffer::InsertResult insert_result) { + video_coding::PacketBuffer::Packet* first_packet = nullptr; + std::vector> payloads; + RtpFrameVector result; + + for (auto& packet : insert_result.packets) { + if (packet->is_first_packet_in_frame()) { + first_packet = packet.get(); + payloads.clear(); + } + payloads.emplace_back(packet->video_payload); + + if (packet->is_last_packet_in_frame()) { + rtc::scoped_refptr bitstream = + depacketizer_->AssembleFrame(payloads); + + if (!bitstream) { + continue; + } + + const video_coding::PacketBuffer::Packet& last_packet = *packet; + result.push_back(std::make_unique( + first_packet->seq_num, // + last_packet.seq_num, // + last_packet.marker_bit, // + /*times_nacked=*/0, // + /*first_packet_received_time=*/0, // + /*last_packet_received_time=*/0, // + first_packet->timestamp, // + /*ntp_time_ms=*/0, // + /*timing=*/VideoSendTiming(), // + first_packet->payload_type, // + first_packet->codec(), // + last_packet.video_header.rotation, // + last_packet.video_header.content_type, // + first_packet->video_header, // + last_packet.video_header.color_space, // + /*packet_infos=*/RtpPacketInfos(), // + std::move(bitstream))); + } + } + + return result; +} + +RtpVideoFrameAssembler::FrameVector +RtpVideoFrameAssembler::Impl::FindReferences(RtpFrameVector frames) { + FrameVector res; + for (auto& frame : frames) { + auto complete_frames = reference_finder_.ManageFrame(std::move(frame)); + 
for (std::unique_ptr& complete_frame : complete_frames) { + res.push_back(std::move(complete_frame)); + } + } + return res; +} + +RtpVideoFrameAssembler::FrameVector +RtpVideoFrameAssembler::Impl::UpdateWithPadding(uint16_t seq_num) { + auto res = + FindReferences(AssembleFrames(packet_buffer_.InsertPadding(seq_num))); + auto ref_finder_update = reference_finder_.PaddingReceived(seq_num); + + res.insert(res.end(), std::make_move_iterator(ref_finder_update.begin()), + std::make_move_iterator(ref_finder_update.end())); + + return res; +} + +bool RtpVideoFrameAssembler::Impl::ParseDependenciesDescriptorExtension( + const RtpPacketReceived& rtp_packet, + RTPVideoHeader& video_header) { + webrtc::DependencyDescriptor dependency_descriptor; + + if (!rtp_packet.GetExtension( + video_structure_.get(), &dependency_descriptor)) { + // Descriptor is either malformed, or the template referenced is not in + // the `video_structure_` currently being held. + // TODO(bugs.webrtc.org/10342): Improve packet reordering behavior. + RTC_LOG(LS_WARNING) << "ssrc: " << rtp_packet.Ssrc() + << " Failed to parse dependency descriptor."; + return false; + } + + if (dependency_descriptor.attached_structure != nullptr && + !dependency_descriptor.first_packet_in_frame) { + RTC_LOG(LS_WARNING) << "ssrc: " << rtp_packet.Ssrc() + << "Invalid dependency descriptor: structure " + "attached to non first packet of a frame."; + return false; + } + + video_header.is_first_packet_in_frame = + dependency_descriptor.first_packet_in_frame; + video_header.is_last_packet_in_frame = + dependency_descriptor.last_packet_in_frame; + + int64_t frame_id = + frame_id_unwrapper_.Unwrap(dependency_descriptor.frame_number); + auto& generic_descriptor_info = video_header.generic.emplace(); + generic_descriptor_info.frame_id = frame_id; + generic_descriptor_info.spatial_index = + dependency_descriptor.frame_dependencies.spatial_id; + generic_descriptor_info.temporal_index = + dependency_descriptor.frame_dependencies.temporal_id; + + for (int fdiff : dependency_descriptor.frame_dependencies.frame_diffs) { + generic_descriptor_info.dependencies.push_back(frame_id - fdiff); + } + for (int cdiff : dependency_descriptor.frame_dependencies.chain_diffs) { + generic_descriptor_info.chain_diffs.push_back(frame_id - cdiff); + } + generic_descriptor_info.decode_target_indications = + dependency_descriptor.frame_dependencies.decode_target_indications; + if (dependency_descriptor.resolution) { + video_header.width = dependency_descriptor.resolution->Width(); + video_header.height = dependency_descriptor.resolution->Height(); + } + if (dependency_descriptor.active_decode_targets_bitmask.has_value()) { + generic_descriptor_info.active_decode_targets = + *dependency_descriptor.active_decode_targets_bitmask; + } + + // FrameDependencyStructure is sent in the dependency descriptor of the first + // packet of a key frame and is required to parse all subsequent packets until + // the next key frame. 
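+  // Until that structure is known, GetExtension() above fails for packets
+  // that reference it, and those packets are dropped.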
+ if (dependency_descriptor.attached_structure) { + RTC_DCHECK(dependency_descriptor.first_packet_in_frame); + if (video_structure_frame_id_ > frame_id) { + RTC_LOG(LS_WARNING) + << "Arrived key frame with id " << frame_id << " and structure id " + << dependency_descriptor.attached_structure->structure_id + << " is older than the latest received key frame with id " + << *video_structure_frame_id_ << " and structure id " + << video_structure_->structure_id; + return false; + } + video_structure_ = std::move(dependency_descriptor.attached_structure); + video_structure_frame_id_ = frame_id; + video_header.frame_type = VideoFrameType::kVideoFrameKey; + } else { + video_header.frame_type = VideoFrameType::kVideoFrameDelta; + } + return true; +} + +bool RtpVideoFrameAssembler::Impl::ParseGenericDescriptorExtension( + const RtpPacketReceived& rtp_packet, + RTPVideoHeader& video_header) { + RtpGenericFrameDescriptor generic_frame_descriptor; + if (!rtp_packet.GetExtension( + &generic_frame_descriptor)) { + return false; + } + + video_header.is_first_packet_in_frame = + generic_frame_descriptor.FirstPacketInSubFrame(); + video_header.is_last_packet_in_frame = + generic_frame_descriptor.LastPacketInSubFrame(); + + if (generic_frame_descriptor.FirstPacketInSubFrame()) { + video_header.frame_type = + generic_frame_descriptor.FrameDependenciesDiffs().empty() + ? VideoFrameType::kVideoFrameKey + : VideoFrameType::kVideoFrameDelta; + + auto& generic_descriptor_info = video_header.generic.emplace(); + int64_t frame_id = + frame_id_unwrapper_.Unwrap(generic_frame_descriptor.FrameId()); + generic_descriptor_info.frame_id = frame_id; + generic_descriptor_info.spatial_index = + generic_frame_descriptor.SpatialLayer(); + generic_descriptor_info.temporal_index = + generic_frame_descriptor.TemporalLayer(); + for (uint16_t fdiff : generic_frame_descriptor.FrameDependenciesDiffs()) { + generic_descriptor_info.dependencies.push_back(frame_id - fdiff); + } + } + video_header.width = generic_frame_descriptor.Width(); + video_header.height = generic_frame_descriptor.Height(); + return true; +} + +RtpVideoFrameAssembler::RtpVideoFrameAssembler(PayloadFormat payload_format) + : impl_(std::make_unique(CreateDepacketizer(payload_format))) {} + +RtpVideoFrameAssembler::~RtpVideoFrameAssembler() = default; + +RtpVideoFrameAssembler::FrameVector RtpVideoFrameAssembler::InsertPacket( + const RtpPacketReceived& packet) { + return impl_->InsertPacket(packet); +} + +} // namespace webrtc diff --git a/api/video/rtp_video_frame_assembler.h b/api/video/rtp_video_frame_assembler.h new file mode 100644 index 0000000000..353942bdc8 --- /dev/null +++ b/api/video/rtp_video_frame_assembler.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef API_VIDEO_RTP_VIDEO_FRAME_ASSEMBLER_H_ +#define API_VIDEO_RTP_VIDEO_FRAME_ASSEMBLER_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "api/video/encoded_frame.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" + +namespace webrtc { +// The RtpVideoFrameAssembler takes RtpPacketReceived and assembles them into +// complete frames. 
A frame is considered complete when all packets of the
+// frame have been received, the bitstream data has been successfully
+// extracted, an ID has been assigned, and all dependencies are known. Frame
+// IDs are strictly monotonic in decode order, and dependencies are expressed
+// as frame IDs.
+class RtpVideoFrameAssembler {
+ public:
+  // FrameVector is just a vector-like type of std::unique_ptr<EncodedFrame>.
+  // The vector type may change without notice.
+  using FrameVector = absl::InlinedVector<std::unique_ptr<EncodedFrame>, 3>;
+  enum PayloadFormat { kRaw, kH264, kVp8, kVp9, kAv1, kGeneric };
+
+  explicit RtpVideoFrameAssembler(PayloadFormat payload_format);
+  RtpVideoFrameAssembler(const RtpVideoFrameAssembler& other) = delete;
+  RtpVideoFrameAssembler& operator=(const RtpVideoFrameAssembler& other) =
+      delete;
+  ~RtpVideoFrameAssembler();
+
+  // Typically, when a packet is inserted, zero or one frame is completed. If
+  // RTP packets are inserted out of order, multiple frames can sometimes be
+  // completed from a single packet, hence the 'FrameVector' return type.
+  FrameVector InsertPacket(const RtpPacketReceived& packet);
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace webrtc
+
+#endif  // API_VIDEO_RTP_VIDEO_FRAME_ASSEMBLER_H_
diff --git a/api/video/rtp_video_frame_assembler_unittests.cc b/api/video/rtp_video_frame_assembler_unittests.cc
new file mode 100644
index 0000000000..916a83cd73
--- /dev/null
+++ b/api/video/rtp_video_frame_assembler_unittests.cc
@@ -0,0 +1,495 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include
+
+#include "api/array_view.h"
+#include "api/video/rtp_video_frame_assembler.h"
+#include "modules/rtp_rtcp/source/rtp_dependency_descriptor_extension.h"
+#include "modules/rtp_rtcp/source/rtp_format.h"
+#include "modules/rtp_rtcp/source/rtp_generic_frame_descriptor_extension.h"
+#include "modules/rtp_rtcp/source/rtp_packet_to_send.h"
+#include "modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.h"
+#include "test/gmock.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace {
+
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::IsEmpty;
+using ::testing::Matches;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
+using PayloadFormat = RtpVideoFrameAssembler::PayloadFormat;
+
+class PacketBuilder {
+ public:
+  explicit PacketBuilder(PayloadFormat format)
+      : format_(format), packet_to_send_(&extension_manager_) {}
+
+  PacketBuilder& WithSeqNum(uint16_t seq_num) {
+    seq_num_ = seq_num;
+    return *this;
+  }
+
+  PacketBuilder& WithPayload(rtc::ArrayView payload) {
+    payload_.assign(payload.begin(), payload.end());
+    return *this;
+  }
+
+  PacketBuilder& WithVideoHeader(const RTPVideoHeader& video_header) {
+    video_header_ = video_header;
+    return *this;
+  }
+
+  template
+  PacketBuilder& WithExtension(int id, const Args&...
args) { + extension_manager_.Register(id); + packet_to_send_.IdentifyExtensions(extension_manager_); + packet_to_send_.SetExtension(std::forward(args)...); + return *this; + } + + RtpPacketReceived Build() { + auto packetizer = + RtpPacketizer::Create(GetVideoCodecType(), payload_, {}, video_header_); + packetizer->NextPacket(&packet_to_send_); + packet_to_send_.SetSequenceNumber(seq_num_); + + RtpPacketReceived received(&extension_manager_); + received.Parse(packet_to_send_.Buffer()); + return received; + } + + private: + absl::optional GetVideoCodecType() { + switch (format_) { + case PayloadFormat::kRaw: { + return absl::nullopt; + } + case PayloadFormat::kH264: { + return kVideoCodecH264; + } + case PayloadFormat::kVp8: { + return kVideoCodecVP8; + } + case PayloadFormat::kVp9: { + return kVideoCodecVP9; + } + case PayloadFormat::kAv1: { + return kVideoCodecAV1; + } + case PayloadFormat::kGeneric: { + return kVideoCodecGeneric; + } + } + RTC_NOTREACHED(); + return absl::nullopt; + } + + const RtpVideoFrameAssembler::PayloadFormat format_; + uint16_t seq_num_ = 0; + std::vector payload_; + RTPVideoHeader video_header_; + RtpPacketReceived::ExtensionManager extension_manager_; + RtpPacketToSend packet_to_send_; +}; + +void AppendFrames(RtpVideoFrameAssembler::FrameVector from, + RtpVideoFrameAssembler::FrameVector& to) { + to.insert(to.end(), std::make_move_iterator(from.begin()), + std::make_move_iterator(from.end())); +} + +rtc::ArrayView References(const std::unique_ptr& frame) { + return rtc::MakeArrayView(frame->references, frame->num_references); +} + +rtc::ArrayView Payload(const std::unique_ptr& frame) { + return rtc::ArrayView(*frame->GetEncodedData()); +} + +TEST(RtpVideoFrameAssembler, Vp8Packetization) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kVp8); + + // When sending VP8 over RTP parts of the payload is actually inspected at the + // RTP level. It just so happen that the initial 'V' sets the keyframe bit + // (0x01) to the correct value. 
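+  // (In the VP8 payload header the least significant bit of the first payload
+  // byte is the inverse key frame flag: 0 means key frame, 1 means delta
+  // frame. 'V' is 0x56, so that bit is 0; 'S' is 0x53, so it is 1.)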
+ uint8_t kKeyframePayload[] = "Vp8Keyframe"; + ASSERT_EQ(kKeyframePayload[0] & 0x01, 0); + + uint8_t kDeltaframePayload[] = "SomeFrame"; + ASSERT_EQ(kDeltaframePayload[0] & 0x01, 1); + + RtpVideoFrameAssembler::FrameVector frames; + + RTPVideoHeader video_header; + auto& vp8_header = + video_header.video_type_header.emplace(); + + vp8_header.pictureId = 10; + vp8_header.tl0PicIdx = 0; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kVp8) + .WithPayload(kKeyframePayload) + .WithVideoHeader(video_header) + .Build()), + frames); + + vp8_header.pictureId = 11; + vp8_header.tl0PicIdx = 1; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kVp8) + .WithPayload(kDeltaframePayload) + .WithVideoHeader(video_header) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[0]->Id(), Eq(10)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kKeyframePayload)); + + EXPECT_THAT(frames[1]->Id(), Eq(11)); + EXPECT_THAT(References(frames[1]), UnorderedElementsAre(10)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kDeltaframePayload)); +} + +TEST(RtpVideoFrameAssembler, Vp9Packetization) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kVp9); + RtpVideoFrameAssembler::FrameVector frames; + + uint8_t kPayload[] = "SomePayload"; + + RTPVideoHeader video_header; + auto& vp9_header = + video_header.video_type_header.emplace(); + vp9_header.InitRTPVideoHeaderVP9(); + + vp9_header.picture_id = 10; + vp9_header.tl0_pic_idx = 0; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kVp9) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .Build()), + frames); + + vp9_header.picture_id = 11; + vp9_header.tl0_pic_idx = 1; + vp9_header.inter_pic_predicted = true; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kVp9) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[0]->Id(), Eq(10)); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + + EXPECT_THAT(frames[1]->Id(), Eq(11)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[1]), UnorderedElementsAre(10)); +} + +TEST(RtpVideoFrameAssembler, Av1Packetization) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kAv1); + RtpVideoFrameAssembler::FrameVector frames; + + auto kKeyframePayload = + BuildAv1Frame({Av1Obu(kAv1ObuTypeSequenceHeader).WithPayload({1, 2, 3}), + Av1Obu(kAv1ObuTypeFrame).WithPayload({4, 5, 6})}); + + auto kDeltaframePayload = + BuildAv1Frame({Av1Obu(kAv1ObuTypeFrame).WithPayload({7, 8, 9})}); + + RTPVideoHeader video_header; + + video_header.frame_type = VideoFrameType::kVideoFrameKey; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kAv1) + .WithPayload(kKeyframePayload) + .WithVideoHeader(video_header) + .WithSeqNum(20) + .Build()), + frames); + + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kAv1) + .WithPayload(kDeltaframePayload) + .WithSeqNum(21) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[0]->Id(), Eq(20)); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kKeyframePayload)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + + EXPECT_THAT(frames[1]->Id(), Eq(21)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kDeltaframePayload)); + EXPECT_THAT(References(frames[1]), 
UnorderedElementsAre(20)); +} + +TEST(RtpVideoFrameAssembler, RawPacketizationDependencyDescriptorExtension) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kRaw); + RtpVideoFrameAssembler::FrameVector frames; + uint8_t kPayload[] = "SomePayload"; + + FrameDependencyStructure dependency_structure; + dependency_structure.num_decode_targets = 1; + dependency_structure.num_chains = 1; + dependency_structure.decode_target_protected_by_chain.push_back(0); + dependency_structure.templates.push_back( + FrameDependencyTemplate().S(0).T(0).Dtis("S").ChainDiffs({0})); + dependency_structure.templates.push_back( + FrameDependencyTemplate().S(0).T(0).Dtis("S").ChainDiffs({10}).FrameDiffs( + {10})); + + DependencyDescriptor dependency_descriptor; + + dependency_descriptor.frame_number = 10; + dependency_descriptor.frame_dependencies = dependency_structure.templates[0]; + dependency_descriptor.attached_structure = + std::make_unique(dependency_structure); + AppendFrames(assembler.InsertPacket( + PacketBuilder(PayloadFormat::kRaw) + .WithPayload(kPayload) + .WithExtension( + 1, dependency_structure, dependency_descriptor) + .Build()), + frames); + + dependency_descriptor.frame_number = 20; + dependency_descriptor.frame_dependencies = dependency_structure.templates[1]; + dependency_descriptor.attached_structure.reset(); + AppendFrames(assembler.InsertPacket( + PacketBuilder(PayloadFormat::kRaw) + .WithPayload(kPayload) + .WithExtension( + 1, dependency_structure, dependency_descriptor) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[0]->Id(), Eq(10)); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + + EXPECT_THAT(frames[1]->Id(), Eq(20)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[1]), UnorderedElementsAre(10)); +} + +TEST(RtpVideoFrameAssembler, RawPacketizationGenericDescriptor00Extension) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kRaw); + RtpVideoFrameAssembler::FrameVector frames; + uint8_t kPayload[] = "SomePayload"; + + RtpGenericFrameDescriptor generic; + + generic.SetFirstPacketInSubFrame(true); + generic.SetLastPacketInSubFrame(true); + generic.SetFrameId(100); + AppendFrames( + assembler.InsertPacket( + PacketBuilder(PayloadFormat::kRaw) + .WithPayload(kPayload) + .WithExtension(1, generic) + .Build()), + frames); + + generic.SetFrameId(102); + generic.AddFrameDependencyDiff(2); + AppendFrames( + assembler.InsertPacket( + PacketBuilder(PayloadFormat::kRaw) + .WithPayload(kPayload) + .WithExtension(1, generic) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[0]->Id(), Eq(100)); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + + EXPECT_THAT(frames[1]->Id(), Eq(102)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[1]), UnorderedElementsAre(100)); +} + +TEST(RtpVideoFrameAssembler, RawPacketizationGenericPayloadDescriptor) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kGeneric); + RtpVideoFrameAssembler::FrameVector frames; + uint8_t kPayload[] = "SomePayload"; + + RTPVideoHeader video_header; + + video_header.frame_type = VideoFrameType::kVideoFrameKey; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(123) + .Build()), + frames); + + video_header.frame_type = 
VideoFrameType::kVideoFrameDelta; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(124) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[0]->Id(), Eq(123)); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + + EXPECT_THAT(frames[1]->Id(), Eq(124)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[1]), UnorderedElementsAre(123)); +} + +TEST(RtpVideoFrameAssembler, Padding) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kGeneric); + RtpVideoFrameAssembler::FrameVector frames; + uint8_t kPayload[] = "SomePayload"; + + RTPVideoHeader video_header; + + video_header.frame_type = VideoFrameType::kVideoFrameKey; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(123) + .Build()), + frames); + + video_header.frame_type = VideoFrameType::kVideoFrameDelta; + AppendFrames(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(125) + .Build()), + frames); + + ASSERT_THAT(frames, SizeIs(1)); + + EXPECT_THAT(frames[0]->Id(), Eq(123)); + EXPECT_THAT(Payload(frames[0]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[0]), IsEmpty()); + + // Padding packets have no bitstream data. An easy way to generate one is to + // build a normal packet and then simply remove the bitstream portion of the + // payload. + RtpPacketReceived padding_packet = PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(124) + .Build(); + // The payload descriptor is one byte, keep it. + padding_packet.SetPayloadSize(1); + + AppendFrames(assembler.InsertPacket(padding_packet), frames); + + ASSERT_THAT(frames, SizeIs(2)); + + EXPECT_THAT(frames[1]->Id(), Eq(125)); + EXPECT_THAT(Payload(frames[1]), ElementsAreArray(kPayload)); + EXPECT_THAT(References(frames[1]), UnorderedElementsAre(123)); +} + +TEST(RtpVideoFrameAssembler, ClearOldPackets) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kGeneric); + + // If we don't have a payload the packet will be counted as a padding packet. 
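+  // The assembler clears state that falls more than 2000 sequence numbers
+  // behind the newest packet (kOldSeqNumThreshold in ClearOldData), so
+  // sequence number 0 is too old once sequence number 2000 has been inserted.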
+ uint8_t kPayload[] = "DontCare"; + + RTPVideoHeader video_header; + video_header.frame_type = VideoFrameType::kVideoFrameKey; + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(0) + .Build()), + SizeIs(1)); + + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(2000) + .Build()), + SizeIs(1)); + + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(0) + .Build()), + SizeIs(0)); + + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(1) + .Build()), + SizeIs(1)); +} + +TEST(RtpVideoFrameAssembler, ClearOldPacketsWithPadding) { + RtpVideoFrameAssembler assembler(RtpVideoFrameAssembler::kGeneric); + uint8_t kPayload[] = "DontCare"; + + RTPVideoHeader video_header; + video_header.frame_type = VideoFrameType::kVideoFrameKey; + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(0) + .Build()), + SizeIs(1)); + + // Padding packets have no bitstream data. An easy way to generate one is to + // build a normal packet and then simply remove the bitstream portion of the + // payload. + RtpPacketReceived padding_packet = PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(2000) + .Build(); + // The payload descriptor is one byte, keep it. + padding_packet.SetPayloadSize(1); + EXPECT_THAT(assembler.InsertPacket(padding_packet), SizeIs(0)); + + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(0) + .Build()), + SizeIs(0)); + + EXPECT_THAT(assembler.InsertPacket(PacketBuilder(PayloadFormat::kGeneric) + .WithPayload(kPayload) + .WithVideoHeader(video_header) + .WithSeqNum(1) + .Build()), + SizeIs(1)); +} + +} // namespace +} // namespace webrtc diff --git a/api/video/test/BUILD.gn b/api/video/test/BUILD.gn index 72f50494bb..1573e7848f 100644 --- a/api/video/test/BUILD.gn +++ b/api/video/test/BUILD.gn @@ -20,7 +20,6 @@ rtc_library("rtc_api_video_unittests") { "..:video_adaptation", "..:video_bitrate_allocation", "..:video_frame", - "..:video_frame_nv12", "..:video_rtp_headers", "../../../test:frame_utils", "../../../test:test_support", diff --git a/api/video/video_frame.h b/api/video/video_frame.h index e62aae8e5d..e073fd5e42 100644 --- a/api/video/video_frame.h +++ b/api/video/video_frame.h @@ -134,11 +134,11 @@ class RTC_EXPORT VideoFrame { // Get frame size in pixels. uint32_t size() const; - // Get frame ID. Returns 0 if ID is not set. Not guarantee to be transferred - // from the sender to the receiver, but preserved on single side. The id + // Get frame ID. Returns 0 if ID is not set. Not guaranteed to be transferred + // from the sender to the receiver, but preserved on the sender side. The id // should be propagated between all frame modifications during its lifetime // from capturing to sending as encoded image. It is intended to be unique - // over a time window of a few minutes for peer connection, to which + // over a time window of a few minutes for the peer connection to which the // corresponding video stream belongs to. 
uint16_t id() const { return id_; } void set_id(uint16_t id) { id_ = id; } diff --git a/api/video/video_frame_buffer.cc b/api/video/video_frame_buffer.cc index 64f339448b..7085010325 100644 --- a/api/video/video_frame_buffer.cc +++ b/api/video/video_frame_buffer.cc @@ -11,6 +11,7 @@ #include "api/video/video_frame_buffer.h" #include "api/video/i420_buffer.h" +#include "api/video/nv12_buffer.h" #include "rtc_base/checks.h" namespace webrtc { @@ -139,4 +140,18 @@ int NV12BufferInterface::ChromaWidth() const { int NV12BufferInterface::ChromaHeight() const { return (height() + 1) / 2; } + +rtc::scoped_refptr NV12BufferInterface::CropAndScale( + int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) { + rtc::scoped_refptr result = + NV12Buffer::Create(scaled_width, scaled_height); + result->CropAndScaleFrom(*this, offset_x, offset_y, crop_width, crop_height); + return result; +} + } // namespace webrtc diff --git a/api/video/video_frame_buffer.h b/api/video/video_frame_buffer.h index 67b8797325..62adc204f6 100644 --- a/api/video/video_frame_buffer.h +++ b/api/video/video_frame_buffer.h @@ -242,6 +242,13 @@ class RTC_EXPORT NV12BufferInterface : public BiplanarYuv8Buffer { int ChromaWidth() const final; int ChromaHeight() const final; + rtc::scoped_refptr CropAndScale(int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) override; + protected: ~NV12BufferInterface() override {} }; diff --git a/api/video/video_source_interface.h b/api/video/video_source_interface.h index b03d7c5483..8b5823fc27 100644 --- a/api/video/video_source_interface.h +++ b/api/video/video_source_interface.h @@ -12,6 +12,7 @@ #define API_VIDEO_VIDEO_SOURCE_INTERFACE_H_ #include +#include #include "absl/types/optional.h" #include "api/video/video_sink_interface.h" @@ -22,6 +23,15 @@ namespace rtc { // VideoSinkWants is used for notifying the source of properties a video frame // should have when it is delivered to a certain sink. struct RTC_EXPORT VideoSinkWants { + struct FrameSize { + FrameSize(int width, int height) : width(width), height(height) {} + FrameSize(const FrameSize&) = default; + ~FrameSize() = default; + + int width; + int height; + }; + VideoSinkWants(); VideoSinkWants(const VideoSinkWants&); ~VideoSinkWants(); @@ -49,8 +59,34 @@ struct RTC_EXPORT VideoSinkWants { // Note that this field is unrelated to any horizontal or vertical stride // requirements the encoder has on the incoming video frame buffers. int resolution_alignment = 1; + + // The resolutions that sink is configured to consume. If the sink is an + // encoder this is what the encoder is configured to encode. In singlecast we + // only encode one resolution, but in simulcast and SVC this can mean multiple + // resolutions per frame. + // + // The sink is always configured to consume a subset of the + // webrtc::VideoFrame's resolution. In the case of encoding, we usually encode + // at webrtc::VideoFrame's resolution but this may not always be the case due + // to scaleResolutionDownBy or turning off simulcast or SVC layers. + // + // For example, we may capture at 720p and due to adaptation (e.g. applying + // |max_pixel_count| constraints) create webrtc::VideoFrames of size 480p, but + // if we do scaleResolutionDownBy:2 then the only resolution we end up + // encoding is 240p. 
In this case we still need to provide webrtc::VideoFrames + // of size 480p but we can optimize internal buffers for 240p, avoiding + // downsampling to 480p if possible. + // + // Note that the |resolutions| can change while frames are in flight and + // should only be used as a hint when constructing the webrtc::VideoFrame. + std::vector resolutions; }; +inline bool operator==(const VideoSinkWants::FrameSize& a, + const VideoSinkWants::FrameSize& b) { + return a.width == b.width && a.height == b.height; +} + template class VideoSourceInterface { public: diff --git a/api/video/video_stream_decoder.h b/api/video/video_stream_decoder.h index 4bf8b985c4..8d71dd300c 100644 --- a/api/video/video_stream_decoder.h +++ b/api/video/video_stream_decoder.h @@ -38,9 +38,7 @@ class VideoStreamDecoderInterface { // Called when the VideoStreamDecoder enters a non-decodable state. virtual void OnNonDecodableState() = 0; - // Called with the last continuous frame. - virtual void OnContinuousUntil( - const video_coding::VideoLayerFrameId& key) = 0; + virtual void OnContinuousUntil(int64_t frame_id) {} virtual void OnDecodedFrame(VideoFrame frame, const FrameInfo& frame_info) = 0; @@ -48,7 +46,7 @@ class VideoStreamDecoderInterface { virtual ~VideoStreamDecoderInterface() = default; - virtual void OnFrame(std::unique_ptr frame) = 0; + virtual void OnFrame(std::unique_ptr frame) = 0; virtual void SetMinPlayoutDelay(TimeDelta min_delay) = 0; virtual void SetMaxPlayoutDelay(TimeDelta max_delay) = 0; diff --git a/api/video/video_stream_decoder_create_unittest.cc b/api/video/video_stream_decoder_create_unittest.cc index 93edb4b8a2..849a054a04 100644 --- a/api/video/video_stream_decoder_create_unittest.cc +++ b/api/video/video_stream_decoder_create_unittest.cc @@ -21,7 +21,6 @@ class NullCallbacks : public VideoStreamDecoderInterface::Callbacks { public: ~NullCallbacks() override = default; void OnNonDecodableState() override {} - void OnContinuousUntil(const video_coding::VideoLayerFrameId& key) override {} void OnDecodedFrame(VideoFrame frame, const VideoStreamDecoderInterface::Callbacks::FrameInfo& frame_info) override {} diff --git a/api/video/video_timing.h b/api/video/video_timing.h index fbd92254a0..80320daa83 100644 --- a/api/video/video_timing.h +++ b/api/video/video_timing.h @@ -41,7 +41,7 @@ struct VideoSendTiming { uint16_t pacer_exit_delta_ms; uint16_t network_timestamp_delta_ms; uint16_t network2_timestamp_delta_ms; - uint8_t flags; + uint8_t flags = TimingFrameFlags::kInvalid; }; // Used to report precise timings of a 'timing frames'. 
Contains all important diff --git a/api/video_codecs/BUILD.gn b/api/video_codecs/BUILD.gn index a99027641e..83d67fcac4 100644 --- a/api/video_codecs/BUILD.gn +++ b/api/video_codecs/BUILD.gn @@ -15,6 +15,8 @@ if (is_android) { rtc_library("video_codecs_api") { visibility = [ "*" ] sources = [ + "h264_profile_level_id.cc", + "h264_profile_level_id.h", "sdp_video_format.cc", "sdp_video_format.h", "spatial_layer.cc", @@ -23,7 +25,6 @@ rtc_library("video_codecs_api") { "video_codec.h", "video_decoder.cc", "video_decoder.h", - "video_decoder_factory.cc", "video_decoder_factory.h", "video_encoder.cc", "video_encoder.h", @@ -35,11 +36,14 @@ rtc_library("video_codecs_api") { "vp8_frame_config.h", "vp8_temporal_layers.cc", "vp8_temporal_layers.h", + "vp9_profile.cc", + "vp9_profile.h", ] deps = [ "..:fec_controller_api", "..:scoped_refptr", + "../../api:array_view", "../../modules/video_coding:codec_globals_headers", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", @@ -138,7 +142,6 @@ rtc_library("rtc_software_fallback_wrappers") { ":video_codecs_api", "..:fec_controller_api", "../../api/video:video_frame", - "../../media:rtc_h264_profile_id", "../../media:rtc_media_base", "../../modules/video_coding:video_codec_interface", "../../modules/video_coding:video_coding_utility", diff --git a/api/video_codecs/builtin_video_encoder_factory.cc b/api/video_codecs/builtin_video_encoder_factory.cc index 2f722a4a5c..9463a9cdf2 100644 --- a/api/video_codecs/builtin_video_encoder_factory.cc +++ b/api/video_codecs/builtin_video_encoder_factory.cc @@ -26,18 +26,6 @@ namespace webrtc { namespace { -bool IsFormatSupported(const std::vector& supported_formats, - const SdpVideoFormat& format) { - for (const SdpVideoFormat& supported_format : supported_formats) { - if (cricket::IsSameCodec(format.name, format.parameters, - supported_format.name, - supported_format.parameters)) { - return true; - } - } - return false; -} - // This class wraps the internal factory and adds simulcast. class BuiltinVideoEncoderFactory : public VideoEncoderFactory { public: @@ -47,8 +35,8 @@ class BuiltinVideoEncoderFactory : public VideoEncoderFactory { VideoEncoderFactory::CodecInfo QueryVideoEncoder( const SdpVideoFormat& format) const override { // Format must be one of the internal formats. - RTC_DCHECK(IsFormatSupported( - internal_encoder_factory_->GetSupportedFormats(), format)); + RTC_DCHECK( + format.IsCodecInList(internal_encoder_factory_->GetSupportedFormats())); VideoEncoderFactory::CodecInfo info; return info; } @@ -57,8 +45,8 @@ class BuiltinVideoEncoderFactory : public VideoEncoderFactory { const SdpVideoFormat& format) override { // Try creating internal encoder. std::unique_ptr internal_encoder; - if (IsFormatSupported(internal_encoder_factory_->GetSupportedFormats(), - format)) { + if (format.IsCodecInList( + internal_encoder_factory_->GetSupportedFormats())) { internal_encoder = std::make_unique( internal_encoder_factory_.get(), format); } diff --git a/api/video_codecs/h264_profile_level_id.cc b/api/video_codecs/h264_profile_level_id.cc new file mode 100644 index 0000000000..fa47758189 --- /dev/null +++ b/api/video_codecs/h264_profile_level_id.cc @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "api/video_codecs/h264_profile_level_id.h" + +#include +#include +#include + +#include "rtc_base/arraysize.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +const char kProfileLevelId[] = "profile-level-id"; + +// For level_idc=11 and profile_idc=0x42, 0x4D, or 0x58, the constraint set3 +// flag specifies if level 1b or level 1.1 is used. +const uint8_t kConstraintSet3Flag = 0x10; + +// Convert a string of 8 characters into a byte where the positions containing +// character c will have their bit set. For example, c = 'x', str = "x1xx0000" +// will return 0b10110000. constexpr is used so that the pattern table in +// kProfilePatterns is statically initialized. +constexpr uint8_t ByteMaskString(char c, const char (&str)[9]) { + return (str[0] == c) << 7 | (str[1] == c) << 6 | (str[2] == c) << 5 | + (str[3] == c) << 4 | (str[4] == c) << 3 | (str[5] == c) << 2 | + (str[6] == c) << 1 | (str[7] == c) << 0; +} + +// Class for matching bit patterns such as "x1xx0000" where 'x' is allowed to be +// either 0 or 1. +class BitPattern { + public: + explicit constexpr BitPattern(const char (&str)[9]) + : mask_(~ByteMaskString('x', str)), + masked_value_(ByteMaskString('1', str)) {} + + bool IsMatch(uint8_t value) const { return masked_value_ == (value & mask_); } + + private: + const uint8_t mask_; + const uint8_t masked_value_; +}; + +// Table for converting between profile_idc/profile_iop to H264Profile. +struct ProfilePattern { + const uint8_t profile_idc; + const BitPattern profile_iop; + const H264Profile profile; +}; + +// This is from https://tools.ietf.org/html/rfc6184#section-8.1. +constexpr ProfilePattern kProfilePatterns[] = { + {0x42, BitPattern("x1xx0000"), H264Profile::kProfileConstrainedBaseline}, + {0x4D, BitPattern("1xxx0000"), H264Profile::kProfileConstrainedBaseline}, + {0x58, BitPattern("11xx0000"), H264Profile::kProfileConstrainedBaseline}, + {0x42, BitPattern("x0xx0000"), H264Profile::kProfileBaseline}, + {0x58, BitPattern("10xx0000"), H264Profile::kProfileBaseline}, + {0x4D, BitPattern("0x0x0000"), H264Profile::kProfileMain}, + {0x64, BitPattern("00000000"), H264Profile::kProfileHigh}, + {0x64, BitPattern("00001100"), H264Profile::kProfileConstrainedHigh}}; + +struct LevelConstraint { + const int max_macroblocks_per_second; + const int max_macroblock_frame_size; + const H264Level level; +}; + +// This is from ITU-T H.264 (02/2016) Table A-1 – Level limits. +static constexpr LevelConstraint kLevelConstraints[] = { + {1485, 99, H264Level::kLevel1}, + {1485, 99, H264Level::kLevel1_b}, + {3000, 396, H264Level::kLevel1_1}, + {6000, 396, H264Level::kLevel1_2}, + {11880, 396, H264Level::kLevel1_3}, + {11880, 396, H264Level::kLevel2}, + {19800, 792, H264Level::kLevel2_1}, + {20250, 1620, H264Level::kLevel2_2}, + {40500, 1620, H264Level::kLevel3}, + {108000, 3600, H264Level::kLevel3_1}, + {216000, 5120, H264Level::kLevel3_2}, + {245760, 8192, H264Level::kLevel4}, + {245760, 8192, H264Level::kLevel4_1}, + {522240, 8704, H264Level::kLevel4_2}, + {589824, 22080, H264Level::kLevel5}, + {983040, 36864, H264Level::kLevel5_1}, + {2073600, 36864, H264Level::kLevel5_2}, +}; + +} // anonymous namespace + +absl::optional ParseH264ProfileLevelId(const char* str) { + // The string should consist of 3 bytes in hexadecimal format. 
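+  // For example, "42e01f" encodes profile_idc = 0x42, profile_iop = 0xE0 and
+  // level_idc = 0x1F (31, i.e. level 3.1).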
+ if (strlen(str) != 6u) + return absl::nullopt; + const uint32_t profile_level_id_numeric = strtol(str, nullptr, 16); + if (profile_level_id_numeric == 0) + return absl::nullopt; + + // Separate into three bytes. + const uint8_t level_idc = + static_cast(profile_level_id_numeric & 0xFF); + const uint8_t profile_iop = + static_cast((profile_level_id_numeric >> 8) & 0xFF); + const uint8_t profile_idc = + static_cast((profile_level_id_numeric >> 16) & 0xFF); + + // Parse level based on level_idc and constraint set 3 flag. + H264Level level_casted = static_cast(level_idc); + H264Level level; + + switch (level_casted) { + case H264Level::kLevel1_1: + level = (profile_iop & kConstraintSet3Flag) != 0 ? H264Level::kLevel1_b + : H264Level::kLevel1_1; + break; + case H264Level::kLevel1: + case H264Level::kLevel1_2: + case H264Level::kLevel1_3: + case H264Level::kLevel2: + case H264Level::kLevel2_1: + case H264Level::kLevel2_2: + case H264Level::kLevel3: + case H264Level::kLevel3_1: + case H264Level::kLevel3_2: + case H264Level::kLevel4: + case H264Level::kLevel4_1: + case H264Level::kLevel4_2: + case H264Level::kLevel5: + case H264Level::kLevel5_1: + case H264Level::kLevel5_2: + level = level_casted; + break; + default: + // Unrecognized level_idc. + return absl::nullopt; + } + + // Parse profile_idc/profile_iop into a Profile enum. + for (const ProfilePattern& pattern : kProfilePatterns) { + if (profile_idc == pattern.profile_idc && + pattern.profile_iop.IsMatch(profile_iop)) { + return H264ProfileLevelId(pattern.profile, level); + } + } + + // Unrecognized profile_idc/profile_iop combination. + return absl::nullopt; +} + +absl::optional H264SupportedLevel(int max_frame_pixel_count, + float max_fps) { + static const int kPixelsPerMacroblock = 16 * 16; + + for (int i = arraysize(kLevelConstraints) - 1; i >= 0; --i) { + const LevelConstraint& level_constraint = kLevelConstraints[i]; + if (level_constraint.max_macroblock_frame_size * kPixelsPerMacroblock <= + max_frame_pixel_count && + level_constraint.max_macroblocks_per_second <= + max_fps * level_constraint.max_macroblock_frame_size) { + return level_constraint.level; + } + } + + // No level supported. + return absl::nullopt; +} + +absl::optional ParseSdpForH264ProfileLevelId( + const SdpVideoFormat::Parameters& params) { + // TODO(magjed): The default should really be kProfileBaseline and kLevel1 + // according to the spec: https://tools.ietf.org/html/rfc6184#section-8.1. In + // order to not break backwards compatibility with older versions of WebRTC + // where external codecs don't have any parameters, use + // kProfileConstrainedBaseline kLevel3_1 instead. This workaround will only be + // done in an interim period to allow external clients to update their code. + // http://crbug/webrtc/6337. + static const H264ProfileLevelId kDefaultProfileLevelId( + H264Profile::kProfileConstrainedBaseline, H264Level::kLevel3_1); + + const auto profile_level_id_it = params.find(kProfileLevelId); + return (profile_level_id_it == params.end()) + ? kDefaultProfileLevelId + : ParseH264ProfileLevelId(profile_level_id_it->second.c_str()); +} + +absl::optional H264ProfileLevelIdToString( + const H264ProfileLevelId& profile_level_id) { + // Handle special case level == 1b. 
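+  // Level 1b shares level_idc 11 with level 1.1 and is signaled via the
+  // constraint_set3 flag instead (see kConstraintSet3Flag above), so its
+  // canonical string depends on the profile.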
+ if (profile_level_id.level == H264Level::kLevel1_b) { + switch (profile_level_id.profile) { + case H264Profile::kProfileConstrainedBaseline: + return {"42f00b"}; + case H264Profile::kProfileBaseline: + return {"42100b"}; + case H264Profile::kProfileMain: + return {"4d100b"}; + // Level 1b is not allowed for other profiles. + default: + return absl::nullopt; + } + } + + const char* profile_idc_iop_string; + switch (profile_level_id.profile) { + case H264Profile::kProfileConstrainedBaseline: + profile_idc_iop_string = "42e0"; + break; + case H264Profile::kProfileBaseline: + profile_idc_iop_string = "4200"; + break; + case H264Profile::kProfileMain: + profile_idc_iop_string = "4d00"; + break; + case H264Profile::kProfileConstrainedHigh: + profile_idc_iop_string = "640c"; + break; + case H264Profile::kProfileHigh: + profile_idc_iop_string = "6400"; + break; + // Unrecognized profile. + default: + return absl::nullopt; + } + + char str[7]; + snprintf(str, 7u, "%s%02x", profile_idc_iop_string, profile_level_id.level); + return {str}; +} + +bool H264IsSameProfile(const SdpVideoFormat::Parameters& params1, + const SdpVideoFormat::Parameters& params2) { + const absl::optional profile_level_id = + ParseSdpForH264ProfileLevelId(params1); + const absl::optional other_profile_level_id = + ParseSdpForH264ProfileLevelId(params2); + // Compare H264 profiles, but not levels. + return profile_level_id && other_profile_level_id && + profile_level_id->profile == other_profile_level_id->profile; +} + +} // namespace webrtc diff --git a/api/video_codecs/h264_profile_level_id.h b/api/video_codecs/h264_profile_level_id.h new file mode 100644 index 0000000000..51d025cd7b --- /dev/null +++ b/api/video_codecs/h264_profile_level_id.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef API_VIDEO_CODECS_H264_PROFILE_LEVEL_ID_H_ +#define API_VIDEO_CODECS_H264_PROFILE_LEVEL_ID_H_ + +#include + +#include "absl/types/optional.h" +#include "api/video_codecs/sdp_video_format.h" +#include "rtc_base/system/rtc_export.h" + +namespace webrtc { + +enum class H264Profile { + kProfileConstrainedBaseline, + kProfileBaseline, + kProfileMain, + kProfileConstrainedHigh, + kProfileHigh, +}; + +// All values are equal to ten times the level number, except level 1b which is +// special. +enum class H264Level { + kLevel1_b = 0, + kLevel1 = 10, + kLevel1_1 = 11, + kLevel1_2 = 12, + kLevel1_3 = 13, + kLevel2 = 20, + kLevel2_1 = 21, + kLevel2_2 = 22, + kLevel3 = 30, + kLevel3_1 = 31, + kLevel3_2 = 32, + kLevel4 = 40, + kLevel4_1 = 41, + kLevel4_2 = 42, + kLevel5 = 50, + kLevel5_1 = 51, + kLevel5_2 = 52 +}; + +struct H264ProfileLevelId { + constexpr H264ProfileLevelId(H264Profile profile, H264Level level) + : profile(profile), level(level) {} + H264Profile profile; + H264Level level; +}; + +// Parse profile level id that is represented as a string of 3 hex bytes. +// Nothing will be returned if the string is not a recognized H264 +// profile level id. +absl::optional ParseH264ProfileLevelId(const char* str); + +// Parse profile level id that is represented as a string of 3 hex bytes +// contained in an SDP key-value map. 
A default profile level id will be +// returned if the profile-level-id key is missing. Nothing will be returned if +// the key is present but the string is invalid. +RTC_EXPORT absl::optional ParseSdpForH264ProfileLevelId( + const SdpVideoFormat::Parameters& params); + +// Given that a decoder supports up to a given frame size (in pixels) at up to a +// given number of frames per second, return the highest H.264 level where it +// can guarantee that it will be able to support all valid encoded streams that +// are within that level. +RTC_EXPORT absl::optional H264SupportedLevel( + int max_frame_pixel_count, + float max_fps); + +// Returns canonical string representation as three hex bytes of the profile +// level id, or returns nothing for invalid profile level ids. +RTC_EXPORT absl::optional H264ProfileLevelIdToString( + const H264ProfileLevelId& profile_level_id); + +// Returns true if the parameters have the same H264 profile (Baseline, High, +// etc). +RTC_EXPORT bool H264IsSameProfile(const SdpVideoFormat::Parameters& params1, + const SdpVideoFormat::Parameters& params2); + +} // namespace webrtc + +#endif // API_VIDEO_CODECS_H264_PROFILE_LEVEL_ID_H_ diff --git a/api/video_codecs/sdp_video_format.cc b/api/video_codecs/sdp_video_format.cc index f8901492ee..689c337ced 100644 --- a/api/video_codecs/sdp_video_format.cc +++ b/api/video_codecs/sdp_video_format.cc @@ -10,10 +10,57 @@ #include "api/video_codecs/sdp_video_format.h" +#include "absl/strings/match.h" +#include "api/video_codecs/h264_profile_level_id.h" +#include "api/video_codecs/video_codec.h" +#include "api/video_codecs/vp9_profile.h" +#include "rtc_base/checks.h" #include "rtc_base/strings/string_builder.h" namespace webrtc { +namespace { + +std::string H264GetPacketizationModeOrDefault( + const SdpVideoFormat::Parameters& params) { + constexpr char kH264FmtpPacketizationMode[] = "packetization-mode"; + const auto it = params.find(kH264FmtpPacketizationMode); + if (it != params.end()) { + return it->second; + } + // If packetization-mode is not present, default to "0". + // https://tools.ietf.org/html/rfc6184#section-6.2 + return "0"; +} + +bool H264IsSamePacketizationMode(const SdpVideoFormat::Parameters& left, + const SdpVideoFormat::Parameters& right) { + return H264GetPacketizationModeOrDefault(left) == + H264GetPacketizationModeOrDefault(right); +} + +// Some (video) codecs are actually families of codecs and rely on parameters +// to distinguish different incompatible family members. +bool IsSameCodecSpecific(const SdpVideoFormat& format1, + const SdpVideoFormat& format2) { + // The assumption when calling this function is that the two formats have the + // same name. 
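+  // For example, two H264 formats with the same profile but different levels
+  // are the same codec, while VP9 profile 0 and VP9 profile 2 are not.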
+ RTC_DCHECK(absl::EqualsIgnoreCase(format1.name, format2.name)); + + VideoCodecType codec_type = PayloadStringToCodecType(format1.name); + switch (codec_type) { + case kVideoCodecH264: + return H264IsSameProfile(format1.parameters, format2.parameters) && + H264IsSamePacketizationMode(format1.parameters, + format2.parameters); + case kVideoCodecVP9: + return VP9IsSameProfile(format1.parameters, format2.parameters); + default: + return true; + } +} +} // namespace + SdpVideoFormat::SdpVideoFormat(const std::string& name) : name(name) {} SdpVideoFormat::SdpVideoFormat(const std::string& name, @@ -37,6 +84,23 @@ std::string SdpVideoFormat::ToString() const { return builder.str(); } +bool SdpVideoFormat::IsSameCodec(const SdpVideoFormat& other) const { + // Two codecs are considered the same if the name matches (case insensitive) + // and certain codec-specific parameters match. + return absl::EqualsIgnoreCase(name, other.name) && + IsSameCodecSpecific(*this, other); +} + +bool SdpVideoFormat::IsCodecInList( + rtc::ArrayView formats) const { + for (const auto& format : formats) { + if (IsSameCodec(format)) { + return true; + } + } + return false; +} + bool operator==(const SdpVideoFormat& a, const SdpVideoFormat& b) { return a.name == b.name && a.parameters == b.parameters; } diff --git a/api/video_codecs/sdp_video_format.h b/api/video_codecs/sdp_video_format.h index 97bb75489d..a1e23f4f9c 100644 --- a/api/video_codecs/sdp_video_format.h +++ b/api/video_codecs/sdp_video_format.h @@ -14,6 +14,7 @@ #include #include +#include "api/array_view.h" #include "rtc_base/system/rtc_export.h" namespace webrtc { @@ -32,6 +33,13 @@ struct RTC_EXPORT SdpVideoFormat { ~SdpVideoFormat(); + // Returns true if the SdpVideoFormats have the same names as well as codec + // specific parameters. Please note that two SdpVideoFormats can represent the + // same codec even though not all parameters are the same. + bool IsSameCodec(const SdpVideoFormat& other) const; + bool IsCodecInList( + rtc::ArrayView formats) const; + std::string ToString() const; friend RTC_EXPORT bool operator==(const SdpVideoFormat& a, diff --git a/api/video_codecs/test/BUILD.gn b/api/video_codecs/test/BUILD.gn index cb810fcb8b..c082dbc562 100644 --- a/api/video_codecs/test/BUILD.gn +++ b/api/video_codecs/test/BUILD.gn @@ -13,6 +13,8 @@ if (rtc_include_tests) { testonly = true sources = [ "builtin_video_encoder_factory_unittest.cc", + "h264_profile_level_id_unittest.cc", + "sdp_video_format_unittest.cc", "video_decoder_software_fallback_wrapper_unittest.cc", "video_encoder_software_fallback_wrapper_unittest.cc", ] diff --git a/api/video_codecs/test/h264_profile_level_id_unittest.cc b/api/video_codecs/test/h264_profile_level_id_unittest.cc new file mode 100644 index 0000000000..47098d2682 --- /dev/null +++ b/api/video_codecs/test/h264_profile_level_id_unittest.cc @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "api/video_codecs/h264_profile_level_id.h" + +#include +#include + +#include "absl/types/optional.h" +#include "test/gtest.h" + +namespace webrtc { + +TEST(H264ProfileLevelId, TestParsingInvalid) { + // Malformed strings. 
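+  // ("gggggg" has six characters but is not valid hex, so strtol() returns 0
+  // and parsing fails.)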
+ EXPECT_FALSE(ParseH264ProfileLevelId("")); + EXPECT_FALSE(ParseH264ProfileLevelId(" 42e01f")); + EXPECT_FALSE(ParseH264ProfileLevelId("4242e01f")); + EXPECT_FALSE(ParseH264ProfileLevelId("e01f")); + EXPECT_FALSE(ParseH264ProfileLevelId("gggggg")); + + // Invalid level. + EXPECT_FALSE(ParseH264ProfileLevelId("42e000")); + EXPECT_FALSE(ParseH264ProfileLevelId("42e00f")); + EXPECT_FALSE(ParseH264ProfileLevelId("42e0ff")); + + // Invalid profile. + EXPECT_FALSE(ParseH264ProfileLevelId("42e11f")); + EXPECT_FALSE(ParseH264ProfileLevelId("58601f")); + EXPECT_FALSE(ParseH264ProfileLevelId("64e01f")); +} + +TEST(H264ProfileLevelId, TestParsingLevel) { + EXPECT_EQ(H264Level::kLevel3_1, ParseH264ProfileLevelId("42e01f")->level); + EXPECT_EQ(H264Level::kLevel1_1, ParseH264ProfileLevelId("42e00b")->level); + EXPECT_EQ(H264Level::kLevel1_b, ParseH264ProfileLevelId("42f00b")->level); + EXPECT_EQ(H264Level::kLevel4_2, ParseH264ProfileLevelId("42C02A")->level); + EXPECT_EQ(H264Level::kLevel5_2, ParseH264ProfileLevelId("640c34")->level); +} + +TEST(H264ProfileLevelId, TestParsingConstrainedBaseline) { + EXPECT_EQ(H264Profile::kProfileConstrainedBaseline, + ParseH264ProfileLevelId("42e01f")->profile); + EXPECT_EQ(H264Profile::kProfileConstrainedBaseline, + ParseH264ProfileLevelId("42C02A")->profile); + EXPECT_EQ(H264Profile::kProfileConstrainedBaseline, + ParseH264ProfileLevelId("4de01f")->profile); + EXPECT_EQ(H264Profile::kProfileConstrainedBaseline, + ParseH264ProfileLevelId("58f01f")->profile); +} + +TEST(H264ProfileLevelId, TestParsingBaseline) { + EXPECT_EQ(H264Profile::kProfileBaseline, + ParseH264ProfileLevelId("42a01f")->profile); + EXPECT_EQ(H264Profile::kProfileBaseline, + ParseH264ProfileLevelId("58A01F")->profile); +} + +TEST(H264ProfileLevelId, TestParsingMain) { + EXPECT_EQ(H264Profile::kProfileMain, + ParseH264ProfileLevelId("4D401f")->profile); +} + +TEST(H264ProfileLevelId, TestParsingHigh) { + EXPECT_EQ(H264Profile::kProfileHigh, + ParseH264ProfileLevelId("64001f")->profile); +} + +TEST(H264ProfileLevelId, TestParsingConstrainedHigh) { + EXPECT_EQ(H264Profile::kProfileConstrainedHigh, + ParseH264ProfileLevelId("640c1f")->profile); +} + +TEST(H264ProfileLevelId, TestSupportedLevel) { + EXPECT_EQ(H264Level::kLevel2_1, *H264SupportedLevel(640 * 480, 25)); + EXPECT_EQ(H264Level::kLevel3_1, *H264SupportedLevel(1280 * 720, 30)); + EXPECT_EQ(H264Level::kLevel4_2, *H264SupportedLevel(1920 * 1280, 60)); +} + +// Test supported level below level 1 requirements. +TEST(H264ProfileLevelId, TestSupportedLevelInvalid) { + EXPECT_FALSE(H264SupportedLevel(0, 0)); + // All levels support fps > 5. + EXPECT_FALSE(H264SupportedLevel(1280 * 720, 5)); + // All levels support frame sizes > 183 * 137. 
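+  // (Level 1 already requires 99 macroblocks per frame, i.e. 99 * 16 * 16 =
+  // 25344 pixels, while 183 * 137 = 25071 is just below that.)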
+ EXPECT_FALSE(H264SupportedLevel(183 * 137, 30)); +} + +TEST(H264ProfileLevelId, TestToString) { + EXPECT_EQ("42e01f", *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileConstrainedBaseline, + H264Level::kLevel3_1))); + EXPECT_EQ("42000a", *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileBaseline, H264Level::kLevel1))); + EXPECT_EQ("4d001f", H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileMain, H264Level::kLevel3_1))); + EXPECT_EQ("640c2a", + *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileConstrainedHigh, H264Level::kLevel4_2))); + EXPECT_EQ("64002a", *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileHigh, H264Level::kLevel4_2))); +} + +TEST(H264ProfileLevelId, TestToStringLevel1b) { + EXPECT_EQ("42f00b", *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileConstrainedBaseline, + H264Level::kLevel1_b))); + EXPECT_EQ("42100b", + *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileBaseline, H264Level::kLevel1_b))); + EXPECT_EQ("4d100b", *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileMain, H264Level::kLevel1_b))); +} + +TEST(H264ProfileLevelId, TestToStringRoundTrip) { + EXPECT_EQ("42e01f", + *H264ProfileLevelIdToString(*ParseH264ProfileLevelId("42e01f"))); + EXPECT_EQ("42e01f", + *H264ProfileLevelIdToString(*ParseH264ProfileLevelId("42E01F"))); + EXPECT_EQ("4d100b", + *H264ProfileLevelIdToString(*ParseH264ProfileLevelId("4d100b"))); + EXPECT_EQ("4d100b", + *H264ProfileLevelIdToString(*ParseH264ProfileLevelId("4D100B"))); + EXPECT_EQ("640c2a", + *H264ProfileLevelIdToString(*ParseH264ProfileLevelId("640c2a"))); + EXPECT_EQ("640c2a", + *H264ProfileLevelIdToString(*ParseH264ProfileLevelId("640C2A"))); +} + +TEST(H264ProfileLevelId, TestToStringInvalid) { + EXPECT_FALSE(H264ProfileLevelIdToString( + H264ProfileLevelId(H264Profile::kProfileHigh, H264Level::kLevel1_b))); + EXPECT_FALSE(H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileConstrainedHigh, H264Level::kLevel1_b))); + EXPECT_FALSE(H264ProfileLevelIdToString( + H264ProfileLevelId(static_cast(255), H264Level::kLevel3_1))); +} + +TEST(H264ProfileLevelId, TestParseSdpProfileLevelIdEmpty) { + const absl::optional profile_level_id = + ParseSdpForH264ProfileLevelId(SdpVideoFormat::Parameters()); + EXPECT_TRUE(profile_level_id); + EXPECT_EQ(H264Profile::kProfileConstrainedBaseline, + profile_level_id->profile); + EXPECT_EQ(H264Level::kLevel3_1, profile_level_id->level); +} + +TEST(H264ProfileLevelId, TestParseSdpProfileLevelIdConstrainedHigh) { + SdpVideoFormat::Parameters params; + params["profile-level-id"] = "640c2a"; + const absl::optional profile_level_id = + ParseSdpForH264ProfileLevelId(params); + EXPECT_TRUE(profile_level_id); + EXPECT_EQ(H264Profile::kProfileConstrainedHigh, profile_level_id->profile); + EXPECT_EQ(H264Level::kLevel4_2, profile_level_id->level); +} + +TEST(H264ProfileLevelId, TestParseSdpProfileLevelIdInvalid) { + SdpVideoFormat::Parameters params; + params["profile-level-id"] = "foobar"; + EXPECT_FALSE(ParseSdpForH264ProfileLevelId(params)); +} + +} // namespace webrtc diff --git a/api/video_codecs/test/sdp_video_format_unittest.cc b/api/video_codecs/test/sdp_video_format_unittest.cc new file mode 100644 index 0000000000..d55816690e --- /dev/null +++ b/api/video_codecs/test/sdp_video_format_unittest.cc @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "api/video_codecs/sdp_video_format.h" + +#include + +#include "test/gtest.h" + +namespace webrtc { + +typedef SdpVideoFormat Sdp; +typedef SdpVideoFormat::Parameters Params; + +TEST(SdpVideoFormatTest, SameCodecNameNoParameters) { + EXPECT_TRUE(Sdp("H264").IsSameCodec(Sdp("h264"))); + EXPECT_TRUE(Sdp("VP8").IsSameCodec(Sdp("vp8"))); + EXPECT_TRUE(Sdp("Vp9").IsSameCodec(Sdp("vp9"))); + EXPECT_TRUE(Sdp("AV1").IsSameCodec(Sdp("Av1"))); +} +TEST(SdpVideoFormatTest, DifferentCodecNameNoParameters) { + EXPECT_FALSE(Sdp("H264").IsSameCodec(Sdp("VP8"))); + EXPECT_FALSE(Sdp("VP8").IsSameCodec(Sdp("VP9"))); + EXPECT_FALSE(Sdp("AV1").IsSameCodec(Sdp(""))); +} +TEST(SdpVideoFormatTest, SameCodecNameSameParameters) { + EXPECT_TRUE(Sdp("VP9").IsSameCodec(Sdp("VP9", Params{{"profile-id", "0"}}))); + EXPECT_TRUE(Sdp("VP9", Params{{"profile-id", "0"}}) + .IsSameCodec(Sdp("VP9", Params{{"profile-id", "0"}}))); + EXPECT_TRUE(Sdp("VP9", Params{{"profile-id", "2"}}) + .IsSameCodec(Sdp("VP9", Params{{"profile-id", "2"}}))); + EXPECT_TRUE( + Sdp("H264", Params{{"profile-level-id", "42e01f"}}) + .IsSameCodec(Sdp("H264", Params{{"profile-level-id", "42e01f"}}))); + EXPECT_TRUE( + Sdp("H264", Params{{"profile-level-id", "640c34"}}) + .IsSameCodec(Sdp("H264", Params{{"profile-level-id", "640c34"}}))); +} + +TEST(SdpVideoFormatTest, SameCodecNameDifferentParameters) { + EXPECT_FALSE(Sdp("VP9").IsSameCodec(Sdp("VP9", Params{{"profile-id", "2"}}))); + EXPECT_FALSE(Sdp("VP9", Params{{"profile-id", "0"}}) + .IsSameCodec(Sdp("VP9", Params{{"profile-id", "1"}}))); + EXPECT_FALSE(Sdp("VP9", Params{{"profile-id", "2"}}) + .IsSameCodec(Sdp("VP9", Params{{"profile-id", "0"}}))); + EXPECT_FALSE( + Sdp("H264", Params{{"profile-level-id", "42e01f"}}) + .IsSameCodec(Sdp("H264", Params{{"profile-level-id", "640c34"}}))); + EXPECT_FALSE( + Sdp("H264", Params{{"profile-level-id", "640c34"}}) + .IsSameCodec(Sdp("H264", Params{{"profile-level-id", "42f00b"}}))); +} + +TEST(SdpVideoFormatTest, DifferentCodecNameSameParameters) { + EXPECT_FALSE(Sdp("VP9", Params{{"profile-id", "0"}}) + .IsSameCodec(Sdp("H264", Params{{"profile-id", "0"}}))); + EXPECT_FALSE(Sdp("VP9", Params{{"profile-id", "2"}}) + .IsSameCodec(Sdp("VP8", Params{{"profile-id", "2"}}))); + EXPECT_FALSE( + Sdp("H264", Params{{"profile-level-id", "42e01f"}}) + .IsSameCodec(Sdp("VP9", Params{{"profile-level-id", "42e01f"}}))); + EXPECT_FALSE( + Sdp("H264", Params{{"profile-level-id", "640c34"}}) + .IsSameCodec(Sdp("VP8", Params{{"profile-level-id", "640c34"}}))); +} + +} // namespace webrtc diff --git a/api/video_codecs/test/video_encoder_software_fallback_wrapper_unittest.cc b/api/video_codecs/test/video_encoder_software_fallback_wrapper_unittest.cc index 5c5a25cc89..2d8b002f2d 100644 --- a/api/video_codecs/test/video_encoder_software_fallback_wrapper_unittest.cc +++ b/api/video_codecs/test/video_encoder_software_fallback_wrapper_unittest.cc @@ -613,13 +613,13 @@ TEST_F(ForcedFallbackTestEnabled, FallbackIsEndedForNonValidSettings) { EncodeFrameAndVerifyLastName("libvpx"); // Re-initialize encoder with invalid setting, expect no fallback. 
- codec_.VP8()->numberOfTemporalLayers = 2; + codec_.numberOfSimulcastStreams = 2; InitEncode(kWidth, kHeight); EXPECT_EQ(1, fake_encoder_->init_encode_count_); EncodeFrameAndVerifyLastName("fake-encoder"); // Re-initialize encoder with valid setting. - codec_.VP8()->numberOfTemporalLayers = 1; + codec_.numberOfSimulcastStreams = 1; InitEncode(kWidth, kHeight); EXPECT_EQ(1, fake_encoder_->init_encode_count_); EncodeFrameAndVerifyLastName("libvpx"); diff --git a/api/video_codecs/video_decoder.cc b/api/video_codecs/video_decoder.cc index fee3ec6d42..04673e6c31 100644 --- a/api/video_codecs/video_decoder.cc +++ b/api/video_codecs/video_decoder.cc @@ -32,10 +32,6 @@ VideoDecoder::DecoderInfo VideoDecoder::GetDecoderInfo() const { return info; } -bool VideoDecoder::PrefersLateDecoding() const { - return true; -} - const char* VideoDecoder::ImplementationName() const { return "unknown"; } diff --git a/api/video_codecs/video_decoder.h b/api/video_codecs/video_decoder.h index a6af3f22e9..04052de08b 100644 --- a/api/video_codecs/video_decoder.h +++ b/api/video_codecs/video_decoder.h @@ -70,12 +70,6 @@ class RTC_EXPORT VideoDecoder { virtual DecoderInfo GetDecoderInfo() const; - // Deprecated, use GetDecoderInfo().prefers_late_decoding instead. - // Returns true if the decoder prefer to decode frames late. - // That is, it can not decode infinite number of frames before the decoded - // frame is consumed. - // TODO(bugs.webrtc.org/12271): Remove when downstream has been updated. - virtual bool PrefersLateDecoding() const; // Deprecated, use GetDecoderInfo().implementation_name instead. virtual const char* ImplementationName() const; }; diff --git a/api/video_codecs/video_decoder_factory.h b/api/video_codecs/video_decoder_factory.h index e4d83c2465..0b6ea4f9f2 100644 --- a/api/video_codecs/video_decoder_factory.h +++ b/api/video_codecs/video_decoder_factory.h @@ -15,31 +15,51 @@ #include #include +#include "absl/types/optional.h" +#include "api/video_codecs/sdp_video_format.h" #include "rtc_base/system/rtc_export.h" namespace webrtc { class VideoDecoder; -struct SdpVideoFormat; // A factory that creates VideoDecoders. // NOTE: This class is still under development and may change without notice. class RTC_EXPORT VideoDecoderFactory { public: + struct CodecSupport { + bool is_supported = false; + bool is_power_efficient = false; + }; + // Returns a list of supported video formats in order of preference, to use // for signaling etc. virtual std::vector GetSupportedFormats() const = 0; + // Query whether the specifed format is supported or not and if it will be + // power efficient, which is currently interpreted as if there is support for + // hardware acceleration. + // See https://w3c.github.io/webrtc-svc/#scalabilitymodes* for a specification + // of valid values for |scalability_mode|. + // NOTE: QueryCodecSupport is currently an experimental feature that is + // subject to change without notice. + virtual CodecSupport QueryCodecSupport( + const SdpVideoFormat& format, + absl::optional scalability_mode) const { + // Default implementation, query for supported formats and check if the + // specified format is supported. Returns false if scalability_mode is + // specified. + CodecSupport codec_support; + if (!scalability_mode) { + codec_support.is_supported = format.IsCodecInList(GetSupportedFormats()); + } + return codec_support; + } + // Creates a VideoDecoder for the specified format. 
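
// Usage sketch for QueryCodecSupport() above (illustrative; `factory` is an
// assumed application-provided VideoDecoderFactory*):
//
//   VideoDecoderFactory::CodecSupport h264 = factory->QueryCodecSupport(
//       SdpVideoFormat("H264", {{"profile-level-id", "42e01f"},
//                               {"packetization-mode", "1"}}),
//       /*scalability_mode=*/absl::nullopt);
//   VideoDecoderFactory::CodecSupport vp9_l1t3 = factory->QueryCodecSupport(
//       SdpVideoFormat("VP9", {{"profile-id", "0"}}), "L1T3");
//
// With the default implementation, h264.is_supported reflects
// IsCodecInList(GetSupportedFormats()) while vp9_l1t3.is_supported stays
// false because a scalability mode was requested.
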
virtual std::unique_ptr CreateVideoDecoder( const SdpVideoFormat& format) = 0; - // Note: Do not call or override this method! This method is a legacy - // workaround and is scheduled for removal without notice. - virtual std::unique_ptr LegacyCreateVideoDecoder( - const SdpVideoFormat& format, - const std::string& receive_stream_id); - virtual ~VideoDecoderFactory() {} }; diff --git a/api/video_codecs/video_encoder.cc b/api/video_codecs/video_encoder.cc index 486200bc82..a7e9d7487c 100644 --- a/api/video_codecs/video_encoder.cc +++ b/api/video_codecs/video_encoder.cc @@ -135,8 +135,17 @@ std::string VideoEncoder::EncoderInfo::ToString() const { << ", is_hardware_accelerated = " << is_hardware_accelerated << ", has_internal_source = " << has_internal_source << ", fps_allocation = ["; + size_t num_spatial_layer_with_fps_allocation = 0; + for (size_t i = 0; i < kMaxSpatialLayers; ++i) { + if (!fps_allocation[i].empty()) { + num_spatial_layer_with_fps_allocation = i + 1; + } + } bool first = true; - for (size_t i = 0; i < fps_allocation->size(); ++i) { + for (size_t i = 0; i < num_spatial_layer_with_fps_allocation; ++i) { + if (fps_allocation[i].empty()) { + break; + } if (!first) { oss << ", "; } diff --git a/api/video_codecs/video_encoder.h b/api/video_codecs/video_encoder.h index a030362ab7..caf069718b 100644 --- a/api/video_codecs/video_encoder.h +++ b/api/video_codecs/video_encoder.h @@ -364,7 +364,7 @@ class RTC_EXPORT VideoEncoder { // TODO(bugs.webrtc.org/10720): After updating downstream projects and posting // an announcement to discuss-webrtc, remove the three-parameters variant // and make the two-parameters variant pure-virtual. - /* RTC_DEPRECATED */ virtual int32_t InitEncode( + /* ABSL_DEPRECATED("bugs.webrtc.org/10720") */ virtual int32_t InitEncode( const VideoCodec* codec_settings, int32_t number_of_cores, size_t max_payload_size); diff --git a/api/video_codecs/video_encoder_config.cc b/api/video_codecs/video_encoder_config.cc index 5956d60365..0321da24da 100644 --- a/api/video_codecs/video_encoder_config.cc +++ b/api/video_codecs/video_encoder_config.cc @@ -57,7 +57,8 @@ VideoEncoderConfig::VideoEncoderConfig() max_bitrate_bps(0), bitrate_priority(1.0), number_of_streams(0), - legacy_conference_mode(false) {} + legacy_conference_mode(false), + is_quality_scaling_allowed(false) {} VideoEncoderConfig::VideoEncoderConfig(VideoEncoderConfig&&) = default; diff --git a/api/video_codecs/video_encoder_config.h b/api/video_codecs/video_encoder_config.h index 1a061f52f7..59163743a2 100644 --- a/api/video_codecs/video_encoder_config.h +++ b/api/video_codecs/video_encoder_config.h @@ -181,6 +181,9 @@ class VideoEncoderConfig { // Legacy Google conference mode flag for simulcast screenshare bool legacy_conference_mode; + // Indicates whether quality scaling can be used or not. + bool is_quality_scaling_allowed; + private: // Access to the copy constructor is private to force use of the Copy() // method for those exceptional cases where we do use it. 
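
The fps_allocation change in video_encoder.cc above assumes encoders leave
unused spatial-layer slots empty. A minimal sketch of how an encoder is
expected to populate EncoderInfo::fps_allocation, assuming
EncoderInfo::kMaxFramerateFraction (255) denotes the full configured framerate
as in the existing header; illustrative only, not part of this change:

VideoEncoder::EncoderInfo info;
// Two active spatial layers with three temporal layers each. Entries are
// cumulative framerate fractions per temporal layer, 255 == 100%.
for (int sid = 0; sid < 2; ++sid) {
  info.fps_allocation[sid] = {
      VideoEncoder::EncoderInfo::kMaxFramerateFraction / 4,  // TL0 only.
      VideoEncoder::EncoderInfo::kMaxFramerateFraction / 2,  // TL0 + TL1.
      VideoEncoder::EncoderInfo::kMaxFramerateFraction};     // All layers.
}
// The remaining spatial-layer slots stay empty, which is what the updated
// ToString() uses to decide how many layers to print.
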
diff --git a/api/video_codecs/video_encoder_factory.h b/api/video_codecs/video_encoder_factory.h index 22430eb19d..c2d66cfa86 100644 --- a/api/video_codecs/video_encoder_factory.h +++ b/api/video_codecs/video_encoder_factory.h @@ -12,6 +12,7 @@ #define API_VIDEO_CODECS_VIDEO_ENCODER_FACTORY_H_ #include +#include #include #include "absl/types/optional.h" @@ -36,6 +37,11 @@ class VideoEncoderFactory { bool has_internal_source = false; }; + struct CodecSupport { + bool is_supported = false; + bool is_power_efficient = false; + }; + // An injectable class that is continuously updated with encoding conditions // and selects the best encoder given those conditions. class EncoderSelectorInterface { @@ -78,6 +84,26 @@ class VideoEncoderFactory { return CodecInfo(); } + // Query whether the specifed format is supported or not and if it will be + // power efficient, which is currently interpreted as if there is support for + // hardware acceleration. + // See https://w3c.github.io/webrtc-svc/#scalabilitymodes* for a specification + // of valid values for |scalability_mode|. + // NOTE: QueryCodecSupport is currently an experimental feature that is + // subject to change without notice. + virtual CodecSupport QueryCodecSupport( + const SdpVideoFormat& format, + absl::optional scalability_mode) const { + // Default implementation, query for supported formats and check if the + // specified format is supported. Returns false if scalability_mode is + // specified. + CodecSupport codec_support; + if (!scalability_mode) { + codec_support.is_supported = format.IsCodecInList(GetSupportedFormats()); + } + return codec_support; + } + // Creates a VideoEncoder for the specified format. virtual std::unique_ptr CreateVideoEncoder( const SdpVideoFormat& format) = 0; diff --git a/api/video_codecs/video_encoder_software_fallback_wrapper.cc b/api/video_codecs/video_encoder_software_fallback_wrapper.cc index 94a18171a1..bcce9dcd93 100644 --- a/api/video_codecs/video_encoder_software_fallback_wrapper.cc +++ b/api/video_codecs/video_encoder_software_fallback_wrapper.cc @@ -25,6 +25,7 @@ #include "api/video/video_frame.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_encoder.h" +#include "media/base/video_common.h" #include "modules/video_coding/include/video_error_codes.h" #include "modules/video_coding/utility/simulcast_utility.h" #include "rtc_base/checks.h" @@ -50,7 +51,6 @@ struct ForcedFallbackParams { return enable_resolution_based_switch && codec.codecType == kVideoCodecVP8 && codec.numberOfSimulcastStreams <= 1 && - codec.VP8().numberOfTemporalLayers == 1 && codec.width * codec.height <= max_pixels; } @@ -418,6 +418,13 @@ VideoEncoder::EncoderInfo VideoEncoderSoftwareFallbackWrapper::GetEncoderInfo() EncoderInfo info = IsFallbackActive() ? fallback_encoder_info : default_encoder_info; + info.requested_resolution_alignment = cricket::LeastCommonMultiple( + fallback_encoder_info.requested_resolution_alignment, + default_encoder_info.requested_resolution_alignment); + info.apply_alignment_to_all_simulcast_layers = + fallback_encoder_info.apply_alignment_to_all_simulcast_layers || + default_encoder_info.apply_alignment_to_all_simulcast_layers; + if (fallback_params_.has_value()) { const auto settings = (encoder_state_ == EncoderState::kForcedFallback) ? 
fallback_encoder_info.scaling_settings diff --git a/media/base/vp9_profile.cc b/api/video_codecs/vp9_profile.cc similarity index 90% rename from media/base/vp9_profile.cc rename to api/video_codecs/vp9_profile.cc index abf2502fc8..5e2bd53a86 100644 --- a/media/base/vp9_profile.cc +++ b/api/video_codecs/vp9_profile.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source @@ -8,7 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "media/base/vp9_profile.h" +#include "api/video_codecs/vp9_profile.h" #include #include @@ -47,7 +47,6 @@ absl::optional StringToVP9Profile(const std::string& str) { default: return absl::nullopt; } - return absl::nullopt; } absl::optional ParseSdpForVP9Profile( @@ -59,7 +58,7 @@ absl::optional ParseSdpForVP9Profile( return StringToVP9Profile(profile_str); } -bool IsSameVP9Profile(const SdpVideoFormat::Parameters& params1, +bool VP9IsSameProfile(const SdpVideoFormat::Parameters& params1, const SdpVideoFormat::Parameters& params2) { const absl::optional profile = ParseSdpForVP9Profile(params1); const absl::optional other_profile = diff --git a/api/video_codecs/vp9_profile.h b/api/video_codecs/vp9_profile.h new file mode 100644 index 0000000000..e632df437b --- /dev/null +++ b/api/video_codecs/vp9_profile.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef API_VIDEO_CODECS_VP9_PROFILE_H_ +#define API_VIDEO_CODECS_VP9_PROFILE_H_ + +#include + +#include "absl/types/optional.h" +#include "api/video_codecs/sdp_video_format.h" +#include "rtc_base/system/rtc_export.h" + +namespace webrtc { + +// Profile information for VP9 video. +extern RTC_EXPORT const char kVP9FmtpProfileId[]; + +enum class VP9Profile { + kProfile0, + kProfile1, + kProfile2, +}; + +// Helper functions to convert VP9Profile to std::string. Returns "0" by +// default. +RTC_EXPORT std::string VP9ProfileToString(VP9Profile profile); + +// Helper functions to convert std::string to VP9Profile. Returns null if given +// an invalid profile string. +absl::optional StringToVP9Profile(const std::string& str); + +// Parse profile that is represented as a string of single digit contained in an +// SDP key-value map. A default profile(kProfile0) will be returned if the +// profile key is missing. Nothing will be returned if the key is present but +// the string is invalid. +RTC_EXPORT absl::optional ParseSdpForVP9Profile( + const SdpVideoFormat::Parameters& params); + +// Returns true if the parameters have the same VP9 profile, or neither contains +// VP9 profile. 
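
// Usage sketch (illustrative): SDP negotiation code typically extracts the
// profile from the fmtp parameters and compares offer and answer, e.g.:
//
//   SdpVideoFormat::Parameters offered = {{kVP9FmtpProfileId, "2"}};
//   SdpVideoFormat::Parameters answered;  // No profile-id -> profile 0.
//   absl::optional<VP9Profile> profile = ParseSdpForVP9Profile(offered);
//   bool same = VP9IsSameProfile(offered, answered);  // false: 2 vs 0.
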
+bool VP9IsSameProfile(const SdpVideoFormat::Parameters& params1, + const SdpVideoFormat::Parameters& params2); + +} // namespace webrtc + +#endif // API_VIDEO_CODECS_VP9_PROFILE_H_ diff --git a/api/video_track_source_proxy_factory.h b/api/video_track_source_proxy_factory.h new file mode 100644 index 0000000000..974720d50b --- /dev/null +++ b/api/video_track_source_proxy_factory.h @@ -0,0 +1,28 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef API_VIDEO_TRACK_SOURCE_PROXY_FACTORY_H_ +#define API_VIDEO_TRACK_SOURCE_PROXY_FACTORY_H_ + +#include "api/media_stream_interface.h" + +namespace webrtc { + +// Creates a proxy source for |source| which makes sure the real +// VideoTrackSourceInterface implementation is destroyed on the signaling thread +// and marshals calls to |worker_thread| and |signaling_thread|. +rtc::scoped_refptr RTC_EXPORT +CreateVideoTrackSourceProxy(rtc::Thread* signaling_thread, + rtc::Thread* worker_thread, + VideoTrackSourceInterface* source); + +} // namespace webrtc + +#endif // API_VIDEO_TRACK_SOURCE_PROXY_FACTORY_H_ diff --git a/api/voip/BUILD.gn b/api/voip/BUILD.gn index 4db59fd98c..714490a526 100644 --- a/api/voip/BUILD.gn +++ b/api/voip/BUILD.gn @@ -21,11 +21,13 @@ rtc_source_set("voip_api") { ] deps = [ "..:array_view", - "../../rtc_base/system:unused", "../audio_codecs:audio_codecs_api", "../neteq:neteq_api", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/base:core_headers", + "//third_party/abseil-cpp/absl/types:optional", + ] } rtc_library("voip_engine_factory") { @@ -47,9 +49,21 @@ rtc_library("voip_engine_factory") { } if (rtc_include_tests) { + rtc_source_set("mock_voip_engine") { + testonly = true + visibility = [ "*" ] + sources = [ "test/mock_voip_engine.h" ] + deps = [ + ":voip_api", + "..:array_view", + "../../test:test_support", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + } + rtc_library("voip_engine_factory_unittests") { testonly = true - sources = [ "voip_engine_factory_unittest.cc" ] + sources = [ "test/voip_engine_factory_unittest.cc" ] deps = [ ":voip_engine_factory", "../../modules/audio_device:mock_audio_device", @@ -59,4 +73,13 @@ if (rtc_include_tests) { "../task_queue:default_task_queue_factory", ] } + + rtc_library("compile_all_headers") { + testonly = true + sources = [ "test/compile_all_headers.cc" ] + deps = [ + ":mock_voip_engine", + "../../test:test_support", + ] + } } diff --git a/api/voip/DEPS b/api/voip/DEPS index 837b9a673e..3845dffab0 100644 --- a/api/voip/DEPS +++ b/api/voip/DEPS @@ -3,10 +3,6 @@ specific_include_rules = { "+third_party/absl/types/optional.h", ], - "voip_base.h": [ - "+rtc_base/system/unused.h", - ], - "voip_engine_factory.h": [ "+modules/audio_device/include/audio_device.h", "+modules/audio_processing/include/audio_processing.h", diff --git a/api/proxy.cc b/api/voip/test/compile_all_headers.cc similarity index 58% rename from api/proxy.cc rename to api/voip/test/compile_all_headers.cc index 67318e7dab..73a0f0d1c4 100644 --- a/api/proxy.cc +++ b/api/voip/test/compile_all_headers.cc @@ -1,5 +1,5 @@ /* - * Copyright 2017 The WebRTC project authors. 
All Rights Reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source @@ -8,5 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "api/proxy.h" +// This file verifies that all include files in this directory can be +// compiled without errors or other required includes. +#include "api/voip/test/mock_voip_engine.h" diff --git a/api/voip/test/mock_voip_engine.h b/api/voip/test/mock_voip_engine.h new file mode 100644 index 0000000000..74b880d652 --- /dev/null +++ b/api/voip/test/mock_voip_engine.h @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef API_VOIP_TEST_MOCK_VOIP_ENGINE_H_ +#define API_VOIP_TEST_MOCK_VOIP_ENGINE_H_ + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/voip/voip_base.h" +#include "api/voip/voip_codec.h" +#include "api/voip/voip_dtmf.h" +#include "api/voip/voip_engine.h" +#include "api/voip/voip_network.h" +#include "api/voip/voip_statistics.h" +#include "api/voip/voip_volume_control.h" +#include "test/gmock.h" + +namespace webrtc { + +class MockVoipBase : public VoipBase { + public: + MOCK_METHOD(ChannelId, + CreateChannel, + (Transport*, absl::optional), + (override)); + MOCK_METHOD(VoipResult, ReleaseChannel, (ChannelId), (override)); + MOCK_METHOD(VoipResult, StartSend, (ChannelId), (override)); + MOCK_METHOD(VoipResult, StopSend, (ChannelId), (override)); + MOCK_METHOD(VoipResult, StartPlayout, (ChannelId), (override)); + MOCK_METHOD(VoipResult, StopPlayout, (ChannelId), (override)); +}; + +class MockVoipCodec : public VoipCodec { + public: + MOCK_METHOD(VoipResult, + SetSendCodec, + (ChannelId, int, const SdpAudioFormat&), + (override)); + MOCK_METHOD(VoipResult, + SetReceiveCodecs, + (ChannelId, (const std::map&)), + (override)); +}; + +class MockVoipDtmf : public VoipDtmf { + public: + MOCK_METHOD(VoipResult, + RegisterTelephoneEventType, + (ChannelId, int, int), + (override)); + MOCK_METHOD(VoipResult, + SendDtmfEvent, + (ChannelId, DtmfEvent, int), + (override)); +}; + +class MockVoipNetwork : public VoipNetwork { + public: + MOCK_METHOD(VoipResult, + ReceivedRTPPacket, + (ChannelId channel_id, rtc::ArrayView rtp_packet), + (override)); + MOCK_METHOD(VoipResult, + ReceivedRTCPPacket, + (ChannelId channel_id, rtc::ArrayView rtcp_packet), + (override)); +}; + +class MockVoipStatistics : public VoipStatistics { + public: + MOCK_METHOD(VoipResult, + GetIngressStatistics, + (ChannelId, IngressStatistics&), + (override)); + MOCK_METHOD(VoipResult, + GetChannelStatistics, + (ChannelId channel_id, ChannelStatistics&), + (override)); +}; + +class MockVoipVolumeControl : public VoipVolumeControl { + public: + MOCK_METHOD(VoipResult, SetInputMuted, (ChannelId, bool), (override)); + + MOCK_METHOD(VoipResult, + GetInputVolumeInfo, + (ChannelId, VolumeInfo&), + (override)); + MOCK_METHOD(VoipResult, + GetOutputVolumeInfo, + (ChannelId, VolumeInfo&), + (override)); +}; + +class MockVoipEngine : public VoipEngine { + public: + VoipBase& Base() override { return base_; } + 
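
// Usage sketch (illustrative): tests can hand a MockVoipEngine to code that
// expects a VoipEngine& and set expectations on the exposed mock members,
// e.g. (`ComponentUnderTest` is hypothetical):
//
//   MockVoipEngine voip;
//   EXPECT_CALL(voip.base_, StartSend(testing::_))
//       .WillOnce(testing::Return(VoipResult::kOk));
//   ComponentUnderTest component(voip);
//   component.Start();
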
VoipNetwork& Network() override { return network_; } + VoipCodec& Codec() override { return codec_; } + VoipDtmf& Dtmf() override { return dtmf_; } + VoipStatistics& Statistics() override { return statistics_; } + VoipVolumeControl& VolumeControl() override { return volume_; } + + // Direct access to underlying members are required for testing. + MockVoipBase base_; + MockVoipNetwork network_; + MockVoipCodec codec_; + MockVoipDtmf dtmf_; + MockVoipStatistics statistics_; + MockVoipVolumeControl volume_; +}; + +} // namespace webrtc + +#endif // API_VOIP_TEST_MOCK_VOIP_ENGINE_H_ diff --git a/api/voip/voip_engine_factory_unittest.cc b/api/voip/test/voip_engine_factory_unittest.cc similarity index 80% rename from api/voip/voip_engine_factory_unittest.cc rename to api/voip/test/voip_engine_factory_unittest.cc index 84b474f3b8..f967a0ba8f 100644 --- a/api/voip/voip_engine_factory_unittest.cc +++ b/api/voip/test/voip_engine_factory_unittest.cc @@ -24,11 +24,11 @@ namespace { // Create voip engine with mock modules as normal use case. TEST(VoipEngineFactoryTest, CreateEngineWithMockModules) { VoipEngineConfig config; - config.encoder_factory = new rtc::RefCountedObject(); - config.decoder_factory = new rtc::RefCountedObject(); + config.encoder_factory = rtc::make_ref_counted(); + config.decoder_factory = rtc::make_ref_counted(); config.task_queue_factory = CreateDefaultTaskQueueFactory(); config.audio_processing = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); config.audio_device_module = test::MockAudioDeviceModule::CreateNice(); auto voip_engine = CreateVoipEngine(std::move(config)); @@ -38,8 +38,8 @@ TEST(VoipEngineFactoryTest, CreateEngineWithMockModules) { // Create voip engine without setting audio processing as optional component. TEST(VoipEngineFactoryTest, UseNoAudioProcessing) { VoipEngineConfig config; - config.encoder_factory = new rtc::RefCountedObject(); - config.decoder_factory = new rtc::RefCountedObject(); + config.encoder_factory = rtc::make_ref_counted(); + config.decoder_factory = rtc::make_ref_counted(); config.task_queue_factory = CreateDefaultTaskQueueFactory(); config.audio_device_module = test::MockAudioDeviceModule::CreateNice(); diff --git a/api/voip/voip_base.h b/api/voip/voip_base.h index 6a411f8d88..d469ea4bd4 100644 --- a/api/voip/voip_base.h +++ b/api/voip/voip_base.h @@ -11,8 +11,8 @@ #ifndef API_VOIP_VOIP_BASE_H_ #define API_VOIP_VOIP_BASE_H_ +#include "absl/base/attributes.h" #include "absl/types/optional.h" -#include "rtc_base/system/unused.h" namespace webrtc { @@ -36,7 +36,7 @@ class Transport; enum class ChannelId : int {}; -enum class RTC_WARN_UNUSED_RESULT VoipResult { +enum class ABSL_MUST_USE_RESULT VoipResult { // kOk indicates the function was successfully invoked with no error. 
kOk, // kInvalidArgument indicates the caller specified an invalid argument, such diff --git a/audio/BUILD.gn b/audio/BUILD.gn index 6901e33673..200f9f4038 100644 --- a/audio/BUILD.gn +++ b/audio/BUILD.gn @@ -47,6 +47,7 @@ rtc_library("audio") { "../api:rtp_headers", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api:transport_api", "../api/audio:aec3_factory", "../api/audio:audio_frame_api", @@ -90,10 +91,11 @@ rtc_library("audio") { "../rtc_base:rtc_base_approved", "../rtc_base:rtc_task_queue", "../rtc_base:safe_minmax", + "../rtc_base:threading", "../rtc_base/experiments:field_trial_parser", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", + "../rtc_base/task_utils:pending_task_safety_flag", "../rtc_base/task_utils:to_queued_task", "../system_wrappers", "../system_wrappers:field_trial", @@ -138,6 +140,7 @@ if (rtc_include_tests) { "mock_voe_channel_proxy.h", "remix_resample_unittest.cc", "test/audio_stats_test.cc", + "test/nack_test.cc", ] deps = [ ":audio", @@ -150,6 +153,7 @@ if (rtc_include_tests) { "../api/audio_codecs:audio_codecs_api", "../api/audio_codecs/opus:audio_decoder_opus", "../api/audio_codecs/opus:audio_encoder_opus", + "../api/crypto:frame_decryptor_interface", "../api/rtc_event_log", "../api/task_queue:default_task_queue_factory", "../api/units:time_delta", @@ -191,7 +195,7 @@ if (rtc_include_tests) { ] } - if (rtc_enable_protobuf) { + if (rtc_enable_protobuf && !build_with_chromium) { rtc_test("low_bandwidth_audio_test") { testonly = true @@ -219,8 +223,8 @@ if (rtc_include_tests) { "../test:test_support", "../test/pc/e2e:network_quality_metrics_reporter", "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", ] + absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag" ] if (is_android) { deps += [ "//testing/android/native_test:native_test_native_code" ] } @@ -278,30 +282,32 @@ if (rtc_include_tests) { } } - rtc_library("audio_perf_tests") { - testonly = true + if (!build_with_chromium) { + rtc_library("audio_perf_tests") { + testonly = true - sources = [ - "test/audio_bwe_integration_test.cc", - "test/audio_bwe_integration_test.h", - ] - deps = [ - "../api:simulated_network_api", - "../api/task_queue", - "../call:fake_network", - "../call:simulated_network", - "../common_audio", - "../rtc_base:rtc_base_approved", - "../rtc_base:task_queue_for_test", - "../system_wrappers", - "../test:field_trial", - "../test:fileutils", - "../test:test_common", - "../test:test_main", - "../test:test_support", - "//testing/gtest", - ] + sources = [ + "test/audio_bwe_integration_test.cc", + "test/audio_bwe_integration_test.h", + ] + deps = [ + "../api:simulated_network_api", + "../api/task_queue", + "../call:fake_network", + "../call:simulated_network", + "../common_audio", + "../rtc_base:rtc_base_approved", + "../rtc_base:task_queue_for_test", + "../system_wrappers", + "../test:field_trial", + "../test:fileutils", + "../test:test_common", + "../test:test_main", + "../test:test_support", + "//testing/gtest", + ] - data = [ "//resources/voice_engine/audio_dtx16.wav" ] + data = [ "//resources/voice_engine/audio_dtx16.wav" ] + } } } diff --git a/audio/audio_receive_stream.cc b/audio/audio_receive_stream.cc index 54c8a02976..f243fa67db 100644 --- a/audio/audio_receive_stream.cc +++ b/audio/audio_receive_stream.cc @@ -18,12 +18,14 @@ #include "api/audio_codecs/audio_format.h" #include "api/call/audio_sink.h" #include "api/rtp_parameters.h" +#include 
"api/sequence_checker.h" #include "audio/audio_send_stream.h" #include "audio/audio_state.h" #include "audio/channel_receive.h" #include "audio/conversion.h" #include "call/rtp_config.h" #include "call/rtp_stream_receiver_controller_interface.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" @@ -68,7 +70,6 @@ namespace { std::unique_ptr CreateChannelReceive( Clock* clock, webrtc::AudioState* audio_state, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, const webrtc::AudioReceiveStream::Config& config, RtcEventLog* event_log) { @@ -76,67 +77,70 @@ std::unique_ptr CreateChannelReceive( internal::AudioState* internal_audio_state = static_cast(audio_state); return voe::CreateChannelReceive( - clock, module_process_thread, neteq_factory, - internal_audio_state->audio_device_module(), config.rtcp_send_transport, - event_log, config.rtp.local_ssrc, config.rtp.remote_ssrc, - config.jitter_buffer_max_packets, config.jitter_buffer_fast_accelerate, - config.jitter_buffer_min_delay_ms, + clock, neteq_factory, internal_audio_state->audio_device_module(), + config.rtcp_send_transport, event_log, config.rtp.local_ssrc, + config.rtp.remote_ssrc, config.jitter_buffer_max_packets, + config.jitter_buffer_fast_accelerate, config.jitter_buffer_min_delay_ms, config.jitter_buffer_enable_rtx_handling, config.decoder_factory, - config.codec_pair_id, config.frame_decryptor, config.crypto_options, - std::move(config.frame_transformer)); + config.codec_pair_id, std::move(config.frame_decryptor), + config.crypto_options, std::move(config.frame_transformer)); } } // namespace AudioReceiveStream::AudioReceiveStream( Clock* clock, - RtpStreamReceiverControllerInterface* receiver_controller, PacketRouter* packet_router, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, const webrtc::AudioReceiveStream::Config& config, const rtc::scoped_refptr& audio_state, webrtc::RtcEventLog* event_log) : AudioReceiveStream(clock, - receiver_controller, packet_router, config, audio_state, event_log, CreateChannelReceive(clock, audio_state.get(), - module_process_thread, neteq_factory, config, event_log)) {} AudioReceiveStream::AudioReceiveStream( Clock* clock, - RtpStreamReceiverControllerInterface* receiver_controller, PacketRouter* packet_router, const webrtc::AudioReceiveStream::Config& config, const rtc::scoped_refptr& audio_state, webrtc::RtcEventLog* event_log, std::unique_ptr channel_receive) - : audio_state_(audio_state), - channel_receive_(std::move(channel_receive)), - source_tracker_(clock) { + : config_(config), + audio_state_(audio_state), + source_tracker_(clock), + channel_receive_(std::move(channel_receive)) { RTC_LOG(LS_INFO) << "AudioReceiveStream: " << config.rtp.remote_ssrc; RTC_DCHECK(config.decoder_factory); RTC_DCHECK(config.rtcp_send_transport); RTC_DCHECK(audio_state_); RTC_DCHECK(channel_receive_); - module_process_thread_checker_.Detach(); + packet_sequence_checker_.Detach(); - RTC_DCHECK(receiver_controller); RTC_DCHECK(packet_router); // Configure bandwidth estimation. channel_receive_->RegisterReceiverCongestionControlObjects(packet_router); - // Register with transport. 
- rtp_stream_receiver_ = receiver_controller->CreateReceiver( - config.rtp.remote_ssrc, channel_receive_.get()); - ConfigureStream(this, config, true); + // When output is muted, ChannelReceive will directly notify the source + // tracker of "delivered" frames, so RtpReceiver information will continue to + // be updated. + channel_receive_->SetSourceTracker(&source_tracker_); + + // Complete configuration. + // TODO(solenberg): Config NACK history window (which is a packet count), + // using the actual packet size for the configured codec. + channel_receive_->SetNACKStatus(config.rtp.nack.rtp_history_ms != 0, + config.rtp.nack.rtp_history_ms / 20); + channel_receive_->SetReceiveCodecs(config.decoder_map); + // `frame_transformer` and `frame_decryptor` have been given to + // `channel_receive_` already. } AudioReceiveStream::~AudioReceiveStream() { @@ -147,10 +151,43 @@ AudioReceiveStream::~AudioReceiveStream() { channel_receive_->ResetReceiverCongestionControlObjects(); } -void AudioReceiveStream::Reconfigure( +void AudioReceiveStream::RegisterWithTransport( + RtpStreamReceiverControllerInterface* receiver_controller) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + RTC_DCHECK(!rtp_stream_receiver_); + rtp_stream_receiver_ = receiver_controller->CreateReceiver( + config_.rtp.remote_ssrc, channel_receive_.get()); +} + +void AudioReceiveStream::UnregisterFromTransport() { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtp_stream_receiver_.reset(); +} + +void AudioReceiveStream::ReconfigureForTesting( const webrtc::AudioReceiveStream::Config& config) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - ConfigureStream(this, config, false); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + + // SSRC can't be changed mid-stream. + RTC_DCHECK_EQ(config_.rtp.remote_ssrc, config.rtp.remote_ssrc); + RTC_DCHECK_EQ(config_.rtp.local_ssrc, config.rtp.local_ssrc); + + // Configuration parameters which cannot be changed. + RTC_DCHECK_EQ(config_.rtcp_send_transport, config.rtcp_send_transport); + // Decoder factory cannot be changed because it is configured at + // voe::Channel construction time. + RTC_DCHECK_EQ(config_.decoder_factory, config.decoder_factory); + + // TODO(solenberg): Config NACK history window (which is a packet count), + // using the actual packet size for the configured codec. 
+ RTC_DCHECK_EQ(config_.rtp.nack.rtp_history_ms, config.rtp.nack.rtp_history_ms) + << "Use SetUseTransportCcAndNackHistory"; + + RTC_DCHECK(config_.decoder_map == config.decoder_map) << "Use SetDecoderMap"; + RTC_DCHECK_EQ(config_.frame_transformer, config.frame_transformer) + << "Use SetDepacketizerToDecoderFrameTransformer"; + + config_ = config; } void AudioReceiveStream::Start() { @@ -173,6 +210,54 @@ void AudioReceiveStream::Stop() { audio_state()->RemoveReceivingStream(this); } +bool AudioReceiveStream::IsRunning() const { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + return playing_; +} + +void AudioReceiveStream::SetDepacketizerToDecoderFrameTransformer( + rtc::scoped_refptr frame_transformer) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + channel_receive_->SetDepacketizerToDecoderFrameTransformer( + std::move(frame_transformer)); +} + +void AudioReceiveStream::SetDecoderMap( + std::map decoder_map) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + config_.decoder_map = std::move(decoder_map); + channel_receive_->SetReceiveCodecs(config_.decoder_map); +} + +void AudioReceiveStream::SetUseTransportCcAndNackHistory(bool use_transport_cc, + int history_ms) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + RTC_DCHECK_GE(history_ms, 0); + config_.rtp.transport_cc = use_transport_cc; + if (config_.rtp.nack.rtp_history_ms != history_ms) { + config_.rtp.nack.rtp_history_ms = history_ms; + // TODO(solenberg): Config NACK history window (which is a packet count), + // using the actual packet size for the configured codec. + channel_receive_->SetNACKStatus(history_ms != 0, history_ms / 20); + } +} + +void AudioReceiveStream::SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + // TODO(bugs.webrtc.org/11993): This is called via WebRtcAudioReceiveStream, + // expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + channel_receive_->SetFrameDecryptor(std::move(frame_decryptor)); +} + +void AudioReceiveStream::SetRtpExtensions( + std::vector extensions) { + // TODO(bugs.webrtc.org/11993): This is called via WebRtcAudioReceiveStream, + // expect to be called on the network thread. 
+ RTC_DCHECK_RUN_ON(&worker_thread_checker_); + config_.rtp.extensions = std::move(extensions); +} + webrtc::AudioReceiveStream::Stats AudioReceiveStream::GetStats( bool get_and_clear_legacy_stats) const { RTC_DCHECK_RUN_ON(&worker_thread_checker_); @@ -193,6 +278,7 @@ webrtc::AudioReceiveStream::Stats AudioReceiveStream::GetStats( call_stats.header_and_padding_bytes_rcvd; stats.packets_rcvd = call_stats.packetsReceived; stats.packets_lost = call_stats.cumulativeLost; + stats.nacks_sent = call_stats.nacks_sent; stats.capture_start_ntp_time_ms = call_stats.capture_start_ntp_time_ms_; stats.last_packet_received_timestamp_ms = call_stats.last_packet_received_timestamp_ms; @@ -253,6 +339,14 @@ webrtc::AudioReceiveStream::Stats AudioReceiveStream::GetStats( stats.decoding_plc_cng = ds.decoded_plc_cng; stats.decoding_muted_output = ds.decoded_muted_output; + stats.last_sender_report_timestamp_ms = + call_stats.last_sender_report_timestamp_ms; + stats.last_sender_report_remote_timestamp_ms = + call_stats.last_sender_report_remote_timestamp_ms; + stats.sender_reports_packets_sent = call_stats.sender_reports_packets_sent; + stats.sender_reports_bytes_sent = call_stats.sender_reports_bytes_sent; + stats.sender_reports_reports_count = call_stats.sender_reports_reports_count; + return stats; } @@ -306,14 +400,10 @@ uint32_t AudioReceiveStream::id() const { } absl::optional AudioReceiveStream::GetInfo() const { - RTC_DCHECK_RUN_ON(&module_process_thread_checker_); - absl::optional info = channel_receive_->GetSyncInfo(); - - if (!info) - return absl::nullopt; - - info->current_delay_ms = channel_receive_->GetDelayEstimate(); - return info; + // TODO(bugs.webrtc.org/11993): This is called via RtpStreamsSynchronizer, + // expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + return channel_receive_->GetSyncInfo(); } bool AudioReceiveStream::GetPlayoutRtpTimestamp(uint32_t* rtp_timestamp, @@ -331,12 +421,14 @@ void AudioReceiveStream::SetEstimatedPlayoutNtpTimestampMs( } bool AudioReceiveStream::SetMinimumPlayoutDelay(int delay_ms) { - RTC_DCHECK_RUN_ON(&module_process_thread_checker_); + // TODO(bugs.webrtc.org/11993): This is called via RtpStreamsSynchronizer, + // expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return channel_receive_->SetMinimumPlayoutDelay(delay_ms); } void AudioReceiveStream::AssociateSendStream(AudioSendStream* send_stream) { - RTC_DCHECK_RUN_ON(&worker_thread_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); channel_receive_->SetAssociatedSendChannel( send_stream ? send_stream->GetChannel() : nullptr); associated_send_stream_ = send_stream; @@ -350,6 +442,24 @@ void AudioReceiveStream::DeliverRtcp(const uint8_t* packet, size_t length) { channel_receive_->ReceivedRTCPPacket(packet, length); } +void AudioReceiveStream::SetSyncGroup(const std::string& sync_group) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + config_.sync_group = sync_group; +} + +void AudioReceiveStream::SetLocalSsrc(uint32_t local_ssrc) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + // TODO(tommi): Consider storing local_ssrc in one place. 
+ config_.rtp.local_ssrc = local_ssrc; + channel_receive_->OnLocalSsrcChange(local_ssrc); +} + +uint32_t AudioReceiveStream::local_ssrc() const { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + RTC_DCHECK_EQ(config_.rtp.local_ssrc, channel_receive_->GetLocalSsrc()); + return config_.rtp.local_ssrc; +} + const webrtc::AudioReceiveStream::Config& AudioReceiveStream::config() const { RTC_DCHECK_RUN_ON(&worker_thread_checker_); return config_; @@ -357,7 +467,7 @@ const webrtc::AudioReceiveStream::Config& AudioReceiveStream::config() const { const AudioSendStream* AudioReceiveStream::GetAssociatedSendStreamForTesting() const { - RTC_DCHECK_RUN_ON(&worker_thread_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); return associated_send_stream_; } @@ -366,50 +476,5 @@ internal::AudioState* AudioReceiveStream::audio_state() const { RTC_DCHECK(audio_state); return audio_state; } - -void AudioReceiveStream::ConfigureStream(AudioReceiveStream* stream, - const Config& new_config, - bool first_time) { - RTC_LOG(LS_INFO) << "AudioReceiveStream::ConfigureStream: " - << new_config.ToString(); - RTC_DCHECK(stream); - const auto& channel_receive = stream->channel_receive_; - const auto& old_config = stream->config_; - - // Configuration parameters which cannot be changed. - RTC_DCHECK(first_time || - old_config.rtp.remote_ssrc == new_config.rtp.remote_ssrc); - RTC_DCHECK(first_time || - old_config.rtcp_send_transport == new_config.rtcp_send_transport); - // Decoder factory cannot be changed because it is configured at - // voe::Channel construction time. - RTC_DCHECK(first_time || - old_config.decoder_factory == new_config.decoder_factory); - - if (!first_time) { - // SSRC can't be changed mid-stream. - RTC_DCHECK_EQ(old_config.rtp.local_ssrc, new_config.rtp.local_ssrc); - RTC_DCHECK_EQ(old_config.rtp.remote_ssrc, new_config.rtp.remote_ssrc); - } - - // TODO(solenberg): Config NACK history window (which is a packet count), - // using the actual packet size for the configured codec. 
- if (first_time || old_config.rtp.nack.rtp_history_ms != - new_config.rtp.nack.rtp_history_ms) { - channel_receive->SetNACKStatus(new_config.rtp.nack.rtp_history_ms != 0, - new_config.rtp.nack.rtp_history_ms / 20); - } - if (first_time || old_config.decoder_map != new_config.decoder_map) { - channel_receive->SetReceiveCodecs(new_config.decoder_map); - } - - if (first_time || - old_config.frame_transformer != new_config.frame_transformer) { - channel_receive->SetDepacketizerToDecoderFrameTransformer( - new_config.frame_transformer); - } - - stream->config_ = new_config; -} } // namespace internal } // namespace webrtc diff --git a/audio/audio_receive_stream.h b/audio/audio_receive_stream.h index 32f8b60d58..61ebc2719f 100644 --- a/audio/audio_receive_stream.h +++ b/audio/audio_receive_stream.h @@ -11,17 +11,20 @@ #ifndef AUDIO_AUDIO_RECEIVE_STREAM_H_ #define AUDIO_AUDIO_RECEIVE_STREAM_H_ +#include #include +#include #include #include "api/audio/audio_mixer.h" #include "api/neteq/neteq_factory.h" #include "api/rtp_headers.h" +#include "api/sequence_checker.h" #include "audio/audio_state.h" #include "call/audio_receive_stream.h" #include "call/syncable.h" #include "modules/rtp_rtcp/source/source_tracker.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/system/no_unique_address.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -44,9 +47,7 @@ class AudioReceiveStream final : public webrtc::AudioReceiveStream, public Syncable { public: AudioReceiveStream(Clock* clock, - RtpStreamReceiverControllerInterface* receiver_controller, PacketRouter* packet_router, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, const webrtc::AudioReceiveStream::Config& config, const rtc::scoped_refptr& audio_state, @@ -54,7 +55,6 @@ class AudioReceiveStream final : public webrtc::AudioReceiveStream, // For unit tests, which need to supply a mock channel receive. AudioReceiveStream( Clock* clock, - RtpStreamReceiverControllerInterface* receiver_controller, PacketRouter* packet_router, const webrtc::AudioReceiveStream::Config& config, const rtc::scoped_refptr& audio_state, @@ -65,12 +65,37 @@ class AudioReceiveStream final : public webrtc::AudioReceiveStream, AudioReceiveStream(const AudioReceiveStream&) = delete; AudioReceiveStream& operator=(const AudioReceiveStream&) = delete; + // Destruction happens on the worker thread. Prior to destruction the caller + // must ensure that a registration with the transport has been cleared. See + // `RegisterWithTransport` for details. + // TODO(tommi): As a further improvement to this, performing the full + // destruction on the network thread could be made the default. ~AudioReceiveStream() override; + // Called on the network thread to register/unregister with the network + // transport. + void RegisterWithTransport( + RtpStreamReceiverControllerInterface* receiver_controller); + // If registration has previously been done (via `RegisterWithTransport`) then + // `UnregisterFromTransport` must be called prior to destruction, on the + // network thread. + void UnregisterFromTransport(); + // webrtc::AudioReceiveStream implementation. 
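
// Call-site sketch (illustrative): with Reconfigure() removed, owners create
// the stream, register it with the transport on the network thread, and use
// the targeted setters below for later changes, e.g.
//
//   auto stream = std::make_unique<AudioReceiveStream>(/*...*/);
//   stream->RegisterWithTransport(&receiver_controller);
//   stream->Start();
//   // ... later, e.g. on renegotiation:
//   stream->SetDecoderMap(new_decoder_map);
//   stream->SetUseTransportCcAndNackHistory(/*use_transport_cc=*/true,
//                                           /*history_ms=*/5000);
//   stream->Stop();
//   stream->UnregisterFromTransport();  // Network thread, before destruction.
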
- void Reconfigure(const webrtc::AudioReceiveStream::Config& config) override; void Start() override; void Stop() override; + const RtpConfig& rtp_config() const override { return config_.rtp; } + bool IsRunning() const override; + void SetDepacketizerToDecoderFrameTransformer( + rtc::scoped_refptr frame_transformer) + override; + void SetDecoderMap(std::map decoder_map) override; + void SetUseTransportCcAndNackHistory(bool use_transport_cc, + int history_ms) override; + void SetFrameDecryptor(rtc::scoped_refptr + frame_decryptor) override; + void SetRtpExtensions(std::vector extensions) override; + webrtc::AudioReceiveStream::Stats GetStats( bool get_and_clear_legacy_stats) const override; void SetSink(AudioSinkInterface* sink) override; @@ -96,27 +121,48 @@ class AudioReceiveStream final : public webrtc::AudioReceiveStream, void AssociateSendStream(AudioSendStream* send_stream); void DeliverRtcp(const uint8_t* packet, size_t length); + + void SetSyncGroup(const std::string& sync_group); + + void SetLocalSsrc(uint32_t local_ssrc); + + uint32_t local_ssrc() const; + + uint32_t remote_ssrc() const { + // The remote_ssrc member variable of config_ will never change and can be + // considered const. + return config_.rtp.remote_ssrc; + } + const webrtc::AudioReceiveStream::Config& config() const; const AudioSendStream* GetAssociatedSendStreamForTesting() const; - private: - static void ConfigureStream(AudioReceiveStream* stream, - const Config& new_config, - bool first_time); + // TODO(tommi): Remove this method. + void ReconfigureForTesting(const webrtc::AudioReceiveStream::Config& config); + private: AudioState* audio_state() const; - rtc::ThreadChecker worker_thread_checker_; - rtc::ThreadChecker module_process_thread_checker_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker worker_thread_checker_; + // TODO(bugs.webrtc.org/11993): This checker conceptually represents + // operations that belong to the network thread. The Call class is currently + // moving towards handling network packets on the network thread and while + // that work is ongoing, this checker may in practice represent the worker + // thread, but still serves as a mechanism of grouping together concepts + // that belong to the network thread. Once the packets are fully delivered + // on the network thread, this comment will be deleted. 
+ RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_sequence_checker_; webrtc::AudioReceiveStream::Config config_; rtc::scoped_refptr audio_state_; - const std::unique_ptr channel_receive_; SourceTracker source_tracker_; - AudioSendStream* associated_send_stream_ = nullptr; + const std::unique_ptr channel_receive_; + AudioSendStream* associated_send_stream_ + RTC_GUARDED_BY(packet_sequence_checker_) = nullptr; bool playing_ RTC_GUARDED_BY(worker_thread_checker_) = false; - std::unique_ptr rtp_stream_receiver_; + std::unique_ptr rtp_stream_receiver_ + RTC_GUARDED_BY(packet_sequence_checker_); }; } // namespace internal } // namespace webrtc diff --git a/audio/audio_receive_stream_unittest.cc b/audio/audio_receive_stream_unittest.cc index fcd691ea80..fb5f1cb876 100644 --- a/audio/audio_receive_stream_unittest.cc +++ b/audio/audio_receive_stream_unittest.cc @@ -74,7 +74,7 @@ const AudioDecodingCallStats kAudioDecodeStats = MakeAudioDecodeStatsForTest(); struct ConfigHelper { explicit ConfigHelper(bool use_null_audio_processing) - : ConfigHelper(new rtc::RefCountedObject(), + : ConfigHelper(rtc::make_ref_counted(), use_null_audio_processing) {} ConfigHelper(rtc::scoped_refptr audio_mixer, @@ -87,9 +87,9 @@ struct ConfigHelper { config.audio_processing = use_null_audio_processing ? nullptr - : new rtc::RefCountedObject>(); + : rtc::make_ref_counted>(); config.audio_device_module = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); audio_state_ = AudioState::Create(config); channel_receive_ = new ::testing::StrictMock(); @@ -104,8 +104,7 @@ struct ConfigHelper { .WillRepeatedly(Invoke([](const std::map& codecs) { EXPECT_THAT(codecs, ::testing::IsEmpty()); })); - EXPECT_CALL(*channel_receive_, SetDepacketizerToDecoderFrameTransformer(_)) - .Times(1); + EXPECT_CALL(*channel_receive_, SetSourceTracker(_)); stream_config_.rtp.local_ssrc = kLocalSsrc; stream_config_.rtp.remote_ssrc = kRemoteSsrc; @@ -116,15 +115,16 @@ struct ConfigHelper { RtpExtension::kTransportSequenceNumberUri, kTransportSequenceNumberId)); stream_config_.rtcp_send_transport = &rtcp_send_transport_; stream_config_.decoder_factory = - new rtc::RefCountedObject; + rtc::make_ref_counted(); } std::unique_ptr CreateAudioReceiveStream() { - return std::unique_ptr( - new internal::AudioReceiveStream( - Clock::GetRealTimeClock(), &rtp_stream_receiver_controller_, - &packet_router_, stream_config_, audio_state_, &event_log_, - std::unique_ptr(channel_receive_))); + auto ret = std::make_unique( + Clock::GetRealTimeClock(), &packet_router_, stream_config_, + audio_state_, &event_log_, + std::unique_ptr(channel_receive_)); + ret->RegisterWithTransport(&rtp_stream_receiver_controller_); + return ret; } AudioReceiveStream::Config& config() { return stream_config_; } @@ -198,6 +198,7 @@ TEST(AudioReceiveStreamTest, ConstructDestruct) { for (bool use_null_audio_processing : {false, true}) { ConfigHelper helper(use_null_audio_processing); auto recv_stream = helper.CreateAudioReceiveStream(); + recv_stream->UnregisterFromTransport(); } } @@ -211,6 +212,7 @@ TEST(AudioReceiveStreamTest, ReceiveRtcpPacket) { ReceivedRTCPPacket(&rtcp_packet[0], rtcp_packet.size())) .WillOnce(Return()); recv_stream->DeliverRtcp(&rtcp_packet[0], rtcp_packet.size()); + recv_stream->UnregisterFromTransport(); } } @@ -275,6 +277,7 @@ TEST(AudioReceiveStreamTest, GetStats) { EXPECT_EQ(kCallStats.capture_start_ntp_time_ms_, stats.capture_start_ntp_time_ms); EXPECT_EQ(kPlayoutNtpTimestampMs, stats.estimated_playout_ntp_timestamp_ms); + 
recv_stream->UnregisterFromTransport(); } } @@ -285,6 +288,7 @@ TEST(AudioReceiveStreamTest, SetGain) { EXPECT_CALL(*helper.channel_receive(), SetChannelOutputVolumeScaling(FloatEq(0.765f))); recv_stream->SetGain(0.765f); + recv_stream->UnregisterFromTransport(); } } @@ -316,14 +320,9 @@ TEST(AudioReceiveStreamTest, StreamsShouldBeAddedToMixerOnceOnStart) { // Stop stream before it is being destructed. recv_stream2->Stop(); - } -} -TEST(AudioReceiveStreamTest, ReconfigureWithSameConfig) { - for (bool use_null_audio_processing : {false, true}) { - ConfigHelper helper(use_null_audio_processing); - auto recv_stream = helper.CreateAudioReceiveStream(); - recv_stream->Reconfigure(helper.config()); + recv_stream1->UnregisterFromTransport(); + recv_stream2->UnregisterFromTransport(); } } @@ -333,20 +332,32 @@ TEST(AudioReceiveStreamTest, ReconfigureWithUpdatedConfig) { auto recv_stream = helper.CreateAudioReceiveStream(); auto new_config = helper.config(); - new_config.rtp.nack.rtp_history_ms = 300 + 20; + new_config.rtp.extensions.clear(); new_config.rtp.extensions.push_back( RtpExtension(RtpExtension::kAudioLevelUri, kAudioLevelId + 1)); new_config.rtp.extensions.push_back( RtpExtension(RtpExtension::kTransportSequenceNumberUri, kTransportSequenceNumberId + 1)); - new_config.decoder_map.emplace(1, SdpAudioFormat("foo", 8000, 1)); MockChannelReceive& channel_receive = *helper.channel_receive(); - EXPECT_CALL(channel_receive, SetNACKStatus(true, 15 + 1)).Times(1); + + // TODO(tommi, nisse): This applies new extensions to the internal config, + // but there's nothing that actually verifies that the changes take effect. + // In fact Call manages the extensions separately in Call::ReceiveRtpConfig + // and changing this config value (there seem to be a few copies), doesn't + // affect that logic. + recv_stream->ReconfigureForTesting(new_config); + + new_config.decoder_map.emplace(1, SdpAudioFormat("foo", 8000, 1)); EXPECT_CALL(channel_receive, SetReceiveCodecs(new_config.decoder_map)); + recv_stream->SetDecoderMap(new_config.decoder_map); + + EXPECT_CALL(channel_receive, SetNACKStatus(true, 15 + 1)).Times(1); + recv_stream->SetUseTransportCcAndNackHistory(new_config.rtp.transport_cc, + 300 + 20); - recv_stream->Reconfigure(new_config); + recv_stream->UnregisterFromTransport(); } } @@ -357,17 +368,23 @@ TEST(AudioReceiveStreamTest, ReconfigureWithFrameDecryptor) { auto new_config_0 = helper.config(); rtc::scoped_refptr mock_frame_decryptor_0( - new rtc::RefCountedObject()); + rtc::make_ref_counted()); new_config_0.frame_decryptor = mock_frame_decryptor_0; - recv_stream->Reconfigure(new_config_0); + // TODO(tommi): While this changes the internal config value, it doesn't + // actually change what frame_decryptor is used. WebRtcAudioReceiveStream + // recreates the whole instance in order to change this value. + // So, it's not clear if changing this post initialization needs to be + // supported. 
+ recv_stream->ReconfigureForTesting(new_config_0); auto new_config_1 = helper.config(); rtc::scoped_refptr mock_frame_decryptor_1( - new rtc::RefCountedObject()); + rtc::make_ref_counted()); new_config_1.frame_decryptor = mock_frame_decryptor_1; new_config_1.crypto_options.sframe.require_frame_encryption = true; - recv_stream->Reconfigure(new_config_1); + recv_stream->ReconfigureForTesting(new_config_1); + recv_stream->UnregisterFromTransport(); } } diff --git a/audio/audio_send_stream.cc b/audio/audio_send_stream.cc index 4e21b1f31d..62dd53d337 100644 --- a/audio/audio_send_stream.cc +++ b/audio/audio_send_stream.cc @@ -102,7 +102,6 @@ AudioSendStream::AudioSendStream( const webrtc::AudioSendStream::Config& config, const rtc::scoped_refptr& audio_state, TaskQueueFactory* task_queue_factory, - ProcessThread* module_process_thread, RtpTransportControllerSendInterface* rtp_transport, BitrateAllocatorInterface* bitrate_allocator, RtcEventLog* event_log, @@ -119,7 +118,6 @@ AudioSendStream::AudioSendStream( voe::CreateChannelSend( clock, task_queue_factory, - module_process_thread, config.send_transport, rtcp_rtt_stats, event_log, @@ -142,7 +140,7 @@ AudioSendStream::AudioSendStream( const absl::optional& suspended_rtp_state, std::unique_ptr channel_send) : clock_(clock), - worker_queue_(rtp_transport->GetWorkerQueue()), + rtp_transport_queue_(rtp_transport->GetWorkerQueue()), allocate_audio_without_feedback_( field_trial::IsEnabled("WebRTC-Audio-ABWENoTWCC")), enable_audio_alr_probing_( @@ -160,7 +158,7 @@ AudioSendStream::AudioSendStream( rtp_rtcp_module_(channel_send_->GetRtpRtcp()), suspended_rtp_state_(suspended_rtp_state) { RTC_LOG(LS_INFO) << "AudioSendStream: " << config.rtp.ssrc; - RTC_DCHECK(worker_queue_); + RTC_DCHECK(rtp_transport_queue_); RTC_DCHECK(audio_state_); RTC_DCHECK(channel_send_); RTC_DCHECK(bitrate_allocator_); @@ -168,31 +166,32 @@ AudioSendStream::AudioSendStream( RTC_DCHECK(rtp_rtcp_module_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); ConfigureStream(config, true); - + UpdateCachedTargetAudioBitrateConstraints(); pacer_thread_checker_.Detach(); } AudioSendStream::~AudioSendStream() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_LOG(LS_INFO) << "~AudioSendStream: " << config_.rtp.ssrc; RTC_DCHECK(!sending_); channel_send_->ResetSenderCongestionControlObjects(); // Blocking call to synchronize state with worker queue to ensure that there // are no pending tasks left that keeps references to audio. 
rtc::Event thread_sync_event; - worker_queue_->PostTask([&] { thread_sync_event.Set(); }); + rtp_transport_queue_->PostTask([&] { thread_sync_event.Set(); }); thread_sync_event.Wait(rtc::Event::kForever); } const webrtc::AudioSendStream::Config& AudioSendStream::GetConfig() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return config_; } void AudioSendStream::Reconfigure( const webrtc::AudioSendStream::Config& new_config) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); ConfigureStream(new_config, false); } @@ -351,20 +350,22 @@ void AudioSendStream::ConfigureStream( } channel_send_->CallEncoder([this](AudioEncoder* encoder) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (!encoder) { return; } - worker_queue_->PostTask( - [this, length_range = encoder->GetFrameLengthRange()] { - RTC_DCHECK_RUN_ON(worker_queue_); - frame_length_range_ = length_range; - }); + frame_length_range_ = encoder->GetFrameLengthRange(); + UpdateCachedTargetAudioBitrateConstraints(); }); if (sending_) { ReconfigureBitrateObserver(new_config); } + config_ = new_config; + if (!first_time) { + UpdateCachedTargetAudioBitrateConstraints(); + } } void AudioSendStream::Start() { @@ -379,13 +380,7 @@ void AudioSendStream::Start() { if (send_side_bwe_with_overhead_) rtp_transport_->IncludeOverheadInPacedSender(); rtp_rtcp_module_->SetAsPartOfAllocation(true); - rtc::Event thread_sync_event; - worker_queue_->PostTask([&] { - RTC_DCHECK_RUN_ON(worker_queue_); - ConfigureBitrateObserver(); - thread_sync_event.Set(); - }); - thread_sync_event.Wait(rtc::Event::kForever); + ConfigureBitrateObserver(); } else { rtp_rtcp_module_->SetAsPartOfAllocation(false); } @@ -396,7 +391,7 @@ void AudioSendStream::Start() { } void AudioSendStream::Stop() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (!sending_) { return; } @@ -431,14 +426,14 @@ bool AudioSendStream::SendTelephoneEvent(int payload_type, int payload_frequency, int event, int duration_ms) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); channel_send_->SetSendTelephoneEventPayloadType(payload_type, payload_frequency); return channel_send_->SendTelephoneEventOutband(event, duration_ms); } void AudioSendStream::SetMuted(bool muted) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); channel_send_->SetInputMute(muted); } @@ -448,7 +443,7 @@ webrtc::AudioSendStream::Stats AudioSendStream::GetStats() const { webrtc::AudioSendStream::Stats AudioSendStream::GetStats( bool has_remote_tracks) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); webrtc::AudioSendStream::Stats stats; stats.local_ssrc = config_.rtp.ssrc; stats.target_bitrate_bps = channel_send_->GetBitrate(); @@ -503,29 +498,35 @@ webrtc::AudioSendStream::Stats AudioSendStream::GetStats( stats.report_block_datas = std::move(call_stats.report_block_datas); + stats.nacks_rcvd = call_stats.nacks_rcvd; + return stats; } void AudioSendStream::DeliverRtcp(const uint8_t* packet, size_t length) { RTC_DCHECK_RUN_ON(&worker_thread_checker_); channel_send_->ReceivedRTCPPacket(packet, length); - worker_queue_->PostTask([&]() { + + { // Poll if overhead has changed, which it can do if ack triggers us to stop // sending mid/rid. 
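// The braces scope overhead_per_packet_lock_ to the overhead update only, so
// that UpdateCachedTargetAudioBitrateConstraints() below runs with the lock
// already released.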
MutexLock lock(&overhead_per_packet_lock_); UpdateOverheadForEncoder(); - }); + } + UpdateCachedTargetAudioBitrateConstraints(); } uint32_t AudioSendStream::OnBitrateUpdated(BitrateAllocationUpdate update) { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); // Pick a target bitrate between the constraints. Overrules the allocator if // it 1) allocated a bitrate of zero to disable the stream or 2) allocated a // higher than max to allow for e.g. extra FEC. - auto constraints = GetMinMaxBitrateConstraints(); - update.target_bitrate.Clamp(constraints.min, constraints.max); - update.stable_target_bitrate.Clamp(constraints.min, constraints.max); + RTC_DCHECK(cached_constraints_.has_value()); + update.target_bitrate.Clamp(cached_constraints_->min, + cached_constraints_->max); + update.stable_target_bitrate.Clamp(cached_constraints_->min, + cached_constraints_->max); channel_send_->OnBitrateAllocation(update); @@ -536,13 +537,17 @@ uint32_t AudioSendStream::OnBitrateUpdated(BitrateAllocationUpdate update) { void AudioSendStream::SetTransportOverhead( int transport_overhead_per_packet_bytes) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - MutexLock lock(&overhead_per_packet_lock_); - transport_overhead_per_packet_bytes_ = transport_overhead_per_packet_bytes; - UpdateOverheadForEncoder(); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + { + MutexLock lock(&overhead_per_packet_lock_); + transport_overhead_per_packet_bytes_ = transport_overhead_per_packet_bytes; + UpdateOverheadForEncoder(); + } + UpdateCachedTargetAudioBitrateConstraints(); } void AudioSendStream::UpdateOverheadForEncoder() { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); size_t overhead_per_packet_bytes = GetPerPacketOverheadBytes(); if (overhead_per_packet_ == overhead_per_packet_bytes) { return; @@ -552,19 +557,11 @@ void AudioSendStream::UpdateOverheadForEncoder() { channel_send_->CallEncoder([&](AudioEncoder* encoder) { encoder->OnReceivedOverhead(overhead_per_packet_bytes); }); - auto update_task = [this, overhead_per_packet_bytes] { - RTC_DCHECK_RUN_ON(worker_queue_); - if (total_packet_overhead_bytes_ != overhead_per_packet_bytes) { - total_packet_overhead_bytes_ = overhead_per_packet_bytes; - if (registered_with_allocator_) { - ConfigureBitrateObserver(); - } + if (total_packet_overhead_bytes_ != overhead_per_packet_bytes) { + total_packet_overhead_bytes_ = overhead_per_packet_bytes; + if (registered_with_allocator_) { + ConfigureBitrateObserver(); } - }; - if (worker_queue_->IsCurrent()) { - update_task(); - } else { - worker_queue_->PostTask(update_task); } } @@ -602,7 +599,6 @@ const internal::AudioState* AudioSendStream::audio_state() const { void AudioSendStream::StoreEncoderProperties(int sample_rate_hz, size_t num_channels) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); encoder_sample_rate_hz_ = sample_rate_hz; encoder_num_channels_ = num_channels; if (sending_) { @@ -800,7 +796,6 @@ void AudioSendStream::ReconfigureCNG(const Config& new_config) { void AudioSendStream::ReconfigureBitrateObserver( const webrtc::AudioSendStream::Config& new_config) { - RTC_DCHECK_RUN_ON(&worker_thread_checker_); // Since the Config's default is for both of these to be -1, this test will // allow us to configure the bitrate observer if the new config has bitrate // limits set, but would only have us call RemoveBitrateObserver if we were @@ -819,20 +814,13 @@ void AudioSendStream::ReconfigureBitrateObserver( rtp_transport_->AccountForAudioPacketsInPacedSender(true); if (send_side_bwe_with_overhead_) 
rtp_transport_->IncludeOverheadInPacedSender(); - rtc::Event thread_sync_event; - worker_queue_->PostTask([&] { - RTC_DCHECK_RUN_ON(worker_queue_); - // We may get a callback immediately as the observer is registered, so - // make - // sure the bitrate limits in config_ are up-to-date. - config_.min_bitrate_bps = new_config.min_bitrate_bps; - config_.max_bitrate_bps = new_config.max_bitrate_bps; - - config_.bitrate_priority = new_config.bitrate_priority; - ConfigureBitrateObserver(); - thread_sync_event.Set(); - }); - thread_sync_event.Wait(rtc::Event::kForever); + // We may get a callback immediately as the observer is registered, so + // make sure the bitrate limits in config_ are up-to-date. + config_.min_bitrate_bps = new_config.min_bitrate_bps; + config_.max_bitrate_bps = new_config.max_bitrate_bps; + + config_.bitrate_priority = new_config.bitrate_priority; + ConfigureBitrateObserver(); rtp_rtcp_module_->SetAsPartOfAllocation(true); } else { rtp_transport_->AccountForAudioPacketsInPacedSender(false); @@ -845,6 +833,7 @@ void AudioSendStream::ConfigureBitrateObserver() { // This either updates the current observer or adds a new observer. // TODO(srte): Add overhead compensation here. auto constraints = GetMinMaxBitrateConstraints(); + RTC_DCHECK(constraints.has_value()); DataRate priority_bitrate = allocation_settings_.priority_bitrate; if (send_side_bwe_with_overhead_) { @@ -866,30 +855,41 @@ void AudioSendStream::ConfigureBitrateObserver() { if (allocation_settings_.priority_bitrate_raw) priority_bitrate = *allocation_settings_.priority_bitrate_raw; - bitrate_allocator_->AddObserver( - this, - MediaStreamAllocationConfig{ - constraints.min.bps(), constraints.max.bps(), 0, - priority_bitrate.bps(), true, - allocation_settings_.bitrate_priority.value_or( - config_.bitrate_priority)}); + rtp_transport_queue_->PostTask([this, constraints, priority_bitrate, + config_bitrate_priority = + config_.bitrate_priority] { + RTC_DCHECK_RUN_ON(rtp_transport_queue_); + bitrate_allocator_->AddObserver( + this, + MediaStreamAllocationConfig{ + constraints->min.bps(), constraints->max.bps(), + 0, priority_bitrate.bps(), true, + allocation_settings_.bitrate_priority.value_or( + config_bitrate_priority)}); + }); registered_with_allocator_ = true; } void AudioSendStream::RemoveBitrateObserver() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + registered_with_allocator_ = false; rtc::Event thread_sync_event; - worker_queue_->PostTask([this, &thread_sync_event] { - RTC_DCHECK_RUN_ON(worker_queue_); - registered_with_allocator_ = false; + rtp_transport_queue_->PostTask([this, &thread_sync_event] { + RTC_DCHECK_RUN_ON(rtp_transport_queue_); bitrate_allocator_->RemoveObserver(this); thread_sync_event.Set(); }); thread_sync_event.Wait(rtc::Event::kForever); } -AudioSendStream::TargetAudioBitrateConstraints +absl::optional AudioSendStream::GetMinMaxBitrateConstraints() const { + if (config_.min_bitrate_bps < 0 || config_.max_bitrate_bps < 0) { + RTC_LOG(LS_WARNING) << "Config is invalid: min_bitrate_bps=" + << config_.min_bitrate_bps + << "; max_bitrate_bps=" << config_.max_bitrate_bps + << "; both expected greater or equal to 0"; + return absl::nullopt; + } TargetAudioBitrateConstraints constraints{ DataRate::BitsPerSec(config_.min_bitrate_bps), DataRate::BitsPerSec(config_.max_bitrate_bps)}; @@ -902,7 +902,11 @@ AudioSendStream::GetMinMaxBitrateConstraints() const { RTC_DCHECK_GE(constraints.min, DataRate::Zero()); RTC_DCHECK_GE(constraints.max, DataRate::Zero()); - RTC_DCHECK_GE(constraints.max, 
constraints.min); + if (constraints.max < constraints.min) { + RTC_LOG(LS_WARNING) << "TargetAudioBitrateConstraints::max is less than " + << "TargetAudioBitrateConstraints::min"; + return absl::nullopt; + } if (send_side_bwe_with_overhead_) { if (use_legacy_overhead_calculation_) { // OverheadPerPacket = Ipv4(20B) + UDP(8B) + SRTP(10B) + RTP(12) @@ -913,7 +917,10 @@ AudioSendStream::GetMinMaxBitrateConstraints() const { constraints.min += kMinOverhead; constraints.max += kMinOverhead; } else { - RTC_DCHECK(frame_length_range_); + if (!frame_length_range_.has_value()) { + RTC_LOG(LS_WARNING) << "frame_length_range_ is not set"; + return absl::nullopt; + } const DataSize kOverheadPerPacket = DataSize::Bytes(total_packet_overhead_bytes_); constraints.min += kOverheadPerPacket / frame_length_range_->second; @@ -927,5 +934,18 @@ void AudioSendStream::RegisterCngPayloadType(int payload_type, int clockrate_hz) { channel_send_->RegisterCngPayloadType(payload_type, clockrate_hz); } + +void AudioSendStream::UpdateCachedTargetAudioBitrateConstraints() { + absl::optional + new_constraints = GetMinMaxBitrateConstraints(); + if (!new_constraints.has_value()) { + return; + } + rtp_transport_queue_->PostTask([this, new_constraints]() { + RTC_DCHECK_RUN_ON(rtp_transport_queue_); + cached_constraints_ = new_constraints; + }); +} + } // namespace internal } // namespace webrtc diff --git a/audio/audio_send_stream.h b/audio/audio_send_stream.h index 1e6982e41f..e0b15dc0c9 100644 --- a/audio/audio_send_stream.h +++ b/audio/audio_send_stream.h @@ -15,6 +15,7 @@ #include #include +#include "api/sequence_checker.h" #include "audio/audio_level.h" #include "audio/channel_send.h" #include "call/audio_send_stream.h" @@ -25,7 +26,6 @@ #include "rtc_base/race_checker.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue.h" -#include "rtc_base/thread_checker.h" namespace webrtc { class RtcEventLog; @@ -58,7 +58,6 @@ class AudioSendStream final : public webrtc::AudioSendStream, const webrtc::AudioSendStream::Config& config, const rtc::scoped_refptr& audio_state, TaskQueueFactory* task_queue_factory, - ProcessThread* module_process_thread, RtpTransportControllerSendInterface* rtp_transport, BitrateAllocatorInterface* bitrate_allocator, RtcEventLog* event_log, @@ -121,22 +120,29 @@ class AudioSendStream final : public webrtc::AudioSendStream, internal::AudioState* audio_state(); const internal::AudioState* audio_state() const; - void StoreEncoderProperties(int sample_rate_hz, size_t num_channels); - - void ConfigureStream(const Config& new_config, bool first_time); - bool SetupSendCodec(const Config& new_config); - bool ReconfigureSendCodec(const Config& new_config); - void ReconfigureANA(const Config& new_config); - void ReconfigureCNG(const Config& new_config); - void ReconfigureBitrateObserver(const Config& new_config); - - void ConfigureBitrateObserver() RTC_RUN_ON(worker_queue_); - void RemoveBitrateObserver(); + void StoreEncoderProperties(int sample_rate_hz, size_t num_channels) + RTC_RUN_ON(worker_thread_checker_); + + void ConfigureStream(const Config& new_config, bool first_time) + RTC_RUN_ON(worker_thread_checker_); + bool SetupSendCodec(const Config& new_config) + RTC_RUN_ON(worker_thread_checker_); + bool ReconfigureSendCodec(const Config& new_config) + RTC_RUN_ON(worker_thread_checker_); + void ReconfigureANA(const Config& new_config) + RTC_RUN_ON(worker_thread_checker_); + void ReconfigureCNG(const Config& new_config) + RTC_RUN_ON(worker_thread_checker_); + void 
ReconfigureBitrateObserver(const Config& new_config) + RTC_RUN_ON(worker_thread_checker_); + + void ConfigureBitrateObserver() RTC_RUN_ON(worker_thread_checker_); + void RemoveBitrateObserver() RTC_RUN_ON(worker_thread_checker_); // Returns bitrate constraints, maybe including overhead when enabled by // field trial. - TargetAudioBitrateConstraints GetMinMaxBitrateConstraints() const - RTC_RUN_ON(worker_queue_); + absl::optional GetMinMaxBitrateConstraints() + const RTC_RUN_ON(worker_thread_checker_); // Sets per-packet overhead on encoded (for ANA) based on current known values // of transport and packetization overheads. @@ -147,13 +153,18 @@ class AudioSendStream final : public webrtc::AudioSendStream, size_t GetPerPacketOverheadBytes() const RTC_EXCLUSIVE_LOCKS_REQUIRED(overhead_per_packet_lock_); - void RegisterCngPayloadType(int payload_type, int clockrate_hz); + void RegisterCngPayloadType(int payload_type, int clockrate_hz) + RTC_RUN_ON(worker_thread_checker_); + + void UpdateCachedTargetAudioBitrateConstraints() + RTC_RUN_ON(worker_thread_checker_); + Clock* clock_; - rtc::ThreadChecker worker_thread_checker_; - rtc::ThreadChecker pacer_thread_checker_; + SequenceChecker worker_thread_checker_; + SequenceChecker pacer_thread_checker_; rtc::RaceChecker audio_capture_race_checker_; - rtc::TaskQueue* worker_queue_; + rtc::TaskQueue* rtp_transport_queue_; const bool allocate_audio_without_feedback_; const bool force_no_audio_feedback_ = allocate_audio_without_feedback_; @@ -161,22 +172,26 @@ class AudioSendStream final : public webrtc::AudioSendStream, const bool send_side_bwe_with_overhead_; const AudioAllocationConfig allocation_settings_; - webrtc::AudioSendStream::Config config_; + webrtc::AudioSendStream::Config config_ + RTC_GUARDED_BY(worker_thread_checker_); rtc::scoped_refptr audio_state_; const std::unique_ptr channel_send_; RtcEventLog* const event_log_; const bool use_legacy_overhead_calculation_; - int encoder_sample_rate_hz_ = 0; - size_t encoder_num_channels_ = 0; - bool sending_ = false; + int encoder_sample_rate_hz_ RTC_GUARDED_BY(worker_thread_checker_) = 0; + size_t encoder_num_channels_ RTC_GUARDED_BY(worker_thread_checker_) = 0; + bool sending_ RTC_GUARDED_BY(worker_thread_checker_) = false; mutable Mutex audio_level_lock_; // Keeps track of audio level, total audio energy and total samples duration. // https://w3c.github.io/webrtc-stats/#dom-rtcaudiohandlerstats-totalaudioenergy webrtc::voe::AudioLevel audio_level_ RTC_GUARDED_BY(audio_level_lock_); BitrateAllocatorInterface* const bitrate_allocator_ - RTC_GUARDED_BY(worker_queue_); + RTC_GUARDED_BY(rtp_transport_queue_); + // Constrains cached to be accessed from |rtp_transport_queue_|. 
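// Aside: a minimal sketch of the caching pattern adopted here, with
// hypothetical `BitrateLimits`/`LimitsCache` names. Values are computed on the
// worker thread and published to the transport queue with PostTask, so
// callbacks running on that queue (such as OnBitrateUpdated above) can read
// them without taking a lock.
#include "absl/types/optional.h"
#include "api/sequence_checker.h"
#include "rtc_base/task_queue.h"
#include "rtc_base/thread_annotations.h"

struct BitrateLimits {
  int min_bps = 0;
  int max_bps = 0;
};

class LimitsCache {
 public:
  explicit LimitsCache(rtc::TaskQueue* queue) : queue_(queue) {}

  // Worker thread: recompute and publish.
  void Update(BitrateLimits limits) {
    queue_->PostTask([this, limits]() {
      RTC_DCHECK_RUN_ON(queue_);
      limits_ = limits;
    });
  }

  // |queue_| only: read the last published snapshot.
  absl::optional<BitrateLimits> Get() const {
    RTC_DCHECK_RUN_ON(queue_);
    return limits_;
  }

 private:
  rtc::TaskQueue* const queue_;
  absl::optional<BitrateLimits> limits_ RTC_GUARDED_BY(queue_);
};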
+ absl::optional + cached_constraints_ RTC_GUARDED_BY(rtp_transport_queue_) = absl::nullopt; RtpTransportControllerSendInterface* const rtp_transport_; RtpRtcpInterface* const rtp_rtcp_module_; @@ -205,10 +220,12 @@ class AudioSendStream final : public webrtc::AudioSendStream, size_t transport_overhead_per_packet_bytes_ RTC_GUARDED_BY(overhead_per_packet_lock_) = 0; - bool registered_with_allocator_ RTC_GUARDED_BY(worker_queue_) = false; - size_t total_packet_overhead_bytes_ RTC_GUARDED_BY(worker_queue_) = 0; + bool registered_with_allocator_ RTC_GUARDED_BY(worker_thread_checker_) = + false; + size_t total_packet_overhead_bytes_ RTC_GUARDED_BY(worker_thread_checker_) = + 0; absl::optional> frame_length_range_ - RTC_GUARDED_BY(worker_queue_); + RTC_GUARDED_BY(worker_thread_checker_); }; } // namespace internal } // namespace webrtc diff --git a/audio/audio_send_stream_unittest.cc b/audio/audio_send_stream_unittest.cc index f76a8fa255..357e08040c 100644 --- a/audio/audio_send_stream_unittest.cc +++ b/audio/audio_send_stream_unittest.cc @@ -121,7 +121,7 @@ std::unique_ptr SetupAudioEncoderMock( rtc::scoped_refptr SetupEncoderFactoryMock() { rtc::scoped_refptr factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); ON_CALL(*factory.get(), GetSupportedEncoders()) .WillByDefault(Return(std::vector( std::begin(kCodecSpecs), std::end(kCodecSpecs)))); @@ -154,7 +154,7 @@ struct ConfigHelper { audio_processing_( use_null_audio_processing ? nullptr - : new rtc::RefCountedObject>()), + : rtc::make_ref_counted>()), bitrate_allocator_(&limit_observer_), worker_queue_(task_queue_factory_->CreateTaskQueue( "ConfigHelper_worker_queue", @@ -165,8 +165,7 @@ struct ConfigHelper { AudioState::Config config; config.audio_mixer = AudioMixerImpl::Create(); config.audio_processing = audio_processing_; - config.audio_device_module = - new rtc::RefCountedObject(); + config.audio_device_module = rtc::make_ref_counted(); audio_state_ = AudioState::Create(config); SetupDefaultChannelSend(audio_bwe_enabled); @@ -923,7 +922,7 @@ TEST(AudioSendStreamTest, ReconfigureWithFrameEncryptor) { auto new_config = helper.config(); rtc::scoped_refptr mock_frame_encryptor_0( - new rtc::RefCountedObject()); + rtc::make_ref_counted()); new_config.frame_encryptor = mock_frame_encryptor_0; EXPECT_CALL(*helper.channel_send(), SetFrameEncryptor(Ne(nullptr))) .Times(1); @@ -936,7 +935,7 @@ TEST(AudioSendStreamTest, ReconfigureWithFrameEncryptor) { // Updating frame encryptor to a new object should force a call to the // proxy. 
rtc::scoped_refptr mock_frame_encryptor_1( - new rtc::RefCountedObject()); + rtc::make_ref_counted()); new_config.frame_encryptor = mock_frame_encryptor_1; new_config.crypto_options.sframe.require_frame_encryption = true; EXPECT_CALL(*helper.channel_send(), SetFrameEncryptor(Ne(nullptr))) diff --git a/audio/audio_state.cc b/audio/audio_state.cc index 566bae1311..0e60f0372b 100644 --- a/audio/audio_state.cc +++ b/audio/audio_state.cc @@ -187,6 +187,6 @@ void AudioState::UpdateNullAudioPollerState() { rtc::scoped_refptr AudioState::Create( const AudioState::Config& config) { - return new rtc::RefCountedObject(config); + return rtc::make_ref_counted(config); } } // namespace webrtc diff --git a/audio/audio_state.h b/audio/audio_state.h index 5e766428d9..89c748dc4e 100644 --- a/audio/audio_state.h +++ b/audio/audio_state.h @@ -15,11 +15,11 @@ #include #include +#include "api/sequence_checker.h" #include "audio/audio_transport_impl.h" #include "audio/null_audio_poller.h" #include "call/audio_state.h" #include "rtc_base/ref_count.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -65,8 +65,8 @@ class AudioState : public webrtc::AudioState { void UpdateAudioTransportWithSendingStreams(); void UpdateNullAudioPollerState(); - rtc::ThreadChecker thread_checker_; - rtc::ThreadChecker process_thread_checker_; + SequenceChecker thread_checker_; + SequenceChecker process_thread_checker_; const webrtc::AudioState::Config config_; bool recording_enabled_ = true; bool playout_enabled_ = true; diff --git a/audio/audio_state_unittest.cc b/audio/audio_state_unittest.cc index 02fc04e6dc..5f07a7b339 100644 --- a/audio/audio_state_unittest.cc +++ b/audio/audio_state_unittest.cc @@ -90,7 +90,7 @@ struct FakeAsyncAudioProcessingHelper { FakeTaskQueueFactory task_queue_factory_; rtc::scoped_refptr CreateFactory() { - return new rtc::RefCountedObject( + return rtc::make_ref_counted( audio_frame_processor_, task_queue_factory_); } }; @@ -107,10 +107,9 @@ struct ConfigHelper { audio_state_config.audio_processing = params.use_null_audio_processing ? 
nullptr - : new rtc::RefCountedObject< - testing::NiceMock>(); + : rtc::make_ref_counted>(); audio_state_config.audio_device_module = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); if (params.use_async_audio_processing) { audio_state_config.async_audio_processing_factory = async_audio_processing_helper_.CreateFactory(); @@ -183,7 +182,7 @@ TEST_P(AudioStateTest, Create) { TEST_P(AudioStateTest, ConstructDestruct) { ConfigHelper helper(GetParam()); rtc::scoped_refptr audio_state( - new rtc::RefCountedObject(helper.config())); + rtc::make_ref_counted(helper.config())); } TEST_P(AudioStateTest, RecordedAudioArrivesAtSingleStream) { @@ -196,7 +195,7 @@ TEST_P(AudioStateTest, RecordedAudioArrivesAtSingleStream) { } rtc::scoped_refptr audio_state( - new rtc::RefCountedObject(helper.config())); + rtc::make_ref_counted(helper.config())); MockAudioSendStream stream; audio_state->AddSendingStream(&stream, 8000, 2); @@ -245,7 +244,7 @@ TEST_P(AudioStateTest, RecordedAudioArrivesAtMultipleStreams) { } rtc::scoped_refptr audio_state( - new rtc::RefCountedObject(helper.config())); + rtc::make_ref_counted(helper.config())); MockAudioSendStream stream_1; MockAudioSendStream stream_2; @@ -308,7 +307,7 @@ TEST_P(AudioStateTest, EnableChannelSwap) { } rtc::scoped_refptr audio_state( - new rtc::RefCountedObject(helper.config())); + rtc::make_ref_counted(helper.config())); audio_state->SetStereoChannelSwapping(true); diff --git a/audio/channel_receive.cc b/audio/channel_receive.cc index 2788dacf78..57269cd193 100644 --- a/audio/channel_receive.cc +++ b/audio/channel_receive.cc @@ -10,8 +10,6 @@ #include "audio/channel_receive.h" -#include - #include #include #include @@ -22,6 +20,8 @@ #include "api/crypto/frame_decryptor_interface.h" #include "api/frame_transformer_interface.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/sequence_checker.h" +#include "api/task_queue/task_queue_base.h" #include "audio/audio_level.h" #include "audio/channel_receive_frame_transformer_delegate.h" #include "audio/channel_send.h" @@ -33,7 +33,8 @@ #include "modules/pacing/packet_router.h" #include "modules/rtp_rtcp/include/receive_statistics.h" #include "modules/rtp_rtcp/include/remote_ntp_time_estimator.h" -#include "modules/rtp_rtcp/source/absolute_capture_time_receiver.h" +#include "modules/rtp_rtcp/source/absolute_capture_time_interpolator.h" +#include "modules/rtp_rtcp/source/capture_clock_offset_updater.h" #include "modules/rtp_rtcp/source/rtp_header_extensions.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" @@ -46,7 +47,9 @@ #include "rtc_base/numerics/safe_minmax.h" #include "rtc_base/race_checker.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/system/no_unique_address.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/time_utils.h" #include "system_wrappers/include/metrics.h" @@ -78,12 +81,12 @@ AudioCodingModule::Config AcmConfig( return acm_config; } -class ChannelReceive : public ChannelReceiveInterface { +class ChannelReceive : public ChannelReceiveInterface, + public RtcpPacketTypeCounterObserver { public: // Used for receive streams. 
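// Aside: ChannelReceive now also implements RtcpPacketTypeCounterObserver so
// the RTP/RTCP module can report per-SSRC RTCP packet counts back to it (used
// further down to expose `nacks_sent` from GetRTCPStatistics()). A
// stand-alone sketch of that observer role, with a hypothetical
// `NackCounterObserver` class:
#include <cstdint>

#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h"
#include "rtc_base/synchronization/mutex.h"

class NackCounterObserver : public webrtc::RtcpPacketTypeCounterObserver {
 public:
  explicit NackCounterObserver(uint32_t remote_ssrc)
      : remote_ssrc_(remote_ssrc) {}

  void RtcpPacketTypesCounterUpdated(
      uint32_t ssrc,
      const webrtc::RtcpPacketTypeCounter& packet_counter) override {
    if (ssrc != remote_ssrc_)
      return;
    webrtc::MutexLock lock(&mutex_);
    nack_packets_ = packet_counter.nack_packets;
  }

  uint32_t nacks_sent() const {
    webrtc::MutexLock lock(&mutex_);
    return nack_packets_;
  }

 private:
  const uint32_t remote_ssrc_;
  mutable webrtc::Mutex mutex_;
  uint32_t nack_packets_ = 0;
};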
ChannelReceive( Clock* clock, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, AudioDeviceModule* audio_device_module, Transport* rtcp_send_transport, @@ -162,6 +165,8 @@ class ChannelReceive : public ChannelReceiveInterface { int PreferredSampleRate() const override; + void SetSourceTracker(SourceTracker* source_tracker) override; + // Associate to a send channel. // Used for obtaining RTT for a receive-only channel. void SetAssociatedSendChannel(const ChannelSendInterface* channel) override; @@ -172,44 +177,55 @@ class ChannelReceive : public ChannelReceiveInterface { rtc::scoped_refptr frame_transformer) override; + void SetFrameDecryptor(rtc::scoped_refptr + frame_decryptor) override; + + void OnLocalSsrcChange(uint32_t local_ssrc) override; + uint32_t GetLocalSsrc() const override; + + void RtcpPacketTypesCounterUpdated( + uint32_t ssrc, + const RtcpPacketTypeCounter& packet_counter) override; + private: void ReceivePacket(const uint8_t* packet, size_t packet_length, - const RTPHeader& header); + const RTPHeader& header) + RTC_RUN_ON(worker_thread_checker_); int ResendPackets(const uint16_t* sequence_numbers, int length); - void UpdatePlayoutTimestamp(bool rtcp, int64_t now_ms); + void UpdatePlayoutTimestamp(bool rtcp, int64_t now_ms) + RTC_RUN_ON(worker_thread_checker_); int GetRtpTimestampRateHz() const; int64_t GetRTT() const; void OnReceivedPayloadData(rtc::ArrayView payload, - const RTPHeader& rtpHeader); + const RTPHeader& rtpHeader) + RTC_RUN_ON(worker_thread_checker_); void InitFrameTransformerDelegate( - rtc::scoped_refptr frame_transformer); - - bool Playing() const { - MutexLock lock(&playing_lock_); - return playing_; - } + rtc::scoped_refptr frame_transformer) + RTC_RUN_ON(worker_thread_checker_); // Thread checkers document and lock usage of some methods to specific threads // we know about. The goal is to eventually split up voe::ChannelReceive into // parts with single-threaded semantics, and thereby reduce the need for // locks. - rtc::ThreadChecker worker_thread_checker_; - rtc::ThreadChecker module_process_thread_checker_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker worker_thread_checker_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker network_thread_checker_; + + TaskQueueBase* const worker_thread_; + ScopedTaskSafety worker_safety_; + // Methods accessed from audio and video threads are checked for sequential- // only access. We don't necessarily own and control these threads, so thread // checkers cannot be used. E.g. Chromium may transfer "ownership" from one // audio thread to another, but access is still sequential. rtc::RaceChecker audio_thread_race_checker_; - rtc::RaceChecker video_capture_thread_race_checker_; Mutex callback_mutex_; Mutex volume_settings_mutex_; - mutable Mutex playing_lock_; - bool playing_ RTC_GUARDED_BY(&playing_lock_) = false; + bool playing_ RTC_GUARDED_BY(worker_thread_checker_) = false; RtcEventLog* const event_log_; @@ -219,34 +235,34 @@ class ChannelReceive : public ChannelReceiveInterface { std::unique_ptr rtp_receive_statistics_; std::unique_ptr rtp_rtcp_; const uint32_t remote_ssrc_; + SourceTracker* source_tracker_ = nullptr; // Info for GetSyncInfo is updated on network or worker thread, and queried on // the worker thread. 
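// Aside: a minimal sketch of the annotation pattern this CL applies across
// ChannelReceive: several mutexes are dropped in favour of members that are
// RTC_GUARDED_BY a SequenceChecker, with accessors asserting the calling
// sequence via RTC_DCHECK_RUN_ON. `PlaybackState` is a hypothetical class for
// illustration only.
#include "api/sequence_checker.h"
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"

class PlaybackState {
 public:
  void SetPlaying(bool playing) {
    RTC_DCHECK_RUN_ON(&checker_);
    playing_ = playing;
  }

  bool playing() const {
    RTC_DCHECK_RUN_ON(&checker_);
    return playing_;
  }

 private:
  RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker checker_;
  bool playing_ RTC_GUARDED_BY(checker_) = false;
};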
- mutable Mutex sync_info_lock_; absl::optional last_received_rtp_timestamp_ - RTC_GUARDED_BY(&sync_info_lock_); + RTC_GUARDED_BY(&worker_thread_checker_); absl::optional last_received_rtp_system_time_ms_ - RTC_GUARDED_BY(&sync_info_lock_); + RTC_GUARDED_BY(&worker_thread_checker_); // The AcmReceiver is thread safe, using its own lock. acm2::AcmReceiver acm_receiver_; AudioSinkInterface* audio_sink_ = nullptr; AudioLevel _outputAudioLevel; + Clock* const clock_; RemoteNtpTimeEstimator ntp_estimator_ RTC_GUARDED_BY(ts_stats_lock_); // Timestamp of the audio pulled from NetEq. absl::optional jitter_buffer_playout_timestamp_; - mutable Mutex video_sync_lock_; - uint32_t playout_timestamp_rtp_ RTC_GUARDED_BY(video_sync_lock_); + uint32_t playout_timestamp_rtp_ RTC_GUARDED_BY(worker_thread_checker_); absl::optional playout_timestamp_rtp_time_ms_ - RTC_GUARDED_BY(video_sync_lock_); - uint32_t playout_delay_ms_ RTC_GUARDED_BY(video_sync_lock_); + RTC_GUARDED_BY(worker_thread_checker_); + uint32_t playout_delay_ms_ RTC_GUARDED_BY(worker_thread_checker_); absl::optional playout_timestamp_ntp_ - RTC_GUARDED_BY(video_sync_lock_); + RTC_GUARDED_BY(worker_thread_checker_); absl::optional playout_timestamp_ntp_time_ms_ - RTC_GUARDED_BY(video_sync_lock_); + RTC_GUARDED_BY(worker_thread_checker_); mutable Mutex ts_stats_lock_; @@ -257,36 +273,64 @@ class ChannelReceive : public ChannelReceiveInterface { // frame. int64_t capture_start_ntp_time_ms_ RTC_GUARDED_BY(ts_stats_lock_); - // uses - ProcessThread* _moduleProcessThreadPtr; AudioDeviceModule* _audioDeviceModulePtr; float _outputGain RTC_GUARDED_BY(volume_settings_mutex_); - // An associated send channel. - mutable Mutex assoc_send_channel_lock_; const ChannelSendInterface* associated_send_channel_ - RTC_GUARDED_BY(assoc_send_channel_lock_); + RTC_GUARDED_BY(network_thread_checker_); PacketRouter* packet_router_ = nullptr; - rtc::ThreadChecker construction_thread_; + SequenceChecker construction_thread_; // E2EE Audio Frame Decryption - rtc::scoped_refptr frame_decryptor_; + rtc::scoped_refptr frame_decryptor_ + RTC_GUARDED_BY(worker_thread_checker_); webrtc::CryptoOptions crypto_options_; - webrtc::AbsoluteCaptureTimeReceiver absolute_capture_time_receiver_; + webrtc::AbsoluteCaptureTimeInterpolator absolute_capture_time_interpolator_ + RTC_GUARDED_BY(worker_thread_checker_); + + webrtc::CaptureClockOffsetUpdater capture_clock_offset_updater_; rtc::scoped_refptr frame_transformer_delegate_; + + // Counter that's used to control the frequency of reporting histograms + // from the `GetAudioFrameWithInfo` callback. + int audio_frame_interval_count_ RTC_GUARDED_BY(audio_thread_race_checker_) = + 0; + // Controls how many callbacks we let pass by before reporting callback stats. + // A value of 100 means 100 callbacks, each one of which represents 10ms worth + // of data, so the stats reporting frequency will be 1Hz (modulo failures). + constexpr static int kHistogramReportingInterval = 100; + + mutable Mutex rtcp_counter_mutex_; + RtcpPacketTypeCounter rtcp_packet_type_counter_ + RTC_GUARDED_BY(rtcp_counter_mutex_); }; void ChannelReceive::OnReceivedPayloadData( rtc::ArrayView payload, const RTPHeader& rtpHeader) { - if (!Playing()) { + if (!playing_) { // Avoid inserting into NetEQ when we are not playing. Count the // packet as discarded. + + // If we have a source_tracker_, tell it that the frame has been + // "delivered". 
Normally, this happens in AudioReceiveStream when audio + // frames are pulled out, but when playout is muted, nothing is pulling + // frames. The downside of this approach is that frames delivered this way + // won't be delayed for playout, and therefore will be unsynchronized with + // (a) audio delay when playing and (b) any audio/video synchronization. But + // the alternative is that muting playout also stops the SourceTracker from + // updating RtpSource information. + if (source_tracker_) { + RtpPacketInfos::vector_type packet_vector = { + RtpPacketInfo(rtpHeader, clock_->CurrentTime())}; + source_tracker_->OnFrameDelivered(RtpPacketInfos(packet_vector)); + } + return; } @@ -312,18 +356,20 @@ void ChannelReceive::InitFrameTransformerDelegate( rtc::scoped_refptr frame_transformer) { RTC_DCHECK(frame_transformer); RTC_DCHECK(!frame_transformer_delegate_); + RTC_DCHECK(worker_thread_->IsCurrent()); // Pass a callback to ChannelReceive::OnReceivedPayloadData, to be called by // the delegate to receive transformed audio. ChannelReceiveFrameTransformerDelegate::ReceiveFrameCallback receive_audio_callback = [this](rtc::ArrayView packet, const RTPHeader& header) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); OnReceivedPayloadData(packet, header); }; frame_transformer_delegate_ = - new rtc::RefCountedObject( + rtc::make_ref_counted( std::move(receive_audio_callback), std::move(frame_transformer), - rtc::Thread::Current()); + worker_thread_); frame_transformer_delegate_->Init(); } @@ -418,17 +464,37 @@ AudioMixer::Source::AudioFrameInfo ChannelReceive::GetAudioFrameWithInfo( } } - { - RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.TargetJitterBufferDelayMs", - acm_receiver_.TargetDelayMs()); - const int jitter_buffer_delay = acm_receiver_.FilteredCurrentDelayMs(); - MutexLock lock(&video_sync_lock_); - RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.ReceiverDelayEstimateMs", - jitter_buffer_delay + playout_delay_ms_); - RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.ReceiverJitterBufferDelayMs", - jitter_buffer_delay); - RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.ReceiverDeviceDelayMs", - playout_delay_ms_); + // Fill in local capture clock offset in |audio_frame->packet_infos_|. 
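// The offset written here is in Q32.32 fixed point: the upper 32 bits carry
// whole seconds and the lower 32 bits the fraction, matching the wire format
// of the absolute capture time extension. For example, a measured
// remote-to-local offset of 250 ms converts (via Int64MsToQ32x32, used in
// ReceivedRTCPPacket() below) to 250 * 2^32 / 1000 == 0x0000'0000'4000'0000,
// i.e. 0.25 s.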
+ RtpPacketInfos::vector_type packet_infos; + for (auto& packet_info : audio_frame->packet_infos_) { + absl::optional local_capture_clock_offset; + if (packet_info.absolute_capture_time().has_value()) { + local_capture_clock_offset = + capture_clock_offset_updater_.AdjustEstimatedCaptureClockOffset( + packet_info.absolute_capture_time() + ->estimated_capture_clock_offset); + } + RtpPacketInfo new_packet_info(packet_info); + new_packet_info.set_local_capture_clock_offset(local_capture_clock_offset); + packet_infos.push_back(std::move(new_packet_info)); + } + audio_frame->packet_infos_ = RtpPacketInfos(packet_infos); + + ++audio_frame_interval_count_; + if (audio_frame_interval_count_ >= kHistogramReportingInterval) { + audio_frame_interval_count_ = 0; + worker_thread_->PostTask(ToQueuedTask(worker_safety_, [this]() { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.TargetJitterBufferDelayMs", + acm_receiver_.TargetDelayMs()); + const int jitter_buffer_delay = acm_receiver_.FilteredCurrentDelayMs(); + RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.ReceiverDelayEstimateMs", + jitter_buffer_delay + playout_delay_ms_); + RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.ReceiverJitterBufferDelayMs", + jitter_buffer_delay); + RTC_HISTOGRAM_COUNTS_1000("WebRTC.Audio.ReceiverDeviceDelayMs", + playout_delay_ms_); + })); } return muted ? AudioMixer::Source::AudioFrameInfo::kMuted @@ -442,9 +508,12 @@ int ChannelReceive::PreferredSampleRate() const { acm_receiver_.last_output_sample_rate_hz()); } +void ChannelReceive::SetSourceTracker(SourceTracker* source_tracker) { + source_tracker_ = source_tracker; +} + ChannelReceive::ChannelReceive( Clock* clock, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, AudioDeviceModule* audio_device_module, Transport* rtcp_send_transport, @@ -460,7 +529,8 @@ ChannelReceive::ChannelReceive( rtc::scoped_refptr frame_decryptor, const webrtc::CryptoOptions& crypto_options, rtc::scoped_refptr frame_transformer) - : event_log_(rtc_event_log), + : worker_thread_(TaskQueueBase::Current()), + event_log_(rtc_event_log), rtp_receive_statistics_(ReceiveStatistics::Create(clock)), remote_ssrc_(remote_ssrc), acm_receiver_(AcmConfig(neteq_factory, @@ -469,25 +539,23 @@ ChannelReceive::ChannelReceive( jitter_buffer_max_packets, jitter_buffer_fast_playout)), _outputAudioLevel(), + clock_(clock), ntp_estimator_(clock), playout_timestamp_rtp_(0), playout_delay_ms_(0), rtp_ts_wraparound_handler_(new rtc::TimestampWrapAroundHandler()), capture_start_rtp_time_stamp_(-1), capture_start_ntp_time_ms_(-1), - _moduleProcessThreadPtr(module_process_thread), _audioDeviceModulePtr(audio_device_module), _outputGain(1.0f), associated_send_channel_(nullptr), frame_decryptor_(frame_decryptor), crypto_options_(crypto_options), - absolute_capture_time_receiver_(clock) { - // TODO(nisse): Use _moduleProcessThreadPtr instead? 
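// Aside: ChannelReceive now records the worker queue via
// TaskQueueBase::Current() at construction (see the initializer list above)
// and posts back to it with ToQueuedTask + ScopedTaskSafety, e.g. for the
// throttled histogram reporting in GetAudioFrameWithInfo(). A minimal sketch
// of that pattern with a hypothetical `Reporter` class:
#include "api/task_queue/task_queue_base.h"
#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"

class Reporter {
 public:
  Reporter() : worker_(webrtc::TaskQueueBase::Current()) {}

  // May be called from another thread (e.g. an audio callback thread). The
  // posted task runs on |worker_| only if this Reporter is still alive.
  void ReportAsync(int value) {
    worker_->PostTask(webrtc::ToQueuedTask(safety_, [this, value]() {
      last_reported_ = value;
    }));
  }

 private:
  webrtc::TaskQueueBase* const worker_;
  webrtc::ScopedTaskSafety safety_;
  int last_reported_ = 0;
};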
- module_process_thread_checker_.Detach(); - - RTC_DCHECK(module_process_thread); + absolute_capture_time_interpolator_(clock) { RTC_DCHECK(audio_device_module); + network_thread_checker_.Detach(); + acm_receiver_.ResetInitialDelay(); acm_receiver_.SetMinimumDelay(0); acm_receiver_.SetMaximumDelay(0); @@ -504,6 +572,7 @@ ChannelReceive::ChannelReceive( configuration.receive_statistics = rtp_receive_statistics_.get(); configuration.event_log = event_log_; configuration.local_media_ssrc = local_ssrc; + configuration.rtcp_packet_type_counter_observer = this; if (frame_transformer) InitFrameTransformerDelegate(std::move(frame_transformer)); @@ -512,53 +581,46 @@ ChannelReceive::ChannelReceive( rtp_rtcp_->SetSendingMediaStatus(false); rtp_rtcp_->SetRemoteSSRC(remote_ssrc_); - _moduleProcessThreadPtr->RegisterModule(rtp_rtcp_.get(), RTC_FROM_HERE); - // Ensure that RTCP is enabled for the created channel. rtp_rtcp_->SetRTCPStatus(RtcpMode::kCompound); } ChannelReceive::~ChannelReceive() { - RTC_DCHECK(construction_thread_.IsCurrent()); + RTC_DCHECK_RUN_ON(&construction_thread_); // Resets the delegate's callback to ChannelReceive::OnReceivedPayloadData. if (frame_transformer_delegate_) frame_transformer_delegate_->Reset(); StopPlayout(); - - if (_moduleProcessThreadPtr) - _moduleProcessThreadPtr->DeRegisterModule(rtp_rtcp_.get()); } void ChannelReceive::SetSink(AudioSinkInterface* sink) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); MutexLock lock(&callback_mutex_); audio_sink_ = sink; } void ChannelReceive::StartPlayout() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - MutexLock lock(&playing_lock_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); playing_ = true; } void ChannelReceive::StopPlayout() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - MutexLock lock(&playing_lock_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); playing_ = false; _outputAudioLevel.ResetLevelFullRange(); } absl::optional> ChannelReceive::GetReceiveCodec() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return acm_receiver_.LastDecoder(); } void ChannelReceive::SetReceiveCodecs( const std::map& codecs) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); for (const auto& kv : codecs) { RTC_DCHECK_GE(kv.second.clockrate_hz, 1000); payload_type_frequencies_[kv.first] = kv.second.clockrate_hz; @@ -566,15 +628,15 @@ void ChannelReceive::SetReceiveCodecs( acm_receiver_.SetCodecs(codecs); } -// May be called on either worker thread or network thread. void ChannelReceive::OnRtpPacket(const RtpPacketReceived& packet) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + // TODO(bugs.webrtc.org/11993): Expect to be called exclusively on the + // network thread. Once that's done, the same applies to + // UpdatePlayoutTimestamp and int64_t now_ms = rtc::TimeMillis(); - { - MutexLock lock(&sync_info_lock_); - last_received_rtp_timestamp_ = packet.Timestamp(); - last_received_rtp_system_time_ms_ = now_ms; - } + last_received_rtp_timestamp_ = packet.Timestamp(); + last_received_rtp_system_time_ms_ = now_ms; // Store playout timestamp for the received RTP packet UpdatePlayoutTimestamp(false, now_ms); @@ -593,9 +655,9 @@ void ChannelReceive::OnRtpPacket(const RtpPacketReceived& packet) { // Interpolates absolute capture timestamp RTP header extension. 
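// Note: the former AbsoluteCaptureTimeReceiver has been split: interpolating
// the extension for packets that do not carry it stays in
// AbsoluteCaptureTimeInterpolator (below), while applying the remote-to-local
// clock offset now lives in CaptureClockOffsetUpdater, which is fed from RTCP
// sender reports in ReceivedRTCPPacket() and applied per packet in
// GetAudioFrameWithInfo().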
header.extension.absolute_capture_time = - absolute_capture_time_receiver_.OnReceivePacket( - AbsoluteCaptureTimeReceiver::GetSource(header.ssrc, - header.arrOfCSRCs), + absolute_capture_time_interpolator_.OnReceivePacket( + AbsoluteCaptureTimeInterpolator::GetSource(header.ssrc, + header.arrOfCSRCs), header.timestamp, rtc::saturated_cast(packet_copy.payload_type_frequency()), header.extension.absolute_capture_time); @@ -607,7 +669,7 @@ void ChannelReceive::ReceivePacket(const uint8_t* packet, size_t packet_length, const RTPHeader& header) { const uint8_t* payload = packet + header.headerLength; - assert(packet_length >= header.headerLength); + RTC_DCHECK_GE(packet_length, header.headerLength); size_t payload_length = packet_length - header.headerLength; size_t payload_data_length = payload_length - header.paddingLength; @@ -654,8 +716,11 @@ void ChannelReceive::ReceivePacket(const uint8_t* packet, } } -// May be called on either worker thread or network thread. void ChannelReceive::ReceivedRTCPPacket(const uint8_t* data, size_t length) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + // TODO(bugs.webrtc.org/11993): Expect to be called exclusively on the + // network thread. + // Store playout timestamp for the received RTCP packet UpdatePlayoutTimestamp(true, rtc::TimeMillis()); @@ -671,8 +736,10 @@ void ChannelReceive::ReceivedRTCPPacket(const uint8_t* data, size_t length) { uint32_t ntp_secs = 0; uint32_t ntp_frac = 0; uint32_t rtp_timestamp = 0; - if (0 != - rtp_rtcp_->RemoteNTP(&ntp_secs, &ntp_frac, NULL, NULL, &rtp_timestamp)) { + if (rtp_rtcp_->RemoteNTP(&ntp_secs, &ntp_frac, + /*rtcp_arrival_time_secs=*/nullptr, + /*rtcp_arrival_time_frac=*/nullptr, + &rtp_timestamp) != 0) { // Waiting for RTCP. return; } @@ -680,33 +747,39 @@ void ChannelReceive::ReceivedRTCPPacket(const uint8_t* data, size_t length) { { MutexLock lock(&ts_stats_lock_); ntp_estimator_.UpdateRtcpTimestamp(rtt, ntp_secs, ntp_frac, rtp_timestamp); + absl::optional remote_to_local_clock_offset_ms = + ntp_estimator_.EstimateRemoteToLocalClockOffsetMs(); + if (remote_to_local_clock_offset_ms.has_value()) { + capture_clock_offset_updater_.SetRemoteToLocalClockOffset( + Int64MsToQ32x32(*remote_to_local_clock_offset_ms)); + } } } int ChannelReceive::GetSpeechOutputLevelFullRange() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return _outputAudioLevel.LevelFullRange(); } double ChannelReceive::GetTotalOutputEnergy() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return _outputAudioLevel.TotalEnergy(); } double ChannelReceive::GetTotalOutputDuration() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return _outputAudioLevel.TotalDuration(); } void ChannelReceive::SetChannelOutputVolumeScaling(float scaling) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); MutexLock lock(&volume_settings_mutex_); _outputGain = scaling; } void ChannelReceive::RegisterReceiverCongestionControlObjects( PacketRouter* packet_router) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(packet_router); RTC_DCHECK(!packet_router_); constexpr bool remb_candidate = false; @@ -715,19 +788,18 @@ void ChannelReceive::RegisterReceiverCongestionControlObjects( } void ChannelReceive::ResetReceiverCongestionControlObjects() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + 
RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(packet_router_); packet_router_->RemoveReceiveRtpModule(rtp_rtcp_.get()); packet_router_ = nullptr; } CallReceiveStatistics ChannelReceive::GetRTCPStatistics() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - // --- RtcpStatistics + RTC_DCHECK_RUN_ON(&worker_thread_checker_); CallReceiveStatistics stats; - // The jitter statistics is updated for each received RTP packet and is - // based on received packets. + // The jitter statistics is updated for each received RTP packet and is based + // on received packets. RtpReceiveStats rtp_stats; StreamStatistician* statistician = rtp_receive_statistics_->GetStatistician(remote_ssrc_); @@ -738,10 +810,9 @@ CallReceiveStatistics ChannelReceive::GetRTCPStatistics() const { stats.cumulativeLost = rtp_stats.packets_lost; stats.jitterSamples = rtp_stats.jitter; - // --- RTT stats.rttMs = GetRTT(); - // --- Data counters + // Data counters. if (statistician) { stats.payload_bytes_rcvd = rtp_stats.packet_counter.payload_bytes; @@ -758,16 +829,38 @@ CallReceiveStatistics ChannelReceive::GetRTCPStatistics() const { stats.last_packet_received_timestamp_ms = absl::nullopt; } - // --- Timestamps + { + MutexLock lock(&rtcp_counter_mutex_); + stats.nacks_sent = rtcp_packet_type_counter_.nack_packets; + } + + // Timestamps. { MutexLock lock(&ts_stats_lock_); stats.capture_start_ntp_time_ms_ = capture_start_ntp_time_ms_; } + + absl::optional rtcp_sr_stats = + rtp_rtcp_->GetSenderReportStats(); + if (rtcp_sr_stats.has_value()) { + // Number of seconds since 1900 January 1 00:00 GMT (see + // https://tools.ietf.org/html/rfc868). + constexpr int64_t kNtpJan1970Millisecs = + 2208988800 * rtc::kNumMillisecsPerSec; + stats.last_sender_report_timestamp_ms = + rtcp_sr_stats->last_arrival_timestamp.ToMs() - kNtpJan1970Millisecs; + stats.last_sender_report_remote_timestamp_ms = + rtcp_sr_stats->last_remote_timestamp.ToMs() - kNtpJan1970Millisecs; + stats.sender_reports_packets_sent = rtcp_sr_stats->packets_sent; + stats.sender_reports_bytes_sent = rtcp_sr_stats->bytes_sent; + stats.sender_reports_reports_count = rtcp_sr_stats->reports_count; + } + return stats; } void ChannelReceive::SetNACKStatus(bool enable, int max_packets) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); // None of these functions can fail. if (enable) { rtp_receive_statistics_->SetMaxReorderingThreshold(max_packets); @@ -785,47 +878,80 @@ int ChannelReceive::ResendPackets(const uint16_t* sequence_numbers, return rtp_rtcp_->SendNACK(sequence_numbers, length); } +void ChannelReceive::RtcpPacketTypesCounterUpdated( + uint32_t ssrc, + const RtcpPacketTypeCounter& packet_counter) { + if (ssrc != remote_ssrc_) { + return; + } + MutexLock lock(&rtcp_counter_mutex_); + rtcp_packet_type_counter_ = packet_counter; +} + void ChannelReceive::SetAssociatedSendChannel( const ChannelSendInterface* channel) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - MutexLock lock(&assoc_send_channel_lock_); + RTC_DCHECK_RUN_ON(&network_thread_checker_); associated_send_channel_ = channel; } void ChannelReceive::SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); // Depending on when the channel is created, the transformer might be set // twice. Don't replace the delegate if it was already initialized. 
- if (!frame_transformer || frame_transformer_delegate_) + if (!frame_transformer || frame_transformer_delegate_) { + RTC_NOTREACHED() << "Not setting the transformer?"; return; + } + InitFrameTransformerDelegate(std::move(frame_transformer)); } +void ChannelReceive::SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + // TODO(bugs.webrtc.org/11993): Expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + frame_decryptor_ = std::move(frame_decryptor); +} + +void ChannelReceive::OnLocalSsrcChange(uint32_t local_ssrc) { + // TODO(bugs.webrtc.org/11993): Expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + rtp_rtcp_->SetLocalSsrc(local_ssrc); +} + +uint32_t ChannelReceive::GetLocalSsrc() const { + // TODO(bugs.webrtc.org/11993): Expect to be called on the network thread. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + return rtp_rtcp_->local_media_ssrc(); +} + NetworkStatistics ChannelReceive::GetNetworkStatistics( bool get_and_clear_legacy_stats) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); NetworkStatistics stats; acm_receiver_.GetNetworkStatistics(&stats, get_and_clear_legacy_stats); return stats; } AudioDecodingCallStats ChannelReceive::GetDecodingCallStatistics() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); AudioDecodingCallStats stats; acm_receiver_.GetDecodingCallStatistics(&stats); return stats; } uint32_t ChannelReceive::GetDelayEstimate() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent() || - module_process_thread_checker_.IsCurrent()); - MutexLock lock(&video_sync_lock_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + // Return the current jitter buffer delay + playout delay. return acm_receiver_.FilteredCurrentDelayMs() + playout_delay_ms_; } bool ChannelReceive::SetMinimumPlayoutDelay(int delay_ms) { - RTC_DCHECK(module_process_thread_checker_.IsCurrent()); + // TODO(bugs.webrtc.org/11993): This should run on the network thread. + // We get here via RtpStreamsSynchronizer. Once that's done, many (all?) of + // these locks aren't needed. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); // Limit to range accepted by both VoE and ACM, so we're at least getting as // close as possible, instead of failing. 
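// rtc::SafeClamp(value, min, max) clamps using sign- and width-safe
// comparisons; e.g. rtc::SafeClamp(-5, 0, 10000) evaluates to 0.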
delay_ms = rtc::SafeClamp(delay_ms, kVoiceEngineMinMinPlayoutDelayMs, @@ -840,29 +966,24 @@ bool ChannelReceive::SetMinimumPlayoutDelay(int delay_ms) { bool ChannelReceive::GetPlayoutRtpTimestamp(uint32_t* rtp_timestamp, int64_t* time_ms) const { - RTC_DCHECK_RUNS_SERIALIZED(&video_capture_thread_race_checker_); - { - MutexLock lock(&video_sync_lock_); - if (!playout_timestamp_rtp_time_ms_) - return false; - *rtp_timestamp = playout_timestamp_rtp_; - *time_ms = playout_timestamp_rtp_time_ms_.value(); - return true; - } + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + if (!playout_timestamp_rtp_time_ms_) + return false; + *rtp_timestamp = playout_timestamp_rtp_; + *time_ms = playout_timestamp_rtp_time_ms_.value(); + return true; } void ChannelReceive::SetEstimatedPlayoutNtpTimestampMs(int64_t ntp_timestamp_ms, int64_t time_ms) { - RTC_DCHECK_RUNS_SERIALIZED(&video_capture_thread_race_checker_); - MutexLock lock(&video_sync_lock_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); playout_timestamp_ntp_ = ntp_timestamp_ms; playout_timestamp_ntp_time_ms_ = time_ms; } absl::optional ChannelReceive::GetCurrentEstimatedPlayoutNtpTimestampMs(int64_t now_ms) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - MutexLock lock(&video_sync_lock_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (!playout_timestamp_ntp_ || !playout_timestamp_ntp_time_ms_) return absl::nullopt; @@ -879,25 +1000,36 @@ int ChannelReceive::GetBaseMinimumPlayoutDelayMs() const { } absl::optional ChannelReceive::GetSyncInfo() const { - RTC_DCHECK(module_process_thread_checker_.IsCurrent()); + // TODO(bugs.webrtc.org/11993): This should run on the network thread. + // We get here via RtpStreamsSynchronizer. Once that's done, many of + // these locks aren't needed. + RTC_DCHECK_RUN_ON(&worker_thread_checker_); Syncable::Info info; if (rtp_rtcp_->RemoteNTP(&info.capture_time_ntp_secs, - &info.capture_time_ntp_frac, nullptr, nullptr, + &info.capture_time_ntp_frac, + /*rtcp_arrival_time_secs=*/nullptr, + /*rtcp_arrival_time_frac=*/nullptr, &info.capture_time_source_clock) != 0) { return absl::nullopt; } - { - MutexLock lock(&sync_info_lock_); - if (!last_received_rtp_timestamp_ || !last_received_rtp_system_time_ms_) { - return absl::nullopt; - } - info.latest_received_capture_timestamp = *last_received_rtp_timestamp_; - info.latest_receive_time_ms = *last_received_rtp_system_time_ms_; + + if (!last_received_rtp_timestamp_ || !last_received_rtp_system_time_ms_) { + return absl::nullopt; } + info.latest_received_capture_timestamp = *last_received_rtp_timestamp_; + info.latest_receive_time_ms = *last_received_rtp_system_time_ms_; + + int jitter_buffer_delay = acm_receiver_.FilteredCurrentDelayMs(); + info.current_delay_ms = jitter_buffer_delay + playout_delay_ms_; + return info; } +// RTC_RUN_ON(worker_thread_checker_) void ChannelReceive::UpdatePlayoutTimestamp(bool rtcp, int64_t now_ms) { + // TODO(bugs.webrtc.org/11993): Expect to be called exclusively on the + // network thread. Once that's done, we won't need video_sync_lock_. + jitter_buffer_playout_timestamp_ = acm_receiver_.GetPlayoutTimestamp(); if (!jitter_buffer_playout_timestamp_) { @@ -920,14 +1052,11 @@ void ChannelReceive::UpdatePlayoutTimestamp(bool rtcp, int64_t now_ms) { // Remove the playout delay. 
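// (delay_ms * rate / 1000) converts the delay from milliseconds to RTP
// timestamp ticks; e.g. at 48 kHz there are 48 ticks per millisecond, so a
// 60 ms playout delay subtracts 2880 ticks.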
playout_timestamp -= (delay_ms * (GetRtpTimestampRateHz() / 1000)); - { - MutexLock lock(&video_sync_lock_); - if (!rtcp && playout_timestamp != playout_timestamp_rtp_) { - playout_timestamp_rtp_ = playout_timestamp; - playout_timestamp_rtp_time_ms_ = now_ms; - } - playout_delay_ms_ = delay_ms; + if (!rtcp && playout_timestamp != playout_timestamp_rtp_) { + playout_timestamp_rtp_ = playout_timestamp; + playout_timestamp_rtp_time_ms_ = now_ms; } + playout_delay_ms_ = delay_ms; } int ChannelReceive::GetRtpTimestampRateHz() const { @@ -945,37 +1074,32 @@ int ChannelReceive::GetRtpTimestampRateHz() const { } int64_t ChannelReceive::GetRTT() const { - std::vector report_blocks; - rtp_rtcp_->RemoteRTCPStat(&report_blocks); + RTC_DCHECK_RUN_ON(&network_thread_checker_); + std::vector report_blocks = + rtp_rtcp_->GetLatestReportBlockData(); - // TODO(nisse): Could we check the return value from the ->RTT() call below, - // instead of checking if we have any report blocks? if (report_blocks.empty()) { - MutexLock lock(&assoc_send_channel_lock_); - // Tries to get RTT from an associated channel. + // Try fall back on an RTT from an associated channel. if (!associated_send_channel_) { return 0; } return associated_send_channel_->GetRTT(); } - int64_t rtt = 0; - int64_t avg_rtt = 0; - int64_t max_rtt = 0; - int64_t min_rtt = 0; // TODO(nisse): This method computes RTT based on sender reports, even though // a receive stream is not supposed to do that. - if (rtp_rtcp_->RTT(remote_ssrc_, &rtt, &avg_rtt, &min_rtt, &max_rtt) != 0) { - return 0; + for (const ReportBlockData& data : report_blocks) { + if (data.report_block().sender_ssrc == remote_ssrc_) { + return data.last_rtt_ms(); + } } - return rtt; + return 0; } } // namespace std::unique_ptr CreateChannelReceive( Clock* clock, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, AudioDeviceModule* audio_device_module, Transport* rtcp_send_transport, @@ -992,12 +1116,11 @@ std::unique_ptr CreateChannelReceive( const webrtc::CryptoOptions& crypto_options, rtc::scoped_refptr frame_transformer) { return std::make_unique( - clock, module_process_thread, neteq_factory, audio_device_module, - rtcp_send_transport, rtc_event_log, local_ssrc, remote_ssrc, - jitter_buffer_max_packets, jitter_buffer_fast_playout, - jitter_buffer_min_delay_ms, jitter_buffer_enable_rtx_handling, - decoder_factory, codec_pair_id, frame_decryptor, crypto_options, - std::move(frame_transformer)); + clock, neteq_factory, audio_device_module, rtcp_send_transport, + rtc_event_log, local_ssrc, remote_ssrc, jitter_buffer_max_packets, + jitter_buffer_fast_playout, jitter_buffer_min_delay_ms, + jitter_buffer_enable_rtx_handling, decoder_factory, codec_pair_id, + std::move(frame_decryptor), crypto_options, std::move(frame_transformer)); } } // namespace voe diff --git a/audio/channel_receive.h b/audio/channel_receive.h index eef2db425c..deec49feaf 100644 --- a/audio/channel_receive.h +++ b/audio/channel_receive.h @@ -28,6 +28,7 @@ #include "call/rtp_packet_sink_interface.h" #include "call/syncable.h" #include "modules/audio_coding/include/audio_coding_module_typedefs.h" +#include "modules/rtp_rtcp/source/source_tracker.h" #include "system_wrappers/include/clock.h" // TODO(solenberg, nisse): This file contains a few NOLINT marks, to silence @@ -43,7 +44,6 @@ namespace webrtc { class AudioDeviceModule; class FrameDecryptorInterface; class PacketRouter; -class ProcessThread; class RateLimiter; class ReceiveStatistics; class RtcEventLog; @@ -57,13 +57,23 @@ struct 
CallReceiveStatistics { int64_t payload_bytes_rcvd = 0; int64_t header_and_padding_bytes_rcvd = 0; int packetsReceived; - // The capture ntp time (in local timebase) of the first played out audio + uint32_t nacks_sent = 0; + // The capture NTP time (in local timebase) of the first played out audio // frame. int64_t capture_start_ntp_time_ms_; // The timestamp at which the last packet was received, i.e. the time of the // local clock when it was received - not the RTP timestamp of that packet. // https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-lastpacketreceivedtimestamp absl::optional last_packet_received_timestamp_ms; + // Remote outbound stats derived by the received RTCP sender reports. + // Note that the timestamps below correspond to the time elapsed since the + // Unix epoch. + // https://w3c.github.io/webrtc-stats/#remoteoutboundrtpstats-dict* + absl::optional last_sender_report_timestamp_ms; + absl::optional last_sender_report_remote_timestamp_ms; + uint32_t sender_reports_packets_sent = 0; + uint64_t sender_reports_bytes_sent = 0; + uint64_t sender_reports_reports_count = 0; }; namespace voe { @@ -135,6 +145,10 @@ class ChannelReceiveInterface : public RtpPacketSinkInterface { virtual int PreferredSampleRate() const = 0; + // Sets the source tracker to notify about "delivered" packets when output is + // muted. + virtual void SetSourceTracker(SourceTracker* source_tracker) = 0; + // Associate to a send channel. // Used for obtaining RTT for a receive-only channel. virtual void SetAssociatedSendChannel( @@ -145,11 +159,16 @@ class ChannelReceiveInterface : public RtpPacketSinkInterface { virtual void SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) = 0; + + virtual void SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) = 0; + + virtual void OnLocalSsrcChange(uint32_t local_ssrc) = 0; + virtual uint32_t GetLocalSsrc() const = 0; }; std::unique_ptr CreateChannelReceive( Clock* clock, - ProcessThread* module_process_thread, NetEqFactory* neteq_factory, AudioDeviceModule* audio_device_module, Transport* rtcp_send_transport, diff --git a/audio/channel_receive_frame_transformer_delegate.cc b/audio/channel_receive_frame_transformer_delegate.cc index 261afbb100..7e617df780 100644 --- a/audio/channel_receive_frame_transformer_delegate.cc +++ b/audio/channel_receive_frame_transformer_delegate.cc @@ -47,7 +47,7 @@ class TransformableAudioFrame : public TransformableAudioFrameInterface { ChannelReceiveFrameTransformerDelegate::ChannelReceiveFrameTransformerDelegate( ReceiveFrameCallback receive_frame_callback, rtc::scoped_refptr frame_transformer, - rtc::Thread* channel_receive_thread) + TaskQueueBase* channel_receive_thread) : receive_frame_callback_(receive_frame_callback), frame_transformer_(std::move(frame_transformer)), channel_receive_thread_(channel_receive_thread) {} diff --git a/audio/channel_receive_frame_transformer_delegate.h b/audio/channel_receive_frame_transformer_delegate.h index 3227c55914..f59834d24e 100644 --- a/audio/channel_receive_frame_transformer_delegate.h +++ b/audio/channel_receive_frame_transformer_delegate.h @@ -14,7 +14,7 @@ #include #include "api/frame_transformer_interface.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "api/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" #include "rtc_base/thread.h" @@ -32,7 +32,7 @@ class ChannelReceiveFrameTransformerDelegate : public TransformedFrameCallback { ChannelReceiveFrameTransformerDelegate( 
ReceiveFrameCallback receive_frame_callback, rtc::scoped_refptr frame_transformer, - rtc::Thread* channel_receive_thread); + TaskQueueBase* channel_receive_thread); // Registers |this| as callback for |frame_transformer_|, to get the // transformed frames. @@ -67,7 +67,7 @@ class ChannelReceiveFrameTransformerDelegate : public TransformedFrameCallback { RTC_GUARDED_BY(sequence_checker_); rtc::scoped_refptr frame_transformer_ RTC_GUARDED_BY(sequence_checker_); - rtc::Thread* channel_receive_thread_; + TaskQueueBase* const channel_receive_thread_; }; } // namespace webrtc diff --git a/audio/channel_receive_frame_transformer_delegate_unittest.cc b/audio/channel_receive_frame_transformer_delegate_unittest.cc index e7f5a454b8..01aac45b24 100644 --- a/audio/channel_receive_frame_transformer_delegate_unittest.cc +++ b/audio/channel_receive_frame_transformer_delegate_unittest.cc @@ -41,9 +41,9 @@ class MockChannelReceive { TEST(ChannelReceiveFrameTransformerDelegateTest, RegisterTransformedFrameCallbackOnInit) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( ChannelReceiveFrameTransformerDelegate::ReceiveFrameCallback(), mock_frame_transformer, nullptr); EXPECT_CALL(*mock_frame_transformer, RegisterTransformedFrameCallback); @@ -55,9 +55,9 @@ TEST(ChannelReceiveFrameTransformerDelegateTest, TEST(ChannelReceiveFrameTransformerDelegateTest, UnregisterTransformedFrameCallbackOnReset) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( ChannelReceiveFrameTransformerDelegate::ReceiveFrameCallback(), mock_frame_transformer, nullptr); EXPECT_CALL(*mock_frame_transformer, UnregisterTransformedFrameCallback); @@ -69,10 +69,10 @@ TEST(ChannelReceiveFrameTransformerDelegateTest, TEST(ChannelReceiveFrameTransformerDelegateTest, TransformRunsChannelReceiveCallback) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); MockChannelReceive mock_channel; rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( mock_channel.callback(), mock_frame_transformer, rtc::Thread::Current()); rtc::scoped_refptr callback; @@ -100,10 +100,10 @@ TEST(ChannelReceiveFrameTransformerDelegateTest, TEST(ChannelReceiveFrameTransformerDelegateTest, OnTransformedDoesNotRunChannelReceiveCallbackAfterReset) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); MockChannelReceive mock_channel; rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( mock_channel.callback(), mock_frame_transformer, rtc::Thread::Current()); diff --git a/audio/channel_send.cc b/audio/channel_send.cc index d331f0129b..06e9238ce8 100644 --- a/audio/channel_send.cc +++ b/audio/channel_send.cc @@ -21,6 +21,7 @@ #include "api/call/transport.h" #include "api/crypto/frame_encryptor_interface.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/sequence_checker.h" #include "audio/channel_send_frame_transformer_delegate.h" #include "audio/utility/audio_frame_operations.h" #include "call/rtp_transport_controller_send_interface.h" @@ -41,7 +42,6 @@ #include "rtc_base/rate_limiter.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/time_utils.h" 
#include "system_wrappers/include/clock.h" #include "system_wrappers/include/field_trial.h" @@ -60,8 +60,9 @@ class TransportSequenceNumberProxy; class VoERtcpObserver; class ChannelSend : public ChannelSendInterface, - public AudioPacketizationCallback { // receive encoded - // packets from the ACM + public AudioPacketizationCallback, // receive encoded + // packets from the ACM + public RtcpPacketTypeCounterObserver { public: // TODO(nisse): Make OnUplinkPacketLossRate public, and delete friend // declaration. @@ -69,7 +70,6 @@ class ChannelSend : public ChannelSendInterface, ChannelSend(Clock* clock, TaskQueueFactory* task_queue_factory, - ProcessThread* module_process_thread, Transport* rtp_transport, RtcpRttStats* rtcp_rtt_stats, RtcEventLog* rtc_event_log, @@ -151,6 +151,11 @@ class ChannelSend : public ChannelSendInterface, rtc::scoped_refptr frame_transformer) override; + // RtcpPacketTypeCounterObserver. + void RtcpPacketTypesCounterUpdated( + uint32_t ssrc, + const RtcpPacketTypeCounter& packet_counter) override; + private: // From AudioPacketizationCallback in the ACM int32_t SendData(AudioFrameType frameType, @@ -179,8 +184,7 @@ class ChannelSend : public ChannelSendInterface, // specific threads we know about. The goal is to eventually split up // voe::Channel into parts with single-threaded semantics, and thereby reduce // the need for locks. - rtc::ThreadChecker worker_thread_checker_; - rtc::ThreadChecker module_process_thread_checker_; + SequenceChecker worker_thread_checker_; // Methods accessed from audio and video threads are checked for sequential- // only access. We don't necessarily own and control these threads, so thread // checkers cannot be used. E.g. Chromium may transfer "ownership" from one @@ -189,6 +193,7 @@ class ChannelSend : public ChannelSendInterface, mutable Mutex volume_settings_mutex_; + const uint32_t ssrc_; bool sending_ RTC_GUARDED_BY(&worker_thread_checker_) = false; RtcEventLog* const event_log_; @@ -200,7 +205,6 @@ class ChannelSend : public ChannelSendInterface, uint32_t _timeStamp RTC_GUARDED_BY(encoder_queue_); // uses - ProcessThread* const _moduleProcessThreadPtr; RmsLevel rms_level_ RTC_GUARDED_BY(encoder_queue_); bool input_mute_ RTC_GUARDED_BY(volume_settings_mutex_); bool previous_frame_muted_ RTC_GUARDED_BY(encoder_queue_); @@ -218,8 +222,7 @@ class ChannelSend : public ChannelSendInterface, const std::unique_ptr rtp_packet_pacer_proxy_; const std::unique_ptr retransmission_rate_limiter_; - rtc::ThreadChecker construction_thread_; - + SequenceChecker construction_thread_; bool encoder_queue_is_active_ RTC_GUARDED_BY(encoder_queue_) = false; @@ -243,6 +246,10 @@ class ChannelSend : public ChannelSendInterface, rtc::TaskQueue encoder_queue_; const bool fixing_timestamp_stall_; + + mutable Mutex rtcp_counter_mutex_; + RtcpPacketTypeCounter rtcp_packet_type_counter_ + RTC_GUARDED_BY(rtcp_counter_mutex_); }; const int kTelephoneEventAttenuationdB = 10; @@ -264,7 +271,7 @@ class RtpPacketSenderProxy : public RtpPacketSender { } private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; Mutex mutex_; RtpPacketSender* rtp_packet_pacer_ RTC_GUARDED_BY(&mutex_); }; @@ -446,7 +453,6 @@ int32_t ChannelSend::SendRtpAudio(AudioFrameType frameType, ChannelSend::ChannelSend( Clock* clock, TaskQueueFactory* task_queue_factory, - ProcessThread* module_process_thread, Transport* rtp_transport, RtcpRttStats* rtcp_rtt_stats, RtcEventLog* rtc_event_log, @@ -457,10 +463,10 @@ ChannelSend::ChannelSend( uint32_t ssrc, rtc::scoped_refptr 
frame_transformer, TransportFeedbackObserver* feedback_observer) - : event_log_(rtc_event_log), + : ssrc_(ssrc), + event_log_(rtc_event_log), _timeStamp(0), // This is just an offset, RTP module will add it's own // random offset - _moduleProcessThreadPtr(module_process_thread), input_mute_(false), previous_frame_muted_(false), _includeAudioLevelIndication(false), @@ -476,9 +482,6 @@ ChannelSend::ChannelSend( TaskQueueFactory::Priority::NORMAL)), fixing_timestamp_stall_( !field_trial::IsDisabled("WebRTC-Audio-FixTimestampStall")) { - RTC_DCHECK(module_process_thread); - module_process_thread_checker_.Detach(); - audio_coding_.reset(AudioCodingModule::Create(AudioCodingModule::Config())); RtpRtcpInterface::Configuration configuration; @@ -496,6 +499,7 @@ ChannelSend::ChannelSend( retransmission_rate_limiter_.get(); configuration.extmap_allow_mixed = extmap_allow_mixed; configuration.rtcp_report_interval_ms = rtcp_report_interval_ms; + configuration.rtcp_packet_type_counter_observer = this; configuration.local_media_ssrc = ssrc; @@ -505,8 +509,6 @@ ChannelSend::ChannelSend( rtp_sender_audio_ = std::make_unique(configuration.clock, rtp_rtcp_->RtpSender()); - _moduleProcessThreadPtr->RegisterModule(rtp_rtcp_.get(), RTC_FROM_HERE); - // Ensure that RTCP is enabled by default for the created channel. rtp_rtcp_->SetRTCPStatus(RtcpMode::kCompound); @@ -526,9 +528,6 @@ ChannelSend::~ChannelSend() { StopSend(); int error = audio_coding_->RegisterTransportCallback(NULL); RTC_DCHECK_EQ(0, error); - - if (_moduleProcessThreadPtr) - _moduleProcessThreadPtr->DeRegisterModule(rtp_rtcp_.get()); } void ChannelSend::StartSend() { @@ -750,25 +749,20 @@ std::vector ChannelSend::GetRemoteRTCPReportBlocks() const { // Get the report blocks from the latest received RTCP Sender or Receiver // Report. Each element in the vector contains the sender's SSRC and a // report block according to RFC 3550. 
- std::vector rtcp_report_blocks; - - int ret = rtp_rtcp_->RemoteRTCPStat(&rtcp_report_blocks); - RTC_DCHECK_EQ(0, ret); - std::vector report_blocks; - - std::vector::const_iterator it = rtcp_report_blocks.begin(); - for (; it != rtcp_report_blocks.end(); ++it) { + for (const ReportBlockData& data : rtp_rtcp_->GetLatestReportBlockData()) { ReportBlock report_block; - report_block.sender_SSRC = it->sender_ssrc; - report_block.source_SSRC = it->source_ssrc; - report_block.fraction_lost = it->fraction_lost; - report_block.cumulative_num_packets_lost = it->packets_lost; + report_block.sender_SSRC = data.report_block().sender_ssrc; + report_block.source_SSRC = data.report_block().source_ssrc; + report_block.fraction_lost = data.report_block().fraction_lost; + report_block.cumulative_num_packets_lost = data.report_block().packets_lost; report_block.extended_highest_sequence_number = - it->extended_highest_sequence_number; - report_block.interarrival_jitter = it->jitter; - report_block.last_SR_timestamp = it->last_sender_report_timestamp; - report_block.delay_since_last_SR = it->delay_since_last_sender_report; + data.report_block().extended_highest_sequence_number; + report_block.interarrival_jitter = data.report_block().jitter; + report_block.last_SR_timestamp = + data.report_block().last_sender_report_timestamp; + report_block.delay_since_last_SR = + data.report_block().delay_since_last_sender_report; report_blocks.push_back(report_block); } return report_blocks; @@ -796,9 +790,24 @@ CallSendStatistics ChannelSend::GetRTCPStatistics() const { stats.retransmitted_packets_sent = rtp_stats.retransmitted.packets; stats.report_block_datas = rtp_rtcp_->GetLatestReportBlockData(); + { + MutexLock lock(&rtcp_counter_mutex_); + stats.nacks_rcvd = rtcp_packet_type_counter_.nack_packets; + } + return stats; } +void ChannelSend::RtcpPacketTypesCounterUpdated( + uint32_t ssrc, + const RtcpPacketTypeCounter& packet_counter) { + if (ssrc != ssrc_) { + return; + } + MutexLock lock(&rtcp_counter_mutex_); + rtcp_packet_type_counter_ = packet_counter; +} + void ChannelSend::ProcessAndEncodeAudio( std::unique_ptr audio_frame) { RTC_DCHECK_RUNS_SERIALIZED(&audio_thread_race_checker_); @@ -864,29 +873,19 @@ ANAStats ChannelSend::GetANAStatistics() const { } RtpRtcpInterface* ChannelSend::GetRtpRtcp() const { - RTC_DCHECK(module_process_thread_checker_.IsCurrent()); return rtp_rtcp_.get(); } int64_t ChannelSend::GetRTT() const { - std::vector report_blocks; - rtp_rtcp_->RemoteRTCPStat(&report_blocks); - + std::vector report_blocks = + rtp_rtcp_->GetLatestReportBlockData(); if (report_blocks.empty()) { return 0; } - int64_t rtt = 0; - int64_t avg_rtt = 0; - int64_t max_rtt = 0; - int64_t min_rtt = 0; // We don't know in advance the remote ssrc used by the other end's receiver - // reports, so use the SSRC of the first report block for calculating the RTT. - if (rtp_rtcp_->RTT(report_blocks[0].sender_ssrc, &rtt, &avg_rtt, &min_rtt, - &max_rtt) != 0) { - return 0; - } - return rtt; + // reports, so use the first report block for the RTT. 
+ return report_blocks.front().last_rtt_ms(); } void ChannelSend::SetFrameEncryptor( @@ -934,7 +933,7 @@ void ChannelSend::InitFrameTransformerDelegate( absolute_capture_timestamp_ms); }; frame_transformer_delegate_ = - new rtc::RefCountedObject( + rtc::make_ref_counted( std::move(send_audio_callback), std::move(frame_transformer), &encoder_queue_); frame_transformer_delegate_->Init(); @@ -945,7 +944,6 @@ void ChannelSend::InitFrameTransformerDelegate( std::unique_ptr CreateChannelSend( Clock* clock, TaskQueueFactory* task_queue_factory, - ProcessThread* module_process_thread, Transport* rtp_transport, RtcpRttStats* rtcp_rtt_stats, RtcEventLog* rtc_event_log, @@ -957,10 +955,10 @@ std::unique_ptr CreateChannelSend( rtc::scoped_refptr frame_transformer, TransportFeedbackObserver* feedback_observer) { return std::make_unique( - clock, task_queue_factory, module_process_thread, rtp_transport, - rtcp_rtt_stats, rtc_event_log, frame_encryptor, crypto_options, - extmap_allow_mixed, rtcp_report_interval_ms, ssrc, - std::move(frame_transformer), feedback_observer); + clock, task_queue_factory, rtp_transport, rtcp_rtt_stats, rtc_event_log, + frame_encryptor, crypto_options, extmap_allow_mixed, + rtcp_report_interval_ms, ssrc, std::move(frame_transformer), + feedback_observer); } } // namespace voe diff --git a/audio/channel_send.h b/audio/channel_send.h index 2e23ef5d2d..67391af956 100644 --- a/audio/channel_send.h +++ b/audio/channel_send.h @@ -28,7 +28,6 @@ namespace webrtc { class FrameEncryptorInterface; -class ProcessThread; class RtcEventLog; class RtpTransportControllerSendInterface; @@ -46,6 +45,7 @@ struct CallSendStatistics { // ReportBlockData represents the latest Report Block that was received for // that pair. std::vector report_block_datas; + uint32_t nacks_rcvd; }; // See section 6.4.2 in http://www.ietf.org/rfc/rfc3550.txt for details. 
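For reference, a minimal sketch (not part of the patch) of the RTT lookup pattern that the GetRTT() hunks above converge on: both ChannelReceive and ChannelSend now read ReportBlockData from RtpRtcpInterface::GetLatestReportBlockData() instead of the removed RemoteRTCPStat()/RTT() pair. The helper name and the include path are illustrative assumptions; the ReportBlockData accessors are the ones used in the diff itself.

    #include <vector>

    #include "absl/types/optional.h"
    #include "modules/rtp_rtcp/include/report_block_data.h"  // assumed path

    // Returns the most recent RTT in ms, or 0 if no report block is usable.
    // ChannelReceive passes its known remote SSRC; ChannelSend passes
    // absl::nullopt and therefore takes the first block, since the remote
    // SSRC of the other end's receiver reports is not known in advance.
    int64_t RttFromReportBlocks(
        const std::vector<webrtc::ReportBlockData>& blocks,
        absl::optional<uint32_t> remote_ssrc) {
      for (const webrtc::ReportBlockData& data : blocks) {
        if (!remote_ssrc || data.report_block().sender_ssrc == *remote_ssrc)
          return data.last_rtt_ms();
      }
      return 0;
    }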
@@ -126,7 +126,6 @@ class ChannelSendInterface { std::unique_ptr CreateChannelSend( Clock* clock, TaskQueueFactory* task_queue_factory, - ProcessThread* module_process_thread, Transport* rtp_transport, RtcpRttStats* rtcp_rtt_stats, RtcEventLog* rtc_event_log, diff --git a/audio/channel_send_frame_transformer_delegate.h b/audio/channel_send_frame_transformer_delegate.h index 531d1bc110..9b7eb33b5c 100644 --- a/audio/channel_send_frame_transformer_delegate.h +++ b/audio/channel_send_frame_transformer_delegate.h @@ -14,10 +14,10 @@ #include #include "api/frame_transformer_interface.h" +#include "api/sequence_checker.h" #include "modules/audio_coding/include/audio_coding_module_typedefs.h" #include "rtc_base/buffer.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_queue.h" namespace webrtc { diff --git a/audio/channel_send_frame_transformer_delegate_unittest.cc b/audio/channel_send_frame_transformer_delegate_unittest.cc index e2f3647c0a..2ec78f8922 100644 --- a/audio/channel_send_frame_transformer_delegate_unittest.cc +++ b/audio/channel_send_frame_transformer_delegate_unittest.cc @@ -53,9 +53,9 @@ class MockChannelSend { TEST(ChannelSendFrameTransformerDelegateTest, RegisterTransformedFrameCallbackOnInit) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( ChannelSendFrameTransformerDelegate::SendFrameCallback(), mock_frame_transformer, nullptr); EXPECT_CALL(*mock_frame_transformer, RegisterTransformedFrameCallback); @@ -67,9 +67,9 @@ TEST(ChannelSendFrameTransformerDelegateTest, TEST(ChannelSendFrameTransformerDelegateTest, UnregisterTransformedFrameCallbackOnReset) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( ChannelSendFrameTransformerDelegate::SendFrameCallback(), mock_frame_transformer, nullptr); EXPECT_CALL(*mock_frame_transformer, UnregisterTransformedFrameCallback); @@ -82,10 +82,10 @@ TEST(ChannelSendFrameTransformerDelegateTest, TransformRunsChannelSendCallback) { TaskQueueForTest channel_queue("channel_queue"); rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); MockChannelSend mock_channel; rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( mock_channel.callback(), mock_frame_transformer, &channel_queue); rtc::scoped_refptr callback; EXPECT_CALL(*mock_frame_transformer, RegisterTransformedFrameCallback) @@ -112,10 +112,10 @@ TEST(ChannelSendFrameTransformerDelegateTest, OnTransformedDoesNotRunChannelSendCallbackAfterReset) { TaskQueueForTest channel_queue("channel_queue"); rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); MockChannelSend mock_channel; rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + rtc::make_ref_counted( mock_channel.callback(), mock_frame_transformer, &channel_queue); delegate->Reset(); diff --git a/audio/mock_voe_channel_proxy.h b/audio/mock_voe_channel_proxy.h index 52e5b2fc83..ea2a2ac3f0 100644 --- a/audio/mock_voe_channel_proxy.h +++ b/audio/mock_voe_channel_proxy.h @@ -17,6 +17,7 @@ #include #include +#include "api/crypto/frame_decryptor_interface.h" #include "api/test/mock_frame_encryptor.h" #include "audio/channel_receive.h" #include "audio/channel_send.h" 
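A side note on the recurring test change in the hunks above, where new rtc::RefCountedObject<T>(...) is replaced by rtc::make_ref_counted<T>(...): a minimal sketch of the two spellings, assuming an illustrative ref-counted type and assuming make_ref_counted is provided by rtc_base/ref_counted_object.h in the tree this patch targets.

    #include "api/scoped_refptr.h"
    #include "rtc_base/ref_count.h"
    #include "rtc_base/ref_counted_object.h"  // assumed to declare rtc::make_ref_counted

    // Illustrative type; anything convertible to rtc::RefCountInterface works.
    class Thing : public rtc::RefCountInterface {
     public:
      explicit Thing(int value) : value_(value) {}
      int value() const { return value_; }

     private:
      const int value_;
    };

    rtc::scoped_refptr<Thing> MakeThingOldStyle() {
      // Old spelling: wrap the type manually in RefCountedObject.
      return rtc::scoped_refptr<Thing>(new rtc::RefCountedObject<Thing>(42));
    }

    rtc::scoped_refptr<Thing> MakeThingNewStyle() {
      // New spelling used throughout this patch: forwards the constructor
      // arguments and returns an already-wrapped scoped_refptr.
      return rtc::make_ref_counted<Thing>(42);
    }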
@@ -59,6 +60,7 @@ class MockChannelReceive : public voe::ChannelReceiveInterface { (int sample_rate_hz, AudioFrame*), (override)); MOCK_METHOD(int, PreferredSampleRate, (), (const, override)); + MOCK_METHOD(void, SetSourceTracker, (SourceTracker*), (override)); MOCK_METHOD(void, SetAssociatedSendChannel, (const voe::ChannelSendInterface*), @@ -97,6 +99,13 @@ class MockChannelReceive : public voe::ChannelReceiveInterface { SetDepacketizerToDecoderFrameTransformer, (rtc::scoped_refptr frame_transformer), (override)); + MOCK_METHOD( + void, + SetFrameDecryptor, + (rtc::scoped_refptr frame_decryptor), + (override)); + MOCK_METHOD(void, OnLocalSsrcChange, (uint32_t local_ssrc), (override)); + MOCK_METHOD(uint32_t, GetLocalSsrc, (), (const, override)); }; class MockChannelSend : public voe::ChannelSendInterface { diff --git a/audio/null_audio_poller.h b/audio/null_audio_poller.h index 97cd2c7e6c..47e67a91da 100644 --- a/audio/null_audio_poller.h +++ b/audio/null_audio_poller.h @@ -13,9 +13,9 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/include/audio_device_defines.h" #include "rtc_base/message_handler.h" -#include "rtc_base/thread_checker.h" namespace webrtc { namespace internal { @@ -29,7 +29,7 @@ class NullAudioPoller final : public rtc::MessageHandler { void OnMessage(rtc::Message* msg) override; private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; AudioTransport* const audio_transport_; int64_t reschedule_at_; }; diff --git a/audio/test/audio_end_to_end_test.cc b/audio/test/audio_end_to_end_test.cc index 896b0f2dae..0d8529a913 100644 --- a/audio/test/audio_end_to_end_test.cc +++ b/audio/test/audio_end_to_end_test.cc @@ -92,6 +92,8 @@ void AudioEndToEndTest::ModifyAudioConfigs( {{"stereo", "1"}}); send_config->send_codec_spec = AudioSendStream::Config::SendCodecSpec( test::CallTest::kAudioSendPayloadType, kDefaultFormat); + send_config->min_bitrate_bps = 32000; + send_config->max_bitrate_bps = 32000; } void AudioEndToEndTest::OnAudioStreamsCreated( diff --git a/audio/test/nack_test.cc b/audio/test/nack_test.cc new file mode 100644 index 0000000000..13cfe74a28 --- /dev/null +++ b/audio/test/nack_test.cc @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "audio/test/audio_end_to_end_test.h" +#include "system_wrappers/include/sleep.h" +#include "test/gtest.h" + +namespace webrtc { +namespace test { + +using NackTest = CallTest; + +TEST_F(NackTest, ShouldNackInLossyNetwork) { + class NackTest : public AudioEndToEndTest { + public: + const int kTestDurationMs = 2000; + const int64_t kRttMs = 30; + const int64_t kLossPercent = 30; + const int kNackHistoryMs = 1000; + + BuiltInNetworkBehaviorConfig GetNetworkPipeConfig() const override { + BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = kRttMs / 2; + pipe_config.loss_percent = kLossPercent; + return pipe_config; + } + + void ModifyAudioConfigs( + AudioSendStream::Config* send_config, + std::vector* receive_configs) override { + ASSERT_EQ(receive_configs->size(), 1U); + (*receive_configs)[0].rtp.nack.rtp_history_ms = kNackHistoryMs; + AudioEndToEndTest::ModifyAudioConfigs(send_config, receive_configs); + } + + void PerformTest() override { SleepMs(kTestDurationMs); } + + void OnStreamsStopped() override { + AudioReceiveStream::Stats recv_stats = + receive_stream()->GetStats(/*get_and_clear_legacy_stats=*/true); + EXPECT_GT(recv_stats.nacks_sent, 0U); + AudioSendStream::Stats send_stats = send_stream()->GetStats(); + EXPECT_GT(send_stats.retransmitted_packets_sent, 0U); + EXPECT_GT(send_stats.nacks_rcvd, 0U); + } + } test; + + RunBaseTest(&test); +} + +} // namespace test +} // namespace webrtc diff --git a/audio/utility/BUILD.gn b/audio/utility/BUILD.gn index 54ca04698d..933553d81b 100644 --- a/audio/utility/BUILD.gn +++ b/audio/utility/BUILD.gn @@ -26,10 +26,10 @@ rtc_library("audio_frame_operations") { "../../api/audio:audio_frame_api", "../../common_audio", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../system_wrappers:field_trial", ] + absl_deps = [ "//third_party/abseil-cpp/absl/base:core_headers" ] } if (rtc_include_tests) { diff --git a/audio/utility/audio_frame_operations.cc b/audio/utility/audio_frame_operations.cc index a9d2cf1632..e13a09bace 100644 --- a/audio/utility/audio_frame_operations.cc +++ b/audio/utility/audio_frame_operations.cc @@ -169,10 +169,10 @@ void AudioFrameOperations::UpmixChannels(size_t target_number_of_channels, if (!frame->muted()) { // Up-mixing done in place. Going backwards through the frame ensure nothing // is irrevocably overwritten. + int16_t* frame_data = frame->mutable_data(); for (int i = frame->samples_per_channel_ - 1; i >= 0; i--) { for (size_t j = 0; j < target_number_of_channels; ++j) { - frame->mutable_data()[target_number_of_channels * i + j] = - frame->data()[i]; + frame_data[target_number_of_channels * i + j] = frame_data[i]; } } } diff --git a/audio/utility/audio_frame_operations.h b/audio/utility/audio_frame_operations.h index 65c310c489..2f1540bcf5 100644 --- a/audio/utility/audio_frame_operations.h +++ b/audio/utility/audio_frame_operations.h @@ -14,8 +14,8 @@ #include #include +#include "absl/base/attributes.h" #include "api/audio/audio_frame.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -36,12 +36,14 @@ class AudioFrameOperations { // |frame.num_channels_| will be updated. This version checks for sufficient // buffer size and that |num_channels_| is mono. Use UpmixChannels // instead. TODO(bugs.webrtc.org/8649): remove. - RTC_DEPRECATED static int MonoToStereo(AudioFrame* frame); + ABSL_DEPRECATED("bugs.webrtc.org/8649") + static int MonoToStereo(AudioFrame* frame); // |frame.num_channels_| will be updated. 
This version checks that // |num_channels_| is stereo. Use DownmixChannels // instead. TODO(bugs.webrtc.org/8649): remove. - RTC_DEPRECATED static int StereoToMono(AudioFrame* frame); + ABSL_DEPRECATED("bugs.webrtc.org/8649") + static int StereoToMono(AudioFrame* frame); // Downmixes 4 channels |src_audio| to stereo |dst_audio|. This is an in-place // operation, meaning |src_audio| and |dst_audio| may point to the same diff --git a/audio/voip/BUILD.gn b/audio/voip/BUILD.gn index ed0508ff1e..5311d7242b 100644 --- a/audio/voip/BUILD.gn +++ b/audio/voip/BUILD.gn @@ -89,6 +89,7 @@ rtc_library("audio_egress") { ] deps = [ "..:audio", + "../../api:sequence_checker", "../../api/audio_codecs:audio_codecs_api", "../../api/task_queue", "../../call:audio_sender_interface", @@ -97,7 +98,6 @@ rtc_library("audio_egress") { "../../modules/rtp_rtcp:rtp_rtcp_format", "../../rtc_base:logging", "../../rtc_base:rtc_task_queue", - "../../rtc_base:thread_checker", "../../rtc_base:timeutils", "../../rtc_base/synchronization:mutex", "../../rtc_base/system:no_unique_address", diff --git a/audio/voip/audio_channel.cc b/audio/voip/audio_channel.cc index d11e6d79f9..b4a50eec12 100644 --- a/audio/voip/audio_channel.cc +++ b/audio/voip/audio_channel.cc @@ -32,12 +32,10 @@ AudioChannel::AudioChannel( Transport* transport, uint32_t local_ssrc, TaskQueueFactory* task_queue_factory, - ProcessThread* process_thread, AudioMixer* audio_mixer, rtc::scoped_refptr decoder_factory) - : audio_mixer_(audio_mixer), process_thread_(process_thread) { + : audio_mixer_(audio_mixer) { RTC_DCHECK(task_queue_factory); - RTC_DCHECK(process_thread); RTC_DCHECK(audio_mixer); Clock* clock = Clock::GetRealTimeClock(); @@ -56,9 +54,6 @@ AudioChannel::AudioChannel( rtp_rtcp_->SetSendingMediaStatus(false); rtp_rtcp_->SetRTCPStatus(RtcpMode::kCompound); - // ProcessThread periodically services RTP stack for RTCP. - process_thread_->RegisterModule(rtp_rtcp_.get(), RTC_FROM_HERE); - ingress_ = std::make_unique(rtp_rtcp_.get(), clock, receive_statistics_.get(), std::move(decoder_factory)); @@ -80,12 +75,10 @@ AudioChannel::~AudioChannel() { audio_mixer_->RemoveSource(ingress_.get()); - // AudioEgress could hold current global TaskQueueBase that we need to clear - // before ProcessThread::DeRegisterModule. + // TODO(bugs.webrtc.org/11581): unclear if we still need to clear |egress_| + // here. egress_.reset(); ingress_.reset(); - - process_thread_->DeRegisterModule(rtp_rtcp_.get()); } bool AudioChannel::StartSend() { diff --git a/audio/voip/audio_channel.h b/audio/voip/audio_channel.h index 7b9fa6f74e..7338d9faab 100644 --- a/audio/voip/audio_channel.h +++ b/audio/voip/audio_channel.h @@ -22,7 +22,6 @@ #include "audio/voip/audio_egress.h" #include "audio/voip/audio_ingress.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" -#include "modules/utility/include/process_thread.h" #include "rtc_base/ref_count.h" namespace webrtc { @@ -35,7 +34,6 @@ class AudioChannel : public rtc::RefCountInterface { AudioChannel(Transport* transport, uint32_t local_ssrc, TaskQueueFactory* task_queue_factory, - ProcessThread* process_thread, AudioMixer* audio_mixer, rtc::scoped_refptr decoder_factory); ~AudioChannel() override; @@ -120,9 +118,6 @@ class AudioChannel : public rtc::RefCountInterface { // Synchronization is handled internally by AudioMixer. AudioMixer* audio_mixer_; - // Synchronization is handled internally by ProcessThread. - ProcessThread* process_thread_; - // Listed in order for safe destruction of AudioChannel object. 
// Synchronization for these are handled internally. std::unique_ptr receive_statistics_; diff --git a/audio/voip/audio_egress.h b/audio/voip/audio_egress.h index fcd9ed0f20..a39c7e225a 100644 --- a/audio/voip/audio_egress.h +++ b/audio/voip/audio_egress.h @@ -15,6 +15,7 @@ #include #include "api/audio_codecs/audio_format.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "audio/audio_level.h" #include "audio/utility/audio_frame_operations.h" @@ -25,7 +26,6 @@ #include "modules/rtp_rtcp/source/rtp_sender_audio.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/time_utils.h" namespace webrtc { diff --git a/audio/voip/test/BUILD.gn b/audio/voip/test/BUILD.gn index ab074f7a47..132f448307 100644 --- a/audio/voip/test/BUILD.gn +++ b/audio/voip/test/BUILD.gn @@ -19,21 +19,23 @@ if (rtc_include_tests) { ] } - rtc_library("voip_core_unittests") { - testonly = true - sources = [ "voip_core_unittest.cc" ] - deps = [ - "..:voip_core", - "../../../api/audio_codecs:builtin_audio_decoder_factory", - "../../../api/audio_codecs:builtin_audio_encoder_factory", - "../../../api/task_queue:default_task_queue_factory", - "../../../modules/audio_device:mock_audio_device", - "../../../modules/audio_processing:mocks", - "../../../modules/utility:mock_process_thread", - "../../../test:audio_codec_mocks", - "../../../test:mock_transport", - "../../../test:test_support", - ] + if (!build_with_chromium) { + rtc_library("voip_core_unittests") { + testonly = true + sources = [ "voip_core_unittest.cc" ] + deps = [ + "..:voip_core", + "../../../api/audio_codecs:builtin_audio_decoder_factory", + "../../../api/audio_codecs:builtin_audio_encoder_factory", + "../../../api/task_queue:default_task_queue_factory", + "../../../modules/audio_device:mock_audio_device", + "../../../modules/audio_processing:mocks", + "../../../modules/utility:mock_process_thread", + "../../../test:audio_codec_mocks", + "../../../test:mock_transport", + "../../../test:test_support", + ] + } } rtc_library("audio_channel_unittests") { diff --git a/audio/voip/test/audio_channel_unittest.cc b/audio/voip/test/audio_channel_unittest.cc index 1a79d847b1..a4f518c5bd 100644 --- a/audio/voip/test/audio_channel_unittest.cc +++ b/audio/voip/test/audio_channel_unittest.cc @@ -17,7 +17,6 @@ #include "modules/audio_mixer/audio_mixer_impl.h" #include "modules/audio_mixer/sine_wave_generator.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "modules/utility/include/process_thread.h" #include "rtc_base/logging.h" #include "test/gmock.h" #include "test/gtest.h" @@ -28,6 +27,7 @@ namespace { using ::testing::Invoke; using ::testing::NiceMock; +using ::testing::Return; using ::testing::Unused; constexpr uint64_t kStartTime = 123456789; @@ -42,7 +42,6 @@ class AudioChannelTest : public ::testing::Test { AudioChannelTest() : fake_clock_(kStartTime), wave_generator_(1000.0, kAudioLevel) { task_queue_factory_ = std::make_unique(&task_queue_); - process_thread_ = ProcessThread::Create("ModuleProcessThread"); audio_mixer_ = AudioMixerImpl::Create(); encoder_factory_ = CreateBuiltinAudioEncoderFactory(); decoder_factory_ = CreateBuiltinAudioDecoderFactory(); @@ -53,23 +52,27 @@ class AudioChannelTest : public ::testing::Test { Invoke([&](std::unique_ptr task) { task->Run(); })); } - void SetUp() override { - audio_channel_ = new rtc::RefCountedObject( - &transport_, kLocalSsrc, task_queue_factory_.get(), - process_thread_.get(), 
audio_mixer_.get(), decoder_factory_); - - audio_channel_->SetEncoder(kPcmuPayload, kPcmuFormat, - encoder_factory_->MakeAudioEncoder( - kPcmuPayload, kPcmuFormat, absl::nullopt)); - audio_channel_->SetReceiveCodecs({{kPcmuPayload, kPcmuFormat}}); - audio_channel_->StartSend(); - audio_channel_->StartPlay(); - } - - void TearDown() override { - audio_channel_->StopSend(); - audio_channel_->StopPlay(); - audio_channel_ = nullptr; + void SetUp() override { audio_channel_ = CreateAudioChannel(kLocalSsrc); } + + void TearDown() override { audio_channel_ = nullptr; } + + rtc::scoped_refptr CreateAudioChannel(uint32_t ssrc) { + // Use same audio mixer here for simplicity sake as we are not checking + // audio activity of RTP in our testcases. If we need to do test on audio + // signal activity then we need to assign audio mixer for each channel. + // Also this uses the same transport object for different audio channel to + // simplify network routing logic. + rtc::scoped_refptr audio_channel = + rtc::make_ref_counted( + &transport_, ssrc, task_queue_factory_.get(), audio_mixer_.get(), + decoder_factory_); + audio_channel->SetEncoder(kPcmuPayload, kPcmuFormat, + encoder_factory_->MakeAudioEncoder( + kPcmuPayload, kPcmuFormat, absl::nullopt)); + audio_channel->SetReceiveCodecs({{kPcmuPayload, kPcmuFormat}}); + audio_channel->StartSend(); + audio_channel->StartPlay(); + return audio_channel; } std::unique_ptr GetAudioFrame(int order) { @@ -90,7 +93,6 @@ class AudioChannelTest : public ::testing::Test { rtc::scoped_refptr audio_mixer_; rtc::scoped_refptr decoder_factory_; rtc::scoped_refptr encoder_factory_; - std::unique_ptr process_thread_; rtc::scoped_refptr audio_channel_; }; @@ -269,5 +271,85 @@ TEST_F(AudioChannelTest, TestChannelStatistics) { EXPECT_FALSE(channel_stats->remote_rtcp->round_trip_time.has_value()); } +// Check ChannelStatistics RTT metric after processing RTP and RTCP packets +// using three audio channels where each represents media endpoint. +// +// 1) AC1 <- RTP/RTCP -> AC2 +// 2) AC1 <- RTP/RTCP -> AC3 +// +// During step 1), AC1 should be able to check RTT from AC2's SSRC. +// During step 2), AC1 should be able to check RTT from AC3's SSRC. +TEST_F(AudioChannelTest, RttIsAvailableAfterChangeOfRemoteSsrc) { + // Create AC2 and AC3. + constexpr uint32_t kAc2Ssrc = 0xdeadbeef; + constexpr uint32_t kAc3Ssrc = 0xdeafbeef; + + auto ac_2 = CreateAudioChannel(kAc2Ssrc); + auto ac_3 = CreateAudioChannel(kAc3Ssrc); + + auto send_recv_rtp = [&](rtc::scoped_refptr rtp_sender, + rtc::scoped_refptr rtp_receiver) { + // Setup routing logic via transport_. + auto route_rtp = [&](const uint8_t* packet, size_t length, Unused) { + rtp_receiver->ReceivedRTPPacket(rtc::MakeArrayView(packet, length)); + return true; + }; + ON_CALL(transport_, SendRtp).WillByDefault(route_rtp); + + // This will trigger route_rtp callback via transport_. + rtp_sender->GetAudioSender()->SendAudioData(GetAudioFrame(0)); + rtp_sender->GetAudioSender()->SendAudioData(GetAudioFrame(1)); + + // Process received RTP in receiver. + AudioFrame audio_frame; + audio_mixer_->Mix(/*number_of_channels=*/1, &audio_frame); + audio_mixer_->Mix(/*number_of_channels=*/1, &audio_frame); + + // Revert to default to avoid using reference in route_rtp lambda. + ON_CALL(transport_, SendRtp).WillByDefault(Return(true)); + }; + + auto send_recv_rtcp = [&](rtc::scoped_refptr rtcp_sender, + rtc::scoped_refptr rtcp_receiver) { + // Setup routing logic via transport_. 
+ auto route_rtcp = [&](const uint8_t* packet, size_t length) { + rtcp_receiver->ReceivedRTCPPacket(rtc::MakeArrayView(packet, length)); + return true; + }; + ON_CALL(transport_, SendRtcp).WillByDefault(route_rtcp); + + // This will trigger route_rtcp callback via transport_. + rtcp_sender->SendRTCPReportForTesting(kRtcpSr); + + // Revert to default to avoid using reference in route_rtcp lambda. + ON_CALL(transport_, SendRtcp).WillByDefault(Return(true)); + }; + + // AC1 <-- RTP/RTCP --> AC2 + send_recv_rtp(audio_channel_, ac_2); + send_recv_rtp(ac_2, audio_channel_); + send_recv_rtcp(audio_channel_, ac_2); + send_recv_rtcp(ac_2, audio_channel_); + + absl::optional channel_stats = + audio_channel_->GetChannelStatistics(); + ASSERT_TRUE(channel_stats); + EXPECT_EQ(channel_stats->remote_ssrc, kAc2Ssrc); + ASSERT_TRUE(channel_stats->remote_rtcp); + EXPECT_GT(channel_stats->remote_rtcp->round_trip_time, 0.0); + + // AC1 <-- RTP/RTCP --> AC3 + send_recv_rtp(audio_channel_, ac_3); + send_recv_rtp(ac_3, audio_channel_); + send_recv_rtcp(audio_channel_, ac_3); + send_recv_rtcp(ac_3, audio_channel_); + + channel_stats = audio_channel_->GetChannelStatistics(); + ASSERT_TRUE(channel_stats); + EXPECT_EQ(channel_stats->remote_ssrc, kAc3Ssrc); + ASSERT_TRUE(channel_stats->remote_rtcp); + EXPECT_GT(channel_stats->remote_rtcp->round_trip_time, 0.0); +} + } // namespace } // namespace webrtc diff --git a/audio/voip/test/voip_core_unittest.cc b/audio/voip/test/voip_core_unittest.cc index d290bd6ec3..896d0d98bb 100644 --- a/audio/voip/test/voip_core_unittest.cc +++ b/audio/voip/test/voip_core_unittest.cc @@ -14,7 +14,6 @@ #include "api/task_queue/default_task_queue_factory.h" #include "modules/audio_device/include/mock_audio_device.h" #include "modules/audio_processing/include/mock_audio_processing.h" -#include "modules/utility/include/mock/mock_process_thread.h" #include "test/gtest.h" #include "test/mock_transport.h" @@ -39,22 +38,17 @@ class VoipCoreTest : public ::testing::Test { auto encoder_factory = CreateBuiltinAudioEncoderFactory(); auto decoder_factory = CreateBuiltinAudioDecoderFactory(); rtc::scoped_refptr audio_processing = - new rtc::RefCountedObject>(); - - auto process_thread = std::make_unique>(); - // Hold the pointer to use for testing. - process_thread_ = process_thread.get(); + rtc::make_ref_counted>(); voip_core_ = std::make_unique( std::move(encoder_factory), std::move(decoder_factory), CreateDefaultTaskQueueFactory(), audio_device_, - std::move(audio_processing), std::move(process_thread)); + std::move(audio_processing)); } std::unique_ptr voip_core_; NiceMock transport_; rtc::scoped_refptr audio_device_; - NiceMock* process_thread_; }; // Validate expected API calls that involves with VoipCore. Some verification is @@ -192,31 +186,5 @@ TEST_F(VoipCoreTest, StopSendAndPlayoutWithoutStarting) { EXPECT_EQ(voip_core_->ReleaseChannel(channel), VoipResult::kOk); } -// This tests correctness on ProcessThread usage where we expect the first/last -// channel creation/release triggers its Start/Stop method once only. 
-TEST_F(VoipCoreTest, TestProcessThreadOperation) { - EXPECT_CALL(*process_thread_, Start); - EXPECT_CALL(*process_thread_, RegisterModule).Times(2); - - auto channel_one = voip_core_->CreateChannel(&transport_, 0xdeadc0de); - auto channel_two = voip_core_->CreateChannel(&transport_, 0xdeadbeef); - - EXPECT_CALL(*process_thread_, Stop); - EXPECT_CALL(*process_thread_, DeRegisterModule).Times(2); - - EXPECT_EQ(voip_core_->ReleaseChannel(channel_one), VoipResult::kOk); - EXPECT_EQ(voip_core_->ReleaseChannel(channel_two), VoipResult::kOk); - - EXPECT_CALL(*process_thread_, Start); - EXPECT_CALL(*process_thread_, RegisterModule); - - auto channel_three = voip_core_->CreateChannel(&transport_, absl::nullopt); - - EXPECT_CALL(*process_thread_, Stop); - EXPECT_CALL(*process_thread_, DeRegisterModule); - - EXPECT_EQ(voip_core_->ReleaseChannel(channel_three), VoipResult::kOk); -} - } // namespace } // namespace webrtc diff --git a/audio/voip/voip_core.cc b/audio/voip/voip_core.cc index 33dadbc9af..fd66379f4a 100644 --- a/audio/voip/voip_core.cc +++ b/audio/voip/voip_core.cc @@ -41,18 +41,12 @@ VoipCore::VoipCore(rtc::scoped_refptr encoder_factory, rtc::scoped_refptr decoder_factory, std::unique_ptr task_queue_factory, rtc::scoped_refptr audio_device_module, - rtc::scoped_refptr audio_processing, - std::unique_ptr process_thread) { + rtc::scoped_refptr audio_processing) { encoder_factory_ = std::move(encoder_factory); decoder_factory_ = std::move(decoder_factory); task_queue_factory_ = std::move(task_queue_factory); audio_device_module_ = std::move(audio_device_module); audio_processing_ = std::move(audio_processing); - process_thread_ = std::move(process_thread); - - if (!process_thread_) { - process_thread_ = ProcessThread::Create("ModuleProcessThread"); - } audio_mixer_ = AudioMixerImpl::Create(); // AudioTransportImpl depends on audio mixer and audio processing instances. @@ -138,19 +132,13 @@ ChannelId VoipCore::CreateChannel(Transport* transport, } rtc::scoped_refptr channel = - new rtc::RefCountedObject( - transport, local_ssrc.value(), task_queue_factory_.get(), - process_thread_.get(), audio_mixer_.get(), decoder_factory_); - - // Check if we need to start the process thread. - bool start_process_thread = false; + rtc::make_ref_counted(transport, local_ssrc.value(), + task_queue_factory_.get(), + audio_mixer_.get(), decoder_factory_); { MutexLock lock(&lock_); - // Start process thread if the channel is the first one. - start_process_thread = channels_.empty(); - channel_id = static_cast(next_channel_id_); channels_[channel_id] = channel; next_channel_id_++; @@ -162,10 +150,6 @@ ChannelId VoipCore::CreateChannel(Transport* transport, // Set ChannelId in audio channel for logging/debugging purpose. channel->SetId(channel_id); - if (start_process_thread) { - process_thread_->Start(); - } - return channel_id; } @@ -194,9 +178,9 @@ VoipResult VoipCore::ReleaseChannel(ChannelId channel_id) { } if (no_channels_after_release) { - // Release audio channel first to have it DeRegisterModule first. + // TODO(bugs.webrtc.org/11581): unclear if we still need to clear |channel| + // here. channel = nullptr; - process_thread_->Stop(); // Make sure to stop playout on ADM if it is playing. 
if (audio_device_module_->Playing()) { diff --git a/audio/voip/voip_core.h b/audio/voip/voip_core.h index b7c1f2947f..359e07272d 100644 --- a/audio/voip/voip_core.h +++ b/audio/voip/voip_core.h @@ -33,7 +33,6 @@ #include "modules/audio_device/include/audio_device.h" #include "modules/audio_mixer/audio_mixer_impl.h" #include "modules/audio_processing/include/audio_processing.h" -#include "modules/utility/include/process_thread.h" #include "rtc_base/synchronization/mutex.h" namespace webrtc { @@ -61,8 +60,7 @@ class VoipCore : public VoipEngine, rtc::scoped_refptr decoder_factory, std::unique_ptr task_queue_factory, rtc::scoped_refptr audio_device_module, - rtc::scoped_refptr audio_processing, - std::unique_ptr process_thread = nullptr); + rtc::scoped_refptr audio_processing); ~VoipCore() override = default; // Implements VoipEngine interfaces. @@ -160,10 +158,6 @@ class VoipCore : public VoipEngine, // Synchronization is handled internally by AudioDeviceModule. rtc::scoped_refptr audio_device_module_; - // Synchronization is handled internally by ProcessThread. - // Must be placed before |channels_| for proper destruction. - std::unique_ptr process_thread_; - Mutex lock_; // Member to track a next ChannelId for new AudioChannel. diff --git a/build_overrides/build.gni b/build_overrides/build.gni index 8facdeab8d..137b6a40b2 100644 --- a/build_overrides/build.gni +++ b/build_overrides/build.gni @@ -20,11 +20,11 @@ checkout_google_benchmark = true asan_suppressions_file = "//build/sanitizers/asan_suppressions.cc" lsan_suppressions_file = "//tools_webrtc/sanitizers/lsan_suppressions_webrtc.cc" tsan_suppressions_file = "//tools_webrtc/sanitizers/tsan_suppressions_webrtc.cc" -msan_blacklist_path = +msan_ignorelist_path = rebase_path("//tools_webrtc/msan/suppressions.txt", root_build_dir) -ubsan_blacklist_path = +ubsan_ignorelist_path = rebase_path("//tools_webrtc/ubsan/suppressions.txt", root_build_dir) -ubsan_vptr_blacklist_path = +ubsan_vptr_ignorelist_path = rebase_path("//tools_webrtc/ubsan/vptr_suppressions.txt", root_build_dir) # For Chromium, Android 32-bit non-component, non-clang builds hit a 4GiB size @@ -34,7 +34,8 @@ ignore_elf32_limitations = true # Use bundled hermetic Xcode installation maintainted by Chromium, # except for local iOS builds where it's unsupported. -if (host_os == "mac") { +# Allow for mac cross compile on linux machines. +if (host_os == "mac" || host_os == "linux") { _result = exec_script("//build/mac/should_use_hermetic_xcode.py", [ target_os ], "value") @@ -51,6 +52,13 @@ declare_args() { enable_base_tracing = false use_perfetto_client_library = false + # Limits the defined //third_party/android_deps targets to only "buildCompile" + # and "buildCompileNoDeps" targets. This is useful for third-party + # repositories which do not use JUnit tests. For instance, + # limit_android_deps == true removes "gn gen" requirement for + # //third_party/robolectric . + limit_android_deps = false + # If true, it assumes that //third_party/abseil-cpp is an available # dependency for googletest. 
gtest_enable_absl_printers = true diff --git a/call/BUILD.gn b/call/BUILD.gn index cef43f4c3d..638eb0b910 100644 --- a/call/BUILD.gn +++ b/call/BUILD.gn @@ -35,8 +35,10 @@ rtc_library("call_interfaces") { if (!build_with_mozilla) { sources += [ "audio_send_stream.cc" ] } + deps = [ ":audio_sender_interface", + ":receive_stream_interface", ":rtp_interfaces", ":video_stream_api", "../api:fec_controller_api", @@ -51,7 +53,6 @@ rtc_library("call_interfaces") { "../api/audio:audio_frame_processor", "../api/audio:audio_mixer_api", "../api/audio_codecs:audio_codecs_api", - "../api/crypto:frame_decryptor_interface", "../api/crypto:frame_encryptor_interface", "../api/crypto:options", "../api/neteq:neteq_api", @@ -59,7 +60,6 @@ rtc_library("call_interfaces") { "../api/transport:bitrate_settings", "../api/transport:network_control", "../api/transport:webrtc_key_value_config", - "../api/transport/rtp:rtp_source", "../modules/async_audio_processing", "../modules/audio_device", "../modules/audio_processing", @@ -73,7 +73,10 @@ rtc_library("call_interfaces") { "../rtc_base:rtc_base_approved", "../rtc_base/network:sent_packet", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/types:optional", + ] } rtc_source_set("audio_sender_interface") { @@ -95,22 +98,29 @@ rtc_library("rtp_interfaces") { "rtp_config.h", "rtp_packet_sink_interface.h", "rtp_stream_receiver_controller_interface.h", + "rtp_transport_config.h", + "rtp_transport_controller_send_factory_interface.h", "rtp_transport_controller_send_interface.h", ] deps = [ "../api:array_view", "../api:fec_controller_api", "../api:frame_transformer_interface", + "../api:network_state_predictor_api", "../api:rtp_headers", "../api:rtp_parameters", "../api/crypto:options", "../api/rtc_event_log", "../api/transport:bitrate_settings", + "../api/transport:network_control", + "../api/transport:webrtc_key_value_config", "../api/units:timestamp", "../common_video:frame_counts", "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/utility", "../rtc_base:checks", "../rtc_base:rtc_base_approved", + "../rtc_base:rtc_task_queue", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -132,10 +142,13 @@ rtc_library("rtp_receiver") { ":rtp_interfaces", "../api:array_view", "../api:rtp_headers", + "../api:sequence_checker", "../modules/rtp_rtcp", "../modules/rtp_rtcp:rtp_rtcp_format", "../rtc_base:checks", "../rtc_base:rtc_base_approved", + "../rtc_base/containers:flat_map", + "../rtc_base/containers:flat_set", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } @@ -146,6 +159,7 @@ rtc_library("rtp_sender") { "rtp_payload_params.h", "rtp_transport_controller_send.cc", "rtp_transport_controller_send.h", + "rtp_transport_controller_send_factory.h", "rtp_video_sender.cc", "rtp_video_sender.h", "rtp_video_sender_interface.h", @@ -158,6 +172,7 @@ rtc_library("rtp_sender") { "../api:fec_controller_api", "../api:network_state_predictor_api", "../api:rtp_parameters", + "../api:sequence_checker", "../api:transport_api", "../api/rtc_event_log", "../api/transport:field_trial_based_config", @@ -226,13 +241,13 @@ rtc_library("bitrate_allocator") { ] deps = [ "../api:bitrate_allocation", + "../api:sequence_checker", "../api/transport:network_control", "../api/units:data_rate", "../api/units:time_delta", "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../rtc_base:safe_minmax", - "../rtc_base/synchronization:sequence_checker", 
"../rtc_base/system:no_unique_address", "../system_wrappers", "../system_wrappers:field_trial", @@ -269,6 +284,7 @@ rtc_library("call") { "../api:fec_controller_api", "../api:rtp_headers", "../api:rtp_parameters", + "../api:sequence_checker", "../api:simulated_network_api", "../api:transport_api", "../api/rtc_event_log", @@ -293,7 +309,6 @@ rtc_library("call") { "../rtc_base:safe_minmax", "../rtc_base/experiments:field_trial_parser", "../rtc_base/network:sent_packet", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", "../rtc_base/task_utils:pending_task_safety_flag", "../system_wrappers", @@ -302,7 +317,21 @@ rtc_library("call") { "../video", "adaptation:resource_adaptation", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_source_set("receive_stream_interface") { + sources = [ "receive_stream.h" ] + deps = [ + "../api:frame_transformer_interface", + "../api:rtp_parameters", + "../api:scoped_refptr", + "../api/crypto:frame_decryptor_interface", + "../api/transport/rtp:rtp_source", + ] } rtc_library("video_stream_api") { @@ -313,6 +342,7 @@ rtc_library("video_stream_api") { "video_send_stream.h", ] deps = [ + ":receive_stream_interface", ":rtp_interfaces", "../api:frame_transformer_interface", "../api:rtp_headers", @@ -320,10 +350,8 @@ rtc_library("video_stream_api") { "../api:scoped_refptr", "../api:transport_api", "../api/adaptation:resource_adaptation_api", - "../api/crypto:frame_decryptor_interface", "../api/crypto:frame_encryptor_interface", "../api/crypto:options", - "../api/transport/rtp:rtp_source", "../api/video:recordable_encoded_frame", "../api/video:video_frame", "../api/video:video_rtp_headers", @@ -344,6 +372,7 @@ rtc_library("simulated_network") { "simulated_network.h", ] deps = [ + "../api:sequence_checker", "../api:simulated_network_api", "../api/units:data_rate", "../api/units:data_size", @@ -352,7 +381,6 @@ rtc_library("simulated_network") { "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } @@ -375,154 +403,160 @@ rtc_library("fake_network") { ":simulated_network", ":simulated_packet_receiver", "../api:rtp_parameters", + "../api:sequence_checker", "../api:simulated_network_api", "../api:transport_api", "../modules/utility", "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../system_wrappers", ] } if (rtc_include_tests) { - rtc_library("call_tests") { - testonly = true + if (!build_with_chromium) { + rtc_library("call_tests") { + testonly = true - sources = [ - "bitrate_allocator_unittest.cc", - "bitrate_estimator_tests.cc", - "call_unittest.cc", - "flexfec_receive_stream_unittest.cc", - "receive_time_calculator_unittest.cc", - "rtp_bitrate_configurator_unittest.cc", - "rtp_demuxer_unittest.cc", - "rtp_payload_params_unittest.cc", - "rtp_video_sender_unittest.cc", - "rtx_receive_stream_unittest.cc", - ] - deps = [ - ":bitrate_allocator", - ":bitrate_configurator", - ":call", - ":call_interfaces", - ":mock_rtp_interfaces", - ":rtp_interfaces", - ":rtp_receiver", - ":rtp_sender", - ":simulated_network", - "../api:array_view", - "../api:create_frame_generator", - "../api:mock_audio_mixer", - "../api:rtp_headers", - 
"../api:rtp_parameters", - "../api:transport_api", - "../api/audio_codecs:builtin_audio_decoder_factory", - "../api/rtc_event_log", - "../api/task_queue:default_task_queue_factory", - "../api/test/video:function_video_factory", - "../api/transport:field_trial_based_config", - "../api/video:builtin_video_bitrate_allocator_factory", - "../api/video:video_frame", - "../api/video:video_rtp_headers", - "../audio", - "../modules/audio_device:mock_audio_device", - "../modules/audio_mixer", - "../modules/audio_mixer:audio_mixer_impl", - "../modules/audio_processing:mocks", - "../modules/congestion_controller", - "../modules/pacing", - "../modules/rtp_rtcp", - "../modules/rtp_rtcp:mock_rtp_rtcp", - "../modules/rtp_rtcp:rtp_rtcp_format", - "../modules/utility:mock_process_thread", - "../modules/video_coding", - "../modules/video_coding:codec_globals_headers", - "../modules/video_coding:video_codec_interface", - "../rtc_base:checks", - "../rtc_base:rate_limiter", - "../rtc_base:rtc_base_approved", - "../rtc_base:task_queue_for_test", - "../rtc_base/synchronization:mutex", - "../system_wrappers", - "../test:audio_codec_mocks", - "../test:direct_transport", - "../test:encoder_settings", - "../test:fake_video_codecs", - "../test:field_trial", - "../test:mock_frame_transformer", - "../test:mock_transport", - "../test:test_common", - "../test:test_support", - "../test:video_test_common", - "../test/time_controller:time_controller", - "../video", - "adaptation:resource_adaptation_test_utilities", - "//test/scenario:scenario", - "//testing/gmock", - "//testing/gtest", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/container:inlined_vector", - "//third_party/abseil-cpp/absl/memory", - "//third_party/abseil-cpp/absl/types:optional", - "//third_party/abseil-cpp/absl/types:variant", - ] - } + sources = [ + "bitrate_allocator_unittest.cc", + "bitrate_estimator_tests.cc", + "call_unittest.cc", + "flexfec_receive_stream_unittest.cc", + "receive_time_calculator_unittest.cc", + "rtp_bitrate_configurator_unittest.cc", + "rtp_demuxer_unittest.cc", + "rtp_payload_params_unittest.cc", + "rtp_video_sender_unittest.cc", + "rtx_receive_stream_unittest.cc", + ] + deps = [ + ":bitrate_allocator", + ":bitrate_configurator", + ":call", + ":call_interfaces", + ":mock_rtp_interfaces", + ":rtp_interfaces", + ":rtp_receiver", + ":rtp_sender", + ":simulated_network", + "../api:array_view", + "../api:create_frame_generator", + "../api:mock_audio_mixer", + "../api:rtp_headers", + "../api:rtp_parameters", + "../api:transport_api", + "../api/audio_codecs:builtin_audio_decoder_factory", + "../api/rtc_event_log", + "../api/task_queue:default_task_queue_factory", + "../api/test/video:function_video_factory", + "../api/transport:field_trial_based_config", + "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_frame", + "../api/video:video_rtp_headers", + "../audio", + "../modules:module_api", + "../modules/audio_device:mock_audio_device", + "../modules/audio_mixer", + "../modules/audio_mixer:audio_mixer_impl", + "../modules/audio_processing:mocks", + "../modules/congestion_controller", + "../modules/pacing", + "../modules/rtp_rtcp", + "../modules/rtp_rtcp:mock_rtp_rtcp", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/utility:mock_process_thread", + "../modules/video_coding", + "../modules/video_coding:codec_globals_headers", + "../modules/video_coding:video_codec_interface", + "../rtc_base:checks", + "../rtc_base:logging", + "../rtc_base:rate_limiter", + "../rtc_base:rtc_base_approved", + 
"../rtc_base:task_queue_for_test", + "../rtc_base/synchronization:mutex", + "../system_wrappers", + "../test:audio_codec_mocks", + "../test:direct_transport", + "../test:encoder_settings", + "../test:explicit_key_value_config", + "../test:fake_video_codecs", + "../test:field_trial", + "../test:mock_frame_transformer", + "../test:mock_transport", + "../test:test_common", + "../test:test_support", + "../test:video_test_common", + "../test/scenario", + "../test/time_controller:time_controller", + "../video", + "adaptation:resource_adaptation_test_utilities", + "//testing/gmock", + "//testing/gtest", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/container:inlined_vector", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/types:optional", + "//third_party/abseil-cpp/absl/types:variant", + ] + } - rtc_library("call_perf_tests") { - testonly = true + rtc_library("call_perf_tests") { + testonly = true - sources = [ - "call_perf_tests.cc", - "rampup_tests.cc", - "rampup_tests.h", - ] - deps = [ - ":call_interfaces", - ":simulated_network", - ":video_stream_api", - "../api:rtc_event_log_output_file", - "../api:simulated_network_api", - "../api/audio_codecs:builtin_audio_encoder_factory", - "../api/rtc_event_log", - "../api/rtc_event_log:rtc_event_log_factory", - "../api/task_queue", - "../api/task_queue:default_task_queue_factory", - "../api/video:builtin_video_bitrate_allocator_factory", - "../api/video:video_bitrate_allocation", - "../api/video_codecs:video_codecs_api", - "../modules/audio_coding", - "../modules/audio_device", - "../modules/audio_device:audio_device_impl", - "../modules/audio_mixer:audio_mixer_impl", - "../modules/rtp_rtcp", - "../modules/rtp_rtcp:rtp_rtcp_format", - "../rtc_base", - "../rtc_base:checks", - "../rtc_base:rtc_base_approved", - "../rtc_base:task_queue_for_test", - "../rtc_base:task_queue_for_test", - "../rtc_base/synchronization:mutex", - "../rtc_base/task_utils:repeating_task", - "../system_wrappers", - "../system_wrappers:metrics", - "../test:direct_transport", - "../test:encoder_settings", - "../test:fake_video_codecs", - "../test:field_trial", - "../test:fileutils", - "../test:null_transport", - "../test:perf_test", - "../test:rtp_test_utils", - "../test:test_common", - "../test:test_support", - "../test:video_test_common", - "../video", - "//testing/gtest", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag" ] + sources = [ + "call_perf_tests.cc", + "rampup_tests.cc", + "rampup_tests.h", + ] + deps = [ + ":call_interfaces", + ":simulated_network", + ":video_stream_api", + "../api:rtc_event_log_output_file", + "../api:simulated_network_api", + "../api/audio_codecs:builtin_audio_encoder_factory", + "../api/rtc_event_log", + "../api/rtc_event_log:rtc_event_log_factory", + "../api/task_queue", + "../api/task_queue:default_task_queue_factory", + "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_bitrate_allocation", + "../api/video_codecs:video_codecs_api", + "../modules/audio_coding", + "../modules/audio_device", + "../modules/audio_device:audio_device_impl", + "../modules/audio_mixer:audio_mixer_impl", + "../modules/rtp_rtcp", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../rtc_base", + "../rtc_base:checks", + "../rtc_base:rtc_base_approved", + "../rtc_base:task_queue_for_test", + "../rtc_base:task_queue_for_test", + "../rtc_base:threading", + "../rtc_base/synchronization:mutex", + "../rtc_base/task_utils:repeating_task", + "../system_wrappers", + "../system_wrappers:metrics", + 
"../test:direct_transport", + "../test:encoder_settings", + "../test:fake_video_codecs", + "../test:field_trial", + "../test:fileutils", + "../test:null_transport", + "../test:perf_test", + "../test:rtp_test_utils", + "../test:test_common", + "../test:test_support", + "../test:video_test_common", + "../video", + "//testing/gtest", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag" ] + } } # TODO(eladalon): This should be moved, as with the TODO for |rtp_interfaces|. diff --git a/call/adaptation/BUILD.gn b/call/adaptation/BUILD.gn index f782a8d5bc..10a46a3d43 100644 --- a/call/adaptation/BUILD.gn +++ b/call/adaptation/BUILD.gn @@ -34,6 +34,7 @@ rtc_library("resource_adaptation") { deps = [ "../../api:rtp_parameters", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/adaptation:resource_adaptation_api", "../../api/task_queue:task_queue", "../../api/video:video_adaptation", @@ -46,7 +47,6 @@ rtc_library("resource_adaptation") { "../../rtc_base:rtc_task_queue", "../../rtc_base/experiments:balanced_degradation_settings", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/system:no_unique_address", "../../rtc_base/task_utils:to_queued_task", ] @@ -108,11 +108,11 @@ if (rtc_include_tests) { deps = [ ":resource_adaptation", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/adaptation:resource_adaptation_api", "../../api/task_queue:task_queue", "../../api/video:video_stream_encoder", "../../rtc_base:rtc_base_approved", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/task_utils:to_queued_task", "../../test:test_support", ] diff --git a/call/adaptation/broadcast_resource_listener.cc b/call/adaptation/broadcast_resource_listener.cc index 59bd1e0c7f..876d4c0bf6 100644 --- a/call/adaptation/broadcast_resource_listener.cc +++ b/call/adaptation/broadcast_resource_listener.cc @@ -83,8 +83,8 @@ BroadcastResourceListener::CreateAdapterResource() { MutexLock lock(&lock_); RTC_DCHECK(is_listening_); rtc::scoped_refptr adapter = - new rtc::RefCountedObject(source_resource_->Name() + - "Adapter"); + rtc::make_ref_counted(source_resource_->Name() + + "Adapter"); adapters_.push_back(adapter); return adapter; } diff --git a/call/adaptation/resource_adaptation_processor.cc b/call/adaptation/resource_adaptation_processor.cc index ac1b1db174..741575ae38 100644 --- a/call/adaptation/resource_adaptation_processor.cc +++ b/call/adaptation/resource_adaptation_processor.cc @@ -15,12 +15,12 @@ #include #include "absl/algorithm/container.h" +#include "api/sequence_checker.h" #include "api/video/video_adaptation_counters.h" #include "call/adaptation/video_stream_adapter.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_utils/to_queued_task.h" namespace webrtc { @@ -72,7 +72,7 @@ ResourceAdaptationProcessor::ResourceAdaptationProcessor( VideoStreamAdapter* stream_adapter) : task_queue_(nullptr), resource_listener_delegate_( - new rtc::RefCountedObject(this)), + rtc::make_ref_counted(this)), resources_(), stream_adapter_(stream_adapter), last_reported_source_restrictions_(), diff --git a/call/adaptation/test/fake_resource.cc b/call/adaptation/test/fake_resource.cc index fa69e886bf..d125468cb6 100644 --- a/call/adaptation/test/fake_resource.cc +++ b/call/adaptation/test/fake_resource.cc @@ -19,7 +19,7 @@ namespace webrtc { // static rtc::scoped_refptr 
FakeResource::Create(std::string name) { - return new rtc::RefCountedObject(name); + return rtc::make_ref_counted(name); } FakeResource::FakeResource(std::string name) diff --git a/call/adaptation/video_stream_adapter.cc b/call/adaptation/video_stream_adapter.cc index 13eb0349a3..64e1a77786 100644 --- a/call/adaptation/video_stream_adapter.cc +++ b/call/adaptation/video_stream_adapter.cc @@ -16,6 +16,7 @@ #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "api/sequence_checker.h" #include "api/video/video_adaptation_counters.h" #include "api/video/video_adaptation_reason.h" #include "api/video_codecs/video_encoder.h" @@ -25,7 +26,6 @@ #include "rtc_base/constructor_magic.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/synchronization/sequence_checker.h" namespace webrtc { @@ -62,13 +62,14 @@ int GetIncreasedMaxPixelsWanted(int target_pixels) { } bool CanDecreaseResolutionTo(int target_pixels, + int target_pixels_min, const VideoStreamInputState& input_state, const VideoSourceRestrictions& restrictions) { int max_pixels_per_frame = rtc::dchecked_cast(restrictions.max_pixels_per_frame().value_or( std::numeric_limits::max())); return target_pixels < max_pixels_per_frame && - target_pixels >= input_state.min_pixels_per_frame(); + target_pixels_min >= input_state.min_pixels_per_frame(); } bool CanIncreaseResolutionTo(int target_pixels, @@ -96,6 +97,11 @@ bool CanIncreaseFrameRateTo(int max_frame_rate, } bool MinPixelLimitReached(const VideoStreamInputState& input_state) { + if (input_state.single_active_stream_pixels().has_value()) { + return GetLowerResolutionThan( + input_state.single_active_stream_pixels().value()) < + input_state.min_pixels_per_frame(); + } return input_state.frame_size_pixels().has_value() && GetLowerResolutionThan(input_state.frame_size_pixels().value()) < input_state.min_pixels_per_frame(); @@ -410,8 +416,10 @@ VideoStreamAdapter::AdaptIfFpsDiffInsufficient( const VideoStreamInputState& input_state, const RestrictionsWithCounters& restrictions) const { RTC_DCHECK_EQ(degradation_preference_, DegradationPreference::BALANCED); + int frame_size_pixels = input_state.single_active_stream_pixels().value_or( + input_state.frame_size_pixels().value()); absl::optional min_fps_diff = - balanced_settings_.MinFpsDiff(input_state.frame_size_pixels().value()); + balanced_settings_.MinFpsDiff(frame_size_pixels); if (current_restrictions_.counters.fps_adaptations < restrictions.counters.fps_adaptations && min_fps_diff && input_state.frames_per_second() > 0) { @@ -470,7 +478,11 @@ VideoStreamAdapter::RestrictionsOrState VideoStreamAdapter::DecreaseResolution( const RestrictionsWithCounters& current_restrictions) { int target_pixels = GetLowerResolutionThan(input_state.frame_size_pixels().value()); - if (!CanDecreaseResolutionTo(target_pixels, input_state, + // Use single active stream if set, this stream could be lower than the input. 
+ int target_pixels_min = + GetLowerResolutionThan(input_state.single_active_stream_pixels().value_or( + input_state.frame_size_pixels().value())); + if (!CanDecreaseResolutionTo(target_pixels, target_pixels_min, input_state, current_restrictions.restrictions)) { return Adaptation::Status::kLimitReached; } @@ -492,9 +504,10 @@ VideoStreamAdapter::RestrictionsOrState VideoStreamAdapter::DecreaseFramerate( if (degradation_preference_ == DegradationPreference::MAINTAIN_RESOLUTION) { max_frame_rate = GetLowerFrameRateThan(input_state.frames_per_second()); } else if (degradation_preference_ == DegradationPreference::BALANCED) { - max_frame_rate = - balanced_settings_.MinFps(input_state.video_codec_type(), - input_state.frame_size_pixels().value()); + int frame_size_pixels = input_state.single_active_stream_pixels().value_or( + input_state.frame_size_pixels().value()); + max_frame_rate = balanced_settings_.MinFps(input_state.video_codec_type(), + frame_size_pixels); } else { RTC_NOTREACHED(); max_frame_rate = GetLowerFrameRateThan(input_state.frames_per_second()); @@ -551,12 +564,21 @@ VideoStreamAdapter::RestrictionsOrState VideoStreamAdapter::IncreaseFramerate( if (degradation_preference_ == DegradationPreference::MAINTAIN_RESOLUTION) { max_frame_rate = GetHigherFrameRateThan(input_state.frames_per_second()); } else if (degradation_preference_ == DegradationPreference::BALANCED) { - max_frame_rate = - balanced_settings_.MaxFps(input_state.video_codec_type(), - input_state.frame_size_pixels().value()); + int frame_size_pixels = input_state.single_active_stream_pixels().value_or( + input_state.frame_size_pixels().value()); + max_frame_rate = balanced_settings_.MaxFps(input_state.video_codec_type(), + frame_size_pixels); + // Temporary fix for cases when there are fewer framerate adaptation steps + // up than down. Make number of down/up steps equal. + if (max_frame_rate == std::numeric_limits::max() && + current_restrictions.counters.fps_adaptations > 1) { + // Do not unrestrict framerate to allow additional adaptation up steps. + RTC_LOG(LS_INFO) << "Modifying framerate due to remaining fps count."; + max_frame_rate -= current_restrictions.counters.fps_adaptations; + } // In BALANCED, the max_frame_rate must be checked before proceeding. This // is because the MaxFps might be the current Fps and so the balanced - // settings may want to scale up the resolution.= + // settings may want to scale up the resolution. if (!CanIncreaseFrameRateTo(max_frame_rate, current_restrictions.restrictions)) { return Adaptation::Status::kLimitReached; @@ -693,4 +715,27 @@ VideoStreamAdapter::AwaitingFrameSizeChange::AwaitingFrameSizeChange( : pixels_increased(pixels_increased), frame_size_pixels(frame_size_pixels) {} +absl::optional VideoStreamAdapter::GetSingleActiveLayerPixels( + const VideoCodec& codec) { + int num_active = 0; + absl::optional pixels; + if (codec.codecType == VideoCodecType::kVideoCodecVP9) { + for (int i = 0; i < codec.VP9().numberOfSpatialLayers; ++i) { + if (codec.spatialLayers[i].active) { + ++num_active; + pixels = codec.spatialLayers[i].width * codec.spatialLayers[i].height; + } + } + } else { + for (int i = 0; i < codec.numberOfSimulcastStreams; ++i) { + if (codec.simulcastStream[i].active) { + ++num_active; + pixels = + codec.simulcastStream[i].width * codec.simulcastStream[i].height; + } + } + } + return (num_active > 1) ? 
absl::nullopt : pixels; +} + } // namespace webrtc diff --git a/call/adaptation/video_stream_adapter.h b/call/adaptation/video_stream_adapter.h index 2b55c3d49c..3c876b8970 100644 --- a/call/adaptation/video_stream_adapter.h +++ b/call/adaptation/video_stream_adapter.h @@ -163,6 +163,9 @@ class VideoStreamAdapter { VideoAdaptationCounters counters; }; + static absl::optional GetSingleActiveLayerPixels( + const VideoCodec& codec); + private: void BroadcastVideoRestrictionsUpdate( const VideoStreamInputState& input_state, diff --git a/call/adaptation/video_stream_input_state.cc b/call/adaptation/video_stream_input_state.cc index dc3315e6d0..9c0d475902 100644 --- a/call/adaptation/video_stream_input_state.cc +++ b/call/adaptation/video_stream_input_state.cc @@ -19,7 +19,8 @@ VideoStreamInputState::VideoStreamInputState() frame_size_pixels_(absl::nullopt), frames_per_second_(0), video_codec_type_(VideoCodecType::kVideoCodecGeneric), - min_pixels_per_frame_(kDefaultMinPixelsPerFrame) {} + min_pixels_per_frame_(kDefaultMinPixelsPerFrame), + single_active_stream_pixels_(absl::nullopt) {} void VideoStreamInputState::set_has_input(bool has_input) { has_input_ = has_input; @@ -43,6 +44,11 @@ void VideoStreamInputState::set_min_pixels_per_frame(int min_pixels_per_frame) { min_pixels_per_frame_ = min_pixels_per_frame; } +void VideoStreamInputState::set_single_active_stream_pixels( + absl::optional single_active_stream_pixels) { + single_active_stream_pixels_ = single_active_stream_pixels; +} + bool VideoStreamInputState::has_input() const { return has_input_; } @@ -63,6 +69,10 @@ int VideoStreamInputState::min_pixels_per_frame() const { return min_pixels_per_frame_; } +absl::optional VideoStreamInputState::single_active_stream_pixels() const { + return single_active_stream_pixels_; +} + bool VideoStreamInputState::HasInputFrameSizeAndFramesPerSecond() const { return has_input_ && frame_size_pixels_.has_value(); } diff --git a/call/adaptation/video_stream_input_state.h b/call/adaptation/video_stream_input_state.h index af0d7c78e9..191e22386a 100644 --- a/call/adaptation/video_stream_input_state.h +++ b/call/adaptation/video_stream_input_state.h @@ -27,12 +27,15 @@ class VideoStreamInputState { void set_frames_per_second(int frames_per_second); void set_video_codec_type(VideoCodecType video_codec_type); void set_min_pixels_per_frame(int min_pixels_per_frame); + void set_single_active_stream_pixels( + absl::optional single_active_stream_pixels); bool has_input() const; absl::optional frame_size_pixels() const; int frames_per_second() const; VideoCodecType video_codec_type() const; int min_pixels_per_frame() const; + absl::optional single_active_stream_pixels() const; bool HasInputFrameSizeAndFramesPerSecond() const; @@ -42,6 +45,7 @@ class VideoStreamInputState { int frames_per_second_; VideoCodecType video_codec_type_; int min_pixels_per_frame_; + absl::optional single_active_stream_pixels_; }; } // namespace webrtc diff --git a/call/adaptation/video_stream_input_state_provider.cc b/call/adaptation/video_stream_input_state_provider.cc index 3c0a7e3fa2..3261af39ea 100644 --- a/call/adaptation/video_stream_input_state_provider.cc +++ b/call/adaptation/video_stream_input_state_provider.cc @@ -10,6 +10,8 @@ #include "call/adaptation/video_stream_input_state_provider.h" +#include "call/adaptation/video_stream_adapter.h" + namespace webrtc { VideoStreamInputStateProvider::VideoStreamInputStateProvider( @@ -36,6 +38,9 @@ void VideoStreamInputStateProvider::OnEncoderSettingsChanged( 
encoder_settings.encoder_config().codec_type); input_state_.set_min_pixels_per_frame( encoder_settings.encoder_info().scaling_settings.min_pixels_per_frame); + input_state_.set_single_active_stream_pixels( + VideoStreamAdapter::GetSingleActiveLayerPixels( + encoder_settings.video_codec())); } VideoStreamInputState VideoStreamInputStateProvider::InputState() { diff --git a/call/adaptation/video_stream_input_state_provider_unittest.cc b/call/adaptation/video_stream_input_state_provider_unittest.cc index 49c662c581..5da2ef21cd 100644 --- a/call/adaptation/video_stream_input_state_provider_unittest.cc +++ b/call/adaptation/video_stream_input_state_provider_unittest.cc @@ -28,6 +28,7 @@ TEST(VideoStreamInputStateProviderTest, DefaultValues) { EXPECT_EQ(0, input_state.frames_per_second()); EXPECT_EQ(VideoCodecType::kVideoCodecGeneric, input_state.video_codec_type()); EXPECT_EQ(kDefaultMinPixelsPerFrame, input_state.min_pixels_per_frame()); + EXPECT_EQ(absl::nullopt, input_state.single_active_stream_pixels()); } TEST(VideoStreamInputStateProviderTest, ValuesSet) { @@ -40,14 +41,22 @@ TEST(VideoStreamInputStateProviderTest, ValuesSet) { encoder_info.scaling_settings.min_pixels_per_frame = 1337; VideoEncoderConfig encoder_config; encoder_config.codec_type = VideoCodecType::kVideoCodecVP9; + VideoCodec video_codec; + video_codec.codecType = VideoCodecType::kVideoCodecVP8; + video_codec.numberOfSimulcastStreams = 2; + video_codec.simulcastStream[0].active = false; + video_codec.simulcastStream[1].active = true; + video_codec.simulcastStream[1].width = 111; + video_codec.simulcastStream[1].height = 222; input_state_provider.OnEncoderSettingsChanged(EncoderSettings( - std::move(encoder_info), std::move(encoder_config), VideoCodec())); + std::move(encoder_info), std::move(encoder_config), video_codec)); VideoStreamInputState input_state = input_state_provider.InputState(); EXPECT_EQ(true, input_state.has_input()); EXPECT_EQ(42, input_state.frame_size_pixels()); EXPECT_EQ(123, input_state.frames_per_second()); EXPECT_EQ(VideoCodecType::kVideoCodecVP9, input_state.video_codec_type()); EXPECT_EQ(1337, input_state.min_pixels_per_frame()); + EXPECT_EQ(111 * 222, input_state.single_active_stream_pixels()); } } // namespace webrtc diff --git a/call/audio_receive_stream.h b/call/audio_receive_stream.h index eee62e9a8a..8403e6bea0 100644 --- a/call/audio_receive_stream.h +++ b/call/audio_receive_stream.h @@ -20,17 +20,14 @@ #include "api/audio_codecs/audio_decoder_factory.h" #include "api/call/transport.h" #include "api/crypto/crypto_options.h" -#include "api/crypto/frame_decryptor_interface.h" -#include "api/frame_transformer_interface.h" #include "api/rtp_parameters.h" -#include "api/scoped_refptr.h" -#include "api/transport/rtp/rtp_source.h" +#include "call/receive_stream.h" #include "call/rtp_config.h" namespace webrtc { class AudioSinkInterface; -class AudioReceiveStream { +class AudioReceiveStream : public MediaReceiveStream { public: struct Stats { Stats(); @@ -42,6 +39,7 @@ class AudioReceiveStream { uint64_t fec_packets_received = 0; uint64_t fec_packets_discarded = 0; uint32_t packets_lost = 0; + uint32_t nacks_sent = 0; std::string codec_name; absl::optional codec_payload_type; uint32_t jitter_ms = 0; @@ -90,6 +88,13 @@ class AudioReceiveStream { int32_t total_interruption_duration_ms = 0; // https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-estimatedplayouttimestamp absl::optional estimated_playout_ntp_timestamp_ms; + // Remote outbound stats derived by the received RTCP sender 
reports. + // https://w3c.github.io/webrtc-stats/#remoteoutboundrtpstats-dict* + absl::optional last_sender_report_timestamp_ms; + absl::optional last_sender_report_remote_timestamp_ms; + uint32_t sender_reports_packets_sent = 0; + uint64_t sender_reports_bytes_sent = 0; + uint64_t sender_reports_reports_count = 0; }; struct Config { @@ -99,29 +104,14 @@ class AudioReceiveStream { std::string ToString() const; // Receive-stream specific RTP settings. - struct Rtp { + struct Rtp : public RtpConfig { Rtp(); ~Rtp(); std::string ToString() const; - // Synchronization source (stream identifier) to be received. - uint32_t remote_ssrc = 0; - - // Sender SSRC used for sending RTCP (such as receiver reports). - uint32_t local_ssrc = 0; - - // Enable feedback for send side bandwidth estimation. - // See - // https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions - // for details. - bool transport_cc = false; - // See NackConfig for description. NackConfig nack; - - // RTP header extensions used for the received stream. - std::vector extensions; } rtp; Transport* rtcp_send_transport = nullptr; @@ -150,22 +140,29 @@ class AudioReceiveStream { // An optional custom frame decryptor that allows the entire frame to be // decrypted in whatever way the caller choses. This is not required by // default. + // TODO(tommi): Remove this member variable from the struct. It's not + // a part of the AudioReceiveStream state but rather a pass through + // variable. rtc::scoped_refptr frame_decryptor; // An optional frame transformer used by insertable streams to transform // encoded frames. + // TODO(tommi): Remove this member variable from the struct. It's not + // a part of the AudioReceiveStream state but rather a pass through + // variable. rtc::scoped_refptr frame_transformer; }; - // Reconfigure the stream according to the Configuration. - virtual void Reconfigure(const Config& config) = 0; + // Methods that support reconfiguring the stream post initialization. + virtual void SetDecoderMap(std::map decoder_map) = 0; + virtual void SetUseTransportCcAndNackHistory(bool use_transport_cc, + int history_ms) = 0; + // Set/change the rtp header extensions. Must be called on the packet + // delivery thread. + virtual void SetRtpExtensions(std::vector extensions) = 0; - // Starts stream activity. - // When a stream is active, it can receive, process and deliver packets. - virtual void Start() = 0; - // Stops stream activity. - // When a stream is stopped, it can't receive, process or deliver packets. - virtual void Stop() = 0; + // Returns true if the stream has been started. + virtual bool IsRunning() const = 0; virtual Stats GetStats(bool get_and_clear_legacy_stats) const = 0; Stats GetStats() { return GetStats(/*get_and_clear_legacy_stats=*/true); } @@ -192,8 +189,6 @@ class AudioReceiveStream { // Returns current value of base minimum delay in milliseconds. 
virtual int GetBaseMinimumPlayoutDelayMs() const = 0; - virtual std::vector GetSources() const = 0; - protected: virtual ~AudioReceiveStream() {} }; diff --git a/call/audio_send_stream.cc b/call/audio_send_stream.cc index 76480f2362..916336b929 100644 --- a/call/audio_send_stream.cc +++ b/call/audio_send_stream.cc @@ -12,7 +12,6 @@ #include -#include "rtc_base/string_encode.h" #include "rtc_base/strings/audio_format_to_string.h" #include "rtc_base/strings/string_builder.h" @@ -27,8 +26,7 @@ AudioSendStream::Config::Config(Transport* send_transport) AudioSendStream::Config::~Config() = default; std::string AudioSendStream::Config::ToString() const { - char buf[1024]; - rtc::SimpleStringBuilder ss(buf); + rtc::StringBuilder ss; ss << "{rtp: " << rtp.ToString(); ss << ", rtcp_report_interval_ms: " << rtcp_report_interval_ms; ss << ", send_transport: " << (send_transport ? "(Transport)" : "null"); @@ -39,8 +37,8 @@ std::string AudioSendStream::Config::ToString() const { ss << ", has_dscp: " << (has_dscp ? "true" : "false"); ss << ", send_codec_spec: " << (send_codec_spec ? send_codec_spec->ToString() : ""); - ss << '}'; - return ss.str(); + ss << "}"; + return ss.Release(); } AudioSendStream::Config::Rtp::Rtp() = default; diff --git a/call/audio_send_stream.h b/call/audio_send_stream.h index d21dff4889..e084d4219d 100644 --- a/call/audio_send_stream.h +++ b/call/audio_send_stream.h @@ -70,6 +70,7 @@ class AudioSendStream : public AudioSender { // per-pair the ReportBlockData represents the latest Report Block that was // received for that pair. std::vector report_block_datas; + uint32_t nacks_rcvd = 0; }; struct Config { diff --git a/call/bitrate_allocator.h b/call/bitrate_allocator.h index 481d91b23c..c0d664b6f0 100644 --- a/call/bitrate_allocator.h +++ b/call/bitrate_allocator.h @@ -20,8 +20,8 @@ #include #include "api/call/bitrate_allocation.h" +#include "api/sequence_checker.h" #include "api/transport/network_types.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { diff --git a/call/call.cc b/call/call.cc index 6f407fc0f0..fb1d7cd3bc 100644 --- a/call/call.cc +++ b/call/call.cc @@ -13,14 +13,17 @@ #include #include +#include #include #include #include #include #include +#include "absl/functional/bind_front.h" #include "absl/types/optional.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/sequence_checker.h" #include "api/transport/network_control.h" #include "audio/audio_receive_stream.h" #include "audio/audio_send_stream.h" @@ -31,6 +34,7 @@ #include "call/receive_time_calculator.h" #include "call/rtp_stream_receiver_controller.h" #include "call/rtp_transport_controller_send.h" +#include "call/rtp_transport_controller_send_factory.h" #include "call/version.h" #include "logging/rtc_event_log/events/rtc_event_audio_receive_stream_config.h" #include "logging/rtc_event_log/events/rtc_event_rtcp_packet_incoming.h" @@ -43,7 +47,7 @@ #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/source/byte_io.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "modules/rtp_rtcp/source/rtp_utility.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "modules/utility/include/process_thread.h" #include "modules/video_coding/fec_controller_default.h" #include "rtc_base/checks.h" @@ -51,7 +55,6 @@ #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/sequence_checker.h" 
#include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread_annotations.h" @@ -78,12 +81,10 @@ bool SendPeriodicFeedback(const std::vector& extensions) { return true; } -// TODO(nisse): This really begs for a shared context struct. -bool UseSendSideBwe(const std::vector& extensions, - bool transport_cc) { - if (!transport_cc) +bool UseSendSideBwe(const ReceiveStream::RtpConfig& rtp) { + if (!rtp.transport_cc) return false; - for (const auto& extension : extensions) { + for (const auto& extension : rtp.extensions) { if (extension.uri == RtpExtension::kTransportSequenceNumberUri || extension.uri == RtpExtension::kTransportSequenceNumberV2Uri) return true; @@ -91,18 +92,6 @@ bool UseSendSideBwe(const std::vector& extensions, return false; } -bool UseSendSideBwe(const VideoReceiveStream::Config& config) { - return UseSendSideBwe(config.rtp.extensions, config.rtp.transport_cc); -} - -bool UseSendSideBwe(const AudioReceiveStream::Config& config) { - return UseSendSideBwe(config.rtp.extensions, config.rtp.transport_cc); -} - -bool UseSendSideBwe(const FlexfecReceiveStream::Config& config) { - return UseSendSideBwe(config.rtp_header_extensions, config.transport_cc); -} - const int* FindKeyByValue(const std::map& m, int v) { for (const auto& kv : m) { if (kv.second == v) @@ -155,11 +144,6 @@ std::unique_ptr CreateRtcLogStreamConfig( return rtclog_config; } -bool IsRtcp(const uint8_t* packet, size_t length) { - RtpUtility::RtpHeaderParser rtp_parser(packet, length); - return rtp_parser.RTCP(); -} - TaskQueueBase* GetCurrentTaskQueueOrThread() { TaskQueueBase* current = TaskQueueBase::Current(); if (!current) @@ -264,6 +248,9 @@ class Call final : public webrtc::Call, const WebRtcKeyValueConfig& trials() const override; + TaskQueueBase* network_thread() const override; + TaskQueueBase* worker_thread() const override; + // Implements PacketReceiver. DeliveryStatus DeliverPacket(MediaType media_type, rtc::CopyOnWriteBuffer packet, @@ -277,6 +264,12 @@ class Call final : public webrtc::Call, void OnAudioTransportOverheadChanged( int transport_overhead_per_packet) override; + void OnLocalSsrcUpdated(webrtc::AudioReceiveStream& stream, + uint32_t local_ssrc) override; + + void OnUpdateSyncGroup(webrtc::AudioReceiveStream& stream, + const std::string& sync_group) override; + void OnSentPacket(const rtc::SentPacket& sent_packet) override; // Implements TargetTransferRateObserver, @@ -289,91 +282,124 @@ class Call final : public webrtc::Call, void SetClientBitratePreferences(const BitrateSettings& preferences) override; private: - DeliveryStatus DeliverRtcp(MediaType media_type, - const uint8_t* packet, - size_t length) - RTC_EXCLUSIVE_LOCKS_REQUIRED(worker_thread_); + // Thread-compatible class that collects received packet stats and exposes + // them as UMA histograms on destruction. 
+ class ReceiveStats { + public: + explicit ReceiveStats(Clock* clock); + ~ReceiveStats(); + + void AddReceivedRtcpBytes(int bytes); + void AddReceivedAudioBytes(int bytes, webrtc::Timestamp arrival_time); + void AddReceivedVideoBytes(int bytes, webrtc::Timestamp arrival_time); + + private: + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; + RateCounter received_bytes_per_second_counter_ + RTC_GUARDED_BY(sequence_checker_); + RateCounter received_audio_bytes_per_second_counter_ + RTC_GUARDED_BY(sequence_checker_); + RateCounter received_video_bytes_per_second_counter_ + RTC_GUARDED_BY(sequence_checker_); + RateCounter received_rtcp_bytes_per_second_counter_ + RTC_GUARDED_BY(sequence_checker_); + absl::optional first_received_rtp_audio_timestamp_ + RTC_GUARDED_BY(sequence_checker_); + absl::optional last_received_rtp_audio_timestamp_ + RTC_GUARDED_BY(sequence_checker_); + absl::optional first_received_rtp_video_timestamp_ + RTC_GUARDED_BY(sequence_checker_); + absl::optional last_received_rtp_video_timestamp_ + RTC_GUARDED_BY(sequence_checker_); + }; + + // Thread-compatible class that collects sent packet stats and exposes + // them as UMA histograms on destruction, provided SetFirstPacketTime was + // called with a non-empty packet timestamp before the destructor. + class SendStats { + public: + explicit SendStats(Clock* clock); + ~SendStats(); + + void SetFirstPacketTime(absl::optional first_sent_packet_time); + void PauseSendAndPacerBitrateCounters(); + void AddTargetBitrateSample(uint32_t target_bitrate_bps); + void SetMinAllocatableRate(BitrateAllocationLimits limits); + + private: + RTC_NO_UNIQUE_ADDRESS SequenceChecker destructor_sequence_checker_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; + Clock* const clock_ RTC_GUARDED_BY(destructor_sequence_checker_); + AvgCounter estimated_send_bitrate_kbps_counter_ + RTC_GUARDED_BY(sequence_checker_); + AvgCounter pacer_bitrate_kbps_counter_ RTC_GUARDED_BY(sequence_checker_); + uint32_t min_allocated_send_bitrate_bps_ RTC_GUARDED_BY(sequence_checker_){ + 0}; + absl::optional first_sent_packet_time_ + RTC_GUARDED_BY(destructor_sequence_checker_); + }; + + void DeliverRtcp(MediaType media_type, rtc::CopyOnWriteBuffer packet) + RTC_RUN_ON(network_thread_); DeliveryStatus DeliverRtp(MediaType media_type, rtc::CopyOnWriteBuffer packet, - int64_t packet_time_us) - RTC_EXCLUSIVE_LOCKS_REQUIRED(worker_thread_); - void ConfigureSync(const std::string& sync_group) - RTC_EXCLUSIVE_LOCKS_REQUIRED(worker_thread_); + int64_t packet_time_us) RTC_RUN_ON(worker_thread_); + void ConfigureSync(const std::string& sync_group) RTC_RUN_ON(worker_thread_); void NotifyBweOfReceivedPacket(const RtpPacketReceived& packet, MediaType media_type) - RTC_SHARED_LOCKS_REQUIRED(worker_thread_); + RTC_RUN_ON(worker_thread_); - void UpdateSendHistograms(Timestamp first_sent_packet) - RTC_EXCLUSIVE_LOCKS_REQUIRED(worker_thread_); - void UpdateReceiveHistograms(); - void UpdateHistograms(); void UpdateAggregateNetworkState(); // Ensure that necessary process threads are started, and any required // callbacks have been registered. 
- void EnsureStarted() RTC_EXCLUSIVE_LOCKS_REQUIRED(worker_thread_); - - rtc::TaskQueue* send_transport_queue() const { - return transport_send_ptr_->GetWorkerQueue(); - } + void EnsureStarted() RTC_RUN_ON(worker_thread_); Clock* const clock_; TaskQueueFactory* const task_queue_factory_; TaskQueueBase* const worker_thread_; + TaskQueueBase* const network_thread_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker send_transport_sequence_checker_; const int num_cpu_cores_; const rtc::scoped_refptr module_process_thread_; const std::unique_ptr call_stats_; const std::unique_ptr bitrate_allocator_; - Call::Config config_; - - NetworkState audio_network_state_; - NetworkState video_network_state_; + const Call::Config config_ RTC_GUARDED_BY(worker_thread_); + // Maps to config_.trials, can be used from any thread via `trials()`. + const WebRtcKeyValueConfig& trials_; + + NetworkState audio_network_state_ RTC_GUARDED_BY(worker_thread_); + NetworkState video_network_state_ RTC_GUARDED_BY(worker_thread_); + // TODO(bugs.webrtc.org/11993): Move aggregate_network_up_ over to the + // network thread. bool aggregate_network_up_ RTC_GUARDED_BY(worker_thread_); // Audio, Video, and FlexFEC receive streams are owned by the client that // creates them. + // TODO(bugs.webrtc.org/11993): Move audio_receive_streams_, + // video_receive_streams_ and sync_stream_mapping_ over to the network thread. std::set audio_receive_streams_ RTC_GUARDED_BY(worker_thread_); std::set video_receive_streams_ RTC_GUARDED_BY(worker_thread_); - std::map sync_stream_mapping_ RTC_GUARDED_BY(worker_thread_); // TODO(nisse): Should eventually be injected at creation, // with a single object in the bundled case. - RtpStreamReceiverController audio_receiver_controller_; - RtpStreamReceiverController video_receiver_controller_; + RtpStreamReceiverController audio_receiver_controller_ + RTC_GUARDED_BY(worker_thread_); + RtpStreamReceiverController video_receiver_controller_ + RTC_GUARDED_BY(worker_thread_); // This extra map is used for receive processing which is // independent of media type. - // TODO(nisse): In the RTP transport refactoring, we should have a - // single mapping from ssrc to a more abstract receive stream, with - // accessor methods for all configuration we need at this level. - struct ReceiveRtpConfig { - explicit ReceiveRtpConfig(const webrtc::AudioReceiveStream::Config& config) - : extensions(config.rtp.extensions), - use_send_side_bwe(UseSendSideBwe(config)) {} - explicit ReceiveRtpConfig(const webrtc::VideoReceiveStream::Config& config) - : extensions(config.rtp.extensions), - use_send_side_bwe(UseSendSideBwe(config)) {} - explicit ReceiveRtpConfig(const FlexfecReceiveStream::Config& config) - : extensions(config.rtp_header_extensions), - use_send_side_bwe(UseSendSideBwe(config)) {} - - // Registered RTP header extensions for each stream. Note that RTP header - // extensions are negotiated per track ("m= line") in the SDP, but we have - // no notion of tracks at the Call level. We therefore store the RTP header - // extensions per SSRC instead, which leads to some storage overhead. - const RtpHeaderExtensionMap extensions; - // Set if both RTP extension the RTCP feedback message needed for - // send side BWE are negotiated. - const bool use_send_side_bwe; - }; - std::map receive_rtp_config_ + // TODO(bugs.webrtc.org/11993): Move receive_rtp_config_ over to the + // network thread. + std::map receive_rtp_config_ RTC_GUARDED_BY(worker_thread_); // Audio and Video send streams are owned by the client that creates them. 
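The hunks above and below replace ad-hoc locking annotations on Call's members with RTC_GUARDED_BY backed by a SequenceChecker, verified at call sites with RTC_DCHECK_RUN_ON. A minimal sketch of that pattern follows; it assumes only the WebRTC headers already referenced in this patch (api/sequence_checker.h, rtc_base/thread_annotations.h, rtc_base/system/no_unique_address.h), and the class and member names are hypothetical, not part of the patch.

#include "api/sequence_checker.h"
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"

// Hypothetical thread-compatible counter in the style of Call::ReceiveStats.
class ExampleCounter {
 public:
  ExampleCounter() {
    // Constructed on one thread; Detach() lets the first Add()/total_bytes()
    // call bind the checker to whichever sequence ends up using the object.
    sequence_checker_.Detach();
  }

  void Add(int bytes) {
    RTC_DCHECK_RUN_ON(&sequence_checker_);  // Debug-checks the calling sequence.
    total_bytes_ += bytes;
  }

  int total_bytes() const {
    RTC_DCHECK_RUN_ON(&sequence_checker_);
    return total_bytes_;
  }

 private:
  RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker sequence_checker_;
  int total_bytes_ RTC_GUARDED_BY(sequence_checker_) = 0;
};

Compared with a mutex, this only documents and DCHECKs single-sequence access; it adds no synchronization, which is why the patch keeps these members on one thread (worker or send-transport sequence).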
@@ -382,6 +408,10 @@ class Call final : public webrtc::Call, std::map video_send_ssrcs_ RTC_GUARDED_BY(worker_thread_); std::set video_send_streams_ RTC_GUARDED_BY(worker_thread_); + // True if |video_send_streams_| is empty, false if not. The atomic variable + // is used to decide UMA send statistics behavior and enables avoiding a + // PostTask(). + std::atomic video_send_streams_empty_{true}; // Each forwarder wraps an adaptation resource that was added to the call. std::vector> @@ -395,49 +425,41 @@ class Call final : public webrtc::Call, RtpPayloadStateMap suspended_video_payload_states_ RTC_GUARDED_BY(worker_thread_); - webrtc::RtcEventLog* event_log_; - - // The following members are only accessed (exclusively) from one thread and - // from the destructor, and therefore doesn't need any explicit - // synchronization. - RateCounter received_bytes_per_second_counter_; - RateCounter received_audio_bytes_per_second_counter_; - RateCounter received_video_bytes_per_second_counter_; - RateCounter received_rtcp_bytes_per_second_counter_; - absl::optional first_received_rtp_audio_ms_; - absl::optional last_received_rtp_audio_ms_; - absl::optional first_received_rtp_video_ms_; - absl::optional last_received_rtp_video_ms_; - - uint32_t last_bandwidth_bps_ RTC_GUARDED_BY(worker_thread_); - // TODO(holmer): Remove this lock once BitrateController no longer calls - // OnNetworkChanged from multiple threads. - uint32_t min_allocated_send_bitrate_bps_ RTC_GUARDED_BY(worker_thread_); - uint32_t configured_max_padding_bitrate_bps_ RTC_GUARDED_BY(worker_thread_); - AvgCounter estimated_send_bitrate_kbps_counter_ - RTC_GUARDED_BY(worker_thread_); - AvgCounter pacer_bitrate_kbps_counter_ RTC_GUARDED_BY(worker_thread_); + webrtc::RtcEventLog* const event_log_; + + // TODO(bugs.webrtc.org/11993) ready to move stats access to the network + // thread. + ReceiveStats receive_stats_ RTC_GUARDED_BY(worker_thread_); + SendStats send_stats_ RTC_GUARDED_BY(send_transport_sequence_checker_); + // |last_bandwidth_bps_| and |configured_max_padding_bitrate_bps_| being + // atomic avoids a PostTask. The variables are used for stats gathering. + std::atomic last_bandwidth_bps_{0}; + std::atomic configured_max_padding_bitrate_bps_{0}; ReceiveSideCongestionController receive_side_cc_; const std::unique_ptr receive_time_calculator_; const std::unique_ptr video_send_delay_stats_; - const int64_t start_ms_; + const Timestamp start_of_call_; // Note that |task_safety_| needs to be at a greater scope than the task queue // owned by |transport_send_| since calls might arrive on the network thread // while Call is being deleted and the task queue is being torn down. - ScopedTaskSafety task_safety_; + const ScopedTaskSafety task_safety_; // Caches transport_send_.get(), to avoid racing with destructor. // Note that this is declared before transport_send_ to ensure that it is not // invalidated until no more tasks can be running on the transport_send_ task // queue. - RtpTransportControllerSendInterface* const transport_send_ptr_; + // For more details on the background of this member variable, see: + // https://webrtc-review.googlesource.com/c/src/+/63023/9/call/call.cc + // https://bugs.chromium.org/p/chromium/issues/detail?id=992640 + RtpTransportControllerSendInterface* const transport_send_ptr_ + RTC_GUARDED_BY(send_transport_sequence_checker_); // Declared last since it will issue callbacks from a task queue. Declaring it // last ensures that it is destroyed first and any running tasks are finished. 
- std::unique_ptr transport_send_; + const std::unique_ptr transport_send_; bool is_started_ RTC_GUARDED_BY(worker_thread_) = false; @@ -462,11 +484,6 @@ Call* Call::Create(const Call::Config& config) { rtc::scoped_refptr call_thread = SharedModuleThread::Create(ProcessThread::Create("ModuleProcessThread"), nullptr); - return Create(config, std::move(call_thread)); -} - -Call* Call::Create(const Call::Config& config, - rtc::scoped_refptr call_thread) { return Create(config, Clock::GetRealTimeClock(), std::move(call_thread), ProcessThread::Create("PacerThread")); } @@ -476,15 +493,28 @@ Call* Call::Create(const Call::Config& config, rtc::scoped_refptr call_thread, std::unique_ptr pacer_thread) { RTC_DCHECK(config.task_queue_factory); + + RtpTransportControllerSendFactory transport_controller_factory_; + + RtpTransportConfig transportConfig = config.ExtractTransportConfig(); + return new internal::Call( clock, config, - std::make_unique( - clock, config.event_log, config.network_state_predictor_factory, - config.network_controller_factory, config.bitrate_config, - std::move(pacer_thread), config.task_queue_factory, config.trials), + transport_controller_factory_.Create(transportConfig, clock, + std::move(pacer_thread)), std::move(call_thread), config.task_queue_factory); } +Call* Call::Create(const Call::Config& config, + Clock* clock, + rtc::scoped_refptr call_thread, + std::unique_ptr + transportControllerSend) { + RTC_DCHECK(config.task_queue_factory); + return new internal::Call(clock, config, std::move(transportControllerSend), + std::move(call_thread), config.task_queue_factory); +} + class SharedModuleThread::Impl { public: Impl(std::unique_ptr process_thread, @@ -589,6 +619,157 @@ VideoSendStream* Call::CreateVideoSendStream( namespace internal { +Call::ReceiveStats::ReceiveStats(Clock* clock) + : received_bytes_per_second_counter_(clock, nullptr, false), + received_audio_bytes_per_second_counter_(clock, nullptr, false), + received_video_bytes_per_second_counter_(clock, nullptr, false), + received_rtcp_bytes_per_second_counter_(clock, nullptr, false) { + sequence_checker_.Detach(); +} + +void Call::ReceiveStats::AddReceivedRtcpBytes(int bytes) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + if (received_bytes_per_second_counter_.HasSample()) { + // First RTP packet has been received. 
+ received_bytes_per_second_counter_.Add(static_cast(bytes)); + received_rtcp_bytes_per_second_counter_.Add(static_cast(bytes)); + } +} + +void Call::ReceiveStats::AddReceivedAudioBytes(int bytes, + webrtc::Timestamp arrival_time) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + received_bytes_per_second_counter_.Add(bytes); + received_audio_bytes_per_second_counter_.Add(bytes); + if (!first_received_rtp_audio_timestamp_) + first_received_rtp_audio_timestamp_ = arrival_time; + last_received_rtp_audio_timestamp_ = arrival_time; +} + +void Call::ReceiveStats::AddReceivedVideoBytes(int bytes, + webrtc::Timestamp arrival_time) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + received_bytes_per_second_counter_.Add(bytes); + received_video_bytes_per_second_counter_.Add(bytes); + if (!first_received_rtp_video_timestamp_) + first_received_rtp_video_timestamp_ = arrival_time; + last_received_rtp_video_timestamp_ = arrival_time; +} + +Call::ReceiveStats::~ReceiveStats() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + if (first_received_rtp_audio_timestamp_) { + RTC_HISTOGRAM_COUNTS_100000( + "WebRTC.Call.TimeReceivingAudioRtpPacketsInSeconds", + (*last_received_rtp_audio_timestamp_ - + *first_received_rtp_audio_timestamp_) + .seconds()); + } + if (first_received_rtp_video_timestamp_) { + RTC_HISTOGRAM_COUNTS_100000( + "WebRTC.Call.TimeReceivingVideoRtpPacketsInSeconds", + (*last_received_rtp_video_timestamp_ - + *first_received_rtp_video_timestamp_) + .seconds()); + } + const int kMinRequiredPeriodicSamples = 5; + AggregatedStats video_bytes_per_sec = + received_video_bytes_per_second_counter_.GetStats(); + if (video_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { + RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.VideoBitrateReceivedInKbps", + video_bytes_per_sec.average * 8 / 1000); + RTC_LOG(LS_INFO) << "WebRTC.Call.VideoBitrateReceivedInBps, " + << video_bytes_per_sec.ToStringWithMultiplier(8); + } + AggregatedStats audio_bytes_per_sec = + received_audio_bytes_per_second_counter_.GetStats(); + if (audio_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { + RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.AudioBitrateReceivedInKbps", + audio_bytes_per_sec.average * 8 / 1000); + RTC_LOG(LS_INFO) << "WebRTC.Call.AudioBitrateReceivedInBps, " + << audio_bytes_per_sec.ToStringWithMultiplier(8); + } + AggregatedStats rtcp_bytes_per_sec = + received_rtcp_bytes_per_second_counter_.GetStats(); + if (rtcp_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { + RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.RtcpBitrateReceivedInBps", + rtcp_bytes_per_sec.average * 8); + RTC_LOG(LS_INFO) << "WebRTC.Call.RtcpBitrateReceivedInBps, " + << rtcp_bytes_per_sec.ToStringWithMultiplier(8); + } + AggregatedStats recv_bytes_per_sec = + received_bytes_per_second_counter_.GetStats(); + if (recv_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { + RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.BitrateReceivedInKbps", + recv_bytes_per_sec.average * 8 / 1000); + RTC_LOG(LS_INFO) << "WebRTC.Call.BitrateReceivedInBps, " + << recv_bytes_per_sec.ToStringWithMultiplier(8); + } +} + +Call::SendStats::SendStats(Clock* clock) + : clock_(clock), + estimated_send_bitrate_kbps_counter_(clock, nullptr, true), + pacer_bitrate_kbps_counter_(clock, nullptr, true) { + destructor_sequence_checker_.Detach(); + sequence_checker_.Detach(); +} + +Call::SendStats::~SendStats() { + RTC_DCHECK_RUN_ON(&destructor_sequence_checker_); + if (!first_sent_packet_time_) + return; + + TimeDelta elapsed = clock_->CurrentTime() - *first_sent_packet_time_; + if 
(elapsed.seconds() < metrics::kMinRunTimeInSeconds) + return; + + const int kMinRequiredPeriodicSamples = 5; + AggregatedStats send_bitrate_stats = + estimated_send_bitrate_kbps_counter_.ProcessAndGetStats(); + if (send_bitrate_stats.num_samples > kMinRequiredPeriodicSamples) { + RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.EstimatedSendBitrateInKbps", + send_bitrate_stats.average); + RTC_LOG(LS_INFO) << "WebRTC.Call.EstimatedSendBitrateInKbps, " + << send_bitrate_stats.ToString(); + } + AggregatedStats pacer_bitrate_stats = + pacer_bitrate_kbps_counter_.ProcessAndGetStats(); + if (pacer_bitrate_stats.num_samples > kMinRequiredPeriodicSamples) { + RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.PacerBitrateInKbps", + pacer_bitrate_stats.average); + RTC_LOG(LS_INFO) << "WebRTC.Call.PacerBitrateInKbps, " + << pacer_bitrate_stats.ToString(); + } +} + +void Call::SendStats::SetFirstPacketTime( + absl::optional first_sent_packet_time) { + RTC_DCHECK_RUN_ON(&destructor_sequence_checker_); + first_sent_packet_time_ = first_sent_packet_time; +} + +void Call::SendStats::PauseSendAndPacerBitrateCounters() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + estimated_send_bitrate_kbps_counter_.ProcessAndPause(); + pacer_bitrate_kbps_counter_.ProcessAndPause(); +} + +void Call::SendStats::AddTargetBitrateSample(uint32_t target_bitrate_bps) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + estimated_send_bitrate_kbps_counter_.Add(target_bitrate_bps / 1000); + // Pacer bitrate may be higher than bitrate estimate if enforcing min + // bitrate. + uint32_t pacer_bitrate_bps = + std::max(target_bitrate_bps, min_allocated_send_bitrate_bps_); + pacer_bitrate_kbps_counter_.Add(pacer_bitrate_bps / 1000); +} + +void Call::SendStats::SetMinAllocatableRate(BitrateAllocationLimits limits) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + min_allocated_send_bitrate_bps_ = limits.min_allocatable_rate.bps(); +} + Call::Call(Clock* clock, const Call::Config& config, std::unique_ptr transport_send, @@ -597,34 +778,40 @@ Call::Call(Clock* clock, : clock_(clock), task_queue_factory_(task_queue_factory), worker_thread_(GetCurrentTaskQueueOrThread()), + // If |network_task_queue_| was set to nullptr, network related calls + // must be made on |worker_thread_| (i.e. they're one and the same). + network_thread_(config.network_task_queue_ ? 
config.network_task_queue_ + : worker_thread_), num_cpu_cores_(CpuInfo::DetectNumberOfCores()), module_process_thread_(std::move(module_process_thread)), call_stats_(new CallStats(clock_, worker_thread_)), bitrate_allocator_(new BitrateAllocator(this)), config_(config), + trials_(*config.trials), audio_network_state_(kNetworkDown), video_network_state_(kNetworkDown), aggregate_network_up_(false), event_log_(config.event_log), - received_bytes_per_second_counter_(clock_, nullptr, true), - received_audio_bytes_per_second_counter_(clock_, nullptr, true), - received_video_bytes_per_second_counter_(clock_, nullptr, true), - received_rtcp_bytes_per_second_counter_(clock_, nullptr, true), - last_bandwidth_bps_(0), - min_allocated_send_bitrate_bps_(0), - configured_max_padding_bitrate_bps_(0), - estimated_send_bitrate_kbps_counter_(clock_, nullptr, true), - pacer_bitrate_kbps_counter_(clock_, nullptr, true), - receive_side_cc_(clock_, transport_send->packet_router()), + receive_stats_(clock_), + send_stats_(clock_), + receive_side_cc_(clock, + absl::bind_front(&PacketRouter::SendCombinedRtcpPacket, + transport_send->packet_router()), + absl::bind_front(&PacketRouter::SendRemb, + transport_send->packet_router()), + /*network_state_estimator=*/nullptr), receive_time_calculator_(ReceiveTimeCalculator::CreateFromFieldTrial()), video_send_delay_stats_(new SendDelayStats(clock_)), - start_ms_(clock_->TimeInMilliseconds()), + start_of_call_(clock_->CurrentTime()), transport_send_ptr_(transport_send.get()), transport_send_(std::move(transport_send)) { RTC_DCHECK(config.event_log != nullptr); RTC_DCHECK(config.trials != nullptr); + RTC_DCHECK(network_thread_); RTC_DCHECK(worker_thread_->IsCurrent()); + send_transport_sequence_checker_.Detach(); + // Do not remove this call; it is here to convince the compiler that the // WebRTC source timestamp string needs to be in the final binary. LoadWebRTCVersionInRegister(); @@ -650,18 +837,11 @@ Call::~Call() { receive_side_cc_.GetRemoteBitrateEstimator(true)); module_process_thread_->process_thread()->DeRegisterModule(&receive_side_cc_); call_stats_->DeregisterStatsObserver(&receive_side_cc_); + send_stats_.SetFirstPacketTime(transport_send_->GetFirstPacketTime()); - absl::optional first_sent_packet_ms = - transport_send_->GetFirstPacketTime(); - - // Only update histograms after process threads have been shut down, so that - // they won't try to concurrently update stats. - if (first_sent_packet_ms) { - UpdateSendHistograms(*first_sent_packet_ms); - } - - UpdateReceiveHistograms(); - UpdateHistograms(); + RTC_HISTOGRAM_COUNTS_100000( + "WebRTC.Call.LifetimeInSeconds", + (clock_->CurrentTime() - start_of_call_).seconds()); } void Call::EnsureStarted() { @@ -670,12 +850,14 @@ void Call::EnsureStarted() { } is_started_ = true; + call_stats_->EnsureStarted(); + // This call seems to kick off a number of things, so probably better left // off being kicked off on request rather than in the ctor. 
- transport_send_ptr_->RegisterTargetTransferRateObserver(this); + transport_send_->RegisterTargetTransferRateObserver(this); module_process_thread_->EnsureStarted(); - transport_send_ptr_->EnsureStarted(); + transport_send_->EnsureStarted(); } void Call::SetClientBitratePreferences(const BitrateSettings& preferences) { @@ -683,85 +865,7 @@ void Call::SetClientBitratePreferences(const BitrateSettings& preferences) { GetTransportControllerSend()->SetClientBitratePreferences(preferences); } -void Call::UpdateHistograms() { - RTC_HISTOGRAM_COUNTS_100000( - "WebRTC.Call.LifetimeInSeconds", - (clock_->TimeInMilliseconds() - start_ms_) / 1000); -} - -// Called from the dtor. -void Call::UpdateSendHistograms(Timestamp first_sent_packet) { - int64_t elapsed_sec = - (clock_->TimeInMilliseconds() - first_sent_packet.ms()) / 1000; - if (elapsed_sec < metrics::kMinRunTimeInSeconds) - return; - const int kMinRequiredPeriodicSamples = 5; - AggregatedStats send_bitrate_stats = - estimated_send_bitrate_kbps_counter_.ProcessAndGetStats(); - if (send_bitrate_stats.num_samples > kMinRequiredPeriodicSamples) { - RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.EstimatedSendBitrateInKbps", - send_bitrate_stats.average); - RTC_LOG(LS_INFO) << "WebRTC.Call.EstimatedSendBitrateInKbps, " - << send_bitrate_stats.ToString(); - } - AggregatedStats pacer_bitrate_stats = - pacer_bitrate_kbps_counter_.ProcessAndGetStats(); - if (pacer_bitrate_stats.num_samples > kMinRequiredPeriodicSamples) { - RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.PacerBitrateInKbps", - pacer_bitrate_stats.average); - RTC_LOG(LS_INFO) << "WebRTC.Call.PacerBitrateInKbps, " - << pacer_bitrate_stats.ToString(); - } -} - -void Call::UpdateReceiveHistograms() { - if (first_received_rtp_audio_ms_) { - RTC_HISTOGRAM_COUNTS_100000( - "WebRTC.Call.TimeReceivingAudioRtpPacketsInSeconds", - (*last_received_rtp_audio_ms_ - *first_received_rtp_audio_ms_) / 1000); - } - if (first_received_rtp_video_ms_) { - RTC_HISTOGRAM_COUNTS_100000( - "WebRTC.Call.TimeReceivingVideoRtpPacketsInSeconds", - (*last_received_rtp_video_ms_ - *first_received_rtp_video_ms_) / 1000); - } - const int kMinRequiredPeriodicSamples = 5; - AggregatedStats video_bytes_per_sec = - received_video_bytes_per_second_counter_.GetStats(); - if (video_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { - RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.VideoBitrateReceivedInKbps", - video_bytes_per_sec.average * 8 / 1000); - RTC_LOG(LS_INFO) << "WebRTC.Call.VideoBitrateReceivedInBps, " - << video_bytes_per_sec.ToStringWithMultiplier(8); - } - AggregatedStats audio_bytes_per_sec = - received_audio_bytes_per_second_counter_.GetStats(); - if (audio_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { - RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.AudioBitrateReceivedInKbps", - audio_bytes_per_sec.average * 8 / 1000); - RTC_LOG(LS_INFO) << "WebRTC.Call.AudioBitrateReceivedInBps, " - << audio_bytes_per_sec.ToStringWithMultiplier(8); - } - AggregatedStats rtcp_bytes_per_sec = - received_rtcp_bytes_per_second_counter_.GetStats(); - if (rtcp_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { - RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.RtcpBitrateReceivedInBps", - rtcp_bytes_per_sec.average * 8); - RTC_LOG(LS_INFO) << "WebRTC.Call.RtcpBitrateReceivedInBps, " - << rtcp_bytes_per_sec.ToStringWithMultiplier(8); - } - AggregatedStats recv_bytes_per_sec = - received_bytes_per_second_counter_.GetStats(); - if (recv_bytes_per_sec.num_samples > kMinRequiredPeriodicSamples) { - 
RTC_HISTOGRAM_COUNTS_100000("WebRTC.Call.BitrateReceivedInKbps", - recv_bytes_per_sec.average * 8 / 1000); - RTC_LOG(LS_INFO) << "WebRTC.Call.BitrateReceivedInBps, " - << recv_bytes_per_sec.ToStringWithMultiplier(8); - } -} - PacketReceiver* Call::Receiver() { - RTC_DCHECK_RUN_ON(worker_thread_); return this; } @@ -784,20 +888,22 @@ webrtc::AudioSendStream* Call::CreateAudioSendStream( AudioSendStream* send_stream = new AudioSendStream( clock_, config, config_.audio_state, task_queue_factory_, - module_process_thread_->process_thread(), transport_send_ptr_, - bitrate_allocator_.get(), event_log_, call_stats_->AsRtcpRttStats(), - suspended_rtp_state); + transport_send_.get(), bitrate_allocator_.get(), event_log_, + call_stats_->AsRtcpRttStats(), suspended_rtp_state); RTC_DCHECK(audio_send_ssrcs_.find(config.rtp.ssrc) == audio_send_ssrcs_.end()); audio_send_ssrcs_[config.rtp.ssrc] = send_stream; + // TODO(bugs.webrtc.org/11993): call AssociateSendStream and + // UpdateAggregateNetworkState asynchronously on the network thread. for (AudioReceiveStream* stream : audio_receive_streams_) { - if (stream->config().rtp.local_ssrc == config.rtp.ssrc) { + if (stream->local_ssrc() == config.rtp.ssrc) { stream->AssociateSendStream(send_stream); } } UpdateAggregateNetworkState(); + return send_stream; } @@ -816,13 +922,16 @@ void Call::DestroyAudioSendStream(webrtc::AudioSendStream* send_stream) { size_t num_deleted = audio_send_ssrcs_.erase(ssrc); RTC_DCHECK_EQ(1, num_deleted); + // TODO(bugs.webrtc.org/11993): call AssociateSendStream and + // UpdateAggregateNetworkState asynchronously on the network thread. for (AudioReceiveStream* stream : audio_receive_streams_) { - if (stream->config().rtp.local_ssrc == ssrc) { + if (stream->local_ssrc() == ssrc) { stream->AssociateSendStream(nullptr); } } UpdateAggregateNetworkState(); + delete send_stream; } @@ -833,14 +942,22 @@ webrtc::AudioReceiveStream* Call::CreateAudioReceiveStream( EnsureStarted(); event_log_->Log(std::make_unique( CreateRtcLogStreamConfig(config))); + AudioReceiveStream* receive_stream = new AudioReceiveStream( - clock_, &audio_receiver_controller_, transport_send_ptr_->packet_router(), - module_process_thread_->process_thread(), config_.neteq_factory, config, + clock_, transport_send_->packet_router(), config_.neteq_factory, config, config_.audio_state, event_log_); - - receive_rtp_config_.emplace(config.rtp.remote_ssrc, ReceiveRtpConfig(config)); audio_receive_streams_.insert(receive_stream); + // TODO(bugs.webrtc.org/11993): Make the registration on the network thread + // (asynchronously). The registration and `audio_receiver_controller_` need + // to live on the network thread. + receive_stream->RegisterWithTransport(&audio_receiver_controller_); + + // TODO(bugs.webrtc.org/11993): Update the below on the network thread. + // We could possibly set up the audio_receiver_controller_ association up + // as part of the async setup. + receive_rtp_config_.emplace(config.rtp.remote_ssrc, receive_stream); + ConfigureSync(config.sync_group); auto it = audio_send_ssrcs_.find(config.rtp.local_ssrc); @@ -860,20 +977,29 @@ void Call::DestroyAudioReceiveStream( webrtc::internal::AudioReceiveStream* audio_receive_stream = static_cast(receive_stream); + // TODO(bugs.webrtc.org/11993): Access the map, rtp config, call ConfigureSync + // and UpdateAggregateNetworkState on the network thread. The call to + // `UnregisterFromTransport` should also happen on the network thread. 
+ audio_receive_stream->UnregisterFromTransport(); + + uint32_t ssrc = audio_receive_stream->remote_ssrc(); const AudioReceiveStream::Config& config = audio_receive_stream->config(); - uint32_t ssrc = config.rtp.remote_ssrc; - receive_side_cc_.GetRemoteBitrateEstimator(UseSendSideBwe(config)) + receive_side_cc_.GetRemoteBitrateEstimator(UseSendSideBwe(config.rtp)) ->RemoveStream(ssrc); + audio_receive_streams_.erase(audio_receive_stream); - const std::string& sync_group = audio_receive_stream->config().sync_group; - const auto it = sync_stream_mapping_.find(sync_group); + + const auto it = sync_stream_mapping_.find(config.sync_group); if (it != sync_stream_mapping_.end() && it->second == audio_receive_stream) { sync_stream_mapping_.erase(it); - ConfigureSync(sync_group); + ConfigureSync(config.sync_group); } receive_rtp_config_.erase(ssrc); UpdateAggregateNetworkState(); + // TODO(bugs.webrtc.org/11993): Consider if deleting |audio_receive_stream| + // on the network thread would be better or if we'd need to tear down the + // state in two phases. delete audio_receive_stream; } @@ -900,8 +1026,8 @@ webrtc::VideoSendStream* Call::CreateVideoSendStream( std::vector ssrcs = config.rtp.ssrcs; VideoSendStream* send_stream = new VideoSendStream( - clock_, num_cpu_cores_, module_process_thread_->process_thread(), - task_queue_factory_, call_stats_->AsRtcpRttStats(), transport_send_ptr_, + clock_, num_cpu_cores_, task_queue_factory_, + call_stats_->AsRtcpRttStats(), transport_send_.get(), bitrate_allocator_.get(), video_send_delay_stats_.get(), event_log_, std::move(config), std::move(encoder_config), suspended_video_send_ssrcs_, suspended_video_payload_states_, std::move(fec_controller)); @@ -911,6 +1037,8 @@ webrtc::VideoSendStream* Call::CreateVideoSendStream( video_send_ssrcs_[ssrc] = send_stream; } video_send_streams_.insert(send_stream); + video_send_streams_empty_.store(false, std::memory_order_relaxed); + // Forward resources that were previously added to the call to the new stream. for (const auto& resource_forwarder : adaptation_resource_forwarders_) { resource_forwarder->OnCreateVideoSendStream(send_stream); @@ -924,6 +1052,7 @@ webrtc::VideoSendStream* Call::CreateVideoSendStream( webrtc::VideoSendStream* Call::CreateVideoSendStream( webrtc::VideoSendStream::Config config, VideoEncoderConfig encoder_config) { + RTC_DCHECK_RUN_ON(worker_thread_); if (config_.fec_controller_factory) { RTC_LOG(LS_INFO) << "External FEC Controller will be used."; } @@ -940,9 +1069,12 @@ void Call::DestroyVideoSendStream(webrtc::VideoSendStream* send_stream) { RTC_DCHECK(send_stream != nullptr); RTC_DCHECK_RUN_ON(worker_thread_); - send_stream->Stop(); - - VideoSendStream* send_stream_impl = nullptr; + VideoSendStream* send_stream_impl = + static_cast(send_stream); + VideoSendStream::RtpStateMap rtp_states; + VideoSendStream::RtpPayloadStateMap rtp_payload_states; + send_stream_impl->StopPermanentlyAndGetRtpStates(&rtp_states, + &rtp_payload_states); auto it = video_send_ssrcs_.begin(); while (it != video_send_ssrcs_.end()) { @@ -953,18 +1085,15 @@ void Call::DestroyVideoSendStream(webrtc::VideoSendStream* send_stream) { ++it; } } + // Stop forwarding resources to the stream being destroyed. 
for (const auto& resource_forwarder : adaptation_resource_forwarders_) { resource_forwarder->OnDestroyVideoSendStream(send_stream_impl); } video_send_streams_.erase(send_stream_impl); + if (video_send_streams_.empty()) + video_send_streams_empty_.store(true, std::memory_order_relaxed); - RTC_CHECK(send_stream_impl != nullptr); - - VideoSendStream::RtpStateMap rtp_states; - VideoSendStream::RtpPayloadStateMap rtp_payload_states; - send_stream_impl->StopPermanentlyAndGetRtpStates(&rtp_states, - &rtp_payload_states); for (const auto& kv : rtp_states) { suspended_video_send_ssrcs_[kv.first] = kv.second; } @@ -973,6 +1102,8 @@ void Call::DestroyVideoSendStream(webrtc::VideoSendStream* send_stream) { } UpdateAggregateNetworkState(); + // TODO(tommi): consider deleting on the same thread as runs + // StopPermanentlyAndGetRtpStates. delete send_stream_impl; } @@ -986,13 +1117,17 @@ webrtc::VideoReceiveStream* Call::CreateVideoReceiveStream( EnsureStarted(); - TaskQueueBase* current = GetCurrentTaskQueueOrThread(); - RTC_CHECK(current); + // TODO(bugs.webrtc.org/11993): Move the registration between |receive_stream| + // and |video_receiver_controller_| out of VideoReceiveStream2 construction + // and set it up asynchronously on the network thread (the registration and + // |video_receiver_controller_| need to live on the network thread). VideoReceiveStream2* receive_stream = new VideoReceiveStream2( - task_queue_factory_, current, &video_receiver_controller_, num_cpu_cores_, - transport_send_ptr_->packet_router(), std::move(configuration), - module_process_thread_->process_thread(), call_stats_.get(), clock_, - new VCMTiming(clock_)); + task_queue_factory_, this, num_cpu_cores_, + transport_send_->packet_router(), std::move(configuration), + call_stats_.get(), clock_, new VCMTiming(clock_)); + // TODO(bugs.webrtc.org/11993): Set this up asynchronously on the network + // thread. + receive_stream->RegisterWithTransport(&video_receiver_controller_); const webrtc::VideoReceiveStream::Config& config = receive_stream->config(); if (config.rtp.rtx_ssrc) { @@ -1000,9 +1135,9 @@ webrtc::VideoReceiveStream* Call::CreateVideoReceiveStream( // stream. Since the transport_send_cc negotiation is per payload // type, we may get an incorrect value for the rtx stream, but // that is unlikely to matter in practice. - receive_rtp_config_.emplace(config.rtp.rtx_ssrc, ReceiveRtpConfig(config)); + receive_rtp_config_.emplace(config.rtp.rtx_ssrc, receive_stream); } - receive_rtp_config_.emplace(config.rtp.remote_ssrc, ReceiveRtpConfig(config)); + receive_rtp_config_.emplace(config.rtp.remote_ssrc, receive_stream); video_receive_streams_.insert(receive_stream); ConfigureSync(config.sync_group); @@ -1020,6 +1155,9 @@ void Call::DestroyVideoReceiveStream( RTC_DCHECK(receive_stream != nullptr); VideoReceiveStream2* receive_stream_impl = static_cast(receive_stream); + // TODO(bugs.webrtc.org/11993): Unregister on the network thread. + receive_stream_impl->UnregisterFromTransport(); + const VideoReceiveStream::Config& config = receive_stream_impl->config(); // Remove all ssrcs pointing to a receive stream. 
As RTX retransmits on a @@ -1031,7 +1169,7 @@ void Call::DestroyVideoReceiveStream( video_receive_streams_.erase(receive_stream_impl); ConfigureSync(config.sync_group); - receive_side_cc_.GetRemoteBitrateEstimator(UseSendSideBwe(config)) + receive_side_cc_.GetRemoteBitrateEstimator(UseSendSideBwe(config.rtp)) ->RemoveStream(config.rtp.remote_ssrc); UpdateAggregateNetworkState(); @@ -1054,12 +1192,15 @@ FlexfecReceiveStream* Call::CreateFlexfecReceiveStream( // OnRtpPacket until the constructor is finished and the object is // in a valid state, since OnRtpPacket runs on the same thread. receive_stream = new FlexfecReceiveStreamImpl( - clock_, &video_receiver_controller_, config, recovered_packet_receiver, - call_stats_->AsRtcpRttStats(), module_process_thread_->process_thread()); + clock_, config, recovered_packet_receiver, call_stats_->AsRtcpRttStats()); - RTC_DCHECK(receive_rtp_config_.find(config.remote_ssrc) == + // TODO(bugs.webrtc.org/11993): Set this up asynchronously on the network + // thread. + receive_stream->RegisterWithTransport(&video_receiver_controller_); + + RTC_DCHECK(receive_rtp_config_.find(config.rtp.remote_ssrc) == receive_rtp_config_.end()); - receive_rtp_config_.emplace(config.remote_ssrc, ReceiveRtpConfig(config)); + receive_rtp_config_.emplace(config.rtp.remote_ssrc, receive_stream); // TODO(brandtr): Store config in RtcEventLog here. @@ -1070,15 +1211,19 @@ void Call::DestroyFlexfecReceiveStream(FlexfecReceiveStream* receive_stream) { TRACE_EVENT0("webrtc", "Call::DestroyFlexfecReceiveStream"); RTC_DCHECK_RUN_ON(worker_thread_); + FlexfecReceiveStreamImpl* receive_stream_impl = + static_cast(receive_stream); + // TODO(bugs.webrtc.org/11993): Unregister on the network thread. + receive_stream_impl->UnregisterFromTransport(); + RTC_DCHECK(receive_stream != nullptr); - const FlexfecReceiveStream::Config& config = receive_stream->GetConfig(); - uint32_t ssrc = config.remote_ssrc; - receive_rtp_config_.erase(ssrc); + const FlexfecReceiveStream::RtpConfig& rtp = receive_stream->rtp_config(); + receive_rtp_config_.erase(rtp.remote_ssrc); // Remove all SSRCs pointing to the FlexfecReceiveStreamImpl to be // destroyed. - receive_side_cc_.GetRemoteBitrateEstimator(UseSendSideBwe(config)) - ->RemoveStream(ssrc); + receive_side_cc_.GetRemoteBitrateEstimator(UseSendSideBwe(rtp)) + ->RemoveStream(rtp.remote_ssrc); delete receive_stream; } @@ -1094,7 +1239,7 @@ void Call::AddAdaptationResource(rtc::scoped_refptr resource) { } RtpTransportControllerSendInterface* Call::GetTransportControllerSend() { - return transport_send_ptr_; + return transport_send_.get(); } Call::Stats Call::GetStats() const { @@ -1104,7 +1249,7 @@ Call::Stats Call::GetStats() const { // TODO(srte): It is unclear if we only want to report queues if network is // available. stats.pacer_delay_ms = - aggregate_network_up_ ? transport_send_ptr_->GetPacerQueuingDelayMs() : 0; + aggregate_network_up_ ? 
transport_send_->GetPacerQueuingDelayMs() : 0; stats.rtt_ms = call_stats_->LastProcessedRtt(); @@ -1114,45 +1259,75 @@ Call::Stats Call::GetStats() const { receive_side_cc_.GetRemoteBitrateEstimator(false)->LatestEstimate( &ssrcs, &recv_bandwidth); stats.recv_bandwidth_bps = recv_bandwidth; - stats.send_bandwidth_bps = last_bandwidth_bps_; - stats.max_padding_bitrate_bps = configured_max_padding_bitrate_bps_; + stats.send_bandwidth_bps = + last_bandwidth_bps_.load(std::memory_order_relaxed); + stats.max_padding_bitrate_bps = + configured_max_padding_bitrate_bps_.load(std::memory_order_relaxed); return stats; } const WebRtcKeyValueConfig& Call::trials() const { - return *config_.trials; + return trials_; +} + +TaskQueueBase* Call::network_thread() const { + return network_thread_; +} + +TaskQueueBase* Call::worker_thread() const { + return worker_thread_; } void Call::SignalChannelNetworkState(MediaType media, NetworkState state) { - RTC_DCHECK_RUN_ON(worker_thread_); - switch (media) { - case MediaType::AUDIO: + RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK(media == MediaType::AUDIO || media == MediaType::VIDEO); + + auto closure = [this, media, state]() { + // TODO(bugs.webrtc.org/11993): Move this over to the network thread. + RTC_DCHECK_RUN_ON(worker_thread_); + if (media == MediaType::AUDIO) { audio_network_state_ = state; - break; - case MediaType::VIDEO: + } else { + RTC_DCHECK_EQ(media, MediaType::VIDEO); video_network_state_ = state; - break; - case MediaType::ANY: - case MediaType::DATA: - RTC_NOTREACHED(); - break; - } + } - UpdateAggregateNetworkState(); - for (VideoReceiveStream2* video_receive_stream : video_receive_streams_) { - video_receive_stream->SignalNetworkState(video_network_state_); + // TODO(tommi): Is it necessary to always do this, including if there + // was no change in state? + UpdateAggregateNetworkState(); + + // TODO(tommi): Is it right to do this if media == AUDIO? + for (VideoReceiveStream2* video_receive_stream : video_receive_streams_) { + video_receive_stream->SignalNetworkState(video_network_state_); + } + }; + + if (network_thread_ == worker_thread_) { + closure(); + } else { + // TODO(bugs.webrtc.org/11993): Remove workaround when we no longer need to + // post to the worker thread. + worker_thread_->PostTask(ToQueuedTask(task_safety_, std::move(closure))); } } void Call::OnAudioTransportOverheadChanged(int transport_overhead_per_packet) { - RTC_DCHECK_RUN_ON(worker_thread_); - for (auto& kv : audio_send_ssrcs_) { - kv.second->SetTransportOverhead(transport_overhead_per_packet); - } + RTC_DCHECK_RUN_ON(network_thread_); + worker_thread_->PostTask( + ToQueuedTask(task_safety_, [this, transport_overhead_per_packet]() { + // TODO(bugs.webrtc.org/11993): Move this over to the network thread. + RTC_DCHECK_RUN_ON(worker_thread_); + for (auto& kv : audio_send_ssrcs_) { + kv.second->SetTransportOverhead(transport_overhead_per_packet); + } + })); } void Call::UpdateAggregateNetworkState() { + // TODO(bugs.webrtc.org/11993): Move this over to the network thread. 
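SignalChannelNetworkState and OnAudioTransportOverheadChanged above hop from the network thread to the worker thread with PostTask(ToQueuedTask(task_safety_, ...)), which drops the closure if the Call has already been torn down. A simplified, self-contained analogue of that guard; WebRTC's real helpers are ScopedTaskSafety/ToQueuedTask, and the weak_ptr token below is only a stand-in:

#include <functional>
#include <iostream>
#include <memory>
#include <queue>

// Single-threaded fake task queue, just enough to demonstrate the hand-off.
class FakeTaskQueue {
 public:
  void PostTask(std::function<void()> task) { tasks_.push(std::move(task)); }
  void RunPending() {
    while (!tasks_.empty()) {
      tasks_.front()();
      tasks_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> tasks_;
};

class CallLike {
 public:
  explicit CallLike(FakeTaskQueue* worker) : worker_(worker) {}

  // Called on the "network thread": re-post to the worker thread, but only run
  // the body if this object is still alive when the task finally executes.
  void SignalNetworkState(bool up) {
    std::weak_ptr<void> alive = alive_;
    worker_->PostTask([this, alive, up]() {
      if (alive.expired())
        return;  // Owner destroyed before the task ran; skip the body.
      network_up_ = up;
    });
  }

  bool network_up() const { return network_up_; }

 private:
  FakeTaskQueue* const worker_;
  bool network_up_ = false;
  std::shared_ptr<void> alive_ = std::make_shared<int>(0);  // Destroyed with *this.
};

int main() {
  FakeTaskQueue worker;
  CallLike call(&worker);
  call.SignalNetworkState(true);
  worker.RunPending();
  std::cout << "network up: " << call.network_up() << "\n";
}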
+ // RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK_RUN_ON(worker_thread_); bool have_audio = @@ -1175,63 +1350,82 @@ void Call::UpdateAggregateNetworkState() { } aggregate_network_up_ = aggregate_network_up; - transport_send_ptr_->OnNetworkAvailability(aggregate_network_up); + transport_send_->OnNetworkAvailability(aggregate_network_up); +} + +void Call::OnLocalSsrcUpdated(webrtc::AudioReceiveStream& stream, + uint32_t local_ssrc) { + RTC_DCHECK_RUN_ON(worker_thread_); + webrtc::internal::AudioReceiveStream& receive_stream = + static_cast(stream); + + receive_stream.SetLocalSsrc(local_ssrc); + auto it = audio_send_ssrcs_.find(local_ssrc); + receive_stream.AssociateSendStream(it != audio_send_ssrcs_.end() ? it->second + : nullptr); +} + +void Call::OnUpdateSyncGroup(webrtc::AudioReceiveStream& stream, + const std::string& sync_group) { + RTC_DCHECK_RUN_ON(worker_thread_); + webrtc::internal::AudioReceiveStream& receive_stream = + static_cast(stream); + receive_stream.SetSyncGroup(sync_group); + ConfigureSync(sync_group); } void Call::OnSentPacket(const rtc::SentPacket& sent_packet) { + // In production and with most tests, this method will be called on the + // network thread. However some test classes such as DirectTransport don't + // incorporate a network thread. This means that tests for RtpSenderEgress + // and ModuleRtpRtcpImpl2 that use DirectTransport, will call this method + // on a ProcessThread. This is alright as is since we forward the call to + // implementations that either just do a PostTask or use locking. video_send_delay_stats_->OnSentPacket(sent_packet.packet_id, clock_->TimeInMilliseconds()); - transport_send_ptr_->OnSentPacket(sent_packet); + transport_send_->OnSentPacket(sent_packet); } void Call::OnStartRateUpdate(DataRate start_rate) { - RTC_DCHECK_RUN_ON(send_transport_queue()); + RTC_DCHECK_RUN_ON(&send_transport_sequence_checker_); bitrate_allocator_->UpdateStartRate(start_rate.bps()); } void Call::OnTargetTransferRate(TargetTransferRate msg) { - RTC_DCHECK_RUN_ON(send_transport_queue()); + RTC_DCHECK_RUN_ON(&send_transport_sequence_checker_); uint32_t target_bitrate_bps = msg.target_rate.bps(); // For controlling the rate of feedback messages. receive_side_cc_.OnBitrateChanged(target_bitrate_bps); bitrate_allocator_->OnNetworkEstimateChanged(msg); - worker_thread_->PostTask( - ToQueuedTask(task_safety_, [this, target_bitrate_bps]() { - RTC_DCHECK_RUN_ON(worker_thread_); - last_bandwidth_bps_ = target_bitrate_bps; - - // Ignore updates if bitrate is zero (the aggregate network state is - // down) or if we're not sending video. - if (target_bitrate_bps == 0 || video_send_streams_.empty()) { - estimated_send_bitrate_kbps_counter_.ProcessAndPause(); - pacer_bitrate_kbps_counter_.ProcessAndPause(); - return; - } + last_bandwidth_bps_.store(target_bitrate_bps, std::memory_order_relaxed); - estimated_send_bitrate_kbps_counter_.Add(target_bitrate_bps / 1000); - // Pacer bitrate may be higher than bitrate estimate if enforcing min - // bitrate. - uint32_t pacer_bitrate_bps = - std::max(target_bitrate_bps, min_allocated_send_bitrate_bps_); - pacer_bitrate_kbps_counter_.Add(pacer_bitrate_bps / 1000); - })); + // Ignore updates if bitrate is zero (the aggregate network state is + // down) or if we're not sending video. + // Using |video_send_streams_empty_| is racy but as the caller can't + // reasonably expect synchronize with changes in |video_send_streams_| (being + // on |send_transport_sequence_checker|), we can avoid a PostTask this way. 
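The comment above accepts a racy read of video_send_streams_empty_ to avoid a PostTask, and the surrounding hunks likewise turn last_bandwidth_bps_ and configured_max_padding_bitrate_bps_ into atomics read with relaxed ordering. A minimal sketch of that snapshot pattern (illustrative class, not the WebRTC one); relaxed ordering suffices because each value is an independent "latest sample" and is not used to publish other state:

#include <atomic>
#include <cstdint>
#include <iostream>

class BandwidthSnapshot {
 public:
  // Written on the transport sequence.
  void OnTargetRate(uint32_t bps) {
    last_bandwidth_bps_.store(bps, std::memory_order_relaxed);
  }
  void OnSendStreamCountChanged(bool any_streams) {
    send_streams_empty_.store(!any_streams, std::memory_order_relaxed);
  }

  // Read on the worker thread, e.g. from a GetStats()-style accessor.
  uint32_t LatestBandwidthBps() const {
    return last_bandwidth_bps_.load(std::memory_order_relaxed);
  }
  bool SendStreamsEmpty() const {
    return send_streams_empty_.load(std::memory_order_relaxed);
  }

 private:
  std::atomic<uint32_t> last_bandwidth_bps_{0};
  std::atomic<bool> send_streams_empty_{true};
};

int main() {
  BandwidthSnapshot stats;
  stats.OnSendStreamCountChanged(/*any_streams=*/true);
  stats.OnTargetRate(300'000);
  std::cout << stats.LatestBandwidthBps() << " bps, empty="
            << stats.SendStreamsEmpty() << "\n";
}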
+ if (target_bitrate_bps == 0 || + video_send_streams_empty_.load(std::memory_order_relaxed)) { + send_stats_.PauseSendAndPacerBitrateCounters(); + } else { + send_stats_.AddTargetBitrateSample(target_bitrate_bps); + } } void Call::OnAllocationLimitsChanged(BitrateAllocationLimits limits) { - RTC_DCHECK_RUN_ON(send_transport_queue()); + RTC_DCHECK_RUN_ON(&send_transport_sequence_checker_); transport_send_ptr_->SetAllocatedSendBitrateLimits(limits); - - worker_thread_->PostTask(ToQueuedTask(task_safety_, [this, limits]() { - RTC_DCHECK_RUN_ON(worker_thread_); - min_allocated_send_bitrate_bps_ = limits.min_allocatable_rate.bps(); - configured_max_padding_bitrate_bps_ = limits.max_padding_rate.bps(); - })); + send_stats_.SetMinAllocatableRate(limits); + configured_max_padding_bitrate_bps_.store(limits.max_padding_rate.bps(), + std::memory_order_relaxed); } +// RTC_RUN_ON(worker_thread_) void Call::ConfigureSync(const std::string& sync_group) { + // TODO(bugs.webrtc.org/11993): Expect to be called on the network thread. // Set sync only if there was no previous one. if (sync_group.empty()) return; @@ -1281,56 +1475,62 @@ void Call::ConfigureSync(const std::string& sync_group) { } } -PacketReceiver::DeliveryStatus Call::DeliverRtcp(MediaType media_type, - const uint8_t* packet, - size_t length) { +// RTC_RUN_ON(network_thread_) +void Call::DeliverRtcp(MediaType media_type, rtc::CopyOnWriteBuffer packet) { TRACE_EVENT0("webrtc", "Call::DeliverRtcp"); - // TODO(pbos): Make sure it's a valid packet. - // Return DELIVERY_UNKNOWN_SSRC if it can be determined that - // there's no receiver of the packet. - if (received_bytes_per_second_counter_.HasSample()) { - // First RTP packet has been received. - received_bytes_per_second_counter_.Add(static_cast(length)); - received_rtcp_bytes_per_second_counter_.Add(static_cast(length)); - } - bool rtcp_delivered = false; - if (media_type == MediaType::ANY || media_type == MediaType::VIDEO) { - for (VideoReceiveStream2* stream : video_receive_streams_) { - if (stream->DeliverRtcp(packet, length)) - rtcp_delivered = true; - } - } - if (media_type == MediaType::ANY || media_type == MediaType::AUDIO) { - for (AudioReceiveStream* stream : audio_receive_streams_) { - stream->DeliverRtcp(packet, length); - rtcp_delivered = true; - } - } - if (media_type == MediaType::ANY || media_type == MediaType::VIDEO) { - for (VideoSendStream* stream : video_send_streams_) { - stream->DeliverRtcp(packet, length); - rtcp_delivered = true; - } - } - if (media_type == MediaType::ANY || media_type == MediaType::AUDIO) { - for (auto& kv : audio_send_ssrcs_) { - kv.second->DeliverRtcp(packet, length); - rtcp_delivered = true; - } - } - if (rtcp_delivered) { - event_log_->Log(std::make_unique( - rtc::MakeArrayView(packet, length))); - } + // TODO(bugs.webrtc.org/11993): This DCHECK is here just to maintain the + // invariant that currently the only call path to this function is via + // `PeerConnection::InitializeRtcpCallback()`. DeliverRtp on the other hand + // gets called via the channel classes and + // WebRtc[Audio|Video]Channel's `OnPacketReceived`. We'll remove the + // PeerConnection involvement as well as + // `JsepTransportController::OnRtcpPacketReceived_n` and `rtcp_handler` + // and make sure that the flow of packets is consistent from the + // `RtpTransport` class, via the *Channel and *Engine classes and into Call. + // This way we'll also know more about the context of the packet. 
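The congestion-controller callbacks above (OnStartRateUpdate, OnTargetTransferRate, OnAllocationLimitsChanged) now assert RTC_DCHECK_RUN_ON(&send_transport_sequence_checker_) instead of comparing against the transport task queue. A bare-bones sketch of what such a checker does, assuming plain threads only; WebRTC's SequenceChecker also understands task queues and synchronizes attachment, which this toy version does not:

#include <cassert>
#include <thread>

// Latches the first thread that touches it and asserts that all later calls
// come from the same thread. Note: the first-use attach below is not race-free;
// a production checker guards it with a lock or atomics.
class SimpleSequenceChecker {
 public:
  bool IsCurrent() const {
    const std::thread::id current = std::this_thread::get_id();
    if (attached_id_ == std::thread::id()) {
      attached_id_ = current;  // First use attaches the checker.
      return true;
    }
    return attached_id_ == current;
  }
  void Detach() { attached_id_ = std::thread::id(); }

 private:
  mutable std::thread::id attached_id_{};
};

#define SKETCH_DCHECK_RUN_ON(checker) assert((checker)->IsCurrent())

// Usage inside a class whose members are confined to one sequence:
//   SimpleSequenceChecker transport_sequence_checker_;
//   void OnTargetRate() { SKETCH_DCHECK_RUN_ON(&transport_sequence_checker_); ... }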
+ RTC_DCHECK_EQ(media_type, MediaType::ANY); + + // TODO(bugs.webrtc.org/11993): This should execute directly on the network + // thread. + worker_thread_->PostTask( + ToQueuedTask(task_safety_, [this, packet = std::move(packet)]() { + RTC_DCHECK_RUN_ON(worker_thread_); + + receive_stats_.AddReceivedRtcpBytes(static_cast(packet.size())); + bool rtcp_delivered = false; + for (VideoReceiveStream2* stream : video_receive_streams_) { + if (stream->DeliverRtcp(packet.cdata(), packet.size())) + rtcp_delivered = true; + } + + for (AudioReceiveStream* stream : audio_receive_streams_) { + stream->DeliverRtcp(packet.cdata(), packet.size()); + rtcp_delivered = true; + } + + for (VideoSendStream* stream : video_send_streams_) { + stream->DeliverRtcp(packet.cdata(), packet.size()); + rtcp_delivered = true; + } + + for (auto& kv : audio_send_ssrcs_) { + kv.second->DeliverRtcp(packet.cdata(), packet.size()); + rtcp_delivered = true; + } - return rtcp_delivered ? DELIVERY_OK : DELIVERY_PACKET_ERROR; + if (rtcp_delivered) { + event_log_->Log(std::make_unique( + rtc::MakeArrayView(packet.cdata(), packet.size()))); + } + })); } PacketReceiver::DeliveryStatus Call::DeliverRtp(MediaType media_type, rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { TRACE_EVENT0("webrtc", "Call::DeliverRtp"); + RTC_DCHECK_NE(media_type, MediaType::ANY); RtpPacketReceived parsed_packet; if (!parsed_packet.Parse(std::move(packet))) @@ -1343,9 +1543,9 @@ PacketReceiver::DeliveryStatus Call::DeliverRtp(MediaType media_type, packet_time_us = receive_time_calculator_->ReconcileReceiveTimes( packet_time_us, rtc::TimeUTCMicros(), clock_->TimeInMicroseconds()); } - parsed_packet.set_arrival_time_ms((packet_time_us + 500) / 1000); + parsed_packet.set_arrival_time(Timestamp::Micros(packet_time_us)); } else { - parsed_packet.set_arrival_time_ms(clock_->TimeInMilliseconds()); + parsed_packet.set_arrival_time(clock_->CurrentTime()); } // We might get RTP keep-alive packets in accordance with RFC6263 section 4.6. 
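The rewritten DeliverRtcp above takes the packet by value as a copy-on-write buffer and moves it into the closure posted to the worker thread, so the bytes stay alive until the task runs without an extra copy. A stripped-down version of that hand-off, using std::vector<uint8_t> in place of rtc::CopyOnWriteBuffer:

#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

class WorkerQueue {
 public:
  void PostTask(std::function<void()> task) { tasks_.push(std::move(task)); }
  void RunPending() {
    while (!tasks_.empty()) {
      tasks_.front()();
      tasks_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> tasks_;
};

class RtcpDispatcher {
 public:
  explicit RtcpDispatcher(WorkerQueue* worker) : worker_(worker) {}

  // Called on the network thread. Ownership of |packet| moves into the task,
  // so no copy is made and the buffer outlives the caller's stack frame.
  void DeliverRtcp(std::vector<uint8_t> packet) {
    worker_->PostTask([this, packet = std::move(packet)]() {
      bytes_delivered_ += packet.size();
      // ...hand packet.data()/packet.size() to the receive and send streams...
    });
  }

  size_t bytes_delivered() const { return bytes_delivered_; }

 private:
  WorkerQueue* const worker_;
  size_t bytes_delivered_ = 0;
};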
@@ -1368,7 +1568,8 @@ PacketReceiver::DeliveryStatus Call::DeliverRtp(MediaType media_type, return DELIVERY_UNKNOWN_SSRC; } - parsed_packet.IdentifyExtensions(it->second.extensions); + parsed_packet.IdentifyExtensions( + RtpHeaderExtensionMap(it->second->rtp_config().extensions)); NotifyBweOfReceivedPacket(parsed_packet, media_type); @@ -1377,29 +1578,19 @@ PacketReceiver::DeliveryStatus Call::DeliverRtp(MediaType media_type, int length = static_cast(parsed_packet.size()); if (media_type == MediaType::AUDIO) { if (audio_receiver_controller_.OnRtpPacket(parsed_packet)) { - received_bytes_per_second_counter_.Add(length); - received_audio_bytes_per_second_counter_.Add(length); + receive_stats_.AddReceivedAudioBytes(length, + parsed_packet.arrival_time()); event_log_->Log( std::make_unique(parsed_packet)); - const int64_t arrival_time_ms = parsed_packet.arrival_time_ms(); - if (!first_received_rtp_audio_ms_) { - first_received_rtp_audio_ms_.emplace(arrival_time_ms); - } - last_received_rtp_audio_ms_.emplace(arrival_time_ms); return DELIVERY_OK; } } else if (media_type == MediaType::VIDEO) { parsed_packet.set_payload_type_frequency(kVideoPayloadTypeFrequency); if (video_receiver_controller_.OnRtpPacket(parsed_packet)) { - received_bytes_per_second_counter_.Add(length); - received_video_bytes_per_second_counter_.Add(length); + receive_stats_.AddReceivedVideoBytes(length, + parsed_packet.arrival_time()); event_log_->Log( std::make_unique(parsed_packet)); - const int64_t arrival_time_ms = parsed_packet.arrival_time_ms(); - if (!first_received_rtp_video_ms_) { - first_received_rtp_video_ms_.emplace(arrival_time_ms); - } - last_received_rtp_video_ms_.emplace(arrival_time_ms); return DELIVERY_OK; } } @@ -1410,15 +1601,20 @@ PacketReceiver::DeliveryStatus Call::DeliverPacket( MediaType media_type, rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { - RTC_DCHECK_RUN_ON(worker_thread_); - - if (IsRtcp(packet.cdata(), packet.size())) - return DeliverRtcp(media_type, packet.cdata(), packet.size()); + if (IsRtcpPacket(packet)) { + RTC_DCHECK_RUN_ON(network_thread_); + DeliverRtcp(media_type, std::move(packet)); + return DELIVERY_OK; + } + RTC_DCHECK_RUN_ON(worker_thread_); return DeliverRtp(media_type, std::move(packet), packet_time_us); } void Call::OnRecoveredPacket(const uint8_t* packet, size_t length) { + // TODO(bugs.webrtc.org/11993): Expect to be called on the network thread. + // This method is called synchronously via |OnRtpPacket()| (see DeliverRtp) + // on the same thread. RTC_DCHECK_RUN_ON(worker_thread_); RtpPacketReceived parsed_packet; if (!parsed_packet.Parse(packet, length)) @@ -1438,29 +1634,31 @@ void Call::OnRecoveredPacket(const uint8_t* packet, size_t length) { // which is being torn down. return; } - parsed_packet.IdentifyExtensions(it->second.extensions); + parsed_packet.IdentifyExtensions( + RtpHeaderExtensionMap(it->second->rtp_config().extensions)); // TODO(brandtr): Update here when we support protecting audio packets too. 
parsed_packet.set_payload_type_frequency(kVideoPayloadTypeFrequency); video_receiver_controller_.OnRtpPacket(parsed_packet); } +// RTC_RUN_ON(worker_thread_) void Call::NotifyBweOfReceivedPacket(const RtpPacketReceived& packet, MediaType media_type) { auto it = receive_rtp_config_.find(packet.Ssrc()); - bool use_send_side_bwe = - (it != receive_rtp_config_.end()) && it->second.use_send_side_bwe; + bool use_send_side_bwe = (it != receive_rtp_config_.end()) && + UseSendSideBwe(it->second->rtp_config()); RTPHeader header; packet.GetHeader(&header); ReceivedPacket packet_msg; packet_msg.size = DataSize::Bytes(packet.payload_size()); - packet_msg.receive_time = Timestamp::Millis(packet.arrival_time_ms()); + packet_msg.receive_time = packet.arrival_time(); if (header.extension.hasAbsoluteSendTime) { packet_msg.send_time = header.extension.GetAbsoluteSendTimestamp(); } - transport_send_ptr_->OnReceivedPacket(packet_msg); + transport_send_->OnReceivedPacket(packet_msg); if (!use_send_side_bwe && header.extension.hasTransportSequenceNumber) { // Inconsistent configuration of send side BWE. Do nothing. @@ -1476,8 +1674,8 @@ void Call::NotifyBweOfReceivedPacket(const RtpPacketReceived& packet, if (media_type == MediaType::VIDEO || (use_send_side_bwe && header.extension.hasTransportSequenceNumber)) { receive_side_cc_.OnReceivedPacket( - packet.arrival_time_ms(), packet.payload_size() + packet.padding_size(), - header); + packet.arrival_time().ms(), + packet.payload_size() + packet.padding_size(), header); } } diff --git a/call/call.h b/call/call.h index a2b3b89598..f6388c3c78 100644 --- a/call/call.h +++ b/call/call.h @@ -17,6 +17,7 @@ #include "api/adaptation/resource.h" #include "api/media_types.h" +#include "api/task_queue/task_queue_base.h" #include "call/audio_receive_stream.h" #include "call/audio_send_stream.h" #include "call/call_config.h" @@ -82,12 +83,15 @@ class Call { }; static Call* Create(const Call::Config& config); - static Call* Create(const Call::Config& config, - rtc::scoped_refptr call_thread); static Call* Create(const Call::Config& config, Clock* clock, rtc::scoped_refptr call_thread, std::unique_ptr pacer_thread); + static Call* Create(const Call::Config& config, + Clock* clock, + rtc::scoped_refptr call_thread, + std::unique_ptr + transportControllerSend); virtual AudioSendStream* CreateAudioSendStream( const AudioSendStream::Config& config) = 0; @@ -151,6 +155,14 @@ class Call { virtual void OnAudioTransportOverheadChanged( int transport_overhead_per_packet) = 0; + // Called when a receive stream's local ssrc has changed and association with + // send streams needs to be updated. 
+ virtual void OnLocalSsrcUpdated(AudioReceiveStream& stream, + uint32_t local_ssrc) = 0; + + virtual void OnUpdateSyncGroup(AudioReceiveStream& stream, + const std::string& sync_group) = 0; + virtual void OnSentPacket(const rtc::SentPacket& sent_packet) = 0; virtual void SetClientBitratePreferences( @@ -158,6 +170,9 @@ class Call { virtual const WebRtcKeyValueConfig& trials() const = 0; + virtual TaskQueueBase* network_thread() const = 0; + virtual TaskQueueBase* worker_thread() const = 0; + virtual ~Call() {} }; diff --git a/call/call_config.cc b/call/call_config.cc index b149c889ea..23b60ce436 100644 --- a/call/call_config.cc +++ b/call/call_config.cc @@ -14,12 +14,27 @@ namespace webrtc { -CallConfig::CallConfig(RtcEventLog* event_log) : event_log(event_log) { +CallConfig::CallConfig(RtcEventLog* event_log, + TaskQueueBase* network_task_queue /* = nullptr*/) + : event_log(event_log), network_task_queue_(network_task_queue) { RTC_DCHECK(event_log); } CallConfig::CallConfig(const CallConfig& config) = default; +RtpTransportConfig CallConfig::ExtractTransportConfig() const { + RtpTransportConfig transportConfig; + transportConfig.bitrate_config = bitrate_config; + transportConfig.event_log = event_log; + transportConfig.network_controller_factory = network_controller_factory; + transportConfig.network_state_predictor_factory = + network_state_predictor_factory; + transportConfig.task_queue_factory = task_queue_factory; + transportConfig.trials = trials; + + return transportConfig; +} + CallConfig::~CallConfig() = default; } // namespace webrtc diff --git a/call/call_config.h b/call/call_config.h index 205f7a48bb..ba6dec3ad6 100644 --- a/call/call_config.h +++ b/call/call_config.h @@ -19,6 +19,8 @@ #include "api/transport/network_control.h" #include "api/transport/webrtc_key_value_config.h" #include "call/audio_state.h" +#include "call/rtp_transport_config.h" +#include "call/rtp_transport_controller_send_factory_interface.h" namespace webrtc { @@ -26,8 +28,13 @@ class AudioProcessing; class RtcEventLog; struct CallConfig { - explicit CallConfig(RtcEventLog* event_log); + // If |network_task_queue| is set to nullptr, Call will assume that network + // related callbacks will be made on the same TQ as the Call instance was + // constructed on. + explicit CallConfig(RtcEventLog* event_log, + TaskQueueBase* network_task_queue = nullptr); CallConfig(const CallConfig&); + RtpTransportConfig ExtractTransportConfig() const; ~CallConfig(); // Bitrate config used until valid bitrate estimates are calculated. Also @@ -42,7 +49,7 @@ struct CallConfig { // RtcEventLog to use for this call. Required. // Use webrtc::RtcEventLog::CreateNull() for a null implementation. - RtcEventLog* event_log = nullptr; + RtcEventLog* const event_log = nullptr; // FecController to use for this call. FecControllerFactoryInterface* fec_controller_factory = nullptr; @@ -63,6 +70,11 @@ struct CallConfig { // Key-value mapping of internal configurations to apply, // e.g. field trials. const WebRtcKeyValueConfig* trials = nullptr; + + TaskQueueBase* const network_task_queue_ = nullptr; + // RtpTransportControllerSend to use for this call. 
+ RtpTransportControllerSendFactoryInterface* + rtp_transport_controller_send_factory = nullptr; }; } // namespace webrtc diff --git a/call/call_factory.cc b/call/call_factory.cc index cc02c02835..aeb3cbdaa7 100644 --- a/call/call_factory.cc +++ b/call/call_factory.cc @@ -14,11 +14,13 @@ #include #include +#include #include "absl/types/optional.h" #include "api/test/simulated_network.h" #include "call/call.h" #include "call/degraded_call.h" +#include "call/rtp_transport_config.h" #include "rtc_base/checks.h" #include "system_wrappers/include/field_trial.h" @@ -81,10 +83,19 @@ Call* CallFactory::CreateCall(const Call::Config& config) { absl::optional receive_degradation_config = ParseDegradationConfig(false); + RtpTransportConfig transportConfig = config.ExtractTransportConfig(); + if (send_degradation_config || receive_degradation_config) { - return new DegradedCall(std::unique_ptr(Call::Create(config)), - send_degradation_config, receive_degradation_config, - config.task_queue_factory); + return new DegradedCall( + std::unique_ptr(Call::Create( + config, Clock::GetRealTimeClock(), + SharedModuleThread::Create( + ProcessThread::Create("ModuleProcessThread"), nullptr), + config.rtp_transport_controller_send_factory->Create( + transportConfig, Clock::GetRealTimeClock(), + ProcessThread::Create("PacerThread")))), + send_degradation_config, receive_degradation_config, + config.task_queue_factory); } if (!module_thread_) { @@ -95,7 +106,10 @@ Call* CallFactory::CreateCall(const Call::Config& config) { }); } - return Call::Create(config, module_thread_); + return Call::Create(config, Clock::GetRealTimeClock(), module_thread_, + config.rtp_transport_controller_send_factory->Create( + transportConfig, Clock::GetRealTimeClock(), + ProcessThread::Create("PacerThread"))); } std::unique_ptr CreateCallFactory() { diff --git a/call/call_factory.h b/call/call_factory.h index 2426caae47..469bec39e1 100644 --- a/call/call_factory.h +++ b/call/call_factory.h @@ -12,9 +12,9 @@ #define CALL_CALL_FACTORY_H_ #include "api/call/call_factory_interface.h" +#include "api/sequence_checker.h" #include "call/call.h" #include "call/call_config.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { diff --git a/call/call_perf_tests.cc b/call/call_perf_tests.cc index 6591ab596d..c163ab2fe7 100644 --- a/call/call_perf_tests.cc +++ b/call/call_perf_tests.cc @@ -561,6 +561,18 @@ TEST_F(CallPerfTest, ReceivesCpuOveruseAndUnderuse) { // TODO(sprang): Add integration test for maintain-framerate mode? void OnSinkWantsChanged(rtc::VideoSinkInterface* sink, const rtc::VideoSinkWants& wants) override { + // The sink wants can change either because an adaptation happened (i.e. + // the pixels or frame rate changed) or for other reasons, such as encoded + // resolutions being communicated (happens whenever we capture a new frame + // size). In this test, we only care about adaptations. + bool did_adapt = + last_wants_.max_pixel_count != wants.max_pixel_count || + last_wants_.target_pixel_count != wants.target_pixel_count || + last_wants_.max_framerate_fps != wants.max_framerate_fps; + last_wants_ = wants; + if (!did_adapt) { + return; + } // At kStart expect CPU overuse. Then expect CPU underuse when the encoder // delay has been decreased. 
switch (test_phase_) { @@ -625,6 +637,9 @@ TEST_F(CallPerfTest, ReceivesCpuOveruseAndUnderuse) { kAdaptedDown, kAdaptedUp } test_phase_; + + private: + rtc::VideoSinkWants last_wants_; } test; RunBaseTest(&test); @@ -639,7 +654,8 @@ void CallPerfTest::TestMinTransmitBitrate(bool pad_to_min_bitrate) { static const int kAcceptableBitrateErrorMargin = 15; // +- 7 class BitrateObserver : public test::EndToEndTest { public: - explicit BitrateObserver(bool using_min_transmit_bitrate) + explicit BitrateObserver(bool using_min_transmit_bitrate, + TaskQueueBase* task_queue) : EndToEndTest(kLongTimeoutMs), send_stream_(nullptr), converged_(false), @@ -652,27 +668,31 @@ void CallPerfTest::TestMinTransmitBitrate(bool pad_to_min_bitrate) { ? kMaxAcceptableTransmitBitrate : (kMaxEncodeBitrateKbps + kAcceptableBitrateErrorMargin / 2)), - num_bitrate_observations_in_range_(0) {} + num_bitrate_observations_in_range_(0), + task_queue_(task_queue) {} private: // TODO(holmer): Run this with a timer instead of once per packet. Action OnSendRtp(const uint8_t* packet, size_t length) override { - VideoSendStream::Stats stats = send_stream_->GetStats(); - if (!stats.substreams.empty()) { - RTC_DCHECK_EQ(1, stats.substreams.size()); - int bitrate_kbps = - stats.substreams.begin()->second.total_bitrate_bps / 1000; - if (bitrate_kbps > min_acceptable_bitrate_ && - bitrate_kbps < max_acceptable_bitrate_) { - converged_ = true; - ++num_bitrate_observations_in_range_; - if (num_bitrate_observations_in_range_ == - kNumBitrateObservationsInRange) - observation_complete_.Set(); + task_queue_->PostTask(ToQueuedTask([this]() { + VideoSendStream::Stats stats = send_stream_->GetStats(); + + if (!stats.substreams.empty()) { + RTC_DCHECK_EQ(1, stats.substreams.size()); + int bitrate_kbps = + stats.substreams.begin()->second.total_bitrate_bps / 1000; + if (bitrate_kbps > min_acceptable_bitrate_ && + bitrate_kbps < max_acceptable_bitrate_) { + converged_ = true; + ++num_bitrate_observations_in_range_; + if (num_bitrate_observations_in_range_ == + kNumBitrateObservationsInRange) + observation_complete_.Set(); + } + if (converged_) + bitrate_kbps_list_.push_back(bitrate_kbps); } - if (converged_) - bitrate_kbps_list_.push_back(bitrate_kbps); - } + })); return SEND_PACKET; } @@ -709,7 +729,8 @@ void CallPerfTest::TestMinTransmitBitrate(bool pad_to_min_bitrate) { const int max_acceptable_bitrate_; int num_bitrate_observations_in_range_; std::vector bitrate_kbps_list_; - } test(pad_to_min_bitrate); + TaskQueueBase* task_queue_; + } test(pad_to_min_bitrate, task_queue()); fake_encoder_max_bitrate_ = kMaxEncodeBitrateKbps; RunBaseTest(&test); @@ -760,7 +781,7 @@ TEST_F(CallPerfTest, MAYBE_KeepsHighBitrateWhenReconfiguringSender) { class BitrateObserver : public test::EndToEndTest, public test::FakeEncoder { public: - BitrateObserver() + explicit BitrateObserver(TaskQueueBase* task_queue) : EndToEndTest(kDefaultTimeoutMs), FakeEncoder(Clock::GetRealTimeClock()), encoder_inits_(0), @@ -769,7 +790,8 @@ TEST_F(CallPerfTest, MAYBE_KeepsHighBitrateWhenReconfiguringSender) { frame_generator_(nullptr), encoder_factory_(this), bitrate_allocator_factory_( - CreateBuiltinVideoBitrateAllocatorFactory()) {} + CreateBuiltinVideoBitrateAllocatorFactory()), + task_queue_(task_queue) {} int32_t InitEncode(const VideoCodec* config, const VideoEncoder::Settings& settings) override { @@ -819,7 +841,7 @@ TEST_F(CallPerfTest, MAYBE_KeepsHighBitrateWhenReconfiguringSender) { bitrate_allocator_factory_.get(); encoder_config->max_bitrate_bps = 2 * 
kReconfigureThresholdKbps * 1000; encoder_config->video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); encoder_config_ = encoder_config->Copy(); } @@ -839,7 +861,9 @@ TEST_F(CallPerfTest, MAYBE_KeepsHighBitrateWhenReconfiguringSender) { ASSERT_TRUE(time_to_reconfigure_.Wait(kDefaultTimeoutMs)) << "Timed out before receiving an initial high bitrate."; frame_generator_->ChangeResolution(kDefaultWidth * 2, kDefaultHeight * 2); - send_stream_->ReconfigureVideoEncoder(encoder_config_.Copy()); + SendTask(RTC_FROM_HERE, task_queue_, [&]() { + send_stream_->ReconfigureVideoEncoder(encoder_config_.Copy()); + }); EXPECT_TRUE(Wait()) << "Timed out while waiting for a couple of high bitrate estimates " "after reconfiguring the send stream."; @@ -854,7 +878,8 @@ TEST_F(CallPerfTest, MAYBE_KeepsHighBitrateWhenReconfiguringSender) { test::VideoEncoderProxyFactory encoder_factory_; std::unique_ptr bitrate_allocator_factory_; VideoEncoderConfig encoder_config_; - } test; + TaskQueueBase* task_queue_; + } test(task_queue()); RunBaseTest(&test); } diff --git a/call/call_unittest.cc b/call/call_unittest.cc index b6be941e53..92a037f157 100644 --- a/call/call_unittest.cc +++ b/call/call_unittest.cc @@ -30,6 +30,7 @@ #include "call/audio_state.h" #include "modules/audio_device/include/mock_audio_device.h" #include "modules/audio_processing/include/mock_audio_processing.h" +#include "modules/include/module.h" #include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" #include "test/fake_encoder.h" #include "test/gtest.h" @@ -49,14 +50,14 @@ struct CallHelper { task_queue_factory_ = webrtc::CreateDefaultTaskQueueFactory(); webrtc::AudioState::Config audio_state_config; audio_state_config.audio_mixer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); audio_state_config.audio_processing = use_null_audio_processing ? 
nullptr - : new rtc::RefCountedObject< + : rtc::make_ref_counted< NiceMock>(); audio_state_config.audio_device_module = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); webrtc::Call::Config config(&event_log_); config.audio_state = webrtc::AudioState::Create(audio_state_config); config.task_queue_factory = task_queue_factory_.get(); @@ -117,7 +118,7 @@ TEST(CallTest, CreateDestroy_AudioReceiveStream) { config.rtp.remote_ssrc = 42; config.rtcp_send_transport = &rtcp_send_transport; config.decoder_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); AudioReceiveStream* stream = call->CreateAudioReceiveStream(config); EXPECT_NE(stream, nullptr); call->DestroyAudioReceiveStream(stream); @@ -156,7 +157,7 @@ TEST(CallTest, CreateDestroy_AudioReceiveStreams) { MockTransport rtcp_send_transport; config.rtcp_send_transport = &rtcp_send_transport; config.decoder_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); std::list streams; for (int i = 0; i < 2; ++i) { for (uint32_t ssrc = 0; ssrc < 1234567; ssrc += 34567) { @@ -186,7 +187,7 @@ TEST(CallTest, CreateDestroy_AssociateAudioSendReceiveStreams_RecvFirst) { recv_config.rtp.local_ssrc = 777; recv_config.rtcp_send_transport = &rtcp_send_transport; recv_config.decoder_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); AudioReceiveStream* recv_stream = call->CreateAudioReceiveStream(recv_config); EXPECT_NE(recv_stream, nullptr); @@ -225,7 +226,7 @@ TEST(CallTest, CreateDestroy_AssociateAudioSendReceiveStreams_SendFirst) { recv_config.rtp.local_ssrc = 777; recv_config.rtcp_send_transport = &rtcp_send_transport; recv_config.decoder_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); AudioReceiveStream* recv_stream = call->CreateAudioReceiveStream(recv_config); EXPECT_NE(recv_stream, nullptr); @@ -247,7 +248,7 @@ TEST(CallTest, CreateDestroy_FlexfecReceiveStream) { MockTransport rtcp_send_transport; FlexfecReceiveStream::Config config(&rtcp_send_transport); config.payload_type = 118; - config.remote_ssrc = 38837212; + config.rtp.remote_ssrc = 38837212; config.protected_media_ssrcs = {27273}; FlexfecReceiveStream* stream = call->CreateFlexfecReceiveStream(config); @@ -266,7 +267,7 @@ TEST(CallTest, CreateDestroy_FlexfecReceiveStreams) { for (int i = 0; i < 2; ++i) { for (uint32_t ssrc = 0; ssrc < 1234567; ssrc += 34567) { - config.remote_ssrc = ssrc; + config.rtp.remote_ssrc = ssrc; config.protected_media_ssrcs = {ssrc + 1}; FlexfecReceiveStream* stream = call->CreateFlexfecReceiveStream(config); EXPECT_NE(stream, nullptr); @@ -294,22 +295,22 @@ TEST(CallTest, MultipleFlexfecReceiveStreamsProtectingSingleVideoStream) { FlexfecReceiveStream* stream; std::list streams; - config.remote_ssrc = 838383; + config.rtp.remote_ssrc = 838383; stream = call->CreateFlexfecReceiveStream(config); EXPECT_NE(stream, nullptr); streams.push_back(stream); - config.remote_ssrc = 424993; + config.rtp.remote_ssrc = 424993; stream = call->CreateFlexfecReceiveStream(config); EXPECT_NE(stream, nullptr); streams.push_back(stream); - config.remote_ssrc = 99383; + config.rtp.remote_ssrc = 99383; stream = call->CreateFlexfecReceiveStream(config); EXPECT_NE(stream, nullptr); streams.push_back(stream); - config.remote_ssrc = 5548; + config.rtp.remote_ssrc = 5548; stream = call->CreateFlexfecReceiveStream(config); EXPECT_NE(stream, nullptr); streams.push_back(stream); diff --git a/call/degraded_call.cc b/call/degraded_call.cc index 0cd43018ac..5462085490 100644 --- a/call/degraded_call.cc +++ 
b/call/degraded_call.cc @@ -270,6 +270,14 @@ const WebRtcKeyValueConfig& DegradedCall::trials() const { return call_->trials(); } +TaskQueueBase* DegradedCall::network_thread() const { + return call_->network_thread(); +} + +TaskQueueBase* DegradedCall::worker_thread() const { + return call_->worker_thread(); +} + void DegradedCall::SignalChannelNetworkState(MediaType media, NetworkState state) { call_->SignalChannelNetworkState(media, state); @@ -280,6 +288,16 @@ void DegradedCall::OnAudioTransportOverheadChanged( call_->OnAudioTransportOverheadChanged(transport_overhead_per_packet); } +void DegradedCall::OnLocalSsrcUpdated(AudioReceiveStream& stream, + uint32_t local_ssrc) { + call_->OnLocalSsrcUpdated(stream, local_ssrc); +} + +void DegradedCall::OnUpdateSyncGroup(AudioReceiveStream& stream, + const std::string& sync_group) { + call_->OnUpdateSyncGroup(stream, sync_group); +} + void DegradedCall::OnSentPacket(const rtc::SentPacket& sent_packet) { if (send_config_) { // If we have a degraded send-transport, we have already notified call diff --git a/call/degraded_call.h b/call/degraded_call.h index d81c65c570..70dc126807 100644 --- a/call/degraded_call.h +++ b/call/degraded_call.h @@ -16,6 +16,7 @@ #include #include +#include #include "absl/types/optional.h" #include "api/call/transport.h" @@ -87,9 +88,16 @@ class DegradedCall : public Call, private PacketReceiver { const WebRtcKeyValueConfig& trials() const override; + TaskQueueBase* network_thread() const override; + TaskQueueBase* worker_thread() const override; + void SignalChannelNetworkState(MediaType media, NetworkState state) override; void OnAudioTransportOverheadChanged( int transport_overhead_per_packet) override; + void OnLocalSsrcUpdated(AudioReceiveStream& stream, + uint32_t local_ssrc) override; + void OnUpdateSyncGroup(AudioReceiveStream& stream, + const std::string& sync_group) override; void OnSentPacket(const rtc::SentPacket& sent_packet) override; protected: diff --git a/call/flexfec_receive_stream.h b/call/flexfec_receive_stream.h index 2f7438f9a4..72e544e7ec 100644 --- a/call/flexfec_receive_stream.h +++ b/call/flexfec_receive_stream.h @@ -19,11 +19,13 @@ #include "api/call/transport.h" #include "api/rtp_headers.h" #include "api/rtp_parameters.h" +#include "call/receive_stream.h" #include "call/rtp_packet_sink_interface.h" namespace webrtc { -class FlexfecReceiveStream : public RtpPacketSinkInterface { +class FlexfecReceiveStream : public RtpPacketSinkInterface, + public ReceiveStream { public: ~FlexfecReceiveStream() override = default; @@ -48,8 +50,7 @@ class FlexfecReceiveStream : public RtpPacketSinkInterface { // Payload type for FlexFEC. int payload_type = -1; - // SSRC for FlexFEC stream to be received. - uint32_t remote_ssrc = 0; + RtpConfig rtp; // Vector containing a single element, corresponding to the SSRC of the // media stream being protected by this FlexFEC stream. The vector MUST have @@ -59,26 +60,14 @@ class FlexfecReceiveStream : public RtpPacketSinkInterface { // protection. std::vector protected_media_ssrcs; - // SSRC for RTCP reports to be sent. - uint32_t local_ssrc = 0; - // What RTCP mode to use in the reports. RtcpMode rtcp_mode = RtcpMode::kCompound; // Transport for outgoing RTCP packets. Transport* rtcp_send_transport = nullptr; - - // |transport_cc| is true whenever the send-side BWE RTCP feedback message - // has been negotiated. This is a prerequisite for enabling send-side BWE. - bool transport_cc = false; - - // RTP header extensions that have been negotiated for this track. 
- std::vector rtp_header_extensions; }; virtual Stats GetStats() const = 0; - - virtual const Config& GetConfig() const = 0; }; } // namespace webrtc diff --git a/call/flexfec_receive_stream_impl.cc b/call/flexfec_receive_stream_impl.cc index e629bca347..688efb7b5e 100644 --- a/call/flexfec_receive_stream_impl.cc +++ b/call/flexfec_receive_stream_impl.cc @@ -44,21 +44,21 @@ std::string FlexfecReceiveStream::Config::ToString() const { char buf[1024]; rtc::SimpleStringBuilder ss(buf); ss << "{payload_type: " << payload_type; - ss << ", remote_ssrc: " << remote_ssrc; - ss << ", local_ssrc: " << local_ssrc; + ss << ", remote_ssrc: " << rtp.remote_ssrc; + ss << ", local_ssrc: " << rtp.local_ssrc; ss << ", protected_media_ssrcs: ["; size_t i = 0; for (; i + 1 < protected_media_ssrcs.size(); ++i) ss << protected_media_ssrcs[i] << ", "; if (!protected_media_ssrcs.empty()) ss << protected_media_ssrcs[i]; - ss << "], transport_cc: " << (transport_cc ? "on" : "off"); - ss << ", rtp_header_extensions: ["; + ss << "], transport_cc: " << (rtp.transport_cc ? "on" : "off"); + ss << ", rtp.extensions: ["; i = 0; - for (; i + 1 < rtp_header_extensions.size(); ++i) - ss << rtp_header_extensions[i].ToString() << ", "; - if (!rtp_header_extensions.empty()) - ss << rtp_header_extensions[i].ToString(); + for (; i + 1 < rtp.extensions.size(); ++i) + ss << rtp.extensions[i].ToString() << ", "; + if (!rtp.extensions.empty()) + ss << rtp.extensions[i].ToString(); ss << "]}"; return ss.str(); } @@ -68,7 +68,7 @@ bool FlexfecReceiveStream::Config::IsCompleteAndEnabled() const { if (payload_type < 0) return false; // Do we have the necessary SSRC information? - if (remote_ssrc == 0) + if (rtp.remote_ssrc == 0) return false; // TODO(brandtr): Update this check when we support multistream protection. if (protected_media_ssrcs.size() != 1u) @@ -91,7 +91,7 @@ std::unique_ptr MaybeCreateFlexfecReceiver( } RTC_DCHECK_GE(config.payload_type, 0); RTC_DCHECK_LE(config.payload_type, 127); - if (config.remote_ssrc == 0) { + if (config.rtp.remote_ssrc == 0) { RTC_LOG(LS_WARNING) << "Invalid FlexFEC SSRC given. 
" "This FlexfecReceiveStream will therefore be useless."; @@ -114,7 +114,7 @@ std::unique_ptr MaybeCreateFlexfecReceiver( } RTC_DCHECK_EQ(1U, config.protected_media_ssrcs.size()); return std::unique_ptr(new FlexfecReceiver( - clock, config.remote_ssrc, config.protected_media_ssrcs[0], + clock, config.rtp.remote_ssrc, config.protected_media_ssrcs[0], recovered_packet_receiver)); } @@ -130,7 +130,7 @@ std::unique_ptr CreateRtpRtcpModule( configuration.receive_statistics = receive_statistics; configuration.outgoing_transport = config.rtcp_send_transport; configuration.rtt_stats = rtt_stats; - configuration.local_media_ssrc = config.local_ssrc; + configuration.local_media_ssrc = config.rtp.local_ssrc; return ModuleRtpRtcpImpl2::Create(configuration); } @@ -138,11 +138,9 @@ std::unique_ptr CreateRtpRtcpModule( FlexfecReceiveStreamImpl::FlexfecReceiveStreamImpl( Clock* clock, - RtpStreamReceiverControllerInterface* receiver_controller, const Config& config, RecoveredPacketReceiver* recovered_packet_receiver, - RtcpRttStats* rtt_stats, - ProcessThread* process_thread) + RtcpRttStats* rtt_stats) : config_(config), receiver_(MaybeCreateFlexfecReceiver(clock, config_, @@ -151,32 +149,38 @@ FlexfecReceiveStreamImpl::FlexfecReceiveStreamImpl( rtp_rtcp_(CreateRtpRtcpModule(clock, rtp_receive_statistics_.get(), config_, - rtt_stats)), - process_thread_(process_thread) { + rtt_stats)) { RTC_LOG(LS_INFO) << "FlexfecReceiveStreamImpl: " << config_.ToString(); + packet_sequence_checker_.Detach(); + // RTCP reporting. rtp_rtcp_->SetRTCPStatus(config_.rtcp_mode); - process_thread_->RegisterModule(rtp_rtcp_.get(), RTC_FROM_HERE); +} + +FlexfecReceiveStreamImpl::~FlexfecReceiveStreamImpl() { + RTC_LOG(LS_INFO) << "~FlexfecReceiveStreamImpl: " << config_.ToString(); +} + +void FlexfecReceiveStreamImpl::RegisterWithTransport( + RtpStreamReceiverControllerInterface* receiver_controller) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + RTC_DCHECK(!rtp_stream_receiver_); + + if (!receiver_) + return; - // Register with transport. // TODO(nisse): OnRtpPacket in this class delegates all real work to - // |receiver_|. So maybe we don't need to implement RtpPacketSinkInterface + // `receiver_`. So maybe we don't need to implement RtpPacketSinkInterface // here at all, we'd then delete the OnRtpPacket method and instead register - // |receiver_| as the RtpPacketSinkInterface for this stream. - // TODO(nisse): Passing |this| from the constructor to the RtpDemuxer, before - // the object is fully initialized, is risky. But it works in this case - // because locking in our caller, Call::CreateFlexfecReceiveStream, ensures - // that the demuxer doesn't call OnRtpPacket before this object is fully - // constructed. Registering |receiver_| instead of |this| would solve this - // problem too. + // `receiver_` as the RtpPacketSinkInterface for this stream. 
rtp_stream_receiver_ = - receiver_controller->CreateReceiver(config_.remote_ssrc, this); + receiver_controller->CreateReceiver(config_.rtp.remote_ssrc, this); } -FlexfecReceiveStreamImpl::~FlexfecReceiveStreamImpl() { - RTC_LOG(LS_INFO) << "~FlexfecReceiveStreamImpl: " << config_.ToString(); - process_thread_->DeRegisterModule(rtp_rtcp_.get()); +void FlexfecReceiveStreamImpl::UnregisterFromTransport() { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtp_stream_receiver_.reset(); } void FlexfecReceiveStreamImpl::OnRtpPacket(const RtpPacketReceived& packet) { @@ -186,7 +190,7 @@ void FlexfecReceiveStreamImpl::OnRtpPacket(const RtpPacketReceived& packet) { receiver_->OnRtpPacket(packet); // Do not report media packets in the RTCP RRs generated by |rtp_rtcp_|. - if (packet.Ssrc() == config_.remote_ssrc) { + if (packet.Ssrc() == config_.rtp.remote_ssrc) { rtp_receive_statistics_->OnRtpPacket(packet); } } @@ -197,9 +201,4 @@ FlexfecReceiveStreamImpl::Stats FlexfecReceiveStreamImpl::GetStats() const { return FlexfecReceiveStream::Stats(); } -const FlexfecReceiveStream::Config& FlexfecReceiveStreamImpl::GetConfig() - const { - return config_; -} - } // namespace webrtc diff --git a/call/flexfec_receive_stream_impl.h b/call/flexfec_receive_stream_impl.h index 888dae9ebd..285a33f7bb 100644 --- a/call/flexfec_receive_stream_impl.h +++ b/call/flexfec_receive_stream_impl.h @@ -16,12 +16,12 @@ #include "call/flexfec_receive_stream.h" #include "call/rtp_packet_sink_interface.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" +#include "rtc_base/system/no_unique_address.h" #include "system_wrappers/include/clock.h" namespace webrtc { class FlexfecReceiver; -class ProcessThread; class ReceiveStatistics; class RecoveredPacketReceiver; class RtcpRttStats; @@ -32,22 +32,37 @@ class RtpStreamReceiverInterface; class FlexfecReceiveStreamImpl : public FlexfecReceiveStream { public: - FlexfecReceiveStreamImpl( - Clock* clock, - RtpStreamReceiverControllerInterface* receiver_controller, - const Config& config, - RecoveredPacketReceiver* recovered_packet_receiver, - RtcpRttStats* rtt_stats, - ProcessThread* process_thread); + FlexfecReceiveStreamImpl(Clock* clock, + const Config& config, + RecoveredPacketReceiver* recovered_packet_receiver, + RtcpRttStats* rtt_stats); + // Destruction happens on the worker thread. Prior to destruction the caller + // must ensure that a registration with the transport has been cleared. See + // `RegisterWithTransport` for details. + // TODO(tommi): As a further improvement to this, performing the full + // destruction on the network thread could be made the default. ~FlexfecReceiveStreamImpl() override; + // Called on the network thread to register/unregister with the network + // transport. + void RegisterWithTransport( + RtpStreamReceiverControllerInterface* receiver_controller); + // If registration has previously been done (via `RegisterWithTransport`) then + // `UnregisterFromTransport` must be called prior to destruction, on the + // network thread. + void UnregisterFromTransport(); + // RtpPacketSinkInterface. void OnRtpPacket(const RtpPacketReceived& packet) override; Stats GetStats() const override; - const Config& GetConfig() const override; + + // ReceiveStream impl. + const RtpConfig& rtp_config() const override { return config_.rtp; } private: + RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_sequence_checker_; + // Config. const Config config_; @@ -57,9 +72,9 @@ class FlexfecReceiveStreamImpl : public FlexfecReceiveStream { // RTCP reporting. 
const std::unique_ptr rtp_receive_statistics_; const std::unique_ptr rtp_rtcp_; - ProcessThread* process_thread_; - std::unique_ptr rtp_stream_receiver_; + std::unique_ptr rtp_stream_receiver_ + RTC_GUARDED_BY(packet_sequence_checker_); }; } // namespace webrtc diff --git a/call/flexfec_receive_stream_unittest.cc b/call/flexfec_receive_stream_unittest.cc index 5e8ee47433..312fe0c907 100644 --- a/call/flexfec_receive_stream_unittest.cc +++ b/call/flexfec_receive_stream_unittest.cc @@ -26,7 +26,6 @@ #include "modules/rtp_rtcp/source/byte_io.h" #include "modules/rtp_rtcp/source/rtp_header_extensions.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "modules/utility/include/mock/mock_process_thread.h" #include "test/gmock.h" #include "test/gtest.h" #include "test/mock_transport.h" @@ -45,7 +44,7 @@ FlexfecReceiveStream::Config CreateDefaultConfig( Transport* rtcp_send_transport) { FlexfecReceiveStream::Config config(rtcp_send_transport); config.payload_type = kFlexfecPlType; - config.remote_ssrc = ByteReader::ReadBigEndian(kFlexfecSsrc); + config.rtp.remote_ssrc = ByteReader::ReadBigEndian(kFlexfecSsrc); config.protected_media_ssrcs = { ByteReader::ReadBigEndian(kMediaSsrc)}; EXPECT_TRUE(config.IsCompleteAndEnabled()); @@ -64,16 +63,16 @@ TEST(FlexfecReceiveStreamConfigTest, IsCompleteAndEnabled) { MockTransport rtcp_send_transport; FlexfecReceiveStream::Config config(&rtcp_send_transport); - config.local_ssrc = 18374743; + config.rtp.local_ssrc = 18374743; config.rtcp_mode = RtcpMode::kCompound; - config.transport_cc = true; - config.rtp_header_extensions.emplace_back(TransportSequenceNumber::kUri, 7); + config.rtp.transport_cc = true; + config.rtp.extensions.emplace_back(TransportSequenceNumber::kUri, 7); EXPECT_FALSE(config.IsCompleteAndEnabled()); config.payload_type = 123; EXPECT_FALSE(config.IsCompleteAndEnabled()); - config.remote_ssrc = 238423838; + config.rtp.remote_ssrc = 238423838; EXPECT_FALSE(config.IsCompleteAndEnabled()); config.protected_media_ssrcs.push_back(138989393); @@ -87,21 +86,20 @@ class FlexfecReceiveStreamTest : public ::testing::Test { protected: FlexfecReceiveStreamTest() : config_(CreateDefaultConfig(&rtcp_send_transport_)) { - EXPECT_CALL(process_thread_, RegisterModule(_, _)).Times(1); receive_stream_ = std::make_unique( - Clock::GetRealTimeClock(), &rtp_stream_receiver_controller_, config_, - &recovered_packet_receiver_, &rtt_stats_, &process_thread_); + Clock::GetRealTimeClock(), config_, &recovered_packet_receiver_, + &rtt_stats_); + receive_stream_->RegisterWithTransport(&rtp_stream_receiver_controller_); } ~FlexfecReceiveStreamTest() { - EXPECT_CALL(process_thread_, DeRegisterModule(_)).Times(1); + receive_stream_->UnregisterFromTransport(); } MockTransport rtcp_send_transport_; FlexfecReceiveStream::Config config_; MockRecoveredPacketReceiver recovered_packet_receiver_; MockRtcpRttStats rtt_stats_; - MockProcessThread process_thread_; RtpStreamReceiverController rtp_stream_receiver_controller_; std::unique_ptr receive_stream_; }; @@ -144,10 +142,10 @@ TEST_F(FlexfecReceiveStreamTest, RecoversPacket) { // clang-format on ::testing::StrictMock recovered_packet_receiver; - EXPECT_CALL(process_thread_, RegisterModule(_, _)).Times(1); - FlexfecReceiveStreamImpl receive_stream( - Clock::GetRealTimeClock(), &rtp_stream_receiver_controller_, config_, - &recovered_packet_receiver, &rtt_stats_, &process_thread_); + FlexfecReceiveStreamImpl receive_stream(Clock::GetRealTimeClock(), config_, + &recovered_packet_receiver, + &rtt_stats_); + 
receive_stream.RegisterWithTransport(&rtp_stream_receiver_controller_); EXPECT_CALL(recovered_packet_receiver, OnRecoveredPacket(_, kRtpHeaderSize + kPayloadLength[1])); @@ -155,7 +153,7 @@ TEST_F(FlexfecReceiveStreamTest, RecoversPacket) { receive_stream.OnRtpPacket(ParsePacket(kFlexfecPacket)); // Tear-down - EXPECT_CALL(process_thread_, DeRegisterModule(_)).Times(1); + receive_stream.UnregisterFromTransport(); } } // namespace webrtc diff --git a/call/packet_receiver.h b/call/packet_receiver.h index df57d8f4f4..13d3b84c90 100644 --- a/call/packet_receiver.h +++ b/call/packet_receiver.h @@ -10,11 +10,6 @@ #ifndef CALL_PACKET_RECEIVER_H_ #define CALL_PACKET_RECEIVER_H_ -#include -#include -#include -#include - #include "api/media_types.h" #include "rtc_base/copy_on_write_buffer.h" diff --git a/call/rampup_tests.cc b/call/rampup_tests.cc index 379f9dcf84..bf136a5df9 100644 --- a/call/rampup_tests.cc +++ b/call/rampup_tests.cc @@ -160,7 +160,7 @@ void RampUpTester::ModifyVideoConfigs( encoder_config->number_of_streams = num_video_streams_; encoder_config->max_bitrate_bps = 2000000; encoder_config->video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); if (num_video_streams_ == 1) { // For single stream rampup until 1mbps expected_bitrate_bps_ = kSingleStreamTargetBps; @@ -295,16 +295,16 @@ void RampUpTester::ModifyFlexfecConfigs( return; RTC_DCHECK_EQ(1, num_flexfec_streams_); (*receive_configs)[0].payload_type = test::CallTest::kFlexfecPayloadType; - (*receive_configs)[0].remote_ssrc = test::CallTest::kFlexfecSendSsrc; + (*receive_configs)[0].rtp.remote_ssrc = test::CallTest::kFlexfecSendSsrc; (*receive_configs)[0].protected_media_ssrcs = {video_ssrcs_[0]}; - (*receive_configs)[0].local_ssrc = video_ssrcs_[0]; + (*receive_configs)[0].rtp.local_ssrc = video_ssrcs_[0]; if (extension_type_ == RtpExtension::kAbsSendTimeUri) { - (*receive_configs)[0].transport_cc = false; - (*receive_configs)[0].rtp_header_extensions.push_back( + (*receive_configs)[0].rtp.transport_cc = false; + (*receive_configs)[0].rtp.extensions.push_back( RtpExtension(extension_type_.c_str(), kAbsSendTimeExtensionId)); } else if (extension_type_ == RtpExtension::kTransportSequenceNumberUri) { - (*receive_configs)[0].transport_cc = true; - (*receive_configs)[0].rtp_header_extensions.push_back(RtpExtension( + (*receive_configs)[0].rtp.transport_cc = true; + (*receive_configs)[0].rtp.extensions.push_back(RtpExtension( extension_type_.c_str(), kTransportSequenceNumberExtensionId)); } } @@ -370,7 +370,10 @@ void RampUpTester::TriggerTestDone() { if (!send_stream_) return; - VideoSendStream::Stats send_stats = send_stream_->GetStats(); + VideoSendStream::Stats send_stats; + SendTask(RTC_FROM_HERE, task_queue_, + [&] { send_stats = send_stream_->GetStats(); }); + send_stream_ = nullptr; // To avoid dereferencing a bad pointer. size_t total_packets_sent = 0; diff --git a/call/receive_stream.h b/call/receive_stream.h new file mode 100644 index 0000000000..0f59b37ae3 --- /dev/null +++ b/call/receive_stream.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+
+#ifndef CALL_RECEIVE_STREAM_H_
+#define CALL_RECEIVE_STREAM_H_
+
+#include
+
+#include "api/crypto/frame_decryptor_interface.h"
+#include "api/frame_transformer_interface.h"
+#include "api/media_types.h"
+#include "api/scoped_refptr.h"
+#include "api/transport/rtp/rtp_source.h"
+
+namespace webrtc {
+
+// Common base interface for MediaReceiveStream based classes and
+// FlexfecReceiveStream.
+class ReceiveStream {
+ public:
+  // Receive-stream specific RTP settings.
+  struct RtpConfig {
+    // Synchronization source (stream identifier) to be received.
+    // This member will not change mid-stream and can be assumed to be const
+    // post initialization.
+    uint32_t remote_ssrc = 0;
+
+    // Sender SSRC used for sending RTCP (such as receiver reports).
+    // This value may change mid-stream; any such change must be made on the
+    // same thread that the value is read on (i.e. the packet delivery thread).
+    uint32_t local_ssrc = 0;
+
+    // Enable feedback for send side bandwidth estimation.
+    // See
+    // https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions
+    // for details.
+    // This value may change mid-stream; any such change must be made on the
+    // same thread that the value is read on (i.e. the packet delivery thread).
+    bool transport_cc = false;
+
+    // RTP header extensions used for the received stream.
+    // This value may change mid-stream; any such change must be made on the
+    // same thread that the value is read on (i.e. the packet delivery thread).
+    std::vector extensions;
+  };
+
+  // Called on the packet delivery thread since some members of the config may
+  // change mid-stream (e.g. the local ssrc). All mutation must also happen on
+  // the packet delivery thread. The return value can be assumed to be used
+  // only in the calling context (essentially on the stack).
+  virtual const RtpConfig& rtp_config() const = 0;
+
+ protected:
+  virtual ~ReceiveStream() {}
+};
+
+// Either an audio or video receive stream.
+class MediaReceiveStream : public ReceiveStream {
+ public:
+  // Starts stream activity.
+  // When a stream is active, it can receive, process and deliver packets.
+  virtual void Start() = 0;
+
+  // Stops stream activity. Must be called to match a previous call to
+  // `Start()`. When a stream has been stopped, it won't receive, decode,
+  // process or deliver packets to downstream objects such as callback pointers
+  // set in the config struct.
+ virtual void Stop() = 0; + + virtual void SetDepacketizerToDecoderFrameTransformer( + rtc::scoped_refptr + frame_transformer) = 0; + + virtual void SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) = 0; + + virtual std::vector GetSources() const = 0; +}; + +} // namespace webrtc + +#endif // CALL_RECEIVE_STREAM_H_ diff --git a/call/rtp_demuxer.cc b/call/rtp_demuxer.cc index 9fc4ba1c16..28962fd2eb 100644 --- a/call/rtp_demuxer.cc +++ b/call/rtp_demuxer.cc @@ -36,16 +36,7 @@ size_t RemoveFromMultimapByValue(Container* multimap, const Value& value) { template size_t RemoveFromMapByValue(Map* map, const Value& value) { - size_t count = 0; - for (auto it = map->begin(); it != map->end();) { - if (it->second == value) { - it = map->erase(it); - ++count; - } else { - ++it; - } - } - return count; + return EraseIf(*map, [&](const auto& elem) { return elem.second == value; }); } } // namespace @@ -53,6 +44,16 @@ size_t RemoveFromMapByValue(Map* map, const Value& value) { RtpDemuxerCriteria::RtpDemuxerCriteria() = default; RtpDemuxerCriteria::~RtpDemuxerCriteria() = default; +bool RtpDemuxerCriteria::operator==(const RtpDemuxerCriteria& other) const { + return this->mid == other.mid && this->rsid == other.rsid && + this->ssrcs == other.ssrcs && + this->payload_types == other.payload_types; +} + +bool RtpDemuxerCriteria::operator!=(const RtpDemuxerCriteria& other) const { + return !(*this == other); +} + std::string RtpDemuxerCriteria::ToString() const { rtc::StringBuilder sb; sb << "{mid: " << (mid.empty() ? "" : mid) @@ -91,7 +92,7 @@ std::string RtpDemuxer::DescribePacket(const RtpPacketReceived& packet) { return sb.Release(); } -RtpDemuxer::RtpDemuxer() = default; +RtpDemuxer::RtpDemuxer(bool use_mid /* = true*/) : use_mid_(use_mid) {} RtpDemuxer::~RtpDemuxer() { RTC_DCHECK(sink_by_mid_.empty()); diff --git a/call/rtp_demuxer.h b/call/rtp_demuxer.h index 3aa7e9df26..fb65fce368 100644 --- a/call/rtp_demuxer.h +++ b/call/rtp_demuxer.h @@ -12,11 +12,13 @@ #define CALL_RTP_DEMUXER_H_ #include -#include #include #include #include +#include "rtc_base/containers/flat_map.h" +#include "rtc_base/containers/flat_set.h" + namespace webrtc { class RtpPacketReceived; @@ -28,6 +30,9 @@ struct RtpDemuxerCriteria { RtpDemuxerCriteria(); ~RtpDemuxerCriteria(); + bool operator==(const RtpDemuxerCriteria& other) const; + bool operator!=(const RtpDemuxerCriteria& other) const; + // If not the empty string, will match packets with this MID. std::string mid; @@ -39,10 +44,10 @@ struct RtpDemuxerCriteria { std::string rsid; // Will match packets with any of these SSRCs. - std::set ssrcs; + flat_set ssrcs; // Will match packets with any of these payload types. - std::set payload_types; + flat_set payload_types; // Return string representation of demux criteria to facilitate logging std::string ToString() const; @@ -94,7 +99,7 @@ class RtpDemuxer { // relevant for demuxing. static std::string DescribePacket(const RtpPacketReceived& packet); - RtpDemuxer(); + explicit RtpDemuxer(bool use_mid = true); ~RtpDemuxer(); RtpDemuxer(const RtpDemuxer&) = delete; @@ -132,10 +137,6 @@ class RtpDemuxer { // if the packet was forwarded and false if the packet was dropped. bool OnRtpPacket(const RtpPacketReceived& packet); - // Configure whether to look at the MID header extension when demuxing - // incoming RTP packets. By default this is enabled. 
- void set_use_mid(bool use_mid) { use_mid_ = use_mid; } - private: // Returns true if adding a sink with the given criteria would cause conflicts // with the existing criteria and should be rejected. @@ -169,29 +170,29 @@ class RtpDemuxer { // Note: Mappings are only modified by AddSink/RemoveSink (except for // SSRC mapping which receives all MID, payload type, or RSID to SSRC bindings // discovered when demuxing packets). - std::map sink_by_mid_; - std::map sink_by_ssrc_; + flat_map sink_by_mid_; + flat_map sink_by_ssrc_; std::multimap sinks_by_pt_; - std::map, RtpPacketSinkInterface*> + flat_map, RtpPacketSinkInterface*> sink_by_mid_and_rsid_; - std::map sink_by_rsid_; + flat_map sink_by_rsid_; // Tracks all the MIDs that have been identified in added criteria. Used to // determine if a packet should be dropped right away because the MID is // unknown. - std::set known_mids_; + flat_set known_mids_; // Records learned mappings of MID --> SSRC and RSID --> SSRC as packets are // received. // This is stored separately from the sink mappings because if a sink is // removed we want to still remember these associations. - std::map mid_by_ssrc_; - std::map rsid_by_ssrc_; + flat_map mid_by_ssrc_; + flat_map rsid_by_ssrc_; // Adds a binding from the SSRC to the given sink. void AddSsrcSinkBinding(uint32_t ssrc, RtpPacketSinkInterface* sink); - bool use_mid_ = true; + const bool use_mid_; }; } // namespace webrtc diff --git a/call/rtp_payload_params.cc b/call/rtp_payload_params.cc index 5c9c6c0c05..c6a56a389e 100644 --- a/call/rtp_payload_params.cc +++ b/call/rtp_payload_params.cc @@ -131,6 +131,9 @@ RtpPayloadParams::RtpPayloadParams(const uint32_t ssrc, : ssrc_(ssrc), generic_picture_id_experiment_( absl::StartsWith(trials.Lookup("WebRTC-GenericPictureId"), + "Enabled")), + simulate_generic_vp9_( + absl::StartsWith(trials.Lookup("WebRTC-Vp9DependencyDescriptor"), "Enabled")) { for (auto& spatial_layer : last_shared_frame_id_) spatial_layer.fill(-1); @@ -156,7 +159,7 @@ RTPVideoHeader RtpPayloadParams::GetRtpVideoHeader( PopulateRtpWithCodecSpecifics(*codec_specific_info, image.SpatialIndex(), &rtp_video_header); } - rtp_video_header.frame_type = image._frameType, + rtp_video_header.frame_type = image._frameType; rtp_video_header.rotation = image.rotation_; rtp_video_header.content_type = image.content_type_; rtp_video_header.playout_delay = image.playout_delay_; @@ -165,6 +168,7 @@ RTPVideoHeader RtpPayloadParams::GetRtpVideoHeader( rtp_video_header.color_space = image.ColorSpace() ? absl::make_optional(*image.ColorSpace()) : absl::nullopt; + rtp_video_header.video_frame_tracking_id = image.VideoFrameTrackingId(); SetVideoTiming(image, &rtp_video_header.video_timing); const bool is_keyframe = image._frameType == VideoFrameType::kVideoFrameKey; @@ -276,8 +280,13 @@ void RtpPayloadParams::SetGeneric(const CodecSpecificInfo* codec_specific_info, } return; case VideoCodecType::kVideoCodecVP9: + if (simulate_generic_vp9_ && codec_specific_info != nullptr) { + Vp9ToGeneric(codec_specific_info->codecSpecific.VP9, frame_id, + *rtp_video_header); + } + return; case VideoCodecType::kVideoCodecAV1: - // TODO(philipel): Implement VP9 and AV1 to generic descriptor. + // TODO(philipel): Implement AV1 to generic descriptor. 
       return;
     case VideoCodecType::kVideoCodecH264:
       if (codec_specific_info) {
@@ -398,6 +407,150 @@ void RtpPayloadParams::Vp8ToGeneric(const CodecSpecificInfoVP8& vp8_info,
   }
 }
 
+FrameDependencyStructure RtpPayloadParams::MinimalisticVp9Structure(
+    const CodecSpecificInfoVP9& vp9) {
+  const int num_spatial_layers = vp9.num_spatial_layers;
+  const int num_temporal_layers = kMaxTemporalStreams;
+  FrameDependencyStructure structure;
+  structure.num_decode_targets = num_spatial_layers * num_temporal_layers;
+  structure.num_chains = num_spatial_layers;
+  structure.templates.reserve(num_spatial_layers * num_temporal_layers);
+  for (int sid = 0; sid < num_spatial_layers; ++sid) {
+    for (int tid = 0; tid < num_temporal_layers; ++tid) {
+      FrameDependencyTemplate a_template;
+      a_template.spatial_id = sid;
+      a_template.temporal_id = tid;
+      for (int s = 0; s < num_spatial_layers; ++s) {
+        for (int t = 0; t < num_temporal_layers; ++t) {
+          // Prefer kSwitch for indicating that the frame is part of the decode
+          // target because RtpPayloadParams::Vp9ToGeneric uses that indication
+          // more often than kRequired, increasing the chance that a custom dti
+          // need not use more bits in the dependency descriptor on the wire.
+          a_template.decode_target_indications.push_back(
+              sid <= s && tid <= t ? DecodeTargetIndication::kSwitch
+                                   : DecodeTargetIndication::kNotPresent);
+        }
+      }
+      a_template.frame_diffs.push_back(tid == 0 ? num_spatial_layers *
+                                                      num_temporal_layers
+                                                : num_spatial_layers);
+      a_template.chain_diffs.assign(structure.num_chains, 1);
+      structure.templates.push_back(a_template);
+
+      structure.decode_target_protected_by_chain.push_back(sid);
+    }
+    if (vp9.ss_data_available && vp9.spatial_layer_resolution_present) {
+      structure.resolutions.emplace_back(vp9.width[sid], vp9.height[sid]);
+    }
+  }
+  return structure;
+}
+
+void RtpPayloadParams::Vp9ToGeneric(const CodecSpecificInfoVP9& vp9_info,
+                                    int64_t shared_frame_id,
+                                    RTPVideoHeader& rtp_video_header) {
+  const auto& vp9_header =
+      absl::get(rtp_video_header.video_type_header);
+  const int num_spatial_layers = vp9_header.num_spatial_layers;
+  const int num_temporal_layers = kMaxTemporalStreams;
+
+  int spatial_index =
+      vp9_header.spatial_idx != kNoSpatialIdx ? vp9_header.spatial_idx : 0;
+  int temporal_index =
+      vp9_header.temporal_idx != kNoTemporalIdx ? vp9_header.temporal_idx : 0;
+
+  if (spatial_index >= num_spatial_layers ||
+      temporal_index >= num_temporal_layers ||
+      num_spatial_layers > RtpGenericFrameDescriptor::kMaxSpatialLayers) {
+    // Prefer to generate no generic layering than an inconsistent one.
+    return;
+  }
+
+  RTPVideoHeader::GenericDescriptorInfo& result =
+      rtp_video_header.generic.emplace();
+
+  result.frame_id = shared_frame_id;
+  result.spatial_index = spatial_index;
+  result.temporal_index = temporal_index;
+
+  result.decode_target_indications.reserve(num_spatial_layers *
+                                           num_temporal_layers);
+  for (int sid = 0; sid < num_spatial_layers; ++sid) {
+    for (int tid = 0; tid < num_temporal_layers; ++tid) {
+      DecodeTargetIndication dti;
+      if (sid < spatial_index || tid < temporal_index) {
+        dti = DecodeTargetIndication::kNotPresent;
+      } else if (spatial_index != sid &&
+                 vp9_header.non_ref_for_inter_layer_pred) {
+        dti = DecodeTargetIndication::kNotPresent;
+      } else if (sid == spatial_index && tid == temporal_index) {
+        // Assume that if the frame is decodable, all of its own layer is
+        // decodable.
+        dti = DecodeTargetIndication::kSwitch;
+      } else if (sid == spatial_index && vp9_header.temporal_up_switch) {
+        dti = DecodeTargetIndication::kSwitch;
+      } else if (!vp9_header.inter_pic_predicted) {
+        // Key frame or spatial upswitch.
+        dti = DecodeTargetIndication::kSwitch;
+      } else {
+        // Make no other assumptions. That should be safe, though suboptimal.
+        // To provide a more accurate dti, the encoder wrapper should fill in
+        // CodecSpecificInfo::generic_frame_info.
+        dti = DecodeTargetIndication::kRequired;
+      }
+      result.decode_target_indications.push_back(dti);
+    }
+  }
+
+  // Calculate frame dependencies.
+  static constexpr int kPictureDiffLimit = 128;
+  if (last_vp9_frame_id_.empty()) {
+    // Create the array only if it is ever used.
+    last_vp9_frame_id_.resize(kPictureDiffLimit);
+  }
+  if (vp9_header.inter_layer_predicted && spatial_index > 0) {
+    result.dependencies.push_back(
+        last_vp9_frame_id_[vp9_header.picture_id % kPictureDiffLimit]
+                          [spatial_index - 1]);
+  }
+  if (vp9_header.inter_pic_predicted) {
+    for (size_t i = 0; i < vp9_header.num_ref_pics; ++i) {
+      // picture_id is a 15-bit number that wraps around. Though underflow may
+      // produce a picture_id that exceeds 2^15, that is ok because in this
+      // code block only the last 7 bits of the picture_id are used.
+      uint16_t depend_on = vp9_header.picture_id - vp9_header.pid_diff[i];
+      result.dependencies.push_back(
+          last_vp9_frame_id_[depend_on % kPictureDiffLimit][spatial_index]);
+    }
+  }
+  last_vp9_frame_id_[vp9_header.picture_id % kPictureDiffLimit][spatial_index] =
+      shared_frame_id;
+
+  // Calculate chains, assuming the chain includes all frames with
+  // temporal_id = 0.
+  if (!vp9_header.inter_pic_predicted && !vp9_header.inter_layer_predicted) {
+    // Assume frames without dependencies also reset chains.
+    for (int sid = spatial_index; sid < num_spatial_layers; ++sid) {
+      chain_last_frame_id_[sid] = -1;
+    }
+  }
+  result.chain_diffs.resize(num_spatial_layers);
+  for (int sid = 0; sid < num_spatial_layers; ++sid) {
+    if (chain_last_frame_id_[sid] == -1) {
+      result.chain_diffs[sid] = 0;
+      continue;
+    }
+    result.chain_diffs[sid] = shared_frame_id - chain_last_frame_id_[sid];
+  }
+
+  if (temporal_index == 0) {
+    chain_last_frame_id_[spatial_index] = shared_frame_id;
+    if (!vp9_header.non_ref_for_inter_layer_pred) {
+      for (int sid = spatial_index + 1; sid < num_spatial_layers; ++sid) {
+        chain_last_frame_id_[sid] = shared_frame_id;
+      }
+    }
+  }
+}
+
 void RtpPayloadParams::SetDependenciesVp8Deprecated(
     const CodecSpecificInfoVP8& vp8_info,
     int64_t shared_frame_id,
diff --git a/call/rtp_payload_params.h b/call/rtp_payload_params.h
index ebfdd4605a..da53cbc5c4 100644
--- a/call/rtp_payload_params.h
+++ b/call/rtp_payload_params.h
@@ -12,6 +12,7 @@
 #define CALL_RTP_PAYLOAD_PARAMS_H_
 
 #include
+#include
 
 #include "absl/types/optional.h"
 #include "api/transport/webrtc_key_value_config.h"
@@ -41,6 +42,14 @@ class RtpPayloadParams final {
                   const CodecSpecificInfo* codec_specific_info,
                   int64_t shared_frame_id);
 
+  // Returns a structure that aligns with the simulated generic info for VP9.
+  // The templates allow producing a valid dependency descriptor for any VP9
+  // stream with up to 4 temporal layers. The set of templates is not tuned
+  // for any particular structure, thus the dependency descriptor will use
+  // more bytes on the wire than with tuned templates.
+ static FrameDependencyStructure MinimalisticVp9Structure( + const CodecSpecificInfoVP9& vp9); + uint32_t ssrc() const; RtpPayloadState state() const; @@ -61,6 +70,10 @@ class RtpPayloadParams final { bool is_keyframe, RTPVideoHeader* rtp_video_header); + void Vp9ToGeneric(const CodecSpecificInfoVP9& vp9_info, + int64_t shared_frame_id, + RTPVideoHeader& rtp_video_header); + void H264ToGeneric(const CodecSpecificInfoH264& h264_info, int64_t shared_frame_id, bool is_keyframe, @@ -94,6 +107,13 @@ class RtpPayloadParams final { std::array, RtpGenericFrameDescriptor::kMaxSpatialLayers> last_shared_frame_id_; + // circular buffer of frame ids for the last 128 vp9 pictures. + // ids for the `picture_id` are stored at the index `picture_id % 128`. + std::vector> + last_vp9_frame_id_; + // Last frame id for each chain + std::array + chain_last_frame_id_; // TODO(eladalon): When additional codecs are supported, // set kMaxCodecBuffersCount to the max() of these codecs' buffer count. @@ -113,6 +133,7 @@ class RtpPayloadParams final { RtpPayloadState state_; const bool generic_picture_id_experiment_; + const bool simulate_generic_vp9_; }; } // namespace webrtc #endif // CALL_RTP_PAYLOAD_PARAMS_H_ diff --git a/call/rtp_payload_params_unittest.cc b/call/rtp_payload_params_unittest.cc index 56ed2cdea6..7db38dbcb8 100644 --- a/call/rtp_payload_params_unittest.cc +++ b/call/rtp_payload_params_unittest.cc @@ -26,10 +26,12 @@ #include "modules/video_coding/codecs/vp8/include/vp8_globals.h" #include "modules/video_coding/codecs/vp9/include/vp9_globals.h" #include "modules/video_coding/include/video_codec_interface.h" +#include "test/explicit_key_value_config.h" #include "test/field_trial.h" #include "test/gmock.h" #include "test/gtest.h" +using ::testing::Each; using ::testing::ElementsAre; using ::testing::IsEmpty; using ::testing::SizeIs; @@ -461,6 +463,410 @@ TEST_F(RtpPayloadParamsVp8ToGenericTest, FrameIdGaps) { ConvertAndCheck(1, 20, VideoFrameType::kVideoFrameDelta, kNoSync, {10, 15}); } +class RtpPayloadParamsVp9ToGenericTest : public ::testing::Test { + protected: + RtpPayloadParamsVp9ToGenericTest() + : field_trials_("WebRTC-Vp9DependencyDescriptor/Enabled/") {} + + test::ExplicitKeyValueConfig field_trials_; + RtpPayloadState state_; +}; + +TEST_F(RtpPayloadParamsVp9ToGenericTest, NoScalability) { + RtpPayloadParams params(/*ssrc=*/123, &state_, field_trials_); + + EncodedImage encoded_image; + CodecSpecificInfo codec_info; + codec_info.codecType = kVideoCodecVP9; + codec_info.codecSpecific.VP9.num_spatial_layers = 1; + codec_info.codecSpecific.VP9.temporal_idx = kNoTemporalIdx; + codec_info.codecSpecific.VP9.first_frame_in_picture = true; + codec_info.end_of_picture = true; + + // Key frame. + encoded_image._frameType = VideoFrameType::kVideoFrameKey; + codec_info.codecSpecific.VP9.inter_pic_predicted = false; + codec_info.codecSpecific.VP9.num_ref_pics = 0; + RTPVideoHeader header = params.GetRtpVideoHeader(encoded_image, &codec_info, + /*shared_frame_id=*/1); + + ASSERT_TRUE(header.generic); + EXPECT_EQ(header.generic->spatial_index, 0); + EXPECT_EQ(header.generic->temporal_index, 0); + EXPECT_EQ(header.generic->frame_id, 1); + ASSERT_THAT(header.generic->decode_target_indications, Not(IsEmpty())); + EXPECT_EQ(header.generic->decode_target_indications[0], + DecodeTargetIndication::kSwitch); + EXPECT_THAT(header.generic->dependencies, IsEmpty()); + EXPECT_THAT(header.generic->chain_diffs, ElementsAre(0)); + + // Delta frame. 
+ encoded_image._frameType = VideoFrameType::kVideoFrameDelta; + codec_info.codecSpecific.VP9.inter_pic_predicted = true; + codec_info.codecSpecific.VP9.num_ref_pics = 1; + codec_info.codecSpecific.VP9.p_diff[0] = 1; + header = params.GetRtpVideoHeader(encoded_image, &codec_info, + /*shared_frame_id=*/3); + + ASSERT_TRUE(header.generic); + EXPECT_EQ(header.generic->spatial_index, 0); + EXPECT_EQ(header.generic->temporal_index, 0); + EXPECT_EQ(header.generic->frame_id, 3); + ASSERT_THAT(header.generic->decode_target_indications, Not(IsEmpty())); + EXPECT_EQ(header.generic->decode_target_indications[0], + DecodeTargetIndication::kSwitch); + EXPECT_THAT(header.generic->dependencies, ElementsAre(1)); + // previous frame in the chain was frame#1, + EXPECT_THAT(header.generic->chain_diffs, ElementsAre(3 - 1)); +} + +TEST_F(RtpPayloadParamsVp9ToGenericTest, TemporalScalabilityWith2Layers) { + // Test with 2 temporal layers structure that is not used by webrtc: + // 1---3 5 + // / / / ... + // 0---2---4--- + RtpPayloadParams params(/*ssrc=*/123, &state_, field_trials_); + + EncodedImage image; + CodecSpecificInfo info; + info.codecType = kVideoCodecVP9; + info.codecSpecific.VP9.num_spatial_layers = 1; + info.codecSpecific.VP9.first_frame_in_picture = true; + info.end_of_picture = true; + + RTPVideoHeader headers[6]; + // Key frame. + image._frameType = VideoFrameType::kVideoFrameKey; + info.codecSpecific.VP9.inter_pic_predicted = false; + info.codecSpecific.VP9.num_ref_pics = 0; + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 0; + headers[0] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/1); + + // Delta frames. + info.codecSpecific.VP9.inter_pic_predicted = true; + image._frameType = VideoFrameType::kVideoFrameDelta; + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 1; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 1; + headers[1] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/3); + + info.codecSpecific.VP9.temporal_up_switch = false; + info.codecSpecific.VP9.temporal_idx = 0; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 2; + headers[2] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/5); + + info.codecSpecific.VP9.temporal_up_switch = false; + info.codecSpecific.VP9.temporal_idx = 1; + info.codecSpecific.VP9.num_ref_pics = 2; + info.codecSpecific.VP9.p_diff[0] = 1; + info.codecSpecific.VP9.p_diff[1] = 2; + headers[3] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/7); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 0; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 2; + headers[4] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/9); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 1; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 1; + headers[5] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/11); + + ASSERT_TRUE(headers[0].generic); + int num_decode_targets = headers[0].generic->decode_target_indications.size(); + ASSERT_GE(num_decode_targets, 2); + + for (int frame_idx = 0; frame_idx < 6; ++frame_idx) { + const RTPVideoHeader& header = headers[frame_idx]; + ASSERT_TRUE(header.generic); + EXPECT_EQ(header.generic->spatial_index, 0); + EXPECT_EQ(header.generic->temporal_index, frame_idx % 2); + 
EXPECT_EQ(header.generic->frame_id, 1 + 2 * frame_idx); + ASSERT_THAT(header.generic->decode_target_indications, + SizeIs(num_decode_targets)); + // Expect only T0 frames are needed for the 1st decode target. + if (header.generic->temporal_index == 0) { + EXPECT_NE(header.generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + } else { + EXPECT_EQ(header.generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + } + // Expect all frames are needed for the 2nd decode target. + EXPECT_NE(header.generic->decode_target_indications[1], + DecodeTargetIndication::kNotPresent); + } + + // Expect switch at every beginning of the pattern. + EXPECT_THAT(headers[0].generic->decode_target_indications, + Each(DecodeTargetIndication::kSwitch)); + EXPECT_THAT(headers[4].generic->decode_target_indications, + Each(DecodeTargetIndication::kSwitch)); + + EXPECT_THAT(headers[0].generic->dependencies, IsEmpty()); // T0, 1 + EXPECT_THAT(headers[1].generic->dependencies, ElementsAre(1)); // T1, 3 + EXPECT_THAT(headers[2].generic->dependencies, ElementsAre(1)); // T0, 5 + EXPECT_THAT(headers[3].generic->dependencies, ElementsAre(5, 3)); // T1, 7 + EXPECT_THAT(headers[4].generic->dependencies, ElementsAre(5)); // T0, 9 + EXPECT_THAT(headers[5].generic->dependencies, ElementsAre(9)); // T1, 11 + + EXPECT_THAT(headers[0].generic->chain_diffs, ElementsAre(0)); + EXPECT_THAT(headers[1].generic->chain_diffs, ElementsAre(2)); + EXPECT_THAT(headers[2].generic->chain_diffs, ElementsAre(4)); + EXPECT_THAT(headers[3].generic->chain_diffs, ElementsAre(2)); + EXPECT_THAT(headers[4].generic->chain_diffs, ElementsAre(4)); + EXPECT_THAT(headers[5].generic->chain_diffs, ElementsAre(2)); +} + +TEST_F(RtpPayloadParamsVp9ToGenericTest, TemporalScalabilityWith3Layers) { + // Test with 3 temporal layers structure that is not used by webrtc, but used + // by chromium: https://imgur.com/pURAGvp + RtpPayloadParams params(/*ssrc=*/123, &state_, field_trials_); + + EncodedImage image; + CodecSpecificInfo info; + info.codecType = kVideoCodecVP9; + info.codecSpecific.VP9.num_spatial_layers = 1; + info.codecSpecific.VP9.first_frame_in_picture = true; + info.end_of_picture = true; + + RTPVideoHeader headers[9]; + // Key frame. + image._frameType = VideoFrameType::kVideoFrameKey; + info.codecSpecific.VP9.inter_pic_predicted = false; + info.codecSpecific.VP9.num_ref_pics = 0; + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 0; + headers[0] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/1); + + // Delta frames. 
+ info.codecSpecific.VP9.inter_pic_predicted = true; + image._frameType = VideoFrameType::kVideoFrameDelta; + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 2; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 1; + headers[1] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/3); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 1; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 2; + headers[2] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/5); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 2; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 1; + headers[3] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/7); + + info.codecSpecific.VP9.temporal_up_switch = false; + info.codecSpecific.VP9.temporal_idx = 0; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 4; + headers[4] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/9); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 2; + info.codecSpecific.VP9.num_ref_pics = 2; + info.codecSpecific.VP9.p_diff[0] = 1; + info.codecSpecific.VP9.p_diff[1] = 3; + headers[5] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/11); + + info.codecSpecific.VP9.temporal_up_switch = false; + info.codecSpecific.VP9.temporal_idx = 1; + info.codecSpecific.VP9.num_ref_pics = 2; + info.codecSpecific.VP9.p_diff[0] = 2; + info.codecSpecific.VP9.p_diff[1] = 4; + headers[6] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/13); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 2; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 1; + headers[7] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/15); + + info.codecSpecific.VP9.temporal_up_switch = true; + info.codecSpecific.VP9.temporal_idx = 0; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 4; + headers[8] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/17); + + ASSERT_TRUE(headers[0].generic); + int num_decode_targets = headers[0].generic->decode_target_indications.size(); + ASSERT_GE(num_decode_targets, 3); + + for (int frame_idx = 0; frame_idx < 9; ++frame_idx) { + const RTPVideoHeader& header = headers[frame_idx]; + ASSERT_TRUE(header.generic); + EXPECT_EQ(header.generic->spatial_index, 0); + EXPECT_EQ(header.generic->frame_id, 1 + 2 * frame_idx); + ASSERT_THAT(header.generic->decode_target_indications, + SizeIs(num_decode_targets)); + // Expect only T0 frames are needed for the 1st decode target. + if (header.generic->temporal_index == 0) { + EXPECT_NE(header.generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + } else { + EXPECT_EQ(header.generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + } + // Expect only T0 and T1 frames are needed for the 2nd decode target. + if (header.generic->temporal_index <= 1) { + EXPECT_NE(header.generic->decode_target_indications[1], + DecodeTargetIndication::kNotPresent); + } else { + EXPECT_EQ(header.generic->decode_target_indications[1], + DecodeTargetIndication::kNotPresent); + } + // Expect all frames are needed for the 3rd decode target. 
+ EXPECT_NE(header.generic->decode_target_indications[2], + DecodeTargetIndication::kNotPresent); + } + + EXPECT_EQ(headers[0].generic->temporal_index, 0); + EXPECT_EQ(headers[1].generic->temporal_index, 2); + EXPECT_EQ(headers[2].generic->temporal_index, 1); + EXPECT_EQ(headers[3].generic->temporal_index, 2); + EXPECT_EQ(headers[4].generic->temporal_index, 0); + EXPECT_EQ(headers[5].generic->temporal_index, 2); + EXPECT_EQ(headers[6].generic->temporal_index, 1); + EXPECT_EQ(headers[7].generic->temporal_index, 2); + EXPECT_EQ(headers[8].generic->temporal_index, 0); + + // Expect switch at every beginning of the pattern. + EXPECT_THAT(headers[0].generic->decode_target_indications, + Each(DecodeTargetIndication::kSwitch)); + EXPECT_THAT(headers[8].generic->decode_target_indications, + Each(DecodeTargetIndication::kSwitch)); + + EXPECT_THAT(headers[0].generic->dependencies, IsEmpty()); // T0, 1 + EXPECT_THAT(headers[1].generic->dependencies, ElementsAre(1)); // T2, 3 + EXPECT_THAT(headers[2].generic->dependencies, ElementsAre(1)); // T1, 5 + EXPECT_THAT(headers[3].generic->dependencies, ElementsAre(5)); // T2, 7 + EXPECT_THAT(headers[4].generic->dependencies, ElementsAre(1)); // T0, 9 + EXPECT_THAT(headers[5].generic->dependencies, ElementsAre(9, 5)); // T2, 11 + EXPECT_THAT(headers[6].generic->dependencies, ElementsAre(9, 5)); // T1, 13 + EXPECT_THAT(headers[7].generic->dependencies, ElementsAre(13)); // T2, 15 + EXPECT_THAT(headers[8].generic->dependencies, ElementsAre(9)); // T0, 17 + + EXPECT_THAT(headers[0].generic->chain_diffs, ElementsAre(0)); + EXPECT_THAT(headers[1].generic->chain_diffs, ElementsAre(2)); + EXPECT_THAT(headers[2].generic->chain_diffs, ElementsAre(4)); + EXPECT_THAT(headers[3].generic->chain_diffs, ElementsAre(6)); + EXPECT_THAT(headers[4].generic->chain_diffs, ElementsAre(8)); + EXPECT_THAT(headers[5].generic->chain_diffs, ElementsAre(2)); + EXPECT_THAT(headers[6].generic->chain_diffs, ElementsAre(4)); + EXPECT_THAT(headers[7].generic->chain_diffs, ElementsAre(6)); + EXPECT_THAT(headers[8].generic->chain_diffs, ElementsAre(8)); +} + +TEST_F(RtpPayloadParamsVp9ToGenericTest, SpatialScalabilityKSvc) { + // 1---3-- + // | ... + // 0---2-- + RtpPayloadParams params(/*ssrc=*/123, &state_, field_trials_); + + EncodedImage image; + CodecSpecificInfo info; + info.codecType = kVideoCodecVP9; + info.codecSpecific.VP9.num_spatial_layers = 2; + info.codecSpecific.VP9.first_frame_in_picture = true; + + RTPVideoHeader headers[4]; + // Key frame. + image._frameType = VideoFrameType::kVideoFrameKey; + image.SetSpatialIndex(0); + info.codecSpecific.VP9.inter_pic_predicted = false; + info.codecSpecific.VP9.inter_layer_predicted = false; + info.codecSpecific.VP9.non_ref_for_inter_layer_pred = false; + info.codecSpecific.VP9.num_ref_pics = 0; + info.codecSpecific.VP9.first_frame_in_picture = true; + info.end_of_picture = false; + headers[0] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/1); + + image.SetSpatialIndex(1); + info.codecSpecific.VP9.inter_layer_predicted = true; + info.codecSpecific.VP9.non_ref_for_inter_layer_pred = true; + info.codecSpecific.VP9.first_frame_in_picture = false; + info.end_of_picture = true; + headers[1] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/3); + + // Delta frames. 
+ info.codecSpecific.VP9.inter_pic_predicted = true; + image._frameType = VideoFrameType::kVideoFrameDelta; + info.codecSpecific.VP9.num_ref_pics = 1; + info.codecSpecific.VP9.p_diff[0] = 1; + + image.SetSpatialIndex(0); + info.codecSpecific.VP9.inter_layer_predicted = false; + info.codecSpecific.VP9.non_ref_for_inter_layer_pred = true; + info.codecSpecific.VP9.first_frame_in_picture = true; + info.end_of_picture = false; + headers[2] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/5); + + image.SetSpatialIndex(1); + info.codecSpecific.VP9.inter_layer_predicted = false; + info.codecSpecific.VP9.non_ref_for_inter_layer_pred = true; + info.codecSpecific.VP9.first_frame_in_picture = false; + info.end_of_picture = true; + headers[3] = params.GetRtpVideoHeader(image, &info, /*shared_frame_id=*/7); + + ASSERT_TRUE(headers[0].generic); + int num_decode_targets = headers[0].generic->decode_target_indications.size(); + // Rely on implementation detail there are always kMaxTemporalStreams temporal + // layers assumed, in particular assume Decode Target#0 matches layer S0T0, + // and Decode Target#kMaxTemporalStreams matches layer S1T0. + ASSERT_EQ(num_decode_targets, kMaxTemporalStreams * 2); + + for (int frame_idx = 0; frame_idx < 4; ++frame_idx) { + const RTPVideoHeader& header = headers[frame_idx]; + ASSERT_TRUE(header.generic); + EXPECT_EQ(header.generic->spatial_index, frame_idx % 2); + EXPECT_EQ(header.generic->temporal_index, 0); + EXPECT_EQ(header.generic->frame_id, 1 + 2 * frame_idx); + ASSERT_THAT(header.generic->decode_target_indications, + SizeIs(num_decode_targets)); + } + + // Expect S0 key frame is switch for both Decode Targets. + EXPECT_EQ(headers[0].generic->decode_target_indications[0], + DecodeTargetIndication::kSwitch); + EXPECT_EQ(headers[0].generic->decode_target_indications[kMaxTemporalStreams], + DecodeTargetIndication::kSwitch); + // S1 key frame is only needed for the 2nd Decode Targets. + EXPECT_EQ(headers[1].generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + EXPECT_NE(headers[1].generic->decode_target_indications[kMaxTemporalStreams], + DecodeTargetIndication::kNotPresent); + // Delta frames are only needed for their own Decode Targets. 
+ EXPECT_NE(headers[2].generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + EXPECT_EQ(headers[2].generic->decode_target_indications[kMaxTemporalStreams], + DecodeTargetIndication::kNotPresent); + EXPECT_EQ(headers[3].generic->decode_target_indications[0], + DecodeTargetIndication::kNotPresent); + EXPECT_NE(headers[3].generic->decode_target_indications[kMaxTemporalStreams], + DecodeTargetIndication::kNotPresent); + + EXPECT_THAT(headers[0].generic->dependencies, IsEmpty()); // S0, 1 + EXPECT_THAT(headers[1].generic->dependencies, ElementsAre(1)); // S1, 3 + EXPECT_THAT(headers[2].generic->dependencies, ElementsAre(1)); // S0, 5 + EXPECT_THAT(headers[3].generic->dependencies, ElementsAre(3)); // S1, 7 + + EXPECT_THAT(headers[0].generic->chain_diffs, ElementsAre(0, 0)); + EXPECT_THAT(headers[1].generic->chain_diffs, ElementsAre(2, 2)); + EXPECT_THAT(headers[2].generic->chain_diffs, ElementsAre(4, 2)); + EXPECT_THAT(headers[3].generic->chain_diffs, ElementsAre(2, 4)); +} + class RtpPayloadParamsH264ToGenericTest : public ::testing::Test { public: enum LayerSync { kNoSync, kSync }; diff --git a/call/rtp_stream_receiver_controller.cc b/call/rtp_stream_receiver_controller.cc index f440b426d6..7150b34bdb 100644 --- a/call/rtp_stream_receiver_controller.cc +++ b/call/rtp_stream_receiver_controller.cc @@ -37,11 +37,7 @@ RtpStreamReceiverController::Receiver::~Receiver() { controller_->RemoveSink(sink_); } -RtpStreamReceiverController::RtpStreamReceiverController() { - // At this level the demuxer is only configured to demux by SSRC, so don't - // worry about MIDs (MIDs are handled by upper layers). - demuxer_.set_use_mid(false); -} +RtpStreamReceiverController::RtpStreamReceiverController() {} RtpStreamReceiverController::~RtpStreamReceiverController() = default; @@ -52,19 +48,19 @@ RtpStreamReceiverController::CreateReceiver(uint32_t ssrc, } bool RtpStreamReceiverController::OnRtpPacket(const RtpPacketReceived& packet) { - rtc::CritScope cs(&lock_); + RTC_DCHECK_RUN_ON(&demuxer_sequence_); return demuxer_.OnRtpPacket(packet); } bool RtpStreamReceiverController::AddSink(uint32_t ssrc, RtpPacketSinkInterface* sink) { - rtc::CritScope cs(&lock_); + RTC_DCHECK_RUN_ON(&demuxer_sequence_); return demuxer_.AddSink(ssrc, sink); } size_t RtpStreamReceiverController::RemoveSink( const RtpPacketSinkInterface* sink) { - rtc::CritScope cs(&lock_); + RTC_DCHECK_RUN_ON(&demuxer_sequence_); return demuxer_.RemoveSink(sink); } diff --git a/call/rtp_stream_receiver_controller.h b/call/rtp_stream_receiver_controller.h index 62447aa521..284c9fa12f 100644 --- a/call/rtp_stream_receiver_controller.h +++ b/call/rtp_stream_receiver_controller.h @@ -12,9 +12,9 @@ #include +#include "api/sequence_checker.h" #include "call/rtp_demuxer.h" #include "call/rtp_stream_receiver_controller_interface.h" -#include "rtc_base/deprecated/recursive_critical_section.h" namespace webrtc { @@ -58,13 +58,18 @@ class RtpStreamReceiverController RtpPacketSinkInterface* const sink_; }; - // TODO(nisse): Move to a TaskQueue for synchronization. When used - // by Call, we expect construction and all methods but OnRtpPacket - // to be called on the same thread, and OnRtpPacket to be called - // by a single, but possibly distinct, thread. But applications not - // using Call may have use threads differently. - rtc::RecursiveCriticalSection lock_; - RtpDemuxer demuxer_ RTC_GUARDED_BY(&lock_); + // TODO(bugs.webrtc.org/11993): We expect construction and all methods to be + // called on the same thread/tq. 
Currently this is the worker thread + // (including OnRtpPacket) but a more natural fit would be the network thread. + // Using a sequence checker to ensure that usage is correct but at the same + // time not require a specific thread/tq, an instance of this class + the + // associated functionality should be easily moved from one execution context + // to another (i.e. when network packets don't hop to the worker thread inside + // of Call). + SequenceChecker demuxer_sequence_; + // At this level the demuxer is only configured to demux by SSRC, so don't + // worry about MIDs (MIDs are handled by upper layers). + RtpDemuxer demuxer_ RTC_GUARDED_BY(&demuxer_sequence_){false /*use_mid*/}; }; } // namespace webrtc diff --git a/call/rtp_transport_config.h b/call/rtp_transport_config.h new file mode 100644 index 0000000000..9aa9f14c16 --- /dev/null +++ b/call/rtp_transport_config.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef CALL_RTP_TRANSPORT_CONFIG_H_ +#define CALL_RTP_TRANSPORT_CONFIG_H_ + +#include + +#include "api/network_state_predictor.h" +#include "api/rtc_event_log/rtc_event_log.h" +#include "api/transport/bitrate_settings.h" +#include "api/transport/network_control.h" +#include "api/transport/webrtc_key_value_config.h" +#include "modules/utility/include/process_thread.h" +#include "rtc_base/task_queue.h" + +namespace webrtc { + +struct RtpTransportConfig { + // Bitrate config used until valid bitrate estimates are calculated. Also + // used to cap total bitrate used. This comes from the remote connection. + BitrateConstraints bitrate_config; + + // RtcEventLog to use for this call. Required. + // Use webrtc::RtcEventLog::CreateNull() for a null implementation. + RtcEventLog* event_log = nullptr; + + // Task Queue Factory to be used in this call. Required. + TaskQueueFactory* task_queue_factory = nullptr; + + // NetworkStatePredictor to use for this call. + NetworkStatePredictorFactoryInterface* network_state_predictor_factory = + nullptr; + + // Network controller factory to use for this call. + NetworkControllerFactoryInterface* network_controller_factory = nullptr; + + // Key-value mapping of internal configurations to apply, + // e.g. field trials. 
+ const WebRtcKeyValueConfig* trials = nullptr; +}; +} // namespace webrtc + +#endif // CALL_RTP_TRANSPORT_CONFIG_H_ diff --git a/call/rtp_transport_controller_send.cc b/call/rtp_transport_controller_send.cc index f5adae68ae..f7b6b11fd7 100644 --- a/call/rtp_transport_controller_send.cc +++ b/call/rtp_transport_controller_send.cc @@ -87,7 +87,7 @@ RtpTransportControllerSend::RtpTransportControllerSend( : clock_(clock), event_log_(event_log), bitrate_configurator_(bitrate_config), - process_thread_started_(false), + pacer_started_(false), process_thread_(std::move(process_thread)), use_task_queue_pacer_(IsEnabled(trials, "WebRTC-TaskQueuePacer")), process_thread_pacer_(use_task_queue_pacer_ @@ -142,6 +142,7 @@ RtpTransportControllerSend::RtpTransportControllerSend( } RtpTransportControllerSend::~RtpTransportControllerSend() { + RTC_DCHECK(video_rtp_senders_.empty()); process_thread_->Stop(); } @@ -156,6 +157,7 @@ RtpVideoSenderInterface* RtpTransportControllerSend::CreateRtpVideoSender( std::unique_ptr fec_controller, const RtpSenderFrameEncryptionConfig& frame_encryption_config, rtc::scoped_refptr frame_transformer) { + RTC_DCHECK_RUN_ON(&main_thread_); video_rtp_senders_.push_back(std::make_unique( clock_, suspended_ssrcs, states, rtp_config, rtcp_report_interval_ms, send_transport, observers, @@ -169,6 +171,7 @@ RtpVideoSenderInterface* RtpTransportControllerSend::CreateRtpVideoSender( void RtpTransportControllerSend::DestroyRtpVideoSender( RtpVideoSenderInterface* rtp_video_sender) { + RTC_DCHECK_RUN_ON(&main_thread_); std::vector>::iterator it = video_rtp_senders_.end(); for (it = video_rtp_senders_.begin(); it != video_rtp_senders_.end(); ++it) { @@ -354,6 +357,7 @@ void RtpTransportControllerSend::OnNetworkRouteChanged( } } void RtpTransportControllerSend::OnNetworkAvailability(bool network_available) { + RTC_DCHECK_RUN_ON(&main_thread_); RTC_LOG(LS_VERBOSE) << "SignalNetworkState " << (network_available ? "Up" : "Down"); NetworkAvailability msg; @@ -470,6 +474,7 @@ RtpTransportControllerSend::ApplyOrLiftRelayCap(bool is_relayed) { void RtpTransportControllerSend::OnTransportOverheadChanged( size_t transport_overhead_bytes_per_packet) { + RTC_DCHECK_RUN_ON(&main_thread_); if (transport_overhead_bytes_per_packet >= kMaxOverheadBytes) { RTC_LOG(LS_ERROR) << "Transport overhead exceeds " << kMaxOverheadBytes; return; @@ -496,9 +501,13 @@ void RtpTransportControllerSend::IncludeOverheadInPacedSender() { } void RtpTransportControllerSend::EnsureStarted() { - if (!use_task_queue_pacer_ && !process_thread_started_) { - process_thread_started_ = true; - process_thread_->Start(); + if (!pacer_started_) { + pacer_started_ = true; + if (use_task_queue_pacer_) { + task_queue_pacer_->EnsureStarted(); + } else { + process_thread_->Start(); + } } } diff --git a/call/rtp_transport_controller_send.h b/call/rtp_transport_controller_send.h index 7025b03312..7455060945 100644 --- a/call/rtp_transport_controller_send.h +++ b/call/rtp_transport_controller_send.h @@ -18,6 +18,7 @@ #include #include "api/network_state_predictor.h" +#include "api/sequence_checker.h" #include "api/transport/network_control.h" #include "api/units/data_rate.h" #include "call/rtp_bitrate_configurator.h" @@ -62,6 +63,7 @@ class RtpTransportControllerSend final const WebRtcKeyValueConfig* trials); ~RtpTransportControllerSend() override; + // TODO(tommi): Change to std::unique_ptr<>. 
RtpVideoSenderInterface* CreateRtpVideoSender( std::map suspended_ssrcs, const std::map& @@ -148,11 +150,13 @@ class RtpTransportControllerSend final Clock* const clock_; RtcEventLog* const event_log_; + SequenceChecker main_thread_; PacketRouter packet_router_; - std::vector> video_rtp_senders_; + std::vector> video_rtp_senders_ + RTC_GUARDED_BY(&main_thread_); RtpBitrateConfigurator bitrate_configurator_; std::map network_routes_; - bool process_thread_started_; + bool pacer_started_; const std::unique_ptr process_thread_; const bool use_task_queue_pacer_; std::unique_ptr process_thread_pacer_; diff --git a/call/rtp_transport_controller_send_factory.h b/call/rtp_transport_controller_send_factory.h new file mode 100644 index 0000000000..a857ca7e6f --- /dev/null +++ b/call/rtp_transport_controller_send_factory.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef CALL_RTP_TRANSPORT_CONTROLLER_SEND_FACTORY_H_ +#define CALL_RTP_TRANSPORT_CONTROLLER_SEND_FACTORY_H_ + +#include +#include + +#include "call/rtp_transport_controller_send.h" +#include "call/rtp_transport_controller_send_factory_interface.h" + +namespace webrtc { +class RtpTransportControllerSendFactory + : public RtpTransportControllerSendFactoryInterface { + public: + std::unique_ptr Create( + const RtpTransportConfig& config, + Clock* clock, + std::unique_ptr process_thread) override { + return std::make_unique( + clock, config.event_log, config.network_state_predictor_factory, + config.network_controller_factory, config.bitrate_config, + std::move(process_thread), config.task_queue_factory, config.trials); + } + + virtual ~RtpTransportControllerSendFactory() {} +}; +} // namespace webrtc +#endif // CALL_RTP_TRANSPORT_CONTROLLER_SEND_FACTORY_H_ diff --git a/call/rtp_transport_controller_send_factory_interface.h b/call/rtp_transport_controller_send_factory_interface.h new file mode 100644 index 0000000000..a0218532a1 --- /dev/null +++ b/call/rtp_transport_controller_send_factory_interface.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef CALL_RTP_TRANSPORT_CONTROLLER_SEND_FACTORY_INTERFACE_H_ +#define CALL_RTP_TRANSPORT_CONTROLLER_SEND_FACTORY_INTERFACE_H_ + +#include + +#include "call/rtp_transport_config.h" +#include "call/rtp_transport_controller_send_interface.h" +#include "modules/utility/include/process_thread.h" + +namespace webrtc { +// A factory used for dependency injection on the send side of the transport +// controller. 
+class RtpTransportControllerSendFactoryInterface { + public: + virtual std::unique_ptr Create( + const RtpTransportConfig& config, + Clock* clock, + std::unique_ptr process_thread) = 0; + + virtual ~RtpTransportControllerSendFactoryInterface() {} +}; +} // namespace webrtc +#endif // CALL_RTP_TRANSPORT_CONTROLLER_SEND_FACTORY_INTERFACE_H_ diff --git a/call/rtp_transport_controller_send_interface.h b/call/rtp_transport_controller_send_interface.h index 605ebfbd3e..2aa6d739da 100644 --- a/call/rtp_transport_controller_send_interface.h +++ b/call/rtp_transport_controller_send_interface.h @@ -52,7 +52,6 @@ struct RtpSenderObservers { RtcpRttStats* rtcp_rtt_stats; RtcpIntraFrameObserver* intra_frame_callback; RtcpLossNotificationObserver* rtcp_loss_notification_observer; - RtcpStatisticsCallback* rtcp_stats; ReportBlockDataObserver* report_block_data_observer; StreamDataCountersCallback* rtp_stats; BitrateStatisticsObserver* bitrate_observer; @@ -135,7 +134,13 @@ class RtpTransportControllerSendInterface { virtual int64_t GetPacerQueuingDelayMs() const = 0; virtual absl::optional GetFirstPacketTime() const = 0; virtual void EnablePeriodicAlrProbing(bool enable) = 0; + + // Called when a packet has been sent. + // The call should arrive on the network thread, but may not in all cases + // (some tests don't adhere to this). Implementations today should not block + // the calling thread or make assumptions about the thread context. virtual void OnSentPacket(const rtc::SentPacket& sent_packet) = 0; + virtual void OnReceivedPacket(const ReceivedPacket& received_packet) = 0; virtual void SetSdpBitrateParameters( diff --git a/call/rtp_video_sender.cc b/call/rtp_video_sender.cc index e8d5db9e46..7fad89b20b 100644 --- a/call/rtp_video_sender.cc +++ b/call/rtp_video_sender.cc @@ -31,6 +31,7 @@ #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/task_queue.h" +#include "rtc_base/trace_event.h" namespace webrtc { @@ -216,7 +217,6 @@ std::vector CreateRtpStreamSenders( configuration.rtt_stats = observers.rtcp_rtt_stats; configuration.rtcp_packet_type_counter_observer = observers.rtcp_type_observer; - configuration.rtcp_statistics_callback = observers.rtcp_stats; configuration.report_block_data_observer = observers.report_block_data_observer; configuration.paced_sender = transport->packet_sender(); @@ -314,14 +314,24 @@ bool IsFirstFrameOfACodedVideoSequence( return false; } - if (codec_specific_info != nullptr && - codec_specific_info->generic_frame_info.has_value()) { - // This function is used before - // `codec_specific_info->generic_frame_info->frame_diffs` are calculated, so - // need to use more complicated way to check for presence of dependencies. - return absl::c_none_of( - codec_specific_info->generic_frame_info->encoder_buffers, - [](const CodecBufferUsage& buffer) { return buffer.referenced; }); + if (codec_specific_info != nullptr) { + if (codec_specific_info->generic_frame_info.has_value()) { + // This function is used before + // `codec_specific_info->generic_frame_info->frame_diffs` are calculated, + // so need to use a more complicated way to check for presence of the + // dependencies. 
+ return absl::c_none_of( + codec_specific_info->generic_frame_info->encoder_buffers, + [](const CodecBufferUsage& buffer) { return buffer.referenced; }); + } + + if (codec_specific_info->codecType == VideoCodecType::kVideoCodecVP8 || + codec_specific_info->codecType == VideoCodecType::kVideoCodecH264 || + codec_specific_info->codecType == VideoCodecType::kVideoCodecGeneric) { + // These codecs do not support intra picture dependencies, so a frame + // marked as a key frame should be a key frame. + return true; + } } // Without depenedencies described in generic format do an educated guess. @@ -357,8 +367,10 @@ RtpVideoSender::RtpVideoSender( field_trials_.Lookup("WebRTC-Video-UseFrameRateForOverhead"), "Enabled")), has_packet_feedback_(TransportSeqNumExtensionConfigured(rtp_config)), + simulate_vp9_structure_(absl::StartsWith( + field_trials_.Lookup("WebRTC-Vp9DependencyDescriptor"), + "Enabled")), active_(false), - module_process_thread_(nullptr), suspended_ssrcs_(std::move(suspended_ssrcs)), fec_controller_(std::move(fec_controller)), fec_allowed_(true), @@ -386,7 +398,6 @@ RtpVideoSender::RtpVideoSender( RTC_DCHECK_EQ(rtp_config_.ssrcs.size(), rtp_streams_.size()); if (send_side_bwe_with_overhead_ && has_packet_feedback_) transport_->IncludeOverheadInPacedSender(); - module_process_thread_checker_.Detach(); // SSRCs are assumed to be sorted in the same order as |rtp_modules|. for (uint32_t ssrc : rtp_config_.ssrcs) { // Restore state if it previously existed. @@ -401,18 +412,6 @@ RtpVideoSender::RtpVideoSender( // RTP/RTCP initialization. - // We add the highest spatial layer first to ensure it'll be prioritized - // when sending padding, with the hope that the packet rate will be smaller, - // and that it's more important to protect than the lower layers. - - // TODO(nisse): Consider moving registration with PacketRouter last, after the - // modules are fully configured. 
- for (const RtpStreamSender& stream : rtp_streams_) { - constexpr bool remb_candidate = true; - transport->packet_router()->AddSendRtpModule(stream.rtp_rtcp.get(), - remb_candidate); - } - for (size_t i = 0; i < rtp_config_.extensions.size(); ++i) { const std::string& extension = rtp_config_.extensions[i].uri; int id = rtp_config_.extensions[i].id; @@ -453,31 +452,12 @@ RtpVideoSender::RtpVideoSender( } RtpVideoSender::~RtpVideoSender() { - for (const RtpStreamSender& stream : rtp_streams_) { - transport_->packet_router()->RemoveSendRtpModule(stream.rtp_rtcp.get()); - } + SetActiveModulesLocked( + std::vector(rtp_streams_.size(), /*active=*/false)); transport_->GetStreamFeedbackProvider()->DeRegisterStreamFeedbackObserver( this); } -void RtpVideoSender::RegisterProcessThread( - ProcessThread* module_process_thread) { - RTC_DCHECK_RUN_ON(&module_process_thread_checker_); - RTC_DCHECK(!module_process_thread_); - module_process_thread_ = module_process_thread; - - for (const RtpStreamSender& stream : rtp_streams_) { - module_process_thread_->RegisterModule(stream.rtp_rtcp.get(), - RTC_FROM_HERE); - } -} - -void RtpVideoSender::DeRegisterProcessThread() { - RTC_DCHECK_RUN_ON(&module_process_thread_checker_); - for (const RtpStreamSender& stream : rtp_streams_) - module_process_thread_->DeRegisterModule(stream.rtp_rtcp.get()); -} - void RtpVideoSender::SetActive(bool active) { MutexLock lock(&mutex_); if (active_ == active) @@ -499,10 +479,29 @@ void RtpVideoSender::SetActiveModulesLocked( if (active_modules[i]) { active_ = true; } + + RtpRtcpInterface& rtp_module = *rtp_streams_[i].rtp_rtcp; + const bool was_active = rtp_module.SendingMedia(); + const bool should_be_active = active_modules[i]; + // Sends a kRtcpByeCode when going from true to false. - rtp_streams_[i].rtp_rtcp->SetSendingStatus(active_modules[i]); + rtp_module.SetSendingStatus(active_modules[i]); + + if (was_active && !should_be_active) { + // Disabling media, remove from packet router map to reduce size and + // prevent any stray packets in the pacer from asynchronously arriving + // to a disabled module. + transport_->packet_router()->RemoveSendRtpModule(&rtp_module); + } + // If set to false this module won't send media. - rtp_streams_[i].rtp_rtcp->SetSendingMediaStatus(active_modules[i]); + rtp_module.SetSendingMediaStatus(active_modules[i]); + + if (!was_active && should_be_active) { + // Turning on media, register with packet router. + transport_->packet_router()->AddSendRtpModule(&rtp_module, + /*remb_candidate=*/true); + } } } @@ -562,10 +561,18 @@ EncodedImageCallback::Result RtpVideoSender::OnEncodedImage( // If encoder adapter produce FrameDependencyStructure, pass it so that // dependency descriptor rtp header extension can be used. // If not supported, disable using dependency descriptor by passing nullptr. - rtp_streams_[stream_index].sender_video->SetVideoStructure( - (codec_specific_info && codec_specific_info->template_structure) - ? 
&*codec_specific_info->template_structure - : nullptr); + RTPSenderVideo& sender_video = *rtp_streams_[stream_index].sender_video; + if (codec_specific_info && codec_specific_info->template_structure) { + sender_video.SetVideoStructure(&*codec_specific_info->template_structure); + } else if (simulate_vp9_structure_ && codec_specific_info && + codec_specific_info->codecType == kVideoCodecVP9) { + FrameDependencyStructure structure = + RtpPayloadParams::MinimalisticVp9Structure( + codec_specific_info->codecSpecific.VP9); + sender_video.SetVideoStructure(&structure); + } else { + sender_video.SetVideoStructure(nullptr); + } } bool send_result = rtp_streams_[stream_index].sender_video->SendEncodedImage( @@ -904,43 +911,45 @@ void RtpVideoSender::OnPacketFeedbackVector( // Map from SSRC to all acked packets for that RTP module. std::map> acked_packets_per_ssrc; for (const StreamPacketInfo& packet : packet_feedback_vector) { - if (packet.received) { - acked_packets_per_ssrc[packet.ssrc].push_back(packet.rtp_sequence_number); + if (packet.received && packet.ssrc) { + acked_packets_per_ssrc[*packet.ssrc].push_back( + packet.rtp_sequence_number); } } - // Map from SSRC to vector of RTP sequence numbers that are indicated as - // lost by feedback, without being trailed by any received packets. - std::map> early_loss_detected_per_ssrc; + // Map from SSRC to vector of RTP sequence numbers that are indicated as + // lost by feedback, without being trailed by any received packets. + std::map> early_loss_detected_per_ssrc; - for (const StreamPacketInfo& packet : packet_feedback_vector) { - if (!packet.received) { - // Last known lost packet, might not be detectable as lost by remote - // jitter buffer. - early_loss_detected_per_ssrc[packet.ssrc].push_back( - packet.rtp_sequence_number); - } else { - // Packet received, so any loss prior to this is already detectable. - early_loss_detected_per_ssrc.erase(packet.ssrc); - } + for (const StreamPacketInfo& packet : packet_feedback_vector) { + // Only include new media packets, not retransmissions/padding/fec. + if (!packet.received && packet.ssrc && !packet.is_retransmission) { + // Last known lost packet, might not be detectable as lost by remote + // jitter buffer. + early_loss_detected_per_ssrc[*packet.ssrc].push_back( + packet.rtp_sequence_number); + } else { + // Packet received, so any loss prior to this is already detectable. + early_loss_detected_per_ssrc.erase(*packet.ssrc); } + } - for (const auto& kv : early_loss_detected_per_ssrc) { - const uint32_t ssrc = kv.first; - auto it = ssrc_to_rtp_module_.find(ssrc); - RTC_DCHECK(it != ssrc_to_rtp_module_.end()); - RTPSender* rtp_sender = it->second->RtpSender(); - for (uint16_t sequence_number : kv.second) { - rtp_sender->ReSendPacket(sequence_number); - } + for (const auto& kv : early_loss_detected_per_ssrc) { + const uint32_t ssrc = kv.first; + auto it = ssrc_to_rtp_module_.find(ssrc); + RTC_CHECK(it != ssrc_to_rtp_module_.end()); + RTPSender* rtp_sender = it->second->RtpSender(); + for (uint16_t sequence_number : kv.second) { + rtp_sender->ReSendPacket(sequence_number); } + } for (const auto& kv : acked_packets_per_ssrc) { const uint32_t ssrc = kv.first; auto it = ssrc_to_rtp_module_.find(ssrc); if (it == ssrc_to_rtp_module_.end()) { - // Packets not for a media SSRC, so likely RTX or FEC. If so, ignore - // since there's no RTP history to clean up anyway. + // No media, likely FEC or padding. Ignore since there's no RTP history to + // clean up anyway. 
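To make the reworked feedback handling in this hunk easier to follow, here is a self-contained sketch of the same partitioning: acked sequence numbers are grouped per SSRC, only lost, non-retransmitted packets count as "early loss", and any received packet for that SSRC clears the pending early-loss list. PacketInfo and Summarize are simplified stand-ins for StreamFeedbackObserver::StreamPacketInfo and the logic above, not the production code.

    #include <cstdint>
    #include <map>
    #include <optional>
    #include <vector>

    // Simplified stand-in for StreamFeedbackObserver::StreamPacketInfo.
    struct PacketInfo {
      bool received = false;
      bool is_retransmission = false;
      std::optional<uint32_t> ssrc;
      uint16_t rtp_sequence_number = 0;
    };

    struct FeedbackSummary {
      std::map<uint32_t, std::vector<uint16_t>> acked_per_ssrc;
      std::map<uint32_t, std::vector<uint16_t>> early_loss_per_ssrc;
    };

    FeedbackSummary Summarize(const std::vector<PacketInfo>& feedback) {
      FeedbackSummary out;
      for (const PacketInfo& p : feedback) {
        if (!p.ssrc)
          continue;  // Padding/FEC-only packets carry no media SSRC.
        if (p.received) {
          out.acked_per_ssrc[*p.ssrc].push_back(p.rtp_sequence_number);
          // A later received packet makes earlier losses detectable by the
          // remote jitter buffer, so they no longer need an early retransmit.
          out.early_loss_per_ssrc.erase(*p.ssrc);
        } else if (!p.is_retransmission) {
          out.early_loss_per_ssrc[*p.ssrc].push_back(p.rtp_sequence_number);
        }
      }
      return out;
    }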
continue; } rtc::ArrayView rtp_sequence_numbers(kv.second); diff --git a/call/rtp_video_sender.h b/call/rtp_video_sender.h index a8fb0ab59c..991276fe79 100644 --- a/call/rtp_video_sender.h +++ b/call/rtp_video_sender.h @@ -22,6 +22,7 @@ #include "api/fec_controller.h" #include "api/fec_controller_override.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/sequence_checker.h" #include "api/transport/field_trial_based_config.h" #include "api/video_codecs/video_encoder.h" #include "call/rtp_config.h" @@ -34,12 +35,10 @@ #include "modules/rtp_rtcp/source/rtp_sender_video.h" #include "modules/rtp_rtcp/source/rtp_sequence_number_map.h" #include "modules/rtp_rtcp/source/rtp_video_header.h" -#include "modules/utility/include/process_thread.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/rate_limiter.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -90,15 +89,6 @@ class RtpVideoSender : public RtpVideoSenderInterface, rtc::scoped_refptr frame_transformer); ~RtpVideoSender() override; - // RegisterProcessThread register |module_process_thread| with those objects - // that use it. Registration has to happen on the thread were - // |module_process_thread| was created (libjingle's worker thread). - // TODO(perkj): Replace the use of |module_process_thread| with a TaskQueue, - // maybe |worker_queue|. - void RegisterProcessThread(ProcessThread* module_process_thread) - RTC_LOCKS_EXCLUDED(mutex_) override; - void DeRegisterProcessThread() RTC_LOCKS_EXCLUDED(mutex_) override; - // RtpVideoSender will only route packets if being active, all packets will be // dropped otherwise. void SetActive(bool active) RTC_LOCKS_EXCLUDED(mutex_) override; @@ -178,14 +168,13 @@ class RtpVideoSender : public RtpVideoSenderInterface, const bool send_side_bwe_with_overhead_; const bool use_frame_rate_for_overhead_; const bool has_packet_feedback_; + const bool simulate_vp9_structure_; // TODO(holmer): Remove mutex_ once RtpVideoSender runs on the // transport task queue. mutable Mutex mutex_; bool active_ RTC_GUARDED_BY(mutex_); - ProcessThread* module_process_thread_; - rtc::ThreadChecker module_process_thread_checker_; std::map suspended_ssrcs_; const std::unique_ptr fec_controller_; diff --git a/call/rtp_video_sender_interface.h b/call/rtp_video_sender_interface.h index 632c9e835a..a0b4baccb4 100644 --- a/call/rtp_video_sender_interface.h +++ b/call/rtp_video_sender_interface.h @@ -22,7 +22,6 @@ #include "call/rtp_config.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/source/rtp_sequence_number_map.h" -#include "modules/utility/include/process_thread.h" #include "modules/video_coding/include/video_codec_interface.h" namespace webrtc { @@ -32,9 +31,6 @@ struct FecProtectionParams; class RtpVideoSenderInterface : public EncodedImageCallback, public FecControllerOverride { public: - virtual void RegisterProcessThread(ProcessThread* module_process_thread) = 0; - virtual void DeRegisterProcessThread() = 0; - // RtpVideoSender will only route packets if being active, all // packets will be dropped otherwise. 
virtual void SetActive(bool active) = 0; diff --git a/call/rtp_video_sender_unittest.cc b/call/rtp_video_sender_unittest.cc index b738c21447..334d97ccfa 100644 --- a/call/rtp_video_sender_unittest.cc +++ b/call/rtp_video_sender_unittest.cc @@ -61,7 +61,6 @@ class MockRtcpIntraFrameObserver : public RtcpIntraFrameObserver { RtpSenderObservers CreateObservers( RtcpRttStats* rtcp_rtt_stats, RtcpIntraFrameObserver* intra_frame_callback, - RtcpStatisticsCallback* rtcp_stats, ReportBlockDataObserver* report_block_data_observer, StreamDataCountersCallback* rtp_stats, BitrateStatisticsObserver* bitrate_observer, @@ -73,7 +72,6 @@ RtpSenderObservers CreateObservers( observers.rtcp_rtt_stats = rtcp_rtt_stats; observers.intra_frame_callback = intra_frame_callback; observers.rtcp_loss_notification_observer = nullptr; - observers.rtcp_stats = rtcp_stats; observers.report_block_data_observer = report_block_data_observer; observers.rtp_stats = rtp_stats; observers.bitrate_observer = bitrate_observer; @@ -107,6 +105,7 @@ VideoSendStream::Config CreateVideoSendStreamConfig( kTransportsSequenceExtensionId); config.rtp.extensions.emplace_back(RtpDependencyDescriptorExtension::kUri, kDependencyDescriptorExtensionId); + config.rtp.extmap_allow_mixed = true; return config; } @@ -146,9 +145,8 @@ class RtpVideoSenderTestFixture { time_controller_.GetClock(), suspended_ssrcs, suspended_payload_states, config_.rtp, config_.rtcp_report_interval_ms, &transport_, CreateObservers(nullptr, &encoder_feedback_, &stats_proxy_, - &stats_proxy_, &stats_proxy_, &stats_proxy_, - frame_count_observer, &stats_proxy_, &stats_proxy_, - &send_delay_stats_), + &stats_proxy_, &stats_proxy_, frame_count_observer, + &stats_proxy_, &stats_proxy_, &send_delay_stats_), &transport_controller_, &event_log_, &retransmission_rate_limiter_, std::make_unique(time_controller_.GetClock()), nullptr, CryptoOptions{}, frame_transformer); @@ -464,11 +462,13 @@ TEST(RtpVideoSenderTest, DoesNotRetrasmitAckedPackets) { lost_packet_feedback.rtp_sequence_number = rtp_sequence_numbers[0]; lost_packet_feedback.ssrc = kSsrc1; lost_packet_feedback.received = false; + lost_packet_feedback.is_retransmission = false; StreamFeedbackObserver::StreamPacketInfo received_packet_feedback; received_packet_feedback.rtp_sequence_number = rtp_sequence_numbers[1]; received_packet_feedback.ssrc = kSsrc1; received_packet_feedback.received = true; + lost_packet_feedback.is_retransmission = false; test.router()->OnPacketFeedbackVector( {lost_packet_feedback, received_packet_feedback}); @@ -640,11 +640,13 @@ TEST(RtpVideoSenderTest, EarlyRetransmits) { first_packet_feedback.rtp_sequence_number = frame1_rtp_sequence_number; first_packet_feedback.ssrc = kSsrc1; first_packet_feedback.received = false; + first_packet_feedback.is_retransmission = false; StreamFeedbackObserver::StreamPacketInfo second_packet_feedback; second_packet_feedback.rtp_sequence_number = frame2_rtp_sequence_number; second_packet_feedback.ssrc = kSsrc2; second_packet_feedback.received = true; + first_packet_feedback.is_retransmission = false; test.router()->OnPacketFeedbackVector( {first_packet_feedback, second_packet_feedback}); @@ -768,6 +770,62 @@ TEST(RtpVideoSenderTest, SupportsDependencyDescriptorForVp9) { EXPECT_TRUE(sent_packets[1].HasExtension()); } +TEST(RtpVideoSenderTest, + SupportsDependencyDescriptorForVp9NotProvidedByEncoder) { + test::ScopedFieldTrials field_trials( + "WebRTC-Vp9DependencyDescriptor/Enabled/"); + RtpVideoSenderTestFixture test({kSsrc1}, {}, kPayloadType, {}); + 
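The new test above enables the trial via test::ScopedFieldTrials; on the sender side the same trial is read once at construction (see the simulate_vp9_structure_ member earlier in this diff) with a prefix match, so values such as "Enabled" or "Enabled-WithParams" both turn the feature on. A minimal sketch of that gate, where trial_value stands in for the string returned by field_trials_.Lookup("WebRTC-Vp9DependencyDescriptor"):

    #include <string>
    #include "absl/strings/match.h"

    bool Vp9StructureSimulationEnabled(const std::string& trial_value) {
      // StartsWith rather than equality, so parameterized trial strings match.
      return absl::StartsWith(trial_value, "Enabled");
    }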
test.router()->SetActive(true); + + RtpHeaderExtensionMap extensions; + extensions.Register( + kDependencyDescriptorExtensionId); + std::vector sent_packets; + ON_CALL(test.transport(), SendRtp) + .WillByDefault([&](const uint8_t* packet, size_t length, + const PacketOptions& options) { + sent_packets.emplace_back(&extensions); + EXPECT_TRUE(sent_packets.back().Parse(packet, length)); + return true; + }); + + const uint8_t kPayload[1] = {'a'}; + EncodedImage encoded_image; + encoded_image.SetTimestamp(1); + encoded_image.capture_time_ms_ = 2; + encoded_image._frameType = VideoFrameType::kVideoFrameKey; + encoded_image._encodedWidth = 320; + encoded_image._encodedHeight = 180; + encoded_image.SetEncodedData( + EncodedImageBuffer::Create(kPayload, sizeof(kPayload))); + + CodecSpecificInfo codec_specific; + codec_specific.codecType = VideoCodecType::kVideoCodecVP9; + codec_specific.codecSpecific.VP9.num_spatial_layers = 1; + codec_specific.codecSpecific.VP9.temporal_idx = kNoTemporalIdx; + codec_specific.codecSpecific.VP9.first_frame_in_picture = true; + codec_specific.end_of_picture = true; + codec_specific.codecSpecific.VP9.inter_pic_predicted = false; + + // Send two tiny images, each mapping to single RTP packet. + EXPECT_EQ(test.router()->OnEncodedImage(encoded_image, &codec_specific).error, + EncodedImageCallback::Result::OK); + + // Send in 2nd picture. + encoded_image._frameType = VideoFrameType::kVideoFrameDelta; + encoded_image.SetTimestamp(3000); + codec_specific.codecSpecific.VP9.inter_pic_predicted = true; + codec_specific.codecSpecific.VP9.num_ref_pics = 1; + codec_specific.codecSpecific.VP9.p_diff[0] = 1; + EXPECT_EQ(test.router()->OnEncodedImage(encoded_image, &codec_specific).error, + EncodedImageCallback::Result::OK); + + test.AdvanceTime(TimeDelta::Millis(33)); + ASSERT_THAT(sent_packets, SizeIs(2)); + EXPECT_TRUE(sent_packets[0].HasExtension()); + EXPECT_TRUE(sent_packets[1].HasExtension()); +} + TEST(RtpVideoSenderTest, SupportsStoppingUsingDependencyDescriptor) { RtpVideoSenderTestFixture test({kSsrc1}, {}, kPayloadType, {}); test.router()->SetActive(true); @@ -825,6 +883,64 @@ TEST(RtpVideoSenderTest, SupportsStoppingUsingDependencyDescriptor) { sent_packets.back().HasExtension()); } +TEST(RtpVideoSenderTest, + SupportsStoppingUsingDependencyDescriptorForVp8Simulcast) { + RtpVideoSenderTestFixture test({kSsrc1, kSsrc2}, {}, kPayloadType, {}); + test.router()->SetActive(true); + + RtpHeaderExtensionMap extensions; + extensions.Register( + kDependencyDescriptorExtensionId); + std::vector sent_packets; + ON_CALL(test.transport(), SendRtp) + .WillByDefault([&](const uint8_t* packet, size_t length, + const PacketOptions& options) { + sent_packets.emplace_back(&extensions); + EXPECT_TRUE(sent_packets.back().Parse(packet, length)); + return true; + }); + + const uint8_t kPayload[1] = {'a'}; + EncodedImage encoded_image; + encoded_image.SetTimestamp(1); + encoded_image.capture_time_ms_ = 2; + encoded_image.SetEncodedData( + EncodedImageBuffer::Create(kPayload, sizeof(kPayload))); + // VP8 simulcast uses spatial index to communicate simulcast stream. + encoded_image.SetSpatialIndex(1); + + CodecSpecificInfo codec_specific; + codec_specific.codecType = VideoCodecType::kVideoCodecVP8; + codec_specific.template_structure.emplace(); + codec_specific.template_structure->num_decode_targets = 1; + codec_specific.template_structure->templates = { + FrameDependencyTemplate().T(0).Dtis("S")}; + + // Send two tiny images, mapping to single RTP packets. + // Send in a key frame. 
+ encoded_image._frameType = VideoFrameType::kVideoFrameKey; + codec_specific.generic_frame_info = + GenericFrameInfo::Builder().T(0).Dtis("S").Build(); + codec_specific.generic_frame_info->encoder_buffers = {{0, false, true}}; + EXPECT_EQ(test.router()->OnEncodedImage(encoded_image, &codec_specific).error, + EncodedImageCallback::Result::OK); + test.AdvanceTime(TimeDelta::Millis(33)); + ASSERT_THAT(sent_packets, SizeIs(1)); + EXPECT_TRUE( + sent_packets.back().HasExtension()); + + // Send in a new key frame without the support for the dependency descriptor. + encoded_image._frameType = VideoFrameType::kVideoFrameKey; + codec_specific.template_structure = absl::nullopt; + codec_specific.generic_frame_info = absl::nullopt; + EXPECT_EQ(test.router()->OnEncodedImage(encoded_image, &codec_specific).error, + EncodedImageCallback::Result::OK); + test.AdvanceTime(TimeDelta::Millis(33)); + ASSERT_THAT(sent_packets, SizeIs(2)); + EXPECT_FALSE( + sent_packets.back().HasExtension()); +} + TEST(RtpVideoSenderTest, CanSetZeroBitrate) { RtpVideoSenderTestFixture test({kSsrc1}, {kRtxSsrc1}, kPayloadType, {}); test.router()->OnBitrateUpdated(CreateBitrateAllocationUpdate(0), @@ -833,7 +949,7 @@ TEST(RtpVideoSenderTest, CanSetZeroBitrate) { TEST(RtpVideoSenderTest, SimulcastSenderRegistersFrameTransformers) { rtc::scoped_refptr transformer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); EXPECT_CALL(*transformer, RegisterTransformedFrameSinkCallback(_, kSsrc1)); EXPECT_CALL(*transformer, RegisterTransformedFrameSinkCallback(_, kSsrc2)); @@ -853,7 +969,7 @@ TEST(RtpVideoSenderTest, OverheadIsSubtractedFromTargetBitrate) { constexpr uint32_t kTransportPacketOverheadBytes = 40; constexpr uint32_t kOverheadPerPacketBytes = kRtpHeaderSizeBytes + kTransportPacketOverheadBytes; - RtpVideoSenderTestFixture test({kSsrc1}, {kRtxSsrc1}, kPayloadType, {}); + RtpVideoSenderTestFixture test({kSsrc1}, {}, kPayloadType, {}); test.router()->OnTransportOverheadChanged(kTransportPacketOverheadBytes); test.router()->SetActive(true); diff --git a/call/rtx_receive_stream.cc b/call/rtx_receive_stream.cc index 9e4a41bc8f..c0b138b416 100644 --- a/call/rtx_receive_stream.cc +++ b/call/rtx_receive_stream.cc @@ -64,7 +64,7 @@ void RtxReceiveStream::OnRtpPacket(const RtpPacketReceived& rtx_packet) { media_packet.SetSequenceNumber((payload[0] << 8) + payload[1]); media_packet.SetPayloadType(it->second); media_packet.set_recovered(true); - media_packet.set_arrival_time_ms(rtx_packet.arrival_time_ms()); + media_packet.set_arrival_time(rtx_packet.arrival_time()); // Skip the RTX header. 
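The set_arrival_time() change in rtx_receive_stream.cc above is part of a wider migration from raw int64 milliseconds to the unit-safe webrtc::Timestamp type. A small usage sketch, assuming a WebRTC checkout so that api/units/timestamp.h is on the include path:

    #include <cstdint>
    #include "api/units/timestamp.h"

    int64_t ArrivalTimeMs() {
      // Construct a typed timestamp instead of passing a bare integer around.
      webrtc::Timestamp arrival = webrtc::Timestamp::Millis(123);
      // Converting back to a plain integer is explicit, so millisecond vs.
      // microsecond mix-ups show up at the call site.
      return arrival.ms();
    }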
rtc::ArrayView rtx_payload = payload.subview(kRtxHeaderSize); diff --git a/call/rtx_receive_stream_unittest.cc b/call/rtx_receive_stream_unittest.cc index 75086fef9c..b06990820f 100644 --- a/call/rtx_receive_stream_unittest.cc +++ b/call/rtx_receive_stream_unittest.cc @@ -194,9 +194,9 @@ TEST(RtxReceiveStreamTest, PropagatesArrivalTime) { RtxReceiveStream rtx_sink(&media_sink, PayloadTypeMapping(), kMediaSSRC); RtpPacketReceived rtx_packet(nullptr); EXPECT_TRUE(rtx_packet.Parse(rtc::ArrayView(kRtxPacket))); - rtx_packet.set_arrival_time_ms(123); - EXPECT_CALL(media_sink, - OnRtpPacket(Property(&RtpPacketReceived::arrival_time_ms, 123))); + rtx_packet.set_arrival_time(Timestamp::Millis(123)); + EXPECT_CALL(media_sink, OnRtpPacket(Property(&RtpPacketReceived::arrival_time, + Timestamp::Millis(123)))); rtx_sink.OnRtpPacket(rtx_packet); } diff --git a/call/simulated_network.h b/call/simulated_network.h index b53ecc0ddb..68d066cb82 100644 --- a/call/simulated_network.h +++ b/call/simulated_network.h @@ -17,6 +17,7 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/test/simulated_network.h" #include "api/units/data_size.h" #include "api/units/timestamp.h" @@ -24,7 +25,6 @@ #include "rtc_base/random.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { // Implementation of the CoDel active queue management algorithm. Loosely based diff --git a/call/version.cc b/call/version.cc index 9956fa2c1c..a76af47b41 100644 --- a/call/version.cc +++ b/call/version.cc @@ -13,7 +13,7 @@ namespace webrtc { // The timestamp is always in UTC. -const char* const kSourceTimestamp = "WebRTC source stamp 2021-01-13T04:03:21"; +const char* const kSourceTimestamp = "WebRTC source stamp 2021-07-13T04:01:55"; void LoadWebRTCVersionInRegister() { // Using volatile to instruct the compiler to not optimize `p` away even diff --git a/call/video_receive_stream.cc b/call/video_receive_stream.cc index e0f3de366b..d0518b6e0d 100644 --- a/call/video_receive_stream.cc +++ b/call/video_receive_stream.cc @@ -14,10 +14,18 @@ namespace webrtc { +VideoReceiveStream::Decoder::Decoder(SdpVideoFormat video_format, + int payload_type) + : video_format(std::move(video_format)), payload_type(payload_type) {} VideoReceiveStream::Decoder::Decoder() : video_format("Unset") {} VideoReceiveStream::Decoder::Decoder(const Decoder&) = default; VideoReceiveStream::Decoder::~Decoder() = default; +bool VideoReceiveStream::Decoder::operator==(const Decoder& other) const { + return payload_type == other.payload_type && + video_format == other.video_format; +} + std::string VideoReceiveStream::Decoder::ToString() const { char buf[1024]; rtc::SimpleStringBuilder ss(buf); @@ -74,8 +82,10 @@ std::string VideoReceiveStream::Stats::ToString(int64_t time_ms) const { VideoReceiveStream::Config::Config(const Config&) = default; VideoReceiveStream::Config::Config(Config&&) = default; -VideoReceiveStream::Config::Config(Transport* rtcp_send_transport) - : rtcp_send_transport(rtcp_send_transport) {} +VideoReceiveStream::Config::Config(Transport* rtcp_send_transport, + VideoDecoderFactory* decoder_factory) + : decoder_factory(decoder_factory), + rtcp_send_transport(rtcp_send_transport) {} VideoReceiveStream::Config& VideoReceiveStream::Config::operator=(Config&&) = default; diff --git a/call/video_receive_stream.h b/call/video_receive_stream.h index 7a6803d9e2..86e5052151 100644 --- a/call/video_receive_stream.h +++ 
b/call/video_receive_stream.h @@ -20,17 +20,15 @@ #include "api/call/transport.h" #include "api/crypto/crypto_options.h" -#include "api/crypto/frame_decryptor_interface.h" -#include "api/frame_transformer_interface.h" #include "api/rtp_headers.h" #include "api/rtp_parameters.h" -#include "api/transport/rtp/rtp_source.h" #include "api/video/recordable_encoded_frame.h" #include "api/video/video_content_type.h" #include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" #include "api/video/video_timing.h" #include "api/video_codecs/sdp_video_format.h" +#include "call/receive_stream.h" #include "call/rtp_config.h" #include "common_video/frame_counts.h" #include "modules/rtp_rtcp/include/rtcp_statistics.h" @@ -41,7 +39,7 @@ namespace webrtc { class RtpPacketSinkInterface; class VideoDecoderFactory; -class VideoReceiveStream { +class VideoReceiveStream : public MediaReceiveStream { public: // Class for handling moving in/out recording state. struct RecordingState { @@ -53,11 +51,6 @@ class VideoReceiveStream { // Callback stored from the VideoReceiveStream. The VideoReceiveStream // client should not interpret the attribute. std::function callback; - // Memento of internal state in VideoReceiveStream, recording wether - // we're currently causing generation of a keyframe from the sender. Needed - // to avoid sending double keyframe requests. The VideoReceiveStream client - // should not interpret the attribute. - bool keyframe_needed = false; // Memento of when a keyframe request was last sent. The VideoReceiveStream // client should not interpret the attribute. absl::optional last_keyframe_request_ms; @@ -66,9 +59,13 @@ class VideoReceiveStream { // TODO(mflodman) Move all these settings to VideoDecoder and move the // declaration to common_types.h. struct Decoder { + Decoder(SdpVideoFormat video_format, int payload_type); Decoder(); Decoder(const Decoder&); ~Decoder(); + + bool operator==(const Decoder& other) const; + std::string ToString() const; SdpVideoFormat video_format; @@ -157,7 +154,8 @@ class VideoReceiveStream { public: Config() = delete; Config(Config&&); - explicit Config(Transport* rtcp_send_transport); + Config(Transport* rtcp_send_transport, + VideoDecoderFactory* decoder_factory = nullptr); Config& operator=(Config&&); Config& operator=(const Config&) = delete; ~Config(); @@ -174,17 +172,14 @@ class VideoReceiveStream { VideoDecoderFactory* decoder_factory = nullptr; // Receive-stream specific RTP settings. - struct Rtp { + struct Rtp : public RtpConfig { Rtp(); Rtp(const Rtp&); ~Rtp(); std::string ToString() const; - // Synchronization source (stream identifier) to be received. - uint32_t remote_ssrc = 0; - - // Sender SSRC used for sending RTCP (such as receiver reports). - uint32_t local_ssrc = 0; + // See NackConfig for description. + NackConfig nack; // See RtcpMode for description. RtcpMode rtcp_mode = RtcpMode::kCompound; @@ -196,15 +191,9 @@ class VideoReceiveStream { bool receiver_reference_time_report = false; } rtcp_xr; - // See draft-holmer-rmcat-transport-wide-cc-extensions for details. - bool transport_cc = false; - // See LntfConfig for description. LntfConfig lntf; - // See NackConfig for description. - NackConfig nack; - // Payload types for ULPFEC and RED, respectively. int ulpfec_payload_type = -1; int red_payload_type = -1; @@ -215,6 +204,10 @@ class VideoReceiveStream { // Set if the stream is protected using FlexFEC. 
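The Decoder changes above add a convenience constructor and value equality, which makes it straightforward to check whether a reconfiguration actually changed the configured decoders. A hedged sketch, assuming a WebRTC checkout (call/video_receive_stream.h and api/video_codecs/sdp_video_format.h on the include path):

    #include "api/video_codecs/sdp_video_format.h"
    #include "call/video_receive_stream.h"

    bool SameDecoderConfigured() {
      webrtc::VideoReceiveStream::Decoder a(webrtc::SdpVideoFormat("VP9"),
                                            /*payload_type=*/98);
      webrtc::VideoReceiveStream::Decoder b(webrtc::SdpVideoFormat("VP9"),
                                            /*payload_type=*/98);
      // operator== compares payload type and SDP format, per the change above.
      return a == b;  // Expected to be true.
    }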
bool protected_by_flexfec = false; + // Optional callback sink to support additional packet handlers such as + // FlexFec. + RtpPacketSinkInterface* packet_sink_ = nullptr; + // Map from rtx payload type -> media payload type. // For RTX to be enabled, both an SSRC and this mapping are needed. std::map rtx_associated_payload_types; @@ -224,9 +217,6 @@ // meta data is expected to be present in generic frame descriptor // RTP header extension). std::set raw_payload_types; - - // RTP header extensions used for the received stream. - std::vector extensions; } rtp; // Transport for outgoing packets (RTCP). @@ -252,10 +242,6 @@ // used for streaming instead of a real-time call. int target_delay_ms = 0; - // TODO(nisse): Used with VideoDecoderFactory::LegacyCreateVideoDecoder. - // Delete when that method is retired. - std::string stream_id; - // An optional custom frame decryptor that allows the entire frame to be // decrypted in whatever way the caller choses. This is not required by // default. @@ -267,25 +253,9 @@ rtc::scoped_refptr frame_transformer; }; - // Starts stream activity. - // When a stream is active, it can receive, process and deliver packets. - virtual void Start() = 0; - // Stops stream activity. - // When a stream is stopped, it can't receive, process or deliver packets. - virtual void Stop() = 0; - // TODO(pbos): Add info on currently-received codec to Stats. virtual Stats GetStats() const = 0; - // RtpDemuxer only forwards a given RTP packet to one sink. However, some - // sinks, such as FlexFEC, might wish to be informed of all of the packets - // a given sink receives (or any set of sinks). They may do so by registering - // themselves as secondary sinks. - virtual void AddSecondarySink(RtpPacketSinkInterface* sink) = 0; - virtual void RemoveSecondarySink(const RtpPacketSinkInterface* sink) = 0; - - virtual std::vector GetSources() const = 0; - // Sets a base minimum for the playout delay. Base minimum delay sets lower // bound on minimum delay value determining lower bound on playout delay. // @@ -295,16 +265,6 @@ // Returns current value of base minimum delay in milliseconds. virtual int GetBaseMinimumPlayoutDelayMs() const = 0; - // Allows a FrameDecryptor to be attached to a VideoReceiveStream after - // creation without resetting the decoder state. - virtual void SetFrameDecryptor( - rtc::scoped_refptr frame_decryptor) = 0; - - // Allows a frame transformer to be attached to a VideoReceiveStream after - // creation without resetting the decoder state. - virtual void SetDepacketizerToDecoderFrameTransformer( - rtc::scoped_refptr frame_transformer) = 0; - // Sets and returns recording state. The old state is moved out // of the video receive stream and returned to the caller, and |state| // is moved in. If the state's callback is set, it will be called with @@ -324,6 +284,16 @@ virtual ~VideoReceiveStream() {} }; +class DEPRECATED_VideoReceiveStream : public VideoReceiveStream { + public: + // RtpDemuxer only forwards a given RTP packet to one sink. However, some + // sinks, such as FlexFEC, might wish to be informed of all of the packets + // a given sink receives (or any set of sinks). They may do so by registering + // themselves as secondary sinks.
+ virtual void AddSecondarySink(RtpPacketSinkInterface* sink) = 0; + virtual void RemoveSecondarySink(const RtpPacketSinkInterface* sink) = 0; +}; + } // namespace webrtc #endif // CALL_VIDEO_RECEIVE_STREAM_H_ diff --git a/call/video_send_stream.cc b/call/video_send_stream.cc index 244d78089c..25513e4e4c 100644 --- a/call/video_send_stream.cc +++ b/call/video_send_stream.cc @@ -51,8 +51,13 @@ std::string VideoSendStream::StreamStats::ToString() const { ss << "retransmit_bps: " << retransmit_bitrate_bps << ", "; ss << "avg_delay_ms: " << avg_delay_ms << ", "; ss << "max_delay_ms: " << max_delay_ms << ", "; - ss << "cum_loss: " << rtcp_stats.packets_lost << ", "; - ss << "max_ext_seq: " << rtcp_stats.extended_highest_sequence_number << ", "; + if (report_block_data) { + ss << "cum_loss: " << report_block_data->report_block().packets_lost + << ", "; + ss << "max_ext_seq: " + << report_block_data->report_block().extended_highest_sequence_number + << ", "; + } ss << "nack: " << rtcp_packet_type_counts.nack_packets << ", "; ss << "fir: " << rtcp_packet_type_counts.fir_packets << ", "; ss << "pli: " << rtcp_packet_type_counts.pli_packets; diff --git a/call/video_send_stream.h b/call/video_send_stream.h index 0df9e6ce05..42e6249fcd 100644 --- a/call/video_send_stream.h +++ b/call/video_send_stream.h @@ -82,7 +82,6 @@ class VideoSendStream { uint64_t total_packet_send_delay_ms = 0; StreamDataCounters rtp_stats; RtcpPacketTypeCounter rtcp_packet_type_counts; - RtcpStatistics rtcp_stats; // A snapshot of the most recent Report Block with additional data of // interest to statistics. Used to implement RTCRemoteInboundRtpStreamStats. absl::optional report_block_data; @@ -108,6 +107,7 @@ class VideoSendStream { uint64_t total_encode_time_ms = 0; // https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-totalencodedbytestarget uint64_t total_encoded_bytes_target = 0; + uint32_t frames = 0; uint32_t frames_dropped_by_capturer = 0; uint32_t frames_dropped_by_encoder_queue = 0; uint32_t frames_dropped_by_rate_limiter = 0; @@ -218,6 +218,15 @@ class VideoSendStream { // When a stream is stopped, it can't receive, process or deliver packets. virtual void Stop() = 0; + // Accessor for determining if the stream is active. This is an inexpensive + // call that must be made on the same thread as `Start()` and `Stop()` methods + // are called on and will return `true` iff activity has been started either + // via `Start()` or `UpdateActiveSimulcastLayers()`. If activity is either + // stopped or is in the process of being stopped as a result of a call to + // either `Stop()` or `UpdateActiveSimulcastLayers()` where all layers were + // deactivated, the return value will be `false`. + virtual bool started() = 0; + // If the resource is overusing, the VideoSendStream will try to reduce // resolution or frame rate until no resource is overusing. 
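With rtcp_stats removed from VideoSendStream::StreamStats above, packet-loss and extended-sequence-number figures now come from the optional report_block_data, so callers must handle the "no report block received yet" case explicitly. A small sketch (WebRTC checkout assumed; the stats object would come from VideoSendStream::GetStats()):

    #include <cstdint>
    #include "call/video_send_stream.h"

    int64_t CumulativeLossOrZero(
        const webrtc::VideoSendStream::StreamStats& stats) {
      if (!stats.report_block_data)
        return 0;  // No RTCP report block seen for this substream yet.
      return stats.report_block_data->report_block().packets_lost;
    }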
// TODO(https://crbug.com/webrtc/11565): When the ResourceAdaptationProcessor diff --git a/common_audio/BUILD.gn b/common_audio/BUILD.gn index a03e9ab659..5b1e581410 100644 --- a/common_audio/BUILD.gn +++ b/common_audio/BUILD.gn @@ -335,7 +335,7 @@ if (rtc_build_with_neon) { } } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_test("common_audio_unittests") { visibility += webrtc_default_visibility testonly = true diff --git a/common_audio/resampler/resampler.cc b/common_audio/resampler/resampler.cc index ccfed5a014..0fdb249052 100644 --- a/common_audio/resampler/resampler.cc +++ b/common_audio/resampler/resampler.cc @@ -916,7 +916,6 @@ int Resampler::Push(const int16_t* samplesIn, outLen = (lengthIn * 8) / 11; free(tmp_mem); return 0; - break; } return 0; } diff --git a/common_audio/signal_processing/division_operations.c b/common_audio/signal_processing/division_operations.c index c6195e7999..4764ddfccd 100644 --- a/common_audio/signal_processing/division_operations.c +++ b/common_audio/signal_processing/division_operations.c @@ -98,8 +98,7 @@ int32_t WebRtcSpl_DivResultInQ31(int32_t num, int32_t den) return div; } -int32_t RTC_NO_SANITIZE("signed-integer-overflow") // bugs.webrtc.org/5486 -WebRtcSpl_DivW32HiLow(int32_t num, int16_t den_hi, int16_t den_low) +int32_t WebRtcSpl_DivW32HiLow(int32_t num, int16_t den_hi, int16_t den_low) { int16_t approx, tmp_hi, tmp_low, num_hi, num_low; int32_t tmpW32; @@ -111,8 +110,8 @@ WebRtcSpl_DivW32HiLow(int32_t num, int16_t den_hi, int16_t den_low) tmpW32 = (den_hi * approx << 1) + ((den_low * approx >> 15) << 1); // tmpW32 = den * approx - tmpW32 = (int32_t)0x7fffffffL - tmpW32; // result in Q30 (tmpW32 = 2.0-(den*approx)) - // UBSan: 2147483647 - -2 cannot be represented in type 'int' + // result in Q30 (tmpW32 = 2.0-(den*approx)) + tmpW32 = (int32_t)((int64_t)0x7fffffffL - tmpW32); // Store tmpW32 in hi and low format tmp_hi = (int16_t)(tmpW32 >> 16); diff --git a/common_audio/signal_processing/include/signal_processing_library.h b/common_audio/signal_processing/include/signal_processing_library.h index 4ad92c4c2b..0c13071a27 100644 --- a/common_audio/signal_processing/include/signal_processing_library.h +++ b/common_audio/signal_processing/include/signal_processing_library.h @@ -228,6 +228,25 @@ int32_t WebRtcSpl_MinValueW32Neon(const int32_t* vector, size_t length); int32_t WebRtcSpl_MinValueW32_mips(const int32_t* vector, size_t length); #endif +// Returns both the minimum and maximum values of a 16-bit vector. +// +// Input: +// - vector : 16-bit input vector. +// - length : Number of samples in vector. +// Output: +// - max_val : Maximum sample value in |vector|. +// - min_val : Minimum sample value in |vector|. +void WebRtcSpl_MinMaxW16(const int16_t* vector, + size_t length, + int16_t* min_val, + int16_t* max_val); +#if defined(WEBRTC_HAS_NEON) +void WebRtcSpl_MinMaxW16Neon(const int16_t* vector, + size_t length, + int16_t* min_val, + int16_t* max_val); +#endif + // Returns the vector index to the largest absolute value of a 16-bit vector. // // Input: @@ -240,6 +259,17 @@ int32_t WebRtcSpl_MinValueW32_mips(const int32_t* vector, size_t length); // -32768 presenting an int16 absolute value of 32767). size_t WebRtcSpl_MaxAbsIndexW16(const int16_t* vector, size_t length); +// Returns the element with the largest absolute value of a 16-bit vector. Note +// that this function can return a negative value. +// +// Input: +// - vector : 16-bit input vector. +// - length : Number of samples in vector.
+// +// Return value : The element with the largest absolute value. Note that this +// may be a negative value. +int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length); + // Returns the vector index to the maximum sample value of a 16-bit vector. // // Input: diff --git a/common_audio/signal_processing/min_max_operations.c b/common_audio/signal_processing/min_max_operations.c index d249a02d40..1b9542e7ef 100644 --- a/common_audio/signal_processing/min_max_operations.c +++ b/common_audio/signal_processing/min_max_operations.c @@ -155,6 +155,15 @@ size_t WebRtcSpl_MaxAbsIndexW16(const int16_t* vector, size_t length) { return index; } +int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length) { + int16_t min_val, max_val; + WebRtcSpl_MinMaxW16(vector, length, &min_val, &max_val); + if (min_val == max_val || min_val < -max_val) { + return min_val; + } + return max_val; +} + // Index of maximum value in a word16 vector. size_t WebRtcSpl_MaxIndexW16(const int16_t* vector, size_t length) { size_t i = 0, index = 0; @@ -222,3 +231,26 @@ size_t WebRtcSpl_MinIndexW32(const int32_t* vector, size_t length) { return index; } + +// Finds both the minimum and maximum elements in an array of 16-bit integers. +void WebRtcSpl_MinMaxW16(const int16_t* vector, size_t length, + int16_t* min_val, int16_t* max_val) { +#if defined(WEBRTC_HAS_NEON) + return WebRtcSpl_MinMaxW16Neon(vector, length, min_val, max_val); +#else + int16_t minimum = WEBRTC_SPL_WORD16_MAX; + int16_t maximum = WEBRTC_SPL_WORD16_MIN; + size_t i = 0; + + RTC_DCHECK_GT(length, 0); + + for (i = 0; i < length; i++) { + if (vector[i] < minimum) + minimum = vector[i]; + if (vector[i] > maximum) + maximum = vector[i]; + } + *min_val = minimum; + *max_val = maximum; +#endif +} diff --git a/common_audio/signal_processing/min_max_operations_neon.c b/common_audio/signal_processing/min_max_operations_neon.c index 53217df7be..e5b4b7c71b 100644 --- a/common_audio/signal_processing/min_max_operations_neon.c +++ b/common_audio/signal_processing/min_max_operations_neon.c @@ -281,3 +281,53 @@ int32_t WebRtcSpl_MinValueW32Neon(const int32_t* vector, size_t length) { return minimum; } +// Finds both the minimum and maximum elements in an array of 16-bit integers. +void WebRtcSpl_MinMaxW16Neon(const int16_t* vector, size_t length, + int16_t* min_val, int16_t* max_val) { + int16_t minimum = WEBRTC_SPL_WORD16_MAX; + int16_t maximum = WEBRTC_SPL_WORD16_MIN; + size_t i = 0; + size_t residual = length & 0x7; + + RTC_DCHECK_GT(length, 0); + + const int16_t* p_start = vector; + int16x8_t min16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MAX); + int16x8_t max16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MIN); + + // First part, unroll the loop 8 times. + for (i = 0; i < length - residual; i += 8) { + int16x8_t in16x8 = vld1q_s16(p_start); + min16x8 = vminq_s16(min16x8, in16x8); + max16x8 = vmaxq_s16(max16x8, in16x8); + p_start += 8; + } + +#if defined(WEBRTC_ARCH_ARM64) + minimum = vminvq_s16(min16x8); + maximum = vmaxvq_s16(max16x8); +#else + int16x4_t min16x4 = vmin_s16(vget_low_s16(min16x8), vget_high_s16(min16x8)); + min16x4 = vpmin_s16(min16x4, min16x4); + min16x4 = vpmin_s16(min16x4, min16x4); + + minimum = vget_lane_s16(min16x4, 0); + + int16x4_t max16x4 = vmax_s16(vget_low_s16(max16x8), vget_high_s16(max16x8)); + max16x4 = vpmax_s16(max16x4, max16x4); + max16x4 = vpmax_s16(max16x4, max16x4); + + maximum = vget_lane_s16(max16x4, 0); +#endif + + // Second part, do the remaining iterations (if any). 
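For reference, the selection rule implemented by WebRtcSpl_MaxAbsElementW16 above can be written portably on top of a single min/max pass; returning the minimum when min == max, or when -min would exceed max, sidesteps negating INT16_MIN, which cannot be represented as int16_t. A standalone C++ sketch of the portable logic (not the NEON path):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int16_t MaxAbsElement(const std::vector<int16_t>& v) {
      assert(!v.empty());
      auto [min_it, max_it] = std::minmax_element(v.begin(), v.end());
      int16_t min_val = *min_it;
      int16_t max_val = *max_it;
      // If the minimum dominates in magnitude, return it as-is; this also
      // covers min_val == max_val and never computes -INT16_MIN.
      if (min_val == max_val || min_val < -max_val)
        return min_val;
      return max_val;
    }

    int main() {
      // Mirrors the two-element unit test below: {-32768, 32767} -> -32768.
      std::vector<int16_t> v = {INT16_MIN, INT16_MAX};
      return MaxAbsElement(v) == INT16_MIN ? 0 : 1;
    }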
+ for (i = residual; i > 0; i--) { + if (*p_start < minimum) + minimum = *p_start; + if (*p_start > maximum) + maximum = *p_start; + p_start++; + } + *min_val = minimum; + *max_val = maximum; +} diff --git a/common_audio/signal_processing/signal_processing_unittest.cc b/common_audio/signal_processing/signal_processing_unittest.cc index 3106c47d2d..9ec8590d6c 100644 --- a/common_audio/signal_processing/signal_processing_unittest.cc +++ b/common_audio/signal_processing/signal_processing_unittest.cc @@ -289,6 +289,12 @@ TEST(SplTest, MinMaxOperationsTest) { WebRtcSpl_MinValueW32(vector32, kVectorSize)); EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW16(vector16, kVectorSize)); EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW32(vector32, kVectorSize)); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, + WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize)); + int16_t min_value, max_value; + WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value); + EXPECT_EQ(12334, max_value); // Test the cases where maximum values have to be caught // outside of the unrolled loops in ARM-Neon. @@ -306,6 +312,11 @@ TEST(SplTest, MinMaxOperationsTest) { EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxAbsIndexW16(vector16, kVectorSize)); EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW16(vector16, kVectorSize)); EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW32(vector32, kVectorSize)); + EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, + WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize)); + WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value); + EXPECT_EQ(-29871, min_value); + EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value); // Test the cases where multiple maximum and minimum values are present. vector16[1] = WEBRTC_SPL_WORD16_MAX; @@ -332,6 +343,43 @@ TEST(SplTest, MinMaxOperationsTest) { EXPECT_EQ(1u, WebRtcSpl_MaxIndexW32(vector32, kVectorSize)); EXPECT_EQ(6u, WebRtcSpl_MinIndexW16(vector16, kVectorSize)); EXPECT_EQ(6u, WebRtcSpl_MinIndexW32(vector32, kVectorSize)); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, + WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize)); + WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value); + EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value); + + // Test a one-element vector. + int16_t single_element_vector = 0; + EXPECT_EQ(0, WebRtcSpl_MaxAbsValueW16(&single_element_vector, 1)); + EXPECT_EQ(0, WebRtcSpl_MaxValueW16(&single_element_vector, 1)); + EXPECT_EQ(0, WebRtcSpl_MinValueW16(&single_element_vector, 1)); + EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(&single_element_vector, 1)); + EXPECT_EQ(0u, WebRtcSpl_MaxIndexW16(&single_element_vector, 1)); + EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(&single_element_vector, 1)); + EXPECT_EQ(0, WebRtcSpl_MaxAbsElementW16(&single_element_vector, 1)); + WebRtcSpl_MinMaxW16(&single_element_vector, 1, &min_value, &max_value); + EXPECT_EQ(0, min_value); + EXPECT_EQ(0, max_value); + + // Test a two-element vector with the values WEBRTC_SPL_WORD16_MIN and + // WEBRTC_SPL_WORD16_MAX. 
+ int16_t two_element_vector[2] = {WEBRTC_SPL_WORD16_MIN, + WEBRTC_SPL_WORD16_MAX}; + EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, + WebRtcSpl_MaxAbsValueW16(two_element_vector, 2)); + EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, + WebRtcSpl_MaxValueW16(two_element_vector, 2)); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, + WebRtcSpl_MinValueW16(two_element_vector, 2)); + EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(two_element_vector, 2)); + EXPECT_EQ(1u, WebRtcSpl_MaxIndexW16(two_element_vector, 2)); + EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(two_element_vector, 2)); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, + WebRtcSpl_MaxAbsElementW16(two_element_vector, 2)); + WebRtcSpl_MinMaxW16(two_element_vector, 2, &min_value, &max_value); + EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value); + EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value); } TEST(SplTest, VectorOperationsTest) { diff --git a/common_video/BUILD.gn b/common_video/BUILD.gn index bea8530860..8e5376725c 100644 --- a/common_video/BUILD.gn +++ b/common_video/BUILD.gn @@ -21,7 +21,6 @@ rtc_library("common_video") { "h264/h264_common.h", "h264/pps_parser.cc", "h264/pps_parser.h", - "h264/profile_level_id.h", "h264/sps_parser.cc", "h264/sps_parser.h", "h264/sps_vui_rewriter.cc", @@ -42,6 +41,7 @@ rtc_library("common_video") { deps = [ "../api:scoped_refptr", + "../api:sequence_checker", "../api/task_queue", "../api/units:time_delta", "../api/units:timestamp", @@ -49,10 +49,9 @@ rtc_library("common_video") { "../api/video:video_bitrate_allocation", "../api/video:video_bitrate_allocator", "../api/video:video_frame", - "../api/video:video_frame_nv12", "../api/video:video_rtp_headers", "../api/video_codecs:bitstream_parser_api", - "../media:rtc_h264_profile_id", + "../api/video_codecs:video_codecs_api", "../rtc_base", "../rtc_base:checks", "../rtc_base:rtc_task_queue", @@ -71,7 +70,7 @@ rtc_source_set("frame_counts") { sources = [ "frame_counts.h" ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { common_video_resources = [ "../resources/foreman_cif.yuv" ] if (is_ios) { @@ -90,7 +89,6 @@ if (rtc_include_tests) { "frame_rate_estimator_unittest.cc", "h264/h264_bitstream_parser_unittest.cc", "h264/pps_parser_unittest.cc", - "h264/profile_level_id_unittest.cc", "h264/sps_parser_unittest.cc", "h264/sps_vui_rewriter_unittest.cc", "libyuv/libyuv_unittest.cc", @@ -104,9 +102,8 @@ if (rtc_include_tests) { "../api/units:time_delta", "../api/video:video_frame", "../api/video:video_frame_i010", - "../api/video:video_frame_nv12", "../api/video:video_rtp_headers", - "../media:rtc_h264_profile_id", + "../api/video_codecs:video_codecs_api", "../rtc_base", "../rtc_base:checks", "../rtc_base:rtc_base_approved", @@ -118,10 +115,11 @@ if (rtc_include_tests) { "../test:test_support", "../test:video_test_common", "//testing/gtest", - "//third_party/abseil-cpp/absl/types:optional", "//third_party/libyuv", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + data = common_video_resources if (is_android) { deps += [ "//testing/android/native_test:native_test_support" ] diff --git a/common_video/h264/h264_bitstream_parser.cc b/common_video/h264/h264_bitstream_parser.cc index b0ada92d74..3b41599fa0 100644 --- a/common_video/h264/h264_bitstream_parser.cc +++ b/common_video/h264/h264_bitstream_parser.cc @@ -28,11 +28,13 @@ const int kMaxQpValue = 51; namespace webrtc { -#define RETURN_ON_FAIL(x, res) \ - if (!(x)) { \ - RTC_LOG_F(LS_ERROR) << "FAILED: " #x; \ - return res; \ - } +#define RETURN_ON_FAIL(x, res) \ + do { \ + if (!(x)) { \ + RTC_LOG_F(LS_ERROR) << "FAILED: " #x; \ + return 
res; \ + } \ + } while (0) #define RETURN_INV_ON_FAIL(x) RETURN_ON_FAIL(x, kInvalidStream) @@ -62,64 +64,63 @@ H264BitstreamParser::Result H264BitstreamParser::ParseNonParameterSetNalu( uint32_t bits_tmp; // first_mb_in_slice: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); // slice_type: ue(v) uint32_t slice_type; - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&slice_type)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(slice_type)); // slice_type's 5..9 range is used to indicate that all slices of a picture // have the same value of slice_type % 5, we don't care about that, so we map // to the corresponding 0..4 range. slice_type %= 5; // pic_parameter_set_id: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); if (sps_->separate_colour_plane_flag == 1) { // colour_plane_id - RETURN_INV_ON_FAIL(slice_reader.ReadBits(&bits_tmp, 2)); + RETURN_INV_ON_FAIL(slice_reader.ReadBits(2, bits_tmp)); } // frame_num: u(v) // Represented by log2_max_frame_num bits. - RETURN_INV_ON_FAIL( - slice_reader.ReadBits(&bits_tmp, sps_->log2_max_frame_num)); + RETURN_INV_ON_FAIL(slice_reader.ReadBits(sps_->log2_max_frame_num, bits_tmp)); uint32_t field_pic_flag = 0; if (sps_->frame_mbs_only_flag == 0) { // field_pic_flag: u(1) - RETURN_INV_ON_FAIL(slice_reader.ReadBits(&field_pic_flag, 1)); + RETURN_INV_ON_FAIL(slice_reader.ReadBits(1, field_pic_flag)); if (field_pic_flag != 0) { // bottom_field_flag: u(1) - RETURN_INV_ON_FAIL(slice_reader.ReadBits(&bits_tmp, 1)); + RETURN_INV_ON_FAIL(slice_reader.ReadBits(1, bits_tmp)); } } if (is_idr) { // idr_pic_id: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } // pic_order_cnt_lsb: u(v) // Represented by sps_.log2_max_pic_order_cnt_lsb bits. 
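All of the ue(v) fields parsed above use unsigned exponential-Golomb coding: count leading zero bits, skip the terminating one bit, read that many suffix bits, and the value is (1 << zeros) - 1 + suffix. The real parser does this through the reader's ReadExponentialGolomb calls shown in the diff; the sketch below is a self-contained reference decoder over a plain bit vector, with ReadUe an illustrative name.

    #include <cstddef>
    #include <cstdint>
    #include <optional>
    #include <vector>

    // Decodes one unsigned Exp-Golomb (ue(v)) value starting at `pos`; on
    // success `pos` is advanced past the codeword.
    std::optional<uint32_t> ReadUe(const std::vector<int>& bits, size_t& pos) {
      size_t zeros = 0;
      while (pos < bits.size() && bits[pos] == 0) {
        ++zeros;
        ++pos;
      }
      if (pos >= bits.size() || zeros > 31)
        return std::nullopt;  // Truncated or implausibly large codeword.
      ++pos;  // Consume the terminating 1 bit.
      uint32_t suffix = 0;
      for (size_t i = 0; i < zeros; ++i) {
        if (pos >= bits.size())
          return std::nullopt;
        suffix = (suffix << 1) | static_cast<uint32_t>(bits[pos++]);
      }
      return (1u << zeros) - 1 + suffix;
    }

    int main() {
      // "00101" encodes 4: two leading zeros, the marker 1, then suffix 01,
      // so the value is (1 << 2) - 1 + 1 = 4.
      std::vector<int> bits = {0, 0, 1, 0, 1};
      size_t pos = 0;
      auto value = ReadUe(bits, pos);
      return (value && *value == 4u) ? 0 : 1;
    }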
if (sps_->pic_order_cnt_type == 0) { RETURN_INV_ON_FAIL( - slice_reader.ReadBits(&bits_tmp, sps_->log2_max_pic_order_cnt_lsb)); + slice_reader.ReadBits(sps_->log2_max_pic_order_cnt_lsb, bits_tmp)); if (pps_->bottom_field_pic_order_in_frame_present_flag && field_pic_flag == 0) { // delta_pic_order_cnt_bottom: se(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } } if (sps_->pic_order_cnt_type == 1 && !sps_->delta_pic_order_always_zero_flag) { // delta_pic_order_cnt[0]: se(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); if (pps_->bottom_field_pic_order_in_frame_present_flag && !field_pic_flag) { // delta_pic_order_cnt[1]: se(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } } if (pps_->redundant_pic_cnt_present_flag) { // redundant_pic_cnt: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } if (slice_type == H264::SliceType::kB) { // direct_spatial_mv_pred_flag: u(1) - RETURN_INV_ON_FAIL(slice_reader.ReadBits(&bits_tmp, 1)); + RETURN_INV_ON_FAIL(slice_reader.ReadBits(1, bits_tmp)); } switch (slice_type) { case H264::SliceType::kP: @@ -128,13 +129,13 @@ H264BitstreamParser::Result H264BitstreamParser::ParseNonParameterSetNalu( uint32_t num_ref_idx_active_override_flag; // num_ref_idx_active_override_flag: u(1) RETURN_INV_ON_FAIL( - slice_reader.ReadBits(&num_ref_idx_active_override_flag, 1)); + slice_reader.ReadBits(1, num_ref_idx_active_override_flag)); if (num_ref_idx_active_override_flag != 0) { // num_ref_idx_l0_active_minus1: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); if (slice_type == H264::SliceType::kB) { // num_ref_idx_l1_active_minus1: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } } break; @@ -158,20 +159,20 @@ H264BitstreamParser::Result H264BitstreamParser::ParseNonParameterSetNalu( // ref_pic_list_modification_flag_l0: u(1) uint32_t ref_pic_list_modification_flag_l0; RETURN_INV_ON_FAIL( - slice_reader.ReadBits(&ref_pic_list_modification_flag_l0, 1)); + slice_reader.ReadBits(1, ref_pic_list_modification_flag_l0)); if (ref_pic_list_modification_flag_l0) { uint32_t modification_of_pic_nums_idc; do { // modification_of_pic_nums_idc: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb( - &modification_of_pic_nums_idc)); + RETURN_INV_ON_FAIL( + slice_reader.ReadExponentialGolomb(modification_of_pic_nums_idc)); if (modification_of_pic_nums_idc == 0 || modification_of_pic_nums_idc == 1) { // abs_diff_pic_num_minus1: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } else if (modification_of_pic_nums_idc == 2) { // long_term_pic_num: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } } while (modification_of_pic_nums_idc != 3); } @@ -180,20 +181,20 @@ H264BitstreamParser::Result H264BitstreamParser::ParseNonParameterSetNalu( // ref_pic_list_modification_flag_l1: u(1) uint32_t 
ref_pic_list_modification_flag_l1; RETURN_INV_ON_FAIL( - slice_reader.ReadBits(&ref_pic_list_modification_flag_l1, 1)); + slice_reader.ReadBits(1, ref_pic_list_modification_flag_l1)); if (ref_pic_list_modification_flag_l1) { uint32_t modification_of_pic_nums_idc; do { // modification_of_pic_nums_idc: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb( - &modification_of_pic_nums_idc)); + RETURN_INV_ON_FAIL( + slice_reader.ReadExponentialGolomb(modification_of_pic_nums_idc)); if (modification_of_pic_nums_idc == 0 || modification_of_pic_nums_idc == 1) { // abs_diff_pic_num_minus1: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } else if (modification_of_pic_nums_idc == 2) { // long_term_pic_num: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } } while (modification_of_pic_nums_idc != 3); } @@ -215,35 +216,35 @@ H264BitstreamParser::Result H264BitstreamParser::ParseNonParameterSetNalu( if (is_idr) { // no_output_of_prior_pics_flag: u(1) // long_term_reference_flag: u(1) - RETURN_INV_ON_FAIL(slice_reader.ReadBits(&bits_tmp, 2)); + RETURN_INV_ON_FAIL(slice_reader.ReadBits(2, bits_tmp)); } else { // adaptive_ref_pic_marking_mode_flag: u(1) uint32_t adaptive_ref_pic_marking_mode_flag; RETURN_INV_ON_FAIL( - slice_reader.ReadBits(&adaptive_ref_pic_marking_mode_flag, 1)); + slice_reader.ReadBits(1, adaptive_ref_pic_marking_mode_flag)); if (adaptive_ref_pic_marking_mode_flag) { uint32_t memory_management_control_operation; do { // memory_management_control_operation: ue(v) RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb( - &memory_management_control_operation)); + memory_management_control_operation)); if (memory_management_control_operation == 1 || memory_management_control_operation == 3) { // difference_of_pic_nums_minus1: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } if (memory_management_control_operation == 2) { // long_term_pic_num: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } if (memory_management_control_operation == 3 || memory_management_control_operation == 6) { // long_term_frame_idx: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } if (memory_management_control_operation == 4) { // max_long_term_frame_idx_plus1: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } } while (memory_management_control_operation != 0); } @@ -252,12 +253,12 @@ H264BitstreamParser::Result H264BitstreamParser::ParseNonParameterSetNalu( if (pps_->entropy_coding_mode_flag && slice_type != H264::SliceType::kI && slice_type != H264::SliceType::kSi) { // cabac_init_idc: ue(v) - RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(&golomb_tmp)); + RETURN_INV_ON_FAIL(slice_reader.ReadExponentialGolomb(golomb_tmp)); } int32_t last_slice_qp_delta; RETURN_INV_ON_FAIL( - slice_reader.ReadSignedExponentialGolomb(&last_slice_qp_delta)); + slice_reader.ReadSignedExponentialGolomb(last_slice_qp_delta)); if (abs(last_slice_qp_delta) > kMaxAbsQpDeltaValue) { // Something has gone wrong, and the parsed value is 
invalid. RTC_LOG(LS_WARNING) << "Parsed QP value out of range."; @@ -275,14 +276,14 @@ void H264BitstreamParser::ParseSlice(const uint8_t* slice, size_t length) { sps_ = SpsParser::ParseSps(slice + H264::kNaluTypeSize, length - H264::kNaluTypeSize); if (!sps_) - RTC_LOG(LS_WARNING) << "Unable to parse SPS from H264 bitstream."; + RTC_DLOG(LS_WARNING) << "Unable to parse SPS from H264 bitstream."; break; } case H264::NaluType::kPps: { pps_ = PpsParser::ParsePps(slice + H264::kNaluTypeSize, length - H264::kNaluTypeSize); if (!pps_) - RTC_LOG(LS_WARNING) << "Unable to parse PPS from H264 bitstream."; + RTC_DLOG(LS_WARNING) << "Unable to parse PPS from H264 bitstream."; break; } case H264::NaluType::kAud: @@ -291,7 +292,7 @@ void H264BitstreamParser::ParseSlice(const uint8_t* slice, size_t length) { default: Result res = ParseNonParameterSetNalu(slice, length, nalu_type); if (res != kOk) - RTC_LOG(LS_INFO) << "Failed to parse bitstream. Error: " << res; + RTC_DLOG(LS_INFO) << "Failed to parse bitstream. Error: " << res; break; } } diff --git a/common_video/h264/pps_parser.cc b/common_video/h264/pps_parser.cc index ae01652189..3d3725f95a 100644 --- a/common_video/h264/pps_parser.cc +++ b/common_video/h264/pps_parser.cc @@ -18,9 +18,11 @@ #include "rtc_base/checks.h" #define RETURN_EMPTY_ON_FAIL(x) \ - if (!(x)) { \ - return absl::nullopt; \ - } + do { \ + if (!(x)) { \ + return absl::nullopt; \ + } \ + } while (0) namespace { const int kMaxPicInitQpDeltaValue = 25; @@ -64,14 +66,14 @@ absl::optional PpsParser::ParsePpsIdFromSlice(const uint8_t* data, uint32_t golomb_tmp; // first_mb_in_slice: ue(v) - if (!slice_reader.ReadExponentialGolomb(&golomb_tmp)) + if (!slice_reader.ReadExponentialGolomb(golomb_tmp)) return absl::nullopt; // slice_type: ue(v) - if (!slice_reader.ReadExponentialGolomb(&golomb_tmp)) + if (!slice_reader.ReadExponentialGolomb(golomb_tmp)) return absl::nullopt; // pic_parameter_set_id: ue(v) uint32_t slice_pps_id; - if (!slice_reader.ReadExponentialGolomb(&slice_pps_id)) + if (!slice_reader.ReadExponentialGolomb(slice_pps_id)) return absl::nullopt; return slice_pps_id; } @@ -86,30 +88,29 @@ absl::optional PpsParser::ParseInternal( uint32_t golomb_ignored; // entropy_coding_mode_flag: u(1) uint32_t entropy_coding_mode_flag; - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(&entropy_coding_mode_flag, 1)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(1, entropy_coding_mode_flag)); pps.entropy_coding_mode_flag = entropy_coding_mode_flag != 0; // bottom_field_pic_order_in_frame_present_flag: u(1) uint32_t bottom_field_pic_order_in_frame_present_flag; RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadBits(&bottom_field_pic_order_in_frame_present_flag, 1)); + bit_buffer->ReadBits(1, bottom_field_pic_order_in_frame_present_flag)); pps.bottom_field_pic_order_in_frame_present_flag = bottom_field_pic_order_in_frame_present_flag != 0; // num_slice_groups_minus1: ue(v) uint32_t num_slice_groups_minus1; RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadExponentialGolomb(&num_slice_groups_minus1)); + bit_buffer->ReadExponentialGolomb(num_slice_groups_minus1)); if (num_slice_groups_minus1 > 0) { uint32_t slice_group_map_type; // slice_group_map_type: ue(v) RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadExponentialGolomb(&slice_group_map_type)); + bit_buffer->ReadExponentialGolomb(slice_group_map_type)); if (slice_group_map_type == 0) { for (uint32_t i_group = 0; i_group <= num_slice_groups_minus1; ++i_group) { // run_length_minus1[iGroup]: ue(v) - RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + 
RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); } } else if (slice_group_map_type == 1) { // TODO(sprang): Implement support for dispersed slice group map type. @@ -118,23 +119,21 @@ absl::optional PpsParser::ParseInternal( for (uint32_t i_group = 0; i_group <= num_slice_groups_minus1; ++i_group) { // top_left[iGroup]: ue(v) - RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); // bottom_right[iGroup]: ue(v) - RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); } } else if (slice_group_map_type == 3 || slice_group_map_type == 4 || slice_group_map_type == 5) { // slice_group_change_direction_flag: u(1) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(&bits_tmp, 1)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(1, bits_tmp)); // slice_group_change_rate_minus1: ue(v) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); } else if (slice_group_map_type == 6) { // pic_size_in_map_units_minus1: ue(v) uint32_t pic_size_in_map_units_minus1; RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadExponentialGolomb(&pic_size_in_map_units_minus1)); + bit_buffer->ReadExponentialGolomb(pic_size_in_map_units_minus1)); uint32_t slice_group_id_bits = 0; uint32_t num_slice_groups = num_slice_groups_minus1 + 1; // If num_slice_groups is not a power of two an additional bit is required @@ -149,39 +148,39 @@ absl::optional PpsParser::ParseInternal( // slice_group_id[i]: u(v) // Represented by ceil(log2(num_slice_groups_minus1 + 1)) bits. RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadBits(&bits_tmp, slice_group_id_bits)); + bit_buffer->ReadBits(slice_group_id_bits, bits_tmp)); } } } // num_ref_idx_l0_default_active_minus1: ue(v) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); // num_ref_idx_l1_default_active_minus1: ue(v) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); // weighted_pred_flag: u(1) uint32_t weighted_pred_flag; - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(&weighted_pred_flag, 1)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(1, weighted_pred_flag)); pps.weighted_pred_flag = weighted_pred_flag != 0; // weighted_bipred_idc: u(2) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(&pps.weighted_bipred_idc, 2)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(2, pps.weighted_bipred_idc)); // pic_init_qp_minus26: se(v) RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadSignedExponentialGolomb(&pps.pic_init_qp_minus26)); + bit_buffer->ReadSignedExponentialGolomb(pps.pic_init_qp_minus26)); // Sanity-check parsed value if (pps.pic_init_qp_minus26 > kMaxPicInitQpDeltaValue || pps.pic_init_qp_minus26 < kMinPicInitQpDeltaValue) { RETURN_EMPTY_ON_FAIL(false); } // pic_init_qs_minus26: se(v) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); // chroma_qp_index_offset: se(v) - RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadExponentialGolomb(golomb_ignored)); // deblocking_filter_control_present_flag: u(1) // constrained_intra_pred_flag: u(1) - 
RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(&bits_tmp, 2)); + RETURN_EMPTY_ON_FAIL(bit_buffer->ReadBits(2, bits_tmp)); // redundant_pic_cnt_present_flag: u(1) RETURN_EMPTY_ON_FAIL( - bit_buffer->ReadBits(&pps.redundant_pic_cnt_present_flag, 1)); + bit_buffer->ReadBits(1, pps.redundant_pic_cnt_present_flag)); return pps; } @@ -189,11 +188,15 @@ absl::optional PpsParser::ParseInternal( bool PpsParser::ParsePpsIdsInternal(rtc::BitBuffer* bit_buffer, uint32_t* pps_id, uint32_t* sps_id) { + if (pps_id == nullptr) + return false; // pic_parameter_set_id: ue(v) - if (!bit_buffer->ReadExponentialGolomb(pps_id)) + if (!bit_buffer->ReadExponentialGolomb(*pps_id)) + return false; + if (sps_id == nullptr) return false; // seq_parameter_set_id: ue(v) - if (!bit_buffer->ReadExponentialGolomb(sps_id)) + if (!bit_buffer->ReadExponentialGolomb(*sps_id)) return false; return true; } diff --git a/common_video/h264/profile_level_id.h b/common_video/h264/profile_level_id.h deleted file mode 100644 index 07b49e57c7..0000000000 --- a/common_video/h264/profile_level_id.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef COMMON_VIDEO_H264_PROFILE_LEVEL_ID_H_ -#define COMMON_VIDEO_H264_PROFILE_LEVEL_ID_H_ - -#include "media/base/h264_profile_level_id.h" - -// TODO(zhihuang): Delete this file once dependent applications switch to -// including "webrtc/media/base/h264_profile_level_id.h" directly. - -#endif // COMMON_VIDEO_H264_PROFILE_LEVEL_ID_H_ diff --git a/common_video/h264/profile_level_id_unittest.cc b/common_video/h264/profile_level_id_unittest.cc deleted file mode 100644 index 957b434a3c..0000000000 --- a/common_video/h264/profile_level_id_unittest.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "common_video/h264/profile_level_id.h" - -#include -#include - -#include "absl/types/optional.h" -#include "media/base/h264_profile_level_id.h" -#include "test/gtest.h" - -namespace webrtc { -namespace H264 { - -TEST(H264ProfileLevelId, TestParsingInvalid) { - // Malformed strings. - EXPECT_FALSE(ParseProfileLevelId("")); - EXPECT_FALSE(ParseProfileLevelId(" 42e01f")); - EXPECT_FALSE(ParseProfileLevelId("4242e01f")); - EXPECT_FALSE(ParseProfileLevelId("e01f")); - EXPECT_FALSE(ParseProfileLevelId("gggggg")); - - // Invalid level. - EXPECT_FALSE(ParseProfileLevelId("42e000")); - EXPECT_FALSE(ParseProfileLevelId("42e00f")); - EXPECT_FALSE(ParseProfileLevelId("42e0ff")); - - // Invalid profile. 
- EXPECT_FALSE(ParseProfileLevelId("42e11f")); - EXPECT_FALSE(ParseProfileLevelId("58601f")); - EXPECT_FALSE(ParseProfileLevelId("64e01f")); -} - -TEST(H264ProfileLevelId, TestParsingLevel) { - EXPECT_EQ(kLevel3_1, ParseProfileLevelId("42e01f")->level); - EXPECT_EQ(kLevel1_1, ParseProfileLevelId("42e00b")->level); - EXPECT_EQ(kLevel1_b, ParseProfileLevelId("42f00b")->level); - EXPECT_EQ(kLevel4_2, ParseProfileLevelId("42C02A")->level); - EXPECT_EQ(kLevel5_2, ParseProfileLevelId("640c34")->level); -} - -TEST(H264ProfileLevelId, TestParsingConstrainedBaseline) { - EXPECT_EQ(kProfileConstrainedBaseline, - ParseProfileLevelId("42e01f")->profile); - EXPECT_EQ(kProfileConstrainedBaseline, - ParseProfileLevelId("42C02A")->profile); - EXPECT_EQ(kProfileConstrainedBaseline, - ParseProfileLevelId("4de01f")->profile); - EXPECT_EQ(kProfileConstrainedBaseline, - ParseProfileLevelId("58f01f")->profile); -} - -TEST(H264ProfileLevelId, TestParsingBaseline) { - EXPECT_EQ(kProfileBaseline, ParseProfileLevelId("42a01f")->profile); - EXPECT_EQ(kProfileBaseline, ParseProfileLevelId("58A01F")->profile); -} - -TEST(H264ProfileLevelId, TestParsingMain) { - EXPECT_EQ(kProfileMain, ParseProfileLevelId("4D401f")->profile); -} - -TEST(H264ProfileLevelId, TestParsingHigh) { - EXPECT_EQ(kProfileHigh, ParseProfileLevelId("64001f")->profile); -} - -TEST(H264ProfileLevelId, TestParsingConstrainedHigh) { - EXPECT_EQ(kProfileConstrainedHigh, ParseProfileLevelId("640c1f")->profile); -} - -TEST(H264ProfileLevelId, TestSupportedLevel) { - EXPECT_EQ(kLevel2_1, *SupportedLevel(640 * 480, 25)); - EXPECT_EQ(kLevel3_1, *SupportedLevel(1280 * 720, 30)); - EXPECT_EQ(kLevel4_2, *SupportedLevel(1920 * 1280, 60)); -} - -// Test supported level below level 1 requirements. -TEST(H264ProfileLevelId, TestSupportedLevelInvalid) { - EXPECT_FALSE(SupportedLevel(0, 0)); - // All levels support fps > 5. - EXPECT_FALSE(SupportedLevel(1280 * 720, 5)); - // All levels support frame sizes > 183 * 137. 
- EXPECT_FALSE(SupportedLevel(183 * 137, 30)); -} - -TEST(H264ProfileLevelId, TestToString) { - EXPECT_EQ("42e01f", *ProfileLevelIdToString(ProfileLevelId( - kProfileConstrainedBaseline, kLevel3_1))); - EXPECT_EQ("42000a", - *ProfileLevelIdToString(ProfileLevelId(kProfileBaseline, kLevel1))); - EXPECT_EQ("4d001f", - ProfileLevelIdToString(ProfileLevelId(kProfileMain, kLevel3_1))); - EXPECT_EQ("640c2a", *ProfileLevelIdToString( - ProfileLevelId(kProfileConstrainedHigh, kLevel4_2))); - EXPECT_EQ("64002a", - *ProfileLevelIdToString(ProfileLevelId(kProfileHigh, kLevel4_2))); -} - -TEST(H264ProfileLevelId, TestToStringLevel1b) { - EXPECT_EQ("42f00b", *ProfileLevelIdToString(ProfileLevelId( - kProfileConstrainedBaseline, kLevel1_b))); - EXPECT_EQ("42100b", *ProfileLevelIdToString( - ProfileLevelId(kProfileBaseline, kLevel1_b))); - EXPECT_EQ("4d100b", - *ProfileLevelIdToString(ProfileLevelId(kProfileMain, kLevel1_b))); -} - -TEST(H264ProfileLevelId, TestToStringRoundTrip) { - EXPECT_EQ("42e01f", *ProfileLevelIdToString(*ParseProfileLevelId("42e01f"))); - EXPECT_EQ("42e01f", *ProfileLevelIdToString(*ParseProfileLevelId("42E01F"))); - EXPECT_EQ("4d100b", *ProfileLevelIdToString(*ParseProfileLevelId("4d100b"))); - EXPECT_EQ("4d100b", *ProfileLevelIdToString(*ParseProfileLevelId("4D100B"))); - EXPECT_EQ("640c2a", *ProfileLevelIdToString(*ParseProfileLevelId("640c2a"))); - EXPECT_EQ("640c2a", *ProfileLevelIdToString(*ParseProfileLevelId("640C2A"))); -} - -TEST(H264ProfileLevelId, TestToStringInvalid) { - EXPECT_FALSE(ProfileLevelIdToString(ProfileLevelId(kProfileHigh, kLevel1_b))); - EXPECT_FALSE(ProfileLevelIdToString( - ProfileLevelId(kProfileConstrainedHigh, kLevel1_b))); - EXPECT_FALSE(ProfileLevelIdToString( - ProfileLevelId(static_cast(255), kLevel3_1))); -} - -TEST(H264ProfileLevelId, TestParseSdpProfileLevelIdEmpty) { - const absl::optional profile_level_id = - ParseSdpProfileLevelId(CodecParameterMap()); - EXPECT_TRUE(profile_level_id); - EXPECT_EQ(kProfileConstrainedBaseline, profile_level_id->profile); - EXPECT_EQ(kLevel3_1, profile_level_id->level); -} - -TEST(H264ProfileLevelId, TestParseSdpProfileLevelIdConstrainedHigh) { - CodecParameterMap params; - params["profile-level-id"] = "640c2a"; - const absl::optional profile_level_id = - ParseSdpProfileLevelId(params); - EXPECT_TRUE(profile_level_id); - EXPECT_EQ(kProfileConstrainedHigh, profile_level_id->profile); - EXPECT_EQ(kLevel4_2, profile_level_id->level); -} - -TEST(H264ProfileLevelId, TestParseSdpProfileLevelIdInvalid) { - CodecParameterMap params; - params["profile-level-id"] = "foobar"; - EXPECT_FALSE(ParseSdpProfileLevelId(params)); -} - -TEST(H264ProfileLevelId, TestGenerateProfileLevelIdForAnswerEmpty) { - CodecParameterMap answer_params; - GenerateProfileLevelIdForAnswer(CodecParameterMap(), CodecParameterMap(), - &answer_params); - EXPECT_TRUE(answer_params.empty()); -} - -TEST(H264ProfileLevelId, - TestGenerateProfileLevelIdForAnswerLevelSymmetryCapped) { - CodecParameterMap low_level; - low_level["profile-level-id"] = "42e015"; - CodecParameterMap high_level; - high_level["profile-level-id"] = "42e01f"; - - // Level asymmetry is not allowed; test that answer level is the lower of the - // local and remote levels. 
- CodecParameterMap answer_params; - GenerateProfileLevelIdForAnswer(low_level /* local_supported */, - high_level /* remote_offered */, - &answer_params); - EXPECT_EQ("42e015", answer_params["profile-level-id"]); - - CodecParameterMap answer_params2; - GenerateProfileLevelIdForAnswer(high_level /* local_supported */, - low_level /* remote_offered */, - &answer_params2); - EXPECT_EQ("42e015", answer_params2["profile-level-id"]); -} - -TEST(H264ProfileLevelId, - TestGenerateProfileLevelIdForAnswerConstrainedBaselineLevelAsymmetry) { - CodecParameterMap local_params; - local_params["profile-level-id"] = "42e01f"; - local_params["level-asymmetry-allowed"] = "1"; - CodecParameterMap remote_params; - remote_params["profile-level-id"] = "42e015"; - remote_params["level-asymmetry-allowed"] = "1"; - CodecParameterMap answer_params; - GenerateProfileLevelIdForAnswer(local_params, remote_params, &answer_params); - // When level asymmetry is allowed, we can answer a higher level than what was - // offered. - EXPECT_EQ("42e01f", answer_params["profile-level-id"]); -} - -} // namespace H264 -} // namespace webrtc diff --git a/common_video/h264/sps_parser.cc b/common_video/h264/sps_parser.cc index 3d78184e7a..f505928f29 100644 --- a/common_video/h264/sps_parser.cc +++ b/common_video/h264/sps_parser.cc @@ -71,14 +71,14 @@ absl::optional SpsParser::ParseSpsUpToVui( // profile_idc: u(8). We need it to determine if we need to read/skip chroma // formats. uint8_t profile_idc; - RETURN_EMPTY_ON_FAIL(buffer->ReadUInt8(&profile_idc)); + RETURN_EMPTY_ON_FAIL(buffer->ReadUInt8(profile_idc)); // constraint_set0_flag through constraint_set5_flag + reserved_zero_2bits // 1 bit each for the flags + 2 bits = 8 bits = 1 byte. RETURN_EMPTY_ON_FAIL(buffer->ConsumeBytes(1)); // level_idc: u(8) RETURN_EMPTY_ON_FAIL(buffer->ConsumeBytes(1)); // seq_parameter_set_id: ue(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&sps.id)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(sps.id)); sps.separate_colour_plane_flag = 0; // See if profile_idc has chroma format information. 
if (profile_idc == 100 || profile_idc == 110 || profile_idc == 122 || @@ -86,21 +86,20 @@ absl::optional SpsParser::ParseSpsUpToVui( profile_idc == 86 || profile_idc == 118 || profile_idc == 128 || profile_idc == 138 || profile_idc == 139 || profile_idc == 134) { // chroma_format_idc: ue(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&chroma_format_idc)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(chroma_format_idc)); if (chroma_format_idc == 3) { // separate_colour_plane_flag: u(1) - RETURN_EMPTY_ON_FAIL( - buffer->ReadBits(&sps.separate_colour_plane_flag, 1)); + RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, sps.separate_colour_plane_flag)); } // bit_depth_luma_minus8: ue(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored)); // bit_depth_chroma_minus8: ue(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored)); // qpprime_y_zero_transform_bypass_flag: u(1) RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1)); // seq_scaling_matrix_present_flag: u(1) uint32_t seq_scaling_matrix_present_flag; - RETURN_EMPTY_ON_FAIL(buffer->ReadBits(&seq_scaling_matrix_present_flag, 1)); + RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, seq_scaling_matrix_present_flag)); if (seq_scaling_matrix_present_flag) { // Process the scaling lists just enough to be able to properly // skip over them, so we can still read the resolution on streams @@ -110,7 +109,7 @@ absl::optional SpsParser::ParseSpsUpToVui( // seq_scaling_list_present_flag[i] : u(1) uint32_t seq_scaling_list_present_flags; RETURN_EMPTY_ON_FAIL( - buffer->ReadBits(&seq_scaling_list_present_flags, 1)); + buffer->ReadBits(1, seq_scaling_list_present_flags)); if (seq_scaling_list_present_flags != 0) { int last_scale = 8; int next_scale = 8; @@ -120,7 +119,7 @@ absl::optional SpsParser::ParseSpsUpToVui( int32_t delta_scale; // delta_scale: se(v) RETURN_EMPTY_ON_FAIL( - buffer->ReadSignedExponentialGolomb(&delta_scale)); + buffer->ReadSignedExponentialGolomb(delta_scale)); RETURN_EMPTY_ON_FAIL(delta_scale >= kScalingDeltaMin && delta_scale <= kScaldingDeltaMax); next_scale = (last_scale + delta_scale + 256) % 256; @@ -140,18 +139,18 @@ absl::optional SpsParser::ParseSpsUpToVui( // log2_max_frame_num_minus4: ue(v) uint32_t log2_max_frame_num_minus4; - if (!buffer->ReadExponentialGolomb(&log2_max_frame_num_minus4) || + if (!buffer->ReadExponentialGolomb(log2_max_frame_num_minus4) || log2_max_frame_num_minus4 > kMaxLog2Minus4) { return OptionalSps(); } sps.log2_max_frame_num = log2_max_frame_num_minus4 + 4; // pic_order_cnt_type: ue(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&sps.pic_order_cnt_type)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(sps.pic_order_cnt_type)); if (sps.pic_order_cnt_type == 0) { // log2_max_pic_order_cnt_lsb_minus4: ue(v) uint32_t log2_max_pic_order_cnt_lsb_minus4; - if (!buffer->ReadExponentialGolomb(&log2_max_pic_order_cnt_lsb_minus4) || + if (!buffer->ReadExponentialGolomb(log2_max_pic_order_cnt_lsb_minus4) || log2_max_pic_order_cnt_lsb_minus4 > kMaxLog2Minus4) { return OptionalSps(); } @@ -159,22 +158,22 @@ absl::optional SpsParser::ParseSpsUpToVui( } else if (sps.pic_order_cnt_type == 1) { // delta_pic_order_always_zero_flag: u(1) RETURN_EMPTY_ON_FAIL( - buffer->ReadBits(&sps.delta_pic_order_always_zero_flag, 1)); + buffer->ReadBits(1, sps.delta_pic_order_always_zero_flag)); // offset_for_non_ref_pic: se(v) - 
RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored)); // offset_for_top_to_bottom_field: se(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored)); // num_ref_frames_in_pic_order_cnt_cycle: ue(v) uint32_t num_ref_frames_in_pic_order_cnt_cycle; RETURN_EMPTY_ON_FAIL( - buffer->ReadExponentialGolomb(&num_ref_frames_in_pic_order_cnt_cycle)); + buffer->ReadExponentialGolomb(num_ref_frames_in_pic_order_cnt_cycle)); for (size_t i = 0; i < num_ref_frames_in_pic_order_cnt_cycle; ++i) { // offset_for_ref_frame[i]: se(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&golomb_ignored)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored)); } } // max_num_ref_frames: ue(v) - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&sps.max_num_ref_frames)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(sps.max_num_ref_frames)); // gaps_in_frame_num_value_allowed_flag: u(1) RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1)); // @@ -185,13 +184,13 @@ absl::optional SpsParser::ParseSpsUpToVui( // // pic_width_in_mbs_minus1: ue(v) uint32_t pic_width_in_mbs_minus1; - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&pic_width_in_mbs_minus1)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(pic_width_in_mbs_minus1)); // pic_height_in_map_units_minus1: ue(v) uint32_t pic_height_in_map_units_minus1; RETURN_EMPTY_ON_FAIL( - buffer->ReadExponentialGolomb(&pic_height_in_map_units_minus1)); + buffer->ReadExponentialGolomb(pic_height_in_map_units_minus1)); // frame_mbs_only_flag: u(1) - RETURN_EMPTY_ON_FAIL(buffer->ReadBits(&sps.frame_mbs_only_flag, 1)); + RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, sps.frame_mbs_only_flag)); if (!sps.frame_mbs_only_flag) { // mb_adaptive_frame_field_flag: u(1) RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1)); @@ -207,19 +206,18 @@ absl::optional SpsParser::ParseSpsUpToVui( uint32_t frame_crop_right_offset = 0; uint32_t frame_crop_top_offset = 0; uint32_t frame_crop_bottom_offset = 0; - RETURN_EMPTY_ON_FAIL(buffer->ReadBits(&frame_cropping_flag, 1)); + RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, frame_cropping_flag)); if (frame_cropping_flag) { // frame_crop_{left, right, top, bottom}_offset: ue(v) + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(frame_crop_left_offset)); RETURN_EMPTY_ON_FAIL( - buffer->ReadExponentialGolomb(&frame_crop_left_offset)); + buffer->ReadExponentialGolomb(frame_crop_right_offset)); + RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(frame_crop_top_offset)); RETURN_EMPTY_ON_FAIL( - buffer->ReadExponentialGolomb(&frame_crop_right_offset)); - RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(&frame_crop_top_offset)); - RETURN_EMPTY_ON_FAIL( - buffer->ReadExponentialGolomb(&frame_crop_bottom_offset)); + buffer->ReadExponentialGolomb(frame_crop_bottom_offset)); } // vui_parameters_present_flag: u(1) - RETURN_EMPTY_ON_FAIL(buffer->ReadBits(&sps.vui_params_present, 1)); + RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, sps.vui_params_present)); // Far enough! We don't use the rest of the SPS. 
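
The H264 parser hunks above all follow one mechanical pattern: the bit-reader methods now take the destination by reference (with `ReadBits` taking the bit count first), and `RETURN_EMPTY_ON_FAIL` is wrapped in `do { ... } while (0)` so the macro expands to a single statement and a trailing `else` cannot bind to the `if` hidden inside it. Below is a minimal stand-alone sketch of that pattern only, not the WebRTC classes themselves; `TinyBitReader` and `ParsePicParameterSetId` are hypothetical names invented for illustration, and `std::nullopt` stands in for the `absl::nullopt` used in the tree.

```
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Hardened macro, as in this diff: behaves like one statement everywhere.
#define RETURN_EMPTY_ON_FAIL(x) \
  do {                          \
    if (!(x)) {                 \
      return std::nullopt;      \
    }                           \
  } while (0)

class TinyBitReader {
 public:
  explicit TinyBitReader(std::vector<uint8_t> bytes) : bytes_(std::move(bytes)) {}

  // Reads `bit_count` bits MSB-first into `value`; false on buffer overrun.
  bool ReadBits(size_t bit_count, uint32_t& value) {
    value = 0;
    for (size_t i = 0; i < bit_count; ++i) {
      if (pos_ >= 8 * bytes_.size())
        return false;
      value = (value << 1) | ((bytes_[pos_ / 8] >> (7 - pos_ % 8)) & 1);
      ++pos_;
    }
    return true;
  }

  // Reads an unsigned exponential-Golomb (ue(v)) value into `value`.
  bool ReadExponentialGolomb(uint32_t& value) {
    size_t leading_zeros = 0;
    uint32_t bit = 0;
    do {
      if (!ReadBits(1, bit))
        return false;
      if (bit == 0)
        ++leading_zeros;
    } while (bit == 0);
    uint32_t suffix = 0;
    if (!ReadBits(leading_zeros, suffix))
      return false;
    value = (1u << leading_zeros) - 1 + suffix;
    return true;
  }

 private:
  std::vector<uint8_t> bytes_;
  size_t pos_ = 0;  // Current bit offset into bytes_.
};

// Same shape as the parsers above: read each syntax element through the
// macro and bail out with an empty optional on the first failure.
std::optional<uint32_t> ParsePicParameterSetId(std::vector<uint8_t> payload) {
  TinyBitReader reader(std::move(payload));
  uint32_t flag = 0;
  RETURN_EMPTY_ON_FAIL(reader.ReadBits(1, flag));              // some u(1) flag
  uint32_t pps_id = 0;
  RETURN_EMPTY_ON_FAIL(reader.ReadExponentialGolomb(pps_id));  // ue(v)
  return pps_id;
}

int main() {
  // Bits 1 | 010...: flag = 1, then ue(v) "010" decodes to 1.
  std::optional<uint32_t> pps_id = ParsePicParameterSetId({0b10100000});
  if (pps_id)
    std::printf("pps_id = %u\n", *pps_id);
}
```
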
diff --git a/common_video/h264/sps_vui_rewriter.cc b/common_video/h264/sps_vui_rewriter.cc index 0d16be8254..856b012b32 100644 --- a/common_video/h264/sps_vui_rewriter.cc +++ b/common_video/h264/sps_vui_rewriter.cc @@ -45,29 +45,31 @@ enum SpsValidEvent { kSpsRewrittenMax = 8 }; -#define RETURN_FALSE_ON_FAIL(x) \ - if (!(x)) { \ - RTC_LOG_F(LS_ERROR) << " (line:" << __LINE__ << ") FAILED: " #x; \ - return false; \ - } +#define RETURN_FALSE_ON_FAIL(x) \ + do { \ + if (!(x)) { \ + RTC_LOG_F(LS_ERROR) << " (line:" << __LINE__ << ") FAILED: " #x; \ + return false; \ + } \ + } while (0) #define COPY_UINT8(src, dest, tmp) \ do { \ - RETURN_FALSE_ON_FAIL((src)->ReadUInt8(&tmp)); \ + RETURN_FALSE_ON_FAIL((src)->ReadUInt8(tmp)); \ if (dest) \ RETURN_FALSE_ON_FAIL((dest)->WriteUInt8(tmp)); \ } while (0) #define COPY_EXP_GOLOMB(src, dest, tmp) \ do { \ - RETURN_FALSE_ON_FAIL((src)->ReadExponentialGolomb(&tmp)); \ + RETURN_FALSE_ON_FAIL((src)->ReadExponentialGolomb(tmp)); \ if (dest) \ RETURN_FALSE_ON_FAIL((dest)->WriteExponentialGolomb(tmp)); \ } while (0) #define COPY_BITS(src, dest, tmp, bits) \ do { \ - RETURN_FALSE_ON_FAIL((src)->ReadBits(&tmp, bits)); \ + RETURN_FALSE_ON_FAIL((src)->ReadBits(bits, tmp)); \ if (dest) \ RETURN_FALSE_ON_FAIL((dest)->WriteBits(tmp, bits)); \ } while (0) @@ -369,7 +371,7 @@ bool CopyAndRewriteVui(const SpsParser::SpsState& sps, // bitstream_restriction_flag: u(1) uint32_t bitstream_restriction_flag; - RETURN_FALSE_ON_FAIL(source->ReadBits(&bitstream_restriction_flag, 1)); + RETURN_FALSE_ON_FAIL(source->ReadBits(1, bitstream_restriction_flag)); RETURN_FALSE_ON_FAIL(destination->WriteBits(1, 1)); if (bitstream_restriction_flag == 0) { // We're adding one from scratch. @@ -396,9 +398,9 @@ bool CopyAndRewriteVui(const SpsParser::SpsState& sps, // want, then we don't need to be rewriting. 
uint32_t max_num_reorder_frames, max_dec_frame_buffering; RETURN_FALSE_ON_FAIL( - source->ReadExponentialGolomb(&max_num_reorder_frames)); + source->ReadExponentialGolomb(max_num_reorder_frames)); RETURN_FALSE_ON_FAIL( - source->ReadExponentialGolomb(&max_dec_frame_buffering)); + source->ReadExponentialGolomb(max_dec_frame_buffering)); RETURN_FALSE_ON_FAIL(destination->WriteExponentialGolomb(0)); RETURN_FALSE_ON_FAIL( destination->WriteExponentialGolomb(sps.max_num_ref_frames)); @@ -511,15 +513,15 @@ bool CopyOrRewriteVideoSignalTypeInfo( uint8_t colour_primaries = 3; // H264 default: unspecified uint8_t transfer_characteristics = 3; // H264 default: unspecified uint8_t matrix_coefficients = 3; // H264 default: unspecified - RETURN_FALSE_ON_FAIL(source->ReadBits(&video_signal_type_present_flag, 1)); + RETURN_FALSE_ON_FAIL(source->ReadBits(1, video_signal_type_present_flag)); if (video_signal_type_present_flag) { - RETURN_FALSE_ON_FAIL(source->ReadBits(&video_format, 3)); - RETURN_FALSE_ON_FAIL(source->ReadBits(&video_full_range_flag, 1)); - RETURN_FALSE_ON_FAIL(source->ReadBits(&colour_description_present_flag, 1)); + RETURN_FALSE_ON_FAIL(source->ReadBits(3, video_format)); + RETURN_FALSE_ON_FAIL(source->ReadBits(1, video_full_range_flag)); + RETURN_FALSE_ON_FAIL(source->ReadBits(1, colour_description_present_flag)); if (colour_description_present_flag) { - RETURN_FALSE_ON_FAIL(source->ReadUInt8(&colour_primaries)); - RETURN_FALSE_ON_FAIL(source->ReadUInt8(&transfer_characteristics)); - RETURN_FALSE_ON_FAIL(source->ReadUInt8(&matrix_coefficients)); + RETURN_FALSE_ON_FAIL(source->ReadUInt8(colour_primaries)); + RETURN_FALSE_ON_FAIL(source->ReadUInt8(transfer_characteristics)); + RETURN_FALSE_ON_FAIL(source->ReadUInt8(matrix_coefficients)); } } diff --git a/common_video/include/incoming_video_stream.h b/common_video/include/incoming_video_stream.h index a779f2c622..d616c5a2ec 100644 --- a/common_video/include/incoming_video_stream.h +++ b/common_video/include/incoming_video_stream.h @@ -13,6 +13,7 @@ #include +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" @@ -20,7 +21,6 @@ #include "rtc_base/race_checker.h" #include "rtc_base/task_queue.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -35,7 +35,7 @@ class IncomingVideoStream : public rtc::VideoSinkInterface { void OnFrame(const VideoFrame& video_frame) override; void Dequeue(); - rtc::ThreadChecker main_thread_checker_; + SequenceChecker main_thread_checker_; rtc::RaceChecker decoder_race_checker_; VideoRenderFrames render_buffers_ RTC_GUARDED_BY(&incoming_render_queue_); diff --git a/common_video/include/video_frame_buffer.h b/common_video/include/video_frame_buffer.h index bc70f34ec8..593464abe4 100644 --- a/common_video/include/video_frame_buffer.h +++ b/common_video/include/video_frame_buffer.h @@ -12,10 +12,10 @@ #define COMMON_VIDEO_INCLUDE_VIDEO_FRAME_BUFFER_H_ #include +#include #include "api/scoped_refptr.h" #include "api/video/video_frame_buffer.h" -#include "rtc_base/callback.h" #include "rtc_base/ref_counted_object.h" namespace webrtc { @@ -29,7 +29,7 @@ rtc::scoped_refptr WrapI420Buffer( int u_stride, const uint8_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used); + std::function no_longer_used); rtc::scoped_refptr WrapI444Buffer( int width, @@ -40,7 +40,7 @@ rtc::scoped_refptr WrapI444Buffer( int u_stride, const uint8_t* v_plane, int v_stride, - 
const rtc::Callback0& no_longer_used); + std::function no_longer_used); rtc::scoped_refptr WrapI420ABuffer( int width, @@ -53,7 +53,7 @@ rtc::scoped_refptr WrapI420ABuffer( int v_stride, const uint8_t* a_plane, int a_stride, - const rtc::Callback0& no_longer_used); + std::function no_longer_used); rtc::scoped_refptr WrapYuvBuffer( VideoFrameBuffer::Type type, @@ -65,7 +65,7 @@ rtc::scoped_refptr WrapYuvBuffer( int u_stride, const uint8_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used); + std::function no_longer_used); rtc::scoped_refptr WrapI010Buffer( int width, @@ -76,7 +76,7 @@ rtc::scoped_refptr WrapI010Buffer( int u_stride, const uint16_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used); + std::function no_longer_used); } // namespace webrtc diff --git a/common_video/video_frame_buffer.cc b/common_video/video_frame_buffer.cc index 8bbe7c8728..78a126419a 100644 --- a/common_video/video_frame_buffer.cc +++ b/common_video/video_frame_buffer.cc @@ -30,7 +30,7 @@ class WrappedYuvBuffer : public Base { int u_stride, const uint8_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used) + std::function no_longer_used) : width_(width), height_(height), y_plane_(y_plane), @@ -70,7 +70,7 @@ class WrappedYuvBuffer : public Base { const int y_stride_; const int u_stride_; const int v_stride_; - rtc::Callback0 no_longer_used_cb_; + std::function no_longer_used_cb_; }; // Template to implement a wrapped buffer for a I4??BufferInterface. @@ -87,7 +87,7 @@ class WrappedYuvaBuffer : public WrappedYuvBuffer { int v_stride, const uint8_t* a_plane, int a_stride, - const rtc::Callback0& no_longer_used) + std::function no_longer_used) : WrappedYuvBuffer(width, height, y_plane, @@ -136,7 +136,7 @@ class WrappedYuv16BBuffer : public Base { int u_stride, const uint16_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used) + std::function no_longer_used) : width_(width), height_(height), y_plane_(y_plane), @@ -176,7 +176,7 @@ class WrappedYuv16BBuffer : public Base { const int y_stride_; const int u_stride_; const int v_stride_; - rtc::Callback0 no_longer_used_cb_; + std::function no_longer_used_cb_; }; class I010BufferBase : public I010BufferInterface { @@ -206,9 +206,9 @@ rtc::scoped_refptr WrapI420Buffer( int u_stride, const uint8_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used) { + std::function no_longer_used) { return rtc::scoped_refptr( - new rtc::RefCountedObject>( + rtc::make_ref_counted>( width, height, y_plane, y_stride, u_plane, u_stride, v_plane, v_stride, no_longer_used)); } @@ -224,9 +224,9 @@ rtc::scoped_refptr WrapI420ABuffer( int v_stride, const uint8_t* a_plane, int a_stride, - const rtc::Callback0& no_longer_used) { + std::function no_longer_used) { return rtc::scoped_refptr( - new rtc::RefCountedObject>( + rtc::make_ref_counted>( width, height, y_plane, y_stride, u_plane, u_stride, v_plane, v_stride, a_plane, a_stride, no_longer_used)); } @@ -240,9 +240,9 @@ rtc::scoped_refptr WrapI444Buffer( int u_stride, const uint8_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used) { + std::function no_longer_used) { return rtc::scoped_refptr( - new rtc::RefCountedObject>( + rtc::make_ref_counted>( width, height, y_plane, y_stride, u_plane, u_stride, v_plane, v_stride, no_longer_used)); } @@ -257,7 +257,7 @@ rtc::scoped_refptr WrapYuvBuffer( int u_stride, const uint8_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used) { + std::function no_longer_used) { switch (type) { case VideoFrameBuffer::Type::kI420: return 
WrapI420Buffer(width, height, y_plane, y_stride, u_plane, u_stride, @@ -279,9 +279,9 @@ rtc::scoped_refptr WrapI010Buffer( int u_stride, const uint16_t* v_plane, int v_stride, - const rtc::Callback0& no_longer_used) { + std::function no_longer_used) { return rtc::scoped_refptr( - new rtc::RefCountedObject>( + rtc::make_ref_counted>( width, height, y_plane, y_stride, u_plane, u_stride, v_plane, v_stride, no_longer_used)); } diff --git a/common_video/video_frame_buffer_pool.cc b/common_video/video_frame_buffer_pool.cc index 6df240d9fe..d225370a4d 100644 --- a/common_video/video_frame_buffer_pool.cc +++ b/common_video/video_frame_buffer_pool.cc @@ -107,7 +107,7 @@ rtc::scoped_refptr VideoFrameBufferPool::CreateI420Buffer( return nullptr; // Allocate new buffer. rtc::scoped_refptr buffer = - new rtc::RefCountedObject(width, height); + rtc::make_ref_counted(width, height); if (zero_initialize_) buffer->InitializeData(); @@ -138,7 +138,7 @@ rtc::scoped_refptr VideoFrameBufferPool::CreateNV12Buffer( return nullptr; // Allocate new buffer. rtc::scoped_refptr buffer = - new rtc::RefCountedObject(width, height); + rtc::make_ref_counted(width, height); if (zero_initialize_) buffer->InitializeData(); diff --git a/common_video/video_frame_unittest.cc b/common_video/video_frame_unittest.cc index 9a7a5e2b7c..b82c14716c 100644 --- a/common_video/video_frame_unittest.cc +++ b/common_video/video_frame_unittest.cc @@ -16,7 +16,6 @@ #include "api/video/i010_buffer.h" #include "api/video/i420_buffer.h" #include "api/video/nv12_buffer.h" -#include "rtc_base/bind.h" #include "rtc_base/time_utils.h" #include "test/fake_texture_frame.h" #include "test/frame_utils.h" diff --git a/docs/native-code/development/index.md b/docs/native-code/development/index.md index 3c7a5342da..b19a15ca5e 100644 --- a/docs/native-code/development/index.md +++ b/docs/native-code/development/index.md @@ -98,6 +98,12 @@ configuration untouched (stored in the args.gn file), do: $ gn clean out/Default ``` +To build the fuzzers residing in the [test/fuzzers][fuzzers] directory, use +``` +$ gn gen out/fuzzers --args='use_libfuzzer=true optimize_for_fuzzing=true' +``` +Depending on the fuzzer additional arguments like `is_asan`, `is_msan` or `is_ubsan_security` might be required. + See the [GN][gn-doc] documentation for all available options. There are also more platform specific tips on the [Android][webrtc-android-development] and [iOS][webrtc-ios-development] instructions. @@ -113,6 +119,14 @@ For [Ninja][ninja] project files generated in `out/Default`: $ ninja -C out/Default ``` +To build everything in the generated folder (`out/Default`): + +``` +$ ninja all -C out/Default +``` + +See [Ninja build rules][ninja-build-rules] to read more about difference between `ninja` and `ninja all`. + ## Using Another Build System @@ -256,6 +270,7 @@ Target name `turnserver`. Used for unit tests. [ninja]: https://ninja-build.org/ +[ninja-build-rules]: https://gn.googlesource.com/gn/+/master/docs/reference.md#the-all-and-default-rules [gn]: https://gn.googlesource.com/gn/+/master/README.md [gn-doc]: https://gn.googlesource.com/gn/+/master/docs/reference.md#IDE-options [webrtc-android-development]: https://webrtc.googlesource.com/src/+/refs/heads/master/docs/native-code/android/index.md @@ -268,3 +283,4 @@ Target name `turnserver`. Used for unit tests. 
 [rfc-5766]: https://tools.ietf.org/html/rfc5766
 [m80-log]: https://webrtc.googlesource.com/src/+log/branch-heads/3987
 [m80]: https://webrtc.googlesource.com/src/+/branch-heads/3987
+[fuzzers]: https://chromium.googlesource.com/external/webrtc/+/refs/heads/master/test/fuzzers/
diff --git a/docs/native-code/rtp-hdrext/video-frame-tracking-id/README.md b/docs/native-code/rtp-hdrext/video-frame-tracking-id/README.md
new file mode 100644
index 0000000000..d1c609744e
--- /dev/null
+++ b/docs/native-code/rtp-hdrext/video-frame-tracking-id/README.md
@@ -0,0 +1,27 @@
+# Video Frame Tracking Id
+
+The Video Frame Tracking Id extension is meant for media quality testing
+purposes and shouldn't be used in production. It tracks the webrtc::VideoFrame
+id field from the sender to the receiver to compute reference-based media
+quality metrics such as PSNR or SSIM.
+Contact for more info.
+
+**Name:** "Video Frame Tracking Id"
+
+**Formal name:**
+
+
+**Status:** This extension is defined to allow for media quality testing. It is
+enabled by using a field trial and should only be used in a testing environment.
+
+### Data layout overview
+ 1-byte header + 2 bytes of data:
+
+      0                   1                   2
+      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3
+     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     |  ID   |  L=1  |    video-frame-tracking-id    |
+     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+Notes: The extension should be present only in the first packet of each frame.
+If attached to other packets it can be ignored.
\ No newline at end of file
diff --git a/docs/native-code/rtp-hdrext/video-layers-allocation00/README.md b/docs/native-code/rtp-hdrext/video-layers-allocation00/README.md
index 5c98610fcf..f367adab4c 100644
--- a/docs/native-code/rtp-hdrext/video-layers-allocation00/README.md
+++ b/docs/native-code/rtp-hdrext/video-layers-allocation00/README.md
@@ -80,3 +80,5 @@ extension size.
 Encoded (width - 1), 16-bit, (height - 1), 16-bit, max frame rate 8-bit per
 spatial layer per RTP stream.
 Values are stored in (RTP stream id, spatial id) ascending order.
+An empty layer allocation (i.e. nothing sent on the SSRC) is encoded as a
+special case with a single 0 byte.
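
Since the Video Frame Tracking Id README added above documents only the wire layout, here is a small stand-alone C++ sketch of how a sender and receiver could pack and unpack that one-byte-header element. It is illustrative only, not WebRTC's actual serializer: the function names are invented, and the big-endian ordering of the 16-bit id is an assumption based on the usual network byte order for RTP header extensions.

```
#include <cstdint>
#include <cstdio>

// Writes the 3-byte element shown in the diagram: 4-bit extension ID,
// 4-bit L (= length - 1, so L=1 means two data bytes), then the 16-bit
// tracking id, assumed here to be in network (big-endian) byte order.
void WriteTrackingIdElement(uint8_t ext_id, uint16_t tracking_id, uint8_t out[3]) {
  const uint8_t kLengthMinusOne = 1;
  out[0] = static_cast<uint8_t>((ext_id << 4) | kLengthMinusOne);
  out[1] = static_cast<uint8_t>(tracking_id >> 8);
  out[2] = static_cast<uint8_t>(tracking_id & 0xFF);
}

// Parses the element back; returns false if the length nibble is not the
// expected L=1 (two data bytes).
bool ParseTrackingIdElement(const uint8_t in[3], uint8_t& ext_id, uint16_t& tracking_id) {
  if ((in[0] & 0x0F) != 1)
    return false;
  ext_id = in[0] >> 4;
  tracking_id = static_cast<uint16_t>((in[1] << 8) | in[2]);
  return true;
}

int main() {
  uint8_t buf[3];
  WriteTrackingIdElement(/*ext_id=*/7, /*tracking_id=*/4711, buf);
  uint8_t ext_id = 0;
  uint16_t tracking_id = 0;
  if (ParseTrackingIdElement(buf, ext_id, tracking_id))
    std::printf("ext id %d, tracking id %d\n", ext_id, tracking_id);
}
```
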
diff --git a/docs/release-notes.md b/docs/release-notes.md index 5bb501b781..5f77b9eb6e 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -8,6 +8,8 @@ To find out the current release and schedule, refer to the [chromium dashboard](https://chromiumdash.appspot.com/schedule) ## List of releases + * [M89 Release Notes](https://groups.google.com/g/discuss-webrtc/c/Zrsn2hi8FV0/m/KIbn0EZPBQAJ) + * [M88 Release Notes](https://groups.google.com/g/discuss-webrtc/c/A0FjOcTW2c0/m/UAv-veyPCAAJ) * [M87 Release Notes](https://groups.google.com/g/discuss-webrtc/c/6VmKkCjRK0k/m/YyOTQyQ5AAAJ) * [M86 Release Notes](https://groups.google.com/g/discuss-webrtc/c/pKCOpi9Llyc/m/QhZjyE02BgAJ) * [M85 Release Notes](https://groups.google.com/d/msg/discuss-webrtc/Qq3nsR2w2HU/7WGLPscPBwAJ) diff --git a/examples/BUILD.gn b/examples/BUILD.gn index c2678962d7..b109d903e5 100644 --- a/examples/BUILD.gn +++ b/examples/BUILD.gn @@ -253,8 +253,6 @@ if (is_ios || (is_mac && target_cpu != "x86")) { "objc/AppRTCMobile/ARDAppClient.m", "objc/AppRTCMobile/ARDAppEngineClient.h", "objc/AppRTCMobile/ARDAppEngineClient.m", - "objc/AppRTCMobile/ARDBitrateTracker.h", - "objc/AppRTCMobile/ARDBitrateTracker.m", "objc/AppRTCMobile/ARDCaptureController.h", "objc/AppRTCMobile/ARDCaptureController.m", "objc/AppRTCMobile/ARDExternalSampleCapturer.h", @@ -344,14 +342,14 @@ if (is_ios || (is_mac && target_cpu != "x86")) { "../sdk:base_objc", "../sdk:helpers_objc", "../sdk:mediaconstraints_objc", + "../sdk:metal_objc", "../sdk:peerconnectionfactory_base_objc", "../sdk:peerconnectionfactory_base_objc", - "../sdk:ui_objc", "../sdk:videocapture_objc", "../sdk:videocodec_objc", ] - if (rtc_use_metal_rendering) { - deps += [ "../sdk:metal_objc" ] + if (rtc_ios_macos_use_opengl_rendering) { + deps += [ "../sdk:opengl_ui_objc" ] } frameworks = [ "AVFoundation.framework" ] @@ -501,14 +499,14 @@ if (is_ios || (is_mac && target_cpu != "x86")) { "../sdk:base_objc", "../sdk:default_codec_factory_objc", "../sdk:helpers_objc", + "../sdk:metal_objc", "../sdk:native_api", - "../sdk:ui_objc", "../sdk:videocapture_objc", "../sdk:videotoolbox_objc", ] - if (current_cpu == "arm64") { - deps += [ "../sdk:metal_objc" ] + if (rtc_ios_macos_use_opengl_rendering) { + deps += [ "../sdk:opengl_ui_objc" ] } } @@ -546,9 +544,9 @@ if (is_ios || (is_mac && target_cpu != "x86")) { "../sdk:helpers_objc", "../sdk:mediaconstraints_objc", "../sdk:metal_objc", + "../sdk:opengl_ui_objc", "../sdk:peerconnectionfactory_base_objc", "../sdk:peerconnectionfactory_base_objc", - "../sdk:ui_objc", "../sdk:videocapture_objc", "../sdk:videocodec_objc", ] @@ -686,6 +684,8 @@ if (is_linux || is_chromeos || is_win) { "../p2p:rtc_p2p", "../pc:video_track_source", "../rtc_base:checks", + "../rtc_base:net_helpers", + "../rtc_base:threading", "../rtc_base/third_party/sigslot", "../system_wrappers:field_trial", "../test:field_trial", @@ -760,6 +760,7 @@ if (is_linux || is_chromeos || is_win) { "peerconnection/server/utils.h", ] deps = [ + "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../system_wrappers:field_trial", "../test:field_trial", @@ -777,7 +778,11 @@ if (is_linux || is_chromeos || is_win) { "../p2p:rtc_p2p", "../pc:rtc_pc", "../rtc_base", + "../rtc_base:ip_address", "../rtc_base:rtc_base_approved", + "../rtc_base:socket_address", + "../rtc_base:socket_server", + "../rtc_base:threading", ] } rtc_executable("stunserver") { @@ -789,6 +794,9 @@ if (is_linux || is_chromeos || is_win) { "../pc:rtc_pc", "../rtc_base", "../rtc_base:rtc_base_approved", + 
"../rtc_base:socket_address", + "../rtc_base:socket_server", + "../rtc_base:threading", ] } } @@ -891,6 +899,7 @@ if (is_android) { ":AppRTCMobile_javalib", "../sdk/android:peerconnection_java", "//base:base_java_test_support", + "//third_party/androidx:androidx_test_core_java", "//third_party/google-truth:google_truth_java", ] @@ -912,6 +921,8 @@ if (!build_with_chromium) { "../rtc_base", "../rtc_base:checks", "../rtc_base:rtc_base_approved", + "../rtc_base:socket_address", + "../rtc_base:threading", "//third_party/abseil-cpp/absl/flags:flag", "//third_party/abseil-cpp/absl/flags:parse", ] diff --git a/examples/aarproject/OWNERS b/examples/aarproject/OWNERS index 3c4e54174e..cf092a316a 100644 --- a/examples/aarproject/OWNERS +++ b/examples/aarproject/OWNERS @@ -1 +1 @@ -sakal@webrtc.org +xalep@webrtc.org diff --git a/examples/androidapp/OWNERS b/examples/androidapp/OWNERS index 299e8b20ec..109bea2725 100644 --- a/examples/androidapp/OWNERS +++ b/examples/androidapp/OWNERS @@ -1,2 +1,2 @@ magjed@webrtc.org -sakal@webrtc.org +xalep@webrtc.org diff --git a/examples/androidapp/res/values/arrays.xml b/examples/androidapp/res/values/arrays.xml index e0e6ccbdc2..4a2948c875 100644 --- a/examples/androidapp/res/values/arrays.xml +++ b/examples/androidapp/res/values/arrays.xml @@ -34,6 +34,7 @@ VP9 H264 Baseline H264 High + AV1 diff --git a/examples/androidapp/src/org/appspot/apprtc/PeerConnectionClient.java b/examples/androidapp/src/org/appspot/apprtc/PeerConnectionClient.java index 8cc487e7b8..b3282a6955 100644 --- a/examples/androidapp/src/org/appspot/apprtc/PeerConnectionClient.java +++ b/examples/androidapp/src/org/appspot/apprtc/PeerConnectionClient.java @@ -36,6 +36,7 @@ import java.util.regex.Pattern; import org.appspot.apprtc.AppRTCClient.SignalingParameters; import org.appspot.apprtc.RecordedAudioToFileController; +import org.webrtc.AddIceObserver; import org.webrtc.AudioSource; import org.webrtc.AudioTrack; import org.webrtc.CameraVideoCapturer; @@ -94,6 +95,8 @@ public class PeerConnectionClient { private static final String VIDEO_CODEC_H264 = "H264"; private static final String VIDEO_CODEC_H264_BASELINE = "H264 Baseline"; private static final String VIDEO_CODEC_H264_HIGH = "H264 High"; + private static final String VIDEO_CODEC_AV1 = "AV1"; + private static final String VIDEO_CODEC_AV1_SDP_CODEC_NAME = "AV1X"; private static final String AUDIO_CODEC_OPUS = "opus"; private static final String AUDIO_CODEC_ISAC = "ISAC"; private static final String VIDEO_CODEC_PARAM_START_BITRATE = "x-google-start-bitrate"; @@ -824,7 +827,16 @@ public void addRemoteIceCandidate(final IceCandidate candidate) { if (queuedRemoteCandidates != null) { queuedRemoteCandidates.add(candidate); } else { - peerConnection.addIceCandidate(candidate); + peerConnection.addIceCandidate(candidate, new AddIceObserver() { + @Override + public void onAddSuccess() { + Log.d(TAG, "Candidate " + candidate + " successfully added."); + } + @Override + public void onAddFailure(String error) { + Log.d(TAG, "Candidate " + candidate + " addition failed: " + error); + } + }); } } }); @@ -976,6 +988,8 @@ private static String getSdpVideoCodecName(PeerConnectionParameters parameters) return VIDEO_CODEC_VP8; case VIDEO_CODEC_VP9: return VIDEO_CODEC_VP9; + case VIDEO_CODEC_AV1: + return VIDEO_CODEC_AV1_SDP_CODEC_NAME; case VIDEO_CODEC_H264_HIGH: case VIDEO_CODEC_H264_BASELINE: return VIDEO_CODEC_H264; @@ -1146,7 +1160,16 @@ private void drainCandidates() { if (queuedRemoteCandidates != null) { Log.d(TAG, "Add " + 
queuedRemoteCandidates.size() + " remote candidates"); for (IceCandidate candidate : queuedRemoteCandidates) { - peerConnection.addIceCandidate(candidate); + peerConnection.addIceCandidate(candidate, new AddIceObserver() { + @Override + public void onAddSuccess() { + Log.d(TAG, "Candidate " + candidate + " successfully added."); + } + @Override + public void onAddFailure(String error) { + Log.d(TAG, "Candidate " + candidate + " addition failed: " + error); + } + }); } queuedRemoteCandidates = null; } @@ -1293,6 +1316,9 @@ public void onRenegotiationNeeded() { @Override public void onAddTrack(final RtpReceiver receiver, final MediaStream[] mediaStreams) {} + + @Override + public void onRemoveTrack(final RtpReceiver receiver) {} } // Implementation detail: handle offer creation/signaling and answer setting, diff --git a/examples/androidjunit/OWNERS b/examples/androidjunit/OWNERS index 3c4e54174e..cf092a316a 100644 --- a/examples/androidjunit/OWNERS +++ b/examples/androidjunit/OWNERS @@ -1 +1 @@ -sakal@webrtc.org +xalep@webrtc.org diff --git a/examples/androidjunit/src/org/appspot/apprtc/BluetoothManagerTest.java b/examples/androidjunit/src/org/appspot/apprtc/BluetoothManagerTest.java index b97f1f0bf6..3060bd7a56 100644 --- a/examples/androidjunit/src/org/appspot/apprtc/BluetoothManagerTest.java +++ b/examples/androidjunit/src/org/appspot/apprtc/BluetoothManagerTest.java @@ -29,6 +29,7 @@ import android.content.IntentFilter; import android.media.AudioManager; import android.util.Log; +import androidx.test.core.app.ApplicationProvider; import java.util.ArrayList; import java.util.List; import org.appspot.apprtc.AppRTCBluetoothManager.State; @@ -36,7 +37,6 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.robolectric.RuntimeEnvironment; import org.robolectric.annotation.Config; import org.robolectric.shadows.ShadowLog; @@ -65,7 +65,7 @@ public class BluetoothManagerTest { @Before public void setUp() { ShadowLog.stream = System.out; - context = RuntimeEnvironment.application; + context = ApplicationProvider.getApplicationContext(); mockedAppRtcAudioManager = mock(AppRTCAudioManager.class); mockedAudioManager = mock(AudioManager.class); mockedBluetoothHeadset = mock(BluetoothHeadset.class); diff --git a/examples/androidnativeapi/OWNERS b/examples/androidnativeapi/OWNERS index 3c4e54174e..cf092a316a 100644 --- a/examples/androidnativeapi/OWNERS +++ b/examples/androidnativeapi/OWNERS @@ -1 +1 @@ -sakal@webrtc.org +xalep@webrtc.org diff --git a/examples/androidnativeapi/jni/android_call_client.cc b/examples/androidnativeapi/jni/android_call_client.cc index f0b060632d..f38de24a3f 100644 --- a/examples/androidnativeapi/jni/android_call_client.cc +++ b/examples/androidnativeapi/jni/android_call_client.cc @@ -179,9 +179,9 @@ void AndroidCallClient::CreatePeerConnection() { config.sdp_semantics = webrtc::SdpSemantics::kUnifiedPlan; // DTLS SRTP has to be disabled for loopback to work. 
config.enable_dtls_srtp = false; - pc_ = pcf_->CreatePeerConnection(config, nullptr /* port_allocator */, - nullptr /* cert_generator */, - pc_observer_.get()); + webrtc::PeerConnectionDependencies deps(pc_observer_.get()); + pc_ = pcf_->CreatePeerConnectionOrError(config, std::move(deps)).MoveValue(); + RTC_LOG(LS_INFO) << "PeerConnection created: " << pc_; rtc::scoped_refptr local_video_track = diff --git a/examples/androidnativeapi/jni/android_call_client.h b/examples/androidnativeapi/jni/android_call_client.h index f3f61a4695..c9153d09bd 100644 --- a/examples/androidnativeapi/jni/android_call_client.h +++ b/examples/androidnativeapi/jni/android_call_client.h @@ -18,8 +18,8 @@ #include "api/peer_connection_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" #include "sdk/android/native_api/jni/scoped_java_ref.h" #include "sdk/android/native_api/video/video_source.h" @@ -46,7 +46,7 @@ class AndroidCallClient { void CreatePeerConnection() RTC_RUN_ON(thread_checker_); void Connect() RTC_RUN_ON(thread_checker_); - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; bool call_started_ RTC_GUARDED_BY(thread_checker_); diff --git a/examples/androidtests/OWNERS b/examples/androidtests/OWNERS index 3c4e54174e..cf092a316a 100644 --- a/examples/androidtests/OWNERS +++ b/examples/androidtests/OWNERS @@ -1 +1 @@ -sakal@webrtc.org +xalep@webrtc.org diff --git a/examples/androidvoip/BUILD.gn b/examples/androidvoip/BUILD.gn index 649e601519..66dde947ac 100644 --- a/examples/androidvoip/BUILD.gn +++ b/examples/androidvoip/BUILD.gn @@ -29,8 +29,8 @@ if (is_android) { "//sdk/android:base_java", "//sdk/android:java_audio_device_module_java", "//sdk/android:video_java", - "//third_party/android_deps:androidx_core_core_java", - "//third_party/android_deps:androidx_legacy_legacy_support_v4_java", + "//third_party/androidx:androidx_core_core_java", + "//third_party/androidx:androidx_legacy_legacy_support_v4_java", ] shared_libraries = [ ":examples_androidvoip_jni" ] @@ -56,6 +56,9 @@ if (is_android) { deps = [ ":generated_jni", + "../../rtc_base:socket_address", + "../../rtc_base:socket_server", + "../../rtc_base:threading", "//api:transport_api", "//api/audio_codecs:audio_codecs_api", "//api/audio_codecs:builtin_audio_decoder_factory", diff --git a/examples/androidvoip/OWNERS b/examples/androidvoip/OWNERS index 0fe5182450..e7d3200562 100644 --- a/examples/androidvoip/OWNERS +++ b/examples/androidvoip/OWNERS @@ -1,2 +1,2 @@ natim@webrtc.org -sakal@webrtc.org +xalep@webrtc.org diff --git a/examples/objc/AppRTCMobile/ARDAppClient.h b/examples/objc/AppRTCMobile/ARDAppClient.h index 1fed247060..8e124ed925 100644 --- a/examples/objc/AppRTCMobile/ARDAppClient.h +++ b/examples/objc/AppRTCMobile/ARDAppClient.h @@ -48,7 +48,7 @@ typedef NS_ENUM(NSInteger, ARDAppClientState) { - (void)appClient:(ARDAppClient *)client didError:(NSError *)error; -- (void)appClient:(ARDAppClient *)client didGetStats:(NSArray *)stats; +- (void)appClient:(ARDAppClient *)client didGetStats:(RTC_OBJC_TYPE(RTCStatisticsReport) *)stats; @optional - (void)appClient:(ARDAppClient *)client diff --git a/examples/objc/AppRTCMobile/ARDAppClient.m b/examples/objc/AppRTCMobile/ARDAppClient.m index 8d12ff2627..fa6a960a54 100644 --- a/examples/objc/AppRTCMobile/ARDAppClient.m +++ b/examples/objc/AppRTCMobile/ARDAppClient.m @@ -191,9 +191,8 @@ - (void)setShouldGetStats:(BOOL)shouldGetStats { repeats:YES timerHandler:^{ 
ARDAppClient *strongSelf = weakSelf; - [strongSelf.peerConnection statsForTrack:nil - statsOutputLevel:RTCStatsOutputLevelDebug - completionHandler:^(NSArray *stats) { + [strongSelf.peerConnection statisticsWithCompletionHandler:^( + RTC_OBJC_TYPE(RTCStatisticsReport) * stats) { dispatch_async(dispatch_get_main_queue(), ^{ ARDAppClient *strongSelf = weakSelf; [strongSelf.delegate appClient:strongSelf didGetStats:stats]; @@ -634,7 +633,14 @@ - (void)processSignalingMessage:(ARDSignalingMessage *)message { case kARDSignalingMessageTypeCandidate: { ARDICECandidateMessage *candidateMessage = (ARDICECandidateMessage *)message; - [_peerConnection addIceCandidate:candidateMessage.candidate]; + __weak ARDAppClient *weakSelf = self; + [_peerConnection addIceCandidate:candidateMessage.candidate + completionHandler:^(NSError *error) { + ARDAppClient *strongSelf = weakSelf; + if (error) { + [strongSelf.delegate appClient:strongSelf didError:error]; + } + }]; break; } case kARDSignalingMessageTypeCandidateRemoval: { diff --git a/examples/objc/AppRTCMobile/ARDBitrateTracker.h b/examples/objc/AppRTCMobile/ARDBitrateTracker.h deleted file mode 100644 index 81ac4b4bd5..0000000000 --- a/examples/objc/AppRTCMobile/ARDBitrateTracker.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright 2015 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#import - -/** Class used to estimate bitrate based on byte count. It is expected that - * byte count is monotonocially increasing. This class tracks the times that - * byte count is updated, and measures the bitrate based on the byte difference - * over the interval between updates. - */ -@interface ARDBitrateTracker : NSObject - -/** The bitrate in bits per second. */ -@property(nonatomic, readonly) double bitrate; -/** The bitrate as a formatted string in bps, Kbps or Mbps. */ -@property(nonatomic, readonly) NSString *bitrateString; - -/** Converts the bitrate to a readable format in bps, Kbps or Mbps. */ -+ (NSString *)bitrateStringForBitrate:(double)bitrate; -/** Updates the tracked bitrate with the new byte count. */ -- (void)updateBitrateWithCurrentByteCount:(NSInteger)byteCount; - -@end diff --git a/examples/objc/AppRTCMobile/ARDBitrateTracker.m b/examples/objc/AppRTCMobile/ARDBitrateTracker.m deleted file mode 100644 index 8158229187..0000000000 --- a/examples/objc/AppRTCMobile/ARDBitrateTracker.m +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2015 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#import "ARDBitrateTracker.h" - -#import - -@implementation ARDBitrateTracker { - CFTimeInterval _prevTime; - NSInteger _prevByteCount; -} - -@synthesize bitrate = _bitrate; - -+ (NSString *)bitrateStringForBitrate:(double)bitrate { - if (bitrate > 1e6) { - return [NSString stringWithFormat:@"%.2fMbps", bitrate * 1e-6]; - } else if (bitrate > 1e3) { - return [NSString stringWithFormat:@"%.0fKbps", bitrate * 1e-3]; - } else { - return [NSString stringWithFormat:@"%.0fbps", bitrate]; - } -} - -- (NSString *)bitrateString { - return [[self class] bitrateStringForBitrate:_bitrate]; -} - -- (void)updateBitrateWithCurrentByteCount:(NSInteger)byteCount { - CFTimeInterval currentTime = CACurrentMediaTime(); - if (_prevTime && (byteCount > _prevByteCount)) { - _bitrate = (byteCount - _prevByteCount) * 8 / (currentTime - _prevTime); - } - _prevByteCount = byteCount; - _prevTime = currentTime; -} - -@end diff --git a/examples/objc/AppRTCMobile/ARDSettingsModel.m b/examples/objc/AppRTCMobile/ARDSettingsModel.m index 8b04c12f47..c628f0fde5 100644 --- a/examples/objc/AppRTCMobile/ARDSettingsModel.m +++ b/examples/objc/AppRTCMobile/ARDSettingsModel.m @@ -77,16 +77,30 @@ - (BOOL)storeVideoResolutionSetting:(NSString *)resolution { - (RTC_OBJC_TYPE(RTCVideoCodecInfo) *)currentVideoCodecSettingFromStore { [self registerStoreDefaults]; NSData *codecData = [[self settingsStore] videoCodec]; - return [NSKeyedUnarchiver unarchiveObjectWithData:codecData]; + Class expectedClass = [RTC_OBJC_TYPE(RTCVideoCodecInfo) class]; + NSError *error; + RTC_OBJC_TYPE(RTCVideoCodecInfo) *videoCodecSetting = + [NSKeyedUnarchiver unarchivedObjectOfClass:expectedClass fromData:codecData error:&error]; + if (!error) { + return videoCodecSetting; + } + return nil; } - (BOOL)storeVideoCodecSetting:(RTC_OBJC_TYPE(RTCVideoCodecInfo) *)videoCodec { if (![[self availableVideoCodecs] containsObject:videoCodec]) { return NO; } - NSData *codecData = [NSKeyedArchiver archivedDataWithRootObject:videoCodec]; - [[self settingsStore] setVideoCodec:codecData]; - return YES; + + NSError *error; + NSData *codecData = [NSKeyedArchiver archivedDataWithRootObject:videoCodec + requiringSecureCoding:NO + error:&error]; + if (!error) { + [[self settingsStore] setVideoCodec:codecData]; + return YES; + } + return NO; } - (nullable NSNumber *)currentMaxBitrateSettingFromStore { @@ -165,14 +179,18 @@ - (int)videoResolutionComponentAtIndex:(int)index inString:(NSString *)resolutio } - (void)registerStoreDefaults { - NSData *codecData = [NSKeyedArchiver archivedDataWithRootObject:[self defaultVideoCodecSetting]]; - [ARDSettingsStore setDefaultsForVideoResolution:[self defaultVideoResolutionSetting] - videoCodec:codecData - bitrate:nil - audioOnly:NO - createAecDump:NO - useManualAudioConfig:YES]; + NSError *error; + NSData *codecData = [NSKeyedArchiver archivedDataWithRootObject:[self defaultVideoCodecSetting] + requiringSecureCoding:NO + error:&error]; + if (!error) { + [ARDSettingsStore setDefaultsForVideoResolution:[self defaultVideoResolutionSetting] + videoCodec:codecData + bitrate:nil + audioOnly:NO + createAecDump:NO + useManualAudioConfig:YES]; + } } - @end NS_ASSUME_NONNULL_END diff --git a/examples/objc/AppRTCMobile/ARDStatsBuilder.h b/examples/objc/AppRTCMobile/ARDStatsBuilder.h index e8224dd707..eaffa67049 100644 --- a/examples/objc/AppRTCMobile/ARDStatsBuilder.h +++ b/examples/objc/AppRTCMobile/ARDStatsBuilder.h @@ -10,10 +10,9 @@ #import +#import "sdk/objc/api/peerconnection/RTCStatisticsReport.h" #import "sdk/objc/base/RTCMacros.h" 
-@class RTC_OBJC_TYPE(RTCLegacyStatsReport); - /** Class used to accumulate stats information into a single displayable string. */ @interface ARDStatsBuilder : NSObject @@ -22,10 +21,6 @@ * class. */ @property(nonatomic, readonly) NSString *statsString; - -/** Parses the information in the stats report into an appropriate internal - * format used to generate the stats string. - */ -- (void)parseStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport; +@property(nonatomic) RTC_OBJC_TYPE(RTCStatisticsReport) * stats; @end diff --git a/examples/objc/AppRTCMobile/ARDStatsBuilder.m b/examples/objc/AppRTCMobile/ARDStatsBuilder.m index a74e351d51..7ebf9fb1c7 100644 --- a/examples/objc/AppRTCMobile/ARDStatsBuilder.m +++ b/examples/objc/AppRTCMobile/ARDStatsBuilder.m @@ -13,333 +13,23 @@ #import "sdk/objc/api/peerconnection/RTCLegacyStatsReport.h" #import "sdk/objc/base/RTCMacros.h" -#import "ARDBitrateTracker.h" #import "ARDUtilities.h" -@implementation ARDStatsBuilder { - // Connection stats. - NSString *_connRecvBitrate; - NSString *_connRtt; - NSString *_connSendBitrate; - NSString *_localCandType; - NSString *_remoteCandType; - NSString *_transportType; +@implementation ARDStatsBuilder - // BWE stats. - NSString *_actualEncBitrate; - NSString *_availableRecvBw; - NSString *_availableSendBw; - NSString *_targetEncBitrate; - - // Video send stats. - NSString *_videoEncodeMs; - NSString *_videoInputFps; - NSString *_videoInputHeight; - NSString *_videoInputWidth; - NSString *_videoSendCodec; - NSString *_videoSendBitrate; - NSString *_videoSendFps; - NSString *_videoSendHeight; - NSString *_videoSendWidth; - - // QP stats. - int _videoQPSum; - int _framesEncoded; - int _oldVideoQPSum; - int _oldFramesEncoded; - - // Video receive stats. - NSString *_videoDecodeMs; - NSString *_videoDecodedFps; - NSString *_videoOutputFps; - NSString *_videoRecvBitrate; - NSString *_videoRecvFps; - NSString *_videoRecvHeight; - NSString *_videoRecvWidth; - - // Audio send stats. - NSString *_audioSendBitrate; - NSString *_audioSendCodec; - - // Audio receive stats. - NSString *_audioCurrentDelay; - NSString *_audioExpandRate; - NSString *_audioRecvBitrate; - NSString *_audioRecvCodec; - - // Bitrate trackers. - ARDBitrateTracker *_audioRecvBitrateTracker; - ARDBitrateTracker *_audioSendBitrateTracker; - ARDBitrateTracker *_connRecvBitrateTracker; - ARDBitrateTracker *_connSendBitrateTracker; - ARDBitrateTracker *_videoRecvBitrateTracker; - ARDBitrateTracker *_videoSendBitrateTracker; -} - -- (instancetype)init { - if (self = [super init]) { - _audioSendBitrateTracker = [[ARDBitrateTracker alloc] init]; - _audioRecvBitrateTracker = [[ARDBitrateTracker alloc] init]; - _connSendBitrateTracker = [[ARDBitrateTracker alloc] init]; - _connRecvBitrateTracker = [[ARDBitrateTracker alloc] init]; - _videoSendBitrateTracker = [[ARDBitrateTracker alloc] init]; - _videoRecvBitrateTracker = [[ARDBitrateTracker alloc] init]; - _videoQPSum = 0; - _framesEncoded = 0; - } - return self; -} +@synthesize stats = _stats; - (NSString *)statsString { NSMutableString *result = [NSMutableString string]; - NSString *systemStatsFormat = @"(cpu)%ld%%\n"; - [result appendString:[NSString stringWithFormat:systemStatsFormat, - (long)ARDGetCpuUsagePercentage()]]; - - // Connection stats. 
- NSString *connStatsFormat = @"CN %@ms | %@->%@/%@ | (s)%@ | (r)%@\n"; - [result appendString:[NSString stringWithFormat:connStatsFormat, - _connRtt, - _localCandType, _remoteCandType, _transportType, - _connSendBitrate, _connRecvBitrate]]; - - // Video send stats. - NSString *videoSendFormat = @"VS (input) %@x%@@%@fps | (sent) %@x%@@%@fps\n" - "VS (enc) %@/%@ | (sent) %@/%@ | %@ms | %@\n" - "AvgQP (past %d encoded frames) = %d\n "; - int avgqp = [self calculateAvgQP]; - - [result appendString:[NSString stringWithFormat:videoSendFormat, - _videoInputWidth, _videoInputHeight, _videoInputFps, - _videoSendWidth, _videoSendHeight, _videoSendFps, - _actualEncBitrate, _targetEncBitrate, - _videoSendBitrate, _availableSendBw, - _videoEncodeMs, - _videoSendCodec, - _framesEncoded - _oldFramesEncoded, avgqp]]; - - // Video receive stats. - NSString *videoReceiveFormat = - @"VR (recv) %@x%@@%@fps | (decoded)%@ | (output)%@fps | %@/%@ | %@ms\n"; - [result appendString:[NSString stringWithFormat:videoReceiveFormat, - _videoRecvWidth, _videoRecvHeight, _videoRecvFps, - _videoDecodedFps, - _videoOutputFps, - _videoRecvBitrate, _availableRecvBw, - _videoDecodeMs]]; - - // Audio send stats. - NSString *audioSendFormat = @"AS %@ | %@\n"; - [result appendString:[NSString stringWithFormat:audioSendFormat, - _audioSendBitrate, _audioSendCodec]]; - // Audio receive stats. - NSString *audioReceiveFormat = @"AR %@ | %@ | %@ms | (expandrate)%@"; - [result appendString:[NSString stringWithFormat:audioReceiveFormat, - _audioRecvBitrate, _audioRecvCodec, _audioCurrentDelay, - _audioExpandRate]]; + [result appendFormat:@"(cpu)%ld%%\n", (long)ARDGetCpuUsagePercentage()]; - return result; -} - -- (void)parseStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - NSString *reportType = statsReport.type; - if ([reportType isEqualToString:@"ssrc"] && - [statsReport.reportId rangeOfString:@"ssrc"].location != NSNotFound) { - if ([statsReport.reportId rangeOfString:@"send"].location != NSNotFound) { - [self parseSendSsrcStatsReport:statsReport]; - } - if ([statsReport.reportId rangeOfString:@"recv"].location != NSNotFound) { - [self parseRecvSsrcStatsReport:statsReport]; - } - } else if ([reportType isEqualToString:@"VideoBwe"]) { - [self parseBweStatsReport:statsReport]; - } else if ([reportType isEqualToString:@"googCandidatePair"]) { - [self parseConnectionStatsReport:statsReport]; + for (NSString *key in _stats.statistics) { + RTC_OBJC_TYPE(RTCStatistics) *stat = _stats.statistics[key]; + [result appendFormat:@"%@\n", stat.description]; } -} - -#pragma mark - Private - -- (int)calculateAvgQP { - int deltaFramesEncoded = _framesEncoded - _oldFramesEncoded; - int deltaQPSum = _videoQPSum - _oldVideoQPSum; - - return deltaFramesEncoded != 0 ? 
deltaQPSum / deltaFramesEncoded : 0; -} -- (void)updateBweStatOfKey:(NSString *)key value:(NSString *)value { - if ([key isEqualToString:@"googAvailableSendBandwidth"]) { - _availableSendBw = [ARDBitrateTracker bitrateStringForBitrate:value.doubleValue]; - } else if ([key isEqualToString:@"googAvailableReceiveBandwidth"]) { - _availableRecvBw = [ARDBitrateTracker bitrateStringForBitrate:value.doubleValue]; - } else if ([key isEqualToString:@"googActualEncBitrate"]) { - _actualEncBitrate = [ARDBitrateTracker bitrateStringForBitrate:value.doubleValue]; - } else if ([key isEqualToString:@"googTargetEncBitrate"]) { - _targetEncBitrate = [ARDBitrateTracker bitrateStringForBitrate:value.doubleValue]; - } -} - -- (void)parseBweStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - [statsReport.values - enumerateKeysAndObjectsUsingBlock:^(NSString *key, NSString *value, BOOL *stop) { - [self updateBweStatOfKey:key value:value]; - }]; -} - -- (void)updateConnectionStatOfKey:(NSString *)key value:(NSString *)value { - if ([key isEqualToString:@"googRtt"]) { - _connRtt = value; - } else if ([key isEqualToString:@"googLocalCandidateType"]) { - _localCandType = value; - } else if ([key isEqualToString:@"googRemoteCandidateType"]) { - _remoteCandType = value; - } else if ([key isEqualToString:@"googTransportType"]) { - _transportType = value; - } else if ([key isEqualToString:@"bytesReceived"]) { - NSInteger byteCount = value.integerValue; - [_connRecvBitrateTracker updateBitrateWithCurrentByteCount:byteCount]; - _connRecvBitrate = _connRecvBitrateTracker.bitrateString; - } else if ([key isEqualToString:@"bytesSent"]) { - NSInteger byteCount = value.integerValue; - [_connSendBitrateTracker updateBitrateWithCurrentByteCount:byteCount]; - _connSendBitrate = _connSendBitrateTracker.bitrateString; - } -} - -- (void)parseConnectionStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - NSString *activeConnection = statsReport.values[@"googActiveConnection"]; - if (![activeConnection isEqualToString:@"true"]) { - return; - } - [statsReport.values - enumerateKeysAndObjectsUsingBlock:^(NSString *key, NSString *value, BOOL *stop) { - [self updateConnectionStatOfKey:key value:value]; - }]; -} - -- (void)parseSendSsrcStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - NSDictionary *values = statsReport.values; - if ([values objectForKey:@"googFrameRateSent"]) { - // Video track. - [self parseVideoSendStatsReport:statsReport]; - } else if ([values objectForKey:@"audioInputLevel"]) { - // Audio track. 
- [self parseAudioSendStatsReport:statsReport]; - } -} - -- (void)updateAudioSendStatOfKey:(NSString *)key value:(NSString *)value { - if ([key isEqualToString:@"googCodecName"]) { - _audioSendCodec = value; - } else if ([key isEqualToString:@"bytesSent"]) { - NSInteger byteCount = value.integerValue; - [_audioSendBitrateTracker updateBitrateWithCurrentByteCount:byteCount]; - _audioSendBitrate = _audioSendBitrateTracker.bitrateString; - } -} - -- (void)parseAudioSendStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - [statsReport.values - enumerateKeysAndObjectsUsingBlock:^(NSString *key, NSString *value, BOOL *stop) { - [self updateAudioSendStatOfKey:key value:value]; - }]; -} - -- (void)updateVideoSendStatOfKey:(NSString *)key value:(NSString *)value { - if ([key isEqualToString:@"googCodecName"]) { - _videoSendCodec = value; - } else if ([key isEqualToString:@"googFrameHeightInput"]) { - _videoInputHeight = value; - } else if ([key isEqualToString:@"googFrameWidthInput"]) { - _videoInputWidth = value; - } else if ([key isEqualToString:@"googFrameRateInput"]) { - _videoInputFps = value; - } else if ([key isEqualToString:@"googFrameHeightSent"]) { - _videoSendHeight = value; - } else if ([key isEqualToString:@"googFrameWidthSent"]) { - _videoSendWidth = value; - } else if ([key isEqualToString:@"googFrameRateSent"]) { - _videoSendFps = value; - } else if ([key isEqualToString:@"googAvgEncodeMs"]) { - _videoEncodeMs = value; - } else if ([key isEqualToString:@"bytesSent"]) { - NSInteger byteCount = value.integerValue; - [_videoSendBitrateTracker updateBitrateWithCurrentByteCount:byteCount]; - _videoSendBitrate = _videoSendBitrateTracker.bitrateString; - } else if ([key isEqualToString:@"qpSum"]) { - _oldVideoQPSum = _videoQPSum; - _videoQPSum = value.integerValue; - } else if ([key isEqualToString:@"framesEncoded"]) { - _oldFramesEncoded = _framesEncoded; - _framesEncoded = value.integerValue; - } -} - -- (void)parseVideoSendStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - [statsReport.values - enumerateKeysAndObjectsUsingBlock:^(NSString *key, NSString *value, BOOL *stop) { - [self updateVideoSendStatOfKey:key value:value]; - }]; -} - -- (void)parseRecvSsrcStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - NSDictionary *values = statsReport.values; - if ([values objectForKey:@"googFrameWidthReceived"]) { - // Video track. - [self parseVideoRecvStatsReport:statsReport]; - } else if ([values objectForKey:@"audioOutputLevel"]) { - // Audio track. 
- [self parseAudioRecvStatsReport:statsReport]; - } -} - -- (void)updateAudioRecvStatOfKey:(NSString *)key value:(NSString *)value { - if ([key isEqualToString:@"googCodecName"]) { - _audioRecvCodec = value; - } else if ([key isEqualToString:@"bytesReceived"]) { - NSInteger byteCount = value.integerValue; - [_audioRecvBitrateTracker updateBitrateWithCurrentByteCount:byteCount]; - _audioRecvBitrate = _audioRecvBitrateTracker.bitrateString; - } else if ([key isEqualToString:@"googSpeechExpandRate"]) { - _audioExpandRate = value; - } else if ([key isEqualToString:@"googCurrentDelayMs"]) { - _audioCurrentDelay = value; - } -} - -- (void)parseAudioRecvStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - [statsReport.values - enumerateKeysAndObjectsUsingBlock:^(NSString *key, NSString *value, BOOL *stop) { - [self updateAudioRecvStatOfKey:key value:value]; - }]; -} - -- (void)updateVideoRecvStatOfKey:(NSString *)key value:(NSString *)value { - if ([key isEqualToString:@"googFrameHeightReceived"]) { - _videoRecvHeight = value; - } else if ([key isEqualToString:@"googFrameWidthReceived"]) { - _videoRecvWidth = value; - } else if ([key isEqualToString:@"googFrameRateReceived"]) { - _videoRecvFps = value; - } else if ([key isEqualToString:@"googFrameRateDecoded"]) { - _videoDecodedFps = value; - } else if ([key isEqualToString:@"googFrameRateOutput"]) { - _videoOutputFps = value; - } else if ([key isEqualToString:@"googDecodeMs"]) { - _videoDecodeMs = value; - } else if ([key isEqualToString:@"bytesReceived"]) { - NSInteger byteCount = value.integerValue; - [_videoRecvBitrateTracker updateBitrateWithCurrentByteCount:byteCount]; - _videoRecvBitrate = _videoRecvBitrateTracker.bitrateString; - } -} - -- (void)parseVideoRecvStatsReport:(RTC_OBJC_TYPE(RTCLegacyStatsReport) *)statsReport { - [statsReport.values - enumerateKeysAndObjectsUsingBlock:^(NSString *key, NSString *value, BOOL *stop) { - [self updateVideoRecvStatOfKey:key value:value]; - }]; + return result; } @end diff --git a/examples/objc/AppRTCMobile/ios/ARDStatsView.h b/examples/objc/AppRTCMobile/ios/ARDStatsView.h index 9c8636476c..72207de64e 100644 --- a/examples/objc/AppRTCMobile/ios/ARDStatsView.h +++ b/examples/objc/AppRTCMobile/ios/ARDStatsView.h @@ -10,8 +10,12 @@ #import +#import "sdk/objc/base/RTCMacros.h" + +@class RTC_OBJC_TYPE(RTCStatisticsReport); + @interface ARDStatsView : UIView -- (void)setStats:(NSArray *)stats; +- (void)setStats:(RTC_OBJC_TYPE(RTCStatisticsReport) *)stats; @end diff --git a/examples/objc/AppRTCMobile/ios/ARDStatsView.m b/examples/objc/AppRTCMobile/ios/ARDStatsView.m index bd97d30fbe..867ba5b09e 100644 --- a/examples/objc/AppRTCMobile/ios/ARDStatsView.m +++ b/examples/objc/AppRTCMobile/ios/ARDStatsView.m @@ -34,10 +34,8 @@ - (instancetype)initWithFrame:(CGRect)frame { return self; } -- (void)setStats:(NSArray *)stats { - for (RTC_OBJC_TYPE(RTCLegacyStatsReport) * report in stats) { - [_statsBuilder parseStatsReport:report]; - } +- (void)setStats:(RTC_OBJC_TYPE(RTCStatisticsReport) *)stats { + _statsBuilder.stats = stats; _statsLabel.text = _statsBuilder.statsString; } diff --git a/examples/objc/AppRTCMobile/ios/ARDVideoCallView.m b/examples/objc/AppRTCMobile/ios/ARDVideoCallView.m index 4301b7ede9..437aea8d56 100644 --- a/examples/objc/AppRTCMobile/ios/ARDVideoCallView.m +++ b/examples/objc/AppRTCMobile/ios/ARDVideoCallView.m @@ -12,10 +12,7 @@ #import -#import "sdk/objc/components/renderer/opengl/RTCEAGLVideoView.h" -#if defined(RTC_SUPPORTS_METAL) -#import 
"sdk/objc/components/renderer/metal/RTCMTLVideoView.h" // nogncheck -#endif +#import "sdk/objc/components/renderer/metal/RTCMTLVideoView.h" #import "UIImage+ARDUtilities.h" @@ -44,14 +41,7 @@ @implementation ARDVideoCallView { - (instancetype)initWithFrame:(CGRect)frame { if (self = [super initWithFrame:frame]) { -#if defined(RTC_SUPPORTS_METAL) _remoteVideoView = [[RTC_OBJC_TYPE(RTCMTLVideoView) alloc] initWithFrame:CGRectZero]; -#else - RTC_OBJC_TYPE(RTCEAGLVideoView) *remoteView = - [[RTC_OBJC_TYPE(RTCEAGLVideoView) alloc] initWithFrame:CGRectZero]; - remoteView.delegate = self; - _remoteVideoView = remoteView; -#endif [self addSubview:_remoteVideoView]; diff --git a/examples/objc/AppRTCMobile/ios/ARDVideoCallViewController.m b/examples/objc/AppRTCMobile/ios/ARDVideoCallViewController.m index cd26829713..a82d90b290 100644 --- a/examples/objc/AppRTCMobile/ios/ARDVideoCallViewController.m +++ b/examples/objc/AppRTCMobile/ios/ARDVideoCallViewController.m @@ -132,8 +132,7 @@ - (void)appClient:(ARDAppClient *)client }); } -- (void)appClient:(ARDAppClient *)client - didGetStats:(NSArray *)stats { +- (void)appClient:(ARDAppClient *)client didGetStats:(RTC_OBJC_TYPE(RTCStatisticsReport) *)stats { _videoCallView.statsView.stats = stats; [_videoCallView setNeedsLayout]; } diff --git a/examples/objc/AppRTCMobile/ios/broadcast_extension/ARDBroadcastSampleHandler.m b/examples/objc/AppRTCMobile/ios/broadcast_extension/ARDBroadcastSampleHandler.m index d9c816d573..1c276d965f 100644 --- a/examples/objc/AppRTCMobile/ios/broadcast_extension/ARDBroadcastSampleHandler.m +++ b/examples/objc/AppRTCMobile/ios/broadcast_extension/ARDBroadcastSampleHandler.m @@ -120,7 +120,7 @@ - (void)appClient:(ARDAppClient *)client didReceiveRemoteVideoTrack:(RTC_OBJC_TYPE(RTCVideoTrack) *)remoteVideoTrack { } -- (void)appClient:(ARDAppClient *)client didGetStats:(NSArray *)stats { +- (void)appClient:(ARDAppClient *)client didGetStats:(RTC_OBJC_TYPE(RTCStatisticsReport) *)stats { } - (void)appClient:(ARDAppClient *)client didError:(NSError *)error { diff --git a/examples/objc/AppRTCMobile/third_party/SocketRocket/SRWebSocket.m b/examples/objc/AppRTCMobile/third_party/SocketRocket/SRWebSocket.m index 45f783feb3..ad7b99a4b2 100644 --- a/examples/objc/AppRTCMobile/third_party/SocketRocket/SRWebSocket.m +++ b/examples/objc/AppRTCMobile/third_party/SocketRocket/SRWebSocket.m @@ -482,16 +482,17 @@ - (void)_readHTTPHeader; if (_receivedHTTPHeaders == NULL) { _receivedHTTPHeaders = CFHTTPMessageCreateEmpty(NULL, NO); } - - [self _readUntilHeaderCompleteWithCallback:^(SRWebSocket *self, NSData *data) { - CFHTTPMessageAppendBytes(_receivedHTTPHeaders, (const UInt8 *)data.bytes, data.length); - - if (CFHTTPMessageIsHeaderComplete(_receivedHTTPHeaders)) { - SRFastLog(@"Finished reading headers %@", CFBridgingRelease(CFHTTPMessageCopyAllHeaderFields(_receivedHTTPHeaders))); - [self _HTTPHeadersDidFinish]; - } else { - [self _readHTTPHeader]; - } + + [self _readUntilHeaderCompleteWithCallback:^(SRWebSocket *self, NSData *data) { + CFHTTPMessageAppendBytes(self->_receivedHTTPHeaders, (const UInt8 *)data.bytes, data.length); + + if (CFHTTPMessageIsHeaderComplete(self->_receivedHTTPHeaders)) { + SRFastLog(@"Finished reading headers %@", + CFBridgingRelease(CFHTTPMessageCopyAllHeaderFields(self->_receivedHTTPHeaders))); + [self _HTTPHeadersDidFinish]; + } else { + [self _readHTTPHeader]; + } }]; } @@ -665,8 +666,8 @@ - (void)_closeWithProtocolError:(NSString *)message; // Need to shunt this on the _callbackQueue first to see if they received 
any messages [self _performDelegateBlock:^{ [self closeWithCode:SRStatusCodeProtocolError reason:message]; - dispatch_async(_workQueue, ^{ - [self _disconnect]; + dispatch_async(self->_workQueue, ^{ + [self _disconnect]; }); }]; } @@ -675,19 +676,19 @@ - (void)_failWithError:(NSError *)error; { dispatch_async(_workQueue, ^{ if (self.readyState != SR_CLOSED) { - _failed = YES; - [self _performDelegateBlock:^{ - if ([self.delegate respondsToSelector:@selector(webSocket:didFailWithError:)]) { - [self.delegate webSocket:self didFailWithError:error]; - } - }]; + self->_failed = YES; + [self _performDelegateBlock:^{ + if ([self.delegate respondsToSelector:@selector(webSocket:didFailWithError:)]) { + [self.delegate webSocket:self didFailWithError:error]; + } + }]; - self.readyState = SR_CLOSED; - _selfRetain = nil; + self.readyState = SR_CLOSED; + self->_selfRetain = nil; - SRFastLog(@"Failing with error %@", error.localizedDescription); - - [self _disconnect]; + SRFastLog(@"Failing with error %@", error.localizedDescription); + + [self _disconnect]; } }); } @@ -735,9 +736,9 @@ - (void)handlePing:(NSData *)pingData; { // Need to pingpong this off _callbackQueue first to make sure messages happen in order [self _performDelegateBlock:^{ - dispatch_async(_workQueue, ^{ - [self _sendFrameWithOpcode:SROpCodePong data:pingData]; - }); + dispatch_async(self->_workQueue, ^{ + [self _sendFrameWithOpcode:SROpCodePong data:pingData]; + }); }]; } @@ -1013,9 +1014,9 @@ - (void)_readFrameContinue; if (header.masked) { [self _closeWithProtocolError:@"Client must receive unmasked data"]; } - - size_t extra_bytes_needed = header.masked ? sizeof(_currentReadMaskKey) : 0; - + + size_t extra_bytes_needed = header.masked ? sizeof(self->_currentReadMaskKey) : 0; + if (header.payload_length == 126) { extra_bytes_needed += sizeof(uint16_t); } else if (header.payload_length == 127) { @@ -1045,8 +1046,10 @@ - (void)_readFrameContinue; if (header.masked) { - assert(mapped_size >= sizeof(_currentReadMaskOffset) + offset); - memcpy(self->_currentReadMaskKey, ((uint8_t *)mapped_buffer) + offset, sizeof(self->_currentReadMaskKey)); + assert(mapped_size >= sizeof(self->_currentReadMaskOffset) + offset); + memcpy(self->_currentReadMaskKey, + ((uint8_t *)mapped_buffer) + offset, + sizeof(self->_currentReadMaskKey)); } [self _handleFrameHeader:header curData:self->_currentFrameData]; @@ -1057,16 +1060,16 @@ - (void)_readFrameContinue; - (void)_readFrameNew; { - dispatch_async(_workQueue, ^{ - [_currentFrameData setLength:0]; - - _currentFrameOpcode = 0; - _currentFrameCount = 0; - _readOpCount = 0; - _currentStringScanPosition = 0; - - [self _readFrameContinue]; - }); + dispatch_async(_workQueue, ^{ + [self->_currentFrameData setLength:0]; + + self->_currentFrameOpcode = 0; + self->_currentFrameCount = 0; + self->_readOpCount = 0; + self->_currentStringScanPosition = 0; + + [self _readFrameContinue]; + }); } - (void)_pumpWriting; @@ -1107,7 +1110,10 @@ - (void)_pumpWriting; if (!_failed) { [self _performDelegateBlock:^{ if ([self.delegate respondsToSelector:@selector(webSocket:didCloseWithCode:reason:wasClean:)]) { - [self.delegate webSocket:self didCloseWithCode:_closeCode reason:_closeReason wasClean:YES]; + [self.delegate webSocket:self + didCloseWithCode:self->_closeCode + reason:self->_closeReason + wasClean:YES]; } }]; } @@ -1420,10 +1426,10 @@ - (void)stream:(NSStream *)aStream handleEvent:(NSStreamEvent)eventCode; if (self.readyState >= SR_CLOSING) { return; } - assert(_readBuffer); - - if (self.readyState == SR_CONNECTING 
&& aStream == _inputStream) { - [self didConnect]; + assert(self->_readBuffer); + + if (self.readyState == SR_CONNECTING && aStream == self->_inputStream) { + [self didConnect]; } [self _pumpWriting]; [self _pumpScanner]; @@ -1434,8 +1440,8 @@ - (void)stream:(NSStream *)aStream handleEvent:(NSStreamEvent)eventCode; SRFastLog(@"NSStreamEventErrorOccurred %@ %@", aStream, [[aStream streamError] copy]); /// TODO specify error better! [self _failWithError:aStream.streamError]; - _readBufferOffset = 0; - [_readBuffer setLength:0]; + self->_readBufferOffset = 0; + [self->_readBuffer setLength:0]; break; } @@ -1448,17 +1454,22 @@ - (void)stream:(NSStream *)aStream handleEvent:(NSStreamEvent)eventCode; } else { if (self.readyState != SR_CLOSED) { self.readyState = SR_CLOSED; - _selfRetain = nil; + self->_selfRetain = nil; } - if (!_sentClose && !_failed) { - _sentClose = YES; - // If we get closed in this state it's probably not clean because we should be sending this when we send messages - [self _performDelegateBlock:^{ + if (!self->_sentClose && !self->_failed) { + self->_sentClose = YES; + // If we get closed in this state it's probably not clean because we should be + // sending this when we send messages + [self + _performDelegateBlock:^{ if ([self.delegate respondsToSelector:@selector(webSocket:didCloseWithCode:reason:wasClean:)]) { - [self.delegate webSocket:self didCloseWithCode:SRStatusCodeGoingAway reason:@"Stream end encountered" wasClean:NO]; + [self.delegate webSocket:self + didCloseWithCode:SRStatusCodeGoingAway + reason:@"Stream end encountered" + wasClean:NO]; } - }]; + }]; } } @@ -1469,19 +1480,19 @@ - (void)stream:(NSStream *)aStream handleEvent:(NSStreamEvent)eventCode; SRFastLog(@"NSStreamEventHasBytesAvailable %@", aStream); const int bufferSize = 2048; uint8_t buffer[bufferSize]; - - while (_inputStream.hasBytesAvailable) { - NSInteger bytes_read = [_inputStream read:buffer maxLength:bufferSize]; - - if (bytes_read > 0) { - [_readBuffer appendBytes:buffer length:bytes_read]; - } else if (bytes_read < 0) { - [self _failWithError:_inputStream.streamError]; - } - - if (bytes_read != bufferSize) { - break; - } + + while (self->_inputStream.hasBytesAvailable) { + NSInteger bytes_read = [self->_inputStream read:buffer maxLength:bufferSize]; + + if (bytes_read > 0) { + [self->_readBuffer appendBytes:buffer length:bytes_read]; + } else if (bytes_read < 0) { + [self _failWithError:self->_inputStream.streamError]; + } + + if (bytes_read != bufferSize) { + break; + } }; [self _pumpScanner]; break; diff --git a/examples/objcnativeapi/objc/NADViewController.mm b/examples/objcnativeapi/objc/NADViewController.mm index 7f6ffbb7e5..fd244799f8 100644 --- a/examples/objcnativeapi/objc/NADViewController.mm +++ b/examples/objcnativeapi/objc/NADViewController.mm @@ -12,10 +12,7 @@ #import "sdk/objc/base/RTCVideoRenderer.h" #import "sdk/objc/components/capturer/RTCCameraVideoCapturer.h" -#if defined(RTC_SUPPORTS_METAL) -#import "sdk/objc/components/renderer/metal/RTCMTLVideoView.h" // nogncheck -#endif -#import "sdk/objc/components/renderer/opengl/RTCEAGLVideoView.h" +#import "sdk/objc/components/renderer/metal/RTCMTLVideoView.h" #import "sdk/objc/helpers/RTCCameraPreviewView.h" #include @@ -49,11 +46,7 @@ @implementation NADViewController { - (void)loadView { _view = [[UIView alloc] initWithFrame:CGRectZero]; -#if defined(RTC_SUPPORTS_METAL) _remoteVideoView = [[RTC_OBJC_TYPE(RTCMTLVideoView) alloc] initWithFrame:CGRectZero]; -#else - _remoteVideoView = [[RTC_OBJC_TYPE(RTCEAGLVideoView) alloc] 
initWithFrame:CGRectZero]; -#endif _remoteVideoView.translatesAutoresizingMaskIntoConstraints = NO; [_view addSubview:_remoteVideoView]; diff --git a/examples/objcnativeapi/objc/objc_call_client.h b/examples/objcnativeapi/objc/objc_call_client.h index b952402bc0..cb8501d9ce 100644 --- a/examples/objcnativeapi/objc/objc_call_client.h +++ b/examples/objcnativeapi/objc/objc_call_client.h @@ -18,8 +18,8 @@ #include "api/peer_connection_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" @class RTC_OBJC_TYPE(RTCVideoCapturer); @protocol RTC_OBJC_TYPE @@ -57,7 +57,7 @@ class ObjCCallClient { void CreatePeerConnection() RTC_RUN_ON(thread_checker_); void Connect() RTC_RUN_ON(thread_checker_); - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; bool call_started_ RTC_GUARDED_BY(thread_checker_); diff --git a/examples/objcnativeapi/objc/objc_call_client.mm b/examples/objcnativeapi/objc/objc_call_client.mm index 5ce7eb7804..419203eb62 100644 --- a/examples/objcnativeapi/objc/objc_call_client.mm +++ b/examples/objcnativeapi/objc/objc_call_client.mm @@ -144,7 +144,7 @@ // DTLS SRTP has to be disabled for loopback to work. config.enable_dtls_srtp = false; webrtc::PeerConnectionDependencies pc_dependencies(pc_observer_.get()); - pc_ = pcf_->CreatePeerConnection(config, std::move(pc_dependencies)); + pc_ = pcf_->CreatePeerConnectionOrError(config, std::move(pc_dependencies)).MoveValue(); RTC_LOG(LS_INFO) << "PeerConnection created: " << pc_; rtc::scoped_refptr local_video_track = diff --git a/examples/peerconnection/client/main.cc b/examples/peerconnection/client/main.cc index cc8bdfbd76..e209171116 100644 --- a/examples/peerconnection/client/main.cc +++ b/examples/peerconnection/client/main.cc @@ -27,7 +27,6 @@ #include "rtc_base/ssl_adapter.h" #include "rtc_base/string_utils.h" // For ToUtf8 #include "rtc_base/win32_socket_init.h" -#include "rtc_base/win32_socket_server.h" #include "system_wrappers/include/field_trial.h" #include "test/field_trial.h" @@ -76,9 +75,8 @@ int PASCAL wWinMain(HINSTANCE instance, wchar_t* cmd_line, int cmd_show) { rtc::WinsockInitializer winsock_init; - rtc::Win32SocketServer w32_ss; - rtc::Win32Thread w32_thread(&w32_ss); - rtc::ThreadManager::Instance()->SetCurrentThread(&w32_thread); + rtc::PhysicalSocketServer ss; + rtc::AutoSocketServerThread main_thread(&ss); WindowsCommandLineArguments win_args; int argc = win_args.argc(); diff --git a/examples/peerconnection/client/peer_connection_client.h b/examples/peerconnection/client/peer_connection_client.h index 56c235a82a..d7ae91343d 100644 --- a/examples/peerconnection/client/peer_connection_client.h +++ b/examples/peerconnection/client/peer_connection_client.h @@ -17,7 +17,6 @@ #include "rtc_base/net_helpers.h" #include "rtc_base/physical_socket_server.h" -#include "rtc_base/signal_thread.h" #include "rtc_base/third_party/sigslot/sigslot.h" typedef std::map Peers; diff --git a/examples/peerconnection/server/data_socket.cc b/examples/peerconnection/server/data_socket.cc index ced0fd1bae..2d595a0e86 100644 --- a/examples/peerconnection/server/data_socket.cc +++ b/examples/peerconnection/server/data_socket.cc @@ -10,7 +10,6 @@ #include "examples/peerconnection/server/data_socket.h" -#include #include #include #include @@ -20,6 +19,7 @@ #endif #include "examples/peerconnection/server/utils.h" +#include "rtc_base/checks.h" static const char kHeaderTerminator[] = "\r\n\r\n"; static const int 
kHeaderTerminatorLength = sizeof(kHeaderTerminator) - 1; @@ -53,7 +53,7 @@ WinsockInitializer WinsockInitializer::singleton; // bool SocketBase::Create() { - assert(!valid()); + RTC_DCHECK(!valid()); socket_ = ::socket(AF_INET, SOCK_STREAM, 0); return valid(); } @@ -77,7 +77,7 @@ std::string DataSocket::request_arguments() const { } bool DataSocket::PathEquals(const char* path) const { - assert(path); + RTC_DCHECK(path); size_t args = request_path_.find('?'); if (args != std::string::npos) return request_path_.substr(0, args).compare(path) == 0; @@ -85,7 +85,7 @@ bool DataSocket::PathEquals(const char* path) const { } bool DataSocket::OnDataAvailable(bool* close_socket) { - assert(valid()); + RTC_DCHECK(valid()); char buffer[0xfff] = {0}; int bytes = recv(socket_, buffer, sizeof(buffer), 0); if (bytes == SOCKET_ERROR || bytes == 0) { @@ -125,8 +125,8 @@ bool DataSocket::Send(const std::string& status, const std::string& content_type, const std::string& extra_headers, const std::string& data) const { - assert(valid()); - assert(!status.empty()); + RTC_DCHECK(valid()); + RTC_DCHECK(!status.empty()); std::string buffer("HTTP/1.1 " + status + "\r\n"); buffer += @@ -165,8 +165,8 @@ void DataSocket::Clear() { } bool DataSocket::ParseHeaders() { - assert(!request_headers_.empty()); - assert(method_ == INVALID); + RTC_DCHECK(!request_headers_.empty()); + RTC_DCHECK_EQ(method_, INVALID); size_t i = request_headers_.find("\r\n"); if (i == std::string::npos) return false; @@ -174,8 +174,8 @@ bool DataSocket::ParseHeaders() { if (!ParseMethodAndPath(request_headers_.data(), i)) return false; - assert(method_ != INVALID); - assert(!request_path_.empty()); + RTC_DCHECK_NE(method_, INVALID); + RTC_DCHECK(!request_path_.empty()); if (method_ == POST) { const char* headers = request_headers_.data() + i + 2; @@ -225,8 +225,8 @@ bool DataSocket::ParseMethodAndPath(const char* begin, size_t len) { } bool DataSocket::ParseContentLengthAndType(const char* headers, size_t length) { - assert(content_length_ == 0); - assert(content_type_.empty()); + RTC_DCHECK_EQ(content_length_, 0); + RTC_DCHECK(content_type_.empty()); const char* end = headers + length; while (headers && headers < end) { @@ -267,7 +267,7 @@ bool DataSocket::ParseContentLengthAndType(const char* headers, size_t length) { // bool ListeningSocket::Listen(unsigned short port) { - assert(valid()); + RTC_DCHECK(valid()); int enabled = 1; setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&enabled), sizeof(enabled)); @@ -284,7 +284,7 @@ bool ListeningSocket::Listen(unsigned short port) { } DataSocket* ListeningSocket::Accept() const { - assert(valid()); + RTC_DCHECK(valid()); struct sockaddr_in addr = {0}; socklen_t size = sizeof(addr); NativeSocket client = diff --git a/examples/peerconnection/server/main.cc b/examples/peerconnection/server/main.cc index b80e4d8247..50b8c23401 100644 --- a/examples/peerconnection/server/main.cc +++ b/examples/peerconnection/server/main.cc @@ -8,7 +8,6 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ -#include #include #include #if defined(WEBRTC_POSIX) @@ -24,6 +23,7 @@ #include "absl/flags/usage.h" #include "examples/peerconnection/server/data_socket.h" #include "examples/peerconnection/server/peer_channel.h" +#include "rtc_base/checks.h" #include "system_wrappers/include/field_trial.h" #include "test/field_trial.h" @@ -41,8 +41,8 @@ ABSL_FLAG(int, port, 8888, "default: 8888"); static const size_t kMaxConnections = (FD_SETSIZE - 2); void HandleBrowserRequest(DataSocket* ds, bool* quit) { - assert(ds && ds->valid()); - assert(quit); + RTC_DCHECK(ds && ds->valid()); + RTC_DCHECK(quit); const std::string& path = ds->request_path(); @@ -162,7 +162,7 @@ int main(int argc, char* argv[]) { if (socket_done) { printf("Disconnecting socket\n"); clients.OnClosing(s); - assert(s->valid()); // Close must not have been called yet. + RTC_DCHECK(s->valid()); // Close must not have been called yet. FD_CLR(s->socket(), &socket_set); delete (*i); i = sockets.erase(i); diff --git a/examples/peerconnection/server/peer_channel.cc b/examples/peerconnection/server/peer_channel.cc index be0f282abd..f53820cc60 100644 --- a/examples/peerconnection/server/peer_channel.cc +++ b/examples/peerconnection/server/peer_channel.cc @@ -10,7 +10,6 @@ #include "examples/peerconnection/server/peer_channel.h" -#include #include #include @@ -18,6 +17,7 @@ #include "examples/peerconnection/server/data_socket.h" #include "examples/peerconnection/server/utils.h" +#include "rtc_base/checks.h" // Set to the peer id of the originator when messages are being // exchanged between peers, but set to the id of the receiving peer @@ -57,9 +57,9 @@ ChannelMember::ChannelMember(DataSocket* socket) id_(++s_member_id_), connected_(true), timestamp_(time(NULL)) { - assert(socket); - assert(socket->method() == DataSocket::GET); - assert(socket->PathEquals("/sign_in")); + RTC_DCHECK(socket); + RTC_DCHECK_EQ(socket->method(), DataSocket::GET); + RTC_DCHECK(socket->PathEquals("/sign_in")); name_ = socket->request_arguments(); if (name_.empty()) name_ = "peer_" + int2str(id_); @@ -85,14 +85,14 @@ std::string ChannelMember::GetPeerIdHeader() const { } bool ChannelMember::NotifyOfOtherMember(const ChannelMember& other) { - assert(&other != this); + RTC_DCHECK_NE(&other, this); QueueResponse("200 OK", "text/plain", GetPeerIdHeader(), other.GetEntry()); return true; } // Returns a string in the form "name,id,connected\n". 
std::string ChannelMember::GetEntry() const { - assert(name_.length() <= kMaxNameLength); + RTC_DCHECK(name_.length() <= kMaxNameLength); // name, 11-digit int, 1-digit bool, newline, null char entry[kMaxNameLength + 15]; @@ -102,8 +102,8 @@ std::string ChannelMember::GetEntry() const { } void ChannelMember::ForwardRequestToPeer(DataSocket* ds, ChannelMember* peer) { - assert(peer); - assert(ds); + RTC_DCHECK(peer); + RTC_DCHECK(ds); std::string extra_headers(GetPeerIdHeader()); @@ -129,8 +129,8 @@ void ChannelMember::QueueResponse(const std::string& status, const std::string& extra_headers, const std::string& data) { if (waiting_socket_) { - assert(queue_.empty()); - assert(waiting_socket_->method() == DataSocket::GET); + RTC_DCHECK(queue_.empty()); + RTC_DCHECK_EQ(waiting_socket_->method(), DataSocket::GET); bool ok = waiting_socket_->Send(status, true, content_type, extra_headers, data); if (!ok) { @@ -149,9 +149,9 @@ void ChannelMember::QueueResponse(const std::string& status, } void ChannelMember::SetWaitingSocket(DataSocket* ds) { - assert(ds->method() == DataSocket::GET); + RTC_DCHECK_EQ(ds->method(), DataSocket::GET); if (ds && !queue_.empty()) { - assert(waiting_socket_ == NULL); + RTC_DCHECK(!waiting_socket_); const QueuedResponse& response = queue_.front(); ds->Send(response.status, true, response.content_type, response.extra_headers, response.data); @@ -167,13 +167,13 @@ void ChannelMember::SetWaitingSocket(DataSocket* ds) { // static bool PeerChannel::IsPeerConnection(const DataSocket* ds) { - assert(ds); + RTC_DCHECK(ds); return (ds->method() == DataSocket::POST && ds->content_length() > 0) || (ds->method() == DataSocket::GET && ds->PathEquals("/sign_in")); } ChannelMember* PeerChannel::Lookup(DataSocket* ds) const { - assert(ds); + RTC_DCHECK(ds); if (ds->method() != DataSocket::GET && ds->method() != DataSocket::POST) return NULL; @@ -209,7 +209,7 @@ ChannelMember* PeerChannel::Lookup(DataSocket* ds) const { } ChannelMember* PeerChannel::IsTargetedRequest(const DataSocket* ds) const { - assert(ds); + RTC_DCHECK(ds); // Regardless of GET or POST, we look for the peer_id parameter // only in the request_path. const std::string& path = ds->request_path(); @@ -239,7 +239,7 @@ ChannelMember* PeerChannel::IsTargetedRequest(const DataSocket* ds) const { } bool PeerChannel::AddMember(DataSocket* ds) { - assert(IsPeerConnection(ds)); + RTC_DCHECK(IsPeerConnection(ds)); ChannelMember* new_guy = new ChannelMember(ds); Members failures; BroadcastChangedState(*new_guy, &failures); @@ -308,7 +308,7 @@ void PeerChannel::DeleteAll() { void PeerChannel::BroadcastChangedState(const ChannelMember& member, Members* delivery_failures) { // This function should be called prior to DataSocket::Close(). - assert(delivery_failures); + RTC_DCHECK(delivery_failures); if (!member.connected()) { printf("Member disconnected: %s\n", member.name().c_str()); @@ -329,12 +329,12 @@ void PeerChannel::BroadcastChangedState(const ChannelMember& member, } void PeerChannel::HandleDeliveryFailures(Members* failures) { - assert(failures); + RTC_DCHECK(failures); while (!failures->empty()) { Members::iterator i = failures->begin(); ChannelMember* member = *i; - assert(!member->connected()); + RTC_DCHECK(!member->connected()); failures->erase(i); BroadcastChangedState(*member, failures); delete member; @@ -344,14 +344,14 @@ void PeerChannel::HandleDeliveryFailures(Members* failures) { // Builds a simple list of "name,id\n" entries for each member. 
std::string PeerChannel::BuildResponseForNewMember(const ChannelMember& member, std::string* content_type) { - assert(content_type); + RTC_DCHECK(content_type); *content_type = "text/plain"; // The peer itself will always be the first entry. std::string response(member.GetEntry()); for (Members::iterator i = members_.begin(); i != members_.end(); ++i) { if (member.id() != (*i)->id()) { - assert((*i)->connected()); + RTC_DCHECK((*i)->connected()); response += (*i)->GetEntry(); } } diff --git a/examples/unityplugin/simple_peer_connection.cc b/examples/unityplugin/simple_peer_connection.cc index 4fd2fc359d..c7e5185bdc 100644 --- a/examples/unityplugin/simple_peer_connection.cc +++ b/examples/unityplugin/simple_peer_connection.cc @@ -190,13 +190,16 @@ bool SimplePeerConnection::CreatePeerConnection(const char** turn_urls, webrtc::PeerConnectionInterface::IceServer stun_server; stun_server.uri = GetPeerConnectionString(); config_.servers.push_back(stun_server); - config_.enable_rtp_data_channel = true; config_.enable_dtls_srtp = false; - peer_connection_ = g_peer_connection_factory->CreatePeerConnection( - config_, nullptr, nullptr, this); - - return peer_connection_.get() != nullptr; + auto result = g_peer_connection_factory->CreatePeerConnectionOrError( + config_, webrtc::PeerConnectionDependencies(this)); + if (!result.ok()) { + peer_connection_ = nullptr; + return false; + } + peer_connection_ = result.MoveValue(); + return true; } void SimplePeerConnection::DeletePeerConnection() { @@ -494,8 +497,9 @@ bool SimplePeerConnection::CreateDataChannel() { struct webrtc::DataChannelInit init; init.ordered = true; init.reliable = true; - data_channel_ = peer_connection_->CreateDataChannel("Hello", &init); - if (data_channel_.get()) { + auto result = peer_connection_->CreateDataChannelOrError("Hello", &init); + if (result.ok()) { + data_channel_ = result.MoveValue(); data_channel_->RegisterObserver(this); RTC_LOG(LS_INFO) << "Succeeds to create data channel"; return true; diff --git a/g3doc.lua b/g3doc.lua new file mode 100644 index 0000000000..85d8474a12 --- /dev/null +++ b/g3doc.lua @@ -0,0 +1 @@ +return require(this.dirname..'g3doc/g3doc.lua') diff --git a/g3doc/OWNERS b/g3doc/OWNERS new file mode 100644 index 0000000000..9ece35c39b --- /dev/null +++ b/g3doc/OWNERS @@ -0,0 +1,5 @@ +titovartem@webrtc.org + +per-file abseil-in-webrtc.md=danilchap@webrtc.org +per-file abseil-in-webrtc.md=mbonadei@webrtc.org +per-file style-guide.md=danilchap@webrtc.org diff --git a/abseil-in-webrtc.md b/g3doc/abseil-in-webrtc.md similarity index 95% rename from abseil-in-webrtc.md rename to g3doc/abseil-in-webrtc.md index 79b1031ffd..692ebe2b0b 100644 --- a/abseil-in-webrtc.md +++ b/g3doc/abseil-in-webrtc.md @@ -1,5 +1,8 @@ # Using Abseil in WebRTC + + + You may use a subset of the utilities provided by the [Abseil][abseil] library when writing WebRTC C++ code. Below, we list the explicitly *allowed* and the explicitly *disallowed* subsets of Abseil; if you @@ -22,6 +25,7 @@ will generate a shared library. ## **Allowed** +* `absl::bind_front` * `absl::InlinedVector` * `absl::WrapUnique` * `absl::optional` and related stuff from `absl/types/optional.h`. diff --git a/g3doc/g3doc.lua b/g3doc/g3doc.lua new file mode 100644 index 0000000000..e97289ff81 --- /dev/null +++ b/g3doc/g3doc.lua @@ -0,0 +1,20 @@ +return { + theme = { + '@builtins/theme/ng.md', + -- We don't want to have more than h3 headings in the Table Of Content. 
+ toc_level = 3, + }, + + site = { + name = 'WebRTC C++ library', + home = this.dirname..'index.md', + logo = this.dirname..'logo.svg', + map = this.dirname..'sitemap.md', + -- Ensure absolute links are rewritten correctly. + root = this.dirname..'..' + }, + + visibility = { '/...' }, + + freshness = {} +} diff --git a/g3doc/how_to_write_documentation.md b/g3doc/how_to_write_documentation.md new file mode 100644 index 0000000000..6fbca116a5 --- /dev/null +++ b/g3doc/how_to_write_documentation.md @@ -0,0 +1,72 @@ +# How to write WebRTC documentation + + + + +## Audience + +Engineers and tech writers who want to contribute to WebRTC documentation. + +## Conceptual documentation + +Conceptual documentation provides an overview of APIs or systems. Examples can +be the threading model of a particular module or a data life cycle. Conceptual +documentation can skip some edge cases in favor of clarity. The main point +is to impart understanding. + +Conceptual documentation often cannot be embedded directly within the source +code because it usually describes multiple APIs and entities, so the only +logical place to document such complex behavior is through a separate +conceptual document. + +A concept document needs to be useful to both experts and novices. Moreover, +it needs to emphasize clarity, so it often needs to sacrifice completeness +and sometimes strict accuracy. That's not to say a conceptual document should +intentionally be inaccurate. It just means that it should focus more on common +usage and leave rare ones or side effects for class/function level comments. + +In the WebRTC repo, conceptual documentation is located in `g3doc` subfolders +of related components. To add a new document for the component `Foo`, find the +`g3doc` subfolder for this component and create a `.md` file there with the +desired documentation. If there is no `g3doc` subfolder, create a new one. + +When you want to specify a link from one page to another, use the absolute +path: + +``` +[My document](/module/g3doc/my_document.md) +``` + +If you are a Googler, please also specify an owner who will be responsible for +keeping this documentation updated, by adding the following lines at the beginning +of your `.md` file immediately after the page title: + +```markdown +' %?> +' %?> +``` + +If you want to configure the owner for all pages under a directory, create a +`g3doc.lua` file in that directory with the content: + +```lua +config = super() +config.freshness.owner = '' +return config +``` + +After the document is ready, you should add it to `/g3doc/sitemap.md` so that it +will be discoverable by others. + +### Documentation format + +The documentation is written in GitHub Markdown +([spec](https://github.github.com/gfm/#:~:text=GitHub%20Flavored%20Markdown%2C%20often%20shortened,a%20strict%20superset%20of%20CommonMark.)). + +## Class/function level comments + +Documentation of specific classes and function APIs and their usage, including +their purpose, is embedded in the .h files defining that API. See the +[C++ style guide](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md) +for pointers on how to write API documentation in .h files. + diff --git a/g3doc/implementation_basics.md b/g3doc/implementation_basics.md new file mode 100644 index 0000000000..933941a0d1 --- /dev/null +++ b/g3doc/implementation_basics.md @@ -0,0 +1,92 @@ + + + +# Basic concepts and primitives + +## Time + +Internally, time is represented using the [webrtc::Timestamp][1] class. 
This +represents +time with a resolution of one microsecond, using a 64-bit integer, and provides +converters to milliseconds or seconds as needed. + +All timestamps need to be measured from the system monotonic time. + +The epoch is not specified (because we can't always know if the system clock is +correct), but whenever an absolute epoch is needed, the Unix time +epoch (Jan 1, 1970 at 0:00 GMT) is used. + +Conversion from/to other formats (for example milliseconds, NTP times, +timestamp strings) should happen as close to the interface requiring that +format as possible. + +NOTE: There are parts of the codebase that don't use Timestamp, parts of the +codebase that use the NTP epoch, and parts of the codebase that don't use the +monotonic clock. They need to +be updated. + +## Threads + +All execution happens on a TaskQueue instance. How a TaskQueue is implemented +varies by platform, but they all have the [webrtc::TaskQueueBase][3] API. + +This API offers primitives for posting tasks, with or without delay. + +Some core parts use the [rtc::Thread][2], which is a subclass of TaskQueueBase. +This may contain a SocketServer for processing I/O, and is used for policing +certain calling patterns between a few core threads (the NetworkThread cannot +do Invoke on the Worker thread, for instance). + +## Synchronization primitives + +### PostTask and thread-guarded variables + +The preferred method for synchronization is to post tasks between threads, +and to let each thread take care of its own variables (lock-free programming). +All variables in +classes intended to be used with multiple threads should therefore be +annotated with RTC_GUARDED_BY(thread). + +For classes used with only one thread, the recommended pattern is to let +them own a webrtc::SequenceChecker (conventionally named sequence_checker_) +and let all variables be RTC_GUARDED_BY(sequence_checker_). A short sketch of +this pattern is included at the end of this page. + +Member variables marked const do not need to be guarded, since they never +change. (But note that they may point to objects that can change!) + +When posting tasks with callbacks, it is the duty of the caller to check +that the object one is calling back into still exists when the callback +is made. A helper for this task is the [webrtc::ScopedTaskSafety][5] +flag, which can automatically drop callbacks in this situation, and +associated classes. + +### Synchronization primitives to be used when needed + +When it is absolutely necessary to let one thread wait for another thread +to do something, Thread::Invoke can be used. This function is DISCOURAGED, +since it leads to performance issues, but is currently still widespread. + +When it is absolutely necessary to access one variable from multiple threads, +the webrtc::Mutex can be used. Such variables MUST be marked up with +RTC_GUARDED_BY(mutex), to allow static analysis that lessens the chance of +deadlocks or unintended consequences. + +### Synchronization primitives that are being removed +The following non-exhaustive list of synchronization primitives is +in the (slow) process of being removed from the codebase. + +* sigslot. Use [webrtc::CallbackList][4] instead, or, when there's only one + signal consumer, a single std::function. + +* AsyncInvoker. + +* RecursiveCriticalSection. Try to use [webrtc::Mutex][6] instead, and don't recurse. 
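Below is a minimal sketch of the single-threaded pattern described above (PostTask plus a SequenceChecker-guarded member). It is an illustration added for this write-up, not code from the WebRTC tree; the class and member names are made up, and the include paths are assumptions based on the headers referenced above.

```cpp
// Illustrative sketch only: a class that owns a SequenceChecker, guards its
// state with RTC_GUARDED_BY, and mutates that state exclusively via tasks
// posted to one task queue.
#include "api/sequence_checker.h"
#include "api/task_queue/task_queue_base.h"
#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread_annotations.h"

class Counter {
 public:
  explicit Counter(webrtc::TaskQueueBase* task_queue)
      : task_queue_(task_queue) {
    // The object may be constructed on a different thread than the one it is
    // used on; detach so the checker binds to the first sequence that uses it.
    sequence_checker_.Detach();
  }

  // May be called from any thread; the state change itself always happens on
  // `task_queue_`, so `count_` needs no lock.
  void Increment() {
    task_queue_->PostTask(webrtc::ToQueuedTask([this] {
      RTC_DCHECK_RUN_ON(&sequence_checker_);
      ++count_;
    }));
  }

 private:
  // In production code, pair the posted lambda with ScopedTaskSafety (see [5])
  // so the callback is dropped if the object is destroyed first.
  webrtc::TaskQueueBase* const task_queue_;
  webrtc::SequenceChecker sequence_checker_;
  int count_ RTC_GUARDED_BY(sequence_checker_) = 0;
};
```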
+ + + +[1]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/api/units/timestamp.h;drc=b95d90b78a3491ef8e8aa0640dd521515ec881ca;l=29 +[2]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/rtc_base/thread.h;drc=1107751b6f11c35259a1c5c8a0f716e227b7e3b4;l=194 +[3]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/api/task_queue/task_queue_base.h;drc=1107751b6f11c35259a1c5c8a0f716e227b7e3b4;l=25 +[4]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/rtc_base/callback_list.h;drc=54b91412de3f579a2d5ccdead6e04cc2cc5ca3a1;l=162 +[5]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/rtc_base/task_utils/pending_task_safety_flag.h;drc=86ee89f73e4f4799b3ebcc0b5c65837c9601fe6d;l=117 +[6]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/rtc_base/synchronization/mutex.h;drc=0d3c09a8fe5f12dfbc9f1bcd5790fda8830624ec;l=40 diff --git a/g3doc/index.md b/g3doc/index.md new file mode 100644 index 0000000000..50a3934a4e --- /dev/null +++ b/g3doc/index.md @@ -0,0 +1,6 @@ +# WebRTC C++ library + + + + +This is a home page for WebRTC C++ library documentation diff --git a/g3doc/logo.svg b/g3doc/logo.svg new file mode 100644 index 0000000000..634b8cb116 --- /dev/null +++ b/g3doc/logo.svg @@ -0,0 +1,675 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/g3doc/sitemap.md b/g3doc/sitemap.md new file mode 100644 index 0000000000..c893d6ce3a --- /dev/null +++ b/g3doc/sitemap.md @@ -0,0 +1,47 @@ +* [Home](/g3doc/index.md) +* How to contribute + * Code + * [Style guide](/g3doc/style-guide.md) + * [Documentation](/g3doc/how_to_write_documentation.md) +* [Public C++ API](/api/g3doc/index.md) + * [Threading](/api/g3doc/threading_design.md) +* Implementation + * [Basic concepts](/g3doc/implementation_basics.md) + * [Supported Platforms and Compilers](/g3doc/supported-platforms-and-compilers.md) + * Network + * [ICE](/p2p/g3doc/ice.md) + * STUN + * TURN + * [DTLS](/pc/g3doc/dtls_transport.md) + * [RTP](/pc/g3doc/rtp.md) + * [SRTP](/pc/g3doc/srtp.md) + * [SCTP](/pc/g3doc/sctp_transport.md) + * [Pacing buffer](/modules/pacing/g3doc/index.md) + * Congestion control and bandwidth estimation + * Audio + * [NetEq](/modules/audio_coding/neteq/g3doc/index.md) + * AudioEngine + * [ADM](/modules/audio_device/g3doc/audio_device_module.md) + * [Audio Coding](/modules/audio_coding/g3doc/index.md) + * [Audio Mixer](/modules/audio_mixer/g3doc/index.md) + * AudioProcessingModule + * [APM](/modules/audio_processing/g3doc/audio_processing_module.md) + * Video + * [Adaptation](/video/g3doc/adaptation.md) + * [Video coding](/modules/video_coding/g3doc/index.md) + * [Stats](/video/g3doc/stats.md) + * DataChannel + * [PeerConnection](/pc/g3doc/peer_connection.md) + * Desktop capture + * Stats + * [Logging](/logging/g3doc/rtc_event_log.md) +* Testing + * Media Quality and performance + * [PeerConnection Framework](/test/pc/e2e/g3doc/index.md) + * [Architecture](/test/pc/e2e/g3doc/architecture.md) + * [Video analyzer](/test/pc/e2e/g3doc/default_video_quality_analyzer.md) + * Call framework + * Video codecs test framework + * Network emulation + * [Implementation](/test/network/g3doc/index.md) + * Performance stats collection diff --git a/g3doc/style-guide.md b/g3doc/style-guide.md new file mode 100644 index 0000000000..f3b0e8869d --- /dev/null +++ b/g3doc/style-guide.md @@ -0,0 +1,279 @@ +# WebRTC coding style guide + + + + +## General advice + +Some older parts of the code 
violate the style guide in various ways. + +* If making small changes to such code, follow the style guide when it's + reasonable to do so, but in matters of formatting etc., it is often better to + be consistent with the surrounding code. +* If making large changes to such code, consider first cleaning it up in a + separate CL. + +## C++ + +WebRTC follows the [Chromium C++ style guide][chr-style] and the +[Google C++ style guide][goog-style]. In cases where they conflict, the Chromium +style guide trumps the Google style guide, and the rules in this file trump them +both. + +[chr-style]: https://chromium.googlesource.com/chromium/src/+/HEAD/styleguide/c++/c++.md +[goog-style]: https://google.github.io/styleguide/cppguide.html + +### C++ version + +WebRTC is written in C++14, but with some restrictions: + +* We only allow the subset of C++14 (language and library) that is not banned by + Chromium; see the [list of banned C++ features in Chromium][chromium-cpp]. +* We only allow the subset of C++14 that is also valid C++17; otherwise, users + would not be able to compile WebRTC in C++17 mode. + +[chromium-cpp]: https://chromium-cpp.appspot.com/ + +Unlike the Chromium and Google C++ style guides, we do not allow C++20-style +designated initializers, because we want to stay compatible with compilers that +do not yet support them. + +### Abseil + +You may use a subset of the utilities provided by the [Abseil][abseil] library +when writing WebRTC C++ code; see the +[instructions on how to use Abseil in WebRTC](abseil-in-webrtc.md). + +[abseil]: https://abseil.io/about/ + +### `.h` and `.cc` files come in pairs + +`.h` and `.cc` files should come in pairs, with the same name (except for the +file type suffix), in the same directory, in the same build target. + +* If a declaration in `path/to/foo.h` has a definition in some `.cc` file, it + should be in `path/to/foo.cc`. +* If a definition in `path/to/foo.cc` has a declaration in some `.h` file, + it should be in `path/to/foo.h`. +* Omit the `.cc` file if it would have been empty, but still list the `.h` file + in a build target. +* Omit the `.h` file if it would have been empty. (This can happen with unit + test `.cc` files, and with `.cc` files that define `main`.) + +See also the +[examples and exceptions on how to treat `.h` and `.cpp` files](style-guide/h-cc-pairs.md). + +This makes the source code easier to navigate and organize, and precludes some +questionable build system practices such as having build targets that don't pull +in definitions for everything they declare. + +### `TODO` comments + +Follow the [Google styleguide for `TODO` comments][goog-style-todo]. When +referencing a WebRTC bug, prefer the url form, e.g. + +```cpp +// TODO(bugs.webrtc.org/12345): Delete the hack when blocking bugs are resolved. +``` + +[goog-style-todo]: https://google.github.io/styleguide/cppguide.html#TODO_Comments + +### Deprecation + +Annotate the declarations of deprecated functions and classes with the +[`ABSL_DEPRECATED` macro][ABSL_DEPRECATED] to cause an error when they're used +inside WebRTC and a compiler warning when they're used by dependent projects. +Like so: + +```cpp +ABSL_DEPRECATED("bugs.webrtc.org/12345") +std::pony PonyPlz(const std::pony_spec& ps); +``` + +NOTE 1: The annotation goes on the declaration in the `.h` file, not the +definition in the `.cc` file! 
+ +NOTE 2: In order to have unit tests that use the deprecated function without +getting errors, do something like this: + +```cpp +std::pony DEPRECATED_PonyPlz(const std::pony_spec& ps); +ABSL_DEPRECATED("bugs.webrtc.org/12345") +inline std::pony PonyPlz(const std::pony_spec& ps) { + return DEPRECATED_PonyPlz(ps); +} +``` + +In other words, rename the existing function, and provide an inline wrapper +using the original name that calls it. That way, callers who are willing to +call it using the `DEPRECATED_`-prefixed name don't get the warning. + +[ABSL_DEPRECATED]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/abseil-cpp/absl/base/attributes.h?q=ABSL_DEPRECATED + +### ArrayView + +When passing an array of values to a function, use `rtc::ArrayView` +whenever possible—that is, whenever you're not passing ownership of +the array, and don't allow the callee to change the array size. + +For example, + +| instead of | use | +|-------------------------------------|----------------------| +| `const std::vector<T>&` | `ArrayView<const T>` | +| `const T* ptr, size_t num_elements` | `ArrayView<const T>` | +| `T* ptr, size_t num_elements` | `ArrayView<T>` | + +See the [source code for `rtc::ArrayView`](api/array_view.h) for more detailed +docs. A short usage sketch is also included below, after the C section. + +### sigslot + +SIGSLOT IS DEPRECATED. + +Prefer `webrtc::CallbackList`, and manage thread safety yourself. + +### Smart pointers + +The following smart pointer types are recommended: + + * `std::unique_ptr` for all singly-owned objects + * `rtc::scoped_refptr` for all objects with shared ownership + +Use of `std::shared_ptr` is *not permitted*. It is banned in the Chromium style +guide (overriding the Google style guide), and offers no compelling advantage +over `rtc::scoped_refptr` (which is cloned from the corresponding Chromium +type). See the +[list of banned C++ library features in Chromium][chr-std-shared-ptr] for more +information. + +In most cases, one will want to explicitly control lifetimes, and therefore use +`std::unique_ptr`, but in some cases, for instance where references have to +exist both from the API users and internally, with no way to invalidate pointers +held by the API user, `rtc::scoped_refptr` can be appropriate. + +[chr-std-shared-ptr]: https://chromium-cpp.appspot.com/#library-blocklist + +### `std::bind` + +Don't use `std::bind`—there are pitfalls, and lambdas are almost as succinct and +already familiar to modern C++ programmers. + +### `std::function` + +`std::function` is allowed, but remember that it's not the right tool for every +occasion. Prefer to use interfaces when that makes sense, and consider +`rtc::FunctionView` for cases where the callee will not save the function +object. + +### Forward declarations + +WebRTC follows the +[Google C++ style guide on forward declarations][goog-forward-declarations]. +In summary: avoid using forward declarations where possible; just `#include` the +headers you need. + +[goog-forward-declarations]: https://google.github.io/styleguide/cppguide.html#Forward_Declarations + +## C + +There's a substantial chunk of legacy C code in WebRTC, and a lot of it is old +enough that it violates the parts of the C++ style guide that also apply to C +(naming etc.) for the simple reason that it pre-dates the use of the current C++ +style guide for this code base. + +* If making small changes to C code, mimic the style of the surrounding code. +* If making large changes to C code, consider converting the whole thing to C++ + first. 
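As a complement to the `ArrayView` guidance above, here is a small illustrative sketch added for this write-up (not part of the style guide; the function names are made up), showing how one callee signature accepts several argument shapes without copying:

```cpp
// Illustrative sketch only: rtc::ArrayView adapts vectors, C arrays and
// pointer/length pairs to a single non-owning view type.
#include <cstdint>
#include <vector>

#include "api/array_view.h"

// The callee neither owns nor resizes the buffer, so it takes
// rtc::ArrayView<const int> instead of const std::vector<int>&.
int64_t Sum(rtc::ArrayView<const int> values) {
  int64_t sum = 0;
  for (int value : values) {
    sum += value;
  }
  return sum;
}

void Demo() {
  std::vector<int> vec = {1, 2, 3};
  int arr[] = {4, 5, 6};
  Sum(vec);                                // From a std::vector.
  Sum(arr);                                // From a C array.
  Sum(rtc::ArrayView<const int>(arr, 2));  // From a pointer + size.
}
```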
+ +## Java + +WebRTC follows the [Google Java style guide][goog-java-style]. + +[goog-java-style]: https://google.github.io/styleguide/javaguide.html + +## Objective-C and Objective-C++ + +WebRTC follows the +[Chromium Objective-C and Objective-C++ style guide][chr-objc-style]. + +[chr-objc-style]: https://chromium.googlesource.com/chromium/src/+/HEAD/styleguide/objective-c/objective-c.md + +## Python + +WebRTC follows [Chromium's Python style][chr-py-style]. + +[chr-py-style]: https://chromium.googlesource.com/chromium/src/+/HEAD/styleguide/python/python.md + +## Build files + +The WebRTC build files are written in [GN][gn], and we follow the +[GN style guide][gn-style]. Additionally, there are some +WebRTC-specific rules below; in case of conflict, they trump the Chromium style +guide. + +[gn]: https://gn.googlesource.com/gn/ +[gn-style]: https://gn.googlesource.com/gn/+/HEAD/docs/style_guide.md + +### WebRTC-specific GN templates + +Use the following [GN templates][gn-templ] to ensure that all our +[GN targets][gn-target] are built with the same configuration: + +| instead of | use | +|------------------|----------------------| +| `executable` | `rtc_executable` | +| `shared_library` | `rtc_shared_library` | +| `source_set` | `rtc_source_set` | +| `static_library` | `rtc_static_library` | +| `test` | `rtc_test` | + + +[gn-templ]: https://gn.googlesource.com/gn/+/HEAD/docs/language.md#Templates +[gn-target]: https://gn.googlesource.com/gn/+/HEAD/docs/language.md#Targets + +### Target visibility and the native API + +The [WebRTC-specific GN templates](#webrtc-gn-templates) declare build targets +whose default `visibility` allows all other targets in the WebRTC tree (and no +targets outside the tree) to depend on them. + +Prefer to restrict the `visibility` if possible: + +* If a target is used by only one or a tiny number of other targets, prefer to + list them explicitly: `visibility = [ ":foo", ":bar" ]` +* If a target is used only by targets in the same `BUILD.gn` file: + `visibility = [ ":*" ]`. + +Setting `visibility = [ "*" ]` means that targets outside the WebRTC tree can +depend on this target; use this only for build targets whose headers are part of +the [native WebRTC API](native-api.md). + +### Conditional compilation with the C preprocessor + +Avoid using the C preprocessor to conditionally enable or disable pieces of +code. But if you can't avoid it, introduce a GN variable, and then set a +preprocessor constant to either 0 or 1 in the build targets that need it: + +```gn +if (apm_debug_dump) { + defines = [ "WEBRTC_APM_DEBUG_DUMP=1" ] +} else { + defines = [ "WEBRTC_APM_DEBUG_DUMP=0" ] +} +``` + +In the C, C++, or Objective-C files, use `#if` when testing the flag, +not `#ifdef` or `#if defined()`: + +```c +#if WEBRTC_APM_DEBUG_DUMP +// One way. +#else +// Or another. +#endif +``` + +When combined with the `-Wundef` compiler option, this produces compile time +warnings if preprocessor symbols are misspelled, or used without corresponding +build rules to set them. 
diff --git a/style-guide/OWNERS b/g3doc/style-guide/OWNERS similarity index 100% rename from style-guide/OWNERS rename to g3doc/style-guide/OWNERS diff --git a/style-guide/h-cc-pairs.md b/g3doc/style-guide/h-cc-pairs.md similarity index 92% rename from style-guide/h-cc-pairs.md rename to g3doc/style-guide/h-cc-pairs.md index 1a24e49d09..bb85871260 100644 --- a/style-guide/h-cc-pairs.md +++ b/g3doc/style-guide/h-cc-pairs.md @@ -1,5 +1,8 @@ # `.h` and `.cc` files come in pairs + + + This is an overflow page for [this](../style-guide.md#h-cc-pairs) style rule. diff --git a/g3doc/supported-platforms-and-compilers.md b/g3doc/supported-platforms-and-compilers.md new file mode 100644 index 0000000000..9e51a29ab7 --- /dev/null +++ b/g3doc/supported-platforms-and-compilers.md @@ -0,0 +1,36 @@ +# WebRTC supported platforms and compilers + + + + +## Operating systems and CPUs + +The list of officially supported operating systems and CPUs is: + +* Android: armeabi-v7a, arm64-v8a, x86, x86_64. +* iOS: arm64, x86_64. +* Linux: armeabi-v7a, arm64-v8a, x86, x86_64. +* macOS: x86_64, arm64 (M1). +* Windows: x86_64. + +Other platforms are not officially supported (which means there is no CI +coverage for them) but patches to keep WebRTC working with them are welcomed by +the WebRTC Team. + +## Compilers + +WebRTC officially supports clang on all the supported platforms. The clang +version officially supported is the one used by Chromium (hence the version is +really close to Tip of Tree and can be checked +[here](https://source.chromium.org/chromium/chromium/src/+/main:tools/clang/scripts/update.py) +by looking at the value of `CLANG_REVISION`). + +See also +[here](https://source.chromium.org/chromium/chromium/src/+/main:docs/clang.md) +for some clang-related documentation from Chromium. + +MSVC is also supported at version VS 2019 16.61. + +Other compilers are not officially supported (which means there is no CI +coverage for them) but patches to keep WebRTC working with them are welcomed by +the WebRTC Team. 
diff --git a/logging/BUILD.gn b/logging/BUILD.gn index 8eb5631919..90a05f7c49 100644 --- a/logging/BUILD.gn +++ b/logging/BUILD.gn @@ -53,6 +53,7 @@ rtc_library("rtc_event_pacing") { deps = [ "../api:scoped_refptr", "../api/rtc_event_log", + "../api/units:timestamp", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] } @@ -73,6 +74,7 @@ rtc_library("rtc_event_audio") { ":rtc_stream_config", "../api:scoped_refptr", "../api/rtc_event_log", + "../api/units:timestamp", "../modules/audio_coding:audio_network_adaptor_config", "../rtc_base:checks", ] @@ -101,6 +103,7 @@ rtc_library("rtc_event_bwe") { "../api:scoped_refptr", "../api/rtc_event_log", "../api/units:data_rate", + "../api/units:timestamp", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory", @@ -115,6 +118,7 @@ rtc_library("rtc_event_frame_events") { ] deps = [ "../api/rtc_event_log", + "../api/units:timestamp", "../api/video:video_frame", "../rtc_base:timeutils", ] @@ -136,6 +140,7 @@ rtc_library("rtc_event_generic_packet_events") { ] deps = [ "../api/rtc_event_log", + "../api/units:timestamp", "../rtc_base:timeutils", ] absl_deps = [ @@ -179,6 +184,7 @@ rtc_library("rtc_event_video") { ":rtc_stream_config", "../api:scoped_refptr", "../api/rtc_event_log", + "../api/units:timestamp", "../rtc_base:checks", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] @@ -261,13 +267,13 @@ if (rtc_enable_protobuf) { ":rtc_event_log_api", ":rtc_event_log_impl_encoder", "../api:libjingle_logging_api", + "../api:sequence_checker", "../api/rtc_event_log", "../api/task_queue", "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_task_queue", "../rtc_base:safe_minmax", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -284,10 +290,9 @@ rtc_library("fake_rtc_event_log") { ] deps = [ - ":ice_log", "../api/rtc_event_log", "../rtc_base", - "../rtc_base:checks", + "../rtc_base/synchronization:mutex", ] } @@ -343,11 +348,11 @@ if (rtc_enable_protobuf) { "../modules/rtp_rtcp", "../modules/rtp_rtcp:rtp_rtcp_format", "../rtc_base:checks", - "../rtc_base:deprecation", "../rtc_base:ignore_wundef", "../rtc_base:protobuf_utils", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_numerics", + "../rtc_base/system:file_wrapper", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory", @@ -409,26 +414,28 @@ if (rtc_enable_protobuf) { ] } - rtc_executable("rtc_event_log_rtp_dump") { - testonly = true - sources = [ "rtc_event_log/rtc_event_log2rtp_dump.cc" ] - deps = [ - ":rtc_event_log_parser", - "../api:array_view", - "../api:rtp_headers", - "../api/rtc_event_log", - "../modules/rtp_rtcp", - "../modules/rtp_rtcp:rtp_rtcp_format", - "../rtc_base:checks", - "../rtc_base:protobuf_utils", - "../rtc_base:rtc_base_approved", - "../test:rtp_test_utils", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/flags:usage", - "//third_party/abseil-cpp/absl/memory", - "//third_party/abseil-cpp/absl/types:optional", - ] + if (!build_with_chromium) { + rtc_executable("rtc_event_log_rtp_dump") { + testonly = true + sources = [ "rtc_event_log/rtc_event_log2rtp_dump.cc" ] + deps = [ + ":rtc_event_log_parser", + "../api:array_view", + "../api:rtp_headers", + "../api/rtc_event_log", + "../modules/rtp_rtcp", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../rtc_base:checks", + "../rtc_base:protobuf_utils", + "../rtc_base:rtc_base_approved", + "../test:rtp_test_utils", + 
"//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/flags:usage", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/types:optional", + ] + } } } } @@ -451,6 +458,7 @@ rtc_library("ice_log") { "../api:libjingle_logging_api", "../api:libjingle_peerconnection_api", # For api/dtls_transport_interface.h "../api/rtc_event_log", + "../api/units:timestamp", "../rtc_base:rtc_base_approved", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] diff --git a/logging/g3doc/rtc_event_log.md b/logging/g3doc/rtc_event_log.md new file mode 100644 index 0000000000..c7996e0b42 --- /dev/null +++ b/logging/g3doc/rtc_event_log.md @@ -0,0 +1,85 @@ +# RTC event log + + + + +## Overview + +RTC event logs can be enabled to capture in-depth inpformation about sent and +received packets and the internal state of some WebRTC components. The logs are +useful to understand network behavior and to debug issues around connectivity, +bandwidth estimation and audio jitter buffers. + +The contents include: + +* Sent and received RTP headers +* Full RTCP feedback +* ICE candidates, pings and responses +* Bandwidth estimator events, including loss-based estimate, delay-based + estimate, probe results and ALR state +* Audio network adaptation settings +* Audio playout events + +## Binary wire format + +No guarantees are made on the wire format, and the format may change without +prior notice. To maintain compatibility with past and future formats, analysis +tools should be built on top of the provided +[rtc_event_log_parser.h](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/logging/rtc_event_log/rtc_event_log_parser.h) + +In particular, an analysis tool should *not* read the log as a protobuf. + +## Visualization + +Since the logs contain a substantial amount of data, it is usually convenient to +get an overview by visualizing them as a set of plots. Use the command: + +``` +out/Default/event_log_visualizer /path/to/log_file | python +``` + +This visualization requires matplotlib to be installed. The tool is capable of +producing a substantial number of plots, of which only a handful are generated +by default. You can select which plots are generated though the `--plot=` +command line argument. For example, the command + +``` +out/Default/event_log_visualizer \ + --plot=incoming_packet_sizes,incoming_stream_bitrate \ + /path/to/log_file | python +``` + +plots the sizes of incoming packets and the bitrate per incoming stream. + +You can get a full list of options for the `--plot` argument through + +``` +out/Default/event_log_visualizer --list_plots /path/to/log_file +``` + +You can also synchronize the x-axis between all plots (so zooming or +panning in one plot affects all of them), by adding the command line +argument `--shared_xaxis`. + + +## Viewing the raw log contents as text + +If you know which format version the log file uses, you can view the raw +contents as text. For version 1, you can use the command + +``` +out/Default/protoc --decode webrtc.rtclog.EventStream \ + ./logging/rtc_event_log/rtc_event_log.proto < /path/to/log_file +``` + +Similarly, you can use + +``` +out/Default/protoc --decode webrtc.rtclog2.EventStream \ + ./logging/rtc_event_log/rtc_event_log2.proto < /path/to/log_file +``` + +for logs that use version 2. However, note that not all of the contents will be +human readable. 
Some fields are based on the raw RTP format or may be encoded as +deltas relative to previous fields. Such fields will be printed as a list of +bytes. diff --git a/logging/rtc_event_log/encoder/blob_encoding.cc b/logging/rtc_event_log/encoder/blob_encoding.cc index 48316b052b..96699dc96a 100644 --- a/logging/rtc_event_log/encoder/blob_encoding.cc +++ b/logging/rtc_event_log/encoder/blob_encoding.cc @@ -58,49 +58,30 @@ std::vector DecodeBlobs(absl::string_view encoded_blobs, return std::vector(); } - size_t read_idx = 0; - // Read the lengths of all blobs. std::vector lengths(num_of_blobs); for (size_t i = 0; i < num_of_blobs; ++i) { - if (read_idx >= encoded_blobs.length()) { - RTC_DCHECK_EQ(read_idx, encoded_blobs.length()); - RTC_LOG(LS_WARNING) << "Corrupt input; excessive number of blobs."; - return std::vector(); - } - - const size_t read_bytes = - DecodeVarInt(encoded_blobs.substr(read_idx), &lengths[i]); - if (read_bytes == 0) { + bool success = false; + std::tie(success, encoded_blobs) = DecodeVarInt(encoded_blobs, &lengths[i]); + if (!success) { RTC_LOG(LS_WARNING) << "Corrupt input; varint decoding failed."; return std::vector(); } - - read_idx += read_bytes; - - // Note: It might be that read_idx == encoded_blobs.length(), if this - // is the last iteration, and all of the blobs are the empty string. - RTC_DCHECK_LE(read_idx, encoded_blobs.length()); } // Read the blobs themselves. std::vector blobs(num_of_blobs); for (size_t i = 0; i < num_of_blobs; ++i) { - if (read_idx + lengths[i] < read_idx) { // Wrap-around detection. - RTC_LOG(LS_WARNING) << "Corrupt input; unreasonably large blob sequence."; - return std::vector(); - } - - if (read_idx + lengths[i] > encoded_blobs.length()) { + if (lengths[i] > encoded_blobs.length()) { RTC_LOG(LS_WARNING) << "Corrupt input; blob sizes exceed input size."; return std::vector(); } - blobs[i] = encoded_blobs.substr(read_idx, lengths[i]); - read_idx += lengths[i]; + blobs[i] = encoded_blobs.substr(0, lengths[i]); + encoded_blobs = encoded_blobs.substr(lengths[i]); } - if (read_idx != encoded_blobs.length()) { + if (!encoded_blobs.empty()) { RTC_LOG(LS_WARNING) << "Corrupt input; unrecognized trailer."; return std::vector(); } diff --git a/logging/rtc_event_log/encoder/delta_encoding.cc b/logging/rtc_event_log/encoder/delta_encoding.cc index 022fb9c163..7bccdabdc8 100644 --- a/logging/rtc_event_log/encoder/delta_encoding.cc +++ b/logging/rtc_event_log/encoder/delta_encoding.cc @@ -693,7 +693,7 @@ bool FixedLengthDeltaDecoder::IsSuitableDecoderFor(const std::string& input) { uint32_t encoding_type_bits; const bool result = - reader.ReadBits(&encoding_type_bits, kBitsInHeaderForEncodingType); + reader.ReadBits(kBitsInHeaderForEncodingType, encoding_type_bits); RTC_DCHECK(result); const auto encoding_type = static_cast(encoding_type_bits); @@ -729,7 +729,7 @@ std::unique_ptr FixedLengthDeltaDecoder::Create( // Encoding type uint32_t encoding_type_bits; const bool result = - reader->ReadBits(&encoding_type_bits, kBitsInHeaderForEncodingType); + reader->ReadBits(kBitsInHeaderForEncodingType, encoding_type_bits); RTC_DCHECK(result); const EncodingType encoding = static_cast(encoding_type_bits); if (encoding != EncodingType::kFixedSizeUnsignedDeltasNoEarlyWrapNoOpt && @@ -742,7 +742,7 @@ std::unique_ptr FixedLengthDeltaDecoder::Create( uint32_t read_buffer; // delta_width_bits - if (!reader->ReadBits(&read_buffer, kBitsInHeaderForDeltaWidthBits)) { + if (!reader->ReadBits(kBitsInHeaderForDeltaWidthBits, read_buffer)) { return nullptr; } 
RTC_DCHECK_LE(read_buffer, 64 - 1); // See encoding for -1's rationale. @@ -759,20 +759,20 @@ std::unique_ptr FixedLengthDeltaDecoder::Create( value_width_bits = kDefaultValueWidthBits; } else { // signed_deltas - if (!reader->ReadBits(&read_buffer, kBitsInHeaderForSignedDeltas)) { + if (!reader->ReadBits(kBitsInHeaderForSignedDeltas, read_buffer)) { return nullptr; } signed_deltas = rtc::dchecked_cast(read_buffer); // values_optional - if (!reader->ReadBits(&read_buffer, kBitsInHeaderForValuesOptional)) { + if (!reader->ReadBits(kBitsInHeaderForValuesOptional, read_buffer)) { return nullptr; } RTC_DCHECK_LE(read_buffer, 1); values_optional = rtc::dchecked_cast(read_buffer); // value_width_bits - if (!reader->ReadBits(&read_buffer, kBitsInHeaderForValueWidthBits)) { + if (!reader->ReadBits(kBitsInHeaderForValueWidthBits, read_buffer)) { return nullptr; } RTC_DCHECK_LE(read_buffer, 64 - 1); // See encoding for -1's rationale. @@ -813,7 +813,7 @@ std::vector> FixedLengthDeltaDecoder::Decode() { if (params_.values_optional()) { for (size_t i = 0; i < num_of_deltas_; ++i) { uint32_t exists; - if (!reader_->ReadBits(&exists, 1u)) { + if (!reader_->ReadBits(1u, exists)) { RTC_LOG(LS_WARNING) << "Failed to read existence-indicating bit."; return std::vector>(); } @@ -877,7 +877,7 @@ bool FixedLengthDeltaDecoder::ParseDelta(uint64_t* delta) { uint32_t higher_bits; if (higher_bit_count > 0) { - if (!reader_->ReadBits(&higher_bits, higher_bit_count)) { + if (!reader_->ReadBits(higher_bit_count, higher_bits)) { RTC_LOG(LS_WARNING) << "Failed to read higher half of delta."; return false; } @@ -885,7 +885,7 @@ bool FixedLengthDeltaDecoder::ParseDelta(uint64_t* delta) { higher_bits = 0; } - if (!reader_->ReadBits(&lower_bits, lower_bit_count)) { + if (!reader_->ReadBits(lower_bit_count, lower_bits)) { RTC_LOG(LS_WARNING) << "Failed to read lower half of delta."; return false; } diff --git a/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.cc b/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.cc index dfbad7669a..2bd7507853 100644 --- a/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.cc +++ b/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.cc @@ -15,6 +15,7 @@ #include #include "absl/types/optional.h" +#include "api/array_view.h" #include "api/network_state_predictor.h" #include "api/rtp_headers.h" #include "api/rtp_parameters.h" @@ -593,14 +594,14 @@ std::string RtcEventLogEncoderLegacy::EncodeRtcpPacketOutgoing( std::string RtcEventLogEncoderLegacy::EncodeRtpPacketIncoming( const RtcEventRtpPacketIncoming& event) { - return EncodeRtpPacket(event.timestamp_us(), event.header(), + return EncodeRtpPacket(event.timestamp_us(), event.RawHeader(), event.packet_length(), PacedPacketInfo::kNotAProbe, true); } std::string RtcEventLogEncoderLegacy::EncodeRtpPacketOutgoing( const RtcEventRtpPacketOutgoing& event) { - return EncodeRtpPacket(event.timestamp_us(), event.header(), + return EncodeRtpPacket(event.timestamp_us(), event.RawHeader(), event.packet_length(), event.probe_cluster_id(), false); } @@ -736,7 +737,7 @@ std::string RtcEventLogEncoderLegacy::EncodeRtcpPacket( std::string RtcEventLogEncoderLegacy::EncodeRtpPacket( int64_t timestamp_us, - const webrtc::RtpPacket& header, + rtc::ArrayView header, size_t packet_length, int probe_cluster_id, bool is_incoming) { diff --git a/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.h b/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.h index 3105dc1e68..37296e797f 100644 --- 
a/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.h +++ b/logging/rtc_event_log/encoder/rtc_event_log_encoder_legacy.h @@ -15,6 +15,7 @@ #include #include +#include "api/array_view.h" #include "logging/rtc_event_log/encoder/rtc_event_log_encoder.h" #include "rtc_base/buffer.h" @@ -94,7 +95,7 @@ class RtcEventLogEncoderLegacy final : public RtcEventLogEncoder { const rtc::Buffer& packet, bool is_incoming); std::string EncodeRtpPacket(int64_t timestamp_us, - const RtpPacket& header, + rtc::ArrayView header, size_t packet_length, int probe_cluster_id, bool is_incoming); diff --git a/logging/rtc_event_log/encoder/rtc_event_log_encoder_new_format.cc b/logging/rtc_event_log/encoder/rtc_event_log_encoder_new_format.cc index f0a307973e..9400c864bf 100644 --- a/logging/rtc_event_log/encoder/rtc_event_log_encoder_new_format.cc +++ b/logging/rtc_event_log/encoder/rtc_event_log_encoder_new_format.cc @@ -394,12 +394,12 @@ void EncodeRtpPacket(const std::vector& batch, // Base event const EventType* const base_event = batch[0]; proto_batch->set_timestamp_ms(base_event->timestamp_ms()); - proto_batch->set_marker(base_event->header().Marker()); + proto_batch->set_marker(base_event->Marker()); // TODO(terelius): Is payload type needed? - proto_batch->set_payload_type(base_event->header().PayloadType()); - proto_batch->set_sequence_number(base_event->header().SequenceNumber()); - proto_batch->set_rtp_timestamp(base_event->header().Timestamp()); - proto_batch->set_ssrc(base_event->header().Ssrc()); + proto_batch->set_payload_type(base_event->PayloadType()); + proto_batch->set_sequence_number(base_event->SequenceNumber()); + proto_batch->set_rtp_timestamp(base_event->Timestamp()); + proto_batch->set_ssrc(base_event->Ssrc()); proto_batch->set_payload_size(base_event->payload_length()); proto_batch->set_header_size(base_event->header_length()); proto_batch->set_padding_size(base_event->padding_length()); @@ -408,8 +408,7 @@ void EncodeRtpPacket(const std::vector& batch, absl::optional base_transport_sequence_number; { uint16_t seqnum; - if (base_event->header().template GetExtension( - &seqnum)) { + if (base_event->template GetExtension(&seqnum)) { proto_batch->set_transport_sequence_number(seqnum); base_transport_sequence_number = seqnum; } @@ -418,8 +417,7 @@ void EncodeRtpPacket(const std::vector& batch, absl::optional unsigned_base_transmission_time_offset; { int32_t offset; - if (base_event->header().template GetExtension( - &offset)) { + if (base_event->template GetExtension(&offset)) { proto_batch->set_transmission_time_offset(offset); unsigned_base_transmission_time_offset = ToUnsigned(offset); } @@ -428,8 +426,7 @@ void EncodeRtpPacket(const std::vector& batch, absl::optional base_absolute_send_time; { uint32_t sendtime; - if (base_event->header().template GetExtension( - &sendtime)) { + if (base_event->template GetExtension(&sendtime)) { proto_batch->set_absolute_send_time(sendtime); base_absolute_send_time = sendtime; } @@ -438,8 +435,7 @@ void EncodeRtpPacket(const std::vector& batch, absl::optional base_video_rotation; { VideoRotation video_rotation; - if (base_event->header().template GetExtension( - &video_rotation)) { + if (base_event->template GetExtension(&video_rotation)) { proto_batch->set_video_rotation( ConvertVideoRotationToCVOByte(video_rotation)); base_video_rotation = ConvertVideoRotationToCVOByte(video_rotation); @@ -451,8 +447,8 @@ void EncodeRtpPacket(const std::vector& batch, { bool voice_activity; uint8_t audio_level; - if (base_event->header().template 
GetExtension(&voice_activity, - &audio_level)) { + if (base_event->template GetExtension(&voice_activity, + &audio_level)) { RTC_DCHECK_LE(audio_level, 0x7Fu); base_audio_level = audio_level; proto_batch->set_audio_level(audio_level); @@ -484,9 +480,9 @@ void EncodeRtpPacket(const std::vector& batch, // marker (RTP base) for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; - values[i] = event->header().Marker(); + values[i] = event->Marker(); } - encoded_deltas = EncodeDeltas(base_event->header().Marker(), values); + encoded_deltas = EncodeDeltas(base_event->Marker(), values); if (!encoded_deltas.empty()) { proto_batch->set_marker_deltas(encoded_deltas); } @@ -494,9 +490,9 @@ void EncodeRtpPacket(const std::vector& batch, // payload_type (RTP base) for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; - values[i] = event->header().PayloadType(); + values[i] = event->PayloadType(); } - encoded_deltas = EncodeDeltas(base_event->header().PayloadType(), values); + encoded_deltas = EncodeDeltas(base_event->PayloadType(), values); if (!encoded_deltas.empty()) { proto_batch->set_payload_type_deltas(encoded_deltas); } @@ -504,9 +500,9 @@ void EncodeRtpPacket(const std::vector& batch, // sequence_number (RTP base) for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; - values[i] = event->header().SequenceNumber(); + values[i] = event->SequenceNumber(); } - encoded_deltas = EncodeDeltas(base_event->header().SequenceNumber(), values); + encoded_deltas = EncodeDeltas(base_event->SequenceNumber(), values); if (!encoded_deltas.empty()) { proto_batch->set_sequence_number_deltas(encoded_deltas); } @@ -514,9 +510,9 @@ void EncodeRtpPacket(const std::vector& batch, // rtp_timestamp (RTP base) for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; - values[i] = event->header().Timestamp(); + values[i] = event->Timestamp(); } - encoded_deltas = EncodeDeltas(base_event->header().Timestamp(), values); + encoded_deltas = EncodeDeltas(base_event->Timestamp(), values); if (!encoded_deltas.empty()) { proto_batch->set_rtp_timestamp_deltas(encoded_deltas); } @@ -524,9 +520,9 @@ void EncodeRtpPacket(const std::vector& batch, // ssrc (RTP base) for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; - values[i] = event->header().Ssrc(); + values[i] = event->Ssrc(); } - encoded_deltas = EncodeDeltas(base_event->header().Ssrc(), values); + encoded_deltas = EncodeDeltas(base_event->Ssrc(), values); if (!encoded_deltas.empty()) { proto_batch->set_ssrc_deltas(encoded_deltas); } @@ -565,8 +561,7 @@ void EncodeRtpPacket(const std::vector& batch, for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; uint16_t seqnum; - if (event->header().template GetExtension( - &seqnum)) { + if (event->template GetExtension(&seqnum)) { values[i] = seqnum; } else { values[i].reset(); @@ -581,7 +576,7 @@ void EncodeRtpPacket(const std::vector& batch, for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; int32_t offset; - if (event->header().template GetExtension(&offset)) { + if (event->template GetExtension(&offset)) { values[i] = ToUnsigned(offset); } else { values[i].reset(); @@ -596,7 +591,7 @@ void EncodeRtpPacket(const std::vector& batch, for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; uint32_t sendtime; - if (event->header().template GetExtension(&sendtime)) { + if (event->template 
GetExtension(&sendtime)) { values[i] = sendtime; } else { values[i].reset(); @@ -611,8 +606,7 @@ void EncodeRtpPacket(const std::vector& batch, for (size_t i = 0; i < values.size(); ++i) { const EventType* event = batch[i + 1]; VideoRotation video_rotation; - if (event->header().template GetExtension( - &video_rotation)) { + if (event->template GetExtension(&video_rotation)) { values[i] = ConvertVideoRotationToCVOByte(video_rotation); } else { values[i].reset(); @@ -628,8 +622,8 @@ void EncodeRtpPacket(const std::vector& batch, const EventType* event = batch[i + 1]; bool voice_activity; uint8_t audio_level; - if (event->header().template GetExtension(&voice_activity, - &audio_level)) { + if (event->template GetExtension(&voice_activity, + &audio_level)) { RTC_DCHECK_LE(audio_level, 0x7Fu); values[i] = audio_level; } else { @@ -646,8 +640,8 @@ void EncodeRtpPacket(const std::vector& batch, const EventType* event = batch[i + 1]; bool voice_activity; uint8_t audio_level; - if (event->header().template GetExtension(&voice_activity, - &audio_level)) { + if (event->template GetExtension(&voice_activity, + &audio_level)) { RTC_DCHECK_LE(audio_level, 0x7Fu); values[i] = voice_activity; } else { @@ -823,14 +817,14 @@ std::string RtcEventLogEncoderNewFormat::EncodeBatch( case RtcEvent::Type::RtpPacketIncoming: { auto* rtc_event = static_cast(it->get()); - auto& v = incoming_rtp_packets[rtc_event->header().Ssrc()]; + auto& v = incoming_rtp_packets[rtc_event->Ssrc()]; v.emplace_back(rtc_event); break; } case RtcEvent::Type::RtpPacketOutgoing: { auto* rtc_event = static_cast(it->get()); - auto& v = outgoing_rtp_packets[rtc_event->header().Ssrc()]; + auto& v = outgoing_rtp_packets[rtc_event->Ssrc()]; v.emplace_back(rtc_event); break; } diff --git a/logging/rtc_event_log/encoder/rtc_event_log_encoder_unittest.cc b/logging/rtc_event_log/encoder/rtc_event_log_encoder_unittest.cc index 6fae2d9cd6..063d425af5 100644 --- a/logging/rtc_event_log/encoder/rtc_event_log_encoder_unittest.cc +++ b/logging/rtc_event_log/encoder/rtc_event_log_encoder_unittest.cc @@ -49,12 +49,12 @@ class RtcEventLogEncoderTest RtcEventLogEncoderTest() : seed_(std::get<0>(GetParam())), prng_(seed_), - encoding_(std::get<1>(GetParam())), + encoding_type_(std::get<1>(GetParam())), event_count_(std::get<2>(GetParam())), force_repeated_fields_(std::get<3>(GetParam())), gen_(seed_ * 880001UL), - verifier_(encoding_) { - switch (encoding_) { + verifier_(encoding_type_) { + switch (encoding_type_) { case RtcEventLog::EncodingType::Legacy: encoder_ = std::make_unique(); break; @@ -62,6 +62,8 @@ class RtcEventLogEncoderTest encoder_ = std::make_unique(); break; } + encoded_ = + encoder_->EncodeLogStart(rtc::TimeMillis(), rtc::TimeUTCMillis()); } ~RtcEventLogEncoderTest() override = default; @@ -89,11 +91,12 @@ class RtcEventLogEncoderTest ParsedRtcEventLog parsed_log_; const uint64_t seed_; Random prng_; - const RtcEventLog::EncodingType encoding_; + const RtcEventLog::EncodingType encoding_type_; const size_t event_count_; const bool force_repeated_fields_; test::EventGenerator gen_; test::EventVerifier verifier_; + std::string encoded_; }; void RtcEventLogEncoderTest::TestRtcEventAudioNetworkAdaptation( @@ -105,8 +108,8 @@ void RtcEventLogEncoderTest::TestRtcEventAudioNetworkAdaptation( history_.push_back(event->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + 
ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& ana_configs = parsed_log_.audio_network_adaptation_events(); ASSERT_EQ(ana_configs.size(), events.size()); @@ -167,7 +170,7 @@ void RtcEventLogEncoderTest::TestRtpPackets() { // TODO(terelius): Test extensions for legacy encoding, too. RtpHeaderExtensionMap extension_map; - if (encoding_ != RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ != RtcEventLog::EncodingType::Legacy) { extension_map = gen_.NewRtpHeaderExtensionMap(true); } @@ -185,8 +188,8 @@ void RtcEventLogEncoderTest::TestRtpPackets() { } // Encode and parse. - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); // For each SSRC, make sure the RTP packets associated with it to have been // correctly encoded and parsed. @@ -212,8 +215,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventAlrState) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& alr_state_events = parsed_log_.alr_state_events(); ASSERT_EQ(alr_state_events.size(), event_count_); @@ -223,7 +226,7 @@ TEST_P(RtcEventLogEncoderTest, RtcEventAlrState) { } TEST_P(RtcEventLogEncoderTest, RtcEventRouteChange) { - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { return; } std::vector> events(event_count_); @@ -233,8 +236,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRouteChange) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& route_change_events = parsed_log_.route_change_events(); ASSERT_EQ(route_change_events.size(), event_count_); @@ -244,7 +247,7 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRouteChange) { } TEST_P(RtcEventLogEncoderTest, RtcEventRemoteEstimate) { - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { return; } std::vector> events(event_count_); @@ -255,8 +258,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRemoteEstimate) { history_.push_back(std::make_unique(*events[i])); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& parsed_events = parsed_log_.remote_estimate_events(); ASSERT_EQ(parsed_events.size(), event_count_); @@ -409,8 +412,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventAudioPlayout) { original_events_by_ssrc[ssrc].push_back(std::move(event)); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& parsed_playout_events_by_ssrc = parsed_log_.audio_playout_events(); @@ -445,8 +448,8 @@ 
TEST_P(RtcEventLogEncoderTest, RtcEventAudioReceiveStreamConfig) { gen_.NewAudioReceiveStreamConfig(ssrc, extensions); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& audio_recv_configs = parsed_log_.audio_recv_configs(); ASSERT_EQ(audio_recv_configs.size(), 1u); @@ -461,8 +464,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventAudioSendStreamConfig) { gen_.NewAudioSendStreamConfig(ssrc, extensions); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& audio_send_configs = parsed_log_.audio_send_configs(); ASSERT_EQ(audio_send_configs.size(), 1u); @@ -479,8 +482,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventBweUpdateDelayBased) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& bwe_delay_updates = parsed_log_.bwe_delay_updates(); ASSERT_EQ(bwe_delay_updates.size(), event_count_); @@ -499,8 +502,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventBweUpdateLossBased) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& bwe_loss_updates = parsed_log_.bwe_loss_updates(); ASSERT_EQ(bwe_loss_updates.size(), event_count_); @@ -511,7 +514,7 @@ TEST_P(RtcEventLogEncoderTest, RtcEventBweUpdateLossBased) { } TEST_P(RtcEventLogEncoderTest, RtcEventGenericPacketReceived) { - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { return; } std::vector> events( @@ -523,8 +526,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventGenericPacketReceived) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& packets_received = parsed_log_.generic_packets_received(); ASSERT_EQ(packets_received.size(), event_count_); @@ -536,7 +539,7 @@ TEST_P(RtcEventLogEncoderTest, RtcEventGenericPacketReceived) { } TEST_P(RtcEventLogEncoderTest, RtcEventGenericPacketSent) { - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { return; } std::vector> events(event_count_); @@ -547,8 +550,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventGenericPacketSent) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); 
const auto& packets_sent = parsed_log_.generic_packets_sent(); ASSERT_EQ(packets_sent.size(), event_count_); @@ -559,7 +562,7 @@ TEST_P(RtcEventLogEncoderTest, RtcEventGenericPacketSent) { } TEST_P(RtcEventLogEncoderTest, RtcEventGenericAcksReceived) { - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { return; } std::vector> events(event_count_); @@ -570,8 +573,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventGenericAcksReceived) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& decoded_events = parsed_log_.generic_acks_received(); ASSERT_EQ(decoded_events.size(), event_count_); @@ -590,12 +593,11 @@ TEST_P(RtcEventLogEncoderTest, RtcEventDtlsTransportState) { history_.push_back(events[i]->Copy()); } - const std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& dtls_transport_states = parsed_log_.dtls_transport_states(); - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { ASSERT_EQ(dtls_transport_states.size(), 0u); return; } @@ -616,12 +618,11 @@ TEST_P(RtcEventLogEncoderTest, RtcEventDtlsWritableState) { history_.push_back(events[i]->Copy()); } - const std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& dtls_writable_states = parsed_log_.dtls_writable_states(); - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { ASSERT_EQ(dtls_writable_states.size(), 0u); return; } @@ -654,15 +655,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventFrameDecoded) { original_events_by_ssrc[ssrc].push_back(std::move(event)); } - const std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - auto status = parsed_log_.ParseString(encoded); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + auto status = parsed_log_.ParseString(encoded_); if (!status.ok()) RTC_LOG(LS_ERROR) << status.message(); ASSERT_TRUE(status.ok()); const auto& decoded_frames_by_ssrc = parsed_log_.decoded_frames(); - if (encoding_ == RtcEventLog::EncodingType::Legacy) { + if (encoding_type_ == RtcEventLog::EncodingType::Legacy) { ASSERT_EQ(decoded_frames_by_ssrc.size(), 0u); return; } @@ -695,8 +695,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventIceCandidatePairConfig) { gen_.NewIceCandidatePairConfig(); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& ice_candidate_pair_configs = parsed_log_.ice_candidate_pair_configs(); @@ -710,8 +710,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventIceCandidatePair) { std::unique_ptr event = gen_.NewIceCandidatePair(); 
history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& ice_candidate_pair_events = parsed_log_.ice_candidate_pair_events(); @@ -721,31 +721,35 @@ TEST_P(RtcEventLogEncoderTest, RtcEventIceCandidatePair) { } TEST_P(RtcEventLogEncoderTest, RtcEventLoggingStarted) { - const int64_t timestamp_us = rtc::TimeMicros(); - const int64_t utc_time_us = rtc::TimeUTCMicros(); + const int64_t timestamp_ms = prng_.Rand(1'000'000'000); + const int64_t utc_time_ms = prng_.Rand(1'000'000'000); - std::string encoded = encoder_->EncodeLogStart(timestamp_us, utc_time_us); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + // Overwrite the previously encoded LogStart event. + encoded_ = encoder_->EncodeLogStart(timestamp_ms * 1000, utc_time_ms * 1000); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& start_log_events = parsed_log_.start_log_events(); ASSERT_EQ(start_log_events.size(), 1u); - verifier_.VerifyLoggedStartEvent(timestamp_us, utc_time_us, + verifier_.VerifyLoggedStartEvent(timestamp_ms * 1000, utc_time_ms * 1000, start_log_events[0]); } TEST_P(RtcEventLogEncoderTest, RtcEventLoggingStopped) { - const int64_t start_timestamp_us = rtc::TimeMicros(); - const int64_t start_utc_time_us = rtc::TimeUTCMicros(); - std::string encoded = - encoder_->EncodeLogStart(start_timestamp_us, start_utc_time_us); - - const int64_t stop_timestamp_us = rtc::TimeMicros(); - encoded += encoder_->EncodeLogEnd(stop_timestamp_us); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + const int64_t start_timestamp_ms = prng_.Rand(1'000'000'000); + const int64_t start_utc_time_ms = prng_.Rand(1'000'000'000); + + // Overwrite the previously encoded LogStart event. + encoded_ = encoder_->EncodeLogStart(start_timestamp_ms * 1000, + start_utc_time_ms * 1000); + + const int64_t stop_timestamp_ms = + prng_.Rand(start_timestamp_ms, 2'000'000'000); + encoded_ += encoder_->EncodeLogEnd(stop_timestamp_ms * 1000); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& stop_log_events = parsed_log_.stop_log_events(); ASSERT_EQ(stop_log_events.size(), 1u); - verifier_.VerifyLoggedStopEvent(stop_timestamp_us, stop_log_events[0]); + verifier_.VerifyLoggedStopEvent(stop_timestamp_ms * 1000, stop_log_events[0]); } // TODO(eladalon/terelius): Test with multiple events in the batch. 
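A recurring change in the test hunks above and below is that each test no longer encodes into a fresh local string; the fixture now seeds `encoded_` with an encoded LogStart event in its constructor and appends every `EncodeBatch()` result to that same string before parsing it. A minimal sketch of that encode/parse round trip, using only calls that appear elsewhere in this diff (the seed value and the choice of event are arbitrary; this is an illustrative fragment, not a complete, buildable test):

```cpp
// Sketch of the encode -> parse round trip used by the updated fixtures.
// Assumes the types shown elsewhere in this diff; not a complete test.
auto encoder = std::make_unique<RtcEventLogEncoderNewFormat>();

// The encoded stream now always begins with a LogStart event.
std::string encoded =
    encoder->EncodeLogStart(rtc::TimeMillis(), rtc::TimeUTCMillis());

// Generate an event and append it to the stream as an encoded batch.
test::EventGenerator gen(/*seed=*/1);  // arbitrary seed
std::deque<std::unique_ptr<RtcEvent>> history;
history.push_back(gen.NewIceCandidatePair());
encoded += encoder->EncodeBatch(history.begin(), history.end());

// Parse the accumulated string and inspect the decoded events.
ParsedRtcEventLog parsed_log;
RTC_CHECK(parsed_log.ParseString(encoded).ok());
RTC_CHECK_EQ(parsed_log.ice_candidate_pair_events().size(), 1u);
```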
@@ -754,8 +758,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventProbeClusterCreated) { gen_.NewProbeClusterCreated(); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& bwe_probe_cluster_created_events = parsed_log_.bwe_probe_cluster_created_events(); @@ -770,8 +774,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventProbeResultFailure) { gen_.NewProbeResultFailure(); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& bwe_probe_failure_events = parsed_log_.bwe_probe_failure_events(); ASSERT_EQ(bwe_probe_failure_events.size(), 1u); @@ -785,8 +789,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventProbeResultSuccess) { gen_.NewProbeResultSuccess(); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& bwe_probe_success_events = parsed_log_.bwe_probe_success_events(); ASSERT_EQ(bwe_probe_success_events.size(), 1u); @@ -809,8 +813,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpPacketIncoming) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& incoming_rtcp_packets = parsed_log_.incoming_rtcp_packets(); ASSERT_EQ(incoming_rtcp_packets.size(), event_count_); @@ -830,8 +834,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpPacketOutgoing) { history_.push_back(events[i]->Copy()); } - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& outgoing_rtcp_packets = parsed_log_.outgoing_rtcp_packets(); ASSERT_EQ(outgoing_rtcp_packets.size(), event_count_); @@ -852,9 +856,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpReceiverReport) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewReceiverReport(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -867,15 +871,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpReceiverReport) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& receiver_reports = 
parsed_log_.receiver_reports(direction); ASSERT_EQ(receiver_reports.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedReceiverReport(timestamps_us[i], events[i], + verifier_.VerifyLoggedReceiverReport(timestamps_ms[i], events[i], receiver_reports[i]); } } @@ -891,9 +894,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpSenderReport) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewSenderReport(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -906,15 +909,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpSenderReport) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& sender_reports = parsed_log_.sender_reports(direction); ASSERT_EQ(sender_reports.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedSenderReport(timestamps_us[i], events[i], + verifier_.VerifyLoggedSenderReport(timestamps_ms[i], events[i], sender_reports[i]); } } @@ -930,9 +932,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpExtendedReports) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewExtendedReports(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -945,15 +947,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpExtendedReports) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& extended_reports = parsed_log_.extended_reports(direction); ASSERT_EQ(extended_reports.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedExtendedReports(timestamps_us[i], events[i], + verifier_.VerifyLoggedExtendedReports(timestamps_ms[i], events[i], extended_reports[i]); } } @@ -969,9 +970,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpFir) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewFir(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -984,15 +985,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpFir) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += 
encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& firs = parsed_log_.firs(direction); ASSERT_EQ(firs.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedFir(timestamps_us[i], events[i], firs[i]); + verifier_.VerifyLoggedFir(timestamps_ms[i], events[i], firs[i]); } } } @@ -1007,9 +1007,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpPli) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewPli(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -1022,15 +1022,51 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpPli) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& plis = parsed_log_.plis(direction); ASSERT_EQ(plis.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedPli(timestamps_us[i], events[i], plis[i]); + verifier_.VerifyLoggedPli(timestamps_ms[i], events[i], plis[i]); + } + } +} + +TEST_P(RtcEventLogEncoderTest, RtcEventRtcpBye) { + if (force_repeated_fields_) { + return; + } + + rtc::ScopedFakeClock fake_clock; + fake_clock.SetTime(Timestamp::Millis(prng_.Rand())); + + for (auto direction : {kIncomingPacket, kOutgoingPacket}) { + std::vector events(event_count_); + std::vector timestamps_ms(event_count_); + for (size_t i = 0; i < event_count_; ++i) { + timestamps_ms[i] = rtc::TimeMillis(); + events[i] = gen_.NewBye(); + rtc::Buffer buffer = events[i].Build(); + if (direction == kIncomingPacket) { + history_.push_back( + std::make_unique(buffer)); + } else { + history_.push_back( + std::make_unique(buffer)); + } + fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); + } + + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); + + const auto& byes = parsed_log_.byes(direction); + ASSERT_EQ(byes.size(), event_count_); + + for (size_t i = 0; i < event_count_; ++i) { + verifier_.VerifyLoggedBye(timestamps_ms[i], events[i], byes[i]); } } } @@ -1045,9 +1081,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpNack) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewNack(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -1060,15 +1096,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpNack) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& nacks = parsed_log_.nacks(direction); ASSERT_EQ(nacks.size(), 
event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedNack(timestamps_us[i], events[i], nacks[i]); + verifier_.VerifyLoggedNack(timestamps_ms[i], events[i], nacks[i]); } } } @@ -1083,9 +1118,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpRemb) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events[i] = gen_.NewRemb(); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -1098,15 +1133,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpRemb) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& rembs = parsed_log_.rembs(direction); ASSERT_EQ(rembs.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedRemb(timestamps_us[i], events[i], rembs[i]); + verifier_.VerifyLoggedRemb(timestamps_ms[i], events[i], rembs[i]); } } } @@ -1122,9 +1156,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpTransportFeedback) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events; events.reserve(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events.emplace_back(gen_.NewTransportFeedback()); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -1137,16 +1171,15 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpTransportFeedback) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& transport_feedbacks = parsed_log_.transport_feedbacks(direction); ASSERT_EQ(transport_feedbacks.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedTransportFeedback(timestamps_us[i], events[i], + verifier_.VerifyLoggedTransportFeedback(timestamps_ms[i], events[i], transport_feedbacks[i]); } } @@ -1163,9 +1196,9 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpLossNotification) { for (auto direction : {kIncomingPacket, kOutgoingPacket}) { std::vector events; events.reserve(event_count_); - std::vector timestamps_us(event_count_); + std::vector timestamps_ms(event_count_); for (size_t i = 0; i < event_count_; ++i) { - timestamps_us[i] = rtc::TimeMicros(); + timestamps_ms[i] = rtc::TimeMillis(); events.emplace_back(gen_.NewLossNotification()); rtc::Buffer buffer = events[i].Build(); if (direction == kIncomingPacket) { @@ -1178,15 +1211,14 @@ TEST_P(RtcEventLogEncoderTest, RtcEventRtcpLossNotification) { fake_clock.AdvanceTime(TimeDelta::Millis(prng_.Rand(0, 1000))); } - std::string encoded = - encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + 
ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& loss_notifications = parsed_log_.loss_notifications(direction); ASSERT_EQ(loss_notifications.size(), event_count_); for (size_t i = 0; i < event_count_; ++i) { - verifier_.VerifyLoggedLossNotification(timestamps_us[i], events[i], + verifier_.VerifyLoggedLossNotification(timestamps_ms[i], events[i], loss_notifications[i]); } } @@ -1208,8 +1240,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventVideoReceiveStreamConfig) { gen_.NewVideoReceiveStreamConfig(ssrc, extensions); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& video_recv_configs = parsed_log_.video_recv_configs(); ASSERT_EQ(video_recv_configs.size(), 1u); @@ -1224,8 +1256,8 @@ TEST_P(RtcEventLogEncoderTest, RtcEventVideoSendStreamConfig) { gen_.NewVideoSendStreamConfig(ssrc, extensions); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); - ASSERT_TRUE(parsed_log_.ParseString(encoded).ok()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); + ASSERT_TRUE(parsed_log_.ParseString(encoded_).ok()); const auto& video_send_configs = parsed_log_.video_send_configs(); ASSERT_EQ(video_send_configs.size(), 1u); @@ -1245,8 +1277,8 @@ INSTANTIATE_TEST_SUITE_P( class RtcEventLogEncoderSimpleTest : public ::testing::TestWithParam { protected: - RtcEventLogEncoderSimpleTest() : encoding_(GetParam()) { - switch (encoding_) { + RtcEventLogEncoderSimpleTest() : encoding_type_(GetParam()) { + switch (encoding_type_) { case RtcEventLog::EncodingType::Legacy: encoder_ = std::make_unique(); break; @@ -1254,13 +1286,16 @@ class RtcEventLogEncoderSimpleTest encoder_ = std::make_unique(); break; } + encoded_ = + encoder_->EncodeLogStart(rtc::TimeMillis(), rtc::TimeUTCMillis()); } ~RtcEventLogEncoderSimpleTest() override = default; std::deque> history_; std::unique_ptr encoder_; ParsedRtcEventLog parsed_log_; - const RtcEventLog::EncodingType encoding_; + const RtcEventLog::EncodingType encoding_type_; + std::string encoded_; }; TEST_P(RtcEventLogEncoderSimpleTest, RtcEventLargeCompoundRtcpPacketIncoming) { @@ -1282,9 +1317,9 @@ TEST_P(RtcEventLogEncoderSimpleTest, RtcEventLargeCompoundRtcpPacketIncoming) { EXPECT_GT(packet.size(), static_cast(IP_PACKET_SIZE)); auto event = std::make_unique(packet); history_.push_back(event->Copy()); - std::string encoded = encoder_->EncodeBatch(history_.begin(), history_.end()); + encoded_ += encoder_->EncodeBatch(history_.begin(), history_.end()); - ParsedRtcEventLog::ParseStatus status = parsed_log_.ParseString(encoded); + ParsedRtcEventLog::ParseStatus status = parsed_log_.ParseString(encoded_); ASSERT_TRUE(status.ok()) << status.message(); const auto& incoming_rtcp_packets = parsed_log_.incoming_rtcp_packets(); diff --git a/logging/rtc_event_log/encoder/var_int.cc b/logging/rtc_event_log/encoder/var_int.cc index b2c695ee78..f2819c0c73 100644 --- a/logging/rtc_event_log/encoder/var_int.cc +++ b/logging/rtc_event_log/encoder/var_int.cc @@ -39,7 +39,8 @@ std::string EncodeVarInt(uint64_t input) { // There is some code duplication between the flavors of this function. // For performance's sake, it's best to just keep it. 
-size_t DecodeVarInt(absl::string_view input, uint64_t* output) { +std::pair DecodeVarInt(absl::string_view input, + uint64_t* output) { RTC_DCHECK(output); uint64_t decoded = 0; @@ -48,11 +49,11 @@ size_t DecodeVarInt(absl::string_view input, uint64_t* output) { << static_cast(7 * i)); if (!(input[i] & 0x80)) { *output = decoded; - return i + 1; + return {true, input.substr(i + 1)}; } } - return 0; + return {false, input}; } // There is some code duplication between the flavors of this function. @@ -63,7 +64,7 @@ size_t DecodeVarInt(rtc::BitBuffer* input, uint64_t* output) { uint64_t decoded = 0; for (size_t i = 0; i < kMaxVarIntLengthBytes; ++i) { uint8_t byte; - if (!input->ReadUInt8(&byte)) { + if (!input->ReadUInt8(byte)) { return 0; } decoded += diff --git a/logging/rtc_event_log/encoder/var_int.h b/logging/rtc_event_log/encoder/var_int.h index 178c9cec18..dbe1f1103f 100644 --- a/logging/rtc_event_log/encoder/var_int.h +++ b/logging/rtc_event_log/encoder/var_int.h @@ -15,6 +15,7 @@ #include #include +#include #include "absl/strings/string_view.h" #include "rtc_base/bit_buffer.h" @@ -26,20 +27,23 @@ extern const size_t kMaxVarIntLengthBytes; // Encode a given uint64_t as a varint. From least to most significant, // each batch of seven bits are put into the lower bits of a byte, and the last // remaining bit in that byte (the highest one) marks whether additional bytes -// follow (which happens if and only if there are other bits in |input| which +// follow (which happens if and only if there are other bits in `input` which // are non-zero). // Notes: If input == 0, one byte is used. If input is uint64_t::max, exactly // kMaxVarIntLengthBytes are used. std::string EncodeVarInt(uint64_t input); // Inverse of EncodeVarInt(). -// If decoding is successful, a non-zero number is returned, indicating the -// number of bytes read from |input|, and the decoded varint is written -// into |output|. -// If not successful, 0 is returned, and |output| is not modified. -size_t DecodeVarInt(absl::string_view input, uint64_t* output); +// Returns true and the remaining (unread) slice of the input if decoding +// succeeds. Returns false otherwise and `output` is not modified. +std::pair DecodeVarInt(absl::string_view input, + uint64_t* output); // Same as other version, but uses a rtc::BitBuffer for input. +// If decoding is successful, a non-zero number is returned, indicating the +// number of bytes read from `input`, and the decoded varint is written +// into `output`. +// If not successful, 0 is returned, and `output` is not modified. // Some bits may be consumed even if a varint fails to be read. 
size_t DecodeVarInt(rtc::BitBuffer* input, uint64_t* output); diff --git a/logging/rtc_event_log/events/rtc_event_alr_state.h b/logging/rtc_event_log/events/rtc_event_alr_state.h index 3ad0f005fb..74d66015ef 100644 --- a/logging/rtc_event_log/events/rtc_event_alr_state.h +++ b/logging/rtc_event_log/events/rtc_event_alr_state.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -39,13 +40,13 @@ class RtcEventAlrState final : public RtcEvent { struct LoggedAlrStateEvent { LoggedAlrStateEvent() = default; - LoggedAlrStateEvent(int64_t timestamp_us, bool in_alr) - : timestamp_us(timestamp_us), in_alr(in_alr) {} + LoggedAlrStateEvent(Timestamp timestamp, bool in_alr) + : timestamp(timestamp), in_alr(in_alr) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); bool in_alr; }; diff --git a/logging/rtc_event_log/events/rtc_event_audio_network_adaptation.h b/logging/rtc_event_log/events/rtc_event_audio_network_adaptation.h index 2b183bb307..aeeb28e218 100644 --- a/logging/rtc_event_log/events/rtc_event_audio_network_adaptation.h +++ b/logging/rtc_event_log/events/rtc_event_audio_network_adaptation.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" #include "modules/audio_coding/audio_network_adaptor/include/audio_network_adaptor_config.h" namespace webrtc { @@ -43,14 +44,14 @@ class RtcEventAudioNetworkAdaptation final : public RtcEvent { struct LoggedAudioNetworkAdaptationEvent { LoggedAudioNetworkAdaptationEvent() = default; - LoggedAudioNetworkAdaptationEvent(int64_t timestamp_us, + LoggedAudioNetworkAdaptationEvent(Timestamp timestamp, const AudioEncoderRuntimeConfig& config) - : timestamp_us(timestamp_us), config(config) {} + : timestamp(timestamp), config(config) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); AudioEncoderRuntimeConfig config; }; diff --git a/logging/rtc_event_log/events/rtc_event_audio_playout.h b/logging/rtc_event_log/events/rtc_event_audio_playout.h index 83825217a1..00d07a65bf 100644 --- a/logging/rtc_event_log/events/rtc_event_audio_playout.h +++ b/logging/rtc_event_log/events/rtc_event_audio_playout.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -41,13 +42,13 @@ class RtcEventAudioPlayout final : public RtcEvent { struct LoggedAudioPlayoutEvent { LoggedAudioPlayoutEvent() = default; - LoggedAudioPlayoutEvent(int64_t timestamp_us, uint32_t ssrc) - : timestamp_us(timestamp_us), ssrc(ssrc) {} + LoggedAudioPlayoutEvent(Timestamp timestamp, uint32_t ssrc) + : timestamp(timestamp), ssrc(ssrc) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); uint32_t ssrc; }; diff --git 
a/logging/rtc_event_log/events/rtc_event_audio_receive_stream_config.h b/logging/rtc_event_log/events/rtc_event_audio_receive_stream_config.h index 1edd8e1e46..ccf76025e6 100644 --- a/logging/rtc_event_log/events/rtc_event_audio_receive_stream_config.h +++ b/logging/rtc_event_log/events/rtc_event_audio_receive_stream_config.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" #include "logging/rtc_event_log/rtc_stream_config.h" namespace webrtc { @@ -42,13 +43,13 @@ class RtcEventAudioReceiveStreamConfig final : public RtcEvent { struct LoggedAudioRecvConfig { LoggedAudioRecvConfig() = default; - LoggedAudioRecvConfig(int64_t timestamp_us, const rtclog::StreamConfig config) - : timestamp_us(timestamp_us), config(config) {} + LoggedAudioRecvConfig(Timestamp timestamp, const rtclog::StreamConfig config) + : timestamp(timestamp), config(config) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtclog::StreamConfig config; }; diff --git a/logging/rtc_event_log/events/rtc_event_audio_send_stream_config.h b/logging/rtc_event_log/events/rtc_event_audio_send_stream_config.h index d3c60683b4..4e93871ae8 100644 --- a/logging/rtc_event_log/events/rtc_event_audio_send_stream_config.h +++ b/logging/rtc_event_log/events/rtc_event_audio_send_stream_config.h @@ -41,13 +41,13 @@ class RtcEventAudioSendStreamConfig final : public RtcEvent { struct LoggedAudioSendConfig { LoggedAudioSendConfig() = default; - LoggedAudioSendConfig(int64_t timestamp_us, const rtclog::StreamConfig config) - : timestamp_us(timestamp_us), config(config) {} + LoggedAudioSendConfig(Timestamp timestamp, const rtclog::StreamConfig config) + : timestamp(timestamp), config(config) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtclog::StreamConfig config; }; } // namespace webrtc diff --git a/logging/rtc_event_log/events/rtc_event_bwe_update_delay_based.h b/logging/rtc_event_log/events/rtc_event_bwe_update_delay_based.h index a83ea8b693..522f98fd8d 100644 --- a/logging/rtc_event_log/events/rtc_event_bwe_update_delay_based.h +++ b/logging/rtc_event_log/events/rtc_event_bwe_update_delay_based.h @@ -17,6 +17,7 @@ #include "api/network_state_predictor.h" #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -45,17 +46,17 @@ class RtcEventBweUpdateDelayBased final : public RtcEvent { struct LoggedBweDelayBasedUpdate { LoggedBweDelayBasedUpdate() = default; - LoggedBweDelayBasedUpdate(int64_t timestamp_us, + LoggedBweDelayBasedUpdate(Timestamp timestamp, int32_t bitrate_bps, BandwidthUsage detector_state) - : timestamp_us(timestamp_us), + : timestamp(timestamp), bitrate_bps(bitrate_bps), detector_state(detector_state) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int32_t 
bitrate_bps; BandwidthUsage detector_state; }; diff --git a/logging/rtc_event_log/events/rtc_event_bwe_update_loss_based.h b/logging/rtc_event_log/events/rtc_event_bwe_update_loss_based.h index b638f1ac16..b031658ea2 100644 --- a/logging/rtc_event_log/events/rtc_event_bwe_update_loss_based.h +++ b/logging/rtc_event_log/events/rtc_event_bwe_update_loss_based.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -47,19 +48,19 @@ class RtcEventBweUpdateLossBased final : public RtcEvent { struct LoggedBweLossBasedUpdate { LoggedBweLossBasedUpdate() = default; - LoggedBweLossBasedUpdate(int64_t timestamp_us, + LoggedBweLossBasedUpdate(Timestamp timestamp, int32_t bitrate_bps, uint8_t fraction_lost, int32_t expected_packets) - : timestamp_us(timestamp_us), + : timestamp(timestamp), bitrate_bps(bitrate_bps), fraction_lost(fraction_lost), expected_packets(expected_packets) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int32_t bitrate_bps; uint8_t fraction_lost; int32_t expected_packets; diff --git a/logging/rtc_event_log/events/rtc_event_dtls_transport_state.h b/logging/rtc_event_log/events/rtc_event_dtls_transport_state.h index af35a3f3bc..9a3eecb3d3 100644 --- a/logging/rtc_event_log/events/rtc_event_dtls_transport_state.h +++ b/logging/rtc_event_log/events/rtc_event_dtls_transport_state.h @@ -15,6 +15,7 @@ #include "api/dtls_transport_interface.h" #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -41,10 +42,10 @@ class RtcEventDtlsTransportState : public RtcEvent { }; struct LoggedDtlsTransportState { - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); DtlsTransportState dtls_transport_state; }; diff --git a/logging/rtc_event_log/events/rtc_event_dtls_writable_state.h b/logging/rtc_event_log/events/rtc_event_dtls_writable_state.h index c3ecce00ef..c0cc5b87ef 100644 --- a/logging/rtc_event_log/events/rtc_event_dtls_writable_state.h +++ b/logging/rtc_event_log/events/rtc_event_dtls_writable_state.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -41,10 +42,10 @@ struct LoggedDtlsWritableState { LoggedDtlsWritableState() = default; explicit LoggedDtlsWritableState(bool writable) : writable(writable) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); bool writable; }; diff --git a/logging/rtc_event_log/events/rtc_event_frame_decoded.h b/logging/rtc_event_log/events/rtc_event_frame_decoded.h index c549aa8831..4a6bb90d02 100644 --- a/logging/rtc_event_log/events/rtc_event_frame_decoded.h +++ b/logging/rtc_event_log/events/rtc_event_frame_decoded.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include 
"api/units/timestamp.h" #include "api/video/video_codec_type.h" namespace webrtc { @@ -56,10 +57,10 @@ class RtcEventFrameDecoded final : public RtcEvent { }; struct LoggedFrameDecoded { - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int64_t render_time_ms; uint32_t ssrc; int width; diff --git a/logging/rtc_event_log/events/rtc_event_generic_ack_received.cc b/logging/rtc_event_log/events/rtc_event_generic_ack_received.cc index 2da2de6145..ba18d50ab6 100644 --- a/logging/rtc_event_log/events/rtc_event_generic_ack_received.cc +++ b/logging/rtc_event_log/events/rtc_event_generic_ack_received.cc @@ -23,6 +23,7 @@ RtcEventGenericAckReceived::CreateLogs( const std::vector& acked_packets) { std::vector> result; int64_t time_us = rtc::TimeMicros(); + result.reserve(acked_packets.size()); for (const AckedPacket& packet : acked_packets) { result.emplace_back(new RtcEventGenericAckReceived( time_us, packet_number, packet.packet_number, diff --git a/logging/rtc_event_log/events/rtc_event_generic_ack_received.h b/logging/rtc_event_log/events/rtc_event_generic_ack_received.h index 76e3cc24c4..75fc83c8b8 100644 --- a/logging/rtc_event_log/events/rtc_event_generic_ack_received.h +++ b/logging/rtc_event_log/events/rtc_event_generic_ack_received.h @@ -16,6 +16,7 @@ #include "absl/types/optional.h" #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -77,19 +78,19 @@ class RtcEventGenericAckReceived final : public RtcEvent { struct LoggedGenericAckReceived { LoggedGenericAckReceived() = default; - LoggedGenericAckReceived(int64_t timestamp_us, + LoggedGenericAckReceived(Timestamp timestamp, int64_t packet_number, int64_t acked_packet_number, absl::optional receive_acked_packet_time_ms) - : timestamp_us(timestamp_us), + : timestamp(timestamp), packet_number(packet_number), acked_packet_number(acked_packet_number), receive_acked_packet_time_ms(receive_acked_packet_time_ms) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int64_t packet_number; int64_t acked_packet_number; absl::optional receive_acked_packet_time_ms; diff --git a/logging/rtc_event_log/events/rtc_event_generic_packet_received.h b/logging/rtc_event_log/events/rtc_event_generic_packet_received.h index 45e5e4cc44..428e7b3806 100644 --- a/logging/rtc_event_log/events/rtc_event_generic_packet_received.h +++ b/logging/rtc_event_log/events/rtc_event_generic_packet_received.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -45,17 +46,17 @@ class RtcEventGenericPacketReceived final : public RtcEvent { struct LoggedGenericPacketReceived { LoggedGenericPacketReceived() = default; - LoggedGenericPacketReceived(int64_t timestamp_us, + LoggedGenericPacketReceived(Timestamp timestamp, int64_t packet_number, int packet_length) - : timestamp_us(timestamp_us), + : timestamp(timestamp), packet_number(packet_number), packet_length(packet_length) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { 
return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int64_t packet_number; int packet_length; }; diff --git a/logging/rtc_event_log/events/rtc_event_generic_packet_sent.h b/logging/rtc_event_log/events/rtc_event_generic_packet_sent.h index 9ebafbe2ec..6e626e63a1 100644 --- a/logging/rtc_event_log/events/rtc_event_generic_packet_sent.h +++ b/logging/rtc_event_log/events/rtc_event_generic_packet_sent.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -62,24 +63,24 @@ class RtcEventGenericPacketSent final : public RtcEvent { struct LoggedGenericPacketSent { LoggedGenericPacketSent() = default; - LoggedGenericPacketSent(int64_t timestamp_us, + LoggedGenericPacketSent(Timestamp timestamp, int64_t packet_number, size_t overhead_length, size_t payload_length, size_t padding_length) - : timestamp_us(timestamp_us), + : timestamp(timestamp), packet_number(packet_number), overhead_length(overhead_length), payload_length(payload_length), padding_length(padding_length) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } size_t packet_length() const { return payload_length + padding_length + overhead_length; } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int64_t packet_number; size_t overhead_length; size_t payload_length; diff --git a/logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h b/logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h index 717ddf360d..1f4d825a99 100644 --- a/logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h +++ b/logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -56,19 +57,19 @@ class RtcEventIceCandidatePair final : public RtcEvent { struct LoggedIceCandidatePairEvent { LoggedIceCandidatePairEvent() = default; - LoggedIceCandidatePairEvent(int64_t timestamp_us, + LoggedIceCandidatePairEvent(Timestamp timestamp, IceCandidatePairEventType type, uint32_t candidate_pair_id, uint32_t transaction_id) - : timestamp_us(timestamp_us), + : timestamp(timestamp), type(type), candidate_pair_id(candidate_pair_id), transaction_id(transaction_id) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); IceCandidatePairEventType type; uint32_t candidate_pair_id; uint32_t transaction_id; diff --git a/logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h b/logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h index ab2eaf2422..465a799780 100644 --- a/logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h +++ b/logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -113,10 +114,10 @@ class RtcEventIceCandidatePairConfig final : public RtcEvent { }; struct 
LoggedIceCandidatePairConfig { - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); IceCandidatePairConfigType type; uint32_t candidate_pair_id; IceCandidateType local_candidate_type; diff --git a/logging/rtc_event_log/events/rtc_event_probe_cluster_created.h b/logging/rtc_event_log/events/rtc_event_probe_cluster_created.h index f3221b91fd..974a0c9a5c 100644 --- a/logging/rtc_event_log/events/rtc_event_probe_cluster_created.h +++ b/logging/rtc_event_log/events/rtc_event_probe_cluster_created.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -50,21 +51,21 @@ class RtcEventProbeClusterCreated final : public RtcEvent { struct LoggedBweProbeClusterCreatedEvent { LoggedBweProbeClusterCreatedEvent() = default; - LoggedBweProbeClusterCreatedEvent(int64_t timestamp_us, + LoggedBweProbeClusterCreatedEvent(Timestamp timestamp, int32_t id, int32_t bitrate_bps, uint32_t min_packets, uint32_t min_bytes) - : timestamp_us(timestamp_us), + : timestamp(timestamp), id(id), bitrate_bps(bitrate_bps), min_packets(min_packets), min_bytes(min_bytes) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int32_t id; int32_t bitrate_bps; uint32_t min_packets; diff --git a/logging/rtc_event_log/events/rtc_event_probe_result_failure.h b/logging/rtc_event_log/events/rtc_event_probe_result_failure.h index 868c30b61c..fa61b314b4 100644 --- a/logging/rtc_event_log/events/rtc_event_probe_result_failure.h +++ b/logging/rtc_event_log/events/rtc_event_probe_result_failure.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -50,15 +51,15 @@ class RtcEventProbeResultFailure final : public RtcEvent { struct LoggedBweProbeFailureEvent { LoggedBweProbeFailureEvent() = default; - LoggedBweProbeFailureEvent(int64_t timestamp_us, + LoggedBweProbeFailureEvent(Timestamp timestamp, int32_t id, ProbeFailureReason failure_reason) - : timestamp_us(timestamp_us), id(id), failure_reason(failure_reason) {} + : timestamp(timestamp), id(id), failure_reason(failure_reason) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int32_t id; ProbeFailureReason failure_reason; }; diff --git a/logging/rtc_event_log/events/rtc_event_probe_result_success.h b/logging/rtc_event_log/events/rtc_event_probe_result_success.h index e3746681f6..d00cfa81d6 100644 --- a/logging/rtc_event_log/events/rtc_event_probe_result_success.h +++ b/logging/rtc_event_log/events/rtc_event_probe_result_success.h @@ -16,6 +16,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -43,15 +44,15 @@ class RtcEventProbeResultSuccess final : public RtcEvent { struct LoggedBweProbeSuccessEvent { LoggedBweProbeSuccessEvent() = 
default; - LoggedBweProbeSuccessEvent(int64_t timestamp_us, + LoggedBweProbeSuccessEvent(Timestamp timestamp, int32_t id, int32_t bitrate_bps) - : timestamp_us(timestamp_us), id(id), bitrate_bps(bitrate_bps) {} + : timestamp(timestamp), id(id), bitrate_bps(bitrate_bps) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); int32_t id; int32_t bitrate_bps; }; diff --git a/logging/rtc_event_log/events/rtc_event_remote_estimate.h b/logging/rtc_event_log/events/rtc_event_remote_estimate.h index 29b0c47195..956e05f682 100644 --- a/logging/rtc_event_log/events/rtc_event_remote_estimate.h +++ b/logging/rtc_event_log/events/rtc_event_remote_estimate.h @@ -15,6 +15,7 @@ #include "absl/types/optional.h" #include "api/rtc_event_log/rtc_event.h" #include "api/units/data_rate.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -37,10 +38,10 @@ class RtcEventRemoteEstimate final : public RtcEvent { struct LoggedRemoteEstimateEvent { LoggedRemoteEstimateEvent() = default; - int64_t log_time_us() const { return timestamp_ms * 1000; } - int64_t log_time_ms() const { return timestamp_ms; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_ms; + Timestamp timestamp = Timestamp::MinusInfinity(); absl::optional link_capacity_lower; absl::optional link_capacity_upper; }; diff --git a/logging/rtc_event_log/events/rtc_event_route_change.h b/logging/rtc_event_log/events/rtc_event_route_change.h index 455a832141..4a4e9aef80 100644 --- a/logging/rtc_event_log/events/rtc_event_route_change.h +++ b/logging/rtc_event_log/events/rtc_event_route_change.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" namespace webrtc { @@ -41,15 +42,13 @@ class RtcEventRouteChange final : public RtcEvent { struct LoggedRouteChangeEvent { LoggedRouteChangeEvent() = default; - LoggedRouteChangeEvent(int64_t timestamp_ms, - bool connected, - uint32_t overhead) - : timestamp_ms(timestamp_ms), connected(connected), overhead(overhead) {} + LoggedRouteChangeEvent(Timestamp timestamp, bool connected, uint32_t overhead) + : timestamp(timestamp), connected(connected), overhead(overhead) {} - int64_t log_time_us() const { return timestamp_ms * 1000; } - int64_t log_time_ms() const { return timestamp_ms; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_ms; + Timestamp timestamp = Timestamp::MinusInfinity(); bool connected; uint32_t overhead; }; diff --git a/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.cc b/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.cc index 4e505bdbf1..4cf33a238f 100644 --- a/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.cc +++ b/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.cc @@ -18,22 +18,11 @@ namespace webrtc { RtcEventRtpPacketIncoming::RtcEventRtpPacketIncoming( const RtpPacketReceived& packet) - : payload_length_(packet.payload_size()), - header_length_(packet.headers_size()), - padding_length_(packet.padding_size()) { - header_.CopyHeaderFrom(packet); - RTC_DCHECK_EQ(packet.size(), - payload_length_ + header_length_ + padding_length_); -} + : packet_(packet) {} 
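For context, a minimal sketch (not part of this change) of the size bookkeeping the new RtcEventRtpPacketIncoming relies on: the stored RtpPacket already partitions its total size into header, payload and padding, which is exactly what the removed RTC_DCHECK_EQ asserted, so the separately cached length fields become redundant. The helper name CheckLengthInvariant is illustrative only.

// Illustrative only: packet.size() == headers_size() + payload_size() +
// padding_size(), so the event can derive all lengths from the stored packet.
#include "modules/rtp_rtcp/source/rtp_packet_received.h"
#include "rtc_base/checks.h"

namespace webrtc {

void CheckLengthInvariant(const RtpPacketReceived& packet) {
  // Same relationship the removed RTC_DCHECK_EQ in the old constructor checked.
  RTC_DCHECK_EQ(packet.size(), packet.headers_size() + packet.payload_size() +
                                   packet.padding_size());
}

}  // namespace webrtc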
RtcEventRtpPacketIncoming::RtcEventRtpPacketIncoming( const RtcEventRtpPacketIncoming& other) - : RtcEvent(other.timestamp_us_), - payload_length_(other.payload_length_), - header_length_(other.header_length_), - padding_length_(other.padding_length_) { - header_.CopyHeaderFrom(other.header_); -} + : RtcEvent(other.timestamp_us_), packet_(other.packet_) {} RtcEventRtpPacketIncoming::~RtcEventRtpPacketIncoming() = default; diff --git a/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.h b/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.h index 8d13dc6e87..ee48fa360b 100644 --- a/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.h +++ b/logging/rtc_event_log/events/rtc_event_rtp_packet_incoming.h @@ -11,8 +11,12 @@ #ifndef LOGGING_RTC_EVENT_LOG_EVENTS_RTC_EVENT_RTP_PACKET_INCOMING_H_ #define LOGGING_RTC_EVENT_LOG_EVENTS_RTC_EVENT_RTP_PACKET_INCOMING_H_ +#include +#include #include +#include +#include "api/array_view.h" #include "api/rtc_event_log/rtc_event.h" #include "modules/rtp_rtcp/source/rtp_packet.h" @@ -32,22 +36,33 @@ class RtcEventRtpPacketIncoming final : public RtcEvent { std::unique_ptr Copy() const; - size_t packet_length() const { - return payload_length_ + header_length_ + padding_length_; + size_t packet_length() const { return packet_.size(); } + + rtc::ArrayView RawHeader() const { + return rtc::MakeArrayView(packet_.data(), header_length()); + } + uint32_t Ssrc() const { return packet_.Ssrc(); } + uint32_t Timestamp() const { return packet_.Timestamp(); } + uint16_t SequenceNumber() const { return packet_.SequenceNumber(); } + uint8_t PayloadType() const { return packet_.PayloadType(); } + bool Marker() const { return packet_.Marker(); } + template + bool GetExtension(Args&&... args) const { + return packet_.GetExtension(std::forward(args)...); + } + template + bool HasExtension() const { + return packet_.HasExtension(); } - const RtpPacket& header() const { return header_; } - size_t payload_length() const { return payload_length_; } - size_t header_length() const { return header_length_; } - size_t padding_length() const { return padding_length_; } + size_t payload_length() const { return packet_.payload_size(); } + size_t header_length() const { return packet_.headers_size(); } + size_t padding_length() const { return packet_.padding_size(); } private: RtcEventRtpPacketIncoming(const RtcEventRtpPacketIncoming& other); - RtpPacket header_; // Only the packet's header will be stored here. - const size_t payload_length_; // Media payload, excluding header and padding. - const size_t header_length_; // RTP header. - const size_t padding_length_; // RTP padding. 
+ const RtpPacket packet_; }; } // namespace webrtc diff --git a/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.cc b/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.cc index e5324bf1a3..a6a4d99702 100644 --- a/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.cc +++ b/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.cc @@ -19,24 +19,13 @@ namespace webrtc { RtcEventRtpPacketOutgoing::RtcEventRtpPacketOutgoing( const RtpPacketToSend& packet, int probe_cluster_id) - : payload_length_(packet.payload_size()), - header_length_(packet.headers_size()), - padding_length_(packet.padding_size()), - probe_cluster_id_(probe_cluster_id) { - header_.CopyHeaderFrom(packet); - RTC_DCHECK_EQ(packet.size(), - payload_length_ + header_length_ + padding_length_); -} + : packet_(packet), probe_cluster_id_(probe_cluster_id) {} RtcEventRtpPacketOutgoing::RtcEventRtpPacketOutgoing( const RtcEventRtpPacketOutgoing& other) : RtcEvent(other.timestamp_us_), - payload_length_(other.payload_length_), - header_length_(other.header_length_), - padding_length_(other.padding_length_), - probe_cluster_id_(other.probe_cluster_id_) { - header_.CopyHeaderFrom(other.header_); -} + packet_(other.packet_), + probe_cluster_id_(other.probe_cluster_id_) {} RtcEventRtpPacketOutgoing::~RtcEventRtpPacketOutgoing() = default; diff --git a/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.h b/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.h index de4abcc904..9ef5b1afdd 100644 --- a/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.h +++ b/logging/rtc_event_log/events/rtc_event_rtp_packet_outgoing.h @@ -11,8 +11,12 @@ #ifndef LOGGING_RTC_EVENT_LOG_EVENTS_RTC_EVENT_RTP_PACKET_OUTGOING_H_ #define LOGGING_RTC_EVENT_LOG_EVENTS_RTC_EVENT_RTP_PACKET_OUTGOING_H_ +#include +#include #include +#include +#include "api/array_view.h" #include "api/rtc_event_log/rtc_event.h" #include "modules/rtp_rtcp/source/rtp_packet.h" @@ -33,23 +37,34 @@ class RtcEventRtpPacketOutgoing final : public RtcEvent { std::unique_ptr Copy() const; - size_t packet_length() const { - return payload_length_ + header_length_ + padding_length_; + size_t packet_length() const { return packet_.size(); } + + rtc::ArrayView RawHeader() const { + return rtc::MakeArrayView(packet_.data(), header_length()); + } + uint32_t Ssrc() const { return packet_.Ssrc(); } + uint32_t Timestamp() const { return packet_.Timestamp(); } + uint16_t SequenceNumber() const { return packet_.SequenceNumber(); } + uint8_t PayloadType() const { return packet_.PayloadType(); } + bool Marker() const { return packet_.Marker(); } + template + bool GetExtension(Args&&... args) const { + return packet_.GetExtension(std::forward(args)...); + } + template + bool HasExtension() const { + return packet_.HasExtension(); } - const RtpPacket& header() const { return header_; } - size_t payload_length() const { return payload_length_; } - size_t header_length() const { return header_length_; } - size_t padding_length() const { return padding_length_; } + size_t payload_length() const { return packet_.payload_size(); } + size_t header_length() const { return packet_.headers_size(); } + size_t padding_length() const { return packet_.padding_size(); } int probe_cluster_id() const { return probe_cluster_id_; } private: RtcEventRtpPacketOutgoing(const RtcEventRtpPacketOutgoing& other); - RtpPacket header_; // Only the packet's header will be stored here. - const size_t payload_length_; // Media payload, excluding header and padding. 
- const size_t header_length_; // RTP header. - const size_t padding_length_; // RTP padding. + const RtpPacket packet_; // TODO(eladalon): Delete |probe_cluster_id_| along with legacy encoding. const int probe_cluster_id_; }; diff --git a/logging/rtc_event_log/events/rtc_event_video_receive_stream_config.h b/logging/rtc_event_log/events/rtc_event_video_receive_stream_config.h index 2bf52476a1..e7b9061872 100644 --- a/logging/rtc_event_log/events/rtc_event_video_receive_stream_config.h +++ b/logging/rtc_event_log/events/rtc_event_video_receive_stream_config.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" #include "logging/rtc_event_log/rtc_stream_config.h" namespace webrtc { @@ -42,13 +43,13 @@ class RtcEventVideoReceiveStreamConfig final : public RtcEvent { struct LoggedVideoRecvConfig { LoggedVideoRecvConfig() = default; - LoggedVideoRecvConfig(int64_t timestamp_us, const rtclog::StreamConfig config) - : timestamp_us(timestamp_us), config(config) {} + LoggedVideoRecvConfig(Timestamp timestamp, const rtclog::StreamConfig config) + : timestamp(timestamp), config(config) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtclog::StreamConfig config; }; diff --git a/logging/rtc_event_log/events/rtc_event_video_send_stream_config.h b/logging/rtc_event_log/events/rtc_event_video_send_stream_config.h index cf95afc4d8..e72e75e49d 100644 --- a/logging/rtc_event_log/events/rtc_event_video_send_stream_config.h +++ b/logging/rtc_event_log/events/rtc_event_video_send_stream_config.h @@ -14,6 +14,7 @@ #include #include "api/rtc_event_log/rtc_event.h" +#include "api/units/timestamp.h" #include "logging/rtc_event_log/rtc_stream_config.h" namespace webrtc { @@ -41,13 +42,13 @@ class RtcEventVideoSendStreamConfig final : public RtcEvent { struct LoggedVideoSendConfig { LoggedVideoSendConfig() = default; - LoggedVideoSendConfig(int64_t timestamp_us, const rtclog::StreamConfig config) - : timestamp_us(timestamp_us), config(config) {} + LoggedVideoSendConfig(Timestamp timestamp, const rtclog::StreamConfig config) + : timestamp(timestamp), config(config) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtclog::StreamConfig config; }; } // namespace webrtc diff --git a/logging/rtc_event_log/fake_rtc_event_log.cc b/logging/rtc_event_log/fake_rtc_event_log.cc index 55f4b582c7..5a44b00694 100644 --- a/logging/rtc_event_log/fake_rtc_event_log.cc +++ b/logging/rtc_event_log/fake_rtc_event_log.cc @@ -10,32 +10,29 @@ #include "logging/rtc_event_log/fake_rtc_event_log.h" -#include "logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h" -#include "rtc_base/bind.h" -#include "rtc_base/checks.h" -#include "rtc_base/logging.h" +#include +#include -namespace webrtc { +#include "api/rtc_event_log/rtc_event_log.h" +#include "rtc_base/synchronization/mutex.h" -FakeRtcEventLog::FakeRtcEventLog(rtc::Thread* thread) : thread_(thread) { - RTC_DCHECK(thread_); -} -FakeRtcEventLog::~FakeRtcEventLog() = default; +namespace webrtc { bool 
FakeRtcEventLog::StartLogging(std::unique_ptr<RtcEventLogOutput> output, int64_t output_period_ms) { return true; } -void FakeRtcEventLog::StopLogging() { - invoker_.Flush(thread_); -} +void FakeRtcEventLog::StopLogging() {} void FakeRtcEventLog::Log(std::unique_ptr<RtcEvent> event) { - RtcEvent::Type rtc_event_type = event->GetType(); - invoker_.AsyncInvoke<void>( - RTC_FROM_HERE, thread_, - rtc::Bind(&FakeRtcEventLog::IncrementEventCount, this, rtc_event_type)); + MutexLock lock(&mu_); + ++count_[event->GetType()]; +} + +int FakeRtcEventLog::GetEventCount(RtcEvent::Type event_type) { + MutexLock lock(&mu_); + return count_[event_type]; } } // namespace webrtc diff --git a/logging/rtc_event_log/fake_rtc_event_log.h b/logging/rtc_event_log/fake_rtc_event_log.h index fb0e6ff4dc..effa7507f1 100644 --- a/logging/rtc_event_log/fake_rtc_event_log.h +++ b/logging/rtc_event_log/fake_rtc_event_log.h @@ -16,26 +16,25 @@ #include "api/rtc_event_log/rtc_event.h" #include "api/rtc_event_log/rtc_event_log.h" -#include "rtc_base/async_invoker.h" -#include "rtc_base/thread.h" +#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { class FakeRtcEventLog : public RtcEventLog { public: - explicit FakeRtcEventLog(rtc::Thread* thread); - ~FakeRtcEventLog() override; + FakeRtcEventLog() = default; + ~FakeRtcEventLog() override = default; + bool StartLogging(std::unique_ptr<RtcEventLogOutput> output, int64_t output_period_ms) override; void StopLogging() override; void Log(std::unique_ptr<RtcEvent> event) override; - int GetEventCount(RtcEvent::Type event_type) { return count_[event_type]; } + int GetEventCount(RtcEvent::Type event_type); private: - void IncrementEventCount(RtcEvent::Type event_type) { ++count_[event_type]; } - std::map<RtcEvent::Type, int> count_; - rtc::Thread* thread_; - rtc::AsyncInvoker invoker_; + Mutex mu_; + std::map<RtcEvent::Type, int> count_ RTC_GUARDED_BY(mu_); }; } // namespace webrtc diff --git a/logging/rtc_event_log/fake_rtc_event_log_factory.cc b/logging/rtc_event_log/fake_rtc_event_log_factory.cc index f84f74fdb6..f663ec5abe 100644 --- a/logging/rtc_event_log/fake_rtc_event_log_factory.cc +++ b/logging/rtc_event_log/fake_rtc_event_log_factory.cc @@ -10,14 +10,16 @@ #include "logging/rtc_event_log/fake_rtc_event_log_factory.h" +#include <memory> + #include "api/rtc_event_log/rtc_event_log.h" #include "logging/rtc_event_log/fake_rtc_event_log.h" namespace webrtc { std::unique_ptr<RtcEventLog> FakeRtcEventLogFactory::CreateRtcEventLog( - RtcEventLog::EncodingType encoding_type) { - std::unique_ptr<FakeRtcEventLog> fake_event_log(new FakeRtcEventLog(thread())); + RtcEventLog::EncodingType /*encoding_type*/) { + auto fake_event_log = std::make_unique<FakeRtcEventLog>(); last_log_created_ = fake_event_log.get(); return fake_event_log; } diff --git a/logging/rtc_event_log/fake_rtc_event_log_factory.h b/logging/rtc_event_log/fake_rtc_event_log_factory.h index 873e50efdc..114c3e6323 100644 --- a/logging/rtc_event_log/fake_rtc_event_log_factory.h +++ b/logging/rtc_event_log/fake_rtc_event_log_factory.h @@ -15,24 +15,21 @@ #include "api/rtc_event_log/rtc_event_log_factory_interface.h" #include "logging/rtc_event_log/fake_rtc_event_log.h" -#include "rtc_base/thread.h" namespace webrtc { class FakeRtcEventLogFactory : public RtcEventLogFactoryInterface { public: - explicit FakeRtcEventLogFactory(rtc::Thread* thread) : thread_(thread) {} - ~FakeRtcEventLogFactory() override {} + FakeRtcEventLogFactory() = default; + ~FakeRtcEventLogFactory() override = default; std::unique_ptr<RtcEventLog> CreateRtcEventLog( RtcEventLog::EncodingType encoding_type) override; - webrtc::RtcEventLog* last_log_created() { return 
last_log_created_; } - rtc::Thread* thread() { return thread_; } + webrtc::FakeRtcEventLog* last_log_created() { return last_log_created_; } private: - webrtc::RtcEventLog* last_log_created_; - rtc::Thread* thread_; + webrtc::FakeRtcEventLog* last_log_created_; }; } // namespace webrtc diff --git a/logging/rtc_event_log/logged_events.cc b/logging/rtc_event_log/logged_events.cc index dd0a8aae2a..5ef3de11c0 100644 --- a/logging/rtc_event_log/logged_events.cc +++ b/logging/rtc_event_log/logged_events.cc @@ -40,13 +40,13 @@ LoggedPacketInfo::LoggedPacketInfo(const LoggedPacketInfo&) = default; LoggedPacketInfo::~LoggedPacketInfo() {} -LoggedRtcpPacket::LoggedRtcpPacket(int64_t timestamp_us, +LoggedRtcpPacket::LoggedRtcpPacket(Timestamp timestamp, const std::vector& packet) - : timestamp_us(timestamp_us), raw_data(packet) {} + : timestamp(timestamp), raw_data(packet) {} -LoggedRtcpPacket::LoggedRtcpPacket(int64_t timestamp_us, +LoggedRtcpPacket::LoggedRtcpPacket(Timestamp timestamp, const std::string& packet) - : timestamp_us(timestamp_us), raw_data(packet.size()) { + : timestamp(timestamp), raw_data(packet.size()) { memcpy(raw_data.data(), packet.data(), packet.size()); } diff --git a/logging/rtc_event_log/logged_events.h b/logging/rtc_event_log/logged_events.h index 1ed21befe0..5bce658c30 100644 --- a/logging/rtc_event_log/logged_events.h +++ b/logging/rtc_event_log/logged_events.h @@ -17,6 +17,7 @@ #include "api/rtp_headers.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" +#include "modules/rtp_rtcp/source/rtcp_packet/bye.h" #include "modules/rtp_rtcp/source/rtcp_packet/extended_reports.h" #include "modules/rtp_rtcp/source/rtcp_packet/fir.h" #include "modules/rtp_rtcp/source/rtcp_packet/loss_notification.h" @@ -36,19 +37,19 @@ namespace webrtc { // adding a vptr. struct LoggedRtpPacket { - LoggedRtpPacket(int64_t timestamp_us, + LoggedRtpPacket(Timestamp timestamp, RTPHeader header, size_t header_length, size_t total_length) - : timestamp_us(timestamp_us), + : timestamp(timestamp), header(header), header_length(header_length), total_length(total_length) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp; // TODO(terelius): This allocates space for 15 CSRCs even if none are used. 
RTPHeader header; size_t header_length; @@ -56,145 +57,145 @@ struct LoggedRtpPacket { }; struct LoggedRtpPacketIncoming { - LoggedRtpPacketIncoming(int64_t timestamp_us, + LoggedRtpPacketIncoming(Timestamp timestamp, RTPHeader header, size_t header_length, size_t total_length) - : rtp(timestamp_us, header, header_length, total_length) {} - int64_t log_time_us() const { return rtp.timestamp_us; } - int64_t log_time_ms() const { return rtp.timestamp_us / 1000; } + : rtp(timestamp, header, header_length, total_length) {} + int64_t log_time_us() const { return rtp.timestamp.us(); } + int64_t log_time_ms() const { return rtp.timestamp.ms(); } LoggedRtpPacket rtp; }; struct LoggedRtpPacketOutgoing { - LoggedRtpPacketOutgoing(int64_t timestamp_us, + LoggedRtpPacketOutgoing(Timestamp timestamp, RTPHeader header, size_t header_length, size_t total_length) - : rtp(timestamp_us, header, header_length, total_length) {} - int64_t log_time_us() const { return rtp.timestamp_us; } - int64_t log_time_ms() const { return rtp.timestamp_us / 1000; } + : rtp(timestamp, header, header_length, total_length) {} + int64_t log_time_us() const { return rtp.timestamp.us(); } + int64_t log_time_ms() const { return rtp.timestamp.ms(); } LoggedRtpPacket rtp; }; struct LoggedRtcpPacket { - LoggedRtcpPacket(int64_t timestamp_us, const std::vector& packet); - LoggedRtcpPacket(int64_t timestamp_us, const std::string& packet); + LoggedRtcpPacket(Timestamp timestamp, const std::vector& packet); + LoggedRtcpPacket(Timestamp timestamp, const std::string& packet); LoggedRtcpPacket(const LoggedRtcpPacket&); ~LoggedRtcpPacket(); - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp; std::vector raw_data; }; struct LoggedRtcpPacketIncoming { - LoggedRtcpPacketIncoming(int64_t timestamp_us, + LoggedRtcpPacketIncoming(Timestamp timestamp, const std::vector& packet) - : rtcp(timestamp_us, packet) {} - LoggedRtcpPacketIncoming(uint64_t timestamp_us, const std::string& packet) - : rtcp(timestamp_us, packet) {} + : rtcp(timestamp, packet) {} + LoggedRtcpPacketIncoming(Timestamp timestamp, const std::string& packet) + : rtcp(timestamp, packet) {} - int64_t log_time_us() const { return rtcp.timestamp_us; } - int64_t log_time_ms() const { return rtcp.timestamp_us / 1000; } + int64_t log_time_us() const { return rtcp.timestamp.us(); } + int64_t log_time_ms() const { return rtcp.timestamp.ms(); } LoggedRtcpPacket rtcp; }; struct LoggedRtcpPacketOutgoing { - LoggedRtcpPacketOutgoing(int64_t timestamp_us, + LoggedRtcpPacketOutgoing(Timestamp timestamp, const std::vector& packet) - : rtcp(timestamp_us, packet) {} - LoggedRtcpPacketOutgoing(uint64_t timestamp_us, const std::string& packet) - : rtcp(timestamp_us, packet) {} + : rtcp(timestamp, packet) {} + LoggedRtcpPacketOutgoing(Timestamp timestamp, const std::string& packet) + : rtcp(timestamp, packet) {} - int64_t log_time_us() const { return rtcp.timestamp_us; } - int64_t log_time_ms() const { return rtcp.timestamp_us / 1000; } + int64_t log_time_us() const { return rtcp.timestamp.us(); } + int64_t log_time_ms() const { return rtcp.timestamp.ms(); } LoggedRtcpPacket rtcp; }; struct LoggedRtcpPacketReceiverReport { LoggedRtcpPacketReceiverReport() = default; - LoggedRtcpPacketReceiverReport(int64_t timestamp_us, + LoggedRtcpPacketReceiverReport(Timestamp timestamp, const 
rtcp::ReceiverReport& rr) - : timestamp_us(timestamp_us), rr(rr) {} + : timestamp(timestamp), rr(rr) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::ReceiverReport rr; }; struct LoggedRtcpPacketSenderReport { LoggedRtcpPacketSenderReport() = default; - LoggedRtcpPacketSenderReport(int64_t timestamp_us, + LoggedRtcpPacketSenderReport(Timestamp timestamp, const rtcp::SenderReport& sr) - : timestamp_us(timestamp_us), sr(sr) {} + : timestamp(timestamp), sr(sr) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::SenderReport sr; }; struct LoggedRtcpPacketExtendedReports { LoggedRtcpPacketExtendedReports() = default; - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::ExtendedReports xr; }; struct LoggedRtcpPacketRemb { LoggedRtcpPacketRemb() = default; - LoggedRtcpPacketRemb(int64_t timestamp_us, const rtcp::Remb& remb) - : timestamp_us(timestamp_us), remb(remb) {} + LoggedRtcpPacketRemb(Timestamp timestamp, const rtcp::Remb& remb) + : timestamp(timestamp), remb(remb) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::Remb remb; }; struct LoggedRtcpPacketNack { LoggedRtcpPacketNack() = default; - LoggedRtcpPacketNack(int64_t timestamp_us, const rtcp::Nack& nack) - : timestamp_us(timestamp_us), nack(nack) {} + LoggedRtcpPacketNack(Timestamp timestamp, const rtcp::Nack& nack) + : timestamp(timestamp), nack(nack) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::Nack nack; }; struct LoggedRtcpPacketFir { LoggedRtcpPacketFir() = default; - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::Fir fir; }; struct LoggedRtcpPacketPli { LoggedRtcpPacketPli() = default; - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::Pli pli; }; @@ -203,52 +204,64 @@ struct LoggedRtcpPacketTransportFeedback { 
: transport_feedback(/*include_timestamps=*/true, /*include_lost*/ true) { } LoggedRtcpPacketTransportFeedback( - int64_t timestamp_us, + Timestamp timestamp, const rtcp::TransportFeedback& transport_feedback) - : timestamp_us(timestamp_us), transport_feedback(transport_feedback) {} + : timestamp(timestamp), transport_feedback(transport_feedback) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::TransportFeedback transport_feedback; }; struct LoggedRtcpPacketLossNotification { LoggedRtcpPacketLossNotification() = default; LoggedRtcpPacketLossNotification( - int64_t timestamp_us, + Timestamp timestamp, const rtcp::LossNotification& loss_notification) - : timestamp_us(timestamp_us), loss_notification(loss_notification) {} + : timestamp(timestamp), loss_notification(loss_notification) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp = Timestamp::MinusInfinity(); rtcp::LossNotification loss_notification; }; +struct LoggedRtcpPacketBye { + LoggedRtcpPacketBye() = default; + + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } + + Timestamp timestamp = Timestamp::MinusInfinity(); + rtcp::Bye bye; +}; + struct LoggedStartEvent { - explicit LoggedStartEvent(int64_t timestamp_us) - : LoggedStartEvent(timestamp_us, timestamp_us / 1000) {} + explicit LoggedStartEvent(Timestamp timestamp) + : LoggedStartEvent(timestamp, timestamp) {} - LoggedStartEvent(int64_t timestamp_us, int64_t utc_start_time_ms) - : timestamp_us(timestamp_us), utc_start_time_ms(utc_start_time_ms) {} + LoggedStartEvent(Timestamp timestamp, Timestamp utc_start_time) + : timestamp(timestamp), utc_start_time(utc_start_time) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; - int64_t utc_start_time_ms; + Timestamp utc_time() const { return utc_start_time; } + + Timestamp timestamp; + Timestamp utc_start_time; }; struct LoggedStopEvent { - explicit LoggedStopEvent(int64_t timestamp_us) : timestamp_us(timestamp_us) {} + explicit LoggedStopEvent(Timestamp timestamp) : timestamp(timestamp) {} - int64_t log_time_us() const { return timestamp_us; } - int64_t log_time_ms() const { return timestamp_us / 1000; } + int64_t log_time_us() const { return timestamp.us(); } + int64_t log_time_ms() const { return timestamp.ms(); } - int64_t timestamp_us; + Timestamp timestamp; }; struct InferredRouteChangeEvent { @@ -326,8 +339,5 @@ struct LoggedIceEvent { }; - - - } // namespace webrtc #endif // LOGGING_RTC_EVENT_LOG_LOGGED_EVENTS_H_ diff --git a/logging/rtc_event_log/rtc_event_log_impl.cc b/logging/rtc_event_log/rtc_event_log_impl.cc index 4a272f08cf..700f639311 100644 --- a/logging/rtc_event_log/rtc_event_log_impl.cc +++ b/logging/rtc_event_log/rtc_event_log_impl.cc @@ -90,8 +90,8 @@ bool RtcEventLogImpl::StartLogging(std::unique_ptr output, return false; } - const int64_t timestamp_us = 
rtc::TimeMicros(); - const int64_t utc_time_us = rtc::TimeUTCMicros(); + const int64_t timestamp_us = rtc::TimeMillis() * 1000; + const int64_t utc_time_us = rtc::TimeUTCMillis() * 1000; RTC_LOG(LS_INFO) << "Starting WebRTC event log. (Timestamp, UTC) = " "(" << timestamp_us << ", " << utc_time_us << ")."; @@ -253,7 +253,7 @@ void RtcEventLogImpl::StopOutput() { void RtcEventLogImpl::StopLoggingInternal() { if (event_output_) { RTC_DCHECK(event_output_->IsActive()); - const int64_t timestamp_us = rtc::TimeMicros(); + const int64_t timestamp_us = rtc::TimeMillis() * 1000; event_output_->Write(event_encoder_->EncodeLogEnd(timestamp_us)); } StopOutput(); diff --git a/logging/rtc_event_log/rtc_event_log_impl.h b/logging/rtc_event_log/rtc_event_log_impl.h index bdbde612eb..0b6a71b24b 100644 --- a/logging/rtc_event_log/rtc_event_log_impl.h +++ b/logging/rtc_event_log/rtc_event_log_impl.h @@ -21,9 +21,9 @@ #include "api/rtc_event_log/rtc_event.h" #include "api/rtc_event_log/rtc_event_log.h" #include "api/rtc_event_log_output.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "logging/rtc_event_log/encoder/rtc_event_log_encoder.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" #include "rtc_base/thread_annotations.h" diff --git a/logging/rtc_event_log/rtc_event_log_parser.cc b/logging/rtc_event_log/rtc_event_log_parser.cc index 59a722c7f3..08fb9408c1 100644 --- a/logging/rtc_event_log/rtc_event_log_parser.cc +++ b/logging/rtc_event_log/rtc_event_log_parser.cc @@ -14,8 +14,6 @@ #include #include -#include -#include // no-presubmit-check TODO(webrtc:8982) #include #include #include @@ -29,6 +27,7 @@ #include "logging/rtc_event_log/encoder/blob_encoding.h" #include "logging/rtc_event_log/encoder/delta_encoding.h" #include "logging/rtc_event_log/encoder/rtc_event_log_encoder_common.h" +#include "logging/rtc_event_log/encoder/var_int.h" #include "logging/rtc_event_log/rtc_event_processor.h" #include "modules/audio_coding/audio_network_adaptor/include/audio_network_adaptor.h" #include "modules/include/module_common_types_public.h" @@ -42,6 +41,7 @@ #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/numerics/sequence_number_util.h" #include "rtc_base/protobuf_utils.h" +#include "rtc_base/system/file_wrapper.h" // These macros were added to convert existing code using RTC_CHECKs // to returning a Status object instead. Macros are necessary (over @@ -98,6 +98,8 @@ using webrtc_event_logging::ToUnsigned; namespace webrtc { namespace { +constexpr int64_t kMaxLogSize = 250000000; + constexpr size_t kIpv4Overhead = 20; constexpr size_t kIpv6Overhead = 40; constexpr size_t kUdpOverhead = 8; @@ -313,33 +315,6 @@ VideoCodecType GetRuntimeCodecType(rtclog2::FrameDecodedEvents::Codec codec) { return VideoCodecType::kVideoCodecMultiplex; } -// Reads a VarInt from |stream| and returns it. Also writes the read bytes to -// |buffer| starting |bytes_written| bytes into the buffer. |bytes_written| is -// incremented for each written byte. -ParsedRtcEventLog::ParseStatusOr ParseVarInt( - std::istream& stream, // no-presubmit-check TODO(webrtc:8982) - char* buffer, - size_t* bytes_written) { - uint64_t varint = 0; - for (size_t bytes_read = 0; bytes_read < 10; ++bytes_read) { - // The most significant bit of each byte is 0 if it is the last byte in - // the varint and 1 otherwise. 
Thus, we take the 7 least significant bits - // of each byte and shift them 7 bits for each byte read previously to get - // the (unsigned) integer. - int byte = stream.get(); - RTC_PARSE_CHECK_OR_RETURN(!stream.eof()); - RTC_DCHECK_GE(byte, 0); - RTC_DCHECK_LE(byte, 255); - varint |= static_cast(byte & 0x7F) << (7 * bytes_read); - buffer[*bytes_written] = byte; - *bytes_written += 1; - if ((byte & 0x80) == 0) { - return varint; - } - } - RTC_PARSE_CHECK_OR_RETURN(false); -} - ParsedRtcEventLog::ParseStatus GetHeaderExtensions( std::vector* header_extensions, const RepeatedPtrField& @@ -415,7 +390,7 @@ ParsedRtcEventLog::ParseStatus StoreRtpPackets( RTC_PARSE_CHECK_OR_RETURN(!proto.has_voice_activity()); } (*rtp_packets_map)[header.ssrc].emplace_back( - proto.timestamp_ms() * 1000, header, proto.header_size(), + Timestamp::Millis(proto.timestamp_ms()), header, proto.header_size(), proto.payload_size() + header.headerLength + header.paddingLength); } @@ -617,7 +592,7 @@ ParsedRtcEventLog::ParseStatus StoreRtpPackets( !voice_activity_values[i].has_value()); } (*rtp_packets_map)[header.ssrc].emplace_back( - 1000 * timestamp_ms, header, header.headerLength, + Timestamp::Millis(timestamp_ms), header, header.headerLength, payload_size_values[i].value() + header.headerLength + header.paddingLength); } @@ -640,7 +615,8 @@ ParsedRtcEventLog::ParseStatus StoreRtcpPackets( !IdenticalRtcpContents(rtcp_packets->back().rtcp.raw_data, proto.raw_packet())) { // Base event - rtcp_packets->emplace_back(proto.timestamp_ms() * 1000, proto.raw_packet()); + rtcp_packets->emplace_back(Timestamp::Millis(proto.timestamp_ms()), + proto.raw_packet()); } const size_t number_of_deltas = @@ -678,7 +654,7 @@ ParsedRtcEventLog::ParseStatus StoreRtcpPackets( continue; } std::string data(raw_packet_values[i]); - rtcp_packets->emplace_back(1000 * timestamp_ms, data); + rtcp_packets->emplace_back(Timestamp::Millis(timestamp_ms), data); } return ParsedRtcEventLog::ParseStatus::Success(); } @@ -694,8 +670,10 @@ ParsedRtcEventLog::ParseStatus StoreRtcpBlocks( std::vector* nack_list, std::vector* fir_list, std::vector* pli_list, + std::vector* bye_list, std::vector* transport_feedback_list, std::vector* loss_notification_list) { + Timestamp timestamp = Timestamp::Micros(timestamp_us); rtcp::CommonHeader header; for (const uint8_t* block = packet_begin; block < packet_end; block = header.NextPacket()) { @@ -703,47 +681,53 @@ ParsedRtcEventLog::ParseStatus StoreRtcpBlocks( if (header.type() == rtcp::TransportFeedback::kPacketType && header.fmt() == rtcp::TransportFeedback::kFeedbackMessageType) { LoggedRtcpPacketTransportFeedback parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.transport_feedback.Parse(header)) transport_feedback_list->push_back(std::move(parsed_block)); } else if (header.type() == rtcp::SenderReport::kPacketType) { LoggedRtcpPacketSenderReport parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.sr.Parse(header)) { sr_list->push_back(std::move(parsed_block)); } } else if (header.type() == rtcp::ReceiverReport::kPacketType) { LoggedRtcpPacketReceiverReport parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.rr.Parse(header)) { rr_list->push_back(std::move(parsed_block)); } } else if (header.type() == rtcp::ExtendedReports::kPacketType) { LoggedRtcpPacketExtendedReports parsed_block; - parsed_block.timestamp_us = timestamp_us; + 
parsed_block.timestamp = timestamp; if (parsed_block.xr.Parse(header)) { xr_list->push_back(std::move(parsed_block)); } } else if (header.type() == rtcp::Fir::kPacketType && header.fmt() == rtcp::Fir::kFeedbackMessageType) { LoggedRtcpPacketFir parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.fir.Parse(header)) { fir_list->push_back(std::move(parsed_block)); } } else if (header.type() == rtcp::Pli::kPacketType && header.fmt() == rtcp::Pli::kFeedbackMessageType) { LoggedRtcpPacketPli parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.pli.Parse(header)) { pli_list->push_back(std::move(parsed_block)); } - } else if (header.type() == rtcp::Remb::kPacketType && + } else if (header.type() == rtcp::Bye::kPacketType) { + LoggedRtcpPacketBye parsed_block; + parsed_block.timestamp = timestamp; + if (parsed_block.bye.Parse(header)) { + bye_list->push_back(std::move(parsed_block)); + } + } else if (header.type() == rtcp::Psfb::kPacketType && header.fmt() == rtcp::Psfb::kAfbMessageType) { bool type_found = false; if (!type_found) { LoggedRtcpPacketRemb parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.remb.Parse(header)) { remb_list->push_back(std::move(parsed_block)); type_found = true; @@ -751,7 +735,7 @@ ParsedRtcEventLog::ParseStatus StoreRtcpBlocks( } if (!type_found) { LoggedRtcpPacketLossNotification parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.loss_notification.Parse(header)) { loss_notification_list->push_back(std::move(parsed_block)); type_found = true; @@ -760,7 +744,7 @@ ParsedRtcEventLog::ParseStatus StoreRtcpBlocks( } else if (header.type() == rtcp::Nack::kPacketType && header.fmt() == rtcp::Nack::kFeedbackMessageType) { LoggedRtcpPacketNack parsed_block; - parsed_block.timestamp_us = timestamp_us; + parsed_block.timestamp = timestamp; if (parsed_block.nack.Parse(header)) { nack_list->push_back(std::move(parsed_block)); } @@ -977,23 +961,21 @@ ParsedRtcEventLog::LoggedRtpStreamOutgoing::~LoggedRtpStreamOutgoing() = ParsedRtcEventLog::LoggedRtpStreamView::LoggedRtpStreamView( uint32_t ssrc, - const LoggedRtpPacketIncoming* ptr, - size_t num_elements) - : ssrc(ssrc), - packet_view(PacketView::Create( - ptr, - num_elements, - offsetof(LoggedRtpPacketIncoming, rtp))) {} + const std::vector& packets) + : ssrc(ssrc), packet_view() { + for (const LoggedRtpPacketIncoming& packet : packets) { + packet_view.push_back(&(packet.rtp)); + } +} ParsedRtcEventLog::LoggedRtpStreamView::LoggedRtpStreamView( uint32_t ssrc, - const LoggedRtpPacketOutgoing* ptr, - size_t num_elements) - : ssrc(ssrc), - packet_view(PacketView::Create( - ptr, - num_elements, - offsetof(LoggedRtpPacketOutgoing, rtp))) {} + const std::vector& packets) + : ssrc(ssrc), packet_view() { + for (const LoggedRtpPacketOutgoing& packet : packets) { + packet_view.push_back(&(packet.rtp)); + } +} ParsedRtcEventLog::LoggedRtpStreamView::LoggedRtpStreamView( const LoggedRtpStreamView&) = default; @@ -1102,27 +1084,39 @@ void ParsedRtcEventLog::Clear() { ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseFile( const std::string& filename) { - std::ifstream file( // no-presubmit-check TODO(webrtc:8982) - filename, std::ios_base::in | std::ios_base::binary); - if (!file.good() || !file.is_open()) { - RTC_LOG(LS_WARNING) << "Could not open file for reading."; - 
RTC_PARSE_CHECK_OR_RETURN(file.good() && file.is_open()); + FileWrapper file = FileWrapper::OpenReadOnly(filename); + if (!file.is_open()) { + RTC_LOG(LS_WARNING) << "Could not open file " << filename + << " for reading."; + RTC_PARSE_CHECK_OR_RETURN(file.is_open()); } - return ParseStream(file); + // Compute file size. + long signed_filesize = file.FileSize(); // NOLINT(runtime/int) + RTC_PARSE_CHECK_OR_RETURN_GE(signed_filesize, 0); + RTC_PARSE_CHECK_OR_RETURN_LE(signed_filesize, kMaxLogSize); + size_t filesize = rtc::checked_cast(signed_filesize); + + // Read file into memory. + std::string buffer(filesize, '\0'); + size_t bytes_read = file.Read(&buffer[0], buffer.size()); + if (bytes_read != filesize) { + RTC_LOG(LS_WARNING) << "Failed to read file " << filename; + RTC_PARSE_CHECK_OR_RETURN_EQ(bytes_read, filesize); + } + + return ParseStream(buffer); } ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseString( const std::string& s) { - std::istringstream stream( // no-presubmit-check TODO(webrtc:8982) - s, std::ios_base::in | std::ios_base::binary); - return ParseStream(stream); + return ParseStream(s); } ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStream( - std::istream& stream) { // no-presubmit-check TODO(webrtc:8982) + const std::string& s) { Clear(); - ParseStatus status = ParseStreamInternal(stream); + ParseStatus status = ParseStreamInternal(s); // Cache the configured SSRCs. for (const auto& video_recv_config : video_recv_configs()) { @@ -1165,36 +1159,34 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStream( // Build PacketViews for easier iteration over RTP packets. for (const auto& stream : incoming_rtp_packets_by_ssrc_) { incoming_rtp_packet_views_by_ssrc_.emplace_back( - LoggedRtpStreamView(stream.ssrc, stream.incoming_packets.data(), - stream.incoming_packets.size())); + LoggedRtpStreamView(stream.ssrc, stream.incoming_packets)); } for (const auto& stream : outgoing_rtp_packets_by_ssrc_) { outgoing_rtp_packet_views_by_ssrc_.emplace_back( - LoggedRtpStreamView(stream.ssrc, stream.outgoing_packets.data(), - stream.outgoing_packets.size())); + LoggedRtpStreamView(stream.ssrc, stream.outgoing_packets)); } // Set up convenience wrappers around the most commonly used RTCP types. 
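
To see the new parsing path above end to end: ParseFile() now loads the whole log into a string (bounded by kMaxLogSize) and forwards it to ParseStream(), and the parsed RTCP BYE packets added in this change are exposed through the new byes() accessor. A minimal caller sketch, assuming only the include paths that already appear in this change plus rtc_base/logging.h for RTC_LOG:

#include <string>

#include "logging/rtc_event_log/rtc_event_log_parser.h"
#include "rtc_base/logging.h"

// Parses a log from disk and prints every incoming RTCP BYE.
bool DumpIncomingByes(const std::string& path) {
  webrtc::ParsedRtcEventLog parsed;
  auto status = parsed.ParseFile(path);  // The whole file is read into memory.
  if (!status.ok()) {
    RTC_LOG(LS_WARNING) << "Could not parse " << path;
    return false;
  }
  for (const auto& logged : parsed.byes(webrtc::kIncomingPacket)) {
    RTC_LOG(LS_INFO) << "BYE from SSRC " << logged.bye.sender_ssrc() << " at "
                     << logged.log_time_ms() << " ms";
  }
  return true;
}
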
for (const auto& incoming : incoming_rtcp_packets_) { - const int64_t timestamp_us = incoming.rtcp.timestamp_us; + const int64_t timestamp_us = incoming.rtcp.timestamp.us(); const uint8_t* packet_begin = incoming.rtcp.raw_data.data(); const uint8_t* packet_end = packet_begin + incoming.rtcp.raw_data.size(); auto status = StoreRtcpBlocks( timestamp_us, packet_begin, packet_end, &incoming_sr_, &incoming_rr_, &incoming_xr_, &incoming_remb_, &incoming_nack_, &incoming_fir_, - &incoming_pli_, &incoming_transport_feedback_, + &incoming_pli_, &incoming_bye_, &incoming_transport_feedback_, &incoming_loss_notification_); RTC_RETURN_IF_ERROR(status); } for (const auto& outgoing : outgoing_rtcp_packets_) { - const int64_t timestamp_us = outgoing.rtcp.timestamp_us; + const int64_t timestamp_us = outgoing.rtcp.timestamp.us(); const uint8_t* packet_begin = outgoing.rtcp.raw_data.data(); const uint8_t* packet_end = packet_begin + outgoing.rtcp.raw_data.size(); auto status = StoreRtcpBlocks( timestamp_us, packet_begin, packet_end, &outgoing_sr_, &outgoing_rr_, &outgoing_xr_, &outgoing_remb_, &outgoing_nack_, &outgoing_fir_, - &outgoing_pli_, &outgoing_transport_feedback_, + &outgoing_pli_, &outgoing_bye_, &outgoing_transport_feedback_, &outgoing_loss_notification_); RTC_RETURN_IF_ERROR(status); } @@ -1273,17 +1265,12 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStream( } ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStreamInternal( - std::istream& stream) { // no-presubmit-check TODO(webrtc:8982) + absl::string_view s) { constexpr uint64_t kMaxEventSize = 10000000; // Sanity check. - std::vector buffer(0xFFFF); - RTC_DCHECK(stream.good()); - while (1) { - // Check whether we have reached end of file. - stream.peek(); - if (stream.eof()) { - break; - } + while (!s.empty()) { + absl::string_view event_start = s; + bool success = false; // Read the next message tag. Protobuf defines the message tag as // (field_number << 3) | wire_type. In the legacy encoding, the field number @@ -1291,18 +1278,18 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStreamInternal( // In the new encoding we still expect the wire type to be 2, but the field // number will be greater than 1. constexpr uint64_t kExpectedV1Tag = (1 << 3) | 2; - size_t bytes_written = 0; - ParsedRtcEventLog::ParseStatusOr tag = - ParseVarInt(stream, buffer.data(), &bytes_written); - if (!tag.ok()) { + uint64_t tag = 0; + std::tie(success, s) = DecodeVarInt(s, &tag); + if (!success) { RTC_LOG(LS_WARNING) - << "Missing field tag from beginning of protobuf event."; + << "Failed to read field tag from beginning of protobuf event."; RTC_PARSE_WARN_AND_RETURN_SUCCESS_IF(allow_incomplete_logs_, kIncompleteLogError); - return tag.status(); + return ParseStatus::Error("Failed to read field tag varint", __FILE__, + __LINE__); } constexpr uint64_t kWireTypeMask = 0x07; - const uint64_t wire_type = tag.value() & kWireTypeMask; + const uint64_t wire_type = tag & kWireTypeMask; if (wire_type != 2) { RTC_LOG(LS_WARNING) << "Expected field tag with wire type 2 (length " "delimited message). Found wire type " @@ -1313,36 +1300,32 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStreamInternal( } // Read the length field. 
- ParsedRtcEventLog::ParseStatusOr message_length = - ParseVarInt(stream, buffer.data(), &bytes_written); - if (!message_length.ok()) { + uint64_t message_length = 0; + std::tie(success, s) = DecodeVarInt(s, &message_length); + if (!success) { RTC_LOG(LS_WARNING) << "Missing message length after protobuf field tag."; RTC_PARSE_WARN_AND_RETURN_SUCCESS_IF(allow_incomplete_logs_, kIncompleteLogError); - return message_length.status(); - } else if (message_length.value() > kMaxEventSize) { - RTC_LOG(LS_WARNING) << "Protobuf message length is too large."; - RTC_PARSE_WARN_AND_RETURN_SUCCESS_IF(allow_incomplete_logs_, - kIncompleteLogError); - RTC_PARSE_CHECK_OR_RETURN_LE(message_length.value(), kMaxEventSize); + return ParseStatus::Error("Failed to read message length varint", + __FILE__, __LINE__); } - // Read the next protobuf event to a temporary char buffer. - if (buffer.size() < bytes_written + message_length.value()) - buffer.resize(bytes_written + message_length.value()); - stream.read(buffer.data() + bytes_written, message_length.value()); - if (stream.gcount() != static_cast(message_length.value())) { - RTC_LOG(LS_WARNING) << "Failed to read protobuf message."; + if (message_length > s.size()) { + RTC_LOG(LS_WARNING) << "Protobuf message length is too large."; RTC_PARSE_WARN_AND_RETURN_SUCCESS_IF(allow_incomplete_logs_, kIncompleteLogError); - RTC_PARSE_CHECK_OR_RETURN(false); + RTC_PARSE_CHECK_OR_RETURN_LE(message_length, kMaxEventSize); } - size_t buffer_size = bytes_written + message_length.value(); - if (tag.value() == kExpectedV1Tag) { + // Skip forward to the start of the next event. + s = s.substr(message_length); + size_t total_event_size = event_start.size() - s.size(); + RTC_CHECK_LE(total_event_size, event_start.size()); + + if (tag == kExpectedV1Tag) { // Parse the protobuf event from the buffer. rtclog::EventStream event_stream; - if (!event_stream.ParseFromArray(buffer.data(), buffer_size)) { + if (!event_stream.ParseFromArray(event_start.data(), total_event_size)) { RTC_LOG(LS_WARNING) << "Failed to parse legacy-format protobuf message."; RTC_PARSE_WARN_AND_RETURN_SUCCESS_IF(allow_incomplete_logs_, @@ -1356,7 +1339,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::ParseStreamInternal( } else { // Parse the protobuf event from the buffer. 
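
The framing being decoded here is plain protobuf: each event is a length-delimited field, i.e. a (tag varint, length varint, payload) record where tag = (field_number << 3) | wire_type and wire type 2 means length-delimited. The parser relies on DecodeVarInt() from encoder/var_int.h; the sketch below only illustrates the base-128 varint format and is not that helper's implementation:

#include <cstdint>
#include <utility>

#include "absl/strings/string_view.h"

// Illustrative only: decodes a little-endian base-128 varint from the front
// of `input`. Returns {success, remaining input} and writes the value to
// `output`, mirroring how DecodeVarInt() is consumed in ParseStreamInternal().
std::pair<bool, absl::string_view> DecodeVarIntSketch(absl::string_view input,
                                                      uint64_t* output) {
  uint64_t value = 0;
  for (size_t i = 0; i < input.size() && i < 10; ++i) {
    uint8_t byte = static_cast<uint8_t>(input[i]);
    // The low 7 bits carry data; the high bit says whether more bytes follow.
    value |= static_cast<uint64_t>(byte & 0x7F) << (7 * i);
    if ((byte & 0x80) == 0) {
      *output = value;
      return {true, input.substr(i + 1)};
    }
  }
  return {false, input};  // Truncated or over-long varint.
}
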
rtclog2::EventStream event_stream; - if (!event_stream.ParseFromArray(buffer.data(), buffer_size)) { + if (!event_stream.ParseFromArray(event_start.data(), total_event_size)) { RTC_LOG(LS_WARNING) << "Failed to parse new-format protobuf message."; RTC_PARSE_WARN_AND_RETURN_SUCCESS_IF(allow_incomplete_logs_, kIncompleteLogError); @@ -1389,7 +1372,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreParsedLegacyEvent( RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); int64_t timestamp_us = event.timestamp_us(); - video_recv_configs_.emplace_back(timestamp_us, config.value()); + video_recv_configs_.emplace_back(Timestamp::Micros(timestamp_us), + config.value()); incoming_rtp_extensions_maps_[config.value().remote_ssrc] = RtpHeaderExtensionMap(config.value().rtp_extensions); incoming_rtp_extensions_maps_[config.value().rtx_ssrc] = @@ -1403,7 +1387,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreParsedLegacyEvent( RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); int64_t timestamp_us = event.timestamp_us(); - video_send_configs_.emplace_back(timestamp_us, config.value()); + video_send_configs_.emplace_back(Timestamp::Micros(timestamp_us), + config.value()); outgoing_rtp_extensions_maps_[config.value().local_ssrc] = RtpHeaderExtensionMap(config.value().rtp_extensions); outgoing_rtp_extensions_maps_[config.value().rtx_ssrc] = @@ -1417,7 +1402,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreParsedLegacyEvent( RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); int64_t timestamp_us = event.timestamp_us(); - audio_recv_configs_.emplace_back(timestamp_us, config.value()); + audio_recv_configs_.emplace_back(Timestamp::Micros(timestamp_us), + config.value()); incoming_rtp_extensions_maps_[config.value().remote_ssrc] = RtpHeaderExtensionMap(config.value().rtp_extensions); break; @@ -1428,7 +1414,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreParsedLegacyEvent( return config.status(); RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); int64_t timestamp_us = event.timestamp_us(); - audio_send_configs_.emplace_back(timestamp_us, config.value()); + audio_send_configs_.emplace_back(Timestamp::Micros(timestamp_us), + config.value()); outgoing_rtp_extensions_maps_[config.value().local_ssrc] = RtpHeaderExtensionMap(config.value().rtp_extensions); break; @@ -1461,11 +1448,13 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreParsedLegacyEvent( int64_t timestamp_us = event.timestamp_us(); if (direction == kIncomingPacket) { incoming_rtp_packets_map_[parsed_header.ssrc].push_back( - LoggedRtpPacketIncoming(timestamp_us, parsed_header, header_length, + LoggedRtpPacketIncoming(Timestamp::Micros(timestamp_us), + parsed_header, header_length, total_length)); } else { outgoing_rtp_packets_map_[parsed_header.ssrc].push_back( - LoggedRtpPacketOutgoing(timestamp_us, parsed_header, header_length, + LoggedRtpPacketOutgoing(Timestamp::Micros(timestamp_us), + parsed_header, header_length, total_length)); } break; @@ -1484,24 +1473,26 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreParsedLegacyEvent( if (packet == last_incoming_rtcp_packet_) break; incoming_rtcp_packets_.push_back( - LoggedRtcpPacketIncoming(timestamp_us, packet)); + LoggedRtcpPacketIncoming(Timestamp::Micros(timestamp_us), packet)); last_incoming_rtcp_packet_ = packet; } else { outgoing_rtcp_packets_.push_back( - LoggedRtcpPacketOutgoing(timestamp_us, packet)); + LoggedRtcpPacketOutgoing(Timestamp::Micros(timestamp_us), packet)); } break; } case rtclog::Event::LOG_START: { 
RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); int64_t timestamp_us = event.timestamp_us(); - start_log_events_.push_back(LoggedStartEvent(timestamp_us)); + start_log_events_.push_back( + LoggedStartEvent(Timestamp::Micros(timestamp_us))); break; } case rtclog::Event::LOG_END: { RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); int64_t timestamp_us = event.timestamp_us(); - stop_log_events_.push_back(LoggedStopEvent(timestamp_us)); + stop_log_events_.push_back( + LoggedStopEvent(Timestamp::Micros(timestamp_us))); break; } case rtclog::Event::AUDIO_PLAYOUT_EVENT: { @@ -1820,7 +1811,7 @@ ParsedRtcEventLog::GetAudioPlayout(const rtclog::Event& event) const { const rtclog::AudioPlayoutEvent& playout_event = event.audio_playout_event(); LoggedAudioPlayoutEvent res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(playout_event.has_local_ssrc()); res.ssrc = playout_event.local_ssrc(); return res; @@ -1836,7 +1827,7 @@ ParsedRtcEventLog::GetLossBasedBweUpdate(const rtclog::Event& event) const { LoggedBweLossBasedUpdate bwe_update; RTC_CHECK(event.has_timestamp_us()); - bwe_update.timestamp_us = event.timestamp_us(); + bwe_update.timestamp = Timestamp::Micros(event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(loss_event.has_bitrate_bps()); bwe_update.bitrate_bps = loss_event.bitrate_bps(); RTC_PARSE_CHECK_OR_RETURN(loss_event.has_fraction_loss()); @@ -1857,7 +1848,7 @@ ParsedRtcEventLog::GetDelayBasedBweUpdate(const rtclog::Event& event) const { LoggedBweDelayBasedUpdate res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(delay_event.has_bitrate_bps()); res.bitrate_bps = delay_event.bitrate_bps(); RTC_PARSE_CHECK_OR_RETURN(delay_event.has_detector_state()); @@ -1876,7 +1867,7 @@ ParsedRtcEventLog::GetAudioNetworkAdaptation(const rtclog::Event& event) const { LoggedAudioNetworkAdaptationEvent res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); if (ana_event.has_bitrate_bps()) res.config.bitrate_bps = ana_event.bitrate_bps(); if (ana_event.has_enable_fec()) @@ -1902,7 +1893,7 @@ ParsedRtcEventLog::GetBweProbeClusterCreated(const rtclog::Event& event) const { const rtclog::BweProbeCluster& pcc_event = event.probe_cluster(); LoggedBweProbeClusterCreatedEvent res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(pcc_event.has_id()); res.id = pcc_event.id(); RTC_PARSE_CHECK_OR_RETURN(pcc_event.has_bitrate_bps()); @@ -1927,7 +1918,7 @@ ParsedRtcEventLog::GetBweProbeFailure(const rtclog::Event& event) const { LoggedBweProbeFailureEvent res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(pr_event.has_id()); res.id = pr_event.id(); RTC_PARSE_CHECK_OR_RETURN(pr_event.has_result()); @@ -1960,7 +1951,7 @@ ParsedRtcEventLog::GetBweProbeSuccess(const rtclog::Event& event) const { LoggedBweProbeSuccessEvent res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); 
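
The pattern repeated throughout these conversions is replacing raw int64_t *_us / *_ms fields with webrtc::Timestamp, so the unit travels with the type rather than with ad-hoc "* 1000" factors. A minimal sketch of the round trip, assuming the usual api/units/timestamp.h header:

#include <cstdint>

#include "api/units/timestamp.h"

// New-format protos carry milliseconds, legacy events carry microseconds;
// both now end up in the same strongly typed representation.
webrtc::Timestamp FromProtoMs(int64_t timestamp_ms) {
  return webrtc::Timestamp::Millis(timestamp_ms);
}

webrtc::Timestamp FromLegacyUs(int64_t timestamp_us) {
  return webrtc::Timestamp::Micros(timestamp_us);
}

int64_t LogTimeMs(webrtc::Timestamp t) {
  // Conversions back to integers are explicit, which removes the scattered
  // millisecond/microsecond scaling (and the unit mix-ups it invites).
  return t.ms();
}
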
RTC_PARSE_CHECK_OR_RETURN(pr_event.has_id()); res.id = pr_event.id(); RTC_PARSE_CHECK_OR_RETURN(pr_event.has_bitrate_bps()); @@ -1977,7 +1968,7 @@ ParsedRtcEventLog::GetAlrState(const rtclog::Event& event) const { const rtclog::AlrState& alr_event = event.alr_state(); LoggedAlrStateEvent res; RTC_PARSE_CHECK_OR_RETURN(event.has_timestamp_us()); - res.timestamp_us = event.timestamp_us(); + res.timestamp = Timestamp::Micros(event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(alr_event.has_in_alr()); res.in_alr = alr_event.in_alr(); @@ -1994,7 +1985,7 @@ ParsedRtcEventLog::GetIceCandidatePairConfig( const rtclog::IceCandidatePairConfig& config = rtc_event.ice_candidate_pair_config(); RTC_CHECK(rtc_event.has_timestamp_us()); - res.timestamp_us = rtc_event.timestamp_us(); + res.timestamp = Timestamp::Micros(rtc_event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(config.has_config_type()); res.type = GetRuntimeIceCandidatePairConfigType(config.config_type()); RTC_PARSE_CHECK_OR_RETURN(config.has_candidate_pair_id()); @@ -2033,7 +2024,7 @@ ParsedRtcEventLog::GetIceCandidatePairEvent( const rtclog::IceCandidatePairEvent& event = rtc_event.ice_candidate_pair_event(); RTC_CHECK(rtc_event.has_timestamp_us()); - res.timestamp_us = rtc_event.timestamp_us(); + res.timestamp = Timestamp::Micros(rtc_event.timestamp_us()); RTC_PARSE_CHECK_OR_RETURN(event.has_event_type()); res.type = GetRuntimeIceCandidatePairEventType(event.event_type()); RTC_PARSE_CHECK_OR_RETURN(event.has_candidate_pair_id()); @@ -2419,7 +2410,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreAlrStateEvent( RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_in_alr()); LoggedAlrStateEvent alr_event; - alr_event.timestamp_us = proto.timestamp_ms() * 1000; + alr_event.timestamp = Timestamp::Millis(proto.timestamp_ms()); alr_event.in_alr = proto.in_alr(); alr_state_events_.push_back(alr_event); @@ -2433,7 +2424,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreRouteChangeEvent( RTC_PARSE_CHECK_OR_RETURN(proto.has_connected()); RTC_PARSE_CHECK_OR_RETURN(proto.has_overhead()); LoggedRouteChangeEvent route_event; - route_event.timestamp_ms = proto.timestamp_ms(); + route_event.timestamp = Timestamp::Millis(proto.timestamp_ms()); route_event.connected = proto.connected(); route_event.overhead = proto.overhead(); @@ -2447,7 +2438,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreRemoteEstimateEvent( RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); // Base event LoggedRemoteEstimateEvent base_event; - base_event.timestamp_ms = proto.timestamp_ms(); + base_event.timestamp = Timestamp::Millis(proto.timestamp_ms()); absl::optional base_link_capacity_lower_kbps; if (proto.has_link_capacity_lower_kbps()) { @@ -2495,7 +2486,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreRemoteEstimateEvent( for (size_t i = 0; i < number_of_deltas; ++i) { LoggedRemoteEstimateEvent event; RTC_PARSE_CHECK_OR_RETURN(timestamp_ms_values[i].has_value()); - event.timestamp_ms = *timestamp_ms_values[i]; + event.timestamp = Timestamp::Millis(*timestamp_ms_values[i]); if (link_capacity_lower_kbps_values[i]) event.link_capacity_lower = DataRate::KilobitsPerSec(*link_capacity_lower_kbps_values[i]); @@ -2514,7 +2505,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreAudioPlayoutEvent( // Base event audio_playout_events_[proto.local_ssrc()].emplace_back( - 1000 * proto.timestamp_ms(), proto.local_ssrc()); + Timestamp::Millis(proto.timestamp_ms()), proto.local_ssrc()); const size_t 
number_of_deltas = proto.has_number_of_deltas() ? proto.number_of_deltas() : 0u; @@ -2546,8 +2537,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreAudioPlayoutEvent( const uint32_t local_ssrc = static_cast(local_ssrc_values[i].value()); - audio_playout_events_[local_ssrc].emplace_back(1000 * timestamp_ms, - local_ssrc); + audio_playout_events_[local_ssrc].emplace_back( + Timestamp::Millis(timestamp_ms), local_ssrc); } return ParseStatus::Success(); } @@ -2580,8 +2571,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreStartEvent( RTC_PARSE_CHECK_OR_RETURN(proto.has_version()); RTC_PARSE_CHECK_OR_RETURN(proto.has_utc_time_ms()); RTC_PARSE_CHECK_OR_RETURN_EQ(proto.version(), 2); - LoggedStartEvent start_event(proto.timestamp_ms() * 1000, - proto.utc_time_ms()); + LoggedStartEvent start_event(Timestamp::Millis(proto.timestamp_ms()), + Timestamp::Millis(proto.utc_time_ms())); start_log_events_.push_back(start_event); return ParseStatus::Success(); @@ -2590,7 +2581,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreStartEvent( ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreStopEvent( const rtclog2::EndLogEvent& proto) { RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - LoggedStopEvent stop_event(proto.timestamp_ms() * 1000); + LoggedStopEvent stop_event(Timestamp::Millis(proto.timestamp_ms())); stop_log_events_.push_back(stop_event); return ParseStatus::Success(); @@ -2604,7 +2595,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreBweLossBasedUpdate( RTC_PARSE_CHECK_OR_RETURN(proto.has_total_packets()); // Base event - bwe_loss_updates_.emplace_back(1000 * proto.timestamp_ms(), + bwe_loss_updates_.emplace_back(Timestamp::Millis(proto.timestamp_ms()), proto.bitrate_bps(), proto.fraction_loss(), proto.total_packets()); @@ -2660,7 +2651,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreBweLossBasedUpdate( const uint32_t total_packets = static_cast(total_packets_values[i].value()); - bwe_loss_updates_.emplace_back(1000 * timestamp_ms, bitrate_bps, + bwe_loss_updates_.emplace_back(Timestamp::Millis(timestamp_ms), bitrate_bps, fraction_loss, total_packets); } return ParseStatus::Success(); @@ -2675,7 +2666,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreBweDelayBasedUpdate( // Base event const BandwidthUsage base_detector_state = GetRuntimeDetectorState(proto.detector_state()); - bwe_delay_updates_.emplace_back(1000 * proto.timestamp_ms(), + bwe_delay_updates_.emplace_back(Timestamp::Millis(proto.timestamp_ms()), proto.bitrate_bps(), base_detector_state); const size_t number_of_deltas = @@ -2719,7 +2710,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreBweDelayBasedUpdate( static_cast( detector_state_values[i].value()); - bwe_delay_updates_.emplace_back(1000 * timestamp_ms, bitrate_bps, + bwe_delay_updates_.emplace_back(Timestamp::Millis(timestamp_ms), + bitrate_bps, GetRuntimeDetectorState(detector_state)); } return ParseStatus::Success(); @@ -2729,7 +2721,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreBweProbeClusterCreated( const rtclog2::BweProbeCluster& proto) { LoggedBweProbeClusterCreatedEvent probe_cluster; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - probe_cluster.timestamp_us = proto.timestamp_ms() * 1000; + probe_cluster.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_id()); probe_cluster.id = proto.id(); RTC_PARSE_CHECK_OR_RETURN(proto.has_bitrate_bps()); @@ -2749,7 +2741,7 @@ ParsedRtcEventLog::ParseStatus 
ParsedRtcEventLog::StoreBweProbeSuccessEvent( const rtclog2::BweProbeResultSuccess& proto) { LoggedBweProbeSuccessEvent probe_result; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - probe_result.timestamp_us = proto.timestamp_ms() * 1000; + probe_result.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_id()); probe_result.id = proto.id(); RTC_PARSE_CHECK_OR_RETURN(proto.has_bitrate_bps()); @@ -2765,7 +2757,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreBweProbeFailureEvent( const rtclog2::BweProbeResultFailure& proto) { LoggedBweProbeFailureEvent probe_result; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - probe_result.timestamp_us = proto.timestamp_ms() * 1000; + probe_result.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_id()); probe_result.id = proto.id(); RTC_PARSE_CHECK_OR_RETURN(proto.has_failure()); @@ -2788,7 +2780,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreFrameDecodedEvents( RTC_PARSE_CHECK_OR_RETURN(proto.has_qp()); LoggedFrameDecoded base_frame; - base_frame.timestamp_us = 1000 * proto.timestamp_ms(); + base_frame.timestamp = Timestamp::Millis(proto.timestamp_ms()); base_frame.ssrc = proto.ssrc(); base_frame.render_time_ms = proto.render_time_ms(); base_frame.width = proto.width(); @@ -2851,7 +2843,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreFrameDecodedEvents( RTC_PARSE_CHECK_OR_RETURN(timestamp_ms_values[i].has_value()); RTC_PARSE_CHECK_OR_RETURN( ToSigned(timestamp_ms_values[i].value(), ×tamp_ms)); - frame.timestamp_us = 1000 * timestamp_ms; + frame.timestamp = Timestamp::Millis(timestamp_ms); RTC_PARSE_CHECK_OR_RETURN(ssrc_values[i].has_value()); RTC_PARSE_CHECK_OR_RETURN_LE(ssrc_values[i].value(), @@ -2896,7 +2888,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreGenericAckReceivedEvent( base_receive_acked_packet_time_ms = proto.receive_acked_packet_time_ms(); } generic_acks_received_.push_back( - {proto.timestamp_ms() * 1000, proto.packet_number(), + {Timestamp::Millis(proto.timestamp_ms()), proto.packet_number(), proto.acked_packet_number(), base_receive_acked_packet_time_ms}); const size_t number_of_deltas = @@ -2955,8 +2947,8 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreGenericAckReceivedEvent( ToSigned(receive_acked_packet_time_ms_values[i].value(), &value)); receive_acked_packet_time_ms = value; } - generic_acks_received_.push_back({timestamp_ms * 1000, packet_number, - acked_packet_number, + generic_acks_received_.push_back({Timestamp::Millis(timestamp_ms), + packet_number, acked_packet_number, receive_acked_packet_time_ms}); } return ParseStatus::Success(); @@ -2973,7 +2965,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreGenericPacketSentEvent( RTC_PARSE_CHECK_OR_RETURN(proto.has_padding_length()); generic_packets_sent_.push_back( - {proto.timestamp_ms() * 1000, proto.packet_number(), + {Timestamp::Millis(proto.timestamp_ms()), proto.packet_number(), static_cast(proto.overhead_length()), static_cast(proto.payload_length()), static_cast(proto.padding_length())}); @@ -3020,7 +3012,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreGenericPacketSentEvent( RTC_PARSE_CHECK_OR_RETURN(payload_length_values[i].has_value()); RTC_PARSE_CHECK_OR_RETURN(padding_length_values[i].has_value()); generic_packets_sent_.push_back( - {timestamp_ms * 1000, packet_number, + {Timestamp::Millis(timestamp_ms), packet_number, static_cast(overhead_length_values[i].value()), 
static_cast(payload_length_values[i].value()), static_cast(padding_length_values[i].value())}); @@ -3037,7 +3029,7 @@ ParsedRtcEventLog::StoreGenericPacketReceivedEvent( RTC_PARSE_CHECK_OR_RETURN(proto.has_packet_number()); RTC_PARSE_CHECK_OR_RETURN(proto.has_packet_length()); - generic_packets_received_.push_back({proto.timestamp_ms() * 1000, + generic_packets_received_.push_back({Timestamp::Millis(proto.timestamp_ms()), proto.packet_number(), proto.packet_length()}); @@ -3075,7 +3067,7 @@ ParsedRtcEventLog::StoreGenericPacketReceivedEvent( int32_t packet_length = static_cast(packet_length_values[i].value()); generic_packets_received_.push_back( - {timestamp_ms * 1000, packet_number, packet_length}); + {Timestamp::Millis(timestamp_ms), packet_number, packet_length}); } return ParseStatus::Success(); } @@ -3110,8 +3102,8 @@ ParsedRtcEventLog::StoreAudioNetworkAdaptationEvent( // Note: Encoding N as N-1 only done for |num_channels_deltas|. runtime_config.num_channels = proto.num_channels(); } - audio_network_adaptation_events_.emplace_back(1000 * proto.timestamp_ms(), - runtime_config); + audio_network_adaptation_events_.emplace_back( + Timestamp::Millis(proto.timestamp_ms()), runtime_config); } const size_t number_of_deltas = @@ -3232,8 +3224,8 @@ ParsedRtcEventLog::StoreAudioNetworkAdaptationEvent( runtime_config.num_channels = rtc::checked_cast(num_channels_values[i].value()); } - audio_network_adaptation_events_.emplace_back(1000 * timestamp_ms, - runtime_config); + audio_network_adaptation_events_.emplace_back( + Timestamp::Millis(timestamp_ms), runtime_config); } return ParseStatus::Success(); } @@ -3242,7 +3234,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreDtlsTransportState( const rtclog2::DtlsTransportStateEvent& proto) { LoggedDtlsTransportState dtls_state; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - dtls_state.timestamp_us = proto.timestamp_ms() * 1000; + dtls_state.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_dtls_transport_state()); dtls_state.dtls_transport_state = @@ -3256,7 +3248,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreDtlsWritableState( const rtclog2::DtlsWritableState& proto) { LoggedDtlsWritableState dtls_writable_state; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - dtls_writable_state.timestamp_us = proto.timestamp_ms() * 1000; + dtls_writable_state.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_writable()); dtls_writable_state.writable = proto.writable(); @@ -3268,7 +3260,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreIceCandidatePairConfig( const rtclog2::IceCandidatePairConfig& proto) { LoggedIceCandidatePairConfig ice_config; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - ice_config.timestamp_us = proto.timestamp_ms() * 1000; + ice_config.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_config_type()); ice_config.type = GetRuntimeIceCandidatePairConfigType(proto.config_type()); @@ -3306,7 +3298,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreIceCandidateEvent( const rtclog2::IceCandidatePairEvent& proto) { LoggedIceCandidatePairEvent ice_event; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - ice_event.timestamp_us = proto.timestamp_ms() * 1000; + ice_event.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_event_type()); ice_event.type = GetRuntimeIceCandidatePairEventType(proto.event_type()); 
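
Most of the Store* functions above share one shape: decode a single base event from the proto, then expand number_of_deltas delta-encoded follow-up events. The real expansion is done by DecodeDeltas() from encoder/delta_encoding.h (its exact signature is not reproduced here); the sketch below only illustrates the underlying idea that follow-up values are stored as differences and recovered with a running sum from the base:

#include <cstdint>
#include <vector>

#include "absl/types/optional.h"

// Conceptual only: expands a base timestamp plus optional per-event deltas
// into absolute millisecond timestamps.
std::vector<int64_t> ExpandTimestampsSketch(
    int64_t base_timestamp_ms,
    const std::vector<absl::optional<uint64_t>>& deltas_ms) {
  std::vector<int64_t> timestamps_ms = {base_timestamp_ms};
  uint64_t accumulated = static_cast<uint64_t>(base_timestamp_ms);
  for (const absl::optional<uint64_t>& delta : deltas_ms) {
    if (!delta.has_value())
      continue;  // The real parser treats a missing timestamp delta as an error.
    accumulated += *delta;
    timestamps_ms.push_back(static_cast<int64_t>(accumulated));
  }
  return timestamps_ms;
}
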
RTC_PARSE_CHECK_OR_RETURN(proto.has_candidate_pair_id()); @@ -3326,7 +3318,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreVideoRecvConfig( const rtclog2::VideoRecvStreamConfig& proto) { LoggedVideoRecvConfig stream; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - stream.timestamp_us = proto.timestamp_ms() * 1000; + stream.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_remote_ssrc()); stream.config.remote_ssrc = proto.remote_ssrc(); RTC_PARSE_CHECK_OR_RETURN(proto.has_local_ssrc()); @@ -3346,7 +3338,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreVideoSendConfig( const rtclog2::VideoSendStreamConfig& proto) { LoggedVideoSendConfig stream; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - stream.timestamp_us = proto.timestamp_ms() * 1000; + stream.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_ssrc()); stream.config.local_ssrc = proto.ssrc(); if (proto.has_rtx_ssrc()) { @@ -3364,7 +3356,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreAudioRecvConfig( const rtclog2::AudioRecvStreamConfig& proto) { LoggedAudioRecvConfig stream; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - stream.timestamp_us = proto.timestamp_ms() * 1000; + stream.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_remote_ssrc()); stream.config.remote_ssrc = proto.remote_ssrc(); RTC_PARSE_CHECK_OR_RETURN(proto.has_local_ssrc()); @@ -3381,7 +3373,7 @@ ParsedRtcEventLog::ParseStatus ParsedRtcEventLog::StoreAudioSendConfig( const rtclog2::AudioSendStreamConfig& proto) { LoggedAudioSendConfig stream; RTC_PARSE_CHECK_OR_RETURN(proto.has_timestamp_ms()); - stream.timestamp_us = proto.timestamp_ms() * 1000; + stream.timestamp = Timestamp::Millis(proto.timestamp_ms()); RTC_PARSE_CHECK_OR_RETURN(proto.has_ssrc()); stream.config.local_ssrc = proto.ssrc(); if (proto.has_header_extensions()) { diff --git a/logging/rtc_event_log/rtc_event_log_parser.h b/logging/rtc_event_log/rtc_event_log_parser.h index d890792a39..4898022fae 100644 --- a/logging/rtc_event_log/rtc_event_log_parser.h +++ b/logging/rtc_event_log/rtc_event_log_parser.h @@ -14,11 +14,11 @@ #include #include #include -#include // no-presubmit-check TODO(webrtc:8982) #include #include // pair #include +#include "absl/base/attributes.h" #include "api/rtc_event_log/rtc_event_log.h" #include "call/video_receive_stream.h" #include "call/video_send_stream.h" @@ -64,144 +64,108 @@ namespace webrtc { enum PacketDirection { kIncomingPacket = 0, kOutgoingPacket }; +// This class is used to process lists of LoggedRtpPacketIncoming +// and LoggedRtpPacketOutgoing without duplicating the code. +// TODO(terelius): Remove this class. Instead use e.g. a vector of pointers +// to LoggedRtpPacket or templatize the surrounding code. template -class PacketView; - -template -class PacketIterator { - friend class PacketView; - +class DereferencingVector { public: - // Standard iterator traits. - using difference_type = std::ptrdiff_t; - using value_type = T; - using pointer = T*; - using reference = T&; - using iterator_category = std::bidirectional_iterator_tag; - - // The default-contructed iterator is meaningless, but is required by the - // ForwardIterator concept. 
- PacketIterator() : ptr_(nullptr), element_size_(0) {} - PacketIterator(const PacketIterator& other) - : ptr_(other.ptr_), element_size_(other.element_size_) {} - PacketIterator(const PacketIterator&& other) - : ptr_(other.ptr_), element_size_(other.element_size_) {} - ~PacketIterator() = default; + template + class DereferencingIterator { + public: + // Standard iterator traits. + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = typename std::conditional_t; + using reference = typename std::conditional_t; + using iterator_category = std::bidirectional_iterator_tag; + + using representation = + typename std::conditional_t; + + explicit DereferencingIterator(representation ptr) : ptr_(ptr) {} + + DereferencingIterator(const DereferencingIterator& other) + : ptr_(other.ptr_) {} + DereferencingIterator(const DereferencingIterator&& other) + : ptr_(other.ptr_) {} + ~DereferencingIterator() = default; + + DereferencingIterator& operator=(const DereferencingIterator& other) { + ptr_ = other.ptr_; + return *this; + } + DereferencingIterator& operator=(const DereferencingIterator&& other) { + ptr_ = other.ptr_; + return *this; + } - PacketIterator& operator=(const PacketIterator& other) { - ptr_ = other.ptr_; - element_size_ = other.element_size_; - return *this; - } - PacketIterator& operator=(const PacketIterator&& other) { - ptr_ = other.ptr_; - element_size_ = other.element_size_; - return *this; - } + bool operator==(const DereferencingIterator& other) const { + return ptr_ == other.ptr_; + } + bool operator!=(const DereferencingIterator& other) const { + return ptr_ != other.ptr_; + } - bool operator==(const PacketIterator& other) const { - RTC_DCHECK_EQ(element_size_, other.element_size_); - return ptr_ == other.ptr_; - } - bool operator!=(const PacketIterator& other) const { - RTC_DCHECK_EQ(element_size_, other.element_size_); - return ptr_ != other.ptr_; - } + DereferencingIterator& operator++() { + ++ptr_; + return *this; + } + DereferencingIterator& operator--() { + --ptr_; + return *this; + } + DereferencingIterator operator++(int) { + DereferencingIterator iter_copy(ptr_); + ++ptr_; + return iter_copy; + } + DereferencingIterator operator--(int) { + DereferencingIterator iter_copy(ptr_); + --ptr_; + return iter_copy; + } - PacketIterator& operator++() { - ptr_ += element_size_; - return *this; - } - PacketIterator& operator--() { - ptr_ -= element_size_; - return *this; - } - PacketIterator operator++(int) { - PacketIterator iter_copy(ptr_, element_size_); - ptr_ += element_size_; - return iter_copy; - } - PacketIterator operator--(int) { - PacketIterator iter_copy(ptr_, element_size_); - ptr_ -= element_size_; - return iter_copy; - } + template + std::enable_if_t operator*() { + return **ptr_; + } - T& operator*() { return *reinterpret_cast(ptr_); } - const T& operator*() const { return *reinterpret_cast(ptr_); } + template + std::enable_if_t<_IsConst, reference> operator*() const { + return **ptr_; + } - T* operator->() { return reinterpret_cast(ptr_); } - const T* operator->() const { return reinterpret_cast(ptr_); } + template + std::enable_if_t operator->() { + return *ptr_; + } - private: - PacketIterator(typename std::conditional::value, - const void*, - void*>::type p, - size_t s) - : ptr_(reinterpret_cast(p)), element_size_(s) {} - - typename std::conditional::value, const char*, char*>::type - ptr_; - size_t element_size_; -}; + template + std::enable_if_t<_IsConst, pointer> operator->() const { + return *ptr_; + } -// Suppose that we have a 
struct S where we are only interested in a specific -// member M. Given an array of S, PacketView can be used to treat the array -// as an array of M, without exposing the type S to surrounding code and without -// accessing the member through a virtual function. In this case, we want to -// have a common view for incoming and outgoing RtpPackets, hence the PacketView -// name. -// Note that constructing a PacketView bypasses the typesystem, so the caller -// has to take extra care when constructing these objects. The implementation -// also requires that the containing struct is standard-layout (e.g. POD). -// -// Usage example: -// struct A {...}; -// struct B { A a; ...}; -// struct C { A a; ...}; -// size_t len = 10; -// B* array1 = new B[len]; -// C* array2 = new C[len]; -// -// PacketView view1 = PacketView::Create(array1, len, offsetof(B, a)); -// PacketView view2 = PacketView::Create(array2, len, offsetof(C, a)); -// -// The following code works with either view1 or view2. -// void f(PacketView view) -// for (A& a : view) { -// DoSomething(a); -// } -template -class PacketView { - public: - template - static PacketView Create(U* ptr, size_t num_elements, size_t offset) { - static_assert(std::is_standard_layout::value, - "PacketView can only be created for standard layout types."); - static_assert(std::is_standard_layout::value, - "PacketView can only be created for standard layout types."); - return PacketView(ptr, num_elements, offset, sizeof(U)); - } + private: + representation ptr_; + }; using value_type = T; using reference = value_type&; using const_reference = const value_type&; - using iterator = PacketIterator; - using const_iterator = PacketIterator; + using iterator = DereferencingIterator; + using const_iterator = DereferencingIterator; using reverse_iterator = std::reverse_iterator; using const_reverse_iterator = std::reverse_iterator; - iterator begin() { return iterator(data_, element_size_); } - iterator end() { - auto end_ptr = data_ + num_elements_ * element_size_; - return iterator(end_ptr, element_size_); - } + iterator begin() { return iterator(elems_.data()); } + iterator end() { return iterator(elems_.data() + elems_.size()); } - const_iterator begin() const { return const_iterator(data_, element_size_); } + const_iterator begin() const { return const_iterator(elems_.data()); } const_iterator end() const { - auto end_ptr = data_ + num_elements_ * element_size_; - return const_iterator(end_ptr, element_size_); + return const_iterator(elems_.data() + elems_.size()); } reverse_iterator rbegin() { return reverse_iterator(end()); } @@ -214,35 +178,27 @@ class PacketView { return const_reverse_iterator(begin()); } - size_t size() const { return num_elements_; } + size_t size() const { return elems_.size(); } - bool empty() const { return num_elements_ == 0; } + bool empty() const { return elems_.empty(); } T& operator[](size_t i) { - auto elem_ptr = data_ + i * element_size_; - return *reinterpret_cast(elem_ptr); + RTC_DCHECK_LT(i, elems_.size()); + return *elems_[i]; } const T& operator[](size_t i) const { - auto elem_ptr = data_ + i * element_size_; - return *reinterpret_cast(elem_ptr); + RTC_DCHECK_LT(i, elems_.size()); + return *elems_[i]; + } + + void push_back(T* elem) { + RTC_DCHECK(elem != nullptr); + elems_.push_back(elem); } private: - PacketView(typename std::conditional::value, - const void*, - void*>::type data, - size_t num_elements, - size_t offset, - size_t element_size) - : data_(reinterpret_cast(data) + offset), - num_elements_(num_elements), - 
element_size_(element_size) {} - - typename std::conditional::value, const char*, char*>::type - data_; - size_t num_elements_; - size_t element_size_; + std::vector elems_; }; // Conversion functions for version 2 of the wire format. @@ -296,7 +252,7 @@ class ParsedRtcEventLog { return error_ + " failed at " + file_ + " line " + std::to_string(line_); } - RTC_DEPRECATED operator bool() const { return ok(); } + ABSL_DEPRECATED("Use ok() instead") operator bool() const { return ok(); } private: ParseStatus() : error_(), file_(), line_(0) {} @@ -345,14 +301,12 @@ class ParsedRtcEventLog { struct LoggedRtpStreamView { LoggedRtpStreamView(uint32_t ssrc, - const LoggedRtpPacketIncoming* ptr, - size_t num_elements); + const std::vector& packets); LoggedRtpStreamView(uint32_t ssrc, - const LoggedRtpPacketOutgoing* ptr, - size_t num_elements); + const std::vector& packets); LoggedRtpStreamView(const LoggedRtpStreamView&); uint32_t ssrc; - PacketView packet_view; + DereferencingVector packet_view; }; class LogSegment { @@ -388,9 +342,8 @@ class ParsedRtcEventLog { // Reads an RtcEventLog from a string and returns success if successful. ParseStatus ParseString(const std::string& s); - // Reads an RtcEventLog from an istream and returns success if successful. - ParseStatus ParseStream( - std::istream& stream); // no-presubmit-check TODO(webrtc:8982) + // Reads an RtcEventLog from an string and returns success if successful. + ParseStatus ParseStream(const std::string& s); MediaType GetMediaType(uint32_t ssrc, PacketDirection direction) const; @@ -603,6 +556,15 @@ class ParsedRtcEventLog { } } + const std::vector& byes( + PacketDirection direction) const { + if (direction == kIncomingPacket) { + return incoming_bye_; + } else { + return outgoing_bye_; + } + } + const std::vector& transport_feedbacks( PacketDirection direction) const { if (direction == kIncomingPacket) { @@ -657,8 +619,7 @@ class ParsedRtcEventLog { std::vector GetRouteChanges() const; private: - ABSL_MUST_USE_RESULT ParseStatus ParseStreamInternal( - std::istream& stream); // no-presubmit-check TODO(webrtc:8982) + ABSL_MUST_USE_RESULT ParseStatus ParseStreamInternal(absl::string_view s); ABSL_MUST_USE_RESULT ParseStatus StoreParsedLegacyEvent(const rtclog::Event& event); @@ -849,6 +810,8 @@ class ParsedRtcEventLog { std::vector outgoing_fir_; std::vector incoming_pli_; std::vector outgoing_pli_; + std::vector incoming_bye_; + std::vector outgoing_bye_; std::vector incoming_transport_feedback_; std::vector outgoing_transport_feedback_; std::vector incoming_loss_notification_; diff --git a/logging/rtc_event_log/rtc_event_log_unittest.cc b/logging/rtc_event_log/rtc_event_log_unittest.cc index b6fa1db539..323e4fe009 100644 --- a/logging/rtc_event_log/rtc_event_log_unittest.cc +++ b/logging/rtc_event_log/rtc_event_log_unittest.cc @@ -899,9 +899,9 @@ TEST_P(RtcEventLogCircularBufferTest, KeepsMostRecentEvents) { auto task_queue_factory = CreateDefaultTaskQueueFactory(); RtcEventLogFactory rtc_event_log_factory(task_queue_factory.get()); - // When log_dumper goes out of scope, it causes the log file to be flushed + // When `log` goes out of scope, it causes the log file to be flushed // to disk. - std::unique_ptr log_dumper = + std::unique_ptr log = rtc_event_log_factory.CreateRtcEventLog(encoding_type_); for (size_t i = 0; i < kNumEvents; i++) { @@ -911,18 +911,18 @@ TEST_P(RtcEventLogCircularBufferTest, KeepsMostRecentEvents) { // simplicity. // We base the various values on the index. 
We use this for some basic // consistency checks when we read back. - log_dumper->Log(std::make_unique( + log->Log(std::make_unique( i, kStartBitrate + i * 1000)); fake_clock->AdvanceTime(TimeDelta::Millis(10)); } int64_t start_time_us = rtc::TimeMicros(); int64_t utc_start_time_us = rtc::TimeUTCMicros(); - log_dumper->StartLogging( + log->StartLogging( std::make_unique(temp_filename, 10000000), RtcEventLog::kImmediateOutput); fake_clock->AdvanceTime(TimeDelta::Millis(10)); int64_t stop_time_us = rtc::TimeMicros(); - log_dumper->StopLogging(); + log->StopLogging(); // Read the generated file from disk. ParsedRtcEventLog parsed_log; @@ -944,7 +944,7 @@ TEST_P(RtcEventLogCircularBufferTest, KeepsMostRecentEvents) { EXPECT_LT(probe_success_events.size(), kNumEvents); ASSERT_GT(probe_success_events.size(), 1u); - int64_t first_timestamp_us = probe_success_events[0].timestamp_us; + int64_t first_timestamp_ms = probe_success_events[0].timestamp.ms(); uint32_t first_id = probe_success_events[0].id; int32_t first_bitrate_bps = probe_success_events[0].bitrate_bps; // We want to reset the time to what we used when generating the events, but @@ -953,13 +953,16 @@ TEST_P(RtcEventLogCircularBufferTest, KeepsMostRecentEvents) { // destroyed before the new one is created, so we have to reset() first. fake_clock.reset(); fake_clock = std::make_unique(); - fake_clock->SetTime(Timestamp::Micros(first_timestamp_us)); + fake_clock->SetTime(Timestamp::Millis(first_timestamp_ms)); for (size_t i = 1; i < probe_success_events.size(); i++) { fake_clock->AdvanceTime(TimeDelta::Millis(10)); verifier_.VerifyLoggedBweProbeSuccessEvent( RtcEventProbeResultSuccess(first_id + i, first_bitrate_bps + i * 1000), probe_success_events[i]); } + + // Clean up temporary file - can be pretty slow. + remove(temp_filename.c_str()); } INSTANTIATE_TEST_SUITE_P( @@ -971,4 +974,64 @@ INSTANTIATE_TEST_SUITE_P( // TODO(terelius): Verify parser behavior if the timestamps are not // monotonically increasing in the log. +TEST(DereferencingVectorTest, NonConstVector) { + std::vector v{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + DereferencingVector even; + EXPECT_TRUE(even.empty()); + EXPECT_EQ(even.size(), 0u); + EXPECT_EQ(even.begin(), even.end()); + for (size_t i = 0; i < v.size(); i += 2) { + even.push_back(&v[i]); + } + EXPECT_FALSE(even.empty()); + EXPECT_EQ(even.size(), 5u); + EXPECT_NE(even.begin(), even.end()); + + // Test direct access. + for (size_t i = 0; i < even.size(); i++) { + EXPECT_EQ(even[i], 2 * static_cast(i)); + } + + // Test iterator. + for (int val : even) { + EXPECT_EQ(val % 2, 0); + } + + // Test modification through iterator. + for (int& val : even) { + val = val * 2; + EXPECT_EQ(val % 2, 0); + } + + // Backing vector should have been modified. + std::vector expected{0, 1, 4, 3, 8, 5, 12, 7, 16, 9}; + for (size_t i = 0; i < v.size(); i++) { + EXPECT_EQ(v[i], expected[i]); + } +} + +TEST(DereferencingVectorTest, ConstVector) { + std::vector v{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + DereferencingVector odd; + EXPECT_TRUE(odd.empty()); + EXPECT_EQ(odd.size(), 0u); + EXPECT_EQ(odd.begin(), odd.end()); + for (size_t i = 1; i < v.size(); i += 2) { + odd.push_back(&v[i]); + } + EXPECT_FALSE(odd.empty()); + EXPECT_EQ(odd.size(), 5u); + EXPECT_NE(odd.begin(), odd.end()); + + // Test direct access. + for (size_t i = 0; i < odd.size(); i++) { + EXPECT_EQ(odd[i], 2 * static_cast(i) + 1); + } + + // Test iterator. 
+ for (int val : odd) { + EXPECT_EQ(val % 2, 1); + } +} + } // namespace webrtc diff --git a/logging/rtc_event_log/rtc_event_log_unittest_helper.cc b/logging/rtc_event_log/rtc_event_log_unittest_helper.cc index 2896c130f2..0960c98502 100644 --- a/logging/rtc_event_log/rtc_event_log_unittest_helper.cc +++ b/logging/rtc_event_log/rtc_event_log_unittest_helper.cc @@ -338,6 +338,19 @@ rtcp::Pli EventGenerator::NewPli() { return pli; } +rtcp::Bye EventGenerator::NewBye() { + rtcp::Bye bye; + bye.SetSenderSsrc(prng_.Rand()); + std::vector csrcs{prng_.Rand(), prng_.Rand()}; + bye.SetCsrcs(csrcs); + if (prng_.Rand(0, 2)) { + bye.SetReason("foo"); + } else { + bye.SetReason("bar"); + } + return bye; +} + rtcp::TransportFeedback EventGenerator::NewTransportFeedback() { rtcp::TransportFeedback transport_feedback; uint16_t base_seq_no = prng_.Rand(); @@ -396,6 +409,7 @@ EventGenerator::NewRtcpPacketIncoming() { kPli, kNack, kRemb, + kBye, kTransportFeedback, kNumValues }; @@ -437,6 +451,11 @@ EventGenerator::NewRtcpPacketIncoming() { rtc::Buffer buffer = remb.Build(); return std::make_unique(buffer); } + case SupportedRtcpTypes::kBye: { + rtcp::Bye bye = NewBye(); + rtc::Buffer buffer = bye.Build(); + return std::make_unique(buffer); + } case SupportedRtcpTypes::kTransportFeedback: { rtcp::TransportFeedback transport_feedback = NewTransportFeedback(); rtc::Buffer buffer = transport_feedback.Build(); @@ -459,6 +478,7 @@ EventGenerator::NewRtcpPacketOutgoing() { kPli, kNack, kRemb, + kBye, kTransportFeedback, kNumValues }; @@ -500,6 +520,11 @@ EventGenerator::NewRtcpPacketOutgoing() { rtc::Buffer buffer = remb.Build(); return std::make_unique(buffer); } + case SupportedRtcpTypes::kBye: { + rtcp::Bye bye = NewBye(); + rtc::Buffer buffer = bye.Build(); + return std::make_unique(buffer); + } case SupportedRtcpTypes::kTransportFeedback: { rtcp::TransportFeedback transport_feedback = NewTransportFeedback(); rtc::Buffer buffer = transport_feedback.Build(); @@ -906,7 +931,8 @@ void EventVerifier::VerifyLoggedIceCandidatePairEvent( } } -void VerifyLoggedRtpHeader(const RtpPacket& original_header, +template +void VerifyLoggedRtpHeader(const Event& original_header, const RTPHeader& logged_header) { // Standard RTP header. EXPECT_EQ(original_header.Marker(), logged_header.markerBit); @@ -915,53 +941,57 @@ void VerifyLoggedRtpHeader(const RtpPacket& original_header, EXPECT_EQ(original_header.Timestamp(), logged_header.timestamp); EXPECT_EQ(original_header.Ssrc(), logged_header.ssrc); - EXPECT_EQ(original_header.headers_size(), logged_header.headerLength); + EXPECT_EQ(original_header.header_length(), logged_header.headerLength); // TransmissionOffset header extension. - ASSERT_EQ(original_header.HasExtension(), + ASSERT_EQ(original_header.template HasExtension(), logged_header.extension.hasTransmissionTimeOffset); if (logged_header.extension.hasTransmissionTimeOffset) { int32_t offset; - ASSERT_TRUE(original_header.GetExtension(&offset)); + ASSERT_TRUE( + original_header.template GetExtension(&offset)); EXPECT_EQ(offset, logged_header.extension.transmissionTimeOffset); } // AbsoluteSendTime header extension. 
- ASSERT_EQ(original_header.HasExtension(), + ASSERT_EQ(original_header.template HasExtension(), logged_header.extension.hasAbsoluteSendTime); if (logged_header.extension.hasAbsoluteSendTime) { uint32_t sendtime; - ASSERT_TRUE(original_header.GetExtension(&sendtime)); + ASSERT_TRUE( + original_header.template GetExtension(&sendtime)); EXPECT_EQ(sendtime, logged_header.extension.absoluteSendTime); } // TransportSequenceNumber header extension. - ASSERT_EQ(original_header.HasExtension(), + ASSERT_EQ(original_header.template HasExtension(), logged_header.extension.hasTransportSequenceNumber); if (logged_header.extension.hasTransportSequenceNumber) { uint16_t seqnum; - ASSERT_TRUE(original_header.GetExtension(&seqnum)); + ASSERT_TRUE(original_header.template GetExtension( + &seqnum)); EXPECT_EQ(seqnum, logged_header.extension.transportSequenceNumber); } // AudioLevel header extension. - ASSERT_EQ(original_header.HasExtension(), + ASSERT_EQ(original_header.template HasExtension(), logged_header.extension.hasAudioLevel); if (logged_header.extension.hasAudioLevel) { bool voice_activity; uint8_t audio_level; - ASSERT_TRUE(original_header.GetExtension(&voice_activity, - &audio_level)); + ASSERT_TRUE(original_header.template GetExtension( + &voice_activity, &audio_level)); EXPECT_EQ(voice_activity, logged_header.extension.voiceActivity); EXPECT_EQ(audio_level, logged_header.extension.audioLevel); } // VideoOrientation header extension. - ASSERT_EQ(original_header.HasExtension(), + ASSERT_EQ(original_header.template HasExtension(), logged_header.extension.hasVideoRotation); if (logged_header.extension.hasVideoRotation) { uint8_t rotation; - ASSERT_TRUE(original_header.GetExtension(&rotation)); + ASSERT_TRUE( + original_header.template GetExtension(&rotation)); EXPECT_EQ(ConvertCVOByteToVideoRotation(rotation), logged_header.extension.videoRotation); } @@ -990,8 +1020,7 @@ void EventVerifier::VerifyLoggedRtpPacketIncoming( const LoggedRtpPacketIncoming& logged_event) const { EXPECT_EQ(original_event.timestamp_ms(), logged_event.log_time_ms()); - EXPECT_EQ(original_event.header().headers_size(), - logged_event.rtp.header_length); + EXPECT_EQ(original_event.header_length(), logged_event.rtp.header_length); EXPECT_EQ(original_event.packet_length(), logged_event.rtp.total_length); @@ -1000,7 +1029,7 @@ void EventVerifier::VerifyLoggedRtpPacketIncoming( EXPECT_EQ(original_event.padding_length(), logged_event.rtp.header.paddingLength); - VerifyLoggedRtpHeader(original_event.header(), logged_event.rtp.header); + VerifyLoggedRtpHeader(original_event, logged_event.rtp.header); } void EventVerifier::VerifyLoggedRtpPacketOutgoing( @@ -1008,8 +1037,7 @@ void EventVerifier::VerifyLoggedRtpPacketOutgoing( const LoggedRtpPacketOutgoing& logged_event) const { EXPECT_EQ(original_event.timestamp_ms(), logged_event.log_time_ms()); - EXPECT_EQ(original_event.header().headers_size(), - logged_event.rtp.header_length); + EXPECT_EQ(original_event.header_length(), logged_event.rtp.header_length); EXPECT_EQ(original_event.packet_length(), logged_event.rtp.total_length); @@ -1021,7 +1049,7 @@ void EventVerifier::VerifyLoggedRtpPacketOutgoing( // TODO(terelius): Probe cluster ID isn't parsed, used or tested. Unless // someone has a strong reason to keep it, it'll be removed. 
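
One C++ detail worth noting in the templated VerifyLoggedRtpHeader() above: `original_header` now has a dependent type (the Event template parameter), so calls to the HasExtension/GetExtension member templates need the `template` disambiguator, hence the `.template GetExtension<...>()` spelling. A standalone sketch of the rule, using made-up types rather than WebRTC ones:

#include <cstdint>

// Hypothetical stand-ins, not WebRTC types.
struct FakePacket {
  template <typename Extension>
  bool GetExtension(int32_t* value) const {
    *value = 0;
    return true;
  }
};

struct FakeTransmissionOffset {};

template <typename Event>
int32_t ReadOffset(const Event& event) {
  int32_t offset = 0;
  // Without `template`, the `<` would parse as less-than, because
  // GetExtension is a dependent name inside this function template.
  event.template GetExtension<FakeTransmissionOffset>(&offset);
  return offset;
}

int32_t UseIt() {
  FakePacket packet;
  return ReadOffset(packet);  // Instantiates ReadOffset<FakePacket>.
}
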
- VerifyLoggedRtpHeader(original_event.header(), logged_event.rtp.header); + VerifyLoggedRtpHeader(original_event, logged_event.rtp.header); } void EventVerifier::VerifyLoggedGenericPacketSent( @@ -1096,10 +1124,10 @@ void EventVerifier::VerifyReportBlock( } void EventVerifier::VerifyLoggedSenderReport( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::SenderReport& original_sr, const LoggedRtcpPacketSenderReport& logged_sr) { - EXPECT_EQ(log_time_us, logged_sr.log_time_us()); + EXPECT_EQ(log_time_ms, logged_sr.log_time_ms()); EXPECT_EQ(original_sr.sender_ssrc(), logged_sr.sr.sender_ssrc()); EXPECT_EQ(original_sr.ntp(), logged_sr.sr.ntp()); EXPECT_EQ(original_sr.rtp_timestamp(), logged_sr.sr.rtp_timestamp()); @@ -1116,10 +1144,10 @@ void EventVerifier::VerifyLoggedSenderReport( } void EventVerifier::VerifyLoggedReceiverReport( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::ReceiverReport& original_rr, const LoggedRtcpPacketReceiverReport& logged_rr) { - EXPECT_EQ(log_time_us, logged_rr.log_time_us()); + EXPECT_EQ(log_time_ms, logged_rr.log_time_ms()); EXPECT_EQ(original_rr.sender_ssrc(), logged_rr.rr.sender_ssrc()); ASSERT_EQ(original_rr.report_blocks().size(), logged_rr.rr.report_blocks().size()); @@ -1130,9 +1158,10 @@ void EventVerifier::VerifyLoggedReceiverReport( } void EventVerifier::VerifyLoggedExtendedReports( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::ExtendedReports& original_xr, const LoggedRtcpPacketExtendedReports& logged_xr) { + EXPECT_EQ(log_time_ms, logged_xr.log_time_ms()); EXPECT_EQ(original_xr.sender_ssrc(), logged_xr.xr.sender_ssrc()); EXPECT_EQ(original_xr.rrtr().has_value(), logged_xr.xr.rrtr().has_value()); @@ -1170,11 +1199,11 @@ void EventVerifier::VerifyLoggedExtendedReports( } } -void EventVerifier::VerifyLoggedFir(int64_t log_time_us, +void EventVerifier::VerifyLoggedFir(int64_t log_time_ms, const rtcp::Fir& original_fir, const LoggedRtcpPacketFir& logged_fir) { + EXPECT_EQ(log_time_ms, logged_fir.log_time_ms()); EXPECT_EQ(original_fir.sender_ssrc(), logged_fir.fir.sender_ssrc()); - const auto& original_requests = original_fir.requests(); const auto& logged_requests = logged_fir.fir.requests(); ASSERT_EQ(original_requests.size(), logged_requests.size()); @@ -1184,25 +1213,35 @@ void EventVerifier::VerifyLoggedFir(int64_t log_time_us, } } -void EventVerifier::VerifyLoggedPli(int64_t log_time_us, +void EventVerifier::VerifyLoggedPli(int64_t log_time_ms, const rtcp::Pli& original_pli, const LoggedRtcpPacketPli& logged_pli) { + EXPECT_EQ(log_time_ms, logged_pli.log_time_ms()); EXPECT_EQ(original_pli.sender_ssrc(), logged_pli.pli.sender_ssrc()); EXPECT_EQ(original_pli.media_ssrc(), logged_pli.pli.media_ssrc()); } -void EventVerifier::VerifyLoggedNack(int64_t log_time_us, +void EventVerifier::VerifyLoggedBye(int64_t log_time_ms, + const rtcp::Bye& original_bye, + const LoggedRtcpPacketBye& logged_bye) { + EXPECT_EQ(log_time_ms, logged_bye.log_time_ms()); + EXPECT_EQ(original_bye.sender_ssrc(), logged_bye.bye.sender_ssrc()); + EXPECT_EQ(original_bye.csrcs(), logged_bye.bye.csrcs()); + EXPECT_EQ(original_bye.reason(), logged_bye.bye.reason()); +} + +void EventVerifier::VerifyLoggedNack(int64_t log_time_ms, const rtcp::Nack& original_nack, const LoggedRtcpPacketNack& logged_nack) { - EXPECT_EQ(log_time_us, logged_nack.log_time_us()); + EXPECT_EQ(log_time_ms, logged_nack.log_time_ms()); EXPECT_EQ(original_nack.packet_ids(), logged_nack.nack.packet_ids()); } void EventVerifier::VerifyLoggedTransportFeedback( - int64_t log_time_us, + 
int64_t log_time_ms, const rtcp::TransportFeedback& original_transport_feedback, const LoggedRtcpPacketTransportFeedback& logged_transport_feedback) { - EXPECT_EQ(log_time_us, logged_transport_feedback.log_time_us()); + EXPECT_EQ(log_time_ms, logged_transport_feedback.log_time_ms()); ASSERT_EQ( original_transport_feedback.GetReceivedPackets().size(), logged_transport_feedback.transport_feedback.GetReceivedPackets().size()); @@ -1219,19 +1258,19 @@ void EventVerifier::VerifyLoggedTransportFeedback( } } -void EventVerifier::VerifyLoggedRemb(int64_t log_time_us, +void EventVerifier::VerifyLoggedRemb(int64_t log_time_ms, const rtcp::Remb& original_remb, const LoggedRtcpPacketRemb& logged_remb) { - EXPECT_EQ(log_time_us, logged_remb.log_time_us()); + EXPECT_EQ(log_time_ms, logged_remb.log_time_ms()); EXPECT_EQ(original_remb.ssrcs(), logged_remb.remb.ssrcs()); EXPECT_EQ(original_remb.bitrate_bps(), logged_remb.remb.bitrate_bps()); } void EventVerifier::VerifyLoggedLossNotification( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::LossNotification& original_loss_notification, const LoggedRtcpPacketLossNotification& logged_loss_notification) { - EXPECT_EQ(log_time_us, logged_loss_notification.log_time_us()); + EXPECT_EQ(log_time_ms, logged_loss_notification.log_time_ms()); EXPECT_EQ(original_loss_notification.last_decoded(), logged_loss_notification.loss_notification.last_decoded()); EXPECT_EQ(original_loss_notification.last_received(), @@ -1246,7 +1285,7 @@ void EventVerifier::VerifyLoggedStartEvent( const LoggedStartEvent& logged_event) const { EXPECT_EQ(start_time_us / 1000, logged_event.log_time_ms()); if (encoding_type_ == RtcEventLog::EncodingType::NewFormat) { - EXPECT_EQ(utc_start_time_us / 1000, logged_event.utc_start_time_ms); + EXPECT_EQ(utc_start_time_us / 1000, logged_event.utc_start_time.ms()); } } diff --git a/logging/rtc_event_log/rtc_event_log_unittest_helper.h b/logging/rtc_event_log/rtc_event_log_unittest_helper.h index bf9fb573c1..eb16592271 100644 --- a/logging/rtc_event_log/rtc_event_log_unittest_helper.h +++ b/logging/rtc_event_log/rtc_event_log_unittest_helper.h @@ -45,6 +45,7 @@ #include "logging/rtc_event_log/rtc_event_log_parser.h" #include "logging/rtc_event_log/rtc_stream_config.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" +#include "modules/rtp_rtcp/source/rtcp_packet/bye.h" #include "modules/rtp_rtcp/source/rtcp_packet/extended_reports.h" #include "modules/rtp_rtcp/source/rtcp_packet/fir.h" #include "modules/rtp_rtcp/source/rtcp_packet/loss_notification.h" @@ -93,6 +94,7 @@ class EventGenerator { rtcp::Remb NewRemb(); rtcp::Fir NewFir(); rtcp::Pli NewPli(); + rtcp::Bye NewBye(); rtcp::TransportFeedback NewTransportFeedback(); rtcp::LossNotification NewLossNotification(); @@ -258,35 +260,38 @@ class EventVerifier { const RtcEventRtcpPacketOutgoing& original_event, const LoggedRtcpPacketOutgoing& logged_event) const; - void VerifyLoggedSenderReport(int64_t log_time_us, + void VerifyLoggedSenderReport(int64_t log_time_ms, const rtcp::SenderReport& original_sr, const LoggedRtcpPacketSenderReport& logged_sr); void VerifyLoggedReceiverReport( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::ReceiverReport& original_rr, const LoggedRtcpPacketReceiverReport& logged_rr); void VerifyLoggedExtendedReports( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::ExtendedReports& original_xr, const LoggedRtcpPacketExtendedReports& logged_xr); - void VerifyLoggedFir(int64_t log_time_us, + void VerifyLoggedFir(int64_t log_time_ms, const 
rtcp::Fir& original_fir, const LoggedRtcpPacketFir& logged_fir); - void VerifyLoggedPli(int64_t log_time_us, + void VerifyLoggedPli(int64_t log_time_ms, const rtcp::Pli& original_pli, const LoggedRtcpPacketPli& logged_pli); - void VerifyLoggedNack(int64_t log_time_us, + void VerifyLoggedBye(int64_t log_time_ms, + const rtcp::Bye& original_bye, + const LoggedRtcpPacketBye& logged_bye); + void VerifyLoggedNack(int64_t log_time_ms, const rtcp::Nack& original_nack, const LoggedRtcpPacketNack& logged_nack); void VerifyLoggedTransportFeedback( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::TransportFeedback& original_transport_feedback, const LoggedRtcpPacketTransportFeedback& logged_transport_feedback); - void VerifyLoggedRemb(int64_t log_time_us, + void VerifyLoggedRemb(int64_t log_time_ms, const rtcp::Remb& original_remb, const LoggedRtcpPacketRemb& logged_remb); void VerifyLoggedLossNotification( - int64_t log_time_us, + int64_t log_time_ms, const rtcp::LossNotification& original_loss_notification, const LoggedRtcpPacketLossNotification& logged_loss_notification); diff --git a/logging/rtc_event_log/rtc_event_processor_unittest.cc b/logging/rtc_event_log/rtc_event_processor_unittest.cc index 4ec5abee5e..b0cec25f1f 100644 --- a/logging/rtc_event_log/rtc_event_processor_unittest.cc +++ b/logging/rtc_event_log/rtc_event_processor_unittest.cc @@ -29,7 +29,7 @@ std::vector CreateEventList( std::initializer_list timestamp_list) { std::vector v; for (int64_t timestamp_ms : timestamp_list) { - v.emplace_back(timestamp_ms * 1000); // Convert ms to us. + v.emplace_back(Timestamp::Millis(timestamp_ms)); } return v; } @@ -41,7 +41,7 @@ CreateRandomEventLists(size_t num_lists, size_t num_elements, uint64_t seed) { for (size_t elem = 0; elem < num_elements; elem++) { uint32_t i = prng.Rand(0u, num_lists - 1); int64_t timestamp_ms = elem; - lists[i].emplace_back(timestamp_ms * 1000); + lists[i].emplace_back(Timestamp::Millis(timestamp_ms)); } return lists; } @@ -146,8 +146,8 @@ TEST(RtcEventProcessor, DifferentTypes) { result.push_back(elem.log_time_ms()); }; - std::vector events1{LoggedStartEvent(2000)}; - std::vector events2{LoggedStopEvent(1000)}; + std::vector events1{LoggedStartEvent(Timestamp::Millis(2))}; + std::vector events2{LoggedStopEvent(Timestamp::Millis(1))}; RtcEventProcessor processor; processor.AddEvents(events1, f1); processor.AddEvents(events2, f2); diff --git a/media/BUILD.gn b/media/BUILD.gn index f653af7a61..5f0f527b8f 100644 --- a/media/BUILD.gn +++ b/media/BUILD.gn @@ -23,20 +23,15 @@ config("rtc_media_defines_config") { defines = [ "HAVE_WEBRTC_VIDEO" ] } -rtc_library("rtc_h264_profile_id") { +# Remove once downstream projects stop depend on this. +rtc_source_set("rtc_h264_profile_id") { visibility = [ "*" ] sources = [ "base/h264_profile_level_id.cc", "base/h264_profile_level_id.h", ] - - deps = [ - "../rtc_base", - "../rtc_base:checks", - "../rtc_base:rtc_base_approved", - "../rtc_base/system:rtc_export", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + public_deps = # no-presubmit-check TODO(webrtc:8603) + [ "../api/video_codecs:video_codecs_api" ] } rtc_source_set("rtc_media_config") { @@ -44,30 +39,24 @@ rtc_source_set("rtc_media_config") { sources = [ "base/media_config.h" ] } -rtc_library("rtc_vp9_profile") { +# Remove once downstream projects stop depend on this. 
+rtc_source_set("rtc_vp9_profile") { visibility = [ "*" ] - sources = [ - "base/vp9_profile.cc", - "base/vp9_profile.h", - ] - - deps = [ - "../api/video_codecs:video_codecs_api", - "../rtc_base:rtc_base_approved", - "../rtc_base/system:rtc_export", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ "base/vp9_profile.h" ] + public_deps = # no-presubmit-check TODO(webrtc:8603) + [ "../api/video_codecs:video_codecs_api" ] } -rtc_library("rtc_sdp_fmtp_utils") { +rtc_library("rtc_sdp_video_format_utils") { visibility = [ "*" ] sources = [ - "base/sdp_fmtp_utils.cc", - "base/sdp_fmtp_utils.h", + "base/sdp_video_format_utils.cc", + "base/sdp_video_format_utils.h", ] deps = [ "../api/video_codecs:video_codecs_api", + "../rtc_base:checks", "../rtc_base:stringutils", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -78,9 +67,7 @@ rtc_library("rtc_media_base") { defines = [] libs = [] deps = [ - ":rtc_h264_profile_id", ":rtc_media_config", - ":rtc_vp9_profile", "../api:array_view", "../api:audio_options_api", "../api:frame_transformer_interface", @@ -88,11 +75,13 @@ rtc_library("rtc_media_base") { "../api:rtc_error", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api/audio:audio_frame_processor", "../api/audio_codecs:audio_codecs_api", "../api/crypto:frame_decryptor_interface", "../api/crypto:frame_encryptor_interface", "../api/crypto:options", + "../api/transport:datagram_transport_interface", "../api/transport:stun_types", "../api/transport:webrtc_key_value_config", "../api/transport/rtp:rtp_source", @@ -113,11 +102,14 @@ rtc_library("rtc_media_base") { "../rtc_base:rtc_base_approved", "../rtc_base:rtc_task_queue", "../rtc_base:sanitizer", + "../rtc_base:socket", "../rtc_base:stringutils", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:file_wrapper", + "../rtc_base/system:no_unique_address", "../rtc_base/system:rtc_export", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/sigslot", "../system_wrappers:field_trial", ] @@ -141,8 +133,6 @@ rtc_library("rtc_media_base") { "base/media_engine.h", "base/rid_description.cc", "base/rid_description.h", - "base/rtp_data_engine.cc", - "base/rtp_data_engine.h", "base/rtp_utils.cc", "base/rtp_utils.h", "base/stream_params.cc", @@ -160,16 +150,6 @@ rtc_library("rtc_media_base") { ] } -rtc_library("rtc_constants") { - defines = [] - libs = [] - deps = [] - sources = [ - "engine/constants.cc", - "engine/constants.h", - ] -} - rtc_library("rtc_simulcast_encoder_adapter") { visibility = [ "*" ] defines = [] @@ -182,6 +162,7 @@ rtc_library("rtc_simulcast_encoder_adapter") { ":rtc_media_base", "../api:fec_controller_api", "../api:scoped_refptr", + "../api:sequence_checker", "../api/video:video_codec_constants", "../api/video:video_frame", "../api/video:video_rtp_headers", @@ -192,14 +173,17 @@ rtc_library("rtc_simulcast_encoder_adapter") { "../modules/video_coding:video_coding_utility", "../rtc_base:checks", "../rtc_base:rtc_base_approved", + "../rtc_base/experiments:encoder_info_settings", "../rtc_base/experiments:rate_control_settings", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", "../rtc_base/system:rtc_export", "../system_wrappers", "../system_wrappers:field_trial", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + 
"//third_party/abseil-cpp/absl/types:optional", + ] } rtc_library("rtc_encoder_simulcast_proxy") { @@ -227,9 +211,7 @@ rtc_library("rtc_internal_video_codecs") { defines = [] libs = [] deps = [ - ":rtc_constants", ":rtc_encoder_simulcast_proxy", - ":rtc_h264_profile_id", ":rtc_media_base", ":rtc_simulcast_encoder_adapter", "../api/video:encoded_image", @@ -248,7 +230,6 @@ rtc_library("rtc_internal_video_codecs") { "../modules/video_coding/codecs/av1:libaom_av1_decoder", "../modules/video_coding/codecs/av1:libaom_av1_encoder", "../rtc_base:checks", - "../rtc_base:deprecation", "../rtc_base:rtc_base_approved", "../rtc_base/system:rtc_export", "../test:fake_video_codecs", @@ -276,13 +257,13 @@ rtc_library("rtc_audio_video") { defines = [] libs = [] deps = [ - ":rtc_constants", ":rtc_media_base", "../api:call_api", "../api:libjingle_peerconnection_api", "../api:media_stream_interface", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api:transport_api", "../api/audio:audio_frame_processor", "../api/audio:audio_mixer_api", @@ -320,12 +301,15 @@ rtc_library("rtc_audio_video") { "../rtc_base:ignore_wundef", "../rtc_base:rtc_task_queue", "../rtc_base:stringutils", + "../rtc_base:threading", "../rtc_base/experiments:field_trial_parser", "../rtc_base/experiments:min_video_bitrate_experiment", "../rtc_base/experiments:normalize_simulcast_size_experiment", "../rtc_base/experiments:rate_control_settings", "../rtc_base/synchronization:mutex", "../rtc_base/system:rtc_export", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/base64", "../system_wrappers", "../system_wrappers:metrics", @@ -396,55 +380,116 @@ rtc_library("rtc_media_engine_defaults") { ] } -rtc_library("rtc_data") { - defines = [ - # "SCTP_DEBUG" # Uncomment for SCTP debugging. 
- ] +rtc_source_set("rtc_data_sctp_transport_internal") { + sources = [ "sctp/sctp_transport_internal.h" ] deps = [ - ":rtc_media_base", - "../api:call_api", - "../api:transport_api", + "../api/transport:datagram_transport_interface", + "../media:rtc_media_base", "../p2p:rtc_p2p", - "../rtc_base", "../rtc_base:rtc_base_approved", - "../rtc_base/synchronization:mutex", - "../rtc_base/task_utils:pending_task_safety_flag", - "../rtc_base/task_utils:to_queued_task", + "../rtc_base:threading", "../rtc_base/third_party/sigslot", - "../system_wrappers", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/base:core_headers", - "//third_party/abseil-cpp/absl/types:optional", ] +} - if (rtc_enable_sctp) { +if (rtc_build_dcsctp) { + rtc_library("rtc_data_dcsctp_transport") { sources = [ - "sctp/sctp_transport.cc", - "sctp/sctp_transport.h", - "sctp/sctp_transport_internal.h", + "sctp/dcsctp_transport.cc", + "sctp/dcsctp_transport.h", + ] + deps = [ + ":rtc_data_sctp_transport_internal", + "../api:array_view", + "../media:rtc_media_base", + "../net/dcsctp/public:factory", + "../net/dcsctp/public:socket", + "../net/dcsctp/public:types", + "../net/dcsctp/public:utils", + "../net/dcsctp/timer:task_queue_timeout", + "../p2p:rtc_p2p", + "../rtc_base:checks", + "../rtc_base:rtc_base_approved", + "../rtc_base:threading", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", + "../rtc_base/third_party/sigslot:sigslot", + "../system_wrappers", + ] + absl_deps += [ + "//third_party/abseil-cpp/absl/strings:strings", + "//third_party/abseil-cpp/absl/types:optional", ] - } else { - # libtool on mac does not like empty targets. - sources = [ "sctp/noop.cc" ] } +} - if (rtc_enable_sctp && rtc_build_usrsctp) { - deps += [ - "../api/transport:sctp_transport_factory_interface", +if (rtc_build_usrsctp) { + rtc_library("rtc_data_usrsctp_transport") { + defines = [ + # "SCTP_DEBUG" # Uncomment for SCTP debugging. + ] + sources = [ + "sctp/usrsctp_transport.cc", + "sctp/usrsctp_transport.h", + ] + deps = [ + ":rtc_data_sctp_transport_internal", + "../media:rtc_media_base", + "../p2p:rtc_p2p", + "../rtc_base", + "../rtc_base:rtc_base_approved", + "../rtc_base:threading", + "../rtc_base/synchronization:mutex", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", + "../rtc_base/third_party/sigslot:sigslot", "//third_party/usrsctp", ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/types:optional", + ] + } +} + +rtc_library("rtc_data_sctp_transport_factory") { + defines = [] + sources = [ + "sctp/sctp_transport_factory.cc", + "sctp/sctp_transport_factory.h", + ] + deps = [ + ":rtc_data_sctp_transport_internal", + "../api/transport:sctp_transport_factory_interface", + "../rtc_base:threading", + "../rtc_base/experiments:field_trial_parser", + "../rtc_base/system:unused", + ] + + if (rtc_enable_sctp) { + assert(rtc_build_dcsctp || rtc_build_usrsctp, + "An SCTP backend is required to enable SCTP") + } + + if (rtc_build_dcsctp) { + defines += [ "WEBRTC_HAVE_DCSCTP" ] + deps += [ + ":rtc_data_dcsctp_transport", + "../system_wrappers", + "../system_wrappers:field_trial", + ] + } + + if (rtc_build_usrsctp) { + defines += [ "WEBRTC_HAVE_USRSCTP" ] + deps += [ ":rtc_data_usrsctp_transport" ] } } rtc_source_set("rtc_media") { visibility = [ "*" ] allow_poison = [ "audio_codecs" ] # TODO(bugs.webrtc.org/8396): Remove. 
- deps = [ - ":rtc_audio_video", - ":rtc_data", - ] + deps = [ ":rtc_audio_video" ] } if (rtc_include_tests) { @@ -483,6 +528,7 @@ if (rtc_include_tests) { "../rtc_base:rtc_base_approved", "../rtc_base:rtc_task_queue", "../rtc_base:stringutils", + "../rtc_base:threading", "../rtc_base/synchronization:mutex", "../rtc_base/third_party/sigslot", "../test:test_support", @@ -511,157 +557,164 @@ if (rtc_include_tests) { ] } - rtc_media_unittests_resources = [ - "../resources/media/captured-320x240-2s-48.frames", - "../resources/media/faces.1280x720_P420.yuv", - "../resources/media/faces_I400.jpg", - "../resources/media/faces_I411.jpg", - "../resources/media/faces_I420.jpg", - "../resources/media/faces_I422.jpg", - "../resources/media/faces_I444.jpg", - ] - - if (is_ios) { - bundle_data("rtc_media_unittests_bundle_data") { - testonly = true - sources = rtc_media_unittests_resources - outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] - } - } - - rtc_test("rtc_media_unittests") { - testonly = true - - defines = [] - deps = [ - ":rtc_audio_video", - ":rtc_constants", - ":rtc_data", - ":rtc_encoder_simulcast_proxy", - ":rtc_internal_video_codecs", - ":rtc_media", - ":rtc_media_base", - ":rtc_media_engine_defaults", - ":rtc_media_tests_utils", - ":rtc_sdp_fmtp_utils", - ":rtc_simulcast_encoder_adapter", - ":rtc_vp9_profile", - "../api:create_simulcast_test_fixture_api", - "../api:libjingle_peerconnection_api", - "../api:mock_video_bitrate_allocator", - "../api:mock_video_bitrate_allocator_factory", - "../api:mock_video_codec_factory", - "../api:mock_video_encoder", - "../api:rtp_parameters", - "../api:scoped_refptr", - "../api:simulcast_test_fixture_api", - "../api/audio_codecs:builtin_audio_decoder_factory", - "../api/audio_codecs:builtin_audio_encoder_factory", - "../api/rtc_event_log", - "../api/task_queue", - "../api/task_queue:default_task_queue_factory", - "../api/test/video:function_video_factory", - "../api/transport:field_trial_based_config", - "../api/units:time_delta", - "../api/video:builtin_video_bitrate_allocator_factory", - "../api/video:video_bitrate_allocation", - "../api/video:video_frame", - "../api/video:video_rtp_headers", - "../api/video_codecs:builtin_video_decoder_factory", - "../api/video_codecs:builtin_video_encoder_factory", - "../api/video_codecs:video_codecs_api", - "../audio", - "../call:call_interfaces", - "../common_video", - "../media:rtc_h264_profile_id", - "../modules/audio_device:mock_audio_device", - "../modules/audio_processing", - "../modules/audio_processing:api", - "../modules/audio_processing:mocks", - "../modules/rtp_rtcp", - "../modules/video_coding:simulcast_test_fixture_impl", - "../modules/video_coding:video_codec_interface", - "../modules/video_coding:webrtc_h264", - "../modules/video_coding:webrtc_vp8", - "../modules/video_coding/codecs/av1:libaom_av1_decoder", - "../p2p:p2p_test_utils", - "../rtc_base", - "../rtc_base:checks", - "../rtc_base:gunit_helpers", - "../rtc_base:rtc_base_approved", - "../rtc_base:rtc_base_tests_utils", - "../rtc_base:rtc_task_queue", - "../rtc_base:stringutils", - "../rtc_base/experiments:min_video_bitrate_experiment", - "../rtc_base/synchronization:mutex", - "../rtc_base/third_party/sigslot", - "../test:audio_codec_mocks", - "../test:fake_video_codecs", - "../test:field_trial", - "../test:rtp_test_utils", - "../test:test_main", - "../test:test_support", - "../test:video_test_common", - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/memory", - 
"//third_party/abseil-cpp/absl/strings", - "//third_party/abseil-cpp/absl/types:optional", - ] - sources = [ - "base/codec_unittest.cc", - "base/media_engine_unittest.cc", - "base/rtp_data_engine_unittest.cc", - "base/rtp_utils_unittest.cc", - "base/sdp_fmtp_utils_unittest.cc", - "base/stream_params_unittest.cc", - "base/turn_utils_unittest.cc", - "base/video_adapter_unittest.cc", - "base/video_broadcaster_unittest.cc", - "base/video_common_unittest.cc", - "engine/encoder_simulcast_proxy_unittest.cc", - "engine/internal_decoder_factory_unittest.cc", - "engine/multiplex_codec_factory_unittest.cc", - "engine/null_webrtc_video_engine_unittest.cc", - "engine/payload_type_mapper_unittest.cc", - "engine/simulcast_encoder_adapter_unittest.cc", - "engine/simulcast_unittest.cc", - "engine/unhandled_packets_buffer_unittest.cc", - "engine/webrtc_media_engine_unittest.cc", - "engine/webrtc_video_engine_unittest.cc", + if (!build_with_chromium) { + rtc_media_unittests_resources = [ + "../resources/media/captured-320x240-2s-48.frames", + "../resources/media/faces.1280x720_P420.yuv", + "../resources/media/faces_I400.jpg", + "../resources/media/faces_I411.jpg", + "../resources/media/faces_I420.jpg", + "../resources/media/faces_I422.jpg", + "../resources/media/faces_I444.jpg", ] - # TODO(kthelgason): Reenable this test on iOS. - # See bugs.webrtc.org/5569 - if (!is_ios) { - sources += [ "engine/webrtc_voice_engine_unittest.cc" ] - } - - if (rtc_enable_sctp) { - sources += [ - "sctp/sctp_transport_reliability_unittest.cc", - "sctp/sctp_transport_unittest.cc", - ] - } - - if (rtc_opus_support_120ms_ptime) { - defines += [ "WEBRTC_OPUS_SUPPORT_120MS_PTIME=1" ] - } else { - defines += [ "WEBRTC_OPUS_SUPPORT_120MS_PTIME=0" ] + if (is_ios) { + bundle_data("rtc_media_unittests_bundle_data") { + testonly = true + sources = rtc_media_unittests_resources + outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + } } - data = rtc_media_unittests_resources - - if (is_android) { - deps += [ "//testing/android/native_test:native_test_support" ] - shard_timeout = 900 - } + rtc_test("rtc_media_unittests") { + testonly = true - if (is_ios) { - deps += [ ":rtc_media_unittests_bundle_data" ] - } + defines = [] + deps = [ + ":rtc_audio_video", + ":rtc_encoder_simulcast_proxy", + ":rtc_internal_video_codecs", + ":rtc_media", + ":rtc_media_base", + ":rtc_media_engine_defaults", + ":rtc_media_tests_utils", + ":rtc_sdp_video_format_utils", + ":rtc_simulcast_encoder_adapter", + "../api:create_simulcast_test_fixture_api", + "../api:libjingle_peerconnection_api", + "../api:mock_video_bitrate_allocator", + "../api:mock_video_bitrate_allocator_factory", + "../api:mock_video_codec_factory", + "../api:mock_video_encoder", + "../api:rtp_parameters", + "../api:scoped_refptr", + "../api:simulcast_test_fixture_api", + "../api/audio_codecs:builtin_audio_decoder_factory", + "../api/audio_codecs:builtin_audio_encoder_factory", + "../api/rtc_event_log", + "../api/task_queue", + "../api/task_queue:default_task_queue_factory", + "../api/test/video:function_video_factory", + "../api/transport:field_trial_based_config", + "../api/units:time_delta", + "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_bitrate_allocation", + "../api/video:video_codec_constants", + "../api/video:video_frame", + "../api/video:video_rtp_headers", + "../api/video_codecs:builtin_video_decoder_factory", + "../api/video_codecs:builtin_video_encoder_factory", + "../api/video_codecs:video_codecs_api", + "../audio", + 
"../call:call_interfaces", + "../common_video", + "../modules/audio_device:mock_audio_device", + "../modules/audio_processing", + "../modules/audio_processing:api", + "../modules/audio_processing:mocks", + "../modules/rtp_rtcp", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/video_coding:simulcast_test_fixture_impl", + "../modules/video_coding:video_codec_interface", + "../modules/video_coding:webrtc_h264", + "../modules/video_coding:webrtc_vp8", + "../modules/video_coding/codecs/av1:libaom_av1_decoder", + "../p2p:p2p_test_utils", + "../rtc_base", + "../rtc_base:checks", + "../rtc_base:gunit_helpers", + "../rtc_base:rtc_base_approved", + "../rtc_base:rtc_base_tests_utils", + "../rtc_base:rtc_task_queue", + "../rtc_base:stringutils", + "../rtc_base:threading", + "../rtc_base/experiments:min_video_bitrate_experiment", + "../rtc_base/synchronization:mutex", + "../rtc_base/third_party/sigslot", + "../system_wrappers:field_trial", + "../test:audio_codec_mocks", + "../test:fake_video_codecs", + "../test:field_trial", + "../test:rtp_test_utils", + "../test:test_main", + "../test:test_support", + "../test:video_test_common", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + sources = [ + "base/codec_unittest.cc", + "base/media_engine_unittest.cc", + "base/rtp_utils_unittest.cc", + "base/sdp_video_format_utils_unittest.cc", + "base/stream_params_unittest.cc", + "base/turn_utils_unittest.cc", + "base/video_adapter_unittest.cc", + "base/video_broadcaster_unittest.cc", + "base/video_common_unittest.cc", + "engine/encoder_simulcast_proxy_unittest.cc", + "engine/internal_decoder_factory_unittest.cc", + "engine/multiplex_codec_factory_unittest.cc", + "engine/null_webrtc_video_engine_unittest.cc", + "engine/payload_type_mapper_unittest.cc", + "engine/simulcast_encoder_adapter_unittest.cc", + "engine/simulcast_unittest.cc", + "engine/unhandled_packets_buffer_unittest.cc", + "engine/webrtc_media_engine_unittest.cc", + "engine/webrtc_video_engine_unittest.cc", + ] - if (rtc_enable_sctp && rtc_build_usrsctp) { - deps += [ "//third_party/usrsctp" ] + # TODO(kthelgason): Reenable this test on iOS. 
+ # See bugs.webrtc.org/5569 + if (!is_ios) { + sources += [ "engine/webrtc_voice_engine_unittest.cc" ] + } + + if (rtc_build_usrsctp) { + sources += [ + "sctp/usrsctp_transport_reliability_unittest.cc", + "sctp/usrsctp_transport_unittest.cc", + ] + deps += [ + ":rtc_data_sctp_transport_internal", + ":rtc_data_usrsctp_transport", + "../rtc_base:rtc_event", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", + "//third_party/usrsctp", + ] + } + + if (rtc_opus_support_120ms_ptime) { + defines += [ "WEBRTC_OPUS_SUPPORT_120MS_PTIME=1" ] + } else { + defines += [ "WEBRTC_OPUS_SUPPORT_120MS_PTIME=0" ] + } + + data = rtc_media_unittests_resources + + if (is_android) { + deps += [ "//testing/android/native_test:native_test_support" ] + shard_timeout = 900 + } + + if (is_ios) { + deps += [ ":rtc_media_unittests_bundle_data" ] + } } } } diff --git a/media/DEPS b/media/DEPS index 5b4d9f93b5..127e3ef081 100644 --- a/media/DEPS +++ b/media/DEPS @@ -11,6 +11,7 @@ include_rules = [ "+modules/video_capture", "+modules/video_coding", "+modules/video_coding/utility", + "+net/dcsctp", "+p2p", "+sound", "+system_wrappers", diff --git a/media/base/codec.cc b/media/base/codec.cc index ab39592e24..cb6913e76a 100644 --- a/media/base/codec.cc +++ b/media/base/codec.cc @@ -12,8 +12,8 @@ #include "absl/algorithm/container.h" #include "absl/strings/match.h" -#include "media/base/h264_profile_level_id.h" -#include "media/base/vp9_profile.h" +#include "api/video_codecs/h264_profile_level_id.h" +#include "api/video_codecs/vp9_profile.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/string_encode.h" @@ -51,25 +51,13 @@ bool IsSameCodecSpecific(const std::string& name1, absl::EqualsIgnoreCase(name, name2); }; if (either_name_matches(kH264CodecName)) - return webrtc::H264::IsSameH264Profile(params1, params2) && + return webrtc::H264IsSameProfile(params1, params2) && IsSameH264PacketizationMode(params1, params2); if (either_name_matches(kVp9CodecName)) - return webrtc::IsSameVP9Profile(params1, params2); + return webrtc::VP9IsSameProfile(params1, params2); return true; } -bool IsCodecInList( - const webrtc::SdpVideoFormat& format, - const std::vector& existing_formats) { - for (auto existing_format : existing_formats) { - if (IsSameCodec(format.name, format.parameters, existing_format.name, - existing_format.parameters)) { - return true; - } - } - return false; -} - } // namespace FeedbackParams::FeedbackParams() = default; @@ -396,25 +384,6 @@ bool VideoCodec::ValidateCodecFormat() const { return true; } -RtpDataCodec::RtpDataCodec(int id, const std::string& name) - : Codec(id, name, kDataCodecClockrate) {} - -RtpDataCodec::RtpDataCodec() : Codec() { - clockrate = kDataCodecClockrate; -} - -RtpDataCodec::RtpDataCodec(const RtpDataCodec& c) = default; -RtpDataCodec::RtpDataCodec(RtpDataCodec&& c) = default; -RtpDataCodec& RtpDataCodec::operator=(const RtpDataCodec& c) = default; -RtpDataCodec& RtpDataCodec::operator=(RtpDataCodec&& c) = default; - -std::string RtpDataCodec::ToString() const { - char buf[256]; - rtc::SimpleStringBuilder sb(buf); - sb << "RtpDataCodec[" << id << ":" << name << "]"; - return sb.str(); -} - bool HasLntf(const Codec& codec) { return codec.HasFeedbackParam( FeedbackParam(kRtcpFbParamLntf, kParamValueEmpty)); @@ -452,6 +421,8 @@ const VideoCodec* FindMatchingCodec( return nullptr; } +// TODO(crbug.com/1187565): Remove once downstream projects stopped using this +// method in favor of SdpVideoFormat::IsSameCodec(). 
bool IsSameCodec(const std::string& name1, const CodecParameterMap& params1, const std::string& name2, @@ -473,15 +444,16 @@ void AddH264ConstrainedBaselineProfileToSupportedFormats( for (auto it = supported_formats->cbegin(); it != supported_formats->cend(); ++it) { if (it->name == cricket::kH264CodecName) { - const absl::optional profile_level_id = - webrtc::H264::ParseSdpProfileLevelId(it->parameters); - if (profile_level_id && profile_level_id->profile != - webrtc::H264::kProfileConstrainedBaseline) { + const absl::optional profile_level_id = + webrtc::ParseSdpForH264ProfileLevelId(it->parameters); + if (profile_level_id && + profile_level_id->profile != + webrtc::H264Profile::kProfileConstrainedBaseline) { webrtc::SdpVideoFormat cbp_format = *it; - webrtc::H264::ProfileLevelId cbp_profile = *profile_level_id; - cbp_profile.profile = webrtc::H264::kProfileConstrainedBaseline; + webrtc::H264ProfileLevelId cbp_profile = *profile_level_id; + cbp_profile.profile = webrtc::H264Profile::kProfileConstrainedBaseline; cbp_format.parameters[cricket::kH264FmtpProfileLevelId] = - *webrtc::H264::ProfileLevelIdToString(cbp_profile); + *webrtc::H264ProfileLevelIdToString(cbp_profile); cbr_supported_formats.push_back(cbp_format); } } @@ -492,7 +464,7 @@ void AddH264ConstrainedBaselineProfileToSupportedFormats( std::copy_if(cbr_supported_formats.begin(), cbr_supported_formats.end(), std::back_inserter(*supported_formats), [supported_formats](const webrtc::SdpVideoFormat& format) { - return !IsCodecInList(format, *supported_formats); + return !format.IsCodecInList(*supported_formats); }); if (supported_formats->size() > original_size) { diff --git a/media/base/codec.h b/media/base/codec.h index c3be2334ce..c7c99bf732 100644 --- a/media/base/codec.h +++ b/media/base/codec.h @@ -202,23 +202,6 @@ struct RTC_EXPORT VideoCodec : public Codec { void SetDefaultParameters(); }; -struct RtpDataCodec : public Codec { - RtpDataCodec(int id, const std::string& name); - RtpDataCodec(); - RtpDataCodec(const RtpDataCodec& c); - RtpDataCodec(RtpDataCodec&& c); - ~RtpDataCodec() override = default; - - RtpDataCodec& operator=(const RtpDataCodec& c); - RtpDataCodec& operator=(RtpDataCodec&& c); - - std::string ToString() const; -}; - -// For backwards compatibility -// TODO(bugs.webrtc.org/10597): Remove when no longer needed. -typedef RtpDataCodec DataCodec; - // Get the codec setting associated with |payload_type|. If there // is no codec associated with that payload type it returns nullptr. 
template diff --git a/media/base/codec_unittest.cc b/media/base/codec_unittest.cc index d41ed9bdf9..23bae7b7fe 100644 --- a/media/base/codec_unittest.cc +++ b/media/base/codec_unittest.cc @@ -12,14 +12,13 @@ #include -#include "media/base/h264_profile_level_id.h" -#include "media/base/vp9_profile.h" +#include "api/video_codecs/h264_profile_level_id.h" +#include "api/video_codecs/vp9_profile.h" #include "modules/video_coding/codecs/h264/include/h264.h" #include "rtc_base/gunit.h" using cricket::AudioCodec; using cricket::Codec; -using cricket::DataCodec; using cricket::FeedbackParam; using cricket::kCodecParamAssociatedPayloadType; using cricket::kCodecParamMaxBitrate; @@ -31,7 +30,8 @@ class TestCodec : public Codec { TestCodec(int id, const std::string& name, int clockrate) : Codec(id, name, clockrate) {} TestCodec() : Codec() {} - TestCodec(const TestCodec& c) : Codec(c) {} + TestCodec(const TestCodec& c) = default; + TestCodec& operator=(const TestCodec& c) = default; }; TEST(CodecTest, TestCodecOperators) { @@ -303,27 +303,6 @@ TEST(CodecTest, TestH264CodecMatches) { } } -TEST(CodecTest, TestDataCodecMatches) { - // Test a codec with a static payload type. - DataCodec c0(34, "D"); - EXPECT_TRUE(c0.Matches(DataCodec(34, ""))); - EXPECT_FALSE(c0.Matches(DataCodec(96, "D"))); - EXPECT_FALSE(c0.Matches(DataCodec(96, ""))); - - // Test a codec with a dynamic payload type. - DataCodec c1(96, "D"); - EXPECT_TRUE(c1.Matches(DataCodec(96, "D"))); - EXPECT_TRUE(c1.Matches(DataCodec(97, "D"))); - EXPECT_TRUE(c1.Matches(DataCodec(96, "d"))); - EXPECT_TRUE(c1.Matches(DataCodec(97, "d"))); - EXPECT_TRUE(c1.Matches(DataCodec(35, "d"))); - EXPECT_TRUE(c1.Matches(DataCodec(42, "d"))); - EXPECT_TRUE(c1.Matches(DataCodec(63, "d"))); - EXPECT_FALSE(c1.Matches(DataCodec(96, ""))); - EXPECT_FALSE(c1.Matches(DataCodec(95, "D"))); - EXPECT_FALSE(c1.Matches(DataCodec(34, "D"))); -} - TEST(CodecTest, TestSetParamGetParamAndRemoveParam) { AudioCodec codec; codec.SetParam("a", "1"); @@ -457,10 +436,10 @@ TEST(CodecTest, TestToCodecParameters) { TEST(CodecTest, H264CostrainedBaselineIsAddedIfH264IsSupported) { const std::vector kExplicitlySupportedFormats = { - webrtc::CreateH264Format(webrtc::H264::kProfileBaseline, - webrtc::H264::kLevel3_1, "1"), - webrtc::CreateH264Format(webrtc::H264::kProfileBaseline, - webrtc::H264::kLevel3_1, "0")}; + webrtc::CreateH264Format(webrtc::H264Profile::kProfileBaseline, + webrtc::H264Level::kLevel3_1, "1"), + webrtc::CreateH264Format(webrtc::H264Profile::kProfileBaseline, + webrtc::H264Level::kLevel3_1, "0")}; std::vector supported_formats = kExplicitlySupportedFormats; @@ -468,11 +447,11 @@ TEST(CodecTest, H264CostrainedBaselineIsAddedIfH264IsSupported) { &supported_formats); const webrtc::SdpVideoFormat kH264ConstrainedBasedlinePacketization1 = - webrtc::CreateH264Format(webrtc::H264::kProfileConstrainedBaseline, - webrtc::H264::kLevel3_1, "1"); + webrtc::CreateH264Format(webrtc::H264Profile::kProfileConstrainedBaseline, + webrtc::H264Level::kLevel3_1, "1"); const webrtc::SdpVideoFormat kH264ConstrainedBasedlinePacketization0 = - webrtc::CreateH264Format(webrtc::H264::kProfileConstrainedBaseline, - webrtc::H264::kLevel3_1, "0"); + webrtc::CreateH264Format(webrtc::H264Profile::kProfileConstrainedBaseline, + webrtc::H264Level::kLevel3_1, "0"); EXPECT_EQ(supported_formats[0], kExplicitlySupportedFormats[0]); EXPECT_EQ(supported_formats[1], kExplicitlySupportedFormats[1]); @@ -497,14 +476,14 @@ TEST(CodecTest, H264CostrainedBaselineIsNotAddedIfH264IsUnsupported) { TEST(CodecTest, 
H264CostrainedBaselineNotAddedIfAlreadySpecified) { const std::vector kExplicitlySupportedFormats = { - webrtc::CreateH264Format(webrtc::H264::kProfileBaseline, - webrtc::H264::kLevel3_1, "1"), - webrtc::CreateH264Format(webrtc::H264::kProfileBaseline, - webrtc::H264::kLevel3_1, "0"), - webrtc::CreateH264Format(webrtc::H264::kProfileConstrainedBaseline, - webrtc::H264::kLevel3_1, "1"), - webrtc::CreateH264Format(webrtc::H264::kProfileConstrainedBaseline, - webrtc::H264::kLevel3_1, "0")}; + webrtc::CreateH264Format(webrtc::H264Profile::kProfileBaseline, + webrtc::H264Level::kLevel3_1, "1"), + webrtc::CreateH264Format(webrtc::H264Profile::kProfileBaseline, + webrtc::H264Level::kLevel3_1, "0"), + webrtc::CreateH264Format(webrtc::H264Profile::kProfileConstrainedBaseline, + webrtc::H264Level::kLevel3_1, "1"), + webrtc::CreateH264Format(webrtc::H264Profile::kProfileConstrainedBaseline, + webrtc::H264Level::kLevel3_1, "0")}; std::vector supported_formats = kExplicitlySupportedFormats; diff --git a/media/base/fake_media_engine.cc b/media/base/fake_media_engine.cc index 734a30be75..aa8e2325b6 100644 --- a/media/base/fake_media_engine.cc +++ b/media/base/fake_media_engine.cc @@ -18,6 +18,7 @@ #include "rtc_base/checks.h" namespace cricket { +using webrtc::TaskQueueBase; FakeVoiceMediaChannel::DtmfInfo::DtmfInfo(uint32_t ssrc, int event_code, @@ -49,8 +50,11 @@ AudioSource* FakeVoiceMediaChannel::VoiceChannelAudioSink::source() const { } FakeVoiceMediaChannel::FakeVoiceMediaChannel(FakeVoiceEngine* engine, - const AudioOptions& options) - : engine_(engine), max_bps_(-1) { + const AudioOptions& options, + TaskQueueBase* network_thread) + : RtpHelper(network_thread), + engine_(engine), + max_bps_(-1) { output_scalings_[0] = 1.0; // For default channel. SetOptions(options); } @@ -253,8 +257,11 @@ bool CompareDtmfInfo(const FakeVoiceMediaChannel::DtmfInfo& info, } FakeVideoMediaChannel::FakeVideoMediaChannel(FakeVideoEngine* engine, - const VideoOptions& options) - : engine_(engine), max_bps_(-1) { + const VideoOptions& options, + TaskQueueBase* network_thread) + : RtpHelper(network_thread), + engine_(engine), + max_bps_(-1) { SetOptions(options); } FakeVideoMediaChannel::~FakeVideoMediaChannel() { @@ -422,93 +429,6 @@ void FakeVideoMediaChannel::ClearRecordableEncodedFrameCallback(uint32_t ssrc) { void FakeVideoMediaChannel::GenerateKeyFrame(uint32_t ssrc) {} -FakeDataMediaChannel::FakeDataMediaChannel(void* unused, - const DataOptions& options) - : send_blocked_(false), max_bps_(-1) {} -FakeDataMediaChannel::~FakeDataMediaChannel() {} -const std::vector& FakeDataMediaChannel::recv_codecs() const { - return recv_codecs_; -} -const std::vector& FakeDataMediaChannel::send_codecs() const { - return send_codecs_; -} -const std::vector& FakeDataMediaChannel::codecs() const { - return send_codecs(); -} -int FakeDataMediaChannel::max_bps() const { - return max_bps_; -} -bool FakeDataMediaChannel::SetSendParameters(const DataSendParameters& params) { - set_send_rtcp_parameters(params.rtcp); - return (SetSendCodecs(params.codecs) && - SetMaxSendBandwidth(params.max_bandwidth_bps)); -} -bool FakeDataMediaChannel::SetRecvParameters(const DataRecvParameters& params) { - set_recv_rtcp_parameters(params.rtcp); - return SetRecvCodecs(params.codecs); -} -bool FakeDataMediaChannel::SetSend(bool send) { - return set_sending(send); -} -bool FakeDataMediaChannel::SetReceive(bool receive) { - set_playout(receive); - return true; -} -bool FakeDataMediaChannel::AddRecvStream(const StreamParams& sp) { - if 
(!RtpHelper::AddRecvStream(sp)) - return false; - return true; -} -bool FakeDataMediaChannel::RemoveRecvStream(uint32_t ssrc) { - if (!RtpHelper::RemoveRecvStream(ssrc)) - return false; - return true; -} -bool FakeDataMediaChannel::SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result) { - if (send_blocked_) { - *result = SDR_BLOCK; - return false; - } else { - last_sent_data_params_ = params; - last_sent_data_ = std::string(payload.data(), payload.size()); - return true; - } -} -SendDataParams FakeDataMediaChannel::last_sent_data_params() { - return last_sent_data_params_; -} -std::string FakeDataMediaChannel::last_sent_data() { - return last_sent_data_; -} -bool FakeDataMediaChannel::is_send_blocked() { - return send_blocked_; -} -void FakeDataMediaChannel::set_send_blocked(bool blocked) { - send_blocked_ = blocked; -} -bool FakeDataMediaChannel::SetRecvCodecs(const std::vector& codecs) { - if (fail_set_recv_codecs()) { - // Fake the failure in SetRecvCodecs. - return false; - } - recv_codecs_ = codecs; - return true; -} -bool FakeDataMediaChannel::SetSendCodecs(const std::vector& codecs) { - if (fail_set_send_codecs()) { - // Fake the failure in SetSendCodecs. - return false; - } - send_codecs_ = codecs; - return true; -} -bool FakeDataMediaChannel::SetMaxSendBandwidth(int bps) { - max_bps_ = bps; - return true; -} - FakeVoiceEngine::FakeVoiceEngine() : fail_create_channel_(false) { // Add a fake audio codec. Note that the name must not be "" as there are // sanity checks against that. @@ -527,7 +447,8 @@ VoiceMediaChannel* FakeVoiceEngine::CreateMediaChannel( return nullptr; } - FakeVoiceMediaChannel* ch = new FakeVoiceMediaChannel(this, options); + FakeVoiceMediaChannel* ch = + new FakeVoiceMediaChannel(this, options, call->network_thread()); channels_.push_back(ch); return ch; } @@ -593,7 +514,8 @@ VideoMediaChannel* FakeVideoEngine::CreateMediaChannel( return nullptr; } - FakeVideoMediaChannel* ch = new FakeVideoMediaChannel(this, options); + FakeVideoMediaChannel* ch = + new FakeVideoMediaChannel(this, options, call->network_thread()); channels_.emplace_back(ch); return ch; } @@ -668,22 +590,4 @@ void FakeMediaEngine::set_fail_create_channel(bool fail) { video_->fail_create_channel_ = fail; } -DataMediaChannel* FakeDataEngine::CreateChannel(const MediaConfig& config) { - FakeDataMediaChannel* ch = new FakeDataMediaChannel(this, DataOptions()); - channels_.push_back(ch); - return ch; -} -FakeDataMediaChannel* FakeDataEngine::GetChannel(size_t index) { - return (channels_.size() > index) ? 
channels_[index] : NULL; -} -void FakeDataEngine::UnregisterChannel(DataMediaChannel* channel) { - channels_.erase(absl::c_find(channels_, channel)); -} -void FakeDataEngine::SetDataCodecs(const std::vector& data_codecs) { - data_codecs_ = data_codecs; -} -const std::vector& FakeDataEngine::data_codecs() { - return data_codecs_; -} - } // namespace cricket diff --git a/media/base/fake_media_engine.h b/media/base/fake_media_engine.h index 42940bf1b4..e4f7b6659f 100644 --- a/media/base/fake_media_engine.h +++ b/media/base/fake_media_engine.h @@ -11,6 +11,7 @@ #ifndef MEDIA_BASE_FAKE_MEDIA_ENGINE_H_ #define MEDIA_BASE_FAKE_MEDIA_ENGINE_H_ +#include #include #include #include @@ -42,8 +43,9 @@ class FakeVoiceEngine; template class RtpHelper : public Base { public: - RtpHelper() - : sending_(false), + explicit RtpHelper(webrtc::TaskQueueBase* network_thread) + : Base(network_thread), + sending_(false), playout_(false), fail_set_send_codecs_(false), fail_set_recv_codecs_(false), @@ -118,6 +120,8 @@ class RtpHelper : public Base { return RemoveStreamBySsrc(&send_streams_, ssrc); } virtual void ResetUnsignaledRecvStream() {} + virtual void OnDemuxerCriteriaUpdatePending() {} + virtual void OnDemuxerCriteriaUpdateComplete() {} virtual bool AddRecvStream(const StreamParams& sp) { if (absl::c_linear_search(receive_streams_, sp)) { @@ -265,14 +269,14 @@ class RtpHelper : public Base { void set_recv_rtcp_parameters(const RtcpParameters& params) { recv_rtcp_parameters_ = params; } - virtual void OnPacketReceived(rtc::CopyOnWriteBuffer packet, - int64_t packet_time_us) { + void OnPacketReceived(rtc::CopyOnWriteBuffer packet, + int64_t packet_time_us) override { rtp_packets_.push_back(std::string(packet.cdata(), packet.size())); } - virtual void OnReadyToSend(bool ready) { ready_to_send_ = ready; } - - virtual void OnNetworkRouteChanged(const std::string& transport_name, - const rtc::NetworkRoute& network_route) { + void OnPacketSent(const rtc::SentPacket& sent_packet) override {} + void OnReadyToSend(bool ready) override { ready_to_send_ = ready; } + void OnNetworkRouteChanged(const std::string& transport_name, + const rtc::NetworkRoute& network_route) override { last_network_route_ = network_route; ++num_network_route_changes_; transport_overhead_per_packet_ = network_route.packet_overhead; @@ -281,7 +285,10 @@ class RtpHelper : public Base { bool fail_set_recv_codecs() const { return fail_set_recv_codecs_; } private: - bool sending_; + // TODO(bugs.webrtc.org/12783): This flag is used from more than one thread. + // As a workaround for tsan, it's currently std::atomic but that might not + // be the appropriate fix. 
+ std::atomic sending_; bool playout_; std::vector recv_extensions_; std::vector send_extensions_; @@ -312,8 +319,9 @@ class FakeVoiceMediaChannel : public RtpHelper { int event_code; int duration; }; - explicit FakeVoiceMediaChannel(FakeVoiceEngine* engine, - const AudioOptions& options); + FakeVoiceMediaChannel(FakeVoiceEngine* engine, + const AudioOptions& options, + webrtc::TaskQueueBase* network_thread); ~FakeVoiceMediaChannel(); const std::vector& recv_codecs() const; const std::vector& send_codecs() const; @@ -404,7 +412,9 @@ bool CompareDtmfInfo(const FakeVoiceMediaChannel::DtmfInfo& info, class FakeVideoMediaChannel : public RtpHelper { public: - FakeVideoMediaChannel(FakeVideoEngine* engine, const VideoOptions& options); + FakeVideoMediaChannel(FakeVideoEngine* engine, + const VideoOptions& options, + webrtc::TaskQueueBase* network_thread); ~FakeVideoMediaChannel(); @@ -470,48 +480,6 @@ class FakeVideoMediaChannel : public RtpHelper { int max_bps_; }; -// Dummy option class, needed for the DataTraits abstraction in -// channel_unittest.c. -class DataOptions {}; - -class FakeDataMediaChannel : public RtpHelper { - public: - explicit FakeDataMediaChannel(void* unused, const DataOptions& options); - ~FakeDataMediaChannel(); - const std::vector& recv_codecs() const; - const std::vector& send_codecs() const; - const std::vector& codecs() const; - int max_bps() const; - - bool SetSendParameters(const DataSendParameters& params) override; - bool SetRecvParameters(const DataRecvParameters& params) override; - bool SetSend(bool send) override; - bool SetReceive(bool receive) override; - bool AddRecvStream(const StreamParams& sp) override; - bool RemoveRecvStream(uint32_t ssrc) override; - - bool SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result) override; - - SendDataParams last_sent_data_params(); - std::string last_sent_data(); - bool is_send_blocked(); - void set_send_blocked(bool blocked); - - private: - bool SetRecvCodecs(const std::vector& codecs); - bool SetSendCodecs(const std::vector& codecs); - bool SetMaxSendBandwidth(int bps); - - std::vector recv_codecs_; - std::vector send_codecs_; - SendDataParams last_sent_data_params_; - std::string last_sent_data_; - bool send_blocked_; - int max_bps_; -}; - class FakeVoiceEngine : public VoiceEngineInterface { public: FakeVoiceEngine(); @@ -607,25 +575,6 @@ class FakeMediaEngine : public CompositeMediaEngine { FakeVideoEngine* const video_; }; -// Have to come afterwards due to declaration order - -class FakeDataEngine : public DataEngineInterface { - public: - DataMediaChannel* CreateChannel(const MediaConfig& config) override; - - FakeDataMediaChannel* GetChannel(size_t index); - - void UnregisterChannel(DataMediaChannel* channel); - - void SetDataCodecs(const std::vector& data_codecs); - - const std::vector& data_codecs() override; - - private: - std::vector channels_; - std::vector data_codecs_; -}; - } // namespace cricket #endif // MEDIA_BASE_FAKE_MEDIA_ENGINE_H_ diff --git a/media/base/fake_network_interface.h b/media/base/fake_network_interface.h index 02d53f6781..45b7aa0fc0 100644 --- a/media/base/fake_network_interface.h +++ b/media/base/fake_network_interface.h @@ -18,6 +18,7 @@ #include "media/base/media_channel.h" #include "media/base/rtp_utils.h" #include "rtc_base/byte_order.h" +#include "rtc_base/checks.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/dscp.h" #include "rtc_base/message_handler.h" @@ -83,14 +84,12 @@ class FakeNetworkInterface : public 
MediaChannel::NetworkInterface, return static_cast(sent_ssrcs_.size()); } - // Note: callers are responsible for deleting the returned buffer. - const rtc::CopyOnWriteBuffer* GetRtpPacket(int index) - RTC_LOCKS_EXCLUDED(mutex_) { + rtc::CopyOnWriteBuffer GetRtpPacket(int index) RTC_LOCKS_EXCLUDED(mutex_) { webrtc::MutexLock lock(&mutex_); if (index >= static_cast(rtp_packets_.size())) { - return NULL; + return {}; } - return new rtc::CopyOnWriteBuffer(rtp_packets_[index]); + return rtp_packets_[index]; } int NumRtcpPackets() RTC_LOCKS_EXCLUDED(mutex_) { @@ -129,10 +128,7 @@ class FakeNetworkInterface : public MediaChannel::NetworkInterface, rtp_packets_.push_back(*packet); if (conf_) { for (size_t i = 0; i < conf_sent_ssrcs_.size(); ++i) { - if (!SetRtpSsrc(packet->MutableData(), packet->size(), - conf_sent_ssrcs_[i])) { - return false; - } + SetRtpSsrc(conf_sent_ssrcs_[i], *packet); PostMessage(ST_RTP, *packet); } } else { @@ -184,6 +180,11 @@ class FakeNetworkInterface : public MediaChannel::NetworkInterface, } private: + void SetRtpSsrc(uint32_t ssrc, rtc::CopyOnWriteBuffer& buffer) { + RTC_CHECK_GE(buffer.size(), 12); + rtc::SetBE32(buffer.MutableData() + 8, ssrc); + } + void GetNumRtpBytesAndPackets(uint32_t ssrc, int* bytes, int* packets) { if (bytes) { *bytes = 0; diff --git a/media/base/h264_profile_level_id.cc b/media/base/h264_profile_level_id.cc index 32fa02c143..6f9fa46694 100644 --- a/media/base/h264_profile_level_id.cc +++ b/media/base/h264_profile_level_id.cc @@ -10,301 +10,33 @@ #include "media/base/h264_profile_level_id.h" -#include -#include -#include - -#include "rtc_base/arraysize.h" -#include "rtc_base/checks.h" +// TODO(crbug.com/1187565): Remove this file once downstream projects stop +// depend on it. namespace webrtc { namespace H264 { -namespace { - -const char kProfileLevelId[] = "profile-level-id"; -const char kLevelAsymmetryAllowed[] = "level-asymmetry-allowed"; - -// For level_idc=11 and profile_idc=0x42, 0x4D, or 0x58, the constraint set3 -// flag specifies if level 1b or level 1.1 is used. -const uint8_t kConstraintSet3Flag = 0x10; - -// Convert a string of 8 characters into a byte where the positions containing -// character c will have their bit set. For example, c = 'x', str = "x1xx0000" -// will return 0b10110000. constexpr is used so that the pattern table in -// kProfilePatterns is statically initialized. -constexpr uint8_t ByteMaskString(char c, const char (&str)[9]) { - return (str[0] == c) << 7 | (str[1] == c) << 6 | (str[2] == c) << 5 | - (str[3] == c) << 4 | (str[4] == c) << 3 | (str[5] == c) << 2 | - (str[6] == c) << 1 | (str[7] == c) << 0; -} - -// Class for matching bit patterns such as "x1xx0000" where 'x' is allowed to be -// either 0 or 1. -class BitPattern { - public: - explicit constexpr BitPattern(const char (&str)[9]) - : mask_(~ByteMaskString('x', str)), - masked_value_(ByteMaskString('1', str)) {} - - bool IsMatch(uint8_t value) const { return masked_value_ == (value & mask_); } - - private: - const uint8_t mask_; - const uint8_t masked_value_; -}; - -// Table for converting between profile_idc/profile_iop to H264::Profile. -struct ProfilePattern { - const uint8_t profile_idc; - const BitPattern profile_iop; - const Profile profile; -}; - -// This is from https://tools.ietf.org/html/rfc6184#section-8.1. 
-constexpr ProfilePattern kProfilePatterns[] = { - {0x42, BitPattern("x1xx0000"), kProfileConstrainedBaseline}, - {0x4D, BitPattern("1xxx0000"), kProfileConstrainedBaseline}, - {0x58, BitPattern("11xx0000"), kProfileConstrainedBaseline}, - {0x42, BitPattern("x0xx0000"), kProfileBaseline}, - {0x58, BitPattern("10xx0000"), kProfileBaseline}, - {0x4D, BitPattern("0x0x0000"), kProfileMain}, - {0x64, BitPattern("00000000"), kProfileHigh}, - {0x64, BitPattern("00001100"), kProfileConstrainedHigh}}; - -// Compare H264 levels and handle the level 1b case. -bool IsLess(Level a, Level b) { - if (a == kLevel1_b) - return b != kLevel1 && b != kLevel1_b; - if (b == kLevel1_b) - return a == kLevel1; - return a < b; -} - -Level Min(Level a, Level b) { - return IsLess(a, b) ? a : b; -} - -bool IsLevelAsymmetryAllowed(const CodecParameterMap& params) { - const auto it = params.find(kLevelAsymmetryAllowed); - return it != params.end() && strcmp(it->second.c_str(), "1") == 0; -} - -struct LevelConstraint { - const int max_macroblocks_per_second; - const int max_macroblock_frame_size; - const webrtc::H264::Level level; -}; - -// This is from ITU-T H.264 (02/2016) Table A-1 – Level limits. -static constexpr LevelConstraint kLevelConstraints[] = { - {1485, 99, webrtc::H264::kLevel1}, - {1485, 99, webrtc::H264::kLevel1_b}, - {3000, 396, webrtc::H264::kLevel1_1}, - {6000, 396, webrtc::H264::kLevel1_2}, - {11880, 396, webrtc::H264::kLevel1_3}, - {11880, 396, webrtc::H264::kLevel2}, - {19800, 792, webrtc::H264::kLevel2_1}, - {20250, 1620, webrtc::H264::kLevel2_2}, - {40500, 1620, webrtc::H264::kLevel3}, - {108000, 3600, webrtc::H264::kLevel3_1}, - {216000, 5120, webrtc::H264::kLevel3_2}, - {245760, 8192, webrtc::H264::kLevel4}, - {245760, 8192, webrtc::H264::kLevel4_1}, - {522240, 8704, webrtc::H264::kLevel4_2}, - {589824, 22080, webrtc::H264::kLevel5}, - {983040, 36864, webrtc::H264::kLevel5_1}, - {2073600, 36864, webrtc::H264::kLevel5_2}, -}; - -} // anonymous namespace - absl::optional ParseProfileLevelId(const char* str) { - // The string should consist of 3 bytes in hexadecimal format. - if (strlen(str) != 6u) - return absl::nullopt; - const uint32_t profile_level_id_numeric = strtol(str, nullptr, 16); - if (profile_level_id_numeric == 0) - return absl::nullopt; - - // Separate into three bytes. - const uint8_t level_idc = - static_cast(profile_level_id_numeric & 0xFF); - const uint8_t profile_iop = - static_cast((profile_level_id_numeric >> 8) & 0xFF); - const uint8_t profile_idc = - static_cast((profile_level_id_numeric >> 16) & 0xFF); - - // Parse level based on level_idc and constraint set 3 flag. - Level level; - switch (level_idc) { - case kLevel1_1: - level = (profile_iop & kConstraintSet3Flag) != 0 ? kLevel1_b : kLevel1_1; - break; - case kLevel1: - case kLevel1_2: - case kLevel1_3: - case kLevel2: - case kLevel2_1: - case kLevel2_2: - case kLevel3: - case kLevel3_1: - case kLevel3_2: - case kLevel4: - case kLevel4_1: - case kLevel4_2: - case kLevel5: - case kLevel5_1: - case kLevel5_2: - level = static_cast(level_idc); - break; - default: - // Unrecognized level_idc. - return absl::nullopt; - } - - // Parse profile_idc/profile_iop into a Profile enum. - for (const ProfilePattern& pattern : kProfilePatterns) { - if (profile_idc == pattern.profile_idc && - pattern.profile_iop.IsMatch(profile_iop)) { - return ProfileLevelId(pattern.profile, level); - } - } - - // Unrecognized profile_idc/profile_iop combination. 
- return absl::nullopt; -} - -absl::optional SupportedLevel(int max_frame_pixel_count, float max_fps) { - static const int kPixelsPerMacroblock = 16 * 16; - - for (int i = arraysize(kLevelConstraints) - 1; i >= 0; --i) { - const LevelConstraint& level_constraint = kLevelConstraints[i]; - if (level_constraint.max_macroblock_frame_size * kPixelsPerMacroblock <= - max_frame_pixel_count && - level_constraint.max_macroblocks_per_second <= - max_fps * level_constraint.max_macroblock_frame_size) { - return level_constraint.level; - } - } - - // No level supported. - return absl::nullopt; + return webrtc::ParseH264ProfileLevelId(str); } absl::optional ParseSdpProfileLevelId( - const CodecParameterMap& params) { - // TODO(magjed): The default should really be kProfileBaseline and kLevel1 - // according to the spec: https://tools.ietf.org/html/rfc6184#section-8.1. In - // order to not break backwards compatibility with older versions of WebRTC - // where external codecs don't have any parameters, use - // kProfileConstrainedBaseline kLevel3_1 instead. This workaround will only be - // done in an interim period to allow external clients to update their code. - // http://crbug/webrtc/6337. - static const ProfileLevelId kDefaultProfileLevelId( - kProfileConstrainedBaseline, kLevel3_1); + const SdpVideoFormat::Parameters& params) { + return webrtc::ParseSdpForH264ProfileLevelId(params); +} - const auto profile_level_id_it = params.find(kProfileLevelId); - return (profile_level_id_it == params.end()) - ? kDefaultProfileLevelId - : ParseProfileLevelId(profile_level_id_it->second.c_str()); +absl::optional SupportedLevel(int max_frame_pixel_count, float max_fps) { + return webrtc::H264SupportedLevel(max_frame_pixel_count, max_fps); } absl::optional ProfileLevelIdToString( const ProfileLevelId& profile_level_id) { - // Handle special case level == 1b. - if (profile_level_id.level == kLevel1_b) { - switch (profile_level_id.profile) { - case kProfileConstrainedBaseline: - return {"42f00b"}; - case kProfileBaseline: - return {"42100b"}; - case kProfileMain: - return {"4d100b"}; - // Level 1b is not allowed for other profiles. - default: - return absl::nullopt; - } - } - - const char* profile_idc_iop_string; - switch (profile_level_id.profile) { - case kProfileConstrainedBaseline: - profile_idc_iop_string = "42e0"; - break; - case kProfileBaseline: - profile_idc_iop_string = "4200"; - break; - case kProfileMain: - profile_idc_iop_string = "4d00"; - break; - case kProfileConstrainedHigh: - profile_idc_iop_string = "640c"; - break; - case kProfileHigh: - profile_idc_iop_string = "6400"; - break; - // Unrecognized profile. - default: - return absl::nullopt; - } - - char str[7]; - snprintf(str, 7u, "%s%02x", profile_idc_iop_string, profile_level_id.level); - return {str}; -} - -// Set level according to https://tools.ietf.org/html/rfc6184#section-8.2.2. -void GenerateProfileLevelIdForAnswer( - const CodecParameterMap& local_supported_params, - const CodecParameterMap& remote_offered_params, - CodecParameterMap* answer_params) { - // If both local and remote haven't set profile-level-id, they are both using - // the default profile. In this case, don't set profile-level-id in answer - // either. - if (!local_supported_params.count(kProfileLevelId) && - !remote_offered_params.count(kProfileLevelId)) { - return; - } - - // Parse profile-level-ids. 
- const absl::optional local_profile_level_id = - ParseSdpProfileLevelId(local_supported_params); - const absl::optional remote_profile_level_id = - ParseSdpProfileLevelId(remote_offered_params); - // The local and remote codec must have valid and equal H264 Profiles. - RTC_DCHECK(local_profile_level_id); - RTC_DCHECK(remote_profile_level_id); - RTC_DCHECK_EQ(local_profile_level_id->profile, - remote_profile_level_id->profile); - - // Parse level information. - const bool level_asymmetry_allowed = - IsLevelAsymmetryAllowed(local_supported_params) && - IsLevelAsymmetryAllowed(remote_offered_params); - const Level local_level = local_profile_level_id->level; - const Level remote_level = remote_profile_level_id->level; - const Level min_level = Min(local_level, remote_level); - - // Determine answer level. When level asymmetry is not allowed, level upgrade - // is not allowed, i.e., the level in the answer must be equal to or lower - // than the level in the offer. - const Level answer_level = level_asymmetry_allowed ? local_level : min_level; - - // Set the resulting profile-level-id in the answer parameters. - (*answer_params)[kProfileLevelId] = *ProfileLevelIdToString( - ProfileLevelId(local_profile_level_id->profile, answer_level)); + return webrtc::H264ProfileLevelIdToString(profile_level_id); } -bool IsSameH264Profile(const CodecParameterMap& params1, - const CodecParameterMap& params2) { - const absl::optional profile_level_id = - webrtc::H264::ParseSdpProfileLevelId(params1); - const absl::optional other_profile_level_id = - webrtc::H264::ParseSdpProfileLevelId(params2); - // Compare H264 profiles, but not levels. - return profile_level_id && other_profile_level_id && - profile_level_id->profile == other_profile_level_id->profile; +bool IsSameH264Profile(const SdpVideoFormat::Parameters& params1, + const SdpVideoFormat::Parameters& params2) { + return webrtc::H264IsSameProfile(params1, params2); } } // namespace H264 diff --git a/media/base/h264_profile_level_id.h b/media/base/h264_profile_level_id.h index f0f7928a3a..c85709faa9 100644 --- a/media/base/h264_profile_level_id.h +++ b/media/base/h264_profile_level_id.h @@ -11,54 +11,45 @@ #ifndef MEDIA_BASE_H264_PROFILE_LEVEL_ID_H_ #define MEDIA_BASE_H264_PROFILE_LEVEL_ID_H_ -#include #include -#include "absl/types/optional.h" -#include "rtc_base/system/rtc_export.h" +#include "api/video_codecs/h264_profile_level_id.h" + +// TODO(crbug.com/1187565): Remove this file once downstream projects stop +// depend on it. namespace webrtc { namespace H264 { -enum Profile { - kProfileConstrainedBaseline, - kProfileBaseline, - kProfileMain, - kProfileConstrainedHigh, - kProfileHigh, -}; - -// Map containting SDP codec parameters. -typedef std::map CodecParameterMap; - -// All values are equal to ten times the level number, except level 1b which is -// special. 
-enum Level { - kLevel1_b = 0, - kLevel1 = 10, - kLevel1_1 = 11, - kLevel1_2 = 12, - kLevel1_3 = 13, - kLevel2 = 20, - kLevel2_1 = 21, - kLevel2_2 = 22, - kLevel3 = 30, - kLevel3_1 = 31, - kLevel3_2 = 32, - kLevel4 = 40, - kLevel4_1 = 41, - kLevel4_2 = 42, - kLevel5 = 50, - kLevel5_1 = 51, - kLevel5_2 = 52 -}; - -struct ProfileLevelId { - constexpr ProfileLevelId(Profile profile, Level level) - : profile(profile), level(level) {} - Profile profile; - Level level; -}; +typedef H264Profile Profile; +typedef H264Level Level; +typedef H264ProfileLevelId ProfileLevelId; + +constexpr H264Profile kProfileConstrainedBaseline = + H264Profile::kProfileConstrainedBaseline; +constexpr H264Profile kProfileBaseline = H264Profile::kProfileBaseline; +constexpr H264Profile kProfileMain = H264Profile::kProfileMain; +constexpr H264Profile kProfileConstrainedHigh = + H264Profile::kProfileConstrainedHigh; +constexpr H264Profile kProfileHigh = H264Profile::kProfileHigh; + +constexpr H264Level kLevel1_b = H264Level::kLevel1_b; +constexpr H264Level kLevel1 = H264Level::kLevel1; +constexpr H264Level kLevel1_1 = H264Level::kLevel1_1; +constexpr H264Level kLevel1_2 = H264Level::kLevel1_2; +constexpr H264Level kLevel1_3 = H264Level::kLevel1_3; +constexpr H264Level kLevel2 = H264Level::kLevel2; +constexpr H264Level kLevel2_1 = H264Level::kLevel2_1; +constexpr H264Level kLevel2_2 = H264Level::kLevel2_2; +constexpr H264Level kLevel3 = H264Level::kLevel3; +constexpr H264Level kLevel3_1 = H264Level::kLevel3_1; +constexpr H264Level kLevel3_2 = H264Level::kLevel3_2; +constexpr H264Level kLevel4 = H264Level::kLevel4; +constexpr H264Level kLevel4_1 = H264Level::kLevel4_1; +constexpr H264Level kLevel4_2 = H264Level::kLevel4_2; +constexpr H264Level kLevel5 = H264Level::kLevel5; +constexpr H264Level kLevel5_1 = H264Level::kLevel5_1; +constexpr H264Level kLevel5_2 = H264Level::kLevel5_2; // Parse profile level id that is represented as a string of 3 hex bytes. // Nothing will be returned if the string is not a recognized H264 @@ -70,7 +61,7 @@ absl::optional ParseProfileLevelId(const char* str); // returned if the profile-level-id key is missing. Nothing will be returned if // the key is present but the string is invalid. RTC_EXPORT absl::optional ParseSdpProfileLevelId( - const CodecParameterMap& params); + const SdpVideoFormat::Parameters& params); // Given that a decoder supports up to a given frame size (in pixels) at up to a // given number of frames per second, return the highest H.264 level where it @@ -84,33 +75,11 @@ RTC_EXPORT absl::optional SupportedLevel(int max_frame_pixel_count, RTC_EXPORT absl::optional ProfileLevelIdToString( const ProfileLevelId& profile_level_id); -// Generate codec parameters that will be used as answer in an SDP negotiation -// based on local supported parameters and remote offered parameters. Both -// |local_supported_params|, |remote_offered_params|, and |answer_params| -// represent sendrecv media descriptions, i.e they are a mix of both encode and -// decode capabilities. In theory, when the profile in |local_supported_params| -// represent a strict superset of the profile in |remote_offered_params|, we -// could limit the profile in |answer_params| to the profile in -// |remote_offered_params|. However, to simplify the code, each supported H264 -// profile should be listed explicitly in the list of local supported codecs, -// even if they are redundant. 
Then each local codec in the list should be -// tested one at a time against the remote codec, and only when the profiles are -// equal should this function be called. Therefore, this function does not need -// to handle profile intersection, and the profile of |local_supported_params| -// and |remote_offered_params| must be equal before calling this function. The -// parameters that are used when negotiating are the level part of -// profile-level-id and level-asymmetry-allowed. -void GenerateProfileLevelIdForAnswer( - const CodecParameterMap& local_supported_params, - const CodecParameterMap& remote_offered_params, - CodecParameterMap* answer_params); - // Returns true if the parameters have the same H264 profile, i.e. the same // H264::Profile (Baseline, High, etc). -bool IsSameH264Profile(const CodecParameterMap& params1, - const CodecParameterMap& params2); +RTC_EXPORT bool IsSameH264Profile(const SdpVideoFormat::Parameters& params1, + const SdpVideoFormat::Parameters& params2); } // namespace H264 } // namespace webrtc - #endif // MEDIA_BASE_H264_PROFILE_LEVEL_ID_H_ diff --git a/media/base/media_channel.cc b/media/base/media_channel.cc index 0cef36e2b9..01b043b828 100644 --- a/media/base/media_channel.cc +++ b/media/base/media_channel.cc @@ -10,21 +10,40 @@ #include "media/base/media_channel.h" +#include "media/base/rtp_utils.h" +#include "rtc_base/task_utils/to_queued_task.h" + namespace cricket { +using webrtc::FrameDecryptorInterface; +using webrtc::FrameEncryptorInterface; +using webrtc::FrameTransformerInterface; +using webrtc::PendingTaskSafetyFlag; +using webrtc::TaskQueueBase; +using webrtc::ToQueuedTask; +using webrtc::VideoTrackInterface; VideoOptions::VideoOptions() - : content_hint(webrtc::VideoTrackInterface::ContentHint::kNone) {} + : content_hint(VideoTrackInterface::ContentHint::kNone) {} VideoOptions::~VideoOptions() = default; -MediaChannel::MediaChannel(const MediaConfig& config) - : enable_dscp_(config.enable_dscp) {} +MediaChannel::MediaChannel(const MediaConfig& config, + TaskQueueBase* network_thread) + : enable_dscp_(config.enable_dscp), + network_safety_(PendingTaskSafetyFlag::CreateDetachedInactive()), + network_thread_(network_thread) {} -MediaChannel::MediaChannel() : enable_dscp_(false) {} +MediaChannel::MediaChannel(TaskQueueBase* network_thread) + : enable_dscp_(false), + network_safety_(PendingTaskSafetyFlag::CreateDetachedInactive()), + network_thread_(network_thread) {} -MediaChannel::~MediaChannel() {} +MediaChannel::~MediaChannel() { + RTC_DCHECK(!network_interface_); +} void MediaChannel::SetInterface(NetworkInterface* iface) { - webrtc::MutexLock lock(&network_interface_mutex_); + RTC_DCHECK_RUN_ON(network_thread_); + iface ? network_safety_->SetAlive() : network_safety_->SetNotAlive(); network_interface_ = iface; UpdateDscp(); } @@ -35,24 +54,163 @@ int MediaChannel::GetRtpSendTimeExtnId() const { void MediaChannel::SetFrameEncryptor( uint32_t ssrc, - rtc::scoped_refptr frame_encryptor) { + rtc::scoped_refptr frame_encryptor) { // Placeholder should be pure virtual once internal supports it. } void MediaChannel::SetFrameDecryptor( uint32_t ssrc, - rtc::scoped_refptr frame_decryptor) { + rtc::scoped_refptr frame_decryptor) { // Placeholder should be pure virtual once internal supports it. 
} void MediaChannel::SetVideoCodecSwitchingEnabled(bool enabled) {} +bool MediaChannel::SendPacket(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options) { + return DoSendPacket(packet, false, options); +} + +bool MediaChannel::SendRtcp(rtc::CopyOnWriteBuffer* packet, + const rtc::PacketOptions& options) { + return DoSendPacket(packet, true, options); +} + +int MediaChannel::SetOption(NetworkInterface::SocketType type, + rtc::Socket::Option opt, + int option) { + RTC_DCHECK_RUN_ON(network_thread_); + return SetOptionLocked(type, opt, option); +} + +// Corresponds to the SDP attribute extmap-allow-mixed, see RFC8285. +// Set to true if it's allowed to mix one- and two-byte RTP header extensions +// in the same stream. The setter and getter must only be called from +// worker_thread. +void MediaChannel::SetExtmapAllowMixed(bool extmap_allow_mixed) { + extmap_allow_mixed_ = extmap_allow_mixed; +} + +bool MediaChannel::ExtmapAllowMixed() const { + return extmap_allow_mixed_; +} + void MediaChannel::SetEncoderToPacketizerFrameTransformer( uint32_t ssrc, - rtc::scoped_refptr frame_transformer) {} + rtc::scoped_refptr frame_transformer) {} + void MediaChannel::SetDepacketizerToDecoderFrameTransformer( uint32_t ssrc, - rtc::scoped_refptr frame_transformer) {} + rtc::scoped_refptr frame_transformer) {} + +int MediaChannel::SetOptionLocked(NetworkInterface::SocketType type, + rtc::Socket::Option opt, + int option) { + if (!network_interface_) + return -1; + return network_interface_->SetOption(type, opt, option); +} + +bool MediaChannel::DscpEnabled() const { + return enable_dscp_; +} + +// This is the DSCP value used for both RTP and RTCP channels if DSCP is +// enabled. It can be changed at any time via |SetPreferredDscp|. +rtc::DiffServCodePoint MediaChannel::PreferredDscp() const { + RTC_DCHECK_RUN_ON(network_thread_); + return preferred_dscp_; +} + +void MediaChannel::SetPreferredDscp(rtc::DiffServCodePoint new_dscp) { + if (!network_thread_->IsCurrent()) { + // This is currently the common path as the derived channel classes + // get called on the worker thread. There are still some tests though + // that call directly on the network thread. + network_thread_->PostTask(ToQueuedTask( + network_safety_, [this, new_dscp]() { SetPreferredDscp(new_dscp); })); + return; + } + + RTC_DCHECK_RUN_ON(network_thread_); + if (new_dscp == preferred_dscp_) + return; + + preferred_dscp_ = new_dscp; + UpdateDscp(); +} + +rtc::scoped_refptr MediaChannel::network_safety() { + return network_safety_; +} + +void MediaChannel::UpdateDscp() { + rtc::DiffServCodePoint value = + enable_dscp_ ? preferred_dscp_ : rtc::DSCP_DEFAULT; + int ret = + SetOptionLocked(NetworkInterface::ST_RTP, rtc::Socket::OPT_DSCP, value); + if (ret == 0) + SetOptionLocked(NetworkInterface::ST_RTCP, rtc::Socket::OPT_DSCP, value); +} + +bool MediaChannel::DoSendPacket(rtc::CopyOnWriteBuffer* packet, + bool rtcp, + const rtc::PacketOptions& options) { + RTC_DCHECK_RUN_ON(network_thread_); + if (!network_interface_) + return false; + + return (!rtcp) ? 
network_interface_->SendPacket(packet, options) + : network_interface_->SendRtcp(packet, options); +} + +void MediaChannel::SendRtp(const uint8_t* data, + size_t len, + const webrtc::PacketOptions& options) { + auto send = + [this, packet_id = options.packet_id, + included_in_feedback = options.included_in_feedback, + included_in_allocation = options.included_in_allocation, + packet = rtc::CopyOnWriteBuffer(data, len, kMaxRtpPacketLen)]() mutable { + rtc::PacketOptions rtc_options; + rtc_options.packet_id = packet_id; + if (DscpEnabled()) { + rtc_options.dscp = PreferredDscp(); + } + rtc_options.info_signaled_after_sent.included_in_feedback = + included_in_feedback; + rtc_options.info_signaled_after_sent.included_in_allocation = + included_in_allocation; + SendPacket(&packet, rtc_options); + }; + + // TODO(bugs.webrtc.org/11993): ModuleRtpRtcpImpl2 and related classes (e.g. + // RTCPSender) aren't aware of the network thread and may trigger calls to + // this function from different threads. Update those classes to keep + // network traffic on the network thread. + if (network_thread_->IsCurrent()) { + send(); + } else { + network_thread_->PostTask(ToQueuedTask(network_safety_, std::move(send))); + } +} + +void MediaChannel::SendRtcp(const uint8_t* data, size_t len) { + auto send = [this, packet = rtc::CopyOnWriteBuffer( + data, len, kMaxRtpPacketLen)]() mutable { + rtc::PacketOptions rtc_options; + if (DscpEnabled()) { + rtc_options.dscp = PreferredDscp(); + } + SendRtcp(&packet, rtc_options); + }; + + if (network_thread_->IsCurrent()) { + send(); + } else { + network_thread_->PostTask(ToQueuedTask(network_safety_, std::move(send))); + } +} MediaSenderInfo::MediaSenderInfo() = default; MediaSenderInfo::~MediaSenderInfo() = default; @@ -78,9 +236,6 @@ VoiceMediaInfo::~VoiceMediaInfo() = default; VideoMediaInfo::VideoMediaInfo() = default; VideoMediaInfo::~VideoMediaInfo() = default; -DataMediaInfo::DataMediaInfo() = default; -DataMediaInfo::~DataMediaInfo() = default; - AudioSendParameters::AudioSendParameters() = default; AudioSendParameters::~AudioSendParameters() = default; @@ -107,31 +262,4 @@ cricket::MediaType VideoMediaChannel::media_type() const { return cricket::MediaType::MEDIA_TYPE_VIDEO; } -DataMediaChannel::DataMediaChannel() = default; -DataMediaChannel::DataMediaChannel(const MediaConfig& config) - : MediaChannel(config) {} -DataMediaChannel::~DataMediaChannel() = default; - -webrtc::RtpParameters DataMediaChannel::GetRtpSendParameters( - uint32_t ssrc) const { - // GetRtpSendParameters is not supported for DataMediaChannel. - RTC_NOTREACHED(); - return webrtc::RtpParameters(); -} -webrtc::RTCError DataMediaChannel::SetRtpSendParameters( - uint32_t ssrc, - const webrtc::RtpParameters& parameters) { - // SetRtpSendParameters is not supported for DataMediaChannel. 
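The SendRtp()/SetPreferredDscp() bodies above follow one pattern: work bound for the network thread is wrapped in ToQueuedTask together with network_safety_, so it is silently dropped once SetInterface(nullptr) marks the flag not-alive. A minimal sketch of that pattern, where network_thread stands in for the channel's webrtc::TaskQueueBase*:

auto safety = webrtc::PendingTaskSafetyFlag::CreateDetachedInactive();
safety->SetAlive();  // MediaChannel does this in SetInterface(iface).
network_thread->PostTask(webrtc::ToQueuedTask(safety, [] {
  // Runs only if |safety| is still alive when the task is picked up.
}));
safety->SetNotAlive();  // SetInterface(nullptr); queued tasks become no-ops.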
- RTC_NOTREACHED(); - return webrtc::RTCError(webrtc::RTCErrorType::UNSUPPORTED_OPERATION); -} - -cricket::MediaType DataMediaChannel::media_type() const { - return cricket::MediaType::MEDIA_TYPE_DATA; -} - -bool DataMediaChannel::GetStats(DataMediaInfo* info) { - return true; -} - } // namespace cricket diff --git a/media/base/media_channel.h b/media/base/media_channel.h index a947b47998..7b9a6f138c 100644 --- a/media/base/media_channel.h +++ b/media/base/media_channel.h @@ -26,6 +26,7 @@ #include "api/media_stream_interface.h" #include "api/rtc_error.h" #include "api/rtp_parameters.h" +#include "api/transport/data_channel_transport_interface.h" #include "api/transport/rtp/rtp_source.h" #include "api/video/video_content_type.h" #include "api/video/video_sink_interface.h" @@ -43,7 +44,6 @@ #include "modules/rtp_rtcp/include/report_block_data.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/buffer.h" -#include "rtc_base/callback.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/dscp.h" #include "rtc_base/logging.h" @@ -51,8 +51,7 @@ #include "rtc_base/socket.h" #include "rtc_base/string_encode.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/mutex.h" -#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" namespace rtc { class Timing; @@ -155,7 +154,7 @@ struct VideoOptions { } }; -class MediaChannel : public sigslot::has_slots<> { +class MediaChannel { public: class NetworkInterface { public: @@ -170,18 +169,21 @@ class MediaChannel : public sigslot::has_slots<> { virtual ~NetworkInterface() {} }; - explicit MediaChannel(const MediaConfig& config); - MediaChannel(); - ~MediaChannel() override; + MediaChannel(const MediaConfig& config, + webrtc::TaskQueueBase* network_thread); + explicit MediaChannel(webrtc::TaskQueueBase* network_thread); + virtual ~MediaChannel(); virtual cricket::MediaType media_type() const = 0; // Sets the abstract interface class for sending RTP/RTCP data. - virtual void SetInterface(NetworkInterface* iface) - RTC_LOCKS_EXCLUDED(network_interface_mutex_); - // Called when a RTP packet is received. + virtual void SetInterface(NetworkInterface* iface); + // Called on the network when an RTP packet is received. virtual void OnPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) = 0; + // Called on the network thread after a transport has finished sending a + // packet. + virtual void OnPacketSent(const rtc::SentPacket& sent_packet) = 0; // Called when the socket's ability to send has changed. virtual void OnReadyToSend(bool ready) = 0; // Called when the network route used for sending packets changed. @@ -207,6 +209,17 @@ class MediaChannel : public sigslot::has_slots<> { // Resets any cached StreamParams for an unsignaled RecvStream, and removes // any existing unsignaled streams. virtual void ResetUnsignaledRecvStream() = 0; + // Informs the media channel when the transport's demuxer criteria is updated. + // * OnDemuxerCriteriaUpdatePending() happens on the same thread that the + // channel's streams are added and removed (worker thread). + // * OnDemuxerCriteriaUpdateComplete() happens on the thread where the demuxer + // lives (network thread). + // Because the demuxer is updated asynchronously, there is a window of time + // where packets are arriving to the channel for streams that have already + // been removed on the worker thread. It is important NOT to treat these as + // new unsignalled ssrcs. 
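As a hedged illustration of the contract described in the comment above (the class and counter below are invented for this sketch, not part of WebRTC):

// Worker thread: criteria changed, the network-thread demuxer lags behind.
void MyVideoChannel::OnDemuxerCriteriaUpdatePending() {
  ++pending_demuxer_updates_;  // std::atomic<int>
}
// Network thread: the demuxer now reflects the latest criteria.
void MyVideoChannel::OnDemuxerCriteriaUpdateComplete() {
  --pending_demuxer_updates_;
}
// While pending_demuxer_updates_ > 0, packets with unknown ssrcs may belong
// to streams that were just removed and must not be promoted to new
// unsignalled receive streams.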
+ virtual void OnDemuxerCriteriaUpdatePending() = 0; + virtual void OnDemuxerCriteriaUpdateComplete() = 0; // Returns the absoulte sendtime extension id value from media channel. virtual int GetRtpSendTimeExtnId() const; // Set the frame encryptor to use on all outgoing frames. This is optional. @@ -229,30 +242,21 @@ class MediaChannel : public sigslot::has_slots<> { // Base method to send packet using NetworkInterface. bool SendPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options) { - return DoSendPacket(packet, false, options); - } + const rtc::PacketOptions& options); bool SendRtcp(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options) { - return DoSendPacket(packet, true, options); - } + const rtc::PacketOptions& options); int SetOption(NetworkInterface::SocketType type, rtc::Socket::Option opt, - int option) RTC_LOCKS_EXCLUDED(network_interface_mutex_) { - webrtc::MutexLock lock(&network_interface_mutex_); - return SetOptionLocked(type, opt, option); - } + int option); // Corresponds to the SDP attribute extmap-allow-mixed, see RFC8285. // Set to true if it's allowed to mix one- and two-byte RTP header extensions // in the same stream. The setter and getter must only be called from // worker_thread. - void SetExtmapAllowMixed(bool extmap_allow_mixed) { - extmap_allow_mixed_ = extmap_allow_mixed; - } - bool ExtmapAllowMixed() const { return extmap_allow_mixed_; } + void SetExtmapAllowMixed(bool extmap_allow_mixed); + bool ExtmapAllowMixed() const; virtual webrtc::RtpParameters GetRtpSendParameters(uint32_t ssrc) const = 0; virtual webrtc::RTCError SetRtpSendParameters( @@ -269,69 +273,42 @@ class MediaChannel : public sigslot::has_slots<> { protected: int SetOptionLocked(NetworkInterface::SocketType type, rtc::Socket::Option opt, - int option) - RTC_EXCLUSIVE_LOCKS_REQUIRED(network_interface_mutex_) { - if (!network_interface_) - return -1; - return network_interface_->SetOption(type, opt, option); - } + int option) RTC_RUN_ON(network_thread_); - bool DscpEnabled() const { return enable_dscp_; } + bool DscpEnabled() const; // This is the DSCP value used for both RTP and RTCP channels if DSCP is // enabled. It can be changed at any time via |SetPreferredDscp|. - rtc::DiffServCodePoint PreferredDscp() const - RTC_LOCKS_EXCLUDED(network_interface_mutex_) { - webrtc::MutexLock lock(&network_interface_mutex_); - return preferred_dscp_; - } + rtc::DiffServCodePoint PreferredDscp() const; + void SetPreferredDscp(rtc::DiffServCodePoint new_dscp); - int SetPreferredDscp(rtc::DiffServCodePoint preferred_dscp) - RTC_LOCKS_EXCLUDED(network_interface_mutex_) { - webrtc::MutexLock lock(&network_interface_mutex_); - if (preferred_dscp == preferred_dscp_) { - return 0; - } - preferred_dscp_ = preferred_dscp; - return UpdateDscp(); - } + rtc::scoped_refptr network_safety(); + + // Utility implementation for derived classes (video/voice) that applies + // the packet options and passes the data onwards to `SendPacket`. + void SendRtp(const uint8_t* data, + size_t len, + const webrtc::PacketOptions& options); + + void SendRtcp(const uint8_t* data, size_t len); private: // Apply the preferred DSCP setting to the underlying network interface RTP // and RTCP channels. If DSCP is disabled, then apply the default DSCP value. - int UpdateDscp() RTC_EXCLUSIVE_LOCKS_REQUIRED(network_interface_mutex_) { - rtc::DiffServCodePoint value = - enable_dscp_ ? 
preferred_dscp_ : rtc::DSCP_DEFAULT; - int ret = - SetOptionLocked(NetworkInterface::ST_RTP, rtc::Socket::OPT_DSCP, value); - if (ret == 0) { - ret = SetOptionLocked(NetworkInterface::ST_RTCP, rtc::Socket::OPT_DSCP, - value); - } - return ret; - } + void UpdateDscp() RTC_RUN_ON(network_thread_); bool DoSendPacket(rtc::CopyOnWriteBuffer* packet, bool rtcp, - const rtc::PacketOptions& options) - RTC_LOCKS_EXCLUDED(network_interface_mutex_) { - webrtc::MutexLock lock(&network_interface_mutex_); - if (!network_interface_) - return false; - - return (!rtcp) ? network_interface_->SendPacket(packet, options) - : network_interface_->SendRtcp(packet, options); - } + const rtc::PacketOptions& options); const bool enable_dscp_; - // |network_interface_| can be accessed from the worker_thread and - // from any MediaEngine threads. This critical section is to protect accessing - // of network_interface_ object. - mutable webrtc::Mutex network_interface_mutex_; - NetworkInterface* network_interface_ - RTC_GUARDED_BY(network_interface_mutex_) = nullptr; - rtc::DiffServCodePoint preferred_dscp_ - RTC_GUARDED_BY(network_interface_mutex_) = rtc::DSCP_DEFAULT; + const rtc::scoped_refptr network_safety_ + RTC_PT_GUARDED_BY(network_thread_); + webrtc::TaskQueueBase* const network_thread_; + NetworkInterface* network_interface_ RTC_GUARDED_BY(network_thread_) = + nullptr; + rtc::DiffServCodePoint preferred_dscp_ RTC_GUARDED_BY(network_thread_) = + rtc::DSCP_DEFAULT; bool extmap_allow_mixed_ = false; }; @@ -395,6 +372,8 @@ struct MediaSenderInfo { int packets_sent = 0; // https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-retransmittedpacketssent uint64_t retransmitted_packets_sent = 0; + // https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-nackcount + uint32_t nacks_rcvd = 0; int packets_lost = 0; float fraction_lost = 0.0f; int64_t rtt_ms = 0; @@ -449,6 +428,13 @@ struct MediaReceiverInfo { int64_t header_and_padding_bytes_rcvd = 0; int packets_rcvd = 0; int packets_lost = 0; + absl::optional nacks_sent; + // Jitter (network-related) latency (cumulative). + // https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-jitterbufferdelay + double jitter_buffer_delay_seconds = 0.0; + // Number of observations for cumulative jitter latency. + // https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-jitterbufferemittedcount + uint64_t jitter_buffer_emitted_count = 0; // The timestamp at which the last packet was received, i.e. the time of the // local clock when it was received - not the RTP timestamp of that packet. // https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-lastpacketreceivedtimestamp @@ -492,8 +478,6 @@ struct VoiceReceiverInfo : public MediaReceiverInfo { uint64_t concealed_samples = 0; uint64_t silent_concealed_samples = 0; uint64_t concealment_events = 0; - double jitter_buffer_delay_seconds = 0.0; - uint64_t jitter_buffer_emitted_count = 0; double jitter_buffer_target_delay_seconds = 0.0; uint64_t inserted_samples_for_deceleration = 0; uint64_t removed_samples_for_acceleration = 0; @@ -537,6 +521,13 @@ struct VoiceReceiverInfo : public MediaReceiverInfo { // longer than 150 ms). int32_t interruption_count = 0; int32_t total_interruption_duration_ms = 0; + // Remote outbound stats derived by the received RTCP sender reports. 
+ // https://w3c.github.io/webrtc-stats/#remoteoutboundrtpstats-dict* + absl::optional last_sender_report_timestamp_ms; + absl::optional last_sender_report_remote_timestamp_ms; + uint32_t sender_reports_packets_sent = 0; + uint64_t sender_reports_bytes_sent = 0; + uint64_t sender_reports_reports_count = 0; }; struct VideoSenderInfo : public MediaSenderInfo { @@ -546,9 +537,9 @@ struct VideoSenderInfo : public MediaSenderInfo { std::string encoder_implementation_name; int firs_rcvd = 0; int plis_rcvd = 0; - int nacks_rcvd = 0; int send_frame_width = 0; int send_frame_height = 0; + int frames = 0; int framerate_input = 0; int framerate_sent = 0; int aggregated_framerate_sent = 0; @@ -590,7 +581,6 @@ struct VideoReceiverInfo : public MediaReceiverInfo { int packets_concealed = 0; int firs_sent = 0; int plis_sent = 0; - int nacks_sent = 0; int frame_width = 0; int frame_height = 0; int framerate_rcvd = 0; @@ -617,6 +607,7 @@ struct VideoReceiverInfo : public MediaReceiverInfo { uint32_t total_pauses_duration_ms = 0; uint32_t total_frames_duration_ms = 0; double sum_squared_frame_durations = 0.0; + uint32_t jitter_ms = 0; webrtc::VideoContentType content_type = webrtc::VideoContentType::UNSPECIFIED; @@ -630,12 +621,6 @@ struct VideoReceiverInfo : public MediaReceiverInfo { int max_decode_ms = 0; // Jitter (network-related) latency. int jitter_buffer_ms = 0; - // Jitter (network-related) latency (cumulative). - // https://w3c.github.io/webrtc-stats/#dom-rtcvideoreceiverstats-jitterbufferdelay - double jitter_buffer_delay_seconds = 0; - // Number of observations for cumulative jitter latency. - // https://w3c.github.io/webrtc-stats/#dom-rtcvideoreceiverstats-jitterbufferemittedcount - uint64_t jitter_buffer_emitted_count = 0; // Requested minimum playout latency. int min_playout_delay_ms = 0; // Requested latency to account for rendering delay. 
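With the cumulative counters consolidated into MediaReceiverInfo above, the spec-defined average jitter buffer delay is just the ratio of the two fields; a small helper sketch:

double AverageJitterBufferDelaySeconds(const cricket::MediaReceiverInfo& info) {
  if (info.jitter_buffer_emitted_count == 0)
    return 0.0;
  return info.jitter_buffer_delay_seconds /
         static_cast<double>(info.jitter_buffer_emitted_count);
}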
@@ -657,14 +642,6 @@ struct VideoReceiverInfo : public MediaReceiverInfo { absl::optional timing_frame_info; }; -struct DataSenderInfo : public MediaSenderInfo { - uint32_t ssrc = 0; -}; - -struct DataReceiverInfo : public MediaReceiverInfo { - uint32_t ssrc = 0; -}; - struct BandwidthEstimationInfo { int available_send_bandwidth = 0; int available_recv_bandwidth = 0; @@ -718,17 +695,6 @@ struct VideoMediaInfo { RtpCodecParametersMap receive_codecs; }; -struct DataMediaInfo { - DataMediaInfo(); - ~DataMediaInfo(); - void Clear() { - senders.clear(); - receivers.clear(); - } - std::vector senders; - std::vector receivers; -}; - struct RtcpParameters { bool reduced_size = false; bool remote_estimate = false; @@ -799,9 +765,11 @@ struct AudioRecvParameters : RtpParameters {}; class VoiceMediaChannel : public MediaChannel, public Delayable { public: - VoiceMediaChannel() {} - explicit VoiceMediaChannel(const MediaConfig& config) - : MediaChannel(config) {} + explicit VoiceMediaChannel(webrtc::TaskQueueBase* network_thread) + : MediaChannel(network_thread) {} + VoiceMediaChannel(const MediaConfig& config, + webrtc::TaskQueueBase* network_thread) + : MediaChannel(config, network_thread) {} ~VoiceMediaChannel() override {} cricket::MediaType media_type() const override; @@ -869,9 +837,11 @@ struct VideoRecvParameters : RtpParameters {}; class VideoMediaChannel : public MediaChannel, public Delayable { public: - VideoMediaChannel() {} - explicit VideoMediaChannel(const MediaConfig& config) - : MediaChannel(config) {} + explicit VideoMediaChannel(webrtc::TaskQueueBase* network_thread) + : MediaChannel(network_thread) {} + VideoMediaChannel(const MediaConfig& config, + webrtc::TaskQueueBase* network_thread) + : MediaChannel(config, network_thread) {} ~VideoMediaChannel() override {} cricket::MediaType media_type() const override; @@ -922,102 +892,21 @@ class VideoMediaChannel : public MediaChannel, public Delayable { virtual std::vector GetSources(uint32_t ssrc) const = 0; }; -enum DataMessageType { - // Chrome-Internal use only. See SctpDataMediaChannel for the actual PPID - // values. - DMT_NONE = 0, - DMT_CONTROL = 1, - DMT_BINARY = 2, - DMT_TEXT = 3, -}; - // Info about data received in DataMediaChannel. For use in // DataMediaChannel::SignalDataReceived and in all of the signals that // signal fires, on up the chain. struct ReceiveDataParams { // The in-packet stream indentifier. - // RTP data channels use SSRCs, SCTP data channels use SIDs. - union { - uint32_t ssrc; - int sid = 0; - }; + // SCTP data channels use SIDs. + int sid = 0; // The type of message (binary, text, or control). - DataMessageType type = DMT_TEXT; + webrtc::DataMessageType type = webrtc::DataMessageType::kText; // A per-stream value incremented per packet in the stream. int seq_num = 0; - // A per-stream value monotonically increasing with time. - int timestamp = 0; -}; - -struct SendDataParams { - // The in-packet stream indentifier. - // RTP data channels use SSRCs, SCTP data channels use SIDs. - union { - uint32_t ssrc; - int sid = 0; - }; - // The type of message (binary, text, or control). - DataMessageType type = DMT_TEXT; - - // TODO(pthatcher): Make |ordered| and |reliable| true by default? - // For SCTP, whether to send messages flagged as ordered or not. - // If false, messages can be received out of order. - bool ordered = false; - // For SCTP, whether the messages are sent reliably or not. - // If false, messages may be lost. 
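For context on the partial-reliability knobs deleted just below: with the old struct, reliability was relaxed either by retransmission count or by time, never both at once. A sketch using the fields as they were:

cricket::SendDataParams partial;
partial.sid = 1;
partial.ordered = false;   // delivery order not guaranteed
partial.reliable = false;  // enable partial reliability, then pick ONE of:
partial.max_rtx_count = 3; // retransmit at most 3 times...
// partial.max_rtx_ms = 500;  // ...or retransmit for at most 500 ms.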
- bool reliable = false; - // For SCTP, if reliable == false, provide partial reliability by - // resending up to this many times. Either count or millis - // is supported, not both at the same time. - int max_rtx_count = 0; - // For SCTP, if reliable == false, provide partial reliability by - // resending for up to this many milliseconds. Either count or millis - // is supported, not both at the same time. - int max_rtx_ms = 0; }; enum SendDataResult { SDR_SUCCESS, SDR_ERROR, SDR_BLOCK }; -struct DataSendParameters : RtpSendParameters {}; - -struct DataRecvParameters : RtpParameters {}; - -class DataMediaChannel : public MediaChannel { - public: - DataMediaChannel(); - explicit DataMediaChannel(const MediaConfig& config); - ~DataMediaChannel() override; - - cricket::MediaType media_type() const override; - virtual bool SetSendParameters(const DataSendParameters& params) = 0; - virtual bool SetRecvParameters(const DataRecvParameters& params) = 0; - - // RtpParameter methods are not supported for Data channel. - webrtc::RtpParameters GetRtpSendParameters(uint32_t ssrc) const override; - webrtc::RTCError SetRtpSendParameters( - uint32_t ssrc, - const webrtc::RtpParameters& parameters) override; - - // TODO(pthatcher): Implement this. - virtual bool GetStats(DataMediaInfo* info); - - virtual bool SetSend(bool send) = 0; - virtual bool SetReceive(bool receive) = 0; - - void OnNetworkRouteChanged(const std::string& transport_name, - const rtc::NetworkRoute& network_route) override {} - - virtual bool SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result = NULL) = 0; - // Signals when data is received (params, data, len) - sigslot::signal3 - SignalDataReceived; - // Signal when the media channel is ready to send the stream. 
Arguments are: - // writable(bool) - sigslot::signal1 SignalReadyToSend; -}; - } // namespace cricket #endif // MEDIA_BASE_MEDIA_CHANNEL_H_ diff --git a/media/base/media_constants.cc b/media/base/media_constants.cc index fb34ea8851..17a8a83bd0 100644 --- a/media/base/media_constants.cc +++ b/media/base/media_constants.cc @@ -13,14 +13,15 @@ namespace cricket { const int kVideoCodecClockrate = 90000; -const int kDataCodecClockrate = 90000; -const int kRtpDataMaxBandwidth = 30720; // bps + +const int kVideoMtu = 1200; +const int kVideoRtpSendBufferSize = 65536; +const int kVideoRtpRecvBufferSize = 262144; const float kHighSystemCpuThreshold = 0.85f; const float kLowSystemCpuThreshold = 0.65f; const float kProcessCpuThreshold = 0.10f; -const char kRtxCodecName[] = "rtx"; const char kRedCodecName[] = "red"; const char kUlpfecCodecName[] = "ulpfec"; const char kMultiplexCodecName[] = "multiplex"; @@ -32,7 +33,11 @@ const char kFlexfecCodecName[] = "flexfec-03"; // draft-ietf-payload-flexible-fec-scheme-02.txt const char kFlexfecFmtpRepairWindow[] = "repair-window"; +// RFC 4588 RTP Retransmission Payload Format +const char kRtxCodecName[] = "rtx"; +const char kCodecParamRtxTime[] = "rtx-time"; const char kCodecParamAssociatedPayloadType[] = "apt"; + const char kCodecParamAssociatedCodecName[] = "acn"; const char kOpusCodecName[] = "opus"; @@ -90,9 +95,6 @@ const char kCodecParamMinBitrate[] = "x-google-min-bitrate"; const char kCodecParamStartBitrate[] = "x-google-start-bitrate"; const char kCodecParamMaxQuantization[] = "x-google-max-quantization"; -const int kGoogleRtpDataCodecPlType = 109; -const char kGoogleRtpDataCodecName[] = "google-data"; - const char kComfortNoiseCodecName[] = "CN"; const char kVp8CodecName[] = "VP8"; diff --git a/media/base/media_constants.h b/media/base/media_constants.h index 6907172df2..bf7f0c3047 100644 --- a/media/base/media_constants.h +++ b/media/base/media_constants.h @@ -20,15 +20,16 @@ namespace cricket { extern const int kVideoCodecClockrate; -extern const int kDataCodecClockrate; -extern const int kRtpDataMaxBandwidth; // bps + +extern const int kVideoMtu; +extern const int kVideoRtpSendBufferSize; +extern const int kVideoRtpRecvBufferSize; // Default CPU thresholds. extern const float kHighSystemCpuThreshold; extern const float kLowSystemCpuThreshold; extern const float kProcessCpuThreshold; -extern const char kRtxCodecName[]; extern const char kRedCodecName[]; extern const char kUlpfecCodecName[]; extern const char kFlexfecCodecName[]; @@ -36,8 +37,10 @@ extern const char kMultiplexCodecName[]; extern const char kFlexfecFmtpRepairWindow[]; -// Codec parameters +extern const char kRtxCodecName[]; +extern const char kCodecParamRtxTime[]; extern const char kCodecParamAssociatedPayloadType[]; + extern const char kCodecParamAssociatedCodecName[]; extern const char kOpusCodecName[]; @@ -114,12 +117,6 @@ extern const char kCodecParamMinBitrate[]; extern const char kCodecParamStartBitrate[]; extern const char kCodecParamMaxQuantization[]; -// We put the data codec names here so callers of DataEngine::CreateChannel -// don't have to import rtpdataengine.h to get the codec names they want to -// pass in. 
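The RFC 4588 constants regrouped in media_constants above map directly onto the rtx fmtp attributes; a small sketch using SdpVideoFormat::Parameters, the map type these files now use:

// SDP:  a=rtpmap:97 rtx/90000
//       a=fmtp:97 apt=96;rtx-time=3000
webrtc::SdpVideoFormat::Parameters rtx_params = {
    {cricket::kCodecParamAssociatedPayloadType, "96"},  // "apt"
    {cricket::kCodecParamRtxTime, "3000"},              // "rtx-time"
};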
-extern const int kGoogleRtpDataCodecPlType; -extern const char kGoogleRtpDataCodecName[]; - extern const char kComfortNoiseCodecName[]; RTC_EXPORT extern const char kVp8CodecName[]; diff --git a/media/base/media_engine.h b/media/base/media_engine.h index 1d8917cfcb..6f47127f30 100644 --- a/media/base/media_engine.h +++ b/media/base/media_engine.h @@ -121,9 +121,9 @@ class MediaEngineInterface { public: virtual ~MediaEngineInterface() {} - // Initialization - // Starts the engine. + // Initialization. Needs to be called on the worker thread. virtual bool Init() = 0; + virtual VoiceEngineInterface& voice() = 0; virtual VideoEngineInterface& video() = 0; virtual const VoiceEngineInterface& voice() const = 0; @@ -141,6 +141,8 @@ class CompositeMediaEngine : public MediaEngineInterface { CompositeMediaEngine(std::unique_ptr audio_engine, std::unique_ptr video_engine); ~CompositeMediaEngine() override; + + // Always succeeds. bool Init() override; VoiceEngineInterface& voice() override; @@ -150,21 +152,8 @@ class CompositeMediaEngine : public MediaEngineInterface { private: const std::unique_ptr trials_; - std::unique_ptr voice_engine_; - std::unique_ptr video_engine_; -}; - -enum DataChannelType { - DCT_NONE = 0, - DCT_RTP = 1, - DCT_SCTP = 2, -}; - -class DataEngineInterface { - public: - virtual ~DataEngineInterface() {} - virtual DataMediaChannel* CreateChannel(const MediaConfig& config) = 0; - virtual const std::vector& data_codecs() = 0; + const std::unique_ptr voice_engine_; + const std::unique_ptr video_engine_; }; webrtc::RtpParameters CreateRtpParametersWithOneEncoding(); diff --git a/media/base/rtp_data_engine.cc b/media/base/rtp_data_engine.cc deleted file mode 100644 index 5fbb25f533..0000000000 --- a/media/base/rtp_data_engine.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "media/base/rtp_data_engine.h" - -#include - -#include "absl/strings/match.h" -#include "media/base/codec.h" -#include "media/base/media_constants.h" -#include "media/base/rtp_utils.h" -#include "media/base/stream_params.h" -#include "rtc_base/copy_on_write_buffer.h" -#include "rtc_base/data_rate_limiter.h" -#include "rtc_base/helpers.h" -#include "rtc_base/logging.h" -#include "rtc_base/sanitizer.h" - -namespace cricket { - -// We want to avoid IP fragmentation. -static const size_t kDataMaxRtpPacketLen = 1200U; -// We reserve space after the RTP header for future wiggle room. -static const unsigned char kReservedSpace[] = {0x00, 0x00, 0x00, 0x00}; - -// Amount of overhead SRTP may take. We need to leave room in the -// buffer for it, otherwise SRTP will fail later. If SRTP ever uses -// more than this, we need to increase this number. 
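The size budget enforced by the deleted engine's SendData() (further down) works out as follows, assuming the usual 12-byte kMinRtpPacketLen from rtp_utils.h:

//   kDataMaxRtpPacketLen   1200  (stay under the MTU, avoid IP fragmentation)
// - fixed RTP header         12
// - kReservedSpace            4
// - kMaxSrtpHmacOverhead     16
// = maximum payload        1168  bytes per RTP data packet
static_assert(1200 - 12 - 4 - 16 == 1168, "max RTP data payload");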
-static const size_t kMaxSrtpHmacOverhead = 16; - -RtpDataEngine::RtpDataEngine() { - data_codecs_.push_back( - DataCodec(kGoogleRtpDataCodecPlType, kGoogleRtpDataCodecName)); -} - -DataMediaChannel* RtpDataEngine::CreateChannel(const MediaConfig& config) { - return new RtpDataMediaChannel(config); -} - -static const DataCodec* FindCodecByName(const std::vector& codecs, - const std::string& name) { - for (const DataCodec& codec : codecs) { - if (absl::EqualsIgnoreCase(name, codec.name)) - return &codec; - } - return nullptr; -} - -RtpDataMediaChannel::RtpDataMediaChannel(const MediaConfig& config) - : DataMediaChannel(config) { - Construct(); - SetPreferredDscp(rtc::DSCP_AF41); -} - -void RtpDataMediaChannel::Construct() { - sending_ = false; - receiving_ = false; - send_limiter_.reset(new rtc::DataRateLimiter(kRtpDataMaxBandwidth / 8, 1.0)); -} - -RtpDataMediaChannel::~RtpDataMediaChannel() { - std::map::const_iterator iter; - for (iter = rtp_clock_by_send_ssrc_.begin(); - iter != rtp_clock_by_send_ssrc_.end(); ++iter) { - delete iter->second; - } -} - -void RTC_NO_SANITIZE("float-cast-overflow") // bugs.webrtc.org/8204 - RtpClock::Tick(double now, int* seq_num, uint32_t* timestamp) { - *seq_num = ++last_seq_num_; - *timestamp = timestamp_offset_ + static_cast(now * clockrate_); - // UBSan: 5.92374e+10 is outside the range of representable values of type - // 'unsigned int' -} - -const DataCodec* FindUnknownCodec(const std::vector& codecs) { - DataCodec data_codec(kGoogleRtpDataCodecPlType, kGoogleRtpDataCodecName); - std::vector::const_iterator iter; - for (iter = codecs.begin(); iter != codecs.end(); ++iter) { - if (!iter->Matches(data_codec)) { - return &(*iter); - } - } - return NULL; -} - -const DataCodec* FindKnownCodec(const std::vector& codecs) { - DataCodec data_codec(kGoogleRtpDataCodecPlType, kGoogleRtpDataCodecName); - std::vector::const_iterator iter; - for (iter = codecs.begin(); iter != codecs.end(); ++iter) { - if (iter->Matches(data_codec)) { - return &(*iter); - } - } - return NULL; -} - -bool RtpDataMediaChannel::SetRecvCodecs(const std::vector& codecs) { - const DataCodec* unknown_codec = FindUnknownCodec(codecs); - if (unknown_codec) { - RTC_LOG(LS_WARNING) << "Failed to SetRecvCodecs because of unknown codec: " - << unknown_codec->ToString(); - return false; - } - - recv_codecs_ = codecs; - return true; -} - -bool RtpDataMediaChannel::SetSendCodecs(const std::vector& codecs) { - const DataCodec* known_codec = FindKnownCodec(codecs); - if (!known_codec) { - RTC_LOG(LS_WARNING) - << "Failed to SetSendCodecs because there is no known codec."; - return false; - } - - send_codecs_ = codecs; - return true; -} - -bool RtpDataMediaChannel::SetSendParameters(const DataSendParameters& params) { - return (SetSendCodecs(params.codecs) && - SetMaxSendBandwidth(params.max_bandwidth_bps)); -} - -bool RtpDataMediaChannel::SetRecvParameters(const DataRecvParameters& params) { - return SetRecvCodecs(params.codecs); -} - -bool RtpDataMediaChannel::AddSendStream(const StreamParams& stream) { - if (!stream.has_ssrcs()) { - return false; - } - - if (GetStreamBySsrc(send_streams_, stream.first_ssrc())) { - RTC_LOG(LS_WARNING) << "Not adding data send stream '" << stream.id - << "' with ssrc=" << stream.first_ssrc() - << " because stream already exists."; - return false; - } - - send_streams_.push_back(stream); - // TODO(pthatcher): This should be per-stream, not per-ssrc. - // And we should probably allow more than one per stream. 
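RtpClock::Tick() above increments the sequence number and derives the timestamp as offset + now * clockrate; a quick sketch matching what the deleted unit tests further down expect at the 90 kHz data clock rate:

cricket::RtpClock clock(/*clockrate=*/90000, /*first_seq_num=*/100,
                        /*timestamp_offset=*/0);
int seq = 0;
uint32_t ts = 0;
clock.Tick(/*now=*/0.0, &seq, &ts);  // seq == 101, ts == 0
clock.Tick(/*now=*/2.0, &seq, &ts);  // seq == 102, ts == 180000 (2 s * 90000)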
- rtp_clock_by_send_ssrc_[stream.first_ssrc()] = - new RtpClock(kDataCodecClockrate, rtc::CreateRandomNonZeroId(), - rtc::CreateRandomNonZeroId()); - - RTC_LOG(LS_INFO) << "Added data send stream '" << stream.id - << "' with ssrc=" << stream.first_ssrc(); - return true; -} - -bool RtpDataMediaChannel::RemoveSendStream(uint32_t ssrc) { - if (!GetStreamBySsrc(send_streams_, ssrc)) { - return false; - } - - RemoveStreamBySsrc(&send_streams_, ssrc); - delete rtp_clock_by_send_ssrc_[ssrc]; - rtp_clock_by_send_ssrc_.erase(ssrc); - return true; -} - -bool RtpDataMediaChannel::AddRecvStream(const StreamParams& stream) { - if (!stream.has_ssrcs()) { - return false; - } - - if (GetStreamBySsrc(recv_streams_, stream.first_ssrc())) { - RTC_LOG(LS_WARNING) << "Not adding data recv stream '" << stream.id - << "' with ssrc=" << stream.first_ssrc() - << " because stream already exists."; - return false; - } - - recv_streams_.push_back(stream); - RTC_LOG(LS_INFO) << "Added data recv stream '" << stream.id - << "' with ssrc=" << stream.first_ssrc(); - return true; -} - -bool RtpDataMediaChannel::RemoveRecvStream(uint32_t ssrc) { - RemoveStreamBySsrc(&recv_streams_, ssrc); - return true; -} - -// Not implemented. -void RtpDataMediaChannel::ResetUnsignaledRecvStream() {} - -void RtpDataMediaChannel::OnPacketReceived(rtc::CopyOnWriteBuffer packet, - int64_t /* packet_time_us */) { - RtpHeader header; - if (!GetRtpHeader(packet.cdata(), packet.size(), &header)) { - return; - } - - size_t header_length; - if (!GetRtpHeaderLen(packet.cdata(), packet.size(), &header_length)) { - return; - } - const char* data = - packet.cdata() + header_length + sizeof(kReservedSpace); - size_t data_len = packet.size() - header_length - sizeof(kReservedSpace); - - if (!receiving_) { - RTC_LOG(LS_WARNING) << "Not receiving packet " << header.ssrc << ":" - << header.seq_num << " before SetReceive(true) called."; - return; - } - - if (!FindCodecById(recv_codecs_, header.payload_type)) { - return; - } - - if (!GetStreamBySsrc(recv_streams_, header.ssrc)) { - RTC_LOG(LS_WARNING) << "Received packet for unknown ssrc: " << header.ssrc; - return; - } - - // Uncomment this for easy debugging. - // const auto* found_stream = GetStreamBySsrc(recv_streams_, header.ssrc); - // RTC_LOG(LS_INFO) << "Received packet" - // << " groupid=" << found_stream.groupid - // << ", ssrc=" << header.ssrc - // << ", seqnum=" << header.seq_num - // << ", timestamp=" << header.timestamp - // << ", len=" << data_len; - - ReceiveDataParams params; - params.ssrc = header.ssrc; - params.seq_num = header.seq_num; - params.timestamp = header.timestamp; - SignalDataReceived(params, data, data_len); -} - -bool RtpDataMediaChannel::SetMaxSendBandwidth(int bps) { - if (bps <= 0) { - bps = kRtpDataMaxBandwidth; - } - send_limiter_.reset(new rtc::DataRateLimiter(bps / 8, 1.0)); - RTC_LOG(LS_INFO) << "RtpDataMediaChannel::SetSendBandwidth to " << bps - << "bps."; - return true; -} - -bool RtpDataMediaChannel::SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result) { - if (result) { - // If we return true, we'll set this to SDR_SUCCESS. 
- *result = SDR_ERROR; - } - if (!sending_) { - RTC_LOG(LS_WARNING) << "Not sending packet with ssrc=" << params.ssrc - << " len=" << payload.size() - << " before SetSend(true)."; - return false; - } - - if (params.type != cricket::DMT_TEXT) { - RTC_LOG(LS_WARNING) - << "Not sending data because binary type is unsupported."; - return false; - } - - const StreamParams* found_stream = - GetStreamBySsrc(send_streams_, params.ssrc); - if (!found_stream) { - RTC_LOG(LS_WARNING) << "Not sending data because ssrc is unknown: " - << params.ssrc; - return false; - } - - const DataCodec* found_codec = - FindCodecByName(send_codecs_, kGoogleRtpDataCodecName); - if (!found_codec) { - RTC_LOG(LS_WARNING) << "Not sending data because codec is unknown: " - << kGoogleRtpDataCodecName; - return false; - } - - size_t packet_len = (kMinRtpPacketLen + sizeof(kReservedSpace) + - payload.size() + kMaxSrtpHmacOverhead); - if (packet_len > kDataMaxRtpPacketLen) { - return false; - } - - double now = - rtc::TimeMicros() / static_cast(rtc::kNumMicrosecsPerSec); - - if (!send_limiter_->CanUse(packet_len, now)) { - RTC_LOG(LS_VERBOSE) << "Dropped data packet of len=" << packet_len - << "; already sent " << send_limiter_->used_in_period() - << "/" << send_limiter_->max_per_period(); - return false; - } - - RtpHeader header; - header.payload_type = found_codec->id; - header.ssrc = params.ssrc; - rtp_clock_by_send_ssrc_[header.ssrc]->Tick(now, &header.seq_num, - &header.timestamp); - - rtc::CopyOnWriteBuffer packet(kMinRtpPacketLen, packet_len); - if (!SetRtpHeader(packet.MutableData(), packet.size(), header)) { - return false; - } - packet.AppendData(kReservedSpace); - packet.AppendData(payload); - - RTC_LOG(LS_VERBOSE) << "Sent RTP data packet: " - " stream=" - << found_stream->id << " ssrc=" << header.ssrc - << ", seqnum=" << header.seq_num - << ", timestamp=" << header.timestamp - << ", len=" << payload.size(); - - rtc::PacketOptions options; - options.info_signaled_after_sent.packet_type = rtc::PacketType::kData; - MediaChannel::SendPacket(&packet, options); - send_limiter_->Use(packet_len, now); - if (result) { - *result = SDR_SUCCESS; - } - return true; -} - -} // namespace cricket diff --git a/media/base/rtp_data_engine.h b/media/base/rtp_data_engine.h deleted file mode 100644 index e5f071d5a9..0000000000 --- a/media/base/rtp_data_engine.h +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef MEDIA_BASE_RTP_DATA_ENGINE_H_ -#define MEDIA_BASE_RTP_DATA_ENGINE_H_ - -#include -#include -#include -#include - -#include "media/base/codec.h" -#include "media/base/media_channel.h" -#include "media/base/media_constants.h" -#include "media/base/media_engine.h" - -namespace rtc { -class DataRateLimiter; -} - -namespace cricket { - -class RtpDataEngine : public DataEngineInterface { - public: - RtpDataEngine(); - - virtual DataMediaChannel* CreateChannel(const MediaConfig& config); - - virtual const std::vector& data_codecs() { return data_codecs_; } - - private: - std::vector data_codecs_; -}; - -// Keep track of sequence number and timestamp of an RTP stream. The -// sequence number starts with a "random" value and increments. 
The -// timestamp starts with a "random" value and increases monotonically -// according to the clockrate. -class RtpClock { - public: - RtpClock(int clockrate, uint16_t first_seq_num, uint32_t timestamp_offset) - : clockrate_(clockrate), - last_seq_num_(first_seq_num), - timestamp_offset_(timestamp_offset) {} - - // Given the current time (in number of seconds which must be - // monotonically increasing), Return the next sequence number and - // timestamp. - void Tick(double now, int* seq_num, uint32_t* timestamp); - - private: - int clockrate_; - uint16_t last_seq_num_; - uint32_t timestamp_offset_; -}; - -class RtpDataMediaChannel : public DataMediaChannel { - public: - explicit RtpDataMediaChannel(const MediaConfig& config); - virtual ~RtpDataMediaChannel(); - - virtual bool SetSendParameters(const DataSendParameters& params); - virtual bool SetRecvParameters(const DataRecvParameters& params); - virtual bool AddSendStream(const StreamParams& sp); - virtual bool RemoveSendStream(uint32_t ssrc); - virtual bool AddRecvStream(const StreamParams& sp); - virtual bool RemoveRecvStream(uint32_t ssrc); - virtual void ResetUnsignaledRecvStream(); - virtual bool SetSend(bool send) { - sending_ = send; - return true; - } - virtual bool SetReceive(bool receive) { - receiving_ = receive; - return true; - } - virtual void OnPacketReceived(rtc::CopyOnWriteBuffer packet, - int64_t packet_time_us); - virtual void OnReadyToSend(bool ready) {} - virtual bool SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result); - - private: - void Construct(); - bool SetMaxSendBandwidth(int bps); - bool SetSendCodecs(const std::vector& codecs); - bool SetRecvCodecs(const std::vector& codecs); - - bool sending_; - bool receiving_; - std::vector send_codecs_; - std::vector recv_codecs_; - std::vector send_streams_; - std::vector recv_streams_; - std::map rtp_clock_by_send_ssrc_; - std::unique_ptr send_limiter_; -}; - -} // namespace cricket - -#endif // MEDIA_BASE_RTP_DATA_ENGINE_H_ diff --git a/media/base/rtp_data_engine_unittest.cc b/media/base/rtp_data_engine_unittest.cc deleted file mode 100644 index f01c7c60c7..0000000000 --- a/media/base/rtp_data_engine_unittest.cc +++ /dev/null @@ -1,362 +0,0 @@ -/* - * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "media/base/rtp_data_engine.h" - -#include - -#include -#include - -#include "media/base/fake_network_interface.h" -#include "media/base/media_constants.h" -#include "media/base/rtp_utils.h" -#include "rtc_base/copy_on_write_buffer.h" -#include "rtc_base/fake_clock.h" -#include "rtc_base/third_party/sigslot/sigslot.h" -#include "rtc_base/time_utils.h" -#include "test/gtest.h" - -class FakeDataReceiver : public sigslot::has_slots<> { - public: - FakeDataReceiver() : has_received_data_(false) {} - - void OnDataReceived(const cricket::ReceiveDataParams& params, - const char* data, - size_t len) { - has_received_data_ = true; - last_received_data_ = std::string(data, len); - last_received_data_len_ = len; - last_received_data_params_ = params; - } - - bool has_received_data() const { return has_received_data_; } - std::string last_received_data() const { return last_received_data_; } - size_t last_received_data_len() const { return last_received_data_len_; } - cricket::ReceiveDataParams last_received_data_params() const { - return last_received_data_params_; - } - - private: - bool has_received_data_; - std::string last_received_data_; - size_t last_received_data_len_; - cricket::ReceiveDataParams last_received_data_params_; -}; - -class RtpDataMediaChannelTest : public ::testing::Test { - protected: - virtual void SetUp() { - // Seed needed for each test to satisfy expectations. - iface_.reset(new cricket::FakeNetworkInterface()); - dme_.reset(CreateEngine()); - receiver_.reset(new FakeDataReceiver()); - } - - void SetNow(double now) { clock_.SetTime(webrtc::Timestamp::Seconds(now)); } - - cricket::RtpDataEngine* CreateEngine() { - cricket::RtpDataEngine* dme = new cricket::RtpDataEngine(); - return dme; - } - - cricket::RtpDataMediaChannel* CreateChannel() { - return CreateChannel(dme_.get()); - } - - cricket::RtpDataMediaChannel* CreateChannel(cricket::RtpDataEngine* dme) { - cricket::MediaConfig config; - cricket::RtpDataMediaChannel* channel = - static_cast(dme->CreateChannel(config)); - channel->SetInterface(iface_.get()); - channel->SignalDataReceived.connect(receiver_.get(), - &FakeDataReceiver::OnDataReceived); - return channel; - } - - FakeDataReceiver* receiver() { return receiver_.get(); } - - bool HasReceivedData() { return receiver_->has_received_data(); } - - std::string GetReceivedData() { return receiver_->last_received_data(); } - - size_t GetReceivedDataLen() { return receiver_->last_received_data_len(); } - - cricket::ReceiveDataParams GetReceivedDataParams() { - return receiver_->last_received_data_params(); - } - - bool HasSentData(int count) { return (iface_->NumRtpPackets() > count); } - - std::string GetSentData(int index) { - // Assume RTP header of length 12 - std::unique_ptr packet( - iface_->GetRtpPacket(index)); - if (packet->size() > 12) { - return std::string(packet->data() + 12, packet->size() - 12); - } else { - return ""; - } - } - - cricket::RtpHeader GetSentDataHeader(int index) { - std::unique_ptr packet( - iface_->GetRtpPacket(index)); - cricket::RtpHeader header; - GetRtpHeader(packet->data(), packet->size(), &header); - return header; - } - - private: - std::unique_ptr dme_; - rtc::ScopedFakeClock clock_; - std::unique_ptr iface_; - std::unique_ptr receiver_; -}; - -TEST_F(RtpDataMediaChannelTest, SetUnknownCodecs) { - std::unique_ptr dmc(CreateChannel()); - - cricket::DataCodec known_codec; - known_codec.id = 103; - known_codec.name = "google-data"; - cricket::DataCodec unknown_codec; - unknown_codec.id = 104; - unknown_codec.name = 
"unknown-data"; - - cricket::DataSendParameters send_parameters_known; - send_parameters_known.codecs.push_back(known_codec); - cricket::DataRecvParameters recv_parameters_known; - recv_parameters_known.codecs.push_back(known_codec); - - cricket::DataSendParameters send_parameters_unknown; - send_parameters_unknown.codecs.push_back(unknown_codec); - cricket::DataRecvParameters recv_parameters_unknown; - recv_parameters_unknown.codecs.push_back(unknown_codec); - - cricket::DataSendParameters send_parameters_mixed; - send_parameters_mixed.codecs.push_back(known_codec); - send_parameters_mixed.codecs.push_back(unknown_codec); - cricket::DataRecvParameters recv_parameters_mixed; - recv_parameters_mixed.codecs.push_back(known_codec); - recv_parameters_mixed.codecs.push_back(unknown_codec); - - EXPECT_TRUE(dmc->SetSendParameters(send_parameters_known)); - EXPECT_FALSE(dmc->SetSendParameters(send_parameters_unknown)); - EXPECT_TRUE(dmc->SetSendParameters(send_parameters_mixed)); - EXPECT_TRUE(dmc->SetRecvParameters(recv_parameters_known)); - EXPECT_FALSE(dmc->SetRecvParameters(recv_parameters_unknown)); - EXPECT_FALSE(dmc->SetRecvParameters(recv_parameters_mixed)); -} - -TEST_F(RtpDataMediaChannelTest, AddRemoveSendStream) { - std::unique_ptr dmc(CreateChannel()); - - cricket::StreamParams stream1; - stream1.add_ssrc(41); - EXPECT_TRUE(dmc->AddSendStream(stream1)); - cricket::StreamParams stream2; - stream2.add_ssrc(42); - EXPECT_TRUE(dmc->AddSendStream(stream2)); - - EXPECT_TRUE(dmc->RemoveSendStream(41)); - EXPECT_TRUE(dmc->RemoveSendStream(42)); - EXPECT_FALSE(dmc->RemoveSendStream(43)); -} - -TEST_F(RtpDataMediaChannelTest, AddRemoveRecvStream) { - std::unique_ptr dmc(CreateChannel()); - - cricket::StreamParams stream1; - stream1.add_ssrc(41); - EXPECT_TRUE(dmc->AddRecvStream(stream1)); - cricket::StreamParams stream2; - stream2.add_ssrc(42); - EXPECT_TRUE(dmc->AddRecvStream(stream2)); - EXPECT_FALSE(dmc->AddRecvStream(stream2)); - - EXPECT_TRUE(dmc->RemoveRecvStream(41)); - EXPECT_TRUE(dmc->RemoveRecvStream(42)); -} - -TEST_F(RtpDataMediaChannelTest, SendData) { - std::unique_ptr dmc(CreateChannel()); - - cricket::SendDataParams params; - params.ssrc = 42; - unsigned char data[] = "food"; - rtc::CopyOnWriteBuffer payload(data, 4); - unsigned char padded_data[] = { - 0x00, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'd', - }; - cricket::SendDataResult result; - - // Not sending - EXPECT_FALSE(dmc->SendData(params, payload, &result)); - EXPECT_EQ(cricket::SDR_ERROR, result); - EXPECT_FALSE(HasSentData(0)); - ASSERT_TRUE(dmc->SetSend(true)); - - // Unknown stream name. - EXPECT_FALSE(dmc->SendData(params, payload, &result)); - EXPECT_EQ(cricket::SDR_ERROR, result); - EXPECT_FALSE(HasSentData(0)); - - cricket::StreamParams stream; - stream.add_ssrc(42); - ASSERT_TRUE(dmc->AddSendStream(stream)); - - // Unknown codec; - EXPECT_FALSE(dmc->SendData(params, payload, &result)); - EXPECT_EQ(cricket::SDR_ERROR, result); - EXPECT_FALSE(HasSentData(0)); - - cricket::DataCodec codec; - codec.id = 103; - codec.name = cricket::kGoogleRtpDataCodecName; - cricket::DataSendParameters parameters; - parameters.codecs.push_back(codec); - ASSERT_TRUE(dmc->SetSendParameters(parameters)); - - // Length too large; - std::string x10000(10000, 'x'); - EXPECT_FALSE(dmc->SendData( - params, rtc::CopyOnWriteBuffer(x10000.data(), x10000.length()), &result)); - EXPECT_EQ(cricket::SDR_ERROR, result); - EXPECT_FALSE(HasSentData(0)); - - // Finally works! 
- EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_EQ(cricket::SDR_SUCCESS, result); - ASSERT_TRUE(HasSentData(0)); - EXPECT_EQ(sizeof(padded_data), GetSentData(0).length()); - EXPECT_EQ(0, memcmp(padded_data, GetSentData(0).data(), sizeof(padded_data))); - cricket::RtpHeader header0 = GetSentDataHeader(0); - EXPECT_NE(0, header0.seq_num); - EXPECT_NE(0U, header0.timestamp); - EXPECT_EQ(header0.ssrc, 42U); - EXPECT_EQ(header0.payload_type, 103); - - // Should bump timestamp by 180000 because the clock rate is 90khz. - SetNow(2); - - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - ASSERT_TRUE(HasSentData(1)); - EXPECT_EQ(sizeof(padded_data), GetSentData(1).length()); - EXPECT_EQ(0, memcmp(padded_data, GetSentData(1).data(), sizeof(padded_data))); - cricket::RtpHeader header1 = GetSentDataHeader(1); - EXPECT_EQ(header1.ssrc, 42U); - EXPECT_EQ(header1.payload_type, 103); - EXPECT_EQ(static_cast(header0.seq_num + 1), - static_cast(header1.seq_num)); - EXPECT_EQ(header0.timestamp + 180000, header1.timestamp); -} - -TEST_F(RtpDataMediaChannelTest, SendDataRate) { - std::unique_ptr dmc(CreateChannel()); - - ASSERT_TRUE(dmc->SetSend(true)); - - cricket::DataCodec codec; - codec.id = 103; - codec.name = cricket::kGoogleRtpDataCodecName; - cricket::DataSendParameters parameters; - parameters.codecs.push_back(codec); - ASSERT_TRUE(dmc->SetSendParameters(parameters)); - - cricket::StreamParams stream; - stream.add_ssrc(42); - ASSERT_TRUE(dmc->AddSendStream(stream)); - - cricket::SendDataParams params; - params.ssrc = 42; - unsigned char data[] = "food"; - rtc::CopyOnWriteBuffer payload(data, 4); - cricket::SendDataResult result; - - // With rtp overhead of 32 bytes, each one of our packets is 36 - // bytes, or 288 bits. So, a limit of 872bps will allow 3 packets, - // but not four. - parameters.max_bandwidth_bps = 872; - ASSERT_TRUE(dmc->SetSendParameters(parameters)); - - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_FALSE(dmc->SendData(params, payload, &result)); - EXPECT_FALSE(dmc->SendData(params, payload, &result)); - - SetNow(0.9); - EXPECT_FALSE(dmc->SendData(params, payload, &result)); - - SetNow(1.1); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - SetNow(1.9); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - - SetNow(2.2); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_TRUE(dmc->SendData(params, payload, &result)); - EXPECT_FALSE(dmc->SendData(params, payload, &result)); -} - -TEST_F(RtpDataMediaChannelTest, ReceiveData) { - // PT= 103, SN=2, TS=3, SSRC = 4, data = "abcde" - unsigned char data[] = {0x80, 0x67, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x2A, 0x00, 0x00, - 0x00, 0x00, 'a', 'b', 'c', 'd', 'e'}; - rtc::CopyOnWriteBuffer packet(data, sizeof(data)); - - std::unique_ptr dmc(CreateChannel()); - - // SetReceived not called. 
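Editorial aside, outside the patch: the expectations in the deleted SendDataRate test above follow from arithmetic that is easy to lose in the test body, and the 180000-tick jump in the SendData test is just two seconds at the 90 kHz RTP clock. A minimal, self-contained sketch of both:

#include <cstddef>
#include <cstdio>

int main() {
  // Each packet carries the 4-byte payload "food" plus the 32-byte RTP
  // overhead the test assumes, i.e. 36 bytes = 288 bits per packet.
  const size_t kPacketBits = (4 + 32) * 8;
  const size_t kBudgetBps = 872;
  // 872 / 288 = 3: three packets fit into one second of budget, a fourth does
  // not, which is exactly the EXPECT_TRUE/EXPECT_FALSE pattern above.
  printf("packets per second: %zu\n", kBudgetBps / kPacketBits);
  // Timestamp step in the SendData test: 90 kHz clock, samples 2 seconds apart.
  printf("timestamp delta: %d\n", 90000 * 2);
  return 0;
}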
- dmc->OnPacketReceived(packet, /* packet_time_us */ -1); - EXPECT_FALSE(HasReceivedData()); - - dmc->SetReceive(true); - - // Unknown payload id - dmc->OnPacketReceived(packet, /* packet_time_us */ -1); - EXPECT_FALSE(HasReceivedData()); - - cricket::DataCodec codec; - codec.id = 103; - codec.name = cricket::kGoogleRtpDataCodecName; - cricket::DataRecvParameters parameters; - parameters.codecs.push_back(codec); - ASSERT_TRUE(dmc->SetRecvParameters(parameters)); - - // Unknown stream - dmc->OnPacketReceived(packet, /* packet_time_us */ -1); - EXPECT_FALSE(HasReceivedData()); - - cricket::StreamParams stream; - stream.add_ssrc(42); - ASSERT_TRUE(dmc->AddRecvStream(stream)); - - // Finally works! - dmc->OnPacketReceived(packet, /* packet_time_us */ -1); - EXPECT_TRUE(HasReceivedData()); - EXPECT_EQ("abcde", GetReceivedData()); - EXPECT_EQ(5U, GetReceivedDataLen()); -} - -TEST_F(RtpDataMediaChannelTest, InvalidRtpPackets) { - unsigned char data[] = {0x80, 0x65, 0x00, 0x02}; - rtc::CopyOnWriteBuffer packet(data, sizeof(data)); - - std::unique_ptr dmc(CreateChannel()); - - // Too short - dmc->OnPacketReceived(packet, /* packet_time_us */ -1); - EXPECT_FALSE(HasReceivedData()); -} diff --git a/media/base/rtp_utils.cc b/media/base/rtp_utils.cc index 4714175226..9f90c468f7 100644 --- a/media/base/rtp_utils.cc +++ b/media/base/rtp_utils.cc @@ -17,6 +17,7 @@ // PacketTimeUpdateParams is defined in asyncpacketsocket.h. // TODO(sergeyu): Find more appropriate place for PacketTimeUpdateParams. #include "media/base/turn_utils.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/byte_order.h" #include "rtc_base/checks.h" @@ -24,8 +25,6 @@ namespace cricket { -static const uint8_t kRtpVersion = 2; -static const size_t kRtpFlagsOffset = 0; static const size_t kRtpPayloadTypeOffset = 1; static const size_t kRtpSeqNumOffset = 2; static const size_t kRtpTimestampOffset = 4; @@ -119,8 +118,6 @@ void UpdateRtpAuthTag(uint8_t* rtp, memcpy(auth_tag, output, tag_length); } -} // namespace - bool GetUint8(const void* data, size_t offset, int* value) { if (!data || !value) { return false; @@ -146,36 +143,7 @@ bool GetUint32(const void* data, size_t offset, uint32_t* value) { return true; } -bool SetUint8(void* data, size_t offset, uint8_t value) { - if (!data) { - return false; - } - rtc::Set8(data, offset, value); - return true; -} - -bool SetUint16(void* data, size_t offset, uint16_t value) { - if (!data) { - return false; - } - rtc::SetBE16(static_cast(data) + offset, value); - return true; -} - -bool SetUint32(void* data, size_t offset, uint32_t value) { - if (!data) { - return false; - } - rtc::SetBE32(static_cast(data) + offset, value); - return true; -} - -bool GetRtpFlags(const void* data, size_t len, int* value) { - if (len < kMinRtpPacketLen) { - return false; - } - return GetUint8(data, kRtpFlagsOffset, value); -} +} // namespace bool GetRtpPayloadType(const void* data, size_t len, int* value) { if (len < kMinRtpPacketLen) { @@ -209,34 +177,6 @@ bool GetRtpSsrc(const void* data, size_t len, uint32_t* value) { return GetUint32(data, kRtpSsrcOffset, value); } -bool GetRtpHeaderLen(const void* data, size_t len, size_t* value) { - if (!data || len < kMinRtpPacketLen || !value) - return false; - const uint8_t* header = static_cast(data); - // Get base header size + length of CSRCs (not counting extension yet). 
- size_t header_size = kMinRtpPacketLen + (header[0] & 0xF) * sizeof(uint32_t); - if (len < header_size) - return false; - // If there's an extension, read and add in the extension size. - if (header[0] & 0x10) { - if (len < header_size + sizeof(uint32_t)) - return false; - header_size += - ((rtc::GetBE16(header + header_size + 2) + 1) * sizeof(uint32_t)); - if (len < header_size) - return false; - } - *value = header_size; - return true; -} - -bool GetRtpHeader(const void* data, size_t len, RtpHeader* header) { - return (GetRtpPayloadType(data, len, &(header->payload_type)) && - GetRtpSeqNum(data, len, &(header->seq_num)) && - GetRtpTimestamp(data, len, &(header->timestamp)) && - GetRtpSsrc(data, len, &(header->ssrc))); -} - bool GetRtcpType(const void* data, size_t len, int* value) { if (len < kMinRtcpPacketLen) { return false; @@ -261,47 +201,6 @@ bool GetRtcpSsrc(const void* data, size_t len, uint32_t* value) { return true; } -bool SetRtpSsrc(void* data, size_t len, uint32_t value) { - return SetUint32(data, kRtpSsrcOffset, value); -} - -// Assumes version 2, no padding, no extensions, no csrcs. -bool SetRtpHeader(void* data, size_t len, const RtpHeader& header) { - if (!IsValidRtpPayloadType(header.payload_type) || header.seq_num < 0 || - header.seq_num > static_cast(UINT16_MAX)) { - return false; - } - return (SetUint8(data, kRtpFlagsOffset, kRtpVersion << 6) && - SetUint8(data, kRtpPayloadTypeOffset, header.payload_type & 0x7F) && - SetUint16(data, kRtpSeqNumOffset, - static_cast(header.seq_num)) && - SetUint32(data, kRtpTimestampOffset, header.timestamp) && - SetRtpSsrc(data, len, header.ssrc)); -} - -static bool HasCorrectRtpVersion(rtc::ArrayView packet) { - return packet.data()[0] >> 6 == kRtpVersion; -} - -bool IsRtpPacket(rtc::ArrayView packet) { - return packet.size() >= kMinRtpPacketLen && - HasCorrectRtpVersion( - rtc::reinterpret_array_view(packet)); -} - -// Check the RTP payload type. If 63 < payload type < 96, it's RTCP. -// For additional details, see http://tools.ietf.org/html/rfc5761. -bool IsRtcpPacket(rtc::ArrayView packet) { - if (packet.size() < kMinRtcpPacketLen || - !HasCorrectRtpVersion( - rtc::reinterpret_array_view(packet))) { - return false; - } - - char pt = packet[1] & 0x7F; - return (63 < pt) && (pt < 96); -} - bool IsValidRtpPayloadType(int payload_type) { return payload_type >= 0 && payload_type <= 127; } @@ -327,11 +226,11 @@ absl::string_view RtpPacketTypeToString(RtpPacketType packet_type) { } RtpPacketType InferRtpPacketType(rtc::ArrayView packet) { - // RTCP packets are RTP packets so must check that first. - if (IsRtcpPacket(packet)) { + if (webrtc::IsRtcpPacket( + rtc::reinterpret_array_view(packet))) { return RtpPacketType::kRtcp; } - if (IsRtpPacket(packet)) { + if (webrtc::IsRtpPacket(rtc::reinterpret_array_view(packet))) { return RtpPacketType::kRtp; } return RtpPacketType::kUnknown; @@ -532,7 +431,7 @@ bool ApplyPacketOptions(uint8_t* data, // Making sure we have a valid RTP packet at the end. 
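Editorial sketch, outside the patch: the rule that the deleted GetRtpHeaderLen() above encoded, kept here as a standalone reference now that the cricket helper is gone. The fixed RTP header is 12 bytes, the low nibble of byte 0 counts 4-byte CSRCs, and the X bit (0x10) adds a 4-byte extension header whose 16-bit length field counts further 32-bit words.

#include <cstddef>
#include <cstdint>
#include <optional>

// Returns the RTP header length of |packet|, or nullopt if it is truncated.
std::optional<size_t> RtpHeaderLength(const uint8_t* packet, size_t size) {
  constexpr size_t kFixedHeader = 12;
  if (size < kFixedHeader) return std::nullopt;
  size_t header = kFixedHeader + (packet[0] & 0x0F) * 4;  // CSRC count.
  if (packet[0] & 0x10) {                                 // Extension bit.
    if (size < header + 4) return std::nullopt;
    const uint16_t ext_words = (packet[header + 2] << 8) | packet[header + 3];
    header += 4 + ext_words * 4;  // Extension header plus its payload words.
  }
  if (size < header) return std::nullopt;
  return header;
}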
auto packet = rtc::MakeArrayView(data + rtp_start_pos, rtp_length); - if (!IsRtpPacket(rtc::reinterpret_array_view(packet)) || + if (!webrtc::IsRtpPacket(packet) || !ValidateRtpHeader(data + rtp_start_pos, rtp_length, nullptr)) { RTC_NOTREACHED(); return false; diff --git a/media/base/rtp_utils.h b/media/base/rtp_utils.h index 9ef9f9c7ba..f6b5dbc9f0 100644 --- a/media/base/rtp_utils.h +++ b/media/base/rtp_utils.h @@ -26,13 +26,6 @@ const size_t kMinRtpPacketLen = 12; const size_t kMaxRtpPacketLen = 2048; const size_t kMinRtcpPacketLen = 4; -struct RtpHeader { - int payload_type; - int seq_num; - uint32_t timestamp; - uint32_t ssrc; -}; - enum RtcpTypes { kRtcpTypeSR = 200, // Sender report payload type. kRtcpTypeRR = 201, // Receiver report payload type. @@ -53,18 +46,10 @@ bool GetRtpPayloadType(const void* data, size_t len, int* value); bool GetRtpSeqNum(const void* data, size_t len, int* value); bool GetRtpTimestamp(const void* data, size_t len, uint32_t* value); bool GetRtpSsrc(const void* data, size_t len, uint32_t* value); -bool GetRtpHeaderLen(const void* data, size_t len, size_t* value); + bool GetRtcpType(const void* data, size_t len, int* value); bool GetRtcpSsrc(const void* data, size_t len, uint32_t* value); -bool GetRtpHeader(const void* data, size_t len, RtpHeader* header); - -bool SetRtpSsrc(void* data, size_t len, uint32_t value); -// Assumes version 2, no padding, no extensions, no csrcs. -bool SetRtpHeader(void* data, size_t len, const RtpHeader& header); - -bool IsRtpPacket(rtc::ArrayView packet); -bool IsRtcpPacket(rtc::ArrayView packet); // Checks the packet header to determine if it can be an RTP or RTCP packet. RtpPacketType InferRtpPacketType(rtc::ArrayView packet); // True if |payload type| is 0-127. diff --git a/media/base/rtp_utils_unittest.cc b/media/base/rtp_utils_unittest.cc index a5e8a810f4..14599abca2 100644 --- a/media/base/rtp_utils_unittest.cc +++ b/media/base/rtp_utils_unittest.cc @@ -23,24 +23,7 @@ namespace cricket { static const uint8_t kRtpPacketWithMarker[] = { 0x80, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; -// 3 CSRCs (0x01020304, 0x12345678, 0xAABBCCDD) -// Extension (0xBEDE, 0x1122334455667788) -static const uint8_t kRtpPacketWithMarkerAndCsrcAndExtension[] = { - 0x93, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x01, 0x02, 0x03, 0x04, 0x12, 0x34, 0x56, 0x78, 0xAA, 0xBB, 0xCC, 0xDD, - 0xBE, 0xDE, 0x00, 0x02, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}; static const uint8_t kInvalidPacket[] = {0x80, 0x00}; -static const uint8_t kInvalidPacketWithCsrc[] = { - 0x83, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x01, 0x02, 0x03, 0x04, 0x12, 0x34, 0x56, 0x78, 0xAA, 0xBB, 0xCC}; -static const uint8_t kInvalidPacketWithCsrcAndExtension1[] = { - 0x93, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, 0x12, 0x34, - 0x56, 0x78, 0xAA, 0xBB, 0xCC, 0xDD, 0xBE, 0xDE, 0x00}; -static const uint8_t kInvalidPacketWithCsrcAndExtension2[] = { - 0x93, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x01, 0x02, 0x03, 0x04, 0x12, 0x34, 0x56, 0x78, 0xAA, 0xBB, 0xCC, 0xDD, - 0xBE, 0xDE, 0x00, 0x02, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77}; // PT = 206, FMT = 1, Sender SSRC = 0x1111, Media SSRC = 0x1111 // No FCI information is needed for PLI. 
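Editorial sketch, outside the patch: the demultiplexing rule that the deleted IsRtpPacket/IsRtcpPacket pair implemented and that the webrtc::IsRtpPacket/IsRtcpPacket replacements used above are expected to preserve. Both packet types start with version 2 in the top two bits; RTCP is recognized by a payload-type byte whose low 7 bits fall in the 64-95 range (RFC 5761), and anything else with a valid version and a full 12-byte header is treated as RTP.

#include <cstddef>
#include <cstdint>

enum class PacketKind { kRtp, kRtcp, kUnknown };

PacketKind Classify(const uint8_t* packet, size_t size) {
  constexpr size_t kMinRtcpLen = 4;
  constexpr size_t kMinRtpLen = 12;
  if (size < kMinRtcpLen || (packet[0] >> 6) != 2) return PacketKind::kUnknown;
  const uint8_t pt = packet[1] & 0x7F;  // Strip the marker / high bit.
  if (pt > 63 && pt < 96) return PacketKind::kRtcp;  // RFC 5761 RTCP range.
  return size >= kMinRtpLen ? PacketKind::kRtp : PacketKind::kUnknown;
}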
@@ -102,8 +85,6 @@ static const rtc::ArrayView kInvalidPacketArrayView = sizeof(kInvalidPacket)); TEST(RtpUtilsTest, GetRtp) { - EXPECT_TRUE(IsRtpPacket(kPcmuFrameArrayView)); - int pt; EXPECT_TRUE(GetRtpPayloadType(kPcmuFrame, sizeof(kPcmuFrame), &pt)); EXPECT_EQ(0, pt); @@ -123,59 +104,12 @@ TEST(RtpUtilsTest, GetRtp) { EXPECT_TRUE(GetRtpSsrc(kPcmuFrame, sizeof(kPcmuFrame), &ssrc)); EXPECT_EQ(1u, ssrc); - RtpHeader header; - EXPECT_TRUE(GetRtpHeader(kPcmuFrame, sizeof(kPcmuFrame), &header)); - EXPECT_EQ(0, header.payload_type); - EXPECT_EQ(1, header.seq_num); - EXPECT_EQ(0u, header.timestamp); - EXPECT_EQ(1u, header.ssrc); - EXPECT_FALSE(GetRtpPayloadType(kInvalidPacket, sizeof(kInvalidPacket), &pt)); EXPECT_FALSE(GetRtpSeqNum(kInvalidPacket, sizeof(kInvalidPacket), &seq_num)); EXPECT_FALSE(GetRtpTimestamp(kInvalidPacket, sizeof(kInvalidPacket), &ts)); EXPECT_FALSE(GetRtpSsrc(kInvalidPacket, sizeof(kInvalidPacket), &ssrc)); } -TEST(RtpUtilsTest, SetRtpHeader) { - uint8_t packet[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; - - RtpHeader header = {9, 1111, 2222u, 3333u}; - EXPECT_TRUE(SetRtpHeader(packet, sizeof(packet), header)); - - // Bits: 10 0 0 0000 - EXPECT_EQ(128u, packet[0]); - size_t len; - EXPECT_TRUE(GetRtpHeaderLen(packet, sizeof(packet), &len)); - EXPECT_EQ(12U, len); - EXPECT_TRUE(GetRtpHeader(packet, sizeof(packet), &header)); - EXPECT_EQ(9, header.payload_type); - EXPECT_EQ(1111, header.seq_num); - EXPECT_EQ(2222u, header.timestamp); - EXPECT_EQ(3333u, header.ssrc); -} - -TEST(RtpUtilsTest, GetRtpHeaderLen) { - size_t len; - EXPECT_TRUE(GetRtpHeaderLen(kPcmuFrame, sizeof(kPcmuFrame), &len)); - EXPECT_EQ(12U, len); - - EXPECT_TRUE(GetRtpHeaderLen(kRtpPacketWithMarkerAndCsrcAndExtension, - sizeof(kRtpPacketWithMarkerAndCsrcAndExtension), - &len)); - EXPECT_EQ(sizeof(kRtpPacketWithMarkerAndCsrcAndExtension), len); - - EXPECT_FALSE(GetRtpHeaderLen(kInvalidPacket, sizeof(kInvalidPacket), &len)); - EXPECT_FALSE(GetRtpHeaderLen(kInvalidPacketWithCsrc, - sizeof(kInvalidPacketWithCsrc), &len)); - EXPECT_FALSE(GetRtpHeaderLen(kInvalidPacketWithCsrcAndExtension1, - sizeof(kInvalidPacketWithCsrcAndExtension1), - &len)); - EXPECT_FALSE(GetRtpHeaderLen(kInvalidPacketWithCsrcAndExtension2, - sizeof(kInvalidPacketWithCsrcAndExtension2), - &len)); -} - TEST(RtpUtilsTest, GetRtcp) { int pt; EXPECT_TRUE(GetRtcpType(kRtcpReport, sizeof(kRtcpReport), &pt)); diff --git a/media/base/sdp_fmtp_utils.cc b/media/base/sdp_fmtp_utils.cc deleted file mode 100644 index 4ffc3b9696..0000000000 --- a/media/base/sdp_fmtp_utils.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "media/base/sdp_fmtp_utils.h" - -#include -#include - -#include "rtc_base/string_to_number.h" - -namespace webrtc { -namespace { -// Max frame rate for VP8 and VP9 video. -const char kVPxFmtpMaxFrameRate[] = "max-fr"; -// Max frame size for VP8 and VP9 video. 
-const char kVPxFmtpMaxFrameSize[] = "max-fs"; -const int kVPxFmtpFrameSizeSubBlockPixels = 256; - -absl::optional ParsePositiveNumberFromParams( - const SdpVideoFormat::Parameters& params, - const char* parameter_name) { - const auto max_frame_rate_it = params.find(parameter_name); - if (max_frame_rate_it == params.end()) - return absl::nullopt; - - const absl::optional i = - rtc::StringToNumber(max_frame_rate_it->second); - if (!i.has_value() || i.value() <= 0) - return absl::nullopt; - return i; -} - -} // namespace - -absl::optional ParseSdpForVPxMaxFrameRate( - const SdpVideoFormat::Parameters& params) { - return ParsePositiveNumberFromParams(params, kVPxFmtpMaxFrameRate); -} - -absl::optional ParseSdpForVPxMaxFrameSize( - const SdpVideoFormat::Parameters& params) { - const absl::optional i = - ParsePositiveNumberFromParams(params, kVPxFmtpMaxFrameSize); - return i ? absl::make_optional(i.value() * kVPxFmtpFrameSizeSubBlockPixels) - : absl::nullopt; -} - -} // namespace webrtc diff --git a/media/base/sdp_fmtp_utils.h b/media/base/sdp_fmtp_utils.h deleted file mode 100644 index 04e9183614..0000000000 --- a/media/base/sdp_fmtp_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef MEDIA_BASE_SDP_FMTP_UTILS_H_ -#define MEDIA_BASE_SDP_FMTP_UTILS_H_ - -#include "absl/types/optional.h" -#include "api/video_codecs/sdp_video_format.h" - -namespace webrtc { - -// Parse max frame rate from SDP FMTP line. absl::nullopt is returned if the -// field is missing or not a number. -absl::optional ParseSdpForVPxMaxFrameRate( - const SdpVideoFormat::Parameters& params); - -// Parse max frame size from SDP FMTP line. absl::nullopt is returned if the -// field is missing or not a number. Please note that the value is stored in sub -// blocks but the returned value is in total number of pixels. -absl::optional ParseSdpForVPxMaxFrameSize( - const SdpVideoFormat::Parameters& params); - -} // namespace webrtc - -#endif // MEDIA_BASE_SDP_FMTP_UTILS_H__ diff --git a/media/base/sdp_fmtp_utils_unittest.cc b/media/base/sdp_fmtp_utils_unittest.cc deleted file mode 100644 index 0ff12ffbe1..0000000000 --- a/media/base/sdp_fmtp_utils_unittest.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "media/base/sdp_fmtp_utils.h" - -#include -#include -#include - -#include "rtc_base/string_to_number.h" -#include "test/gtest.h" - -namespace webrtc { -namespace { -// Max frame rate for VP8 and VP9 video. -const char kVPxFmtpMaxFrameRate[] = "max-fr"; -// Max frame size for VP8 and VP9 video. 
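Editorial note, outside the patch: the max-fs value parsed by these helpers is expressed in 16x16-pixel sub-blocks (kVPxFmtpFrameSizeSubBlockPixels = 256), which is why the tests multiply by 256 to recover total pixels. A quick self-contained check of the two values the tests use:

#include <cassert>

int main() {
  const int kSubBlockPixels = 256;  // 16 x 16 pixels per sub-block.
  // max-fs=8100  -> 8100 * 256  = 2,073,600 = 1920 * 1080.
  assert(8100 * kSubBlockPixels == 1920 * 1080);
  // max-fs=32400 -> 32400 * 256 = 8,294,400 = 3840 * 2160.
  assert(32400 * kSubBlockPixels == 3840 * 2160);
  return 0;
}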
-const char kVPxFmtpMaxFrameSize[] = "max-fs"; -} // namespace - -TEST(SdpFmtpUtilsTest, MaxFrameRateIsMissingOrInvalid) { - SdpVideoFormat::Parameters params; - absl::optional empty = ParseSdpForVPxMaxFrameRate(params); - EXPECT_FALSE(empty); - params[kVPxFmtpMaxFrameRate] = "-1"; - EXPECT_FALSE(ParseSdpForVPxMaxFrameRate(params)); - params[kVPxFmtpMaxFrameRate] = "0"; - EXPECT_FALSE(ParseSdpForVPxMaxFrameRate(params)); - params[kVPxFmtpMaxFrameRate] = "abcde"; - EXPECT_FALSE(ParseSdpForVPxMaxFrameRate(params)); -} - -TEST(SdpFmtpUtilsTest, MaxFrameRateIsSpecified) { - SdpVideoFormat::Parameters params; - params[kVPxFmtpMaxFrameRate] = "30"; - EXPECT_EQ(ParseSdpForVPxMaxFrameRate(params), 30); - params[kVPxFmtpMaxFrameRate] = "60"; - EXPECT_EQ(ParseSdpForVPxMaxFrameRate(params), 60); -} - -TEST(SdpFmtpUtilsTest, MaxFrameSizeIsMissingOrInvalid) { - SdpVideoFormat::Parameters params; - absl::optional empty = ParseSdpForVPxMaxFrameSize(params); - EXPECT_FALSE(empty); - params[kVPxFmtpMaxFrameSize] = "-1"; - EXPECT_FALSE(ParseSdpForVPxMaxFrameSize(params)); - params[kVPxFmtpMaxFrameSize] = "0"; - EXPECT_FALSE(ParseSdpForVPxMaxFrameSize(params)); - params[kVPxFmtpMaxFrameSize] = "abcde"; - EXPECT_FALSE(ParseSdpForVPxMaxFrameSize(params)); -} - -TEST(SdpFmtpUtilsTest, MaxFrameSizeIsSpecified) { - SdpVideoFormat::Parameters params; - params[kVPxFmtpMaxFrameSize] = "8100"; // 1920 x 1080 / (16^2) - EXPECT_EQ(ParseSdpForVPxMaxFrameSize(params), 1920 * 1080); - params[kVPxFmtpMaxFrameSize] = "32400"; // 3840 x 2160 / (16^2) - EXPECT_EQ(ParseSdpForVPxMaxFrameSize(params), 3840 * 2160); -} - -} // namespace webrtc diff --git a/media/base/sdp_video_format_utils.cc b/media/base/sdp_video_format_utils.cc new file mode 100644 index 0000000000..a156afdc02 --- /dev/null +++ b/media/base/sdp_video_format_utils.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "media/base/sdp_video_format_utils.h" + +#include +#include +#include + +#include "api/video_codecs/h264_profile_level_id.h" +#include "rtc_base/checks.h" +#include "rtc_base/string_to_number.h" + +namespace webrtc { +namespace { +const char kProfileLevelId[] = "profile-level-id"; +const char kH264LevelAsymmetryAllowed[] = "level-asymmetry-allowed"; +// Max frame rate for VP8 and VP9 video. +const char kVPxFmtpMaxFrameRate[] = "max-fr"; +// Max frame size for VP8 and VP9 video. +const char kVPxFmtpMaxFrameSize[] = "max-fs"; +const int kVPxFmtpFrameSizeSubBlockPixels = 256; + +bool IsH264LevelAsymmetryAllowed(const SdpVideoFormat::Parameters& params) { + const auto it = params.find(kH264LevelAsymmetryAllowed); + return it != params.end() && strcmp(it->second.c_str(), "1") == 0; +} + +// Compare H264 levels and handle the level 1b case. +bool H264LevelIsLess(H264Level a, H264Level b) { + if (a == H264Level::kLevel1_b) + return b != H264Level::kLevel1 && b != H264Level::kLevel1_b; + if (b == H264Level::kLevel1_b) + return a == H264Level::kLevel1; + return a < b; +} + +H264Level H264LevelMin(H264Level a, H264Level b) { + return H264LevelIsLess(a, b) ? 
a : b; +} + +absl::optional ParsePositiveNumberFromParams( + const SdpVideoFormat::Parameters& params, + const char* parameter_name) { + const auto max_frame_rate_it = params.find(parameter_name); + if (max_frame_rate_it == params.end()) + return absl::nullopt; + + const absl::optional i = + rtc::StringToNumber(max_frame_rate_it->second); + if (!i.has_value() || i.value() <= 0) + return absl::nullopt; + return i; +} + +} // namespace + +// Set level according to https://tools.ietf.org/html/rfc6184#section-8.2.2. +void H264GenerateProfileLevelIdForAnswer( + const SdpVideoFormat::Parameters& local_supported_params, + const SdpVideoFormat::Parameters& remote_offered_params, + SdpVideoFormat::Parameters* answer_params) { + // If both local and remote haven't set profile-level-id, they are both using + // the default profile. In this case, don't set profile-level-id in answer + // either. + if (!local_supported_params.count(kProfileLevelId) && + !remote_offered_params.count(kProfileLevelId)) { + return; + } + + // Parse profile-level-ids. + const absl::optional local_profile_level_id = + ParseSdpForH264ProfileLevelId(local_supported_params); + const absl::optional remote_profile_level_id = + ParseSdpForH264ProfileLevelId(remote_offered_params); + // The local and remote codec must have valid and equal H264 Profiles. + RTC_DCHECK(local_profile_level_id); + RTC_DCHECK(remote_profile_level_id); + RTC_DCHECK_EQ(local_profile_level_id->profile, + remote_profile_level_id->profile); + + // Parse level information. + const bool level_asymmetry_allowed = + IsH264LevelAsymmetryAllowed(local_supported_params) && + IsH264LevelAsymmetryAllowed(remote_offered_params); + const H264Level local_level = local_profile_level_id->level; + const H264Level remote_level = remote_profile_level_id->level; + const H264Level min_level = H264LevelMin(local_level, remote_level); + + // Determine answer level. When level asymmetry is not allowed, level upgrade + // is not allowed, i.e., the level in the answer must be equal to or lower + // than the level in the offer. + const H264Level answer_level = + level_asymmetry_allowed ? local_level : min_level; + + // Set the resulting profile-level-id in the answer parameters. + (*answer_params)[kProfileLevelId] = *H264ProfileLevelIdToString( + H264ProfileLevelId(local_profile_level_id->profile, answer_level)); +} + +absl::optional ParseSdpForVPxMaxFrameRate( + const SdpVideoFormat::Parameters& params) { + return ParsePositiveNumberFromParams(params, kVPxFmtpMaxFrameRate); +} + +absl::optional ParseSdpForVPxMaxFrameSize( + const SdpVideoFormat::Parameters& params) { + const absl::optional i = + ParsePositiveNumberFromParams(params, kVPxFmtpMaxFrameSize); + return i ? absl::make_optional(i.value() * kVPxFmtpFrameSizeSubBlockPixels) + : absl::nullopt; +} + +} // namespace webrtc diff --git a/media/base/sdp_video_format_utils.h b/media/base/sdp_video_format_utils.h new file mode 100644 index 0000000000..6671c182ac --- /dev/null +++ b/media/base/sdp_video_format_utils.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MEDIA_BASE_SDP_VIDEO_FORMAT_UTILS_H_ +#define MEDIA_BASE_SDP_VIDEO_FORMAT_UTILS_H_ + +#include "absl/types/optional.h" +#include "api/video_codecs/sdp_video_format.h" + +namespace webrtc { +// Generate codec parameters that will be used as answer in an SDP negotiation +// based on local supported parameters and remote offered parameters. Both +// |local_supported_params|, |remote_offered_params|, and |answer_params| +// represent sendrecv media descriptions, i.e they are a mix of both encode and +// decode capabilities. In theory, when the profile in |local_supported_params| +// represent a strict superset of the profile in |remote_offered_params|, we +// could limit the profile in |answer_params| to the profile in +// |remote_offered_params|. However, to simplify the code, each supported H264 +// profile should be listed explicitly in the list of local supported codecs, +// even if they are redundant. Then each local codec in the list should be +// tested one at a time against the remote codec, and only when the profiles are +// equal should this function be called. Therefore, this function does not need +// to handle profile intersection, and the profile of |local_supported_params| +// and |remote_offered_params| must be equal before calling this function. The +// parameters that are used when negotiating are the level part of +// profile-level-id and level-asymmetry-allowed. +void H264GenerateProfileLevelIdForAnswer( + const SdpVideoFormat::Parameters& local_supported_params, + const SdpVideoFormat::Parameters& remote_offered_params, + SdpVideoFormat::Parameters* answer_params); + +// Parse max frame rate from SDP FMTP line. absl::nullopt is returned if the +// field is missing or not a number. +absl::optional ParseSdpForVPxMaxFrameRate( + const SdpVideoFormat::Parameters& params); + +// Parse max frame size from SDP FMTP line. absl::nullopt is returned if the +// field is missing or not a number. Please note that the value is stored in sub +// blocks but the returned value is in total number of pixels. +absl::optional ParseSdpForVPxMaxFrameSize( + const SdpVideoFormat::Parameters& params); + +} // namespace webrtc + +#endif // MEDIA_BASE_SDP_VIDEO_FORMAT_UTILS_H_ diff --git a/media/base/sdp_video_format_utils_unittest.cc b/media/base/sdp_video_format_utils_unittest.cc new file mode 100644 index 0000000000..d8ef9ab827 --- /dev/null +++ b/media/base/sdp_video_format_utils_unittest.cc @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "media/base/sdp_video_format_utils.h" + +#include + +#include +#include + +#include "rtc_base/string_to_number.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { +// Max frame rate for VP8 and VP9 video. +const char kVPxFmtpMaxFrameRate[] = "max-fr"; +// Max frame size for VP8 and VP9 video. 
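Editorial note, outside the patch: the unit tests below drive H264GenerateProfileLevelIdForAnswer() with raw fmtp strings, which are easier to follow once decoded. The three bytes of profile-level-id are profile_idc, profile_iop (constraint flags) and level_idc, so "42e01f" is profile_idc 0x42 (Baseline) with flags 0xe0 (the Constrained Baseline form WebRTC commonly signals) at level_idc 31, i.e. level 3.1, and "42e015" is the same profile at level 2.1. Level 1b is the odd one out (level_idc 11 with the constraint_set3 flag), which is why H264LevelIsLess() above needs its extra branches. A small decoding sketch:

#include <cstdio>
#include <string>

int main() {
  for (const char* s : {"42e01f", "42e015"}) {
    const std::string plid(s);
    const int profile_idc = std::stoi(plid.substr(0, 2), nullptr, 16);
    const int profile_iop = std::stoi(plid.substr(2, 2), nullptr, 16);
    const int level_idc = std::stoi(plid.substr(4, 2), nullptr, 16);
    printf("%s -> profile_idc=0x%02x iop=0x%02x level=%d.%d\n", plid.c_str(),
           profile_idc, profile_iop, level_idc / 10, level_idc % 10);
  }
  return 0;
}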
+const char kVPxFmtpMaxFrameSize[] = "max-fs"; +} // namespace + +TEST(SdpVideoFormatUtilsTest, TestH264GenerateProfileLevelIdForAnswerEmpty) { + SdpVideoFormat::Parameters answer_params; + H264GenerateProfileLevelIdForAnswer(SdpVideoFormat::Parameters(), + SdpVideoFormat::Parameters(), + &answer_params); + EXPECT_TRUE(answer_params.empty()); +} + +TEST(SdpVideoFormatUtilsTest, + TestH264GenerateProfileLevelIdForAnswerLevelSymmetryCapped) { + SdpVideoFormat::Parameters low_level; + low_level["profile-level-id"] = "42e015"; + SdpVideoFormat::Parameters high_level; + high_level["profile-level-id"] = "42e01f"; + + // Level asymmetry is not allowed; test that answer level is the lower of the + // local and remote levels. + SdpVideoFormat::Parameters answer_params; + H264GenerateProfileLevelIdForAnswer(low_level /* local_supported */, + high_level /* remote_offered */, + &answer_params); + EXPECT_EQ("42e015", answer_params["profile-level-id"]); + + SdpVideoFormat::Parameters answer_params2; + H264GenerateProfileLevelIdForAnswer(high_level /* local_supported */, + low_level /* remote_offered */, + &answer_params2); + EXPECT_EQ("42e015", answer_params2["profile-level-id"]); +} + +TEST(SdpVideoFormatUtilsTest, + TestH264GenerateProfileLevelIdForAnswerConstrainedBaselineLevelAsymmetry) { + SdpVideoFormat::Parameters local_params; + local_params["profile-level-id"] = "42e01f"; + local_params["level-asymmetry-allowed"] = "1"; + SdpVideoFormat::Parameters remote_params; + remote_params["profile-level-id"] = "42e015"; + remote_params["level-asymmetry-allowed"] = "1"; + SdpVideoFormat::Parameters answer_params; + H264GenerateProfileLevelIdForAnswer(local_params, remote_params, + &answer_params); + // When level asymmetry is allowed, we can answer a higher level than what was + // offered. 
+ EXPECT_EQ("42e01f", answer_params["profile-level-id"]); +} + +TEST(SdpVideoFormatUtilsTest, MaxFrameRateIsMissingOrInvalid) { + SdpVideoFormat::Parameters params; + absl::optional empty = ParseSdpForVPxMaxFrameRate(params); + EXPECT_FALSE(empty); + params[kVPxFmtpMaxFrameRate] = "-1"; + EXPECT_FALSE(ParseSdpForVPxMaxFrameRate(params)); + params[kVPxFmtpMaxFrameRate] = "0"; + EXPECT_FALSE(ParseSdpForVPxMaxFrameRate(params)); + params[kVPxFmtpMaxFrameRate] = "abcde"; + EXPECT_FALSE(ParseSdpForVPxMaxFrameRate(params)); +} + +TEST(SdpVideoFormatUtilsTest, MaxFrameRateIsSpecified) { + SdpVideoFormat::Parameters params; + params[kVPxFmtpMaxFrameRate] = "30"; + EXPECT_EQ(ParseSdpForVPxMaxFrameRate(params), 30); + params[kVPxFmtpMaxFrameRate] = "60"; + EXPECT_EQ(ParseSdpForVPxMaxFrameRate(params), 60); +} + +TEST(SdpVideoFormatUtilsTest, MaxFrameSizeIsMissingOrInvalid) { + SdpVideoFormat::Parameters params; + absl::optional empty = ParseSdpForVPxMaxFrameSize(params); + EXPECT_FALSE(empty); + params[kVPxFmtpMaxFrameSize] = "-1"; + EXPECT_FALSE(ParseSdpForVPxMaxFrameSize(params)); + params[kVPxFmtpMaxFrameSize] = "0"; + EXPECT_FALSE(ParseSdpForVPxMaxFrameSize(params)); + params[kVPxFmtpMaxFrameSize] = "abcde"; + EXPECT_FALSE(ParseSdpForVPxMaxFrameSize(params)); +} + +TEST(SdpVideoFormatUtilsTest, MaxFrameSizeIsSpecified) { + SdpVideoFormat::Parameters params; + params[kVPxFmtpMaxFrameSize] = "8100"; // 1920 x 1080 / (16^2) + EXPECT_EQ(ParseSdpForVPxMaxFrameSize(params), 1920 * 1080); + params[kVPxFmtpMaxFrameSize] = "32400"; // 3840 x 2160 / (16^2) + EXPECT_EQ(ParseSdpForVPxMaxFrameSize(params), 3840 * 2160); +} + +} // namespace webrtc diff --git a/media/base/turn_utils.h b/media/base/turn_utils.h index ed8e282ba7..82e492c028 100644 --- a/media/base/turn_utils.h +++ b/media/base/turn_utils.h @@ -18,8 +18,6 @@ namespace cricket { -struct PacketOptions; - // Finds data location within a TURN Channel Message or TURN Send Indication // message. 
bool RTC_EXPORT UnwrapTurnPacket(const uint8_t* packet, diff --git a/media/base/video_broadcaster.cc b/media/base/video_broadcaster.cc index e6a91368fc..3c20eca963 100644 --- a/media/base/video_broadcaster.cc +++ b/media/base/video_broadcaster.cc @@ -94,6 +94,7 @@ void VideoBroadcaster::OnFrame(const webrtc::VideoFrame& frame) { } void VideoBroadcaster::OnDiscardedFrame() { + webrtc::MutexLock lock(&sinks_and_wants_lock_); for (auto& sink_pair : sink_pairs()) { sink_pair.sink->OnDiscardedFrame(); } diff --git a/media/base/video_broadcaster.h b/media/base/video_broadcaster.h index 0703862c4f..2f4e578224 100644 --- a/media/base/video_broadcaster.h +++ b/media/base/video_broadcaster.h @@ -12,12 +12,12 @@ #define MEDIA_BASE_VIDEO_BROADCASTER_H_ #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/video/video_frame_buffer.h" #include "api/video/video_source_interface.h" #include "media/base/video_source_base.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace rtc { diff --git a/media/base/video_source_base.cc b/media/base/video_source_base.cc index d057a24ad8..2454902069 100644 --- a/media/base/video_source_base.cc +++ b/media/base/video_source_base.cc @@ -10,6 +10,8 @@ #include "media/base/video_source_base.h" +#include + #include "absl/algorithm/container.h" #include "rtc_base/checks.h" @@ -52,4 +54,51 @@ VideoSourceBase::SinkPair* VideoSourceBase::FindSinkPair( return nullptr; } +VideoSourceBaseGuarded::VideoSourceBaseGuarded() = default; +VideoSourceBaseGuarded::~VideoSourceBaseGuarded() = default; + +void VideoSourceBaseGuarded::AddOrUpdateSink( + VideoSinkInterface* sink, + const VideoSinkWants& wants) { + RTC_DCHECK_RUN_ON(&source_sequence_); + RTC_DCHECK(sink != nullptr); + + SinkPair* sink_pair = FindSinkPair(sink); + if (!sink_pair) { + sinks_.push_back(SinkPair(sink, wants)); + } else { + sink_pair->wants = wants; + } +} + +void VideoSourceBaseGuarded::RemoveSink( + VideoSinkInterface* sink) { + RTC_DCHECK_RUN_ON(&source_sequence_); + RTC_DCHECK(sink != nullptr); + RTC_DCHECK(FindSinkPair(sink)); + sinks_.erase(std::remove_if(sinks_.begin(), sinks_.end(), + [sink](const SinkPair& sink_pair) { + return sink_pair.sink == sink; + }), + sinks_.end()); +} + +VideoSourceBaseGuarded::SinkPair* VideoSourceBaseGuarded::FindSinkPair( + const VideoSinkInterface* sink) { + RTC_DCHECK_RUN_ON(&source_sequence_); + auto sink_pair_it = absl::c_find_if( + sinks_, + [sink](const SinkPair& sink_pair) { return sink_pair.sink == sink; }); + if (sink_pair_it != sinks_.end()) { + return &*sink_pair_it; + } + return nullptr; +} + +const std::vector& +VideoSourceBaseGuarded::sink_pairs() const { + RTC_DCHECK_RUN_ON(&source_sequence_); + return sinks_; +} + } // namespace rtc diff --git a/media/base/video_source_base.h b/media/base/video_source_base.h index 507fa10645..2644723aa7 100644 --- a/media/base/video_source_base.h +++ b/media/base/video_source_base.h @@ -13,14 +13,18 @@ #include +#include "api/sequence_checker.h" #include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" #include "api/video/video_source_interface.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/system/no_unique_address.h" namespace rtc { -// VideoSourceBase is not thread safe. +// VideoSourceBase is not thread safe. 
Before using this class, consider using +// VideoSourceBaseGuarded below instead, which is an identical implementation +// but applies a sequence checker to help protect internal state. +// TODO(bugs.webrtc.org/12780): Delete this class. class VideoSourceBase : public VideoSourceInterface { public: VideoSourceBase(); @@ -44,6 +48,36 @@ class VideoSourceBase : public VideoSourceInterface { std::vector sinks_; }; +// VideoSourceBaseGuarded assumes that operations related to sinks, occur on the +// same TQ/thread that the object was constructed on. +class VideoSourceBaseGuarded : public VideoSourceInterface { + public: + VideoSourceBaseGuarded(); + ~VideoSourceBaseGuarded() override; + + void AddOrUpdateSink(VideoSinkInterface* sink, + const VideoSinkWants& wants) override; + void RemoveSink(VideoSinkInterface* sink) override; + + protected: + struct SinkPair { + SinkPair(VideoSinkInterface* sink, VideoSinkWants wants) + : sink(sink), wants(wants) {} + VideoSinkInterface* sink; + VideoSinkWants wants; + }; + + SinkPair* FindSinkPair(const VideoSinkInterface* sink); + const std::vector& sink_pairs() const; + + // Keep the `source_sequence_` checker protected to allow sub classes the + // ability to call Detach() if/when appropriate. + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker source_sequence_; + + private: + std::vector sinks_ RTC_GUARDED_BY(&source_sequence_); +}; + } // namespace rtc #endif // MEDIA_BASE_VIDEO_SOURCE_BASE_H_ diff --git a/media/base/vp9_profile.h b/media/base/vp9_profile.h index e47204fede..d44a7998d2 100644 --- a/media/base/vp9_profile.h +++ b/media/base/vp9_profile.h @@ -11,43 +11,9 @@ #ifndef MEDIA_BASE_VP9_PROFILE_H_ #define MEDIA_BASE_VP9_PROFILE_H_ -#include +#include "api/video_codecs/vp9_profile.h" -#include "absl/types/optional.h" -#include "api/video_codecs/sdp_video_format.h" -#include "rtc_base/system/rtc_export.h" - -namespace webrtc { - -// Profile information for VP9 video. -extern RTC_EXPORT const char kVP9FmtpProfileId[]; - -enum class VP9Profile { - kProfile0, - kProfile1, - kProfile2, -}; - -// Helper functions to convert VP9Profile to std::string. Returns "0" by -// default. -RTC_EXPORT std::string VP9ProfileToString(VP9Profile profile); - -// Helper functions to convert std::string to VP9Profile. Returns null if given -// an invalid profile string. -absl::optional StringToVP9Profile(const std::string& str); - -// Parse profile that is represented as a string of single digit contained in an -// SDP key-value map. A default profile(kProfile0) will be returned if the -// profile key is missing. Nothing will be returned if the key is present but -// the string is invalid. -RTC_EXPORT absl::optional ParseSdpForVP9Profile( - const SdpVideoFormat::Parameters& params); - -// Returns true if the parameters have the same VP9 profile, or neither contains -// VP9 profile. -bool IsSameVP9Profile(const SdpVideoFormat::Parameters& params1, - const SdpVideoFormat::Parameters& params2); - -} // namespace webrtc +// TODO(crbug.com/1187565): Remove this file once downstream projects stop +// depend on it. 
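Editorial sketch, outside the patch: the two protection styles these hunks touch, reduced to standard C++ so the difference is visible without the WebRTC macros. VideoBroadcaster keeps taking a mutex because OnFrame()/OnDiscardedFrame() may race with AddOrUpdateSink(), while the new VideoSourceBaseGuarded merely asserts that every sink operation stays on the sequence it was constructed on, which is what RTC_DCHECK_RUN_ON(&source_sequence_) expresses. Rough stand-ins:

#include <cassert>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

class MutexGuardedSinks {  // Rough analogue of VideoBroadcaster's locking.
 public:
  void AddSink(int sink) {
    std::lock_guard<std::mutex> lock(mu_);
    sinks_.push_back(sink);
  }
  size_t NumSinks() {
    std::lock_guard<std::mutex> lock(mu_);
    return sinks_.size();
  }

 private:
  std::mutex mu_;
  std::vector<int> sinks_;
};

class SequenceGuardedSinks {  // Rough analogue of VideoSourceBaseGuarded.
 public:
  SequenceGuardedSinks() : owner_(std::this_thread::get_id()) {}
  void AddSink(int sink) {
    // Stand-in for RTC_DCHECK_RUN_ON(&source_sequence_): no locking, just a
    // debug-time check that we never left the construction thread.
    assert(std::this_thread::get_id() == owner_);
    sinks_.push_back(sink);
  }

 private:
  const std::thread::id owner_;
  std::vector<int> sinks_;
};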
#endif // MEDIA_BASE_VP9_PROFILE_H_ diff --git a/media/engine/encoder_simulcast_proxy_unittest.cc b/media/engine/encoder_simulcast_proxy_unittest.cc index ebbadb00a4..e5eb7a3703 100644 --- a/media/engine/encoder_simulcast_proxy_unittest.cc +++ b/media/engine/encoder_simulcast_proxy_unittest.cc @@ -49,7 +49,8 @@ TEST(EncoderSimulcastProxy, ChoosesCorrectImplementation) { 2000, 1000, 1000, - 56}; + 56, + true}; codec_settings.simulcastStream[1] = {test::kTestWidth, test::kTestHeight, test::kTestFrameRate, @@ -57,7 +58,8 @@ TEST(EncoderSimulcastProxy, ChoosesCorrectImplementation) { 3000, 1000, 1000, - 56}; + 56, + true}; codec_settings.simulcastStream[2] = {test::kTestWidth, test::kTestHeight, test::kTestFrameRate, @@ -65,7 +67,8 @@ TEST(EncoderSimulcastProxy, ChoosesCorrectImplementation) { 5000, 1000, 1000, - 56}; + 56, + true}; codec_settings.numberOfSimulcastStreams = 3; auto mock_encoder = std::make_unique>(); diff --git a/media/engine/fake_webrtc_call.cc b/media/engine/fake_webrtc_call.cc index e320880b2d..e8c7f6e0c9 100644 --- a/media/engine/fake_webrtc_call.cc +++ b/media/engine/fake_webrtc_call.cc @@ -17,6 +17,7 @@ #include "media/base/rtp_utils.h" #include "rtc_base/checks.h" #include "rtc_base/gunit.h" +#include "rtc_base/thread.h" namespace cricket { FakeAudioSendStream::FakeAudioSendStream( @@ -95,9 +96,31 @@ bool FakeAudioReceiveStream::DeliverRtp(const uint8_t* packet, return true; } -void FakeAudioReceiveStream::Reconfigure( - const webrtc::AudioReceiveStream::Config& config) { - config_ = config; +void FakeAudioReceiveStream::SetDepacketizerToDecoderFrameTransformer( + rtc::scoped_refptr frame_transformer) { + config_.frame_transformer = std::move(frame_transformer); +} + +void FakeAudioReceiveStream::SetDecoderMap( + std::map decoder_map) { + config_.decoder_map = std::move(decoder_map); +} + +void FakeAudioReceiveStream::SetUseTransportCcAndNackHistory( + bool use_transport_cc, + int history_ms) { + config_.rtp.transport_cc = use_transport_cc; + config_.rtp.nack.rtp_history_ms = history_ms; +} + +void FakeAudioReceiveStream::SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + config_.frame_decryptor = std::move(frame_decryptor); +} + +void FakeAudioReceiveStream::SetRtpExtensions( + std::vector extensions) { + config_.rtp.extensions = std::move(extensions); } webrtc::AudioReceiveStream::Stats FakeAudioReceiveStream::GetStats( @@ -326,10 +349,7 @@ void FakeVideoSendStream::InjectVideoSinkWants( FakeVideoReceiveStream::FakeVideoReceiveStream( webrtc::VideoReceiveStream::Config config) - : config_(std::move(config)), - receiving_(false), - num_added_secondary_sinks_(0), - num_removed_secondary_sinks_(0) {} + : config_(std::move(config)), receiving_(false) {} const webrtc::VideoReceiveStream::Config& FakeVideoReceiveStream::GetConfig() const { @@ -361,24 +381,6 @@ void FakeVideoReceiveStream::SetStats( stats_ = stats; } -void FakeVideoReceiveStream::AddSecondarySink( - webrtc::RtpPacketSinkInterface* sink) { - ++num_added_secondary_sinks_; -} - -void FakeVideoReceiveStream::RemoveSecondarySink( - const webrtc::RtpPacketSinkInterface* sink) { - ++num_removed_secondary_sinks_; -} - -int FakeVideoReceiveStream::GetNumAddedSecondarySinks() const { - return num_added_secondary_sinks_; -} - -int FakeVideoReceiveStream::GetNumRemovedSecondarySinks() const { - return num_removed_secondary_sinks_; -} - FakeFlexfecReceiveStream::FakeFlexfecReceiveStream( const webrtc::FlexfecReceiveStream::Config& config) : config_(config) {} @@ -398,7 +400,13 @@ void 
FakeFlexfecReceiveStream::OnRtpPacket(const webrtc::RtpPacketReceived&) { } FakeCall::FakeCall() - : audio_network_state_(webrtc::kNetworkUp), + : FakeCall(rtc::Thread::Current(), rtc::Thread::Current()) {} + +FakeCall::FakeCall(webrtc::TaskQueueBase* worker_thread, + webrtc::TaskQueueBase* network_thread) + : network_thread_(network_thread), + worker_thread_(worker_thread), + audio_network_state_(webrtc::kNetworkUp), video_network_state_(webrtc::kNetworkUp), num_created_send_streams_(0), num_created_receive_streams_(0) {} @@ -599,14 +607,17 @@ FakeCall::DeliveryStatus FakeCall::DeliverPacket(webrtc::MediaType media_type, if (media_type == webrtc::MediaType::VIDEO) { for (auto receiver : video_receive_streams_) { - if (receiver->GetConfig().rtp.remote_ssrc == ssrc) + if (receiver->GetConfig().rtp.remote_ssrc == ssrc) { + ++delivered_packets_by_ssrc_[ssrc]; return DELIVERY_OK; + } } } if (media_type == webrtc::MediaType::AUDIO) { for (auto receiver : audio_receive_streams_) { if (receiver->GetConfig().rtp.remote_ssrc == ssrc) { receiver->DeliverRtp(packet.cdata(), packet.size(), packet_time_us); + ++delivered_packets_by_ssrc_[ssrc]; return DELIVERY_OK; } } @@ -630,6 +641,14 @@ webrtc::Call::Stats FakeCall::GetStats() const { return stats_; } +webrtc::TaskQueueBase* FakeCall::network_thread() const { + return network_thread_; +} + +webrtc::TaskQueueBase* FakeCall::worker_thread() const { + return worker_thread_; +} + void FakeCall::SignalChannelNetworkState(webrtc::MediaType media, webrtc::NetworkState state) { switch (media) { @@ -649,6 +668,18 @@ void FakeCall::SignalChannelNetworkState(webrtc::MediaType media, void FakeCall::OnAudioTransportOverheadChanged( int transport_overhead_per_packet) {} +void FakeCall::OnLocalSsrcUpdated(webrtc::AudioReceiveStream& stream, + uint32_t local_ssrc) { + auto& fake_stream = static_cast(stream); + fake_stream.SetLocalSsrc(local_ssrc); +} + +void FakeCall::OnUpdateSyncGroup(webrtc::AudioReceiveStream& stream, + const std::string& sync_group) { + auto& fake_stream = static_cast(stream); + fake_stream.SetSyncGroup(sync_group); +} + void FakeCall::OnSentPacket(const rtc::SentPacket& sent_packet) { last_sent_packet_ = sent_packet; if (sent_packet.packet_id >= 0) { diff --git a/media/engine/fake_webrtc_call.h b/media/engine/fake_webrtc_call.h index 385bbcd76d..aeef95477e 100644 --- a/media/engine/fake_webrtc_call.h +++ b/media/engine/fake_webrtc_call.h @@ -20,6 +20,7 @@ #ifndef MEDIA_ENGINE_FAKE_WEBRTC_CALL_H_ #define MEDIA_ENGINE_FAKE_WEBRTC_CALL_H_ +#include #include #include #include @@ -99,11 +100,31 @@ class FakeAudioReceiveStream final : public webrtc::AudioReceiveStream { return base_mininum_playout_delay_ms_; } + void SetLocalSsrc(uint32_t local_ssrc) { + config_.rtp.local_ssrc = local_ssrc; + } + + void SetSyncGroup(const std::string& sync_group) { + config_.sync_group = sync_group; + } + private: - // webrtc::AudioReceiveStream implementation. 
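Editorial note, outside the patch: the fake mirrors an interface change on webrtc::AudioReceiveStream itself, from one monolithic Reconfigure(config) call to per-field setters (decoder map, transport-cc/NACK history, frame decryptor, RTP extensions), so callers only restate the fields they actually change. A toy before/after of that shape, with hypothetical types rather than the real API:

struct Config {
  bool transport_cc = false;
  int nack_history_ms = 0;
};

class ReconfigureStyle {  // Old shape: every update rewrites the whole config.
 public:
  void Reconfigure(const Config& config) { config_ = config; }

 private:
  Config config_;
};

class SetterStyle {  // New shape: targeted updates, as in the hunks above.
 public:
  void SetUseTransportCcAndNackHistory(bool transport_cc, int history_ms) {
    config_.transport_cc = transport_cc;
    config_.nack_history_ms = history_ms;
  }

 private:
  Config config_;
};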
- void Reconfigure(const webrtc::AudioReceiveStream::Config& config) override; + const webrtc::ReceiveStream::RtpConfig& rtp_config() const override { + return config_.rtp; + } void Start() override { started_ = true; } void Stop() override { started_ = false; } + bool IsRunning() const override { return started_; } + void SetDepacketizerToDecoderFrameTransformer( + rtc::scoped_refptr frame_transformer) + override; + void SetDecoderMap( + std::map decoder_map) override; + void SetUseTransportCcAndNackHistory(bool use_transport_cc, + int history_ms) override; + void SetFrameDecryptor(rtc::scoped_refptr + frame_decryptor) override; + void SetRtpExtensions(std::vector extensions) override; webrtc::AudioReceiveStream::Stats GetStats( bool get_and_clear_legacy_stats) const override; @@ -175,6 +196,7 @@ class FakeVideoSendStream final const std::vector active_layers) override; void Start() override; void Stop() override; + bool started() override { return IsSending(); } void AddAdaptationResource( rtc::scoped_refptr resource) override; std::vector> GetAdaptationResources() @@ -218,12 +240,6 @@ class FakeVideoReceiveStream final : public webrtc::VideoReceiveStream { void SetStats(const webrtc::VideoReceiveStream::Stats& stats); - void AddSecondarySink(webrtc::RtpPacketSinkInterface* sink) override; - void RemoveSecondarySink(const webrtc::RtpPacketSinkInterface* sink) override; - - int GetNumAddedSecondarySinks() const; - int GetNumRemovedSecondarySinks() const; - std::vector GetSources() const override { return std::vector(); } @@ -247,6 +263,9 @@ class FakeVideoReceiveStream final : public webrtc::VideoReceiveStream { private: // webrtc::VideoReceiveStream implementation. + const webrtc::ReceiveStream::RtpConfig& rtp_config() const override { + return config_.rtp; + } void Start() override; void Stop() override; @@ -266,9 +285,6 @@ class FakeVideoReceiveStream final : public webrtc::VideoReceiveStream { webrtc::VideoReceiveStream::Stats stats_; int base_mininum_playout_delay_ms_ = 0; - - int num_added_secondary_sinks_; - int num_removed_secondary_sinks_; }; class FakeFlexfecReceiveStream final : public webrtc::FlexfecReceiveStream { @@ -276,7 +292,11 @@ class FakeFlexfecReceiveStream final : public webrtc::FlexfecReceiveStream { explicit FakeFlexfecReceiveStream( const webrtc::FlexfecReceiveStream::Config& config); - const webrtc::FlexfecReceiveStream::Config& GetConfig() const override; + const webrtc::ReceiveStream::RtpConfig& rtp_config() const override { + return config_.rtp; + } + + const webrtc::FlexfecReceiveStream::Config& GetConfig() const; private: webrtc::FlexfecReceiveStream::Stats GetStats() const override; @@ -289,6 +309,8 @@ class FakeFlexfecReceiveStream final : public webrtc::FlexfecReceiveStream { class FakeCall final : public webrtc::Call, public webrtc::PacketReceiver { public: FakeCall(); + FakeCall(webrtc::TaskQueueBase* worker_thread, + webrtc::TaskQueueBase* network_thread); ~FakeCall() override; webrtc::MockRtpTransportControllerSend* GetMockTransportControllerSend() { @@ -307,6 +329,10 @@ class FakeCall final : public webrtc::Call, public webrtc::PacketReceiver { const std::vector& GetFlexfecReceiveStreams(); rtc::SentPacket last_sent_packet() const { return last_sent_packet_; } + size_t GetDeliveredPacketsForSsrc(uint32_t ssrc) const { + auto it = delivered_packets_by_ssrc_.find(ssrc); + return it != delivered_packets_by_ssrc_.end() ? 
it->second : 0u; + } // This is useful if we care about the last media packet (with id populated) // but not the last ICE packet (with -1 ID). @@ -367,12 +393,22 @@ class FakeCall final : public webrtc::Call, public webrtc::PacketReceiver { return trials_; } + webrtc::TaskQueueBase* network_thread() const override; + webrtc::TaskQueueBase* worker_thread() const override; + void SignalChannelNetworkState(webrtc::MediaType media, webrtc::NetworkState state) override; void OnAudioTransportOverheadChanged( int transport_overhead_per_packet) override; + void OnLocalSsrcUpdated(webrtc::AudioReceiveStream& stream, + uint32_t local_ssrc) override; + void OnUpdateSyncGroup(webrtc::AudioReceiveStream& stream, + const std::string& sync_group) override; void OnSentPacket(const rtc::SentPacket& sent_packet) override; + webrtc::TaskQueueBase* const network_thread_; + webrtc::TaskQueueBase* const worker_thread_; + ::testing::NiceMock transport_controller_send_; @@ -387,6 +423,7 @@ class FakeCall final : public webrtc::Call, public webrtc::PacketReceiver { std::vector video_receive_streams_; std::vector audio_receive_streams_; std::vector flexfec_receive_streams_; + std::map delivered_packets_by_ssrc_; int num_created_send_streams_; int num_created_receive_streams_; diff --git a/media/engine/internal_decoder_factory.cc b/media/engine/internal_decoder_factory.cc index 1c084846a2..a8d1f00009 100644 --- a/media/engine/internal_decoder_factory.cc +++ b/media/engine/internal_decoder_factory.cc @@ -23,23 +23,6 @@ namespace webrtc { -namespace { - -bool IsFormatSupported( - const std::vector& supported_formats, - const webrtc::SdpVideoFormat& format) { - for (const webrtc::SdpVideoFormat& supported_format : supported_formats) { - if (cricket::IsSameCodec(format.name, format.parameters, - supported_format.name, - supported_format.parameters)) { - return true; - } - } - return false; -} - -} // namespace - std::vector InternalDecoderFactory::GetSupportedFormats() const { std::vector formats; @@ -55,7 +38,7 @@ std::vector InternalDecoderFactory::GetSupportedFormats() std::unique_ptr InternalDecoderFactory::CreateVideoDecoder( const SdpVideoFormat& format) { - if (!IsFormatSupported(GetSupportedFormats(), format)) { + if (!format.IsCodecInList(GetSupportedFormats())) { RTC_LOG(LS_WARNING) << "Trying to create decoder for unsupported format. 
" << format.ToString(); return nullptr; diff --git a/media/engine/internal_decoder_factory_unittest.cc b/media/engine/internal_decoder_factory_unittest.cc index 61be5e72df..a2a69211b9 100644 --- a/media/engine/internal_decoder_factory_unittest.cc +++ b/media/engine/internal_decoder_factory_unittest.cc @@ -12,8 +12,8 @@ #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_decoder.h" +#include "api/video_codecs/vp9_profile.h" #include "media/base/media_constants.h" -#include "media/base/vp9_profile.h" #include "modules/video_coding/codecs/av1/libaom_av1_decoder.h" #include "test/gmock.h" #include "test/gtest.h" diff --git a/media/engine/null_webrtc_video_engine_unittest.cc b/media/engine/null_webrtc_video_engine_unittest.cc index 47b9ab22dd..a23a3b6cdf 100644 --- a/media/engine/null_webrtc_video_engine_unittest.cc +++ b/media/engine/null_webrtc_video_engine_unittest.cc @@ -41,8 +41,7 @@ TEST(NullWebRtcVideoEngineTest, CheckInterface) { CompositeMediaEngine engine(std::move(audio_engine), std::make_unique()); - - EXPECT_TRUE(engine.Init()); + engine.Init(); } } // namespace cricket diff --git a/media/engine/payload_type_mapper.cc b/media/engine/payload_type_mapper.cc index e9f863ca63..cbc0a5340d 100644 --- a/media/engine/payload_type_mapper.cc +++ b/media/engine/payload_type_mapper.cc @@ -32,18 +32,18 @@ PayloadTypeMapper::PayloadTypeMapper() max_payload_type_(127), mappings_( {// Static payload type assignments according to RFC 3551. - {{"PCMU", 8000, 1}, 0}, + {{kPcmuCodecName, 8000, 1}, 0}, {{"GSM", 8000, 1}, 3}, {{"G723", 8000, 1}, 4}, {{"DVI4", 8000, 1}, 5}, {{"DVI4", 16000, 1}, 6}, {{"LPC", 8000, 1}, 7}, - {{"PCMA", 8000, 1}, 8}, - {{"G722", 8000, 1}, 9}, - {{"L16", 44100, 2}, 10}, - {{"L16", 44100, 1}, 11}, + {{kPcmaCodecName, 8000, 1}, 8}, + {{kG722CodecName, 8000, 1}, 9}, + {{kL16CodecName, 44100, 2}, 10}, + {{kL16CodecName, 44100, 1}, 11}, {{"QCELP", 8000, 1}, 12}, - {{"CN", 8000, 1}, 13}, + {{kCnCodecName, 8000, 1}, 13}, // RFC 4566 is a bit ambiguous on the contents of the "encoding // parameters" field, which, for audio, encodes the number of // channels. It is "optional and may be omitted if the number of @@ -61,7 +61,6 @@ PayloadTypeMapper::PayloadTypeMapper() // Payload type assignments currently used by WebRTC. // Includes data to reduce collisions (and thus reassignments) - {{kGoogleRtpDataCodecName, 0, 0}, kGoogleRtpDataCodecPlType}, {{kIlbcCodecName, 8000, 1}, 102}, {{kIsacCodecName, 16000, 1}, 103}, {{kIsacCodecName, 32000, 1}, 104}, @@ -70,8 +69,11 @@ PayloadTypeMapper::PayloadTypeMapper() {{kOpusCodecName, 48000, 2, - {{"minptime", "10"}, {"useinbandfec", "1"}}}, + {{kCodecParamMinPTime, "10"}, + {kCodecParamUseInbandFec, kParamValueTrue}}}, 111}, + // RED for opus is assigned in the lower range, starting at the top. + {{kRedCodecName, 48000, 2}, 63}, // TODO(solenberg): Remove the hard coded 16k,32k,48k DTMF once we // assign payload types dynamically for send side as well. 
{{kDtmfCodecName, 48000, 1}, 110}, diff --git a/media/engine/payload_type_mapper_unittest.cc b/media/engine/payload_type_mapper_unittest.cc index fa6864b48a..9c29827fa9 100644 --- a/media/engine/payload_type_mapper_unittest.cc +++ b/media/engine/payload_type_mapper_unittest.cc @@ -46,13 +46,8 @@ TEST_F(PayloadTypeMapperTest, StaticPayloadTypes) { } TEST_F(PayloadTypeMapperTest, WebRTCPayloadTypes) { - // Tests that the payload mapper knows about the audio and data formats we've + // Tests that the payload mapper knows about the audio formats we've // been using in WebRTC, with their hard coded values. - auto data_mapping = [this](const char* name) { - return mapper_.FindMappingFor({name, 0, 0}); - }; - EXPECT_EQ(kGoogleRtpDataCodecPlType, data_mapping(kGoogleRtpDataCodecName)); - EXPECT_EQ(102, mapper_.FindMappingFor({kIlbcCodecName, 8000, 1})); EXPECT_EQ(103, mapper_.FindMappingFor({kIsacCodecName, 16000, 1})); EXPECT_EQ(104, mapper_.FindMappingFor({kIsacCodecName, 32000, 1})); @@ -63,6 +58,7 @@ TEST_F(PayloadTypeMapperTest, WebRTCPayloadTypes) { 48000, 2, {{"minptime", "10"}, {"useinbandfec", "1"}}})); + EXPECT_EQ(63, mapper_.FindMappingFor({kRedCodecName, 48000, 2})); // TODO(solenberg): Remove 16k, 32k, 48k DTMF checks once these payload types // are dynamically assigned. EXPECT_EQ(110, mapper_.FindMappingFor({kDtmfCodecName, 48000, 1})); diff --git a/media/engine/simulcast.cc b/media/engine/simulcast.cc index f74d4adfbe..ebc6a240fe 100644 --- a/media/engine/simulcast.cc +++ b/media/engine/simulcast.cc @@ -15,14 +15,15 @@ #include #include +#include #include "absl/strings/match.h" #include "absl/types/optional.h" #include "api/video/video_codec_constants.h" #include "media/base/media_constants.h" #include "modules/video_coding/utility/simulcast_rate_allocator.h" -#include "rtc_base/arraysize.h" #include "rtc_base/checks.h" +#include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/experiments/min_video_bitrate_experiment.h" #include "rtc_base/experiments/normalize_simulcast_size_experiment.h" #include "rtc_base/experiments/rate_control_settings.h" @@ -41,6 +42,15 @@ constexpr webrtc::DataRate Interpolate(const webrtc::DataRate& a, constexpr char kUseLegacySimulcastLayerLimitFieldTrial[] = "WebRTC-LegacySimulcastLayerLimit"; +constexpr double kDefaultMaxRoundupRate = 0.1; + +// TODO(webrtc:12415): Flip this to a kill switch when this feature launches. +bool EnableLowresBitrateInterpolation( + const webrtc::WebRtcKeyValueConfig& trials) { + return absl::StartsWith( + trials.Lookup("WebRTC-LowresSimulcastBitrateInterpolation"), "Enabled"); +} + // Limits for legacy conference screensharing mode. Currently used for the // lower of the two simulcast streams. constexpr webrtc::DataRate kScreenshareDefaultTl0Bitrate = @@ -61,7 +71,7 @@ struct SimulcastFormat { int width; int height; // The maximum number of simulcast layers can be used for - // resolutions at |widthxheigh| for legacy applications. + // resolutions at |widthxheight| for legacy applications. size_t max_layers; // The maximum bitrate for encoding stream at |widthxheight|, when we are // not sending the next higher spatial stream. 
@@ -96,10 +106,29 @@ constexpr const SimulcastFormat kSimulcastFormats[] = { {320, 180, 1, webrtc::DataRate::KilobitsPerSec(200), webrtc::DataRate::KilobitsPerSec(150), webrtc::DataRate::KilobitsPerSec(30)}, - {0, 0, 1, webrtc::DataRate::KilobitsPerSec(200), - webrtc::DataRate::KilobitsPerSec(150), + // As the resolution goes down, interpolate the target and max bitrates down + // towards zero. The min bitrate is still limited at 30 kbps and the target + // and the max will be capped from below accordingly. + {0, 0, 1, webrtc::DataRate::KilobitsPerSec(0), + webrtc::DataRate::KilobitsPerSec(0), webrtc::DataRate::KilobitsPerSec(30)}}; +std::vector GetSimulcastFormats( + bool enable_lowres_bitrate_interpolation) { + std::vector formats; + formats.insert(formats.begin(), std::begin(kSimulcastFormats), + std::end(kSimulcastFormats)); + if (!enable_lowres_bitrate_interpolation) { + RTC_CHECK_GE(formats.size(), 2u); + SimulcastFormat& format0x0 = formats[formats.size() - 1]; + const SimulcastFormat& format_prev = formats[formats.size() - 2]; + format0x0.max_bitrate = format_prev.max_bitrate; + format0x0.target_bitrate = format_prev.target_bitrate; + format0x0.min_bitrate = format_prev.min_bitrate; + } + return formats; +} + const int kMaxScreenshareSimulcastLayers = 2; // Multiway: Number of temporal layers for each simulcast stream. @@ -135,12 +164,14 @@ int DefaultNumberOfTemporalLayers(int simulcast_id, return default_num_temporal_layers; } -int FindSimulcastFormatIndex(int width, int height) { +int FindSimulcastFormatIndex(int width, + int height, + bool enable_lowres_bitrate_interpolation) { RTC_DCHECK_GE(width, 0); RTC_DCHECK_GE(height, 0); - for (uint32_t i = 0; i < arraysize(kSimulcastFormats); ++i) { - if (width * height >= - kSimulcastFormats[i].width * kSimulcastFormats[i].height) { + const auto formats = GetSimulcastFormats(enable_lowres_bitrate_interpolation); + for (uint32_t i = 0; i < formats.size(); ++i) { + if (width * height >= formats[i].width * formats[i].height) { return i; } } @@ -162,42 +193,70 @@ int NormalizeSimulcastSize(int size, size_t simulcast_layers) { return ((size >> base2_exponent) << base2_exponent); } -SimulcastFormat InterpolateSimulcastFormat(int width, int height) { - const int index = FindSimulcastFormatIndex(width, height); +SimulcastFormat InterpolateSimulcastFormat( + int width, + int height, + absl::optional max_roundup_rate, + bool enable_lowres_bitrate_interpolation) { + const auto formats = GetSimulcastFormats(enable_lowres_bitrate_interpolation); + const int index = FindSimulcastFormatIndex( + width, height, enable_lowres_bitrate_interpolation); if (index == 0) - return kSimulcastFormats[index]; + return formats[index]; const int total_pixels_up = - kSimulcastFormats[index - 1].width * kSimulcastFormats[index - 1].height; - const int total_pixels_down = - kSimulcastFormats[index].width * kSimulcastFormats[index].height; + formats[index - 1].width * formats[index - 1].height; + const int total_pixels_down = formats[index].width * formats[index].height; const int total_pixels = width * height; const float rate = (total_pixels_up - total_pixels) / static_cast(total_pixels_up - total_pixels_down); - size_t max_layers = kSimulcastFormats[index].max_layers; - webrtc::DataRate max_bitrate = - Interpolate(kSimulcastFormats[index - 1].max_bitrate, - kSimulcastFormats[index].max_bitrate, rate); - webrtc::DataRate target_bitrate = - Interpolate(kSimulcastFormats[index - 1].target_bitrate, - kSimulcastFormats[index].target_bitrate, rate); - webrtc::DataRate 
min_bitrate = - Interpolate(kSimulcastFormats[index - 1].min_bitrate, - kSimulcastFormats[index].min_bitrate, rate); + // Use upper resolution if |rate| is below the configured threshold. + size_t max_layers = (rate < max_roundup_rate.value_or(kDefaultMaxRoundupRate)) + ? formats[index - 1].max_layers + : formats[index].max_layers; + webrtc::DataRate max_bitrate = Interpolate(formats[index - 1].max_bitrate, + formats[index].max_bitrate, rate); + webrtc::DataRate target_bitrate = Interpolate( + formats[index - 1].target_bitrate, formats[index].target_bitrate, rate); + webrtc::DataRate min_bitrate = Interpolate(formats[index - 1].min_bitrate, + formats[index].min_bitrate, rate); return {width, height, max_layers, max_bitrate, target_bitrate, min_bitrate}; } -webrtc::DataRate FindSimulcastMaxBitrate(int width, int height) { - return InterpolateSimulcastFormat(width, height).max_bitrate; +SimulcastFormat InterpolateSimulcastFormat( + int width, + int height, + bool enable_lowres_bitrate_interpolation) { + return InterpolateSimulcastFormat(width, height, absl::nullopt, + enable_lowres_bitrate_interpolation); } -webrtc::DataRate FindSimulcastTargetBitrate(int width, int height) { - return InterpolateSimulcastFormat(width, height).target_bitrate; +webrtc::DataRate FindSimulcastMaxBitrate( + int width, + int height, + bool enable_lowres_bitrate_interpolation) { + return InterpolateSimulcastFormat(width, height, + enable_lowres_bitrate_interpolation) + .max_bitrate; } -webrtc::DataRate FindSimulcastMinBitrate(int width, int height) { - return InterpolateSimulcastFormat(width, height).min_bitrate; +webrtc::DataRate FindSimulcastTargetBitrate( + int width, + int height, + bool enable_lowres_bitrate_interpolation) { + return InterpolateSimulcastFormat(width, height, + enable_lowres_bitrate_interpolation) + .target_bitrate; +} + +webrtc::DataRate FindSimulcastMinBitrate( + int width, + int height, + bool enable_lowres_bitrate_interpolation) { + return InterpolateSimulcastFormat(width, height, + enable_lowres_bitrate_interpolation) + .min_bitrate; } void BoostMaxSimulcastLayer(webrtc::DataRate max_bitrate, @@ -235,9 +294,21 @@ size_t LimitSimulcastLayerCount(int width, const webrtc::WebRtcKeyValueConfig& trials) { if (!absl::StartsWith(trials.Lookup(kUseLegacySimulcastLayerLimitFieldTrial), "Disabled")) { + // Max layers from one higher resolution in kSimulcastFormats will be used + // if the ratio (pixels_up - pixels) / (pixels_up - pixels_down) is less + // than configured |max_ratio|. pixels_down is the selected index in + // kSimulcastFormats based on pixels. 
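// Illustrative aside (a sketch, not part of this change): how the
// "WebRTC-SimulcastLayerLimitRoundUp" |max_ratio| round-up works.
// pixels_down is the kSimulcastFormats row selected for the input size,
// pixels_up the row above it; the upper row's layer cap is used when the
// ratio is below |max_ratio| (default 0.1 per kDefaultMaxRoundupRate).
// The 960x540 -> 3 layers and 640x360 -> 2 layers rows are assumed from the
// full table, which this hunk does not show.
#include <cstdio>

int MaxLayersWithRoundUp(int width, int height, double max_ratio) {
  const int pixels_up = 960 * 540;    // 518400 (assumed upper row)
  const int pixels_down = 640 * 360;  // 230400 (assumed selected row)
  const int pixels = width * height;
  const double rate = static_cast<double>(pixels_up - pixels) /
                      (pixels_up - pixels_down);
  return rate < max_ratio ? 3 : 2;
}

int main() {
  // Matches MaxLayersWithDefaultRoundUpRatio further below: 960x512 rounds
  // up to 3 layers (rate ~0.093 < 0.1), 960x508 does not (rate ~0.107).
  std::printf("%d %d\n", MaxLayersWithRoundUp(960, 512, 0.1),
              MaxLayersWithRoundUp(960, 508, 0.1));
}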
+ webrtc::FieldTrialOptional max_ratio("max_ratio"); + webrtc::ParseFieldTrial({&max_ratio}, + trials.Lookup("WebRTC-SimulcastLayerLimitRoundUp")); + + const bool enable_lowres_bitrate_interpolation = + EnableLowresBitrateInterpolation(trials); size_t adaptive_layer_count = std::max( need_layers, - kSimulcastFormats[FindSimulcastFormatIndex(width, height)].max_layers); + InterpolateSimulcastFormat(width, height, max_ratio.GetOptional(), + enable_lowres_bitrate_interpolation) + .max_layers); if (layer_count > adaptive_layer_count) { RTC_LOG(LS_WARNING) << "Reducing simulcast layer count from " << layer_count << " to " << adaptive_layer_count; @@ -291,6 +362,9 @@ std::vector GetNormalSimulcastLayers( const webrtc::WebRtcKeyValueConfig& trials) { std::vector layers(layer_count); + const bool enable_lowres_bitrate_interpolation = + EnableLowresBitrateInterpolation(trials); + // Format width and height has to be divisible by |2 ^ num_simulcast_layers - // 1|. width = NormalizeSimulcastSize(width, layer_count); @@ -306,9 +380,14 @@ std::vector GetNormalSimulcastLayers( temporal_layers_supported ? DefaultNumberOfTemporalLayers(s, false, trials) : 1; - layers[s].max_bitrate_bps = FindSimulcastMaxBitrate(width, height).bps(); + layers[s].max_bitrate_bps = + FindSimulcastMaxBitrate(width, height, + enable_lowres_bitrate_interpolation) + .bps(); layers[s].target_bitrate_bps = - FindSimulcastTargetBitrate(width, height).bps(); + FindSimulcastTargetBitrate(width, height, + enable_lowres_bitrate_interpolation) + .bps(); int num_temporal_layers = DefaultNumberOfTemporalLayers(s, false, trials); if (s == 0) { // If alternative temporal rate allocation is selected, adjust the @@ -335,7 +414,17 @@ std::vector GetNormalSimulcastLayers( layers[s].target_bitrate_bps = static_cast(layers[s].target_bitrate_bps * rate_factor); } - layers[s].min_bitrate_bps = FindSimulcastMinBitrate(width, height).bps(); + layers[s].min_bitrate_bps = + FindSimulcastMinBitrate(width, height, + enable_lowres_bitrate_interpolation) + .bps(); + + // Ensure consistency. + layers[s].max_bitrate_bps = + std::max(layers[s].min_bitrate_bps, layers[s].max_bitrate_bps); + layers[s].target_bitrate_bps = + std::max(layers[s].min_bitrate_bps, layers[s].target_bitrate_bps); + layers[s].max_framerate = kDefaultVideoMaxFramerate; width /= 2; diff --git a/media/engine/simulcast_encoder_adapter.cc b/media/engine/simulcast_encoder_adapter.cc index e0c0ff7bc6..116f987aa4 100644 --- a/media/engine/simulcast_encoder_adapter.cc +++ b/media/engine/simulcast_encoder_adapter.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/algorithm/container.h" #include "api/scoped_refptr.h" #include "api/video/i420_buffer.h" #include "api/video/video_codec_constants.h" @@ -61,25 +62,29 @@ uint32_t SumStreamMaxBitrate(int streams, const webrtc::VideoCodec& codec) { return bitrate_sum; } -int NumberOfStreams(const webrtc::VideoCodec& codec) { - int streams = +int CountAllStreams(const webrtc::VideoCodec& codec) { + int total_streams_count = codec.numberOfSimulcastStreams < 1 ? 
1 : codec.numberOfSimulcastStreams; - uint32_t simulcast_max_bitrate = SumStreamMaxBitrate(streams, codec); + uint32_t simulcast_max_bitrate = + SumStreamMaxBitrate(total_streams_count, codec); if (simulcast_max_bitrate == 0) { - streams = 1; + total_streams_count = 1; } - return streams; + return total_streams_count; } -int NumActiveStreams(const webrtc::VideoCodec& codec) { - int num_configured_streams = NumberOfStreams(codec); - int num_active_streams = 0; - for (int i = 0; i < num_configured_streams; ++i) { +int CountActiveStreams(const webrtc::VideoCodec& codec) { + if (codec.numberOfSimulcastStreams < 1) { + return 1; + } + int total_streams_count = CountAllStreams(codec); + int active_streams_count = 0; + for (int i = 0; i < total_streams_count; ++i) { if (codec.simulcastStream[i].active) { - ++num_active_streams; + ++active_streams_count; } } - return num_active_streams; + return active_streams_count; } int VerifyCodec(const webrtc::VideoCodec* inst) { @@ -97,42 +102,150 @@ int VerifyCodec(const webrtc::VideoCodec* inst) { return WEBRTC_VIDEO_CODEC_ERR_PARAMETER; } if (inst->codecType == webrtc::kVideoCodecVP8 && - inst->VP8().automaticResizeOn && NumActiveStreams(*inst) > 1) { + inst->VP8().automaticResizeOn && CountActiveStreams(*inst) > 1) { return WEBRTC_VIDEO_CODEC_ERR_PARAMETER; } return WEBRTC_VIDEO_CODEC_OK; } -bool StreamResolutionCompare(const webrtc::SpatialLayer& a, - const webrtc::SpatialLayer& b) { +bool StreamQualityCompare(const webrtc::SpatialLayer& a, + const webrtc::SpatialLayer& b) { return std::tie(a.height, a.width, a.maxBitrate, a.maxFramerate) < std::tie(b.height, b.width, b.maxBitrate, b.maxFramerate); } -// An EncodedImageCallback implementation that forwards on calls to a -// SimulcastEncoderAdapter, but with the stream index it's registered with as -// the first parameter to Encoded. 
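// Illustrative aside (a standalone sketch): the lexicographic std::tie
// comparison above, applied with std::minmax_element to find the lowest- and
// highest-quality streams. The Layer struct here is a stand-in for
// webrtc::SpatialLayer, reduced to the compared fields.
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <tuple>
#include <vector>

struct Layer {
  int width;
  int height;
  unsigned maxBitrate;
  unsigned maxFramerate;
};

int main() {
  std::vector<Layer> streams = {{640, 360, 800, 30},
                                {320, 180, 200, 30},
                                {1280, 720, 2500, 30}};
  auto cmp = [](const Layer& a, const Layer& b) {
    return std::tie(a.height, a.width, a.maxBitrate, a.maxFramerate) <
           std::tie(b.height, b.width, b.maxBitrate, b.maxFramerate);
  };
  auto minmax = std::minmax_element(streams.begin(), streams.end(), cmp);
  std::printf("lowest idx=%td, highest idx=%td\n",
              std::distance(streams.begin(), minmax.first),    // 1
              std::distance(streams.begin(), minmax.second));  // 2
}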
-class AdapterEncodedImageCallback : public webrtc::EncodedImageCallback { - public: - AdapterEncodedImageCallback(webrtc::SimulcastEncoderAdapter* adapter, - size_t stream_idx) - : adapter_(adapter), stream_idx_(stream_idx) {} +void GetLowestAndHighestQualityStreamIndixes( + rtc::ArrayView streams, + int* lowest_quality_stream_idx, + int* highest_quality_stream_idx) { + const auto lowest_highest_quality_streams = + absl::c_minmax_element(streams, StreamQualityCompare); + *lowest_quality_stream_idx = + std::distance(streams.begin(), lowest_highest_quality_streams.first); + *highest_quality_stream_idx = + std::distance(streams.begin(), lowest_highest_quality_streams.second); +} - EncodedImageCallback::Result OnEncodedImage( - const webrtc::EncodedImage& encoded_image, - const webrtc::CodecSpecificInfo* codec_specific_info) override { - return adapter_->OnEncodedImage(stream_idx_, encoded_image, - codec_specific_info); +std::vector GetStreamStartBitratesKbps( + const webrtc::VideoCodec& codec) { + std::vector start_bitrates; + std::unique_ptr rate_allocator = + std::make_unique(codec); + webrtc::VideoBitrateAllocation allocation = + rate_allocator->Allocate(webrtc::VideoBitrateAllocationParameters( + codec.startBitrate * 1000, codec.maxFramerate)); + + int total_streams_count = CountAllStreams(codec); + for (int i = 0; i < total_streams_count; ++i) { + uint32_t stream_bitrate = allocation.GetSpatialLayerSum(i) / 1000; + start_bitrates.push_back(stream_bitrate); } + return start_bitrates; +} - private: - webrtc::SimulcastEncoderAdapter* const adapter_; - const size_t stream_idx_; -}; } // namespace namespace webrtc { +SimulcastEncoderAdapter::EncoderContext::EncoderContext( + std::unique_ptr encoder, + bool prefer_temporal_support, + VideoEncoder::EncoderInfo primary_info, + VideoEncoder::EncoderInfo fallback_info) + : encoder_(std::move(encoder)), + prefer_temporal_support_(prefer_temporal_support), + primary_info_(std::move(primary_info)), + fallback_info_(std::move(fallback_info)) {} + +void SimulcastEncoderAdapter::EncoderContext::Release() { + if (encoder_) { + encoder_->RegisterEncodeCompleteCallback(nullptr); + encoder_->Release(); + } +} + +SimulcastEncoderAdapter::StreamContext::StreamContext( + SimulcastEncoderAdapter* parent, + std::unique_ptr encoder_context, + std::unique_ptr framerate_controller, + int stream_idx, + uint16_t width, + uint16_t height, + bool is_paused) + : parent_(parent), + encoder_context_(std::move(encoder_context)), + framerate_controller_(std::move(framerate_controller)), + stream_idx_(stream_idx), + width_(width), + height_(height), + is_keyframe_needed_(false), + is_paused_(is_paused) { + if (parent_) { + encoder_context_->encoder().RegisterEncodeCompleteCallback(this); + } +} + +SimulcastEncoderAdapter::StreamContext::StreamContext(StreamContext&& rhs) + : parent_(rhs.parent_), + encoder_context_(std::move(rhs.encoder_context_)), + framerate_controller_(std::move(rhs.framerate_controller_)), + stream_idx_(rhs.stream_idx_), + width_(rhs.width_), + height_(rhs.height_), + is_keyframe_needed_(rhs.is_keyframe_needed_), + is_paused_(rhs.is_paused_) { + if (parent_) { + encoder_context_->encoder().RegisterEncodeCompleteCallback(this); + } +} + +SimulcastEncoderAdapter::StreamContext::~StreamContext() { + if (encoder_context_) { + encoder_context_->Release(); + } +} + +std::unique_ptr +SimulcastEncoderAdapter::StreamContext::ReleaseEncoderContext() && { + encoder_context_->Release(); + return std::move(encoder_context_); +} + +void 
SimulcastEncoderAdapter::StreamContext::OnKeyframe(Timestamp timestamp) { + is_keyframe_needed_ = false; + if (framerate_controller_) { + framerate_controller_->AddFrame(timestamp.ms()); + } +} + +bool SimulcastEncoderAdapter::StreamContext::ShouldDropFrame( + Timestamp timestamp) { + if (!framerate_controller_) { + return false; + } + + if (framerate_controller_->DropFrame(timestamp.ms())) { + return true; + } + framerate_controller_->AddFrame(timestamp.ms()); + return false; +} + +EncodedImageCallback::Result +SimulcastEncoderAdapter::StreamContext::OnEncodedImage( + const EncodedImage& encoded_image, + const CodecSpecificInfo* codec_specific_info) { + RTC_CHECK(parent_); // If null, this method should never be called. + return parent_->OnEncodedImage(stream_idx_, encoded_image, + codec_specific_info); +} + +void SimulcastEncoderAdapter::StreamContext::OnDroppedFrame( + DropReason /*reason*/) { + RTC_CHECK(parent_); // If null, this method should never be called. + parent_->OnDroppedFrame(stream_idx_); +} + SimulcastEncoderAdapter::SimulcastEncoderAdapter(VideoEncoderFactory* factory, const SdpVideoFormat& format) : SimulcastEncoderAdapter(factory, nullptr, format) {} @@ -145,6 +258,8 @@ SimulcastEncoderAdapter::SimulcastEncoderAdapter( primary_encoder_factory_(primary_factory), fallback_encoder_factory_(fallback_factory), video_format_(format), + total_streams_count_(0), + bypass_mode_(false), encoded_complete_callback_(nullptr), experimental_boosted_screenshare_qp_(GetScreenshareBoostedQpValue()), boost_base_layer_quality_(RateControlSettings::ParseFromFieldTrials() @@ -164,25 +279,23 @@ SimulcastEncoderAdapter::~SimulcastEncoderAdapter() { } void SimulcastEncoderAdapter::SetFecControllerOverride( - FecControllerOverride* fec_controller_override) { + FecControllerOverride* /*fec_controller_override*/) { // Ignored. } int SimulcastEncoderAdapter::Release() { RTC_DCHECK_RUN_ON(&encoder_queue_); - while (!streaminfos_.empty()) { - std::unique_ptr encoder = - std::move(streaminfos_.back().encoder); - // Even though it seems very unlikely, there are no guarantees that the - // encoder will not call back after being Release()'d. Therefore, we first - // disable the callbacks here. - encoder->RegisterEncodeCompleteCallback(nullptr); - encoder->Release(); - streaminfos_.pop_back(); // Deletes callback adapter. - stored_encoders_.push(std::move(encoder)); + while (!stream_contexts_.empty()) { + // Move the encoder instances and put it on the |cached_encoder_contexts_| + // where it may possibly be reused from (ordering does not matter). + cached_encoder_contexts_.push_front( + std::move(stream_contexts_.back()).ReleaseEncoderContext()); + stream_contexts_.pop_back(); } + bypass_mode_ = false; + // It's legal to move the encoder to another queue now. encoder_queue_.Detach(); @@ -191,7 +304,6 @@ int SimulcastEncoderAdapter::Release() { return WEBRTC_VIDEO_CODEC_OK; } -// TODO(eladalon): s/inst/codec_settings/g. 
int SimulcastEncoderAdapter::InitEncode( const VideoCodec* inst, const VideoEncoder::Settings& settings) { @@ -206,136 +318,118 @@ int SimulcastEncoderAdapter::InitEncode( return ret; } - ret = Release(); - if (ret < 0) { - return ret; - } - - int number_of_streams = NumberOfStreams(*inst); - RTC_DCHECK_LE(number_of_streams, kMaxSimulcastStreams); - bool doing_simulcast_using_adapter = (number_of_streams > 1); - int num_active_streams = NumActiveStreams(*inst); + Release(); codec_ = *inst; - SimulcastRateAllocator rate_allocator(codec_); - VideoBitrateAllocation allocation = - rate_allocator.Allocate(VideoBitrateAllocationParameters( - codec_.startBitrate * 1000, codec_.maxFramerate)); - std::vector start_bitrates; - for (int i = 0; i < kMaxSimulcastStreams; ++i) { - uint32_t stream_bitrate = allocation.GetSpatialLayerSum(i) / 1000; - start_bitrates.push_back(stream_bitrate); - } - - // Create |number_of_streams| of encoder instances and init them. - const auto minmax = std::minmax_element( - std::begin(codec_.simulcastStream), - std::begin(codec_.simulcastStream) + number_of_streams, - StreamResolutionCompare); - const auto lowest_resolution_stream_index = - std::distance(std::begin(codec_.simulcastStream), minmax.first); - const auto highest_resolution_stream_index = - std::distance(std::begin(codec_.simulcastStream), minmax.second); - - RTC_DCHECK_LT(lowest_resolution_stream_index, number_of_streams); - RTC_DCHECK_LT(highest_resolution_stream_index, number_of_streams); - - for (int i = 0; i < number_of_streams; ++i) { - // If an existing encoder instance exists, reuse it. - // TODO(brandtr): Set initial RTP state (e.g., picture_id/tl0_pic_idx) here, - // when we start storing that state outside the encoder wrappers. - std::unique_ptr encoder; - if (!stored_encoders_.empty()) { - encoder = std::move(stored_encoders_.top()); - stored_encoders_.pop(); - } else { - encoder = primary_encoder_factory_->CreateVideoEncoder(video_format_); - if (fallback_encoder_factory_ != nullptr) { - encoder = CreateVideoEncoderSoftwareFallbackWrapper( - fallback_encoder_factory_->CreateVideoEncoder(video_format_), - std::move(encoder), - i == lowest_resolution_stream_index && - prefer_temporal_support_on_base_layer_); - } + total_streams_count_ = CountAllStreams(*inst); + + // TODO(ronghuawu): Remove once this is handled in LibvpxVp8Encoder. + if (codec_.qpMax < kDefaultMinQp) { + codec_.qpMax = kDefaultMaxQp; + } + + bool is_legacy_singlecast = codec_.numberOfSimulcastStreams == 0; + int lowest_quality_stream_idx = 0; + int highest_quality_stream_idx = 0; + if (!is_legacy_singlecast) { + GetLowestAndHighestQualityStreamIndixes( + rtc::ArrayView(codec_.simulcastStream, + total_streams_count_), + &lowest_quality_stream_idx, &highest_quality_stream_idx); + } + + std::unique_ptr encoder_context = FetchOrCreateEncoderContext( + /*is_lowest_quality_stream=*/( + is_legacy_singlecast || + codec_.simulcastStream[lowest_quality_stream_idx].active)); + if (encoder_context == nullptr) { + return WEBRTC_VIDEO_CODEC_MEMORY; + } + + // Two distinct scenarios: + // * Singlecast (total_streams_count == 1) or simulcast with simulcast-capable + // underlaying encoder implementation if active_streams_count > 1. SEA + // operates in bypass mode: original settings are passed to the underlaying + // encoder, frame encode complete callback is not intercepted. + // * Multi-encoder simulcast or singlecast if layers are deactivated + // (active_streams_count >= 1). 
SEA creates N=active_streams_count encoders + // and configures each to produce a single stream. + + int active_streams_count = CountActiveStreams(*inst); + // If we only have a single active layer it is better to create an encoder + // with only one configured layer than creating it with all-but-one disabled + // layers because that way we control scaling. + bool separate_encoders_needed = + !encoder_context->encoder().GetEncoderInfo().supports_simulcast || + active_streams_count == 1; + // Singlecast or simulcast with simulcast-capable underlaying encoder. + if (total_streams_count_ == 1 || !separate_encoders_needed) { + int ret = encoder_context->encoder().InitEncode(&codec_, settings); + if (ret >= 0) { + stream_contexts_.emplace_back( + /*parent=*/nullptr, std::move(encoder_context), + /*framerate_controller=*/nullptr, /*stream_idx=*/0, codec_.width, + codec_.height, /*is_paused=*/active_streams_count == 0); + bypass_mode_ = true; + + DestroyStoredEncoders(); + rtc::AtomicOps::ReleaseStore(&inited_, 1); + return WEBRTC_VIDEO_CODEC_OK; } - bool encoder_initialized = false; - if (doing_simulcast_using_adapter && i == 0 && - encoder->GetEncoderInfo().supports_simulcast) { - ret = encoder->InitEncode(&codec_, settings); - if (ret < 0) { - encoder->Release(); - } else { - doing_simulcast_using_adapter = false; - number_of_streams = 1; - encoder_initialized = true; - } + encoder_context->Release(); + if (total_streams_count_ == 1) { + // Failed to initialize singlecast encoder. + return ret; } + } - VideoCodec stream_codec; - uint32_t start_bitrate_kbps = start_bitrates[i]; - const bool send_stream = doing_simulcast_using_adapter - ? start_bitrate_kbps > 0 - : num_active_streams > 0; - if (!doing_simulcast_using_adapter) { - stream_codec = codec_; - stream_codec.numberOfSimulcastStreams = - std::max(1, stream_codec.numberOfSimulcastStreams); - } else { - // Cap start bitrate to the min bitrate in order to avoid strange codec - // behavior. Since sending will be false, this should not matter. - StreamResolution stream_resolution = - i == highest_resolution_stream_index - ? StreamResolution::HIGHEST - : i == lowest_resolution_stream_index ? StreamResolution::LOWEST - : StreamResolution::OTHER; - - start_bitrate_kbps = - std::max(codec_.simulcastStream[i].minBitrate, start_bitrate_kbps); - PopulateStreamCodec(codec_, i, start_bitrate_kbps, stream_resolution, - &stream_codec); - } + // Multi-encoder simulcast or singlecast (deactivated layers). + std::vector stream_start_bitrate_kbps = + GetStreamStartBitratesKbps(codec_); - // TODO(ronghuawu): Remove once this is handled in LibvpxVp8Encoder. - if (stream_codec.qpMax < kDefaultMinQp) { - stream_codec.qpMax = kDefaultMaxQp; + for (int stream_idx = 0; stream_idx < total_streams_count_; ++stream_idx) { + if (!is_legacy_singlecast && !codec_.simulcastStream[stream_idx].active) { + continue; } - if (!encoder_initialized) { - ret = encoder->InitEncode(&stream_codec, settings); - if (ret < 0) { - // Explicitly destroy the current encoder; because we haven't registered - // a StreamInfo for it yet, Release won't do anything about it. 
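// Illustrative aside: the bypass decision described in the comment above,
// reduced to a pure function (a sketch mirroring the logic in this hunk).
// SEA passes the original settings straight through when one encoder can
// serve the configuration; otherwise it creates one encoder per active stream.
bool UseBypassMode(int total_streams_count,
                   int active_streams_count,
                   bool encoder_supports_simulcast) {
  const bool separate_encoders_needed =
      !encoder_supports_simulcast || active_streams_count == 1;
  return total_streams_count == 1 || !separate_encoders_needed;
}
// E.g. UseBypassMode(3, 3, /*supports_simulcast=*/true) -> true (bypass),
// UseBypassMode(3, 1, true) -> false (single encoder, one configured layer),
// UseBypassMode(3, 3, false) -> false (one encoder per active stream).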
- encoder.reset(); - Release(); - return ret; - } + if (encoder_context == nullptr) { + encoder_context = FetchOrCreateEncoderContext( + /*is_lowest_quality_stream=*/stream_idx == lowest_quality_stream_idx); + } + if (encoder_context == nullptr) { + Release(); + return WEBRTC_VIDEO_CODEC_MEMORY; } - if (!doing_simulcast_using_adapter) { - // Without simulcast, just pass through the encoder info from the one - // active encoder. - encoder->RegisterEncodeCompleteCallback(encoded_complete_callback_); - streaminfos_.emplace_back( - std::move(encoder), nullptr, - std::make_unique(stream_codec.maxFramerate), - stream_codec.width, stream_codec.height, send_stream); - } else { - std::unique_ptr callback( - new AdapterEncodedImageCallback(this, i)); - encoder->RegisterEncodeCompleteCallback(callback.get()); - streaminfos_.emplace_back( - std::move(encoder), std::move(callback), - std::make_unique(stream_codec.maxFramerate), - stream_codec.width, stream_codec.height, send_stream); + VideoCodec stream_codec = MakeStreamCodec( + codec_, stream_idx, stream_start_bitrate_kbps[stream_idx], + /*is_lowest_quality_stream=*/stream_idx == lowest_quality_stream_idx, + /*is_highest_quality_stream=*/stream_idx == highest_quality_stream_idx); + + int ret = encoder_context->encoder().InitEncode(&stream_codec, settings); + if (ret < 0) { + encoder_context.reset(); + Release(); + return ret; } + + // Intercept frame encode complete callback only for upper streams, where + // we need to set a correct stream index. Set |parent| to nullptr for the + // lowest stream to bypass the callback. + SimulcastEncoderAdapter* parent = stream_idx > 0 ? this : nullptr; + + bool is_paused = stream_start_bitrate_kbps[stream_idx] == 0; + stream_contexts_.emplace_back( + parent, std::move(encoder_context), + std::make_unique(stream_codec.maxFramerate), + stream_idx, stream_codec.width, stream_codec.height, is_paused); } // To save memory, don't store encoders that we don't use. DestroyStoredEncoders(); rtc::AtomicOps::ReleaseStore(&inited_, 1); - return WEBRTC_VIDEO_CODEC_OK; } @@ -351,22 +445,46 @@ int SimulcastEncoderAdapter::Encode( return WEBRTC_VIDEO_CODEC_UNINITIALIZED; } + if (encoder_info_override_.requested_resolution_alignment()) { + const int alignment = + *encoder_info_override_.requested_resolution_alignment(); + if (input_image.width() % alignment != 0 || + input_image.height() % alignment != 0) { + RTC_LOG(LS_WARNING) << "Frame " << input_image.width() << "x" + << input_image.height() << " not divisible by " + << alignment; + return WEBRTC_VIDEO_CODEC_ERROR; + } + if (encoder_info_override_.apply_alignment_to_all_simulcast_layers()) { + for (const auto& layer : stream_contexts_) { + if (layer.width() % alignment != 0 || layer.height() % alignment != 0) { + RTC_LOG(LS_WARNING) + << "Codec " << layer.width() << "x" << layer.height() + << " not divisible by " << alignment; + return WEBRTC_VIDEO_CODEC_ERROR; + } + } + } + } + // All active streams should generate a key frame if // a key frame is requested by any stream. 
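// Illustrative aside: the key frame aggregation implemented just below,
// sketched as a helper. FrameType is a stand-in for webrtc::VideoFrameType;
// a key frame requested on any stream, or pending on any layer, forces key
// frames on every active stream in this Encode() call.
#include <algorithm>
#include <vector>

enum class FrameType { kDelta, kKey };

bool IsKeyframeNeeded(const std::vector<FrameType>& requested_frame_types,
                      const std::vector<bool>& layer_keyframe_pending) {
  return std::any_of(requested_frame_types.begin(),
                     requested_frame_types.end(),
                     [](FrameType t) { return t == FrameType::kKey; }) ||
         std::any_of(layer_keyframe_pending.begin(),
                     layer_keyframe_pending.end(),
                     [](bool pending) { return pending; });
}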
- bool send_key_frame = false; + bool is_keyframe_needed = false; if (frame_types) { - for (size_t i = 0; i < frame_types->size(); ++i) { - if (frame_types->at(i) == VideoFrameType::kVideoFrameKey) { - send_key_frame = true; + for (const auto& frame_type : *frame_types) { + if (frame_type == VideoFrameType::kVideoFrameKey) { + is_keyframe_needed = true; break; } } } - for (size_t stream_idx = 0; stream_idx < streaminfos_.size(); ++stream_idx) { - if (streaminfos_[stream_idx].key_frame_request && - streaminfos_[stream_idx].send_stream) { - send_key_frame = true; - break; + + if (!is_keyframe_needed) { + for (const auto& layer : stream_contexts_) { + if (layer.is_keyframe_needed()) { + is_keyframe_needed = true; + break; + } } } @@ -374,36 +492,34 @@ int SimulcastEncoderAdapter::Encode( rtc::scoped_refptr src_buffer; int src_width = input_image.width(); int src_height = input_image.height(); - for (size_t stream_idx = 0; stream_idx < streaminfos_.size(); ++stream_idx) { + + for (auto& layer : stream_contexts_) { // Don't encode frames in resolutions that we don't intend to send. - if (!streaminfos_[stream_idx].send_stream) { + if (layer.is_paused()) { continue; } - const uint32_t frame_timestamp_ms = - 1000 * input_image.timestamp() / 90000; // kVideoPayloadTypeFrequency; + // Convert timestamp from RTP 90kHz clock. + const Timestamp frame_timestamp = + Timestamp::Micros((1000 * input_image.timestamp()) / 90); // If adapter is passed through and only one sw encoder does simulcast, // frame types for all streams should be passed to the encoder unchanged. // Otherwise a single per-encoder frame type is passed. std::vector stream_frame_types( - streaminfos_.size() == 1 ? NumberOfStreams(codec_) : 1); - if (send_key_frame) { + bypass_mode_ ? total_streams_count_ : 1); + if (is_keyframe_needed) { std::fill(stream_frame_types.begin(), stream_frame_types.end(), VideoFrameType::kVideoFrameKey); - streaminfos_[stream_idx].key_frame_request = false; + layer.OnKeyframe(frame_timestamp); } else { - if (streaminfos_[stream_idx].framerate_controller->DropFrame( - frame_timestamp_ms)) { + if (layer.ShouldDropFrame(frame_timestamp)) { continue; } std::fill(stream_frame_types.begin(), stream_frame_types.end(), VideoFrameType::kVideoFrameDelta); } - streaminfos_[stream_idx].framerate_controller->AddFrame(frame_timestamp_ms); - int dst_width = streaminfos_[stream_idx].width; - int dst_height = streaminfos_[stream_idx].height; // If scaling isn't required, because the input resolution // matches the destination or the input image is empty (e.g. // a keyframe request for encoders with internal camera @@ -414,14 +530,11 @@ int SimulcastEncoderAdapter::Encode( // correctly sample/scale the source texture. // TODO(perkj): ensure that works going forward, and figure out how this // affects webrtc:5683. 
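// Illustrative aside: the RTP timestamp conversion added in this hunk. RTP
// video timestamps tick at 90 kHz, so (1000 * ts) / 90 yields microseconds.
#include <cstdint>
#include <cstdio>

int64_t RtpTimestampToMicros(uint32_t rtp_timestamp) {
  return (1000 * static_cast<int64_t>(rtp_timestamp)) / 90;
}

int main() {
  std::printf("%lld us\n",
              static_cast<long long>(RtpTimestampToMicros(90000)));  // 1000000
  std::printf("%lld us\n",
              static_cast<long long>(RtpTimestampToMicros(2700)));   // 30000
}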
- if ((dst_width == src_width && dst_height == src_height) || + if ((layer.width() == src_width && layer.height() == src_height) || (input_image.video_frame_buffer()->type() == VideoFrameBuffer::Type::kNative && - streaminfos_[stream_idx] - .encoder->GetEncoderInfo() - .supports_native_handle)) { - int ret = streaminfos_[stream_idx].encoder->Encode(input_image, - &stream_frame_types); + layer.encoder().GetEncoderInfo().supports_native_handle)) { + int ret = layer.encoder().Encode(input_image, &stream_frame_types); if (ret != WEBRTC_VIDEO_CODEC_OK) { return ret; } @@ -430,7 +543,7 @@ int SimulcastEncoderAdapter::Encode( src_buffer = input_image.video_frame_buffer(); } rtc::scoped_refptr dst_buffer = - src_buffer->Scale(dst_width, dst_height); + src_buffer->Scale(layer.width(), layer.height()); if (!dst_buffer) { RTC_LOG(LS_ERROR) << "Failed to scale video frame"; return WEBRTC_VIDEO_CODEC_ENCODER_FAILURE; @@ -443,8 +556,7 @@ int SimulcastEncoderAdapter::Encode( frame.set_rotation(webrtc::kVideoRotation_0); frame.set_update_rect( VideoFrame::UpdateRect{0, 0, frame.width(), frame.height()}); - int ret = - streaminfos_[stream_idx].encoder->Encode(frame, &stream_frame_types); + int ret = layer.encoder().Encode(frame, &stream_frame_types); if (ret != WEBRTC_VIDEO_CODEC_OK) { return ret; } @@ -458,8 +570,10 @@ int SimulcastEncoderAdapter::RegisterEncodeCompleteCallback( EncodedImageCallback* callback) { RTC_DCHECK_RUN_ON(&encoder_queue_); encoded_complete_callback_ = callback; - if (streaminfos_.size() == 1) { - streaminfos_[0].encoder->RegisterEncodeCompleteCallback(callback); + if (!stream_contexts_.empty() && stream_contexts_.front().stream_idx() == 0) { + // Bypass frame encode complete callback for the lowest layer since there is + // no need to override frame's spatial index. + stream_contexts_.front().encoder().RegisterEncodeCompleteCallback(callback); } return WEBRTC_VIDEO_CODEC_OK; } @@ -480,21 +594,21 @@ void SimulcastEncoderAdapter::SetRates( codec_.maxFramerate = static_cast(parameters.framerate_fps + 0.5); - if (streaminfos_.size() == 1) { - // Not doing simulcast. - streaminfos_[0].encoder->SetRates(parameters); + if (bypass_mode_) { + stream_contexts_.front().encoder().SetRates(parameters); return; } - for (size_t stream_idx = 0; stream_idx < streaminfos_.size(); ++stream_idx) { + for (StreamContext& layer_context : stream_contexts_) { + int stream_idx = layer_context.stream_idx(); uint32_t stream_bitrate_kbps = parameters.bitrate.GetSpatialLayerSum(stream_idx) / 1000; // Need a key frame if we have not sent this stream before. - if (stream_bitrate_kbps > 0 && !streaminfos_[stream_idx].send_stream) { - streaminfos_[stream_idx].key_frame_request = true; + if (stream_bitrate_kbps > 0 && layer_context.is_paused()) { + layer_context.set_is_keyframe_needed(); } - streaminfos_[stream_idx].send_stream = stream_bitrate_kbps > 0; + layer_context.set_is_paused(stream_bitrate_kbps == 0); // Slice the temporal layers out of the full allocation and pass it on to // the encoder handling the current simulcast stream. 
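// Illustrative aside: the pause/resume handling in SetRates above, as a
// minimal sketch. A layer with a zero bitrate allocation is paused, and a
// layer coming back from zero requests a key frame before it resumes sending.
#include <cstdint>

struct LayerState {
  bool is_paused = false;
  bool is_keyframe_needed = false;
};

void ApplyAllocation(LayerState& layer, uint32_t stream_bitrate_kbps) {
  if (stream_bitrate_kbps > 0 && layer.is_paused) {
    layer.is_keyframe_needed = true;  // Resuming: force a key frame.
  }
  layer.is_paused = (stream_bitrate_kbps == 0);
}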
@@ -524,28 +638,28 @@ void SimulcastEncoderAdapter::SetRates( stream_parameters.framerate_fps = std::min( parameters.framerate_fps, - streaminfos_[stream_idx].framerate_controller->GetTargetRate()); + layer_context.target_fps().value_or(parameters.framerate_fps)); - streaminfos_[stream_idx].encoder->SetRates(stream_parameters); + layer_context.encoder().SetRates(stream_parameters); } } void SimulcastEncoderAdapter::OnPacketLossRateUpdate(float packet_loss_rate) { - for (StreamInfo& info : streaminfos_) { - info.encoder->OnPacketLossRateUpdate(packet_loss_rate); + for (auto& c : stream_contexts_) { + c.encoder().OnPacketLossRateUpdate(packet_loss_rate); } } void SimulcastEncoderAdapter::OnRttUpdate(int64_t rtt_ms) { - for (StreamInfo& info : streaminfos_) { - info.encoder->OnRttUpdate(rtt_ms); + for (auto& c : stream_contexts_) { + c.encoder().OnRttUpdate(rtt_ms); } } void SimulcastEncoderAdapter::OnLossNotification( const LossNotification& loss_notification) { - for (StreamInfo& info : streaminfos_) { - info.encoder->OnLossNotification(loss_notification); + for (auto& c : stream_contexts_) { + c.encoder().OnLossNotification(loss_notification); } } @@ -564,75 +678,147 @@ EncodedImageCallback::Result SimulcastEncoderAdapter::OnEncodedImage( &stream_codec_specific); } -void SimulcastEncoderAdapter::PopulateStreamCodec( - const webrtc::VideoCodec& inst, - int stream_index, +void SimulcastEncoderAdapter::OnDroppedFrame(size_t stream_idx) { + // Not yet implemented. +} + +bool SimulcastEncoderAdapter::Initialized() const { + return rtc::AtomicOps::AcquireLoad(&inited_) == 1; +} + +void SimulcastEncoderAdapter::DestroyStoredEncoders() { + while (!cached_encoder_contexts_.empty()) { + cached_encoder_contexts_.pop_back(); + } +} + +std::unique_ptr +SimulcastEncoderAdapter::FetchOrCreateEncoderContext( + bool is_lowest_quality_stream) const { + bool prefer_temporal_support = fallback_encoder_factory_ != nullptr && + is_lowest_quality_stream && + prefer_temporal_support_on_base_layer_; + + // Toggling of |prefer_temporal_support| requires encoder recreation. Find + // and reuse encoder with desired |prefer_temporal_support|. Otherwise, if + // there is no such encoder in the cache, create a new instance. 
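// Illustrative aside: the cache lookup that follows, sketched with
// placeholder types. Released encoder contexts are kept in a list and reused
// only when their |prefer_temporal_support| flag matches, since toggling
// that flag requires recreating the encoder.
#include <algorithm>
#include <list>
#include <memory>

struct Ctx {
  bool prefer_temporal_support;
};

std::unique_ptr<Ctx> FetchOrCreate(std::list<std::unique_ptr<Ctx>>& cache,
                                   bool prefer_temporal_support) {
  auto it = std::find_if(cache.begin(), cache.end(),
                         [&](const std::unique_ptr<Ctx>& c) {
                           return c->prefer_temporal_support ==
                                  prefer_temporal_support;
                         });
  if (it != cache.end()) {
    std::unique_ptr<Ctx> reused = std::move(*it);
    cache.erase(it);
    return reused;
  }
  return std::make_unique<Ctx>(Ctx{prefer_temporal_support});
}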
+ auto encoder_context_iter = + std::find_if(cached_encoder_contexts_.begin(), + cached_encoder_contexts_.end(), [&](auto& encoder_context) { + return encoder_context->prefer_temporal_support() == + prefer_temporal_support; + }); + + std::unique_ptr encoder_context; + if (encoder_context_iter != cached_encoder_contexts_.end()) { + encoder_context = std::move(*encoder_context_iter); + cached_encoder_contexts_.erase(encoder_context_iter); + } else { + std::unique_ptr encoder = + primary_encoder_factory_->CreateVideoEncoder(video_format_); + VideoEncoder::EncoderInfo primary_info = encoder->GetEncoderInfo(); + VideoEncoder::EncoderInfo fallback_info = primary_info; + if (fallback_encoder_factory_ != nullptr) { + std::unique_ptr fallback_encoder = + fallback_encoder_factory_->CreateVideoEncoder(video_format_); + fallback_info = fallback_encoder->GetEncoderInfo(); + encoder = CreateVideoEncoderSoftwareFallbackWrapper( + std::move(fallback_encoder), std::move(encoder), + prefer_temporal_support); + } + + encoder_context = std::make_unique( + std::move(encoder), prefer_temporal_support, primary_info, + fallback_info); + } + + encoder_context->encoder().RegisterEncodeCompleteCallback( + encoded_complete_callback_); + return encoder_context; +} + +webrtc::VideoCodec SimulcastEncoderAdapter::MakeStreamCodec( + const webrtc::VideoCodec& codec, + int stream_idx, uint32_t start_bitrate_kbps, - StreamResolution stream_resolution, - webrtc::VideoCodec* stream_codec) { - *stream_codec = inst; - - // Stream specific settings. - stream_codec->numberOfSimulcastStreams = 0; - stream_codec->width = inst.simulcastStream[stream_index].width; - stream_codec->height = inst.simulcastStream[stream_index].height; - stream_codec->maxBitrate = inst.simulcastStream[stream_index].maxBitrate; - stream_codec->minBitrate = inst.simulcastStream[stream_index].minBitrate; - stream_codec->maxFramerate = inst.simulcastStream[stream_index].maxFramerate; - stream_codec->qpMax = inst.simulcastStream[stream_index].qpMax; - stream_codec->active = inst.simulcastStream[stream_index].active; + bool is_lowest_quality_stream, + bool is_highest_quality_stream) { + webrtc::VideoCodec codec_params = codec; + const SpatialLayer& stream_params = codec.simulcastStream[stream_idx]; + + codec_params.numberOfSimulcastStreams = 0; + codec_params.width = stream_params.width; + codec_params.height = stream_params.height; + codec_params.maxBitrate = stream_params.maxBitrate; + codec_params.minBitrate = stream_params.minBitrate; + codec_params.maxFramerate = stream_params.maxFramerate; + codec_params.qpMax = stream_params.qpMax; + codec_params.active = stream_params.active; // Settings that are based on stream/resolution. - if (stream_resolution == StreamResolution::LOWEST) { + if (is_lowest_quality_stream) { // Settings for lowest spatial resolutions. 
- if (inst.mode == VideoCodecMode::kScreensharing) { + if (codec.mode == VideoCodecMode::kScreensharing) { if (experimental_boosted_screenshare_qp_) { - stream_codec->qpMax = *experimental_boosted_screenshare_qp_; + codec_params.qpMax = *experimental_boosted_screenshare_qp_; } } else if (boost_base_layer_quality_) { - stream_codec->qpMax = kLowestResMaxQp; + codec_params.qpMax = kLowestResMaxQp; } } - if (inst.codecType == webrtc::kVideoCodecVP8) { - stream_codec->VP8()->numberOfTemporalLayers = - inst.simulcastStream[stream_index].numberOfTemporalLayers; - if (stream_resolution != StreamResolution::HIGHEST) { + if (codec.codecType == webrtc::kVideoCodecVP8) { + codec_params.VP8()->numberOfTemporalLayers = + stream_params.numberOfTemporalLayers; + if (!is_highest_quality_stream) { // For resolutions below CIF, set the codec |complexity| parameter to // kComplexityHigher, which maps to cpu_used = -4. - int pixels_per_frame = stream_codec->width * stream_codec->height; + int pixels_per_frame = codec_params.width * codec_params.height; if (pixels_per_frame < 352 * 288) { - stream_codec->VP8()->complexity = + codec_params.VP8()->complexity = webrtc::VideoCodecComplexity::kComplexityHigher; } // Turn off denoising for all streams but the highest resolution. - stream_codec->VP8()->denoisingOn = false; + codec_params.VP8()->denoisingOn = false; } - } else if (inst.codecType == webrtc::kVideoCodecH264) { - stream_codec->H264()->numberOfTemporalLayers = - inst.simulcastStream[stream_index].numberOfTemporalLayers; + } else if (codec.codecType == webrtc::kVideoCodecH264) { + codec_params.H264()->numberOfTemporalLayers = + stream_params.numberOfTemporalLayers; } - // TODO(ronghuawu): what to do with targetBitrate. - stream_codec->startBitrate = start_bitrate_kbps; + // Cap start bitrate to the min bitrate in order to avoid strange codec + // behavior. + codec_params.startBitrate = + std::max(stream_params.minBitrate, start_bitrate_kbps); // Legacy screenshare mode is only enabled for the first simulcast layer - stream_codec->legacy_conference_mode = - inst.legacy_conference_mode && stream_index == 0; -} + codec_params.legacy_conference_mode = + codec.legacy_conference_mode && stream_idx == 0; -bool SimulcastEncoderAdapter::Initialized() const { - return rtc::AtomicOps::AcquireLoad(&inited_) == 1; + return codec_params; } -void SimulcastEncoderAdapter::DestroyStoredEncoders() { - while (!stored_encoders_.empty()) { - stored_encoders_.pop(); +void SimulcastEncoderAdapter::OverrideFromFieldTrial( + VideoEncoder::EncoderInfo* info) const { + if (encoder_info_override_.requested_resolution_alignment()) { + info->requested_resolution_alignment = cricket::LeastCommonMultiple( + info->requested_resolution_alignment, + *encoder_info_override_.requested_resolution_alignment()); + info->apply_alignment_to_all_simulcast_layers = + info->apply_alignment_to_all_simulcast_layers || + encoder_info_override_.apply_alignment_to_all_simulcast_layers(); + } + if (!encoder_info_override_.resolution_bitrate_limits().empty()) { + info->resolution_bitrate_limits = + encoder_info_override_.resolution_bitrate_limits(); } } VideoEncoder::EncoderInfo SimulcastEncoderAdapter::GetEncoderInfo() const { - if (streaminfos_.size() == 1) { + if (stream_contexts_.size() == 1) { // Not using simulcast adapting functionality, just pass through. 
- return streaminfos_[0].encoder->GetEncoderInfo(); + VideoEncoder::EncoderInfo info = + stream_contexts_.front().encoder().GetEncoderInfo(); + OverrideFromFieldTrial(&info); + return info; } VideoEncoder::EncoderInfo encoder_info; @@ -641,17 +827,43 @@ VideoEncoder::EncoderInfo SimulcastEncoderAdapter::GetEncoderInfo() const { encoder_info.apply_alignment_to_all_simulcast_layers = false; encoder_info.supports_native_handle = true; encoder_info.scaling_settings.thresholds = absl::nullopt; - if (streaminfos_.empty()) { + + if (stream_contexts_.empty()) { + // GetEncoderInfo queried before InitEncode. Only alignment info is needed + // to be filled. + // Create one encoder and query it. + + std::unique_ptr encoder_context = + FetchOrCreateEncoderContext(true); + + const VideoEncoder::EncoderInfo& primary_info = + encoder_context->PrimaryInfo(); + const VideoEncoder::EncoderInfo& fallback_info = + encoder_context->FallbackInfo(); + + encoder_info.requested_resolution_alignment = cricket::LeastCommonMultiple( + primary_info.requested_resolution_alignment, + fallback_info.requested_resolution_alignment); + + encoder_info.apply_alignment_to_all_simulcast_layers = + primary_info.apply_alignment_to_all_simulcast_layers || + fallback_info.apply_alignment_to_all_simulcast_layers; + + if (!primary_info.supports_simulcast || !fallback_info.supports_simulcast) { + encoder_info.apply_alignment_to_all_simulcast_layers = true; + } + + cached_encoder_contexts_.emplace_back(std::move(encoder_context)); + + OverrideFromFieldTrial(&encoder_info); return encoder_info; } encoder_info.scaling_settings = VideoEncoder::ScalingSettings::kOff; - int num_active_streams = NumActiveStreams(codec_); - for (size_t i = 0; i < streaminfos_.size(); ++i) { + for (size_t i = 0; i < stream_contexts_.size(); ++i) { VideoEncoder::EncoderInfo encoder_impl_info = - streaminfos_[i].encoder->GetEncoderInfo(); - + stream_contexts_[i].encoder().GetEncoderInfo(); if (i == 0) { // Encoder name indicates names of all sub-encoders. encoder_info.implementation_name += " ("; @@ -690,15 +902,19 @@ VideoEncoder::EncoderInfo SimulcastEncoderAdapter::GetEncoderInfo() const { encoder_info.requested_resolution_alignment = cricket::LeastCommonMultiple( encoder_info.requested_resolution_alignment, encoder_impl_info.requested_resolution_alignment); - if (encoder_impl_info.apply_alignment_to_all_simulcast_layers) { + // request alignment on all layers if any of the encoders may need it, or + // if any non-top layer encoder requests a non-trivial alignment. 
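// Illustrative aside: how requested_resolution_alignment values are merged
// across sub-encoders. cricket::LeastCommonMultiple (used just above) is
// presumably the standard lcm; this sketch uses std::lcm from <numeric>.
#include <cstdio>
#include <numeric>

int main() {
  int merged = 1;
  for (int alignment : {2, 4, 16}) {
    merged = std::lcm(merged, alignment);
  }
  std::printf("%d\n", merged);  // 16: every layer must satisfy the strictest
                                // requirement when the values nest.
  std::printf("%d\n", std::lcm(6, 4));  // 12 when the requirements don't nest.
}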
+ if (encoder_impl_info.apply_alignment_to_all_simulcast_layers || + (encoder_impl_info.requested_resolution_alignment > 1 && + (codec_.simulcastStream[i].height < codec_.height || + codec_.simulcastStream[i].width < codec_.width))) { encoder_info.apply_alignment_to_all_simulcast_layers = true; } - if (num_active_streams == 1 && codec_.simulcastStream[i].active) { - encoder_info.scaling_settings = encoder_impl_info.scaling_settings; - } } encoder_info.implementation_name += ")"; + OverrideFromFieldTrial(&encoder_info); + return encoder_info; } diff --git a/media/engine/simulcast_encoder_adapter.h b/media/engine/simulcast_encoder_adapter.h index 1067df8ed1..07e3ccd024 100644 --- a/media/engine/simulcast_encoder_adapter.h +++ b/media/engine/simulcast_encoder_adapter.h @@ -12,6 +12,7 @@ #ifndef MEDIA_ENGINE_SIMULCAST_ENCODER_ADAPTER_H_ #define MEDIA_ENGINE_SIMULCAST_ENCODER_ADAPTER_H_ +#include #include #include #include @@ -20,20 +21,19 @@ #include "absl/types/optional.h" #include "api/fec_controller_override.h" +#include "api/sequence_checker.h" #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_encoder.h" +#include "api/video_codecs/video_encoder_factory.h" #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/utility/framerate_controller.h" #include "rtc_base/atomic_ops.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/experiments/encoder_info_settings.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/system/rtc_export.h" namespace webrtc { -class SimulcastRateAllocator; -class VideoEncoderFactory; - // SimulcastEncoderAdapter implements simulcast support by creating multiple // webrtc::VideoEncoder instances with the given VideoEncoderFactory. // The object is created and destroyed on the worker thread, but all public @@ -65,75 +65,132 @@ class RTC_EXPORT SimulcastEncoderAdapter : public VideoEncoder { void OnRttUpdate(int64_t rtt_ms) override; void OnLossNotification(const LossNotification& loss_notification) override; - // Eventual handler for the contained encoders' EncodedImageCallbacks, but - // called from an internal helper that also knows the correct stream - // index. 
- EncodedImageCallback::Result OnEncodedImage( - size_t stream_idx, - const EncodedImage& encoded_image, - const CodecSpecificInfo* codec_specific_info); - EncoderInfo GetEncoderInfo() const override; private: - struct StreamInfo { - StreamInfo(std::unique_ptr encoder, - std::unique_ptr callback, - std::unique_ptr framerate_controller, - uint16_t width, - uint16_t height, - bool send_stream) - : encoder(std::move(encoder)), - callback(std::move(callback)), - framerate_controller(std::move(framerate_controller)), - width(width), - height(height), - key_frame_request(false), - send_stream(send_stream) {} - std::unique_ptr encoder; - std::unique_ptr callback; - std::unique_ptr framerate_controller; - uint16_t width; - uint16_t height; - bool key_frame_request; - bool send_stream; + class EncoderContext { + public: + EncoderContext(std::unique_ptr encoder, + bool prefer_temporal_support, + VideoEncoder::EncoderInfo primary_info, + VideoEncoder::EncoderInfo fallback_info); + EncoderContext& operator=(EncoderContext&&) = delete; + + VideoEncoder& encoder() { return *encoder_; } + bool prefer_temporal_support() { return prefer_temporal_support_; } + void Release(); + + const VideoEncoder::EncoderInfo& PrimaryInfo() { return primary_info_; } + + const VideoEncoder::EncoderInfo& FallbackInfo() { return fallback_info_; } + + private: + std::unique_ptr encoder_; + bool prefer_temporal_support_; + const VideoEncoder::EncoderInfo primary_info_; + const VideoEncoder::EncoderInfo fallback_info_; }; - enum class StreamResolution { - OTHER, - HIGHEST, - LOWEST, + class StreamContext : public EncodedImageCallback { + public: + StreamContext(SimulcastEncoderAdapter* parent, + std::unique_ptr encoder_context, + std::unique_ptr framerate_controller, + int stream_idx, + uint16_t width, + uint16_t height, + bool send_stream); + StreamContext(StreamContext&& rhs); + StreamContext& operator=(StreamContext&&) = delete; + ~StreamContext() override; + + Result OnEncodedImage( + const EncodedImage& encoded_image, + const CodecSpecificInfo* codec_specific_info) override; + void OnDroppedFrame(DropReason reason) override; + + VideoEncoder& encoder() { return encoder_context_->encoder(); } + const VideoEncoder& encoder() const { return encoder_context_->encoder(); } + int stream_idx() const { return stream_idx_; } + uint16_t width() const { return width_; } + uint16_t height() const { return height_; } + bool is_keyframe_needed() const { + return !is_paused_ && is_keyframe_needed_; + } + void set_is_keyframe_needed() { is_keyframe_needed_ = true; } + bool is_paused() const { return is_paused_; } + void set_is_paused(bool is_paused) { is_paused_ = is_paused; } + absl::optional target_fps() const { + return framerate_controller_ == nullptr + ? absl::nullopt + : absl::optional( + framerate_controller_->GetTargetRate()); + } + + std::unique_ptr ReleaseEncoderContext() &&; + void OnKeyframe(Timestamp timestamp); + bool ShouldDropFrame(Timestamp timestamp); + + private: + SimulcastEncoderAdapter* const parent_; + std::unique_ptr encoder_context_; + std::unique_ptr framerate_controller_; + const int stream_idx_; + const uint16_t width_; + const uint16_t height_; + bool is_keyframe_needed_; + bool is_paused_; }; - // Populate the codec settings for each simulcast stream. 
- void PopulateStreamCodec(const webrtc::VideoCodec& inst, - int stream_index, - uint32_t start_bitrate_kbps, - StreamResolution stream_resolution, - webrtc::VideoCodec* stream_codec); - bool Initialized() const; void DestroyStoredEncoders(); + // This method creates encoder. May reuse previously created encoders from + // |cached_encoder_contexts_|. It's const because it's used from + // const GetEncoderInfo(). + std::unique_ptr FetchOrCreateEncoderContext( + bool is_lowest_quality_stream) const; + + webrtc::VideoCodec MakeStreamCodec(const webrtc::VideoCodec& codec, + int stream_idx, + uint32_t start_bitrate_kbps, + bool is_lowest_quality_stream, + bool is_highest_quality_stream); + + EncodedImageCallback::Result OnEncodedImage( + size_t stream_idx, + const EncodedImage& encoded_image, + const CodecSpecificInfo* codec_specific_info); + + void OnDroppedFrame(size_t stream_idx); + + void OverrideFromFieldTrial(VideoEncoder::EncoderInfo* info) const; + volatile int inited_; // Accessed atomically. VideoEncoderFactory* const primary_encoder_factory_; VideoEncoderFactory* const fallback_encoder_factory_; const SdpVideoFormat video_format_; VideoCodec codec_; - std::vector streaminfos_; + int total_streams_count_; + bool bypass_mode_; + std::vector stream_contexts_; EncodedImageCallback* encoded_complete_callback_; // Used for checking the single-threaded access of the encoder interface. RTC_NO_UNIQUE_ADDRESS SequenceChecker encoder_queue_; - // Store encoders in between calls to Release and InitEncode, so they don't - // have to be recreated. Remaining encoders are destroyed by the destructor. - std::stack> stored_encoders_; + // Store previously created and released encoders , so they don't have to be + // recreated. Remaining encoders are destroyed by the destructor. + // Marked as |mutable| becuase we may need to temporarily create encoder in + // GetEncoderInfo(), which is const. 
+ mutable std::list> cached_encoder_contexts_; const absl::optional experimental_boosted_screenshare_qp_; const bool boost_base_layer_quality_; const bool prefer_temporal_support_on_base_layer_; + + const SimulcastEncoderAdapterEncoderInfoSettings encoder_info_override_; }; } // namespace webrtc diff --git a/media/engine/simulcast_encoder_adapter_unittest.cc b/media/engine/simulcast_encoder_adapter_unittest.cc index 24686e813e..48e005f1c2 100644 --- a/media/engine/simulcast_encoder_adapter_unittest.cc +++ b/media/engine/simulcast_encoder_adapter_unittest.cc @@ -18,6 +18,7 @@ #include "api/test/simulcast_test_fixture.h" #include "api/test/video/function_video_decoder_factory.h" #include "api/test/video/function_video_encoder_factory.h" +#include "api/video/video_codec_constants.h" #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_encoder.h" #include "api/video_codecs/video_encoder_factory.h" @@ -28,6 +29,7 @@ #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/utility/simulcast_test_fixture_impl.h" #include "rtc_base/checks.h" +#include "test/field_trial.h" #include "test/gmock.h" #include "test/gtest.h" @@ -420,14 +422,24 @@ class TestSimulcastEncoderAdapterFake : public ::testing::Test, } void SetUp() override { - helper_ = std::make_unique( - use_fallback_factory_, SdpVideoFormat("VP8", sdp_video_parameters_)); + helper_.reset(new TestSimulcastEncoderAdapterFakeHelper( + use_fallback_factory_, SdpVideoFormat("VP8", sdp_video_parameters_))); adapter_.reset(helper_->CreateMockEncoderAdapter()); last_encoded_image_width_ = -1; last_encoded_image_height_ = -1; last_encoded_image_simulcast_index_ = -1; } + void ReSetUp() { + if (adapter_) { + adapter_->Release(); + // |helper_| owns factories which |adapter_| needs to destroy encoders. + // Release |adapter_| before |helper_| (released in SetUp()). + adapter_.reset(); + } + SetUp(); + } + Result OnEncodedImage(const EncodedImage& encoded_image, const CodecSpecificInfo* codec_specific_info) override { last_encoded_image_width_ = encoded_image._encodedWidth; @@ -450,10 +462,23 @@ class TestSimulcastEncoderAdapterFake : public ::testing::Test, return true; } - void SetupCodec() { + void SetupCodec() { SetupCodec(/*active_streams=*/{true, true, true}); } + + void SetupCodec(std::vector active_streams) { SimulcastTestFixtureImpl::DefaultSettings( &codec_, static_cast(kTestTemporalLayerProfile), kVideoCodecVP8); + ASSERT_LE(active_streams.size(), codec_.numberOfSimulcastStreams); + codec_.numberOfSimulcastStreams = active_streams.size(); + for (size_t stream_idx = 0; stream_idx < kMaxSimulcastStreams; + ++stream_idx) { + if (stream_idx >= codec_.numberOfSimulcastStreams) { + // Reset parameters of unspecified stream. + codec_.simulcastStream[stream_idx] = {0}; + } else { + codec_.simulcastStream[stream_idx].active = active_streams[stream_idx]; + } + } rate_allocator_.reset(new SimulcastRateAllocator(codec_)); EXPECT_EQ(0, adapter_->InitEncode(&codec_, kSettings)); adapter_->RegisterEncodeCompleteCallback(this); @@ -578,7 +603,8 @@ TEST_F(TestSimulcastEncoderAdapterFake, EncodedCallbackForDifferentEncoders) { EXPECT_TRUE(GetLastEncodedImageInfo(&width, &height, &simulcast_index)); EXPECT_EQ(1152, width); EXPECT_EQ(704, height); - EXPECT_EQ(0, simulcast_index); + // SEA doesn't intercept frame encode complete callback for the lowest stream. 
+ EXPECT_EQ(-1, simulcast_index); encoders[1]->SendEncodedImage(300, 620); EXPECT_TRUE(GetLastEncodedImageInfo(&width, &height, &simulcast_index)); @@ -794,7 +820,8 @@ TEST_F(TestSimulcastEncoderAdapterFake, ReinitDoesNotReorderFrameSimulcastIdx) { int height; int simulcast_index; EXPECT_TRUE(GetLastEncodedImageInfo(&width, &height, &simulcast_index)); - EXPECT_EQ(0, simulcast_index); + // SEA doesn't intercept frame encode complete callback for the lowest stream. + EXPECT_EQ(-1, simulcast_index); encoders[1]->SendEncodedImage(300, 620); EXPECT_TRUE(GetLastEncodedImageInfo(&width, &height, &simulcast_index)); @@ -814,7 +841,7 @@ TEST_F(TestSimulcastEncoderAdapterFake, ReinitDoesNotReorderFrameSimulcastIdx) { // Verify that the same encoder sends out frames on the same simulcast index. encoders[0]->SendEncodedImage(1152, 704); EXPECT_TRUE(GetLastEncodedImageInfo(&width, &height, &simulcast_index)); - EXPECT_EQ(0, simulcast_index); + EXPECT_EQ(-1, simulcast_index); encoders[1]->SendEncodedImage(300, 620); EXPECT_TRUE(GetLastEncodedImageInfo(&width, &height, &simulcast_index)); @@ -873,8 +900,6 @@ TEST_F(TestSimulcastEncoderAdapterFake, SetRatesUnderMinBitrate) { } TEST_F(TestSimulcastEncoderAdapterFake, SupportsImplementationName) { - EXPECT_EQ("SimulcastEncoderAdapter", - adapter_->GetEncoderInfo().implementation_name); SimulcastTestFixtureImpl::DefaultSettings( &codec_, static_cast(kTestTemporalLayerProfile), kVideoCodecVP8); @@ -883,6 +908,8 @@ TEST_F(TestSimulcastEncoderAdapterFake, SupportsImplementationName) { encoder_names.push_back("codec2"); encoder_names.push_back("codec3"); helper_->factory()->SetEncoderNames(encoder_names); + EXPECT_EQ("SimulcastEncoderAdapter", + adapter_->GetEncoderInfo().implementation_name); EXPECT_EQ(0, adapter_->InitEncode(&codec_, kSettings)); EXPECT_EQ("SimulcastEncoderAdapter (codec1, codec2, codec3)", adapter_->GetEncoderInfo().implementation_name); @@ -979,8 +1006,8 @@ TEST_F(TestSimulcastEncoderAdapterFake, EXPECT_TRUE(adapter_->GetEncoderInfo().supports_native_handle); rtc::scoped_refptr buffer( - new rtc::RefCountedObject(1280, 720, - /*allow_to_i420=*/false)); + rtc::make_ref_counted(1280, 720, + /*allow_to_i420=*/false)); VideoFrame input_frame = VideoFrame::Builder() .set_video_frame_buffer(buffer) .set_timestamp_rtp(100) @@ -1016,8 +1043,8 @@ TEST_F(TestSimulcastEncoderAdapterFake, NativeHandleForwardingOnlyIfSupported) { EXPECT_TRUE(adapter_->GetEncoderInfo().supports_native_handle); rtc::scoped_refptr buffer( - new rtc::RefCountedObject(1280, 720, - /*allow_to_i420=*/true)); + rtc::make_ref_counted(1280, 720, + /*allow_to_i420=*/true)); VideoFrame input_frame = VideoFrame::Builder() .set_video_frame_buffer(buffer) .set_timestamp_rtp(100) @@ -1291,6 +1318,53 @@ TEST_F(TestSimulcastEncoderAdapterFake, adapter_->GetEncoderInfo().apply_alignment_to_all_simulcast_layers); } +TEST_F(TestSimulcastEncoderAdapterFake, EncoderInfoFromFieldTrial) { + test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "requested_resolution_alignment:8," + "apply_alignment_to_all_simulcast_layers/"); + SetUp(); + SimulcastTestFixtureImpl::DefaultSettings( + &codec_, static_cast(kTestTemporalLayerProfile), + kVideoCodecVP8); + codec_.numberOfSimulcastStreams = 3; + EXPECT_EQ(0, adapter_->InitEncode(&codec_, kSettings)); + ASSERT_EQ(3u, helper_->factory()->encoders().size()); + + EXPECT_EQ(8, adapter_->GetEncoderInfo().requested_resolution_alignment); + EXPECT_TRUE( + 
adapter_->GetEncoderInfo().apply_alignment_to_all_simulcast_layers); + EXPECT_TRUE(adapter_->GetEncoderInfo().resolution_bitrate_limits.empty()); +} + +TEST_F(TestSimulcastEncoderAdapterFake, + EncoderInfoFromFieldTrialForSingleStream) { + test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "requested_resolution_alignment:9," + "frame_size_pixels:123|456|789," + "min_start_bitrate_bps:11000|22000|33000," + "min_bitrate_bps:44000|55000|66000," + "max_bitrate_bps:77000|88000|99000/"); + SetUp(); + SimulcastTestFixtureImpl::DefaultSettings( + &codec_, static_cast(kTestTemporalLayerProfile), + kVideoCodecVP8); + codec_.numberOfSimulcastStreams = 1; + EXPECT_EQ(0, adapter_->InitEncode(&codec_, kSettings)); + ASSERT_EQ(1u, helper_->factory()->encoders().size()); + + EXPECT_EQ(9, adapter_->GetEncoderInfo().requested_resolution_alignment); + EXPECT_FALSE( + adapter_->GetEncoderInfo().apply_alignment_to_all_simulcast_layers); + EXPECT_THAT( + adapter_->GetEncoderInfo().resolution_bitrate_limits, + ::testing::ElementsAre( + VideoEncoder::ResolutionBitrateLimits{123, 11000, 44000, 77000}, + VideoEncoder::ResolutionBitrateLimits{456, 22000, 55000, 88000}, + VideoEncoder::ResolutionBitrateLimits{789, 33000, 66000, 99000})); +} + TEST_F(TestSimulcastEncoderAdapterFake, ReportsInternalSource) { SimulcastTestFixtureImpl::DefaultSettings( &codec_, static_cast(kTestTemporalLayerProfile), @@ -1545,5 +1619,69 @@ TEST_F(TestSimulcastEncoderAdapterFake, SupportsPerSimulcastLayerMaxFramerate) { EXPECT_EQ(10u, helper_->factory()->encoders()[2]->codec().maxFramerate); } +TEST_F(TestSimulcastEncoderAdapterFake, CreatesEncoderOnlyIfStreamIsActive) { + // Legacy singlecast + SetupCodec(/*active_streams=*/{}); + EXPECT_EQ(1u, helper_->factory()->encoders().size()); + + // Simulcast-capable underlying encoder + ReSetUp(); + helper_->factory()->set_supports_simulcast(true); + SetupCodec(/*active_streams=*/{true, true}); + EXPECT_EQ(1u, helper_->factory()->encoders().size()); + + // Multi-encoder simulcast + ReSetUp(); + helper_->factory()->set_supports_simulcast(false); + SetupCodec(/*active_streams=*/{true, true}); + EXPECT_EQ(2u, helper_->factory()->encoders().size()); + + // Singlecast via layers deactivation. Lowest layer is active. + ReSetUp(); + helper_->factory()->set_supports_simulcast(false); + SetupCodec(/*active_streams=*/{true, false}); + EXPECT_EQ(1u, helper_->factory()->encoders().size()); + + // Singlecast via layers deactivation. Highest layer is active. + ReSetUp(); + helper_->factory()->set_supports_simulcast(false); + SetupCodec(/*active_streams=*/{false, true}); + EXPECT_EQ(1u, helper_->factory()->encoders().size()); +} + +TEST_F(TestSimulcastEncoderAdapterFake, + RecreateEncoderIfPreferTemporalSupportIsEnabled) { + // Normally SEA reuses encoders. But, when TL-based SW fallback is enabled, + // the encoder which served the lowest stream should be recreated before it + // can be used to process an upper layer and vice-versa. + test::ScopedFieldTrials field_trials( + "WebRTC-Video-PreferTemporalSupportOnBaseLayer/Enabled/"); + use_fallback_factory_ = true; + ReSetUp(); + + // Legacy singlecast + SetupCodec(/*active_streams=*/{}); + ASSERT_EQ(1u, helper_->factory()->encoders().size()); + + // Singlecast, the lowest stream is active. Encoder should be reused.
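The GetEncoderInfoOverride trial used in the two tests above packs one ResolutionBitrateLimits entry per resolution into four pipe-separated lists that are zipped element-wise. A minimal sketch of that zipping, using only the standard library; the BitrateLimits struct and the helper names are placeholders rather than the actual SimulcastEncoderAdapterEncoderInfoSettings parser:

#include <sstream>
#include <string>
#include <vector>

struct BitrateLimits {  // Mirrors VideoEncoder::ResolutionBitrateLimits.
  int frame_size_pixels;
  int min_start_bitrate_bps;
  int min_bitrate_bps;
  int max_bitrate_bps;
};

std::vector<int> SplitInts(const std::string& list) {
  std::vector<int> out;
  std::stringstream ss(list);
  std::string token;
  while (std::getline(ss, token, '|'))
    out.push_back(std::stoi(token));
  return out;
}

std::vector<BitrateLimits> ZipLimits(const std::string& sizes,
                                     const std::string& start,
                                     const std::string& min,
                                     const std::string& max) {
  std::vector<int> s = SplitInts(sizes), st = SplitInts(start),
                   mn = SplitInts(min), mx = SplitInts(max);
  std::vector<BitrateLimits> limits;
  for (size_t i = 0; i < s.size(); ++i)
    limits.push_back({s[i], st[i], mn[i], mx[i]});
  return limits;
}
// ZipLimits("123|456|789", "11000|22000|33000", "44000|55000|66000",
//           "77000|88000|99000") yields the three entries the test expects.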
+ MockVideoEncoder* prev_encoder = helper_->factory()->encoders()[0]; + SetupCodec(/*active_streams=*/{true, false}); + ASSERT_EQ(1u, helper_->factory()->encoders().size()); + EXPECT_EQ(helper_->factory()->encoders()[0], prev_encoder); + + // Singlecast, an upper stream is active. Encoder should be recreated. + EXPECT_CALL(*prev_encoder, Release()).Times(1); + SetupCodec(/*active_streams=*/{false, true}); + ASSERT_EQ(1u, helper_->factory()->encoders().size()); + EXPECT_NE(helper_->factory()->encoders()[0], prev_encoder); + + // Singlecast, the lowest stream is active. Encoder should be recreated. + prev_encoder = helper_->factory()->encoders()[0]; + EXPECT_CALL(*prev_encoder, Release()).Times(1); + SetupCodec(/*active_streams=*/{true, false}); + ASSERT_EQ(1u, helper_->factory()->encoders().size()); + EXPECT_NE(helper_->factory()->encoders()[0], prev_encoder); +} + } // namespace test } // namespace webrtc diff --git a/media/engine/simulcast_unittest.cc b/media/engine/simulcast_unittest.cc index 27b1574456..47a9db75a1 100644 --- a/media/engine/simulcast_unittest.cc +++ b/media/engine/simulcast_unittest.cc @@ -12,7 +12,6 @@ #include "api/transport/field_trial_based_config.h" #include "media/base/media_constants.h" -#include "media/engine/constants.h" #include "test/field_trial.h" #include "test/gtest.h" @@ -378,4 +377,149 @@ TEST(SimulcastTest, BitratesForCloseToStandardResolution) { } } +TEST(SimulcastTest, MaxLayersWithRoundUpDisabled) { + test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastLayerLimitRoundUp/max_ratio:0.0/"); + FieldTrialBasedConfig trials; + const size_t kMinLayers = 1; + const int kMaxLayers = 3; + + std::vector streams; + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 960, 540, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(3u, streams.size()); + // <960x540: 2 layers + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 960, 539, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 270, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + // <480x270: 1 layer + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 269, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(1u, streams.size()); +} + +TEST(SimulcastTest, MaxLayersWithDefaultRoundUpRatio) { + // Default: "WebRTC-SimulcastLayerLimitRoundUp/max_ratio:0.1/" + FieldTrialBasedConfig trials; + const size_t kMinLayers = 1; + const int kMaxLayers = 3; + + std::vector streams; + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 960, 540, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(3u, streams.size()); + // Lowest cropped height where max layers from higher resolution is used. + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 960, 512, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(3u, streams.size()); + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 960, 508, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 270, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + // Lowest cropped height where max layers from higher resolution is used. 
+ streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 256, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 254, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(1u, streams.size()); +} + +TEST(SimulcastTest, MaxLayersWithRoundUpRatio) { + test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastLayerLimitRoundUp/max_ratio:0.13/"); + FieldTrialBasedConfig trials; + const size_t kMinLayers = 1; + const int kMaxLayers = 3; + + std::vector streams; + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 270, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + // Lowest cropped height where max layers from higher resolution is used. + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 252, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(2u, streams.size()); + streams = cricket::GetSimulcastConfig(kMinLayers, kMaxLayers, 480, 250, + kBitratePriority, kQpMax, !kScreenshare, + true, trials); + EXPECT_EQ(1u, streams.size()); +} + +TEST(SimulcastTest, BitratesInterpolatedForResBelow180p) { + // TODO(webrtc:12415): Remove when feature launches. + test::ScopedFieldTrials field_trials( + "WebRTC-LowresSimulcastBitrateInterpolation/Enabled/"); + + const size_t kMaxLayers = 3; + FieldTrialBasedConfig trials; + + std::vector streams = cricket::GetSimulcastConfig( + /* min_layers = */ 1, kMaxLayers, /* width = */ 960, /* height = */ 540, + kBitratePriority, kQpMax, !kScreenshare, true, trials); + + ASSERT_EQ(streams.size(), kMaxLayers); + EXPECT_EQ(240u, streams[0].width); + EXPECT_EQ(135u, streams[0].height); + EXPECT_EQ(streams[0].max_bitrate_bps, 112500); + EXPECT_EQ(streams[0].target_bitrate_bps, 84375); + EXPECT_EQ(streams[0].min_bitrate_bps, 30000); +} + +TEST(SimulcastTest, BitratesConsistentForVerySmallRes) { + // TODO(webrtc:12415): Remove when feature launches. 
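The layer-count expectations above are consistent with a simple round-up rule: when the input pixel count falls between two entries of the default simulcast table, the larger entry's layer count is used whenever (pixels_up - pixels) / (pixels_up - pixels_down) is below max_ratio. A sketch of just that check, with the arithmetic behind a few of the expectations worked out in comments (illustrative; the real logic lives in media/engine/simulcast.cc):

// Round-up decision between two neighbouring simulcast table entries.
// pixels_down / pixels_up are the nearest entries below / above `pixels`.
bool UseUpperLayerCount(int pixels, int pixels_down, int pixels_up,
                        double max_ratio) {
  const double ratio = static_cast<double>(pixels_up - pixels) /
                       static_cast<double>(pixels_up - pixels_down);
  return ratio < max_ratio;
}
// 960x512 against 960x540 (3 layers) and 640x360 (2 layers), max_ratio 0.1:
//   (518400 - 491520) / (518400 - 230400) = 0.093 < 0.1  -> 3 layers.
// 960x508: (518400 - 487680) / 288000 = 0.107            -> stays at 2 layers.
// 480x256 against 480x270 (2 layers) and 320x180 (1 layer):
//   (129600 - 122880) / 72000 = 0.093 < 0.1              -> 2 layers.
// The interpolation expectations just above match linear scaling by pixel
// count below 320x180: 200000 * 32400 / 57600 = 112500 and
// 150000 * 32400 / 57600 = 84375, with the minimum clamped at 30000 bps.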
+ test::ScopedFieldTrials field_trials( + "WebRTC-LowresSimulcastBitrateInterpolation/Enabled/"); + + FieldTrialBasedConfig trials; + + std::vector streams = cricket::GetSimulcastConfig( + /* min_layers = */ 1, /* max_layers = */ 3, /* width = */ 1, + /* height = */ 1, kBitratePriority, kQpMax, !kScreenshare, true, trials); + + ASSERT_TRUE(!streams.empty()); + EXPECT_EQ(1u, streams[0].width); + EXPECT_EQ(1u, streams[0].height); + EXPECT_EQ(streams[0].max_bitrate_bps, 30000); + EXPECT_EQ(streams[0].target_bitrate_bps, 30000); + EXPECT_EQ(streams[0].min_bitrate_bps, 30000); +} + +TEST(SimulcastTest, + BitratesNotInterpolatedForResBelow180pWhenDisabledTrialSet) { + test::ScopedFieldTrials field_trials( + "WebRTC-LowresSimulcastBitrateInterpolation/Disabled/"); + + const size_t kMaxLayers = 3; + FieldTrialBasedConfig trials; + + std::vector streams = cricket::GetSimulcastConfig( + /* min_layers = */ 1, kMaxLayers, /* width = */ 960, /* height = */ 540, + kBitratePriority, kQpMax, !kScreenshare, true, trials); + + ASSERT_EQ(streams.size(), kMaxLayers); + EXPECT_EQ(240u, streams[0].width); + EXPECT_EQ(135u, streams[0].height); + EXPECT_EQ(streams[0].max_bitrate_bps, 200000); + EXPECT_EQ(streams[0].target_bitrate_bps, 150000); + EXPECT_EQ(streams[0].min_bitrate_bps, 30000); +} + } // namespace webrtc diff --git a/media/engine/webrtc_video_engine.cc b/media/engine/webrtc_video_engine.cc index 758e495a9f..38a210ee7d 100644 --- a/media/engine/webrtc_video_engine.cc +++ b/media/engine/webrtc_video_engine.cc @@ -39,7 +39,6 @@ #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/thread.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" @@ -48,6 +47,7 @@ namespace cricket { namespace { const int kMinLayerSize = 16; +constexpr int64_t kUnsignaledSsrcCooldownMs = rtc::kNumMillisecsPerSec / 2; const char* StreamTypeToString( webrtc::VideoSendStream::StreamStats::StreamType type) { @@ -106,10 +106,10 @@ void AddDefaultFeedbackParams(VideoCodec* codec, } } -// This function will assign dynamic payload types (in the range [96, 127]) to -// the input codecs, and also add ULPFEC, RED, FlexFEC, and associated RTX -// codecs for recognized codecs (VP8, VP9, H264, and RED). It will also add -// default feedback params to the codecs. +// This function will assign dynamic payload types (in the range [96, 127] +// and then [35, 63]) to the input codecs, and also add ULPFEC, RED, FlexFEC, +// and associated RTX codecs for recognized codecs (VP8, VP9, H264, and RED). +// It will also add default feedback params to the codecs. // is_decoder_factory is needed to keep track of the implict assumption that any // H264 decoder also supports constrained base line profile. // Also, is_decoder_factory is used to decide whether FlexFEC video format @@ -134,16 +134,6 @@ std::vector GetPayloadTypesAndDefaultCodecs( if (supported_formats.empty()) return std::vector(); - // Due to interoperability issues with old Chrome/WebRTC versions only use - // the lower range for new codecs. 
- static const int kFirstDynamicPayloadTypeLowerRange = 35; - static const int kLastDynamicPayloadTypeLowerRange = 65; - - static const int kFirstDynamicPayloadTypeUpperRange = 96; - static const int kLastDynamicPayloadTypeUpperRange = 127; - int payload_type_upper = kFirstDynamicPayloadTypeUpperRange; - int payload_type_lower = kFirstDynamicPayloadTypeLowerRange; - supported_formats.push_back(webrtc::SdpVideoFormat(kRedCodecName)); supported_formats.push_back(webrtc::SdpVideoFormat(kUlpfecCodecName)); @@ -163,60 +153,65 @@ std::vector GetPayloadTypesAndDefaultCodecs( supported_formats.push_back(flexfec_format); } + // Due to interoperability issues with old Chrome/WebRTC versions that + // ignore the [35, 63] range prefer the lower range for new codecs. + static const int kFirstDynamicPayloadTypeLowerRange = 35; + static const int kLastDynamicPayloadTypeLowerRange = 63; + + static const int kFirstDynamicPayloadTypeUpperRange = 96; + static const int kLastDynamicPayloadTypeUpperRange = 127; + int payload_type_upper = kFirstDynamicPayloadTypeUpperRange; + int payload_type_lower = kFirstDynamicPayloadTypeLowerRange; + std::vector output_codecs; for (const webrtc::SdpVideoFormat& format : supported_formats) { VideoCodec codec(format); bool isCodecValidForLowerRange = absl::EqualsIgnoreCase(codec.name, kFlexfecCodecName) || absl::EqualsIgnoreCase(codec.name, kAv1CodecName); - if (!isCodecValidForLowerRange) { - codec.id = payload_type_upper++; - } else { - codec.id = payload_type_lower++; - } - AddDefaultFeedbackParams(&codec, trials); - output_codecs.push_back(codec); + bool isFecCodec = absl::EqualsIgnoreCase(codec.name, kUlpfecCodecName) || + absl::EqualsIgnoreCase(codec.name, kFlexfecCodecName); - if (payload_type_upper > kLastDynamicPayloadTypeUpperRange) { - RTC_LOG(LS_ERROR) - << "Out of dynamic payload types [96,127], skipping the rest."; - // TODO(https://bugs.chromium.org/p/webrtc/issues/detail?id=12194): - // continue in lower range. - break; - } + // Check if we ran out of payload types. if (payload_type_lower > kLastDynamicPayloadTypeLowerRange) { // TODO(https://bugs.chromium.org/p/webrtc/issues/detail?id=12248): // return an error. - RTC_LOG(LS_ERROR) - << "Out of dynamic payload types [35,65], skipping the rest."; + RTC_LOG(LS_ERROR) << "Out of dynamic payload types [35,63] after " + "fallback from [96, 127], skipping the rest."; + RTC_DCHECK_EQ(payload_type_upper, kLastDynamicPayloadTypeUpperRange); break; } - // Add associated RTX codec for non-FEC codecs. - if (!absl::EqualsIgnoreCase(codec.name, kUlpfecCodecName) && - !absl::EqualsIgnoreCase(codec.name, kFlexfecCodecName)) { - if (!isCodecValidForLowerRange) { - output_codecs.push_back( - VideoCodec::CreateRtxCodec(payload_type_upper++, codec.id)); - } else { - output_codecs.push_back( - VideoCodec::CreateRtxCodec(payload_type_lower++, codec.id)); - } + // Lower range gets used for "new" codecs or when running out of payload + // types in the upper range. + if (isCodecValidForLowerRange || + payload_type_upper >= kLastDynamicPayloadTypeUpperRange) { + codec.id = payload_type_lower++; + } else { + codec.id = payload_type_upper++; + } + AddDefaultFeedbackParams(&codec, trials); + output_codecs.push_back(codec); - if (payload_type_upper > kLastDynamicPayloadTypeUpperRange) { - RTC_LOG(LS_ERROR) - << "Out of dynamic payload types [96,127], skipping rtx."; - // TODO(https://bugs.chromium.org/p/webrtc/issues/detail?id=12194): - // continue in lower range. - break; - } + // Add associated RTX codec for non-FEC codecs. 
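The rewritten loop above amounts to a small allocation policy: AV1 and FlexFEC are placed in the lower range from the start, every other codec takes the upper range until it would reach 127, and allocation then falls back to [35, 63] until that range too is exhausted. A compact sketch of only that policy; the helper is hypothetical, with constants mirroring the diff:

#include <optional>

// Sketch of the dynamic payload type picking policy described above. The two
// cursors start at 96 (upper range) and 35 (lower range) respectively.
constexpr int kLastLowerRangePt = 63;
constexpr int kLastUpperRangePt = 127;

std::optional<int> NextPayloadType(bool prefer_lower_range,
                                   int& next_lower,
                                   int& next_upper) {
  if (next_lower > kLastLowerRangePt)
    return std::nullopt;  // Both ranges exhausted; caller skips this codec.
  if (prefer_lower_range || next_upper >= kLastUpperRangePt)
    return next_lower++;  // AV1/FlexFEC, or the upper range is used up.
  return next_upper++;
}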
+ if (!isFecCodec) { + // Check if we ran out of payload types. if (payload_type_lower > kLastDynamicPayloadTypeLowerRange) { // TODO(https://bugs.chromium.org/p/webrtc/issues/detail?id=12248): // return an error. - RTC_LOG(LS_ERROR) - << "Out of dynamic payload types [35,65], skipping rtx."; + RTC_LOG(LS_ERROR) << "Out of dynamic payload types [35,63] after " + "fallback from [96, 127], skipping the rest."; + RTC_DCHECK_EQ(payload_type_upper, kLastDynamicPayloadTypeUpperRange); break; } + if (isCodecValidForLowerRange || + payload_type_upper >= kLastDynamicPayloadTypeUpperRange) { + output_codecs.push_back( + VideoCodec::CreateRtxCodec(payload_type_lower++, codec.id)); + } else { + output_codecs.push_back( + VideoCodec::CreateRtxCodec(payload_type_upper++, codec.id)); + } } } return output_codecs; @@ -502,7 +497,7 @@ WebRtcVideoChannel::WebRtcVideoSendStream::ConfigureVideoEncoderSettings( webrtc::VideoCodecH264 h264_settings = webrtc::VideoEncoder::GetDefaultH264Settings(); h264_settings.frameDroppingOn = frame_dropping; - return new rtc::RefCountedObject< + return rtc::make_ref_counted< webrtc::VideoEncoderConfig::H264EncoderSpecificSettings>(h264_settings); } if (absl::EqualsIgnoreCase(codec.name, kVp8CodecName)) { @@ -512,7 +507,7 @@ WebRtcVideoChannel::WebRtcVideoSendStream::ConfigureVideoEncoderSettings( // VP8 denoising is enabled by default. vp8_settings.denoisingOn = codec_default_denoising ? true : denoising; vp8_settings.frameDroppingOn = frame_dropping; - return new rtc::RefCountedObject< + return rtc::make_ref_counted< webrtc::VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8_settings); } if (absl::EqualsIgnoreCase(codec.name, kVp9CodecName)) { @@ -562,7 +557,7 @@ WebRtcVideoChannel::WebRtcVideoSendStream::ConfigureVideoEncoderSettings( vp9_settings.flexibleMode = vp9_settings.numberOfSpatialLayers > 1; vp9_settings.interLayerPred = webrtc::InterLayerPredMode::kOn; } - return new rtc::RefCountedObject< + return rtc::make_ref_counted< webrtc::VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); } return nullptr; @@ -626,11 +621,11 @@ WebRtcVideoEngine::WebRtcVideoEngine( : decoder_factory_(std::move(video_decoder_factory)), encoder_factory_(std::move(video_encoder_factory)), trials_(trials) { - RTC_LOG(LS_INFO) << "WebRtcVideoEngine::WebRtcVideoEngine()"; + RTC_DLOG(LS_INFO) << "WebRtcVideoEngine::WebRtcVideoEngine()"; } WebRtcVideoEngine::~WebRtcVideoEngine() { - RTC_LOG(LS_INFO) << "WebRtcVideoEngine::~WebRtcVideoEngine"; + RTC_DLOG(LS_INFO) << "WebRtcVideoEngine::~WebRtcVideoEngine"; } VideoMediaChannel* WebRtcVideoEngine::CreateMediaChannel( @@ -686,6 +681,12 @@ WebRtcVideoEngine::GetRtpHeaderExtensions() const { ? webrtc::RtpTransceiverDirection::kSendRecv : webrtc::RtpTransceiverDirection::kStopped); + result.emplace_back( + webrtc::RtpExtension::kVideoFrameTrackingIdUri, id++, + IsEnabled(trials_, "WebRTC-VideoFrameTrackingIdAdvertised") + ? 
webrtc::RtpTransceiverDirection::kSendRecv + : webrtc::RtpTransceiverDirection::kStopped); + return result; } @@ -697,8 +698,8 @@ WebRtcVideoChannel::WebRtcVideoChannel( webrtc::VideoEncoderFactory* encoder_factory, webrtc::VideoDecoderFactory* decoder_factory, webrtc::VideoBitrateAllocatorFactory* bitrate_allocator_factory) - : VideoMediaChannel(config), - worker_thread_(rtc::Thread::Current()), + : VideoMediaChannel(config, call->network_thread()), + worker_thread_(call->worker_thread()), call_(call), unsignalled_ssrc_handler_(&default_unsignalled_ssrc_handler_), video_config_(config.video), @@ -716,7 +717,8 @@ WebRtcVideoChannel::WebRtcVideoChannel( "WebRTC-Video-BufferPacketsWithUnknownSsrc") ? new UnhandledPacketsBuffer() : nullptr) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&thread_checker_); + network_thread_checker_.Detach(); rtcp_receiver_report_ssrc_ = kDefaultRtcpReceiverReportSsrc; sending_ = false; @@ -752,8 +754,8 @@ WebRtcVideoChannel::SelectSendVideoCodecs( // following the spec in https://tools.ietf.org/html/rfc6184#section-8.2.2 // since we should limit the encode level to the lower of local and remote // level when level asymmetry is not allowed. - if (IsSameCodec(format_it->name, format_it->parameters, - remote_codec.codec.name, remote_codec.codec.params)) { + if (format_it->IsSameCodec( + {remote_codec.codec.name, remote_codec.codec.params})) { encoders.push_back(remote_codec); // To allow the VideoEncoderFactory to keep information about which @@ -816,7 +818,6 @@ bool WebRtcVideoChannel::GetChangedSendParameters( // Never enable sending FlexFEC, unless we are in the experiment. if (!IsEnabled(call_->trials(), "WebRTC-FlexFEC-03")) { - RTC_LOG(LS_INFO) << "WebRTC-FlexFEC-03 field trial is not enabled."; for (VideoCodecSettings& codec : negotiated_codecs) codec.flexfec_payload_type = -1; } @@ -948,8 +949,8 @@ void WebRtcVideoChannel::RequestEncoderSwitch( RTC_DCHECK_RUN_ON(&thread_checker_); for (const VideoCodecSettings& codec_setting : negotiated_codecs_) { - if (IsSameCodec(format.name, format.parameters, codec_setting.codec.name, - codec_setting.codec.params)) { + if (format.IsSameCodec( + {codec_setting.codec.name, codec_setting.codec.params})) { VideoCodecSettings new_codec_setting = codec_setting; for (const auto& kv : format.parameters) { new_codec_setting.codec.params[kv.first] = kv.second; @@ -1032,7 +1033,7 @@ bool WebRtcVideoChannel::ApplyChangedParams( if (changed_params.send_codec || changed_params.rtcp_mode) { // Update receive feedback parameters from new codec or RTCP mode. RTC_LOG(LS_INFO) - << "SetFeedbackOptions on all the receive streams because the send " + << "SetFeedbackParameters on all the receive streams because the send " "codec or RTCP mode has changed."; for (auto& kv : receive_streams_) { RTC_DCHECK(kv.second != nullptr); @@ -1040,7 +1041,8 @@ bool WebRtcVideoChannel::ApplyChangedParams( HasLntf(send_codec_->codec), HasNack(send_codec_->codec), HasTransportCc(send_codec_->codec), send_params_.rtcp.reduced_size ? 
webrtc::RtcpMode::kReducedSize - : webrtc::RtcpMode::kCompound); + : webrtc::RtcpMode::kCompound, + send_codec_->rtx_time); } } return true; @@ -1214,25 +1216,25 @@ bool WebRtcVideoChannel::GetChangedRecvParameters( bool WebRtcVideoChannel::SetRecvParameters(const VideoRecvParameters& params) { RTC_DCHECK_RUN_ON(&thread_checker_); TRACE_EVENT0("webrtc", "WebRtcVideoChannel::SetRecvParameters"); - RTC_LOG(LS_INFO) << "SetRecvParameters: " << params.ToString(); + RTC_DLOG(LS_INFO) << "SetRecvParameters: " << params.ToString(); ChangedRecvParameters changed_params; if (!GetChangedRecvParameters(params, &changed_params)) { return false; } if (changed_params.flexfec_payload_type) { - RTC_LOG(LS_INFO) << "Changing FlexFEC payload type (recv) from " - << recv_flexfec_payload_type_ << " to " - << *changed_params.flexfec_payload_type; + RTC_DLOG(LS_INFO) << "Changing FlexFEC payload type (recv) from " + << recv_flexfec_payload_type_ << " to " + << *changed_params.flexfec_payload_type; recv_flexfec_payload_type_ = *changed_params.flexfec_payload_type; } if (changed_params.rtp_header_extensions) { recv_rtp_extensions_ = *changed_params.rtp_header_extensions; } if (changed_params.codec_settings) { - RTC_LOG(LS_INFO) << "Changing recv codecs from " - << CodecSettingsVectorToString(recv_codecs_) << " to " - << CodecSettingsVectorToString( - *changed_params.codec_settings); + RTC_DLOG(LS_INFO) << "Changing recv codecs from " + << CodecSettingsVectorToString(recv_codecs_) << " to " + << CodecSettingsVectorToString( + *changed_params.codec_settings); recv_codecs_ = *changed_params.codec_settings; } @@ -1466,7 +1468,7 @@ bool WebRtcVideoChannel::AddRecvStream(const StreamParams& sp, for (uint32_t used_ssrc : sp.ssrcs) receive_ssrcs_.insert(used_ssrc); - webrtc::VideoReceiveStream::Config config(this); + webrtc::VideoReceiveStream::Config config(this, decoder_factory_); webrtc::FlexfecReceiveStream::Config flexfec_config(this); ConfigureReceiverRtp(&config, &flexfec_config, sp); @@ -1481,8 +1483,8 @@ bool WebRtcVideoChannel::AddRecvStream(const StreamParams& sp, config.frame_transformer = unsignaled_frame_transformer_; receive_streams_[ssrc] = new WebRtcVideoReceiveStream( - this, call_, sp, std::move(config), decoder_factory_, default_stream, - recv_codecs_, flexfec_config); + this, call_, sp, std::move(config), default_stream, recv_codecs_, + flexfec_config); return true; } @@ -1517,6 +1519,12 @@ void WebRtcVideoChannel::ConfigureReceiverRtp( ? webrtc::RtcpMode::kReducedSize : webrtc::RtcpMode::kCompound; + // rtx-time (RFC 4588) is a declarative attribute similar to rtcp-rsize and + // determined by the sender / send codec. + if (send_codec_ && send_codec_->rtx_time != -1) { + config->rtp.nack.rtp_history_ms = send_codec_->rtx_time; + } + config->rtp.transport_cc = send_codec_ ? HasTransportCc(send_codec_->codec) : false; @@ -1527,14 +1535,14 @@ void WebRtcVideoChannel::ConfigureReceiverRtp( // TODO(brandtr): Generalize when we add support for multistream protection. 
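Taken together, the rtx-time changes in this file (MapCodecs clamps the signalled value, ConfigureReceiverRtp and SetFeedbackParameters apply it) reduce to one rule for the receive-side NACK history. A condensed sketch, assuming kNackHistoryMs is the existing default and -1 means the attribute was not signalled:

#include <algorithm>

// Condensed view of how a negotiated rtx-time (RFC 4588, milliseconds)
// translates into the NACK history configured on receive streams.
int ReceiveNackHistoryMs(bool nack_enabled, int rtx_time_ms,
                         int default_history_ms /* kNackHistoryMs */) {
  if (!nack_enabled)
    return 0;                   // No NACK, no retransmission history.
  if (rtx_time_ms <= 0)
    return default_history_ms;  // Attribute absent (-1): keep the default.
  return std::min(rtx_time_ms, default_history_ms);  // Never above the cap.
}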
flexfec_config->payload_type = recv_flexfec_payload_type_; if (!IsDisabled(call_->trials(), "WebRTC-FlexFEC-03-Advertised") && - sp.GetFecFrSsrc(ssrc, &flexfec_config->remote_ssrc)) { + sp.GetFecFrSsrc(ssrc, &flexfec_config->rtp.remote_ssrc)) { flexfec_config->protected_media_ssrcs = {ssrc}; - flexfec_config->local_ssrc = config->rtp.local_ssrc; + flexfec_config->rtp.local_ssrc = config->rtp.local_ssrc; flexfec_config->rtcp_mode = config->rtp.rtcp_mode; // TODO(brandtr): We should be spec-compliant and set |transport_cc| here // based on the rtcp-fb for the FlexFEC codec, not the media codec. - flexfec_config->transport_cc = config->rtp.transport_cc; - flexfec_config->rtp_header_extensions = config->rtp.extensions; + flexfec_config->rtp.transport_cc = config->rtp.transport_cc; + flexfec_config->rtp.extensions = config->rtp.extensions; } } @@ -1558,6 +1566,7 @@ void WebRtcVideoChannel::ResetUnsignaledRecvStream() { RTC_DCHECK_RUN_ON(&thread_checker_); RTC_LOG(LS_INFO) << "ResetUnsignaledRecvStream."; unsignaled_stream_params_ = StreamParams(); + last_unsignalled_ssrc_creation_time_ms_ = absl::nullopt; // Delete any created default streams. This is needed to avoid SSRC collisions // in Call's RtpDemuxer, in the case that |this| has created a default video @@ -1574,6 +1583,19 @@ void WebRtcVideoChannel::ResetUnsignaledRecvStream() { } } +void WebRtcVideoChannel::OnDemuxerCriteriaUpdatePending() { + RTC_DCHECK_RUN_ON(&thread_checker_); + ++demuxer_criteria_id_; +} + +void WebRtcVideoChannel::OnDemuxerCriteriaUpdateComplete() { + RTC_DCHECK_RUN_ON(&network_thread_checker_); + worker_thread_->PostTask(ToQueuedTask(task_safety_, [this] { + RTC_DCHECK_RUN_ON(&thread_checker_); + ++demuxer_criteria_completed_id_; + })); +} + bool WebRtcVideoChannel::SetSink( uint32_t ssrc, rtc::VideoSinkInterface* sink) { @@ -1684,67 +1706,112 @@ void WebRtcVideoChannel::FillSendAndReceiveCodecStats( void WebRtcVideoChannel::OnPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { - RTC_DCHECK_RUN_ON(&thread_checker_); - const webrtc::PacketReceiver::DeliveryStatus delivery_result = - call_->Receiver()->DeliverPacket(webrtc::MediaType::VIDEO, packet, - packet_time_us); - switch (delivery_result) { - case webrtc::PacketReceiver::DELIVERY_OK: - return; - case webrtc::PacketReceiver::DELIVERY_PACKET_ERROR: - return; - case webrtc::PacketReceiver::DELIVERY_UNKNOWN_SSRC: - break; - } + RTC_DCHECK_RUN_ON(&network_thread_checker_); + // TODO(bugs.webrtc.org/11993): This code is very similar to what + // WebRtcVoiceMediaChannel::OnPacketReceived does. For maintainability and + // consistency it would be good to move the interaction with call_->Receiver() + // to a common implementation and provide a callback on the worker thread + // for the exception case (DELIVERY_UNKNOWN_SSRC) and how retry is attempted. 
+ worker_thread_->PostTask( + ToQueuedTask(task_safety_, [this, packet, packet_time_us] { + RTC_DCHECK_RUN_ON(&thread_checker_); + const webrtc::PacketReceiver::DeliveryStatus delivery_result = + call_->Receiver()->DeliverPacket(webrtc::MediaType::VIDEO, packet, + packet_time_us); + switch (delivery_result) { + case webrtc::PacketReceiver::DELIVERY_OK: + return; + case webrtc::PacketReceiver::DELIVERY_PACKET_ERROR: + return; + case webrtc::PacketReceiver::DELIVERY_UNKNOWN_SSRC: + break; + } - uint32_t ssrc = 0; - if (!GetRtpSsrc(packet.cdata(), packet.size(), &ssrc)) { - return; - } + uint32_t ssrc = 0; + if (!GetRtpSsrc(packet.cdata(), packet.size(), &ssrc)) { + return; + } - if (unknown_ssrc_packet_buffer_) { - unknown_ssrc_packet_buffer_->AddPacket(ssrc, packet_time_us, packet); - return; - } + if (unknown_ssrc_packet_buffer_) { + unknown_ssrc_packet_buffer_->AddPacket(ssrc, packet_time_us, packet); + return; + } - if (discard_unknown_ssrc_packets_) { - return; - } + if (discard_unknown_ssrc_packets_) { + return; + } - int payload_type = 0; - if (!GetRtpPayloadType(packet.cdata(), packet.size(), &payload_type)) { - return; - } + int payload_type = 0; + if (!GetRtpPayloadType(packet.cdata(), packet.size(), &payload_type)) { + return; + } - // See if this payload_type is registered as one that usually gets its own - // SSRC (RTX) or at least is safe to drop either way (FEC). If it is, and - // it wasn't handled above by DeliverPacket, that means we don't know what - // stream it associates with, and we shouldn't ever create an implicit channel - // for these. - for (auto& codec : recv_codecs_) { - if (payload_type == codec.rtx_payload_type || - payload_type == codec.ulpfec.red_rtx_payload_type || - payload_type == codec.ulpfec.ulpfec_payload_type) { - return; - } - } - if (payload_type == recv_flexfec_payload_type_) { - return; - } + // See if this payload_type is registered as one that usually gets its + // own SSRC (RTX) or at least is safe to drop either way (FEC). If it + // is, and it wasn't handled above by DeliverPacket, that means we don't + // know what stream it associates with, and we shouldn't ever create an + // implicit channel for these. + for (auto& codec : recv_codecs_) { + if (payload_type == codec.rtx_payload_type || + payload_type == codec.ulpfec.red_rtx_payload_type || + payload_type == codec.ulpfec.ulpfec_payload_type) { + return; + } + } + if (payload_type == recv_flexfec_payload_type_) { + return; + } - switch (unsignalled_ssrc_handler_->OnUnsignalledSsrc(this, ssrc)) { - case UnsignalledSsrcHandler::kDropPacket: - return; - case UnsignalledSsrcHandler::kDeliverPacket: - break; - } + // Ignore unknown ssrcs if there is a demuxer criteria update pending. + // During a demuxer update we may receive ssrcs that were recently + // removed or we may receve ssrcs that were recently configured for a + // different video channel. + if (demuxer_criteria_id_ != demuxer_criteria_completed_id_) { + return; + } + // Ignore unknown ssrcs if we recently created an unsignalled receive + // stream since this shouldn't happen frequently. Getting into a state + // of creating decoders on every packet eats up processing time (e.g. + // https://crbug.com/1069603) and this cooldown prevents that. 
+ if (last_unsignalled_ssrc_creation_time_ms_.has_value()) { + int64_t now_ms = rtc::TimeMillis(); + if (now_ms - last_unsignalled_ssrc_creation_time_ms_.value() < + kUnsignaledSsrcCooldownMs) { + // We've already created an unsignalled ssrc stream within the last + // 0.5 s, ignore with a warning. + RTC_LOG(LS_WARNING) + << "Another unsignalled ssrc packet arrived shortly after the " + << "creation of an unsignalled ssrc stream. Dropping packet."; + return; + } + } + // Let the unsignalled ssrc handler decide whether to drop or deliver. + switch (unsignalled_ssrc_handler_->OnUnsignalledSsrc(this, ssrc)) { + case UnsignalledSsrcHandler::kDropPacket: + return; + case UnsignalledSsrcHandler::kDeliverPacket: + break; + } - if (call_->Receiver()->DeliverPacket(webrtc::MediaType::VIDEO, packet, - packet_time_us) != - webrtc::PacketReceiver::DELIVERY_OK) { - RTC_LOG(LS_WARNING) << "Failed to deliver RTP packet on re-delivery."; - return; - } + if (call_->Receiver()->DeliverPacket(webrtc::MediaType::VIDEO, packet, + packet_time_us) != + webrtc::PacketReceiver::DELIVERY_OK) { + RTC_LOG(LS_WARNING) << "Failed to deliver RTP packet on re-delivery."; + } + last_unsignalled_ssrc_creation_time_ms_ = rtc::TimeMillis(); + })); +} + +void WebRtcVideoChannel::OnPacketSent(const rtc::SentPacket& sent_packet) { + RTC_DCHECK_RUN_ON(&network_thread_checker_); + // TODO(tommi): We shouldn't need to go through call_ to deliver this + // notification. We should already have direct access to + // video_send_delay_stats_ and transport_send_ptr_ via `stream_`. + // So we should be able to remove OnSentPacket from Call and handle this per + // channel instead. At the moment Call::OnSentPacket calls OnSentPacket for + // the video stats, for all sent packets, including audio, which causes + // unnecessary lookups. + call_->OnSentPacket(sent_packet); } void WebRtcVideoChannel::BackfillBufferedPackets( @@ -1794,7 +1861,7 @@ void WebRtcVideoChannel::BackfillBufferedPackets( } void WebRtcVideoChannel::OnReadyToSend(bool ready) { - RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_DCHECK_RUN_ON(&network_thread_checker_); RTC_LOG(LS_VERBOSE) << "OnReadyToSend: " << (ready ? "Ready." : "Not ready."); call_->SignalChannelNetworkState( webrtc::MediaType::VIDEO, @@ -1804,15 +1871,19 @@ void WebRtcVideoChannel::OnReadyToSend(bool ready) { void WebRtcVideoChannel::OnNetworkRouteChanged( const std::string& transport_name, const rtc::NetworkRoute& network_route) { - RTC_DCHECK_RUN_ON(&thread_checker_); - call_->GetTransportControllerSend()->OnNetworkRouteChanged(transport_name, - network_route); - call_->GetTransportControllerSend()->OnTransportOverheadChanged( - network_route.packet_overhead); + RTC_DCHECK_RUN_ON(&network_thread_checker_); + worker_thread_->PostTask(ToQueuedTask( + task_safety_, [this, name = transport_name, route = network_route] { + RTC_DCHECK_RUN_ON(&thread_checker_); + webrtc::RtpTransportControllerSendInterface* transport = + call_->GetTransportControllerSend(); + transport->OnNetworkRouteChanged(name, route); + transport->OnTransportOverheadChanged(route.packet_overhead); + })); } void WebRtcVideoChannel::SetInterface(NetworkInterface* iface) { - RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_DCHECK_RUN_ON(&network_thread_checker_); MediaChannel::SetInterface(iface); // Set the RTP recv/send buffer to a bigger size. 
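OnPacketReceived, OnNetworkRouteChanged and OnDemuxerCriteriaUpdateComplete above all use the same shape: validate or capture on the network thread, then post to the worker thread with ToQueuedTask tied to task_safety_ so the closure is silently dropped if the channel has already been destroyed. A stripped-down sketch of that pattern; ExampleChannel and its members are invented for illustration, only the task-queue and safety-flag helpers are real WebRTC APIs:

#include "api/task_queue/task_queue_base.h"
#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"

// Invented example of the network-thread -> worker-thread hand-off used above.
class ExampleChannel {
 public:
  explicit ExampleChannel(webrtc::TaskQueueBase* worker_thread)
      : worker_thread_(worker_thread) {}

  // Runs on the network thread.
  void OnSomethingFromNetwork(int value) {
    worker_thread_->PostTask(webrtc::ToQueuedTask(task_safety_, [this, value] {
      // Runs on the worker thread, but only if |this| still exists: the
      // ScopedTaskSafety member flips its flag in ~ExampleChannel(), which
      // makes the queued task skip the closure instead of touching freed
      // memory.
      value_on_worker_ = value;
    }));
  }

 private:
  webrtc::TaskQueueBase* const worker_thread_;
  int value_on_worker_ = 0;
  webrtc::ScopedTaskSafety task_safety_;
};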
@@ -1961,27 +2032,13 @@ std::vector WebRtcVideoChannel::GetSources( bool WebRtcVideoChannel::SendRtp(const uint8_t* data, size_t len, const webrtc::PacketOptions& options) { - rtc::CopyOnWriteBuffer packet(data, len, kMaxRtpPacketLen); - rtc::PacketOptions rtc_options; - rtc_options.packet_id = options.packet_id; - if (DscpEnabled()) { - rtc_options.dscp = PreferredDscp(); - } - rtc_options.info_signaled_after_sent.included_in_feedback = - options.included_in_feedback; - rtc_options.info_signaled_after_sent.included_in_allocation = - options.included_in_allocation; - return MediaChannel::SendPacket(&packet, rtc_options); + MediaChannel::SendRtp(data, len, options); + return true; } bool WebRtcVideoChannel::SendRtcp(const uint8_t* data, size_t len) { - rtc::CopyOnWriteBuffer packet(data, len, kMaxRtpPacketLen); - rtc::PacketOptions rtc_options; - if (DscpEnabled()) { - rtc_options.dscp = PreferredDscp(); - } - - return MediaChannel::SendRtcp(&packet, rtc_options); + MediaChannel::SendRtcp(data, len); + return true; } WebRtcVideoChannel::WebRtcVideoSendStream::VideoSendStreamParameters:: @@ -2008,7 +2065,7 @@ WebRtcVideoChannel::WebRtcVideoSendStream::WebRtcVideoSendStream( // TODO(deadbeef): Don't duplicate information between send_params, // rtp_extensions, options, etc. const VideoSendParameters& send_params) - : worker_thread_(rtc::Thread::Current()), + : worker_thread_(call->worker_thread()), ssrcs_(sp.ssrcs), ssrc_groups_(sp.ssrc_groups), call_(call), @@ -2297,6 +2354,9 @@ webrtc::RTCError WebRtcVideoChannel::WebRtcVideoSendStream::SetRtpParameters( // TODO(bugs.webrtc.org/8807): The active field as well should not require // a full encoder reconfiguration, but it needs to update both the bitrate // allocator and the video bitrate allocator. + // + // Note that the simulcast encoder adapter relies on the fact that layers + // de/activation triggers encoder reinitialization. 
bool new_send_state = false; for (size_t i = 0; i < rtp_parameters_.encodings.size(); ++i) { bool new_active = IsLayerActive(new_parameters.encodings[i]); @@ -2477,11 +2537,17 @@ WebRtcVideoChannel::WebRtcVideoSendStream::CreateVideoEncoderConfig( encoder_config.legacy_conference_mode = parameters_.conference_mode; + encoder_config.is_quality_scaling_allowed = + !disable_automatic_resize_ && !is_screencast && + (parameters_.config.rtp.ssrcs.size() == 1 || + NumActiveStreams(rtp_parameters_) == 1); + int max_qp = kDefaultQpMax; codec.GetParam(kCodecParamMaxQuantization, &max_qp); encoder_config.video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( codec.name, max_qp, is_screencast, parameters_.conference_mode); + return encoder_config; } @@ -2559,6 +2625,7 @@ WebRtcVideoChannel::WebRtcVideoSendStream::GetPerLayerVideoSenderInfos( stats.quality_limitation_resolution_changes; common_info.encoder_implementation_name = stats.encoder_implementation_name; common_info.ssrc_groups = ssrc_groups_; + common_info.frames = stats.frames; common_info.framerate_input = stats.input_frame_rate; common_info.avg_encode_ms = stats.avg_encode_time_ms; common_info.encode_usage_percent = stats.encode_usage_percent; @@ -2608,15 +2675,18 @@ WebRtcVideoChannel::WebRtcVideoSendStream::GetPerLayerVideoSenderInfos( stream_stats.rtp_stats.retransmitted.payload_bytes; info.retransmitted_packets_sent = stream_stats.rtp_stats.retransmitted.packets; - info.packets_lost = stream_stats.rtcp_stats.packets_lost; info.firs_rcvd = stream_stats.rtcp_packet_type_counts.fir_packets; info.nacks_rcvd = stream_stats.rtcp_packet_type_counts.nack_packets; info.plis_rcvd = stream_stats.rtcp_packet_type_counts.pli_packets; if (stream_stats.report_block_data.has_value()) { - info.report_block_datas.push_back(stream_stats.report_block_data.value()); + info.packets_lost = + stream_stats.report_block_data->report_block().packets_lost; + info.fraction_lost = + static_cast( + stream_stats.report_block_data->report_block().fraction_lost) / + (1 << 8); + info.report_block_datas.push_back(*stream_stats.report_block_data); } - info.fraction_lost = - static_cast(stream_stats.rtcp_stats.fraction_lost) / (1 << 8); info.qp_sum = stream_stats.qp_sum; info.total_encode_time_ms = stream_stats.total_encode_time_ms; info.total_encoded_bytes_target = stream_stats.total_encoded_bytes_target; @@ -2749,7 +2819,6 @@ WebRtcVideoChannel::WebRtcVideoReceiveStream::WebRtcVideoReceiveStream( webrtc::Call* call, const StreamParams& sp, webrtc::VideoReceiveStream::Config config, - webrtc::VideoDecoderFactory* decoder_factory, bool default_stream, const std::vector& recv_codecs, const webrtc::FlexfecReceiveStream::Config& flexfec_config) @@ -2761,23 +2830,20 @@ WebRtcVideoChannel::WebRtcVideoReceiveStream::WebRtcVideoReceiveStream( config_(std::move(config)), flexfec_config_(flexfec_config), flexfec_stream_(nullptr), - decoder_factory_(decoder_factory), sink_(NULL), first_frame_timestamp_(-1), estimated_remote_start_ntp_time_ms_(0) { + RTC_DCHECK(config_.decoder_factory); config_.renderer = this; ConfigureCodecs(recv_codecs); - ConfigureFlexfecCodec(flexfec_config.payload_type); - MaybeRecreateWebRtcFlexfecStream(); + flexfec_config_.payload_type = flexfec_config.payload_type; RecreateWebRtcVideoStream(); } WebRtcVideoChannel::WebRtcVideoReceiveStream::~WebRtcVideoReceiveStream() { - if (flexfec_stream_) { - MaybeDissociateFlexfecFromVideo(); - call_->DestroyFlexfecReceiveStream(flexfec_stream_); - } call_->DestroyVideoReceiveStream(stream_); + if 
(flexfec_stream_) + call_->DestroyFlexfecReceiveStream(flexfec_stream_); } const std::vector& @@ -2809,47 +2875,84 @@ WebRtcVideoChannel::WebRtcVideoReceiveStream::GetRtpParameters() const { return rtp_parameters; } -void WebRtcVideoChannel::WebRtcVideoReceiveStream::ConfigureCodecs( +bool WebRtcVideoChannel::WebRtcVideoReceiveStream::ConfigureCodecs( const std::vector& recv_codecs) { RTC_DCHECK(!recv_codecs.empty()); - config_.decoders.clear(); - config_.rtp.rtx_associated_payload_types.clear(); - config_.rtp.raw_payload_types.clear(); - config_.decoder_factory = decoder_factory_; + + std::map rtx_associated_payload_types; + std::set raw_payload_types; + std::vector decoders; for (const auto& recv_codec : recv_codecs) { - webrtc::SdpVideoFormat video_format(recv_codec.codec.name, - recv_codec.codec.params); - - webrtc::VideoReceiveStream::Decoder decoder; - decoder.video_format = video_format; - decoder.payload_type = recv_codec.codec.id; - decoder.video_format = - webrtc::SdpVideoFormat(recv_codec.codec.name, recv_codec.codec.params); - config_.decoders.push_back(decoder); - config_.rtp.rtx_associated_payload_types[recv_codec.rtx_payload_type] = - recv_codec.codec.id; + decoders.emplace_back( + webrtc::SdpVideoFormat(recv_codec.codec.name, recv_codec.codec.params), + recv_codec.codec.id); + rtx_associated_payload_types.insert( + {recv_codec.rtx_payload_type, recv_codec.codec.id}); if (recv_codec.codec.packetization == kPacketizationParamRaw) { - config_.rtp.raw_payload_types.insert(recv_codec.codec.id); + raw_payload_types.insert(recv_codec.codec.id); } } + bool recreate_needed = (stream_ == nullptr); + const auto& codec = recv_codecs.front(); - config_.rtp.ulpfec_payload_type = codec.ulpfec.ulpfec_payload_type; - config_.rtp.red_payload_type = codec.ulpfec.red_payload_type; + if (config_.rtp.ulpfec_payload_type != codec.ulpfec.ulpfec_payload_type) { + config_.rtp.ulpfec_payload_type = codec.ulpfec.ulpfec_payload_type; + recreate_needed = true; + } + + if (config_.rtp.red_payload_type != codec.ulpfec.red_payload_type) { + config_.rtp.red_payload_type = codec.ulpfec.red_payload_type; + recreate_needed = true; + } + + const bool has_lntf = HasLntf(codec.codec); + if (config_.rtp.lntf.enabled != has_lntf) { + config_.rtp.lntf.enabled = has_lntf; + recreate_needed = true; + } + + const int rtp_history_ms = HasNack(codec.codec) ? kNackHistoryMs : 0; + if (rtp_history_ms != config_.rtp.nack.rtp_history_ms) { + config_.rtp.nack.rtp_history_ms = rtp_history_ms; + recreate_needed = true; + } + + // The rtx-time parameter can be used to override the hardcoded default for + // the NACK buffer length. + if (codec.rtx_time != -1 && config_.rtp.nack.rtp_history_ms != 0) { + config_.rtp.nack.rtp_history_ms = codec.rtx_time; + recreate_needed = true; + } + + const bool has_rtr = HasRrtr(codec.codec); + if (has_rtr != config_.rtp.rtcp_xr.receiver_reference_time_report) { + config_.rtp.rtcp_xr.receiver_reference_time_report = has_rtr; + recreate_needed = true; + } - config_.rtp.lntf.enabled = HasLntf(codec.codec); - config_.rtp.nack.rtp_history_ms = HasNack(codec.codec) ? 
kNackHistoryMs : 0; - config_.rtp.rtcp_xr.receiver_reference_time_report = HasRrtr(codec.codec); if (codec.ulpfec.red_rtx_payload_type != -1) { - config_.rtp - .rtx_associated_payload_types[codec.ulpfec.red_rtx_payload_type] = + rtx_associated_payload_types[codec.ulpfec.red_rtx_payload_type] = codec.ulpfec.red_payload_type; } -} -void WebRtcVideoChannel::WebRtcVideoReceiveStream::ConfigureFlexfecCodec( - int flexfec_payload_type) { - flexfec_config_.payload_type = flexfec_payload_type; + if (config_.rtp.rtx_associated_payload_types != + rtx_associated_payload_types) { + rtx_associated_payload_types.swap(config_.rtp.rtx_associated_payload_types); + recreate_needed = true; + } + + if (raw_payload_types != config_.rtp.raw_payload_types) { + raw_payload_types.swap(config_.rtp.raw_payload_types); + recreate_needed = true; + } + + if (decoders != config_.decoders) { + decoders.swap(config_.decoders); + recreate_needed = true; + } + + return recreate_needed; } void WebRtcVideoChannel::WebRtcVideoReceiveStream::SetLocalSsrc( @@ -2866,11 +2969,10 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::SetLocalSsrc( } config_.rtp.local_ssrc = local_ssrc; - flexfec_config_.local_ssrc = local_ssrc; + flexfec_config_.rtp.local_ssrc = local_ssrc; RTC_LOG(LS_INFO) << "RecreateWebRtcVideoStream (recv) because of SetLocalSsrc; local_ssrc=" << local_ssrc; - MaybeRecreateWebRtcFlexfecStream(); RecreateWebRtcVideoStream(); } @@ -2878,8 +2980,10 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::SetFeedbackParameters( bool lntf_enabled, bool nack_enabled, bool transport_cc_enabled, - webrtc::RtcpMode rtcp_mode) { - int nack_history_ms = nack_enabled ? kNackHistoryMs : 0; + webrtc::RtcpMode rtcp_mode, + int rtx_time) { + int nack_history_ms = + nack_enabled ? rtx_time != -1 ? rtx_time : kNackHistoryMs : 0; if (config_.rtp.lntf.enabled == lntf_enabled && config_.rtp.nack.rtp_history_ms == nack_history_ms && config_.rtp.transport_cc == transport_cc_enabled && @@ -2888,7 +2992,8 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::SetFeedbackParameters( << "Ignoring call to SetFeedbackParameters because parameters are " "unchanged; lntf=" << lntf_enabled << ", nack=" << nack_enabled - << ", transport_cc=" << transport_cc_enabled; + << ", transport_cc=" << transport_cc_enabled + << ", rtx_time=" << rtx_time; return; } config_.rtp.lntf.enabled = lntf_enabled; @@ -2897,41 +3002,43 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::SetFeedbackParameters( config_.rtp.rtcp_mode = rtcp_mode; // TODO(brandtr): We should be spec-compliant and set |transport_cc| here // based on the rtcp-fb for the FlexFEC codec, not the media codec. 
- flexfec_config_.transport_cc = config_.rtp.transport_cc; + flexfec_config_.rtp.transport_cc = config_.rtp.transport_cc; flexfec_config_.rtcp_mode = config_.rtp.rtcp_mode; RTC_LOG(LS_INFO) << "RecreateWebRtcVideoStream (recv) because of " "SetFeedbackParameters; nack=" << nack_enabled << ", transport_cc=" << transport_cc_enabled; - MaybeRecreateWebRtcFlexfecStream(); RecreateWebRtcVideoStream(); } void WebRtcVideoChannel::WebRtcVideoReceiveStream::SetRecvParameters( const ChangedRecvParameters& params) { bool video_needs_recreation = false; - bool flexfec_needs_recreation = false; if (params.codec_settings) { - ConfigureCodecs(*params.codec_settings); - video_needs_recreation = true; + video_needs_recreation = ConfigureCodecs(*params.codec_settings); } + if (params.rtp_header_extensions) { - config_.rtp.extensions = *params.rtp_header_extensions; - flexfec_config_.rtp_header_extensions = *params.rtp_header_extensions; - video_needs_recreation = true; - flexfec_needs_recreation = true; + if (config_.rtp.extensions != *params.rtp_header_extensions) { + config_.rtp.extensions = *params.rtp_header_extensions; + video_needs_recreation = true; + } + + if (flexfec_config_.rtp.extensions != *params.rtp_header_extensions) { + flexfec_config_.rtp.extensions = *params.rtp_header_extensions; + if (flexfec_stream_ || flexfec_config_.IsCompleteAndEnabled()) + video_needs_recreation = true; + } } if (params.flexfec_payload_type) { - ConfigureFlexfecCodec(*params.flexfec_payload_type); - flexfec_needs_recreation = true; - } - if (flexfec_needs_recreation) { - RTC_LOG(LS_INFO) << "MaybeRecreateWebRtcFlexfecStream (recv) because of " - "SetRecvParameters"; - MaybeRecreateWebRtcFlexfecStream(); + flexfec_config_.payload_type = *params.flexfec_payload_type; + // TODO(tommi): See if it is better to always have a flexfec stream object + // configured and instead of recreating the video stream, reconfigure the + // flexfec object from within the rtp callback (soon to be on the network + // thread). 
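The rewritten ConfigureCodecs and SetRecvParameters above follow a compare-and-set discipline: each config field is only written when it actually differs, and a single flag accumulates whether the underlying VideoReceiveStream has to be rebuilt. The generic shape, sketched without the real config types:

// Generic sketch of the "write only on change, accumulate a recreate flag"
// pattern used by ConfigureCodecs()/SetRecvParameters() above.
template <typename T>
bool SetIfChanged(T& current, const T& desired) {
  if (current == desired)
    return false;
  current = desired;
  return true;
}

// Usage shape (field names borrowed from the diff for illustration):
//   bool recreate_needed = (stream_ == nullptr);
//   recreate_needed |= SetIfChanged(config_.rtp.red_payload_type, red_pt);
//   recreate_needed |= SetIfChanged(config_.rtp.lntf.enabled, has_lntf);
//   if (recreate_needed)
//     RecreateWebRtcVideoStream();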
+ if (flexfec_stream_ || flexfec_config_.IsCompleteAndEnabled()) + video_needs_recreation = true; } if (video_needs_recreation) { - RTC_LOG(LS_INFO) - << "RecreateWebRtcVideoStream (recv) because of SetRecvParameters"; RecreateWebRtcVideoStream(); } } @@ -2944,13 +3051,22 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::RecreateWebRtcVideoStream() { recording_state = stream_->SetAndGetRecordingState( webrtc::VideoReceiveStream::RecordingState(), /*generate_key_frame=*/false); - MaybeDissociateFlexfecFromVideo(); call_->DestroyVideoReceiveStream(stream_); stream_ = nullptr; } + + if (flexfec_stream_) { + call_->DestroyFlexfecReceiveStream(flexfec_stream_); + flexfec_stream_ = nullptr; + } + + if (flexfec_config_.IsCompleteAndEnabled()) { + flexfec_stream_ = call_->CreateFlexfecReceiveStream(flexfec_config_); + } + webrtc::VideoReceiveStream::Config config = config_.Copy(); config.rtp.protected_by_flexfec = (flexfec_stream_ != nullptr); - config.stream_id = stream_params_.id; + config.rtp.packet_sink_ = flexfec_stream_; stream_ = call_->CreateVideoReceiveStream(std::move(config)); if (base_minimum_playout_delay_ms) { stream_->SetBaseMinimumPlayoutDelayMs( @@ -2960,7 +3076,7 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::RecreateWebRtcVideoStream() { stream_->SetAndGetRecordingState(std::move(*recording_state), /*generate_key_frame=*/false); } - MaybeAssociateFlexfecWithVideo(); + stream_->Start(); if (IsEnabled(call_->trials(), "WebRTC-Video-BufferPacketsWithUnknownSsrc")) { @@ -2968,33 +3084,6 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream::RecreateWebRtcVideoStream() { } } -void WebRtcVideoChannel::WebRtcVideoReceiveStream:: - MaybeRecreateWebRtcFlexfecStream() { - if (flexfec_stream_) { - MaybeDissociateFlexfecFromVideo(); - call_->DestroyFlexfecReceiveStream(flexfec_stream_); - flexfec_stream_ = nullptr; - } - if (flexfec_config_.IsCompleteAndEnabled()) { - flexfec_stream_ = call_->CreateFlexfecReceiveStream(flexfec_config_); - MaybeAssociateFlexfecWithVideo(); - } -} - -void WebRtcVideoChannel::WebRtcVideoReceiveStream:: - MaybeAssociateFlexfecWithVideo() { - if (stream_ && flexfec_stream_) { - stream_->AddSecondarySink(flexfec_stream_); - } -} - -void WebRtcVideoChannel::WebRtcVideoReceiveStream:: - MaybeDissociateFlexfecFromVideo() { - if (stream_ && flexfec_stream_) { - stream_->RemoveSecondarySink(flexfec_stream_); - } -} - void WebRtcVideoChannel::WebRtcVideoReceiveStream::OnFrame( const webrtc::VideoFrame& frame) { webrtc::MutexLock lock(&sink_lock_); @@ -3074,6 +3163,7 @@ WebRtcVideoChannel::WebRtcVideoReceiveStream::GetVideoReceiverInfo( stats.rtp_stats.packet_counter.padding_bytes; info.packets_rcvd = stats.rtp_stats.packet_counter.packets; info.packets_lost = stats.rtp_stats.packets_lost; + info.jitter_ms = stats.rtp_stats.jitter; info.framerate_rcvd = stats.network_frame_rate; info.framerate_decoded = stats.decode_frame_rate; @@ -3180,20 +3270,21 @@ void WebRtcVideoChannel::WebRtcVideoReceiveStream:: } WebRtcVideoChannel::VideoCodecSettings::VideoCodecSettings() - : flexfec_payload_type(-1), rtx_payload_type(-1) {} + : flexfec_payload_type(-1), rtx_payload_type(-1), rtx_time(-1) {} bool WebRtcVideoChannel::VideoCodecSettings::operator==( const WebRtcVideoChannel::VideoCodecSettings& other) const { return codec == other.codec && ulpfec == other.ulpfec && flexfec_payload_type == other.flexfec_payload_type && - rtx_payload_type == other.rtx_payload_type; + rtx_payload_type == other.rtx_payload_type && + rtx_time == other.rtx_time; } bool 
WebRtcVideoChannel::VideoCodecSettings::EqualsDisregardingFlexfec( const WebRtcVideoChannel::VideoCodecSettings& a, const WebRtcVideoChannel::VideoCodecSettings& b) { return a.codec == b.codec && a.ulpfec == b.ulpfec && - a.rtx_payload_type == b.rtx_payload_type; + a.rtx_payload_type == b.rtx_payload_type && a.rtx_time == b.rtx_time; } bool WebRtcVideoChannel::VideoCodecSettings::operator!=( @@ -3211,6 +3302,7 @@ WebRtcVideoChannel::MapCodecs(const std::vector& codecs) { std::map payload_codec_type; // |rtx_mapping| maps video payload type to rtx payload type. std::map rtx_mapping; + std::map rtx_time_mapping; webrtc::UlpfecConfig ulpfec_config; absl::optional flexfec_payload_type; @@ -3272,6 +3364,10 @@ WebRtcVideoChannel::MapCodecs(const std::vector& codecs) { << in_codec.ToString(); return {}; } + int rtx_time; + if (in_codec.GetParam(kCodecParamRtxTime, &rtx_time) && rtx_time > 0) { + rtx_time_mapping[associated_payload_type] = rtx_time; + } rtx_mapping[associated_payload_type] = payload_type; break; } @@ -3321,6 +3417,16 @@ WebRtcVideoChannel::MapCodecs(const std::vector& codecs) { if (it != rtx_mapping.end()) { const int rtx_payload_type = it->second; codec_settings.rtx_payload_type = rtx_payload_type; + + auto rtx_time_it = rtx_time_mapping.find(payload_type); + if (rtx_time_it != rtx_time_mapping.end()) { + const int rtx_time = rtx_time_it->second; + if (rtx_time < kNackHistoryMs) { + codec_settings.rtx_time = rtx_time; + } else { + codec_settings.rtx_time = kNackHistoryMs; + } + } } } diff --git a/media/engine/webrtc_video_engine.h b/media/engine/webrtc_video_engine.h index 321a5a8c2e..a67a010ed7 100644 --- a/media/engine/webrtc_video_engine.h +++ b/media/engine/webrtc_video_engine.h @@ -19,6 +19,7 @@ #include "absl/types/optional.h" #include "api/call/transport.h" +#include "api/sequence_checker.h" #include "api/transport/field_trial_based_config.h" #include "api/video/video_bitrate_allocator_factory.h" #include "api/video/video_frame.h" @@ -30,23 +31,17 @@ #include "call/video_receive_stream.h" #include "call/video_send_stream.h" #include "media/base/media_engine.h" -#include "media/engine/constants.h" #include "media/engine/unhandled_packets_buffer.h" #include "rtc_base/network_route.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { class VideoDecoderFactory; class VideoEncoderFactory; -struct MediaConfig; } // namespace webrtc -namespace rtc { -class Thread; -} // namespace rtc - namespace cricket { class WebRtcVideoChannel; @@ -159,6 +154,8 @@ class WebRtcVideoChannel : public VideoMediaChannel, bool AddRecvStream(const StreamParams& sp, bool default_stream); bool RemoveRecvStream(uint32_t ssrc) override; void ResetUnsignaledRecvStream() override; + void OnDemuxerCriteriaUpdatePending() override; + void OnDemuxerCriteriaUpdateComplete() override; bool SetSink(uint32_t ssrc, rtc::VideoSinkInterface* sink) override; void SetDefaultSink( @@ -168,6 +165,7 @@ class WebRtcVideoChannel : public VideoMediaChannel, void OnPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) override; + void OnPacketSent(const rtc::SentPacket& sent_packet) override; void OnReadyToSend(bool ready) override; void OnNetworkRouteChanged(const std::string& transport_name, const rtc::NetworkRoute& network_route) override; @@ -273,6 +271,7 @@ class WebRtcVideoChannel : public VideoMediaChannel, webrtc::UlpfecConfig ulpfec; int 
flexfec_payload_type; // -1 if absent. int rtx_payload_type; // -1 if absent. + int rtx_time; // -1 if absent. }; struct ChangedSendParameters { @@ -397,8 +396,8 @@ class WebRtcVideoChannel : public VideoMediaChannel, webrtc::DegradationPreference GetDegradationPreference() const RTC_EXCLUSIVE_LOCKS_REQUIRED(&thread_checker_); - rtc::ThreadChecker thread_checker_; - rtc::Thread* worker_thread_; + webrtc::SequenceChecker thread_checker_; + webrtc::TaskQueueBase* const worker_thread_; const std::vector ssrcs_ RTC_GUARDED_BY(&thread_checker_); const std::vector ssrc_groups_ RTC_GUARDED_BY(&thread_checker_); webrtc::Call* const call_; @@ -437,7 +436,6 @@ class WebRtcVideoChannel : public VideoMediaChannel, webrtc::Call* call, const StreamParams& sp, webrtc::VideoReceiveStream::Config config, - webrtc::VideoDecoderFactory* decoder_factory, bool default_stream, const std::vector& recv_codecs, const webrtc::FlexfecReceiveStream::Config& flexfec_config); @@ -455,7 +453,8 @@ class WebRtcVideoChannel : public VideoMediaChannel, void SetFeedbackParameters(bool lntf_enabled, bool nack_enabled, bool transport_cc_enabled, - webrtc::RtcpMode rtcp_mode); + webrtc::RtcpMode rtcp_mode, + int rtx_time); void SetRecvParameters(const ChangedRecvParameters& recv_params); void OnFrame(const webrtc::VideoFrame& frame) override; @@ -483,13 +482,11 @@ class WebRtcVideoChannel : public VideoMediaChannel, private: void RecreateWebRtcVideoStream(); - void MaybeRecreateWebRtcFlexfecStream(); - - void MaybeAssociateFlexfecWithVideo(); - void MaybeDissociateFlexfecFromVideo(); - void ConfigureCodecs(const std::vector& recv_codecs); - void ConfigureFlexfecCodec(int flexfec_payload_type); + // Applies a new receive codecs configration to `config_`. Returns true + // if the internal stream needs to be reconstructed, or false if no changes + // were applied. + bool ConfigureCodecs(const std::vector& recv_codecs); std::string GetCodecNameFromPayloadType(int payload_type); @@ -506,8 +503,6 @@ class WebRtcVideoChannel : public VideoMediaChannel, webrtc::FlexfecReceiveStream::Config flexfec_config_; webrtc::FlexfecReceiveStream* flexfec_stream_; - webrtc::VideoDecoderFactory* const decoder_factory_; - webrtc::Mutex sink_lock_; rtc::VideoSinkInterface* sink_ RTC_GUARDED_BY(sink_lock_); @@ -553,12 +548,14 @@ class WebRtcVideoChannel : public VideoMediaChannel, void FillSendAndReceiveCodecStats(VideoMediaInfo* video_media_info) RTC_EXCLUSIVE_LOCKS_REQUIRED(thread_checker_); - rtc::Thread* const worker_thread_; - rtc::ThreadChecker thread_checker_; + webrtc::TaskQueueBase* const worker_thread_; + webrtc::ScopedTaskSafety task_safety_; + webrtc::SequenceChecker network_thread_checker_; + webrtc::SequenceChecker thread_checker_; uint32_t rtcp_receiver_report_ssrc_ RTC_GUARDED_BY(thread_checker_); bool sending_ RTC_GUARDED_BY(thread_checker_); - webrtc::Call* const call_ RTC_GUARDED_BY(thread_checker_); + webrtc::Call* const call_; DefaultUnsignalledSsrcHandler default_unsignalled_ssrc_handler_ RTC_GUARDED_BY(thread_checker_); @@ -575,6 +572,24 @@ class WebRtcVideoChannel : public VideoMediaChannel, RTC_GUARDED_BY(thread_checker_); std::map receive_streams_ RTC_GUARDED_BY(thread_checker_); + // When the channel and demuxer get reconfigured, there is a window of time + // where we have to be prepared for packets arriving based on the old demuxer + // criteria because the streams live on the worker thread and the demuxer + // lives on the network thread. 
Because packets are posted from the network + // thread to the worker thread, they can still be in-flight when streams are + // reconfigured. This can happen when |demuxer_criteria_id_| and + // |demuxer_criteria_completed_id_| don't match. During this time, we do not + // want to create unsignalled receive streams and should instead drop the + // packets. E.g.: + // * If RemoveRecvStream(old_ssrc) was recently called, there may be packets + // in-flight for that ssrc. This happens when a receiver becomes inactive. + // * If we go from one to many m= sections, the demuxer may change from + // forwarding all packets to only forwarding the configured ssrcs, so there + // is a risk of receiving ssrcs for other, recently added m= sections. + uint32_t demuxer_criteria_id_ RTC_GUARDED_BY(thread_checker_) = 0; + uint32_t demuxer_criteria_completed_id_ RTC_GUARDED_BY(thread_checker_) = 0; + absl::optional last_unsignalled_ssrc_creation_time_ms_ + RTC_GUARDED_BY(thread_checker_); std::set send_ssrcs_ RTC_GUARDED_BY(thread_checker_); std::set receive_ssrcs_ RTC_GUARDED_BY(thread_checker_); diff --git a/media/engine/webrtc_video_engine_unittest.cc b/media/engine/webrtc_video_engine_unittest.cc index 72fbc56885..d0745e35f5 100644 --- a/media/engine/webrtc_video_engine_unittest.cc +++ b/media/engine/webrtc_video_engine_unittest.cc @@ -35,34 +35,35 @@ #include "api/video/video_bitrate_allocation.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" +#include "api/video_codecs/h264_profile_level_id.h" #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_decoder_factory.h" #include "api/video_codecs/video_encoder.h" #include "api/video_codecs/video_encoder_factory.h" #include "call/flexfec_receive_stream.h" -#include "common_video/h264/profile_level_id.h" #include "media/base/fake_frame_source.h" #include "media/base/fake_network_interface.h" #include "media/base/fake_video_renderer.h" #include "media/base/media_constants.h" #include "media/base/rtp_utils.h" #include "media/base/test_utils.h" -#include "media/engine/constants.h" #include "media/engine/fake_webrtc_call.h" #include "media/engine/fake_webrtc_video_engine.h" #include "media/engine/simulcast.h" #include "media/engine/webrtc_voice_engine.h" +#include "modules/rtp_rtcp/source/rtp_packet.h" #include "rtc_base/arraysize.h" +#include "rtc_base/event.h" #include "rtc_base/experiments/min_video_bitrate_experiment.h" #include "rtc_base/fake_clock.h" #include "rtc_base/gunit.h" #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/time_utils.h" +#include "system_wrappers/include/field_trial.h" #include "test/fake_decoder.h" #include "test/field_trial.h" #include "test/frame_forwarder.h" #include "test/gmock.h" -#include "test/rtp_header_parser.h" using ::testing::_; using ::testing::Contains; @@ -96,6 +97,7 @@ static const uint32_t kSsrcs3[] = {1, 2, 3}; static const uint32_t kRtxSsrcs1[] = {4}; static const uint32_t kFlexfecSsrc = 5; static const uint32_t kIncomingUnsignalledSsrc = 0xC0FFEE; +static const int64_t kUnsignalledReceiveStreamCooldownMs = 500; constexpr uint32_t kRtpHeaderSize = 12; @@ -398,6 +400,17 @@ TEST_F(WebRtcVideoEngineTestWithVideoLayersAllocation, ExpectRtpCapabilitySupport(RtpExtension::kVideoLayersAllocationUri, true); } +class WebRtcVideoFrameTrackingId : public WebRtcVideoEngineTest { + public: + WebRtcVideoFrameTrackingId() + : WebRtcVideoEngineTest( + "WebRTC-VideoFrameTrackingIdAdvertised/Enabled/") {} +}; +
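// --- Editor's aside: an illustrative sketch, not the real WebRtcVideoChannel
// members, of how the |demuxer_criteria_id_| / |demuxer_criteria_completed_id_|
// counters documented in webrtc_video_engine.h above can gate unsignalled
// receive stream creation while a demuxer update is still in flight. The class
// name and method bodies below are assumptions made for illustration only.
#include <cstdint>

class DemuxerCriteriaGate {
 public:
  // Called when new demuxer criteria have been handed off to the network thread.
  void OnUpdatePending() { ++criteria_id_; }
  // Called once the network thread reports the new criteria as applied.
  void OnUpdateComplete() { ++criteria_completed_id_; }
  // While the counters differ, packets for unknown ssrcs may be stale in-flight
  // traffic, so no unsignalled receive stream should be created for them.
  bool MayCreateUnsignalledStream() const {
    return criteria_id_ == criteria_completed_id_;
  }

 private:
  uint32_t criteria_id_ = 0;
  uint32_t criteria_completed_id_ = 0;
};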
+TEST_F(WebRtcVideoFrameTrackingId, AdvertiseVideoFrameTrackingId) { + ExpectRtpCapabilitySupport(RtpExtension::kVideoFrameTrackingIdUri, true); +} + TEST_F(WebRtcVideoEngineTest, CVOSetHeaderExtensionBeforeCapturer) { // Allocate the source first to prevent early destruction before channel's // dtor is called. @@ -571,20 +584,21 @@ TEST_F(WebRtcVideoEngineTest, UseFactoryForVp8WhenSupported) { // TODO(deadbeef): This test should be updated if/when we start // adding RTX codecs for unrecognized codec names. TEST_F(WebRtcVideoEngineTest, RtxCodecAddedForH264Codec) { - using webrtc::H264::kLevel1; - using webrtc::H264::ProfileLevelId; - using webrtc::H264::ProfileLevelIdToString; + using webrtc::H264Level; + using webrtc::H264Profile; + using webrtc::H264ProfileLevelId; + using webrtc::H264ProfileLevelIdToString; webrtc::SdpVideoFormat h264_constrained_baseline("H264"); h264_constrained_baseline.parameters[kH264FmtpProfileLevelId] = - *ProfileLevelIdToString( - ProfileLevelId(webrtc::H264::kProfileConstrainedBaseline, kLevel1)); + *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileConstrainedBaseline, H264Level::kLevel1)); webrtc::SdpVideoFormat h264_constrained_high("H264"); h264_constrained_high.parameters[kH264FmtpProfileLevelId] = - *ProfileLevelIdToString( - ProfileLevelId(webrtc::H264::kProfileConstrainedHigh, kLevel1)); + *H264ProfileLevelIdToString(H264ProfileLevelId( + H264Profile::kProfileConstrainedHigh, H264Level::kLevel1)); webrtc::SdpVideoFormat h264_high("H264"); - h264_high.parameters[kH264FmtpProfileLevelId] = *ProfileLevelIdToString( - ProfileLevelId(webrtc::H264::kProfileHigh, kLevel1)); + h264_high.parameters[kH264FmtpProfileLevelId] = *H264ProfileLevelIdToString( + H264ProfileLevelId(H264Profile::kProfileHigh, H264Level::kLevel1)); encoder_factory_->AddSupportedVideoCodec(h264_constrained_baseline); encoder_factory_->AddSupportedVideoCodec(h264_constrained_high); @@ -711,10 +725,10 @@ size_t WebRtcVideoEngineTest::GetEngineCodecIndex( // The tests only use H264 Constrained Baseline. Make sure we don't return // an internal H264 codec from the engine with a different H264 profile. if (absl::EqualsIgnoreCase(name.c_str(), kH264CodecName)) { - const absl::optional profile_level_id = - webrtc::H264::ParseSdpProfileLevelId(engine_codec.params); + const absl::optional profile_level_id = + webrtc::ParseSdpForH264ProfileLevelId(engine_codec.params); if (profile_level_id->profile != - webrtc::H264::kProfileConstrainedBaseline) { + webrtc::H264Profile::kProfileConstrainedBaseline) { continue; } } @@ -1406,6 +1420,10 @@ class WebRtcVideoChannelEncodedFrameCallbackTest : public ::testing::Test { channel_->SetRecvParameters(parameters); } + ~WebRtcVideoChannelEncodedFrameCallbackTest() override { + channel_->SetInterface(nullptr); + } + void DeliverKeyFrame(uint32_t ssrc) { webrtc::RtpPacket packet; packet.SetMarker(true); @@ -1416,6 +1434,13 @@ class WebRtcVideoChannelEncodedFrameCallbackTest : public ::testing::Test { uint8_t* buf_ptr = packet.AllocatePayload(11); memset(buf_ptr, 0, 11); // Pass MSAN (don't care about bytes 1-9) buf_ptr[0] = 0x10; // Partition ID 0 + beginning of partition. 
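  // (Editor's note, added for clarity; hedged, based on how the VP8 key frame
  // header appears to be laid out after the one-byte payload descriptor): the
  // four assignments below store a 16-bit little-endian width and height in
  // payload bytes 6..9, so the delivered key frame advertises a plausible
  // resolution (1080x720 here) instead of the all-zero values left by memset.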
+ constexpr unsigned width = 1080; + constexpr unsigned height = 720; + buf_ptr[6] = width & 255; + buf_ptr[7] = width >> 8; + buf_ptr[8] = height & 255; + buf_ptr[9] = height >> 8; + call_->Receiver()->DeliverPacket(webrtc::MediaType::VIDEO, packet.Buffer(), /*packet_time_us=*/0); } @@ -1526,7 +1551,7 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { webrtc::CreateBuiltinVideoDecoderFactory(), field_trials_) {} - virtual void SetUp() { + void SetUp() override { // One testcase calls SetUp in a loop, only create call_ once. if (!call_) { webrtc::Call::Config call_config(&event_log_); @@ -1568,6 +1593,7 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { // Make the second renderer available for use by a new stream. EXPECT_TRUE(channel_->SetSink(kSsrc + 2, &renderer2_)); } + // Setup an additional stream just to send video. Defer add recv stream. // This is required if you want to test unsignalled recv of video rtp packets. void SetUpSecondStreamWithNoRecv() { @@ -1586,7 +1612,17 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { EXPECT_TRUE( channel_->SetVideoSend(kSsrc + 2, nullptr, frame_forwarder_2_.get())); } - virtual void TearDown() { channel_.reset(); } + + void TearDown() override { + channel_->SetInterface(nullptr); + channel_.reset(); + } + + void ResetTest() { + TearDown(); + SetUp(); + } + bool SetDefaultCodec() { return SetOneCodec(DefaultCodec()); } bool SetOneCodec(const cricket::VideoCodec& codec) { @@ -1626,20 +1662,13 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { return network_interface_.NumRtpPackets(ssrc); } int NumSentSsrcs() { return network_interface_.NumSentSsrcs(); } - const rtc::CopyOnWriteBuffer* GetRtpPacket(int index) { + rtc::CopyOnWriteBuffer GetRtpPacket(int index) { return network_interface_.GetRtpPacket(index); } - static int GetPayloadType(const rtc::CopyOnWriteBuffer* p) { - webrtc::RTPHeader header; - EXPECT_TRUE(ParseRtpPacket(p, &header)); - return header.payloadType; - } - - static bool ParseRtpPacket(const rtc::CopyOnWriteBuffer* p, - webrtc::RTPHeader* header) { - std::unique_ptr parser( - webrtc::RtpHeaderParser::CreateForTest()); - return parser->Parse(p->cdata(), p->size(), header); + static int GetPayloadType(rtc::CopyOnWriteBuffer p) { + webrtc::RtpPacket header; + EXPECT_TRUE(header.Parse(std::move(p))); + return header.PayloadType(); } // Tests that we can send and receive frames. 
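  // --- Editor's aside: a hypothetical companion helper (not part of this
  // patch) showing that webrtc::RtpPacket, which replaces the removed
  // test/rtp_header_parser.h, exposes any header field the tests might need,
  // not just the payload type:
  static uint32_t GetSsrc(rtc::CopyOnWriteBuffer p) {
    webrtc::RtpPacket header;
    EXPECT_TRUE(header.Parse(std::move(p)));
    return header.Ssrc();
  }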
@@ -1650,8 +1679,7 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { EXPECT_EQ(0, renderer_.num_rendered_frames()); SendFrame(); EXPECT_FRAME_WAIT(1, kVideoWidth, kVideoHeight, kTimeout); - std::unique_ptr p(GetRtpPacket(0)); - EXPECT_EQ(codec.id, GetPayloadType(p.get())); + EXPECT_EQ(codec.id, GetPayloadType(GetRtpPacket(0))); } void SendReceiveManyAndGetStats(const cricket::VideoCodec& codec, @@ -1667,8 +1695,7 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { EXPECT_FRAME_WAIT(frame + i * fps, kVideoWidth, kVideoHeight, kTimeout); } } - std::unique_ptr p(GetRtpPacket(0)); - EXPECT_EQ(codec.id, GetPayloadType(p.get())); + EXPECT_EQ(codec.id, GetPayloadType(GetRtpPacket(0))); } cricket::VideoSenderInfo GetSenderStats(size_t i) { @@ -1714,6 +1741,7 @@ class WebRtcVideoChannelBaseTest : public ::testing::Test { webrtc::RtcEventLogNull event_log_; webrtc::FieldTrialBasedConfig field_trials_; + std::unique_ptr override_field_trials_; std::unique_ptr task_queue_factory_; std::unique_ptr call_; std::unique_ptr @@ -1768,9 +1796,11 @@ TEST_F(WebRtcVideoChannelBaseTest, OverridesRecvBufferSize) { // Set field trial to override the default recv buffer size, and then re-run // setup where the interface is created and configured. const int kCustomRecvBufferSize = 123456; - webrtc::test::ScopedFieldTrials field_trial( + RTC_DCHECK(!override_field_trials_); + override_field_trials_ = std::make_unique( "WebRTC-IncreasedReceivebuffers/123456/"); - SetUp(); + + ResetTest(); EXPECT_TRUE(SetOneCodec(DefaultCodec())); EXPECT_TRUE(SetSend(true)); @@ -1784,9 +1814,10 @@ TEST_F(WebRtcVideoChannelBaseTest, OverridesRecvBufferSizeWithSuffix) { // Set field trial to override the default recv buffer size, and then re-run // setup where the interface is created and configured. const int kCustomRecvBufferSize = 123456; - webrtc::test::ScopedFieldTrials field_trial( + RTC_DCHECK(!override_field_trials_); + override_field_trials_ = std::make_unique( "WebRTC-IncreasedReceivebuffers/123456_Dogfood/"); - SetUp(); + ResetTest(); EXPECT_TRUE(SetOneCodec(DefaultCodec())); EXPECT_TRUE(SetSend(true)); @@ -1801,24 +1832,50 @@ TEST_F(WebRtcVideoChannelBaseTest, InvalidRecvBufferSize) { // then re-run setup where the interface is created and configured. The // default value should still be used. + const char* prev_field_trials = webrtc::field_trial::GetFieldTrialString(); + + std::string field_trial_string; for (std::string group : {" ", "NotANumber", "-1", "0"}) { - std::string field_trial_string = "WebRTC-IncreasedReceivebuffers/"; - field_trial_string += group; - field_trial_string += "/"; - webrtc::test::ScopedFieldTrials field_trial(field_trial_string); + std::string trial_string = "WebRTC-IncreasedReceivebuffers/"; + trial_string += group; + trial_string += "/"; + + // Dear reader. Sorry for this... it's a bit of a mess. + // TODO(bugs.webrtc.org/12854): This test needs to be rewritten to not use + // ResetTest and changing global field trials in a loop. + TearDown(); + // This is a hack to appease tsan. Because of the way the test is written + // active state within Call, including running task queues may race with + // the test changing the global field trial variable. + // This particular hack, pauses the transport controller TQ while we + // change the field trial. 
+ rtc::TaskQueue* tq = call_->GetTransportControllerSend()->GetWorkerQueue(); + rtc::Event waiting, resume; + tq->PostTask([&waiting, &resume]() { + waiting.Set(); + resume.Wait(rtc::Event::kForever); + }); + + waiting.Wait(rtc::Event::kForever); + field_trial_string = std::move(trial_string); + webrtc::field_trial::InitFieldTrialsFromString(field_trial_string.c_str()); + SetUp(); + resume.Set(); + + // OK, now the test can carry on. EXPECT_TRUE(SetOneCodec(DefaultCodec())); EXPECT_TRUE(SetSend(true)); EXPECT_EQ(64 * 1024, network_interface_.sendbuf_size()); EXPECT_EQ(256 * 1024, network_interface_.recvbuf_size()); } + + webrtc::field_trial::InitFieldTrialsFromString(prev_field_trials); } // Test that stats work properly for a 1-1 call. TEST_F(WebRtcVideoChannelBaseTest, GetStats) { - SetUp(); - const int kDurationSec = 3; const int kFps = 10; SendReceiveManyAndGetStats(DefaultCodec(), kDurationSec, kFps); @@ -1837,7 +1894,7 @@ TEST_F(WebRtcVideoChannelBaseTest, GetStats) { EXPECT_EQ(DefaultCodec().id, *info.senders[0].codec_payload_type); EXPECT_EQ(0, info.senders[0].firs_rcvd); EXPECT_EQ(0, info.senders[0].plis_rcvd); - EXPECT_EQ(0, info.senders[0].nacks_rcvd); + EXPECT_EQ(0u, info.senders[0].nacks_rcvd); EXPECT_EQ(kVideoWidth, info.senders[0].send_frame_width); EXPECT_EQ(kVideoHeight, info.senders[0].send_frame_height); EXPECT_GT(info.senders[0].framerate_input, 0); @@ -1861,7 +1918,7 @@ TEST_F(WebRtcVideoChannelBaseTest, GetStats) { // EXPECT_EQ(0, info.receivers[0].packets_concealed); EXPECT_EQ(0, info.receivers[0].firs_sent); EXPECT_EQ(0, info.receivers[0].plis_sent); - EXPECT_EQ(0, info.receivers[0].nacks_sent); + EXPECT_EQ(0U, info.receivers[0].nacks_sent); EXPECT_EQ(kVideoWidth, info.receivers[0].frame_width); EXPECT_EQ(kVideoHeight, info.receivers[0].frame_height); EXPECT_GT(info.receivers[0].framerate_rcvd, 0); @@ -1875,8 +1932,6 @@ TEST_F(WebRtcVideoChannelBaseTest, GetStats) { // Test that stats work properly for a conf call with multiple recv streams. TEST_F(WebRtcVideoChannelBaseTest, GetStatsMultipleRecvStreams) { - SetUp(); - cricket::FakeVideoRenderer renderer1, renderer2; EXPECT_TRUE(SetOneCodec(DefaultCodec())); cricket::VideoSendParameters parameters; @@ -2005,15 +2060,14 @@ TEST_F(WebRtcVideoChannelBaseTest, SetSendSsrc) { EXPECT_TRUE(SetSend(true)); SendFrame(); EXPECT_TRUE_WAIT(NumRtpPackets() > 0, kTimeout); - webrtc::RTPHeader header; - std::unique_ptr p(GetRtpPacket(0)); - EXPECT_TRUE(ParseRtpPacket(p.get(), &header)); - EXPECT_EQ(kSsrc, header.ssrc); + webrtc::RtpPacket header; + EXPECT_TRUE(header.Parse(GetRtpPacket(0))); + EXPECT_EQ(kSsrc, header.Ssrc()); // Packets are being paced out, so these can mismatch between the first and // second call to NumRtpPackets until pending packets are paced out. 
- EXPECT_EQ_WAIT(NumRtpPackets(), NumRtpPackets(header.ssrc), kTimeout); - EXPECT_EQ_WAIT(NumRtpBytes(), NumRtpBytes(header.ssrc), kTimeout); + EXPECT_EQ_WAIT(NumRtpPackets(), NumRtpPackets(header.Ssrc()), kTimeout); + EXPECT_EQ_WAIT(NumRtpBytes(), NumRtpBytes(header.Ssrc()), kTimeout); EXPECT_EQ(1, NumSentSsrcs()); EXPECT_EQ(0, NumRtpPackets(kSsrc - 1)); EXPECT_EQ(0, NumRtpBytes(kSsrc - 1)); @@ -2030,14 +2084,13 @@ TEST_F(WebRtcVideoChannelBaseTest, SetSendSsrcAfterSetCodecs) { EXPECT_TRUE(SetSend(true)); EXPECT_TRUE(WaitAndSendFrame(0)); EXPECT_TRUE_WAIT(NumRtpPackets() > 0, kTimeout); - webrtc::RTPHeader header; - std::unique_ptr p(GetRtpPacket(0)); - EXPECT_TRUE(ParseRtpPacket(p.get(), &header)); - EXPECT_EQ(999u, header.ssrc); + webrtc::RtpPacket header; + EXPECT_TRUE(header.Parse(GetRtpPacket(0))); + EXPECT_EQ(999u, header.Ssrc()); // Packets are being paced out, so these can mismatch between the first and // second call to NumRtpPackets until pending packets are paced out. - EXPECT_EQ_WAIT(NumRtpPackets(), NumRtpPackets(header.ssrc), kTimeout); - EXPECT_EQ_WAIT(NumRtpBytes(), NumRtpBytes(header.ssrc), kTimeout); + EXPECT_EQ_WAIT(NumRtpPackets(), NumRtpPackets(header.Ssrc()), kTimeout); + EXPECT_EQ_WAIT(NumRtpBytes(), NumRtpBytes(header.Ssrc()), kTimeout); EXPECT_EQ(1, NumSentSsrcs()); EXPECT_EQ(0, NumRtpPackets(kSsrc)); EXPECT_EQ(0, NumRtpBytes(kSsrc)); @@ -2069,12 +2122,10 @@ TEST_F(WebRtcVideoChannelBaseTest, AddRemoveSendStreams) { SendFrame(); EXPECT_FRAME_WAIT(1, kVideoWidth, kVideoHeight, kTimeout); EXPECT_GT(NumRtpPackets(), 0); - webrtc::RTPHeader header; + webrtc::RtpPacket header; size_t last_packet = NumRtpPackets() - 1; - std::unique_ptr p( - GetRtpPacket(static_cast(last_packet))); - EXPECT_TRUE(ParseRtpPacket(p.get(), &header)); - EXPECT_EQ(kSsrc, header.ssrc); + EXPECT_TRUE(header.Parse(GetRtpPacket(static_cast(last_packet)))); + EXPECT_EQ(kSsrc, header.Ssrc()); // Remove the send stream that was added during Setup. EXPECT_TRUE(channel_->RemoveSendStream(kSsrc)); @@ -2089,9 +2140,8 @@ TEST_F(WebRtcVideoChannelBaseTest, AddRemoveSendStreams) { EXPECT_TRUE_WAIT(NumRtpPackets() > rtp_packets, kTimeout); last_packet = NumRtpPackets() - 1; - p.reset(GetRtpPacket(static_cast(last_packet))); - EXPECT_TRUE(ParseRtpPacket(p.get(), &header)); - EXPECT_EQ(789u, header.ssrc); + EXPECT_TRUE(header.Parse(GetRtpPacket(static_cast(last_packet)))); + EXPECT_EQ(789u, header.Ssrc()); } // Tests the behavior of incoming streams in a conference scenario. 
@@ -2119,8 +2169,7 @@ TEST_F(WebRtcVideoChannelBaseTest, SimulateConference) { EXPECT_FRAME_ON_RENDERER_WAIT(renderer2, 1, kVideoWidth, kVideoHeight, kTimeout); - std::unique_ptr p(GetRtpPacket(0)); - EXPECT_EQ(DefaultCodec().id, GetPayloadType(p.get())); + EXPECT_EQ(DefaultCodec().id, GetPayloadType(GetRtpPacket(0))); EXPECT_EQ(kVideoWidth, renderer1.width()); EXPECT_EQ(kVideoHeight, renderer1.height()); EXPECT_EQ(kVideoWidth, renderer2.width()); EXPECT_EQ(kVideoHeight, renderer2.height()); @@ -2495,6 +2544,16 @@ class WebRtcVideoChannelTest : public WebRtcVideoEngineTest { ASSERT_TRUE(channel_->SetSendParameters(send_parameters_)); } + void TearDown() override { + channel_->SetInterface(nullptr); + channel_ = nullptr; + } + + void ResetTest() { + TearDown(); + SetUp(); + } + cricket::VideoCodec GetEngineCodec(const std::string& name) { for (const cricket::VideoCodec& engine_codec : engine_.send_codecs()) { if (absl::EqualsIgnoreCase(name, engine_codec.name)) @@ -2507,6 +2566,16 @@ class WebRtcVideoChannelTest : public WebRtcVideoEngineTest { cricket::VideoCodec DefaultCodec() { return GetEngineCodec("VP8"); } + // After receiving and processing the packet, enough time is advanced that + // the unsignalled receive stream cooldown is no longer in effect. + void ReceivePacketAndAdvanceTime(rtc::CopyOnWriteBuffer packet, + int64_t packet_time_us) { + channel_->OnPacketReceived(packet, packet_time_us); + rtc::Thread::Current()->ProcessMessages(0); + fake_clock_.AdvanceTime( + webrtc::TimeDelta::Millis(kUnsignalledReceiveStreamCooldownMs)); + } + protected: FakeVideoSendStream* AddSendStream() { return AddSendStream(StreamParams::CreateLegacy(++last_ssrc_)); } @@ -2916,7 +2985,7 @@ TEST_F(WebRtcVideoChannelTest, RecvAbsoluteSendTimeHeaderExtensions) { } TEST_F(WebRtcVideoChannelTest, FiltersExtensionsPicksTransportSeqNum) { - webrtc::test::ScopedFieldTrials override_field_trials_( + webrtc::test::ScopedFieldTrials override_field_trials( "WebRTC-FilterAbsSendTimeExtension/Enabled/"); // Enable three redundant extensions.
std::vector extensions; @@ -3147,7 +3216,7 @@ TEST_F(WebRtcVideoChannelTest, LossNotificationIsEnabledByFieldTrial) { RTC_DCHECK(!override_field_trials_); override_field_trials_ = std::make_unique( "WebRTC-RtcpLossNotification/Enabled/"); - SetUp(); + ResetTest(); TestLossNotificationState(true); } @@ -3155,7 +3224,7 @@ TEST_F(WebRtcVideoChannelTest, LossNotificationCanBeEnabledAndDisabled) { RTC_DCHECK(!override_field_trials_); override_field_trials_ = std::make_unique( "WebRTC-RtcpLossNotification/Enabled/"); - SetUp(); + ResetTest(); AssignDefaultCodec(); VerifyCodecHasDefaultFeedbackParams(default_codec_, true); @@ -4082,8 +4151,10 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetDefaultRecvCodecsWithoutSsrc) { fake_call_->GetVideoReceiveStreams(); ASSERT_EQ(1U, video_streams.size()); const FakeVideoReceiveStream& video_stream = *video_streams.front(); - EXPECT_EQ(0, video_stream.GetNumAddedSecondarySinks()); - EXPECT_EQ(0, video_stream.GetNumRemovedSecondarySinks()); + const webrtc::VideoReceiveStream::Config& video_config = + video_stream.GetConfig(); + EXPECT_FALSE(video_config.rtp.protected_by_flexfec); + EXPECT_EQ(video_config.rtp.packet_sink_, nullptr); } TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetDefaultRecvCodecsWithSsrc) { @@ -4093,10 +4164,10 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetDefaultRecvCodecsWithSsrc) { const std::vector& streams = fake_call_->GetFlexfecReceiveStreams(); ASSERT_EQ(1U, streams.size()); - const FakeFlexfecReceiveStream* stream = streams.front(); + const auto* stream = streams.front(); const webrtc::FlexfecReceiveStream::Config& config = stream->GetConfig(); EXPECT_EQ(GetEngineCodec("flexfec-03").id, config.payload_type); - EXPECT_EQ(kFlexfecSsrc, config.remote_ssrc); + EXPECT_EQ(kFlexfecSsrc, config.rtp.remote_ssrc); ASSERT_EQ(1U, config.protected_media_ssrcs.size()); EXPECT_EQ(kSsrcs1[0], config.protected_media_ssrcs[0]); @@ -4104,14 +4175,17 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetDefaultRecvCodecsWithSsrc) { fake_call_->GetVideoReceiveStreams(); ASSERT_EQ(1U, video_streams.size()); const FakeVideoReceiveStream& video_stream = *video_streams.front(); - EXPECT_EQ(1, video_stream.GetNumAddedSecondarySinks()); const webrtc::VideoReceiveStream::Config& video_config = video_stream.GetConfig(); EXPECT_TRUE(video_config.rtp.protected_by_flexfec); + EXPECT_NE(video_config.rtp.packet_sink_, nullptr); } +// Test changing the configuration after a video stream has been created and +// turn on flexfec. This will result in the video stream being recreated because +// the flexfec stream pointer is injected to the video stream at construction. 
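// --- Editor's aside: a hedged sketch (hypothetical helper, not channel code)
// of the recreate-or-not decision the two tests below now verify. The config
// fields referenced (rtp.protected_by_flexfec and rtp.packet_sink_) are the
// ones the assertions check; everything else here is an assumption.
bool FlexfecChangeRequiresRecreation(
    const webrtc::VideoReceiveStream::Config& before,
    const webrtc::VideoReceiveStream::Config& after) {
  // Both fields are consumed when the receive stream is constructed, so a
  // change in either one means tearing the stream down and rebuilding it.
  return before.rtp.protected_by_flexfec != after.rtp.protected_by_flexfec ||
         before.rtp.packet_sink_ != after.rtp.packet_sink_;
}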
TEST_F(WebRtcVideoChannelFlexfecRecvTest, - EnablingFlexfecDoesNotRecreateVideoReceiveStream) { + EnablingFlexfecRecreatesVideoReceiveStream) { cricket::VideoRecvParameters recv_parameters; recv_parameters.codecs.push_back(GetEngineCodec("VP8")); ASSERT_TRUE(channel_->SetRecvParameters(recv_parameters)); @@ -4122,25 +4196,37 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, const std::vector& video_streams = fake_call_->GetVideoReceiveStreams(); ASSERT_EQ(1U, video_streams.size()); - const FakeVideoReceiveStream& video_stream = *video_streams.front(); - EXPECT_EQ(0, video_stream.GetNumAddedSecondarySinks()); - EXPECT_EQ(0, video_stream.GetNumRemovedSecondarySinks()); + const FakeVideoReceiveStream* video_stream = video_streams.front(); + const webrtc::VideoReceiveStream::Config* video_config = + &video_stream->GetConfig(); + EXPECT_FALSE(video_config->rtp.protected_by_flexfec); + EXPECT_EQ(video_config->rtp.packet_sink_, nullptr); // Enable FlexFEC. recv_parameters.codecs.push_back(GetEngineCodec("flexfec-03")); ASSERT_TRUE(channel_->SetRecvParameters(recv_parameters)); - EXPECT_EQ(2, fake_call_->GetNumCreatedReceiveStreams()) + + // Now the count of created streams will be 3 since the video stream was + // recreated and a flexfec stream was created. + EXPECT_EQ(3, fake_call_->GetNumCreatedReceiveStreams()) << "Enabling FlexFEC should create FlexfecReceiveStream."; + EXPECT_EQ(1U, fake_call_->GetVideoReceiveStreams().size()) << "Enabling FlexFEC should not create VideoReceiveStream."; EXPECT_EQ(1U, fake_call_->GetFlexfecReceiveStreams().size()) << "Enabling FlexFEC should create a single FlexfecReceiveStream."; - EXPECT_EQ(1, video_stream.GetNumAddedSecondarySinks()); - EXPECT_EQ(0, video_stream.GetNumRemovedSecondarySinks()); + video_stream = video_streams.front(); + video_config = &video_stream->GetConfig(); + EXPECT_TRUE(video_config->rtp.protected_by_flexfec); + EXPECT_NE(video_config->rtp.packet_sink_, nullptr); } +// Test changing the configuration after a video stream has been created with +// flexfec enabled and then turn off flexfec. This will result in the video +// stream being recreated because the flexfec stream pointer is injected to the +// video stream at construction and that config needs to be torn down. TEST_F(WebRtcVideoChannelFlexfecRecvTest, - DisablingFlexfecDoesNotRecreateVideoReceiveStream) { + DisablingFlexfecRecreatesVideoReceiveStream) { cricket::VideoRecvParameters recv_parameters; recv_parameters.codecs.push_back(GetEngineCodec("VP8")); recv_parameters.codecs.push_back(GetEngineCodec("flexfec-03")); @@ -4153,22 +4239,28 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, const std::vector& video_streams = fake_call_->GetVideoReceiveStreams(); ASSERT_EQ(1U, video_streams.size()); - const FakeVideoReceiveStream& video_stream = *video_streams.front(); - EXPECT_EQ(1, video_stream.GetNumAddedSecondarySinks()); - EXPECT_EQ(0, video_stream.GetNumRemovedSecondarySinks()); + const FakeVideoReceiveStream* video_stream = video_streams.front(); + const webrtc::VideoReceiveStream::Config* video_config = + &video_stream->GetConfig(); + EXPECT_TRUE(video_config->rtp.protected_by_flexfec); + EXPECT_NE(video_config->rtp.packet_sink_, nullptr); // Disable FlexFEC. 
recv_parameters.codecs.clear(); recv_parameters.codecs.push_back(GetEngineCodec("VP8")); ASSERT_TRUE(channel_->SetRecvParameters(recv_parameters)); - EXPECT_EQ(2, fake_call_->GetNumCreatedReceiveStreams()) + // Now the count of created streams will be 3 since the video stream had to + // be recreated on account of the flexfec stream being deleted. + EXPECT_EQ(3, fake_call_->GetNumCreatedReceiveStreams()) << "Disabling FlexFEC should not recreate VideoReceiveStream."; EXPECT_EQ(1U, fake_call_->GetVideoReceiveStreams().size()) << "Disabling FlexFEC should not destroy VideoReceiveStream."; EXPECT_TRUE(fake_call_->GetFlexfecReceiveStreams().empty()) << "Disabling FlexFEC should destroy FlexfecReceiveStream."; - EXPECT_EQ(1, video_stream.GetNumAddedSecondarySinks()); - EXPECT_EQ(1, video_stream.GetNumRemovedSecondarySinks()); + video_stream = video_streams.front(); + video_config = &video_stream->GetConfig(); + EXPECT_FALSE(video_config->rtp.protected_by_flexfec); + EXPECT_EQ(video_config->rtp.packet_sink_, nullptr); } TEST_F(WebRtcVideoChannelFlexfecRecvTest, DuplicateFlexfecCodecIsDropped) { @@ -4188,7 +4280,7 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, DuplicateFlexfecCodecIsDropped) { const std::vector& streams = fake_call_->GetFlexfecReceiveStreams(); ASSERT_EQ(1U, streams.size()); - const FakeFlexfecReceiveStream* stream = streams.front(); + const auto* stream = streams.front(); const webrtc::FlexfecReceiveStream::Config& config = stream->GetConfig(); EXPECT_EQ(GetEngineCodec("flexfec-03").id, config.payload_type); } @@ -4264,7 +4356,7 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetRecvCodecsWithFec) { flexfec_stream->GetConfig(); EXPECT_EQ(GetEngineCodec("flexfec-03").id, flexfec_stream_config.payload_type); - EXPECT_EQ(kFlexfecSsrc, flexfec_stream_config.remote_ssrc); + EXPECT_EQ(kFlexfecSsrc, flexfec_stream_config.rtp.remote_ssrc); ASSERT_EQ(1U, flexfec_stream_config.protected_media_ssrcs.size()); EXPECT_EQ(kSsrcs1[0], flexfec_stream_config.protected_media_ssrcs[0]); const std::vector& video_streams = @@ -4273,17 +4365,17 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetRecvCodecsWithFec) { const webrtc::VideoReceiveStream::Config& video_stream_config = video_stream->GetConfig(); EXPECT_EQ(video_stream_config.rtp.local_ssrc, - flexfec_stream_config.local_ssrc); + flexfec_stream_config.rtp.local_ssrc); EXPECT_EQ(video_stream_config.rtp.rtcp_mode, flexfec_stream_config.rtcp_mode); EXPECT_EQ(video_stream_config.rtcp_send_transport, flexfec_stream_config.rtcp_send_transport); // TODO(brandtr): Update this EXPECT when we set |transport_cc| in a // spec-compliant way. EXPECT_EQ(video_stream_config.rtp.transport_cc, - flexfec_stream_config.transport_cc); + flexfec_stream_config.rtp.transport_cc); EXPECT_EQ(video_stream_config.rtp.rtcp_mode, flexfec_stream_config.rtcp_mode); EXPECT_EQ(video_stream_config.rtp.extensions, - flexfec_stream_config.rtp_header_extensions); + flexfec_stream_config.rtp.extensions); } // We should not send FlexFEC, even if we advertise it, unless the right @@ -4901,6 +4993,76 @@ TEST_F(WebRtcVideoChannelTest, SetRecvCodecsWithChangedRtxPayloadType) { EXPECT_EQ(kRtxSsrcs1[0], config_after.rtp.rtx_ssrc); } +TEST_F(WebRtcVideoChannelTest, SetRecvCodecsRtxWithRtxTime) { + const int kUnusedPayloadType1 = 126; + const int kUnusedPayloadType2 = 127; + EXPECT_FALSE(FindCodecById(engine_.recv_codecs(), kUnusedPayloadType1)); + EXPECT_FALSE(FindCodecById(engine_.recv_codecs(), kUnusedPayloadType2)); + + // SSRCs for RTX. 
+ cricket::StreamParams params = + cricket::StreamParams::CreateLegacy(kSsrcs1[0]); + params.AddFidSsrc(kSsrcs1[0], kRtxSsrcs1[0]); + AddRecvStream(params); + + // Payload type for RTX. + cricket::VideoRecvParameters parameters; + parameters.codecs.push_back(GetEngineCodec("VP8")); + cricket::VideoCodec rtx_codec(kUnusedPayloadType1, "rtx"); + rtx_codec.SetParam("apt", GetEngineCodec("VP8").id); + parameters.codecs.push_back(rtx_codec); + EXPECT_TRUE(channel_->SetRecvParameters(parameters)); + ASSERT_EQ(1U, fake_call_->GetVideoReceiveStreams().size()); + const webrtc::VideoReceiveStream::Config& config = + fake_call_->GetVideoReceiveStreams()[0]->GetConfig(); + + const int kRtxTime = 343; + // Assert that the default value is different from the ones we test + // and store the default value. + EXPECT_NE(config.rtp.nack.rtp_history_ms, kRtxTime); + int default_history_ms = config.rtp.nack.rtp_history_ms; + + // Set rtx-time. + parameters.codecs[1].SetParam(kCodecParamRtxTime, kRtxTime); + EXPECT_TRUE(channel_->SetRecvParameters(parameters)); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams()[0] + ->GetConfig() + .rtp.nack.rtp_history_ms, + kRtxTime); + + // Negative values are ignored so the default value applies. + parameters.codecs[1].SetParam(kCodecParamRtxTime, -1); + EXPECT_TRUE(channel_->SetRecvParameters(parameters)); + EXPECT_NE(fake_call_->GetVideoReceiveStreams()[0] + ->GetConfig() + .rtp.nack.rtp_history_ms, + -1); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams()[0] + ->GetConfig() + .rtp.nack.rtp_history_ms, + default_history_ms); + + // 0 is ignored so the default applies. + parameters.codecs[1].SetParam(kCodecParamRtxTime, 0); + EXPECT_TRUE(channel_->SetRecvParameters(parameters)); + EXPECT_NE(fake_call_->GetVideoReceiveStreams()[0] + ->GetConfig() + .rtp.nack.rtp_history_ms, + 0); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams()[0] + ->GetConfig() + .rtp.nack.rtp_history_ms, + default_history_ms); + + // Values larger than the default are clamped to the default. 
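  // --- Editor's aside (sketch, not part of the test): the clamping rule that
  // MapCodecs() in webrtc_video_engine.cc applies and that the assertions
  // around this point exercise. `clamp_rtx_time` is a hypothetical lambda; the
  // real code compares against kNackHistoryMs, assumed here to equal the
  // default rtp_history_ms captured above as default_history_ms.
  auto clamp_rtx_time = [](int rtx_time_ms, int nack_history_ms) {
    if (rtx_time_ms <= 0)
      return nack_history_ms;  // Non-positive values are ignored; keep default.
    return rtx_time_ms < nack_history_ms ? rtx_time_ms : nack_history_ms;
  };
  // For example, clamp_rtx_time(343, 1000) == 343 while
  // clamp_rtx_time(1100, 1000) == 1000.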
+ parameters.codecs[1].SetParam(kCodecParamRtxTime, default_history_ms + 100); + EXPECT_TRUE(channel_->SetRecvParameters(parameters)); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams()[0] + ->GetConfig() + .rtp.nack.rtp_history_ms, + default_history_ms); +} + TEST_F(WebRtcVideoChannelTest, SetRecvCodecsDifferentPayloadType) { cricket::VideoRecvParameters parameters; parameters.codecs.push_back(GetEngineCodec("VP8")); @@ -4964,7 +5126,7 @@ TEST_F(WebRtcVideoChannelFlexfecRecvTest, SetRecvParamsWithoutFecDisablesFec) { ASSERT_EQ(1U, streams.size()); const FakeFlexfecReceiveStream* stream = streams.front(); EXPECT_EQ(GetEngineCodec("flexfec-03").id, stream->GetConfig().payload_type); - EXPECT_EQ(kFlexfecSsrc, stream->GetConfig().remote_ssrc); + EXPECT_EQ(kFlexfecSsrc, stream->rtp_config().remote_ssrc); ASSERT_EQ(1U, stream->GetConfig().protected_media_ssrcs.size()); EXPECT_EQ(kSsrcs1[0], stream->GetConfig().protected_media_ssrcs[0]); @@ -5017,7 +5179,7 @@ TEST_F(WebRtcVideoChannelFlexfecSendRecvTest, const FakeFlexfecReceiveStream* stream_with_recv_params = streams.front(); EXPECT_EQ(GetEngineCodec("flexfec-03").id, stream_with_recv_params->GetConfig().payload_type); - EXPECT_EQ(kFlexfecSsrc, stream_with_recv_params->GetConfig().remote_ssrc); + EXPECT_EQ(kFlexfecSsrc, stream_with_recv_params->GetConfig().rtp.remote_ssrc); EXPECT_EQ(1U, stream_with_recv_params->GetConfig().protected_media_ssrcs.size()); EXPECT_EQ(kSsrcs1[0], @@ -5031,7 +5193,7 @@ TEST_F(WebRtcVideoChannelFlexfecSendRecvTest, const FakeFlexfecReceiveStream* stream_with_send_params = streams.front(); EXPECT_EQ(GetEngineCodec("flexfec-03").id, stream_with_send_params->GetConfig().payload_type); - EXPECT_EQ(kFlexfecSsrc, stream_with_send_params->GetConfig().remote_ssrc); + EXPECT_EQ(kFlexfecSsrc, stream_with_send_params->GetConfig().rtp.remote_ssrc); EXPECT_EQ(1U, stream_with_send_params->GetConfig().protected_media_ssrcs.size()); EXPECT_EQ(kSsrcs1[0], @@ -5137,6 +5299,7 @@ TEST_F(WebRtcVideoChannelTest, TestSetDscpOptions) { channel->SetInterface(network_interface.get()); // Default value when DSCP is disabled should be DSCP_DEFAULT. EXPECT_EQ(rtc::DSCP_DEFAULT, network_interface->dscp()); + channel->SetInterface(nullptr); // Default value when DSCP is enabled is also DSCP_DEFAULT, until it is set // through rtp parameters. @@ -5166,6 +5329,7 @@ TEST_F(WebRtcVideoChannelTest, TestSetDscpOptions) { EXPECT_TRUE(static_cast(channel.get()) ->SendRtcp(kData, sizeof(kData))); EXPECT_EQ(rtc::DSCP_CS1, network_interface->options().dscp); + channel->SetInterface(nullptr); // Verify that setting the option to false resets the // DiffServCodePoint. @@ -5176,6 +5340,7 @@ TEST_F(WebRtcVideoChannelTest, TestSetDscpOptions) { video_bitrate_allocator_factory_.get()))); channel->SetInterface(network_interface.get()); EXPECT_EQ(rtc::DSCP_DEFAULT, network_interface->dscp()); + channel->SetInterface(nullptr); } // This test verifies that the RTCP reduced size mode is properly applied to @@ -5390,7 +5555,7 @@ TEST_F(WebRtcVideoChannelTest, GetAggregatedStatsReportWithoutSubStreams) { // Comes from substream only. 
EXPECT_EQ(sender.firs_rcvd, 0); EXPECT_EQ(sender.plis_rcvd, 0); - EXPECT_EQ(sender.nacks_rcvd, 0); + EXPECT_EQ(sender.nacks_rcvd, 0u); EXPECT_EQ(sender.send_frame_width, 0); EXPECT_EQ(sender.send_frame_height, 0); @@ -5450,9 +5615,11 @@ TEST_F(WebRtcVideoChannelTest, GetAggregatedStatsReportForSubStreams) { substream.rtcp_packet_type_counts.fir_packets = 14; substream.rtcp_packet_type_counts.nack_packets = 15; substream.rtcp_packet_type_counts.pli_packets = 16; - substream.rtcp_stats.packets_lost = 17; - substream.rtcp_stats.fraction_lost = 18; + webrtc::RTCPReportBlock report_block; + report_block.packets_lost = 17; + report_block.fraction_lost = 18; webrtc::ReportBlockData report_block_data; + report_block_data.SetReportBlock(report_block, 0); report_block_data.AddRoundTripTimeSample(19); substream.report_block_data = report_block_data; substream.encode_frame_rate = 20.0; @@ -5486,9 +5653,12 @@ TEST_F(WebRtcVideoChannelTest, GetAggregatedStatsReportForSubStreams) { static_cast(2 * substream.rtp_stats.transmitted.packets)); EXPECT_EQ(sender.retransmitted_packets_sent, 2u * substream.rtp_stats.retransmitted.packets); - EXPECT_EQ(sender.packets_lost, 2 * substream.rtcp_stats.packets_lost); + EXPECT_EQ(sender.packets_lost, + 2 * substream.report_block_data->report_block().packets_lost); EXPECT_EQ(sender.fraction_lost, - static_cast(substream.rtcp_stats.fraction_lost) / (1 << 8)); + static_cast( + substream.report_block_data->report_block().fraction_lost) / + (1 << 8)); EXPECT_EQ(sender.rtt_ms, 0); EXPECT_EQ(sender.codec_name, DefaultCodec().name); EXPECT_EQ(sender.codec_payload_type, DefaultCodec().id); @@ -5509,9 +5679,8 @@ TEST_F(WebRtcVideoChannelTest, GetAggregatedStatsReportForSubStreams) { EXPECT_EQ( sender.plis_rcvd, static_cast(2 * substream.rtcp_packet_type_counts.pli_packets)); - EXPECT_EQ( - sender.nacks_rcvd, - static_cast(2 * substream.rtcp_packet_type_counts.nack_packets)); + EXPECT_EQ(sender.nacks_rcvd, + 2 * substream.rtcp_packet_type_counts.nack_packets); EXPECT_EQ(sender.send_frame_width, substream.width); EXPECT_EQ(sender.send_frame_height, substream.height); @@ -5568,9 +5737,11 @@ TEST_F(WebRtcVideoChannelTest, GetPerLayerStatsReportForSubStreams) { substream.rtcp_packet_type_counts.fir_packets = 14; substream.rtcp_packet_type_counts.nack_packets = 15; substream.rtcp_packet_type_counts.pli_packets = 16; - substream.rtcp_stats.packets_lost = 17; - substream.rtcp_stats.fraction_lost = 18; + webrtc::RTCPReportBlock report_block; + report_block.packets_lost = 17; + report_block.fraction_lost = 18; webrtc::ReportBlockData report_block_data; + report_block_data.SetReportBlock(report_block, 0); report_block_data.AddRoundTripTimeSample(19); substream.report_block_data = report_block_data; substream.encode_frame_rate = 20.0; @@ -5604,9 +5775,12 @@ TEST_F(WebRtcVideoChannelTest, GetPerLayerStatsReportForSubStreams) { static_cast(substream.rtp_stats.transmitted.packets)); EXPECT_EQ(sender.retransmitted_packets_sent, substream.rtp_stats.retransmitted.packets); - EXPECT_EQ(sender.packets_lost, substream.rtcp_stats.packets_lost); + EXPECT_EQ(sender.packets_lost, + substream.report_block_data->report_block().packets_lost); EXPECT_EQ(sender.fraction_lost, - static_cast(substream.rtcp_stats.fraction_lost) / (1 << 8)); + static_cast( + substream.report_block_data->report_block().fraction_lost) / + (1 << 8)); EXPECT_EQ(sender.rtt_ms, 0); EXPECT_EQ(sender.codec_name, DefaultCodec().name); EXPECT_EQ(sender.codec_payload_type, DefaultCodec().id); @@ -5625,8 +5799,7 @@ 
TEST_F(WebRtcVideoChannelTest, GetPerLayerStatsReportForSubStreams) { static_cast(substream.rtcp_packet_type_counts.fir_packets)); EXPECT_EQ(sender.plis_rcvd, static_cast(substream.rtcp_packet_type_counts.pli_packets)); - EXPECT_EQ(sender.nacks_rcvd, - static_cast(substream.rtcp_packet_type_counts.nack_packets)); + EXPECT_EQ(sender.nacks_rcvd, substream.rtcp_packet_type_counts.nack_packets); EXPECT_EQ(sender.send_frame_width, substream.width); EXPECT_EQ(sender.send_frame_height, substream.height); @@ -5947,15 +6120,15 @@ TEST_F(WebRtcVideoChannelTest, GetStatsTranslatesSendRtcpPacketTypesCorrectly) { cricket::VideoMediaInfo info; ASSERT_TRUE(channel_->GetStats(&info)); EXPECT_EQ(2, info.senders[0].firs_rcvd); - EXPECT_EQ(3, info.senders[0].nacks_rcvd); + EXPECT_EQ(3u, info.senders[0].nacks_rcvd); EXPECT_EQ(4, info.senders[0].plis_rcvd); EXPECT_EQ(5, info.senders[1].firs_rcvd); - EXPECT_EQ(7, info.senders[1].nacks_rcvd); + EXPECT_EQ(7u, info.senders[1].nacks_rcvd); EXPECT_EQ(9, info.senders[1].plis_rcvd); EXPECT_EQ(7, info.aggregated_senders[0].firs_rcvd); - EXPECT_EQ(10, info.aggregated_senders[0].nacks_rcvd); + EXPECT_EQ(10u, info.aggregated_senders[0].nacks_rcvd); EXPECT_EQ(13, info.aggregated_senders[0].plis_rcvd); } @@ -5973,7 +6146,7 @@ TEST_F(WebRtcVideoChannelTest, EXPECT_EQ(stats.rtcp_packet_type_counts.fir_packets, rtc::checked_cast(info.receivers[0].firs_sent)); EXPECT_EQ(stats.rtcp_packet_type_counts.nack_packets, - rtc::checked_cast(info.receivers[0].nacks_sent)); + info.receivers[0].nacks_sent); EXPECT_EQ(stats.rtcp_packet_type_counts.pli_packets, rtc::checked_cast(info.receivers[0].plis_sent)); } @@ -6133,7 +6306,7 @@ TEST_F(WebRtcVideoChannelTest, DefaultReceiveStreamReconfiguresToUseRtx) { memset(data, 0, sizeof(data)); rtc::SetBE32(&data[8], ssrcs[0]); rtc::CopyOnWriteBuffer packet(data, kDataLength); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); ASSERT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()) << "No default receive stream created."; @@ -6281,6 +6454,9 @@ TEST_F(WebRtcVideoChannelTest, RecvUnsignaledSsrcWithSignaledStreamId) { cricket::StreamParams unsignaled_stream; unsignaled_stream.set_stream_ids({kSyncLabel}); ASSERT_TRUE(channel_->AddRecvStream(unsignaled_stream)); + channel_->OnDemuxerCriteriaUpdatePending(); + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); // The stream shouldn't have been created at this point because it doesn't // have any SSRCs. EXPECT_EQ(0u, fake_call_->GetVideoReceiveStreams().size()); @@ -6291,19 +6467,29 @@ TEST_F(WebRtcVideoChannelTest, RecvUnsignaledSsrcWithSignaledStreamId) { memset(data, 0, sizeof(data)); rtc::SetBE32(&data[8], kIncomingUnsignalledSsrc); rtc::CopyOnWriteBuffer packet(data, kDataLength); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); // The stream should now be created with the appropriate sync label. EXPECT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()); EXPECT_EQ(kSyncLabel, fake_call_->GetVideoReceiveStreams()[0]->GetConfig().sync_group); - // Reset the unsignaled stream to clear the cache. This time when - // a default video receive stream is created it won't have a sync_group. + // Reset the unsignaled stream to clear the cache. This deletes the receive + // stream. 
channel_->ResetUnsignaledRecvStream(); + channel_->OnDemuxerCriteriaUpdatePending(); + EXPECT_EQ(0u, fake_call_->GetVideoReceiveStreams().size()); + + // Until the demuxer criteria has been updated, we ignore in-flight ssrcs of + // the recently removed unsignaled receive stream. + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); EXPECT_EQ(0u, fake_call_->GetVideoReceiveStreams().size()); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + // After the demuxer criteria has been updated, we should proceed to create + // unsignalled receive streams. This time when a default video receive stream + // is created it won't have a sync_group. + channel_->OnDemuxerCriteriaUpdateComplete(); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); EXPECT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()); EXPECT_TRUE( fake_call_->GetVideoReceiveStreams()[0]->GetConfig().sync_group.empty()); @@ -6320,7 +6506,7 @@ TEST_F(WebRtcVideoChannelTest, memset(data, 0, sizeof(data)); rtc::SetBE32(&data[8], kIncomingUnsignalledSsrc); rtc::CopyOnWriteBuffer packet(data, kDataLength); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); // Default receive stream created. const auto& receivers1 = fake_call_->GetVideoReceiveStreams(); @@ -6340,6 +6526,332 @@ TEST_F(WebRtcVideoChannelTest, EXPECT_EQ(receivers2[0]->GetConfig().rtp.remote_ssrc, kIncomingSignalledSsrc); } +TEST_F(WebRtcVideoChannelTest, + RecentlyAddedSsrcsDoNotCreateUnsignalledRecvStreams) { + const uint32_t kSsrc1 = 1; + const uint32_t kSsrc2 = 2; + + // Starting point: receiving kSsrc1. + EXPECT_TRUE(channel_->AddRecvStream(StreamParams::CreateLegacy(kSsrc1))); + channel_->OnDemuxerCriteriaUpdatePending(); + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + + // If this is the only m= section the demuxer might be configured to forward + // all packets, regardless of ssrc, to this channel. When we go to multiple m= + // sections, there can thus be a window of time where packets that should + // never have belonged to this channel arrive anyway. + + // Emulate a second m= section being created by updating the demuxer criteria + // without adding any streams. + channel_->OnDemuxerCriteriaUpdatePending(); + + // Emulate there being in-flight packets for kSsrc1 and kSsrc2 arriving before + // the demuxer is updated. + { + // Receive a packet for kSsrc1. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc1); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + { + // Receive a packet for kSsrc2. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc2); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + + // No unsignaled ssrc for kSsrc2 should have been created, but kSsrc1 should + // arrive since it already has a stream. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 0u); + + // Signal that the demuxer update is complete.
Because there are no more + // pending demuxer updates, receiving unknown ssrcs (kSsrc2) should again + // result in unsignalled receive streams being created. + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + + // Receive packets for kSsrc1 and kSsrc2 again. + { + // Receive a packet for kSsrc1. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc1); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + { + // Receive a packet for kSsrc2. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc2); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + + // An unsignalled ssrc for kSsrc2 should be created and the packet counter + // should increase for both ssrcs. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 2u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 2u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 1u); +} + +TEST_F(WebRtcVideoChannelTest, + RecentlyRemovedSsrcsDoNotCreateUnsignalledRecvStreams) { + const uint32_t kSsrc1 = 1; + const uint32_t kSsrc2 = 2; + + // Starting point: receiving kSsrc1 and kSsrc2. + EXPECT_TRUE(channel_->AddRecvStream(StreamParams::CreateLegacy(kSsrc1))); + EXPECT_TRUE(channel_->AddRecvStream(StreamParams::CreateLegacy(kSsrc2))); + channel_->OnDemuxerCriteriaUpdatePending(); + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 2u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 0u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 0u); + + // Remove kSsrc1, signal that a demuxer criteria update is pending, but not + // completed yet. + EXPECT_TRUE(channel_->RemoveRecvStream(kSsrc1)); + channel_->OnDemuxerCriteriaUpdatePending(); + + // We only have a receiver for kSsrc2 now. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + + // Emulate there being in-flight packets for kSsrc1 and kSsrc2 arriving before + // the demuxer is updated. + { + // Receive a packet for kSsrc1. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc1); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + { + // Receive a packet for kSsrc2. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc2); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + + // No unsignaled ssrc for kSsrc1 should have been created, but the packet + // count for kSsrc2 should increase. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 0u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 1u); + + // Signal that the demuxer update is complete. This means we should stop + // ignoring kSsrc1. + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + + // Receive packets for kSsrc1 and kSsrc2 again. + { + // Receive a packet for kSsrc1.
+ const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc1); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + { + // Receive a packet for kSsrc2. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc2); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + + // An unsignalled ssrc for kSsrc1 should be created and the packet counter + // should increase for both ssrcs. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 2u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 2u); +} + +TEST_F(WebRtcVideoChannelTest, MultiplePendingDemuxerCriteriaUpdates) { + const uint32_t kSsrc = 1; + + // Starting point: receiving kSsrc. + EXPECT_TRUE(channel_->AddRecvStream(StreamParams::CreateLegacy(kSsrc))); + channel_->OnDemuxerCriteriaUpdatePending(); + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + ASSERT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + + // Remove kSsrc... + EXPECT_TRUE(channel_->RemoveRecvStream(kSsrc)); + channel_->OnDemuxerCriteriaUpdatePending(); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 0u); + // And then add it back again, before the demuxer knows about the new + // criteria! + EXPECT_TRUE(channel_->AddRecvStream(StreamParams::CreateLegacy(kSsrc))); + channel_->OnDemuxerCriteriaUpdatePending(); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + + // In-flight packets should arrive because the stream was recreated, even + // though demuxer criteria updates are pending... + { + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc), 1u); + + // Signal that the demuxer knows about the first update: the removal. + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + + // This still should not prevent in-flight packets from arriving because we + // have a receive stream for it. + { + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc), 2u); + + // Remove the kSsrc again while previous demuxer updates are still pending. + EXPECT_TRUE(channel_->RemoveRecvStream(kSsrc)); + channel_->OnDemuxerCriteriaUpdatePending(); + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 0u); + + // Now the packet should be dropped and not create an unsignalled receive + // stream. + { + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 0u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc), 2u); + + // Signal that the demuxer knows about the second update: adding it back. 
+ channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + + // The packets should continue to be dropped because removal happened after + // the most recently completed demuxer update. + { + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 0u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc), 2u); + + // Signal that the demuxer knows about the last update: the second removal. + channel_->OnDemuxerCriteriaUpdateComplete(); + rtc::Thread::Current()->ProcessMessages(0); + + // If packets still arrive after the demuxer knows about the latest removal we + // should finally create an unsignalled receive stream. + { + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); + } + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc), 3u); +} + +TEST_F(WebRtcVideoChannelTest, UnsignalledSsrcHasACooldown) { + const uint32_t kSsrc1 = 1; + const uint32_t kSsrc2 = 2; + + // Send packets for kSsrc1, creating an unsignalled receive stream. + { + // Receive a packet for kSsrc1. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc1); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + } + rtc::Thread::Current()->ProcessMessages(0); + fake_clock_.AdvanceTime( + webrtc::TimeDelta::Millis(kUnsignalledReceiveStreamCooldownMs - 1)); + + // We now have an unsignalled receive stream for kSsrc1. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 0u); + + { + // Receive a packet for kSsrc2. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc2); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + } + rtc::Thread::Current()->ProcessMessages(0); + + // Not enough time has passed to replace the unsignalled receive stream, so + // the kSsrc2 should be ignored. + EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 0u); + + // After 500 ms, kSsrc2 should trigger a new unsignalled receive stream that + // replaces the old one. + fake_clock_.AdvanceTime(webrtc::TimeDelta::Millis(1)); + { + // Receive a packet for kSsrc2. + const size_t kDataLength = 12; + uint8_t data[kDataLength]; + memset(data, 0, sizeof(data)); + rtc::SetBE32(&data[8], kSsrc2); + rtc::CopyOnWriteBuffer packet(data, kDataLength); + channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + } + rtc::Thread::Current()->ProcessMessages(0); + + // The old unsignalled receive stream was destroyed and replaced, so we still + // only have one unsignalled receive stream. But the packet counter for kSsrc2 + // has now increased.
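  // --- Editor's aside (illustrative sketch, not channel code): the cooldown
  // rule this test exercises. The real channel records the creation time in
  // last_unsignalled_ssrc_creation_time_ms_ (see webrtc_video_engine.h); the
  // hypothetical lambda below restates that check, assuming a replacement is
  // allowed once at least kUnsignalledReceiveStreamCooldownMs (500 ms here)
  // has elapsed since the previous unsignalled stream was created.
  auto may_replace_unsignalled_stream =
      [](int64_t now_ms, absl::optional<int64_t> last_created_ms,
         int64_t cooldown_ms) {
        return !last_created_ms.has_value() ||
               now_ms - *last_created_ms >= cooldown_ms;
      };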
+ EXPECT_EQ(fake_call_->GetVideoReceiveStreams().size(), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc1), 1u); + EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(kSsrc2), 1u); +} + // Test BaseMinimumPlayoutDelayMs on receive streams. TEST_F(WebRtcVideoChannelTest, BaseMinimumPlayoutDelayMs) { // Test that set won't work for non-existing receive streams. @@ -6373,7 +6885,7 @@ TEST_F(WebRtcVideoChannelTest, BaseMinimumPlayoutDelayMsUnsignaledRecvStream) { memset(data, 0, sizeof(data)); rtc::SetBE32(&data[8], kIncomingUnsignalledSsrc); rtc::CopyOnWriteBuffer packet(data, kDataLength); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); recv_stream = fake_call_->GetVideoReceiveStream(kIncomingUnsignalledSsrc); EXPECT_EQ(recv_stream->base_mininum_playout_delay_ms(), 200); @@ -6410,7 +6922,7 @@ void WebRtcVideoChannelTest::TestReceiveUnsignaledSsrcPacket( rtc::Set8(data, 1, payload_type); rtc::SetBE32(&data[8], kIncomingUnsignalledSsrc); rtc::CopyOnWriteBuffer packet(data, kDataLength); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + ReceivePacketAndAdvanceTime(packet, /* packet_time_us */ -1); if (expect_created_receive_stream) { EXPECT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()) @@ -6490,18 +7002,14 @@ TEST_F(WebRtcVideoChannelTest, ReceiveDifferentUnsignaledSsrc) { channel_->SetDefaultSink(&renderer); // Receive VP8 packet on first SSRC. - uint8_t data[kMinRtpPacketLen]; - cricket::RtpHeader rtpHeader; - rtpHeader.payload_type = GetEngineCodec("VP8").id; - rtpHeader.seq_num = rtpHeader.timestamp = 0; - rtpHeader.ssrc = kIncomingUnsignalledSsrc + 1; - cricket::SetRtpHeader(data, sizeof(data), rtpHeader); - rtc::CopyOnWriteBuffer packet(data, sizeof(data)); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + webrtc::RtpPacket rtp_packet; + rtp_packet.SetPayloadType(GetEngineCodec("VP8").id); + rtp_packet.SetSsrc(kIncomingUnsignalledSsrc + 1); + ReceivePacketAndAdvanceTime(rtp_packet.Buffer(), /* packet_time_us */ -1); // VP8 packet should create default receive stream. ASSERT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()); FakeVideoReceiveStream* recv_stream = fake_call_->GetVideoReceiveStreams()[0]; - EXPECT_EQ(rtpHeader.ssrc, recv_stream->GetConfig().rtp.remote_ssrc); + EXPECT_EQ(rtp_packet.Ssrc(), recv_stream->GetConfig().rtp.remote_ssrc); // Verify that the receive stream sinks to a renderer. webrtc::VideoFrame video_frame = webrtc::VideoFrame::Builder() @@ -6514,15 +7022,13 @@ TEST_F(WebRtcVideoChannelTest, ReceiveDifferentUnsignaledSsrc) { EXPECT_EQ(1, renderer.num_rendered_frames()); // Receive VP9 packet on second SSRC. - rtpHeader.payload_type = GetEngineCodec("VP9").id; - rtpHeader.ssrc = kIncomingUnsignalledSsrc + 2; - cricket::SetRtpHeader(data, sizeof(data), rtpHeader); - rtc::CopyOnWriteBuffer packet2(data, sizeof(data)); - channel_->OnPacketReceived(packet2, /* packet_time_us */ -1); + rtp_packet.SetPayloadType(GetEngineCodec("VP9").id); + rtp_packet.SetSsrc(kIncomingUnsignalledSsrc + 2); + ReceivePacketAndAdvanceTime(rtp_packet.Buffer(), /* packet_time_us */ -1); // VP9 packet should replace the default receive SSRC. ASSERT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()); recv_stream = fake_call_->GetVideoReceiveStreams()[0]; - EXPECT_EQ(rtpHeader.ssrc, recv_stream->GetConfig().rtp.remote_ssrc); + EXPECT_EQ(rtp_packet.Ssrc(), recv_stream->GetConfig().rtp.remote_ssrc); // Verify that the receive stream sinks to a renderer. 
webrtc::VideoFrame video_frame2 = webrtc::VideoFrame::Builder() @@ -6536,15 +7042,13 @@ TEST_F(WebRtcVideoChannelTest, ReceiveDifferentUnsignaledSsrc) { #if defined(WEBRTC_USE_H264) // Receive H264 packet on third SSRC. - rtpHeader.payload_type = 126; - rtpHeader.ssrc = kIncomingUnsignalledSsrc + 3; - cricket::SetRtpHeader(data, sizeof(data), rtpHeader); - rtc::CopyOnWriteBuffer packet3(data, sizeof(data)); - channel_->OnPacketReceived(packet3, /* packet_time_us */ -1); + rtp_packet.SetPayloadType(126); + rtp_packet.SetSsrc(kIncomingUnsignalledSsrc + 3); + ReceivePacketAndAdvanceTime(rtp_packet.Buffer(), /* packet_time_us */ -1); // H264 packet should replace the default receive SSRC. ASSERT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()); recv_stream = fake_call_->GetVideoReceiveStreams()[0]; - EXPECT_EQ(rtpHeader.ssrc, recv_stream->GetConfig().rtp.remote_ssrc); + EXPECT_EQ(rtp_packet.Ssrc(), recv_stream->GetConfig().rtp.remote_ssrc); // Verify that the receive stream sinks to a renderer. webrtc::VideoFrame video_frame3 = webrtc::VideoFrame::Builder() @@ -6572,14 +7076,10 @@ TEST_F(WebRtcVideoChannelTest, EXPECT_EQ(0u, fake_call_->GetVideoReceiveStreams().size()); // Receive packet on an unsignaled SSRC. - uint8_t data[kMinRtpPacketLen]; - cricket::RtpHeader rtp_header; - rtp_header.payload_type = GetEngineCodec("VP8").id; - rtp_header.seq_num = rtp_header.timestamp = 0; - rtp_header.ssrc = kSsrcs3[0]; - cricket::SetRtpHeader(data, sizeof(data), rtp_header); - rtc::CopyOnWriteBuffer packet(data, sizeof(data)); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + webrtc::RtpPacket rtp_packet; + rtp_packet.SetPayloadType(GetEngineCodec("VP8").id); + rtp_packet.SetSsrc(kSsrcs3[0]); + ReceivePacketAndAdvanceTime(rtp_packet.Buffer(), /* packet_time_us */ -1); // Default receive stream should be created. ASSERT_EQ(1u, fake_call_->GetVideoReceiveStreams().size()); FakeVideoReceiveStream* recv_stream0 = @@ -6594,10 +7094,8 @@ TEST_F(WebRtcVideoChannelTest, EXPECT_EQ(kSsrcs3[0], recv_stream0->GetConfig().rtp.remote_ssrc); // Receive packet on a different unsignaled SSRC. - rtp_header.ssrc = kSsrcs3[1]; - cricket::SetRtpHeader(data, sizeof(data), rtp_header); - packet.SetData(data, sizeof(data)); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + rtp_packet.SetSsrc(kSsrcs3[1]); + ReceivePacketAndAdvanceTime(rtp_packet.Buffer(), /* packet_time_us */ -1); // New default receive stream should be created, but old stream should remain. ASSERT_EQ(2u, fake_call_->GetVideoReceiveStreams().size()); EXPECT_EQ(recv_stream0, fake_call_->GetVideoReceiveStreams()[0]); @@ -8203,14 +8701,10 @@ TEST_F(WebRtcVideoChannelTest, EXPECT_FALSE(rtp_parameters.encodings[0].ssrc); // Receive VP8 packet. - uint8_t data[kMinRtpPacketLen]; - cricket::RtpHeader rtpHeader; - rtpHeader.payload_type = GetEngineCodec("VP8").id; - rtpHeader.seq_num = rtpHeader.timestamp = 0; - rtpHeader.ssrc = kIncomingUnsignalledSsrc; - cricket::SetRtpHeader(data, sizeof(data), rtpHeader); - rtc::CopyOnWriteBuffer packet(data, sizeof(data)); - channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + webrtc::RtpPacket rtp_packet; + rtp_packet.SetPayloadType(GetEngineCodec("VP8").id); + rtp_packet.SetSsrc(kIncomingUnsignalledSsrc); + ReceivePacketAndAdvanceTime(rtp_packet.Buffer(), /* packet_time_us */ -1); // The |ssrc| member should still be unset. 
rtp_parameters = channel_->GetDefaultRtpReceiveParameters(); @@ -8268,6 +8762,48 @@ TEST_F(WebRtcVideoChannelTest, ConfiguresLocalSsrcOnExistingReceivers) { TestReceiverLocalSsrcConfiguration(true); } +TEST_F(WebRtcVideoChannelTest, Simulcast_QualityScalingNotAllowed) { + FakeVideoSendStream* stream = SetUpSimulcast(true, true); + EXPECT_FALSE(stream->GetEncoderConfig().is_quality_scaling_allowed); +} + +TEST_F(WebRtcVideoChannelTest, Singlecast_QualityScalingAllowed) { + FakeVideoSendStream* stream = SetUpSimulcast(false, true); + EXPECT_TRUE(stream->GetEncoderConfig().is_quality_scaling_allowed); +} + +TEST_F(WebRtcVideoChannelTest, + SinglecastScreenSharing_QualityScalingNotAllowed) { + SetUpSimulcast(false, true); + + webrtc::test::FrameForwarder frame_forwarder; + VideoOptions options; + options.is_screencast = true; + EXPECT_TRUE(channel_->SetVideoSend(last_ssrc_, &options, &frame_forwarder)); + // Fetch the latest stream since SetVideoSend() may recreate it if the + // screen content setting is changed. + FakeVideoSendStream* stream = fake_call_->GetVideoSendStreams().front(); + + EXPECT_FALSE(stream->GetEncoderConfig().is_quality_scaling_allowed); + EXPECT_TRUE(channel_->SetVideoSend(last_ssrc_, nullptr, nullptr)); +} + +TEST_F(WebRtcVideoChannelTest, + SimulcastSingleActiveStream_QualityScalingAllowed) { + FakeVideoSendStream* stream = SetUpSimulcast(true, false); + + webrtc::RtpParameters rtp_parameters = + channel_->GetRtpSendParameters(last_ssrc_); + ASSERT_EQ(3u, rtp_parameters.encodings.size()); + ASSERT_TRUE(rtp_parameters.encodings[0].active); + ASSERT_TRUE(rtp_parameters.encodings[1].active); + ASSERT_TRUE(rtp_parameters.encodings[2].active); + rtp_parameters.encodings[0].active = false; + rtp_parameters.encodings[1].active = false; + EXPECT_TRUE(channel_->SetRtpSendParameters(last_ssrc_, rtp_parameters).ok()); + EXPECT_TRUE(stream->GetEncoderConfig().is_quality_scaling_allowed); +} + class WebRtcVideoChannelSimulcastTest : public ::testing::Test { public: WebRtcVideoChannelSimulcastTest() diff --git a/media/engine/webrtc_voice_engine.cc b/media/engine/webrtc_voice_engine.cc index 2ed78b429b..aa80c8724a 100644 --- a/media/engine/webrtc_voice_engine.cc +++ b/media/engine/webrtc_voice_engine.cc @@ -11,6 +11,7 @@ #include "media/engine/webrtc_voice_engine.h" #include +#include #include #include #include @@ -46,6 +47,8 @@ #include "rtc_base/strings/audio_format_to_string.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/strings/string_format.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/third_party/base64/base64.h" #include "rtc_base/trace_event.h" #include "system_wrappers/include/metrics.h" @@ -207,7 +210,9 @@ bool IsEnabled(const webrtc::WebRtcKeyValueConfig& config, struct AdaptivePtimeConfig { bool enabled = false; webrtc::DataRate min_payload_bitrate = webrtc::DataRate::KilobitsPerSec(16); - webrtc::DataRate min_encoder_bitrate = webrtc::DataRate::KilobitsPerSec(12); + // Value is chosen to ensure FEC can be encoded, see LBRR_WB_MIN_RATE_BPS in + // libopus. + webrtc::DataRate min_encoder_bitrate = webrtc::DataRate::KilobitsPerSec(16); bool use_slow_adaptation = true; absl::optional audio_network_adaptor_config; @@ -235,6 +240,49 @@ struct AdaptivePtimeConfig { } }; +// TODO(tommi): Constructing a receive stream could be made simpler. +// Move some of this boiler plate code into the config structs themselves. 
+webrtc::AudioReceiveStream::Config BuildReceiveStreamConfig( + uint32_t remote_ssrc, + uint32_t local_ssrc, + bool use_transport_cc, + bool use_nack, + const std::vector& stream_ids, + const std::vector& extensions, + webrtc::Transport* rtcp_send_transport, + const rtc::scoped_refptr& decoder_factory, + const std::map& decoder_map, + absl::optional codec_pair_id, + size_t jitter_buffer_max_packets, + bool jitter_buffer_fast_accelerate, + int jitter_buffer_min_delay_ms, + bool jitter_buffer_enable_rtx_handling, + rtc::scoped_refptr frame_decryptor, + const webrtc::CryptoOptions& crypto_options, + rtc::scoped_refptr frame_transformer) { + webrtc::AudioReceiveStream::Config config; + config.rtp.remote_ssrc = remote_ssrc; + config.rtp.local_ssrc = local_ssrc; + config.rtp.transport_cc = use_transport_cc; + config.rtp.nack.rtp_history_ms = use_nack ? kNackRtpHistoryMs : 0; + if (!stream_ids.empty()) { + config.sync_group = stream_ids[0]; + } + config.rtp.extensions = extensions; + config.rtcp_send_transport = rtcp_send_transport; + config.decoder_factory = decoder_factory; + config.decoder_map = decoder_map; + config.codec_pair_id = codec_pair_id; + config.jitter_buffer_max_packets = jitter_buffer_max_packets; + config.jitter_buffer_fast_accelerate = jitter_buffer_fast_accelerate; + config.jitter_buffer_min_delay_ms = jitter_buffer_min_delay_ms; + config.jitter_buffer_enable_rtx_handling = jitter_buffer_enable_rtx_handling; + config.frame_decryptor = std::move(frame_decryptor); + config.crypto_options = crypto_options; + config.frame_transformer = std::move(frame_transformer); + return config; +} + } // namespace WebRtcVoiceEngine::WebRtcVoiceEngine( @@ -267,7 +315,7 @@ WebRtcVoiceEngine::WebRtcVoiceEngine( } WebRtcVoiceEngine::~WebRtcVoiceEngine() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_LOG(LS_INFO) << "WebRtcVoiceEngine::~WebRtcVoiceEngine"; if (initialized_) { StopAecDump(); @@ -281,7 +329,7 @@ WebRtcVoiceEngine::~WebRtcVoiceEngine() { } void WebRtcVoiceEngine::Init() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_LOG(LS_INFO) << "WebRtcVoiceEngine::Init"; // TaskQueue expects to be created/destroyed on the same thread. @@ -324,7 +372,7 @@ void WebRtcVoiceEngine::Init() { config.audio_device_module = adm_; if (audio_frame_processor_) config.async_audio_processing_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( *audio_frame_processor_, *task_queue_factory_); audio_state_ = webrtc::AudioState::Create(config); } @@ -362,7 +410,7 @@ void WebRtcVoiceEngine::Init() { rtc::scoped_refptr WebRtcVoiceEngine::GetAudioState() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return audio_state_; } @@ -371,13 +419,13 @@ VoiceMediaChannel* WebRtcVoiceEngine::CreateMediaChannel( const MediaConfig& config, const AudioOptions& options, const webrtc::CryptoOptions& crypto_options) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(call->worker_thread()); return new WebRtcVoiceMediaChannel(this, config, options, crypto_options, call); } bool WebRtcVoiceEngine::ApplyOptions(const AudioOptions& options_in) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_LOG(LS_INFO) << "WebRtcVoiceEngine::ApplyOptions: " << options_in.ToString(); AudioOptions options = options_in; // The options are modified below. 
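A note on the RTC_DCHECK(worker_thread_checker_.IsCurrent()) to RTC_DCHECK_RUN_ON(&worker_thread_checker_) conversions in the hunks above: together with the rtc::ThreadChecker to webrtc::SequenceChecker switch later in this file, RTC_DCHECK_RUN_ON keeps the runtime check but also tells Clang's thread-safety analysis that the checked sequence is "held", which is what lets members such as raw_audio_sink_ gain RTC_GUARDED_BY annotations further down in the patch. A minimal sketch of that pattern follows; the class, method, and member names are invented, not taken from the patch, and the headers assumed are the ones the patch itself includes.

    #include "api/sequence_checker.h"
    #include "rtc_base/thread_annotations.h"

    // Sketch only: illustrative names, not part of this change.
    class Example {
     public:
      void SetVolume(double volume) {
        // Runtime-asserts we are on the construction sequence and marks the
        // checker as "held" for the static thread-safety analysis.
        RTC_DCHECK_RUN_ON(&worker_thread_checker_);
        volume_ = volume;  // OK: provably accessed on the right sequence.
      }

     private:
      webrtc::SequenceChecker worker_thread_checker_;
      double volume_ RTC_GUARDED_BY(worker_thread_checker_) = 1.0;
    };

Accessing volume_ from a method that lacks the RTC_DCHECK_RUN_ON line would be flagged at compile time, not just at runtime, which is the point of the conversion.
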
@@ -620,22 +668,9 @@ WebRtcVoiceEngine::GetRtpHeaderExtensions() const { return result; } -void WebRtcVoiceEngine::RegisterChannel(WebRtcVoiceMediaChannel* channel) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(channel); - channels_.push_back(channel); -} - -void WebRtcVoiceEngine::UnregisterChannel(WebRtcVoiceMediaChannel* channel) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - auto it = absl::c_find(channels_, channel); - RTC_DCHECK(it != channels_.end()); - channels_.erase(it); -} - bool WebRtcVoiceEngine::StartAecDump(webrtc::FileWrapper file, int64_t max_size_bytes) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); webrtc::AudioProcessing* ap = apm(); if (!ap) { @@ -650,7 +685,7 @@ bool WebRtcVoiceEngine::StartAecDump(webrtc::FileWrapper file, } void WebRtcVoiceEngine::StopAecDump() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); webrtc::AudioProcessing* ap = apm(); if (ap) { ap->DetachAecDump(); @@ -661,18 +696,18 @@ void WebRtcVoiceEngine::StopAecDump() { } webrtc::AudioDeviceModule* WebRtcVoiceEngine::adm() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(adm_); return adm_.get(); } webrtc::AudioProcessing* WebRtcVoiceEngine::apm() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return apm_.get(); } webrtc::AudioState* WebRtcVoiceEngine::audio_state() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(audio_state_); return audio_state_.get(); } @@ -814,7 +849,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream WebRtcAudioSendStream& operator=(const WebRtcAudioSendStream&) = delete; ~WebRtcAudioSendStream() override { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); ClearSource(); call_->DestroyAudioSendStream(stream_); } @@ -826,7 +861,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream } void SetRtpExtensions(const std::vector& extensions) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); config_.rtp.extensions = extensions; rtp_parameters_.header_extensions = extensions; ReconfigureAudioSendStream(); @@ -838,7 +873,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream } void SetMid(const std::string& mid) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (config_.rtp.mid == mid) { return; } @@ -848,14 +883,14 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream void SetFrameEncryptor( rtc::scoped_refptr frame_encryptor) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); config_.frame_encryptor = frame_encryptor; ReconfigureAudioSendStream(); } void SetAudioNetworkAdaptorConfig( const absl::optional& audio_network_adaptor_config) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (audio_network_adaptor_config_from_options_ == audio_network_adaptor_config) { return; @@ -867,7 +902,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream } bool SetMaxSendBitrate(int bps) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(config_.send_codec_spec); RTC_DCHECK(audio_codec_spec_); auto send_rate = ComputeSendBitrate( @@ -890,32 +925,32 @@ class 
WebRtcVoiceMediaChannel::WebRtcAudioSendStream int payload_freq, int event, int duration_ms) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(stream_); return stream_->SendTelephoneEvent(payload_type, payload_freq, event, duration_ms); } void SetSend(bool send) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); send_ = send; UpdateSendState(); } void SetMuted(bool muted) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(stream_); stream_->SetMuted(muted); muted_ = muted; } bool muted() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return muted_; } webrtc::AudioSendStream::Stats GetStats(bool has_remote_tracks) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(stream_); return stream_->GetStats(has_remote_tracks); } @@ -925,7 +960,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream // This method is called on the libjingle worker thread. // TODO(xians): Make sure Start() is called only once. void SetSource(AudioSource* source) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(source); if (source_) { RTC_DCHECK(source_ == source); @@ -940,7 +975,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream // callback will be received after this method. // This method is called on the libjingle worker thread. void ClearSource() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (source_) { source_->SetSink(nullptr); source_ = nullptr; @@ -976,7 +1011,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream // Callback from the |source_| when it is going away. In case Start() has // never been called, this callback won't be triggered. void OnClose() override { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); // Set |source_| to nullptr to make sure no more callback will get into // the source. 
source_ = nullptr; @@ -1043,14 +1078,14 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream void SetEncoderToPacketizerFrameTransformer( rtc::scoped_refptr frame_transformer) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); config_.frame_transformer = std::move(frame_transformer); ReconfigureAudioSendStream(); } private: void UpdateSendState() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(stream_); RTC_DCHECK_EQ(1UL, rtp_parameters_.encodings.size()); if (send_ && source_ != nullptr && rtp_parameters_.encodings[0].active) { @@ -1061,7 +1096,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream } void UpdateAllowedBitrateRange() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); // The order of precedence, from lowest to highest is: // - a reasonable default of 32kbps min/max // - fixed target bitrate from codec spec @@ -1093,7 +1128,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream void UpdateSendCodecSpec( const webrtc::AudioSendStream::Config::SendCodecSpec& send_codec_spec) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); config_.send_codec_spec = send_codec_spec; auto info = config_.encoder_factory->QueryAudioEncoder(send_codec_spec.format); @@ -1136,7 +1171,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream } void ReconfigureAudioSendStream() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); RTC_DCHECK(stream_); stream_->Reconfigure(config_); } @@ -1144,7 +1179,7 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream int NumPreferredChannels() const override { return num_encoded_channels_; } const AdaptivePtimeConfig adaptive_ptime_config_; - rtc::ThreadChecker worker_thread_checker_; + webrtc::SequenceChecker worker_thread_checker_; rtc::RaceChecker audio_capture_race_checker_; webrtc::Call* call_ = nullptr; webrtc::AudioSendStream::Config config_; @@ -1164,53 +1199,16 @@ class WebRtcVoiceMediaChannel::WebRtcAudioSendStream // TODO(webrtc:11717): Remove this once audio_network_adaptor in AudioOptions // has been removed. absl::optional audio_network_adaptor_config_from_options_; - int num_encoded_channels_ = -1; + std::atomic num_encoded_channels_{-1}; }; class WebRtcVoiceMediaChannel::WebRtcAudioReceiveStream { public: - WebRtcAudioReceiveStream( - uint32_t remote_ssrc, - uint32_t local_ssrc, - bool use_transport_cc, - bool use_nack, - const std::vector& stream_ids, - const std::vector& extensions, - webrtc::Call* call, - webrtc::Transport* rtcp_send_transport, - const rtc::scoped_refptr& decoder_factory, - const std::map& decoder_map, - absl::optional codec_pair_id, - size_t jitter_buffer_max_packets, - bool jitter_buffer_fast_accelerate, - int jitter_buffer_min_delay_ms, - bool jitter_buffer_enable_rtx_handling, - rtc::scoped_refptr frame_decryptor, - const webrtc::CryptoOptions& crypto_options, - rtc::scoped_refptr frame_transformer) - : call_(call), config_() { + WebRtcAudioReceiveStream(webrtc::AudioReceiveStream::Config config, + webrtc::Call* call) + : call_(call), stream_(call_->CreateAudioReceiveStream(config)) { RTC_DCHECK(call); - config_.rtp.remote_ssrc = remote_ssrc; - config_.rtp.local_ssrc = local_ssrc; - config_.rtp.transport_cc = use_transport_cc; - config_.rtp.nack.rtp_history_ms = use_nack ? 
kNackRtpHistoryMs : 0; - config_.rtp.extensions = extensions; - config_.rtcp_send_transport = rtcp_send_transport; - config_.jitter_buffer_max_packets = jitter_buffer_max_packets; - config_.jitter_buffer_fast_accelerate = jitter_buffer_fast_accelerate; - config_.jitter_buffer_min_delay_ms = jitter_buffer_min_delay_ms; - config_.jitter_buffer_enable_rtx_handling = - jitter_buffer_enable_rtx_handling; - if (!stream_ids.empty()) { - config_.sync_group = stream_ids[0]; - } - config_.decoder_factory = decoder_factory; - config_.decoder_map = decoder_map; - config_.codec_pair_id = codec_pair_id; - config_.frame_decryptor = frame_decryptor; - config_.crypto_options = crypto_options; - config_.frame_transformer = std::move(frame_transformer); - RecreateAudioReceiveStream(); + RTC_DCHECK(stream_); } WebRtcAudioReceiveStream() = delete; @@ -1218,72 +1216,46 @@ class WebRtcVoiceMediaChannel::WebRtcAudioReceiveStream { WebRtcAudioReceiveStream& operator=(const WebRtcAudioReceiveStream&) = delete; ~WebRtcAudioReceiveStream() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); call_->DestroyAudioReceiveStream(stream_); } - void SetFrameDecryptor( - rtc::scoped_refptr frame_decryptor) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - config_.frame_decryptor = frame_decryptor; - RecreateAudioReceiveStream(); + webrtc::AudioReceiveStream& stream() { + RTC_DCHECK(stream_); + return *stream_; } - void SetLocalSsrc(uint32_t local_ssrc) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - if (local_ssrc != config_.rtp.local_ssrc) { - config_.rtp.local_ssrc = local_ssrc; - RecreateAudioReceiveStream(); - } + void SetFrameDecryptor( + rtc::scoped_refptr frame_decryptor) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + stream_->SetFrameDecryptor(std::move(frame_decryptor)); } - void SetUseTransportCcAndRecreateStream(bool use_transport_cc, - bool use_nack) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - config_.rtp.transport_cc = use_transport_cc; - config_.rtp.nack.rtp_history_ms = use_nack ? kNackRtpHistoryMs : 0; - ReconfigureAudioReceiveStream(); + void SetUseTransportCc(bool use_transport_cc, bool use_nack) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + stream_->SetUseTransportCcAndNackHistory(use_transport_cc, + use_nack ? kNackRtpHistoryMs : 0); } - void SetRtpExtensionsAndRecreateStream( - const std::vector& extensions) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - config_.rtp.extensions = extensions; - RecreateAudioReceiveStream(); + void SetRtpExtensions(const std::vector& extensions) { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + stream_->SetRtpExtensions(extensions); } // Set a new payload type -> decoder map. 
void SetDecoderMap(const std::map& decoder_map) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - config_.decoder_map = decoder_map; - ReconfigureAudioReceiveStream(); - } - - void MaybeRecreateAudioReceiveStream( - const std::vector& stream_ids) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - std::string sync_group; - if (!stream_ids.empty()) { - sync_group = stream_ids[0]; - } - if (config_.sync_group != sync_group) { - RTC_LOG(LS_INFO) << "Recreating AudioReceiveStream for SSRC=" - << config_.rtp.remote_ssrc - << " because of sync group change."; - config_.sync_group = sync_group; - RecreateAudioReceiveStream(); - } + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + stream_->SetDecoderMap(decoder_map); } webrtc::AudioReceiveStream::Stats GetStats( bool get_and_clear_legacy_stats) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(stream_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return stream_->GetStats(get_and_clear_legacy_stats); } void SetRawAudioSink(std::unique_ptr sink) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); // Need to update the stream's sink first; once raw_audio_sink_ is // reassigned, whatever was in there before is destroyed. stream_->SetSink(sink.get()); @@ -1291,95 +1263,62 @@ class WebRtcVoiceMediaChannel::WebRtcAudioReceiveStream { } void SetOutputVolume(double volume) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - output_volume_ = volume; + RTC_DCHECK_RUN_ON(&worker_thread_checker_); stream_->SetGain(volume); } void SetPlayout(bool playout) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(stream_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); if (playout) { stream_->Start(); } else { stream_->Stop(); } - playout_ = playout; } bool SetBaseMinimumPlayoutDelayMs(int delay_ms) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(stream_); - if (stream_->SetBaseMinimumPlayoutDelayMs(delay_ms)) { - // Memorize only valid delay because during stream recreation it will be - // passed to the constructor and it must be valid value. 
- config_.jitter_buffer_min_delay_ms = delay_ms; + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + if (stream_->SetBaseMinimumPlayoutDelayMs(delay_ms)) return true; - } else { - RTC_LOG(LS_ERROR) << "Failed to SetBaseMinimumPlayoutDelayMs" - " on AudioReceiveStream on SSRC=" - << config_.rtp.remote_ssrc - << " with delay_ms=" << delay_ms; - return false; - } + + RTC_LOG(LS_ERROR) << "Failed to SetBaseMinimumPlayoutDelayMs" + " on AudioReceiveStream on SSRC=" + << stream_->rtp_config().remote_ssrc + << " with delay_ms=" << delay_ms; + return false; } int GetBaseMinimumPlayoutDelayMs() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(stream_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return stream_->GetBaseMinimumPlayoutDelayMs(); } std::vector GetSources() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(stream_); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); return stream_->GetSources(); } webrtc::RtpParameters GetRtpParameters() const { webrtc::RtpParameters rtp_parameters; rtp_parameters.encodings.emplace_back(); - rtp_parameters.encodings[0].ssrc = config_.rtp.remote_ssrc; - rtp_parameters.header_extensions = config_.rtp.extensions; - + const auto& config = stream_->rtp_config(); + rtp_parameters.encodings[0].ssrc = config.remote_ssrc; + rtp_parameters.header_extensions = config.extensions; return rtp_parameters; } void SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - config_.frame_transformer = std::move(frame_transformer); - ReconfigureAudioReceiveStream(); + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + stream_->SetDepacketizerToDecoderFrameTransformer(frame_transformer); } private: - void RecreateAudioReceiveStream() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - if (stream_) { - call_->DestroyAudioReceiveStream(stream_); - } - stream_ = call_->CreateAudioReceiveStream(config_); - RTC_CHECK(stream_); - stream_->SetGain(output_volume_); - SetPlayout(playout_); - stream_->SetSink(raw_audio_sink_.get()); - } - - void ReconfigureAudioReceiveStream() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - RTC_DCHECK(stream_); - stream_->Reconfigure(config_); - } - - rtc::ThreadChecker worker_thread_checker_; + webrtc::SequenceChecker worker_thread_checker_; webrtc::Call* call_ = nullptr; - webrtc::AudioReceiveStream::Config config_; - // The stream is owned by WebRtcAudioReceiveStream and may be reallocated if - // configuration changes. 
- webrtc::AudioReceiveStream* stream_ = nullptr; - bool playout_ = false; - float output_volume_ = 1.0; - std::unique_ptr raw_audio_sink_; + webrtc::AudioReceiveStream* const stream_ = nullptr; + std::unique_ptr raw_audio_sink_ + RTC_GUARDED_BY(worker_thread_checker_); }; WebRtcVoiceMediaChannel::WebRtcVoiceMediaChannel( @@ -1388,21 +1327,22 @@ WebRtcVoiceMediaChannel::WebRtcVoiceMediaChannel( const AudioOptions& options, const webrtc::CryptoOptions& crypto_options, webrtc::Call* call) - : VoiceMediaChannel(config), + : VoiceMediaChannel(config, call->network_thread()), + worker_thread_(call->worker_thread()), engine_(engine), call_(call), audio_config_(config.audio), crypto_options_(crypto_options), audio_red_for_opus_trial_enabled_( IsEnabled(call->trials(), "WebRTC-Audio-Red-For-Opus")) { + network_thread_checker_.Detach(); RTC_LOG(LS_VERBOSE) << "WebRtcVoiceMediaChannel::WebRtcVoiceMediaChannel"; RTC_DCHECK(call); - engine->RegisterChannel(this); SetOptions(options); } WebRtcVoiceMediaChannel::~WebRtcVoiceMediaChannel() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_VERBOSE) << "WebRtcVoiceMediaChannel::~WebRtcVoiceMediaChannel"; // TODO(solenberg): Should be able to delete the streams directly, without // going through RemoveNnStream(), once stream objects handle @@ -1413,13 +1353,12 @@ WebRtcVoiceMediaChannel::~WebRtcVoiceMediaChannel() { while (!recv_streams_.empty()) { RemoveRecvStream(recv_streams_.begin()->first); } - engine()->UnregisterChannel(this); } bool WebRtcVoiceMediaChannel::SetSendParameters( const AudioSendParameters& params) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::SetSendParameters"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "WebRtcVoiceMediaChannel::SetSendParameters: " << params.ToString(); // TODO(pthatcher): Refactor this to be more clean now that we have @@ -1465,7 +1404,7 @@ bool WebRtcVoiceMediaChannel::SetSendParameters( bool WebRtcVoiceMediaChannel::SetRecvParameters( const AudioRecvParameters& params) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::SetRecvParameters"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "WebRtcVoiceMediaChannel::SetRecvParameters: " << params.ToString(); // TODO(pthatcher): Refactor this to be more clean now that we have @@ -1484,7 +1423,7 @@ bool WebRtcVoiceMediaChannel::SetRecvParameters( if (recv_rtp_extensions_ != filtered_extensions) { recv_rtp_extensions_.swap(filtered_extensions); for (auto& it : recv_streams_) { - it.second->SetRtpExtensionsAndRecreateStream(recv_rtp_extensions_); + it.second->SetRtpExtensions(recv_rtp_extensions_); } } return true; @@ -1492,7 +1431,7 @@ bool WebRtcVoiceMediaChannel::SetRecvParameters( webrtc::RtpParameters WebRtcVoiceMediaChannel::GetRtpSendParameters( uint32_t ssrc) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto it = send_streams_.find(ssrc); if (it == send_streams_.end()) { RTC_LOG(LS_WARNING) << "Attempting to get RTP send parameters for stream " @@ -1513,7 +1452,7 @@ webrtc::RtpParameters WebRtcVoiceMediaChannel::GetRtpSendParameters( webrtc::RTCError WebRtcVoiceMediaChannel::SetRtpSendParameters( uint32_t ssrc, const webrtc::RtpParameters& parameters) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto it = send_streams_.find(ssrc); if (it == send_streams_.end()) { RTC_LOG(LS_WARNING) << "Attempting 
to set RTP send parameters for stream " @@ -1568,7 +1507,7 @@ webrtc::RTCError WebRtcVoiceMediaChannel::SetRtpSendParameters( webrtc::RtpParameters WebRtcVoiceMediaChannel::GetRtpReceiveParameters( uint32_t ssrc) const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); webrtc::RtpParameters rtp_params; auto it = recv_streams_.find(ssrc); if (it == recv_streams_.end()) { @@ -1588,7 +1527,7 @@ webrtc::RtpParameters WebRtcVoiceMediaChannel::GetRtpReceiveParameters( webrtc::RtpParameters WebRtcVoiceMediaChannel::GetDefaultRtpReceiveParameters() const { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); webrtc::RtpParameters rtp_params; if (!default_sink_) { RTC_LOG(LS_WARNING) << "Attempting to get RTP parameters for the default, " @@ -1605,7 +1544,7 @@ webrtc::RtpParameters WebRtcVoiceMediaChannel::GetDefaultRtpReceiveParameters() } bool WebRtcVoiceMediaChannel::SetOptions(const AudioOptions& options) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "Setting voice channel options: " << options.ToString(); // We retain all of the existing options, and apply the given ones @@ -1631,7 +1570,7 @@ bool WebRtcVoiceMediaChannel::SetOptions(const AudioOptions& options) { bool WebRtcVoiceMediaChannel::SetRecvCodecs( const std::vector& codecs) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); // Set the payload types to be used for incoming media. RTC_LOG(LS_INFO) << "Setting receive voice codecs."; @@ -1689,21 +1628,22 @@ bool WebRtcVoiceMediaChannel::SetRecvCodecs( return true; } - if (playout_) { - // Receive codecs can not be changed while playing. So we temporarily - // pause playout. - ChangePlayout(false); - } + bool playout_enabled = playout_; + // Receive codecs can not be changed while playing. So we temporarily + // pause playout. + SetPlayout(false); + RTC_DCHECK(!playout_); decoder_map_ = std::move(decoder_map); for (auto& kv : recv_streams_) { kv.second->SetDecoderMap(decoder_map_); } + recv_codecs_ = codecs; - if (desired_playout_ && !playout_) { - ChangePlayout(desired_playout_); - } + SetPlayout(playout_enabled); + RTC_DCHECK_EQ(playout_, playout_enabled); + return true; } @@ -1712,7 +1652,7 @@ bool WebRtcVoiceMediaChannel::SetRecvCodecs( // and receive streams may be reconfigured based on the new settings. 
bool WebRtcVoiceMediaChannel::SetSendCodecs( const std::vector& codecs) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); dtmf_payload_type_ = absl::nullopt; dtmf_payload_freq_ = -1; @@ -1848,8 +1788,8 @@ bool WebRtcVoiceMediaChannel::SetSendCodecs( recv_transport_cc_enabled_ = send_codec_spec_->transport_cc_enabled; recv_nack_enabled_ = send_codec_spec_->nack_enabled; for (auto& kv : recv_streams_) { - kv.second->SetUseTransportCcAndRecreateStream(recv_transport_cc_enabled_, - recv_nack_enabled_); + kv.second->SetUseTransportCc(recv_transport_cc_enabled_, + recv_nack_enabled_); } } @@ -1858,13 +1798,8 @@ bool WebRtcVoiceMediaChannel::SetSendCodecs( } void WebRtcVoiceMediaChannel::SetPlayout(bool playout) { - desired_playout_ = playout; - return ChangePlayout(desired_playout_); -} - -void WebRtcVoiceMediaChannel::ChangePlayout(bool playout) { - TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::ChangePlayout"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::SetPlayout"); + RTC_DCHECK_RUN_ON(worker_thread_); if (playout_ == playout) { return; } @@ -1907,7 +1842,7 @@ bool WebRtcVoiceMediaChannel::SetAudioSend(uint32_t ssrc, bool enable, const AudioOptions* options, AudioSource* source) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); // TODO(solenberg): The state change should be fully rolled back if any one of // these calls fail. if (!SetLocalSource(ssrc, source)) { @@ -1924,7 +1859,7 @@ bool WebRtcVoiceMediaChannel::SetAudioSend(uint32_t ssrc, bool WebRtcVoiceMediaChannel::AddSendStream(const StreamParams& sp) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::AddSendStream"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "AddSendStream: " << sp.ToString(); uint32_t ssrc = sp.first_ssrc(); @@ -1950,10 +1885,8 @@ bool WebRtcVoiceMediaChannel::AddSendStream(const StreamParams& sp) { // same SSRC in order to send receiver reports. if (send_streams_.size() == 1) { receiver_reports_ssrc_ = ssrc; - for (const auto& kv : recv_streams_) { - // TODO(solenberg): Allow applications to set the RTCP SSRC of receive - // streams instead, so we can avoid reconfiguring the streams here. - kv.second->SetLocalSsrc(ssrc); + for (auto& kv : recv_streams_) { + call_->OnLocalSsrcUpdated(kv.second->stream(), ssrc); } } @@ -1963,7 +1896,7 @@ bool WebRtcVoiceMediaChannel::AddSendStream(const StreamParams& sp) { bool WebRtcVoiceMediaChannel::RemoveSendStream(uint32_t ssrc) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::RemoveSendStream"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "RemoveSendStream: " << ssrc; auto it = send_streams_.find(ssrc); @@ -1989,7 +1922,7 @@ bool WebRtcVoiceMediaChannel::RemoveSendStream(uint32_t ssrc) { bool WebRtcVoiceMediaChannel::AddRecvStream(const StreamParams& sp) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::AddRecvStream"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "AddRecvStream: " << sp.ToString(); if (!sp.has_ssrcs()) { @@ -2006,9 +1939,12 @@ bool WebRtcVoiceMediaChannel::AddRecvStream(const StreamParams& sp) { const uint32_t ssrc = sp.first_ssrc(); // If this stream was previously received unsignaled, we promote it, possibly - // recreating the AudioReceiveStream, if stream ids have changed. + // updating the sync group if stream ids have changed. 
if (MaybeDeregisterUnsignaledRecvStream(ssrc)) { - recv_streams_[ssrc]->MaybeRecreateAudioReceiveStream(sp.stream_ids()); + auto stream_ids = sp.stream_ids(); + std::string sync_group = stream_ids.empty() ? std::string() : stream_ids[0]; + call_->OnUpdateSyncGroup(recv_streams_[ssrc]->stream(), + std::move(sync_group)); return true; } @@ -2018,16 +1954,18 @@ bool WebRtcVoiceMediaChannel::AddRecvStream(const StreamParams& sp) { } // Create a new channel for receiving audio data. + auto config = BuildReceiveStreamConfig( + ssrc, receiver_reports_ssrc_, recv_transport_cc_enabled_, + recv_nack_enabled_, sp.stream_ids(), recv_rtp_extensions_, this, + engine()->decoder_factory_, decoder_map_, codec_pair_id_, + engine()->audio_jitter_buffer_max_packets_, + engine()->audio_jitter_buffer_fast_accelerate_, + engine()->audio_jitter_buffer_min_delay_ms_, + engine()->audio_jitter_buffer_enable_rtx_handling_, + unsignaled_frame_decryptor_, crypto_options_, nullptr); + recv_streams_.insert(std::make_pair( - ssrc, new WebRtcAudioReceiveStream( - ssrc, receiver_reports_ssrc_, recv_transport_cc_enabled_, - recv_nack_enabled_, sp.stream_ids(), recv_rtp_extensions_, - call_, this, engine()->decoder_factory_, decoder_map_, - codec_pair_id_, engine()->audio_jitter_buffer_max_packets_, - engine()->audio_jitter_buffer_fast_accelerate_, - engine()->audio_jitter_buffer_min_delay_ms_, - engine()->audio_jitter_buffer_enable_rtx_handling_, - unsignaled_frame_decryptor_, crypto_options_, nullptr))); + ssrc, new WebRtcAudioReceiveStream(std::move(config), call_))); recv_streams_[ssrc]->SetPlayout(playout_); return true; @@ -2035,7 +1973,7 @@ bool WebRtcVoiceMediaChannel::AddRecvStream(const StreamParams& sp) { bool WebRtcVoiceMediaChannel::RemoveRecvStream(uint32_t ssrc) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::RemoveRecvStream"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "RemoveRecvStream: " << ssrc; const auto it = recv_streams_.find(ssrc); @@ -2054,7 +1992,7 @@ bool WebRtcVoiceMediaChannel::RemoveRecvStream(uint32_t ssrc) { } void WebRtcVoiceMediaChannel::ResetUnsignaledRecvStream() { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "ResetUnsignaledRecvStream."; unsignaled_stream_params_ = StreamParams(); // Create a copy since RemoveRecvStream will modify |unsignaled_recv_ssrcs_|. @@ -2064,6 +2002,13 @@ void WebRtcVoiceMediaChannel::ResetUnsignaledRecvStream() { } } +// Not implemented. +// TODO(https://crbug.com/webrtc/12676): Implement a fix for the unsignalled +// SSRC race that can happen when an m= section goes from receiving to not +// receiving. 
+void WebRtcVoiceMediaChannel::OnDemuxerCriteriaUpdatePending() {} +void WebRtcVoiceMediaChannel::OnDemuxerCriteriaUpdateComplete() {} + bool WebRtcVoiceMediaChannel::SetLocalSource(uint32_t ssrc, AudioSource* source) { auto it = send_streams_.find(ssrc); @@ -2088,7 +2033,7 @@ bool WebRtcVoiceMediaChannel::SetLocalSource(uint32_t ssrc, } bool WebRtcVoiceMediaChannel::SetOutputVolume(uint32_t ssrc, double volume) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << rtc::StringFormat("WRVMC::%s({ssrc=%u}, {volume=%.2f})", __func__, ssrc, volume); const auto it = recv_streams_.find(ssrc); @@ -2106,7 +2051,7 @@ bool WebRtcVoiceMediaChannel::SetOutputVolume(uint32_t ssrc, double volume) { } bool WebRtcVoiceMediaChannel::SetDefaultOutputVolume(double volume) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); default_recv_volume_ = volume; for (uint32_t ssrc : unsignaled_recv_ssrcs_) { const auto it = recv_streams_.find(ssrc); @@ -2123,7 +2068,7 @@ bool WebRtcVoiceMediaChannel::SetDefaultOutputVolume(double volume) { bool WebRtcVoiceMediaChannel::SetBaseMinimumPlayoutDelayMs(uint32_t ssrc, int delay_ms) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); std::vector ssrcs(1, ssrc); // SSRC of 0 represents the default receive stream. if (ssrc == 0) { @@ -2166,7 +2111,7 @@ bool WebRtcVoiceMediaChannel::CanInsertDtmf() { void WebRtcVoiceMediaChannel::SetFrameDecryptor( uint32_t ssrc, rtc::scoped_refptr frame_decryptor) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto matching_stream = recv_streams_.find(ssrc); if (matching_stream != recv_streams_.end()) { matching_stream->second->SetFrameDecryptor(frame_decryptor); @@ -2180,7 +2125,7 @@ void WebRtcVoiceMediaChannel::SetFrameDecryptor( void WebRtcVoiceMediaChannel::SetFrameEncryptor( uint32_t ssrc, rtc::scoped_refptr frame_encryptor) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto matching_stream = send_streams_.find(ssrc); if (matching_stream != send_streams_.end()) { matching_stream->second->SetFrameEncryptor(frame_encryptor); @@ -2190,7 +2135,7 @@ void WebRtcVoiceMediaChannel::SetFrameEncryptor( bool WebRtcVoiceMediaChannel::InsertDtmf(uint32_t ssrc, int event, int duration) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_INFO) << "WebRtcVoiceMediaChannel::InsertDtmf"; if (!CanInsertDtmf()) { return false; @@ -2213,78 +2158,104 @@ bool WebRtcVoiceMediaChannel::InsertDtmf(uint32_t ssrc, void WebRtcVoiceMediaChannel::OnPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - - webrtc::PacketReceiver::DeliveryStatus delivery_result = - call_->Receiver()->DeliverPacket(webrtc::MediaType::AUDIO, packet, - packet_time_us); - - if (delivery_result != webrtc::PacketReceiver::DELIVERY_UNKNOWN_SSRC) { - return; - } - - // Create an unsignaled receive stream for this previously not received ssrc. - // If there already is N unsignaled receive streams, delete the oldest. 
- // See: https://bugs.chromium.org/p/webrtc/issues/detail?id=5208 - uint32_t ssrc = 0; - if (!GetRtpSsrc(packet.cdata(), packet.size(), &ssrc)) { - return; - } - RTC_DCHECK(!absl::c_linear_search(unsignaled_recv_ssrcs_, ssrc)); + RTC_DCHECK_RUN_ON(&network_thread_checker_); + // TODO(bugs.webrtc.org/11993): This code is very similar to what + // WebRtcVideoChannel::OnPacketReceived does. For maintainability and + // consistency it would be good to move the interaction with call_->Receiver() + // to a common implementation and provide a callback on the worker thread + // for the exception case (DELIVERY_UNKNOWN_SSRC) and how retry is attempted. + worker_thread_->PostTask(ToQueuedTask(task_safety_, [this, packet, + packet_time_us] { + RTC_DCHECK_RUN_ON(worker_thread_); + + webrtc::PacketReceiver::DeliveryStatus delivery_result = + call_->Receiver()->DeliverPacket(webrtc::MediaType::AUDIO, packet, + packet_time_us); + + if (delivery_result != webrtc::PacketReceiver::DELIVERY_UNKNOWN_SSRC) { + return; + } - // Add new stream. - StreamParams sp = unsignaled_stream_params_; - sp.ssrcs.push_back(ssrc); - RTC_LOG(LS_INFO) << "Creating unsignaled receive stream for SSRC=" << ssrc; - if (!AddRecvStream(sp)) { - RTC_LOG(LS_WARNING) << "Could not create unsignaled receive stream."; - return; - } - unsignaled_recv_ssrcs_.push_back(ssrc); - RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.NumOfUnsignaledStreams", - unsignaled_recv_ssrcs_.size(), 1, 100, 101); - - // Remove oldest unsignaled stream, if we have too many. - if (unsignaled_recv_ssrcs_.size() > kMaxUnsignaledRecvStreams) { - uint32_t remove_ssrc = unsignaled_recv_ssrcs_.front(); - RTC_DLOG(LS_INFO) << "Removing unsignaled receive stream with SSRC=" - << remove_ssrc; - RemoveRecvStream(remove_ssrc); - } - RTC_DCHECK_GE(kMaxUnsignaledRecvStreams, unsignaled_recv_ssrcs_.size()); - - SetOutputVolume(ssrc, default_recv_volume_); - SetBaseMinimumPlayoutDelayMs(ssrc, default_recv_base_minimum_delay_ms_); - - // The default sink can only be attached to one stream at a time, so we hook - // it up to the *latest* unsignaled stream we've seen, in order to support the - // case where the SSRC of one unsignaled stream changes. - if (default_sink_) { - for (uint32_t drop_ssrc : unsignaled_recv_ssrcs_) { - auto it = recv_streams_.find(drop_ssrc); - it->second->SetRawAudioSink(nullptr); + // Create an unsignaled receive stream for this previously not received + // ssrc. If there already is N unsignaled receive streams, delete the + // oldest. See: https://bugs.chromium.org/p/webrtc/issues/detail?id=5208 + uint32_t ssrc = 0; + if (!GetRtpSsrc(packet.cdata(), packet.size(), &ssrc)) { + return; + } + RTC_DCHECK(!absl::c_linear_search(unsignaled_recv_ssrcs_, ssrc)); + + // Add new stream. + StreamParams sp = unsignaled_stream_params_; + sp.ssrcs.push_back(ssrc); + RTC_LOG(LS_INFO) << "Creating unsignaled receive stream for SSRC=" << ssrc; + if (!AddRecvStream(sp)) { + RTC_LOG(LS_WARNING) << "Could not create unsignaled receive stream."; + return; + } + unsignaled_recv_ssrcs_.push_back(ssrc); + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.NumOfUnsignaledStreams", + unsignaled_recv_ssrcs_.size(), 1, 100, 101); + + // Remove oldest unsignaled stream, if we have too many. 
+ if (unsignaled_recv_ssrcs_.size() > kMaxUnsignaledRecvStreams) { + uint32_t remove_ssrc = unsignaled_recv_ssrcs_.front(); + RTC_DLOG(LS_INFO) << "Removing unsignaled receive stream with SSRC=" + << remove_ssrc; + RemoveRecvStream(remove_ssrc); + } + RTC_DCHECK_GE(kMaxUnsignaledRecvStreams, unsignaled_recv_ssrcs_.size()); + + SetOutputVolume(ssrc, default_recv_volume_); + SetBaseMinimumPlayoutDelayMs(ssrc, default_recv_base_minimum_delay_ms_); + + // The default sink can only be attached to one stream at a time, so we hook + // it up to the *latest* unsignaled stream we've seen, in order to support + // the case where the SSRC of one unsignaled stream changes. + if (default_sink_) { + for (uint32_t drop_ssrc : unsignaled_recv_ssrcs_) { + auto it = recv_streams_.find(drop_ssrc); + it->second->SetRawAudioSink(nullptr); + } + std::unique_ptr proxy_sink( + new ProxySink(default_sink_.get())); + SetRawAudioSink(ssrc, std::move(proxy_sink)); } - std::unique_ptr proxy_sink( - new ProxySink(default_sink_.get())); - SetRawAudioSink(ssrc, std::move(proxy_sink)); - } - delivery_result = call_->Receiver()->DeliverPacket(webrtc::MediaType::AUDIO, - packet, packet_time_us); - RTC_DCHECK_NE(webrtc::PacketReceiver::DELIVERY_UNKNOWN_SSRC, delivery_result); + delivery_result = call_->Receiver()->DeliverPacket(webrtc::MediaType::AUDIO, + packet, packet_time_us); + RTC_DCHECK_NE(webrtc::PacketReceiver::DELIVERY_UNKNOWN_SSRC, + delivery_result); + })); +} + +void WebRtcVoiceMediaChannel::OnPacketSent(const rtc::SentPacket& sent_packet) { + RTC_DCHECK_RUN_ON(&network_thread_checker_); + // TODO(tommi): We shouldn't need to go through call_ to deliver this + // notification. We should already have direct access to + // video_send_delay_stats_ and transport_send_ptr_ via `stream_`. + // So we should be able to remove OnSentPacket from Call and handle this per + // channel instead. At the moment Call::OnSentPacket calls OnSentPacket for + // the video stats, which we should be able to skip. + call_->OnSentPacket(sent_packet); } void WebRtcVoiceMediaChannel::OnNetworkRouteChanged( const std::string& transport_name, const rtc::NetworkRoute& network_route) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); - call_->GetTransportControllerSend()->OnNetworkRouteChanged(transport_name, - network_route); + RTC_DCHECK_RUN_ON(&network_thread_checker_); + call_->OnAudioTransportOverheadChanged(network_route.packet_overhead); + + worker_thread_->PostTask(ToQueuedTask( + task_safety_, [this, name = transport_name, route = network_route] { + RTC_DCHECK_RUN_ON(worker_thread_); + call_->GetTransportControllerSend()->OnNetworkRouteChanged(name, route); + })); } bool WebRtcVoiceMediaChannel::MuteStream(uint32_t ssrc, bool muted) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); const auto it = send_streams_.find(ssrc); if (it == send_streams_.end()) { RTC_LOG(LS_WARNING) << "The specified ssrc " << ssrc << " is not in use."; @@ -2322,7 +2293,7 @@ bool WebRtcVoiceMediaChannel::SetMaxSendBitrate(int bps) { } void WebRtcVoiceMediaChannel::OnReadyToSend(bool ready) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&network_thread_checker_); RTC_LOG(LS_VERBOSE) << "OnReadyToSend: " << (ready ? "Ready." 
: "Not ready."); call_->SignalChannelNetworkState( webrtc::MediaType::AUDIO, @@ -2332,7 +2303,7 @@ void WebRtcVoiceMediaChannel::OnReadyToSend(bool ready) { bool WebRtcVoiceMediaChannel::GetStats(VoiceMediaInfo* info, bool get_and_clear_legacy_stats) { TRACE_EVENT0("webrtc", "WebRtcVoiceMediaChannel::GetStats"); - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_DCHECK(info); // Get SSRC and stats for each sender. @@ -2349,6 +2320,7 @@ bool WebRtcVoiceMediaChannel::GetStats(VoiceMediaInfo* info, sinfo.retransmitted_packets_sent = stats.retransmitted_packets_sent; sinfo.packets_lost = stats.packets_lost; sinfo.fraction_lost = stats.fraction_lost; + sinfo.nacks_rcvd = stats.nacks_rcvd; sinfo.codec_name = stats.codec_name; sinfo.codec_payload_type = stats.codec_payload_type; sinfo.jitter_ms = stats.jitter_ms; @@ -2440,6 +2412,17 @@ bool WebRtcVoiceMediaChannel::GetStats(VoiceMediaInfo* info, stats.relative_packet_arrival_delay_seconds; rinfo.interruption_count = stats.interruption_count; rinfo.total_interruption_duration_ms = stats.total_interruption_duration_ms; + rinfo.last_sender_report_timestamp_ms = + stats.last_sender_report_timestamp_ms; + rinfo.last_sender_report_remote_timestamp_ms = + stats.last_sender_report_remote_timestamp_ms; + rinfo.sender_reports_packets_sent = stats.sender_reports_packets_sent; + rinfo.sender_reports_bytes_sent = stats.sender_reports_bytes_sent; + rinfo.sender_reports_reports_count = stats.sender_reports_reports_count; + + if (recv_nack_enabled_) { + rinfo.nacks_sent = stats.nacks_sent; + } info->receivers.push_back(rinfo); } @@ -2463,7 +2446,7 @@ bool WebRtcVoiceMediaChannel::GetStats(VoiceMediaInfo* info, void WebRtcVoiceMediaChannel::SetRawAudioSink( uint32_t ssrc, std::unique_ptr sink) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_VERBOSE) << "WebRtcVoiceMediaChannel::SetRawAudioSink: ssrc:" << ssrc << " " << (sink ? 
"(ptr)" : "NULL"); const auto it = recv_streams_.find(ssrc); @@ -2476,7 +2459,7 @@ void WebRtcVoiceMediaChannel::SetRawAudioSink( void WebRtcVoiceMediaChannel::SetDefaultRawAudioSink( std::unique_ptr sink) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_LOG(LS_VERBOSE) << "WebRtcVoiceMediaChannel::SetDefaultRawAudioSink:"; if (!unsignaled_recv_ssrcs_.empty()) { std::unique_ptr proxy_sink( @@ -2500,7 +2483,7 @@ std::vector WebRtcVoiceMediaChannel::GetSources( void WebRtcVoiceMediaChannel::SetEncoderToPacketizerFrameTransformer( uint32_t ssrc, rtc::scoped_refptr frame_transformer) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto matching_stream = send_streams_.find(ssrc); if (matching_stream == send_streams_.end()) { RTC_LOG(LS_INFO) << "Attempting to set frame transformer for SSRC:" << ssrc @@ -2514,7 +2497,7 @@ void WebRtcVoiceMediaChannel::SetEncoderToPacketizerFrameTransformer( void WebRtcVoiceMediaChannel::SetDepacketizerToDecoderFrameTransformer( uint32_t ssrc, rtc::scoped_refptr frame_transformer) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto matching_stream = recv_streams_.find(ssrc); if (matching_stream == recv_streams_.end()) { RTC_LOG(LS_INFO) << "Attempting to set frame transformer for SSRC:" << ssrc @@ -2525,9 +2508,21 @@ void WebRtcVoiceMediaChannel::SetDepacketizerToDecoderFrameTransformer( std::move(frame_transformer)); } +bool WebRtcVoiceMediaChannel::SendRtp(const uint8_t* data, + size_t len, + const webrtc::PacketOptions& options) { + MediaChannel::SendRtp(data, len, options); + return true; +} + +bool WebRtcVoiceMediaChannel::SendRtcp(const uint8_t* data, size_t len) { + MediaChannel::SendRtcp(data, len); + return true; +} + bool WebRtcVoiceMediaChannel::MaybeDeregisterUnsignaledRecvStream( uint32_t ssrc) { - RTC_DCHECK(worker_thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread_); auto it = absl::c_find(unsignaled_recv_ssrcs_, ssrc); if (it != unsignaled_recv_ssrcs_.end()) { unsignaled_recv_ssrcs_.erase(it); diff --git a/media/engine/webrtc_voice_engine.h b/media/engine/webrtc_voice_engine.h index c2da3b9df0..147688b0e0 100644 --- a/media/engine/webrtc_voice_engine.h +++ b/media/engine/webrtc_voice_engine.h @@ -18,6 +18,7 @@ #include "api/audio_codecs/audio_encoder_factory.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/transport/rtp/rtp_source.h" #include "api/transport/webrtc_key_value_config.h" @@ -29,7 +30,7 @@ #include "rtc_base/buffer.h" #include "rtc_base/network_route.h" #include "rtc_base/task_queue.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" namespace webrtc { class AudioFrameProcessor; @@ -37,8 +38,6 @@ class AudioFrameProcessor; namespace cricket { -class AudioDeviceModule; -class AudioMixer; class AudioSource; class WebRtcVoiceMediaChannel; @@ -79,12 +78,6 @@ class WebRtcVoiceEngine final : public VoiceEngineInterface { std::vector GetRtpHeaderExtensions() const override; - // For tracking WebRtc channels. Needed because we have to pause them - // all when switching devices. - // May only be called by WebRtcVoiceMediaChannel. - void RegisterChannel(WebRtcVoiceMediaChannel* channel); - void UnregisterChannel(WebRtcVoiceMediaChannel* channel); - // Starts AEC dump using an existing file. A maximum file size in bytes can be // specified. 
When the maximum file size is reached, logging is stopped and // the file is closed. If max_size_bytes is set to <= 0, no limit will be @@ -112,8 +105,8 @@ class WebRtcVoiceEngine final : public VoiceEngineInterface { std::vector CollectCodecs( const std::vector& specs) const; - rtc::ThreadChecker signal_thread_checker_; - rtc::ThreadChecker worker_thread_checker_; + webrtc::SequenceChecker signal_thread_checker_; + webrtc::SequenceChecker worker_thread_checker_; // The audio device module. rtc::scoped_refptr adm_; @@ -128,7 +121,6 @@ class WebRtcVoiceEngine final : public VoiceEngineInterface { rtc::scoped_refptr audio_state_; std::vector send_codecs_; std::vector recv_codecs_; - std::vector channels_; bool is_dumping_aec_ = false; bool initialized_ = false; @@ -186,6 +178,8 @@ class WebRtcVoiceMediaChannel final : public VoiceMediaChannel, bool AddRecvStream(const StreamParams& sp) override; bool RemoveRecvStream(uint32_t ssrc) override; void ResetUnsignaledRecvStream() override; + void OnDemuxerCriteriaUpdatePending() override; + void OnDemuxerCriteriaUpdateComplete() override; // E2EE Frame API // Set a frame decryptor to a particular ssrc that will intercept all @@ -214,6 +208,7 @@ class WebRtcVoiceMediaChannel final : public VoiceMediaChannel, void OnPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) override; + void OnPacketSent(const rtc::SentPacket& sent_packet) override; void OnNetworkRouteChanged(const std::string& transport_name, const rtc::NetworkRoute& network_route) override; void OnReadyToSend(bool ready) override; @@ -244,29 +239,9 @@ class WebRtcVoiceMediaChannel final : public VoiceMediaChannel, // implements Transport interface bool SendRtp(const uint8_t* data, size_t len, - const webrtc::PacketOptions& options) override { - rtc::CopyOnWriteBuffer packet(data, len, kMaxRtpPacketLen); - rtc::PacketOptions rtc_options; - rtc_options.packet_id = options.packet_id; - if (DscpEnabled()) { - rtc_options.dscp = PreferredDscp(); - } - rtc_options.info_signaled_after_sent.included_in_feedback = - options.included_in_feedback; - rtc_options.info_signaled_after_sent.included_in_allocation = - options.included_in_allocation; - return VoiceMediaChannel::SendPacket(&packet, rtc_options); - } - - bool SendRtcp(const uint8_t* data, size_t len) override { - rtc::CopyOnWriteBuffer packet(data, len, kMaxRtpPacketLen); - rtc::PacketOptions rtc_options; - if (DscpEnabled()) { - rtc_options.dscp = PreferredDscp(); - } - - return VoiceMediaChannel::SendRtcp(&packet, rtc_options); - } + const webrtc::PacketOptions& options) override; + + bool SendRtcp(const uint8_t* data, size_t len) override; private: bool SetOptions(const AudioOptions& options); @@ -276,7 +251,6 @@ class WebRtcVoiceMediaChannel final : public VoiceMediaChannel, bool MuteStream(uint32_t ssrc, bool mute); WebRtcVoiceEngine* engine() { return engine_; } - void ChangePlayout(bool playout); int CreateVoEChannel(); bool DeleteVoEChannel(int channel); bool SetMaxSendBitrate(int bps); @@ -285,7 +259,9 @@ class WebRtcVoiceMediaChannel final : public VoiceMediaChannel, // unsignaled anymore (i.e. it is now removed, or signaled), and return true. 
bool MaybeDeregisterUnsignaledRecvStream(uint32_t ssrc); - rtc::ThreadChecker worker_thread_checker_; + webrtc::TaskQueueBase* const worker_thread_; + webrtc::ScopedTaskSafety task_safety_; + webrtc::SequenceChecker network_thread_checker_; WebRtcVoiceEngine* const engine_ = nullptr; std::vector send_codecs_; @@ -301,7 +277,6 @@ class WebRtcVoiceMediaChannel final : public VoiceMediaChannel, int dtmf_payload_freq_ = -1; bool recv_transport_cc_enabled_ = false; bool recv_nack_enabled_ = false; - bool desired_playout_ = false; bool playout_ = false; bool send_ = false; webrtc::Call* const call_ = nullptr; diff --git a/media/engine/webrtc_voice_engine_unittest.cc b/media/engine/webrtc_voice_engine_unittest.cc index 87678be087..c570b1a03a 100644 --- a/media/engine/webrtc_voice_engine_unittest.cc +++ b/media/engine/webrtc_voice_engine_unittest.cc @@ -277,6 +277,7 @@ class WebRtcVoiceEngineTestFake : public ::testing::TestWithParam { void DeliverPacket(const void* data, int len) { rtc::CopyOnWriteBuffer packet(reinterpret_cast(data), len); channel_->OnPacketReceived(packet, /* packet_time_us */ -1); + rtc::Thread::Current()->ProcessMessages(0); } void TearDown() override { delete channel_; } @@ -1217,7 +1218,7 @@ TEST_P(WebRtcVoiceEngineTestFake, SetRtpParametersAdaptivePtime) { parameters.encodings[0].adaptive_ptime = true; EXPECT_TRUE(channel_->SetRtpSendParameters(kSsrcX, parameters).ok()); EXPECT_TRUE(GetAudioNetworkAdaptorConfig(kSsrcX)); - EXPECT_EQ(12000, GetSendStreamConfig(kSsrcX).min_bitrate_bps); + EXPECT_EQ(16000, GetSendStreamConfig(kSsrcX).min_bitrate_bps); parameters.encodings[0].adaptive_ptime = false; EXPECT_TRUE(channel_->SetRtpSendParameters(kSsrcX, parameters).ok()); @@ -2844,7 +2845,7 @@ TEST_P(WebRtcVoiceEngineTestFake, AddRecvStreamAfterUnsignaled_NoRecreate) { EXPECT_EQ(audio_receive_stream_id, streams.front()->id()); } -TEST_P(WebRtcVoiceEngineTestFake, AddRecvStreamAfterUnsignaled_Recreate) { +TEST_P(WebRtcVoiceEngineTestFake, AddRecvStreamAfterUnsignaled_Updates) { EXPECT_TRUE(SetupChannel()); // Spawn unsignaled stream with SSRC=1. @@ -2853,17 +2854,26 @@ TEST_P(WebRtcVoiceEngineTestFake, AddRecvStreamAfterUnsignaled_Recreate) { EXPECT_TRUE( GetRecvStream(1).VerifyLastPacket(kPcmuFrame, sizeof(kPcmuFrame))); - // Verify that the underlying stream object in Call *is* recreated when a + // Verify that the underlying stream object in Call gets updated when a // stream with SSRC=1 is added, and which has changed stream parameters. const auto& streams = call_.GetAudioReceiveStreams(); EXPECT_EQ(1u, streams.size()); + // The sync_group id should be empty. + EXPECT_TRUE(streams.front()->GetConfig().sync_group.empty()); + + const std::string new_stream_id("stream_id"); int audio_receive_stream_id = streams.front()->id(); cricket::StreamParams stream_params; stream_params.ssrcs.push_back(1); - stream_params.set_stream_ids({"stream_id"}); + stream_params.set_stream_ids({new_stream_id}); + EXPECT_TRUE(channel_->AddRecvStream(stream_params)); EXPECT_EQ(1u, streams.size()); - EXPECT_NE(audio_receive_stream_id, streams.front()->id()); + // The audio receive stream should not have been recreated. + EXPECT_EQ(audio_receive_stream_id, streams.front()->id()); + + // The sync_group id should now match with the new stream params. + EXPECT_EQ(new_stream_id, streams.front()->GetConfig().sync_group); } // Test that AddRecvStream creates new stream. 
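The ProcessMessages(0) calls added to the tests below are needed because the channel now hands received packets from the network thread to the worker thread asynchronously instead of delivering them inline. A minimal sketch of that hand-off, assuming the worker_thread_/task_safety_ members introduced in this header; the PacketForwarder class and its lambda body are illustrative, not the exact code in this patch:

#include <cstdint>
#include <utility>

#include "api/task_queue/task_queue_base.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/logging.h"
#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"

// Hypothetical helper showing the network->worker hand-off pattern.
class PacketForwarder {
 public:
  explicit PacketForwarder(webrtc::TaskQueueBase* worker_thread)
      : worker_thread_(worker_thread) {}

  // Called on the network thread.
  void OnPacketReceived(rtc::CopyOnWriteBuffer packet) {
    worker_thread_->PostTask(webrtc::ToQueuedTask(
        task_safety_, [packet = std::move(packet)]() {
          // Now on the worker thread: deliver |packet| to webrtc::Call here.
          // If the owner was destroyed first, ScopedTaskSafety has already
          // cancelled this task.
          RTC_LOG(LS_VERBOSE) << "Delivering packet, size=" << packet.size();
        }));
  }

 private:
  webrtc::TaskQueueBase* const worker_thread_;
  webrtc::ScopedTaskSafety task_safety_;
};

In a test, rtc::Thread::Current()->ProcessMessages(0) runs the task queued by PostTask, so the packet becomes visible to the assertions that follow.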
@@ -3202,6 +3212,7 @@ TEST_P(WebRtcVoiceEngineTestFake, TestSetDscpOptions) { channel->SetInterface(&network_interface); // Default value when DSCP is disabled should be DSCP_DEFAULT. EXPECT_EQ(rtc::DSCP_DEFAULT, network_interface.dscp()); + channel->SetInterface(nullptr); config.enable_dscp = true; channel.reset(static_cast( @@ -3228,6 +3239,7 @@ TEST_P(WebRtcVoiceEngineTestFake, TestSetDscpOptions) { const uint8_t kData[10] = {0}; EXPECT_TRUE(channel->SendRtcp(kData, sizeof(kData))); EXPECT_EQ(rtc::DSCP_CS1, network_interface.options().dscp); + channel->SetInterface(nullptr); // Verify that setting the option to false resets the // DiffServCodePoint. @@ -3443,6 +3455,8 @@ TEST_P(WebRtcVoiceEngineTestFake, DeliverAudioPacket_Call) { call_.GetAudioReceiveStream(kAudioSsrc); EXPECT_EQ(0, s->received_packets()); channel_->OnPacketReceived(kPcmuPacket, /* packet_time_us */ -1); + rtc::Thread::Current()->ProcessMessages(0); + EXPECT_EQ(1, s->received_packets()); } diff --git a/media/sctp/OWNERS b/media/sctp/OWNERS index a32f041ac8..da2f0178a8 100644 --- a/media/sctp/OWNERS +++ b/media/sctp/OWNERS @@ -1 +1,3 @@ +boivie@webrtc.org deadbeef@webrtc.org +orphis@webrtc.org diff --git a/media/sctp/dcsctp_transport.cc b/media/sctp/dcsctp_transport.cc new file mode 100644 index 0000000000..3b89af1ec2 --- /dev/null +++ b/media/sctp/dcsctp_transport.cc @@ -0,0 +1,524 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "media/sctp/dcsctp_transport.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "media/base/media_channel.h" +#include "net/dcsctp/public/dcsctp_socket_factory.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/text_pcap_packet_observer.h" +#include "net/dcsctp/public/types.h" +#include "p2p/base/packet_transport_internal.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" +#include "rtc_base/thread.h" +#include "rtc_base/trace_event.h" +#include "system_wrappers/include/clock.h" + +namespace webrtc { + +namespace { + +enum class WebrtcPPID : dcsctp::PPID::UnderlyingType { + // https://www.rfc-editor.org/rfc/rfc8832.html#section-8.1 + kDCEP = 50, + // https://www.rfc-editor.org/rfc/rfc8831.html#section-8 + kString = 51, + kBinaryPartial = 52, // Deprecated + kBinary = 53, + kStringPartial = 54, // Deprecated + kStringEmpty = 56, + kBinaryEmpty = 57, +}; + +WebrtcPPID ToPPID(DataMessageType message_type, size_t size) { + switch (message_type) { + case webrtc::DataMessageType::kControl: + return WebrtcPPID::kDCEP; + case webrtc::DataMessageType::kText: + return size > 0 ? WebrtcPPID::kString : WebrtcPPID::kStringEmpty; + case webrtc::DataMessageType::kBinary: + return size > 0 ? 
WebrtcPPID::kBinary : WebrtcPPID::kBinaryEmpty; + } +} + +absl::optional ToDataMessageType(dcsctp::PPID ppid) { + switch (static_cast(ppid.value())) { + case WebrtcPPID::kDCEP: + return webrtc::DataMessageType::kControl; + case WebrtcPPID::kString: + case WebrtcPPID::kStringPartial: + case WebrtcPPID::kStringEmpty: + return webrtc::DataMessageType::kText; + case WebrtcPPID::kBinary: + case WebrtcPPID::kBinaryPartial: + case WebrtcPPID::kBinaryEmpty: + return webrtc::DataMessageType::kBinary; + } + return absl::nullopt; +} + +absl::optional ToErrorCauseCode( + dcsctp::ErrorKind error) { + switch (error) { + case dcsctp::ErrorKind::kParseFailed: + return cricket::SctpErrorCauseCode::kUnrecognizedParameters; + case dcsctp::ErrorKind::kPeerReported: + return cricket::SctpErrorCauseCode::kUserInitiatedAbort; + case dcsctp::ErrorKind::kWrongSequence: + case dcsctp::ErrorKind::kProtocolViolation: + return cricket::SctpErrorCauseCode::kProtocolViolation; + case dcsctp::ErrorKind::kResourceExhaustion: + return cricket::SctpErrorCauseCode::kOutOfResource; + case dcsctp::ErrorKind::kTooManyRetries: + case dcsctp::ErrorKind::kUnsupportedOperation: + case dcsctp::ErrorKind::kNoError: + case dcsctp::ErrorKind::kNotConnected: + // No SCTP error cause code matches those + break; + } + return absl::nullopt; +} + +bool IsEmptyPPID(dcsctp::PPID ppid) { + WebrtcPPID webrtc_ppid = static_cast(ppid.value()); + return webrtc_ppid == WebrtcPPID::kStringEmpty || + webrtc_ppid == WebrtcPPID::kBinaryEmpty; +} +} // namespace + +DcSctpTransport::DcSctpTransport(rtc::Thread* network_thread, + rtc::PacketTransportInternal* transport, + Clock* clock) + : network_thread_(network_thread), + transport_(transport), + clock_(clock), + random_(clock_->TimeInMicroseconds()), + task_queue_timeout_factory_( + *network_thread, + [this]() { return TimeMillis(); }, + [this](dcsctp::TimeoutID timeout_id) { + socket_->HandleTimeout(timeout_id); + }) { + RTC_DCHECK_RUN_ON(network_thread_); + static int instance_count = 0; + rtc::StringBuilder sb; + sb << debug_name_ << instance_count++; + debug_name_ = sb.Release(); + ConnectTransportSignals(); +} + +DcSctpTransport::~DcSctpTransport() { + if (socket_) { + socket_->Close(); + } +} + +void DcSctpTransport::SetDtlsTransport( + rtc::PacketTransportInternal* transport) { + RTC_DCHECK_RUN_ON(network_thread_); + DisconnectTransportSignals(); + transport_ = transport; + ConnectTransportSignals(); + MaybeConnectSocket(); +} + +bool DcSctpTransport::Start(int local_sctp_port, + int remote_sctp_port, + int max_message_size) { + RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK(max_message_size > 0); + + RTC_LOG(LS_INFO) << debug_name_ << "->Start(local=" << local_sctp_port + << ", remote=" << remote_sctp_port + << ", max_message_size=" << max_message_size << ")"; + + if (!socket_) { + dcsctp::DcSctpOptions options; + options.local_port = local_sctp_port; + options.remote_port = remote_sctp_port; + options.max_message_size = max_message_size; + + std::unique_ptr packet_observer; + if (RTC_LOG_CHECK_LEVEL(LS_VERBOSE)) { + packet_observer = + std::make_unique(debug_name_); + } + + dcsctp::DcSctpSocketFactory factory; + socket_ = + factory.Create(debug_name_, *this, std::move(packet_observer), options); + } else { + if (local_sctp_port != socket_->options().local_port || + remote_sctp_port != socket_->options().remote_port) { + RTC_LOG(LS_ERROR) + << debug_name_ << "->Start(local=" << local_sctp_port + << ", remote=" << remote_sctp_port + << "): Can't change ports on already started transport."; + 
return false; + } + socket_->SetMaxMessageSize(max_message_size); + } + + MaybeConnectSocket(); + + return true; +} + +bool DcSctpTransport::OpenStream(int sid) { + RTC_LOG(LS_INFO) << debug_name_ << "->OpenStream(" << sid << ")."; + if (!socket_) { + RTC_LOG(LS_ERROR) << debug_name_ << "->OpenStream(sid=" << sid + << "): Transport is not started."; + return false; + } + return true; +} + +bool DcSctpTransport::ResetStream(int sid) { + RTC_LOG(LS_INFO) << debug_name_ << "->ResetStream(" << sid << ")."; + if (!socket_) { + RTC_LOG(LS_ERROR) << debug_name_ << "->OpenStream(sid=" << sid + << "): Transport is not started."; + return false; + } + dcsctp::StreamID streams[1] = {dcsctp::StreamID(static_cast(sid))}; + socket_->ResetStreams(streams); + return true; +} + +bool DcSctpTransport::SendData(int sid, + const SendDataParams& params, + const rtc::CopyOnWriteBuffer& payload, + cricket::SendDataResult* result) { + RTC_DCHECK_RUN_ON(network_thread_); + + RTC_LOG(LS_VERBOSE) << debug_name_ << "->SendData(sid=" << sid + << ", type=" << static_cast(params.type) + << ", length=" << payload.size() << ")."; + + if (!socket_) { + RTC_LOG(LS_ERROR) << debug_name_ + << "->SendData(...): Transport is not started."; + *result = cricket::SDR_ERROR; + return false; + } + + auto max_message_size = socket_->options().max_message_size; + if (max_message_size > 0 && payload.size() > max_message_size) { + RTC_LOG(LS_WARNING) << debug_name_ + << "->SendData(...): " + "Trying to send packet bigger " + "than the max message size: " + << payload.size() << " vs max of " << max_message_size; + *result = cricket::SDR_ERROR; + return false; + } + + std::vector message_payload(payload.cdata(), + payload.cdata() + payload.size()); + if (message_payload.empty()) { + // https://www.rfc-editor.org/rfc/rfc8831.html#section-6.6 + // SCTP does not support the sending of empty user messages. Therefore, if + // an empty message has to be sent, the appropriate PPID (WebRTC String + // Empty or WebRTC Binary Empty) is used, and the SCTP user message of one + // zero byte is sent. 
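// Illustrative sketch (not part of this patch): the empty-message rule above
// pairs with the IsEmptyPPID() check on the receive path further down.
// Condensed into standalone helpers, using the RFC 8831 PPID values from this
// file; the helper names are invented for illustration.
#include <cstdint>
#include <utility>
#include <vector>

constexpr uint32_t kPpidStringLast = 51;   // "WebRTC String"
constexpr uint32_t kPpidStringEmpty = 56;  // "WebRTC String Empty"

// Send side: SCTP cannot carry a zero-length user message, so an empty text
// message becomes a single NUL byte tagged with the "String Empty" PPID.
std::pair<uint32_t, std::vector<uint8_t>> PrepareTextPayload(
    std::vector<uint8_t> payload) {
  if (payload.empty())
    return {kPpidStringEmpty, {0}};
  return {kPpidStringLast, std::move(payload)};
}

// Receive side: the "Empty" PPID means the one-byte body is padding and the
// application should see an empty message.
std::vector<uint8_t> ExtractTextPayload(uint32_t ppid,
                                        std::vector<uint8_t> body) {
  if (ppid == kPpidStringEmpty)
    return {};
  return body;
}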
+ message_payload.push_back('\0'); + } + + dcsctp::DcSctpMessage message( + dcsctp::StreamID(static_cast(sid)), + dcsctp::PPID(static_cast(ToPPID(params.type, payload.size()))), + std::move(message_payload)); + + dcsctp::SendOptions send_options; + send_options.unordered = dcsctp::IsUnordered(!params.ordered); + if (params.max_rtx_ms.has_value()) { + RTC_DCHECK(*params.max_rtx_ms >= 0 && + *params.max_rtx_ms <= std::numeric_limits::max()); + send_options.lifetime = dcsctp::DurationMs(*params.max_rtx_ms); + } + if (params.max_rtx_count.has_value()) { + RTC_DCHECK(*params.max_rtx_count >= 0 && + *params.max_rtx_count <= std::numeric_limits::max()); + send_options.max_retransmissions = *params.max_rtx_count; + } + + auto error = socket_->Send(std::move(message), send_options); + switch (error) { + case dcsctp::SendStatus::kSuccess: + *result = cricket::SDR_SUCCESS; + break; + case dcsctp::SendStatus::kErrorResourceExhaustion: + *result = cricket::SDR_BLOCK; + ready_to_send_data_ = false; + break; + default: + RTC_LOG(LS_ERROR) << debug_name_ + << "->SendData(...): send() failed with error " + << dcsctp::ToString(error) << "."; + *result = cricket::SDR_ERROR; + } + + return *result == cricket::SDR_SUCCESS; +} + +bool DcSctpTransport::ReadyToSendData() { + return ready_to_send_data_; +} + +int DcSctpTransport::max_message_size() const { + if (!socket_) { + RTC_LOG(LS_ERROR) << debug_name_ + << "->max_message_size(...): Transport is not started."; + return 0; + } + return socket_->options().max_message_size; +} + +absl::optional DcSctpTransport::max_outbound_streams() const { + if (!socket_) + return absl::nullopt; + return socket_->options().announced_maximum_outgoing_streams; +} + +absl::optional DcSctpTransport::max_inbound_streams() const { + if (!socket_) + return absl::nullopt; + return socket_->options().announced_maximum_incoming_streams; +} + +void DcSctpTransport::set_debug_name_for_testing(const char* debug_name) { + debug_name_ = debug_name; +} + +void DcSctpTransport::SendPacket(rtc::ArrayView data) { + RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK(socket_); + + if (data.size() > (socket_->options().mtu)) { + RTC_LOG(LS_ERROR) << debug_name_ + << "->SendPacket(...): " + "SCTP seems to have made a packet that is bigger " + "than its official MTU: " + << data.size() << " vs max of " << socket_->options().mtu; + return; + } + TRACE_EVENT0("webrtc", "DcSctpTransport::SendPacket"); + + if (!transport_ || !transport_->writable()) + return; + + RTC_LOG(LS_VERBOSE) << debug_name_ << "->SendPacket(length=" << data.size() + << ")"; + + auto result = + transport_->SendPacket(reinterpret_cast(data.data()), + data.size(), rtc::PacketOptions(), 0); + + if (result < 0) { + RTC_LOG(LS_WARNING) << debug_name_ << "->SendPacket(length=" << data.size() + << ") failed with error: " << transport_->GetError() + << "."; + } +} + +std::unique_ptr DcSctpTransport::CreateTimeout() { + return task_queue_timeout_factory_.CreateTimeout(); +} + +dcsctp::TimeMs DcSctpTransport::TimeMillis() { + return dcsctp::TimeMs(clock_->TimeInMilliseconds()); +} + +uint32_t DcSctpTransport::GetRandomInt(uint32_t low, uint32_t high) { + return random_.Rand(low, high); +} + +void DcSctpTransport::OnTotalBufferedAmountLow() { + if (!ready_to_send_data_) { + ready_to_send_data_ = true; + SignalReadyToSendData(); + } +} + +void DcSctpTransport::OnMessageReceived(dcsctp::DcSctpMessage message) { + RTC_DCHECK_RUN_ON(network_thread_); + RTC_LOG(LS_VERBOSE) << debug_name_ << "->OnMessageReceived(sid=" + << 
message.stream_id().value() + << ", ppid=" << message.ppid().value() + << ", length=" << message.payload().size() << ")."; + cricket::ReceiveDataParams receive_data_params; + receive_data_params.sid = message.stream_id().value(); + auto type = ToDataMessageType(message.ppid()); + if (!type.has_value()) { + RTC_LOG(LS_VERBOSE) << debug_name_ + << "->OnMessageReceived(): Received an unknown PPID " + << message.ppid().value() + << " on an SCTP packet. Dropping."; + } + receive_data_params.type = *type; + // No seq_num available from dcSCTP + receive_data_params.seq_num = 0; + receive_buffer_.Clear(); + if (!IsEmptyPPID(message.ppid())) + receive_buffer_.AppendData(message.payload().data(), + message.payload().size()); + + SignalDataReceived(receive_data_params, receive_buffer_); +} + +void DcSctpTransport::OnError(dcsctp::ErrorKind error, + absl::string_view message) { + RTC_LOG(LS_ERROR) << debug_name_ + << "->OnError(error=" << dcsctp::ToString(error) + << ", message=" << message << ")."; +} + +void DcSctpTransport::OnAborted(dcsctp::ErrorKind error, + absl::string_view message) { + RTC_LOG(LS_ERROR) << debug_name_ + << "->OnAborted(error=" << dcsctp::ToString(error) + << ", message=" << message << ")."; + ready_to_send_data_ = false; + RTCError rtc_error(RTCErrorType::OPERATION_ERROR_WITH_DATA, + std::string(message)); + rtc_error.set_error_detail(RTCErrorDetailType::SCTP_FAILURE); + auto code = ToErrorCauseCode(error); + if (code.has_value()) { + rtc_error.set_sctp_cause_code(static_cast(*code)); + } + SignalClosedAbruptly(rtc_error); +} + +void DcSctpTransport::OnConnected() { + RTC_LOG(LS_INFO) << debug_name_ << "->OnConnected()."; + ready_to_send_data_ = true; + SignalReadyToSendData(); + SignalAssociationChangeCommunicationUp(); +} + +void DcSctpTransport::OnClosed() { + RTC_LOG(LS_INFO) << debug_name_ << "->OnClosed()."; + ready_to_send_data_ = false; +} + +void DcSctpTransport::OnConnectionRestarted() { + RTC_LOG(LS_INFO) << debug_name_ << "->OnConnectionRestarted()."; +} + +void DcSctpTransport::OnStreamsResetFailed( + rtc::ArrayView outgoing_streams, + absl::string_view reason) { + // TODO(orphis): Need a test to check for correct behavior + for (auto& stream_id : outgoing_streams) { + RTC_LOG(LS_WARNING) + << debug_name_ + << "->OnStreamsResetFailed(...): Outgoing stream reset failed" + << ", sid=" << stream_id.value() << ", reason: " << reason << "."; + } +} + +void DcSctpTransport::OnStreamsResetPerformed( + rtc::ArrayView outgoing_streams) { + for (auto& stream_id : outgoing_streams) { + RTC_LOG(LS_INFO) << debug_name_ + << "->OnStreamsResetPerformed(...): Outgoing stream reset" + << ", sid=" << stream_id.value(); + SignalClosingProcedureComplete(stream_id.value()); + } +} + +void DcSctpTransport::OnIncomingStreamsReset( + rtc::ArrayView incoming_streams) { + for (auto& stream_id : incoming_streams) { + RTC_LOG(LS_INFO) << debug_name_ + << "->OnIncomingStreamsReset(...): Incoming stream reset" + << ", sid=" << stream_id.value(); + SignalClosingProcedureStartedRemotely(stream_id.value()); + SignalClosingProcedureComplete(stream_id.value()); + } +} + +void DcSctpTransport::ConnectTransportSignals() { + RTC_DCHECK_RUN_ON(network_thread_); + if (!transport_) { + return; + } + transport_->SignalWritableState.connect( + this, &DcSctpTransport::OnTransportWritableState); + transport_->SignalReadPacket.connect(this, + &DcSctpTransport::OnTransportReadPacket); + transport_->SignalClosed.connect(this, &DcSctpTransport::OnTransportClosed); +} + +void 
DcSctpTransport::DisconnectTransportSignals() { + RTC_DCHECK_RUN_ON(network_thread_); + if (!transport_) { + return; + } + transport_->SignalWritableState.disconnect(this); + transport_->SignalReadPacket.disconnect(this); + transport_->SignalClosed.disconnect(this); +} + +void DcSctpTransport::OnTransportWritableState( + rtc::PacketTransportInternal* transport) { + RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK_EQ(transport_, transport); + + RTC_LOG(LS_VERBOSE) << debug_name_ + << "->OnTransportWritableState(), writable=" + << transport->writable(); + + MaybeConnectSocket(); +} + +void DcSctpTransport::OnTransportReadPacket( + rtc::PacketTransportInternal* transport, + const char* data, + size_t length, + const int64_t& /* packet_time_us */, + int flags) { + if (flags) { + // We are only interested in SCTP packets. + return; + } + + RTC_LOG(LS_VERBOSE) << debug_name_ + << "->OnTransportReadPacket(), length=" << length; + if (socket_) { + socket_->ReceivePacket(rtc::ArrayView( + reinterpret_cast(data), length)); + } +} + +void DcSctpTransport::OnTransportClosed( + rtc::PacketTransportInternal* transport) { + RTC_LOG(LS_VERBOSE) << debug_name_ << "->OnTransportClosed()."; + SignalClosedAbruptly({}); +} + +void DcSctpTransport::MaybeConnectSocket() { + if (transport_ && transport_->writable() && socket_ && + socket_->state() == dcsctp::SocketState::kClosed) { + socket_->Connect(); + } +} +} // namespace webrtc diff --git a/media/sctp/dcsctp_transport.h b/media/sctp/dcsctp_transport.h new file mode 100644 index 0000000000..15933383b5 --- /dev/null +++ b/media/sctp/dcsctp_transport.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MEDIA_SCTP_DCSCTP_TRANSPORT_H_ +#define MEDIA_SCTP_DCSCTP_TRANSPORT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "media/sctp/sctp_transport_internal.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/timer/task_queue_timeout.h" +#include "p2p/base/packet_transport_internal.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/random.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "system_wrappers/include/clock.h" + +namespace webrtc { + +class DcSctpTransport : public cricket::SctpTransportInternal, + public dcsctp::DcSctpSocketCallbacks, + public sigslot::has_slots<> { + public: + DcSctpTransport(rtc::Thread* network_thread, + rtc::PacketTransportInternal* transport, + Clock* clock); + ~DcSctpTransport() override; + + // cricket::SctpTransportInternal + void SetDtlsTransport(rtc::PacketTransportInternal* transport) override; + bool Start(int local_sctp_port, + int remote_sctp_port, + int max_message_size) override; + bool OpenStream(int sid) override; + bool ResetStream(int sid) override; + bool SendData(int sid, + const SendDataParams& params, + const rtc::CopyOnWriteBuffer& payload, + cricket::SendDataResult* result = nullptr) override; + bool ReadyToSendData() override; + int max_message_size() const override; + absl::optional max_outbound_streams() const override; + absl::optional max_inbound_streams() const override; + void set_debug_name_for_testing(const char* debug_name) override; + + private: + // dcsctp::DcSctpSocketCallbacks + void SendPacket(rtc::ArrayView data) override; + std::unique_ptr CreateTimeout() override; + dcsctp::TimeMs TimeMillis() override; + uint32_t GetRandomInt(uint32_t low, uint32_t high) override; + void OnTotalBufferedAmountLow() override; + void OnMessageReceived(dcsctp::DcSctpMessage message) override; + void OnError(dcsctp::ErrorKind error, absl::string_view message) override; + void OnAborted(dcsctp::ErrorKind error, absl::string_view message) override; + void OnConnected() override; + void OnClosed() override; + void OnConnectionRestarted() override; + void OnStreamsResetFailed( + rtc::ArrayView outgoing_streams, + absl::string_view reason) override; + void OnStreamsResetPerformed( + rtc::ArrayView outgoing_streams) override; + void OnIncomingStreamsReset( + rtc::ArrayView incoming_streams) override; + + // Transport callbacks + void ConnectTransportSignals(); + void DisconnectTransportSignals(); + void OnTransportWritableState(rtc::PacketTransportInternal* transport); + void OnTransportReadPacket(rtc::PacketTransportInternal* transport, + const char* data, + size_t length, + const int64_t& /* packet_time_us */, + int flags); + void OnTransportClosed(rtc::PacketTransportInternal* transport); + + void MaybeConnectSocket(); + + rtc::Thread* network_thread_; + rtc::PacketTransportInternal* transport_; + Clock* clock_; + Random random_; + + dcsctp::TaskQueueTimeoutFactory task_queue_timeout_factory_; + std::unique_ptr socket_; + std::string debug_name_ = "DcSctpTransport"; + rtc::CopyOnWriteBuffer receive_buffer_; + + bool ready_to_send_data_ = false; +}; + +} // namespace webrtc + +#endif // MEDIA_SCTP_DCSCTP_TRANSPORT_H_ diff --git a/media/sctp/noop.cc b/media/sctp/noop.cc deleted file mode 100644 index a3523b18b2..0000000000 --- a/media/sctp/noop.cc +++ /dev/null @@ -1,13 +0,0 @@ -/* 
- * Copyright 2017 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -// This file is only needed to make ninja happy on some platforms. -// On some platforms it is not possible to link an rtc_static_library -// without any source file listed in the GN target. diff --git a/media/sctp/sctp_transport_factory.cc b/media/sctp/sctp_transport_factory.cc new file mode 100644 index 0000000000..5097d423d9 --- /dev/null +++ b/media/sctp/sctp_transport_factory.cc @@ -0,0 +1,55 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "media/sctp/sctp_transport_factory.h" + +#include "rtc_base/system/unused.h" + +#ifdef WEBRTC_HAVE_DCSCTP +#include "media/sctp/dcsctp_transport.h" // nogncheck +#include "system_wrappers/include/clock.h" // nogncheck +#include "system_wrappers/include/field_trial.h" // nogncheck +#endif + +#ifdef WEBRTC_HAVE_USRSCTP +#include "media/sctp/usrsctp_transport.h" // nogncheck +#endif + +namespace cricket { + +SctpTransportFactory::SctpTransportFactory(rtc::Thread* network_thread) + : network_thread_(network_thread), use_dcsctp_("Enabled", false) { + RTC_UNUSED(network_thread_); +#ifdef WEBRTC_HAVE_DCSCTP + webrtc::ParseFieldTrial({&use_dcsctp_}, webrtc::field_trial::FindFullName( + "WebRTC-DataChannel-Dcsctp")); +#endif +} + +std::unique_ptr +SctpTransportFactory::CreateSctpTransport( + rtc::PacketTransportInternal* transport) { + std::unique_ptr result; +#ifdef WEBRTC_HAVE_DCSCTP + if (use_dcsctp_.Get()) { + result = std::unique_ptr(new webrtc::DcSctpTransport( + network_thread_, transport, webrtc::Clock::GetRealTimeClock())); + } +#endif +#ifdef WEBRTC_HAVE_USRSCTP + if (!result) { + result = std::unique_ptr( + new UsrsctpTransport(network_thread_, transport)); + } +#endif + return result; +} + +} // namespace cricket diff --git a/media/sctp/sctp_transport_factory.h b/media/sctp/sctp_transport_factory.h new file mode 100644 index 0000000000..ed7c2163d7 --- /dev/null +++ b/media/sctp/sctp_transport_factory.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MEDIA_SCTP_SCTP_TRANSPORT_FACTORY_H_ +#define MEDIA_SCTP_SCTP_TRANSPORT_FACTORY_H_ + +#include + +#include "api/transport/sctp_transport_factory_interface.h" +#include "media/sctp/sctp_transport_internal.h" +#include "rtc_base/experiments/field_trial_parser.h" +#include "rtc_base/thread.h" + +namespace cricket { + +class SctpTransportFactory : public webrtc::SctpTransportFactoryInterface { + public: + explicit SctpTransportFactory(rtc::Thread* network_thread); + + std::unique_ptr CreateSctpTransport( + rtc::PacketTransportInternal* transport) override; + + private: + rtc::Thread* network_thread_; + webrtc::FieldTrialFlag use_dcsctp_; +}; + +} // namespace cricket + +#endif // MEDIA_SCTP_SCTP_TRANSPORT_FACTORY_H__ diff --git a/media/sctp/sctp_transport_internal.h b/media/sctp/sctp_transport_internal.h index dc8ac4558d..b1327165b6 100644 --- a/media/sctp/sctp_transport_internal.h +++ b/media/sctp/sctp_transport_internal.h @@ -18,6 +18,7 @@ #include #include +#include "api/transport/data_channel_transport_interface.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/thread.h" // For SendDataParams/ReceiveDataParams. @@ -52,6 +53,24 @@ constexpr uint16_t kMinSctpSid = 0; // usrsctp.h) const int kSctpDefaultPort = 5000; +// Error cause codes defined at +// https://www.iana.org/assignments/sctp-parameters/sctp-parameters.xhtml#sctp-parameters-24 +enum class SctpErrorCauseCode : uint16_t { + kInvalidStreamIdentifier = 1, + kMissingMandatoryParameter = 2, + kStaleCookieError = 3, + kOutOfResource = 4, + kUnresolvableAddress = 5, + kUnrecognizedChunkType = 6, + kInvalidMandatoryParameter = 7, + kUnrecognizedParameters = 8, + kNoUserData = 9, + kCookieReceivedWhileShuttingDown = 10, + kRestartWithNewAddresses = 11, + kUserInitiatedAbort = 12, + kProtocolViolation = 13, +}; + // Abstract SctpTransport interface for use internally (by PeerConnection etc.). // Exists to allow mock/fake SctpTransports to be created. class SctpTransportInternal { @@ -101,7 +120,8 @@ class SctpTransportInternal { // usrsctp that will then post the network interface). // Returns true iff successful data somewhere on the send-queue/network. // Uses |params.ssrc| as the SCTP sid. - virtual bool SendData(const SendDataParams& params, + virtual bool SendData(int sid, + const webrtc::SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, SendDataResult* result = nullptr) = 0; @@ -135,8 +155,8 @@ class SctpTransportInternal { // and outgoing streams reset). sigslot::signal1 SignalClosingProcedureComplete; // Fired when the underlying DTLS transport has closed due to an error - // or an incoming DTLS disconnect. - sigslot::signal0<> SignalClosedAbruptly; + // or an incoming DTLS disconnect or SCTP transport errors. + sigslot::signal1 SignalClosedAbruptly; // Helper for debugging. virtual void set_debug_name_for_testing(const char* debug_name) = 0; diff --git a/media/sctp/sctp_transport.cc b/media/sctp/usrsctp_transport.cc similarity index 74% rename from media/sctp/sctp_transport.cc rename to media/sctp/usrsctp_transport.cc index 6bb4a8fdf2..7824a72934 100644 --- a/media/sctp/sctp_transport.cc +++ b/media/sctp/usrsctp_transport.cc @@ -20,6 +20,7 @@ enum PreservedErrno { // Successful return value from usrsctp callbacks. Is not actually used by // usrsctp, but all example programs for usrsctp use 1 as their return value. 
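// Illustrative sketch: with SignalClosedAbruptly now carrying an RTCError (see
// sctp_transport_internal.h above), a listener can distinguish SCTP-level
// aborts from a plain DTLS teardown. The sctp_cause_code() accessor is assumed
// to mirror the set_sctp_cause_code() setter used by DcSctpTransport::OnAborted().
#include "api/rtc_error.h"
#include "rtc_base/logging.h"
#include "rtc_base/third_party/sigslot/sigslot.h"

struct SctpClosureObserver : public sigslot::has_slots<> {
  // Connect via transport->SignalClosedAbruptly.connect(
  //     this, &SctpClosureObserver::OnClosedAbruptly);
  void OnClosedAbruptly(webrtc::RTCError error) {
    if (error.error_detail() == webrtc::RTCErrorDetailType::SCTP_FAILURE &&
        error.sctp_cause_code().has_value()) {
      RTC_LOG(LS_ERROR) << "SCTP association aborted, cause code "
                        << *error.sctp_cause_code() << ": " << error.message();
    } else {
      RTC_LOG(LS_ERROR) << "SCTP transport closed abruptly: "
                        << error.message();
    }
  }
};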
constexpr int kSctpSuccessReturn = 1; +constexpr int kSctpErrorReturn = 0; } // namespace @@ -29,15 +30,17 @@ constexpr int kSctpSuccessReturn = 1; #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "media/base/codec.h" #include "media/base/media_channel.h" #include "media/base/media_constants.h" #include "media/base/stream_params.h" -#include "media/sctp/sctp_transport.h" +#include "media/sctp/usrsctp_transport.h" #include "p2p/base/dtls_transport_internal.h" // For PF_NORMAL #include "rtc_base/arraysize.h" #include "rtc_base/copy_on_write_buffer.h" @@ -48,92 +51,69 @@ constexpr int kSctpSuccessReturn = 1; #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/trace_event.h" +namespace cricket { namespace { // The biggest SCTP packet. Starting from a 'safe' wire MTU value of 1280, -// take off 80 bytes for DTLS/TURN/TCP/IP overhead. -static constexpr size_t kSctpMtu = 1200; +// take off 85 bytes for DTLS/TURN/TCP/IP and ciphertext overhead. +// +// Additionally, it's possible that TURN adds an additional 4 bytes of overhead +// after a channel has been established, so we subtract an additional 4 bytes. +// +// 1280 IPV6 MTU +// -40 IPV6 header +// -8 UDP +// -24 GCM Cipher +// -13 DTLS record header +// -4 TURN ChannelData +// = 1191 bytes. +static constexpr size_t kSctpMtu = 1191; // Set the initial value of the static SCTP Data Engines reference count. ABSL_CONST_INIT int g_usrsctp_usage_count = 0; ABSL_CONST_INIT bool g_usrsctp_initialized_ = false; ABSL_CONST_INIT webrtc::GlobalMutex g_usrsctp_lock_(absl::kConstInit); +ABSL_CONST_INIT char kZero[] = {'\0'}; // DataMessageType is used for the SCTP "Payload Protocol Identifier", as // defined in http://tools.ietf.org/html/rfc4960#section-14.4 // // For the list of IANA approved values see: +// https://tools.ietf.org/html/rfc8831 Sec. 8 // http://www.iana.org/assignments/sctp-parameters/sctp-parameters.xml // The value is not used by SCTP itself. It indicates the protocol running // on top of SCTP. enum { PPID_NONE = 0, // No protocol is specified. - // Matches the PPIDs in mozilla source and - // https://datatracker.ietf.org/doc/draft-ietf-rtcweb-data-protocol Sec. 9 - // They're not yet assigned by IANA. PPID_CONTROL = 50, - PPID_BINARY_PARTIAL = 52, + PPID_TEXT_LAST = 51, + PPID_BINARY_PARTIAL = 52, // Deprecated PPID_BINARY_LAST = 53, - PPID_TEXT_PARTIAL = 54, - PPID_TEXT_LAST = 51 + PPID_TEXT_PARTIAL = 54, // Deprecated + PPID_TEXT_EMPTY = 56, + PPID_BINARY_EMPTY = 57, }; -// Maps SCTP transport ID to SctpTransport object, necessary in send threshold -// callback and outgoing packet callback. -// TODO(crbug.com/1076703): Remove once the underlying problem is fixed or -// workaround is provided in usrsctp. -class SctpTransportMap { - public: - SctpTransportMap() = default; - - // Assigns a new unused ID to the following transport. - uintptr_t Register(cricket::SctpTransport* transport) { - webrtc::MutexLock lock(&lock_); - // usrsctp_connect fails with a value of 0... - if (next_id_ == 0) { - ++next_id_; - } - // In case we've wrapped around and need to find an empty spot from a - // removed transport. Assumes we'll never be full. 
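// Illustrative aside: the kSctpMtu budget spelled out in the comment above can
// be restated as a compile-time check (values copied from that comment; the
// constant names here are invented for illustration).
#include <cstddef>

constexpr std::size_t kIpv6Mtu = 1280;
constexpr std::size_t kIpv6Header = 40;
constexpr std::size_t kUdpHeader = 8;
constexpr std::size_t kGcmCipherOverhead = 24;
constexpr std::size_t kDtlsRecordHeader = 13;
constexpr std::size_t kTurnChannelDataHeader = 4;
static_assert(kIpv6Mtu - kIpv6Header - kUdpHeader - kGcmCipherOverhead -
                      kDtlsRecordHeader - kTurnChannelDataHeader ==
                  1191,
              "kSctpMtu accounts for IP/UDP/DTLS/TURN overhead");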
- while (map_.find(next_id_) != map_.end()) { - ++next_id_; - if (next_id_ == 0) { - ++next_id_; - } - }; - map_[next_id_] = transport; - return next_id_++; - } +// Should only be modified by UsrSctpWrapper. +ABSL_CONST_INIT cricket::UsrsctpTransportMap* g_transport_map_ = nullptr; - // Returns true if found. - bool Deregister(uintptr_t id) { - webrtc::MutexLock lock(&lock_); - return map_.erase(id) > 0; - } +// Helper that will call C's free automatically. +// TODO(b/181900299): Figure out why unique_ptr with a custom deleter is causing +// issues in a certain build environment. +class AutoFreedPointer { + public: + explicit AutoFreedPointer(void* ptr) : ptr_(ptr) {} + AutoFreedPointer(AutoFreedPointer&& o) : ptr_(o.ptr_) { o.ptr_ = nullptr; } + ~AutoFreedPointer() { free(ptr_); } - cricket::SctpTransport* Retrieve(uintptr_t id) const { - webrtc::MutexLock lock(&lock_); - auto it = map_.find(id); - if (it == map_.end()) { - return nullptr; - } - return it->second; - } + void* get() const { return ptr_; } private: - mutable webrtc::Mutex lock_; - - uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0; - std::unordered_map map_ - RTC_GUARDED_BY(lock_); + void* ptr_; }; -// Should only be modified by UsrSctpWrapper. -ABSL_CONST_INIT SctpTransportMap* g_transport_map_ = nullptr; - // Helper for logging SCTP messages. #if defined(__GNUC__) __attribute__((__format__(__printf__, 1, 2))) @@ -150,44 +130,41 @@ void DebugSctpPrintf(const char* format, ...) { } // Get the PPID to use for the terminating fragment of this type. -uint32_t GetPpid(cricket::DataMessageType type) { +uint32_t GetPpid(webrtc::DataMessageType type, size_t size) { switch (type) { - default: - case cricket::DMT_NONE: - return PPID_NONE; - case cricket::DMT_CONTROL: + case webrtc::DataMessageType::kControl: return PPID_CONTROL; - case cricket::DMT_BINARY: - return PPID_BINARY_LAST; - case cricket::DMT_TEXT: - return PPID_TEXT_LAST; + case webrtc::DataMessageType::kBinary: + return size > 0 ? PPID_BINARY_LAST : PPID_BINARY_EMPTY; + case webrtc::DataMessageType::kText: + return size > 0 ? PPID_TEXT_LAST : PPID_TEXT_EMPTY; } } -bool GetDataMediaType(uint32_t ppid, cricket::DataMessageType* dest) { +bool GetDataMediaType(uint32_t ppid, webrtc::DataMessageType* dest) { RTC_DCHECK(dest != NULL); switch (ppid) { case PPID_BINARY_PARTIAL: case PPID_BINARY_LAST: - *dest = cricket::DMT_BINARY; + case PPID_BINARY_EMPTY: + *dest = webrtc::DataMessageType::kBinary; return true; case PPID_TEXT_PARTIAL: case PPID_TEXT_LAST: - *dest = cricket::DMT_TEXT; + case PPID_TEXT_EMPTY: + *dest = webrtc::DataMessageType::kText; return true; case PPID_CONTROL: - *dest = cricket::DMT_CONTROL; - return true; - - case PPID_NONE: - *dest = cricket::DMT_NONE; + *dest = webrtc::DataMessageType::kControl; return true; - - default: - return false; } + return false; +} + +bool IsEmptyPPID(uint32_t ppid) { + return ppid == PPID_BINARY_EMPTY || ppid == PPID_TEXT_EMPTY; } // Log the packet in text2pcap format, if log level is at LS_VERBOSE. @@ -227,11 +204,13 @@ void VerboseLogPacket(const void* data, size_t length, int direction) { // Creates the sctp_sendv_spa struct used for setting flags in the // sctp_sendv() call. 
-sctp_sendv_spa CreateSctpSendParams(const cricket::SendDataParams& params) { +sctp_sendv_spa CreateSctpSendParams(int sid, + const webrtc::SendDataParams& params, + size_t size) { struct sctp_sendv_spa spa = {0}; spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID; - spa.sendv_sndinfo.snd_sid = params.sid; - spa.sendv_sndinfo.snd_ppid = rtc::HostToNetwork32(GetPpid(params.type)); + spa.sendv_sndinfo.snd_sid = sid; + spa.sendv_sndinfo.snd_ppid = rtc::HostToNetwork32(GetPpid(params.type, size)); // Explicitly marking the EOR flag turns the usrsctp_sendv call below into a // non atomic operation. This means that the sctp lib might only accept the // message partially. This is done in order to improve throughput, so that we @@ -239,28 +218,128 @@ sctp_sendv_spa CreateSctpSendParams(const cricket::SendDataParams& params) { // example. spa.sendv_sndinfo.snd_flags |= SCTP_EOR; - // Ordered implies reliable. if (!params.ordered) { spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; - if (params.max_rtx_count >= 0 || params.max_rtx_ms == 0) { - spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; - spa.sendv_prinfo.pr_value = params.max_rtx_count; - } else { - spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; - spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; - spa.sendv_prinfo.pr_value = params.max_rtx_ms; - } + } + if (params.max_rtx_count.has_value()) { + RTC_DCHECK(*params.max_rtx_count >= 0 && + *params.max_rtx_count <= std::numeric_limits::max()); + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; + spa.sendv_prinfo.pr_value = *params.max_rtx_count; + } + if (params.max_rtx_ms.has_value()) { + RTC_DCHECK(*params.max_rtx_ms >= 0 && + *params.max_rtx_ms <= std::numeric_limits::max()); + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; + spa.sendv_prinfo.pr_value = *params.max_rtx_ms; } return spa; } + +std::string SctpErrorCauseCodeToString(SctpErrorCauseCode code) { + switch (code) { + case SctpErrorCauseCode::kInvalidStreamIdentifier: + return "Invalid Stream Identifier"; + case SctpErrorCauseCode::kMissingMandatoryParameter: + return "Missing Mandatory Parameter"; + case SctpErrorCauseCode::kStaleCookieError: + return "Stale Cookie Error"; + case SctpErrorCauseCode::kOutOfResource: + return "Out of Resource"; + case SctpErrorCauseCode::kUnresolvableAddress: + return "Unresolvable Address"; + case SctpErrorCauseCode::kUnrecognizedChunkType: + return "Unrecognized Chunk Type"; + case SctpErrorCauseCode::kInvalidMandatoryParameter: + return "Invalid Mandatory Parameter"; + case SctpErrorCauseCode::kUnrecognizedParameters: + return "Unrecognized Parameters"; + case SctpErrorCauseCode::kNoUserData: + return "No User Data"; + case SctpErrorCauseCode::kCookieReceivedWhileShuttingDown: + return "Cookie Received Whilte Shutting Down"; + case SctpErrorCauseCode::kRestartWithNewAddresses: + return "Restart With New Addresses"; + case SctpErrorCauseCode::kUserInitiatedAbort: + return "User Initiated Abort"; + case SctpErrorCauseCode::kProtocolViolation: + return "Protocol Violation"; + } + return "Unknown error"; +} } // namespace -namespace cricket { +// Maps SCTP transport ID to UsrsctpTransport object, necessary in send +// threshold callback and outgoing packet callback. It also provides a facility +// to safely post a task to an UsrsctpTransport's network thread from another +// thread. 
+class UsrsctpTransportMap { + public: + UsrsctpTransportMap() = default; + + // Assigns a new unused ID to the following transport. + uintptr_t Register(cricket::UsrsctpTransport* transport) { + webrtc::MutexLock lock(&lock_); + // usrsctp_connect fails with a value of 0... + if (next_id_ == 0) { + ++next_id_; + } + // In case we've wrapped around and need to find an empty spot from a + // removed transport. Assumes we'll never be full. + while (map_.find(next_id_) != map_.end()) { + ++next_id_; + if (next_id_ == 0) { + ++next_id_; + } + } + map_[next_id_] = transport; + return next_id_++; + } + + // Returns true if found. + bool Deregister(uintptr_t id) { + webrtc::MutexLock lock(&lock_); + return map_.erase(id) > 0; + } + + // Posts |action| to the network thread of the transport identified by |id| + // and returns true if found, all while holding a lock to protect against the + // transport being simultaneously deleted/deregistered, or returns false if + // not found. + template + bool PostToTransportThread(uintptr_t id, F action) const { + webrtc::MutexLock lock(&lock_); + UsrsctpTransport* transport = RetrieveWhileHoldingLock(id); + if (!transport) { + return false; + } + transport->network_thread_->PostTask(ToQueuedTask( + transport->task_safety_, + [transport, action{std::move(action)}]() { action(transport); })); + return true; + } + + private: + UsrsctpTransport* RetrieveWhileHoldingLock(uintptr_t id) const + RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_) { + auto it = map_.find(id); + if (it == map_.end()) { + return nullptr; + } + return it->second; + } + + mutable webrtc::Mutex lock_; + + uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0; + std::unordered_map map_ RTC_GUARDED_BY(lock_); +}; // Handles global init/deinit, and mapping from usrsctp callbacks to -// SctpTransport calls. -class SctpTransport::UsrSctpWrapper { +// UsrsctpTransport calls. +class UsrsctpTransport::UsrSctpWrapper { public: static void InitializeUsrSctp() { RTC_LOG(LS_INFO) << __FUNCTION__; @@ -319,7 +398,7 @@ class SctpTransport::UsrSctpWrapper { // send in the SCTP INIT message. usrsctp_sysctl_set_sctp_nr_outgoing_streams_default(kMaxSctpStreams); - g_transport_map_ = new SctpTransportMap(); + g_transport_map_ = new UsrsctpTransportMap(); } static void UninitializeUsrSctp() { @@ -370,14 +449,6 @@ class SctpTransport::UsrSctpWrapper { << "OnSctpOutboundPacket called after usrsctp uninitialized?"; return EINVAL; } - SctpTransport* transport = - g_transport_map_->Retrieve(reinterpret_cast(addr)); - if (!transport) { - RTC_LOG(LS_ERROR) - << "OnSctpOutboundPacket: Failed to get transport for socket ID " - << addr; - return EINVAL; - } RTC_LOG(LS_VERBOSE) << "global OnSctpOutboundPacket():" "addr: " << addr << "; length: " << length @@ -385,13 +456,23 @@ class SctpTransport::UsrSctpWrapper { << "; set_df: " << rtc::ToHex(set_df); VerboseLogPacket(data, length, SCTP_DUMP_OUTBOUND); + // Note: We have to copy the data; the caller will delete it. rtc::CopyOnWriteBuffer buf(reinterpret_cast(data), length); - transport->network_thread_->PostTask(ToQueuedTask( - transport->task_safety_, [transport, buf = std::move(buf)]() { + // PostsToTransportThread protects against the transport being + // simultaneously deregistered/deleted, since this callback may come from + // the SCTP timer thread and thus race with the network thread. 
+ bool found = g_transport_map_->PostToTransportThread( + reinterpret_cast(addr), [buf](UsrsctpTransport* transport) { transport->OnPacketFromSctpToNetwork(buf); - })); + }); + if (!found) { + RTC_LOG(LS_ERROR) + << "OnSctpOutboundPacket: Failed to get transport for socket ID " + << addr << "; possibly was already destroyed."; + return EINVAL; + } return 0; } @@ -407,89 +488,123 @@ class SctpTransport::UsrSctpWrapper { struct sctp_rcvinfo rcv, int flags, void* ulp_info) { - SctpTransport* transport = GetTransportFromSocket(sock); - if (!transport) { + AutoFreedPointer owned_data(data); + + absl::optional id = GetTransportIdFromSocket(sock); + if (!id) { RTC_LOG(LS_ERROR) - << "OnSctpInboundPacket: Failed to get transport for socket " << sock - << "; possibly was already destroyed."; - free(data); - return 0; + << "OnSctpInboundPacket: Failed to get transport ID from socket " + << sock; + return kSctpErrorReturn; } - // Sanity check that both methods of getting the SctpTransport pointer - // yield the same result. - RTC_CHECK_EQ(transport, static_cast(ulp_info)); - int result = - transport->OnDataOrNotificationFromSctp(data, length, rcv, flags); - free(data); - return result; + + if (!g_transport_map_) { + RTC_LOG(LS_ERROR) + << "OnSctpInboundPacket called after usrsctp uninitialized?"; + return kSctpErrorReturn; + } + // PostsToTransportThread protects against the transport being + // simultaneously deregistered/deleted, since this callback may come from + // the SCTP timer thread and thus race with the network thread. + bool found = g_transport_map_->PostToTransportThread( + *id, [owned_data{std::move(owned_data)}, length, rcv, + flags](UsrsctpTransport* transport) { + transport->OnDataOrNotificationFromSctp(owned_data.get(), length, rcv, + flags); + }); + if (!found) { + RTC_LOG(LS_ERROR) + << "OnSctpInboundPacket: Failed to get transport for socket ID " + << *id << "; possibly was already destroyed."; + return kSctpErrorReturn; + } + return kSctpSuccessReturn; } - static SctpTransport* GetTransportFromSocket(struct socket* sock) { + static absl::optional GetTransportIdFromSocket( + struct socket* sock) { + absl::optional ret; struct sockaddr* addrs = nullptr; int naddrs = usrsctp_getladdrs(sock, 0, &addrs); if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) { - return nullptr; + return ret; } // usrsctp_getladdrs() returns the addresses bound to this socket, which - // contains the SctpTransport id as sconn_addr. Read the id, + // contains the UsrsctpTransport id as sconn_addr. Read the id, // then free the list of addresses once we have the pointer. We only open // AF_CONN sockets, and they should all have the sconn_addr set to the // id of the transport that created them, so [0] is as good as any other. struct sockaddr_conn* sconn = reinterpret_cast(&addrs[0]); - if (!g_transport_map_) { - RTC_LOG(LS_ERROR) - << "GetTransportFromSocket called after usrsctp uninitialized?"; - usrsctp_freeladdrs(addrs); - return nullptr; - } - SctpTransport* transport = g_transport_map_->Retrieve( - reinterpret_cast(sconn->sconn_addr)); + ret = reinterpret_cast(sconn->sconn_addr); usrsctp_freeladdrs(addrs); - return transport; + return ret; } // TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove // when usrsctp is updated. static int SendThresholdCallback(struct socket* sock, uint32_t sb_free) { - // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets + // Fired on our I/O thread. 
UsrsctpTransport::OnPacketReceived() gets // a packet containing acknowledgments, which goes into usrsctp_conninput, // and then back here. - SctpTransport* transport = GetTransportFromSocket(sock); - if (!transport) { + absl::optional id = GetTransportIdFromSocket(sock); + if (!id) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback: Failed to get transport ID from socket " + << sock; + return 0; + } + if (!g_transport_map_) { RTC_LOG(LS_ERROR) - << "SendThresholdCallback: Failed to get transport for socket " - << sock << "; possibly was already destroyed."; + << "SendThresholdCallback called after usrsctp uninitialized?"; return 0; } - transport->OnSendThresholdCallback(); + bool found = g_transport_map_->PostToTransportThread( + *id, [](UsrsctpTransport* transport) { + transport->OnSendThresholdCallback(); + }); + if (!found) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback: Failed to get transport for socket ID " + << *id << "; possibly was already destroyed."; + } return 0; } static int SendThresholdCallback(struct socket* sock, uint32_t sb_free, void* ulp_info) { - // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets + // Fired on our I/O thread. UsrsctpTransport::OnPacketReceived() gets // a packet containing acknowledgments, which goes into usrsctp_conninput, // and then back here. - SctpTransport* transport = GetTransportFromSocket(sock); - if (!transport) { + absl::optional id = GetTransportIdFromSocket(sock); + if (!id) { RTC_LOG(LS_ERROR) - << "SendThresholdCallback: Failed to get transport for socket " - << sock << "; possibly was already destroyed."; + << "SendThresholdCallback: Failed to get transport ID from socket " + << sock; return 0; } - // Sanity check that both methods of getting the SctpTransport pointer - // yield the same result. - RTC_CHECK_EQ(transport, static_cast(ulp_info)); - transport->OnSendThresholdCallback(); + if (!g_transport_map_) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback called after usrsctp uninitialized?"; + return 0; + } + bool found = g_transport_map_->PostToTransportThread( + *id, [](UsrsctpTransport* transport) { + transport->OnSendThresholdCallback(); + }); + if (!found) { + RTC_LOG(LS_ERROR) + << "SendThresholdCallback: Failed to get transport for socket ID " + << *id << "; possibly was already destroyed."; + } return 0; } }; -SctpTransport::SctpTransport(rtc::Thread* network_thread, - rtc::PacketTransportInternal* transport) +UsrsctpTransport::UsrsctpTransport(rtc::Thread* network_thread, + rtc::PacketTransportInternal* transport) : network_thread_(network_thread), transport_(transport), was_ever_writable_(transport ? transport->writable() : false) { @@ -498,17 +613,17 @@ SctpTransport::SctpTransport(rtc::Thread* network_thread, ConnectTransportSignals(); } -SctpTransport::~SctpTransport() { +UsrsctpTransport::~UsrsctpTransport() { RTC_DCHECK_RUN_ON(network_thread_); // Close abruptly; no reset procedure. CloseSctpSocket(); // It's not strictly necessary to reset these fields to nullptr, // but having these fields set to nullptr is a clear indication that // object was destructed. There was a bug in usrsctp when it - // invoked OnSctpOutboundPacket callback for destructed SctpTransport, + // invoked OnSctpOutboundPacket callback for destructed UsrsctpTransport, // which caused obscure SIGSEGV on access to these fields, // having this fields set to nullptr will make it easier to understand - // that SctpTransport was destructed and "use-after-free" bug happen. 
+ // that UsrsctpTransport was destructed and "use-after-free" bug happen. // SIGSEGV error triggered on dereference these pointers will also // be easier to understand due to 0x0 address. All of this assumes // that ASAN is not enabled to detect "use-after-free", which is @@ -517,7 +632,8 @@ SctpTransport::~SctpTransport() { transport_ = nullptr; } -void SctpTransport::SetDtlsTransport(rtc::PacketTransportInternal* transport) { +void UsrsctpTransport::SetDtlsTransport( + rtc::PacketTransportInternal* transport) { RTC_DCHECK_RUN_ON(network_thread_); DisconnectTransportSignals(); transport_ = transport; @@ -533,9 +649,9 @@ void SctpTransport::SetDtlsTransport(rtc::PacketTransportInternal* transport) { } } -bool SctpTransport::Start(int local_sctp_port, - int remote_sctp_port, - int max_message_size) { +bool UsrsctpTransport::Start(int local_sctp_port, + int remote_sctp_port, + int max_message_size) { RTC_DCHECK_RUN_ON(network_thread_); if (local_sctp_port == -1) { local_sctp_port = kSctpDefaultPort; @@ -577,7 +693,7 @@ bool SctpTransport::Start(int local_sctp_port, return true; } -bool SctpTransport::OpenStream(int sid) { +bool UsrsctpTransport::OpenStream(int sid) { RTC_DCHECK_RUN_ON(network_thread_); if (sid > kMaxSctpSid) { RTC_LOG(LS_WARNING) << debug_name_ @@ -609,7 +725,7 @@ bool SctpTransport::OpenStream(int sid) { } } -bool SctpTransport::ResetStream(int sid) { +bool UsrsctpTransport::ResetStream(int sid) { RTC_DCHECK_RUN_ON(network_thread_); auto it = stream_status_by_sid_.find(sid); @@ -631,9 +747,10 @@ bool SctpTransport::ResetStream(int sid) { return true; } -bool SctpTransport::SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result) { +bool UsrsctpTransport::SendData(int sid, + const webrtc::SendDataParams& params, + const rtc::CopyOnWriteBuffer& payload, + SendDataResult* result) { RTC_DCHECK_RUN_ON(network_thread_); if (partial_outgoing_message_.has_value()) { @@ -644,8 +761,23 @@ bool SctpTransport::SendData(const SendDataParams& params, ready_to_send_data_ = false; return false; } + + // Do not queue data to send on a closing stream. 
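// Illustrative sketch: a caller of the new SendData() signature, where the
// stream id travels as a separate argument and the reliability knobs come from
// webrtc::SendDataParams (field names taken from the usage in this file; the
// wrapper function itself is invented for illustration).
#include "api/transport/data_channel_transport_interface.h"
#include "media/base/media_channel.h"
#include "media/sctp/sctp_transport_internal.h"
#include "rtc_base/copy_on_write_buffer.h"

bool SendTextMessage(cricket::SctpTransportInternal* sctp,
                     int sid,
                     const rtc::CopyOnWriteBuffer& payload) {
  webrtc::SendDataParams params;
  params.type = webrtc::DataMessageType::kText;
  params.ordered = true;
  // Leave max_rtx_count / max_rtx_ms unset for a fully reliable send.
  cricket::SendDataResult result = cricket::SDR_ERROR;
  if (!sctp->SendData(sid, params, payload, &result)) {
    // SDR_BLOCK indicates a full send buffer; wait for SignalReadyToSendData.
    return false;
  }
  return true;
}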
+ auto it = stream_status_by_sid_.find(sid); + if (it == stream_status_by_sid_.end() || !it->second.is_open()) { + RTC_LOG(LS_WARNING) + << debug_name_ + << "->SendData(...): " + "Not sending data because sid is unknown or closing: " + << sid; + if (result) { + *result = SDR_ERROR; + } + return false; + } + size_t payload_size = payload.size(); - OutgoingMessage message(payload, params); + OutgoingMessage message(payload, sid, params); SendDataResult send_message_result = SendMessageInternal(&message); if (result) { *result = send_message_result; @@ -668,24 +800,23 @@ bool SctpTransport::SendData(const SendDataParams& params, return true; } -SendDataResult SctpTransport::SendMessageInternal(OutgoingMessage* message) { +SendDataResult UsrsctpTransport::SendMessageInternal(OutgoingMessage* message) { RTC_DCHECK_RUN_ON(network_thread_); if (!sock_) { RTC_LOG(LS_WARNING) << debug_name_ << "->SendMessageInternal(...): " "Not sending packet with sid=" - << message->send_params().sid - << " len=" << message->size() << " before Start()."; + << message->sid() << " len=" << message->size() + << " before Start()."; return SDR_ERROR; } - if (message->send_params().type != DMT_CONTROL) { - auto it = stream_status_by_sid_.find(message->send_params().sid); - if (it == stream_status_by_sid_.end() || !it->second.is_open()) { - RTC_LOG(LS_WARNING) - << debug_name_ - << "->SendMessageInternal(...): " - "Not sending data because sid is unknown or closing: " - << message->send_params().sid; + if (message->send_params().type != webrtc::DataMessageType::kControl) { + auto it = stream_status_by_sid_.find(message->sid()); + if (it == stream_status_by_sid_.end()) { + RTC_LOG(LS_WARNING) << debug_name_ + << "->SendMessageInternal(...): " + "Not sending data because sid is unknown: " + << message->sid(); return SDR_ERROR; } } @@ -697,13 +828,23 @@ SendDataResult SctpTransport::SendMessageInternal(OutgoingMessage* message) { } // Send data using SCTP. - sctp_sendv_spa spa = CreateSctpSendParams(message->send_params()); + sctp_sendv_spa spa = CreateSctpSendParams( + message->sid(), message->send_params(), message->size()); + const void* data = message->data(); + size_t data_length = message->size(); + if (message->size() == 0) { + // Empty messages are replaced by a single NUL byte on the wire as SCTP + // doesn't support empty messages. + // The PPID carries the information that the payload needs to be ignored. + data = kZero; + data_length = 1; + } // Note: this send call is not atomic because the EOR bit is set. This means // that usrsctp can partially accept this message and it is our duty to buffer // the rest. - ssize_t send_res = usrsctp_sendv( - sock_, message->data(), message->size(), NULL, 0, &spa, - rtc::checked_cast(sizeof(spa)), SCTP_SENDV_SPA, 0); + ssize_t send_res = usrsctp_sendv(sock_, data, data_length, NULL, 0, &spa, + rtc::checked_cast(sizeof(spa)), + SCTP_SENDV_SPA, 0); if (send_res < 0) { if (errno == SCTP_EWOULDBLOCK) { ready_to_send_data_ = false; @@ -719,29 +860,30 @@ SendDataResult SctpTransport::SendMessageInternal(OutgoingMessage* message) { } size_t amount_sent = static_cast(send_res); - RTC_DCHECK_LE(amount_sent, message->size()); - message->Advance(amount_sent); + RTC_DCHECK_LE(amount_sent, data_length); + if (message->size() != 0) + message->Advance(amount_sent); // Only way out now is success. 
return SDR_SUCCESS; } -bool SctpTransport::ReadyToSendData() { +bool UsrsctpTransport::ReadyToSendData() { RTC_DCHECK_RUN_ON(network_thread_); return ready_to_send_data_; } -void SctpTransport::ConnectTransportSignals() { +void UsrsctpTransport::ConnectTransportSignals() { RTC_DCHECK_RUN_ON(network_thread_); if (!transport_) { return; } transport_->SignalWritableState.connect(this, - &SctpTransport::OnWritableState); - transport_->SignalReadPacket.connect(this, &SctpTransport::OnPacketRead); - transport_->SignalClosed.connect(this, &SctpTransport::OnClosed); + &UsrsctpTransport::OnWritableState); + transport_->SignalReadPacket.connect(this, &UsrsctpTransport::OnPacketRead); + transport_->SignalClosed.connect(this, &UsrsctpTransport::OnClosed); } -void SctpTransport::DisconnectTransportSignals() { +void UsrsctpTransport::DisconnectTransportSignals() { RTC_DCHECK_RUN_ON(network_thread_); if (!transport_) { return; @@ -751,7 +893,7 @@ void SctpTransport::DisconnectTransportSignals() { transport_->SignalClosed.disconnect(this); } -bool SctpTransport::Connect() { +bool UsrsctpTransport::Connect() { RTC_DCHECK_RUN_ON(network_thread_); RTC_LOG(LS_VERBOSE) << debug_name_ << "->Connect()."; @@ -814,7 +956,7 @@ bool SctpTransport::Connect() { return true; } -bool SctpTransport::OpenSctpSocket() { +bool UsrsctpTransport::OpenSctpSocket() { RTC_DCHECK_RUN_ON(network_thread_); if (sock_) { RTC_LOG(LS_WARNING) << debug_name_ @@ -857,7 +999,7 @@ bool SctpTransport::OpenSctpSocket() { return true; } -bool SctpTransport::ConfigureSctpSocket() { +bool UsrsctpTransport::ConfigureSctpSocket() { RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK(sock_); // Make the socket non-blocking. Connect, close, shutdown etc will not block @@ -938,7 +1080,7 @@ bool SctpTransport::ConfigureSctpSocket() { return true; } -void SctpTransport::CloseSctpSocket() { +void UsrsctpTransport::CloseSctpSocket() { RTC_DCHECK_RUN_ON(network_thread_); if (sock_) { // We assume that SO_LINGER option is set to close the association when @@ -953,16 +1095,22 @@ void SctpTransport::CloseSctpSocket() { } } -bool SctpTransport::SendQueuedStreamResets() { +bool UsrsctpTransport::SendQueuedStreamResets() { RTC_DCHECK_RUN_ON(network_thread_); + auto needs_reset = + [this](const std::map::value_type& stream) { + // Ignore streams with partial outgoing messages as they are required to + // be fully sent by the WebRTC spec + // https://w3c.github.io/webrtc-pc/#closing-procedure + return stream.second.need_outgoing_reset() && + (!partial_outgoing_message_.has_value() || + partial_outgoing_message_.value().sid() != + static_cast(stream.first)); + }; // Figure out how many streams need to be reset. We need to do this so we can // allocate the right amount of memory for the sctp_reset_streams structure. - size_t num_streams = absl::c_count_if( - stream_status_by_sid_, - [](const std::map::value_type& stream) { - return stream.second.need_outgoing_reset(); - }); + size_t num_streams = absl::c_count_if(stream_status_by_sid_, needs_reset); if (num_streams == 0) { // Nothing to reset. 
return true; @@ -981,12 +1129,10 @@ bool SctpTransport::SendQueuedStreamResets() { resetp->srs_number_streams = rtc::checked_cast(num_streams); int result_idx = 0; - for (const std::map::value_type& stream : - stream_status_by_sid_) { - if (!stream.second.need_outgoing_reset()) { - continue; + for (const auto& stream : stream_status_by_sid_) { + if (needs_reset(stream)) { + resetp->srs_stream_list[result_idx++] = stream.first; } - resetp->srs_stream_list[result_idx++] = stream.first; } int ret = @@ -1015,7 +1161,7 @@ bool SctpTransport::SendQueuedStreamResets() { return true; } -void SctpTransport::SetReadyToSendData() { +void UsrsctpTransport::SetReadyToSendData() { RTC_DCHECK_RUN_ON(network_thread_); if (!ready_to_send_data_) { ready_to_send_data_ = true; @@ -1023,7 +1169,7 @@ void SctpTransport::SetReadyToSendData() { } } -bool SctpTransport::SendBufferedMessage() { +bool UsrsctpTransport::SendBufferedMessage() { RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK(partial_outgoing_message_.has_value()); RTC_DLOG(LS_VERBOSE) << "Sending partially buffered message of size " @@ -1035,11 +1181,21 @@ bool SctpTransport::SendBufferedMessage() { return false; } RTC_DCHECK_EQ(0u, partial_outgoing_message_->size()); + + int sid = partial_outgoing_message_->sid(); partial_outgoing_message_.reset(); + + // Send the queued stream reset if it was pending for this stream. + auto it = stream_status_by_sid_.find(sid); + if (it->second.need_outgoing_reset()) { + SendQueuedStreamResets(); + } + return true; } -void SctpTransport::OnWritableState(rtc::PacketTransportInternal* transport) { +void UsrsctpTransport::OnWritableState( + rtc::PacketTransportInternal* transport) { RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK_EQ(transport_, transport); if (!was_ever_writable_ && transport->writable()) { @@ -1051,14 +1207,14 @@ void SctpTransport::OnWritableState(rtc::PacketTransportInternal* transport) { } // Called by network interface when a packet has been received. -void SctpTransport::OnPacketRead(rtc::PacketTransportInternal* transport, - const char* data, - size_t len, - const int64_t& /* packet_time_us */, - int flags) { +void UsrsctpTransport::OnPacketRead(rtc::PacketTransportInternal* transport, + const char* data, + size_t len, + const int64_t& /* packet_time_us */, + int flags) { RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK_EQ(transport_, transport); - TRACE_EVENT0("webrtc", "SctpTransport::OnPacketRead"); + TRACE_EVENT0("webrtc", "UsrsctpTransport::OnPacketRead"); if (flags & PF_SRTP_BYPASS) { // We are only interested in SCTP packets. 
@@ -1085,11 +1241,15 @@ void SctpTransport::OnPacketRead(rtc::PacketTransportInternal* transport, } } -void SctpTransport::OnClosed(rtc::PacketTransportInternal* transport) { - SignalClosedAbruptly(); +void UsrsctpTransport::OnClosed(rtc::PacketTransportInternal* transport) { + webrtc::RTCError error = + webrtc::RTCError(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, + "Transport channel closed"); + error.set_error_detail(webrtc::RTCErrorDetailType::SCTP_FAILURE); + SignalClosedAbruptly(error); } -void SctpTransport::OnSendThresholdCallback() { +void UsrsctpTransport::OnSendThresholdCallback() { RTC_DCHECK_RUN_ON(network_thread_); if (partial_outgoing_message_.has_value()) { if (!SendBufferedMessage()) { @@ -1100,7 +1260,7 @@ void SctpTransport::OnSendThresholdCallback() { SetReadyToSendData(); } -sockaddr_conn SctpTransport::GetSctpSockAddr(int port) { +sockaddr_conn UsrsctpTransport::GetSctpSockAddr(int port) { sockaddr_conn sconn = {0}; sconn.sconn_family = AF_CONN; #ifdef HAVE_SCONN_LEN @@ -1112,7 +1272,7 @@ sockaddr_conn SctpTransport::GetSctpSockAddr(int port) { return sconn; } -void SctpTransport::OnPacketFromSctpToNetwork( +void UsrsctpTransport::OnPacketFromSctpToNetwork( const rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK_RUN_ON(network_thread_); if (buffer.size() > (kSctpMtu)) { @@ -1122,7 +1282,7 @@ void SctpTransport::OnPacketFromSctpToNetwork( "than its official MTU: " << buffer.size() << " vs max of " << kSctpMtu; } - TRACE_EVENT0("webrtc", "SctpTransport::OnPacketFromSctpToNetwork"); + TRACE_EVENT0("webrtc", "UsrsctpTransport::OnPacketFromSctpToNetwork"); // Don't create noise by trying to send a packet when the DTLS transport isn't // even writable. @@ -1135,24 +1295,25 @@ void SctpTransport::OnPacketFromSctpToNetwork( rtc::PacketOptions(), PF_NORMAL); } -int SctpTransport::InjectDataOrNotificationFromSctpForTesting( +void UsrsctpTransport::InjectDataOrNotificationFromSctpForTesting( const void* data, size_t length, struct sctp_rcvinfo rcv, int flags) { - return OnDataOrNotificationFromSctp(data, length, rcv, flags); + OnDataOrNotificationFromSctp(data, length, rcv, flags); } -int SctpTransport::OnDataOrNotificationFromSctp(const void* data, - size_t length, - struct sctp_rcvinfo rcv, - int flags) { +void UsrsctpTransport::OnDataOrNotificationFromSctp(const void* data, + size_t length, + struct sctp_rcvinfo rcv, + int flags) { + RTC_DCHECK_RUN_ON(network_thread_); // If data is NULL, the SCTP association has been closed. if (!data) { RTC_LOG(LS_INFO) << debug_name_ << "->OnDataOrNotificationFromSctp(...): " "No data; association closed."; - return kSctpSuccessReturn; + return; } // Handle notifications early. @@ -1165,14 +1326,10 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, << "->OnDataOrNotificationFromSctp(...): SCTP notification" << " length=" << length; - // Copy and dispatch asynchronously rtc::CopyOnWriteBuffer notification(reinterpret_cast(data), length); - network_thread_->PostTask(ToQueuedTask( - task_safety_, [this, notification = std::move(notification)]() { - OnNotificationFromSctp(notification); - })); - return kSctpSuccessReturn; + OnNotificationFromSctp(notification); + return; } // Log data chunk @@ -1185,12 +1342,12 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, << ", eor=" << ((flags & MSG_EOR) ? 
"y" : "n"); // Validate payload protocol identifier - DataMessageType type = DMT_NONE; + webrtc::DataMessageType type; if (!GetDataMediaType(ppid, &type)) { // Unexpected PPID, dropping RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid << " on an SCTP packet. Dropping."; - return kSctpSuccessReturn; + return; } // Expect only continuation messages belonging to the same SID. The SCTP @@ -1212,12 +1369,13 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, // Furthermore, it is increased per stream and not on the whole // association. params.seq_num = rcv.rcv_ssn; - // There is no timestamp field in the SCTP API - params.timestamp = 0; - // Append the chunk's data to the message buffer - partial_incoming_message_.AppendData(reinterpret_cast(data), - length); + // Append the chunk's data to the message buffer unless we have a chunk with a + // PPID marking an empty message. + // See: https://tools.ietf.org/html/rfc8831#section-6.6 + if (!IsEmptyPPID(ppid)) + partial_incoming_message_.AppendData(reinterpret_cast(data), + length); partial_params_ = params; partial_flags_ = flags; @@ -1226,7 +1384,7 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, if (partial_incoming_message_.size() < kSctpSendBufferSize) { // We still have space in the buffer. Continue buffering chunks until // the message is complete before handing it out. - return kSctpSuccessReturn; + return; } else { // The sender is exceeding the maximum message size that we announced. // Spit out a warning but still hand out the partial message. Note that @@ -1240,21 +1398,12 @@ int SctpTransport::OnDataOrNotificationFromSctp(const void* data, } } - // Dispatch the complete message. - // The ownership of the packet transfers to |invoker_|. Using - // CopyOnWriteBuffer is the most convenient way to do this. - network_thread_->PostTask(webrtc::ToQueuedTask( - task_safety_, [this, params = std::move(params), - message = partial_incoming_message_]() { - OnDataFromSctpToTransport(params, message); - })); - - // Reset the message buffer + // Dispatch the complete message and reset the message buffer. + OnDataFromSctpToTransport(params, partial_incoming_message_); partial_incoming_message_.Clear(); - return kSctpSuccessReturn; } -void SctpTransport::OnDataFromSctpToTransport( +void UsrsctpTransport::OnDataFromSctpToTransport( const ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK_RUN_ON(network_thread_); @@ -1267,7 +1416,7 @@ void SctpTransport::OnDataFromSctpToTransport( SignalDataReceived(params, buffer); } -void SctpTransport::OnNotificationFromSctp( +void UsrsctpTransport::OnNotificationFromSctp( const rtc::CopyOnWriteBuffer& buffer) { RTC_DCHECK_RUN_ON(network_thread_); if (buffer.size() < sizeof(sctp_notification::sn_header)) { @@ -1368,7 +1517,8 @@ void SctpTransport::OnNotificationFromSctp( } } -void SctpTransport::OnNotificationAssocChange(const sctp_assoc_change& change) { +void UsrsctpTransport::OnNotificationAssocChange( + const sctp_assoc_change& change) { RTC_DCHECK_RUN_ON(network_thread_); switch (change.sac_state) { case SCTP_COMM_UP: @@ -1382,9 +1532,17 @@ void SctpTransport::OnNotificationAssocChange(const sctp_assoc_change& change) { // came up, send any queued resets. 
SendQueuedStreamResets(); break; - case SCTP_COMM_LOST: + case SCTP_COMM_LOST: { RTC_LOG(LS_INFO) << "Association change SCTP_COMM_LOST"; + webrtc::RTCError error = webrtc::RTCError( + webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, + SctpErrorCauseCodeToString( + static_cast(change.sac_error))); + error.set_error_detail(webrtc::RTCErrorDetailType::SCTP_FAILURE); + error.set_sctp_cause_code(change.sac_error); + SignalClosedAbruptly(error); break; + } case SCTP_RESTART: RTC_LOG(LS_INFO) << "Association change SCTP_RESTART"; break; @@ -1400,7 +1558,7 @@ void SctpTransport::OnNotificationAssocChange(const sctp_assoc_change& change) { } } -void SctpTransport::OnStreamResetEvent( +void UsrsctpTransport::OnStreamResetEvent( const struct sctp_stream_reset_event* evt) { RTC_DCHECK_RUN_ON(network_thread_); diff --git a/media/sctp/sctp_transport.h b/media/sctp/usrsctp_transport.h similarity index 85% rename from media/sctp/sctp_transport.h rename to media/sctp/usrsctp_transport.h index 38a89fcb61..5dcf57b243 100644 --- a/media/sctp/sctp_transport.h +++ b/media/sctp/usrsctp_transport.h @@ -8,8 +8,8 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef MEDIA_SCTP_SCTP_TRANSPORT_H_ -#define MEDIA_SCTP_SCTP_TRANSPORT_H_ +#ifndef MEDIA_SCTP_USRSCTP_TRANSPORT_H_ +#define MEDIA_SCTP_USRSCTP_TRANSPORT_H_ #include @@ -21,7 +21,6 @@ #include #include "absl/types/optional.h" -#include "api/transport/sctp_transport_factory_interface.h" #include "rtc_base/buffer.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/copy_on_write_buffer.h" @@ -66,23 +65,24 @@ struct SctpInboundPacket; // 12. SctpTransport::SignalDataReceived(data) // [from the same thread, methods registered/connected to // SctpTransport are called with the recieved data] -class SctpTransport : public SctpTransportInternal, - public sigslot::has_slots<> { +class UsrsctpTransport : public SctpTransportInternal, + public sigslot::has_slots<> { public: // |network_thread| is where packets will be processed and callbacks from // this transport will be posted, and is the only thread on which public // methods can be called. // |transport| is not required (can be null). - SctpTransport(rtc::Thread* network_thread, - rtc::PacketTransportInternal* transport); - ~SctpTransport() override; + UsrsctpTransport(rtc::Thread* network_thread, + rtc::PacketTransportInternal* transport); + ~UsrsctpTransport() override; // SctpTransportInternal overrides (see sctptransportinternal.h for comments). void SetDtlsTransport(rtc::PacketTransportInternal* transport) override; bool Start(int local_port, int remote_port, int max_message_size) override; bool OpenStream(int sid) override; bool ResetStream(int sid) override; - bool SendData(const SendDataParams& params, + bool SendData(int sid, + const webrtc::SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, SendDataResult* result = nullptr) override; bool ReadyToSendData() override; @@ -96,10 +96,10 @@ class SctpTransport : public SctpTransportInternal, void set_debug_name_for_testing(const char* debug_name) override { debug_name_ = debug_name; } - int InjectDataOrNotificationFromSctpForTesting(const void* data, - size_t length, - struct sctp_rcvinfo rcv, - int flags); + void InjectDataOrNotificationFromSctpForTesting(const void* data, + size_t length, + struct sctp_rcvinfo rcv, + int flags); // Exposed to allow Post call from c-callbacks. // TODO(deadbeef): Remove this or at least make it return a const pointer. 
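
A hedged sketch of a consumer of the richer close signal introduced above, assuming SignalClosedAbruptly now carries the webrtc::RTCError that OnClosed() and the SCTP_COMM_LOST branch construct; the observer class is illustrative only:

    #include "api/rtc_error.h"
    #include "media/sctp/usrsctp_transport.h"
    #include "rtc_base/logging.h"
    #include "rtc_base/third_party/sigslot/sigslot.h"

    class AbruptCloseObserver : public sigslot::has_slots<> {
     public:
      explicit AbruptCloseObserver(cricket::UsrsctpTransport* transport) {
        transport->SignalClosedAbruptly.connect(
            this, &AbruptCloseObserver::OnClosedAbruptly);
      }

     private:
      void OnClosedAbruptly(webrtc::RTCError error) {
        // Expected: OPERATION_ERROR_WITH_DATA with error_detail() ==
        // SCTP_FAILURE; the SCTP cause code is attached for SCTP_COMM_LOST.
        RTC_LOG(LS_WARNING) << "SCTP closed abruptly: " << error.message();
        if (error.sctp_cause_code().has_value()) {
          RTC_LOG(LS_WARNING) << "SCTP cause code: "
                              << *error.sctp_cause_code();
        }
      }
    };
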
@@ -114,8 +114,9 @@ class SctpTransport : public SctpTransportInternal, class OutgoingMessage { public: OutgoingMessage(const rtc::CopyOnWriteBuffer& buffer, - const SendDataParams& send_params) - : buffer_(buffer), send_params_(send_params) {} + int sid, + const webrtc::SendDataParams& send_params) + : buffer_(buffer), sid_(sid), send_params_(send_params) {} // Advances the buffer by the incremented amount. Must not advance further // than the current data size. @@ -128,11 +129,13 @@ class SctpTransport : public SctpTransportInternal, const void* data() const { return buffer_.data() + offset_; } - SendDataParams send_params() const { return send_params_; } + int sid() const { return sid_; } + webrtc::SendDataParams send_params() const { return send_params_; } private: const rtc::CopyOnWriteBuffer buffer_; - const SendDataParams send_params_; + int sid_; + const webrtc::SendDataParams send_params_; size_t offset_ = 0; }; @@ -180,12 +183,12 @@ class SctpTransport : public SctpTransportInternal, // Called using |invoker_| to send packet on the network. void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer); - // Called on the SCTP thread. + // Called on the network thread. // Flags are standard socket API flags (RFC 6458). - int OnDataOrNotificationFromSctp(const void* data, - size_t length, - struct sctp_rcvinfo rcv, - int flags); + void OnDataOrNotificationFromSctp(const void* data, + size_t length, + struct sctp_rcvinfo rcv, + int flags); // Called using |invoker_| to decide what to do with the data. void OnDataFromSctpToTransport(const ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& buffer); @@ -270,7 +273,7 @@ class SctpTransport : public SctpTransportInternal, std::map stream_status_by_sid_; // A static human-readable name for debugging messages. - const char* debug_name_ = "SctpTransport"; + const char* debug_name_ = "UsrsctpTransport"; // Hides usrsctp interactions from this header file. class UsrSctpWrapper; // Number of channels negotiated. Not set before negotiation completes. @@ -281,24 +284,13 @@ class SctpTransport : public SctpTransportInternal, // various callbacks. uintptr_t id_ = 0; - RTC_DISALLOW_COPY_AND_ASSIGN(SctpTransport); -}; - -class SctpTransportFactory : public webrtc::SctpTransportFactoryInterface { - public: - explicit SctpTransportFactory(rtc::Thread* network_thread) - : network_thread_(network_thread) {} - - std::unique_ptr CreateSctpTransport( - rtc::PacketTransportInternal* transport) override { - return std::unique_ptr( - new SctpTransport(network_thread_, transport)); - } + friend class UsrsctpTransportMap; - private: - rtc::Thread* network_thread_; + RTC_DISALLOW_COPY_AND_ASSIGN(UsrsctpTransport); }; +class UsrsctpTransportMap; + } // namespace cricket -#endif // MEDIA_SCTP_SCTP_TRANSPORT_H_ +#endif // MEDIA_SCTP_USRSCTP_TRANSPORT_H_ diff --git a/media/sctp/sctp_transport_reliability_unittest.cc b/media/sctp/usrsctp_transport_reliability_unittest.cc similarity index 88% rename from media/sctp/sctp_transport_reliability_unittest.cc rename to media/sctp/usrsctp_transport_reliability_unittest.cc index 80b7d61215..104e320398 100644 --- a/media/sctp/sctp_transport_reliability_unittest.cc +++ b/media/sctp/usrsctp_transport_reliability_unittest.cc @@ -7,19 +7,20 @@ * in the file PATENTS. All contributing project authors may * be found in the AUTHORS file in the root of the source tree. 
*/ -#include "media/sctp/sctp_transport.h" - #include #include #include #include "media/sctp/sctp_transport_internal.h" -#include "rtc_base/async_invoker.h" +#include "media/sctp/usrsctp_transport.h" #include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/event.h" #include "rtc_base/gunit.h" #include "rtc_base/logging.h" #include "rtc_base/random.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" #include "test/gtest.h" @@ -53,11 +54,6 @@ class SimulatedPacketTransport final : public rtc::PacketTransportInternal { ~SimulatedPacketTransport() override { RTC_DCHECK_RUN_ON(transport_thread_); - auto destination = destination_.load(); - if (destination != nullptr) { - invoker_.Flush(destination->transport_thread_); - } - invoker_.Flush(transport_thread_); destination_ = nullptr; SignalWritableState(this); } @@ -82,15 +78,13 @@ class SimulatedPacketTransport final : public rtc::PacketTransportInternal { return 0; } rtc::CopyOnWriteBuffer buffer(data, len); - auto send_job = [this, flags, buffer = std::move(buffer)] { - auto destination = destination_.load(); - if (destination == nullptr) { - return; - } - destination->SignalReadPacket( - destination, reinterpret_cast(buffer.data()), - buffer.size(), rtc::Time(), flags); - }; + auto send_task = ToQueuedTask( + destination->task_safety_.flag(), + [destination, flags, buffer = std::move(buffer)] { + destination->SignalReadPacket( + destination, reinterpret_cast(buffer.data()), + buffer.size(), rtc::Time(), flags); + }); // Introduce random send delay in range [0 .. 2 * avg_send_delay_millis_] // millis, which will also work as random packet reordering mechanism. uint16_t actual_send_delay = avg_send_delay_millis_; @@ -100,12 +94,10 @@ class SimulatedPacketTransport final : public rtc::PacketTransportInternal { actual_send_delay += reorder_delay; if (actual_send_delay > 0) { - invoker_.AsyncInvokeDelayed(RTC_FROM_HERE, - destination->transport_thread_, - std::move(send_job), actual_send_delay); + destination->transport_thread_->PostDelayedTask(std::move(send_task), + actual_send_delay); } else { - invoker_.AsyncInvoke(RTC_FROM_HERE, destination->transport_thread_, - std::move(send_job)); + destination->transport_thread_->PostTask(std::move(send_task)); } return 0; } @@ -135,29 +127,25 @@ class SimulatedPacketTransport final : public rtc::PacketTransportInternal { const uint8_t packet_loss_percents_; const uint16_t avg_send_delay_millis_; std::atomic destination_ ATOMIC_VAR_INIT(nullptr); - rtc::AsyncInvoker invoker_; webrtc::Random random_; + webrtc::ScopedTaskSafety task_safety_; RTC_DISALLOW_COPY_AND_ASSIGN(SimulatedPacketTransport); }; /** - * A helper class to send specified number of messages - * over SctpTransport with SCTP reliability settings - * provided by user. The reliability settings are specified - * by passing a template instance of SendDataParams. - * When .sid field inside SendDataParams is specified to - * negative value it means that actual .sid will be - * assigned by sender itself, .sid will be assigned from - * range [cricket::kMinSctpSid; cricket::kMaxSctpSid]. - * The wide range of sids are used to possibly trigger - * more execution paths inside usrsctp. + * A helper class to send specified number of messages over UsrsctpTransport + * with SCTP reliability settings provided by user. The reliability settings are + * specified by passing a template instance of SendDataParams. 
The sid will be + * assigned by sender itself and will be assigned from range + * [cricket::kMinSctpSid; cricket::kMaxSctpSid]. The wide range of sids are used + * to possibly trigger more execution paths inside usrsctp. */ class SctpDataSender final { public: SctpDataSender(rtc::Thread* thread, - cricket::SctpTransport* transport, + cricket::UsrsctpTransport* transport, uint64_t target_messages_count, - cricket::SendDataParams send_params, + webrtc::SendDataParams send_params, uint32_t sender_id) : thread_(thread), transport_(transport), @@ -169,14 +157,14 @@ class SctpDataSender final { } void Start() { - invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, [this] { + thread_->PostTask(ToQueuedTask(task_safety_.flag(), [this] { if (started_) { RTC_LOG(LS_INFO) << sender_id_ << " sender is already started"; return; } started_ = true; SendNextMessage(); - }); + })); } uint64_t BytesSentCount() const { return num_bytes_sent_; } @@ -208,55 +196,52 @@ class SctpDataSender final { << target_messages_count_; } - cricket::SendDataParams params(send_params_); - if (params.sid < 0) { - params.sid = cricket::kMinSctpSid + - (num_messages_sent_ % cricket::kMaxSctpStreams); - } + webrtc::SendDataParams params(send_params_); + int sid = + cricket::kMinSctpSid + (num_messages_sent_ % cricket::kMaxSctpStreams); cricket::SendDataResult result; - transport_->SendData(params, payload_, &result); + transport_->SendData(sid, params, payload_, &result); switch (result) { case cricket::SDR_BLOCK: // retry after timeout - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, thread_, - rtc::Bind(&SctpDataSender::SendNextMessage, this), 500); + thread_->PostDelayedTask( + ToQueuedTask(task_safety_.flag(), [this] { SendNextMessage(); }), + 500); break; case cricket::SDR_SUCCESS: // send next num_bytes_sent_ += payload_.size(); ++num_messages_sent_; - invoker_.AsyncInvoke( - RTC_FROM_HERE, thread_, - rtc::Bind(&SctpDataSender::SendNextMessage, this)); + thread_->PostTask( + ToQueuedTask(task_safety_.flag(), [this] { SendNextMessage(); })); break; case cricket::SDR_ERROR: // give up - last_error_ = "SctpTransport::SendData error returned"; + last_error_ = "UsrsctpTransport::SendData error returned"; sent_target_messages_count_.Set(); break; } } rtc::Thread* const thread_; - cricket::SctpTransport* const transport_; + cricket::UsrsctpTransport* const transport_; const uint64_t target_messages_count_; - const cricket::SendDataParams send_params_; + const webrtc::SendDataParams send_params_; const uint32_t sender_id_; rtc::CopyOnWriteBuffer payload_{std::string(1400, '.').c_str(), 1400}; std::atomic started_ ATOMIC_VAR_INIT(false); - rtc::AsyncInvoker invoker_; std::atomic num_messages_sent_ ATOMIC_VAR_INIT(0); rtc::Event sent_target_messages_count_{true, false}; std::atomic num_bytes_sent_ ATOMIC_VAR_INIT(0); absl::optional last_error_; + webrtc::ScopedTaskSafetyDetached task_safety_; RTC_DISALLOW_COPY_AND_ASSIGN(SctpDataSender); }; /** * A helper class which counts number of received messages - * and bytes over SctpTransport. Also allow waiting until + * and bytes over UsrsctpTransport. Also allow waiting until * specified number of messages received. */ class SctpDataReceiver final : public sigslot::has_slots<> { @@ -323,7 +308,7 @@ class ThreadPool final { }; /** - * Represents single ping-pong test over SctpTransport. + * Represents single ping-pong test over UsrsctpTransport. 
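
The test helpers in this file drop rtc::AsyncInvoker in favor of tasks posted through ToQueuedTask() guarded by a task-safety flag; a small standalone sketch of that pattern, with illustrative class and member names:

    #include "rtc_base/task_utils/pending_task_safety_flag.h"
    #include "rtc_base/task_utils/to_queued_task.h"
    #include "rtc_base/thread.h"

    // Re-posts itself on |thread| until destroyed; queued tasks are dropped
    // (not run) once |task_safety_| goes out of scope, which is what makes the
    // old explicit invoker Flush() calls unnecessary.
    class PeriodicSender {
     public:
      explicit PeriodicSender(rtc::Thread* thread) : thread_(thread) {}

      void Start() {
        thread_->PostTask(
            webrtc::ToQueuedTask(task_safety_.flag(), [this] { SendNext(); }));
      }

     private:
      void SendNext() {
        // ... do one unit of work, then schedule the next attempt.
        thread_->PostDelayedTask(
            webrtc::ToQueuedTask(task_safety_.flag(), [this] { SendNext(); }),
            /*milliseconds=*/500);
      }

      rtc::Thread* const thread_;
      webrtc::ScopedTaskSafety task_safety_;
    };
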
* User can specify target number of message for bidirectional * send, underlying transport packets loss and average packet delay * and SCTP delivery settings. @@ -338,7 +323,7 @@ class SctpPingPong final { uint32_t messages_count, uint8_t packet_loss_percents, uint16_t avg_send_delay_millis, - cricket::SendDataParams send_params) + webrtc::SendDataParams send_params) : id_(id), port1_(port1), port2_(port2), @@ -505,7 +490,7 @@ class SctpPingPong final { "SctpPingPong id = " + rtc::ToString(id_) + ", packet transport 1", transport_thread1_, packet_loss_percents_, avg_send_delay_millis_)); data_receiver1_.reset(new SctpDataReceiver(id_, messages_count_)); - sctp_transport1_.reset(new cricket::SctpTransport( + sctp_transport1_.reset(new cricket::UsrsctpTransport( transport_thread1_, packet_transport1_.get())); sctp_transport1_->set_debug_name_for_testing("sctp transport 1"); @@ -527,7 +512,7 @@ class SctpPingPong final { "SctpPingPong id = " + rtc::ToString(id_) + "packet transport 2", transport_thread2_, packet_loss_percents_, avg_send_delay_millis_)); data_receiver2_.reset(new SctpDataReceiver(id_, messages_count_)); - sctp_transport2_.reset(new cricket::SctpTransport( + sctp_transport2_.reset(new cricket::UsrsctpTransport( transport_thread2_, packet_transport2_.get())); sctp_transport2_->set_debug_name_for_testing("sctp transport 2"); sctp_transport2_->SignalDataReceived.connect( @@ -576,8 +561,8 @@ class SctpPingPong final { std::unique_ptr packet_transport2_; std::unique_ptr data_receiver1_; std::unique_ptr data_receiver2_; - std::unique_ptr sctp_transport1_; - std::unique_ptr sctp_transport2_; + std::unique_ptr sctp_transport1_; + std::unique_ptr sctp_transport2_; std::unique_ptr data_sender1_; std::unique_ptr data_sender2_; mutable webrtc::Mutex lock_; @@ -591,7 +576,7 @@ class SctpPingPong final { const uint32_t messages_count_; const uint8_t packet_loss_percents_; const uint16_t avg_send_delay_millis_; - const cricket::SendDataParams send_params_; + const webrtc::SendDataParams send_params_; RTC_DISALLOW_COPY_AND_ASSIGN(SctpPingPong); }; @@ -652,12 +637,8 @@ TEST_F(UsrSctpReliabilityTest, static_assert(wait_timeout > 0, "Timeout computation must produce positive value"); - cricket::SendDataParams send_params; - send_params.sid = -1; + webrtc::SendDataParams send_params; send_params.ordered = true; - send_params.reliable = true; - send_params.max_rtx_count = 0; - send_params.max_rtx_ms = 0; SctpPingPong test(1, kTransport1Port, kTransport2Port, thread1.get(), thread2.get(), messages_count, packet_loss_percents, @@ -690,12 +671,8 @@ TEST_F(UsrSctpReliabilityTest, static_assert(wait_timeout > 0, "Timeout computation must produce positive value"); - cricket::SendDataParams send_params; - send_params.sid = -1; + webrtc::SendDataParams send_params; send_params.ordered = true; - send_params.reliable = true; - send_params.max_rtx_count = 0; - send_params.max_rtx_ms = 0; SctpPingPong test(1, kTransport1Port, kTransport2Port, thread1.get(), thread2.get(), messages_count, packet_loss_percents, @@ -729,12 +706,10 @@ TEST_F(UsrSctpReliabilityTest, static_assert(wait_timeout > 0, "Timeout computation must produce positive value"); - cricket::SendDataParams send_params; - send_params.sid = -1; + webrtc::SendDataParams send_params; send_params.ordered = false; - send_params.reliable = false; - send_params.max_rtx_count = INT_MAX; - send_params.max_rtx_ms = INT_MAX; + send_params.max_rtx_count = std::numeric_limits::max(); + send_params.max_rtx_ms = std::numeric_limits::max(); SctpPingPong test(1, 
kTransport1Port, kTransport2Port, thread1.get(), thread2.get(), messages_count, packet_loss_percents, @@ -766,12 +741,8 @@ TEST_F(UsrSctpReliabilityTest, DISABLED_AllMessagesAreDeliveredOverLossyConnectionConcurrentTests) { ThreadPool pool(16); - cricket::SendDataParams send_params; - send_params.sid = -1; + webrtc::SendDataParams send_params; send_params.ordered = true; - send_params.reliable = true; - send_params.max_rtx_count = 0; - send_params.max_rtx_ms = 0; constexpr uint32_t base_sctp_port = 5000; // The constants value below were experimentally chosen diff --git a/media/sctp/sctp_transport_unittest.cc b/media/sctp/usrsctp_transport_unittest.cc similarity index 86% rename from media/sctp/sctp_transport_unittest.cc rename to media/sctp/usrsctp_transport_unittest.cc index 98a91225b2..59e9c59b3d 100644 --- a/media/sctp/sctp_transport_unittest.cc +++ b/media/sctp/usrsctp_transport_unittest.cc @@ -8,7 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "media/sctp/sctp_transport.h" +#include "media/sctp/usrsctp_transport.h" #include #include @@ -74,7 +74,7 @@ class SctpFakeDataReceiver : public sigslot::has_slots<> { class SctpTransportObserver : public sigslot::has_slots<> { public: - explicit SctpTransportObserver(SctpTransport* transport) { + explicit SctpTransportObserver(UsrsctpTransport* transport) { transport->SignalClosingProcedureComplete.connect( this, &SctpTransportObserver::OnClosingProcedureComplete); transport->SignalReadyToSendData.connect( @@ -105,7 +105,8 @@ class SctpTransportObserver : public sigslot::has_slots<> { // been closed. class SignalTransportClosedReopener : public sigslot::has_slots<> { public: - SignalTransportClosedReopener(SctpTransport* transport, SctpTransport* peer) + SignalTransportClosedReopener(UsrsctpTransport* transport, + UsrsctpTransport* peer) : transport_(transport), peer_(peer) {} int StreamCloseCount(int stream) { return absl::c_count(streams_, stream); } @@ -117,8 +118,8 @@ class SignalTransportClosedReopener : public sigslot::has_slots<> { streams_.push_back(stream); } - SctpTransport* transport_; - SctpTransport* peer_; + UsrsctpTransport* transport_; + UsrsctpTransport* peer_; std::vector streams_; }; @@ -169,27 +170,26 @@ class SctpTransportTest : public ::testing::Test, public sigslot::has_slots<> { return ret; } - SctpTransport* CreateTransport(FakeDtlsTransport* fake_dtls, - SctpFakeDataReceiver* recv) { - SctpTransport* transport = - new SctpTransport(rtc::Thread::Current(), fake_dtls); + UsrsctpTransport* CreateTransport(FakeDtlsTransport* fake_dtls, + SctpFakeDataReceiver* recv) { + UsrsctpTransport* transport = + new UsrsctpTransport(rtc::Thread::Current(), fake_dtls); // When data is received, pass it to the SctpFakeDataReceiver. 
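
These tests also show the slimmed-down webrtc::SendDataParams: the sid and reliable fields are gone, and partial reliability is expressed through the optional max_rtx_count / max_rtx_ms fields. A sketch of the two configurations the reliability tests use, assuming the field names as used above and the header location noted in the comment:

    #include <limits>

    // webrtc::SendDataParams is assumed to come from
    // api/transport/data_channel_transport_interface.h.
    #include "api/transport/data_channel_transport_interface.h"

    // Ordered and fully reliable: simply leave the retransmit limits unset.
    webrtc::SendDataParams MakeOrderedReliableParams() {
      webrtc::SendDataParams params;
      params.ordered = true;
      return params;
    }

    // Unordered, exercising the partial-reliability path while remaining
    // effectively reliable (limits set to the maximum representable values).
    webrtc::SendDataParams MakeUnorderedParams() {
      webrtc::SendDataParams params;
      params.ordered = false;
      params.max_rtx_count = std::numeric_limits<int>::max();
      params.max_rtx_ms = std::numeric_limits<int>::max();
      return params;
    }

The stream id itself is now an explicit SendData() argument, as in the fixture's SendData() helper above.
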
transport->SignalDataReceived.connect( recv, &SctpFakeDataReceiver::OnDataReceived); return transport; } - bool SendData(SctpTransport* chan, + bool SendData(UsrsctpTransport* chan, int sid, const std::string& msg, SendDataResult* result, bool ordered = false) { - SendDataParams params; - params.sid = sid; + webrtc::SendDataParams params; params.ordered = ordered; - return chan->SendData(params, rtc::CopyOnWriteBuffer(&msg[0], msg.length()), - result); + return chan->SendData( + sid, params, rtc::CopyOnWriteBuffer(&msg[0], msg.length()), result); } bool ReceivedData(const SctpFakeDataReceiver* recv, @@ -210,8 +210,8 @@ class SctpTransportTest : public ::testing::Test, public sigslot::has_slots<> { return !thread->IsQuitting(); } - SctpTransport* transport1() { return transport1_.get(); } - SctpTransport* transport2() { return transport2_.get(); } + UsrsctpTransport* transport1() { return transport1_.get(); } + UsrsctpTransport* transport2() { return transport2_.get(); } SctpFakeDataReceiver* receiver1() { return recv1_.get(); } SctpFakeDataReceiver* receiver2() { return recv2_.get(); } FakeDtlsTransport* fake_dtls1() { return fake_dtls1_.get(); } @@ -229,8 +229,8 @@ class SctpTransportTest : public ::testing::Test, public sigslot::has_slots<> { std::unique_ptr fake_dtls2_; std::unique_ptr recv1_; std::unique_ptr recv2_; - std::unique_ptr transport1_; - std::unique_ptr transport2_; + std::unique_ptr transport1_; + std::unique_ptr transport2_; int transport1_ready_to_send_count_ = 0; int transport2_ready_to_send_count_ = 0; @@ -244,9 +244,9 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) { FakeDtlsTransport fake_dtls2("fake dtls 2", 0); SctpFakeDataReceiver recv1; SctpFakeDataReceiver recv2; - std::unique_ptr transport1( + std::unique_ptr transport1( CreateTransport(&fake_dtls1, &recv1)); - std::unique_ptr transport2( + std::unique_ptr transport2( CreateTransport(&fake_dtls2, &recv2)); // Add a stream. @@ -282,8 +282,8 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) { meta.rcv_tsn = 42; meta.rcv_cumtsn = 42; chunk.SetData("meow?", 5); - EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( - chunk.data(), chunk.size(), meta, 0)); + transport1->InjectDataOrNotificationFromSctpForTesting(chunk.data(), + chunk.size(), meta, 0); // Inject a notification in between chunks. union sctp_notification notification; @@ -292,15 +292,15 @@ TEST_F(SctpTransportTest, MessageInterleavedWithNotification) { notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE; notification.sn_header.sn_flags = 0; notification.sn_header.sn_length = sizeof(notification); - EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( - ¬ification, sizeof(notification), {0}, MSG_NOTIFICATION)); + transport1->InjectDataOrNotificationFromSctpForTesting( + ¬ification, sizeof(notification), {0}, MSG_NOTIFICATION); // Inject chunk 2/2 meta.rcv_tsn = 42; meta.rcv_cumtsn = 43; chunk.SetData(" rawr!", 6); - EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting( - chunk.data(), chunk.size(), meta, MSG_EOR)); + transport1->InjectDataOrNotificationFromSctpForTesting( + chunk.data(), chunk.size(), meta, MSG_EOR); // Expect the message to contain both chunks. EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout); @@ -317,9 +317,9 @@ TEST_F(SctpTransportTest, SwitchDtlsTransport) { SctpFakeDataReceiver recv2; // Construct transport1 with the "black hole" transport. 
- std::unique_ptr transport1( + std::unique_ptr transport1( CreateTransport(&black_hole, &recv1)); - std::unique_ptr transport2( + std::unique_ptr transport2( CreateTransport(&fake_dtls2, &recv2)); // Add a stream. @@ -377,9 +377,9 @@ TEST_F(SctpTransportTest, NegativeOnePortTreatedAsDefault) { FakeDtlsTransport fake_dtls2("fake dtls 2", 0); SctpFakeDataReceiver recv1; SctpFakeDataReceiver recv2; - std::unique_ptr transport1( + std::unique_ptr transport1( CreateTransport(&fake_dtls1, &recv1)); - std::unique_ptr transport2( + std::unique_ptr transport2( CreateTransport(&fake_dtls2, &recv2)); // Add a stream. @@ -406,7 +406,8 @@ TEST_F(SctpTransportTest, NegativeOnePortTreatedAsDefault) { TEST_F(SctpTransportTest, OpenStreamWithAlreadyOpenedStreamFails) { FakeDtlsTransport fake_dtls("fake dtls", 0); SctpFakeDataReceiver recv; - std::unique_ptr transport(CreateTransport(&fake_dtls, &recv)); + std::unique_ptr transport( + CreateTransport(&fake_dtls, &recv)); EXPECT_TRUE(transport->OpenStream(1)); EXPECT_FALSE(transport->OpenStream(1)); } @@ -414,7 +415,8 @@ TEST_F(SctpTransportTest, OpenStreamWithAlreadyOpenedStreamFails) { TEST_F(SctpTransportTest, ResetStreamWithAlreadyResetStreamFails) { FakeDtlsTransport fake_dtls("fake dtls", 0); SctpFakeDataReceiver recv; - std::unique_ptr transport(CreateTransport(&fake_dtls, &recv)); + std::unique_ptr transport( + CreateTransport(&fake_dtls, &recv)); EXPECT_TRUE(transport->OpenStream(1)); EXPECT_TRUE(transport->ResetStream(1)); EXPECT_FALSE(transport->ResetStream(1)); @@ -425,7 +427,8 @@ TEST_F(SctpTransportTest, ResetStreamWithAlreadyResetStreamFails) { TEST_F(SctpTransportTest, SignalReadyToSendDataAfterDtlsWritable) { FakeDtlsTransport fake_dtls("fake dtls", 0); SctpFakeDataReceiver recv; - std::unique_ptr transport(CreateTransport(&fake_dtls, &recv)); + std::unique_ptr transport( + CreateTransport(&fake_dtls, &recv)); SctpTransportObserver observer(transport.get()); transport->Start(kSctpDefaultPort, kSctpDefaultPort, kSctpSendBufferSize); @@ -438,8 +441,8 @@ class SctpTransportTestWithOrdered : public SctpTransportTest, public ::testing::WithParamInterface {}; -// Tests that a small message gets buffered and later sent by the SctpTransport -// when the sctp library only accepts the message partially. +// Tests that a small message gets buffered and later sent by the +// UsrsctpTransport when the sctp library only accepts the message partially. TEST_P(SctpTransportTestWithOrdered, SendSmallBufferedOutgoingMessage) { bool ordered = GetParam(); SetupConnectedTransportsWithTwoStreams(); @@ -456,7 +459,7 @@ TEST_P(SctpTransportTestWithOrdered, SendSmallBufferedOutgoingMessage) { ordered)); std::string buffered_message("hello hello"); - // SctpTransport accepts this message by buffering part of it. + // UsrsctpTransport accepts this message by buffering part of it. ASSERT_TRUE( SendData(transport1(), /*sid=*/1, buffered_message, &result, ordered)); ASSERT_TRUE(transport1()->ReadyToSendData()); @@ -478,8 +481,8 @@ TEST_P(SctpTransportTestWithOrdered, SendSmallBufferedOutgoingMessage) { EXPECT_EQ(2u, receiver2()->num_messages_received()); } -// Tests that a large message gets buffered and later sent by the SctpTransport -// when the sctp library only accepts the message partially. +// Tests that a large message gets buffered and later sent by the +// UsrsctpTransport when the sctp library only accepts the message partially. 
TEST_P(SctpTransportTestWithOrdered, SendLargeBufferedOutgoingMessage) { bool ordered = GetParam(); SetupConnectedTransportsWithTwoStreams(); @@ -496,7 +499,7 @@ TEST_P(SctpTransportTestWithOrdered, SendLargeBufferedOutgoingMessage) { ordered)); std::string buffered_message(kSctpSendBufferSize, 'b'); - // SctpTransport accepts this message by buffering the second half. + // UsrsctpTransport accepts this message by buffering the second half. ASSERT_TRUE( SendData(transport1(), /*sid=*/1, buffered_message, &result, ordered)); ASSERT_TRUE(transport1()->ReadyToSendData()); @@ -518,6 +521,47 @@ TEST_P(SctpTransportTestWithOrdered, SendLargeBufferedOutgoingMessage) { EXPECT_EQ(2u, receiver2()->num_messages_received()); } +// Tests that a large message gets buffered and later sent by the +// UsrsctpTransport when the sctp library only accepts the message partially +// during a stream reset. +TEST_P(SctpTransportTestWithOrdered, + SendLargeBufferedOutgoingMessageDuringReset) { + bool ordered = GetParam(); + SetupConnectedTransportsWithTwoStreams(); + SctpTransportObserver transport2_observer(transport2()); + + // Wait for initial SCTP association to be formed. + EXPECT_EQ_WAIT(1, transport1_ready_to_send_count(), kDefaultTimeout); + // Make the fake transport unwritable so that messages pile up for the SCTP + // socket. + fake_dtls1()->SetWritable(false); + SendDataResult result; + + // Fill almost all of sctp library's send buffer. + ASSERT_TRUE(SendData(transport1(), /*sid=*/1, + std::string(kSctpSendBufferSize / 2, 'a'), &result, + ordered)); + + std::string buffered_message(kSctpSendBufferSize, 'b'); + // UsrsctpTransport accepts this message by buffering the second half. + ASSERT_TRUE( + SendData(transport1(), /*sid=*/1, buffered_message, &result, ordered)); + // Queue a stream reset + transport1()->ResetStream(/*sid=*/1); + + // Make the transport writable again and expect a "SignalReadyToSendData" at + // some point after sending the buffered message. 
+ fake_dtls1()->SetWritable(true); + EXPECT_EQ_WAIT(2, transport1_ready_to_send_count(), kDefaultTimeout); + + // Queued message should be received by the receiver before receiving the + // reset + EXPECT_TRUE_WAIT(ReceivedData(receiver2(), 1, buffered_message), + kDefaultTimeout); + EXPECT_EQ(2u, receiver2()->num_messages_received()); + EXPECT_TRUE_WAIT(transport2_observer.WasStreamClosed(1), kDefaultTimeout); +} + TEST_P(SctpTransportTestWithOrdered, SendData) { bool ordered = GetParam(); SetupConnectedTransportsWithTwoStreams(); @@ -531,8 +575,6 @@ TEST_P(SctpTransportTestWithOrdered, SendData) { RTC_LOG(LS_VERBOSE) << "recv2.received=" << receiver2()->received() << ", recv2.last_params.sid=" << receiver2()->last_params().sid - << ", recv2.last_params.timestamp=" - << receiver2()->last_params().timestamp << ", recv2.last_params.seq_num=" << receiver2()->last_params().seq_num << ", recv2.last_data=" << receiver2()->last_data(); @@ -546,8 +588,6 @@ TEST_P(SctpTransportTestWithOrdered, SendData) { RTC_LOG(LS_VERBOSE) << "recv1.received=" << receiver1()->received() << ", recv1.last_params.sid=" << receiver1()->last_params().sid - << ", recv1.last_params.timestamp=" - << receiver1()->last_params().timestamp << ", recv1.last_params.seq_num=" << receiver1()->last_params().seq_num << ", recv1.last_data=" << receiver1()->last_data(); @@ -558,15 +598,14 @@ TEST_P(SctpTransportTestWithOrdered, SendDataBlocked) { SetupConnectedTransportsWithTwoStreams(); SendDataResult result; - SendDataParams params; - params.sid = 1; + webrtc::SendDataParams params; params.ordered = GetParam(); std::vector buffer(1024 * 64, 0); for (size_t i = 0; i < 100; ++i) { transport1()->SendData( - params, rtc::CopyOnWriteBuffer(&buffer[0], buffer.size()), &result); + 1, params, rtc::CopyOnWriteBuffer(&buffer[0], buffer.size()), &result); if (result == SDR_BLOCK) break; } @@ -585,15 +624,15 @@ TEST_P(SctpTransportTestWithOrdered, SignalReadyToSendDataAfterBlocked) { fake_dtls1()->SetWritable(false); // Send messages until we get EWOULDBLOCK. 
static const size_t kMaxMessages = 1024; - SendDataParams params; - params.sid = 1; + webrtc::SendDataParams params; params.ordered = GetParam(); rtc::CopyOnWriteBuffer buf(1024); memset(buf.MutableData(), 0, 1024); SendDataResult result; size_t message_count = 0; for (; message_count < kMaxMessages; ++message_count) { - if (!transport1()->SendData(params, buf, &result) && result == SDR_BLOCK) { + if (!transport1()->SendData(1, params, buf, &result) && + result == SDR_BLOCK) { break; } } @@ -767,7 +806,8 @@ TEST_F(SctpTransportTest, ReusesAStream) { TEST_F(SctpTransportTest, RejectsTooLargeMessageSize) { FakeDtlsTransport fake_dtls("fake dtls", 0); SctpFakeDataReceiver recv; - std::unique_ptr transport(CreateTransport(&fake_dtls, &recv)); + std::unique_ptr transport( + CreateTransport(&fake_dtls, &recv)); EXPECT_FALSE(transport->Start(kSctpDefaultPort, kSctpDefaultPort, kSctpSendBufferSize + 1)); @@ -776,7 +816,8 @@ TEST_F(SctpTransportTest, RejectsTooLargeMessageSize) { TEST_F(SctpTransportTest, RejectsTooSmallMessageSize) { FakeDtlsTransport fake_dtls("fake dtls", 0); SctpFakeDataReceiver recv; - std::unique_ptr transport(CreateTransport(&fake_dtls, &recv)); + std::unique_ptr transport( + CreateTransport(&fake_dtls, &recv)); EXPECT_FALSE(transport->Start(kSctpDefaultPort, kSctpDefaultPort, 0)); } @@ -803,11 +844,11 @@ TEST_F(SctpTransportTest, SctpRestartWithPendingDataDoesNotDeadlock) { SctpFakeDataReceiver recv2; SctpFakeDataReceiver recv3; - std::unique_ptr transport1( + std::unique_ptr transport1( CreateTransport(&fake_dtls1, &recv1)); - std::unique_ptr transport2( + std::unique_ptr transport2( CreateTransport(&fake_dtls2, &recv2)); - std::unique_ptr transport3( + std::unique_ptr transport3( CreateTransport(&fake_dtls3, &recv3)); SctpTransportObserver observer(transport1.get()); diff --git a/modules/BUILD.gn b/modules/BUILD.gn index bb6b7cc242..54dffe0a63 100644 --- a/modules/BUILD.gn +++ b/modules/BUILD.gn @@ -47,7 +47,7 @@ rtc_source_set("module_fec_api") { sources = [ "include/module_fec_types.h" ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { modules_tests_resources = [ "../resources/audio_coding/testfile16kHz.pcm", "../resources/audio_coding/testfile32kHz.pcm", diff --git a/modules/async_audio_processing/BUILD.gn b/modules/async_audio_processing/BUILD.gn index 6a2a95ecf3..9330b67f92 100644 --- a/modules/async_audio_processing/BUILD.gn +++ b/modules/async_audio_processing/BUILD.gn @@ -18,13 +18,13 @@ rtc_library("async_audio_processing") { deps = [ "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/audio:audio_frame_api", "../../api/audio:audio_frame_processor", "../../api/task_queue:task_queue", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_task_queue", - "../../rtc_base/synchronization:sequence_checker", ] } diff --git a/modules/audio_coding/BUILD.gn b/modules/audio_coding/BUILD.gn index e440b43da7..d1d17267e5 100644 --- a/modules/audio_coding/BUILD.gn +++ b/modules/audio_coding/BUILD.gn @@ -17,7 +17,6 @@ visibility = [ ":*" ] rtc_source_set("audio_coding_module_typedefs") { visibility += [ "*" ] sources = [ "include/audio_coding_module_typedefs.h" ] - deps = [ "../../rtc_base:deprecation" ] } rtc_library("audio_coding") { @@ -52,7 +51,6 @@ rtc_library("audio_coding") { "../../common_audio:common_audio_c", "../../rtc_base:audio_format_to_string", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base/synchronization:mutex", 
"../../system_wrappers", @@ -125,6 +123,7 @@ rtc_library("red") { "../../common_audio", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", + "../../system_wrappers:field_trial", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } @@ -368,8 +367,8 @@ rtc_library("ilbc_c") { "../../rtc_base:rtc_base_approved", "../../rtc_base:sanitizer", "../../rtc_base/system:arch", - "../../rtc_base/system:unused", ] + absl_deps = [ "//third_party/abseil-cpp/absl/base:core_headers" ] } rtc_source_set("isac_common") { @@ -831,6 +830,7 @@ rtc_library("webrtc_opus_wrapper") { } deps = [ + "../../api:array_view", "../../rtc_base:checks", "../../rtc_base:ignore_wundef", "../../rtc_base:rtc_base_approved", @@ -1058,6 +1058,7 @@ rtc_library("neteq_tools_minimal") { deps = [ ":default_neteq_factory", ":neteq", + "../../api:array_view", "../../api:neteq_simulator_api", "../../api:rtp_headers", "../../api/audio:audio_frame_api", @@ -1068,7 +1069,6 @@ rtc_library("neteq_tools_minimal") { "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", "../../system_wrappers", - "../rtp_rtcp", "../rtp_rtcp:rtp_rtcp_format", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -1248,7 +1248,6 @@ rtc_library("audio_coding_modules_tests_shared") { "../../system_wrappers", "../../test:fileutils", "../../test:test_support", - "../rtp_rtcp:rtp_rtcp_format", "//testing/gtest", ] absl_deps = [ @@ -1301,36 +1300,39 @@ if (rtc_include_tests) { ] } - group("audio_coding_tests") { - visibility += webrtc_default_visibility - testonly = true - public_deps = [ - ":acm_receive_test", - ":acm_send_test", - ":audio_codec_speed_tests", - ":audio_decoder_unittests", - ":audio_decoder_unittests", - ":g711_test", - ":g722_test", - ":ilbc_test", - ":isac_api_test", - ":isac_fix_test", - ":isac_switch_samprate_test", - ":isac_test", - ":neteq_ilbc_quality_test", - ":neteq_isac_quality_test", - ":neteq_opus_quality_test", - ":neteq_pcm16b_quality_test", - ":neteq_pcmu_quality_test", - ":neteq_speed_test", - ":rtp_analyze", - ":rtp_encode", - ":rtp_jitter", - ":rtpcat", - ":webrtc_opus_fec_test", - ] - if (rtc_enable_protobuf) { - public_deps += [ ":neteq_rtpplay" ] + if (!build_with_chromium) { + group("audio_coding_tests") { + visibility += webrtc_default_visibility + testonly = true + public_deps = [ # no-presubmit-check TODO(webrtc:8603) + ":acm_receive_test", + ":acm_send_test", + ":audio_codec_speed_tests", + ":audio_decoder_unittests", + ":audio_decoder_unittests", + ":g711_test", + ":g722_test", + ":ilbc_test", + ":isac_api_test", + ":isac_fix_test", + ":isac_switch_samprate_test", + ":isac_test", + ":neteq_ilbc_quality_test", + ":neteq_isac_quality_test", + ":neteq_opus_quality_test", + ":neteq_pcm16b_quality_test", + ":neteq_pcmu_quality_test", + ":neteq_speed_test", + ":rtp_analyze", + ":rtp_encode", + ":rtp_jitter", + ":rtpcat", + ":webrtc_opus_fec_test", + ] + if (rtc_enable_protobuf) { + public_deps += # no-presubmit-check TODO(webrtc:8603) + [ ":neteq_rtpplay" ] + } } } @@ -1454,7 +1456,6 @@ if (rtc_include_tests) { defines = audio_coding_defines deps = audio_coding_deps + [ - "//third_party/abseil-cpp/absl/strings", "../../api/audio:audio_frame_api", "../../rtc_base:checks", ":audio_coding", @@ -1466,49 +1467,53 @@ if (rtc_include_tests) { "../../test:test_support", "//testing/gtest", ] - } - audio_decoder_unittests_resources = - [ "../../resources/audio_coding/testfile32kHz.pcm" ] - - if (is_ios) { - bundle_data("audio_decoder_unittests_bundle_data") { - testonly = true - sources = 
audio_decoder_unittests_resources - outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] - } + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] } - rtc_test("audio_decoder_unittests") { - testonly = true - sources = [ "neteq/audio_decoder_unittest.cc" ] + if (!build_with_chromium) { + audio_decoder_unittests_resources = + [ "../../resources/audio_coding/testfile32kHz.pcm" ] - defines = neteq_defines + if (is_ios) { + bundle_data("audio_decoder_unittests_bundle_data") { + testonly = true + sources = audio_decoder_unittests_resources + outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + } + } - deps = [ - ":ilbc", - ":isac", - ":isac_fix", - ":neteq", - ":neteq_tools", - "../../test:fileutils", - "../../api/audio_codecs:audio_codecs_api", - "../../api/audio_codecs/opus:audio_encoder_opus", - "../../common_audio", - "../../rtc_base/system:arch", - "../../test:test_main", - "//testing/gtest", - "../../test:test_support", - ] + audio_coding_deps + rtc_test("audio_decoder_unittests") { + testonly = true + sources = [ "neteq/audio_decoder_unittest.cc" ] - data = audio_decoder_unittests_resources + defines = neteq_defines - if (is_android) { - deps += [ "//testing/android/native_test:native_test_native_code" ] - shard_timeout = 900 - } - if (is_ios) { - deps += [ ":audio_decoder_unittests_bundle_data" ] + deps = [ + ":ilbc", + ":isac", + ":isac_fix", + ":neteq", + ":neteq_tools", + "../../test:fileutils", + "../../api/audio_codecs:audio_codecs_api", + "../../api/audio_codecs/opus:audio_encoder_opus", + "../../common_audio", + "../../rtc_base/system:arch", + "../../test:test_main", + "//testing/gtest", + "../../test:test_support", + ] + audio_coding_deps + + data = audio_decoder_unittests_resources + + if (is_android) { + deps += [ "//testing/android/native_test:native_test_native_code" ] + shard_timeout = 900 + } + if (is_ios) { + deps += [ ":audio_decoder_unittests_bundle_data" ] + } } } @@ -1538,7 +1543,9 @@ if (rtc_include_tests) { "../../test:test_support", ] } + } + if (rtc_enable_protobuf && !build_with_chromium) { rtc_executable("neteq_rtpplay") { testonly = true visibility += [ "*" ] @@ -1559,51 +1566,54 @@ if (rtc_include_tests) { } } - audio_codec_speed_tests_resources = [ - "//resources/audio_coding/music_stereo_48kHz.pcm", - "//resources/audio_coding/speech_mono_16kHz.pcm", - "//resources/audio_coding/speech_mono_32_48kHz.pcm", - ] + if (!build_with_chromium) { + audio_codec_speed_tests_resources = [ + "//resources/audio_coding/music_stereo_48kHz.pcm", + "//resources/audio_coding/speech_mono_16kHz.pcm", + "//resources/audio_coding/speech_mono_32_48kHz.pcm", + ] - if (is_ios) { - bundle_data("audio_codec_speed_tests_data") { - testonly = true - sources = audio_codec_speed_tests_resources - outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + if (is_ios) { + bundle_data("audio_codec_speed_tests_data") { + testonly = true + sources = audio_codec_speed_tests_resources + outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + } } - } - rtc_test("audio_codec_speed_tests") { - testonly = true - defines = [] - deps = [ "../../test:fileutils" ] - sources = [ - "codecs/isac/fix/test/isac_speed_test.cc", - "codecs/opus/opus_speed_test.cc", - "codecs/tools/audio_codec_speed_test.cc", - "codecs/tools/audio_codec_speed_test.h", - ] + rtc_test("audio_codec_speed_tests") { + testonly = true + defines = [] + deps = [ "../../test:fileutils" ] + sources = [ + "codecs/isac/fix/test/isac_speed_test.cc", + "codecs/opus/opus_speed_test.cc", + 
"codecs/tools/audio_codec_speed_test.cc", + "codecs/tools/audio_codec_speed_test.h", + ] - data = audio_codec_speed_tests_resources + data = audio_codec_speed_tests_resources - if (is_android) { - deps += [ "//testing/android/native_test:native_test_native_code" ] - shard_timeout = 900 - } + if (is_android) { + deps += [ "//testing/android/native_test:native_test_native_code" ] + shard_timeout = 900 + } - if (is_ios) { - deps += [ ":audio_codec_speed_tests_data" ] - } + if (is_ios) { + deps += [ ":audio_codec_speed_tests_data" ] + } - deps += [ - ":isac_fix", - ":webrtc_opus", - "../../rtc_base:rtc_base_approved", - "../../test:test_main", - "../../test:test_support", - "../audio_processing", - "//testing/gtest", - ] + deps += [ + ":isac_fix", + ":webrtc_opus", + "../../rtc_base:checks", + "../../rtc_base:rtc_base_approved", + "../../test:test_main", + "../../test:test_support", + "../audio_processing", + "//testing/gtest", + ] + } } rtc_library("neteq_test_support") { @@ -1631,210 +1641,212 @@ if (rtc_include_tests) { ] } - rtc_library("neteq_quality_test_support") { - testonly = true - sources = [ - "neteq/tools/neteq_quality_test.cc", - "neteq/tools/neteq_quality_test.h", - ] - - deps = [ - ":default_neteq_factory", - ":neteq", - ":neteq_test_tools", - "../../api/audio_codecs:builtin_audio_decoder_factory", - "../../api/neteq:neteq_api", - "../../rtc_base:checks", - "../../system_wrappers", - "../../test:fileutils", - "../../test:test_support", - "//testing/gtest", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag" ] - } - - rtc_executable("rtp_encode") { - testonly = true + if (!build_with_chromium) { + rtc_library("neteq_quality_test_support") { + testonly = true + sources = [ + "neteq/tools/neteq_quality_test.cc", + "neteq/tools/neteq_quality_test.h", + ] - deps = audio_coding_deps + [ - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ":audio_coding", - ":audio_encoder_cng", - ":neteq_input_audio_tools", - "../../api/audio:audio_frame_api", - "../../api/audio_codecs/g711:audio_encoder_g711", - "../../api/audio_codecs/L16:audio_encoder_L16", - "../../api/audio_codecs/g722:audio_encoder_g722", - "../../api/audio_codecs/ilbc:audio_encoder_ilbc", - "../../api/audio_codecs/isac:audio_encoder_isac", - "../../api/audio_codecs/opus:audio_encoder_opus", - "../../rtc_base:safe_conversions", - "//third_party/abseil-cpp/absl/memory", - ] + deps = [ + ":default_neteq_factory", + ":neteq", + ":neteq_test_tools", + "../../api/audio_codecs:builtin_audio_decoder_factory", + "../../api/neteq:neteq_api", + "../../rtc_base:checks", + "../../system_wrappers", + "../../test:fileutils", + "../../test:test_support", + "//testing/gtest", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag" ] + } - sources = [ "neteq/tools/rtp_encode.cc" ] + rtc_executable("rtp_encode") { + testonly = true - defines = audio_coding_defines - } + deps = audio_coding_deps + [ + ":audio_coding", + ":audio_encoder_cng", + ":neteq_input_audio_tools", + "../../api/audio:audio_frame_api", + "../../api/audio_codecs/g711:audio_encoder_g711", + "../../api/audio_codecs/L16:audio_encoder_L16", + "../../api/audio_codecs/g722:audio_encoder_g722", + "../../api/audio_codecs/ilbc:audio_encoder_ilbc", + "../../api/audio_codecs/isac:audio_encoder_isac", + "../../api/audio_codecs/opus:audio_encoder_opus", + "../../rtc_base:safe_conversions", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/memory", + ] + 
+ sources = [ "neteq/tools/rtp_encode.cc" ] + + defines = audio_coding_defines + } - rtc_executable("rtp_jitter") { - testonly = true + rtc_executable("rtp_jitter") { + testonly = true - deps = audio_coding_deps + [ - "../rtp_rtcp:rtp_rtcp_format", - "../../api:array_view", - "../../rtc_base:rtc_base_approved", - ] + deps = audio_coding_deps + [ + "../rtp_rtcp:rtp_rtcp_format", + "../../api:array_view", + "../../rtc_base:rtc_base_approved", + ] - sources = [ "neteq/tools/rtp_jitter.cc" ] + sources = [ "neteq/tools/rtp_jitter.cc" ] - defines = audio_coding_defines - } + defines = audio_coding_defines + } - rtc_executable("rtpcat") { - testonly = true + rtc_executable("rtpcat") { + testonly = true - sources = [ "neteq/tools/rtpcat.cc" ] + sources = [ "neteq/tools/rtpcat.cc" ] - deps = [ - "../../rtc_base:checks", - "../../rtc_base:rtc_base_approved", - "../../test:rtp_test_utils", - "//testing/gtest", - ] - } + deps = [ + "../../rtc_base:checks", + "../../rtc_base:rtc_base_approved", + "../../test:rtp_test_utils", + "//testing/gtest", + ] + } - rtc_executable("rtp_analyze") { - testonly = true + rtc_executable("rtp_analyze") { + testonly = true - sources = [ "neteq/tools/rtp_analyze.cc" ] + sources = [ "neteq/tools/rtp_analyze.cc" ] - deps = [ - ":neteq", - ":neteq_test_tools", - ":pcm16b", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] - } + deps = [ + ":neteq", + ":neteq_test_tools", + ":pcm16b", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } - rtc_executable("neteq_opus_quality_test") { - testonly = true + rtc_executable("neteq_opus_quality_test") { + testonly = true - sources = [ "neteq/test/neteq_opus_quality_test.cc" ] + sources = [ "neteq/test/neteq_opus_quality_test.cc" ] - deps = [ - ":neteq", - ":neteq_quality_test_support", - ":neteq_tools", - ":webrtc_opus", - "../../rtc_base:rtc_base_approved", - "../../test:test_main", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - ] - } + deps = [ + ":neteq", + ":neteq_quality_test_support", + ":neteq_tools", + ":webrtc_opus", + "../../rtc_base:rtc_base_approved", + "../../test:test_main", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + ] + } - rtc_executable("neteq_speed_test") { - testonly = true + rtc_executable("neteq_speed_test") { + testonly = true - sources = [ "neteq/test/neteq_speed_test.cc" ] + sources = [ "neteq/test/neteq_speed_test.cc" ] - deps = [ - ":neteq", - ":neteq_test_support", - "../../rtc_base:checks", - "../../test:test_support", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] - } + deps = [ + ":neteq", + ":neteq_test_support", + "../../rtc_base:checks", + "../../test:test_support", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } - rtc_executable("neteq_ilbc_quality_test") { - testonly = true + rtc_executable("neteq_ilbc_quality_test") { + testonly = true - sources = [ "neteq/test/neteq_ilbc_quality_test.cc" ] + sources = [ "neteq/test/neteq_ilbc_quality_test.cc" ] - deps = [ - ":ilbc", - ":neteq", - ":neteq_quality_test_support", - ":neteq_tools", - "../../rtc_base:checks", - "../../rtc_base:rtc_base_approved", - "../../test:fileutils", - "../../test:test_main", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - ] - } + deps = [ + ":ilbc", + ":neteq", + ":neteq_quality_test_support", + ":neteq_tools", + 
"../../rtc_base:checks", + "../../rtc_base:rtc_base_approved", + "../../test:fileutils", + "../../test:test_main", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + ] + } - rtc_executable("neteq_isac_quality_test") { - testonly = true + rtc_executable("neteq_isac_quality_test") { + testonly = true - sources = [ "neteq/test/neteq_isac_quality_test.cc" ] + sources = [ "neteq/test/neteq_isac_quality_test.cc" ] - deps = [ - ":isac_fix", - ":neteq", - ":neteq_quality_test_support", - "../../rtc_base:rtc_base_approved", - "../../test:test_main", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - ] - } + deps = [ + ":isac_fix", + ":neteq", + ":neteq_quality_test_support", + "../../rtc_base:rtc_base_approved", + "../../test:test_main", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + ] + } - rtc_executable("neteq_pcmu_quality_test") { - testonly = true + rtc_executable("neteq_pcmu_quality_test") { + testonly = true - sources = [ "neteq/test/neteq_pcmu_quality_test.cc" ] + sources = [ "neteq/test/neteq_pcmu_quality_test.cc" ] - deps = [ - ":g711", - ":neteq", - ":neteq_quality_test_support", - "../../rtc_base:checks", - "../../rtc_base:rtc_base_approved", - "../../test:fileutils", - "../../test:test_main", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - ] - } + deps = [ + ":g711", + ":neteq", + ":neteq_quality_test_support", + "../../rtc_base:checks", + "../../rtc_base:rtc_base_approved", + "../../test:fileutils", + "../../test:test_main", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + ] + } - rtc_executable("neteq_pcm16b_quality_test") { - testonly = true + rtc_executable("neteq_pcm16b_quality_test") { + testonly = true - sources = [ "neteq/test/neteq_pcm16b_quality_test.cc" ] + sources = [ "neteq/test/neteq_pcm16b_quality_test.cc" ] - deps = [ - ":neteq", - ":neteq_quality_test_support", - ":pcm16b", - "../../rtc_base:checks", - "../../rtc_base:rtc_base_approved", - "../../test:fileutils", - "../../test:test_main", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - ] - } + deps = [ + ":neteq", + ":neteq_quality_test_support", + ":pcm16b", + "../../rtc_base:checks", + "../../rtc_base:rtc_base_approved", + "../../test:fileutils", + "../../test:test_main", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + ] + } - rtc_executable("isac_fix_test") { - testonly = true + rtc_test("isac_fix_test") { + testonly = true - sources = [ "codecs/isac/fix/test/kenny.cc" ] + sources = [ "codecs/isac/fix/test/kenny.cc" ] - deps = [ - ":isac_fix", - "../../test:perf_test", - "../../test:test_support", - ] + deps = [ + ":isac_fix", + "../../test:perf_test", + "../../test:test_support", + ] - data = [ "../../resources/speech_and_misc_wb.pcm" ] + data = [ "../../resources/speech_and_misc_wb.pcm" ] + } } rtc_library("isac_test_util") { @@ -1845,16 +1857,18 @@ if (rtc_include_tests) { ] } - rtc_executable("isac_test") { - testonly = true + if (!build_with_chromium) { + rtc_executable("isac_test") { + testonly = true - sources = [ "codecs/isac/main/test/simpleKenny.c" ] + sources = [ "codecs/isac/main/test/simpleKenny.c" ] - deps = [ - ":isac", - ":isac_test_util", - "../../rtc_base:rtc_base_approved", - ] + deps = [ + ":isac", + ":isac_test_util", + "../../rtc_base:rtc_base_approved", + ] + } } rtc_executable("g711_test") { @@ -1873,225 +1887,228 @@ if (rtc_include_tests) { deps = [ ":g722" ] } - rtc_executable("isac_api_test") { - testonly = true + if (!build_with_chromium) { + 
rtc_executable("isac_api_test") { + testonly = true - sources = [ "codecs/isac/main/test/ReleaseTest-API/ReleaseTest-API.cc" ] + sources = [ "codecs/isac/main/test/ReleaseTest-API/ReleaseTest-API.cc" ] - deps = [ - ":isac", - ":isac_test_util", - "../../rtc_base:rtc_base_approved", - ] - } + deps = [ + ":isac", + ":isac_test_util", + "../../rtc_base:rtc_base_approved", + ] + } - rtc_executable("isac_switch_samprate_test") { - testonly = true + rtc_executable("isac_switch_samprate_test") { + testonly = true - sources = [ "codecs/isac/main/test/SwitchingSampRate/SwitchingSampRate.cc" ] + sources = + [ "codecs/isac/main/test/SwitchingSampRate/SwitchingSampRate.cc" ] - deps = [ - ":isac", - ":isac_test_util", - "../../common_audio", - "../../common_audio:common_audio_c", - ] - } + deps = [ + ":isac", + ":isac_test_util", + "../../common_audio", + "../../common_audio:common_audio_c", + ] + } - rtc_executable("ilbc_test") { - testonly = true + rtc_executable("ilbc_test") { + testonly = true - sources = [ "codecs/ilbc/test/iLBC_test.c" ] + sources = [ "codecs/ilbc/test/iLBC_test.c" ] - deps = [ ":ilbc" ] - } + deps = [ ":ilbc" ] + } - rtc_executable("webrtc_opus_fec_test") { - testonly = true + rtc_executable("webrtc_opus_fec_test") { + testonly = true - sources = [ "codecs/opus/opus_fec_test.cc" ] + sources = [ "codecs/opus/opus_fec_test.cc" ] - deps = [ - ":webrtc_opus", - "../../common_audio", - "../../rtc_base:rtc_base_approved", - "../../test:fileutils", - "../../test:test_main", - "../../test:test_support", - "//testing/gtest", - ] - } + deps = [ + ":webrtc_opus", + "../../common_audio", + "../../rtc_base:rtc_base_approved", + "../../test:fileutils", + "../../test:test_main", + "../../test:test_support", + "//testing/gtest", + ] + } - rtc_library("audio_coding_unittests") { - testonly = true - visibility += webrtc_default_visibility + rtc_library("audio_coding_unittests") { + testonly = true + visibility += webrtc_default_visibility - sources = [ - "acm2/acm_receiver_unittest.cc", - "acm2/acm_remixing_unittest.cc", - "acm2/audio_coding_module_unittest.cc", - "acm2/call_statistics_unittest.cc", - "audio_network_adaptor/audio_network_adaptor_impl_unittest.cc", - "audio_network_adaptor/bitrate_controller_unittest.cc", - "audio_network_adaptor/channel_controller_unittest.cc", - "audio_network_adaptor/controller_manager_unittest.cc", - "audio_network_adaptor/dtx_controller_unittest.cc", - "audio_network_adaptor/event_log_writer_unittest.cc", - "audio_network_adaptor/fec_controller_plr_based_unittest.cc", - "audio_network_adaptor/frame_length_controller_unittest.cc", - "audio_network_adaptor/frame_length_controller_v2_unittest.cc", - "audio_network_adaptor/util/threshold_curve_unittest.cc", - "codecs/builtin_audio_decoder_factory_unittest.cc", - "codecs/builtin_audio_encoder_factory_unittest.cc", - "codecs/cng/audio_encoder_cng_unittest.cc", - "codecs/cng/cng_unittest.cc", - "codecs/ilbc/ilbc_unittest.cc", - "codecs/isac/fix/source/filterbanks_unittest.cc", - "codecs/isac/fix/source/filters_unittest.cc", - "codecs/isac/fix/source/lpc_masking_model_unittest.cc", - "codecs/isac/fix/source/transform_unittest.cc", - "codecs/isac/isac_webrtc_api_test.cc", - "codecs/isac/main/source/audio_encoder_isac_unittest.cc", - "codecs/isac/main/source/isac_unittest.cc", - "codecs/legacy_encoded_audio_frame_unittest.cc", - "codecs/opus/audio_decoder_multi_channel_opus_unittest.cc", - "codecs/opus/audio_encoder_multi_channel_opus_unittest.cc", - "codecs/opus/audio_encoder_opus_unittest.cc", - 
"codecs/opus/opus_bandwidth_unittest.cc", - "codecs/opus/opus_unittest.cc", - "codecs/red/audio_encoder_copy_red_unittest.cc", - "neteq/audio_multi_vector_unittest.cc", - "neteq/audio_vector_unittest.cc", - "neteq/background_noise_unittest.cc", - "neteq/buffer_level_filter_unittest.cc", - "neteq/comfort_noise_unittest.cc", - "neteq/decision_logic_unittest.cc", - "neteq/decoder_database_unittest.cc", - "neteq/delay_manager_unittest.cc", - "neteq/dsp_helper_unittest.cc", - "neteq/dtmf_buffer_unittest.cc", - "neteq/dtmf_tone_generator_unittest.cc", - "neteq/expand_unittest.cc", - "neteq/histogram_unittest.cc", - "neteq/merge_unittest.cc", - "neteq/mock/mock_buffer_level_filter.h", - "neteq/mock/mock_decoder_database.h", - "neteq/mock/mock_delay_manager.h", - "neteq/mock/mock_dtmf_buffer.h", - "neteq/mock/mock_dtmf_tone_generator.h", - "neteq/mock/mock_expand.h", - "neteq/mock/mock_histogram.h", - "neteq/mock/mock_neteq_controller.h", - "neteq/mock/mock_packet_buffer.h", - "neteq/mock/mock_red_payload_splitter.h", - "neteq/mock/mock_statistics_calculator.h", - "neteq/nack_tracker_unittest.cc", - "neteq/neteq_decoder_plc_unittest.cc", - "neteq/neteq_impl_unittest.cc", - "neteq/neteq_network_stats_unittest.cc", - "neteq/neteq_stereo_unittest.cc", - "neteq/neteq_unittest.cc", - "neteq/normal_unittest.cc", - "neteq/packet_buffer_unittest.cc", - "neteq/post_decode_vad_unittest.cc", - "neteq/random_vector_unittest.cc", - "neteq/red_payload_splitter_unittest.cc", - "neteq/statistics_calculator_unittest.cc", - "neteq/sync_buffer_unittest.cc", - "neteq/time_stretch_unittest.cc", - "neteq/timestamp_scaler_unittest.cc", - "neteq/tools/input_audio_file_unittest.cc", - "neteq/tools/packet_unittest.cc", - ] + sources = [ + "acm2/acm_receiver_unittest.cc", + "acm2/acm_remixing_unittest.cc", + "acm2/audio_coding_module_unittest.cc", + "acm2/call_statistics_unittest.cc", + "audio_network_adaptor/audio_network_adaptor_impl_unittest.cc", + "audio_network_adaptor/bitrate_controller_unittest.cc", + "audio_network_adaptor/channel_controller_unittest.cc", + "audio_network_adaptor/controller_manager_unittest.cc", + "audio_network_adaptor/dtx_controller_unittest.cc", + "audio_network_adaptor/event_log_writer_unittest.cc", + "audio_network_adaptor/fec_controller_plr_based_unittest.cc", + "audio_network_adaptor/frame_length_controller_unittest.cc", + "audio_network_adaptor/frame_length_controller_v2_unittest.cc", + "audio_network_adaptor/util/threshold_curve_unittest.cc", + "codecs/builtin_audio_decoder_factory_unittest.cc", + "codecs/builtin_audio_encoder_factory_unittest.cc", + "codecs/cng/audio_encoder_cng_unittest.cc", + "codecs/cng/cng_unittest.cc", + "codecs/ilbc/ilbc_unittest.cc", + "codecs/isac/fix/source/filterbanks_unittest.cc", + "codecs/isac/fix/source/filters_unittest.cc", + "codecs/isac/fix/source/lpc_masking_model_unittest.cc", + "codecs/isac/fix/source/transform_unittest.cc", + "codecs/isac/isac_webrtc_api_test.cc", + "codecs/isac/main/source/audio_encoder_isac_unittest.cc", + "codecs/isac/main/source/isac_unittest.cc", + "codecs/legacy_encoded_audio_frame_unittest.cc", + "codecs/opus/audio_decoder_multi_channel_opus_unittest.cc", + "codecs/opus/audio_encoder_multi_channel_opus_unittest.cc", + "codecs/opus/audio_encoder_opus_unittest.cc", + "codecs/opus/opus_bandwidth_unittest.cc", + "codecs/opus/opus_unittest.cc", + "codecs/red/audio_encoder_copy_red_unittest.cc", + "neteq/audio_multi_vector_unittest.cc", + "neteq/audio_vector_unittest.cc", + "neteq/background_noise_unittest.cc", + 
"neteq/buffer_level_filter_unittest.cc", + "neteq/comfort_noise_unittest.cc", + "neteq/decision_logic_unittest.cc", + "neteq/decoder_database_unittest.cc", + "neteq/delay_manager_unittest.cc", + "neteq/dsp_helper_unittest.cc", + "neteq/dtmf_buffer_unittest.cc", + "neteq/dtmf_tone_generator_unittest.cc", + "neteq/expand_unittest.cc", + "neteq/histogram_unittest.cc", + "neteq/merge_unittest.cc", + "neteq/mock/mock_buffer_level_filter.h", + "neteq/mock/mock_decoder_database.h", + "neteq/mock/mock_delay_manager.h", + "neteq/mock/mock_dtmf_buffer.h", + "neteq/mock/mock_dtmf_tone_generator.h", + "neteq/mock/mock_expand.h", + "neteq/mock/mock_histogram.h", + "neteq/mock/mock_neteq_controller.h", + "neteq/mock/mock_packet_buffer.h", + "neteq/mock/mock_red_payload_splitter.h", + "neteq/mock/mock_statistics_calculator.h", + "neteq/nack_tracker_unittest.cc", + "neteq/neteq_decoder_plc_unittest.cc", + "neteq/neteq_impl_unittest.cc", + "neteq/neteq_network_stats_unittest.cc", + "neteq/neteq_stereo_unittest.cc", + "neteq/neteq_unittest.cc", + "neteq/normal_unittest.cc", + "neteq/packet_buffer_unittest.cc", + "neteq/post_decode_vad_unittest.cc", + "neteq/random_vector_unittest.cc", + "neteq/red_payload_splitter_unittest.cc", + "neteq/statistics_calculator_unittest.cc", + "neteq/sync_buffer_unittest.cc", + "neteq/time_stretch_unittest.cc", + "neteq/timestamp_scaler_unittest.cc", + "neteq/tools/input_audio_file_unittest.cc", + "neteq/tools/packet_unittest.cc", + ] - deps = [ - ":acm_receive_test", - ":acm_send_test", - ":audio_coding", - ":audio_coding_module_typedefs", - ":audio_coding_modules_tests_shared", - ":audio_coding_opus_common", - ":audio_encoder_cng", - ":audio_network_adaptor", - ":default_neteq_factory", - ":g711", - ":ilbc", - ":isac", - ":isac_c", - ":isac_common", - ":isac_fix", - ":legacy_encoded_audio_frame", - ":mocks", - ":neteq", - ":neteq_test_support", - ":neteq_test_tools", - ":pcm16b", - ":red", - ":webrtc_cng", - ":webrtc_opus", - "..:module_api", - "..:module_api_public", - "../../api:array_view", - "../../api/audio:audio_frame_api", - "../../api/audio_codecs:audio_codecs_api", - "../../api/audio_codecs:builtin_audio_decoder_factory", - "../../api/audio_codecs:builtin_audio_encoder_factory", - "../../api/audio_codecs/isac:audio_decoder_isac_fix", - "../../api/audio_codecs/isac:audio_decoder_isac_float", - "../../api/audio_codecs/isac:audio_encoder_isac_fix", - "../../api/audio_codecs/isac:audio_encoder_isac_float", - "../../api/audio_codecs/opus:audio_decoder_multiopus", - "../../api/audio_codecs/opus:audio_decoder_opus", - "../../api/audio_codecs/opus:audio_encoder_multiopus", - "../../api/audio_codecs/opus:audio_encoder_opus", - "../../api/neteq:default_neteq_controller_factory", - "../../api/neteq:neteq_api", - "../../api/neteq:neteq_controller_api", - "../../api/neteq:tick_timer", - "../../api/neteq:tick_timer_unittest", - "../../api/rtc_event_log", - "../../common_audio", - "../../common_audio:common_audio_c", - "../../common_audio:mock_common_audio", - "../../logging:mocks", - "../../logging:rtc_event_audio", - "../../modules/rtp_rtcp:rtp_rtcp_format", - "../../rtc_base", - "../../rtc_base:checks", - "../../rtc_base:ignore_wundef", - "../../rtc_base:rtc_base_approved", - "../../rtc_base:rtc_base_tests_utils", - "../../rtc_base:sanitizer", - "../../rtc_base:timeutils", - "../../rtc_base/synchronization:mutex", - "../../rtc_base/system:arch", - "../../system_wrappers", - "../../test:audio_codec_mocks", - "../../test:field_trial", - "../../test:fileutils", - 
"../../test:rtc_expect_death", - "../../test:rtp_test_utils", - "../../test:test_common", - "../../test:test_support", - "codecs/opus/test", - "codecs/opus/test:test_unittest", - "//testing/gtest", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/memory", - "//third_party/abseil-cpp/absl/strings", - "//third_party/abseil-cpp/absl/types:optional", - ] + deps = [ + ":acm_receive_test", + ":acm_send_test", + ":audio_coding", + ":audio_coding_module_typedefs", + ":audio_coding_modules_tests_shared", + ":audio_coding_opus_common", + ":audio_encoder_cng", + ":audio_network_adaptor", + ":default_neteq_factory", + ":g711", + ":ilbc", + ":isac", + ":isac_c", + ":isac_common", + ":isac_fix", + ":legacy_encoded_audio_frame", + ":mocks", + ":neteq", + ":neteq_test_support", + ":neteq_test_tools", + ":pcm16b", + ":red", + ":webrtc_cng", + ":webrtc_opus", + "..:module_api", + "..:module_api_public", + "../../api:array_view", + "../../api/audio:audio_frame_api", + "../../api/audio_codecs:audio_codecs_api", + "../../api/audio_codecs:builtin_audio_decoder_factory", + "../../api/audio_codecs:builtin_audio_encoder_factory", + "../../api/audio_codecs/isac:audio_decoder_isac_fix", + "../../api/audio_codecs/isac:audio_decoder_isac_float", + "../../api/audio_codecs/isac:audio_encoder_isac_fix", + "../../api/audio_codecs/isac:audio_encoder_isac_float", + "../../api/audio_codecs/opus:audio_decoder_multiopus", + "../../api/audio_codecs/opus:audio_decoder_opus", + "../../api/audio_codecs/opus:audio_encoder_multiopus", + "../../api/audio_codecs/opus:audio_encoder_opus", + "../../api/neteq:default_neteq_controller_factory", + "../../api/neteq:neteq_api", + "../../api/neteq:neteq_controller_api", + "../../api/neteq:tick_timer", + "../../api/neteq:tick_timer_unittest", + "../../api/rtc_event_log", + "../../common_audio", + "../../common_audio:common_audio_c", + "../../common_audio:mock_common_audio", + "../../logging:mocks", + "../../logging:rtc_event_audio", + "../../modules/rtp_rtcp:rtp_rtcp_format", + "../../rtc_base", + "../../rtc_base:checks", + "../../rtc_base:ignore_wundef", + "../../rtc_base:rtc_base_approved", + "../../rtc_base:rtc_base_tests_utils", + "../../rtc_base:sanitizer", + "../../rtc_base:timeutils", + "../../rtc_base/synchronization:mutex", + "../../rtc_base/system:arch", + "../../system_wrappers", + "../../test:audio_codec_mocks", + "../../test:field_trial", + "../../test:fileutils", + "../../test:rtc_expect_death", + "../../test:rtp_test_utils", + "../../test:test_common", + "../../test:test_support", + "codecs/opus/test", + "codecs/opus/test:test_unittest", + "//testing/gtest", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] - defines = audio_coding_defines + defines = audio_coding_defines - if (rtc_enable_protobuf) { - defines += [ "WEBRTC_NETEQ_UNITTEST_BITEXACT" ] - deps += [ - ":ana_config_proto", - ":neteq_unittest_proto", - ] + if (rtc_enable_protobuf) { + defines += [ "WEBRTC_NETEQ_UNITTEST_BITEXACT" ] + deps += [ + ":ana_config_proto", + ":neteq_unittest_proto", + ] + } } } } diff --git a/modules/audio_coding/OWNERS b/modules/audio_coding/OWNERS index f7a0e4797e..c27c2a8d2d 100644 --- a/modules/audio_coding/OWNERS +++ b/modules/audio_coding/OWNERS @@ -1,3 +1,4 @@ henrik.lundin@webrtc.org minyue@webrtc.org ivoc@webrtc.org +jakobi@webrtc.org diff --git 
a/modules/audio_coding/acm2/acm_receiver.cc b/modules/audio_coding/acm2/acm_receiver.cc index 0e615cae82..3214ce6f7b 100644 --- a/modules/audio_coding/acm2/acm_receiver.cc +++ b/modules/audio_coding/acm2/acm_receiver.cc @@ -146,20 +146,22 @@ int AcmReceiver::GetAudio(int desired_freq_hz, AudioFrame* audio_frame, bool* muted) { RTC_DCHECK(muted); - // Accessing members, take the lock. - MutexLock lock(&mutex_); - if (neteq_->GetAudio(audio_frame, muted) != NetEq::kOK) { + int current_sample_rate_hz = 0; + if (neteq_->GetAudio(audio_frame, muted, &current_sample_rate_hz) != + NetEq::kOK) { RTC_LOG(LERROR) << "AcmReceiver::GetAudio - NetEq Failed."; return -1; } - const int current_sample_rate_hz = neteq_->last_output_sample_rate_hz(); + RTC_DCHECK_NE(current_sample_rate_hz, 0); // Update if resampling is required. const bool need_resampling = (desired_freq_hz != -1) && (current_sample_rate_hz != desired_freq_hz); + // Accessing members, take the lock. + MutexLock lock(&mutex_); if (need_resampling && !resampled_last_output_frame_) { // Prime the resampler with the last frame. int16_t temp_output[AudioFrame::kMaxDataSizeSamples]; @@ -174,8 +176,8 @@ int AcmReceiver::GetAudio(int desired_freq_hz, } } - // TODO(henrik.lundin) Glitches in the output may appear if the output rate - // from NetEq changes. See WebRTC issue 3923. + // TODO(bugs.webrtc.org/3923) Glitches in the output may appear if the output + // rate from NetEq changes. if (need_resampling) { // TODO(yujo): handle this more efficiently for muted frames. int samples_per_channel_int = resampler_.Resample10Msec( diff --git a/modules/audio_coding/acm2/acm_receiver_unittest.cc b/modules/audio_coding/acm2/acm_receiver_unittest.cc index a8da77e6b6..2338a53235 100644 --- a/modules/audio_coding/acm2/acm_receiver_unittest.cc +++ b/modules/audio_coding/acm2/acm_receiver_unittest.cc @@ -119,7 +119,7 @@ class AcmReceiverTestOldApi : public AudioPacketizationCallback, rtp_header_, rtc::ArrayView<const uint8_t>(payload_data, payload_len_bytes)); if (ret_val < 0) { - assert(false); + RTC_NOTREACHED(); return -1; } rtp_header_.sequenceNumber++; diff --git a/modules/audio_coding/acm2/acm_resampler.cc b/modules/audio_coding/acm2/acm_resampler.cc index ca3583e32c..367ec2b9cd 100644 --- a/modules/audio_coding/acm2/acm_resampler.cc +++ b/modules/audio_coding/acm2/acm_resampler.cc @@ -31,7 +31,7 @@ int ACMResampler::Resample10Msec(const int16_t* in_audio, size_t in_length = in_freq_hz * num_audio_channels / 100; if (in_freq_hz == out_freq_hz) { if (out_capacity_samples < in_length) { - assert(false); + RTC_NOTREACHED(); return -1; } memcpy(out_audio, in_audio, in_length * sizeof(int16_t)); diff --git a/modules/audio_coding/acm2/acm_send_test.cc b/modules/audio_coding/acm2/acm_send_test.cc index b3e1e1ecb2..cda668dab8 100644 --- a/modules/audio_coding/acm2/acm_send_test.cc +++ b/modules/audio_coding/acm2/acm_send_test.cc @@ -51,8 +51,8 @@ AcmSendTestOldApi::AcmSendTestOldApi(InputAudioFile* audio_source, input_frame_.sample_rate_hz_ = source_rate_hz_; input_frame_.num_channels_ = 1; input_frame_.samples_per_channel_ = input_block_size_samples_; - assert(input_block_size_samples_ * input_frame_.num_channels_ <= - AudioFrame::kMaxDataSizeSamples); + RTC_DCHECK_LE(input_block_size_samples_ * input_frame_.num_channels_, + AudioFrame::kMaxDataSizeSamples); acm_->RegisterTransportCallback(this); } @@ -81,8 +81,8 @@ bool AcmSendTestOldApi::RegisterCodec(const char* payload_name, factory->MakeAudioEncoder(payload_type, format, absl::nullopt)); codec_registered_ = true;
input_frame_.num_channels_ = num_channels; - assert(input_block_size_samples_ * input_frame_.num_channels_ <= - AudioFrame::kMaxDataSizeSamples); + RTC_DCHECK_LE(input_block_size_samples_ * input_frame_.num_channels_, + AudioFrame::kMaxDataSizeSamples); return codec_registered_; } @@ -90,13 +90,13 @@ void AcmSendTestOldApi::RegisterExternalCodec( std::unique_ptr external_speech_encoder) { input_frame_.num_channels_ = external_speech_encoder->NumChannels(); acm_->SetEncoder(std::move(external_speech_encoder)); - assert(input_block_size_samples_ * input_frame_.num_channels_ <= - AudioFrame::kMaxDataSizeSamples); + RTC_DCHECK_LE(input_block_size_samples_ * input_frame_.num_channels_, + AudioFrame::kMaxDataSizeSamples); codec_registered_ = true; } std::unique_ptr AcmSendTestOldApi::NextPacket() { - assert(codec_registered_); + RTC_DCHECK(codec_registered_); if (filter_.test(static_cast(payload_type_))) { // This payload type should be filtered out. Since the payload type is the // same throughout the whole test run, no packet at all will be delivered. @@ -133,15 +133,16 @@ int32_t AcmSendTestOldApi::SendData(AudioFrameType frame_type, payload_type_ = payload_type; timestamp_ = timestamp; last_payload_vec_.assign(payload_data, payload_data + payload_len_bytes); - assert(last_payload_vec_.size() == payload_len_bytes); + RTC_DCHECK_EQ(last_payload_vec_.size(), payload_len_bytes); data_to_send_ = true; return 0; } std::unique_ptr AcmSendTestOldApi::CreatePacket() { const size_t kRtpHeaderSize = 12; - size_t allocated_bytes = last_payload_vec_.size() + kRtpHeaderSize; - uint8_t* packet_memory = new uint8_t[allocated_bytes]; + rtc::CopyOnWriteBuffer packet_buffer(last_payload_vec_.size() + + kRtpHeaderSize); + uint8_t* packet_memory = packet_buffer.MutableData(); // Populate the header bytes. packet_memory[0] = 0x80; packet_memory[1] = static_cast(payload_type_); @@ -162,8 +163,8 @@ std::unique_ptr AcmSendTestOldApi::CreatePacket() { // Copy the payload data. 
memcpy(packet_memory + kRtpHeaderSize, &last_payload_vec_[0], last_payload_vec_.size()); - std::unique_ptr packet( - new Packet(packet_memory, allocated_bytes, clock_.TimeInMilliseconds())); + auto packet = std::make_unique(std::move(packet_buffer), + clock_.TimeInMilliseconds()); RTC_DCHECK(packet); RTC_DCHECK(packet->valid_header()); return packet; diff --git a/modules/audio_coding/acm2/audio_coding_module.cc b/modules/audio_coding/acm2/audio_coding_module.cc index 648ae6e5ea..7d0f4d1e84 100644 --- a/modules/audio_coding/acm2/audio_coding_module.cc +++ b/modules/audio_coding/acm2/audio_coding_module.cc @@ -343,13 +343,13 @@ int AudioCodingModuleImpl::Add10MsData(const AudioFrame& audio_frame) { int AudioCodingModuleImpl::Add10MsDataInternal(const AudioFrame& audio_frame, InputData* input_data) { if (audio_frame.samples_per_channel_ == 0) { - assert(false); + RTC_NOTREACHED(); RTC_LOG(LS_ERROR) << "Cannot Add 10 ms audio, payload length is zero"; return -1; } if (audio_frame.sample_rate_hz_ > kMaxInputSampleRateHz) { - assert(false); + RTC_NOTREACHED(); RTC_LOG(LS_ERROR) << "Cannot Add 10 ms audio, input frequency not valid"; return -1; } diff --git a/modules/audio_coding/acm2/audio_coding_module_unittest.cc b/modules/audio_coding/acm2/audio_coding_module_unittest.cc index 590dc30f47..74654565e3 100644 --- a/modules/audio_coding/acm2/audio_coding_module_unittest.cc +++ b/modules/audio_coding/acm2/audio_coding_module_unittest.cc @@ -429,15 +429,6 @@ class AudioCodingModuleMtTestOldApi : public AudioCodingModuleTestOldApi { AudioCodingModuleMtTestOldApi() : AudioCodingModuleTestOldApi(), - send_thread_(CbSendThread, this, "send", rtc::kRealtimePriority), - insert_packet_thread_(CbInsertPacketThread, - this, - "insert_packet", - rtc::kRealtimePriority), - pull_audio_thread_(CbPullAudioThread, - this, - "pull_audio", - rtc::kRealtimePriority), send_count_(0), insert_packet_count_(0), pull_audio_count_(0), @@ -454,17 +445,38 @@ class AudioCodingModuleMtTestOldApi : public AudioCodingModuleTestOldApi { void StartThreads() { quit_.store(false); - send_thread_.Start(); - insert_packet_thread_.Start(); - pull_audio_thread_.Start(); + + const auto attributes = + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime); + send_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!quit_.load()) { + CbSendImpl(); + } + }, + "send", attributes); + insert_packet_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!quit_.load()) { + CbInsertPacketImpl(); + } + }, + "insert_packet", attributes); + pull_audio_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!quit_.load()) { + CbPullAudioImpl(); + } + }, + "pull_audio", attributes); } void TearDown() { AudioCodingModuleTestOldApi::TearDown(); quit_.store(true); - pull_audio_thread_.Stop(); - send_thread_.Stop(); - insert_packet_thread_.Stop(); + pull_audio_thread_.Finalize(); + send_thread_.Finalize(); + insert_packet_thread_.Finalize(); } bool RunTest() { @@ -482,14 +494,6 @@ class AudioCodingModuleMtTestOldApi : public AudioCodingModuleTestOldApi { return false; } - static void CbSendThread(void* context) { - AudioCodingModuleMtTestOldApi* fixture = - reinterpret_cast(context); - while (!fixture->quit_.load()) { - fixture->CbSendImpl(); - } - } - // The send thread doesn't have to care about the current simulated time, // since only the AcmReceiver is using the clock. 
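// The test fixture above moves from the deprecated PlatformThread constructor
// (static callback plus Start()/Stop()) to rtc::PlatformThread::SpawnJoinable
// with a lambda and Finalize(). A minimal, hedged sketch of the same pattern
// outside the test; Worker and DoOneIteration are illustrative names, not part
// of the change.
#include <atomic>

#include "rtc_base/platform_thread.h"

class Worker {
 public:
  void Start() {
    quit_.store(false);
    thread_ = rtc::PlatformThread::SpawnJoinable(
        [this] {
          while (!quit_.load()) {
            DoOneIteration();
          }
        },
        "worker",
        rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime));
  }
  void Stop() {
    quit_.store(true);
    thread_.Finalize();  // Joins the spawned thread; call only after Start().
  }

 private:
  void DoOneIteration() { /* one unit of work per loop iteration */ }

  std::atomic<bool> quit_{false};
  rtc::PlatformThread thread_;
};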
void CbSendImpl() { @@ -505,14 +509,6 @@ class AudioCodingModuleMtTestOldApi : public AudioCodingModuleTestOldApi { } } - static void CbInsertPacketThread(void* context) { - AudioCodingModuleMtTestOldApi* fixture = - reinterpret_cast(context); - while (!fixture->quit_.load()) { - fixture->CbInsertPacketImpl(); - } - } - void CbInsertPacketImpl() { SleepMs(1); { @@ -527,14 +523,6 @@ class AudioCodingModuleMtTestOldApi : public AudioCodingModuleTestOldApi { InsertPacket(); } - static void CbPullAudioThread(void* context) { - AudioCodingModuleMtTestOldApi* fixture = - reinterpret_cast(context); - while (!fixture->quit_.load()) { - fixture->CbPullAudioImpl(); - } - } - void CbPullAudioImpl() { SleepMs(1); { @@ -693,14 +681,6 @@ class AcmReRegisterIsacMtTestOldApi : public AudioCodingModuleTestOldApi { AcmReRegisterIsacMtTestOldApi() : AudioCodingModuleTestOldApi(), - receive_thread_(CbReceiveThread, - this, - "receive", - rtc::kRealtimePriority), - codec_registration_thread_(CbCodecRegistrationThread, - this, - "codec_registration", - rtc::kRealtimePriority), codec_registered_(false), receive_packet_count_(0), next_insert_packet_time_ms_(0), @@ -732,28 +712,34 @@ class AcmReRegisterIsacMtTestOldApi : public AudioCodingModuleTestOldApi { void StartThreads() { quit_.store(false); - receive_thread_.Start(); - codec_registration_thread_.Start(); + const auto attributes = + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime); + receive_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!quit_.load() && CbReceiveImpl()) { + } + }, + "receive", attributes); + codec_registration_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!quit_.load()) { + CbCodecRegistrationImpl(); + } + }, + "codec_registration", attributes); } void TearDown() override { AudioCodingModuleTestOldApi::TearDown(); quit_.store(true); - receive_thread_.Stop(); - codec_registration_thread_.Stop(); + receive_thread_.Finalize(); + codec_registration_thread_.Finalize(); } bool RunTest() { return test_complete_.Wait(10 * 60 * 1000); // 10 minutes' timeout. } - static void CbReceiveThread(void* context) { - AcmReRegisterIsacMtTestOldApi* fixture = - reinterpret_cast(context); - while (!fixture->quit_.load() && fixture->CbReceiveImpl()) { - } - } - bool CbReceiveImpl() { SleepMs(1); rtc::Buffer encoded; @@ -799,14 +785,6 @@ class AcmReRegisterIsacMtTestOldApi : public AudioCodingModuleTestOldApi { return true; } - static void CbCodecRegistrationThread(void* context) { - AcmReRegisterIsacMtTestOldApi* fixture = - reinterpret_cast(context); - while (!fixture->quit_.load()) { - fixture->CbCodecRegistrationImpl(); - } - } - void CbCodecRegistrationImpl() { SleepMs(1); if (HasFatalFailure()) { @@ -862,9 +840,12 @@ class AcmReceiverBitExactnessOldApi : public ::testing::Test { std::string win64, std::string android_arm32, std::string android_arm64, - std::string android_arm64_clang) { + std::string android_arm64_clang, + std::string mac_arm64) { #if defined(_WIN32) && defined(WEBRTC_ARCH_64_BITS) return win64; +#elif defined(WEBRTC_MAC) && defined(WEBRTC_ARCH_ARM64) + return mac_arm64; #elif defined(WEBRTC_ANDROID) && defined(WEBRTC_ARCH_ARM) return android_arm32; #elif defined(WEBRTC_ANDROID) && defined(WEBRTC_ARCH_ARM64) @@ -939,58 +920,87 @@ class AcmReceiverBitExactnessOldApi : public ::testing::Test { defined(WEBRTC_CODEC_ILBC) TEST_F(AcmReceiverBitExactnessOldApi, 8kHzOutput) { std::string others_checksum_reference = - GetCPUInfo(kAVX2) != 0 ? 
"1d7b784031599e2c01a3f575f8439f2f" - : "c119fda4ea2c119ff2a720fd0c289071"; + GetCPUInfo(kAVX2) != 0 ? "e0c966d7b8c36ff60167988fa35d33e0" +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. +#if defined(WEBRTC_LINUX) && defined(NDEBUG) + : "5af28619e3a3c606b2242c9a12f4f64e"; +#else + : "7d8f6b84abd1e57ec010a53bc2130652"; +#endif std::string win64_checksum_reference = GetCPUInfo(kAVX2) != 0 ? "405a50f0bcb8827e20aa944299fc59f6" - : "38e70d4e186f8e1a56b929fafcb7c379"; + : "0ed5830930f5527a01bbec0ba11f8541"; Run(8000, PlatformChecksum(others_checksum_reference, win64_checksum_reference, - "3b03e41731e1cef5ae2b9f9618660b42", + "b892ed69c38b21b16c132ec2ce03aa7b", "4598140b5e4f7ee66c5adad609e65a3e", - "da7e76687c8c0a9509cd1d57ee1aba3b")); + "5fec8d770778ef7969ec98c56d9eb10f", + "636efe6d0a148f22c5383f356da3deac")); } TEST_F(AcmReceiverBitExactnessOldApi, 16kHzOutput) { std::string others_checksum_reference = - GetCPUInfo(kAVX2) != 0 ? "8884d910e443c244d8593c2e3cef5e63" - : "36dc8c0532ba0efa099e2b6a689cde40"; + GetCPUInfo(kAVX2) != 0 ? "a63c578e1195c8420f453962c6d8519c" + +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. +#if defined(WEBRTC_LINUX) && defined(NDEBUG) + : "f788cc9200ac4a7d498d9081987808a3"; +#else + : "6bac83762c1306b932cd25a560155681"; +#endif std::string win64_checksum_reference = GetCPUInfo(kAVX2) != 0 ? "58fd62a5c49ee513f9fa6fe7dbf62c97" - : "07e4b388168e273fa19da0a167aff782"; + : "0509cf0672f543efb4b050e8cffefb1d"; Run(16000, PlatformChecksum(others_checksum_reference, win64_checksum_reference, - "06b08d14a72f6e7c72840b1cc9ad204d", + "3cea9abbeabbdea9a79719941b241af5", "f2aad418af974a3b1694d5ae5cc2c3c7", - "1d5f9a93f3975e7e491373b81eb5fd14")); + "9d4b92c31c00e321a4cff29ad002d6a2", + "1e2d1b482fdc924f79a838503ee7ead5")); } TEST_F(AcmReceiverBitExactnessOldApi, 32kHzOutput) { std::string others_checksum_reference = - GetCPUInfo(kAVX2) != 0 ? "73f4fe21996c0af495e2c47e3708e519" - : "c848ce9002d3825056a1eac2a067c0d3"; + GetCPUInfo(kAVX2) != 0 ? "8775ce387f44dc5ff4a26da295d5ee7c" +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. +#if defined(WEBRTC_LINUX) && defined(NDEBUG) + : "5b84b2a179cb8533a8f9bcd19612e7f0"; +#else + : "e319222ca47733709f90fdf33c8574db"; +#endif std::string win64_checksum_reference = GetCPUInfo(kAVX2) != 0 ? "04ce6a1dac5ffdd8438d804623d0132f" - : "0e705f6844c75fd57a84734f7c30af87"; + : "39a4a7a1c455b35baeffb9fd193d7858"; Run(32000, PlatformChecksum(others_checksum_reference, win64_checksum_reference, - "c18e98e5701ec91bba5c026b720d1790", + "4df55b3b62bcbf4328786d474ae87f61", "100869c8dcde51346c2073e52a272d98", - "e35df943bfa3ca32e86b26bf1e37ed8f")); + "ff58d3153d2780a3df6bc2068844cb2d", + "51788e9784a10ae14a030f075a039205")); } TEST_F(AcmReceiverBitExactnessOldApi, 48kHzOutput) { std::string others_checksum_reference = - GetCPUInfo(kAVX2) != 0 ? "884243f7e1476931e93eda5de88d1326" - : "ba0f66d538487bba377e721cfac62d1e"; + GetCPUInfo(kAVX2) != 0 ? "7a55700b7ca9aa60237db58b33e55606" +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. +#if defined(WEBRTC_LINUX) && defined(NDEBUG) + : "a2459749062f96297283cce4a8c7e6db"; +#else + : "57d1d316c88279f4f3da3511665069a9"; +#endif std::string win64_checksum_reference = GetCPUInfo(kAVX2) != 0 ? 
"f59833d9b0924f4b0704707dd3589f80" - : "6a480541fb86faa95c7563b9de08104d"; + : "74cbe7345e2b6b45c1e455a5d1e921ca"; Run(48000, PlatformChecksum(others_checksum_reference, win64_checksum_reference, - "30e617e4b3c9ba1979d1b2e8eba3519b", + "f52bc7bf0f499c9da25932fdf176c4ec", "bd44bf97e7899186532f91235cef444d", - "90158462a1853b1de50873eebd68dba7")); + "364d403dae55d73cd69e6dbd6b723a4d", + "71bc5c15a151400517c2119d1602ee9f")); } TEST_F(AcmReceiverBitExactnessOldApi, 48kHzOutputExternalDecoder) { @@ -1069,16 +1079,23 @@ TEST_F(AcmReceiverBitExactnessOldApi, 48kHzOutputExternalDecoder) { rtc::scoped_refptr> factory( new rtc::RefCountedObject); std::string others_checksum_reference = - GetCPUInfo(kAVX2) != 0 ? "884243f7e1476931e93eda5de88d1326" - : "ba0f66d538487bba377e721cfac62d1e"; + GetCPUInfo(kAVX2) != 0 ? "7a55700b7ca9aa60237db58b33e55606" +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. +#if defined(WEBRTC_LINUX) && defined(NDEBUG) + : "a2459749062f96297283cce4a8c7e6db"; +#else + : "57d1d316c88279f4f3da3511665069a9"; +#endif std::string win64_checksum_reference = GetCPUInfo(kAVX2) != 0 ? "f59833d9b0924f4b0704707dd3589f80" - : "6a480541fb86faa95c7563b9de08104d"; + : "74cbe7345e2b6b45c1e455a5d1e921ca"; Run(48000, PlatformChecksum(others_checksum_reference, win64_checksum_reference, - "30e617e4b3c9ba1979d1b2e8eba3519b", + "f52bc7bf0f499c9da25932fdf176c4ec", "bd44bf97e7899186532f91235cef444d", - "90158462a1853b1de50873eebd68dba7"), + "364d403dae55d73cd69e6dbd6b723a4d", + "71bc5c15a151400517c2119d1602ee9f"), factory, [](AudioCodingModule* acm) { acm->SetReceiveCodecs({{0, {"MockPCMu", 8000, 1}}, {103, {"ISAC", 16000, 1}}, @@ -1299,12 +1316,14 @@ TEST_F(AcmSenderBitExactnessOldApi, IsacWb30ms) { "9336a9b993cbd8a751f0e8958e66c89c", "5c2eb46199994506236f68b2c8e51b0d", "343f1f42be0607c61e6516aece424609", + "2c9cb15d4ed55b5a0cadd04883bc73b0", "2c9cb15d4ed55b5a0cadd04883bc73b0"), AcmReceiverBitExactnessOldApi::PlatformChecksum( "3c79f16f34218271f3dca4e2b1dfe1bb", "d42cb5195463da26c8129bbfe73a22e6", "83de248aea9c3c2bd680b6952401b4ca", "3c79f16f34218271f3dca4e2b1dfe1bb", + "3c79f16f34218271f3dca4e2b1dfe1bb", "3c79f16f34218271f3dca4e2b1dfe1bb"), 33, test::AcmReceiveTestOldApi::kMonoOutput); } @@ -1316,12 +1335,14 @@ TEST_F(AcmSenderBitExactnessOldApi, IsacWb60ms) { "14d63c5f08127d280e722e3191b73bdd", "9a81e467eb1485f84aca796f8ea65011", "ef75e900e6f375e3061163c53fd09a63", + "1ad29139a04782a33daad8c2b9b35875", "1ad29139a04782a33daad8c2b9b35875"), AcmReceiverBitExactnessOldApi::PlatformChecksum( "9e0a0ab743ad987b55b8e14802769c56", "ebe04a819d3a9d83a83a17f271e1139a", "97aeef98553b5a4b5a68f8b716e8eaf0", "9e0a0ab743ad987b55b8e14802769c56", + "9e0a0ab743ad987b55b8e14802769c56", "9e0a0ab743ad987b55b8e14802769c56"), 16, test::AcmReceiveTestOldApi::kMonoOutput); } @@ -1336,13 +1357,21 @@ TEST_F(AcmSenderBitExactnessOldApi, IsacWb60ms) { TEST_F(AcmSenderBitExactnessOldApi, MAYBE_IsacSwb30ms) { ASSERT_NO_FATAL_FAILURE(SetUpTest("ISAC", 32000, 1, 104, 960, 960)); Run(AcmReceiverBitExactnessOldApi::PlatformChecksum( +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. 
+#if defined(WEBRTC_LINUX) && defined(NDEBUG) && defined(WEBRTC_ARCH_X86) + "13d4d2a4c9e8e94a4b74a176e4bf7cc4", +#else "5683b58da0fbf2063c7adc2e6bfb3fb8", +#endif "2b3c387d06f00b7b7aad4c9be56fb83d", "android_arm32_audio", - "android_arm64_audio", "android_arm64_clang_audio"), + "android_arm64_audio", "android_arm64_clang_audio", + "5683b58da0fbf2063c7adc2e6bfb3fb8"), AcmReceiverBitExactnessOldApi::PlatformChecksum( "ce86106a93419aefb063097108ec94ab", "bcc2041e7744c7ebd9f701866856849c", "android_arm32_payload", - "android_arm64_payload", "android_arm64_clang_payload"), + "android_arm64_payload", "android_arm64_clang_payload", + "ce86106a93419aefb063097108ec94ab"), 33, test::AcmReceiveTestOldApi::kMonoOutput); } #endif @@ -1418,11 +1447,13 @@ TEST_F(AcmSenderBitExactnessOldApi, MAYBE_Ilbc_30ms) { Run(AcmReceiverBitExactnessOldApi::PlatformChecksum( "7b6ec10910debd9af08011d3ed5249f7", "7b6ec10910debd9af08011d3ed5249f7", "android_arm32_audio", - "android_arm64_audio", "android_arm64_clang_audio"), + "android_arm64_audio", "android_arm64_clang_audio", + "7b6ec10910debd9af08011d3ed5249f7"), AcmReceiverBitExactnessOldApi::PlatformChecksum( "cfae2e9f6aba96e145f2bcdd5050ce78", "cfae2e9f6aba96e145f2bcdd5050ce78", "android_arm32_payload", - "android_arm64_payload", "android_arm64_clang_payload"), + "android_arm64_payload", "android_arm64_clang_payload", + "cfae2e9f6aba96e145f2bcdd5050ce78"), 33, test::AcmReceiveTestOldApi::kMonoOutput); } #endif @@ -1437,11 +1468,13 @@ TEST_F(AcmSenderBitExactnessOldApi, MAYBE_G722_20ms) { Run(AcmReceiverBitExactnessOldApi::PlatformChecksum( "e99c89be49a46325d03c0d990c292d68", "e99c89be49a46325d03c0d990c292d68", "android_arm32_audio", - "android_arm64_audio", "android_arm64_clang_audio"), + "android_arm64_audio", "android_arm64_clang_audio", + "e99c89be49a46325d03c0d990c292d68"), AcmReceiverBitExactnessOldApi::PlatformChecksum( "fc68a87e1380614e658087cb35d5ca10", "fc68a87e1380614e658087cb35d5ca10", "android_arm32_payload", - "android_arm64_payload", "android_arm64_clang_payload"), + "android_arm64_payload", "android_arm64_clang_payload", + "fc68a87e1380614e658087cb35d5ca10"), 50, test::AcmReceiveTestOldApi::kMonoOutput); } @@ -1455,11 +1488,13 @@ TEST_F(AcmSenderBitExactnessOldApi, MAYBE_G722_stereo_20ms) { Run(AcmReceiverBitExactnessOldApi::PlatformChecksum( "e280aed283e499d37091b481ca094807", "e280aed283e499d37091b481ca094807", "android_arm32_audio", - "android_arm64_audio", "android_arm64_clang_audio"), + "android_arm64_audio", "android_arm64_clang_audio", + "e280aed283e499d37091b481ca094807"), AcmReceiverBitExactnessOldApi::PlatformChecksum( "66516152eeaa1e650ad94ff85f668dac", "66516152eeaa1e650ad94ff85f668dac", "android_arm32_payload", - "android_arm64_payload", "android_arm64_clang_payload"), + "android_arm64_payload", "android_arm64_clang_payload", + "66516152eeaa1e650ad94ff85f668dac"), 50, test::AcmReceiveTestOldApi::kStereoOutput); } @@ -1478,23 +1513,29 @@ const std::string audio_checksum = audio_maybe_sse, "6fcceb83acf427730570bc13eeac920c", "fd96f15d547c4e155daeeef4253b174e", - "fd96f15d547c4e155daeeef4253b174e"); + "fd96f15d547c4e155daeeef4253b174e", + "Mac_arm64_checksum_placeholder"); const std::string payload_checksum = AcmReceiverBitExactnessOldApi::PlatformChecksum( payload_maybe_sse, payload_maybe_sse, "4bd846d0aa5656ecd5dfd85701a1b78c", "7efbfc9f8e3b4b2933ae2d01ab919028", - "7efbfc9f8e3b4b2933ae2d01ab919028"); + "7efbfc9f8e3b4b2933ae2d01ab919028", + "Mac_arm64_checksum_placeholder"); } // namespace -TEST_F(AcmSenderBitExactnessOldApi, 
Opus_stereo_20ms) { +// TODO(http://bugs.webrtc.org/12518): Enable the test after Opus has been +// updated. +TEST_F(AcmSenderBitExactnessOldApi, DISABLED_Opus_stereo_20ms) { ASSERT_NO_FATAL_FAILURE(SetUpTest("opus", 48000, 2, 120, 960, 960)); Run(audio_checksum, payload_checksum, 50, test::AcmReceiveTestOldApi::kStereoOutput); } -TEST_F(AcmSenderBitExactnessNewApi, MAYBE_OpusFromFormat_stereo_20ms) { +// TODO(http://bugs.webrtc.org/12518): Enable the test after Opus has been +// updated. +TEST_F(AcmSenderBitExactnessNewApi, DISABLED_OpusFromFormat_stereo_20ms) { const auto config = AudioEncoderOpus::SdpToConfig( SdpAudioFormat("opus", 48000, 2, {{"stereo", "1"}})); ASSERT_TRUE(SetUpSender(kTestFileFakeStereo32kHz, 32000)); @@ -1541,17 +1582,19 @@ TEST_F(AcmSenderBitExactnessNewApi, DISABLED_OpusManyChannels) { "audio checksum check downstream|8051617907766bec5f4e4a4f7c6d5291", "8051617907766bec5f4e4a4f7c6d5291", "6183752a62dc1368f959eb3a8c93b846", "android arm64 audio checksum", - "48bf1f3ca0b72f3c9cdfbe79956122b1"), + "48bf1f3ca0b72f3c9cdfbe79956122b1", "Mac_arm64_checksum_placeholder"), // payload_checksum, AcmReceiverBitExactnessOldApi::PlatformChecksum( // payload checksum "payload checksum check downstream|b09c52e44b2bdd9a0809e3a5b1623a76", "b09c52e44b2bdd9a0809e3a5b1623a76", "2ea535ef60f7d0c9d89e3002d4c2124f", "android arm64 payload checksum", - "e87995a80f50a0a735a230ca8b04a67d"), + "e87995a80f50a0a735a230ca8b04a67d", "Mac_arm64_checksum_placeholder"), 50, test::AcmReceiveTestOldApi::kQuadOutput, decoder_factory); } -TEST_F(AcmSenderBitExactnessNewApi, OpusFromFormat_stereo_20ms_voip) { +// TODO(http://bugs.webrtc.org/12518): Enable the test after Opus has been +// updated. +TEST_F(AcmSenderBitExactnessNewApi, DISABLED_OpusFromFormat_stereo_20ms_voip) { auto config = AudioEncoderOpus::SdpToConfig( SdpAudioFormat("opus", 48000, 2, {{"stereo", "1"}})); // If not set, default will be kAudio in case of stereo. @@ -1568,12 +1611,12 @@ TEST_F(AcmSenderBitExactnessNewApi, OpusFromFormat_stereo_20ms_voip) { Run(AcmReceiverBitExactnessOldApi::PlatformChecksum( audio_maybe_sse, audio_maybe_sse, "f1cefe107ffdced7694d7f735342adf3", "3b1bfe5dd8ed16ee5b04b93a5b5e7e48", - "3b1bfe5dd8ed16ee5b04b93a5b5e7e48"), + "3b1bfe5dd8ed16ee5b04b93a5b5e7e48", "Mac_arm64_checksum_placeholder"), AcmReceiverBitExactnessOldApi::PlatformChecksum( payload_maybe_sse, payload_maybe_sse, "5e79a2f51c633fe145b6c10ae198d1aa", "e730050cb304d54d853fd285ab0424fa", - "e730050cb304d54d853fd285ab0424fa"), + "e730050cb304d54d853fd285ab0424fa", "Mac_arm64_checksum_placeholder"), 50, test::AcmReceiveTestOldApi::kStereoOutput); } diff --git a/modules/audio_coding/codecs/ilbc/cb_construct.h b/modules/audio_coding/codecs/ilbc/cb_construct.h index 0a4a47aa06..8f7c663164 100644 --- a/modules/audio_coding/codecs/ilbc/cb_construct.h +++ b/modules/audio_coding/codecs/ilbc/cb_construct.h @@ -23,14 +23,15 @@ #include #include +#include "absl/base/attributes.h" #include "modules/audio_coding/codecs/ilbc/defines.h" -#include "rtc_base/system/unused.h" /*----------------------------------------------------------------* * Construct decoded vector from codebook and gains. *---------------------------------------------------------------*/ // Returns true on success, false on failure. 
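// These iLBC headers replace RTC_WARN_UNUSED_RESULT (from
// rtc_base/system/unused.h, written after the parameter list) with
// ABSL_MUST_USE_RESULT (from absl/base/attributes.h, written before the
// declaration). A minimal hedged sketch of the new form; DecodeBlock and its
// body are illustrative, not part of the change.
#include <stdint.h>

#include "absl/base/attributes.h"

ABSL_MUST_USE_RESULT bool DecodeBlock(const int16_t* bits, int16_t* out) {
  out[0] = bits[0];  // Illustrative body only.
  return true;
}

void Caller(const int16_t* bits, int16_t* out) {
  // Ignoring the return value of DecodeBlock() now triggers a compiler
  // warning on toolchains that support the attribute.
  if (!DecodeBlock(bits, out)) {
    // On failure the decoder state may be corrupted and should be reset,
    // as the surrounding comments note.
  }
}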
+ABSL_MUST_USE_RESULT bool WebRtcIlbcfix_CbConstruct( int16_t* decvector, /* (o) Decoded vector */ const int16_t* index, /* (i) Codebook indices */ @@ -38,6 +39,6 @@ bool WebRtcIlbcfix_CbConstruct( int16_t* mem, /* (i) Buffer for codevector construction */ size_t lMem, /* (i) Length of buffer */ size_t veclen /* (i) Length of vector */ - ) RTC_WARN_UNUSED_RESULT; +); #endif diff --git a/modules/audio_coding/codecs/ilbc/decode.h b/modules/audio_coding/codecs/ilbc/decode.h index d73f79880b..a7d2910115 100644 --- a/modules/audio_coding/codecs/ilbc/decode.h +++ b/modules/audio_coding/codecs/ilbc/decode.h @@ -21,21 +21,22 @@ #include +#include "absl/base/attributes.h" #include "modules/audio_coding/codecs/ilbc/defines.h" -#include "rtc_base/system/unused.h" /*----------------------------------------------------------------* * main decoder function *---------------------------------------------------------------*/ // Returns 0 on success, -1 on error. +ABSL_MUST_USE_RESULT int WebRtcIlbcfix_DecodeImpl( int16_t* decblock, /* (o) decoded signal block */ const uint16_t* bytes, /* (i) encoded signal bits */ IlbcDecoder* iLBCdec_inst, /* (i/o) the decoder state structure */ - int16_t mode /* (i) 0: bad packet, PLC, - 1: normal */ - ) RTC_WARN_UNUSED_RESULT; + int16_t mode /* (i) 0: bad packet, PLC, + 1: normal */ +); #endif diff --git a/modules/audio_coding/codecs/ilbc/decode_residual.h b/modules/audio_coding/codecs/ilbc/decode_residual.h index 30eb35f82b..d079577661 100644 --- a/modules/audio_coding/codecs/ilbc/decode_residual.h +++ b/modules/audio_coding/codecs/ilbc/decode_residual.h @@ -23,8 +23,8 @@ #include #include +#include "absl/base/attributes.h" #include "modules/audio_coding/codecs/ilbc/defines.h" -#include "rtc_base/system/unused.h" /*----------------------------------------------------------------* * frame residual decoder function (subrutine to iLBC_decode) @@ -32,6 +32,7 @@ // Returns true on success, false on failure. In case of failure, the decoder // state may be corrupted and needs resetting. +ABSL_MUST_USE_RESULT bool WebRtcIlbcfix_DecodeResidual( IlbcDecoder* iLBCdec_inst, /* (i/o) the decoder state structure */ iLBC_bits* iLBC_encbits, /* (i/o) Encoded bits, which are used @@ -39,6 +40,6 @@ bool WebRtcIlbcfix_DecodeResidual( int16_t* decresidual, /* (o) decoded residual frame */ int16_t* syntdenum /* (i) the decoded synthesis filter coefficients */ - ) RTC_WARN_UNUSED_RESULT; +); #endif diff --git a/modules/audio_coding/codecs/ilbc/enhancer_interface.c b/modules/audio_coding/codecs/ilbc/enhancer_interface.c index fb9740eb22..ca23e19ae3 100644 --- a/modules/audio_coding/codecs/ilbc/enhancer_interface.c +++ b/modules/audio_coding/codecs/ilbc/enhancer_interface.c @@ -18,6 +18,7 @@ #include "modules/audio_coding/codecs/ilbc/enhancer_interface.h" +#include #include #include "modules/audio_coding/codecs/ilbc/constants.h" @@ -203,11 +204,11 @@ size_t // (o) Estimated lag in end of in[] regressor=in+tlag-1; /* scaling */ - // Note that this is not abs-max, but it doesn't matter since we use only - // the square of it. - max16 = regressor[WebRtcSpl_MaxAbsIndexW16(regressor, plc_blockl + 3 - 1)]; - - const int64_t max_val = plc_blockl * max16 * max16; + // Note that this is not abs-max, so we will take the absolute value below. 
+ max16 = WebRtcSpl_MaxAbsElementW16(regressor, plc_blockl + 3 - 1); + const int16_t max_target = + WebRtcSpl_MaxAbsElementW16(target, plc_blockl + 3 - 1); + const int64_t max_val = plc_blockl * abs(max16 * max_target); const int32_t factor = max_val >> 31; shifts = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor); diff --git a/modules/audio_coding/codecs/ilbc/get_cd_vec.h b/modules/audio_coding/codecs/ilbc/get_cd_vec.h index 647b0634a0..99537dd0f7 100644 --- a/modules/audio_coding/codecs/ilbc/get_cd_vec.h +++ b/modules/audio_coding/codecs/ilbc/get_cd_vec.h @@ -23,17 +23,18 @@ #include #include +#include "absl/base/attributes.h" #include "modules/audio_coding/codecs/ilbc/defines.h" -#include "rtc_base/system/unused.h" // Returns true on success, false on failure. In case of failure, the decoder // state may be corrupted and needs resetting. +ABSL_MUST_USE_RESULT bool WebRtcIlbcfix_GetCbVec( int16_t* cbvec, /* (o) Constructed codebook vector */ int16_t* mem, /* (i) Codebook buffer */ size_t index, /* (i) Codebook index */ size_t lMem, /* (i) Length of codebook buffer */ size_t cbveclen /* (i) Codebook vector length */ - ) RTC_WARN_UNUSED_RESULT; +); #endif diff --git a/modules/audio_coding/codecs/isac/audio_encoder_isac_t_impl.h b/modules/audio_coding/codecs/isac/audio_encoder_isac_t_impl.h index 0bde3f797f..fa84515204 100644 --- a/modules/audio_coding/codecs/isac/audio_encoder_isac_t_impl.h +++ b/modules/audio_coding/codecs/isac/audio_encoder_isac_t_impl.h @@ -140,6 +140,11 @@ AudioEncoder::EncodedInfo AudioEncoderIsacT::EncodeImpl( kSufficientEncodeBufferSizeBytes, [&](rtc::ArrayView encoded) { int r = T::Encode(isac_state_, audio.data(), encoded.data()); + if (T::GetErrorCode(isac_state_) == 6450) { + // Isac is not able to effectively compress all types of signals. This + // is a limitation of the codec that cannot be easily fixed. 
+ r = 0; + } RTC_CHECK_GE(r, 0) << "Encode failed (error code " << T::GetErrorCode(isac_state_) << ")"; diff --git a/modules/audio_coding/codecs/isac/fix/test/isac_speed_test.cc b/modules/audio_coding/codecs/isac/fix/test/isac_speed_test.cc index 20752639fc..903ac64aff 100644 --- a/modules/audio_coding/codecs/isac/fix/test/isac_speed_test.cc +++ b/modules/audio_coding/codecs/isac/fix/test/isac_speed_test.cc @@ -11,6 +11,7 @@ #include "modules/audio_coding/codecs/isac/fix/include/isacfix.h" #include "modules/audio_coding/codecs/isac/fix/source/settings.h" #include "modules/audio_coding/codecs/tools/audio_codec_speed_test.h" +#include "rtc_base/checks.h" using std::string; @@ -83,7 +84,7 @@ float IsacSpeedTest::EncodeABlock(int16_t* in_data, } clocks = clock() - clocks; *encoded_bytes = static_cast(value); - assert(*encoded_bytes <= max_bytes); + RTC_DCHECK_LE(*encoded_bytes, max_bytes); return 1000.0 * clocks / CLOCKS_PER_SEC; } diff --git a/modules/audio_coding/codecs/legacy_encoded_audio_frame_unittest.cc b/modules/audio_coding/codecs/legacy_encoded_audio_frame_unittest.cc index 2ca1d4ca98..f081a5380f 100644 --- a/modules/audio_coding/codecs/legacy_encoded_audio_frame_unittest.cc +++ b/modules/audio_coding/codecs/legacy_encoded_audio_frame_unittest.cc @@ -88,7 +88,7 @@ class SplitBySamplesTest : public ::testing::TestWithParam { samples_per_ms_ = 8; break; default: - assert(false); + RTC_NOTREACHED(); break; } } diff --git a/modules/audio_coding/codecs/opus/audio_encoder_opus.cc b/modules/audio_coding/codecs/opus/audio_encoder_opus.cc index 203cb5aeb3..7c62e98c5b 100644 --- a/modules/audio_coding/codecs/opus/audio_encoder_opus.cc +++ b/modules/audio_coding/codecs/opus/audio_encoder_opus.cc @@ -704,6 +704,11 @@ bool AudioEncoderOpusImpl::RecreateEncoderInstance( } void AudioEncoderOpusImpl::SetFrameLength(int frame_length_ms) { + if (next_frame_length_ms_ != frame_length_ms) { + RTC_LOG(LS_VERBOSE) << "Update Opus frame length " + << "from " << next_frame_length_ms_ << " ms " + << "to " << frame_length_ms << " ms."; + } next_frame_length_ms_ = frame_length_ms; } diff --git a/modules/audio_coding/codecs/opus/audio_encoder_opus_unittest.cc b/modules/audio_coding/codecs/opus/audio_encoder_opus_unittest.cc index 0fe87bc31e..f1953eaacf 100644 --- a/modules/audio_coding/codecs/opus/audio_encoder_opus_unittest.cc +++ b/modules/audio_coding/codecs/opus/audio_encoder_opus_unittest.cc @@ -809,4 +809,97 @@ TEST_P(AudioEncoderOpusTest, OpusFlagDtxAsNonSpeech) { EXPECT_GT(max_nonspeech_frames, 15); } +TEST(AudioEncoderOpusTest, OpusDtxFilteringHighEnergyRefreshPackets) { + test::ScopedFieldTrials override_field_trials( + "WebRTC-Audio-OpusAvoidNoisePumpingDuringDtx/Enabled/"); + const std::string kInputFileName = + webrtc::test::ResourcePath("audio_coding/testfile16kHz", "pcm"); + constexpr int kSampleRateHz = 16000; + AudioEncoderOpusConfig config; + config.dtx_enabled = true; + config.sample_rate_hz = kSampleRateHz; + constexpr int payload_type = 17; + const auto encoder = AudioEncoderOpus::MakeAudioEncoder(config, payload_type); + test::AudioLoop audio_loop; + constexpr size_t kMaxLoopLengthSaples = kSampleRateHz * 11.6f; + constexpr size_t kInputBlockSizeSamples = kSampleRateHz / 100; + EXPECT_TRUE(audio_loop.Init(kInputFileName, kMaxLoopLengthSaples, + kInputBlockSizeSamples)); + AudioEncoder::EncodedInfo info; + rtc::Buffer encoded(500); + // Encode the audio file and store the last part that corresponds to silence. 
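// In this test a payload of at most 2 bytes is treated as a DTX (no-activity)
// packet and anything larger as an active packet, while encoded_bytes == 0
// only means the encoder has not completed a packet yet. A condensed, hedged
// sketch of that classification for mono 10 ms blocks; CountActivePackets and
// its arguments are illustrative names, not part of the test.
#include <cstdint>
#include <vector>

#include "api/audio_codecs/audio_encoder.h"
#include "rtc_base/buffer.h"

int CountActivePackets(webrtc::AudioEncoder& encoder,
                       const std::vector<std::vector<int16_t>>& blocks_10ms) {
  rtc::Buffer encoded;
  uint32_t rtp_timestamp = 0;
  int active_packets = 0;
  for (const std::vector<int16_t>& block : blocks_10ms) {
    encoded.Clear();
    const webrtc::AudioEncoder::EncodedInfo info =
        encoder.Encode(rtp_timestamp, block, &encoded);
    rtp_timestamp += block.size();
    if (info.encoded_bytes > 2) {
      ++active_packets;  // Speech or a (possibly refresh) comfort-noise payload.
    }
  }
  return active_packets;
}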
+ constexpr size_t kSilenceDurationSamples = kSampleRateHz * 0.2f; + std::array silence; + uint32_t rtp_timestamp = 0; + bool last_packet_dtx_frame = false; + bool opus_entered_dtx = false; + bool silence_filled = false; + size_t timestamp_start_silence = 0; + while (!silence_filled && rtp_timestamp < kMaxLoopLengthSaples) { + encoded.Clear(); + // Every second call to the encoder will generate an Opus packet. + for (int j = 0; j < 2; j++) { + auto next_frame = audio_loop.GetNextBlock(); + info = encoder->Encode(rtp_timestamp, next_frame, &encoded); + if (opus_entered_dtx) { + size_t silence_frame_start = rtp_timestamp - timestamp_start_silence; + silence_filled = silence_frame_start >= kSilenceDurationSamples; + if (!silence_filled) { + std::copy(next_frame.begin(), next_frame.end(), + silence.begin() + silence_frame_start); + } + } + rtp_timestamp += kInputBlockSizeSamples; + } + EXPECT_TRUE(info.encoded_bytes > 0 || last_packet_dtx_frame); + last_packet_dtx_frame = info.encoded_bytes > 0 ? info.encoded_bytes <= 2 + : last_packet_dtx_frame; + if (info.encoded_bytes <= 2 && !opus_entered_dtx) { + timestamp_start_silence = rtp_timestamp; + } + opus_entered_dtx = info.encoded_bytes <= 2; + } + + EXPECT_TRUE(silence_filled); + // The copied 200 ms of silence is used for creating 6 bursts that are fed to + // the encoder, the first three ones with a larger energy and the last three + // with a lower energy. This test verifies that the encoder just sends refresh + // DTX packets during the last bursts. + int number_non_empty_packets_during_increase = 0; + int number_non_empty_packets_during_decrease = 0; + for (size_t burst = 0; burst < 6; ++burst) { + uint32_t rtp_timestamp_start = rtp_timestamp; + const bool increase_noise = burst < 3; + const float gain = increase_noise ? 1.4f : 0.0f; + while (rtp_timestamp < rtp_timestamp_start + kSilenceDurationSamples) { + encoded.Clear(); + // Every second call to the encoder will generate an Opus packet. + for (int j = 0; j < 2; j++) { + std::array silence_frame; + size_t silence_frame_start = rtp_timestamp - rtp_timestamp_start; + std::transform( + silence.begin() + silence_frame_start, + silence.begin() + silence_frame_start + kInputBlockSizeSamples, + silence_frame.begin(), [gain](float s) { return gain * s; }); + info = encoder->Encode(rtp_timestamp, silence_frame, &encoded); + rtp_timestamp += kInputBlockSizeSamples; + } + EXPECT_TRUE(info.encoded_bytes > 0 || last_packet_dtx_frame); + last_packet_dtx_frame = info.encoded_bytes > 0 ? info.encoded_bytes <= 2 + : last_packet_dtx_frame; + // Tracking the number of non empty packets. + if (increase_noise && info.encoded_bytes > 2) { + number_non_empty_packets_during_increase++; + } + if (!increase_noise && info.encoded_bytes > 2) { + number_non_empty_packets_during_decrease++; + } + } + } + // Check that the refresh DTX packets are just sent during the decrease energy + // region. 
+ EXPECT_EQ(number_non_empty_packets_during_increase, 0); + EXPECT_GT(number_non_empty_packets_during_decrease, 0); +} + } // namespace webrtc diff --git a/modules/audio_coding/codecs/opus/opus_inst.h b/modules/audio_coding/codecs/opus/opus_inst.h index 148baa2806..2c25e43f25 100644 --- a/modules/audio_coding/codecs/opus/opus_inst.h +++ b/modules/audio_coding/codecs/opus/opus_inst.h @@ -25,6 +25,9 @@ struct WebRtcOpusEncInst { OpusMSEncoder* multistream_encoder; size_t channels; int in_dtx_mode; + bool avoid_noise_pumping_during_dtx; + int sample_rate_hz; + float smooth_energy_non_active_frames; }; struct WebRtcOpusDecInst { diff --git a/modules/audio_coding/codecs/opus/opus_interface.cc b/modules/audio_coding/codecs/opus/opus_interface.cc index ca39ed8235..f684452ad5 100644 --- a/modules/audio_coding/codecs/opus/opus_interface.cc +++ b/modules/audio_coding/codecs/opus/opus_interface.cc @@ -12,6 +12,9 @@ #include +#include + +#include "api/array_view.h" #include "rtc_base/checks.h" #include "system_wrappers/include/field_trial.h" @@ -36,6 +39,9 @@ enum { constexpr char kPlcUsePrevDecodedSamplesFieldTrial[] = "WebRTC-Audio-OpusPlcUsePrevDecodedSamples"; +constexpr char kAvoidNoisePumpingDuringDtxFieldTrial[] = + "WebRTC-Audio-OpusAvoidNoisePumpingDuringDtx"; + static int FrameSizePerChannel(int frame_size_ms, int sample_rate_hz) { RTC_DCHECK_GT(frame_size_ms, 0); RTC_DCHECK_EQ(frame_size_ms % 10, 0); @@ -54,6 +60,46 @@ static int DefaultFrameSizePerChannel(int sample_rate_hz) { return FrameSizePerChannel(20, sample_rate_hz); } +// Returns true if the `encoded` payload corresponds to a refresh DTX packet +// whose energy is larger than the expected for non activity packets. +static bool WebRtcOpus_IsHighEnergyRefreshDtxPacket( + OpusEncInst* inst, + rtc::ArrayView frame, + rtc::ArrayView encoded) { + if (encoded.size() <= 2) { + return false; + } + int number_frames = + frame.size() / DefaultFrameSizePerChannel(inst->sample_rate_hz); + if (number_frames > 0 && + WebRtcOpus_PacketHasVoiceActivity(encoded.data(), encoded.size()) == 0) { + const float average_frame_energy = + std::accumulate(frame.begin(), frame.end(), 0.0f, + [](float a, int32_t b) { return a + b * b; }) / + number_frames; + if (WebRtcOpus_GetInDtx(inst) == 1 && + average_frame_energy >= inst->smooth_energy_non_active_frames * 0.5f) { + // This is a refresh DTX packet as the encoder is in DTX and has + // produced a payload > 2 bytes. This refresh packet has a higher energy + // than the smooth energy of non activity frames (with a 3 dB negative + // margin) and, therefore, it is flagged as a high energy refresh DTX + // packet. + return true; + } + // The average energy is tracked in a similar way as the modeling of the + // comfort noise in the Silk decoder in Opus + // (third_party/opus/src/silk/CNG.c). 
+ if (average_frame_energy < inst->smooth_energy_non_active_frames * 0.5f) { + inst->smooth_energy_non_active_frames = average_frame_energy; + } else { + inst->smooth_energy_non_active_frames += + (average_frame_energy - inst->smooth_energy_non_active_frames) * + 0.25f; + } + } + return false; +} + int16_t WebRtcOpus_EncoderCreate(OpusEncInst** inst, size_t channels, int32_t application, @@ -88,6 +134,10 @@ int16_t WebRtcOpus_EncoderCreate(OpusEncInst** inst, state->in_dtx_mode = 0; state->channels = channels; + state->sample_rate_hz = sample_rate_hz; + state->smooth_energy_non_active_frames = 0.0f; + state->avoid_noise_pumping_during_dtx = + webrtc::field_trial::IsEnabled(kAvoidNoisePumpingDuringDtxFieldTrial); *inst = state; return 0; @@ -120,9 +170,10 @@ int16_t WebRtcOpus_MultistreamEncoderCreate( RTC_DCHECK(state); int error; - state->multistream_encoder = - opus_multistream_encoder_create(48000, channels, streams, coupled_streams, - channel_mapping, opus_app, &error); + const int sample_rate_hz = 48000; + state->multistream_encoder = opus_multistream_encoder_create( + sample_rate_hz, channels, streams, coupled_streams, channel_mapping, + opus_app, &error); if (error != OPUS_OK || (!state->encoder && !state->multistream_encoder)) { WebRtcOpus_EncoderFree(state); @@ -131,6 +182,9 @@ int16_t WebRtcOpus_MultistreamEncoderCreate( state->in_dtx_mode = 0; state->channels = channels; + state->sample_rate_hz = sample_rate_hz; + state->smooth_energy_non_active_frames = 0.0f; + state->avoid_noise_pumping_during_dtx = false; *inst = state; return 0; @@ -188,6 +242,21 @@ int WebRtcOpus_Encode(OpusEncInst* inst, } } + if (inst->avoid_noise_pumping_during_dtx && WebRtcOpus_GetUseDtx(inst) == 1 && + WebRtcOpus_IsHighEnergyRefreshDtxPacket( + inst, rtc::MakeArrayView(audio_in, samples), + rtc::MakeArrayView(encoded, res))) { + // This packet is a high energy refresh DTX packet. For avoiding an increase + // of the energy in the DTX region at the decoder, this packet is + // substituted by a TOC byte with one empty frame. + // The number of frames described in the TOC byte + // (https://tools.ietf.org/html/rfc6716#section-3.1) are overwritten to + // always indicate one frame (last two bits equal to 0). + encoded[0] = encoded[0] & 0b11111100; + inst->in_dtx_mode = 1; + // The payload is just the TOC byte and has 1 byte as length. + return 1; + } inst->in_dtx_mode = 0; return res; } @@ -316,6 +385,16 @@ int16_t WebRtcOpus_DisableDtx(OpusEncInst* inst) { } } +int16_t WebRtcOpus_GetUseDtx(OpusEncInst* inst) { + if (inst) { + opus_int32 use_dtx; + if (ENCODER_CTL(inst, OPUS_GET_DTX(&use_dtx)) == 0) { + return use_dtx; + } + } + return -1; +} + int16_t WebRtcOpus_EnableCbr(OpusEncInst* inst) { if (inst) { return ENCODER_CTL(inst, OPUS_SET_VBR(0)); diff --git a/modules/audio_coding/codecs/opus/opus_interface.h b/modules/audio_coding/codecs/opus/opus_interface.h index 2a3ceaa7d3..89159ce1c0 100644 --- a/modules/audio_coding/codecs/opus/opus_interface.h +++ b/modules/audio_coding/codecs/opus/opus_interface.h @@ -231,6 +231,20 @@ int16_t WebRtcOpus_EnableDtx(OpusEncInst* inst); */ int16_t WebRtcOpus_DisableDtx(OpusEncInst* inst); +/**************************************************************************** + * WebRtcOpus_GetUseDtx() + * + * This function gets the DTX configuration used for encoding. + * + * Input: + * - inst : Encoder context + * + * Return value : 0 - Encoder does not use DTX. + * 1 - Encoder uses DTX. + * -1 - Error. 
+ */ +int16_t WebRtcOpus_GetUseDtx(OpusEncInst* inst); + /**************************************************************************** * WebRtcOpus_EnableCbr() * diff --git a/modules/audio_coding/codecs/red/audio_encoder_copy_red.cc b/modules/audio_coding/codecs/red/audio_encoder_copy_red.cc index 1432e3182f..c72768e937 100644 --- a/modules/audio_coding/codecs/red/audio_encoder_copy_red.cc +++ b/modules/audio_coding/codecs/red/audio_encoder_copy_red.cc @@ -17,22 +17,51 @@ #include "rtc_base/byte_order.h" #include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" namespace webrtc { -// RED packets must be less than 1024 bytes to fit the 10 bit block length. -static constexpr const int kRedMaxPacketSize = 1 << 10; -// The typical MTU is 1200 bytes. -static constexpr const size_t kAudioMaxRtpPacketLen = 1200; +static constexpr const int kRedMaxPacketSize = + 1 << 10; // RED packets must be less than 1024 bytes to fit the 10 bit + // block length. +static constexpr const size_t kAudioMaxRtpPacketLen = + 1200; // The typical MTU is 1200 bytes. + +static constexpr size_t kRedHeaderLength = 4; // 4 bytes RED header. +static constexpr size_t kRedLastHeaderLength = + 1; // reduced size for last RED header. + +static constexpr size_t kRedNumberOfRedundantEncodings = + 2; // The level of redundancy we support. AudioEncoderCopyRed::Config::Config() = default; AudioEncoderCopyRed::Config::Config(Config&&) = default; AudioEncoderCopyRed::Config::~Config() = default; +size_t GetMaxRedundancyFromFieldTrial() { + const std::string red_trial = + webrtc::field_trial::FindFullName("WebRTC-Audio-Red-For-Opus"); + size_t redundancy = 0; + if (sscanf(red_trial.c_str(), "Enabled-%zu", &redundancy) != 1 || + redundancy < 1 || redundancy > 9) { + return kRedNumberOfRedundantEncodings; + } + return redundancy; +} + AudioEncoderCopyRed::AudioEncoderCopyRed(Config&& config) : speech_encoder_(std::move(config.speech_encoder)), + primary_encoded_(0, kAudioMaxRtpPacketLen), max_packet_length_(kAudioMaxRtpPacketLen), red_payload_type_(config.payload_type) { RTC_CHECK(speech_encoder_) << "Speech encoder not provided."; + + auto number_of_redundant_encodings = GetMaxRedundancyFromFieldTrial(); + for (size_t i = 0; i < number_of_redundant_encodings; i++) { + std::pair redundant; + redundant.second.EnsureCapacity(kAudioMaxRtpPacketLen); + redundant_encodings_.push_front(std::move(redundant)); + } } AudioEncoderCopyRed::~AudioEncoderCopyRed() = default; @@ -61,104 +90,86 @@ int AudioEncoderCopyRed::GetTargetBitrate() const { return speech_encoder_->GetTargetBitrate(); } -size_t AudioEncoderCopyRed::CalculateHeaderLength(size_t encoded_bytes) const { - size_t header_size = 1; - size_t bytes_available = max_packet_length_ - encoded_bytes; - if (secondary_info_.encoded_bytes > 0 && - secondary_info_.encoded_bytes < bytes_available) { - header_size += 4; - bytes_available -= secondary_info_.encoded_bytes; - } - if (tertiary_info_.encoded_bytes > 0 && - tertiary_info_.encoded_bytes < bytes_available) { - header_size += 4; - } - return header_size > 1 ? 
header_size : 0; -} - AudioEncoder::EncodedInfo AudioEncoderCopyRed::EncodeImpl( uint32_t rtp_timestamp, rtc::ArrayView audio, rtc::Buffer* encoded) { - rtc::Buffer primary_encoded; + primary_encoded_.Clear(); EncodedInfo info = - speech_encoder_->Encode(rtp_timestamp, audio, &primary_encoded); + speech_encoder_->Encode(rtp_timestamp, audio, &primary_encoded_); RTC_CHECK(info.redundant.empty()) << "Cannot use nested redundant encoders."; - RTC_DCHECK_EQ(primary_encoded.size(), info.encoded_bytes); + RTC_DCHECK_EQ(primary_encoded_.size(), info.encoded_bytes); if (info.encoded_bytes == 0 || info.encoded_bytes > kRedMaxPacketSize) { return info; } RTC_DCHECK_GT(max_packet_length_, info.encoded_bytes); + size_t header_length_bytes = kRedLastHeaderLength; + size_t bytes_available = max_packet_length_ - info.encoded_bytes; + auto it = redundant_encodings_.begin(); + + // Determine how much redundancy we can fit into our packet by + // iterating forward. + for (; it != redundant_encodings_.end(); it++) { + if (bytes_available < kRedHeaderLength + it->first.encoded_bytes) { + break; + } + if (it->first.encoded_bytes == 0) { + break; + } + bytes_available -= kRedHeaderLength + it->first.encoded_bytes; + header_length_bytes += kRedHeaderLength; + } + // Allocate room for RFC 2198 header if there is redundant data. // Otherwise this will send the primary payload type without // wrapping in RED. - const size_t header_length_bytes = CalculateHeaderLength(info.encoded_bytes); + if (header_length_bytes == kRedLastHeaderLength) { + header_length_bytes = 0; + } encoded->SetSize(header_length_bytes); + // Iterate backwards and append the data. size_t header_offset = 0; - size_t bytes_available = max_packet_length_ - info.encoded_bytes; - if (tertiary_info_.encoded_bytes > 0 && - tertiary_info_.encoded_bytes + secondary_info_.encoded_bytes < - bytes_available) { - encoded->AppendData(tertiary_encoded_); + while (it-- != redundant_encodings_.begin()) { + encoded->AppendData(it->second); const uint32_t timestamp_delta = - info.encoded_timestamp - tertiary_info_.encoded_timestamp; - - encoded->data()[header_offset] = tertiary_info_.payload_type | 0x80; + info.encoded_timestamp - it->first.encoded_timestamp; + encoded->data()[header_offset] = it->first.payload_type | 0x80; rtc::SetBE16(static_cast(encoded->data()) + header_offset + 1, - (timestamp_delta << 2) | (tertiary_info_.encoded_bytes >> 8)); - encoded->data()[header_offset + 3] = tertiary_info_.encoded_bytes & 0xff; - header_offset += 4; - bytes_available -= tertiary_info_.encoded_bytes; + (timestamp_delta << 2) | (it->first.encoded_bytes >> 8)); + encoded->data()[header_offset + 3] = it->first.encoded_bytes & 0xff; + header_offset += kRedHeaderLength; + info.redundant.push_back(it->first); } - if (secondary_info_.encoded_bytes > 0 && - secondary_info_.encoded_bytes < bytes_available) { - encoded->AppendData(secondary_encoded_); - - const uint32_t timestamp_delta = - info.encoded_timestamp - secondary_info_.encoded_timestamp; - - encoded->data()[header_offset] = secondary_info_.payload_type | 0x80; - rtc::SetBE16(static_cast(encoded->data()) + header_offset + 1, - (timestamp_delta << 2) | (secondary_info_.encoded_bytes >> 8)); - encoded->data()[header_offset + 3] = secondary_info_.encoded_bytes & 0xff; - header_offset += 4; - bytes_available -= secondary_info_.encoded_bytes; + // |info| will be implicitly cast to an EncodedInfoLeaf struct, effectively + // discarding the (empty) vector of redundant information. This is + // intentional. 
+ if (header_length_bytes > 0) { + info.redundant.push_back(info); + RTC_DCHECK_EQ(info.speech, + info.redundant[info.redundant.size() - 1].speech); } - encoded->AppendData(primary_encoded); + encoded->AppendData(primary_encoded_); if (header_length_bytes > 0) { RTC_DCHECK_EQ(header_offset, header_length_bytes - 1); encoded->data()[header_offset] = info.payload_type; } - // |info| will be implicitly cast to an EncodedInfoLeaf struct, effectively - // discarding the (empty) vector of redundant information. This is - // intentional. - info.redundant.push_back(info); - RTC_DCHECK_EQ(info.redundant.size(), 1); - RTC_DCHECK_EQ(info.speech, info.redundant[0].speech); - if (secondary_info_.encoded_bytes > 0) { - info.redundant.push_back(secondary_info_); - RTC_DCHECK_EQ(info.redundant.size(), 2); + // Shift the redundant encodings. + it = redundant_encodings_.begin(); + for (auto next = std::next(it); next != redundant_encodings_.end(); + it++, next = std::next(it)) { + next->first = it->first; + next->second.SetData(it->second); } - if (tertiary_info_.encoded_bytes > 0) { - info.redundant.push_back(tertiary_info_); - RTC_DCHECK_EQ(info.redundant.size(), - 2 + (secondary_info_.encoded_bytes > 0 ? 1 : 0)); - } - - // Save secondary to tertiary. - tertiary_encoded_.SetData(secondary_encoded_); - tertiary_info_ = secondary_info_; - - // Save primary to secondary. - secondary_encoded_.SetData(primary_encoded); - secondary_info_ = info; + it = redundant_encodings_.begin(); + it->first = info; + it->second.SetData(primary_encoded_); // Update main EncodedInfo. if (header_length_bytes > 0) { @@ -170,8 +181,13 @@ AudioEncoder::EncodedInfo AudioEncoderCopyRed::EncodeImpl( void AudioEncoderCopyRed::Reset() { speech_encoder_->Reset(); - secondary_encoded_.Clear(); - secondary_info_.encoded_bytes = 0; + auto number_of_redundant_encodings = redundant_encodings_.size(); + redundant_encodings_.clear(); + for (size_t i = 0; i < number_of_redundant_encodings; i++) { + std::pair redundant; + redundant.second.EnsureCapacity(kAudioMaxRtpPacketLen); + redundant_encodings_.push_front(std::move(redundant)); + } } bool AudioEncoderCopyRed::SetFec(bool enable) { @@ -182,6 +198,10 @@ bool AudioEncoderCopyRed::SetDtx(bool enable) { return speech_encoder_->SetDtx(enable); } +bool AudioEncoderCopyRed::GetDtx() const { + return speech_encoder_->GetDtx(); +} + bool AudioEncoderCopyRed::SetApplication(Application application) { return speech_encoder_->SetApplication(application); } @@ -190,9 +210,14 @@ void AudioEncoderCopyRed::SetMaxPlaybackRate(int frequency_hz) { speech_encoder_->SetMaxPlaybackRate(frequency_hz); } -rtc::ArrayView> -AudioEncoderCopyRed::ReclaimContainedEncoders() { - return rtc::ArrayView>(&speech_encoder_, 1); +bool AudioEncoderCopyRed::EnableAudioNetworkAdaptor( + const std::string& config_string, + RtcEventLog* event_log) { + return speech_encoder_->EnableAudioNetworkAdaptor(config_string, event_log); +} + +void AudioEncoderCopyRed::DisableAudioNetworkAdaptor() { + speech_encoder_->DisableAudioNetworkAdaptor(); } void AudioEncoderCopyRed::OnReceivedUplinkPacketLossFraction( @@ -208,14 +233,38 @@ void AudioEncoderCopyRed::OnReceivedUplinkBandwidth( bwe_period_ms); } +void AudioEncoderCopyRed::OnReceivedUplinkAllocation( + BitrateAllocationUpdate update) { + speech_encoder_->OnReceivedUplinkAllocation(update); +} + absl::optional> AudioEncoderCopyRed::GetFrameLengthRange() const { return speech_encoder_->GetFrameLengthRange(); } +void AudioEncoderCopyRed::OnReceivedRtt(int rtt_ms) { + 
speech_encoder_->OnReceivedRtt(rtt_ms); +} + void AudioEncoderCopyRed::OnReceivedOverhead(size_t overhead_bytes_per_packet) { max_packet_length_ = kAudioMaxRtpPacketLen - overhead_bytes_per_packet; return speech_encoder_->OnReceivedOverhead(overhead_bytes_per_packet); } +void AudioEncoderCopyRed::SetReceiverFrameLengthRange(int min_frame_length_ms, + int max_frame_length_ms) { + return speech_encoder_->SetReceiverFrameLengthRange(min_frame_length_ms, + max_frame_length_ms); +} + +ANAStats AudioEncoderCopyRed::GetANAStats() const { + return speech_encoder_->GetANAStats(); +} + +rtc::ArrayView> +AudioEncoderCopyRed::ReclaimContainedEncoders() { + return rtc::ArrayView>(&speech_encoder_, 1); +} + } // namespace webrtc diff --git a/modules/audio_coding/codecs/red/audio_encoder_copy_red.h b/modules/audio_coding/codecs/red/audio_encoder_copy_red.h index 9806772ba4..d5b1bf6868 100644 --- a/modules/audio_coding/codecs/red/audio_encoder_copy_red.h +++ b/modules/audio_coding/codecs/red/audio_encoder_copy_red.h @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -26,10 +27,12 @@ namespace webrtc { -// This class implements redundant audio coding. The class object will have an -// underlying AudioEncoder object that performs the actual encodings. The -// current class will gather the two latest encodings from the underlying codec -// into one packet. +// This class implements redundant audio coding as described in +// https://tools.ietf.org/html/rfc2198 +// The class object will have an underlying AudioEncoder object that performs +// the actual encodings. The current class will gather the N latest encodings +// from the underlying codec into one packet. Currently N is hard-coded to 2. + class AudioEncoderCopyRed final : public AudioEncoder { public: struct Config { @@ -50,21 +53,33 @@ class AudioEncoderCopyRed final : public AudioEncoder { size_t Num10MsFramesInNextPacket() const override; size_t Max10MsFramesInAPacket() const override; int GetTargetBitrate() const override; + void Reset() override; bool SetFec(bool enable) override; + bool SetDtx(bool enable) override; + bool GetDtx() const override; + bool SetApplication(Application application) override; void SetMaxPlaybackRate(int frequency_hz) override; - rtc::ArrayView> ReclaimContainedEncoders() - override; + bool EnableAudioNetworkAdaptor(const std::string& config_string, + RtcEventLog* event_log) override; + void DisableAudioNetworkAdaptor() override; void OnReceivedUplinkPacketLossFraction( float uplink_packet_loss_fraction) override; void OnReceivedUplinkBandwidth( int target_audio_bitrate_bps, absl::optional bwe_period_ms) override; + void OnReceivedUplinkAllocation(BitrateAllocationUpdate update) override; + void OnReceivedRtt(int rtt_ms) override; void OnReceivedOverhead(size_t overhead_bytes_per_packet) override; + void SetReceiverFrameLengthRange(int min_frame_length_ms, + int max_frame_length_ms) override; + ANAStats GetANAStats() const override; absl::optional> GetFrameLengthRange() const override; + rtc::ArrayView> ReclaimContainedEncoders() + override; protected: EncodedInfo EncodeImpl(uint32_t rtp_timestamp, @@ -72,15 +87,11 @@ class AudioEncoderCopyRed final : public AudioEncoder { rtc::Buffer* encoded) override; private: - size_t CalculateHeaderLength(size_t encoded_bytes) const; - std::unique_ptr speech_encoder_; + rtc::Buffer primary_encoded_; size_t max_packet_length_; int red_payload_type_; - rtc::Buffer secondary_encoded_; - EncodedInfoLeaf secondary_info_; - rtc::Buffer tertiary_encoded_; - EncodedInfoLeaf 
tertiary_info_; + std::list> redundant_encodings_; RTC_DISALLOW_COPY_AND_ASSIGN(AudioEncoderCopyRed); }; diff --git a/modules/audio_coding/codecs/red/audio_encoder_copy_red_unittest.cc b/modules/audio_coding/codecs/red/audio_encoder_copy_red_unittest.cc index 33527997b5..ddd82441db 100644 --- a/modules/audio_coding/codecs/red/audio_encoder_copy_red_unittest.cc +++ b/modules/audio_coding/codecs/red/audio_encoder_copy_red_unittest.cc @@ -152,7 +152,7 @@ TEST_F(AudioEncoderCopyRedTest, CheckNoOutput) { Encode(); // First call is a special case, since it does not include a secondary // payload. - EXPECT_EQ(1u, encoded_info_.redundant.size()); + EXPECT_EQ(0u, encoded_info_.redundant.size()); EXPECT_EQ(kEncodedSize, encoded_info_.encoded_bytes); // Next call to the speech encoder will not produce any output. @@ -180,7 +180,7 @@ TEST_F(AudioEncoderCopyRedTest, CheckPayloadSizes) { // First call is a special case, since it does not include a secondary // payload. Encode(); - EXPECT_EQ(1u, encoded_info_.redundant.size()); + EXPECT_EQ(0u, encoded_info_.redundant.size()); EXPECT_EQ(1u, encoded_info_.encoded_bytes); // Second call is also special since it does not include a ternary @@ -192,9 +192,9 @@ TEST_F(AudioEncoderCopyRedTest, CheckPayloadSizes) { for (size_t i = 3; i <= kNumPackets; ++i) { Encode(); ASSERT_EQ(3u, encoded_info_.redundant.size()); - EXPECT_EQ(i, encoded_info_.redundant[0].encoded_bytes); + EXPECT_EQ(i, encoded_info_.redundant[2].encoded_bytes); EXPECT_EQ(i - 1, encoded_info_.redundant[1].encoded_bytes); - EXPECT_EQ(i - 2, encoded_info_.redundant[2].encoded_bytes); + EXPECT_EQ(i - 2, encoded_info_.redundant[0].encoded_bytes); EXPECT_EQ(9 + i + (i - 1) + (i - 2), encoded_info_.encoded_bytes); } } @@ -222,8 +222,8 @@ TEST_F(AudioEncoderCopyRedTest, CheckTimestamps) { Encode(); ASSERT_EQ(2u, encoded_info_.redundant.size()); - EXPECT_EQ(primary_timestamp, encoded_info_.redundant[0].encoded_timestamp); - EXPECT_EQ(secondary_timestamp, encoded_info_.redundant[1].encoded_timestamp); + EXPECT_EQ(primary_timestamp, encoded_info_.redundant[1].encoded_timestamp); + EXPECT_EQ(secondary_timestamp, encoded_info_.redundant[0].encoded_timestamp); EXPECT_EQ(primary_timestamp, encoded_info_.encoded_timestamp); } @@ -280,9 +280,7 @@ TEST_F(AudioEncoderCopyRedTest, CheckPayloadType) { // First call is a special case, since it does not include a secondary // payload. 
Encode(); - ASSERT_EQ(1u, encoded_info_.redundant.size()); - EXPECT_EQ(primary_payload_type, encoded_info_.redundant[0].payload_type); - EXPECT_EQ(primary_payload_type, encoded_info_.payload_type); + ASSERT_EQ(0u, encoded_info_.redundant.size()); const int secondary_payload_type = red_payload_type_ + 2; info.payload_type = secondary_payload_type; @@ -291,8 +289,8 @@ TEST_F(AudioEncoderCopyRedTest, CheckPayloadType) { Encode(); ASSERT_EQ(2u, encoded_info_.redundant.size()); - EXPECT_EQ(secondary_payload_type, encoded_info_.redundant[0].payload_type); - EXPECT_EQ(primary_payload_type, encoded_info_.redundant[1].payload_type); + EXPECT_EQ(secondary_payload_type, encoded_info_.redundant[1].payload_type); + EXPECT_EQ(primary_payload_type, encoded_info_.redundant[0].payload_type); EXPECT_EQ(red_payload_type_, encoded_info_.payload_type); } @@ -316,7 +314,7 @@ TEST_F(AudioEncoderCopyRedTest, CheckRFC2198Header) { EXPECT_EQ(encoded_[0], primary_payload_type | 0x80); uint32_t timestamp_delta = encoded_info_.encoded_timestamp - - encoded_info_.redundant[1].encoded_timestamp; + encoded_info_.redundant[0].encoded_timestamp; // Timestamp delta is encoded as a 14 bit value. EXPECT_EQ(encoded_[1], timestamp_delta >> 6); EXPECT_EQ(static_cast(encoded_[2] >> 2), timestamp_delta & 0x3f); @@ -335,13 +333,13 @@ TEST_F(AudioEncoderCopyRedTest, CheckRFC2198Header) { EXPECT_EQ(encoded_[0], primary_payload_type | 0x80); timestamp_delta = encoded_info_.encoded_timestamp - - encoded_info_.redundant[2].encoded_timestamp; + encoded_info_.redundant[0].encoded_timestamp; // Timestamp delta is encoded as a 14 bit value. EXPECT_EQ(encoded_[1], timestamp_delta >> 6); EXPECT_EQ(static_cast(encoded_[2] >> 2), timestamp_delta & 0x3f); // Redundant length is encoded as 10 bit value. - EXPECT_EQ(encoded_[2] & 0x3u, encoded_info_.redundant[2].encoded_bytes >> 8); - EXPECT_EQ(encoded_[3], encoded_info_.redundant[2].encoded_bytes & 0xff); + EXPECT_EQ(encoded_[2] & 0x3u, encoded_info_.redundant[1].encoded_bytes >> 8); + EXPECT_EQ(encoded_[3], encoded_info_.redundant[1].encoded_bytes & 0xff); EXPECT_EQ(encoded_[4], primary_payload_type | 0x80); timestamp_delta = encoded_info_.encoded_timestamp - @@ -350,8 +348,8 @@ TEST_F(AudioEncoderCopyRedTest, CheckRFC2198Header) { EXPECT_EQ(encoded_[5], timestamp_delta >> 6); EXPECT_EQ(static_cast(encoded_[6] >> 2), timestamp_delta & 0x3f); // Redundant length is encoded as 10 bit value. - EXPECT_EQ(encoded_[6] & 0x3u, encoded_info_.redundant[2].encoded_bytes >> 8); - EXPECT_EQ(encoded_[7], encoded_info_.redundant[2].encoded_bytes & 0xff); + EXPECT_EQ(encoded_[6] & 0x3u, encoded_info_.redundant[1].encoded_bytes >> 8); + EXPECT_EQ(encoded_[7], encoded_info_.redundant[1].encoded_bytes & 0xff); EXPECT_EQ(encoded_[8], primary_payload_type); } diff --git a/modules/audio_coding/codecs/tools/audio_codec_speed_test.cc b/modules/audio_coding/codecs/tools/audio_codec_speed_test.cc index 3d5ba0b7c8..f61aacc474 100644 --- a/modules/audio_coding/codecs/tools/audio_codec_speed_test.cc +++ b/modules/audio_coding/codecs/tools/audio_codec_speed_test.cc @@ -10,6 +10,7 @@ #include "modules/audio_coding/codecs/tools/audio_codec_speed_test.h" +#include "rtc_base/checks.h" #include "rtc_base/format_macros.h" #include "test/gtest.h" #include "test/testsupport/file_utils.h" @@ -43,7 +44,7 @@ void AudioCodecSpeedTest::SetUp() { save_out_data_ = get<4>(GetParam()); FILE* fp = fopen(in_filename_.c_str(), "rb"); - assert(fp != NULL); + RTC_DCHECK(fp); // Obtain file size. 
fseek(fp, 0, SEEK_END); @@ -83,7 +84,7 @@ void AudioCodecSpeedTest::SetUp() { out_filename = test::OutputPath() + out_filename + ".pcm"; out_file_ = fopen(out_filename.c_str(), "wb"); - assert(out_file_ != NULL); + RTC_DCHECK(out_file_); printf("Output to be saved in %s.\n", out_filename.c_str()); } diff --git a/modules/audio_coding/g3doc/index.md b/modules/audio_coding/g3doc/index.md new file mode 100644 index 0000000000..bf50c155fc --- /dev/null +++ b/modules/audio_coding/g3doc/index.md @@ -0,0 +1,32 @@ + + + +# The WebRTC Audio Coding Module + +The WebRTC audio coding module can handle both audio sending and receiving. Folder +[`acm2`][acm2] contains implementations of the APIs. + +* Audio Sending: Audio frames, each of which should always contain 10 ms worth + of data, are provided to the audio coding module through + [`Add10MsData()`][Add10MsData]. The audio coding module uses a provided + audio encoder to encode audio frames and deliver the data to a + pre-registered audio packetization callback, which is supposed to wrap the + encoded audio into RTP packets and send them over a transport. Built-in + audio codecs are included in the [`codecs`][codecs] folder. The + [audio network adaptor][ANA] provides add-on functionality to an audio + encoder (currently limited to Opus) to make the audio encoder adaptive to + network conditions (bandwidth, packet loss rate, etc.). + +* Audio Receiving: Audio packets are provided to the audio coding module + through [`IncomingPacket()`][IncomingPacket], and are processed by an audio + jitter buffer ([NetEq][NetEq]), which includes decoding of the packets. + Audio decoders are provided by an audio decoder factory. Decoded audio + samples should be queried by calling [`PlayoutData10Ms()`][PlayoutData10Ms]. + +[acm2]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/acm2/;drc=854d59f7501aac9e9bccfa7b4d1f7f4db7842719 +[Add10MsData]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/include/audio_coding_module.h;l=136;drc=d82a02c837d33cdfd75121e40dcccd32515e42d6 +[codecs]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/codecs/;drc=883fea1548d58e0080f98d66fab2e0c744dfb556 +[ANA]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/audio_network_adaptor/;drc=1f99551775cd876c116d1d90cba94c8a4670d184 +[IncomingPacket]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/include/audio_coding_module.h;l=192;drc=d82a02c837d33cdfd75121e40dcccd32515e42d6 +[NetEq]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/neteq/;drc=213dc2cfc5f1b360b1c6fc51d393491f5de49d3d +[PlayoutData10Ms]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/include/audio_coding_module.h;l=216;drc=d82a02c837d33cdfd75121e40dcccd32515e42d6 diff --git a/modules/audio_coding/include/audio_coding_module_typedefs.h b/modules/audio_coding/include/audio_coding_module_typedefs.h index 07aa8c956f..a7210dadcb 100644 --- a/modules/audio_coding/include/audio_coding_module_typedefs.h +++ b/modules/audio_coding/include/audio_coding_module_typedefs.h @@ -13,8 +13,6 @@ #include -#include "rtc_base/deprecation.h" - namespace webrtc { /////////////////////////////////////////////////////////////////////////// diff --git a/modules/audio_coding/neteq/accelerate.cc
b/modules/audio_coding/neteq/accelerate.cc index 6161a8f91b..e97191d8d2 100644 --- a/modules/audio_coding/neteq/accelerate.cc +++ b/modules/audio_coding/neteq/accelerate.cc @@ -69,7 +69,7 @@ Accelerate::ReturnCodes Accelerate::CheckCriteriaAndStretch( peak_index = (fs_mult_120 / peak_index) * peak_index; } - assert(fs_mult_120 >= peak_index); // Should be handled in Process(). + RTC_DCHECK_GE(fs_mult_120, peak_index); // Should be handled in Process(). // Copy first part; 0 to 15 ms. output->PushBackInterleaved( rtc::ArrayView(input, fs_mult_120 * num_channels_)); diff --git a/modules/audio_coding/neteq/audio_decoder_unittest.cc b/modules/audio_coding/neteq/audio_decoder_unittest.cc index 56708eca2a..2277872ee4 100644 --- a/modules/audio_coding/neteq/audio_decoder_unittest.cc +++ b/modules/audio_coding/neteq/audio_decoder_unittest.cc @@ -77,9 +77,9 @@ double MseInputOutput(const std::vector& input, size_t num_samples, size_t channels, int delay) { - assert(delay < static_cast(num_samples)); - assert(num_samples <= input.size()); - assert(num_samples * channels <= output.size()); + RTC_DCHECK_LT(delay, static_cast(num_samples)); + RTC_DCHECK_LE(num_samples, input.size()); + RTC_DCHECK_LE(num_samples * channels, output.size()); if (num_samples == 0) return 0.0; double squared_sum = 0.0; @@ -303,7 +303,7 @@ class AudioDecoderPcm16BTest : public AudioDecoderTest { frame_size_ = 20 * codec_input_rate_hz_ / 1000; data_length_ = 10 * frame_size_; decoder_ = new AudioDecoderPcm16B(codec_input_rate_hz_, 1); - assert(decoder_); + RTC_DCHECK(decoder_); AudioEncoderPcm16B::Config config; config.sample_rate_hz = codec_input_rate_hz_; config.frame_size_ms = @@ -320,7 +320,7 @@ class AudioDecoderIlbcTest : public AudioDecoderTest { frame_size_ = 240; data_length_ = 10 * frame_size_; decoder_ = new AudioDecoderIlbcImpl; - assert(decoder_); + RTC_DCHECK(decoder_); AudioEncoderIlbcConfig config; config.frame_size_ms = 30; audio_encoder_.reset(new AudioEncoderIlbcImpl(config, payload_type_)); @@ -414,7 +414,7 @@ class AudioDecoderG722Test : public AudioDecoderTest { frame_size_ = 160; data_length_ = 10 * frame_size_; decoder_ = new AudioDecoderG722Impl; - assert(decoder_); + RTC_DCHECK(decoder_); AudioEncoderG722Config config; config.frame_size_ms = 10; config.num_channels = 1; @@ -430,7 +430,7 @@ class AudioDecoderG722StereoTest : public AudioDecoderTest { frame_size_ = 160; data_length_ = 10 * frame_size_; decoder_ = new AudioDecoderG722StereoImpl; - assert(decoder_); + RTC_DCHECK(decoder_); AudioEncoderG722Config config; config.frame_size_ms = 10; config.num_channels = 2; @@ -587,7 +587,9 @@ TEST_F(AudioDecoderIsacFixTest, EncodeDecode) { int delay = 54; // Delay from input to output. #if defined(WEBRTC_ANDROID) && defined(WEBRTC_ARCH_ARM) static const int kEncodedBytes = 685; -#elif defined(WEBRTC_ANDROID) && defined(WEBRTC_ARCH_ARM64) +#elif defined(WEBRTC_ARCH_ARM64) + static const int kEncodedBytes = 673; +#elif defined(WEBRTC_MAC) && defined(WEBRTC_ARCH_ARM64) // M1 Mac static const int kEncodedBytes = 673; #else static const int kEncodedBytes = 671; @@ -639,7 +641,9 @@ TEST_F(AudioDecoderG722StereoTest, SetTargetBitrate) { TestSetAndGetTargetBitratesWithFixedCodec(audio_encoder_.get(), 128000); } -TEST_P(AudioDecoderOpusTest, EncodeDecode) { +// TODO(http://bugs.webrtc.org/12518): Enable the test after Opus has been +// updated. 
+TEST_P(AudioDecoderOpusTest, DISABLED_EncodeDecode) { constexpr int tolerance = 6176; constexpr int channel_diff_tolerance = 6; constexpr double mse = 238630.0; diff --git a/modules/audio_coding/neteq/audio_multi_vector.cc b/modules/audio_coding/neteq/audio_multi_vector.cc index 349d75dcdc..290d7eae22 100644 --- a/modules/audio_coding/neteq/audio_multi_vector.cc +++ b/modules/audio_coding/neteq/audio_multi_vector.cc @@ -19,7 +19,7 @@ namespace webrtc { AudioMultiVector::AudioMultiVector(size_t N) { - assert(N > 0); + RTC_DCHECK_GT(N, 0); if (N < 1) N = 1; for (size_t n = 0; n < N; ++n) { @@ -29,7 +29,7 @@ AudioMultiVector::AudioMultiVector(size_t N) { } AudioMultiVector::AudioMultiVector(size_t N, size_t initial_size) { - assert(N > 0); + RTC_DCHECK_GT(N, 0); if (N < 1) N = 1; for (size_t n = 0; n < N; ++n) { @@ -91,7 +91,7 @@ void AudioMultiVector::PushBackInterleaved( } void AudioMultiVector::PushBack(const AudioMultiVector& append_this) { - assert(num_channels_ == append_this.num_channels_); + RTC_DCHECK_EQ(num_channels_, append_this.num_channels_); if (num_channels_ == append_this.num_channels_) { for (size_t i = 0; i < num_channels_; ++i) { channels_[i]->PushBack(append_this[i]); @@ -101,10 +101,10 @@ void AudioMultiVector::PushBack(const AudioMultiVector& append_this) { void AudioMultiVector::PushBackFromIndex(const AudioMultiVector& append_this, size_t index) { - assert(index < append_this.Size()); + RTC_DCHECK_LT(index, append_this.Size()); index = std::min(index, append_this.Size() - 1); size_t length = append_this.Size() - index; - assert(num_channels_ == append_this.num_channels_); + RTC_DCHECK_EQ(num_channels_, append_this.num_channels_); if (num_channels_ == append_this.num_channels_) { for (size_t i = 0; i < num_channels_; ++i) { channels_[i]->PushBack(append_this[i], length, index); @@ -162,9 +162,9 @@ size_t AudioMultiVector::ReadInterleavedFromEnd(size_t length, void AudioMultiVector::OverwriteAt(const AudioMultiVector& insert_this, size_t length, size_t position) { - assert(num_channels_ == insert_this.num_channels_); + RTC_DCHECK_EQ(num_channels_, insert_this.num_channels_); // Cap |length| at the length of |insert_this|. 
- assert(length <= insert_this.Size()); + RTC_DCHECK_LE(length, insert_this.Size()); length = std::min(length, insert_this.Size()); if (num_channels_ == insert_this.num_channels_) { for (size_t i = 0; i < num_channels_; ++i) { @@ -175,7 +175,7 @@ void AudioMultiVector::OverwriteAt(const AudioMultiVector& insert_this, void AudioMultiVector::CrossFade(const AudioMultiVector& append_this, size_t fade_length) { - assert(num_channels_ == append_this.num_channels_); + RTC_DCHECK_EQ(num_channels_, append_this.num_channels_); if (num_channels_ == append_this.num_channels_) { for (size_t i = 0; i < num_channels_; ++i) { channels_[i]->CrossFade(append_this[i], fade_length); @@ -188,7 +188,7 @@ size_t AudioMultiVector::Channels() const { } size_t AudioMultiVector::Size() const { - assert(channels_[0]); + RTC_DCHECK(channels_[0]); return channels_[0]->Size(); } @@ -202,13 +202,13 @@ void AudioMultiVector::AssertSize(size_t required_size) { } bool AudioMultiVector::Empty() const { - assert(channels_[0]); + RTC_DCHECK(channels_[0]); return channels_[0]->Empty(); } void AudioMultiVector::CopyChannel(size_t from_channel, size_t to_channel) { - assert(from_channel < num_channels_); - assert(to_channel < num_channels_); + RTC_DCHECK_LT(from_channel, num_channels_); + RTC_DCHECK_LT(to_channel, num_channels_); channels_[from_channel]->CopyTo(channels_[to_channel]); } diff --git a/modules/audio_coding/neteq/audio_vector.cc b/modules/audio_coding/neteq/audio_vector.cc index b3ad48fc43..5e435e944d 100644 --- a/modules/audio_coding/neteq/audio_vector.cc +++ b/modules/audio_coding/neteq/audio_vector.cc @@ -247,8 +247,8 @@ void AudioVector::OverwriteAt(const int16_t* insert_this, void AudioVector::CrossFade(const AudioVector& append_this, size_t fade_length) { // Fade length cannot be longer than the current vector or |append_this|. - assert(fade_length <= Size()); - assert(fade_length <= append_this.Size()); + RTC_DCHECK_LE(fade_length, Size()); + RTC_DCHECK_LE(fade_length, append_this.Size()); fade_length = std::min(fade_length, Size()); fade_length = std::min(fade_length, append_this.Size()); size_t position = Size() - fade_length + begin_index_; @@ -265,7 +265,7 @@ void AudioVector::CrossFade(const AudioVector& append_this, (16384 - alpha) * append_this[i] + 8192) >> 14; } - assert(alpha >= 0); // Verify that the slope was correct. + RTC_DCHECK_GE(alpha, 0); // Verify that the slope was correct. // Append what is left of |append_this|. 
size_t samples_to_push_back = append_this.Size() - fade_length; if (samples_to_push_back > 0) diff --git a/modules/audio_coding/neteq/background_noise.cc b/modules/audio_coding/neteq/background_noise.cc index c0dcc5e04d..ae4645c78e 100644 --- a/modules/audio_coding/neteq/background_noise.cc +++ b/modules/audio_coding/neteq/background_noise.cc @@ -136,7 +136,7 @@ void BackgroundNoise::GenerateBackgroundNoise( int16_t* buffer) { constexpr size_t kNoiseLpcOrder = kMaxLpcOrder; int16_t scaled_random_vector[kMaxSampleRate / 8000 * 125]; - assert(num_noise_samples <= (kMaxSampleRate / 8000 * 125)); + RTC_DCHECK_LE(num_noise_samples, (kMaxSampleRate / 8000 * 125)); RTC_DCHECK_GE(random_vector.size(), num_noise_samples); int16_t* noise_samples = &buffer[kNoiseLpcOrder]; if (initialized()) { @@ -178,44 +178,44 @@ void BackgroundNoise::GenerateBackgroundNoise( } int32_t BackgroundNoise::Energy(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].energy; } void BackgroundNoise::SetMuteFactor(size_t channel, int16_t value) { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); channel_parameters_[channel].mute_factor = value; } int16_t BackgroundNoise::MuteFactor(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].mute_factor; } const int16_t* BackgroundNoise::Filter(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].filter; } const int16_t* BackgroundNoise::FilterState(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].filter_state; } void BackgroundNoise::SetFilterState(size_t channel, rtc::ArrayView input) { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); size_t length = std::min(input.size(), kMaxLpcOrder); memcpy(channel_parameters_[channel].filter_state, input.data(), length * sizeof(int16_t)); } int16_t BackgroundNoise::Scale(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].scale; } int16_t BackgroundNoise::ScaleShift(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].scale_shift; } @@ -240,7 +240,7 @@ void BackgroundNoise::IncrementEnergyThreshold(size_t channel, // to the limited-width operations, it is not exactly the same. The // difference should be inaudible, but bit-exactness would not be // maintained. 
- assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); ChannelParameters& parameters = channel_parameters_[channel]; int32_t temp_energy = (kThresholdIncrement * parameters.low_energy_update_threshold) >> 16; @@ -278,7 +278,7 @@ void BackgroundNoise::SaveParameters(size_t channel, const int16_t* filter_state, int32_t sample_energy, int32_t residual_energy) { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); ChannelParameters& parameters = channel_parameters_[channel]; memcpy(parameters.filter, lpc_coefficients, (kMaxLpcOrder + 1) * sizeof(int16_t)); diff --git a/modules/audio_coding/neteq/buffer_level_filter.cc b/modules/audio_coding/neteq/buffer_level_filter.cc index 5d503e9918..8901c01f77 100644 --- a/modules/audio_coding/neteq/buffer_level_filter.cc +++ b/modules/audio_coding/neteq/buffer_level_filter.cc @@ -35,14 +35,13 @@ void BufferLevelFilter::Update(size_t buffer_size_samples, // |level_factor_| and |filtered_current_level_| are in Q8. // |buffer_size_samples| is in Q0. const int64_t filtered_current_level = - ((level_factor_ * int64_t{filtered_current_level_}) >> 8) + - ((256 - level_factor_) * rtc::dchecked_cast(buffer_size_samples)); + (level_factor_ * int64_t{filtered_current_level_} >> 8) + + (256 - level_factor_) * rtc::dchecked_cast(buffer_size_samples); // Account for time-scale operations (accelerate and pre-emptive expand) and // make sure that the filtered value remains non-negative. filtered_current_level_ = rtc::saturated_cast(std::max( - 0, - filtered_current_level - (int64_t{time_stretched_samples} * (1 << 8)))); + 0, filtered_current_level - int64_t{time_stretched_samples} * (1 << 8))); } void BufferLevelFilter::SetFilteredBufferLevel(int buffer_size_samples) { diff --git a/modules/audio_coding/neteq/buffer_level_filter.h b/modules/audio_coding/neteq/buffer_level_filter.h index 89fcaf4612..218a142648 100644 --- a/modules/audio_coding/neteq/buffer_level_filter.h +++ b/modules/audio_coding/neteq/buffer_level_filter.h @@ -12,6 +12,7 @@ #define MODULES_AUDIO_CODING_NETEQ_BUFFER_LEVEL_FILTER_H_ #include +#include #include "rtc_base/constructor_magic.h" @@ -39,7 +40,7 @@ class BufferLevelFilter { // Returns filtered current level in number of samples. virtual int filtered_current_level() const { // Round to nearest whole sample. - return (filtered_current_level_ + (1 << 7)) >> 8; + return (int64_t{filtered_current_level_} + (1 << 7)) >> 8; } private: diff --git a/modules/audio_coding/neteq/comfort_noise.cc b/modules/audio_coding/neteq/comfort_noise.cc index a21cddab4d..b02e3d747f 100644 --- a/modules/audio_coding/neteq/comfort_noise.cc +++ b/modules/audio_coding/neteq/comfort_noise.cc @@ -45,8 +45,8 @@ int ComfortNoise::UpdateParameters(const Packet& packet) { int ComfortNoise::Generate(size_t requested_length, AudioMultiVector* output) { // TODO(hlundin): Change to an enumerator and skip assert. - assert(fs_hz_ == 8000 || fs_hz_ == 16000 || fs_hz_ == 32000 || - fs_hz_ == 48000); + RTC_DCHECK(fs_hz_ == 8000 || fs_hz_ == 16000 || fs_hz_ == 32000 || + fs_hz_ == 48000); // Not adapted for multi-channel yet. 
if (output->Channels() != 1) { RTC_LOG(LS_ERROR) << "No multi-channel support"; diff --git a/modules/audio_coding/neteq/cross_correlation.cc b/modules/audio_coding/neteq/cross_correlation.cc index 7ee867aa9b..37ed9374f0 100644 --- a/modules/audio_coding/neteq/cross_correlation.cc +++ b/modules/audio_coding/neteq/cross_correlation.cc @@ -25,22 +25,23 @@ int CrossCorrelationWithAutoShift(const int16_t* sequence_1, size_t cross_correlation_length, int cross_correlation_step, int32_t* cross_correlation) { - // Find the maximum absolute value of sequence_1 and 2. - const int32_t max_1 = - abs(sequence_1[WebRtcSpl_MaxAbsIndexW16(sequence_1, sequence_1_length)]); + // Find the element that has the maximum absolute value of sequence_1 and 2. + // Note that these values may be negative. + const int16_t max_1 = + WebRtcSpl_MaxAbsElementW16(sequence_1, sequence_1_length); const int sequence_2_shift = cross_correlation_step * (static_cast(cross_correlation_length) - 1); const int16_t* sequence_2_start = sequence_2_shift >= 0 ? sequence_2 : sequence_2 + sequence_2_shift; const size_t sequence_2_length = sequence_1_length + std::abs(sequence_2_shift); - const int32_t max_2 = abs(sequence_2_start[WebRtcSpl_MaxAbsIndexW16( - sequence_2_start, sequence_2_length)]); + const int16_t max_2 = + WebRtcSpl_MaxAbsElementW16(sequence_2_start, sequence_2_length); // In order to avoid overflow when computing the sum we should scale the // samples so that (in_vector_length * max_1 * max_2) will not overflow. const int64_t max_value = - max_1 * max_2 * static_cast(sequence_1_length); + abs(max_1 * max_2) * static_cast(sequence_1_length); const int32_t factor = max_value >> 31; const int scaling = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor); diff --git a/modules/audio_coding/neteq/decision_logic.cc b/modules/audio_coding/neteq/decision_logic.cc index 266e675148..d702729881 100644 --- a/modules/audio_coding/neteq/decision_logic.cc +++ b/modules/audio_coding/neteq/decision_logic.cc @@ -50,8 +50,8 @@ DecisionLogic::DecisionLogic( disallow_time_stretching_(!config.allow_time_stretching), timescale_countdown_( tick_timer_->GetNewCountdown(kMinTimescaleInterval + 1)), - estimate_dtx_delay_("estimate_dtx_delay", false), - time_stretch_cn_("time_stretch_cn", false), + estimate_dtx_delay_("estimate_dtx_delay", true), + time_stretch_cn_("time_stretch_cn", true), target_level_window_ms_("target_level_window", kDefaultTargetLevelWindowMs, 0, @@ -96,7 +96,8 @@ void DecisionLogic::SoftReset() { void DecisionLogic::SetSampleRate(int fs_hz, size_t output_size_samples) { // TODO(hlundin): Change to an enumerator and skip assert. 
- assert(fs_hz == 8000 || fs_hz == 16000 || fs_hz == 32000 || fs_hz == 48000); + RTC_DCHECK(fs_hz == 8000 || fs_hz == 16000 || fs_hz == 32000 || + fs_hz == 48000); sample_rate_ = fs_hz; output_size_samples_ = output_size_samples; } diff --git a/modules/audio_coding/neteq/decoder_database_unittest.cc b/modules/audio_coding/neteq/decoder_database_unittest.cc index c1b92b5375..33bee8d6f5 100644 --- a/modules/audio_coding/neteq/decoder_database_unittest.cc +++ b/modules/audio_coding/neteq/decoder_database_unittest.cc @@ -27,15 +27,14 @@ using ::testing::Invoke; namespace webrtc { TEST(DecoderDatabase, CreateAndDestroy) { - DecoderDatabase db(new rtc::RefCountedObject, + DecoderDatabase db(rtc::make_ref_counted(), absl::nullopt); EXPECT_EQ(0, db.Size()); EXPECT_TRUE(db.Empty()); } TEST(DecoderDatabase, InsertAndRemove) { - rtc::scoped_refptr factory( - new rtc::RefCountedObject); + auto factory = rtc::make_ref_counted(); DecoderDatabase db(factory, absl::nullopt); const uint8_t kPayloadType = 0; const std::string kCodecName = "Robert\'); DROP TABLE Students;"; @@ -50,8 +49,7 @@ TEST(DecoderDatabase, InsertAndRemove) { } TEST(DecoderDatabase, InsertAndRemoveAll) { - rtc::scoped_refptr factory( - new rtc::RefCountedObject); + auto factory = rtc::make_ref_counted(); DecoderDatabase db(factory, absl::nullopt); const std::string kCodecName1 = "Robert\'); DROP TABLE Students;"; const std::string kCodecName2 = "https://xkcd.com/327/"; @@ -67,8 +65,7 @@ TEST(DecoderDatabase, InsertAndRemoveAll) { } TEST(DecoderDatabase, GetDecoderInfo) { - rtc::scoped_refptr factory( - new rtc::RefCountedObject); + auto factory = rtc::make_ref_counted(); auto* decoder = new MockAudioDecoder; EXPECT_CALL(*factory, MakeAudioDecoderMock(_, _, _)) .WillOnce(Invoke([decoder](const SdpAudioFormat& format, @@ -103,8 +100,7 @@ TEST(DecoderDatabase, GetDecoder) { } TEST(DecoderDatabase, TypeTests) { - rtc::scoped_refptr factory( - new rtc::RefCountedObject); + auto factory = rtc::make_ref_counted(); DecoderDatabase db(factory, absl::nullopt); const uint8_t kPayloadTypePcmU = 0; const uint8_t kPayloadTypeCng = 13; @@ -140,8 +136,7 @@ TEST(DecoderDatabase, TypeTests) { TEST(DecoderDatabase, CheckPayloadTypes) { constexpr int kNumPayloads = 10; - rtc::scoped_refptr factory( - new rtc::RefCountedObject); + auto factory = rtc::make_ref_counted(); DecoderDatabase db(factory, absl::nullopt); // Load a number of payloads into the database. Payload types are 0, 1, ..., // while the decoder type is the same for all payload types (this does not diff --git a/modules/audio_coding/neteq/dsp_helper.cc b/modules/audio_coding/neteq/dsp_helper.cc index 05b0f70bcf..91979f2d48 100644 --- a/modules/audio_coding/neteq/dsp_helper.cc +++ b/modules/audio_coding/neteq/dsp_helper.cc @@ -89,7 +89,7 @@ int DspHelper::RampSignal(AudioMultiVector* signal, size_t length, int factor, int increment) { - assert(start_index + length <= signal->Size()); + RTC_DCHECK_LE(start_index + length, signal->Size()); if (start_index + length > signal->Size()) { // Wrong parameters. Do nothing and return the scale factor unaltered. 
return factor; @@ -355,7 +355,7 @@ int DspHelper::DownsampleTo4kHz(const int16_t* input, break; } default: { - assert(false); + RTC_NOTREACHED(); return -1; } } diff --git a/modules/audio_coding/neteq/expand.cc b/modules/audio_coding/neteq/expand.cc index 8df2c7afde..ffaa4c74aa 100644 --- a/modules/audio_coding/neteq/expand.cc +++ b/modules/audio_coding/neteq/expand.cc @@ -48,9 +48,10 @@ Expand::Expand(BackgroundNoise* background_noise, stop_muting_(false), expand_duration_samples_(0), channel_parameters_(new ChannelParameters[num_channels_]) { - assert(fs == 8000 || fs == 16000 || fs == 32000 || fs == 48000); - assert(fs <= static_cast(kMaxSampleRate)); // Should not be possible. - assert(num_channels_ > 0); + RTC_DCHECK(fs == 8000 || fs == 16000 || fs == 32000 || fs == 48000); + RTC_DCHECK_LE(fs, + static_cast(kMaxSampleRate)); // Should not be possible. + RTC_DCHECK_GT(num_channels_, 0); memset(expand_lags_, 0, sizeof(expand_lags_)); Reset(); } @@ -91,7 +92,7 @@ int Expand::Process(AudioMultiVector* output) { // Extract a noise segment. size_t rand_length = max_lag_; // This only applies to SWB where length could be larger than 256. - assert(rand_length <= kMaxSampleRate / 8000 * 120 + 30); + RTC_DCHECK_LE(rand_length, kMaxSampleRate / 8000 * 120 + 30); GenerateRandomVector(2, rand_length, random_vector); } @@ -110,8 +111,8 @@ int Expand::Process(AudioMultiVector* output) { ChannelParameters& parameters = channel_parameters_[channel_ix]; if (current_lag_index_ == 0) { // Use only expand_vector0. - assert(expansion_vector_position + temp_length <= - parameters.expand_vector0.Size()); + RTC_DCHECK_LE(expansion_vector_position + temp_length, + parameters.expand_vector0.Size()); parameters.expand_vector0.CopyTo(temp_length, expansion_vector_position, voiced_vector_storage); } else if (current_lag_index_ == 1) { @@ -126,10 +127,10 @@ int Expand::Process(AudioMultiVector* output) { voiced_vector_storage, temp_length); } else if (current_lag_index_ == 2) { // Mix 1/2 of expand_vector0 with 1/2 of expand_vector1. - assert(expansion_vector_position + temp_length <= - parameters.expand_vector0.Size()); - assert(expansion_vector_position + temp_length <= - parameters.expand_vector1.Size()); + RTC_DCHECK_LE(expansion_vector_position + temp_length, + parameters.expand_vector0.Size()); + RTC_DCHECK_LE(expansion_vector_position + temp_length, + parameters.expand_vector1.Size()); std::unique_ptr temp_0(new int16_t[temp_length]); parameters.expand_vector0.CopyTo(temp_length, expansion_vector_position, @@ -303,7 +304,7 @@ int Expand::Process(AudioMultiVector* output) { if (channel_ix == 0) { output->AssertSize(current_lag); } else { - assert(output->Size() == current_lag); + RTC_DCHECK_EQ(output->Size(), current_lag); } (*output)[channel_ix].OverwriteAt(temp_data, current_lag, 0); } @@ -465,7 +466,7 @@ void Expand::AnalyzeSignal(int16_t* random_vector) { size_t start_index = std::min(distortion_lag, correlation_lag); size_t correlation_lags = static_cast( WEBRTC_SPL_ABS_W16((distortion_lag - correlation_lag)) + 1); - assert(correlation_lags <= static_cast(99 * fs_mult + 1)); + RTC_DCHECK_LE(correlation_lags, static_cast(99 * fs_mult + 1)); for (size_t channel_ix = 0; channel_ix < num_channels_; ++channel_ix) { ChannelParameters& parameters = channel_parameters_[channel_ix]; @@ -659,7 +660,7 @@ void Expand::AnalyzeSignal(int16_t* random_vector) { // |kRandomTableSize|. 
memcpy(random_vector, RandomVector::kRandomTable, sizeof(int16_t) * RandomVector::kRandomTableSize); - assert(noise_length <= kMaxSampleRate / 8000 * 120 + 30); + RTC_DCHECK_LE(noise_length, kMaxSampleRate / 8000 * 120 + 30); random_vector_->IncreaseSeedIncrement(2); random_vector_->Generate( noise_length - RandomVector::kRandomTableSize, diff --git a/modules/audio_coding/neteq/expand.h b/modules/audio_coding/neteq/expand.h index 45d78d0823..3b0cea3d93 100644 --- a/modules/audio_coding/neteq/expand.h +++ b/modules/audio_coding/neteq/expand.h @@ -59,7 +59,7 @@ class Expand { // Returns the mute factor for |channel|. int16_t MuteFactor(size_t channel) const { - assert(channel < num_channels_); + RTC_DCHECK_LT(channel, num_channels_); return channel_parameters_[channel].mute_factor; } diff --git a/modules/audio_coding/neteq/g3doc/index.md b/modules/audio_coding/neteq/g3doc/index.md new file mode 100644 index 0000000000..d0624f46ef --- /dev/null +++ b/modules/audio_coding/neteq/g3doc/index.md @@ -0,0 +1,102 @@ + + + +# NetEq + +NetEq is the audio jitter buffer and packet loss concealer. The jitter buffer is +an adaptive jitter buffer, meaning that the buffering delay is continuously +optimized based on the network conditions. Its main goal is to ensure a smooth +playout of incoming audio packets from the network with a low amount of audio +artifacts (alterations to the original content of the packets) while at the same +time keeping the delay as low as possible. + +## API + +At a high level, the NetEq API has two main functions: +[`InsertPacket`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=198;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72) +and +[`GetAudio`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=219;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72). + +### InsertPacket + +[`InsertPacket`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=198;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72) +delivers an RTP packet from the network to NetEq where the following happens: + +1. The packet is discarded if it is too late for playout (for example if it was + reordered). Otherwise it is put into the packet buffer where it is stored + until it is time for playout. If the buffer is full, discard all the + existing packets (this should be rare). +2. The interarrival time between packets is analyzed and statistics are updated + and used to derive a new target playout delay. The interarrival time is + measured in the number of GetAudio ‘ticks’ and thus clock drift between the + sender and receiver can be accounted for. + +### GetAudio + +[`GetAudio`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=219;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72) +pulls 10 ms of audio from NetEq for playout. A much simplified decision logic is +as follows: + +1. If there is 10 ms of audio in the sync buffer then return that. +2. If the next packet is available (based on RTP timestamp) in the packet + buffer then decode it and append the result to the sync buffer. + 1. Compare the current delay estimate (filtered buffer level) with the + target delay and time stretch (accelerate or decelerate) the contents of + the sync buffer if the buffer level is too high or too low. + 2. Return 10 ms of audio from the sync buffer. +3. If the last decoded packet was a discontinuous transmission (DTX) packet + then generate comfort noise.
+4. If there is no available packet for decoding due to the next packet having + not arrived or been lost then generate packet loss concealment by + extrapolating the remaining audio in the sync buffer or by asking the + decoder to produce it. + +In summary, the output is the result of one of the following operations: + +* Normal: audio decoded from a packet. +* Acceleration: accelerated playout of a decoded packet. +* Preemptive expand: decelerated playout of a decoded packet. +* Expand: packet loss concealment generated by NetEq or the decoder. +* Merge: audio stitched together from packet loss concealment to decoded data + in case of a loss. +* Comfort noise (CNG): comfort noise generated by NetEq or the decoder between + talk spurts due to discontinuous transmission of packets (DTX). + +## Statistics + +There are a number of functions that can be used to query the internal state of +NetEq, statistics about the type of audio output and latency metrics such as how +long packets have waited in the buffer. + +* [`NetworkStatistics`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=273;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72): + instantaneous values or stats averaged over the duration since the last call to + this function. +* [`GetLifetimeStatistics`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=280;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72): + cumulative stats that persist over the lifetime of the class. +* [`GetOperationsAndState`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/neteq/neteq.h;l=284;drc=4461f059d180fe8c2886d422ebd1cb55b5c83e72): + information about the internal state of NetEq (only intended to be used + for testing and debugging). + +## Tests and tools + +* [`neteq_rtpplay`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/neteq/tools/neteq_rtpplay.cc;drc=cee751abff598fc19506f77de08bea7c61b9dcca): + Simulate NetEq behavior based on either an RTP dump, a PCAP file or an RTC + event log. A replacement audio file can also be used instead of the original + payload. Outputs aggregated statistics and optionally an audio file to + listen to. +* [`neteq_speed_test`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_coding/neteq/test/neteq_speed_test.cc;drc=2ab97f6f8e27b47c0d9beeb8b6ca5387bda9f55c): + Measure performance of NetEq, used on perf bots. +* Unit tests including bit exactness tests where an RTP file is used as input + to NetEq, the output is concatenated and a checksum is calculated and + compared against a reference. + +## Other responsibilities + +* Dual-tone multi-frequency signaling (DTMF): receive telephone events and + produce dual tone waveforms. +* Forward error correction (RED or codec inband FEC): split inserted packets + and prioritize the payloads. +* NACK (negative acknowledgement): keep track of lost packets and generate a + list of packets to NACK. +* Audio/video sync: NetEq can be instructed to increase the latency in order + to keep audio and video in sync.
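The RED format mentioned in the "Other responsibilities" list above is the same RFC 2198 payload format that audio_encoder_copy_red.cc assembles earlier in this change. As a minimal standalone sketch of that header layout (the helper names below are illustrative and not part of WebRTC), each redundant block is preceded by a 4-byte header and the primary block by a single final byte:

```cpp
// Illustrative sketch only: how the RFC 2198 RED block headers written by
// AudioEncoderCopyRed::EncodeImpl are laid out. Each redundant block gets a
// 4-byte header with the F bit set; the primary block gets a 1-byte header
// with the F bit cleared.
#include <cstddef>
#include <cstdint>
#include <vector>

// Appends the 4-byte header for one redundant block. |timestamp_delta| must
// fit in 14 bits and |block_length| in 10 bits, as required by RFC 2198.
void AppendRedundantBlockHeader(uint8_t payload_type,
                                uint32_t timestamp_delta,
                                size_t block_length,
                                std::vector<uint8_t>* out) {
  out->push_back(static_cast<uint8_t>(0x80 | (payload_type & 0x7f)));
  // Upper 8 of the 14 timestamp-offset bits.
  out->push_back(static_cast<uint8_t>(timestamp_delta >> 6));
  // Lower 6 timestamp-offset bits, then the upper 2 of the 10 length bits.
  out->push_back(static_cast<uint8_t>(((timestamp_delta & 0x3f) << 2) |
                                      ((block_length >> 8) & 0x3)));
  // Lower 8 length bits.
  out->push_back(static_cast<uint8_t>(block_length & 0xff));
}

// Appends the final 1-byte header that precedes the primary (non-redundant)
// block.
void AppendPrimaryBlockHeader(uint8_t payload_type,
                              std::vector<uint8_t>* out) {
  out->push_back(static_cast<uint8_t>(payload_type & 0x7f));
}
```

With the default of two redundant encodings this gives 4 + 4 + 1 = 9 header bytes in front of the payloads, which is the constant in the `9 + i + (i - 1) + (i - 2)` size expectation of the updated CheckPayloadSizes test in audio_encoder_copy_red_unittest.cc.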
diff --git a/modules/audio_coding/neteq/merge.cc b/modules/audio_coding/neteq/merge.cc index 5bf239bfc5..770e2e3590 100644 --- a/modules/audio_coding/neteq/merge.cc +++ b/modules/audio_coding/neteq/merge.cc @@ -38,7 +38,7 @@ Merge::Merge(int fs_hz, expand_(expand), sync_buffer_(sync_buffer), expanded_(num_channels_) { - assert(num_channels_ > 0); + RTC_DCHECK_GT(num_channels_, 0); } Merge::~Merge() = default; @@ -47,9 +47,9 @@ size_t Merge::Process(int16_t* input, size_t input_length, AudioMultiVector* output) { // TODO(hlundin): Change to an enumerator and skip assert. - assert(fs_hz_ == 8000 || fs_hz_ == 16000 || fs_hz_ == 32000 || - fs_hz_ == 48000); - assert(fs_hz_ <= kMaxSampleRate); // Should not be possible. + RTC_DCHECK(fs_hz_ == 8000 || fs_hz_ == 16000 || fs_hz_ == 32000 || + fs_hz_ == 48000); + RTC_DCHECK_LE(fs_hz_, kMaxSampleRate); // Should not be possible. if (input_length == 0) { return 0; } @@ -64,7 +64,7 @@ size_t Merge::Process(int16_t* input, input_vector.PushBackInterleaved( rtc::ArrayView(input, input_length)); size_t input_length_per_channel = input_vector.Size(); - assert(input_length_per_channel == input_length / num_channels_); + RTC_DCHECK_EQ(input_length_per_channel, input_length / num_channels_); size_t best_correlation_index = 0; size_t output_length = 0; @@ -142,10 +142,10 @@ size_t Merge::Process(int16_t* input, output_length = best_correlation_index + input_length_per_channel; if (channel == 0) { - assert(output->Empty()); // Output should be empty at this point. + RTC_DCHECK(output->Empty()); // Output should be empty at this point. output->AssertSize(output_length); } else { - assert(output->Size() == output_length); + RTC_DCHECK_EQ(output->Size(), output_length); } (*output)[channel].OverwriteAt(temp_data_.data(), output_length, 0); } @@ -165,7 +165,7 @@ size_t Merge::GetExpandedSignal(size_t* old_length, size_t* expand_period) { // Check how much data that is left since earlier. *old_length = sync_buffer_->FutureLength(); // Should never be less than overlap_length. - assert(*old_length >= expand_->overlap_length()); + RTC_DCHECK_GE(*old_length, expand_->overlap_length()); // Generate data to merge the overlap with using expand. expand_->SetParametersForMergeAfterExpand(); @@ -182,7 +182,7 @@ size_t Merge::GetExpandedSignal(size_t* old_length, size_t* expand_period) { // This is the truncated length. } // This assert should always be true thanks to the if statement above. - assert(210 * kMaxSampleRate / 8000 >= *old_length); + RTC_DCHECK_GE(210 * kMaxSampleRate / 8000, *old_length); AudioMultiVector expanded_temp(num_channels_); expand_->Process(&expanded_temp); @@ -191,8 +191,8 @@ size_t Merge::GetExpandedSignal(size_t* old_length, size_t* expand_period) { expanded_.Clear(); // Copy what is left since earlier into the expanded vector. expanded_.PushBackFromIndex(*sync_buffer_, sync_buffer_->next_index()); - assert(expanded_.Size() == *old_length); - assert(expanded_temp.Size() > 0); + RTC_DCHECK_EQ(expanded_.Size(), *old_length); + RTC_DCHECK_GT(expanded_temp.Size(), 0); // Do "ugly" copy and paste from the expanded in order to generate more data // to correlate (but not interpolate) with. const size_t required_length = static_cast((120 + 80 + 2) * fs_mult_); @@ -204,7 +204,7 @@ size_t Merge::GetExpandedSignal(size_t* old_length, size_t* expand_period) { // Trim the length to exactly |required_length|. 
expanded_.PopBack(expanded_.Size() - required_length); } - assert(expanded_.Size() >= required_length); + RTC_DCHECK_GE(expanded_.Size(), required_length); return required_length; } @@ -373,7 +373,7 @@ size_t Merge::CorrelateAndPeakSearch(size_t start_position, while (((best_correlation_index + input_length) < (timestamps_per_call_ + expand_->overlap_length())) || ((best_correlation_index + input_length) < start_position)) { - assert(false); // Should never happen. + RTC_NOTREACHED(); // Should never happen. best_correlation_index += expand_period; // Jump one lag ahead. } return best_correlation_index; diff --git a/modules/audio_coding/neteq/nack_tracker.cc b/modules/audio_coding/neteq/nack_tracker.cc index 8358769804..9a873eee07 100644 --- a/modules/audio_coding/neteq/nack_tracker.cc +++ b/modules/audio_coding/neteq/nack_tracker.cc @@ -44,7 +44,7 @@ NackTracker* NackTracker::Create(int nack_threshold_packets) { } void NackTracker::UpdateSampleRate(int sample_rate_hz) { - assert(sample_rate_hz > 0); + RTC_DCHECK_GT(sample_rate_hz, 0); sample_rate_khz_ = sample_rate_hz / 1000; } @@ -120,9 +120,9 @@ uint32_t NackTracker::EstimateTimestamp(uint16_t sequence_num) { } void NackTracker::AddToList(uint16_t sequence_number_current_received_rtp) { - assert(!any_rtp_decoded_ || - IsNewerSequenceNumber(sequence_number_current_received_rtp, - sequence_num_last_decoded_rtp_)); + RTC_DCHECK(!any_rtp_decoded_ || + IsNewerSequenceNumber(sequence_number_current_received_rtp, + sequence_num_last_decoded_rtp_)); // Packets with sequence numbers older than |upper_bound_missing| are // considered missing, and the rest are considered late. @@ -164,7 +164,7 @@ void NackTracker::UpdateLastDecodedPacket(uint16_t sequence_number, ++it) it->second.time_to_play_ms = TimeToPlay(it->second.estimated_timestamp); } else { - assert(sequence_number == sequence_num_last_decoded_rtp_); + RTC_DCHECK_EQ(sequence_number, sequence_num_last_decoded_rtp_); // Same sequence number as before. 10 ms is elapsed, update estimations for // time-to-play. diff --git a/modules/audio_coding/neteq/neteq_decoder_plc_unittest.cc b/modules/audio_coding/neteq/neteq_decoder_plc_unittest.cc index daf81f2a9c..2b4ae7e63e 100644 --- a/modules/audio_coding/neteq/neteq_decoder_plc_unittest.cc +++ b/modules/audio_coding/neteq/neteq_decoder_plc_unittest.cc @@ -10,7 +10,6 @@ // Test to verify correct operation when using the decoder-internal PLC. -#include #include #include #include @@ -33,6 +32,9 @@ namespace webrtc { namespace test { namespace { +constexpr int kSampleRateHz = 32000; +constexpr int kRunTimeMs = 10000; + // This class implements a fake decoder. The decoder will read audio from a file // and present as output, both for regular decoding and for PLC. 
class AudioDecoderPlc : public AudioDecoder { @@ -48,7 +50,8 @@ class AudioDecoderPlc : public AudioDecoder { int sample_rate_hz, int16_t* decoded, SpeechType* speech_type) override { - RTC_CHECK_EQ(encoded_len / 2, 20 * sample_rate_hz_ / 1000); + RTC_CHECK_GE(encoded_len / 2, 10 * sample_rate_hz_ / 1000); + RTC_CHECK_LE(encoded_len / 2, 2 * 10 * sample_rate_hz_ / 1000); RTC_CHECK_EQ(sample_rate_hz, sample_rate_hz_); RTC_CHECK(decoded); RTC_CHECK(speech_type); @@ -60,17 +63,21 @@ class AudioDecoderPlc : public AudioDecoder { void GeneratePlc(size_t requested_samples_per_channel, rtc::BufferT* concealment_audio) override { + // Instead of generating random data for GeneratePlc we use the same data as + // the input, so we can check that we produce the same result independently + // of the losses. + RTC_DCHECK_EQ(requested_samples_per_channel, 10 * sample_rate_hz_ / 1000); + // Must keep a local copy of this since DecodeInternal sets it to false. const bool last_was_plc = last_was_plc_; - SpeechType speech_type; + std::vector decoded(5760); - int dec_len = DecodeInternal(nullptr, 2 * 20 * sample_rate_hz_ / 1000, + SpeechType speech_type; + int dec_len = DecodeInternal(nullptr, 2 * 10 * sample_rate_hz_ / 1000, sample_rate_hz_, decoded.data(), &speech_type); - // This fake decoder can only generate 20 ms of PLC data each time. Make - // sure the caller didn't ask for more. - RTC_CHECK_GE(dec_len, requested_samples_per_channel); concealment_audio->AppendData(decoded.data(), dec_len); concealed_samples_ += rtc::checked_cast(dec_len); + if (!last_was_plc) { ++concealment_events_; } @@ -103,11 +110,15 @@ class ZeroSampleGenerator : public EncodeNetEqInput::Generator { }; // A NetEqInput which connects to another NetEqInput, but drops a number of -// packets on the way. +// consecutive packets on the way class LossyInput : public NetEqInput { public: - LossyInput(int loss_cadence, std::unique_ptr input) - : loss_cadence_(loss_cadence), input_(std::move(input)) {} + LossyInput(int loss_cadence, + int burst_length, + std::unique_ptr input) + : loss_cadence_(loss_cadence), + burst_length_(burst_length), + input_(std::move(input)) {} absl::optional NextPacketTime() const override { return input_->NextPacketTime(); @@ -119,8 +130,12 @@ class LossyInput : public NetEqInput { std::unique_ptr PopPacket() override { if (loss_cadence_ != 0 && (++count_ % loss_cadence_) == 0) { - // Pop one extra packet to create the loss. - input_->PopPacket(); + // Pop `burst_length_` packets to create the loss. + auto packet_to_return = input_->PopPacket(); + for (int i = 0; i < burst_length_; i++) { + input_->PopPacket(); + } + return packet_to_return; } return input_->PopPacket(); } @@ -135,6 +150,7 @@ class LossyInput : public NetEqInput { private: const int loss_cadence_; + const int burst_length_; int count_ = 0; const std::unique_ptr input_; }; @@ -149,7 +165,14 @@ class AudioChecksumWithOutput : public AudioChecksum { std::string& output_str_; }; -NetEqNetworkStatistics RunTest(int loss_cadence, std::string* checksum) { +struct TestStatistics { + NetEqNetworkStatistics network; + NetEqLifetimeStatistics lifetime; +}; + +TestStatistics RunTest(int loss_cadence, + int burst_length, + std::string* checksum) { NetEq::Config config; config.for_test_no_time_stretching = true; @@ -157,20 +180,18 @@ NetEqNetworkStatistics RunTest(int loss_cadence, std::string* checksum) { // but the actual encoded samples will never be used by the decoder in the // test. See below about the decoder. 
auto generator = std::make_unique(); - constexpr int kSampleRateHz = 32000; constexpr int kPayloadType = 100; AudioEncoderPcm16B::Config encoder_config; encoder_config.sample_rate_hz = kSampleRateHz; encoder_config.payload_type = kPayloadType; auto encoder = std::make_unique(encoder_config); - constexpr int kRunTimeMs = 10000; auto input = std::make_unique( std::move(generator), std::move(encoder), kRunTimeMs); // Wrap the input in a loss function. - auto lossy_input = - std::make_unique(loss_cadence, std::move(input)); + auto lossy_input = std::make_unique(loss_cadence, burst_length, + std::move(input)); - // Settinng up decoders. + // Setting up decoders. NetEqTest::DecoderMap decoders; // Using a fake decoder which simply reads the output audio from a file. auto input_file = std::make_unique( @@ -187,7 +208,7 @@ NetEqNetworkStatistics RunTest(int loss_cadence, std::string* checksum) { NetEqTest neteq_test( config, /*decoder_factory=*/ - new rtc::RefCountedObject(&dec), + rtc::make_ref_counted(&dec), /*codecs=*/decoders, /*text_log=*/nullptr, /*neteq_factory=*/nullptr, /*input=*/std::move(lossy_input), std::move(output), callbacks); EXPECT_LE(kRunTimeMs, neteq_test.Run()); @@ -195,24 +216,98 @@ NetEqNetworkStatistics RunTest(int loss_cadence, std::string* checksum) { auto lifetime_stats = neteq_test.LifetimeStats(); EXPECT_EQ(dec.concealed_samples(), lifetime_stats.concealed_samples); EXPECT_EQ(dec.concealment_events(), lifetime_stats.concealment_events); - - return neteq_test.SimulationStats(); + return {neteq_test.SimulationStats(), neteq_test.LifetimeStats()}; } } // namespace -TEST(NetEqDecoderPlc, Test) { +// Check that some basic metrics are produced in the right direction. In +// particular, expand_rate should only increase if there are losses present. Our +// dummy decoder is designed such as the checksum should always be the same +// regardless of the losses given that calls are executed in the right order. +TEST(NetEqDecoderPlc, BasicMetrics) { std::string checksum; - auto stats = RunTest(10, &checksum); + + // Drop 1 packet every 10 packets. + auto stats = RunTest(10, 1, &checksum); std::string checksum_no_loss; - auto stats_no_loss = RunTest(0, &checksum_no_loss); + auto stats_no_loss = RunTest(0, 0, &checksum_no_loss); EXPECT_EQ(checksum, checksum_no_loss); - EXPECT_EQ(stats.preemptive_rate, stats_no_loss.preemptive_rate); - EXPECT_EQ(stats.accelerate_rate, stats_no_loss.accelerate_rate); - EXPECT_EQ(0, stats_no_loss.expand_rate); - EXPECT_GT(stats.expand_rate, 0); + EXPECT_EQ(stats.network.preemptive_rate, + stats_no_loss.network.preemptive_rate); + EXPECT_EQ(stats.network.accelerate_rate, + stats_no_loss.network.accelerate_rate); + EXPECT_EQ(0, stats_no_loss.network.expand_rate); + EXPECT_GT(stats.network.expand_rate, 0); +} + +// Checks that interruptions are not counted in small losses but they are +// correctly counted in long interruptions. +TEST(NetEqDecoderPlc, CountInterruptions) { + std::string checksum; + std::string checksum_2; + std::string checksum_3; + + // Half of the packets lost but in short interruptions. + auto stats_no_interruptions = RunTest(1, 1, &checksum); + // One lost of 500 ms (250 packets). + auto stats_one_interruption = RunTest(200, 250, &checksum_2); + // Two losses of 250ms each (125 packets). 
+ auto stats_two_interruptions = RunTest(125, 125, &checksum_3); + + EXPECT_EQ(checksum, checksum_2); + EXPECT_EQ(checksum, checksum_3); + EXPECT_GT(stats_no_interruptions.network.expand_rate, 0); + EXPECT_EQ(stats_no_interruptions.lifetime.total_interruption_duration_ms, 0); + EXPECT_EQ(stats_no_interruptions.lifetime.interruption_count, 0); + + EXPECT_GT(stats_one_interruption.network.expand_rate, 0); + EXPECT_EQ(stats_one_interruption.lifetime.total_interruption_duration_ms, + 5000); + EXPECT_EQ(stats_one_interruption.lifetime.interruption_count, 1); + + EXPECT_GT(stats_two_interruptions.network.expand_rate, 0); + EXPECT_EQ(stats_two_interruptions.lifetime.total_interruption_duration_ms, + 5000); + EXPECT_EQ(stats_two_interruptions.lifetime.interruption_count, 2); +} + +// Checks that small losses do not produce interruptions. +TEST(NetEqDecoderPlc, NoInterruptionsInSmallLosses) { + std::string checksum_1; + std::string checksum_4; + + auto stats_1 = RunTest(300, 1, &checksum_1); + auto stats_4 = RunTest(300, 4, &checksum_4); + + EXPECT_EQ(checksum_1, checksum_4); + + EXPECT_EQ(stats_1.lifetime.interruption_count, 0); + EXPECT_EQ(stats_1.lifetime.total_interruption_duration_ms, 0); + EXPECT_EQ(stats_1.lifetime.concealed_samples, 640u); // 20ms of concealment. + EXPECT_EQ(stats_1.lifetime.concealment_events, 1u); // in just one event. + + EXPECT_EQ(stats_4.lifetime.interruption_count, 0); + EXPECT_EQ(stats_4.lifetime.total_interruption_duration_ms, 0); + EXPECT_EQ(stats_4.lifetime.concealed_samples, 2560u); // 80ms of concealment. + EXPECT_EQ(stats_4.lifetime.concealment_events, 1u); // in just one event. +} + +// Checks that interruptions of different sizes report correct duration. +TEST(NetEqDecoderPlc, InterruptionsReportCorrectSize) { + std::string checksum; + + for (int burst_length = 5; burst_length < 10; burst_length++) { + auto stats = RunTest(300, burst_length, &checksum); + auto duration = stats.lifetime.total_interruption_duration_ms; + if (burst_length < 8) { + EXPECT_EQ(duration, 0); + } else { + EXPECT_EQ(duration, burst_length * 20); + } + } } } // namespace test diff --git a/modules/audio_coding/neteq/neteq_impl.cc b/modules/audio_coding/neteq/neteq_impl.cc index 9ec7bd5bca..8b07d7e47c 100644 --- a/modules/audio_coding/neteq/neteq_impl.cc +++ b/modules/audio_coding/neteq/neteq_impl.cc @@ -258,6 +258,7 @@ void SetAudioFrameActivityAndType(bool vad_enabled, int NetEqImpl::GetAudio(AudioFrame* audio_frame, bool* muted, + int* current_sample_rate_hz, absl::optional action_override) { TRACE_EVENT0("webrtc", "NetEqImpl::GetAudio"); MutexLock lock(&mutex_); @@ -296,6 +297,11 @@ int NetEqImpl::GetAudio(AudioFrame* audio_frame, } } + if (current_sample_rate_hz) { + *current_sample_rate_hz = delayed_last_output_sample_rate_hz_.value_or( + last_output_sample_rate_hz_); + } + return kOK; } @@ -337,7 +343,7 @@ void NetEqImpl::RemoveAllPayloadTypes() { bool NetEqImpl::SetMinimumDelay(int delay_ms) { MutexLock lock(&mutex_); if (delay_ms >= 0 && delay_ms <= 10000) { - assert(controller_.get()); + RTC_DCHECK(controller_.get()); return controller_->SetMinimumDelay( std::max(delay_ms - output_delay_chain_ms_, 0)); } @@ -347,7 +353,7 @@ bool NetEqImpl::SetMinimumDelay(int delay_ms) { bool NetEqImpl::SetMaximumDelay(int delay_ms) { MutexLock lock(&mutex_); if (delay_ms >= 0 && delay_ms <= 10000) { - assert(controller_.get()); + RTC_DCHECK(controller_.get()); return controller_->SetMaximumDelay( std::max(delay_ms - output_delay_chain_ms_, 0)); } @@ -386,7 +392,7 @@ int 
NetEqImpl::FilteredCurrentDelayMs() const { int NetEqImpl::NetworkStatistics(NetEqNetworkStatistics* stats) { MutexLock lock(&mutex_); - assert(decoder_database_.get()); + RTC_DCHECK(decoder_database_.get()); *stats = CurrentNetworkStatisticsInternal(); stats_->GetNetworkStatistics(decoder_frame_length_, stats); // Compensate for output delay chain. @@ -403,13 +409,13 @@ NetEqNetworkStatistics NetEqImpl::CurrentNetworkStatistics() const { } NetEqNetworkStatistics NetEqImpl::CurrentNetworkStatisticsInternal() const { - assert(decoder_database_.get()); + RTC_DCHECK(decoder_database_.get()); NetEqNetworkStatistics stats; const size_t total_samples_in_buffers = packet_buffer_->NumSamplesInBuffer(decoder_frame_length_) + sync_buffer_->FutureLength(); - assert(controller_.get()); + RTC_DCHECK(controller_.get()); stats.preferred_buffer_size_ms = controller_->TargetLevelMs(); stats.jitter_peaks_found = controller_->PeakFound(); RTC_DCHECK_GT(fs_hz_, 0); @@ -443,13 +449,13 @@ NetEqOperationsAndState NetEqImpl::GetOperationsAndState() const { void NetEqImpl::EnableVad() { MutexLock lock(&mutex_); - assert(vad_.get()); + RTC_DCHECK(vad_.get()); vad_->Enable(); } void NetEqImpl::DisableVad() { MutexLock lock(&mutex_); - assert(vad_.get()); + RTC_DCHECK(vad_.get()); vad_->Disable(); } @@ -500,8 +506,8 @@ void NetEqImpl::FlushBuffers() { MutexLock lock(&mutex_); RTC_LOG(LS_VERBOSE) << "FlushBuffers"; packet_buffer_->Flush(stats_.get()); - assert(sync_buffer_.get()); - assert(expand_.get()); + RTC_DCHECK(sync_buffer_.get()); + RTC_DCHECK(expand_.get()); sync_buffer_->Flush(); sync_buffer_->set_next_index(sync_buffer_->next_index() - expand_->overlap_length()); @@ -565,19 +571,19 @@ int NetEqImpl::InsertPacketInternal(const RTPHeader& rtp_header, return kInvalidPointer; } - int64_t receive_time_ms = clock_->TimeInMilliseconds(); + Timestamp receive_time = clock_->CurrentTime(); stats_->ReceivedPacket(); PacketList packet_list; // Insert packet in a packet list. - packet_list.push_back([&rtp_header, &payload, &receive_time_ms] { + packet_list.push_back([&rtp_header, &payload, &receive_time] { // Convert to Packet. Packet packet; packet.payload_type = rtp_header.payloadType; packet.sequence_number = rtp_header.sequenceNumber; packet.timestamp = rtp_header.timestamp; packet.payload.SetData(payload.data(), payload.size()); - packet.packet_info = RtpPacketInfo(rtp_header, receive_time_ms); + packet.packet_info = RtpPacketInfo(rtp_header, receive_time); // Waiting time will be set upon inserting the packet in the buffer. RTC_DCHECK(!packet.waiting_time); return packet; @@ -622,8 +628,7 @@ int NetEqImpl::InsertPacketInternal(const RTPHeader& rtp_header, if (update_sample_rate_and_channels) { nack_->Reset(); } - nack_->UpdateLastReceivedPacket(rtp_header.sequenceNumber, - rtp_header.timestamp); + nack_->UpdateLastReceivedPacket(main_sequence_number, main_timestamp); } // Check for RED payload type, and separate payloads into several packets. @@ -792,12 +797,12 @@ int NetEqImpl::InsertPacketInternal(const RTPHeader& rtp_header, size_t channels = 1; if (!decoder_database_->IsComfortNoise(payload_type)) { AudioDecoder* decoder = decoder_database_->GetDecoder(payload_type); - assert(decoder); // Payloads are already checked to be valid. + RTC_DCHECK(decoder); // Payloads are already checked to be valid. 
channels = decoder->Channels(); } const DecoderDatabase::DecoderInfo* decoder_info = decoder_database_->GetDecoderInfo(payload_type); - assert(decoder_info); + RTC_DCHECK(decoder_info); if (decoder_info->SampleRateHz() != fs_hz_ || channels != algorithm_buffer_->Channels()) { SetSampleRateAndChannels(decoder_info->SampleRateHz(), channels); @@ -811,7 +816,7 @@ int NetEqImpl::InsertPacketInternal(const RTPHeader& rtp_header, const DecoderDatabase::DecoderInfo* dec_info = decoder_database_->GetDecoderInfo(main_payload_type); - assert(dec_info); // Already checked that the payload type is known. + RTC_DCHECK(dec_info); // Already checked that the payload type is known. NetEqController::PacketArrivedInfo info; info.is_cng_or_dtmf = dec_info->IsComfortNoise() || dec_info->IsDtmf(); @@ -889,7 +894,7 @@ int NetEqImpl::GetAudioInternal(AudioFrame* audio_frame, int decode_return_value = Decode(&packet_list, &operation, &length, &speech_type); - assert(vad_.get()); + RTC_DCHECK(vad_.get()); bool sid_frame_available = (operation == Operation::kRfc3389Cng && !packet_list.empty()); vad_->Update(decoded_buffer_.get(), static_cast(length), speech_type, @@ -960,7 +965,7 @@ int NetEqImpl::GetAudioInternal(AudioFrame* audio_frame, } case Operation::kUndefined: { RTC_LOG(LS_ERROR) << "Invalid operation kUndefined."; - assert(false); // This should not happen. + RTC_NOTREACHED(); // This should not happen. last_mode_ = Mode::kError; return kInvalidOperation; } @@ -1096,7 +1101,7 @@ int NetEqImpl::GetDecision(Operation* operation, *play_dtmf = false; *operation = Operation::kUndefined; - assert(sync_buffer_.get()); + RTC_DCHECK(sync_buffer_.get()); uint32_t end_timestamp = sync_buffer_->end_timestamp(); if (!new_codec_) { const uint32_t five_seconds_samples = 5 * fs_hz_; @@ -1123,7 +1128,7 @@ int NetEqImpl::GetDecision(Operation* operation, // Don't use this packet, discard it. if (packet_buffer_->DiscardNextPacket(stats_.get()) != PacketBuffer::kOK) { - assert(false); // Must be ok by design. + RTC_NOTREACHED(); // Must be ok by design. } // Check buffer again. if (!new_codec_) { @@ -1134,7 +1139,7 @@ int NetEqImpl::GetDecision(Operation* operation, } } - assert(expand_.get()); + RTC_DCHECK(expand_.get()); const int samples_left = static_cast(sync_buffer_->FutureLength() - expand_->overlap_length()); if (last_mode_ == Mode::kAccelerateSuccess || @@ -1154,8 +1159,8 @@ int NetEqImpl::GetDecision(Operation* operation, } // Get instruction. - assert(sync_buffer_.get()); - assert(expand_.get()); + RTC_DCHECK(sync_buffer_.get()); + RTC_DCHECK(expand_.get()); generated_noise_samples = generated_noise_stopwatch_ ? generated_noise_stopwatch_->ElapsedTicks() * output_size_samples_ + @@ -1214,11 +1219,16 @@ int NetEqImpl::GetDecision(Operation* operation, } controller_->ExpandDecision(*operation); + if ((last_mode_ == Mode::kCodecPlc) && (*operation != Operation::kExpand)) { + // Getting out of the PLC expand mode, reporting interruptions. + // NetEq PLC reports this metrics in expand.cc + stats_->EndExpandEvent(fs_hz_); + } // Check conditions for reset. if (new_codec_ || *operation == Operation::kUndefined) { // The only valid reason to get kUndefined is that new_codec_ is set. 
- assert(new_codec_); + RTC_DCHECK(new_codec_); if (*play_dtmf && !packet) { timestamp_ = dtmf_event->timestamp; } else { @@ -1390,7 +1400,7 @@ int NetEqImpl::Decode(PacketList* packet_list, uint8_t payload_type = packet.payload_type; if (!decoder_database_->IsComfortNoise(payload_type)) { decoder = decoder_database_->GetDecoder(payload_type); - assert(decoder); + RTC_DCHECK(decoder); if (!decoder) { RTC_LOG(LS_WARNING) << "Unknown payload type " << static_cast(payload_type); @@ -1403,7 +1413,7 @@ int NetEqImpl::Decode(PacketList* packet_list, // We have a new decoder. Re-init some values. const DecoderDatabase::DecoderInfo* decoder_info = decoder_database_->GetDecoderInfo(payload_type); - assert(decoder_info); + RTC_DCHECK(decoder_info); if (!decoder_info) { RTC_LOG(LS_WARNING) << "Unknown payload type " << static_cast(payload_type); @@ -1475,8 +1485,8 @@ int NetEqImpl::Decode(PacketList* packet_list, // Don't increment timestamp if codec returned CNG speech type // since in this case, the we will increment the CNGplayedTS counter. // Increase with number of samples per channel. - assert(*decoded_length == 0 || - (decoder && decoder->Channels() == sync_buffer_->Channels())); + RTC_DCHECK(*decoded_length == 0 || + (decoder && decoder->Channels() == sync_buffer_->Channels())); sync_buffer_->IncreaseEndTimestamp( *decoded_length / static_cast(sync_buffer_->Channels())); } @@ -1525,16 +1535,16 @@ int NetEqImpl::DecodeLoop(PacketList* packet_list, // Do decoding. while (!packet_list->empty() && !decoder_database_->IsComfortNoise( packet_list->front().payload_type)) { - assert(decoder); // At this point, we must have a decoder object. + RTC_DCHECK(decoder); // At this point, we must have a decoder object. // The number of channels in the |sync_buffer_| should be the same as the // number decoder channels. - assert(sync_buffer_->Channels() == decoder->Channels()); - assert(decoded_buffer_length_ >= kMaxFrameSize * decoder->Channels()); - assert(operation == Operation::kNormal || - operation == Operation::kAccelerate || - operation == Operation::kFastAccelerate || - operation == Operation::kMerge || - operation == Operation::kPreemptiveExpand); + RTC_DCHECK_EQ(sync_buffer_->Channels(), decoder->Channels()); + RTC_DCHECK_GE(decoded_buffer_length_, kMaxFrameSize * decoder->Channels()); + RTC_DCHECK(operation == Operation::kNormal || + operation == Operation::kAccelerate || + operation == Operation::kFastAccelerate || + operation == Operation::kMerge || + operation == Operation::kPreemptiveExpand); auto opt_result = packet_list->front().frame->Decode( rtc::ArrayView(&decoded_buffer_[*decoded_length], @@ -1571,9 +1581,10 @@ int NetEqImpl::DecodeLoop(PacketList* packet_list, // If the list is not empty at this point, either a decoding error terminated // the while-loop, or list must hold exactly one CNG packet. 
- assert(packet_list->empty() || *decoded_length < 0 || - (packet_list->size() == 1 && decoder_database_->IsComfortNoise( - packet_list->front().payload_type))); + RTC_DCHECK( + packet_list->empty() || *decoded_length < 0 || + (packet_list->size() == 1 && + decoder_database_->IsComfortNoise(packet_list->front().payload_type))); return 0; } @@ -1581,7 +1592,7 @@ void NetEqImpl::DoNormal(const int16_t* decoded_buffer, size_t decoded_length, AudioDecoder::SpeechType speech_type, bool play_dtmf) { - assert(normal_.get()); + RTC_DCHECK(normal_.get()); normal_->Process(decoded_buffer, decoded_length, last_mode_, algorithm_buffer_.get()); if (decoded_length != 0) { @@ -1604,7 +1615,7 @@ void NetEqImpl::DoMerge(int16_t* decoded_buffer, size_t decoded_length, AudioDecoder::SpeechType speech_type, bool play_dtmf) { - assert(merge_.get()); + RTC_DCHECK(merge_.get()); size_t new_length = merge_->Process(decoded_buffer, decoded_length, algorithm_buffer_.get()); // Correction can be negative. @@ -1765,7 +1776,7 @@ int NetEqImpl::DoAccelerate(int16_t* decoded_buffer, sync_buffer_->Size() - borrowed_samples_per_channel); sync_buffer_->PushFrontZeros(borrowed_samples_per_channel - length); algorithm_buffer_->PopFront(length); - assert(algorithm_buffer_->Empty()); + RTC_DCHECK(algorithm_buffer_->Empty()); } else { sync_buffer_->ReplaceAtIndex( *algorithm_buffer_, borrowed_samples_per_channel, @@ -1854,7 +1865,7 @@ int NetEqImpl::DoPreemptiveExpand(int16_t* decoded_buffer, int NetEqImpl::DoRfc3389Cng(PacketList* packet_list, bool play_dtmf) { if (!packet_list->empty()) { // Must have exactly one SID frame at this point. - assert(packet_list->size() == 1); + RTC_DCHECK_EQ(packet_list->size(), 1); const Packet& packet = packet_list->front(); if (!decoder_database_->IsComfortNoise(packet.payload_type)) { RTC_LOG(LS_ERROR) << "Trying to decode non-CNG payload as CNG."; @@ -1937,14 +1948,14 @@ int NetEqImpl::DoDtmf(const DtmfEvent& dtmf_event, bool* play_dtmf) { // // it must be copied to the speech buffer. // // TODO(hlundin): This code seems incorrect. (Legacy.) Write test and // // verify correct operation. - // assert(false); + // RTC_NOTREACHED(); // // Must generate enough data to replace all of the |sync_buffer_| // // "future". // int required_length = sync_buffer_->FutureLength(); - // assert(dtmf_tone_generator_->initialized()); + // RTC_DCHECK(dtmf_tone_generator_->initialized()); // dtmf_return_value = dtmf_tone_generator_->Generate(required_length, // algorithm_buffer_); - // assert((size_t) required_length == algorithm_buffer_->Size()); + // RTC_DCHECK((size_t) required_length == algorithm_buffer_->Size()); // if (dtmf_return_value < 0) { // algorithm_buffer_->Zeros(output_size_samples_); // return dtmf_return_value; @@ -1954,7 +1965,7 @@ int NetEqImpl::DoDtmf(const DtmfEvent& dtmf_event, bool* play_dtmf) { // // data. // // TODO(hlundin): It seems that this overwriting has gone lost. // // Not adapted for multi-channel yet. 
- // assert(algorithm_buffer_->Channels() == 1); + // RTC_DCHECK(algorithm_buffer_->Channels() == 1); // if (algorithm_buffer_->Channels() != 1) { // RTC_LOG(LS_WARNING) << "DTMF not supported for more than one channel"; // return kStereoNotSupported; @@ -1996,7 +2007,7 @@ int NetEqImpl::DtmfOverdub(const DtmfEvent& dtmf_event, if (dtmf_return_value == 0) { dtmf_return_value = dtmf_tone_generator_->Generate(overdub_length, &dtmf_output); - assert(overdub_length == dtmf_output.Size()); + RTC_DCHECK_EQ(overdub_length, dtmf_output.Size()); } dtmf_output.ReadInterleaved(overdub_length, &output[out_index]); return dtmf_return_value < 0 ? dtmf_return_value : 0; @@ -2027,7 +2038,7 @@ int NetEqImpl::ExtractPackets(size_t required_samples, next_packet = nullptr; if (!packet) { RTC_LOG(LS_ERROR) << "Should always be able to extract a packet here"; - assert(false); // Should always be able to extract a packet here. + RTC_NOTREACHED(); // Should always be able to extract a packet here. return -1; } const uint64_t waiting_time_ms = packet->waiting_time->ElapsedMs(); @@ -2120,8 +2131,9 @@ void NetEqImpl::SetSampleRateAndChannels(int fs_hz, size_t channels) { RTC_LOG(LS_VERBOSE) << "SetSampleRateAndChannels " << fs_hz << " " << channels; // TODO(hlundin): Change to an enumerator and skip assert. - assert(fs_hz == 8000 || fs_hz == 16000 || fs_hz == 32000 || fs_hz == 48000); - assert(channels > 0); + RTC_DCHECK(fs_hz == 8000 || fs_hz == 16000 || fs_hz == 32000 || + fs_hz == 48000); + RTC_DCHECK_GT(channels, 0); // Before changing the sample rate, end and report any ongoing expand event. stats_->EndExpandEvent(fs_hz_); @@ -2137,7 +2149,7 @@ void NetEqImpl::SetSampleRateAndChannels(int fs_hz, size_t channels) { cng_decoder->Reset(); // Reinit post-decode VAD with new sample rate. - assert(vad_.get()); // Cannot be NULL here. + RTC_DCHECK(vad_.get()); // Cannot be NULL here. vad_->Init(); // Delete algorithm buffer and create a new one. 
@@ -2159,7 +2171,7 @@ void NetEqImpl::SetSampleRateAndChannels(int fs_hz, size_t channels) { expand_->overlap_length()); normal_.reset(new Normal(fs_hz, decoder_database_.get(), *background_noise_, - expand_.get())); + expand_.get(), stats_.get())); accelerate_.reset( accelerate_factory_->Create(fs_hz, channels, *background_noise_)); preemptive_expand_.reset(preemptive_expand_factory_->Create( @@ -2180,8 +2192,8 @@ void NetEqImpl::SetSampleRateAndChannels(int fs_hz, size_t channels) { } NetEqImpl::OutputType NetEqImpl::LastOutputType() { - assert(vad_.get()); - assert(expand_.get()); + RTC_DCHECK(vad_.get()); + RTC_DCHECK(expand_.get()); if (last_mode_ == Mode::kCodecInternalCng || last_mode_ == Mode::kRfc3389Cng) { return OutputType::kCNG; diff --git a/modules/audio_coding/neteq/neteq_impl.h b/modules/audio_coding/neteq/neteq_impl.h index e130422a30..88da6dcbd5 100644 --- a/modules/audio_coding/neteq/neteq_impl.h +++ b/modules/audio_coding/neteq/neteq_impl.h @@ -133,6 +133,7 @@ class NetEqImpl : public webrtc::NetEq { int GetAudio( AudioFrame* audio_frame, bool* muted, + int* current_sample_rate_hz = nullptr, absl::optional action_override = absl::nullopt) override; void SetCodecs(const std::map& codecs) override; diff --git a/modules/audio_coding/neteq/neteq_impl_unittest.cc b/modules/audio_coding/neteq/neteq_impl_unittest.cc index c66a0e25f9..53b4dae17d 100644 --- a/modules/audio_coding/neteq/neteq_impl_unittest.cc +++ b/modules/audio_coding/neteq/neteq_impl_unittest.cc @@ -303,8 +303,7 @@ TEST_F(NetEqImplTest, InsertPacket) { fake_packet.sequence_number = kFirstSequenceNumber; fake_packet.timestamp = kFirstTimestamp; - rtc::scoped_refptr mock_decoder_factory( - new rtc::RefCountedObject); + auto mock_decoder_factory = rtc::make_ref_counted(); EXPECT_CALL(*mock_decoder_factory, MakeAudioDecoderMock(_, _, _)) .WillOnce(Invoke([&](const SdpAudioFormat& format, absl::optional codec_pair_id, @@ -487,8 +486,8 @@ TEST_F(NetEqImplTest, VerifyTimestampPropagation) { int16_t next_value_; } decoder_; - rtc::scoped_refptr decoder_factory = - new rtc::RefCountedObject(&decoder_); + auto decoder_factory = + rtc::make_ref_counted(&decoder_); UseNoMocks(); CreateInstance(decoder_factory); @@ -498,7 +497,7 @@ TEST_F(NetEqImplTest, VerifyTimestampPropagation) { // Insert one packet. clock_.AdvanceTimeMilliseconds(123456); - int64_t expected_receive_time_ms = clock_.TimeInMilliseconds(); + Timestamp expected_receive_time = clock_.CurrentTime(); EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); // Pull audio once. @@ -519,7 +518,7 @@ TEST_F(NetEqImplTest, VerifyTimestampPropagation) { EXPECT_THAT(packet_info.csrcs(), ElementsAre(43, 65, 17)); EXPECT_EQ(packet_info.rtp_timestamp(), rtp_header.timestamp); EXPECT_FALSE(packet_info.audio_level().has_value()); - EXPECT_EQ(packet_info.receive_time_ms(), expected_receive_time_ms); + EXPECT_EQ(packet_info.receive_time(), expected_receive_time); } // Start with a simple check that the fake decoder is behaving as expected. @@ -555,7 +554,7 @@ TEST_F(NetEqImplTest, ReorderedPacket) { MockAudioDecoder mock_decoder; CreateInstance( - new rtc::RefCountedObject(&mock_decoder)); + rtc::make_ref_counted(&mock_decoder)); const uint8_t kPayloadType = 17; // Just an arbitrary number. const int kSampleRateHz = 8000; @@ -591,7 +590,7 @@ TEST_F(NetEqImplTest, ReorderedPacket) { // Insert one packet. 
clock_.AdvanceTimeMilliseconds(123456); - int64_t expected_receive_time_ms = clock_.TimeInMilliseconds(); + Timestamp expected_receive_time = clock_.CurrentTime(); EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); // Pull audio once. @@ -611,7 +610,7 @@ TEST_F(NetEqImplTest, ReorderedPacket) { EXPECT_THAT(packet_info.csrcs(), IsEmpty()); EXPECT_EQ(packet_info.rtp_timestamp(), rtp_header.timestamp); EXPECT_EQ(packet_info.audio_level(), rtp_header.extension.audioLevel); - EXPECT_EQ(packet_info.receive_time_ms(), expected_receive_time_ms); + EXPECT_EQ(packet_info.receive_time(), expected_receive_time); } // Insert two more packets. The first one is out of order, and is already too @@ -627,7 +626,7 @@ TEST_F(NetEqImplTest, ReorderedPacket) { rtp_header.extension.audioLevel = 2; payload[0] = 2; clock_.AdvanceTimeMilliseconds(2000); - expected_receive_time_ms = clock_.TimeInMilliseconds(); + expected_receive_time = clock_.CurrentTime(); EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); // Expect only the second packet to be decoded (the one with "2" as the first @@ -657,7 +656,7 @@ TEST_F(NetEqImplTest, ReorderedPacket) { EXPECT_THAT(packet_info.csrcs(), IsEmpty()); EXPECT_EQ(packet_info.rtp_timestamp(), rtp_header.timestamp); EXPECT_EQ(packet_info.audio_level(), rtp_header.extension.audioLevel); - EXPECT_EQ(packet_info.receive_time_ms(), expected_receive_time_ms); + EXPECT_EQ(packet_info.receive_time(), expected_receive_time); } EXPECT_CALL(mock_decoder, Die()); @@ -737,6 +736,15 @@ class NetEqImplTestSampleRateParameter const int initial_sample_rate_hz_; }; +class NetEqImplTestSdpFormatParameter + : public NetEqImplTest, + public testing::WithParamInterface { + protected: + NetEqImplTestSdpFormatParameter() + : NetEqImplTest(), sdp_format_(GetParam()) {} + const SdpAudioFormat sdp_format_; +}; + // This test does the following: // 0. Set up NetEq with initial sample rate given by test parameter, and a codec // sample rate of 16000. @@ -920,6 +928,67 @@ INSTANTIATE_TEST_SUITE_P(SampleRates, NetEqImplTestSampleRateParameter, testing::Values(8000, 16000, 32000, 48000)); +TEST_P(NetEqImplTestSdpFormatParameter, GetNackListScaledTimestamp) { + UseNoMocks(); + CreateInstance(); + + neteq_->EnableNack(128); + + const uint8_t kPayloadType = 17; // Just an arbitrary number. + const int kPayloadSampleRateHz = sdp_format_.clockrate_hz; + const size_t kPayloadLengthSamples = + static_cast(10 * kPayloadSampleRateHz / 1000); // 10 ms. + const size_t kPayloadLengthBytes = kPayloadLengthSamples * 2; + std::vector payload(kPayloadLengthBytes, 0); + RTPHeader rtp_header; + rtp_header.payloadType = kPayloadType; + rtp_header.sequenceNumber = 0x1234; + rtp_header.timestamp = 0x12345678; + rtp_header.ssrc = 0x87654321; + + EXPECT_TRUE(neteq_->RegisterPayloadType(kPayloadType, sdp_format_)); + + auto insert_packet = [&](bool lost = false) { + rtp_header.sequenceNumber++; + rtp_header.timestamp += kPayloadLengthSamples; + if (!lost) + EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); + }; + + // Insert and decode 10 packets. + for (size_t i = 0; i < 10; ++i) { + insert_packet(); + } + AudioFrame output; + size_t count_loops = 0; + do { + bool muted; + // Make sure we don't hang the test if we never go to PLC. + ASSERT_LT(++count_loops, 100u); + EXPECT_EQ(NetEq::kOK, neteq_->GetAudio(&output, &muted)); + } while (output.speech_type_ == AudioFrame::kNormalSpeech); + + insert_packet(); + + insert_packet(/*lost=*/true); + + // Ensure packet gets marked as missing. 
+ for (int i = 0; i < 5; ++i) { + insert_packet(); + } + + // Missing packet recoverable with 5ms RTT. + EXPECT_THAT(neteq_->GetNackList(5), Not(IsEmpty())); + + // No packets should have TimeToPlay > 500ms. + EXPECT_THAT(neteq_->GetNackList(500), IsEmpty()); +} + +INSTANTIATE_TEST_SUITE_P(GetNackList, + NetEqImplTestSdpFormatParameter, + testing::Values(SdpAudioFormat("g722", 8000, 1), + SdpAudioFormat("opus", 48000, 2))); + // This test verifies that NetEq can handle comfort noise and enters/quits codec // internal CNG mode properly. TEST_F(NetEqImplTest, CodecInternalCng) { @@ -927,7 +996,7 @@ TEST_F(NetEqImplTest, CodecInternalCng) { // Create a mock decoder object. MockAudioDecoder mock_decoder; CreateInstance( - new rtc::RefCountedObject(&mock_decoder)); + rtc::make_ref_counted(&mock_decoder)); const uint8_t kPayloadType = 17; // Just an arbitrary number. const int kSampleRateKhz = 48; @@ -987,15 +1056,6 @@ TEST_F(NetEqImplTest, CodecInternalCng) { EXPECT_TRUE(neteq_->RegisterPayloadType(kPayloadType, SdpAudioFormat("opus", 48000, 2))); - // Insert one packet (decoder will return speech). - EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); - - // Insert second packet (decoder will return CNG). - payload[0] = 1; - rtp_header.sequenceNumber++; - rtp_header.timestamp += kPayloadLengthSamples; - EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); - const size_t kMaxOutputSize = static_cast(10 * kSampleRateKhz); AudioFrame output; AudioFrame::SpeechType expected_type[8] = { @@ -1012,11 +1072,20 @@ TEST_F(NetEqImplTest, CodecInternalCng) { 50 * kSampleRateKhz, 10 * kSampleRateKhz}; + // Insert one packet (decoder will return speech). + EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); + bool muted; EXPECT_EQ(NetEq::kOK, neteq_->GetAudio(&output, &muted)); absl::optional last_timestamp = neteq_->GetPlayoutTimestamp(); ASSERT_TRUE(last_timestamp); + // Insert second packet (decoder will return CNG). + payload[0] = 1; + rtp_header.sequenceNumber++; + rtp_header.timestamp += kPayloadLengthSamples; + EXPECT_EQ(NetEq::kOK, neteq_->InsertPacket(rtp_header, payload)); + // Lambda for verifying the timestamps. auto verify_timestamp = [&last_timestamp, &expected_timestamp_increment]( absl::optional ts, size_t i) { @@ -1066,7 +1135,7 @@ TEST_F(NetEqImplTest, UnsupportedDecoder) { ::testing::NiceMock decoder; CreateInstance( - new rtc::RefCountedObject(&decoder)); + rtc::make_ref_counted(&decoder)); static const size_t kNetEqMaxFrameSize = 5760; // 120 ms @ 48 kHz. static const size_t kChannels = 2; @@ -1193,7 +1262,7 @@ TEST_F(NetEqImplTest, DecodedPayloadTooShort) { MockAudioDecoder mock_decoder; CreateInstance( - new rtc::RefCountedObject(&mock_decoder)); + rtc::make_ref_counted(&mock_decoder)); const uint8_t kPayloadType = 17; // Just an arbitrary number. const int kSampleRateHz = 8000; @@ -1252,7 +1321,7 @@ TEST_F(NetEqImplTest, DecodingError) { MockAudioDecoder mock_decoder; CreateInstance( - new rtc::RefCountedObject(&mock_decoder)); + rtc::make_ref_counted(&mock_decoder)); const uint8_t kPayloadType = 17; // Just an arbitrary number. const int kSampleRateHz = 8000; @@ -1364,7 +1433,7 @@ TEST_F(NetEqImplTest, DecodingErrorDuringInternalCng) { // Create a mock decoder object. MockAudioDecoder mock_decoder; CreateInstance( - new rtc::RefCountedObject(&mock_decoder)); + rtc::make_ref_counted(&mock_decoder)); const uint8_t kPayloadType = 17; // Just an arbitrary number. 
const int kSampleRateHz = 8000; @@ -1658,14 +1727,13 @@ class NetEqImplTest120ms : public NetEqImplTest { void Register120msCodec(AudioDecoder::SpeechType speech_type) { const uint32_t sampling_freq = kSamplingFreq_; - decoder_factory_ = - new rtc::RefCountedObject( - [sampling_freq, speech_type]() { - std::unique_ptr decoder = - std::make_unique(sampling_freq, speech_type); - RTC_CHECK_EQ(2, decoder->Channels()); - return decoder; - }); + decoder_factory_ = rtc::make_ref_counted( + [sampling_freq, speech_type]() { + std::unique_ptr decoder = + std::make_unique(sampling_freq, speech_type); + RTC_CHECK_EQ(2, decoder->Channels()); + return decoder; + }); } rtc::scoped_refptr decoder_factory_; diff --git a/modules/audio_coding/neteq/neteq_network_stats_unittest.cc b/modules/audio_coding/neteq/neteq_network_stats_unittest.cc index 5f15babbe3..8f72734d23 100644 --- a/modules/audio_coding/neteq/neteq_network_stats_unittest.cc +++ b/modules/audio_coding/neteq/neteq_network_stats_unittest.cc @@ -162,7 +162,7 @@ class NetEqNetworkStatsTest { NetEqNetworkStatsTest(const SdpAudioFormat& format, MockAudioDecoder* decoder) : decoder_(decoder), decoder_factory_( - new rtc::RefCountedObject(decoder)), + rtc::make_ref_counted(decoder)), samples_per_ms_(format.clockrate_hz / 1000), frame_size_samples_(kFrameSizeMs * samples_per_ms_), rtp_generator_(new RtpGenerator(samples_per_ms_)), diff --git a/modules/audio_coding/neteq/neteq_unittest.cc b/modules/audio_coding/neteq/neteq_unittest.cc index c6d514d827..bdd90e96cc 100644 --- a/modules/audio_coding/neteq/neteq_unittest.cc +++ b/modules/audio_coding/neteq/neteq_unittest.cc @@ -34,7 +34,6 @@ #include "rtc_base/ignore_wundef.h" #include "rtc_base/message_digest.h" #include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/string_encode.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/system/arch.h" #include "test/field_trial.h" @@ -83,17 +82,29 @@ TEST_F(NetEqDecodingTest, MAYBE_TestBitExactness) { const std::string input_rtp_file = webrtc::test::ResourcePath("audio_coding/neteq_universal_new", "rtp"); - const std::string output_checksum = - PlatformChecksum("68ec266d2d152dfc0d938484e7936f6af4f803e3", - "1c243feb35e3e9ab37039eddf5b3c3ecfca3c60c", "not used", - "68ec266d2d152dfc0d938484e7936f6af4f803e3", - "f68c546a43bb25743297c9c0c9027e8424b8e10b"); - - const std::string network_stats_checksum = - PlatformChecksum("2a5516cdc1c6af9f1d9d3c2f95ed292f509311c7", - "e96a7f081ebc111f49c7373d3728274057012ae9", "not used", - "2a5516cdc1c6af9f1d9d3c2f95ed292f509311c7", - "2a5516cdc1c6af9f1d9d3c2f95ed292f509311c7"); + const std::string output_checksum = PlatformChecksum( +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. +#if defined(WEBRTC_LINUX) && defined(NDEBUG) && defined(WEBRTC_ARCH_X86) + "8d9c177b7f2f9398c0944a851edffae214de2c56", +#else + "6c35140ce4d75874bdd60aa1872400b05fd05ca2", +#endif + "ab451bb8301d9a92fbf4de91556b56f1ea38b4ce", "not used", + "6c35140ce4d75874bdd60aa1872400b05fd05ca2", + "64b46bb3c1165537a880ae8404afce2efba456c0"); + + const std::string network_stats_checksum = PlatformChecksum( +// TODO(bugs.webrtc.org/12941): Linux x86 optimized builds have a different +// checksum. 
+#if defined(WEBRTC_LINUX) && defined(NDEBUG) && defined(WEBRTC_ARCH_X86) + "8cc08e3cd6801dcba4fcc15eb4036c19296a140d", +#else + "90594d85fa31d3d9584d79293bf7aa4ee55ed751", +#endif + "77b9c3640b81aff6a38d69d07dd782d39c15321d", "not used", + "90594d85fa31d3d9584d79293bf7aa4ee55ed751", + "90594d85fa31d3d9584d79293bf7aa4ee55ed751"); DecodeAndCompare(input_rtp_file, output_checksum, network_stats_checksum, absl::GetFlag(FLAGS_gen_ref)); @@ -105,35 +116,33 @@ TEST_F(NetEqDecodingTest, MAYBE_TestBitExactness) { #else #define MAYBE_TestOpusBitExactness DISABLED_TestOpusBitExactness #endif -TEST_F(NetEqDecodingTest, MAYBE_TestOpusBitExactness) { +// TODO(http://bugs.webrtc.org/12518): Enable the test after Opus has been +// updated. +TEST_F(NetEqDecodingTest, DISABLED_TestOpusBitExactness) { const std::string input_rtp_file = webrtc::test::ResourcePath("audio_coding/neteq_opus", "rtp"); const std::string maybe_sse = - "554ad4133934e3920f97575579a46f674683d77c" - "|de316e2bfb15192edb820fe5fb579d11ff5a524b"; + "c7887ff60eecf460332c6c7a28c81561f9e8a40f" + "|673dd422cfc174152536d3b13af64f9722520ab5"; const std::string output_checksum = PlatformChecksum( - maybe_sse, "b3fac4ad4f6ea384aff676ee1ea816bd70415490", - "373ccd99c147cd3fcef0e7dcad6f87d0f8e5a1c0", maybe_sse, maybe_sse); + maybe_sse, "e39283dd61a89cead3786ef8642d2637cc447296", + "53d8073eb848b70974cba9e26424f4946508fd19", maybe_sse, maybe_sse); const std::string network_stats_checksum = - PlatformChecksum("ec29e047b019a86ec06e2c40643143dc1975c69f", - "ce6f519bc1220b003944ac5d9db077665a06834e", - "abb686d3ac6fac0001ca8d45a6ca6f5aefb2f9d6", - "ec29e047b019a86ec06e2c40643143dc1975c69f", - "ec29e047b019a86ec06e2c40643143dc1975c69f"); + PlatformChecksum("c438bfa3b018f77691279eb9c63730569f54585c", + "8a474ed0992591e0c84f593824bb05979c3de157", + "9a05378dbf7e6edd56cdeb8ec45bcd6d8589623c", + "c438bfa3b018f77691279eb9c63730569f54585c", + "c438bfa3b018f77691279eb9c63730569f54585c"); DecodeAndCompare(input_rtp_file, output_checksum, network_stats_checksum, absl::GetFlag(FLAGS_gen_ref)); } -#if !defined(WEBRTC_IOS) && defined(WEBRTC_NETEQ_UNITTEST_BITEXACT) && \ - defined(WEBRTC_CODEC_OPUS) -#define MAYBE_TestOpusDtxBitExactness TestOpusDtxBitExactness -#else -#define MAYBE_TestOpusDtxBitExactness DISABLED_TestOpusDtxBitExactness -#endif -TEST_F(NetEqDecodingTest, MAYBE_TestOpusDtxBitExactness) { +// TODO(http://bugs.webrtc.org/12518): Enable the test after Opus has been +// updated. +TEST_F(NetEqDecodingTest, DISABLED_TestOpusDtxBitExactness) { const std::string input_rtp_file = webrtc::test::ResourcePath("audio_coding/neteq_opus_dtx", "rtp"); @@ -1068,7 +1077,7 @@ TEST_F(NetEqDecodingTestFaxMode, TestJitterBufferDelayWithAcceleration) { expected_target_delay += neteq_->TargetDelayMs() * 2 * kSamples; // We have two packets in the buffer and kAccelerate operation will // extract 20 ms of data. - neteq_->GetAudio(&out_frame_, &muted, NetEq::Operation::kAccelerate); + neteq_->GetAudio(&out_frame_, &muted, nullptr, NetEq::Operation::kAccelerate); // Check jitter buffer delay. 
NetEqLifetimeStatistics stats = neteq_->GetLifetimeStatistics(); diff --git a/modules/audio_coding/neteq/normal.cc b/modules/audio_coding/neteq/normal.cc index 967deea77a..3ed0e26a75 100644 --- a/modules/audio_coding/neteq/normal.cc +++ b/modules/audio_coding/neteq/normal.cc @@ -14,7 +14,6 @@ #include // min -#include "api/audio_codecs/audio_decoder.h" #include "common_audio/signal_processing/include/signal_processing_library.h" #include "modules/audio_coding/neteq/audio_multi_vector.h" #include "modules/audio_coding/neteq/background_noise.h" @@ -50,6 +49,13 @@ int Normal::Process(const int16_t* input, // TODO(hlundin): Investigate this further. const int fs_shift = 30 - WebRtcSpl_NormW32(fs_mult); + // If last call resulted in a CodedPlc we don't need to do cross-fading but we + // need to report the end of the interruption once we are back to normal + // operation. + if (last_mode == NetEq::Mode::kCodecPlc) { + statistics_->EndExpandEvent(fs_hz_); + } + // Check if last RecOut call resulted in an Expand. If so, we have to take // care of some cross-fading and unmuting. if (last_mode == NetEq::Mode::kExpand) { diff --git a/modules/audio_coding/neteq/normal.h b/modules/audio_coding/neteq/normal.h index d8c13e6190..d6dc84a2d6 100644 --- a/modules/audio_coding/neteq/normal.h +++ b/modules/audio_coding/neteq/normal.h @@ -15,6 +15,7 @@ #include // Access to size_t. #include "api/neteq/neteq.h" +#include "modules/audio_coding/neteq/statistics_calculator.h" #include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/numerics/safe_conversions.h" @@ -35,14 +36,16 @@ class Normal { Normal(int fs_hz, DecoderDatabase* decoder_database, const BackgroundNoise& background_noise, - Expand* expand) + Expand* expand, + StatisticsCalculator* statistics) : fs_hz_(fs_hz), decoder_database_(decoder_database), background_noise_(background_noise), expand_(expand), samples_per_ms_(rtc::CheckedDivExact(fs_hz_, 1000)), default_win_slope_Q14_( - rtc::dchecked_cast((1 << 14) / samples_per_ms_)) {} + rtc::dchecked_cast((1 << 14) / samples_per_ms_)), + statistics_(statistics) {} virtual ~Normal() {} @@ -64,6 +67,7 @@ class Normal { Expand* expand_; const size_t samples_per_ms_; const int16_t default_win_slope_Q14_; + StatisticsCalculator* const statistics_; RTC_DISALLOW_COPY_AND_ASSIGN(Normal); }; diff --git a/modules/audio_coding/neteq/normal_unittest.cc b/modules/audio_coding/neteq/normal_unittest.cc index 36751f8bcc..7e533bb1eb 100644 --- a/modules/audio_coding/neteq/normal_unittest.cc +++ b/modules/audio_coding/neteq/normal_unittest.cc @@ -50,7 +50,7 @@ TEST(Normal, CreateAndDestroy) { RandomVector random_vector; StatisticsCalculator statistics; Expand expand(&bgn, &sync_buffer, &random_vector, &statistics, fs, channels); - Normal normal(fs, &db, bgn, &expand); + Normal normal(fs, &db, bgn, &expand, &statistics); EXPECT_CALL(db, Die()); // Called when |db| goes out of scope. 
} @@ -64,7 +64,7 @@ TEST(Normal, AvoidDivideByZero) { StatisticsCalculator statistics; MockExpand expand(&bgn, &sync_buffer, &random_vector, &statistics, fs, channels); - Normal normal(fs, &db, bgn, &expand); + Normal normal(fs, &db, bgn, &expand, &statistics); int16_t input[1000] = {0}; AudioMultiVector output(channels); @@ -99,7 +99,7 @@ TEST(Normal, InputLengthAndChannelsDoNotMatch) { StatisticsCalculator statistics; MockExpand expand(&bgn, &sync_buffer, &random_vector, &statistics, fs, channels); - Normal normal(fs, &db, bgn, &expand); + Normal normal(fs, &db, bgn, &expand, &statistics); int16_t input[1000] = {0}; AudioMultiVector output(channels); @@ -124,7 +124,7 @@ TEST(Normal, LastModeExpand120msPacket) { StatisticsCalculator statistics; MockExpand expand(&bgn, &sync_buffer, &random_vector, &statistics, kFs, kChannels); - Normal normal(kFs, &db, bgn, &expand); + Normal normal(kFs, &db, bgn, &expand, &statistics); int16_t input[kPacketsizeBytes] = {0}; AudioMultiVector output(kChannels); diff --git a/modules/audio_coding/neteq/red_payload_splitter.cc b/modules/audio_coding/neteq/red_payload_splitter.cc index 5681464f4d..2f21a5ff6c 100644 --- a/modules/audio_coding/neteq/red_payload_splitter.cc +++ b/modules/audio_coding/neteq/red_payload_splitter.cc @@ -41,7 +41,7 @@ bool RedPayloadSplitter::SplitRed(PacketList* packet_list) { PacketList::iterator it = packet_list->begin(); while (it != packet_list->end()) { const Packet& red_packet = *it; - assert(!red_packet.payload.empty()); + RTC_DCHECK(!red_packet.payload.empty()); const uint8_t* payload_ptr = red_packet.payload.data(); size_t payload_length = red_packet.payload.size(); @@ -139,7 +139,7 @@ bool RedPayloadSplitter::SplitRed(PacketList* packet_list) { /*rtp_timestamp=*/new_packet.timestamp, /*audio_level=*/absl::nullopt, /*absolute_capture_time=*/absl::nullopt, - /*receive_time_ms=*/red_packet.packet_info.receive_time_ms()); + /*receive_time=*/red_packet.packet_info.receive_time()); new_packets.push_front(std::move(new_packet)); payload_ptr += payload_length; } diff --git a/modules/audio_coding/neteq/red_payload_splitter_unittest.cc b/modules/audio_coding/neteq/red_payload_splitter_unittest.cc index 5956971b33..7275232daa 100644 --- a/modules/audio_coding/neteq/red_payload_splitter_unittest.cc +++ b/modules/audio_coding/neteq/red_payload_splitter_unittest.cc @@ -103,7 +103,7 @@ Packet CreateRedPayload(size_t num_payloads, rtc::checked_cast((num_payloads - i - 1) * timestamp_offset); *payload_ptr = this_offset >> 6; ++payload_ptr; - assert(kPayloadLength <= 1023); // Max length described by 10 bits. + RTC_DCHECK_LE(kPayloadLength, 1023); // Max length described by 10 bits. *payload_ptr = ((this_offset & 0x3F) << 2) | (kPayloadLength >> 8); ++payload_ptr; *payload_ptr = kPayloadLength & 0xFF; @@ -298,7 +298,7 @@ TEST(RedPayloadSplitter, CheckRedPayloads) { // easier to just register the payload types and let the actual implementation // do its job. DecoderDatabase decoder_database( - new rtc::RefCountedObject, absl::nullopt); + rtc::make_ref_counted(), absl::nullopt); decoder_database.RegisterPayload(0, SdpAudioFormat("cn", 8000, 1)); decoder_database.RegisterPayload(1, SdpAudioFormat("pcmu", 8000, 1)); decoder_database.RegisterPayload(2, @@ -333,7 +333,7 @@ TEST(RedPayloadSplitter, CheckRedPayloadsRecursiveRed) { // easier to just register the payload types and let the actual implementation // do its job. 
DecoderDatabase decoder_database( - new rtc::RefCountedObject, absl::nullopt); + rtc::make_ref_counted(), absl::nullopt); decoder_database.RegisterPayload(kRedPayloadType, SdpAudioFormat("red", 8000, 1)); diff --git a/modules/audio_coding/neteq/statistics_calculator.cc b/modules/audio_coding/neteq/statistics_calculator.cc index 708780a8a8..12a0e3c9ec 100644 --- a/modules/audio_coding/neteq/statistics_calculator.cc +++ b/modules/audio_coding/neteq/statistics_calculator.cc @@ -375,7 +375,7 @@ uint16_t StatisticsCalculator::CalculateQ14Ratio(size_t numerator, return 0; } else if (numerator < denominator) { // Ratio must be smaller than 1 in Q14. - assert((numerator << 14) / denominator < (1 << 14)); + RTC_DCHECK_LT((numerator << 14) / denominator, (1 << 14)); return static_cast((numerator << 14) / denominator); } else { // Will not produce a ratio larger than 1, since this is probably an error. diff --git a/modules/audio_coding/neteq/sync_buffer.cc b/modules/audio_coding/neteq/sync_buffer.cc index 4949bb201f..73e0628ea6 100644 --- a/modules/audio_coding/neteq/sync_buffer.cc +++ b/modules/audio_coding/neteq/sync_buffer.cc @@ -28,7 +28,7 @@ void SyncBuffer::PushBack(const AudioMultiVector& append_this) { next_index_ -= samples_added; } else { // This means that we are pushing out future data that was never used. - // assert(false); + // RTC_NOTREACHED(); // TODO(hlundin): This assert must be disabled to support 60 ms frames. // This should not happen even for 60 ms frames, but it does. Investigate // why. diff --git a/modules/audio_coding/neteq/test/result_sink.cc b/modules/audio_coding/neteq/test/result_sink.cc index bb2a59bcfe..b70016180e 100644 --- a/modules/audio_coding/neteq/test/result_sink.cc +++ b/modules/audio_coding/neteq/test/result_sink.cc @@ -47,15 +47,6 @@ void Convert(const webrtc::NetEqNetworkStatistics& stats_raw, stats->set_max_waiting_time_ms(stats_raw.max_waiting_time_ms); } -void Convert(const webrtc::RtcpStatistics& stats_raw, - webrtc::neteq_unittest::RtcpStatistics* stats) { - stats->set_fraction_lost(stats_raw.fraction_lost); - stats->set_cumulative_lost(stats_raw.packets_lost); - stats->set_extended_max_sequence_number( - stats_raw.extended_highest_sequence_number); - stats->set_jitter(stats_raw.jitter); -} - void AddMessage(FILE* file, rtc::MessageDigest* digest, const std::string& message) { @@ -99,19 +90,6 @@ void ResultSink::AddResult(const NetEqNetworkStatistics& stats_raw) { #endif // WEBRTC_NETEQ_UNITTEST_BITEXACT } -void ResultSink::AddResult(const RtcpStatistics& stats_raw) { -#ifdef WEBRTC_NETEQ_UNITTEST_BITEXACT - neteq_unittest::RtcpStatistics stats; - Convert(stats_raw, &stats); - - std::string stats_string; - ASSERT_TRUE(stats.SerializeToString(&stats_string)); - AddMessage(output_fp_, digest_.get(), stats_string); -#else - FAIL() << "Writing to reference file requires Proto Buffer."; -#endif // WEBRTC_NETEQ_UNITTEST_BITEXACT -} - void ResultSink::VerifyChecksum(const std::string& checksum) { std::vector buffer; buffer.resize(digest_->Size()); diff --git a/modules/audio_coding/neteq/test/result_sink.h b/modules/audio_coding/neteq/test/result_sink.h index 357b635b08..dcde02d450 100644 --- a/modules/audio_coding/neteq/test/result_sink.h +++ b/modules/audio_coding/neteq/test/result_sink.h @@ -16,7 +16,6 @@ #include #include "api/neteq/neteq.h" -#include "modules/rtp_rtcp/include/rtcp_statistics.h" #include "rtc_base/message_digest.h" namespace webrtc { @@ -30,7 +29,6 @@ class ResultSink { void AddResult(const T* test_results, size_t length); void 
AddResult(const NetEqNetworkStatistics& stats); - void AddResult(const RtcpStatistics& stats); void VerifyChecksum(const std::string& ref_check_sum); diff --git a/modules/audio_coding/neteq/time_stretch.cc b/modules/audio_coding/neteq/time_stretch.cc index ba24e0bfc3..b7680292bd 100644 --- a/modules/audio_coding/neteq/time_stretch.cc +++ b/modules/audio_coding/neteq/time_stretch.cc @@ -66,7 +66,7 @@ TimeStretch::ReturnCodes TimeStretch::Process(const int16_t* input, DspHelper::PeakDetection(auto_correlation_, kCorrelationLen, kNumPeaks, fs_mult_, &peak_index, &peak_value); // Assert that |peak_index| stays within boundaries. - assert(peak_index <= (2 * kCorrelationLen - 1) * fs_mult_); + RTC_DCHECK_LE(peak_index, (2 * kCorrelationLen - 1) * fs_mult_); // Compensate peak_index for displaced starting position. The displacement // happens in AutoCorrelation(). Here, |kMinLag| is in the down-sampled 4 kHz @@ -74,8 +74,9 @@ TimeStretch::ReturnCodes TimeStretch::Process(const int16_t* input, // multiplication by fs_mult_ * 2. peak_index += kMinLag * fs_mult_ * 2; // Assert that |peak_index| stays within boundaries. - assert(peak_index >= static_cast(20 * fs_mult_)); - assert(peak_index <= 20 * fs_mult_ + (2 * kCorrelationLen - 1) * fs_mult_); + RTC_DCHECK_GE(peak_index, static_cast(20 * fs_mult_)); + RTC_DCHECK_LE(peak_index, + 20 * fs_mult_ + (2 * kCorrelationLen - 1) * fs_mult_); // Calculate scaling to ensure that |peak_index| samples can be square-summed // without overflowing. diff --git a/modules/audio_coding/neteq/time_stretch.h b/modules/audio_coding/neteq/time_stretch.h index aede9cadf3..26d295f669 100644 --- a/modules/audio_coding/neteq/time_stretch.h +++ b/modules/audio_coding/neteq/time_stretch.h @@ -42,9 +42,9 @@ class TimeStretch { num_channels_(num_channels), background_noise_(background_noise), max_input_value_(0) { - assert(sample_rate_hz_ == 8000 || sample_rate_hz_ == 16000 || - sample_rate_hz_ == 32000 || sample_rate_hz_ == 48000); - assert(num_channels_ > 0); + RTC_DCHECK(sample_rate_hz_ == 8000 || sample_rate_hz_ == 16000 || + sample_rate_hz_ == 32000 || sample_rate_hz_ == 48000); + RTC_DCHECK_GT(num_channels_, 0); memset(auto_correlation_, 0, sizeof(auto_correlation_)); } diff --git a/modules/audio_coding/neteq/tools/constant_pcm_packet_source.cc b/modules/audio_coding/neteq/tools/constant_pcm_packet_source.cc index 6b325b6c5c..6cbba20e5f 100644 --- a/modules/audio_coding/neteq/tools/constant_pcm_packet_source.cc +++ b/modules/audio_coding/neteq/tools/constant_pcm_packet_source.cc @@ -37,14 +37,15 @@ ConstantPcmPacketSource::ConstantPcmPacketSource(size_t payload_len_samples, std::unique_ptr ConstantPcmPacketSource::NextPacket() { RTC_CHECK_GT(packet_len_bytes_, kHeaderLenBytes); - uint8_t* packet_memory = new uint8_t[packet_len_bytes_]; + rtc::CopyOnWriteBuffer packet_buffer(packet_len_bytes_); + uint8_t* packet_memory = packet_buffer.MutableData(); // Fill the payload part of the packet memory with the pre-encoded value. for (unsigned i = 0; i < 2 * payload_len_samples_; ++i) packet_memory[kHeaderLenBytes + i] = encoded_sample_[i % 2]; WriteHeader(packet_memory); // |packet| assumes ownership of |packet_memory|. 
- std::unique_ptr packet( - new Packet(packet_memory, packet_len_bytes_, next_arrival_time_ms_)); + auto packet = + std::make_unique(std::move(packet_buffer), next_arrival_time_ms_); next_arrival_time_ms_ += payload_len_samples_ / samples_per_ms_; return packet; } diff --git a/modules/audio_coding/neteq/tools/neteq_test.cc b/modules/audio_coding/neteq/tools/neteq_test.cc index 0988d2c8e5..22f5ad6931 100644 --- a/modules/audio_coding/neteq/tools/neteq_test.cc +++ b/modules/audio_coding/neteq/tools/neteq_test.cc @@ -172,7 +172,7 @@ NetEqTest::SimulationStepResult NetEqTest::RunToNextGetAudio() { } AudioFrame out_frame; bool muted; - int error = neteq_->GetAudio(&out_frame, &muted, + int error = neteq_->GetAudio(&out_frame, &muted, nullptr, ActionToOperations(next_action_)); next_action_ = absl::nullopt; RTC_CHECK(!muted) << "The code does not handle enable_muted_state"; diff --git a/modules/audio_coding/neteq/tools/neteq_test_factory.cc b/modules/audio_coding/neteq/tools/neteq_test_factory.cc index f8ec36bd25..1a0ea156f1 100644 --- a/modules/audio_coding/neteq/tools/neteq_test_factory.cc +++ b/modules/audio_coding/neteq/tools/neteq_test_factory.cc @@ -285,7 +285,7 @@ std::unique_ptr NetEqTestFactory::InitializeTest( // Note that capture-by-copy implies that the lambda captures the value of // decoder_factory before it's reassigned on the left-hand side. - decoder_factory = new rtc::RefCountedObject( + decoder_factory = rtc::make_ref_counted( [decoder_factory, config]( const SdpAudioFormat& format, absl::optional codec_pair_id) { diff --git a/modules/audio_coding/neteq/tools/output_audio_file.h b/modules/audio_coding/neteq/tools/output_audio_file.h index d729c9cbeb..7220a36d69 100644 --- a/modules/audio_coding/neteq/tools/output_audio_file.h +++ b/modules/audio_coding/neteq/tools/output_audio_file.h @@ -36,7 +36,7 @@ class OutputAudioFile : public AudioSink { } bool WriteArray(const int16_t* audio, size_t num_samples) override { - assert(out_file_); + RTC_DCHECK(out_file_); return fwrite(audio, sizeof(*audio), num_samples, out_file_) == num_samples; } diff --git a/modules/audio_coding/neteq/tools/packet.cc b/modules/audio_coding/neteq/tools/packet.cc index 48959e4f62..e540173f43 100644 --- a/modules/audio_coding/neteq/tools/packet.cc +++ b/modules/audio_coding/neteq/tools/packet.cc @@ -10,30 +10,22 @@ #include "modules/audio_coding/neteq/tools/packet.h" -#include - -#include - -#include "modules/rtp_rtcp/source/rtp_utility.h" +#include "api/array_view.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "rtc_base/checks.h" +#include "rtc_base/copy_on_write_buffer.h" namespace webrtc { namespace test { -using webrtc::RtpUtility::RtpHeaderParser; - -Packet::Packet(uint8_t* packet_memory, - size_t allocated_bytes, +Packet::Packet(rtc::CopyOnWriteBuffer packet, size_t virtual_packet_length_bytes, double time_ms, - const RtpUtility::RtpHeaderParser& parser, - const RtpHeaderExtensionMap* extension_map /*= nullptr*/) - : payload_memory_(packet_memory), - packet_length_bytes_(allocated_bytes), + const RtpHeaderExtensionMap* extension_map) + : packet_(std::move(packet)), virtual_packet_length_bytes_(virtual_packet_length_bytes), - virtual_payload_length_bytes_(0), time_ms_(time_ms), - valid_header_(ParseHeader(parser, extension_map)) {} + valid_header_(ParseHeader(extension_map)) {} Packet::Packet(const RTPHeader& header, size_t virtual_packet_length_bytes, @@ -45,23 +37,6 @@ Packet::Packet(const RTPHeader& header, time_ms_(time_ms), valid_header_(true) {} -Packet::Packet(uint8_t* 
packet_memory, size_t allocated_bytes, double time_ms) - : Packet(packet_memory, - allocated_bytes, - allocated_bytes, - time_ms, - RtpUtility::RtpHeaderParser(packet_memory, allocated_bytes)) {} - -Packet::Packet(uint8_t* packet_memory, - size_t allocated_bytes, - size_t virtual_packet_length_bytes, - double time_ms) - : Packet(packet_memory, - allocated_bytes, - virtual_packet_length_bytes, - time_ms, - RtpUtility::RtpHeaderParser(packet_memory, allocated_bytes)) {} - Packet::~Packet() = default; bool Packet::ExtractRedHeaders(std::list* headers) const { @@ -77,9 +52,8 @@ bool Packet::ExtractRedHeaders(std::list* headers) const { // +-+-+-+-+-+-+-+-+ // - RTC_DCHECK(payload_); - const uint8_t* payload_ptr = payload_; - const uint8_t* payload_end_ptr = payload_ptr + payload_length_bytes_; + const uint8_t* payload_ptr = payload(); + const uint8_t* payload_end_ptr = payload_ptr + payload_length_bytes(); // Find all RED headers with the extension bit set to 1. That is, all headers // but the last one. @@ -111,27 +85,43 @@ void Packet::DeleteRedHeaders(std::list* headers) { } } -bool Packet::ParseHeader(const RtpHeaderParser& parser, - const RtpHeaderExtensionMap* extension_map) { - bool valid_header = parser.Parse(&header_, extension_map); - - // Special case for dummy packets that have padding marked in the RTP header. - // This causes the RTP header parser to report failure, but is fine in this - // context. - const bool header_only_with_padding = - (header_.headerLength == packet_length_bytes_ && - header_.paddingLength > 0); - if (!valid_header && !header_only_with_padding) { - return false; +bool Packet::ParseHeader(const RtpHeaderExtensionMap* extension_map) { + // Use RtpPacketReceived instead of RtpPacket because former already has a + // converter into legacy RTPHeader. + webrtc::RtpPacketReceived rtp_packet(extension_map); + + // Because of the special case of dummy packets that have padding marked in + // the RTP header, but do not have rtp payload with the padding size, handle + // padding manually. Regular RTP packet parser reports failure, but it is fine + // in this context. + bool padding = (packet_[0] & 0b0010'0000); + size_t padding_size = 0; + if (padding) { + // Clear the padding bit to prevent failure when rtp payload is omited. 
+ rtc::CopyOnWriteBuffer packet(packet_); + packet.MutableData()[0] &= ~0b0010'0000; + if (!rtp_packet.Parse(std::move(packet))) { + return false; + } + if (rtp_packet.payload_size() > 0) { + padding_size = rtp_packet.data()[rtp_packet.size() - 1]; + } + if (padding_size > rtp_packet.payload_size()) { + return false; + } + } else { + if (!rtp_packet.Parse(packet_)) { + return false; + } } - RTC_DCHECK_LE(header_.headerLength, packet_length_bytes_); - payload_ = &payload_memory_[header_.headerLength]; - RTC_DCHECK_GE(packet_length_bytes_, header_.headerLength); - payload_length_bytes_ = packet_length_bytes_ - header_.headerLength; - RTC_CHECK_GE(virtual_packet_length_bytes_, packet_length_bytes_); - RTC_DCHECK_GE(virtual_packet_length_bytes_, header_.headerLength); + rtp_payload_ = rtc::MakeArrayView(packet_.data() + rtp_packet.headers_size(), + rtp_packet.payload_size() - padding_size); + rtp_packet.GetHeader(&header_); + + RTC_CHECK_GE(virtual_packet_length_bytes_, rtp_packet.size()); + RTC_DCHECK_GE(virtual_packet_length_bytes_, rtp_packet.headers_size()); virtual_payload_length_bytes_ = - virtual_packet_length_bytes_ - header_.headerLength; + virtual_packet_length_bytes_ - rtp_packet.headers_size(); return true; } diff --git a/modules/audio_coding/neteq/tools/packet.h b/modules/audio_coding/neteq/tools/packet.h index f4189aae10..ef118d9f0b 100644 --- a/modules/audio_coding/neteq/tools/packet.h +++ b/modules/audio_coding/neteq/tools/packet.h @@ -12,62 +12,46 @@ #define MODULES_AUDIO_CODING_NETEQ_TOOLS_PACKET_H_ #include -#include -#include "api/rtp_headers.h" // NOLINT(build/include) +#include "api/array_view.h" +#include "api/rtp_headers.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "rtc_base/constructor_magic.h" +#include "rtc_base/copy_on_write_buffer.h" namespace webrtc { - -namespace RtpUtility { -class RtpHeaderParser; -} // namespace RtpUtility - namespace test { // Class for handling RTP packets in test applications. class Packet { public: // Creates a packet, with the packet payload (including header bytes) in - // |packet_memory|. The length of |packet_memory| is |allocated_bytes|. - // The new object assumes ownership of |packet_memory| and will delete it - // when the Packet object is deleted. The |time_ms| is an extra time - // associated with this packet, typically used to denote arrival time. - // The first bytes in |packet_memory| will be parsed using |parser|. - // |virtual_packet_length_bytes| is typically used when reading RTP dump files + // `packet`. The `time_ms` is an extra time associated with this packet, + // typically used to denote arrival time. + // `virtual_packet_length_bytes` is typically used when reading RTP dump files // that only contain the RTP headers, and no payload (a.k.a RTP dummy files or - // RTP light). The |virtual_packet_length_bytes| tells what size the packet - // had on wire, including the now discarded payload, whereas |allocated_bytes| - // is the length of the remaining payload (typically only the RTP header). - Packet(uint8_t* packet_memory, - size_t allocated_bytes, + // RTP light). The `virtual_packet_length_bytes` tells what size the packet + // had on wire, including the now discarded payload. 
+ Packet(rtc::CopyOnWriteBuffer packet, size_t virtual_packet_length_bytes, double time_ms, - const RtpUtility::RtpHeaderParser& parser, const RtpHeaderExtensionMap* extension_map = nullptr); + Packet(rtc::CopyOnWriteBuffer packet, + double time_ms, + const RtpHeaderExtensionMap* extension_map = nullptr) + : Packet(packet, packet.size(), time_ms, extension_map) {} + // Same as above, but creates the packet from an already parsed RTPHeader. // This is typically used when reading RTP dump files that only contain the - // RTP headers, and no payload. The |virtual_packet_length_bytes| tells what + // RTP headers, and no payload. The `virtual_packet_length_bytes` tells what // size the packet had on wire, including the now discarded payload, - // The |virtual_payload_length_bytes| tells the size of the payload. + // The `virtual_payload_length_bytes` tells the size of the payload. Packet(const RTPHeader& header, size_t virtual_packet_length_bytes, size_t virtual_payload_length_bytes, double time_ms); - // The following constructors are the same as the first two, but without a - // parser. Note that when the object is constructed using any of these - // methods, the header will be parsed using a default RtpHeaderParser object. - // In particular, RTP header extensions won't be parsed. - Packet(uint8_t* packet_memory, size_t allocated_bytes, double time_ms); - - Packet(uint8_t* packet_memory, - size_t allocated_bytes, - size_t virtual_packet_length_bytes, - double time_ms); - virtual ~Packet(); // Parses the first bytes of the RTP payload, interpreting them as RED headers @@ -80,11 +64,11 @@ class Packet { // itself. static void DeleteRedHeaders(std::list* headers); - const uint8_t* payload() const { return payload_; } + const uint8_t* payload() const { return rtp_payload_.data(); } - size_t packet_length_bytes() const { return packet_length_bytes_; } + size_t packet_length_bytes() const { return packet_.size(); } - size_t payload_length_bytes() const { return payload_length_bytes_; } + size_t payload_length_bytes() const { return rtp_payload_.size(); } size_t virtual_packet_length_bytes() const { return virtual_packet_length_bytes_; @@ -100,21 +84,17 @@ class Packet { bool valid_header() const { return valid_header_; } private: - bool ParseHeader(const webrtc::RtpUtility::RtpHeaderParser& parser, - const RtpHeaderExtensionMap* extension_map); + bool ParseHeader(const RtpHeaderExtensionMap* extension_map); void CopyToHeader(RTPHeader* destination) const; RTPHeader header_; - const std::unique_ptr payload_memory_; - const uint8_t* payload_ = nullptr; // First byte after header. - const size_t packet_length_bytes_ = 0; // Total length of packet. - size_t payload_length_bytes_ = 0; // Length of the payload, after RTP header. - // Zero for dummy RTP packets. + const rtc::CopyOnWriteBuffer packet_; + rtc::ArrayView rtp_payload_; // Empty for dummy RTP packets. // Virtual lengths are used when parsing RTP header files (dummy RTP files). const size_t virtual_packet_length_bytes_; size_t virtual_payload_length_bytes_ = 0; const double time_ms_; // Used to denote a packet's arrival time. - const bool valid_header_; // Set by the RtpHeaderParser. 
+ const bool valid_header_; RTC_DISALLOW_COPY_AND_ASSIGN(Packet); }; diff --git a/modules/audio_coding/neteq/tools/packet_unittest.cc b/modules/audio_coding/neteq/tools/packet_unittest.cc index 7f3d6630c3..7cc9a48ee6 100644 --- a/modules/audio_coding/neteq/tools/packet_unittest.cc +++ b/modules/audio_coding/neteq/tools/packet_unittest.cc @@ -42,16 +42,15 @@ void MakeRtpHeader(int payload_type, TEST(TestPacket, RegularPacket) { const size_t kPacketLengthBytes = 100; - uint8_t* packet_memory = new uint8_t[kPacketLengthBytes]; + rtc::CopyOnWriteBuffer packet_memory(kPacketLengthBytes); const uint8_t kPayloadType = 17; const uint16_t kSequenceNumber = 4711; const uint32_t kTimestamp = 47114711; const uint32_t kSsrc = 0x12345678; MakeRtpHeader(kPayloadType, kSequenceNumber, kTimestamp, kSsrc, - packet_memory); + packet_memory.MutableData()); const double kPacketTime = 1.0; - // Hand over ownership of |packet_memory| to |packet|. - Packet packet(packet_memory, kPacketLengthBytes, kPacketTime); + Packet packet(std::move(packet_memory), kPacketTime); ASSERT_TRUE(packet.valid_header()); EXPECT_EQ(kPayloadType, packet.header().payloadType); EXPECT_EQ(kSequenceNumber, packet.header().sequenceNumber); @@ -70,16 +69,44 @@ TEST(TestPacket, RegularPacket) { TEST(TestPacket, DummyPacket) { const size_t kPacketLengthBytes = kHeaderLengthBytes; // Only RTP header. const size_t kVirtualPacketLengthBytes = 100; - uint8_t* packet_memory = new uint8_t[kPacketLengthBytes]; + rtc::CopyOnWriteBuffer packet_memory(kPacketLengthBytes); const uint8_t kPayloadType = 17; const uint16_t kSequenceNumber = 4711; const uint32_t kTimestamp = 47114711; const uint32_t kSsrc = 0x12345678; MakeRtpHeader(kPayloadType, kSequenceNumber, kTimestamp, kSsrc, - packet_memory); + packet_memory.MutableData()); const double kPacketTime = 1.0; - // Hand over ownership of |packet_memory| to |packet|. - Packet packet(packet_memory, kPacketLengthBytes, kVirtualPacketLengthBytes, + Packet packet(std::move(packet_memory), kVirtualPacketLengthBytes, + kPacketTime); + ASSERT_TRUE(packet.valid_header()); + EXPECT_EQ(kPayloadType, packet.header().payloadType); + EXPECT_EQ(kSequenceNumber, packet.header().sequenceNumber); + EXPECT_EQ(kTimestamp, packet.header().timestamp); + EXPECT_EQ(kSsrc, packet.header().ssrc); + EXPECT_EQ(0, packet.header().numCSRCs); + EXPECT_EQ(kPacketLengthBytes, packet.packet_length_bytes()); + EXPECT_EQ(kPacketLengthBytes - kHeaderLengthBytes, + packet.payload_length_bytes()); + EXPECT_EQ(kVirtualPacketLengthBytes, packet.virtual_packet_length_bytes()); + EXPECT_EQ(kVirtualPacketLengthBytes - kHeaderLengthBytes, + packet.virtual_payload_length_bytes()); + EXPECT_EQ(kPacketTime, packet.time_ms()); +} + +TEST(TestPacket, DummyPaddingPacket) { + const size_t kPacketLengthBytes = kHeaderLengthBytes; // Only RTP header. + const size_t kVirtualPacketLengthBytes = 100; + rtc::CopyOnWriteBuffer packet_memory(kPacketLengthBytes); + const uint8_t kPayloadType = 17; + const uint16_t kSequenceNumber = 4711; + const uint32_t kTimestamp = 47114711; + const uint32_t kSsrc = 0x12345678; + MakeRtpHeader(kPayloadType, kSequenceNumber, kTimestamp, kSsrc, + packet_memory.MutableData()); + packet_memory.MutableData()[0] |= 0b0010'0000; // Set the padding bit. 
+ const double kPacketTime = 1.0; + Packet packet(std::move(packet_memory), kVirtualPacketLengthBytes, kPacketTime); ASSERT_TRUE(packet.valid_header()); EXPECT_EQ(kPayloadType, packet.header().payloadType); @@ -133,19 +160,19 @@ int MakeRedHeader(int payload_type, TEST(TestPacket, RED) { const size_t kPacketLengthBytes = 100; - uint8_t* packet_memory = new uint8_t[kPacketLengthBytes]; + rtc::CopyOnWriteBuffer packet_memory(kPacketLengthBytes); const uint8_t kRedPayloadType = 17; const uint16_t kSequenceNumber = 4711; const uint32_t kTimestamp = 47114711; const uint32_t kSsrc = 0x12345678; MakeRtpHeader(kRedPayloadType, kSequenceNumber, kTimestamp, kSsrc, - packet_memory); + packet_memory.MutableData()); // Create four RED headers. // Payload types are just the same as the block index the offset is 100 times // the block index. const int kRedBlocks = 4; - uint8_t* payload_ptr = - &packet_memory[kHeaderLengthBytes]; // First byte after header. + uint8_t* payload_ptr = packet_memory.MutableData() + + kHeaderLengthBytes; // First byte after header. for (int i = 0; i < kRedBlocks; ++i) { int payload_type = i; // Offset value is not used for the last block. diff --git a/modules/audio_coding/neteq/tools/rtp_analyze.cc b/modules/audio_coding/neteq/tools/rtp_analyze.cc index dad3750940..46fc2d744e 100644 --- a/modules/audio_coding/neteq/tools/rtp_analyze.cc +++ b/modules/audio_coding/neteq/tools/rtp_analyze.cc @@ -56,7 +56,7 @@ int main(int argc, char* argv[]) { printf("Input file: %s\n", args[1]); std::unique_ptr file_source( webrtc::test::RtpFileSource::Create(args[1])); - assert(file_source.get()); + RTC_DCHECK(file_source.get()); // Set RTP extension IDs. bool print_audio_level = false; if (absl::GetFlag(FLAGS_audio_level) != -1) { @@ -151,7 +151,7 @@ int main(int argc, char* argv[]) { packet->ExtractRedHeaders(&red_headers); while (!red_headers.empty()) { webrtc::RTPHeader* red = red_headers.front(); - assert(red); + RTC_DCHECK(red); fprintf(out_file, "* %5u %10u %10u %5i\n", red->sequenceNumber, red->timestamp, static_cast(packet->time_ms()), red->payloadType); diff --git a/modules/audio_coding/neteq/tools/rtp_file_source.cc b/modules/audio_coding/neteq/tools/rtp_file_source.cc index 78523308e3..16b225e5df 100644 --- a/modules/audio_coding/neteq/tools/rtp_file_source.cc +++ b/modules/audio_coding/neteq/tools/rtp_file_source.cc @@ -62,12 +62,9 @@ std::unique_ptr RtpFileSource::NextPacket() { // Read the next one. 
continue; } - std::unique_ptr packet_memory(new uint8_t[temp_packet.length]); - memcpy(packet_memory.get(), temp_packet.data, temp_packet.length); - RtpUtility::RtpHeaderParser parser(packet_memory.get(), temp_packet.length); auto packet = std::make_unique( - packet_memory.release(), temp_packet.length, - temp_packet.original_length, temp_packet.time_ms, parser, + rtc::CopyOnWriteBuffer(temp_packet.data, temp_packet.length), + temp_packet.original_length, temp_packet.time_ms, &rtp_header_extension_map_); if (!packet->valid_header()) { continue; diff --git a/modules/audio_coding/neteq/tools/rtp_generator.cc b/modules/audio_coding/neteq/tools/rtp_generator.cc index accd1635b5..a37edef20a 100644 --- a/modules/audio_coding/neteq/tools/rtp_generator.cc +++ b/modules/audio_coding/neteq/tools/rtp_generator.cc @@ -18,7 +18,7 @@ namespace test { uint32_t RtpGenerator::GetRtpHeader(uint8_t payload_type, size_t payload_length_samples, RTPHeader* rtp_header) { - assert(rtp_header); + RTC_DCHECK(rtp_header); if (!rtp_header) { return 0; } @@ -31,7 +31,7 @@ uint32_t RtpGenerator::GetRtpHeader(uint8_t payload_type, rtp_header->numCSRCs = 0; uint32_t this_send_time = next_send_time_ms_; - assert(samples_per_ms_ > 0); + RTC_DCHECK_GT(samples_per_ms_, 0); next_send_time_ms_ += ((1.0 + drift_factor_) * payload_length_samples) / samples_per_ms_; return this_send_time; diff --git a/modules/audio_coding/test/Channel.cc b/modules/audio_coding/test/Channel.cc index 9456145d8c..d7bd6a968b 100644 --- a/modules/audio_coding/test/Channel.cc +++ b/modules/audio_coding/test/Channel.cc @@ -125,7 +125,7 @@ void Channel::CalcStatistics(const RTPHeader& rtp_header, size_t payloadSize) { (uint32_t)((uint32_t)rtp_header.timestamp - (uint32_t)currentPayloadStr->lastTimestamp); } - assert(_lastFrameSizeSample > 0); + RTC_DCHECK_GT(_lastFrameSizeSample, 0); int k = 0; for (; k < MAX_NUM_FRAMESIZES; ++k) { if ((currentPayloadStr->frameSizeStats[k].frameSizeSample == diff --git a/modules/audio_device/BUILD.gn b/modules/audio_device/BUILD.gn index 4f701e4be8..5d6a1d82fc 100644 --- a/modules/audio_device/BUILD.gn +++ b/modules/audio_device/BUILD.gn @@ -51,7 +51,6 @@ rtc_source_set("audio_device_api") { "../../api:scoped_refptr", "../../api/task_queue", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base:stringutils", ] @@ -68,6 +67,7 @@ rtc_library("audio_device_buffer") { deps = [ ":audio_device_api", "../../api:array_view", + "../../api:sequence_checker", "../../api/task_queue", "../../common_audio:common_audio_c", "../../rtc_base:checks", @@ -164,12 +164,12 @@ rtc_library("audio_device_impl") { "../../api:array_view", "../../api:refcountedbase", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/task_queue", "../../common_audio", "../../common_audio:common_audio_c", "../../rtc_base", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_task_queue", "../../rtc_base/synchronization:mutex", @@ -181,6 +181,7 @@ rtc_library("audio_device_impl") { "../../system_wrappers:metrics", "../utility", ] + absl_deps = [ "//third_party/abseil-cpp/absl/base:core_headers" ] if (rtc_include_internal_audio_device && is_ios) { deps += [ "../../sdk:audio_device" ] } @@ -351,6 +352,7 @@ if (is_mac) { } rtc_source_set("mock_audio_device") { + visibility = [ "*" ] testonly = true sources = [ "include/mock_audio_device.h", @@ -366,7 +368,7 @@ rtc_source_set("mock_audio_device") { ] } -if (rtc_include_tests) { +if 
(rtc_include_tests && !build_with_chromium) { rtc_library("audio_device_unittests") { testonly = true @@ -381,6 +383,7 @@ if (rtc_include_tests) { ":mock_audio_device", "../../api:array_view", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/task_queue", "../../api/task_queue:default_task_queue_factory", "../../common_audio", diff --git a/modules/audio_device/android/aaudio_player.h b/modules/audio_device/android/aaudio_player.h index 820d279d6e..9e9182aed8 100644 --- a/modules/audio_device/android/aaudio_player.h +++ b/modules/audio_device/android/aaudio_player.h @@ -15,12 +15,12 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/aaudio_wrapper.h" #include "modules/audio_device/include/audio_device_defines.h" #include "rtc_base/message_handler.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -95,12 +95,12 @@ class AAudioPlayer final : public AAudioObserverInterface, // Ensures that methods are called from the same thread as this object is // created on. - rtc::ThreadChecker main_thread_checker_; + SequenceChecker main_thread_checker_; // Stores thread ID in first call to AAudioPlayer::OnDataCallback from a // real-time thread owned by AAudio. Detached during construction of this // object. - rtc::ThreadChecker thread_checker_aaudio_; + SequenceChecker thread_checker_aaudio_; // The thread on which this object is created on. rtc::Thread* main_thread_; diff --git a/modules/audio_device/android/aaudio_recorder.h b/modules/audio_device/android/aaudio_recorder.h index d9427e2aec..bbf2cacf9b 100644 --- a/modules/audio_device/android/aaudio_recorder.h +++ b/modules/audio_device/android/aaudio_recorder.h @@ -15,11 +15,11 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/aaudio_wrapper.h" #include "modules/audio_device/include/audio_device_defines.h" #include "rtc_base/message_handler.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -88,12 +88,12 @@ class AAudioRecorder : public AAudioObserverInterface, // Ensures that methods are called from the same thread as this object is // created on. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Stores thread ID in first call to AAudioPlayer::OnDataCallback from a // real-time thread owned by AAudio. Detached during construction of this // object. - rtc::ThreadChecker thread_checker_aaudio_; + SequenceChecker thread_checker_aaudio_; // The thread on which this object is created on. 
rtc::Thread* main_thread_; diff --git a/modules/audio_device/android/aaudio_wrapper.h b/modules/audio_device/android/aaudio_wrapper.h index 4915092149..1f925b96d3 100644 --- a/modules/audio_device/android/aaudio_wrapper.h +++ b/modules/audio_device/android/aaudio_wrapper.h @@ -13,8 +13,8 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/include/audio_device_defines.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -113,8 +113,8 @@ class AAudioWrapper { bool VerifyStreamConfiguration(); bool OptimizeBuffers(); - rtc::ThreadChecker thread_checker_; - rtc::ThreadChecker aaudio_thread_checker_; + SequenceChecker thread_checker_; + SequenceChecker aaudio_thread_checker_; AudioParameters audio_parameters_; const aaudio_direction_t direction_; AAudioObserverInterface* observer_ = nullptr; diff --git a/modules/audio_device/android/audio_device_template.h b/modules/audio_device/android/audio_device_template.h index fb5bf6fa59..3ea248f79e 100644 --- a/modules/audio_device/android/audio_device_template.h +++ b/modules/audio_device/android/audio_device_template.h @@ -11,11 +11,11 @@ #ifndef MODULES_AUDIO_DEVICE_ANDROID_AUDIO_DEVICE_TEMPLATE_H_ #define MODULES_AUDIO_DEVICE_ANDROID_AUDIO_DEVICE_TEMPLATE_H_ +#include "api/sequence_checker.h" #include "modules/audio_device/android/audio_manager.h" #include "modules/audio_device/audio_device_generic.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -39,7 +39,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { output_(audio_manager_), input_(audio_manager_), initialized_(false) { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_CHECK(audio_manager); audio_manager_->SetActiveAudioLayer(audio_layer); } @@ -48,13 +48,13 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { int32_t ActiveAudioLayer( AudioDeviceModule::AudioLayer& audioLayer) const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; audioLayer = audio_layer_; return 0; } InitStatus Init() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(!initialized_); if (!audio_manager_->Init()) { @@ -74,7 +74,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } int32_t Terminate() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK(thread_checker_.IsCurrent()); int32_t err = input_.Terminate(); err |= output_.Terminate(); @@ -85,18 +85,18 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } bool Initialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK(thread_checker_.IsCurrent()); return initialized_; } int16_t PlayoutDevices() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return 1; } int16_t RecordingDevices() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return 1; } @@ -115,7 +115,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { int32_t SetPlayoutDevice(uint16_t index) override { // OK to use but it has no effect currently since device selection is // done using Andoid APIs instead. 
- RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return 0; } @@ -127,7 +127,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { int32_t SetRecordingDevice(uint16_t index) override { // OK to use but it has no effect currently since device selection is // done using Andoid APIs instead. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return 0; } @@ -137,39 +137,39 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } int32_t PlayoutIsAvailable(bool& available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; available = true; return 0; } int32_t InitPlayout() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.InitPlayout(); } bool PlayoutIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.PlayoutIsInitialized(); } int32_t RecordingIsAvailable(bool& available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; available = true; return 0; } int32_t InitRecording() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return input_.InitRecording(); } bool RecordingIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return input_.RecordingIsInitialized(); } int32_t StartPlayout() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; if (!audio_manager_->IsCommunicationModeEnabled()) { RTC_LOG(WARNING) << "The application should use MODE_IN_COMMUNICATION audio mode!"; @@ -181,7 +181,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { // Avoid using audio manger (JNI/Java cost) if playout was inactive. if (!Playing()) return 0; - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; int32_t err = output_.StopPlayout(); return err; } @@ -192,7 +192,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } int32_t StartRecording() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; if (!audio_manager_->IsCommunicationModeEnabled()) { RTC_LOG(WARNING) << "The application should use MODE_IN_COMMUNICATION audio mode!"; @@ -202,7 +202,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { int32_t StopRecording() override { // Avoid using audio manger (JNI/Java cost) if recording was inactive. 
- RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; if (!Recording()) return 0; int32_t err = input_.StopRecording(); @@ -212,47 +212,47 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { bool Recording() const override { return input_.Recording(); } int32_t InitSpeaker() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return 0; } bool SpeakerIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return true; } int32_t InitMicrophone() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return 0; } bool MicrophoneIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return true; } int32_t SpeakerVolumeIsAvailable(bool& available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.SpeakerVolumeIsAvailable(available); } int32_t SetSpeakerVolume(uint32_t volume) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.SetSpeakerVolume(volume); } int32_t SpeakerVolume(uint32_t& volume) const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.SpeakerVolume(volume); } int32_t MaxSpeakerVolume(uint32_t& maxVolume) const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.MaxSpeakerVolume(maxVolume); } int32_t MinSpeakerVolume(uint32_t& minVolume) const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return output_.MinSpeakerVolume(minVolume); } @@ -299,13 +299,13 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { // Returns true if the audio manager has been configured to support stereo // and false otherwised. Default is mono. int32_t StereoPlayoutIsAvailable(bool& available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; available = audio_manager_->IsStereoPlayoutSupported(); return 0; } int32_t SetStereoPlayout(bool enable) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; bool available = audio_manager_->IsStereoPlayoutSupported(); // Android does not support changes between mono and stero on the fly. // Instead, the native audio layer is configured via the audio manager @@ -320,13 +320,13 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } int32_t StereoRecordingIsAvailable(bool& available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; available = audio_manager_->IsStereoRecordSupported(); return 0; } int32_t SetStereoRecording(bool enable) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; bool available = audio_manager_->IsStereoRecordSupported(); // Android does not support changes between mono and stero on the fly. 
// Instead, the native audio layer is configured via the audio manager @@ -336,7 +336,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } int32_t StereoRecording(bool& enabled) const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; enabled = audio_manager_->IsStereoRecordSupported(); return 0; } @@ -349,7 +349,7 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { } void AttachAudioBuffer(AudioDeviceBuffer* audioBuffer) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; output_.AttachAudioBuffer(audioBuffer); input_.AttachAudioBuffer(audioBuffer); } @@ -367,13 +367,13 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { // a "Not Implemented" log will be filed. This non-perfect state will remain // until I have added full support for audio effects based on OpenSL ES APIs. bool BuiltInAECIsAvailable() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return audio_manager_->IsAcousticEchoCancelerSupported(); } // TODO(henrika): add implementation for OpenSL ES based audio as well. int32_t EnableBuiltInAEC(bool enable) override { - RTC_LOG(INFO) << __FUNCTION__ << "(" << enable << ")"; + RTC_DLOG(INFO) << __FUNCTION__ << "(" << enable << ")"; RTC_CHECK(BuiltInAECIsAvailable()) << "HW AEC is not available"; return input_.EnableBuiltInAEC(enable); } @@ -383,13 +383,13 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { // TODO(henrika): add implementation for OpenSL ES based audio as well. // In addition, see comments for BuiltInAECIsAvailable(). bool BuiltInAGCIsAvailable() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return audio_manager_->IsAutomaticGainControlSupported(); } // TODO(henrika): add implementation for OpenSL ES based audio as well. int32_t EnableBuiltInAGC(bool enable) override { - RTC_LOG(INFO) << __FUNCTION__ << "(" << enable << ")"; + RTC_DLOG(INFO) << __FUNCTION__ << "(" << enable << ")"; RTC_CHECK(BuiltInAGCIsAvailable()) << "HW AGC is not available"; return input_.EnableBuiltInAGC(enable); } @@ -399,19 +399,19 @@ class AudioDeviceTemplate : public AudioDeviceGeneric { // TODO(henrika): add implementation for OpenSL ES based audio as well. // In addition, see comments for BuiltInAECIsAvailable(). bool BuiltInNSIsAvailable() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return audio_manager_->IsNoiseSuppressorSupported(); } // TODO(henrika): add implementation for OpenSL ES based audio as well. int32_t EnableBuiltInNS(bool enable) override { - RTC_LOG(INFO) << __FUNCTION__ << "(" << enable << ")"; + RTC_DLOG(INFO) << __FUNCTION__ << "(" << enable << ")"; RTC_CHECK(BuiltInNSIsAvailable()) << "HW NS is not available"; return input_.EnableBuiltInNS(enable); } private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Local copy of the audio layer set during construction of the // AudioDeviceModuleImpl instance. Read only value. 
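A note on the `RTC_LOG(INFO)` → `RTC_DLOG(INFO)` change above (illustrative only, not part of the patch): `RTC_DLOG` statements are compiled out unless `RTC_DLOG_IS_ON` is set (debug builds, or builds defining `DLOG_ALWAYS_ON`), so the per-call function-entry logging disappears from release builds. A minimal sketch of the difference, with a hypothetical helper:

```
#include "rtc_base/logging.h"

// Hypothetical function, only to contrast the two macros.
void LogExample() {
  RTC_LOG(INFO) << "emitted in every build type";
  RTC_DLOG(INFO) << "compiled out unless RTC_DLOG_IS_ON";
}
```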
diff --git a/modules/audio_device/android/audio_manager.h b/modules/audio_device/android/audio_manager.h index d1debdb415..900fc78a68 100644 --- a/modules/audio_device/android/audio_manager.h +++ b/modules/audio_device/android/audio_manager.h @@ -16,6 +16,7 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/audio_common.h" #include "modules/audio_device/android/opensles_common.h" #include "modules/audio_device/audio_device_config.h" @@ -23,7 +24,6 @@ #include "modules/audio_device/include/audio_device_defines.h" #include "modules/utility/include/helpers_android.h" #include "modules/utility/include/jvm_android.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -158,9 +158,9 @@ class AudioManager { jint input_buffer_size); // Stores thread ID in the constructor. - // We can then use ThreadChecker::IsCurrent() to ensure that + // We can then use RTC_DCHECK_RUN_ON(&thread_checker_) to ensure that // other methods are called from the same thread. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Calls JavaVM::AttachCurrentThread() if this thread is not attached at // construction. diff --git a/modules/audio_device/android/audio_record_jni.h b/modules/audio_device/android/audio_record_jni.h index 102f29ab1a..c445360d6c 100644 --- a/modules/audio_device/android/audio_record_jni.h +++ b/modules/audio_device/android/audio_record_jni.h @@ -15,12 +15,12 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/audio_manager.h" #include "modules/audio_device/audio_device_generic.h" #include "modules/audio_device/include/audio_device_defines.h" #include "modules/utility/include/helpers_android.h" #include "modules/utility/include/jvm_android.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -110,11 +110,11 @@ class AudioRecordJni { void OnDataIsRecorded(int length); // Stores thread ID in constructor. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Stores thread ID in first call to OnDataIsRecorded() from high-priority // thread in Java. Detached during construction of this object. - rtc::ThreadChecker thread_checker_java_; + SequenceChecker thread_checker_java_; // Calls JavaVM::AttachCurrentThread() if this thread is not attached at // construction. diff --git a/modules/audio_device/android/audio_track_jni.h b/modules/audio_device/android/audio_track_jni.h index 529a9013e8..62bcba42b0 100644 --- a/modules/audio_device/android/audio_track_jni.h +++ b/modules/audio_device/android/audio_track_jni.h @@ -15,13 +15,13 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/audio_common.h" #include "modules/audio_device/android/audio_manager.h" #include "modules/audio_device/audio_device_generic.h" #include "modules/audio_device/include/audio_device_defines.h" #include "modules/utility/include/helpers_android.h" #include "modules/utility/include/jvm_android.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -109,11 +109,11 @@ class AudioTrackJni { void OnGetPlayoutData(size_t length); // Stores thread ID in constructor. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Stores thread ID in first call to OnGetPlayoutData() from high-priority // thread in Java. Detached during construction of this object. - rtc::ThreadChecker thread_checker_java_; + SequenceChecker thread_checker_java_; // Calls JavaVM::AttachCurrentThread() if this thread is not attached at // construction. 
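For context on the `rtc::ThreadChecker` → `SequenceChecker` migration in the hunks above, here is a minimal sketch of the pattern these classes follow (the class `JniAudioState` and its members are hypothetical, not WebRTC code): one checker binds to the construction sequence, a second is detached and re-binds on the first callback from the native/Java thread, and `RTC_DCHECK_RUN_ON` plus `RTC_GUARDED_BY` document and enforce the contract in debug builds.

```
#include "api/sequence_checker.h"
#include "rtc_base/thread_annotations.h"

class JniAudioState {
 public:
  JniAudioState() {
    // The callback checker re-binds to whichever thread calls OnDataCallback()
    // first, mirroring the "Detached during construction" comments above.
    callback_checker_.Detach();
  }

  void StartRecording() {
    RTC_DCHECK_RUN_ON(&construction_checker_);
    recording_ = true;
  }

  void OnDataCallback() {
    RTC_DCHECK_RUN_ON(&callback_checker_);
    ++callbacks_;
  }

 private:
  webrtc::SequenceChecker construction_checker_;
  webrtc::SequenceChecker callback_checker_;
  bool recording_ RTC_GUARDED_BY(construction_checker_) = false;
  int callbacks_ RTC_GUARDED_BY(callback_checker_) = 0;
};
```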
diff --git a/modules/audio_device/android/opensles_player.h b/modules/audio_device/android/opensles_player.h index 20107585a6..78af29b6b6 100644 --- a/modules/audio_device/android/opensles_player.h +++ b/modules/audio_device/android/opensles_player.h @@ -15,13 +15,13 @@ #include #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/audio_common.h" #include "modules/audio_device/android/audio_manager.h" #include "modules/audio_device/android/opensles_common.h" #include "modules/audio_device/audio_device_generic.h" #include "modules/audio_device/include/audio_device_defines.h" #include "modules/utility/include/helpers_android.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -113,12 +113,12 @@ class OpenSLESPlayer { // Ensures that methods are called from the same thread as this object is // created on. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Stores thread ID in first call to SimpleBufferQueueCallback() from internal // non-application thread which is not attached to the Dalvik JVM. // Detached during construction of this object. - rtc::ThreadChecker thread_checker_opensles_; + SequenceChecker thread_checker_opensles_; // Raw pointer to the audio manager injected at construction. Used to cache // audio parameters and to access the global SL engine object needed by the diff --git a/modules/audio_device/android/opensles_recorder.h b/modules/audio_device/android/opensles_recorder.h index ee1ede51d5..5f975d7242 100644 --- a/modules/audio_device/android/opensles_recorder.h +++ b/modules/audio_device/android/opensles_recorder.h @@ -17,13 +17,13 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/android/audio_common.h" #include "modules/audio_device/android/audio_manager.h" #include "modules/audio_device/android/opensles_common.h" #include "modules/audio_device/audio_device_generic.h" #include "modules/audio_device/include/audio_device_defines.h" #include "modules/utility/include/helpers_android.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -123,12 +123,12 @@ class OpenSLESRecorder { // Ensures that methods are called from the same thread as this object is // created on. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Stores thread ID in first call to SimpleBufferQueueCallback() from internal // non-application thread which is not attached to the Dalvik JVM. // Detached during construction of this object. - rtc::ThreadChecker thread_checker_opensles_; + SequenceChecker thread_checker_opensles_; // Raw pointer to the audio manager injected at construction. 
Used to cache // audio parameters and to access the global SL engine object needed by the diff --git a/modules/audio_device/audio_device_buffer.cc b/modules/audio_device/audio_device_buffer.cc index 520976482c..977045419a 100644 --- a/modules/audio_device/audio_device_buffer.cc +++ b/modules/audio_device/audio_device_buffer.cc @@ -78,7 +78,7 @@ AudioDeviceBuffer::~AudioDeviceBuffer() { int32_t AudioDeviceBuffer::RegisterAudioCallback( AudioTransport* audio_callback) { RTC_DCHECK_RUN_ON(&main_thread_checker_); - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; if (playing_ || recording_) { RTC_LOG(LS_ERROR) << "Failed to set audio transport since media was active"; return -1; @@ -95,7 +95,7 @@ void AudioDeviceBuffer::StartPlayout() { if (playing_) { return; } - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; // Clear members tracking playout stats and do it on the task queue. task_queue_.PostTask([this] { ResetPlayStats(); }); // Start a periodic timer based on task queue if not already done by the @@ -114,7 +114,7 @@ void AudioDeviceBuffer::StartRecording() { if (recording_) { return; } - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; // Clear members tracking recording stats and do it on the task queue. task_queue_.PostTask([this] { ResetRecStats(); }); // Start a periodic timer based on task queue if not already done by the @@ -136,7 +136,7 @@ void AudioDeviceBuffer::StopPlayout() { if (!playing_) { return; } - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; playing_ = false; // Stop periodic logging if no more media is active. if (!recording_) { @@ -150,7 +150,7 @@ void AudioDeviceBuffer::StopRecording() { if (!recording_) { return; } - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; recording_ = false; // Stop periodic logging if no more media is active. if (!playing_) { diff --git a/modules/audio_device/audio_device_buffer.h b/modules/audio_device/audio_device_buffer.h index 37b8a2ec5e..a0b7953194 100644 --- a/modules/audio_device/audio_device_buffer.h +++ b/modules/audio_device/audio_device_buffer.h @@ -16,13 +16,13 @@ #include +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "modules/audio_device/include/audio_device_defines.h" #include "rtc_base/buffer.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -140,7 +140,7 @@ class AudioDeviceBuffer { // TODO(henrika): see if it is possible to refactor and annotate all members. // Main thread on which this object is created. 
- rtc::ThreadChecker main_thread_checker_; + SequenceChecker main_thread_checker_; Mutex lock_; diff --git a/modules/audio_device/audio_device_data_observer.cc b/modules/audio_device/audio_device_data_observer.cc index 89265a288f..be78fd16d7 100644 --- a/modules/audio_device/audio_device_data_observer.cc +++ b/modules/audio_device/audio_device_data_observer.cc @@ -301,9 +301,8 @@ class ADMWrapper : public AudioDeviceModule, public AudioTransport { rtc::scoped_refptr CreateAudioDeviceWithDataObserver( rtc::scoped_refptr impl, std::unique_ptr observer) { - rtc::scoped_refptr audio_device( - new rtc::RefCountedObject(impl, observer.get(), - std::move(observer))); + auto audio_device = rtc::make_ref_counted(impl, observer.get(), + std::move(observer)); if (!audio_device->IsValid()) { return nullptr; @@ -315,8 +314,8 @@ rtc::scoped_refptr CreateAudioDeviceWithDataObserver( rtc::scoped_refptr CreateAudioDeviceWithDataObserver( rtc::scoped_refptr impl, AudioDeviceDataObserver* legacy_observer) { - rtc::scoped_refptr audio_device( - new rtc::RefCountedObject(impl, legacy_observer, nullptr)); + auto audio_device = + rtc::make_ref_counted(impl, legacy_observer, nullptr); if (!audio_device->IsValid()) { return nullptr; @@ -329,10 +328,8 @@ rtc::scoped_refptr CreateAudioDeviceWithDataObserver( AudioDeviceModule::AudioLayer audio_layer, TaskQueueFactory* task_queue_factory, std::unique_ptr observer) { - rtc::scoped_refptr audio_device( - new rtc::RefCountedObject(audio_layer, task_queue_factory, - observer.get(), - std::move(observer))); + auto audio_device = rtc::make_ref_counted( + audio_layer, task_queue_factory, observer.get(), std::move(observer)); if (!audio_device->IsValid()) { return nullptr; @@ -345,9 +342,8 @@ rtc::scoped_refptr CreateAudioDeviceWithDataObserver( AudioDeviceModule::AudioLayer audio_layer, TaskQueueFactory* task_queue_factory, AudioDeviceDataObserver* legacy_observer) { - rtc::scoped_refptr audio_device( - new rtc::RefCountedObject(audio_layer, task_queue_factory, - legacy_observer, nullptr)); + auto audio_device = rtc::make_ref_counted( + audio_layer, task_queue_factory, legacy_observer, nullptr); if (!audio_device->IsValid()) { return nullptr; diff --git a/modules/audio_device/audio_device_impl.cc b/modules/audio_device/audio_device_impl.cc index b410654a14..84460ff83f 100644 --- a/modules/audio_device/audio_device_impl.cc +++ b/modules/audio_device/audio_device_impl.cc @@ -73,7 +73,7 @@ namespace webrtc { rtc::scoped_refptr AudioDeviceModule::Create( AudioLayer audio_layer, TaskQueueFactory* task_queue_factory) { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; return AudioDeviceModule::CreateForTest(audio_layer, task_queue_factory); } @@ -81,7 +81,7 @@ rtc::scoped_refptr AudioDeviceModule::Create( rtc::scoped_refptr AudioDeviceModule::CreateForTest( AudioLayer audio_layer, TaskQueueFactory* task_queue_factory) { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; // The "AudioDeviceModule::kWindowsCoreAudio2" audio layer has its own // dedicated factory method which should be used instead. @@ -92,38 +92,37 @@ rtc::scoped_refptr AudioDeviceModule::CreateForTest( } // Create the generic reference counted (platform independent) implementation. - rtc::scoped_refptr audioDevice( - new rtc::RefCountedObject(audio_layer, - task_queue_factory)); + auto audio_device = rtc::make_ref_counted( + audio_layer, task_queue_factory); // Ensure that the current platform is supported. 
- if (audioDevice->CheckPlatform() == -1) { + if (audio_device->CheckPlatform() == -1) { return nullptr; } // Create the platform-dependent implementation. - if (audioDevice->CreatePlatformSpecificObjects() == -1) { + if (audio_device->CreatePlatformSpecificObjects() == -1) { return nullptr; } // Ensure that the generic audio buffer can communicate with the platform // specific parts. - if (audioDevice->AttachAudioBuffer() == -1) { + if (audio_device->AttachAudioBuffer() == -1) { return nullptr; } - return audioDevice; + return audio_device; } AudioDeviceModuleImpl::AudioDeviceModuleImpl( AudioLayer audio_layer, TaskQueueFactory* task_queue_factory) : audio_layer_(audio_layer), audio_device_buffer_(task_queue_factory) { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; } int32_t AudioDeviceModuleImpl::CheckPlatform() { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; // Ensure that the current platform is supported PlatformType platform(kPlatformNotSupported); #if defined(_WIN32) diff --git a/modules/audio_device/audio_device_unittest.cc b/modules/audio_device/audio_device_unittest.cc index 709b191b9f..b0af9521c6 100644 --- a/modules/audio_device/audio_device_unittest.cc +++ b/modules/audio_device/audio_device_unittest.cc @@ -19,6 +19,7 @@ #include "absl/types/optional.h" #include "api/array_view.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/task_queue/default_task_queue_factory.h" #include "api/task_queue/task_queue_factory.h" #include "modules/audio_device/audio_device_impl.h" @@ -31,7 +32,6 @@ #include "rtc_base/race_checker.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/time_utils.h" #include "test/gmock.h" #include "test/gtest.h" @@ -317,8 +317,8 @@ class LatencyAudioStream : public AudioStream { Mutex lock_; rtc::RaceChecker race_checker_; - rtc::ThreadChecker read_thread_checker_; - rtc::ThreadChecker write_thread_checker_; + SequenceChecker read_thread_checker_; + SequenceChecker write_thread_checker_; absl::optional pulse_time_ RTC_GUARDED_BY(lock_); std::vector latencies_ RTC_GUARDED_BY(race_checker_); diff --git a/modules/audio_device/dummy/file_audio_device.cc b/modules/audio_device/dummy/file_audio_device.cc index c68e7bba1a..e345a16c44 100644 --- a/modules/audio_device/dummy/file_audio_device.cc +++ b/modules/audio_device/dummy/file_audio_device.cc @@ -216,10 +216,13 @@ int32_t FileAudioDevice::StartPlayout() { } } - _ptrThreadPlay.reset(new rtc::PlatformThread( - PlayThreadFunc, this, "webrtc_audio_module_play_thread", - rtc::kRealtimePriority)); - _ptrThreadPlay->Start(); + _ptrThreadPlay = rtc::PlatformThread::SpawnJoinable( + [this] { + while (PlayThreadProcess()) { + } + }, + "webrtc_audio_module_play_thread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); RTC_LOG(LS_INFO) << "Started playout capture to output file: " << _outputFilename; @@ -233,10 +236,8 @@ int32_t FileAudioDevice::StopPlayout() { } // stop playout thread first - if (_ptrThreadPlay) { - _ptrThreadPlay->Stop(); - _ptrThreadPlay.reset(); - } + if (!_ptrThreadPlay.empty()) + _ptrThreadPlay.Finalize(); MutexLock lock(&mutex_); @@ -276,11 +277,13 @@ int32_t FileAudioDevice::StartRecording() { } } - _ptrThreadRec.reset(new rtc::PlatformThread( - RecThreadFunc, this, "webrtc_audio_module_capture_thread", - rtc::kRealtimePriority)); - - _ptrThreadRec->Start(); + _ptrThreadRec = rtc::PlatformThread::SpawnJoinable( + 
[this] { + while (RecThreadProcess()) { + } + }, + "webrtc_audio_module_capture_thread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); RTC_LOG(LS_INFO) << "Started recording from input file: " << _inputFilename; @@ -293,10 +296,8 @@ int32_t FileAudioDevice::StopRecording() { _recording = false; } - if (_ptrThreadRec) { - _ptrThreadRec->Stop(); - _ptrThreadRec.reset(); - } + if (!_ptrThreadRec.empty()) + _ptrThreadRec.Finalize(); MutexLock lock(&mutex_); _recordingFramesLeft = 0; @@ -439,18 +440,6 @@ void FileAudioDevice::AttachAudioBuffer(AudioDeviceBuffer* audioBuffer) { _ptrAudioBuffer->SetPlayoutChannels(0); } -void FileAudioDevice::PlayThreadFunc(void* pThis) { - FileAudioDevice* device = static_cast(pThis); - while (device->PlayThreadProcess()) { - } -} - -void FileAudioDevice::RecThreadFunc(void* pThis) { - FileAudioDevice* device = static_cast(pThis); - while (device->RecThreadProcess()) { - } -} - bool FileAudioDevice::PlayThreadProcess() { if (!_playing) { return false; diff --git a/modules/audio_device/dummy/file_audio_device.h b/modules/audio_device/dummy/file_audio_device.h index ecb3f2f533..f4a6b76586 100644 --- a/modules/audio_device/dummy/file_audio_device.h +++ b/modules/audio_device/dummy/file_audio_device.h @@ -17,14 +17,11 @@ #include #include "modules/audio_device/audio_device_generic.h" +#include "rtc_base/platform_thread.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/system/file_wrapper.h" #include "rtc_base/time_utils.h" -namespace rtc { -class PlatformThread; -} // namespace rtc - namespace webrtc { // This is a fake audio device which plays audio from a file as its microphone @@ -145,9 +142,8 @@ class FileAudioDevice : public AudioDeviceGeneric { size_t _recordingFramesIn10MS; size_t _playoutFramesIn10MS; - // TODO(pbos): Make plain members instead of pointers and stop resetting them. - std::unique_ptr _ptrThreadRec; - std::unique_ptr _ptrThreadPlay; + rtc::PlatformThread _ptrThreadRec; + rtc::PlatformThread _ptrThreadPlay; bool _playing; bool _recording; diff --git a/modules/audio_device/g3doc/audio_device_module.md b/modules/audio_device/g3doc/audio_device_module.md new file mode 100644 index 0000000000..3aa1a59d08 --- /dev/null +++ b/modules/audio_device/g3doc/audio_device_module.md @@ -0,0 +1,171 @@ +# Audio Device Module (ADM) + + + + +## Overview + +The ADM is responsible for driving input (microphone) and output (speaker) audio +in WebRTC and the API is defined in [audio_device.h][19]. + +Main functions of the ADM are: + +* Initialization and termination of native audio libraries. +* Registration of an [AudioTransport object][16] which handles audio callbacks + for audio in both directions. +* Device enumeration and selection (only for Linux, Windows and Mac OSX). +* Start/Stop physical audio streams: + * Recording audio from the selected microphone, and + * playing out audio on the selected speaker. +* Level control of the active audio streams. +* Control of built-in audio effects (Audio Echo Cancelation (AEC), Audio Gain + Control (AGC) and Noise Suppression (NS)) for Android and iOS. + +ADM implementations reside at two different locations in the WebRTC repository: +`/modules/audio_device/` and `/sdk/`. The latest implementations for [iOS][20] +and [Android][21] can be found under `/sdk/`. `/modules/audio_device/` contains +older versions for mobile platforms and also implementations for desktop +platforms such as [Linux][22], [Windows][23] and [Mac OSX][24]. 
This document focuses on the parts in
+`/modules/audio_device/`, but implementation-specific details such as threading
+models are omitted to keep the descriptions as simple as possible.
+
+By default, the ADM in WebRTC is created in [`WebRtcVoiceEngine::Init`][1], but
+an external implementation can also be injected using
+[`rtc::CreatePeerConnectionFactory`][25]. An example of where an external ADM is
+injected can be found in [PeerConnectionInterfaceTest][26], where a so-called
+[fake ADM][29] is utilized to avoid a hardware dependency in a gtest. Clients
+can also inject their own ADMs in situations where functionality is needed that
+is not provided by the default implementations.
+
+## Background
+
+This section contains a historical background of the ADM API.
+
+The ADM interface is old and has undergone many changes over the years. It used
+to be much more granular, but it still contains more than 50 methods and is
+implemented on several different hardware platforms.
+
+Some APIs are not implemented on all platforms, and functionality can be spread
+out differently between the methods.
+
+The most up-to-date implementations of the ADM interface are for [iOS][27] and
+for [Android][28].
+
+The desktop versions have not been updated to comply with the latest
+[C++ style guide](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md),
+and more work is also needed to improve the performance and stability of these
+versions.
+
+## WebRtcVoiceEngine
+
+[`WebRtcVoiceEngine`][2] does not utilize all methods of the ADM, but it still
+serves as the best example of its architecture and how to use it. For a more
+detailed view of all methods in the ADM interface, see [ADM unit tests][3].
+
+Assuming that an external ADM implementation is not injected, a default - or
+internal - ADM is created in [`WebRtcVoiceEngine::Init`][1] using
+[`AudioDeviceModule::Create`][4].
+
+Basic initialization is done using a utility method called
+[`adm_helpers::Init`][5], which calls fundamental ADM APIs such as:
+
+* [`AudioDeviceModule::Init`][6] - initializes the native audio parts required
+  for each platform.
+* [`AudioDeviceModule::SetPlayoutDevice`][7] - specifies which speaker to use
+  for playing out audio using an `index` retrieved by the corresponding
+  enumeration method [`AudioDeviceModule::PlayoutDeviceName`][8].
+* [`AudioDeviceModule::SetRecordingDevice`][9] - specifies which microphone to
+  use for recording audio using an `index` retrieved by the corresponding
+  enumeration method [`AudioDeviceModule::RecordingDeviceName`][10].
+* [`AudioDeviceModule::InitSpeaker`][11] - sets up the parts of the ADM needed
+  to use the selected output device.
+* [`AudioDeviceModule::InitMicrophone`][12] - sets up the parts of the ADM
+  needed to use the selected input device.
+* [`AudioDeviceModule::SetStereoPlayout`][13] - enables playout in stereo if
+  the selected audio device supports it.
+* [`AudioDeviceModule::SetStereoRecording`][14] - enables recording in stereo
+  if the selected audio device supports it.
+
+[`WebRtcVoiceEngine::Init`][1] also calls
+[`AudioDeviceModule::RegisterAudioCallback`][15] to register an existing
+[AudioTransport][16] implementation which handles audio callbacks in both
+directions and therefore serves as the bridge between the native ADM and the
+upper WebRTC layers. A usage sketch of this initialization sequence follows
+below.
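+
+For illustration only, here is a minimal sketch of that sequence written
+against the public ADM interface. It is not taken from `adm_helpers.cc` or
+`webrtc_voice_engine.cc`; the function name, the device index `0` and the
+`transport` pointer are placeholders, and error handling is omitted:
+
+```
+#include <memory>
+
+#include "api/scoped_refptr.h"
+#include "api/task_queue/default_task_queue_factory.h"
+#include "modules/audio_device/include/audio_device.h"
+
+// Sketch: roughly mirrors the adm_helpers::Init sequence described above.
+void InitializeAdmSketch(webrtc::AudioTransport* transport) {
+  std::unique_ptr<webrtc::TaskQueueFactory> task_queue_factory =
+      webrtc::CreateDefaultTaskQueueFactory();
+  rtc::scoped_refptr<webrtc::AudioDeviceModule> adm =
+      webrtc::AudioDeviceModule::Create(
+          webrtc::AudioDeviceModule::kPlatformDefaultAudio,
+          task_queue_factory.get());
+  adm->Init();                 // Initialize the native audio layer.
+  adm->SetPlayoutDevice(0);    // Index retrieved via PlayoutDeviceName().
+  adm->SetRecordingDevice(0);  // Index retrieved via RecordingDeviceName().
+  adm->InitSpeaker();
+  adm->InitMicrophone();
+  adm->SetStereoPlayout(true);
+  adm->SetStereoRecording(true);
+  // Register the bridge between the ADM and the upper WebRTC layers.
+  adm->RegisterAudioCallback(transport);
+}
+```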
+
+Recorded audio samples are delivered from the ADM to the `WebRtcVoiceEngine`
+(which owns the `AudioTransport` object) via
+[`AudioTransport::RecordedDataIsAvailable`][17]:
+
+```
+int32_t RecordedDataIsAvailable(const void* audioSamples, size_t nSamples, size_t nBytesPerSample,
+                                size_t nChannels, uint32_t samplesPerSec, uint32_t totalDelayMS,
+                                int32_t clockDrift, uint32_t currentMicLevel, bool keyPressed,
+                                uint32_t& newMicLevel)
+```
+
+Decoded audio samples ready to be played out are delivered by the
+`WebRtcVoiceEngine` to the ADM via [`AudioTransport::NeedMorePlayData`][18]:
+
+```
+int32_t NeedMorePlayData(size_t nSamples, size_t nBytesPerSample, size_t nChannels, int32_t samplesPerSec,
+                         void* audioSamples, size_t& nSamplesOut,
+                         int64_t* elapsed_time_ms, int64_t* ntp_time_ms)
+```
+
+Audio samples are 16-bit [linear PCM](https://wiki.multimedia.cx/index.php/PCM)
+using regular interleaving of channels within each sample.
+
+`WebRtcVoiceEngine` also owns an [`AudioState`][30] member, and this class is
+used as a helper to start and stop audio to and from the ADM. To initialize and
+start recording, it calls:
+
+* [`AudioDeviceModule::InitRecording`][31]
+* [`AudioDeviceModule::StartRecording`][32]
+
+and to initialize and start playout:
+
+* [`AudioDeviceModule::InitPlayout`][33]
+* [`AudioDeviceModule::StartPlayout`][34]
+
+Finally, the corresponding stop methods [`AudioDeviceModule::StopRecording`][35]
+and [`AudioDeviceModule::StopPlayout`][36] are called, followed by
+[`AudioDeviceModule::Terminate`][37].
+
+[1]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/media/engine/webrtc_voice_engine.cc;l=314;drc=f7b1b95f11c74cb5369fdd528b73c70a50f2e206
+[2]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/media/engine/webrtc_voice_engine.h;l=48;drc=d15a575ec3528c252419149d35977e55269d8a41
+[3]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/audio_device_unittest.cc;l=1;drc=d15a575ec3528c252419149d35977e55269d8a41
+[4]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=46;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e
+[5]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/media/engine/adm_helpers.h;drc=2222a80e79ae1ef5cb9510ec51d3868be75f47a2
+[6]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=62;drc=9438fb3fff97c803d1ead34c0e4f223db168526f
+[7]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=77;drc=9438fb3fff97c803d1ead34c0e4f223db168526f
+[8]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=69;drc=9438fb3fff97c803d1ead34c0e4f223db168526f
+[9]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=79;drc=9438fb3fff97c803d1ead34c0e4f223db168526f
+[10]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=72;drc=9438fb3fff97c803d1ead34c0e4f223db168526f
+[11]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=99;drc=9438fb3fff97c803d1ead34c0e4f223db168526f
+[12]: 
https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=101;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[13]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=130;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[14]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=133;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[15]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=59;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[16]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device_defines.h;l=34;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[17]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device_defines.h;l=36;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[18]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device_defines.h;l=48;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[19]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738es +[20]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/sdk/objc/native/api/audio_device_module.h;drc=76443eafa9375374d9f1d23da2b913f2acac6ac2 +[21]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/sdk/android/src/jni/audio_device/audio_device_module.h;drc=bbeb10925eb106eeed6143ccf571bc438ec22ce1 +[22]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/linux/;drc=d15a575ec3528c252419149d35977e55269d8a41 +[23]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/win/;drc=d15a575ec3528c252419149d35977e55269d8a41 +[24]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/mac/;drc=3b68aa346a5d3483c3448852d19d91723846825c +[25]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/create_peerconnection_factory.h;l=45;drc=09ceed2165137c4bea4e02e8d3db31970d0bf273 +[26]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/pc/peer_connection_interface_unittest.cc;l=692;drc=2efb8a5ec61b1b87475d046c03d20244f53b14b6 +[27]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/sdk/objc/native/api/audio_device_module.h;drc=76443eafa9375374d9f1d23da2b913f2acac6ac2 +[28]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/sdk/android/src/jni/audio_device/audio_device_module.h;drc=bbeb10925eb106eeed6143ccf571bc438ec22ce1 +[29]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/pc/test/fake_audio_capture_module.h;l=42;drc=d15a575ec3528c252419149d35977e55269d8a41 +[30]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/audio/audio_state.h;drc=d15a575ec3528c252419149d35977e55269d8a41 +[31]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=87;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e +[32]: 
https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=94;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e +[33]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=84;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e +[34]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=91;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e +[35]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=95;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e +[36]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=92;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e +[37]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_device/include/audio_device.h;l=63;drc=eb8c4ca608486add9800f6bfb7a8ba3cf23e738e diff --git a/modules/audio_device/include/audio_device_data_observer.h b/modules/audio_device/include/audio_device_data_observer.h index e1c2035d67..b59cafcb5d 100644 --- a/modules/audio_device/include/audio_device_data_observer.h +++ b/modules/audio_device/include/audio_device_data_observer.h @@ -14,6 +14,7 @@ #include #include +#include "absl/base/attributes.h" #include "api/scoped_refptr.h" #include "api/task_queue/task_queue_factory.h" #include "modules/audio_device/include/audio_device.h" @@ -48,7 +49,7 @@ rtc::scoped_refptr CreateAudioDeviceWithDataObserver( // Creates an ADMWrapper around an ADM instance that registers // the provided AudioDeviceDataObserver. -RTC_DEPRECATED +ABSL_DEPRECATED("") rtc::scoped_refptr CreateAudioDeviceWithDataObserver( rtc::scoped_refptr impl, AudioDeviceDataObserver* observer); @@ -60,7 +61,7 @@ rtc::scoped_refptr CreateAudioDeviceWithDataObserver( std::unique_ptr observer); // Creates an ADM instance with AudioDeviceDataObserver registered. 
-RTC_DEPRECATED +ABSL_DEPRECATED("") rtc::scoped_refptr CreateAudioDeviceWithDataObserver( const AudioDeviceModule::AudioLayer audio_layer, TaskQueueFactory* task_queue_factory, diff --git a/modules/audio_device/include/audio_device_defines.h b/modules/audio_device/include/audio_device_defines.h index d5d4d7372e..01129a47a9 100644 --- a/modules/audio_device/include/audio_device_defines.h +++ b/modules/audio_device/include/audio_device_defines.h @@ -16,7 +16,6 @@ #include #include "rtc_base/checks.h" -#include "rtc_base/deprecation.h" #include "rtc_base/strings/string_builder.h" namespace webrtc { diff --git a/modules/audio_device/include/mock_audio_device.h b/modules/audio_device/include/mock_audio_device.h index 0ca19de156..8483aa3da8 100644 --- a/modules/audio_device/include/mock_audio_device.h +++ b/modules/audio_device/include/mock_audio_device.h @@ -23,11 +23,10 @@ namespace test { class MockAudioDeviceModule : public AudioDeviceModule { public: static rtc::scoped_refptr CreateNice() { - return new rtc::RefCountedObject< - ::testing::NiceMock>(); + return rtc::make_ref_counted<::testing::NiceMock>(); } static rtc::scoped_refptr CreateStrict() { - return new rtc::RefCountedObject< + return rtc::make_ref_counted< ::testing::StrictMock>(); } diff --git a/modules/audio_device/include/test_audio_device.cc b/modules/audio_device/include/test_audio_device.cc index 46bf216540..8351e8a405 100644 --- a/modules/audio_device/include/test_audio_device.cc +++ b/modules/audio_device/include/test_audio_device.cc @@ -447,7 +447,7 @@ rtc::scoped_refptr TestAudioDeviceModule::Create( std::unique_ptr capturer, std::unique_ptr renderer, float speed) { - return new rtc::RefCountedObject( + return rtc::make_ref_counted( task_queue_factory, std::move(capturer), std::move(renderer), speed); } diff --git a/modules/audio_device/linux/audio_device_alsa_linux.cc b/modules/audio_device/linux/audio_device_alsa_linux.cc index 84d05e0f6c..60e01e1239 100644 --- a/modules/audio_device/linux/audio_device_alsa_linux.cc +++ b/modules/audio_device/linux/audio_device_alsa_linux.cc @@ -98,7 +98,7 @@ AudioDeviceLinuxALSA::AudioDeviceLinuxALSA() _recordingDelay(0), _playoutDelay(0) { memset(_oldKeyState, 0, sizeof(_oldKeyState)); - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; } // ---------------------------------------------------------------------------- @@ -106,7 +106,7 @@ AudioDeviceLinuxALSA::AudioDeviceLinuxALSA() // ---------------------------------------------------------------------------- AudioDeviceLinuxALSA::~AudioDeviceLinuxALSA() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; Terminate(); @@ -178,26 +178,13 @@ int32_t AudioDeviceLinuxALSA::Terminate() { _mixerManager.Close(); // RECORDING - if (_ptrThreadRec) { - rtc::PlatformThread* tmpThread = _ptrThreadRec.release(); - mutex_.Unlock(); - - tmpThread->Stop(); - delete tmpThread; - - mutex_.Lock(); - } + mutex_.Unlock(); + _ptrThreadRec.Finalize(); // PLAYOUT - if (_ptrThreadPlay) { - rtc::PlatformThread* tmpThread = _ptrThreadPlay.release(); - mutex_.Unlock(); - - tmpThread->Stop(); - delete tmpThread; + _ptrThreadPlay.Finalize(); + mutex_.Lock(); - mutex_.Lock(); - } #if defined(WEBRTC_USE_X11) if (_XDisplay) { XCloseDisplay(_XDisplay); @@ -1040,11 +1027,13 @@ int32_t AudioDeviceLinuxALSA::StartRecording() { return -1; } // RECORDING - _ptrThreadRec.reset(new rtc::PlatformThread( - RecThreadFunc, this, "webrtc_audio_module_capture_thread", - 
rtc::kRealtimePriority)); - - _ptrThreadRec->Start(); + _ptrThreadRec = rtc::PlatformThread::SpawnJoinable( + [this] { + while (RecThreadProcess()) { + } + }, + "webrtc_audio_module_capture_thread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); errVal = LATE(snd_pcm_prepare)(_handleRecord); if (errVal < 0) { @@ -1088,10 +1077,7 @@ int32_t AudioDeviceLinuxALSA::StopRecordingLocked() { _recIsInitialized = false; _recording = false; - if (_ptrThreadRec) { - _ptrThreadRec->Stop(); - _ptrThreadRec.reset(); - } + _ptrThreadRec.Finalize(); _recordingFramesLeft = 0; if (_recordingBuffer) { @@ -1158,10 +1144,13 @@ int32_t AudioDeviceLinuxALSA::StartPlayout() { } // PLAYOUT - _ptrThreadPlay.reset(new rtc::PlatformThread( - PlayThreadFunc, this, "webrtc_audio_module_play_thread", - rtc::kRealtimePriority)); - _ptrThreadPlay->Start(); + _ptrThreadPlay = rtc::PlatformThread::SpawnJoinable( + [this] { + while (PlayThreadProcess()) { + } + }, + "webrtc_audio_module_play_thread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); int errVal = LATE(snd_pcm_prepare)(_handlePlayout); if (errVal < 0) { @@ -1191,10 +1180,7 @@ int32_t AudioDeviceLinuxALSA::StopPlayoutLocked() { _playing = false; // stop playout thread first - if (_ptrThreadPlay) { - _ptrThreadPlay->Stop(); - _ptrThreadPlay.reset(); - } + _ptrThreadPlay.Finalize(); _playoutFramesLeft = 0; delete[] _playoutBuffer; @@ -1469,18 +1455,6 @@ int32_t AudioDeviceLinuxALSA::ErrorRecovery(int32_t error, // Thread Methods // ============================================================================ -void AudioDeviceLinuxALSA::PlayThreadFunc(void* pThis) { - AudioDeviceLinuxALSA* device = static_cast(pThis); - while (device->PlayThreadProcess()) { - } -} - -void AudioDeviceLinuxALSA::RecThreadFunc(void* pThis) { - AudioDeviceLinuxALSA* device = static_cast(pThis); - while (device->RecThreadProcess()) { - } -} - bool AudioDeviceLinuxALSA::PlayThreadProcess() { if (!_playing) return false; @@ -1516,7 +1490,7 @@ bool AudioDeviceLinuxALSA::PlayThreadProcess() { Lock(); _playoutFramesLeft = _ptrAudioBuffer->GetPlayoutData(_playoutBuffer); - assert(_playoutFramesLeft == _playoutFramesIn10MS); + RTC_DCHECK_EQ(_playoutFramesLeft, _playoutFramesIn10MS); } if (static_cast(avail_frames) > _playoutFramesLeft) @@ -1535,7 +1509,7 @@ bool AudioDeviceLinuxALSA::PlayThreadProcess() { UnLock(); return true; } else { - assert(frames == avail_frames); + RTC_DCHECK_EQ(frames, avail_frames); _playoutFramesLeft -= frames; } @@ -1585,7 +1559,7 @@ bool AudioDeviceLinuxALSA::RecThreadProcess() { UnLock(); return true; } else if (frames > 0) { - assert(frames == avail_frames); + RTC_DCHECK_EQ(frames, avail_frames); int left_size = LATE(snd_pcm_frames_to_bytes)(_handleRecord, _recordingFramesLeft); diff --git a/modules/audio_device/linux/audio_device_alsa_linux.h b/modules/audio_device/linux/audio_device_alsa_linux.h index 410afcf42c..1f4a231640 100644 --- a/modules/audio_device/linux/audio_device_alsa_linux.h +++ b/modules/audio_device/linux/audio_device_alsa_linux.h @@ -155,10 +155,8 @@ class AudioDeviceLinuxALSA : public AudioDeviceGeneric { Mutex mutex_; - // TODO(pbos): Make plain members and start/stop instead of resetting these - // pointers. A thread can be reused. 
- std::unique_ptr _ptrThreadRec; - std::unique_ptr _ptrThreadPlay; + rtc::PlatformThread _ptrThreadRec; + rtc::PlatformThread _ptrThreadPlay; AudioMixerManagerLinuxALSA _mixerManager; diff --git a/modules/audio_device/linux/audio_device_pulse_linux.cc b/modules/audio_device/linux/audio_device_pulse_linux.cc index 9a7d1d0ca3..7742420fc2 100644 --- a/modules/audio_device/linux/audio_device_pulse_linux.cc +++ b/modules/audio_device/linux/audio_device_pulse_linux.cc @@ -15,6 +15,7 @@ #include "modules/audio_device/linux/latebindingsymboltable_linux.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/platform_thread.h" WebRTCPulseSymbolTable* GetPulseSymbolTable() { static WebRTCPulseSymbolTable* pulse_symbol_table = @@ -78,7 +79,7 @@ AudioDeviceLinuxPulse::AudioDeviceLinuxPulse() _playStream(NULL), _recStreamFlags(0), _playStreamFlags(0) { - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; memset(_paServerVersion, 0, sizeof(_paServerVersion)); memset(&_playBufferAttr, 0, sizeof(_playBufferAttr)); @@ -87,7 +88,7 @@ AudioDeviceLinuxPulse::AudioDeviceLinuxPulse() } AudioDeviceLinuxPulse::~AudioDeviceLinuxPulse() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; RTC_DCHECK(thread_checker_.IsCurrent()); Terminate(); @@ -158,18 +159,22 @@ AudioDeviceGeneric::InitStatus AudioDeviceLinuxPulse::Init() { #endif // RECORDING - _ptrThreadRec.reset(new rtc::PlatformThread(RecThreadFunc, this, - "webrtc_audio_module_rec_thread", - rtc::kRealtimePriority)); - - _ptrThreadRec->Start(); + const auto attributes = + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime); + _ptrThreadRec = rtc::PlatformThread::SpawnJoinable( + [this] { + while (RecThreadProcess()) { + } + }, + "webrtc_audio_module_rec_thread", attributes); // PLAYOUT - _ptrThreadPlay.reset(new rtc::PlatformThread( - PlayThreadFunc, this, "webrtc_audio_module_play_thread", - rtc::kRealtimePriority)); - _ptrThreadPlay->Start(); - + _ptrThreadPlay = rtc::PlatformThread::SpawnJoinable( + [this] { + while (PlayThreadProcess()) { + } + }, + "webrtc_audio_module_play_thread", attributes); _initialized = true; return InitStatus::OK; @@ -187,22 +192,12 @@ int32_t AudioDeviceLinuxPulse::Terminate() { _mixerManager.Close(); // RECORDING - if (_ptrThreadRec) { - rtc::PlatformThread* tmpThread = _ptrThreadRec.release(); - - _timeEventRec.Set(); - tmpThread->Stop(); - delete tmpThread; - } + _timeEventRec.Set(); + _ptrThreadRec.Finalize(); // PLAYOUT - if (_ptrThreadPlay) { - rtc::PlatformThread* tmpThread = _ptrThreadPlay.release(); - - _timeEventPlay.Set(); - tmpThread->Stop(); - delete tmpThread; - } + _timeEventPlay.Set(); + _ptrThreadPlay.Finalize(); // Terminate PulseAudio if (TerminatePulseAudio() < 0) { @@ -1981,18 +1976,6 @@ int32_t AudioDeviceLinuxPulse::ProcessRecordedData(int8_t* bufferData, return 0; } -void AudioDeviceLinuxPulse::PlayThreadFunc(void* pThis) { - AudioDeviceLinuxPulse* device = static_cast(pThis); - while (device->PlayThreadProcess()) { - } -} - -void AudioDeviceLinuxPulse::RecThreadFunc(void* pThis) { - AudioDeviceLinuxPulse* device = static_cast(pThis); - while (device->RecThreadProcess()) { - } -} - bool AudioDeviceLinuxPulse::PlayThreadProcess() { if (!_timeEventPlay.Wait(1000)) { return true; diff --git a/modules/audio_device/linux/audio_device_pulse_linux.h b/modules/audio_device/linux/audio_device_pulse_linux.h index 03aa16bb85..0cf89ef011 100644 --- 
a/modules/audio_device/linux/audio_device_pulse_linux.h +++ b/modules/audio_device/linux/audio_device_pulse_linux.h @@ -13,6 +13,7 @@ #include +#include "api/sequence_checker.h" #include "modules/audio_device/audio_device_buffer.h" #include "modules/audio_device/audio_device_generic.h" #include "modules/audio_device/include/audio_device.h" @@ -23,7 +24,6 @@ #include "rtc_base/platform_thread.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #if defined(WEBRTC_USE_X11) #include @@ -268,9 +268,8 @@ class AudioDeviceLinuxPulse : public AudioDeviceGeneric { rtc::Event _recStartEvent; rtc::Event _playStartEvent; - // TODO(pbos): Remove unique_ptr and use directly without resetting. - std::unique_ptr _ptrThreadPlay; - std::unique_ptr _ptrThreadRec; + rtc::PlatformThread _ptrThreadPlay; + rtc::PlatformThread _ptrThreadRec; AudioMixerManagerLinuxPulse _mixerManager; @@ -284,10 +283,10 @@ class AudioDeviceLinuxPulse : public AudioDeviceGeneric { uint8_t _playChannels; // Stores thread ID in constructor. - // We can then use ThreadChecker::IsCurrent() to ensure that + // We can then use RTC_DCHECK_RUN_ON(&worker_thread_checker_) to ensure that // other methods are called from the same thread. // Currently only does RTC_DCHECK(thread_checker_.IsCurrent()). - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; bool _initialized; bool _recording; diff --git a/modules/audio_device/linux/audio_mixer_manager_alsa_linux.cc b/modules/audio_device/linux/audio_mixer_manager_alsa_linux.cc index fb9d874ef3..e7e7033173 100644 --- a/modules/audio_device/linux/audio_mixer_manager_alsa_linux.cc +++ b/modules/audio_device/linux/audio_mixer_manager_alsa_linux.cc @@ -27,14 +27,14 @@ AudioMixerManagerLinuxALSA::AudioMixerManagerLinuxALSA() _inputMixerHandle(NULL), _outputMixerElement(NULL), _inputMixerElement(NULL) { - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; memset(_outputMixerStr, 0, kAdmMaxDeviceNameSize); memset(_inputMixerStr, 0, kAdmMaxDeviceNameSize); } AudioMixerManagerLinuxALSA::~AudioMixerManagerLinuxALSA() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; Close(); } @@ -43,7 +43,7 @@ AudioMixerManagerLinuxALSA::~AudioMixerManagerLinuxALSA() { // ============================================================================ int32_t AudioMixerManagerLinuxALSA::Close() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; MutexLock lock(&mutex_); @@ -59,7 +59,7 @@ int32_t AudioMixerManagerLinuxALSA::CloseSpeaker() { } int32_t AudioMixerManagerLinuxALSA::CloseSpeakerLocked() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; int errVal = 0; @@ -94,7 +94,7 @@ int32_t AudioMixerManagerLinuxALSA::CloseMicrophone() { } int32_t AudioMixerManagerLinuxALSA::CloseMicrophoneLocked() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; int errVal = 0; @@ -289,13 +289,13 @@ int32_t AudioMixerManagerLinuxALSA::OpenMicrophone(char* deviceName) { } bool AudioMixerManagerLinuxALSA::SpeakerIsInitialized() const { - RTC_LOG(LS_INFO) << __FUNCTION__; + RTC_DLOG(LS_INFO) << __FUNCTION__; return (_outputMixerHandle != NULL); } bool AudioMixerManagerLinuxALSA::MicrophoneIsInitialized() const { - RTC_LOG(LS_INFO) << __FUNCTION__; + RTC_DLOG(LS_INFO) << __FUNCTION__; return (_inputMixerHandle != NULL); } diff --git 
a/modules/audio_device/linux/audio_mixer_manager_pulse_linux.cc b/modules/audio_device/linux/audio_mixer_manager_pulse_linux.cc index c507e623b3..91beee3c87 100644 --- a/modules/audio_device/linux/audio_mixer_manager_pulse_linux.cc +++ b/modules/audio_device/linux/audio_mixer_manager_pulse_linux.cc @@ -54,12 +54,12 @@ AudioMixerManagerLinuxPulse::AudioMixerManagerLinuxPulse() _paSpeakerVolume(PA_VOLUME_NORM), _paChannels(0), _paObjectsSet(false) { - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; } AudioMixerManagerLinuxPulse::~AudioMixerManagerLinuxPulse() { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; Close(); } @@ -72,7 +72,7 @@ int32_t AudioMixerManagerLinuxPulse::SetPulseAudioObjects( pa_threaded_mainloop* mainloop, pa_context* context) { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; if (!mainloop || !context) { RTC_LOG(LS_ERROR) << "could not set PulseAudio objects for mixer"; @@ -90,7 +90,7 @@ int32_t AudioMixerManagerLinuxPulse::SetPulseAudioObjects( int32_t AudioMixerManagerLinuxPulse::Close() { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; CloseSpeaker(); CloseMicrophone(); @@ -104,7 +104,7 @@ int32_t AudioMixerManagerLinuxPulse::Close() { int32_t AudioMixerManagerLinuxPulse::CloseSpeaker() { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; // Reset the index to -1 _paOutputDeviceIndex = -1; @@ -115,7 +115,7 @@ int32_t AudioMixerManagerLinuxPulse::CloseSpeaker() { int32_t AudioMixerManagerLinuxPulse::CloseMicrophone() { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; // Reset the index to -1 _paInputDeviceIndex = -1; @@ -186,14 +186,14 @@ int32_t AudioMixerManagerLinuxPulse::OpenMicrophone(uint16_t deviceIndex) { bool AudioMixerManagerLinuxPulse::SpeakerIsInitialized() const { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_INFO) << __FUNCTION__; + RTC_DLOG(LS_INFO) << __FUNCTION__; return (_paOutputDeviceIndex != -1); } bool AudioMixerManagerLinuxPulse::MicrophoneIsInitialized() const { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_LOG(LS_INFO) << __FUNCTION__; + RTC_DLOG(LS_INFO) << __FUNCTION__; return (_paInputDeviceIndex != -1); } diff --git a/modules/audio_device/linux/audio_mixer_manager_pulse_linux.h b/modules/audio_device/linux/audio_mixer_manager_pulse_linux.h index f2f3e48c70..546440c4a6 100644 --- a/modules/audio_device/linux/audio_mixer_manager_pulse_linux.h +++ b/modules/audio_device/linux/audio_mixer_manager_pulse_linux.h @@ -14,7 +14,7 @@ #include #include -#include "rtc_base/thread_checker.h" +#include "api/sequence_checker.h" #ifndef UINT32_MAX #define UINT32_MAX ((uint32_t)-1) @@ -103,10 +103,10 @@ class AudioMixerManagerLinuxPulse { bool _paObjectsSet; // Stores thread ID in constructor. - // We can then use ThreadChecker::IsCurrent() to ensure that + // We can then use RTC_DCHECK_RUN_ON(&worker_thread_checker_) to ensure that // other methods are called from the same thread. // Currently only does RTC_DCHECK(thread_checker_.IsCurrent()). 
- rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; }; } // namespace webrtc diff --git a/modules/audio_device/linux/latebindingsymboltable_linux.h b/modules/audio_device/linux/latebindingsymboltable_linux.h index edb62aede8..6cfb659749 100644 --- a/modules/audio_device/linux/latebindingsymboltable_linux.h +++ b/modules/audio_device/linux/latebindingsymboltable_linux.h @@ -11,10 +11,10 @@ #ifndef AUDIO_DEVICE_LATEBINDINGSYMBOLTABLE_LINUX_H_ #define AUDIO_DEVICE_LATEBINDINGSYMBOLTABLE_LINUX_H_ -#include #include // for NULL #include +#include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" // This file provides macros for creating "symbol table" classes to simplify the @@ -59,7 +59,7 @@ class LateBindingSymbolTable { // We do not use this, but we offer it for theoretical convenience. static const char* GetSymbolName(int index) { - assert(index < NumSymbols()); + RTC_DCHECK_LT(index, NumSymbols()); return kSymbolNames[index]; } @@ -100,8 +100,8 @@ class LateBindingSymbolTable { // Retrieves the given symbol. NOTE: Recommended to use LATESYM_GET below // instead of this. void* GetSymbol(int index) const { - assert(IsLoaded()); - assert(index < NumSymbols()); + RTC_DCHECK(IsLoaded()); + RTC_DCHECK_LT(index, NumSymbols()); return symbols_[index]; } diff --git a/modules/audio_device/mac/audio_device_mac.cc b/modules/audio_device/mac/audio_device_mac.cc index 0c6e9f5dec..2088b017a0 100644 --- a/modules/audio_device/mac/audio_device_mac.cc +++ b/modules/audio_device/mac/audio_device_mac.cc @@ -150,7 +150,7 @@ AudioDeviceMac::AudioDeviceMac() _captureBufSizeSamples(0), _renderBufSizeSamples(0), prev_key_state_() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; memset(_renderConvertData, 0, sizeof(_renderConvertData)); memset(&_outStreamFormat, 0, sizeof(AudioStreamBasicDescription)); @@ -160,14 +160,14 @@ AudioDeviceMac::AudioDeviceMac() } AudioDeviceMac::~AudioDeviceMac() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; if (!_isShutDown) { Terminate(); } - RTC_DCHECK(!capture_worker_thread_.get()); - RTC_DCHECK(!render_worker_thread_.get()); + RTC_DCHECK(capture_worker_thread_.empty()); + RTC_DCHECK(render_worker_thread_.empty()); if (_paRenderBuffer) { delete _paRenderBuffer; @@ -1308,11 +1308,14 @@ int32_t AudioDeviceMac::StartRecording() { return -1; } - RTC_DCHECK(!capture_worker_thread_.get()); - capture_worker_thread_.reset(new rtc::PlatformThread( - RunCapture, this, "CaptureWorkerThread", rtc::kRealtimePriority)); - RTC_DCHECK(capture_worker_thread_.get()); - capture_worker_thread_->Start(); + RTC_DCHECK(capture_worker_thread_.empty()); + capture_worker_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (CaptureWorkerThread()) { + } + }, + "CaptureWorkerThread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); OSStatus err = noErr; if (_twoDevices) { @@ -1394,10 +1397,9 @@ int32_t AudioDeviceMac::StopRecording() { // Setting this signal will allow the worker thread to be stopped. 
AtomicSet32(&_captureDeviceIsAlive, 0); - if (capture_worker_thread_.get()) { + if (!capture_worker_thread_.empty()) { mutex_.Unlock(); - capture_worker_thread_->Stop(); - capture_worker_thread_.reset(); + capture_worker_thread_.Finalize(); mutex_.Lock(); } @@ -1443,10 +1445,14 @@ int32_t AudioDeviceMac::StartPlayout() { return 0; } - RTC_DCHECK(!render_worker_thread_.get()); - render_worker_thread_.reset(new rtc::PlatformThread( - RunRender, this, "RenderWorkerThread", rtc::kRealtimePriority)); - render_worker_thread_->Start(); + RTC_DCHECK(render_worker_thread_.empty()); + render_worker_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (RenderWorkerThread()) { + } + }, + "RenderWorkerThread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); if (_twoDevices || !_recording) { OSStatus err = noErr; @@ -1504,10 +1510,9 @@ int32_t AudioDeviceMac::StopPlayout() { // Setting this signal will allow the worker thread to be stopped. AtomicSet32(&_renderDeviceIsAlive, 0); - if (render_worker_thread_.get()) { + if (!render_worker_thread_.empty()) { mutex_.Unlock(); - render_worker_thread_->Stop(); - render_worker_thread_.reset(); + render_worker_thread_.Finalize(); mutex_.Lock(); } @@ -2369,12 +2374,6 @@ OSStatus AudioDeviceMac::implInConverterProc(UInt32* numberDataPackets, return 0; } -void AudioDeviceMac::RunRender(void* ptrThis) { - AudioDeviceMac* device = static_cast(ptrThis); - while (device->RenderWorkerThread()) { - } -} - bool AudioDeviceMac::RenderWorkerThread() { PaRingBufferSize numSamples = ENGINE_PLAY_BUF_SIZE_IN_SAMPLES * _outDesiredFormat.mChannelsPerFrame; @@ -2440,12 +2439,6 @@ bool AudioDeviceMac::RenderWorkerThread() { return true; } -void AudioDeviceMac::RunCapture(void* ptrThis) { - AudioDeviceMac* device = static_cast(ptrThis); - while (device->CaptureWorkerThread()) { - } -} - bool AudioDeviceMac::CaptureWorkerThread() { OSStatus err = noErr; UInt32 noRecSamples = diff --git a/modules/audio_device/mac/audio_device_mac.h b/modules/audio_device/mac/audio_device_mac.h index 985db9da52..f9504b64b5 100644 --- a/modules/audio_device/mac/audio_device_mac.h +++ b/modules/audio_device/mac/audio_device_mac.h @@ -21,15 +21,12 @@ #include "modules/audio_device/mac/audio_mixer_manager_mac.h" #include "rtc_base/event.h" #include "rtc_base/logging.h" +#include "rtc_base/platform_thread.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" struct PaUtilRingBuffer; -namespace rtc { -class PlatformThread; -} // namespace rtc - namespace webrtc { const uint32_t N_REC_SAMPLES_PER_SEC = 48000; @@ -271,13 +268,11 @@ class AudioDeviceMac : public AudioDeviceGeneric { rtc::Event _stopEventRec; rtc::Event _stopEvent; - // TODO(pbos): Replace with direct members, just start/stop, no need to - // recreate the thread. // Only valid/running between calls to StartRecording and StopRecording. - std::unique_ptr capture_worker_thread_; + rtc::PlatformThread capture_worker_thread_; // Only valid/running between calls to StartPlayout and StopPlayout. 
- std::unique_ptr render_worker_thread_; + rtc::PlatformThread render_worker_thread_; AudioMixerManagerMac _mixerManager; diff --git a/modules/audio_device/mac/audio_mixer_manager_mac.cc b/modules/audio_device/mac/audio_mixer_manager_mac.cc index 162f3f255d..942e7db3b3 100644 --- a/modules/audio_device/mac/audio_mixer_manager_mac.cc +++ b/modules/audio_device/mac/audio_mixer_manager_mac.cc @@ -46,11 +46,11 @@ AudioMixerManagerMac::AudioMixerManagerMac() _outputDeviceID(kAudioObjectUnknown), _noInputChannels(0), _noOutputChannels(0) { - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; } AudioMixerManagerMac::~AudioMixerManagerMac() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; Close(); } @@ -59,7 +59,7 @@ AudioMixerManagerMac::~AudioMixerManagerMac() { // ============================================================================ int32_t AudioMixerManagerMac::Close() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; MutexLock lock(&mutex_); @@ -75,7 +75,7 @@ int32_t AudioMixerManagerMac::CloseSpeaker() { } int32_t AudioMixerManagerMac::CloseSpeakerLocked() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; _outputDeviceID = kAudioObjectUnknown; _noOutputChannels = 0; @@ -89,7 +89,7 @@ int32_t AudioMixerManagerMac::CloseMicrophone() { } int32_t AudioMixerManagerMac::CloseMicrophoneLocked() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; _inputDeviceID = kAudioObjectUnknown; _noInputChannels = 0; @@ -196,13 +196,13 @@ int32_t AudioMixerManagerMac::OpenMicrophone(AudioDeviceID deviceID) { } bool AudioMixerManagerMac::SpeakerIsInitialized() const { - RTC_LOG(LS_INFO) << __FUNCTION__; + RTC_DLOG(LS_INFO) << __FUNCTION__; return (_outputDeviceID != kAudioObjectUnknown); } bool AudioMixerManagerMac::MicrophoneIsInitialized() const { - RTC_LOG(LS_INFO) << __FUNCTION__; + RTC_DLOG(LS_INFO) << __FUNCTION__; return (_inputDeviceID != kAudioObjectUnknown); } @@ -225,7 +225,7 @@ int32_t AudioMixerManagerMac::SetSpeakerVolume(uint32_t volume) { // volume range is 0.0 - 1.0, convert from 0 -255 const Float32 vol = (Float32)(volume / 255.0); - assert(vol <= 1.0 && vol >= 0.0); + RTC_DCHECK(vol <= 1.0 && vol >= 0.0); // Does the capture device have a master volume control? // If so, use it exclusively. @@ -311,7 +311,7 @@ int32_t AudioMixerManagerMac::SpeakerVolume(uint32_t& volume) const { return -1; } - assert(channels > 0); + RTC_DCHECK_GT(channels, 0); // vol 0.0 to 1.0 -> convert to 0 - 255 volume = static_cast(255 * vol / channels + 0.5); } @@ -522,7 +522,7 @@ int32_t AudioMixerManagerMac::SpeakerMute(bool& enabled) const { return -1; } - assert(channels > 0); + RTC_DCHECK_GT(channels, 0); // 1 means muted enabled = static_cast(muted); } @@ -690,7 +690,7 @@ int32_t AudioMixerManagerMac::MicrophoneMute(bool& enabled) const { return -1; } - assert(channels > 0); + RTC_DCHECK_GT(channels, 0); // 1 means muted enabled = static_cast(muted); } @@ -757,7 +757,7 @@ int32_t AudioMixerManagerMac::SetMicrophoneVolume(uint32_t volume) { // volume range is 0.0 - 1.0, convert from 0 - 255 const Float32 vol = (Float32)(volume / 255.0); - assert(vol <= 1.0 && vol >= 0.0); + RTC_DCHECK(vol <= 1.0 && vol >= 0.0); // Does the capture device have a master volume control? // If so, use it exclusively. 
@@ -843,7 +843,7 @@ int32_t AudioMixerManagerMac::MicrophoneVolume(uint32_t& volume) const { return -1; } - assert(channels > 0); + RTC_DCHECK_GT(channels, 0); // vol 0.0 to 1.0 -> convert to 0 - 255 volume = static_cast(255 * volFloat32 / channels + 0.5); } diff --git a/modules/audio_device/win/audio_device_core_win.cc b/modules/audio_device/win/audio_device_core_win.cc index 776a16cda4..a3723edb56 100644 --- a/modules/audio_device/win/audio_device_core_win.cc +++ b/modules/audio_device/win/audio_device_core_win.cc @@ -174,7 +174,7 @@ class MediaBufferImpl final : public IMediaBuffer { // ---------------------------------------------------------------------------- bool AudioDeviceWindowsCore::CoreAudioIsSupported() { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; bool MMDeviceIsAvailable(false); bool coreAudioIsSupported(false); @@ -281,7 +281,7 @@ bool AudioDeviceWindowsCore::CoreAudioIsSupported() { DWORD messageLength = ::FormatMessageW(dwFlags, 0, hr, dwLangID, errorText, MAXERRORLENGTH, NULL); - assert(messageLength <= MAXERRORLENGTH); + RTC_DCHECK_LE(messageLength, MAXERRORLENGTH); // Trims tailing white space (FormatMessage() leaves a trailing cr-lf.). for (; messageLength && ::isspace(errorText[messageLength - 1]); @@ -395,7 +395,7 @@ AudioDeviceWindowsCore::AudioDeviceWindowsCore() _outputDevice(AudioDeviceModule::kDefaultCommunicationDevice), _inputDeviceIndex(0), _outputDeviceIndex(0) { - RTC_LOG(LS_INFO) << __FUNCTION__ << " created"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " created"; RTC_DCHECK(_comInit.Succeeded()); // Try to load the Avrt DLL @@ -469,7 +469,7 @@ AudioDeviceWindowsCore::AudioDeviceWindowsCore() CoCreateInstance(__uuidof(MMDeviceEnumerator), NULL, CLSCTX_ALL, __uuidof(IMMDeviceEnumerator), reinterpret_cast(&_ptrEnumerator)); - assert(NULL != _ptrEnumerator); + RTC_DCHECK(_ptrEnumerator); // DMO initialization for built-in WASAPI AEC. 
{ @@ -492,7 +492,7 @@ AudioDeviceWindowsCore::AudioDeviceWindowsCore() // ---------------------------------------------------------------------------- AudioDeviceWindowsCore::~AudioDeviceWindowsCore() { - RTC_LOG(LS_INFO) << __FUNCTION__ << " destroyed"; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " destroyed"; Terminate(); @@ -1347,7 +1347,7 @@ int32_t AudioDeviceWindowsCore::MicrophoneVolume(uint32_t& volume) const { // ---------------------------------------------------------------------------- int32_t AudioDeviceWindowsCore::MaxMicrophoneVolume(uint32_t& maxVolume) const { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; if (!_microphoneIsInitialized) { return -1; @@ -1411,7 +1411,7 @@ int32_t AudioDeviceWindowsCore::SetPlayoutDevice(uint16_t index) { HRESULT hr(S_OK); - assert(_ptrRenderCollection != NULL); + RTC_DCHECK(_ptrRenderCollection); // Select an endpoint rendering device given the specified index SAFE_RELEASE(_ptrDeviceOut); @@ -1461,7 +1461,7 @@ int32_t AudioDeviceWindowsCore::SetPlayoutDevice( HRESULT hr(S_OK); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(_ptrEnumerator); // Select an endpoint rendering device given the specified role SAFE_RELEASE(_ptrDeviceOut); @@ -1677,7 +1677,7 @@ int32_t AudioDeviceWindowsCore::SetRecordingDevice(uint16_t index) { HRESULT hr(S_OK); - assert(_ptrCaptureCollection != NULL); + RTC_DCHECK(_ptrCaptureCollection); // Select an endpoint capture device given the specified index SAFE_RELEASE(_ptrDeviceIn); @@ -1727,7 +1727,7 @@ int32_t AudioDeviceWindowsCore::SetRecordingDevice( HRESULT hr(S_OK); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(_ptrEnumerator); // Select an endpoint capture device given the specified role SAFE_RELEASE(_ptrDeviceIn); @@ -2036,8 +2036,8 @@ int32_t AudioDeviceWindowsCore::InitPlayout() { // handles device initialization itself. // Reference: http://msdn.microsoft.com/en-us/library/ff819492(v=vs.85).aspx int32_t AudioDeviceWindowsCore::InitRecordingDMO() { - assert(_builtInAecEnabled); - assert(_dmo != NULL); + RTC_DCHECK(_builtInAecEnabled); + RTC_DCHECK(_dmo); if (SetDMOProperties() == -1) { return -1; @@ -2356,7 +2356,7 @@ int32_t AudioDeviceWindowsCore::StartRecording() { } } - assert(_hRecThread == NULL); + RTC_DCHECK(_hRecThread == NULL); _hRecThread = CreateThread(NULL, 0, lpStartAddress, this, 0, NULL); if (_hRecThread == NULL) { RTC_LOG(LS_ERROR) << "failed to create the recording thread"; @@ -2421,8 +2421,8 @@ int32_t AudioDeviceWindowsCore::StopRecording() { ResetEvent(_hShutdownCaptureEvent); // Must be manually reset. // Ensure that the thread has released these interfaces properly. - assert(err == -1 || _ptrClientIn == NULL); - assert(err == -1 || _ptrCaptureClient == NULL); + RTC_DCHECK(err == -1 || _ptrClientIn == NULL); + RTC_DCHECK(err == -1 || _ptrCaptureClient == NULL); _recIsInitialized = false; _recording = false; @@ -2433,7 +2433,7 @@ int32_t AudioDeviceWindowsCore::StopRecording() { _hRecThread = NULL; if (_builtInAecEnabled) { - assert(_dmo != NULL); + RTC_DCHECK(_dmo); // This is necessary. Otherwise the DMO can generate garbage render // audio even after rendering has stopped. HRESULT hr = _dmo->FreeStreamingResources(); @@ -2493,7 +2493,7 @@ int32_t AudioDeviceWindowsCore::StartPlayout() { MutexLock lockScoped(&mutex_); // Create thread which will drive the rendering. 
- assert(_hPlayThread == NULL); + RTC_DCHECK(_hPlayThread == NULL); _hPlayThread = CreateThread(NULL, 0, WSAPIRenderThread, this, 0, NULL); if (_hPlayThread == NULL) { RTC_LOG(LS_ERROR) << "failed to create the playout thread"; @@ -2954,7 +2954,7 @@ void AudioDeviceWindowsCore::RevertCaptureThreadPriority() { } DWORD AudioDeviceWindowsCore::DoCaptureThreadPollDMO() { - assert(_mediaBuffer != NULL); + RTC_DCHECK(_mediaBuffer); bool keepRecording = true; // Initialize COM as MTA in this thread. @@ -3010,7 +3010,7 @@ DWORD AudioDeviceWindowsCore::DoCaptureThreadPollDMO() { if (FAILED(hr)) { _TraceCOMError(hr); keepRecording = false; - assert(false); + RTC_NOTREACHED(); break; } @@ -3022,7 +3022,7 @@ DWORD AudioDeviceWindowsCore::DoCaptureThreadPollDMO() { if (FAILED(hr)) { _TraceCOMError(hr); keepRecording = false; - assert(false); + RTC_NOTREACHED(); break; } @@ -3031,8 +3031,8 @@ DWORD AudioDeviceWindowsCore::DoCaptureThreadPollDMO() { // TODO(andrew): verify that this is always satisfied. It might // be that ProcessOutput will try to return more than 10 ms if // we fail to call it frequently enough. - assert(kSamplesProduced == static_cast(_recBlockSize)); - assert(sizeof(BYTE) == sizeof(int8_t)); + RTC_DCHECK_EQ(kSamplesProduced, static_cast(_recBlockSize)); + RTC_DCHECK_EQ(sizeof(BYTE), sizeof(int8_t)); _ptrAudioBuffer->SetRecordedBuffer(reinterpret_cast(data), kSamplesProduced); _ptrAudioBuffer->SetVQEData(0, 0); @@ -3047,7 +3047,7 @@ DWORD AudioDeviceWindowsCore::DoCaptureThreadPollDMO() { if (FAILED(hr)) { _TraceCOMError(hr); keepRecording = false; - assert(false); + RTC_NOTREACHED(); break; } @@ -3228,7 +3228,7 @@ DWORD AudioDeviceWindowsCore::DoCaptureThread() { pData = NULL; } - assert(framesAvailable != 0); + RTC_DCHECK_NE(framesAvailable, 0); if (pData) { CopyMemory(&syncBuffer[syncBufIndex * _recAudioFrameSize], pData, @@ -3237,8 +3237,8 @@ DWORD AudioDeviceWindowsCore::DoCaptureThread() { ZeroMemory(&syncBuffer[syncBufIndex * _recAudioFrameSize], framesAvailable * _recAudioFrameSize); } - assert(syncBufferSize >= (syncBufIndex * _recAudioFrameSize) + - framesAvailable * _recAudioFrameSize); + RTC_DCHECK_GE(syncBufferSize, (syncBufIndex * _recAudioFrameSize) + + framesAvailable * _recAudioFrameSize); // Release the capture buffer // @@ -3377,7 +3377,7 @@ void AudioDeviceWindowsCore::_UnLock() RTC_NO_THREAD_SAFETY_ANALYSIS { int AudioDeviceWindowsCore::SetDMOProperties() { HRESULT hr = S_OK; - assert(_dmo != NULL); + RTC_DCHECK(_dmo); rtc::scoped_refptr ps; { @@ -3512,13 +3512,13 @@ int AudioDeviceWindowsCore::SetVtI4Property(IPropertyStore* ptrPS, // ---------------------------------------------------------------------------- int32_t AudioDeviceWindowsCore::_RefreshDeviceList(EDataFlow dir) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; IMMDeviceCollection* pCollection = NULL; - assert(dir == eRender || dir == eCapture); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(dir == eRender || dir == eCapture); + RTC_DCHECK(_ptrEnumerator); // Create a fresh list of devices using the specified direction hr = _ptrEnumerator->EnumAudioEndpoints(dir, DEVICE_STATE_ACTIVE, @@ -3548,12 +3548,12 @@ int32_t AudioDeviceWindowsCore::_RefreshDeviceList(EDataFlow dir) { // ---------------------------------------------------------------------------- int16_t AudioDeviceWindowsCore::_DeviceListCount(EDataFlow dir) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; UINT count = 0; - assert(eRender == dir || 
eCapture == dir); + RTC_DCHECK(eRender == dir || eCapture == dir); if (eRender == dir && NULL != _ptrRenderCollection) { hr = _ptrRenderCollection->GetCount(&count); @@ -3584,12 +3584,12 @@ int32_t AudioDeviceWindowsCore::_GetListDeviceName(EDataFlow dir, int index, LPWSTR szBuffer, int bufferLen) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; IMMDevice* pDevice = NULL; - assert(dir == eRender || dir == eCapture); + RTC_DCHECK(dir == eRender || dir == eCapture); if (eRender == dir && NULL != _ptrRenderCollection) { hr = _ptrRenderCollection->Item(index, &pDevice); @@ -3621,14 +3621,14 @@ int32_t AudioDeviceWindowsCore::_GetDefaultDeviceName(EDataFlow dir, ERole role, LPWSTR szBuffer, int bufferLen) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; IMMDevice* pDevice = NULL; - assert(dir == eRender || dir == eCapture); - assert(role == eConsole || role == eCommunications); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(dir == eRender || dir == eCapture); + RTC_DCHECK(role == eConsole || role == eCommunications); + RTC_DCHECK(_ptrEnumerator); hr = _ptrEnumerator->GetDefaultAudioEndpoint(dir, role, &pDevice); @@ -3658,12 +3658,12 @@ int32_t AudioDeviceWindowsCore::_GetListDeviceID(EDataFlow dir, int index, LPWSTR szBuffer, int bufferLen) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; IMMDevice* pDevice = NULL; - assert(dir == eRender || dir == eCapture); + RTC_DCHECK(dir == eRender || dir == eCapture); if (eRender == dir && NULL != _ptrRenderCollection) { hr = _ptrRenderCollection->Item(index, &pDevice); @@ -3695,14 +3695,14 @@ int32_t AudioDeviceWindowsCore::_GetDefaultDeviceID(EDataFlow dir, ERole role, LPWSTR szBuffer, int bufferLen) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; IMMDevice* pDevice = NULL; - assert(dir == eRender || dir == eCapture); - assert(role == eConsole || role == eCommunications); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(dir == eRender || dir == eCapture); + RTC_DCHECK(role == eConsole || role == eCommunications); + RTC_DCHECK(_ptrEnumerator); hr = _ptrEnumerator->GetDefaultAudioEndpoint(dir, role, &pDevice); @@ -3720,15 +3720,15 @@ int32_t AudioDeviceWindowsCore::_GetDefaultDeviceID(EDataFlow dir, int32_t AudioDeviceWindowsCore::_GetDefaultDeviceIndex(EDataFlow dir, ERole role, int* index) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr = S_OK; WCHAR szDefaultDeviceID[MAX_PATH] = {0}; WCHAR szDeviceID[MAX_PATH] = {0}; const size_t kDeviceIDLength = sizeof(szDeviceID) / sizeof(szDeviceID[0]); - assert(kDeviceIDLength == - sizeof(szDefaultDeviceID) / sizeof(szDefaultDeviceID[0])); + RTC_DCHECK_EQ(kDeviceIDLength, + sizeof(szDefaultDeviceID) / sizeof(szDefaultDeviceID[0])); if (_GetDefaultDeviceID(dir, role, szDefaultDeviceID, kDeviceIDLength) == -1) { @@ -3793,7 +3793,7 @@ int32_t AudioDeviceWindowsCore::_GetDefaultDeviceIndex(EDataFlow dir, int32_t AudioDeviceWindowsCore::_GetDeviceName(IMMDevice* pDevice, LPWSTR pszBuffer, int bufferLen) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; static const WCHAR szDefault[] = L""; @@ -3801,8 +3801,8 @@ int32_t AudioDeviceWindowsCore::_GetDeviceName(IMMDevice* pDevice, IPropertyStore* pProps = NULL; PROPVARIANT varName; - assert(pszBuffer != NULL); - assert(bufferLen > 0); + RTC_DCHECK(pszBuffer); + RTC_DCHECK_GT(bufferLen, 0); if 
(pDevice != NULL) { hr = pDevice->OpenPropertyStore(STGM_READ, &pProps); @@ -3860,15 +3860,15 @@ int32_t AudioDeviceWindowsCore::_GetDeviceName(IMMDevice* pDevice, int32_t AudioDeviceWindowsCore::_GetDeviceID(IMMDevice* pDevice, LPWSTR pszBuffer, int bufferLen) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; static const WCHAR szDefault[] = L""; HRESULT hr = E_FAIL; LPWSTR pwszID = NULL; - assert(pszBuffer != NULL); - assert(bufferLen > 0); + RTC_DCHECK(pszBuffer); + RTC_DCHECK_GT(bufferLen, 0); if (pDevice != NULL) { hr = pDevice->GetId(&pwszID); @@ -3893,11 +3893,11 @@ int32_t AudioDeviceWindowsCore::_GetDeviceID(IMMDevice* pDevice, int32_t AudioDeviceWindowsCore::_GetDefaultDevice(EDataFlow dir, ERole role, IMMDevice** ppDevice) { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; HRESULT hr(S_OK); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(_ptrEnumerator); hr = _ptrEnumerator->GetDefaultAudioEndpoint(dir, role, ppDevice); if (FAILED(hr)) { @@ -3917,7 +3917,7 @@ int32_t AudioDeviceWindowsCore::_GetListDevice(EDataFlow dir, IMMDevice** ppDevice) { HRESULT hr(S_OK); - assert(_ptrEnumerator != NULL); + RTC_DCHECK(_ptrEnumerator); IMMDeviceCollection* pCollection = NULL; @@ -3938,6 +3938,8 @@ int32_t AudioDeviceWindowsCore::_GetListDevice(EDataFlow dir, return -1; } + SAFE_RELEASE(pCollection); + return 0; } @@ -3947,9 +3949,9 @@ int32_t AudioDeviceWindowsCore::_GetListDevice(EDataFlow dir, int32_t AudioDeviceWindowsCore::_EnumerateEndpointDevicesAll( EDataFlow dataFlow) const { - RTC_LOG(LS_VERBOSE) << __FUNCTION__; + RTC_DLOG(LS_VERBOSE) << __FUNCTION__; - assert(_ptrEnumerator != NULL); + RTC_DCHECK(_ptrEnumerator); HRESULT hr = S_OK; IMMDeviceCollection* pCollection = NULL; @@ -4141,7 +4143,7 @@ void AudioDeviceWindowsCore::_TraceCOMError(HRESULT hr) const { DWORD messageLength = ::FormatMessageW(dwFlags, 0, hr, dwLangID, errorText, MAXERRORLENGTH, NULL); - assert(messageLength <= MAXERRORLENGTH); + RTC_DCHECK_LE(messageLength, MAXERRORLENGTH); // Trims tailing white space (FormatMessage() leaves a trailing cr-lf.). 
for (; messageLength && ::isspace(errorText[messageLength - 1]); diff --git a/modules/audio_device/win/audio_device_module_win.cc b/modules/audio_device/win/audio_device_module_win.cc index b77a24aadb..8cc4b7fc36 100644 --- a/modules/audio_device/win/audio_device_module_win.cc +++ b/modules/audio_device/win/audio_device_module_win.cc @@ -13,13 +13,13 @@ #include #include +#include "api/sequence_checker.h" #include "modules/audio_device/audio_device_buffer.h" #include "modules/audio_device/include/audio_device.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" #include "rtc_base/string_utils.h" -#include "rtc_base/thread_checker.h" namespace webrtc { namespace webrtc_win { @@ -95,12 +95,12 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { task_queue_factory_(task_queue_factory) { RTC_CHECK(input_); RTC_CHECK(output_); - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); } ~WindowsAudioDeviceModule() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); Terminate(); } @@ -110,7 +110,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t ActiveAudioLayer( AudioDeviceModule::AudioLayer* audioLayer) const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); // TODO(henrika): it might be possible to remove this unique signature. *audioLayer = AudioDeviceModule::kWindowsCoreAudio2; @@ -118,14 +118,14 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } int32_t RegisterAudioCallback(AudioTransport* audioCallback) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK(audio_device_buffer_); RTC_DCHECK_RUN_ON(&thread_checker_); return audio_device_buffer_->RegisterAudioCallback(audioCallback); } int32_t Init() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); RETURN_IF_INPUT_RESTARTS(0); @@ -153,7 +153,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } int32_t Terminate() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); RETURN_IF_INPUT_RESTARTS(0); @@ -172,14 +172,14 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } int16_t PlayoutDevices() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); return output_->NumDevices(); } int16_t RecordingDevices() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_INPUT_RESTARTS(0); return input_->NumDevices(); @@ -188,7 +188,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t PlayoutDeviceName(uint16_t index, char name[kAdmMaxDeviceNameSize], char guid[kAdmMaxGuidSize]) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); std::string name_str, guid_str; @@ -205,7 +205,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t RecordingDeviceName(uint16_t index, char name[kAdmMaxDeviceNameSize], char guid[kAdmMaxGuidSize]) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; 
RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_INPUT_RESTARTS(0); std::string name_str, guid_str; @@ -221,7 +221,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } int32_t SetPlayoutDevice(uint16_t index) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); return output_->SetDevice(index); @@ -229,33 +229,33 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t SetPlayoutDevice( AudioDeviceModule::WindowsDeviceType device) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); return output_->SetDevice(device); } int32_t SetRecordingDevice(uint16_t index) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); return input_->SetDevice(index); } int32_t SetRecordingDevice( AudioDeviceModule::WindowsDeviceType device) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); return input_->SetDevice(device); } int32_t PlayoutIsAvailable(bool* available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *available = true; return 0; } int32_t InitPlayout() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); RETURN_IF_OUTPUT_IS_INITIALIZED(0); @@ -263,21 +263,21 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } bool PlayoutIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(true); return output_->PlayoutIsInitialized(); } int32_t RecordingIsAvailable(bool* available) override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *available = true; return 0; } int32_t InitRecording() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_INPUT_RESTARTS(0); RETURN_IF_INPUT_IS_INITIALIZED(0); @@ -285,14 +285,14 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } bool RecordingIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_INPUT_RESTARTS(true); return input_->RecordingIsInitialized(); } int32_t StartPlayout() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(0); RETURN_IF_OUTPUT_IS_ACTIVE(0); @@ -300,21 +300,21 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { } int32_t StopPlayout() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(-1); return output_->StopPlayout(); } bool Playing() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_OUTPUT_RESTARTS(true); return output_->Playing(); } int32_t StartRecording() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_INPUT_RESTARTS(0); RETURN_IF_INPUT_IS_ACTIVE(0); @@ -322,41 +322,41 @@ class WindowsAudioDeviceModule : public 
AudioDeviceModuleForTest { } int32_t StopRecording() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RETURN_IF_INPUT_RESTARTS(-1); return input_->StopRecording(); } bool Recording() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RETURN_IF_INPUT_RESTARTS(true); return input_->Recording(); } int32_t InitSpeaker() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RTC_DLOG(LS_WARNING) << "This method has no effect"; return initialized_ ? 0 : -1; } bool SpeakerIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RTC_DLOG(LS_WARNING) << "This method has no effect"; return initialized_; } int32_t InitMicrophone() override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RTC_DLOG(LS_WARNING) << "This method has no effect"; return initialized_ ? 0 : -1; } bool MicrophoneIsInitialized() const override { - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); RTC_DLOG(LS_WARNING) << "This method has no effect"; return initialized_; @@ -364,7 +364,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t SpeakerVolumeIsAvailable(bool* available) override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *available = false; return 0; @@ -377,7 +377,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t MicrophoneVolumeIsAvailable(bool* available) override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *available = false; return 0; @@ -398,7 +398,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t StereoPlayoutIsAvailable(bool* available) const override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *available = true; return 0; @@ -406,14 +406,14 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t SetStereoPlayout(bool enable) override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); return 0; } int32_t StereoPlayout(bool* enabled) const override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *enabled = true; return 0; @@ -421,7 +421,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t StereoRecordingIsAvailable(bool* available) const override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *available = true; return 0; @@ -429,14 +429,14 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { int32_t SetStereoRecording(bool enable) override { // TODO(henrika): improve support. - RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); return 0; } int32_t StereoRecording(bool* enabled) const override { // TODO(henrika): improve support. 
- RTC_LOG(INFO) << __FUNCTION__; + RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK_RUN_ON(&thread_checker_); *enabled = true; return 0; @@ -487,7 +487,7 @@ class WindowsAudioDeviceModule : public AudioDeviceModuleForTest { private: // Ensures that the class is used on the same thread as it is constructed // and destroyed on. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; // Implements the AudioInput interface and deals with audio capturing parts. const std::unique_ptr input_; @@ -513,8 +513,8 @@ CreateWindowsCoreAudioAudioDeviceModuleFromInputAndOutput( std::unique_ptr audio_input, std::unique_ptr audio_output, TaskQueueFactory* task_queue_factory) { - RTC_LOG(INFO) << __FUNCTION__; - return new rtc::RefCountedObject( + RTC_DLOG(INFO) << __FUNCTION__; + return rtc::make_ref_counted( std::move(audio_input), std::move(audio_output), task_queue_factory); } diff --git a/modules/audio_device/win/core_audio_base_win.cc b/modules/audio_device/win/core_audio_base_win.cc index 672e482478..7d93fcb14a 100644 --- a/modules/audio_device/win/core_audio_base_win.cc +++ b/modules/audio_device/win/core_audio_base_win.cc @@ -9,15 +9,16 @@ */ #include "modules/audio_device/win/core_audio_base_win.h" -#include "modules/audio_device/audio_device_buffer.h" #include #include +#include "modules/audio_device/audio_device_buffer.h" #include "rtc_base/arraysize.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/platform_thread.h" #include "rtc_base/time_utils.h" #include "rtc_base/win/scoped_com_initializer.h" #include "rtc_base/win/windows_version.h" @@ -118,11 +119,6 @@ const char* SessionDisconnectReasonToString( } } -void Run(void* obj) { - RTC_DCHECK(obj); - reinterpret_cast(obj)->ThreadRun(); -} - // Returns true if the selected audio device supports low latency, i.e, if it // is possible to initialize the engine using periods less than the default // period (10ms). @@ -552,24 +548,19 @@ bool CoreAudioBase::Start() { // Audio thread should be alive during internal restart since the restart // callback is triggered on that thread and it also makes the restart // sequence less complex. - RTC_DCHECK(audio_thread_); + RTC_DCHECK(!audio_thread_.empty()); } // Start an audio thread but only if one does not already exist (which is the // case during restart). - if (!audio_thread_) { - audio_thread_ = std::make_unique( - Run, this, IsInput() ? "wasapi_capture_thread" : "wasapi_render_thread", - rtc::kRealtimePriority); - RTC_DCHECK(audio_thread_); - audio_thread_->Start(); - if (!audio_thread_->IsRunning()) { - StopThread(); - RTC_LOG(LS_ERROR) << "Failed to start audio thread"; - return false; - } - RTC_DLOG(INFO) << "Started thread with name: " << audio_thread_->name() - << " and id: " << audio_thread_->GetThreadRef(); + if (audio_thread_.empty()) { + const absl::string_view name = + IsInput() ? "wasapi_capture_thread" : "wasapi_render_thread"; + audio_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { ThreadRun(); }, name, + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime)); + RTC_DLOG(INFO) << "Started thread with name: " << name + << " and handle: " << *audio_thread_.GetHandle(); } // Start streaming data between the endpoint buffer and the audio engine. 
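As context for the change above: the new `rtc::PlatformThread` usage reduces to roughly the pattern sketched below. This is an illustrative sketch only, assuming the API surface visible in this patch (`SpawnJoinable()`, `empty()`, `GetHandle()`, `Finalize()`); the `AudioThreadOwner` class and its loop body are hypothetical stand-ins for the corresponding members of `CoreAudioBase`.

```cpp
#include "rtc_base/logging.h"
#include "rtc_base/platform_thread.h"

class AudioThreadOwner {
 public:
  void StartThread() {
    if (!audio_thread_.empty())
      return;  // Already running, e.g. during an internal restart.
    // SpawnJoinable() starts the thread immediately; there is no separate
    // Start()/IsRunning() step as with the old PlatformThread constructor.
    audio_thread_ = rtc::PlatformThread::SpawnJoinable(
        [this] { ThreadRun(); }, "wasapi_render_thread",
        rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime));
    RTC_DLOG(INFO) << "Started thread, handle: " << *audio_thread_.GetHandle();
  }

  void StopThread() {
    if (audio_thread_.empty())
      return;
    // Signal the loop in ThreadRun() to exit (application-specific), then
    // join and release the handle.
    audio_thread_.Finalize();
  }

 private:
  void ThreadRun() { /* Capture/render loop; exits when signaled to stop. */ }

  rtc::PlatformThread audio_thread_;
};
```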
@@ -696,14 +687,11 @@ bool CoreAudioBase::Restart() { void CoreAudioBase::StopThread() { RTC_DLOG(INFO) << __FUNCTION__; RTC_DCHECK(!IsRestarting()); - if (audio_thread_) { - if (audio_thread_->IsRunning()) { - RTC_DLOG(INFO) << "Sets stop_event..."; - SetEvent(stop_event_.Get()); - RTC_DLOG(INFO) << "PlatformThread::Stop..."; - audio_thread_->Stop(); - } - audio_thread_.reset(); + if (!audio_thread_.empty()) { + RTC_DLOG(INFO) << "Sets stop_event..."; + SetEvent(stop_event_.Get()); + RTC_DLOG(INFO) << "PlatformThread::Finalize..."; + audio_thread_.Finalize(); // Ensure that we don't quit the main thread loop immediately next // time Start() is called. @@ -716,7 +704,7 @@ bool CoreAudioBase::HandleRestartEvent() { RTC_DLOG(INFO) << __FUNCTION__ << "[" << DirectionToString(direction()) << "]"; RTC_DCHECK_RUN_ON(&thread_checker_audio_); - RTC_DCHECK(audio_thread_); + RTC_DCHECK(!audio_thread_.empty()); RTC_DCHECK(IsRestarting()); // Let each client (input and/or output) take care of its own restart // sequence since each side might need unique actions. diff --git a/modules/audio_device/win/core_audio_base_win.h b/modules/audio_device/win/core_audio_base_win.h index 87f306f541..afcc6a684d 100644 --- a/modules/audio_device/win/core_audio_base_win.h +++ b/modules/audio_device/win/core_audio_base_win.h @@ -17,9 +17,9 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "modules/audio_device/win/core_audio_utility_win.h" #include "rtc_base/platform_thread.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -128,8 +128,8 @@ class CoreAudioBase : public IAudioSessionEvents { // level here. In addition, calls to Init(), Start() and Stop() are not // included to allow for support of internal restart (where these methods are // called on the audio thread). - rtc::ThreadChecker thread_checker_; - rtc::ThreadChecker thread_checker_audio_; + SequenceChecker thread_checker_; + SequenceChecker thread_checker_audio_; AudioDeviceBuffer* audio_device_buffer_ = nullptr; bool initialized_ = false; WAVEFORMATEXTENSIBLE format_ = {}; @@ -158,7 +158,7 @@ class CoreAudioBase : public IAudioSessionEvents { // Set when restart process starts and cleared when restart stops // successfully. Accessed atomically. std::atomic is_restarting_; - std::unique_ptr audio_thread_; + rtc::PlatformThread audio_thread_; Microsoft::WRL::ComPtr audio_session_control_; void StopThread(); diff --git a/modules/audio_device/win/core_audio_output_win.cc b/modules/audio_device/win/core_audio_output_win.cc index 299eefe18c..36ec703c3a 100644 --- a/modules/audio_device/win/core_audio_output_win.cc +++ b/modules/audio_device/win/core_audio_output_win.cc @@ -14,7 +14,6 @@ #include "modules/audio_device/audio_device_buffer.h" #include "modules/audio_device/fine_audio_buffer.h" -#include "rtc_base/bind.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/time_utils.h" diff --git a/modules/audio_device/win/core_audio_utility_win.cc b/modules/audio_device/win/core_audio_utility_win.cc index f17ee99143..289abe9d26 100644 --- a/modules/audio_device/win/core_audio_utility_win.cc +++ b/modules/audio_device/win/core_audio_utility_win.cc @@ -323,7 +323,7 @@ ComPtr CreateDeviceInternal(const std::string& device_id, // Verify that the audio endpoint device is active, i.e., that the audio // adapter that connects to the endpoint device is present and enabled. 
- if (SUCCEEDED(error.Error()) && !audio_endpoint_device.Get() && + if (SUCCEEDED(error.Error()) && audio_endpoint_device.Get() && !IsDeviceActive(audio_endpoint_device.Get())) { RTC_LOG(LS_WARNING) << "Selected endpoint device is not active"; audio_endpoint_device.Reset(); diff --git a/modules/audio_mixer/BUILD.gn b/modules/audio_mixer/BUILD.gn index 7ce35ffeb3..d51be4af04 100644 --- a/modules/audio_mixer/BUILD.gn +++ b/modules/audio_mixer/BUILD.gn @@ -39,6 +39,7 @@ rtc_library("audio_mixer_impl") { deps = [ ":audio_frame_manipulator", "../../api:array_view", + "../../api:rtp_packet_info", "../../api:scoped_refptr", "../../api/audio:audio_frame_api", "../../api/audio:audio_mixer_api", @@ -46,6 +47,7 @@ rtc_library("audio_mixer_impl") { "../../common_audio", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", + "../../rtc_base:safe_conversions", "../../rtc_base/synchronization:mutex", "../../system_wrappers", "../../system_wrappers:metrics", @@ -104,13 +106,15 @@ if (rtc_include_tests) { "audio_mixer_impl_unittest.cc", "frame_combiner_unittest.cc", ] - + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] deps = [ ":audio_frame_manipulator", ":audio_mixer_impl", ":audio_mixer_test_utils", "../../api:array_view", + "../../api:rtp_packet_info", "../../api/audio:audio_mixer_api", + "../../api/units:timestamp", "../../audio/utility:audio_frame_operations", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", @@ -119,17 +123,19 @@ if (rtc_include_tests) { ] } - rtc_executable("audio_mixer_test") { - testonly = true - sources = [ "audio_mixer_test.cc" ] - - deps = [ - ":audio_mixer_impl", - "../../api/audio:audio_mixer_api", - "../../common_audio", - "../../rtc_base:stringutils", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] + if (!build_with_chromium) { + rtc_executable("audio_mixer_test") { + testonly = true + sources = [ "audio_mixer_test.cc" ] + + deps = [ + ":audio_mixer_impl", + "../../api/audio:audio_mixer_api", + "../../common_audio", + "../../rtc_base:stringutils", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } } } diff --git a/modules/audio_mixer/OWNERS b/modules/audio_mixer/OWNERS index b33d599697..5edc304ab3 100644 --- a/modules/audio_mixer/OWNERS +++ b/modules/audio_mixer/OWNERS @@ -1,2 +1,2 @@ -aleloi@webrtc.org +alessiob@webrtc.org henrik.lundin@webrtc.org diff --git a/modules/audio_mixer/audio_mixer_impl.cc b/modules/audio_mixer/audio_mixer_impl.cc index 04a8bcf723..8cebc38779 100644 --- a/modules/audio_mixer/audio_mixer_impl.cc +++ b/modules/audio_mixer/audio_mixer_impl.cc @@ -126,30 +126,33 @@ struct AudioMixerImpl::HelperContainers { AudioMixerImpl::AudioMixerImpl( std::unique_ptr output_rate_calculator, - bool use_limiter) - : output_rate_calculator_(std::move(output_rate_calculator)), + bool use_limiter, + int max_sources_to_mix) + : max_sources_to_mix_(max_sources_to_mix), + output_rate_calculator_(std::move(output_rate_calculator)), audio_source_list_(), helper_containers_(std::make_unique()), frame_combiner_(use_limiter) { - const int kTypicalMaxNumberOfMixedStreams = 3; - audio_source_list_.reserve(kTypicalMaxNumberOfMixedStreams); - helper_containers_->resize(kTypicalMaxNumberOfMixedStreams); + RTC_CHECK_GE(max_sources_to_mix, 1) << "At least one source must be mixed"; + audio_source_list_.reserve(max_sources_to_mix); + helper_containers_->resize(max_sources_to_mix); } AudioMixerImpl::~AudioMixerImpl() {} -rtc::scoped_refptr 
AudioMixerImpl::Create() { +rtc::scoped_refptr AudioMixerImpl::Create( + int max_sources_to_mix) { return Create(std::unique_ptr( new DefaultOutputRateCalculator()), - true); + /*use_limiter=*/true, max_sources_to_mix); } rtc::scoped_refptr AudioMixerImpl::Create( std::unique_ptr output_rate_calculator, - bool use_limiter) { - return rtc::scoped_refptr( - new rtc::RefCountedObject( - std::move(output_rate_calculator), use_limiter)); + bool use_limiter, + int max_sources_to_mix) { + return rtc::make_ref_counted( + std::move(output_rate_calculator), use_limiter, max_sources_to_mix); } void AudioMixerImpl::Mix(size_t number_of_channels, @@ -219,7 +222,7 @@ rtc::ArrayView AudioMixerImpl::GetAudioFromSources( std::sort(audio_source_mixing_data_view.begin(), audio_source_mixing_data_view.end(), ShouldMixBefore); - int max_audio_frame_counter = kMaximumAmountOfMixedAudioSources; + int max_audio_frame_counter = max_sources_to_mix_; int ramp_list_lengh = 0; int audio_to_mix_count = 0; // Go through list in order and put unmuted frames in result list. diff --git a/modules/audio_mixer/audio_mixer_impl.h b/modules/audio_mixer/audio_mixer_impl.h index 0a13082725..737fcbdc43 100644 --- a/modules/audio_mixer/audio_mixer_impl.h +++ b/modules/audio_mixer/audio_mixer_impl.h @@ -35,13 +35,16 @@ class AudioMixerImpl : public AudioMixer { // AudioProcessing only accepts 10 ms frames. static const int kFrameDurationInMs = 10; - enum : int { kMaximumAmountOfMixedAudioSources = 3 }; - static rtc::scoped_refptr Create(); + static const int kDefaultNumberOfMixedAudioSources = 3; + + static rtc::scoped_refptr Create( + int max_sources_to_mix = kDefaultNumberOfMixedAudioSources); static rtc::scoped_refptr Create( std::unique_ptr output_rate_calculator, - bool use_limiter); + bool use_limiter, + int max_sources_to_mix = kDefaultNumberOfMixedAudioSources); ~AudioMixerImpl() override; @@ -60,7 +63,8 @@ class AudioMixerImpl : public AudioMixer { protected: AudioMixerImpl(std::unique_ptr output_rate_calculator, - bool use_limiter); + bool use_limiter, + int max_sources_to_mix); private: struct HelperContainers; @@ -76,6 +80,8 @@ class AudioMixerImpl : public AudioMixer { // checks that mixing is done sequentially. mutable Mutex mutex_; + const int max_sources_to_mix_; + std::unique_ptr output_rate_calculator_; // List of all audio sources. 
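As context for the `max_sources_to_mix` parameter introduced above, a hedged usage sketch follows. It assumes the `AudioMixerImpl::Create()` overloads and the `AudioMixer` interface shown in this patch; the free function and the source pointers are hypothetical.

```cpp
#include "api/audio/audio_frame.h"
#include "api/audio/audio_mixer.h"
#include "modules/audio_mixer/audio_mixer_impl.h"

// Creates a mixer that mixes at most two sources per 10 ms Mix() call.
void MixWithCustomCap(webrtc::AudioMixer::Source* source_a,
                      webrtc::AudioMixer::Source* source_b,
                      webrtc::AudioMixer::Source* source_c) {
  // The default cap is kDefaultNumberOfMixedAudioSources (3); Create()
  // RTC_CHECKs that the requested value is at least 1.
  auto mixer = webrtc::AudioMixerImpl::Create(/*max_sources_to_mix=*/2);
  mixer->AddSource(source_a);
  mixer->AddSource(source_b);
  mixer->AddSource(source_c);

  webrtc::AudioFrame mixed;
  // Only the two highest-ranked sources (by VAD activity and energy) end up
  // in the mix; the remaining source is left out of this call.
  mixer->Mix(/*number_of_channels=*/1, &mixed);
}
```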
diff --git a/modules/audio_mixer/audio_mixer_impl_unittest.cc b/modules/audio_mixer/audio_mixer_impl_unittest.cc index 18a4384b66..61aa74e0a1 100644 --- a/modules/audio_mixer/audio_mixer_impl_unittest.cc +++ b/modules/audio_mixer/audio_mixer_impl_unittest.cc @@ -12,14 +12,19 @@ #include +#include #include #include #include #include +#include +#include "absl/types/optional.h" #include "api/audio/audio_mixer.h" +#include "api/rtp_packet_info.h" +#include "api/rtp_packet_infos.h" +#include "api/units/timestamp.h" #include "modules/audio_mixer/default_output_rate_calculator.h" -#include "rtc_base/bind.h" #include "rtc_base/checks.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/task_queue_for_test.h" @@ -30,6 +35,7 @@ using ::testing::_; using ::testing::Exactly; using ::testing::Invoke; using ::testing::Return; +using ::testing::UnorderedElementsAre; namespace webrtc { @@ -88,6 +94,10 @@ class MockMixerAudioSource : public ::testing::NiceMock { fake_audio_frame_info_ = audio_frame_info; } + void set_packet_infos(const RtpPacketInfos& packet_infos) { + packet_infos_ = packet_infos; + } + private: AudioFrameInfo FakeAudioFrameWithInfo(int sample_rate_hz, AudioFrame* audio_frame) { @@ -95,11 +105,13 @@ class MockMixerAudioSource : public ::testing::NiceMock { audio_frame->sample_rate_hz_ = sample_rate_hz; audio_frame->samples_per_channel_ = rtc::CheckedDivExact(sample_rate_hz, 100); + audio_frame->packet_infos_ = packet_infos_; return fake_info(); } AudioFrame fake_frame_; AudioFrameInfo fake_audio_frame_info_; + RtpPacketInfos packet_infos_; }; class CustomRateCalculator : public OutputRateCalculator { @@ -161,7 +173,7 @@ void MixMonoAtGivenNativeRate(int native_sample_rate, TEST(AudioMixer, LargestEnergyVadActiveMixed) { constexpr int kAudioSources = - AudioMixerImpl::kMaximumAmountOfMixedAudioSources + 3; + AudioMixerImpl::kDefaultNumberOfMixedAudioSources + 3; const auto mixer = AudioMixerImpl::Create(); @@ -192,7 +204,7 @@ TEST(AudioMixer, LargestEnergyVadActiveMixed) { mixer->GetAudioSourceMixabilityStatusForTest(&participants[i]); if (i == kAudioSources - 1 || i < kAudioSources - 1 - - AudioMixerImpl::kMaximumAmountOfMixedAudioSources) { + AudioMixerImpl::kDefaultNumberOfMixedAudioSources) { EXPECT_FALSE(is_mixed) << "Mixing status of AudioSource #" << i << " wrong."; } else { @@ -323,7 +335,7 @@ TEST(AudioMixer, ParticipantNumberOfChannels) { // another participant with higher energy is added. 
TEST(AudioMixer, RampedOutSourcesShouldNotBeMarkedMixed) { constexpr int kAudioSources = - AudioMixerImpl::kMaximumAmountOfMixedAudioSources + 1; + AudioMixerImpl::kDefaultNumberOfMixedAudioSources + 1; const auto mixer = AudioMixerImpl::Create(); MockMixerAudioSource participants[kAudioSources]; @@ -400,7 +412,7 @@ TEST(AudioMixer, ConstructFromOtherThread) { TEST(AudioMixer, MutedShouldMixAfterUnmuted) { constexpr int kAudioSources = - AudioMixerImpl::kMaximumAmountOfMixedAudioSources + 1; + AudioMixerImpl::kDefaultNumberOfMixedAudioSources + 1; std::vector frames(kAudioSources); for (auto& frame : frames) { @@ -418,7 +430,7 @@ TEST(AudioMixer, MutedShouldMixAfterUnmuted) { TEST(AudioMixer, PassiveShouldMixAfterNormal) { constexpr int kAudioSources = - AudioMixerImpl::kMaximumAmountOfMixedAudioSources + 1; + AudioMixerImpl::kDefaultNumberOfMixedAudioSources + 1; std::vector frames(kAudioSources); for (auto& frame : frames) { @@ -436,7 +448,7 @@ TEST(AudioMixer, PassiveShouldMixAfterNormal) { TEST(AudioMixer, ActiveShouldMixBeforeLoud) { constexpr int kAudioSources = - AudioMixerImpl::kMaximumAmountOfMixedAudioSources + 1; + AudioMixerImpl::kDefaultNumberOfMixedAudioSources + 1; std::vector frames(kAudioSources); for (auto& frame : frames) { @@ -455,9 +467,52 @@ TEST(AudioMixer, ActiveShouldMixBeforeLoud) { MixAndCompare(frames, frame_info, expected_status); } +TEST(AudioMixer, ShouldMixUpToSpecifiedNumberOfSourcesToMix) { + constexpr int kAudioSources = 5; + constexpr int kSourcesToMix = 2; + + std::vector frames(kAudioSources); + for (auto& frame : frames) { + ResetFrame(&frame); + } + + std::vector frame_info( + kAudioSources, AudioMixer::Source::AudioFrameInfo::kNormal); + // Set up to kSourceToMix sources with kVadActive so that they're mixed. + const std::vector kVadActivities = { + AudioFrame::kVadUnknown, AudioFrame::kVadPassive, AudioFrame::kVadPassive, + AudioFrame::kVadActive, AudioFrame::kVadActive}; + // Populate VAD and frame for all sources. 
+ for (int i = 0; i < kAudioSources; i++) { + frames[i].vad_activity_ = kVadActivities[i]; + } + + std::vector participants(kAudioSources); + for (int i = 0; i < kAudioSources; ++i) { + participants[i].fake_frame()->CopyFrom(frames[i]); + participants[i].set_fake_info(frame_info[i]); + } + + const auto mixer = AudioMixerImpl::Create(kSourcesToMix); + for (int i = 0; i < kAudioSources; ++i) { + EXPECT_TRUE(mixer->AddSource(&participants[i])); + EXPECT_CALL(participants[i], GetAudioFrameWithInfo(kDefaultSampleRateHz, _)) + .Times(Exactly(1)); + } + + mixer->Mix(1, &frame_for_mixing); + + std::vector expected_status = {false, false, false, true, true}; + for (int i = 0; i < kAudioSources; ++i) { + EXPECT_EQ(expected_status[i], + mixer->GetAudioSourceMixabilityStatusForTest(&participants[i])) + << "Wrong mix status for source #" << i << " is wrong"; + } +} + TEST(AudioMixer, UnmutedShouldMixBeforeLoud) { constexpr int kAudioSources = - AudioMixerImpl::kMaximumAmountOfMixedAudioSources + 1; + AudioMixerImpl::kDefaultNumberOfMixedAudioSources + 1; std::vector frames(kAudioSources); for (auto& frame : frames) { @@ -596,6 +651,100 @@ TEST(AudioMixer, MultipleChannelsManyParticipants) { } } +TEST(AudioMixer, ShouldIncludeRtpPacketInfoFromAllMixedSources) { + const uint32_t kSsrc0 = 10; + const uint32_t kSsrc1 = 11; + const uint32_t kSsrc2 = 12; + const uint32_t kCsrc0 = 20; + const uint32_t kCsrc1 = 21; + const uint32_t kCsrc2 = 22; + const uint32_t kCsrc3 = 23; + const int kAudioLevel0 = 10; + const int kAudioLevel1 = 40; + const absl::optional kAudioLevel2 = absl::nullopt; + const uint32_t kRtpTimestamp0 = 300; + const uint32_t kRtpTimestamp1 = 400; + const Timestamp kReceiveTime0 = Timestamp::Millis(10); + const Timestamp kReceiveTime1 = Timestamp::Millis(20); + + const RtpPacketInfo kPacketInfo0(kSsrc0, {kCsrc0, kCsrc1}, kRtpTimestamp0, + kAudioLevel0, absl::nullopt, kReceiveTime0); + const RtpPacketInfo kPacketInfo1(kSsrc1, {kCsrc2}, kRtpTimestamp1, + kAudioLevel1, absl::nullopt, kReceiveTime1); + const RtpPacketInfo kPacketInfo2(kSsrc2, {kCsrc3}, kRtpTimestamp1, + kAudioLevel2, absl::nullopt, kReceiveTime1); + + const auto mixer = AudioMixerImpl::Create(); + + MockMixerAudioSource source; + source.set_packet_infos(RtpPacketInfos({kPacketInfo0})); + mixer->AddSource(&source); + ResetFrame(source.fake_frame()); + mixer->Mix(1, &frame_for_mixing); + + MockMixerAudioSource other_source; + other_source.set_packet_infos(RtpPacketInfos({kPacketInfo1, kPacketInfo2})); + ResetFrame(other_source.fake_frame()); + mixer->AddSource(&other_source); + + mixer->Mix(/*number_of_channels=*/1, &frame_for_mixing); + + EXPECT_THAT(frame_for_mixing.packet_infos_, + UnorderedElementsAre(kPacketInfo0, kPacketInfo1, kPacketInfo2)); +} + +TEST(AudioMixer, MixerShouldIncludeRtpPacketInfoFromMixedSourcesOnly) { + const uint32_t kSsrc0 = 10; + const uint32_t kSsrc1 = 11; + const uint32_t kSsrc2 = 21; + const uint32_t kCsrc0 = 30; + const uint32_t kCsrc1 = 31; + const uint32_t kCsrc2 = 32; + const uint32_t kCsrc3 = 33; + const int kAudioLevel0 = 10; + const absl::optional kAudioLevelMissing = absl::nullopt; + const uint32_t kRtpTimestamp0 = 300; + const uint32_t kRtpTimestamp1 = 400; + const Timestamp kReceiveTime0 = Timestamp::Millis(10); + const Timestamp kReceiveTime1 = Timestamp::Millis(20); + + const RtpPacketInfo kPacketInfo0(kSsrc0, {kCsrc0, kCsrc1}, kRtpTimestamp0, + kAudioLevel0, absl::nullopt, kReceiveTime0); + const RtpPacketInfo kPacketInfo1(kSsrc1, {kCsrc2}, kRtpTimestamp1, + kAudioLevelMissing, 
absl::nullopt, + kReceiveTime1); + const RtpPacketInfo kPacketInfo2(kSsrc2, {kCsrc3}, kRtpTimestamp1, + kAudioLevelMissing, absl::nullopt, + kReceiveTime1); + + const auto mixer = AudioMixerImpl::Create(/*max_sources_to_mix=*/2); + + MockMixerAudioSource source1; + source1.set_packet_infos(RtpPacketInfos({kPacketInfo0})); + mixer->AddSource(&source1); + ResetFrame(source1.fake_frame()); + mixer->Mix(1, &frame_for_mixing); + + MockMixerAudioSource source2; + source2.set_packet_infos(RtpPacketInfos({kPacketInfo1})); + ResetFrame(source2.fake_frame()); + mixer->AddSource(&source2); + + // The mixer prioritizes kVadActive over kVadPassive. + // We limit the number of sources to mix to 2 and set the third source's VAD + // activity to kVadPassive so that it will not be added to the mix. + MockMixerAudioSource source3; + source3.set_packet_infos(RtpPacketInfos({kPacketInfo2})); + ResetFrame(source3.fake_frame()); + source3.fake_frame()->vad_activity_ = AudioFrame::kVadPassive; + mixer->AddSource(&source3); + + mixer->Mix(/*number_of_channels=*/1, &frame_for_mixing); + + EXPECT_THAT(frame_for_mixing.packet_infos_, + UnorderedElementsAre(kPacketInfo0, kPacketInfo1)); +} + class HighOutputRateCalculator : public OutputRateCalculator { public: static const int kDefaultFrequency = 76000; diff --git a/modules/audio_mixer/frame_combiner.cc b/modules/audio_mixer/frame_combiner.cc index e184506b4c..e31eea595f 100644 --- a/modules/audio_mixer/frame_combiner.cc +++ b/modules/audio_mixer/frame_combiner.cc @@ -16,8 +16,12 @@ #include #include #include +#include +#include #include "api/array_view.h" +#include "api/rtp_packet_info.h" +#include "api/rtp_packet_infos.h" #include "common_audio/include/audio_util.h" #include "modules/audio_mixer/audio_frame_manipulator.h" #include "modules/audio_mixer/audio_mixer_impl.h" @@ -26,6 +30,7 @@ #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/arraysize.h" #include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_conversions.h" #include "system_wrappers/include/metrics.h" namespace webrtc { @@ -53,11 +58,23 @@ void SetAudioFrameFields(rtc::ArrayView mix_list, if (mix_list.empty()) { audio_frame_for_mixing->elapsed_time_ms_ = -1; - } else if (mix_list.size() == 1) { + } else { audio_frame_for_mixing->timestamp_ = mix_list[0]->timestamp_; audio_frame_for_mixing->elapsed_time_ms_ = mix_list[0]->elapsed_time_ms_; audio_frame_for_mixing->ntp_time_ms_ = mix_list[0]->ntp_time_ms_; - audio_frame_for_mixing->packet_infos_ = mix_list[0]->packet_infos_; + std::vector packet_infos; + for (const auto& frame : mix_list) { + audio_frame_for_mixing->timestamp_ = + std::min(audio_frame_for_mixing->timestamp_, frame->timestamp_); + audio_frame_for_mixing->ntp_time_ms_ = + std::min(audio_frame_for_mixing->ntp_time_ms_, frame->ntp_time_ms_); + audio_frame_for_mixing->elapsed_time_ms_ = std::max( + audio_frame_for_mixing->elapsed_time_ms_, frame->elapsed_time_ms_); + packet_infos.insert(packet_infos.end(), frame->packet_infos_.begin(), + frame->packet_infos_.end()); + } + audio_frame_for_mixing->packet_infos_ = + RtpPacketInfos(std::move(packet_infos)); } } @@ -88,13 +105,14 @@ void MixToFloatFrame(rtc::ArrayView mix_list, // Convert to FloatS16 and mix. 
for (size_t i = 0; i < mix_list.size(); ++i) { const AudioFrame* const frame = mix_list[i]; + const int16_t* const frame_data = frame->data(); for (size_t j = 0; j < std::min(number_of_channels, FrameCombiner::kMaximumNumberOfChannels); ++j) { for (size_t k = 0; k < std::min(samples_per_channel, FrameCombiner::kMaximumChannelSize); ++k) { - (*mixing_buffer)[j][k] += frame->data()[number_of_channels * k + j]; + (*mixing_buffer)[j][k] += frame_data[number_of_channels * k + j]; } } } @@ -113,10 +131,11 @@ void InterleaveToAudioFrame(AudioFrameView mixing_buffer_view, AudioFrame* audio_frame_for_mixing) { const size_t number_of_channels = mixing_buffer_view.num_channels(); const size_t samples_per_channel = mixing_buffer_view.samples_per_channel(); + int16_t* const mixing_data = audio_frame_for_mixing->mutable_data(); // Put data in the result frame. for (size_t i = 0; i < number_of_channels; ++i) { for (size_t j = 0; j < samples_per_channel; ++j) { - audio_frame_for_mixing->mutable_data()[number_of_channels * j + i] = + mixing_data[number_of_channels * j + i] = FloatS16ToS16(mixing_buffer_view.channel(i)[j]); } } @@ -205,10 +224,10 @@ void FrameCombiner::LogMixingStats( uma_logging_counter_ = 0; RTC_HISTOGRAM_COUNTS_100("WebRTC.Audio.AudioMixer.NumIncomingStreams", static_cast(number_of_streams)); - RTC_HISTOGRAM_ENUMERATION( - "WebRTC.Audio.AudioMixer.NumIncomingActiveStreams", - static_cast(mix_list.size()), - AudioMixerImpl::kMaximumAmountOfMixedAudioSources); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.AudioMixer.NumIncomingActiveStreams2", + rtc::dchecked_cast(mix_list.size()), /*min=*/1, /*max=*/16, + /*bucket_count=*/16); using NativeRate = AudioProcessing::NativeRate; static constexpr NativeRate native_rates[] = { diff --git a/modules/audio_mixer/frame_combiner_unittest.cc b/modules/audio_mixer/frame_combiner_unittest.cc index 4b189a052e..fa1fef325c 100644 --- a/modules/audio_mixer/frame_combiner_unittest.cc +++ b/modules/audio_mixer/frame_combiner_unittest.cc @@ -15,8 +15,12 @@ #include #include #include +#include +#include "absl/types/optional.h" #include "api/array_view.h" +#include "api/rtp_packet_info.h" +#include "api/rtp_packet_infos.h" #include "audio/utility/audio_frame_operations.h" #include "modules/audio_mixer/gain_change_calculator.h" #include "modules/audio_mixer/sine_wave_generator.h" @@ -28,7 +32,13 @@ namespace webrtc { namespace { + +using ::testing::ElementsAreArray; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAreArray; + using LimiterType = FrameCombiner::LimiterType; + struct FrameCombinerConfig { bool use_limiter; int sample_rate_hz; @@ -57,9 +67,24 @@ std::string ProduceDebugText(const FrameCombinerConfig& config) { AudioFrame frame1; AudioFrame frame2; -AudioFrame audio_frame_for_mixing; void SetUpFrames(int sample_rate_hz, int number_of_channels) { + RtpPacketInfo packet_info1( + /*ssrc=*/1001, /*csrcs=*/{}, /*rtp_timestamp=*/1000, + /*audio_level=*/absl::nullopt, /*absolute_capture_time=*/absl::nullopt, + /*receive_time_ms=*/1); + RtpPacketInfo packet_info2( + /*ssrc=*/4004, /*csrcs=*/{}, /*rtp_timestamp=*/1234, + /*audio_level=*/absl::nullopt, /*absolute_capture_time=*/absl::nullopt, + /*receive_time_ms=*/2); + RtpPacketInfo packet_info3( + /*ssrc=*/7007, /*csrcs=*/{}, /*rtp_timestamp=*/1333, + /*audio_level=*/absl::nullopt, /*absolute_capture_time=*/absl::nullopt, + /*receive_time_ms=*/2); + + frame1.packet_infos_ = RtpPacketInfos({packet_info1}); + frame2.packet_infos_ = RtpPacketInfos({packet_info2, packet_info3}); + for (auto* frame 
: {&frame1, &frame2}) { frame->UpdateFrame(0, nullptr, rtc::CheckedDivExact(sample_rate_hz, 100), sample_rate_hz, AudioFrame::kNormalSpeech, @@ -81,6 +106,7 @@ TEST(FrameCombiner, BasicApiCallsLimiter) { ProduceDebugText(rate, number_of_channels, number_of_frames)); const std::vector frames_to_combine( all_frames.begin(), all_frames.begin() + number_of_frames); + AudioFrame audio_frame_for_mixing; combiner.Combine(frames_to_combine, number_of_channels, rate, frames_to_combine.size(), &audio_frame_for_mixing); } @@ -88,6 +114,35 @@ TEST(FrameCombiner, BasicApiCallsLimiter) { } } +// The RtpPacketInfos field of the mixed packet should contain the union of the +// RtpPacketInfos from the frames that were actually mixed. +TEST(FrameCombiner, ContainsAllRtpPacketInfos) { + static constexpr int kSampleRateHz = 48000; + static constexpr int kNumChannels = 1; + FrameCombiner combiner(true); + const std::vector all_frames = {&frame1, &frame2}; + SetUpFrames(kSampleRateHz, kNumChannels); + + for (const int number_of_frames : {0, 1, 2}) { + SCOPED_TRACE( + ProduceDebugText(kSampleRateHz, kNumChannels, number_of_frames)); + const std::vector frames_to_combine( + all_frames.begin(), all_frames.begin() + number_of_frames); + + std::vector packet_infos; + for (const auto& frame : frames_to_combine) { + packet_infos.insert(packet_infos.end(), frame->packet_infos_.begin(), + frame->packet_infos_.end()); + } + + AudioFrame audio_frame_for_mixing; + combiner.Combine(frames_to_combine, kNumChannels, kSampleRateHz, + frames_to_combine.size(), &audio_frame_for_mixing); + EXPECT_THAT(audio_frame_for_mixing.packet_infos_, + UnorderedElementsAreArray(packet_infos)); + } +} + // There are DCHECKs in place to check for invalid parameters. TEST(FrameCombinerDeathTest, DebugBuildCrashesWithManyChannels) { FrameCombiner combiner(true); @@ -105,6 +160,7 @@ TEST(FrameCombinerDeathTest, DebugBuildCrashesWithManyChannels) { ProduceDebugText(rate, number_of_channels, number_of_frames)); const std::vector frames_to_combine( all_frames.begin(), all_frames.begin() + number_of_frames); + AudioFrame audio_frame_for_mixing; #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) EXPECT_DEATH( combiner.Combine(frames_to_combine, number_of_channels, rate, @@ -134,6 +190,7 @@ TEST(FrameCombinerDeathTest, DebugBuildCrashesWithHighRate) { ProduceDebugText(rate, number_of_channels, number_of_frames)); const std::vector frames_to_combine( all_frames.begin(), all_frames.begin() + number_of_frames); + AudioFrame audio_frame_for_mixing; #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) EXPECT_DEATH( combiner.Combine(frames_to_combine, number_of_channels, rate, @@ -161,6 +218,7 @@ TEST(FrameCombiner, BasicApiCallsNoLimiter) { ProduceDebugText(rate, number_of_channels, number_of_frames)); const std::vector frames_to_combine( all_frames.begin(), all_frames.begin() + number_of_frames); + AudioFrame audio_frame_for_mixing; combiner.Combine(frames_to_combine, number_of_channels, rate, frames_to_combine.size(), &audio_frame_for_mixing); } @@ -174,10 +232,11 @@ TEST(FrameCombiner, CombiningZeroFramesShouldProduceSilence) { for (const int number_of_channels : {1, 2}) { SCOPED_TRACE(ProduceDebugText(rate, number_of_channels, 0)); + AudioFrame audio_frame_for_mixing; + const std::vector frames_to_combine; combiner.Combine(frames_to_combine, number_of_channels, rate, frames_to_combine.size(), &audio_frame_for_mixing); - const int16_t* audio_frame_for_mixing_data = audio_frame_for_mixing.data(); const std::vector 
mixed_data( @@ -186,6 +245,7 @@ TEST(FrameCombiner, CombiningZeroFramesShouldProduceSilence) { const std::vector expected(number_of_channels * rate / 100, 0); EXPECT_EQ(mixed_data, expected); + EXPECT_THAT(audio_frame_for_mixing.packet_infos_, IsEmpty()); } } }
@@ -196,6 +256,8 @@ TEST(FrameCombiner, CombiningOneFrameShouldNotChangeFrame) { for (const int number_of_channels : {1, 2, 4, 8, 10}) { SCOPED_TRACE(ProduceDebugText(rate, number_of_channels, 1)); + AudioFrame audio_frame_for_mixing; + SetUpFrames(rate, number_of_channels); int16_t* frame1_data = frame1.mutable_data(); std::iota(frame1_data, frame1_data + number_of_channels * rate / 100, 0);
@@ -212,6 +274,8 @@ TEST(FrameCombiner, CombiningOneFrameShouldNotChangeFrame) { std::vector expected(number_of_channels * rate / 100); std::iota(expected.begin(), expected.end(), 0); EXPECT_EQ(mixed_data, expected); + EXPECT_THAT(audio_frame_for_mixing.packet_infos_, + ElementsAreArray(frame1.packet_infos_)); } } }
@@ -255,6 +319,7 @@ TEST(FrameCombiner, GainCurveIsSmoothForAlternatingNumberOfStreams) { // Ensures limiter is on if 'use_limiter'. constexpr size_t number_of_streams = 2; + AudioFrame audio_frame_for_mixing; combiner.Combine(frames_to_combine, config.number_of_channels, config.sample_rate_hz, number_of_streams, &audio_frame_for_mixing);
diff --git a/modules/audio_mixer/g3doc/index.md b/modules/audio_mixer/g3doc/index.md new file mode 100644 index 0000000000..285530e95a --- /dev/null +++ b/modules/audio_mixer/g3doc/index.md @@ -0,0 +1,54 @@ + + + +# The WebRTC Audio Mixer Module + +The WebRTC audio mixer module is responsible for mixing multiple incoming audio +streams (sources) into a single audio stream (mix). It works with 10 ms frames +and supports sample rates up to 48 kHz and up to 8 audio channels. The API is +defined in +[`api/audio/audio_mixer.h`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/audio/audio_mixer.h) +and it includes the definition of +[`AudioMixer::Source`](https://source.chromium.org/search?q=symbol:AudioMixer::Source%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h), +which describes an incoming audio stream, and the definition of +[`AudioMixer`](https://source.chromium.org/search?q=symbol:AudioMixer%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h), +which operates on a collection of +[`AudioMixer::Source`](https://source.chromium.org/search?q=symbol:AudioMixer::Source%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h) +objects to produce a mix.
+ +## AudioMixer::Source + +A source has different characteristics (e.g., sample rate, number of channels, +muted state) and is identified by an SSRC[^1]. +[`AudioMixer::Source::GetAudioFrameWithInfo()`](https://source.chromium.org/search?q=symbol:AudioMixer::Source::GetAudioFrameWithInfo%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h) +is used to retrieve the next 10 ms chunk of audio to be mixed.
+ +[^1]: A synchronization source (SSRC) is the source of a stream of RTP packets, + identified by a 32-bit numeric SSRC identifier carried in the RTP header + so as not to be dependent upon the network address (see + [RFC 3550](https://tools.ietf.org/html/rfc3550#section-3)).
+ +## AudioMixer + +The interface allows adding and removing sources, and the +[`AudioMixer::Mix()`](https://source.chromium.org/search?q=symbol:AudioMixer::Mix%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h) +method generates a mix with the desired number of channels.
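A minimal usage sketch follows (illustrative only; it assumes the `AudioMixer` interface described above, and `my_source` stands for any caller-provided `AudioMixer::Source` implementation):

```cpp
#include "api/audio/audio_frame.h"
#include "api/audio/audio_mixer.h"
#include "api/scoped_refptr.h"
#include "modules/audio_mixer/audio_mixer_impl.h"

// Mixes one 10 ms frame from a single registered source.
void MixOneFrame(webrtc::AudioMixer::Source* my_source) {
  rtc::scoped_refptr<webrtc::AudioMixer> mixer =
      webrtc::AudioMixerImpl::Create();
  mixer->AddSource(my_source);  // Register the incoming stream.

  webrtc::AudioFrame mixed_frame;
  // Pulls the next 10 ms from every registered source and writes the mix,
  // with the requested number of channels, into `mixed_frame`.
  mixer->Mix(/*number_of_channels=*/2, &mixed_frame);

  mixer->RemoveSource(my_source);  // Unregister when done.
}
```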
+ +## WebRTC implementation + +The interface is implemented in different parts of WebRTC: + +* [`AudioMixer::Source`](https://source.chromium.org/search?q=symbol:AudioMixer::Source%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h): + [`audio/audio_receive_stream.h`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/audio/audio_receive_stream.h) +* [`AudioMixer`](https://source.chromium.org/search?q=symbol:AudioMixer%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h): + [`modules/audio_mixer/audio_mixer_impl.h`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_mixer/audio_mixer_impl.h) + +[`AudioMixer`](https://source.chromium.org/search?q=symbol:AudioMixer%20file:third_party%2Fwebrtc%2Fapi%2Faudio%2Faudio_mixer.h) +is thread-safe. The output sample rate of the generated mix is automatically +assigned depending on the sample rate of the sources; whereas the number of +output channels is defined by the caller[^2]. Samples from the non-muted sources +are summed up and then a limiter is used to apply soft-clipping when needed. + +[^2]: [`audio/utility/channel_mixer.h`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/audio/utility/channel_mixer.h) + is used to mix channels in the non-trivial cases - i.e., if the number of + channels for a source or the mix is greater than 3. diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn index dbb1882de2..a733612ccc 100644 --- a/modules/audio_processing/BUILD.gn +++ b/modules/audio_processing/BUILD.gn @@ -43,7 +43,6 @@ rtc_library("api") { "../../api/audio:aec3_config", "../../api/audio:audio_frame_api", "../../api/audio:echo_control", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base/system:arch", "../../rtc_base/system:file_wrapper", @@ -119,8 +118,8 @@ rtc_source_set("aec_dump_interface") { deps = [ ":api", ":audio_frame_view", - "../../rtc_base:deprecation", ] + absl_deps = [ "//third_party/abseil-cpp/absl/base:core_headers" ] } rtc_library("audio_processing") { @@ -177,7 +176,6 @@ rtc_library("audio_processing") { "../../common_audio:common_audio_c", "../../common_audio/third_party/ooura:fft_size_256", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:gtest_prod", "../../rtc_base:ignore_wundef", "../../rtc_base:refcount", @@ -197,6 +195,7 @@ rtc_library("audio_processing") { "agc2:adaptive_digital", "agc2:fixed_digital", "agc2:gain_applier", + "capture_levels_adjuster", "ns", "transient:transient_suppressor_api", "vad", @@ -291,6 +290,7 @@ rtc_library("apm_logging") { "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] defines = [] } @@ -308,141 +308,146 @@ if (rtc_include_tests) { ] } - group("audio_processing_tests") { - testonly = true - deps = [ - ":audioproc_test_utils", - "transient:click_annotate", - "transient:transient_suppression_test", - ] - - if (rtc_enable_protobuf) { - deps += [ - ":audioproc_unittest_proto", - "aec_dump:aec_dump_unittests", - "test/conversational_speech", - "test/py_quality_assessment", + if (!build_with_chromium) { + group("audio_processing_tests") { + testonly = true + deps = [ + ":audioproc_test_utils", + "transient:click_annotate", + "transient:transient_suppression_test", ] - } - } - - rtc_library("audio_processing_unittests") { - testonly = true - configs += [ ":apm_debug_dump" ] - sources = [ - "audio_buffer_unittest.cc", - 
"audio_frame_view_unittest.cc", - "config_unittest.cc", - "echo_control_mobile_unittest.cc", - "gain_controller2_unittest.cc", - "splitting_filter_unittest.cc", - "test/fake_recording_device_unittest.cc", - ] - - deps = [ - ":analog_mic_simulation", - ":api", - ":apm_logging", - ":audio_buffer", - ":audio_frame_view", - ":audio_processing", - ":audioproc_test_utils", - ":config", - ":high_pass_filter", - ":mocks", - ":voice_detection", - "../../api:array_view", - "../../api:scoped_refptr", - "../../api/audio:aec3_config", - "../../api/audio:aec3_factory", - "../../common_audio", - "../../common_audio:common_audio_c", - "../../rtc_base", - "../../rtc_base:checks", - "../../rtc_base:gtest_prod", - "../../rtc_base:ignore_wundef", - "../../rtc_base:protobuf_utils", - "../../rtc_base:rtc_base_approved", - "../../rtc_base:rtc_base_tests_utils", - "../../rtc_base:safe_minmax", - "../../rtc_base:task_queue_for_test", - "../../rtc_base/synchronization:mutex", - "../../rtc_base/system:arch", - "../../rtc_base/system:file_wrapper", - "../../system_wrappers", - "../../test:fileutils", - "../../test:rtc_expect_death", - "../../test:test_support", - "../audio_coding:neteq_input_audio_tools", - "aec_dump:mock_aec_dump_unittests", - "agc:agc_unittests", - "agc2:adaptive_digital_unittests", - "agc2:biquad_filter_unittests", - "agc2:fixed_digital_unittests", - "agc2:noise_estimator_unittests", - "agc2:rnn_vad_with_level_unittests", - "agc2:test_utils", - "agc2/rnn_vad:unittests", - "test/conversational_speech:unittest", - "transient:transient_suppression_unittests", - "utility:legacy_delay_estimator_unittest", - "utility:pffft_wrapper_unittest", - "vad:vad_unittests", - "//testing/gtest", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + if (rtc_enable_protobuf) { + deps += [ + ":audioproc_unittest_proto", + "aec_dump:aec_dump_unittests", + "test/conversational_speech", + "test/py_quality_assessment", + ] + } + } - defines = [] + rtc_library("audio_processing_unittests") { + testonly = true - if (rtc_prefer_fixed_point) { - defines += [ "WEBRTC_AUDIOPROC_FIXED_PROFILE" ] - } else { - defines += [ "WEBRTC_AUDIOPROC_FLOAT_PROFILE" ] - } + configs += [ ":apm_debug_dump" ] + sources = [ + "audio_buffer_unittest.cc", + "audio_frame_view_unittest.cc", + "config_unittest.cc", + "echo_control_mobile_unittest.cc", + "gain_controller2_unittest.cc", + "splitting_filter_unittest.cc", + "test/fake_recording_device_unittest.cc", + ] - if (rtc_enable_protobuf) { - defines += [ "WEBRTC_AUDIOPROC_DEBUG_DUMP" ] - deps += [ - ":audioproc_debug_proto", - ":audioproc_protobuf_utils", + deps = [ + ":analog_mic_simulation", + ":api", + ":apm_logging", + ":audio_buffer", + ":audio_frame_view", + ":audio_processing", ":audioproc_test_utils", - ":audioproc_unittest_proto", - ":optionally_built_submodule_creators", - ":rms_level", - ":runtime_settings_protobuf_utils", - "../../api/audio:audio_frame_api", - "../../api/audio:echo_control", + ":config", + ":high_pass_filter", + ":mocks", + ":voice_detection", + "../../api:array_view", + "../../api:scoped_refptr", + "../../api/audio:aec3_config", + "../../api/audio:aec3_factory", + "../../common_audio", + "../../common_audio:common_audio_c", + "../../rtc_base", + "../../rtc_base:checks", + "../../rtc_base:gtest_prod", + "../../rtc_base:ignore_wundef", + "../../rtc_base:protobuf_utils", + "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_base_tests_utils", - "../../rtc_base:rtc_task_queue", - "aec_dump", - "aec_dump:aec_dump_unittests", - ] - absl_deps += [ 
"//third_party/abseil-cpp/absl/flags:flag" ] - sources += [ - "audio_processing_impl_locking_unittest.cc", - "audio_processing_impl_unittest.cc", - "audio_processing_unittest.cc", - "echo_control_mobile_bit_exact_unittest.cc", - "echo_detector/circular_buffer_unittest.cc", - "echo_detector/mean_variance_estimator_unittest.cc", - "echo_detector/moving_max_unittest.cc", - "echo_detector/normalized_covariance_estimator_unittest.cc", - "gain_control_unittest.cc", - "high_pass_filter_unittest.cc", - "level_estimator_unittest.cc", - "residual_echo_detector_unittest.cc", - "rms_level_unittest.cc", - "test/debug_dump_replayer.cc", - "test/debug_dump_replayer.h", - "test/debug_dump_test.cc", - "test/echo_canceller_test_tools.cc", - "test/echo_canceller_test_tools.h", - "test/echo_canceller_test_tools_unittest.cc", - "test/echo_control_mock.h", - "test/test_utils.h", - "voice_detection_unittest.cc", + "../../rtc_base:safe_minmax", + "../../rtc_base:task_queue_for_test", + "../../rtc_base:threading", + "../../rtc_base/synchronization:mutex", + "../../rtc_base/system:arch", + "../../rtc_base/system:file_wrapper", + "../../system_wrappers", + "../../test:fileutils", + "../../test:rtc_expect_death", + "../../test:test_support", + "../audio_coding:neteq_input_audio_tools", + "aec_dump:mock_aec_dump_unittests", + "agc:agc_unittests", + "agc2:adaptive_digital_unittests", + "agc2:biquad_filter_unittests", + "agc2:fixed_digital_unittests", + "agc2:noise_estimator_unittests", + "agc2:rnn_vad_with_level_unittests", + "agc2:test_utils", + "agc2/rnn_vad:unittests", + "capture_levels_adjuster", + "capture_levels_adjuster:capture_levels_adjuster_unittests", + "test/conversational_speech:unittest", + "transient:transient_suppression_unittests", + "utility:legacy_delay_estimator_unittest", + "utility:pffft_wrapper_unittest", + "vad:vad_unittests", + "//testing/gtest", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + + defines = [] + + if (rtc_prefer_fixed_point) { + defines += [ "WEBRTC_AUDIOPROC_FIXED_PROFILE" ] + } else { + defines += [ "WEBRTC_AUDIOPROC_FLOAT_PROFILE" ] + } + + if (rtc_enable_protobuf) { + defines += [ "WEBRTC_AUDIOPROC_DEBUG_DUMP" ] + deps += [ + ":audioproc_debug_proto", + ":audioproc_protobuf_utils", + ":audioproc_test_utils", + ":audioproc_unittest_proto", + ":optionally_built_submodule_creators", + ":rms_level", + ":runtime_settings_protobuf_utils", + "../../api/audio:audio_frame_api", + "../../api/audio:echo_control", + "../../rtc_base:rtc_base_tests_utils", + "../../rtc_base:rtc_task_queue", + "aec_dump", + "aec_dump:aec_dump_unittests", + ] + absl_deps += [ "//third_party/abseil-cpp/absl/flags:flag" ] + sources += [ + "audio_processing_impl_locking_unittest.cc", + "audio_processing_impl_unittest.cc", + "audio_processing_unittest.cc", + "echo_control_mobile_bit_exact_unittest.cc", + "echo_detector/circular_buffer_unittest.cc", + "echo_detector/mean_variance_estimator_unittest.cc", + "echo_detector/moving_max_unittest.cc", + "echo_detector/normalized_covariance_estimator_unittest.cc", + "gain_control_unittest.cc", + "high_pass_filter_unittest.cc", + "level_estimator_unittest.cc", + "residual_echo_detector_unittest.cc", + "rms_level_unittest.cc", + "test/debug_dump_replayer.cc", + "test/debug_dump_replayer.h", + "test/debug_dump_test.cc", + "test/echo_canceller_test_tools.cc", + "test/echo_canceller_test_tools.h", + "test/echo_canceller_test_tools_unittest.cc", + "test/echo_control_mock.h", + "test/test_utils.h", + "voice_detection_unittest.cc", + ] + } } } @@ -480,7 
+485,7 @@ if (rtc_include_tests) { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } - if (rtc_enable_protobuf) { + if (rtc_enable_protobuf && !build_with_chromium) { rtc_library("audioproc_f_impl") { testonly = true configs += [ ":apm_debug_dump" ] diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index c98fa4c122..3ce494346f 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -302,7 +302,6 @@ if (rtc_include_tests) { "..:apm_logging", "..:audio_buffer", "..:audio_processing", - "..:audio_processing_unittests", "..:high_pass_filter", "../../../api:array_view", "../../../api/audio:aec3_config", @@ -363,5 +362,9 @@ if (rtc_include_tests) { "vector_math_unittest.cc", ] } + + if (!build_with_chromium) { + deps += [ "..:audio_processing_unittests" ] + } } } diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc index 5b31e3cb9f..21cad2186f 100644 --- a/modules/audio_processing/aec3/aec_state.cc +++ b/modules/audio_processing/aec3/aec_state.cc @@ -113,14 +113,6 @@ void AecState::GetResidualEchoScaling( residual_scaling); } -absl::optional AecState::ErleUncertainty() const { - if (SaturatedEcho()) { - return 1.f; - } - - return absl::nullopt; -} - AecState::AecState(const EchoCanceller3Config& config, size_t num_capture_channels) : data_dumper_( @@ -302,7 +294,9 @@ void AecState::Update( data_dumper_->DumpRaw("aec3_active_render", active_render); data_dumper_->DumpRaw("aec3_erl", Erl()); data_dumper_->DumpRaw("aec3_erl_time_domain", ErlTimeDomain()); - data_dumper_->DumpRaw("aec3_erle", Erle()[0]); + data_dumper_->DumpRaw("aec3_erle", Erle(/*onset_compensated=*/false)[0]); + data_dumper_->DumpRaw("aec3_erle_onset_compensated", + Erle(/*onset_compensated=*/true)[0]); data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate()); data_dumper_->DumpRaw("aec3_transparent_mode", TransparentModeActive()); data_dumper_->DumpRaw("aec3_filter_delay", diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h index 5b40e9513a..e2f70a4c68 100644 --- a/modules/audio_processing/aec3/aec_state.h +++ b/modules/audio_processing/aec3/aec_state.h @@ -70,15 +70,16 @@ class AecState { } // Returns the ERLE. - rtc::ArrayView> Erle() const { - return erle_estimator_.Erle(); + rtc::ArrayView> Erle( + bool onset_compensated) const { + return erle_estimator_.Erle(onset_compensated); } - // Returns an offset to apply to the estimation of the residual echo - // computation. Returning nullopt means that no offset should be used, while - // any other value will be applied as a multiplier to the estimated residual - // echo. - absl::optional ErleUncertainty() const; + // Returns the non-capped ERLE. + rtc::ArrayView> ErleUnbounded() + const { + return erle_estimator_.ErleUnbounded(); + } // Returns the fullband ERLE estimate in log2 units. float FullBandErleLog2() const { return erle_estimator_.FullbandErleLog2(); } diff --git a/modules/audio_processing/aec3/aec_state_unittest.cc b/modules/audio_processing/aec3/aec_state_unittest.cc index c9db8bdb36..6e62a586ed 100644 --- a/modules/audio_processing/aec3/aec_state_unittest.cc +++ b/modules/audio_processing/aec3/aec_state_unittest.cc @@ -182,7 +182,7 @@ void RunNormalUsageTest(size_t num_render_channels, { // Note that the render spectrum is built so it does not have energy in // the odd bands but just in the even bands. 
- const auto& erle = state.Erle()[0]; + const auto& erle = state.Erle(/*onset_compensated=*/true)[0]; EXPECT_EQ(erle[0], erle[1]); constexpr size_t kLowFrequencyLimit = 32; for (size_t k = 2; k < kLowFrequencyLimit; k = k + 2) { @@ -210,7 +210,7 @@ void RunNormalUsageTest(size_t num_render_channels, ASSERT_TRUE(state.UsableLinearEstimate()); { - const auto& erle = state.Erle()[0]; + const auto& erle = state.Erle(/*onset_compensated=*/true)[0]; EXPECT_EQ(erle[0], erle[1]); constexpr size_t kLowFrequencyLimit = 32; for (size_t k = 1; k < kLowFrequencyLimit; ++k) { diff --git a/modules/audio_processing/aec3/block_processor.cc b/modules/audio_processing/aec3/block_processor.cc index f2f3261489..2ee32b82dc 100644 --- a/modules/audio_processing/aec3/block_processor.cc +++ b/modules/audio_processing/aec3/block_processor.cc @@ -63,6 +63,7 @@ class BlockProcessorImpl final : public BlockProcessor { void GetMetrics(EchoControl::Metrics* metrics) const override; void SetAudioBufferDelay(int delay_ms) override; + void SetCaptureOutputUsage(bool capture_output_used) override; private: static int instance_count_; @@ -237,6 +238,10 @@ void BlockProcessorImpl::SetAudioBufferDelay(int delay_ms) { render_buffer_->SetAudioBufferDelay(delay_ms); } +void BlockProcessorImpl::SetCaptureOutputUsage(bool capture_output_used) { + echo_remover_->SetCaptureOutputUsage(capture_output_used); +} + } // namespace BlockProcessor* BlockProcessor::Create(const EchoCanceller3Config& config, diff --git a/modules/audio_processing/aec3/block_processor.h b/modules/audio_processing/aec3/block_processor.h index 9bb0cf19f3..41ce016dc0 100644 --- a/modules/audio_processing/aec3/block_processor.h +++ b/modules/audio_processing/aec3/block_processor.h @@ -69,6 +69,12 @@ class BlockProcessor { // Reports whether echo leakage has been detected in the echo canceller // output. virtual void UpdateEchoLeakageStatus(bool leakage_detected) = 0; + + // Specifies whether the capture output will be used. The purpose of this is + // to allow the block processor to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. 
+ virtual void SetCaptureOutputUsage(bool capture_output_used) = 0; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/echo_canceller3.cc b/modules/audio_processing/aec3/echo_canceller3.cc index 98da232bba..181b649f6d 100644 --- a/modules/audio_processing/aec3/echo_canceller3.cc +++ b/modules/audio_processing/aec3/echo_canceller3.cc @@ -49,7 +49,11 @@ void RetrieveFieldTrialValue(const char* trial_name, ParseFieldTrial({&field_trial_param}, field_trial_str); float field_trial_value = static_cast(field_trial_param.Get()); - if (field_trial_value >= min && field_trial_value <= max) { + if (field_trial_value >= min && field_trial_value <= max && + field_trial_value != *value_to_update) { + RTC_LOG(LS_INFO) << "Key " << trial_name + << " changing AEC3 parameter value from " + << *value_to_update << " to " << field_trial_value; *value_to_update = field_trial_value; } } @@ -65,7 +69,11 @@ void RetrieveFieldTrialValue(const char* trial_name, ParseFieldTrial({&field_trial_param}, field_trial_str); float field_trial_value = field_trial_param.Get(); - if (field_trial_value >= min && field_trial_value <= max) { + if (field_trial_value >= min && field_trial_value <= max && + field_trial_value != *value_to_update) { + RTC_LOG(LS_INFO) << "Key " << trial_name + << " changing AEC3 parameter value from " + << *value_to_update << " to " << field_trial_value; *value_to_update = field_trial_value; } } @@ -251,6 +259,10 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { adjusted_cfg.filter.initial_state_seconds = 2.0f; } + if (field_trial::IsEnabled("WebRTC-Aec3HighPassFilterEchoReference")) { + adjusted_cfg.filter.high_pass_filter_echo_reference = true; + } + if (field_trial::IsEnabled("WebRTC-Aec3EchoSaturationDetectionKillSwitch")) { adjusted_cfg.ep_strength.echo_can_saturate = false; } @@ -568,12 +580,19 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { RetrieveFieldTrialValue("WebRTC-Aec3SuppressorEpStrengthDefaultLenOverride", -1.f, 1.f, &adjusted_cfg.ep_strength.default_len); + // Field trial-based overrides of individual delay estimator parameters. 
+ RetrieveFieldTrialValue("WebRTC-Aec3DelayEstimateSmoothingOverride", 0.f, 1.f, + &adjusted_cfg.delay.delay_estimate_smoothing); + RetrieveFieldTrialValue( + "WebRTC-Aec3DelayEstimateSmoothingDelayFoundOverride", 0.f, 1.f, + &adjusted_cfg.delay.delay_estimate_smoothing_delay_found); return adjusted_cfg; } class EchoCanceller3::RenderWriter { public: RenderWriter(ApmDataDumper* data_dumper, + const EchoCanceller3Config& config, SwapQueue>>, Aec3RenderQueueItemVerifier>* render_transfer_queue, size_t num_bands, @@ -590,7 +609,7 @@ class EchoCanceller3::RenderWriter { ApmDataDumper* data_dumper_; const size_t num_bands_; const size_t num_channels_; - HighPassFilter high_pass_filter_; + std::unique_ptr high_pass_filter_; std::vector>> render_queue_input_frame_; SwapQueue>>, Aec3RenderQueueItemVerifier>* render_transfer_queue_; @@ -598,6 +617,7 @@ class EchoCanceller3::RenderWriter { EchoCanceller3::RenderWriter::RenderWriter( ApmDataDumper* data_dumper, + const EchoCanceller3Config& config, SwapQueue>>, Aec3RenderQueueItemVerifier>* render_transfer_queue, size_t num_bands, @@ -605,7 +625,6 @@ EchoCanceller3::RenderWriter::RenderWriter( : data_dumper_(data_dumper), num_bands_(num_bands), num_channels_(num_channels), - high_pass_filter_(16000, num_channels), render_queue_input_frame_( num_bands_, std::vector>( @@ -613,6 +632,9 @@ EchoCanceller3::RenderWriter::RenderWriter( std::vector(AudioBuffer::kSplitBandSize, 0.f))), render_transfer_queue_(render_transfer_queue) { RTC_DCHECK(data_dumper); + if (config.filter.high_pass_filter_echo_reference) { + high_pass_filter_ = std::make_unique(16000, num_channels); + } } EchoCanceller3::RenderWriter::~RenderWriter() = default; @@ -631,7 +653,9 @@ void EchoCanceller3::RenderWriter::Insert(const AudioBuffer& input) { CopyBufferIntoFrame(input, num_bands_, num_channels_, &render_queue_input_frame_); - high_pass_filter_.Process(&render_queue_input_frame_[0]); + if (high_pass_filter_) { + high_pass_filter_->Process(&render_queue_input_frame_[0]); + } static_cast(render_transfer_queue_->Insert(&render_queue_input_frame_)); } @@ -704,7 +728,7 @@ EchoCanceller3::EchoCanceller3(const EchoCanceller3Config& config, config_.delay.fixed_capture_delay_samples)); } - render_writer_.reset(new RenderWriter(data_dumper_.get(), + render_writer_.reset(new RenderWriter(data_dumper_.get(), config_, &render_transfer_queue_, num_bands_, num_render_channels_)); @@ -721,6 +745,10 @@ EchoCanceller3::EchoCanceller3(const EchoCanceller3Config& config, std::vector>>( 1, std::vector>(num_capture_channels_)); } + + RTC_LOG(LS_INFO) << "AEC3 created with sample rate: " << sample_rate_hz_ + << " Hz, num render channels: " << num_render_channels_ + << ", num capture channels: " << num_capture_channels_; } EchoCanceller3::~EchoCanceller3() = default; @@ -823,6 +851,11 @@ void EchoCanceller3::SetAudioBufferDelay(int delay_ms) { block_processor_->SetAudioBufferDelay(delay_ms); } +void EchoCanceller3::SetCaptureOutputUsage(bool capture_output_used) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + block_processor_->SetCaptureOutputUsage(capture_output_used); +} + bool EchoCanceller3::ActiveProcessing() const { return true; } diff --git a/modules/audio_processing/aec3/echo_canceller3.h b/modules/audio_processing/aec3/echo_canceller3.h index bacd5dfc48..a4aab4987f 100644 --- a/modules/audio_processing/aec3/echo_canceller3.h +++ b/modules/audio_processing/aec3/echo_canceller3.h @@ -118,6 +118,12 @@ class EchoCanceller3 : public EchoControl { // Provides an optional external estimate of 
the audio buffer delay. void SetAudioBufferDelay(int delay_ms) override; + // Specifies whether the capture output will be used. The purpose of this is + // to allow the echo controller to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. + void SetCaptureOutputUsage(bool capture_output_used) override; + bool ActiveProcessing() const override; // Signals whether an external detector has detected echo leakage from the diff --git a/modules/audio_processing/aec3/echo_canceller3_unittest.cc b/modules/audio_processing/aec3/echo_canceller3_unittest.cc index a02cfa3904..4a3c466712 100644 --- a/modules/audio_processing/aec3/echo_canceller3_unittest.cc +++ b/modules/audio_processing/aec3/echo_canceller3_unittest.cc @@ -131,6 +131,8 @@ class CaptureTransportVerificationProcessor : public BlockProcessor { void GetMetrics(EchoControl::Metrics* metrics) const override {} void SetAudioBufferDelay(int delay_ms) override {} + + void SetCaptureOutputUsage(bool capture_output_used) {} }; // Class for testing that the render data is properly received by the block @@ -169,6 +171,8 @@ class RenderTransportVerificationProcessor : public BlockProcessor { void SetAudioBufferDelay(int delay_ms) override {} + void SetCaptureOutputUsage(bool capture_output_used) {} + private: std::deque>>> received_render_blocks_; @@ -252,8 +256,6 @@ class EchoCanceller3Tester { capture_output.push_back(capture_buffer_.split_bands(0)[0][k]); } } - HighPassFilter hp_filter(16000, 1); - hp_filter.Process(&render_input); EXPECT_TRUE( VerifyOutputFrameBitexactness(render_input[0], capture_output, -64)); @@ -545,8 +547,6 @@ class EchoCanceller3Tester { capture_output.push_back(capture_buffer_.split_bands(0)[0][k]); } } - HighPassFilter hp_filter(16000, 1); - hp_filter.Process(&render_input); EXPECT_TRUE( VerifyOutputFrameBitexactness(render_input[0], capture_output, -64)); diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/modules/audio_processing/aec3/echo_path_delay_estimator.cc index 2c987f9341..8a78834143 100644 --- a/modules/audio_processing/aec3/echo_path_delay_estimator.cc +++ b/modules/audio_processing/aec3/echo_path_delay_estimator.cc @@ -42,6 +42,7 @@ EchoPathDelayEstimator::EchoPathDelayEstimator( ? 
config.render_levels.poor_excitation_render_limit_ds8 : config.render_levels.poor_excitation_render_limit, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold), matched_filter_lag_aggregator_(data_dumper_, matched_filter_.GetMaxFilterLag(), @@ -71,7 +72,8 @@ absl::optional EchoPathDelayEstimator::EstimateDelay( data_dumper_->DumpWav("aec3_capture_decimator_output", downsampled_capture.size(), downsampled_capture.data(), 16000 / down_sampling_factor_, 1); - matched_filter_.Update(render_buffer, downsampled_capture); + matched_filter_.Update(render_buffer, downsampled_capture, + matched_filter_lag_aggregator_.ReliableDelayFound()); absl::optional aggregated_matched_filter_lag = matched_filter_lag_aggregator_.Aggregate( diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc index df539bfad0..2bfaa951d8 100644 --- a/modules/audio_processing/aec3/echo_remover.cc +++ b/modules/audio_processing/aec3/echo_remover.cc @@ -132,6 +132,10 @@ class EchoRemoverImpl final : public EchoRemover { echo_leakage_detected_ = leakage_detected; } + void SetCaptureOutputUsage(bool capture_output_used) override { + capture_output_used_ = capture_output_used; + } + private: // Selects which of the coarse and refined linear filter outputs that is most // appropriate to pass to the suppressor and forms the linear filter output by @@ -155,6 +159,7 @@ class EchoRemoverImpl final : public EchoRemover { RenderSignalAnalyzer render_signal_analyzer_; ResidualEchoEstimator residual_echo_estimator_; bool echo_leakage_detected_ = false; + bool capture_output_used_ = true; AecState aec_state_; EchoRemoverMetrics metrics_; std::vector> e_old_; @@ -167,6 +172,7 @@ class EchoRemoverImpl final : public EchoRemover { std::vector> Y2_heap_; std::vector> E2_heap_; std::vector> R2_heap_; + std::vector> R2_unbounded_heap_; std::vector> S2_linear_heap_; std::vector Y_heap_; std::vector E_heap_; @@ -213,6 +219,7 @@ EchoRemoverImpl::EchoRemoverImpl(const EchoCanceller3Config& config, Y2_heap_(NumChannelsOnHeap(num_capture_channels_)), E2_heap_(NumChannelsOnHeap(num_capture_channels_)), R2_heap_(NumChannelsOnHeap(num_capture_channels_)), + R2_unbounded_heap_(NumChannelsOnHeap(num_capture_channels_)), S2_linear_heap_(NumChannelsOnHeap(num_capture_channels_)), Y_heap_(NumChannelsOnHeap(num_capture_channels_)), E_heap_(NumChannelsOnHeap(num_capture_channels_)), @@ -259,6 +266,8 @@ void EchoRemoverImpl::ProcessCapture( E2_stack; std::array, kMaxNumChannelsOnStack> R2_stack; + std::array, kMaxNumChannelsOnStack> + R2_unbounded_stack; std::array, kMaxNumChannelsOnStack> S2_linear_stack; std::array Y_stack; @@ -275,6 +284,8 @@ void EchoRemoverImpl::ProcessCapture( E2_stack.data(), num_capture_channels_); rtc::ArrayView> R2( R2_stack.data(), num_capture_channels_); + rtc::ArrayView> R2_unbounded( + R2_unbounded_stack.data(), num_capture_channels_); rtc::ArrayView> S2_linear( S2_linear_stack.data(), num_capture_channels_); rtc::ArrayView Y(Y_stack.data(), num_capture_channels_); @@ -296,6 +307,8 @@ void EchoRemoverImpl::ProcessCapture( E2_heap_.data(), num_capture_channels_); R2 = rtc::ArrayView>( R2_heap_.data(), num_capture_channels_); + R2_unbounded = rtc::ArrayView>( + R2_unbounded_heap_.data(), num_capture_channels_); S2_linear = rtc::ArrayView>( S2_linear_heap_.data(), num_capture_channels_); Y = rtc::ArrayView(Y_heap_.data(), num_capture_channels_); @@ -391,42 +404,50 @@ void 
EchoRemoverImpl::ProcessCapture( 1); data_dumper_->DumpWav("aec3_output_linear2", kBlockSize, &e[0][0], 16000, 1); - // Estimate the residual echo power. - residual_echo_estimator_.Estimate(aec_state_, *render_buffer, S2_linear, Y2, - R2); - // Estimate the comfort noise. cng_.Compute(aec_state_.SaturatedCapture(), Y2, comfort_noise, high_band_comfort_noise); - // Suppressor nearend estimate. - if (aec_state_.UsableLinearEstimate()) { - // E2 is bound by Y2. - for (size_t ch = 0; ch < num_capture_channels_; ++ch) { - std::transform(E2[ch].begin(), E2[ch].end(), Y2[ch].begin(), - E2[ch].begin(), - [](float a, float b) { return std::min(a, b); }); + // Only do the below processing if the output of the audio processing module + // is used. + std::array G; + if (capture_output_used_) { + // Estimate the residual echo power. + residual_echo_estimator_.Estimate(aec_state_, *render_buffer, S2_linear, Y2, + suppression_gain_.IsDominantNearend(), R2, + R2_unbounded); + + // Suppressor nearend estimate. + if (aec_state_.UsableLinearEstimate()) { + // E2 is bound by Y2. + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::transform(E2[ch].begin(), E2[ch].end(), Y2[ch].begin(), + E2[ch].begin(), + [](float a, float b) { return std::min(a, b); }); + } } - } - const auto& nearend_spectrum = aec_state_.UsableLinearEstimate() ? E2 : Y2; + const auto& nearend_spectrum = aec_state_.UsableLinearEstimate() ? E2 : Y2; - // Suppressor echo estimate. - const auto& echo_spectrum = - aec_state_.UsableLinearEstimate() ? S2_linear : R2; + // Suppressor echo estimate. + const auto& echo_spectrum = + aec_state_.UsableLinearEstimate() ? S2_linear : R2; - // Determine if the suppressor should assume clock drift. - const bool clock_drift = config_.echo_removal_control.has_clock_drift || - echo_path_variability.clock_drift; + // Determine if the suppressor should assume clock drift. + const bool clock_drift = config_.echo_removal_control.has_clock_drift || + echo_path_variability.clock_drift; - // Compute preferred gains. - float high_bands_gain; - std::array G; - suppression_gain_.GetGain(nearend_spectrum, echo_spectrum, R2, - cng_.NoiseSpectrum(), render_signal_analyzer_, - aec_state_, x, clock_drift, &high_bands_gain, &G); + // Compute preferred gains. + float high_bands_gain; + suppression_gain_.GetGain(nearend_spectrum, echo_spectrum, R2, R2_unbounded, + cng_.NoiseSpectrum(), render_signal_analyzer_, + aec_state_, x, clock_drift, &high_bands_gain, &G); - suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G, - high_bands_gain, Y_fft, y); + suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G, + high_bands_gain, Y_fft, y); + + } else { + G.fill(0.f); + } // Update the metrics. metrics_.Update(aec_state_, cng_.NoiseSpectrum()[0], G); diff --git a/modules/audio_processing/aec3/echo_remover.h b/modules/audio_processing/aec3/echo_remover.h index ef4164688b..486a9a72f4 100644 --- a/modules/audio_processing/aec3/echo_remover.h +++ b/modules/audio_processing/aec3/echo_remover.h @@ -48,6 +48,12 @@ class EchoRemover { // Updates the status on whether echo leakage is detected in the output of the // echo remover. virtual void UpdateEchoLeakageStatus(bool leakage_detected) = 0; + + // Specifies whether the capture output will be used. The purpose of this is + // to allow the echo remover to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. 
+ virtual void SetCaptureOutputUsage(bool capture_output_used) = 0; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/erle_estimator.cc b/modules/audio_processing/aec3/erle_estimator.cc index 4d843457d3..0e3d715c59 100644 --- a/modules/audio_processing/aec3/erle_estimator.cc +++ b/modules/audio_processing/aec3/erle_estimator.cc @@ -52,8 +52,9 @@ void ErleEstimator::Update( rtc::ArrayView> subtractor_spectra, const std::vector& converged_filters) { - RTC_DCHECK_EQ(subband_erle_estimator_.Erle().size(), capture_spectra.size()); - RTC_DCHECK_EQ(subband_erle_estimator_.Erle().size(), + RTC_DCHECK_EQ(subband_erle_estimator_.Erle(/*onset_compensated=*/true).size(), + capture_spectra.size()); + RTC_DCHECK_EQ(subband_erle_estimator_.Erle(/*onset_compensated=*/true).size(), subtractor_spectra.size()); const auto& X2_reverb = avg_render_spectrum_with_reverb; const auto& Y2 = capture_spectra; @@ -68,7 +69,9 @@ void ErleEstimator::Update( if (signal_dependent_erle_estimator_) { signal_dependent_erle_estimator_->Update( render_buffer, filter_frequency_responses, X2_reverb, Y2, E2, - subband_erle_estimator_.Erle(), converged_filters); + subband_erle_estimator_.Erle(/*onset_compensated=*/false), + subband_erle_estimator_.Erle(/*onset_compensated=*/true), + converged_filters); } fullband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters); diff --git a/modules/audio_processing/aec3/erle_estimator.h b/modules/audio_processing/aec3/erle_estimator.h index d741cff3da..55797592a9 100644 --- a/modules/audio_processing/aec3/erle_estimator.h +++ b/modules/audio_processing/aec3/erle_estimator.h @@ -55,17 +55,30 @@ class ErleEstimator { const std::vector& converged_filters); // Returns the most recent subband ERLE estimates. - rtc::ArrayView> Erle() const { + rtc::ArrayView> Erle( + bool onset_compensated) const { return signal_dependent_erle_estimator_ - ? signal_dependent_erle_estimator_->Erle() - : subband_erle_estimator_.Erle(); + ? signal_dependent_erle_estimator_->Erle(onset_compensated) + : subband_erle_estimator_.Erle(onset_compensated); + } + + // Returns the non-capped subband ERLE. + rtc::ArrayView> ErleUnbounded() + const { + // Unbounded ERLE is only used with the subband erle estimator where the + // ERLE is often capped at low values. When the signal dependent ERLE + // estimator is used the capped ERLE is returned. + return !signal_dependent_erle_estimator_ + ? subband_erle_estimator_.ErleUnbounded() + : signal_dependent_erle_estimator_->Erle( + /*onset_compensated=*/false); } // Returns the subband ERLE that are estimated during onsets (only used for // testing). - rtc::ArrayView> ErleOnsets() + rtc::ArrayView> ErleDuringOnsets() const { - return subband_erle_estimator_.ErleOnsets(); + return subband_erle_estimator_.ErleDuringOnsets(); } // Returns the fullband ERLE estimate. 
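The hunks above plumb a new SetCaptureOutputUsage() call from EchoCanceller3 through BlockProcessor down to EchoRemover, so that most of the suppressor work can be skipped while the processed capture output is not consumed, for instance when the endpoint is muted. A minimal caller-side sketch follows; it assumes the corresponding virtual is also declared on the EchoControl interface (implied by the override keyword in echo_canceller3.h but not shown in this section), and CaptureMuteForwarder / OnCaptureMuteChanged() are hypothetical names used only for illustration:

#include "api/audio/echo_control.h"

// Hypothetical glue object that forwards the application's capture-mute state
// to the echo controller, letting AEC3 skip suppression-gain computation while
// its output is discarded anyway.
class CaptureMuteForwarder {
 public:
  explicit CaptureMuteForwarder(webrtc::EchoControl* echo_controller)
      : echo_controller_(echo_controller) {}

  void OnCaptureMuteChanged(bool muted) {
    // capture_output_used == false signals that the processed output is unused.
    echo_controller_->SetCaptureOutputUsage(/*capture_output_used=*/!muted);
  }

 private:
  webrtc::EchoControl* const echo_controller_;
};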
diff --git a/modules/audio_processing/aec3/erle_estimator_unittest.cc b/modules/audio_processing/aec3/erle_estimator_unittest.cc index 2a5a98d29f..e38f2386f7 100644 --- a/modules/audio_processing/aec3/erle_estimator_unittest.cc +++ b/modules/audio_processing/aec3/erle_estimator_unittest.cc @@ -50,6 +50,16 @@ void VerifyErle( EXPECT_NEAR(kTrueErle, erle_time_domain, 0.5); } +void VerifyErleGreaterOrEqual( + rtc::ArrayView> erle1, + rtc::ArrayView> erle2) { + for (size_t ch = 0; ch < erle1.size(); ++ch) { + for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) { + EXPECT_GE(erle1[ch][i], erle2[ch][i]); + } + } +} + void FormFarendTimeFrame(std::vector>>* x) { const std::array frame = { 7459.88, 17209.6, 17383, 20768.9, 16816.7, 18386.3, 4492.83, 9675.85, @@ -156,9 +166,10 @@ TEST_P(ErleEstimatorMultiChannel, VerifyErleIncreaseAndHold) { kNumBands, std::vector>( num_render_channels, std::vector(kBlockSize, 0.f))); std::vector>> - filter_frequency_response( - config.filter.refined.length_blocks, - std::vector>(num_capture_channels)); + filter_frequency_response( + config.filter.refined.length_blocks, + std::vector>( + num_capture_channels)); std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); @@ -178,8 +189,13 @@ TEST_P(ErleEstimatorMultiChannel, VerifyErleIncreaseAndHold) { estimator.Update(*render_delay_buffer->GetRenderBuffer(), filter_frequency_response, X2, Y2, E2, converged_filters); } - VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()), - config.erle.max_l, config.erle.max_h); + VerifyErle(estimator.Erle(/*onset_compensated=*/true), + std::pow(2.f, estimator.FullbandErleLog2()), config.erle.max_l, + config.erle.max_h); + VerifyErleGreaterOrEqual(estimator.Erle(/*onset_compensated=*/false), + estimator.Erle(/*onset_compensated=*/true)); + VerifyErleGreaterOrEqual(estimator.ErleUnbounded(), + estimator.Erle(/*onset_compensated=*/false)); FormNearendFrame(&x, &X2, E2, Y2); // Verifies that the ERLE is not immediately decreased during nearend @@ -190,8 +206,13 @@ TEST_P(ErleEstimatorMultiChannel, VerifyErleIncreaseAndHold) { estimator.Update(*render_delay_buffer->GetRenderBuffer(), filter_frequency_response, X2, Y2, E2, converged_filters); } - VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()), - config.erle.max_l, config.erle.max_h); + VerifyErle(estimator.Erle(/*onset_compensated=*/true), + std::pow(2.f, estimator.FullbandErleLog2()), config.erle.max_l, + config.erle.max_h); + VerifyErleGreaterOrEqual(estimator.Erle(/*onset_compensated=*/false), + estimator.Erle(/*onset_compensated=*/true)); + VerifyErleGreaterOrEqual(estimator.ErleUnbounded(), + estimator.Erle(/*onset_compensated=*/false)); } TEST_P(ErleEstimatorMultiChannel, VerifyErleTrackingOnOnsets) { @@ -210,9 +231,10 @@ TEST_P(ErleEstimatorMultiChannel, VerifyErleTrackingOnOnsets) { kNumBands, std::vector>( num_render_channels, std::vector(kBlockSize, 0.f))); std::vector>> - filter_frequency_response( - config.filter.refined.length_blocks, - std::vector>(num_capture_channels)); + filter_frequency_response( + config.filter.refined.length_blocks, + std::vector>( + num_capture_channels)); std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); @@ -253,7 +275,8 @@ TEST_P(ErleEstimatorMultiChannel, VerifyErleTrackingOnOnsets) { converged_filters); } } - VerifyErleBands(estimator.ErleOnsets(), config.erle.min, config.erle.min); + VerifyErleBands(estimator.ErleDuringOnsets(), config.erle.min, 
+ config.erle.min); FormNearendFrame(&x, &X2, E2, Y2); for (size_t k = 0; k < 1000; k++) { estimator.Update(*render_delay_buffer->GetRenderBuffer(), @@ -261,8 +284,9 @@ TEST_P(ErleEstimatorMultiChannel, VerifyErleTrackingOnOnsets) { } // Verifies that during ne activity, Erle converges to the Erle for // onsets. - VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()), - config.erle.min, config.erle.min); + VerifyErle(estimator.Erle(/*onset_compensated=*/true), + std::pow(2.f, estimator.FullbandErleLog2()), config.erle.min, + config.erle.min); } } // namespace webrtc diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc index 64b2d4e697..1721e9c983 100644 --- a/modules/audio_processing/aec3/matched_filter.cc +++ b/modules/audio_processing/aec3/matched_filter.cc @@ -307,7 +307,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, int num_matched_filters, size_t alignment_shift_sub_blocks, float excitation_limit, - float smoothing, + float smoothing_fast, + float smoothing_slow, float matching_filter_threshold) : data_dumper_(data_dumper), optimization_(optimization), @@ -319,7 +320,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, lag_estimates_(num_matched_filters), filters_offsets_(num_matched_filters, 0), excitation_limit_(excitation_limit), - smoothing_(smoothing), + smoothing_fast_(smoothing_fast), + smoothing_slow_(smoothing_slow), matching_filter_threshold_(matching_filter_threshold) { RTC_DCHECK(data_dumper); RTC_DCHECK_LT(0, window_size_sub_blocks); @@ -340,10 +342,14 @@ void MatchedFilter::Reset() { } void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, - rtc::ArrayView capture) { + rtc::ArrayView capture, + bool use_slow_smoothing) { RTC_DCHECK_EQ(sub_block_size_, capture.size()); auto& y = capture; + const float smoothing = + use_slow_smoothing ? 
smoothing_slow_ : smoothing_fast_; + const float x2_sum_threshold = filters_[0].size() * excitation_limit_ * excitation_limit_; @@ -360,25 +366,25 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) case Aec3Optimization::kSse2: - aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, - smoothing_, render_buffer.buffer, y, - filters_[n], &filters_updated, &error_sum); + aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing, + render_buffer.buffer, y, filters_[n], + &filters_updated, &error_sum); break; case Aec3Optimization::kAvx2: - aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, - smoothing_, render_buffer.buffer, y, - filters_[n], &filters_updated, &error_sum); + aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing, + render_buffer.buffer, y, filters_[n], + &filters_updated, &error_sum); break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: - aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, - smoothing_, render_buffer.buffer, y, - filters_[n], &filters_updated, &error_sum); + aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing, + render_buffer.buffer, y, filters_[n], + &filters_updated, &error_sum); break; #endif default: - aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing_, + aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, filters_[n], &filters_updated, &error_sum); } diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h index fa44eb27fd..c6410ab4ee 100644 --- a/modules/audio_processing/aec3/matched_filter.h +++ b/modules/audio_processing/aec3/matched_filter.h @@ -100,7 +100,8 @@ class MatchedFilter { int num_matched_filters, size_t alignment_shift_sub_blocks, float excitation_limit, - float smoothing, + float smoothing_fast, + float smoothing_slow, float matching_filter_threshold); MatchedFilter() = delete; @@ -111,7 +112,8 @@ class MatchedFilter { // Updates the correlation with the values in the capture buffer. void Update(const DownsampledRenderBuffer& render_buffer, - rtc::ArrayView capture); + rtc::ArrayView capture, + bool use_slow_smoothing); // Resets the matched filter. void Reset(); @@ -140,7 +142,8 @@ class MatchedFilter { std::vector lag_estimates_; std::vector filters_offsets_; const float excitation_limit_; - const float smoothing_; + const float smoothing_fast_; + const float smoothing_slow_; const float matching_filter_threshold_; }; diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h index d48011e477..612bd5d942 100644 --- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h +++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h @@ -45,6 +45,9 @@ class MatchedFilterLagAggregator { absl::optional Aggregate( rtc::ArrayView lag_estimates); + // Returns whether a reliable delay estimate has been found. 
+ bool ReliableDelayFound() const { return significant_candidate_found_; } + private: ApmDataDumper* const data_dumper_; std::vector histogram_; diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc index 137275fd74..37b51fa624 100644 --- a/modules/audio_processing/aec3/matched_filter_unittest.cc +++ b/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -206,6 +206,7 @@ TEST(MatchedFilter, LagEstimation) { kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold); std::unique_ptr render_delay_buffer( @@ -231,7 +232,7 @@ TEST(MatchedFilter, LagEstimation) { downsampled_capture_data.data(), sub_block_size); capture_decimator.Decimate(capture[0], downsampled_capture); filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), - downsampled_capture); + downsampled_capture, false); } // Obtain the lag estimates. @@ -318,6 +319,7 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold); // Analyze the correlation between render and capture. @@ -325,7 +327,8 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) { RandomizeSampleVector(&random_generator, render[0][0]); RandomizeSampleVector(&random_generator, capture); render_delay_buffer->Insert(render); - filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture); + filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture, + false); } // Obtain the lag estimates. @@ -361,6 +364,7 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold); std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, @@ -379,7 +383,7 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) { sub_block_size); capture_decimator.Decimate(capture[0], downsampled_capture); filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), - downsampled_capture); + downsampled_capture, false); } // Obtain the lag estimates. 
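MatchedFilter::Update() now selects between two adaptation rates: smoothing_fast while no reliable delay has been found, and smoothing_slow once MatchedFilterLagAggregator::ReliableDelayFound() reports a stable estimate, which keeps the filters from drifting after convergence. A simplified scalar sketch of the per-tap update under that switch; the production code dispatches to the SSE2/AVX2/NEON kernels shown above and applies an excitation threshold, so the function below is illustrative only:

#include <cstddef>
#include <vector>

// NLMS-style update of one matched filter, with the step size chosen from the
// fast/slow smoothing pair depending on whether a reliable delay was found.
void UpdateMatchedFilterTaps(const std::vector<float>& x,  // render history
                             float y,                      // capture sample
                             bool reliable_delay_found,
                             float smoothing_fast,
                             float smoothing_slow,
                             std::vector<float>* h) {
  const float smoothing =
      reliable_delay_found ? smoothing_slow : smoothing_fast;
  float x2_sum = 0.f;
  float error = y;
  for (size_t k = 0; k < h->size(); ++k) {
    x2_sum += x[k] * x[k];
    error -= (*h)[k] * x[k];
  }
  if (x2_sum > 0.f) {
    const float alpha = smoothing * error / x2_sum;
    for (size_t k = 0; k < h->size(); ++k) {
      (*h)[k] += alpha * x[k];
    }
  }
}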
@@ -407,6 +411,7 @@ TEST(MatchedFilter, NumberOfLagEstimates) { MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size, 32, num_matched_filters, 1, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold); EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size()); } @@ -421,6 +426,7 @@ TEST(MatchedFilterDeathTest, ZeroWindowSize) { EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold), ""); } @@ -430,6 +436,7 @@ TEST(MatchedFilterDeathTest, NullDataDumper) { EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold), ""); } @@ -441,6 +448,7 @@ TEST(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) { EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold), ""); } @@ -453,6 +461,7 @@ TEST(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) { EchoCanceller3Config config; EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1, 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, config.delay.delay_candidate_detection_threshold), ""); } diff --git a/modules/audio_processing/aec3/mock/mock_block_processor.h b/modules/audio_processing/aec3/mock/mock_block_processor.h index e1eb26702f..aa612257ea 100644 --- a/modules/audio_processing/aec3/mock/mock_block_processor.h +++ b/modules/audio_processing/aec3/mock/mock_block_processor.h @@ -44,6 +44,10 @@ class MockBlockProcessor : public BlockProcessor { (EchoControl::Metrics * metrics), (const, override)); MOCK_METHOD(void, SetAudioBufferDelay, (int delay_ms), (override)); + MOCK_METHOD(void, + SetCaptureOutputUsage, + (bool capture_output_used), + (override)); }; } // namespace test diff --git a/modules/audio_processing/aec3/mock/mock_echo_remover.h b/modules/audio_processing/aec3/mock/mock_echo_remover.h index 8a3044bcf1..60c5bf433e 100644 --- a/modules/audio_processing/aec3/mock/mock_echo_remover.h +++ b/modules/audio_processing/aec3/mock/mock_echo_remover.h @@ -44,6 +44,10 @@ class MockEchoRemover : public EchoRemover { GetMetrics, (EchoControl::Metrics * metrics), (const, override)); + MOCK_METHOD(void, + SetCaptureOutputUsage, + (bool capture_output_used), + (override)); }; } // namespace test diff --git a/modules/audio_processing/aec3/residual_echo_estimator.cc b/modules/audio_processing/aec3/residual_echo_estimator.cc index e352cf5552..15bebecb5f 100644 --- a/modules/audio_processing/aec3/residual_echo_estimator.cc +++ b/modules/audio_processing/aec3/residual_echo_estimator.cc @@ -45,6 +45,13 @@ float GetLateReflectionsDefaultModeGain( return config.default_gain; } +bool UseErleOnsetCompensationInDominantNearend( + const EchoCanceller3Config::EpStrength& config) { + return config.erle_onset_compensation_in_dominant_nearend || + field_trial::IsEnabled( + "WebRTC-Aec3UseErleOnsetCompensationInDominantNearend"); +} + // Computes the indexes that will be used for computing spectral power 
over // the blocks surrounding the delay. void GetRenderIndexesToAnalyze( @@ -84,22 +91,6 @@ void LinearEstimate( } } -// Estimates the residual echo power based on an uncertainty estimate of the -// echo return loss enhancement (ERLE) and the linear power estimate. -void LinearEstimate( - rtc::ArrayView> S2_linear, - float erle_uncertainty, - rtc::ArrayView> R2) { - RTC_DCHECK_EQ(S2_linear.size(), R2.size()); - - const size_t num_capture_channels = R2.size(); - for (size_t ch = 0; ch < num_capture_channels; ++ch) { - for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - R2[ch][k] = S2_linear[ch][k] * erle_uncertainty; - } - } -} - // Estimates the residual echo power based on the estimate of the echo path // gain. void NonLinearEstimate( @@ -172,7 +163,9 @@ ResidualEchoEstimator::ResidualEchoEstimator(const EchoCanceller3Config& config, early_reflections_general_gain_( GetEarlyReflectionsDefaultModeGain(config_.ep_strength)), late_reflections_general_gain_( - GetLateReflectionsDefaultModeGain(config_.ep_strength)) { + GetLateReflectionsDefaultModeGain(config_.ep_strength)), + erle_onset_compensation_in_dominant_nearend_( + UseErleOnsetCompensationInDominantNearend(config_.ep_strength)) { Reset(); } @@ -183,7 +176,9 @@ void ResidualEchoEstimator::Estimate( const RenderBuffer& render_buffer, rtc::ArrayView> S2_linear, rtc::ArrayView> Y2, - rtc::ArrayView> R2) { + bool dominant_nearend, + rtc::ArrayView> R2, + rtc::ArrayView> R2_unbounded) { RTC_DCHECK_EQ(R2.size(), Y2.size()); RTC_DCHECK_EQ(R2.size(), S2_linear.size()); @@ -199,17 +194,18 @@ void ResidualEchoEstimator::Estimate( if (aec_state.SaturatedEcho()) { for (size_t ch = 0; ch < num_capture_channels; ++ch) { std::copy(Y2[ch].begin(), Y2[ch].end(), R2[ch].begin()); + std::copy(Y2[ch].begin(), Y2[ch].end(), R2_unbounded[ch].begin()); } } else { - absl::optional erle_uncertainty = aec_state.ErleUncertainty(); - if (erle_uncertainty) { - LinearEstimate(S2_linear, *erle_uncertainty, R2); - } else { - LinearEstimate(S2_linear, aec_state.Erle(), R2); - } + const bool onset_compensated = + erle_onset_compensation_in_dominant_nearend_ || !dominant_nearend; + LinearEstimate(S2_linear, aec_state.Erle(onset_compensated), R2); + LinearEstimate(S2_linear, aec_state.ErleUnbounded(), R2_unbounded); } - AddReverb(ReverbType::kLinear, aec_state, render_buffer, R2); + UpdateReverb(ReverbType::kLinear, aec_state, render_buffer); + AddReverb(R2); + AddReverb(R2_unbounded); } else { const float echo_path_gain = GetEchoPathGain(aec_state, /*gain_for_early_reflections=*/true); @@ -219,6 +215,7 @@ void ResidualEchoEstimator::Estimate( if (aec_state.SaturatedEcho()) { for (size_t ch = 0; ch < num_capture_channels; ++ch) { std::copy(Y2[ch].begin(), Y2[ch].end(), R2[ch].begin()); + std::copy(Y2[ch].begin(), Y2[ch].end(), R2_unbounded[ch].begin()); } } else { // Estimate the echo generating signal power. 
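With AecState::ErleUncertainty() removed, the linear path of the residual echo estimator always derives its estimate from the linear echo power and an ERLE estimate, and it now fills two outputs in parallel: R2 from the capped ERLE (onset-compensated except during dominant nearend, unless erle_onset_compensation_in_dominant_nearend is set) and R2_unbounded from the non-capped ERLE used by the dominant-nearend detector. A minimal per-band sketch of that relationship, using plain vectors instead of the per-channel spectra of the production code:

#include <algorithm>
#include <cstddef>
#include <vector>

// Residual echo power from the linear echo estimate and an ERLE estimate:
// R2[k] = S2_linear[k] / erle[k]. The same relation is evaluated twice, once
// with the capped ERLE and once with the unbounded ERLE.
std::vector<float> ResidualFromErle(const std::vector<float>& S2_linear,
                                    const std::vector<float>& erle) {
  std::vector<float> R2(S2_linear.size());
  for (size_t k = 0; k < S2_linear.size(); ++k) {
    R2[k] = S2_linear[k] / std::max(erle[k], 1e-6f);
  }
  return R2;
}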
@@ -238,11 +235,14 @@ void ResidualEchoEstimator::Estimate( } NonLinearEstimate(echo_path_gain, X2, R2); + NonLinearEstimate(echo_path_gain, X2, R2_unbounded); } if (config_.echo_model.model_reverb_in_nonlinear_mode && !aec_state.TransparentModeActive()) { - AddReverb(ReverbType::kNonLinear, aec_state, render_buffer, R2); + UpdateReverb(ReverbType::kNonLinear, aec_state, render_buffer); + AddReverb(R2); + AddReverb(R2_unbounded); } } @@ -253,6 +253,7 @@ void ResidualEchoEstimator::Estimate( for (size_t ch = 0; ch < num_capture_channels; ++ch) { for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { R2[ch][k] *= residual_scaling[k]; + R2_unbounded[ch][k] *= residual_scaling[k]; } } } @@ -301,14 +302,10 @@ void ResidualEchoEstimator::UpdateRenderNoisePower( } } -// Adds the estimated power of the reverb to the residual echo power. -void ResidualEchoEstimator::AddReverb( - ReverbType reverb_type, - const AecState& aec_state, - const RenderBuffer& render_buffer, - rtc::ArrayView> R2) { - const size_t num_capture_channels = R2.size(); - +// Updates the reverb estimation. +void ResidualEchoEstimator::UpdateReverb(ReverbType reverb_type, + const AecState& aec_state, + const RenderBuffer& render_buffer) { // Choose reverb partition based on what type of echo power model is used. const size_t first_reverb_partition = reverb_type == ReverbType::kLinear @@ -343,6 +340,11 @@ void ResidualEchoEstimator::AddReverb( echo_reverb_.UpdateReverbNoFreqShaping(render_power, echo_path_gain, aec_state.ReverbDecay()); } +} +// Adds the estimated power of the reverb to the residual echo power. +void ResidualEchoEstimator::AddReverb( + rtc::ArrayView> R2) const { + const size_t num_capture_channels = R2.size(); // Add the reverb power. rtc::ArrayView reverb_power = diff --git a/modules/audio_processing/aec3/residual_echo_estimator.h b/modules/audio_processing/aec3/residual_echo_estimator.h index 8fe7a84f04..c071854c4a 100644 --- a/modules/audio_processing/aec3/residual_echo_estimator.h +++ b/modules/audio_processing/aec3/residual_echo_estimator.h @@ -39,7 +39,9 @@ class ResidualEchoEstimator { const RenderBuffer& render_buffer, rtc::ArrayView> S2_linear, rtc::ArrayView> Y2, - rtc::ArrayView> R2); + bool dominant_nearend, + rtc::ArrayView> R2, + rtc::ArrayView> R2_unbounded); private: enum class ReverbType { kLinear, kNonLinear }; @@ -51,12 +53,15 @@ class ResidualEchoEstimator { // render signal. void UpdateRenderNoisePower(const RenderBuffer& render_buffer); + // Updates the reverb estimation. + void UpdateReverb(ReverbType reverb_type, + const AecState& aec_state, + const RenderBuffer& render_buffer); + // Adds the estimated unmodelled echo power to the residual echo power // estimate. - void AddReverb(ReverbType reverb_type, - const AecState& aec_state, - const RenderBuffer& render_buffer, - rtc::ArrayView> R2); + void AddReverb( + rtc::ArrayView> R2) const; // Gets the echo path gain to apply. 
float GetEchoPathGain(const AecState& aec_state, @@ -68,6 +73,7 @@ class ResidualEchoEstimator { const float late_reflections_transparent_mode_gain_; const float early_reflections_general_gain_; const float late_reflections_general_gain_; + const bool erle_onset_compensation_in_dominant_nearend_; std::array X2_noise_floor_; std::array X2_noise_floor_counter_; ReverbModel echo_reverb_; diff --git a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc index f184eb8e6d..3d760b7dda 100644 --- a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc +++ b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc @@ -48,6 +48,8 @@ TEST_P(ResidualEchoEstimatorMultiChannel, BasicTest) { num_capture_channels); std::vector> Y2(num_capture_channels); std::vector> R2(num_capture_channels); + std::vector> R2_unbounded( + num_capture_channels); std::vector>> x( kNumBands, std::vector>( num_render_channels, std::vector(kBlockSize, 0.f))); @@ -100,7 +102,8 @@ TEST_P(ResidualEchoEstimatorMultiChannel, BasicTest) { output); estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(), - S2_linear, Y2, R2); + S2_linear, Y2, /*dominant_nearend=*/false, R2, + R2_unbounded); } } diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc b/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc index 5a3ba6c842..a5e77092a6 100644 --- a/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc +++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc @@ -131,7 +131,9 @@ SignalDependentErleEstimator::SignalDependentErleEstimator( section_boundaries_blocks_(SetSectionsBoundaries(delay_headroom_blocks_, num_blocks_, num_sections_)), + use_onset_detection_(config.erle.onset_detection), erle_(num_capture_channels), + erle_onset_compensated_(num_capture_channels), S2_section_accum_( num_capture_channels, std::vector>(num_sections_)), @@ -154,6 +156,7 @@ SignalDependentErleEstimator::~SignalDependentErleEstimator() = default; void SignalDependentErleEstimator::Reset() { for (size_t ch = 0; ch < erle_.size(); ++ch) { erle_[ch].fill(min_erle_); + erle_onset_compensated_[ch].fill(min_erle_); for (auto& erle_estimator : erle_estimators_[ch]) { erle_estimator.fill(min_erle_); } @@ -180,6 +183,8 @@ void SignalDependentErleEstimator::Update( rtc::ArrayView> Y2, rtc::ArrayView> E2, rtc::ArrayView> average_erle, + rtc::ArrayView> + average_erle_onset_compensated, const std::vector& converged_filters) { RTC_DCHECK_GT(num_sections_, 1); @@ -202,6 +207,11 @@ void SignalDependentErleEstimator::Update( [band_to_subband_[k]]; erle_[ch][k] = rtc::SafeClamp(average_erle[ch][k] * correction_factor, min_erle_, max_erle_[band_to_subband_[k]]); + if (use_onset_detection_) { + erle_onset_compensated_[ch][k] = rtc::SafeClamp( + average_erle_onset_compensated[ch][k] * correction_factor, + min_erle_, max_erle_[band_to_subband_[k]]); + } } } } diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator.h b/modules/audio_processing/aec3/signal_dependent_erle_estimator.h index 498e922f13..6847c1ab13 100644 --- a/modules/audio_processing/aec3/signal_dependent_erle_estimator.h +++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator.h @@ -37,8 +37,10 @@ class SignalDependentErleEstimator { void Reset(); // Returns the Erle per frequency subband. 
- rtc::ArrayView> Erle() const { - return erle_; + rtc::ArrayView> Erle( + bool onset_compensated) const { + return onset_compensated && use_onset_detection_ ? erle_onset_compensated_ + : erle_; } // Updates the Erle estimate. The Erle that is passed as an input is required @@ -51,6 +53,8 @@ class SignalDependentErleEstimator { rtc::ArrayView> Y2, rtc::ArrayView> E2, rtc::ArrayView> average_erle, + rtc::ArrayView> + average_erle_onset_compensated, const std::vector& converged_filters); void Dump(const std::unique_ptr& data_dumper) const; @@ -83,7 +87,9 @@ class SignalDependentErleEstimator { const std::array band_to_subband_; const std::array max_erle_; const std::vector section_boundaries_blocks_; + const bool use_onset_detection_; std::vector> erle_; + std::vector> erle_onset_compensated_; std::vector>> S2_section_accum_; std::vector>> erle_estimators_; diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc b/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc index f8a4aece89..58f56d8d53 100644 --- a/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc +++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc @@ -172,7 +172,7 @@ TEST_P(SignalDependentErleEstimatorMultiChannel, SweepSettings) { for (size_t n = 0; n < 10; ++n) { inputs.Update(); s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), - inputs.GetY2(), inputs.GetE2(), average_erle, + inputs.GetY2(), inputs.GetE2(), average_erle, average_erle, inputs.GetConvergedFilters()); } } @@ -201,7 +201,7 @@ TEST_P(SignalDependentErleEstimatorMultiChannel, LongerRun) { for (size_t n = 0; n < 200; ++n) { inputs.Update(); s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), - inputs.GetY2(), inputs.GetE2(), average_erle, + inputs.GetY2(), inputs.GetE2(), average_erle, average_erle, inputs.GetConvergedFilters()); } } diff --git a/modules/audio_processing/aec3/subband_erle_estimator.cc b/modules/audio_processing/aec3/subband_erle_estimator.cc index 6c00091266..dc7f92fd99 100644 --- a/modules/audio_processing/aec3/subband_erle_estimator.cc +++ b/modules/audio_processing/aec3/subband_erle_estimator.cc @@ -48,7 +48,9 @@ SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config, use_min_erle_during_onsets_(EnableMinErleDuringOnsets()), accum_spectra_(num_capture_channels), erle_(num_capture_channels), - erle_onsets_(num_capture_channels), + erle_onset_compensated_(num_capture_channels), + erle_unbounded_(num_capture_channels), + erle_during_onsets_(num_capture_channels), coming_onset_(num_capture_channels), hold_counters_(num_capture_channels) { Reset(); @@ -57,11 +59,12 @@ SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config, SubbandErleEstimator::~SubbandErleEstimator() = default; void SubbandErleEstimator::Reset() { - for (auto& erle : erle_) { - erle.fill(min_erle_); - } - for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) { - erle_onsets_[ch].fill(min_erle_); + const size_t num_capture_channels = erle_.size(); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + erle_[ch].fill(min_erle_); + erle_onset_compensated_[ch].fill(min_erle_); + erle_unbounded_[ch].fill(min_erle_); + erle_during_onsets_[ch].fill(min_erle_); coming_onset_[ch].fill(true); hold_counters_[ch].fill(0); } @@ -80,15 +83,25 @@ void SubbandErleEstimator::Update( DecreaseErlePerBandForLowRenderSignals(); } - for (auto& erle : erle_) { + const size_t num_capture_channels = erle_.size(); + for (size_t ch 
= 0; ch < num_capture_channels; ++ch) { + auto& erle = erle_[ch]; erle[0] = erle[1]; erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1]; + + auto& erle_oc = erle_onset_compensated_[ch]; + erle_oc[0] = erle_oc[1]; + erle_oc[kFftLengthBy2] = erle_oc[kFftLengthBy2 - 1]; + + auto& erle_u = erle_unbounded_[ch]; + erle_u[0] = erle_u[1]; + erle_u[kFftLengthBy2] = erle_u[kFftLengthBy2 - 1]; } } void SubbandErleEstimator::Dump( const std::unique_ptr& data_dumper) const { - data_dumper->DumpRaw("aec3_erle_onset", ErleOnsets()[0]); + data_dumper->DumpRaw("aec3_erle_onset", ErleDuringOnsets()[0]); } void SubbandErleEstimator::UpdateBands( @@ -102,13 +115,16 @@ void SubbandErleEstimator::UpdateBands( continue; } + if (accum_spectra_.num_points[ch] != kPointsToAccumulate) { + continue; + } + std::array new_erle; std::array is_erle_updated; is_erle_updated.fill(false); for (size_t k = 1; k < kFftLengthBy2; ++k) { - if (accum_spectra_.num_points[ch] == kPointsToAccumulate && - accum_spectra_.E2[ch][k] > 0.f) { + if (accum_spectra_.E2[ch][k] > 0.f) { new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k]; is_erle_updated[k] = true; } @@ -120,10 +136,11 @@ void SubbandErleEstimator::UpdateBands( if (coming_onset_[ch][k]) { coming_onset_[ch][k] = false; if (!use_min_erle_during_onsets_) { - float alpha = new_erle[k] < erle_onsets_[ch][k] ? 0.3f : 0.15f; - erle_onsets_[ch][k] = rtc::SafeClamp( - erle_onsets_[ch][k] + - alpha * (new_erle[k] - erle_onsets_[ch][k]), + float alpha = + new_erle[k] < erle_during_onsets_[ch][k] ? 0.3f : 0.15f; + erle_during_onsets_[ch][k] = rtc::SafeClamp( + erle_during_onsets_[ch][k] + + alpha * (new_erle[k] - erle_during_onsets_[ch][k]), min_erle_, max_erle_[k]); } } @@ -132,15 +149,31 @@ void SubbandErleEstimator::UpdateBands( } } + auto update_erle_band = [](float& erle, float new_erle, + bool low_render_energy, float min_erle, + float max_erle) { + float alpha = 0.05f; + if (new_erle < erle) { + alpha = low_render_energy ? 0.f : 0.1f; + } + erle = + rtc::SafeClamp(erle + alpha * (new_erle - erle), min_erle, max_erle); + }; + for (size_t k = 1; k < kFftLengthBy2; ++k) { if (is_erle_updated[k]) { - float alpha = 0.05f; - if (new_erle[k] < erle_[ch][k]) { - alpha = accum_spectra_.low_render_energy[ch][k] ? 0.f : 0.1f; + const bool low_render_energy = accum_spectra_.low_render_energy[ch][k]; + update_erle_band(erle_[ch][k], new_erle[k], low_render_energy, + min_erle_, max_erle_[k]); + if (use_onset_detection_) { + update_erle_band(erle_onset_compensated_[ch][k], new_erle[k], + low_render_energy, min_erle_, max_erle_[k]); } - erle_[ch][k] = - rtc::SafeClamp(erle_[ch][k] + alpha * (new_erle[k] - erle_[ch][k]), - min_erle_, max_erle_[k]); + + // Virtually unbounded ERLE. 
+ constexpr float kUnboundedErleMax = 100000.0f; + update_erle_band(erle_unbounded_[ch][k], new_erle[k], low_render_energy, + min_erle_, kUnboundedErleMax); } } } @@ -153,9 +186,11 @@ void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() { --hold_counters_[ch][k]; if (hold_counters_[ch][k] <= (kBlocksForOnsetDetection - kBlocksToHoldErle)) { - if (erle_[ch][k] > erle_onsets_[ch][k]) { - erle_[ch][k] = std::max(erle_onsets_[ch][k], 0.97f * erle_[ch][k]); - RTC_DCHECK_LE(min_erle_, erle_[ch][k]); + if (erle_onset_compensated_[ch][k] > erle_during_onsets_[ch][k]) { + erle_onset_compensated_[ch][k] = + std::max(erle_during_onsets_[ch][k], + 0.97f * erle_onset_compensated_[ch][k]); + RTC_DCHECK_LE(min_erle_, erle_onset_compensated_[ch][k]); } if (hold_counters_[ch][k] <= 0) { coming_onset_[ch][k] = true; @@ -167,7 +202,7 @@ void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() { } void SubbandErleEstimator::ResetAccumulatedSpectra() { - for (size_t ch = 0; ch < erle_onsets_.size(); ++ch) { + for (size_t ch = 0; ch < erle_during_onsets_.size(); ++ch) { accum_spectra_.Y2[ch].fill(0.f); accum_spectra_.E2[ch].fill(0.f); accum_spectra_.num_points[ch] = 0; diff --git a/modules/audio_processing/aec3/subband_erle_estimator.h b/modules/audio_processing/aec3/subband_erle_estimator.h index 90363e081d..8bf9c4d645 100644 --- a/modules/audio_processing/aec3/subband_erle_estimator.h +++ b/modules/audio_processing/aec3/subband_erle_estimator.h @@ -41,14 +41,22 @@ class SubbandErleEstimator { const std::vector& converged_filters); // Returns the ERLE estimate. - rtc::ArrayView> Erle() const { - return erle_; + rtc::ArrayView> Erle( + bool onset_compensated) const { + return onset_compensated && use_onset_detection_ ? erle_onset_compensated_ + : erle_; + } + + // Returns the non-capped ERLE estimate. + rtc::ArrayView> ErleUnbounded() + const { + return erle_unbounded_; } // Returns the ERLE estimate at onsets (only used for testing). - rtc::ArrayView> ErleOnsets() + rtc::ArrayView> ErleDuringOnsets() const { - return erle_onsets_; + return erle_during_onsets_; } void Dump(const std::unique_ptr& data_dumper) const; @@ -82,8 +90,13 @@ class SubbandErleEstimator { const std::array max_erle_; const bool use_min_erle_during_onsets_; AccumulatedSpectra accum_spectra_; + // ERLE without special handling of render onsets. std::vector> erle_; - std::vector> erle_onsets_; + // ERLE lowered during render onsets. + std::vector> erle_onset_compensated_; + std::vector> erle_unbounded_; + // Estimation of ERLE during render onsets. + std::vector> erle_during_onsets_; std::vector> coming_onset_; std::vector> hold_counters_; }; diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc index d10e4ffc52..2eae686752 100644 --- a/modules/audio_processing/aec3/subtractor.cc +++ b/modules/audio_processing/aec3/subtractor.cc @@ -91,7 +91,20 @@ Subtractor::Subtractor(const EchoCanceller3Config& config, std::vector(GetTimeDomainLength(std::max( config_.filter.refined_initial.length_blocks, config_.filter.refined.length_blocks)), - 0.f)) { + 0.f)), + coarse_impulse_responses_(0) { + // Set up the storing of coarse impulse responses if data dumping is + // available. 
+ if (ApmDataDumper::IsAvailable()) { + coarse_impulse_responses_.resize(num_capture_channels_); + const size_t filter_size = GetTimeDomainLength( + std::max(config_.filter.coarse_initial.length_blocks, + config_.filter.coarse.length_blocks)); + for (std::vector& impulse_response : coarse_impulse_responses_) { + impulse_response.resize(filter_size, 0.f); + } + } + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { refined_filters_[ch] = std::make_unique( config_.filter.refined.length_blocks, @@ -285,7 +298,14 @@ void Subtractor::Process(const RenderBuffer& render_buffer, config_.filter.coarse_reset_hangover_blocks; } - coarse_filter_[ch]->Adapt(render_buffer, G); + if (ApmDataDumper::IsAvailable()) { + RTC_DCHECK_LT(ch, coarse_impulse_responses_.size()); + coarse_filter_[ch]->Adapt(render_buffer, G, + &coarse_impulse_responses_[ch]); + } else { + coarse_filter_[ch]->Adapt(render_buffer, G); + } + if (ch == 0) { data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.re); data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.im); diff --git a/modules/audio_processing/aec3/subtractor.h b/modules/audio_processing/aec3/subtractor.h index 560f6568eb..767e4aad46 100644 --- a/modules/audio_processing/aec3/subtractor.h +++ b/modules/audio_processing/aec3/subtractor.h @@ -78,6 +78,15 @@ class Subtractor { refined_impulse_responses_[0].data(), GetTimeDomainLength( refined_filters_[0]->max_filter_size_partitions()))); + if (ApmDataDumper::IsAvailable()) { + RTC_DCHECK_GT(coarse_impulse_responses_.size(), 0); + data_dumper_->DumpRaw( + "aec3_subtractor_h_coarse", + rtc::ArrayView( + coarse_impulse_responses_[0].data(), + GetTimeDomainLength( + coarse_filter_[0]->max_filter_size_partitions()))); + } refined_filters_[0]->DumpFilter("aec3_subtractor_H_refined"); coarse_filter_[0]->DumpFilter("aec3_subtractor_H_coarse"); @@ -132,6 +141,7 @@ class Subtractor { std::vector>> refined_frequency_responses_; std::vector> refined_impulse_responses_; + std::vector> coarse_impulse_responses_; }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/suppression_gain.cc b/modules/audio_processing/aec3/suppression_gain.cc index 5b01c52908..6405d71c2d 100644 --- a/modules/audio_processing/aec3/suppression_gain.cc +++ b/modules/audio_processing/aec3/suppression_gain.cc @@ -23,10 +23,15 @@ #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/atomic_ops.h" #include "rtc_base/checks.h" +#include "system_wrappers/include/field_trial.h" namespace webrtc { namespace { +bool UseUnboundedEchoSpectrum() { + return field_trial::IsEnabled("WebRTC-Aec3UseUnboundedEchoSpectrum"); +} + void LimitLowFrequencyGains(std::array* gain) { // Limit the low frequency gains to avoid the impact of the high-pass filter // on the lower-frequency gain influencing the overall achieved gain. @@ -230,16 +235,20 @@ void SuppressionGain::GetMinGain( min_gain[k] = std::min(min_gain[k], 1.f); } - const bool is_nearend_state = dominant_nearend_detector_->IsNearendState(); - for (size_t k = 0; k < 6; ++k) { - const auto& dec = is_nearend_state ? nearend_params_.max_dec_factor_lf - : normal_params_.max_dec_factor_lf; - - // Make sure the gains of the low frequencies do not decrease too - // quickly after strong nearend. - if (last_nearend[k] > last_echo[k]) { - min_gain[k] = std::max(min_gain[k], last_gain_[k] * dec); - min_gain[k] = std::min(min_gain[k], 1.f); + if (!initial_state_ || + config_.suppressor.lf_smoothing_during_initial_phase) { + const float& dec = dominant_nearend_detector_->IsNearendState() + ? 
nearend_params_.max_dec_factor_lf + : normal_params_.max_dec_factor_lf; + + for (int k = 0; k <= config_.suppressor.last_lf_smoothing_band; ++k) { + // Make sure the gains of the low frequencies do not decrease too + // quickly after strong nearend. + if (last_nearend[k] > last_echo[k] || + k <= config_.suppressor.last_permanent_lf_smoothing_band) { + min_gain[k] = std::max(min_gain[k], last_gain_[k] * dec); + min_gain[k] = std::min(min_gain[k], 1.f); + } } } } else { @@ -333,8 +342,13 @@ SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, num_capture_channels_, aec3::MovingAverage(kFftLengthBy2Plus1, config.suppressor.nearend_average_blocks)), - nearend_params_(config_.suppressor.nearend_tuning), - normal_params_(config_.suppressor.normal_tuning) { + nearend_params_(config_.suppressor.last_lf_band, + config_.suppressor.first_hf_band, + config_.suppressor.nearend_tuning), + normal_params_(config_.suppressor.last_lf_band, + config_.suppressor.first_hf_band, + config_.suppressor.normal_tuning), + use_unbounded_echo_spectrum_(UseUnboundedEchoSpectrum()) { RTC_DCHECK_LT(0, state_change_duration_blocks_); last_gain_.fill(1.f); if (config_.suppressor.use_subband_nearend_detection) { @@ -355,6 +369,8 @@ void SuppressionGain::GetGain( rtc::ArrayView> echo_spectrum, rtc::ArrayView> residual_echo_spectrum, + rtc::ArrayView> + residual_echo_spectrum_unbounded, rtc::ArrayView> comfort_noise_spectrum, const RenderSignalAnalyzer& render_signal_analyzer, @@ -366,8 +382,13 @@ void SuppressionGain::GetGain( RTC_DCHECK(high_bands_gain); RTC_DCHECK(low_band_gain); + // Choose residual echo spectrum for the dominant nearend detector. + const auto echo = use_unbounded_echo_spectrum_ + ? residual_echo_spectrum_unbounded + : residual_echo_spectrum; + // Update the nearend state selection. - dominant_nearend_detector_->Update(nearend_spectrum, residual_echo_spectrum, + dominant_nearend_detector_->Update(nearend_spectrum, echo, comfort_noise_spectrum, initial_state_); // Compute gain for the lower band. @@ -383,6 +404,9 @@ void SuppressionGain::GetGain( *high_bands_gain = UpperBandsGain(echo_spectrum, comfort_noise_spectrum, narrow_peak_band, aec_state.SaturatedEcho(), render, *low_band_gain); + + data_dumper_->DumpRaw("aec3_dominant_nearend", + dominant_nearend_detector_->IsNearendState()); } void SuppressionGain::SetInitialState(bool state) { @@ -419,23 +443,23 @@ bool SuppressionGain::LowNoiseRenderDetector::Detect( } SuppressionGain::GainParameters::GainParameters( + int last_lf_band, + int first_hf_band, const EchoCanceller3Config::Suppressor::Tuning& tuning) : max_inc_factor(tuning.max_inc_factor), max_dec_factor_lf(tuning.max_dec_factor_lf) { // Compute per-band masking thresholds. 
- constexpr size_t kLastLfBand = 5; - constexpr size_t kFirstHfBand = 8; - RTC_DCHECK_LT(kLastLfBand, kFirstHfBand); + RTC_DCHECK_LT(last_lf_band, first_hf_band); auto& lf = tuning.mask_lf; auto& hf = tuning.mask_hf; RTC_DCHECK_LT(lf.enr_transparent, lf.enr_suppress); RTC_DCHECK_LT(hf.enr_transparent, hf.enr_suppress); - for (size_t k = 0; k < kFftLengthBy2Plus1; k++) { + for (int k = 0; k < static_cast(kFftLengthBy2Plus1); k++) { float a; - if (k <= kLastLfBand) { + if (k <= last_lf_band) { a = 0.f; - } else if (k < kFirstHfBand) { - a = (k - kLastLfBand) / static_cast(kFirstHfBand - kLastLfBand); + } else if (k < first_hf_band) { + a = (k - last_lf_band) / static_cast(first_hf_band - last_lf_band); } else { a = 1.f; } diff --git a/modules/audio_processing/aec3/suppression_gain.h b/modules/audio_processing/aec3/suppression_gain.h index e7175c36da..7c4a1c9f7d 100644 --- a/modules/audio_processing/aec3/suppression_gain.h +++ b/modules/audio_processing/aec3/suppression_gain.h @@ -42,6 +42,8 @@ class SuppressionGain { rtc::ArrayView> echo_spectrum, rtc::ArrayView> residual_echo_spectrum, + rtc::ArrayView> + residual_echo_spectrum_unbounded, rtc::ArrayView> comfort_noise_spectrum, const RenderSignalAnalyzer& render_signal_analyzer, @@ -51,6 +53,10 @@ class SuppressionGain { float* high_bands_gain, std::array* low_band_gain); + bool IsDominantNearend() { + return dominant_nearend_detector_->IsNearendState(); + } + // Toggles the usage of the initial state. void SetInitialState(bool state); @@ -99,6 +105,8 @@ class SuppressionGain { struct GainParameters { explicit GainParameters( + int last_lf_band, + int first_hf_band, const EchoCanceller3Config::Suppressor::Tuning& tuning); const float max_inc_factor; const float max_dec_factor_lf; @@ -122,6 +130,9 @@ class SuppressionGain { std::vector nearend_smoothers_; const GainParameters nearend_params_; const GainParameters normal_params_; + // Determines if the dominant nearend detector uses the unbounded residual + // echo spectrum. + const bool use_unbounded_echo_spectrum_; std::unique_ptr dominant_nearend_detector_; RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionGain); diff --git a/modules/audio_processing/aec3/suppression_gain_unittest.cc b/modules/audio_processing/aec3/suppression_gain_unittest.cc index 26bfc24ebb..999b0f27ab 100644 --- a/modules/audio_processing/aec3/suppression_gain_unittest.cc +++ b/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -26,29 +26,30 @@ namespace aec3 { // Verifies that the check for non-null output gains works. 
TEST(SuppressionGainDeathTest, NullOutputGains) { - std::vector> E2(1, {0.f}); - std::vector> R2(1, {0.f}); + std::vector> E2(1, {0.0f}); + std::vector> R2(1, {0.0f}); + std::vector> R2_unbounded(1, {0.0f}); std::vector> S2(1); - std::vector> N2(1, {0.f}); + std::vector> N2(1, {0.0f}); for (auto& S2_k : S2) { - S2_k.fill(.1f); + S2_k.fill(0.1f); } FftData E; FftData Y; - E.re.fill(0.f); - E.im.fill(0.f); - Y.re.fill(0.f); - Y.im.fill(0.f); + E.re.fill(0.0f); + E.im.fill(0.0f); + Y.re.fill(0.0f); + Y.im.fill(0.0f); float high_bands_gain; AecState aec_state(EchoCanceller3Config{}, 1); EXPECT_DEATH( SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000, 1) - .GetGain(E2, S2, R2, N2, + .GetGain(E2, S2, R2, R2_unbounded, N2, RenderSignalAnalyzer((EchoCanceller3Config{})), aec_state, std::vector>>( 3, std::vector>( - 1, std::vector(kBlockSize, 0.f))), + 1, std::vector(kBlockSize, 0.0f))), false, &high_bands_gain, nullptr), ""); } @@ -67,15 +68,17 @@ TEST(SuppressionGain, BasicGainComputation) { float high_bands_gain; std::vector> E2(kNumCaptureChannels); std::vector> S2(kNumCaptureChannels, - {0.f}); + {0.0f}); std::vector> Y2(kNumCaptureChannels); std::vector> R2(kNumCaptureChannels); + std::vector> R2_unbounded( + kNumCaptureChannels); std::vector> N2(kNumCaptureChannels); std::array g; std::vector output(kNumCaptureChannels); std::vector>> x( kNumBands, std::vector>( - kNumRenderChannels, std::vector(kBlockSize, 0.f))); + kNumRenderChannels, std::vector(kBlockSize, 0.0f))); EchoCanceller3Config config; AecState aec_state(config, kNumCaptureChannels); ApmDataDumper data_dumper(42); @@ -89,8 +92,9 @@ TEST(SuppressionGain, BasicGainComputation) { for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) { E2[ch].fill(10.f); Y2[ch].fill(10.f); - R2[ch].fill(.1f); - N2[ch].fill(100.f); + R2[ch].fill(0.1f); + R2_unbounded[ch].fill(0.1f); + N2[ch].fill(100.0f); } for (auto& subtractor_output : output) { subtractor_output.Reset(); @@ -107,17 +111,18 @@ TEST(SuppressionGain, BasicGainComputation) { aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), subtractor.FilterImpulseResponses(), *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); - suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, false, - &high_bands_gain, &g); + suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state, + x, false, &high_bands_gain, &g); } std::for_each(g.begin(), g.end(), - [](float a) { EXPECT_NEAR(1.f, a, 0.001); }); + [](float a) { EXPECT_NEAR(1.0f, a, 0.001f); }); // Ensure that a strong nearend is detected to mask any echoes. for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) { E2[ch].fill(100.f); Y2[ch].fill(100.f); R2[ch].fill(0.1f); + R2_unbounded[ch].fill(0.1f); S2[ch].fill(0.1f); N2[ch].fill(0.f); } @@ -126,22 +131,23 @@ TEST(SuppressionGain, BasicGainComputation) { aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), subtractor.FilterImpulseResponses(), *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); - suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, false, - &high_bands_gain, &g); + suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state, + x, false, &high_bands_gain, &g); } std::for_each(g.begin(), g.end(), - [](float a) { EXPECT_NEAR(1.f, a, 0.001); }); + [](float a) { EXPECT_NEAR(1.0f, a, 0.001f); }); // Add a strong echo to one of the channels and ensure that it is suppressed. 
- E2[1].fill(1000000000.f); - R2[1].fill(10000000000000.f); + E2[1].fill(1000000000.0f); + R2[1].fill(10000000000000.0f); + R2_unbounded[1].fill(10000000000000.0f); for (int k = 0; k < 10; ++k) { - suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, false, - &high_bands_gain, &g); + suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state, + x, false, &high_bands_gain, &g); } std::for_each(g.begin(), g.end(), - [](float a) { EXPECT_NEAR(0.f, a, 0.001); }); + [](float a) { EXPECT_NEAR(0.0f, a, 0.001f); }); } } // namespace aec3 diff --git a/modules/audio_processing/aec3/transparent_mode.cc b/modules/audio_processing/aec3/transparent_mode.cc index 7cfa3e8eae..489f53f4f1 100644 --- a/modules/audio_processing/aec3/transparent_mode.cc +++ b/modules/audio_processing/aec3/transparent_mode.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/aec3/transparent_mode.h" #include "rtc_base/checks.h" +#include "rtc_base/logging.h" #include "system_wrappers/include/field_trial.h" namespace webrtc { @@ -228,11 +229,14 @@ class LegacyTransparentModeImpl : public TransparentMode { std::unique_ptr TransparentMode::Create( const EchoCanceller3Config& config) { if (config.ep_strength.bounded_erl || DeactivateTransparentMode()) { + RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: Disabled"; return nullptr; } if (ActivateTransparentModeHmm()) { + RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: HMM"; return std::make_unique(); } + RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: Legacy"; return std::make_unique(config); } diff --git a/modules/audio_processing/aec_dump/aec_dump_impl.cc b/modules/audio_processing/aec_dump/aec_dump_impl.cc index 18f85721b1..db61b36c29 100644 --- a/modules/audio_processing/aec_dump/aec_dump_impl.cc +++ b/modules/audio_processing/aec_dump/aec_dump_impl.cc @@ -186,6 +186,12 @@ void AecDumpImpl::WriteRuntimeSetting( setting->set_capture_pre_gain(x); break; } + case AudioProcessing::RuntimeSetting::Type::kCapturePostGain: { + float x; + runtime_setting.GetFloat(&x); + setting->set_capture_post_gain(x); + break; + } case AudioProcessing::RuntimeSetting::Type:: kCustomRenderProcessingRuntimeSetting: { float x; diff --git a/modules/audio_processing/agc/BUILD.gn b/modules/audio_processing/agc/BUILD.gn index 8235456dd9..4bb8c5494b 100644 --- a/modules/audio_processing/agc/BUILD.gn +++ b/modules/audio_processing/agc/BUILD.gn @@ -19,11 +19,14 @@ rtc_library("agc") { ] configs += [ "..:apm_debug_dump" ] deps = [ + ":clipping_predictor", + ":clipping_predictor_evaluator", ":gain_control_interface", ":gain_map", ":level_estimation", "..:apm_logging", "..:audio_buffer", + "..:audio_frame_view", "../../../common_audio", "../../../common_audio:common_audio_c", "../../../rtc_base:checks", @@ -33,12 +36,54 @@ rtc_library("agc") { "../../../rtc_base:safe_minmax", "../../../system_wrappers:field_trial", "../../../system_wrappers:metrics", - "../agc2:level_estimation_agc", "../vad", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } +rtc_library("clipping_predictor") { + sources = [ + "clipping_predictor.cc", + "clipping_predictor.h", + ] + deps = [ + ":clipping_predictor_level_buffer", + ":gain_map", + "..:api", + "..:audio_frame_view", + "../../../common_audio", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:safe_minmax", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_library("clipping_predictor_evaluator") { + sources = [ + "clipping_predictor_evaluator.cc", + "clipping_predictor_evaluator.h", + ] + 
deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:logging", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_library("clipping_predictor_level_buffer") { + sources = [ + "clipping_predictor_level_buffer.cc", + "clipping_predictor_level_buffer.h", + ] + deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:rtc_base_approved", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + rtc_library("level_estimation") { sources = [ "agc.cc", @@ -97,6 +142,9 @@ if (rtc_include_tests) { testonly = true sources = [ "agc_manager_direct_unittest.cc", + "clipping_predictor_evaluator_unittest.cc", + "clipping_predictor_level_buffer_unittest.cc", + "clipping_predictor_unittest.cc", "loudness_histogram_unittest.cc", "mock_agc.h", ] @@ -104,13 +152,20 @@ if (rtc_include_tests) { deps = [ ":agc", + ":clipping_predictor", + ":clipping_predictor_evaluator", + ":clipping_predictor_level_buffer", ":gain_control_interface", ":level_estimation", "..:mocks", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base:safe_conversions", "../../../test:field_trial", "../../../test:fileutils", "../../../test:test_support", "//testing/gtest", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } } diff --git a/modules/audio_processing/agc/agc_manager_direct.cc b/modules/audio_processing/agc/agc_manager_direct.cc index 1428d2a0e7..e2a5b998a4 100644 --- a/modules/audio_processing/agc/agc_manager_direct.cc +++ b/modules/audio_processing/agc/agc_manager_direct.cc @@ -16,7 +16,7 @@ #include "common_audio/include/audio_util.h" #include "modules/audio_processing/agc/gain_control.h" #include "modules/audio_processing/agc/gain_map_internal.h" -#include "modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h" +#include "modules/audio_processing/include/audio_frame_view.h" #include "rtc_base/atomic_ops.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" @@ -28,33 +28,33 @@ namespace webrtc { namespace { -// Amount the microphone level is lowered with every clipping event. -const int kClippedLevelStep = 15; -// Proportion of clipped samples required to declare a clipping event. -const float kClippedRatioThreshold = 0.1f; -// Time in frames to wait after a clipping event before checking again. -const int kClippedWaitFrames = 300; - // Amount of error we tolerate in the microphone level (presumably due to OS // quantization) before we assume the user has manually adjusted the microphone. -const int kLevelQuantizationSlack = 25; +constexpr int kLevelQuantizationSlack = 25; -const int kDefaultCompressionGain = 7; -const int kMaxCompressionGain = 12; -const int kMinCompressionGain = 2; +constexpr int kDefaultCompressionGain = 7; +constexpr int kMaxCompressionGain = 12; +constexpr int kMinCompressionGain = 2; // Controls the rate of compression changes towards the target. -const float kCompressionGainStep = 0.05f; +constexpr float kCompressionGainStep = 0.05f; -const int kMaxMicLevel = 255; +constexpr int kMaxMicLevel = 255; static_assert(kGainMapSize > kMaxMicLevel, "gain map too small"); -const int kMinMicLevel = 12; +constexpr int kMinMicLevel = 12; // Prevent very large microphone level changes. -const int kMaxResidualGainChange = 15; +constexpr int kMaxResidualGainChange = 15; // Maximum additional gain allowed to compensate for microphone level // restrictions from clipping events. 
-const int kSurplusCompressionGain = 6; +constexpr int kSurplusCompressionGain = 6; + +// History size for the clipping predictor evaluator (unit: number of 10 ms +// frames). +constexpr int kClippingPredictorEvaluatorHistorySize = 32; + +using ClippingPredictorConfig = AudioProcessing::Config::GainController1:: + AnalogGainController::ClippingPredictor; // Returns whether a fall-back solution to choose the maximum level should be // chosen. @@ -133,29 +133,57 @@ float ComputeClippedRatio(const float* const* audio, return static_cast(num_clipped) / (samples_per_channel); } +void LogClippingPredictorMetrics(const ClippingPredictorEvaluator& evaluator) { + RTC_LOG(LS_INFO) << "Clipping predictor metrics: TP " + << evaluator.true_positives() << " TN " + << evaluator.true_negatives() << " FP " + << evaluator.false_positives() << " FN " + << evaluator.false_negatives(); + const float precision_denominator = + evaluator.true_positives() + evaluator.false_positives(); + const float recall_denominator = + evaluator.true_positives() + evaluator.false_negatives(); + if (precision_denominator > 0 && recall_denominator > 0) { + const float precision = evaluator.true_positives() / precision_denominator; + const float recall = evaluator.true_positives() / recall_denominator; + RTC_LOG(LS_INFO) << "Clipping predictor metrics: P " << precision << " R " + << recall; + const float f1_score_denominator = precision + recall; + if (f1_score_denominator > 0.0f) { + const float f1_score = 2 * precision * recall / f1_score_denominator; + RTC_LOG(LS_INFO) << "Clipping predictor metrics: F1 " << f1_score; + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc.ClippingPredictor.F1Score", + std::round(f1_score * 100.0f), /*min=*/0, + /*max=*/100, + /*bucket_count=*/50); + } + } +} + +void LogClippingMetrics(int clipping_rate) { + RTC_LOG(LS_INFO) << "Input clipping rate: " << clipping_rate << "%"; + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc.InputClippingRate", + clipping_rate, /*min=*/0, /*max=*/100, + /*bucket_count=*/50); +} + } // namespace MonoAgc::MonoAgc(ApmDataDumper* data_dumper, int startup_min_level, int clipped_level_min, - bool use_agc2_level_estimation, bool disable_digital_adaptive, int min_mic_level) : min_mic_level_(min_mic_level), disable_digital_adaptive_(disable_digital_adaptive), + agc_(std::make_unique()), max_level_(kMaxMicLevel), max_compression_gain_(kMaxCompressionGain), target_compression_(kDefaultCompressionGain), compression_(target_compression_), compression_accumulator_(compression_), startup_min_level_(ClampLevel(startup_min_level, min_mic_level_)), - clipped_level_min_(clipped_level_min) { - if (use_agc2_level_estimation) { - agc_ = std::make_unique(data_dumper); - } else { - agc_ = std::make_unique(); - } -} + clipped_level_min_(clipped_level_min) {} MonoAgc::~MonoAgc() = default; @@ -165,7 +193,7 @@ void MonoAgc::Initialize() { target_compression_ = disable_digital_adaptive_ ? 0 : kDefaultCompressionGain; compression_ = disable_digital_adaptive_ ? 0 : target_compression_; compression_accumulator_ = compression_; - capture_muted_ = false; + capture_output_used_ = true; check_volume_on_next_process_ = true; } @@ -189,19 +217,19 @@ void MonoAgc::Process(const int16_t* audio, } } -void MonoAgc::HandleClipping() { +void MonoAgc::HandleClipping(int clipped_level_step) { // Always decrease the maximum level, even if the current level is below // threshold. 
- SetMaxLevel(std::max(clipped_level_min_, max_level_ - kClippedLevelStep)); + SetMaxLevel(std::max(clipped_level_min_, max_level_ - clipped_level_step)); if (log_to_histograms_) { RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.AgcClippingAdjustmentAllowed", - level_ - kClippedLevelStep >= clipped_level_min_); + level_ - clipped_level_step >= clipped_level_min_); } if (level_ > clipped_level_min_) { // Don't try to adjust the level if we're already below the limit. As // a consequence, if the user has brought the level above the limit, we // will still not react until the postproc updates the level. - SetLevel(std::max(clipped_level_min_, level_ - kClippedLevelStep)); + SetLevel(std::max(clipped_level_min_, level_ - clipped_level_step)); // Reset the AGCs for all channels since the level has changed. agc_->Reset(); } @@ -263,14 +291,14 @@ void MonoAgc::SetMaxLevel(int level) { << ", max_compression_gain_=" << max_compression_gain_; } -void MonoAgc::SetCaptureMuted(bool muted) { - if (capture_muted_ == muted) { +void MonoAgc::HandleCaptureOutputUsedChange(bool capture_output_used) { + if (capture_output_used_ == capture_output_used) { return; } - capture_muted_ = muted; + capture_output_used_ = capture_output_used; - if (!muted) { - // When we unmute, we should reset things to be safe. + if (capture_output_used) { + // When we start using the output, we should reset things to be safe. check_volume_on_next_process_ = true; } } @@ -408,46 +436,74 @@ void MonoAgc::UpdateCompressor() { int AgcManagerDirect::instance_counter_ = 0; -AgcManagerDirect::AgcManagerDirect(Agc* agc, - int startup_min_level, - int clipped_level_min, - int sample_rate_hz) +AgcManagerDirect::AgcManagerDirect( + Agc* agc, + int startup_min_level, + int clipped_level_min, + int sample_rate_hz, + int clipped_level_step, + float clipped_ratio_threshold, + int clipped_wait_frames, + const ClippingPredictorConfig& clipping_config) : AgcManagerDirect(/*num_capture_channels*/ 1, startup_min_level, clipped_level_min, - /*use_agc2_level_estimation*/ false, /*disable_digital_adaptive*/ false, - sample_rate_hz) { + sample_rate_hz, + clipped_level_step, + clipped_ratio_threshold, + clipped_wait_frames, + clipping_config) { RTC_DCHECK(channel_agcs_[0]); RTC_DCHECK(agc); channel_agcs_[0]->set_agc(agc); } -AgcManagerDirect::AgcManagerDirect(int num_capture_channels, - int startup_min_level, - int clipped_level_min, - bool use_agc2_level_estimation, - bool disable_digital_adaptive, - int sample_rate_hz) +AgcManagerDirect::AgcManagerDirect( + int num_capture_channels, + int startup_min_level, + int clipped_level_min, + bool disable_digital_adaptive, + int sample_rate_hz, + int clipped_level_step, + float clipped_ratio_threshold, + int clipped_wait_frames, + const ClippingPredictorConfig& clipping_config) : data_dumper_( new ApmDataDumper(rtc::AtomicOps::Increment(&instance_counter_))), use_min_channel_level_(!UseMaxAnalogChannelLevel()), sample_rate_hz_(sample_rate_hz), num_capture_channels_(num_capture_channels), disable_digital_adaptive_(disable_digital_adaptive), - frames_since_clipped_(kClippedWaitFrames), - capture_muted_(false), + frames_since_clipped_(clipped_wait_frames), + capture_output_used_(true), + clipped_level_step_(clipped_level_step), + clipped_ratio_threshold_(clipped_ratio_threshold), + clipped_wait_frames_(clipped_wait_frames), channel_agcs_(num_capture_channels), - new_compressions_to_set_(num_capture_channels) { + new_compressions_to_set_(num_capture_channels), + clipping_predictor_( + 
CreateClippingPredictor(num_capture_channels, clipping_config)), + use_clipping_predictor_step_(!!clipping_predictor_ && + clipping_config.use_predicted_step), + clipping_predictor_evaluator_(kClippingPredictorEvaluatorHistorySize), + clipping_predictor_log_counter_(0), + clipping_rate_log_(0.0f), + clipping_rate_log_counter_(0) { const int min_mic_level = GetMinMicLevel(); for (size_t ch = 0; ch < channel_agcs_.size(); ++ch) { ApmDataDumper* data_dumper_ch = ch == 0 ? data_dumper_.get() : nullptr; channel_agcs_[ch] = std::make_unique( data_dumper_ch, startup_min_level, clipped_level_min, - use_agc2_level_estimation, disable_digital_adaptive_, min_mic_level); - } - RTC_DCHECK_LT(0, channel_agcs_.size()); + disable_digital_adaptive_, min_mic_level); + } + RTC_DCHECK(!channel_agcs_.empty()); + RTC_DCHECK_GT(clipped_level_step, 0); + RTC_DCHECK_LE(clipped_level_step, 255); + RTC_DCHECK_GT(clipped_ratio_threshold, 0.f); + RTC_DCHECK_LT(clipped_ratio_threshold, 1.f); + RTC_DCHECK_GT(clipped_wait_frames, 0); channel_agcs_[0]->ActivateLogging(); } @@ -459,9 +515,13 @@ void AgcManagerDirect::Initialize() { for (size_t ch = 0; ch < channel_agcs_.size(); ++ch) { channel_agcs_[ch]->Initialize(); } - capture_muted_ = false; + capture_output_used_ = true; AggregateChannelLevels(); + clipping_predictor_evaluator_.Reset(); + clipping_predictor_log_counter_ = 0; + clipping_rate_log_ = 0.0f; + clipping_rate_log_counter_ = 0; } void AgcManagerDirect::SetupDigitalGainControl( @@ -494,13 +554,14 @@ void AgcManagerDirect::AnalyzePreProcess(const float* const* audio, size_t samples_per_channel) { RTC_DCHECK(audio); AggregateChannelLevels(); - if (capture_muted_) { + if (!capture_output_used_) { return; } - if (frames_since_clipped_ < kClippedWaitFrames) { - ++frames_since_clipped_; - return; + if (!!clipping_predictor_) { + AudioFrameView frame = AudioFrameView( + audio, num_capture_channels_, static_cast(samples_per_channel)); + clipping_predictor_->Analyze(frame); } // Check for clipped samples, as the AGC has difficulty detecting pitch @@ -514,14 +575,67 @@ void AgcManagerDirect::AnalyzePreProcess(const float* const* audio, // gain is increased, through SetMaxLevel(). float clipped_ratio = ComputeClippedRatio(audio, num_capture_channels_, samples_per_channel); + clipping_rate_log_ = std::max(clipped_ratio, clipping_rate_log_); + clipping_rate_log_counter_++; + constexpr int kNumFramesIn30Seconds = 3000; + if (clipping_rate_log_counter_ == kNumFramesIn30Seconds) { + LogClippingMetrics(std::round(100.0f * clipping_rate_log_)); + clipping_rate_log_ = 0.0f; + clipping_rate_log_counter_ = 0; + } - if (clipped_ratio > kClippedRatioThreshold) { - RTC_DLOG(LS_INFO) << "[agc] Clipping detected. clipped_ratio=" - << clipped_ratio; + if (frames_since_clipped_ < clipped_wait_frames_) { + ++frames_since_clipped_; + return; + } + + const bool clipping_detected = clipped_ratio > clipped_ratio_threshold_; + bool clipping_predicted = false; + int predicted_step = 0; + if (!!clipping_predictor_) { + for (int channel = 0; channel < num_capture_channels_; ++channel) { + const auto step = clipping_predictor_->EstimateClippedLevelStep( + channel, stream_analog_level_, clipped_level_step_, + channel_agcs_[channel]->min_mic_level(), kMaxMicLevel); + if (use_clipping_predictor_step_ && step.has_value()) { + predicted_step = std::max(predicted_step, step.value()); + clipping_predicted = true; + } + } + // Clipping prediction evaluation. 
+ absl::optional prediction_interval = + clipping_predictor_evaluator_.Observe(clipping_detected, + clipping_predicted); + if (prediction_interval.has_value()) { + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.Agc.ClippingPredictor.PredictionInterval", + prediction_interval.value(), /*min=*/0, + /*max=*/49, /*bucket_count=*/50); + } + clipping_predictor_log_counter_++; + if (clipping_predictor_log_counter_ == kNumFramesIn30Seconds) { + LogClippingPredictorMetrics(clipping_predictor_evaluator_); + clipping_predictor_log_counter_ = 0; + } + } + if (clipping_detected || clipping_predicted) { + int step = clipped_level_step_; + if (clipping_detected) { + RTC_DLOG(LS_INFO) << "[agc] Clipping detected. clipped_ratio=" + << clipped_ratio; + } + if (clipping_predicted) { + step = std::max(predicted_step, clipped_level_step_); + RTC_DLOG(LS_INFO) << "[agc] Clipping predicted. step=" << step; + } for (auto& state_ch : channel_agcs_) { - state_ch->HandleClipping(); + state_ch->HandleClipping(step); } frames_since_clipped_ = 0; + if (!!clipping_predictor_) { + clipping_predictor_->Reset(); + clipping_predictor_evaluator_.Reset(); + } } AggregateChannelLevels(); } @@ -529,7 +643,7 @@ void AgcManagerDirect::AnalyzePreProcess(const float* const* audio, void AgcManagerDirect::Process(const AudioBuffer* audio) { AggregateChannelLevels(); - if (capture_muted_) { + if (!capture_output_used_) { return; } @@ -558,11 +672,11 @@ absl::optional AgcManagerDirect::GetDigitalComressionGain() { return new_compressions_to_set_[channel_controlling_gain_]; } -void AgcManagerDirect::SetCaptureMuted(bool muted) { +void AgcManagerDirect::HandleCaptureOutputUsedChange(bool capture_output_used) { for (size_t ch = 0; ch < channel_agcs_.size(); ++ch) { - channel_agcs_[ch]->SetCaptureMuted(muted); + channel_agcs_[ch]->HandleCaptureOutputUsedChange(capture_output_used); } - capture_muted_ = muted; + capture_output_used_ = capture_output_used; } float AgcManagerDirect::voice_probability() const { diff --git a/modules/audio_processing/agc/agc_manager_direct.h b/modules/audio_processing/agc/agc_manager_direct.h index d3663be69e..d80a255ced 100644 --- a/modules/audio_processing/agc/agc_manager_direct.h +++ b/modules/audio_processing/agc/agc_manager_direct.h @@ -15,6 +15,8 @@ #include "absl/types/optional.h" #include "modules/audio_processing/agc/agc.h" +#include "modules/audio_processing/agc/clipping_predictor.h" +#include "modules/audio_processing/agc/clipping_predictor_evaluator.h" #include "modules/audio_processing/audio_buffer.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/gtest_prod_util.h" @@ -34,13 +36,23 @@ class AgcManagerDirect final { // AgcManagerDirect will configure GainControl internally. The user is // responsible for processing the audio using it after the call to Process. // The operating range of startup_min_level is [12, 255] and any input value - // outside that range will be clamped. - AgcManagerDirect(int num_capture_channels, - int startup_min_level, - int clipped_level_min, - bool use_agc2_level_estimation, - bool disable_digital_adaptive, - int sample_rate_hz); + // outside that range will be clamped. `clipped_level_step` is the amount + // the microphone level is lowered with every clipping event, limited to + // (0, 255]. `clipped_ratio_threshold` is the proportion of clipped + // samples required to declare a clipping event, limited to (0.f, 1.f). 
+ // `clipped_wait_frames` is the time in frames to wait after a clipping event + // before checking again, limited to values higher than 0. + AgcManagerDirect( + int num_capture_channels, + int startup_min_level, + int clipped_level_min, + bool disable_digital_adaptive, + int sample_rate_hz, + int clipped_level_step, + float clipped_ratio_threshold, + int clipped_wait_frames, + const AudioProcessing::Config::GainController1::AnalogGainController:: + ClippingPredictor& clipping_config); ~AgcManagerDirect(); AgcManagerDirect(const AgcManagerDirect&) = delete; @@ -52,10 +64,9 @@ class AgcManagerDirect final { void AnalyzePreProcess(const AudioBuffer* audio); void Process(const AudioBuffer* audio); - // Call when the capture stream has been muted/unmuted. This causes the - // manager to disregard all incoming audio; chances are good it's background - // noise to which we'd like to avoid adapting. - void SetCaptureMuted(bool muted); + // Call when the capture stream output has been flagged to be used/not-used. + // If unused, the manager disregards all incoming audio. + void HandleCaptureOutputUsedChange(bool capture_output_used); float voice_probability() const; int stream_analog_level() const { return stream_analog_level_; } @@ -66,6 +77,14 @@ class AgcManagerDirect final { // If available, returns a new compression gain for the digital gain control. absl::optional GetDigitalComressionGain(); + // Returns true if clipping prediction is enabled. + bool clipping_predictor_enabled() const { return !!clipping_predictor_; } + + // Returns true if clipping prediction is used to adjust the analog gain. + bool use_clipping_predictor_step() const { + return use_clipping_predictor_step_; + } + private: friend class AgcManagerDirectTest; @@ -73,13 +92,38 @@ class AgcManagerDirect final { DisableDigitalDisablesDigital); FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, AgcMinMicLevelExperiment); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + AgcMinMicLevelExperimentDisabled); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + AgcMinMicLevelExperimentOutOfRangeAbove); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + AgcMinMicLevelExperimentOutOfRangeBelow); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + AgcMinMicLevelExperimentEnabled50); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + AgcMinMicLevelExperimentEnabledAboveStartupLevel); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + ClippingParametersVerified); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + DisableClippingPredictorDoesNotLowerVolume); + FRIEND_TEST_ALL_PREFIXES( + AgcManagerDirectStandaloneTest, + EnableClippingPredictorWithUnusedPredictedStepDoesNotLowerVolume); + FRIEND_TEST_ALL_PREFIXES(AgcManagerDirectStandaloneTest, + EnableClippingPredictorLowersVolume); // Dependency injection for testing. Don't delete |agc| as the memory is owned // by the manager. 
- AgcManagerDirect(Agc* agc, - int startup_min_level, - int clipped_level_min, - int sample_rate_hz); + AgcManagerDirect( + Agc* agc, + int startup_min_level, + int clipped_level_min, + int sample_rate_hz, + int clipped_level_step, + float clipped_ratio_threshold, + int clipped_wait_frames, + const AudioProcessing::Config::GainController1::AnalogGainController:: + ClippingPredictor& clipping_config); void AnalyzePreProcess(const float* const* audio, size_t samples_per_channel); @@ -94,11 +138,22 @@ class AgcManagerDirect final { int frames_since_clipped_; int stream_analog_level_ = 0; - bool capture_muted_; + bool capture_output_used_; int channel_controlling_gain_ = 0; + const int clipped_level_step_; + const float clipped_ratio_threshold_; + const int clipped_wait_frames_; + std::vector> channel_agcs_; std::vector> new_compressions_to_set_; + + const std::unique_ptr clipping_predictor_; + const bool use_clipping_predictor_step_; + ClippingPredictorEvaluator clipping_predictor_evaluator_; + int clipping_predictor_log_counter_; + float clipping_rate_log_; + int clipping_rate_log_counter_; }; class MonoAgc { @@ -106,7 +161,6 @@ class MonoAgc { MonoAgc(ApmDataDumper* data_dumper, int startup_min_level, int clipped_level_min, - bool use_agc2_level_estimation, bool disable_digital_adaptive, int min_mic_level); ~MonoAgc(); @@ -114,9 +168,9 @@ class MonoAgc { MonoAgc& operator=(const MonoAgc&) = delete; void Initialize(); - void SetCaptureMuted(bool muted); + void HandleCaptureOutputUsedChange(bool capture_output_used); - void HandleClipping(); + void HandleClipping(int clipped_level_step); void Process(const int16_t* audio, size_t samples_per_channel, @@ -158,7 +212,7 @@ class MonoAgc { int target_compression_; int compression_; float compression_accumulator_; - bool capture_muted_ = false; + bool capture_output_used_ = true; bool check_volume_on_next_process_ = true; bool startup_ = true; int startup_min_level_; diff --git a/modules/audio_processing/agc/agc_manager_direct_unittest.cc b/modules/audio_processing/agc/agc_manager_direct_unittest.cc index 995801a8cb..bb284f9abc 100644 --- a/modules/audio_processing/agc/agc_manager_direct_unittest.cc +++ b/modules/audio_processing/agc/agc_manager_direct_unittest.cc @@ -26,13 +26,19 @@ using ::testing::SetArgPointee; namespace webrtc { namespace { -const int kSampleRateHz = 32000; -const int kNumChannels = 1; -const int kSamplesPerChannel = kSampleRateHz / 100; -const int kInitialVolume = 128; +constexpr int kSampleRateHz = 32000; +constexpr int kNumChannels = 1; +constexpr int kSamplesPerChannel = kSampleRateHz / 100; +constexpr int kInitialVolume = 128; constexpr int kClippedMin = 165; // Arbitrary, but different from the default. 
-const float kAboveClippedThreshold = 0.2f; -const int kMinMicLevel = 12; +constexpr float kAboveClippedThreshold = 0.2f; +constexpr int kMinMicLevel = 12; +constexpr int kClippedLevelStep = 15; +constexpr float kClippedRatioThreshold = 0.1f; +constexpr int kClippedWaitFrames = 300; + +using ClippingPredictorConfig = AudioProcessing::Config::GainController1:: + AnalogGainController::ClippingPredictor; class MockGainControl : public GainControl { public: @@ -56,13 +62,70 @@ class MockGainControl : public GainControl { MOCK_METHOD(bool, stream_is_saturated, (), (const, override)); }; +std::unique_ptr CreateAgcManagerDirect( + int startup_min_level, + int clipped_level_step, + float clipped_ratio_threshold, + int clipped_wait_frames) { + return std::make_unique( + /*num_capture_channels=*/1, startup_min_level, kClippedMin, + /*disable_digital_adaptive=*/true, kSampleRateHz, clipped_level_step, + clipped_ratio_threshold, clipped_wait_frames, ClippingPredictorConfig()); +} + +std::unique_ptr CreateAgcManagerDirect( + int startup_min_level, + int clipped_level_step, + float clipped_ratio_threshold, + int clipped_wait_frames, + const ClippingPredictorConfig& clipping_cfg) { + return std::make_unique( + /*num_capture_channels=*/1, startup_min_level, kClippedMin, + /*disable_digital_adaptive=*/true, kSampleRateHz, clipped_level_step, + clipped_ratio_threshold, clipped_wait_frames, clipping_cfg); +} + +void CallPreProcessAudioBuffer(int num_calls, + float peak_ratio, + AgcManagerDirect& manager) { + RTC_DCHECK_GE(1.f, peak_ratio); + AudioBuffer audio_buffer(kSampleRateHz, 1, kSampleRateHz, 1, kSampleRateHz, + 1); + const int num_channels = audio_buffer.num_channels(); + const int num_frames = audio_buffer.num_frames(); + for (int ch = 0; ch < num_channels; ++ch) { + for (int i = 0; i < num_frames; i += 2) { + audio_buffer.channels()[ch][i] = peak_ratio * 32767.f; + audio_buffer.channels()[ch][i + 1] = 0.0f; + } + } + for (int n = 0; n < num_calls / 2; ++n) { + manager.AnalyzePreProcess(&audio_buffer); + } + for (int ch = 0; ch < num_channels; ++ch) { + for (int i = 0; i < num_frames; ++i) { + audio_buffer.channels()[ch][i] = peak_ratio * 32767.f; + } + } + for (int n = 0; n < num_calls - num_calls / 2; ++n) { + manager.AnalyzePreProcess(&audio_buffer); + } +} + } // namespace class AgcManagerDirectTest : public ::testing::Test { protected: AgcManagerDirectTest() : agc_(new MockAgc), - manager_(agc_, kInitialVolume, kClippedMin, kSampleRateHz), + manager_(agc_, + kInitialVolume, + kClippedMin, + kSampleRateHz, + kClippedLevelStep, + kClippedRatioThreshold, + kClippedWaitFrames, + ClippingPredictorConfig()), audio(kNumChannels), audio_data(kNumChannels * kSamplesPerChannel, 0.f) { ExpectInitialize(); @@ -117,12 +180,32 @@ class AgcManagerDirectTest : public ::testing::Test { audio[ch][k] = 32767.f; } } - for (int i = 0; i < num_calls; ++i) { manager_.AnalyzePreProcess(audio.data(), kSamplesPerChannel); } } + void CallPreProcForChangingAudio(int num_calls, float peak_ratio) { + RTC_DCHECK_GE(1.f, peak_ratio); + std::fill(audio_data.begin(), audio_data.end(), 0.f); + for (size_t ch = 0; ch < kNumChannels; ++ch) { + for (size_t k = 0; k < kSamplesPerChannel; k += 2) { + audio[ch][k] = peak_ratio * 32767.f; + } + } + for (int i = 0; i < num_calls / 2; ++i) { + manager_.AnalyzePreProcess(audio.data(), kSamplesPerChannel); + } + for (size_t ch = 0; ch < kNumChannels; ++ch) { + for (size_t k = 0; k < kSamplesPerChannel; ++k) { + audio[ch][k] = peak_ratio * 32767.f; + } + } + for (int i = 0; i < num_calls - 
num_calls / 2; ++i) { + manager_.AnalyzePreProcess(audio.data(), kSamplesPerChannel); + } + } + MockAgc* agc_; MockGainControl gctrl_; AgcManagerDirect manager_; @@ -368,7 +451,7 @@ TEST_F(AgcManagerDirectTest, CompressorReachesMinimum) { } TEST_F(AgcManagerDirectTest, NoActionWhileMuted) { - manager_.SetCaptureMuted(true); + manager_.HandleCaptureOutputUsedChange(false); manager_.Process(nullptr); absl::optional new_digital_gain = manager_.GetDigitalComressionGain(); if (new_digital_gain) { @@ -379,8 +462,8 @@ TEST_F(AgcManagerDirectTest, NoActionWhileMuted) { TEST_F(AgcManagerDirectTest, UnmutingChecksVolumeWithoutRaising) { FirstProcess(); - manager_.SetCaptureMuted(true); - manager_.SetCaptureMuted(false); + manager_.HandleCaptureOutputUsedChange(false); + manager_.HandleCaptureOutputUsedChange(true); ExpectCheckVolumeAndReset(127); // SetMicVolume should not be called. EXPECT_CALL(*agc_, GetRmsErrorDb(_)).WillOnce(Return(false)); @@ -391,8 +474,8 @@ TEST_F(AgcManagerDirectTest, UnmutingChecksVolumeWithoutRaising) { TEST_F(AgcManagerDirectTest, UnmutingRaisesTooLowVolume) { FirstProcess(); - manager_.SetCaptureMuted(true); - manager_.SetCaptureMuted(false); + manager_.HandleCaptureOutputUsedChange(false); + manager_.HandleCaptureOutputUsedChange(true); ExpectCheckVolumeAndReset(11); EXPECT_CALL(*agc_, GetRmsErrorDb(_)).WillOnce(Return(false)); CallProcess(1); @@ -689,80 +772,227 @@ TEST_F(AgcManagerDirectTest, TakesNoActionOnZeroMicVolume) { EXPECT_EQ(0, manager_.stream_analog_level()); } +TEST_F(AgcManagerDirectTest, ClippingDetectionLowersVolume) { + SetVolumeAndProcess(255); + EXPECT_EQ(255, manager_.stream_analog_level()); + CallPreProcForChangingAudio(/*num_calls=*/100, /*peak_ratio=*/0.99f); + EXPECT_EQ(255, manager_.stream_analog_level()); + CallPreProcForChangingAudio(/*num_calls=*/100, /*peak_ratio=*/1.0f); + EXPECT_EQ(240, manager_.stream_analog_level()); +} + +TEST_F(AgcManagerDirectTest, DisabledClippingPredictorDoesNotLowerVolume) { + SetVolumeAndProcess(255); + EXPECT_FALSE(manager_.clipping_predictor_enabled()); + EXPECT_EQ(255, manager_.stream_analog_level()); + CallPreProcForChangingAudio(/*num_calls=*/100, /*peak_ratio=*/0.99f); + EXPECT_EQ(255, manager_.stream_analog_level()); + CallPreProcForChangingAudio(/*num_calls=*/100, /*peak_ratio=*/0.99f); + EXPECT_EQ(255, manager_.stream_analog_level()); +} + TEST(AgcManagerDirectStandaloneTest, DisableDigitalDisablesDigital) { auto agc = std::unique_ptr(new ::testing::NiceMock()); MockGainControl gctrl; - AgcManagerDirect manager(/* num_capture_channels */ 1, kInitialVolume, - kClippedMin, - /* use agc2 level estimation */ false, - /* disable digital adaptive */ true, kSampleRateHz); - EXPECT_CALL(gctrl, set_mode(GainControl::kFixedDigital)); EXPECT_CALL(gctrl, set_target_level_dbfs(0)); EXPECT_CALL(gctrl, set_compression_gain_db(0)); EXPECT_CALL(gctrl, enable_limiter(false)); - manager.Initialize(); - manager.SetupDigitalGainControl(&gctrl); + std::unique_ptr manager = + CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep, + kClippedRatioThreshold, kClippedWaitFrames); + manager->Initialize(); + manager->SetupDigitalGainControl(&gctrl); } TEST(AgcManagerDirectStandaloneTest, AgcMinMicLevelExperiment) { - auto agc_man = std::unique_ptr(new AgcManagerDirect( - /* num_capture_channels */ 1, kInitialVolume, kClippedMin, true, true, - kSampleRateHz)); - EXPECT_EQ(agc_man->channel_agcs_[0]->min_mic_level(), kMinMicLevel); - EXPECT_EQ(agc_man->channel_agcs_[0]->startup_min_level(), kInitialVolume); - { - 
test::ScopedFieldTrials field_trial( - "WebRTC-Audio-AgcMinMicLevelExperiment/Disabled/"); - agc_man.reset(new AgcManagerDirect( - /* num_capture_channels */ 1, kInitialVolume, kClippedMin, true, true, - kSampleRateHz)); - EXPECT_EQ(agc_man->channel_agcs_[0]->min_mic_level(), kMinMicLevel); - EXPECT_EQ(agc_man->channel_agcs_[0]->startup_min_level(), kInitialVolume); - } - { - // Valid range of field-trial parameter is [0,255]. - test::ScopedFieldTrials field_trial( - "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled-256/"); - agc_man.reset(new AgcManagerDirect( - /* num_capture_channels */ 1, kInitialVolume, kClippedMin, true, true, - kSampleRateHz)); - EXPECT_EQ(agc_man->channel_agcs_[0]->min_mic_level(), kMinMicLevel); - EXPECT_EQ(agc_man->channel_agcs_[0]->startup_min_level(), kInitialVolume); - } - { - test::ScopedFieldTrials field_trial( - "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled--1/"); - agc_man.reset(new AgcManagerDirect( - /* num_capture_channels */ 1, kInitialVolume, kClippedMin, true, true, - kSampleRateHz)); - EXPECT_EQ(agc_man->channel_agcs_[0]->min_mic_level(), kMinMicLevel); - EXPECT_EQ(agc_man->channel_agcs_[0]->startup_min_level(), kInitialVolume); - } - { - // Verify that a valid experiment changes the minimum microphone level. - // The start volume is larger than the min level and should therefore not - // be changed. - test::ScopedFieldTrials field_trial( - "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled-50/"); - agc_man.reset(new AgcManagerDirect( - /* num_capture_channels */ 1, kInitialVolume, kClippedMin, true, true, - kSampleRateHz)); - EXPECT_EQ(agc_man->channel_agcs_[0]->min_mic_level(), 50); - EXPECT_EQ(agc_man->channel_agcs_[0]->startup_min_level(), kInitialVolume); - } - { - // Use experiment to reduce the default minimum microphone level, start at - // a lower level and ensure that the startup level is increased to the min - // level set by the experiment. - test::ScopedFieldTrials field_trial( - "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled-50/"); - agc_man.reset(new AgcManagerDirect(/* num_capture_channels */ 1, 30, - kClippedMin, true, true, kSampleRateHz)); - EXPECT_EQ(agc_man->channel_agcs_[0]->min_mic_level(), 50); - EXPECT_EQ(agc_man->channel_agcs_[0]->startup_min_level(), 50); - } + std::unique_ptr manager = + CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep, + kClippedRatioThreshold, kClippedWaitFrames); + EXPECT_EQ(manager->channel_agcs_[0]->min_mic_level(), kMinMicLevel); + EXPECT_EQ(manager->channel_agcs_[0]->startup_min_level(), kInitialVolume); +} + +TEST(AgcManagerDirectStandaloneTest, AgcMinMicLevelExperimentDisabled) { + test::ScopedFieldTrials field_trial( + "WebRTC-Audio-AgcMinMicLevelExperiment/Disabled/"); + std::unique_ptr manager = + CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep, + kClippedRatioThreshold, kClippedWaitFrames); + EXPECT_EQ(manager->channel_agcs_[0]->min_mic_level(), kMinMicLevel); + EXPECT_EQ(manager->channel_agcs_[0]->startup_min_level(), kInitialVolume); +} + +// Checks that a field-trial parameter outside of the valid range [0,255] is +// ignored. 
+TEST(AgcManagerDirectStandaloneTest, AgcMinMicLevelExperimentOutOfRangeAbove) {
+  test::ScopedFieldTrials field_trial(
+      "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled-256/");
+  std::unique_ptr<AgcManagerDirect> manager =
+      CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep,
+                             kClippedRatioThreshold, kClippedWaitFrames);
+  EXPECT_EQ(manager->channel_agcs_[0]->min_mic_level(), kMinMicLevel);
+  EXPECT_EQ(manager->channel_agcs_[0]->startup_min_level(), kInitialVolume);
+}
+
+// Checks that a field-trial parameter outside of the valid range [0,255] is
+// ignored.
+TEST(AgcManagerDirectStandaloneTest, AgcMinMicLevelExperimentOutOfRangeBelow) {
+  test::ScopedFieldTrials field_trial(
+      "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled--1/");
+  std::unique_ptr<AgcManagerDirect> manager =
+      CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep,
+                             kClippedRatioThreshold, kClippedWaitFrames);
+  EXPECT_EQ(manager->channel_agcs_[0]->min_mic_level(), kMinMicLevel);
+  EXPECT_EQ(manager->channel_agcs_[0]->startup_min_level(), kInitialVolume);
+}
+
+// Verifies that a valid experiment changes the minimum microphone level. The
+// start volume is larger than the min level and should therefore not be
+// changed.
+TEST(AgcManagerDirectStandaloneTest, AgcMinMicLevelExperimentEnabled50) {
+  test::ScopedFieldTrials field_trial(
+      "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled-50/");
+  std::unique_ptr<AgcManagerDirect> manager =
+      CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep,
+                             kClippedRatioThreshold, kClippedWaitFrames);
+  EXPECT_EQ(manager->channel_agcs_[0]->min_mic_level(), 50);
+  EXPECT_EQ(manager->channel_agcs_[0]->startup_min_level(), kInitialVolume);
+}
+
+// Uses experiment to reduce the default minimum microphone level, start at a
+// lower level and ensure that the startup level is increased to the min level
+// set by the experiment.
+TEST(AgcManagerDirectStandaloneTest,
+     AgcMinMicLevelExperimentEnabledAboveStartupLevel) {
+  test::ScopedFieldTrials field_trial(
+      "WebRTC-Audio-AgcMinMicLevelExperiment/Enabled-50/");
+  std::unique_ptr<AgcManagerDirect> manager =
+      CreateAgcManagerDirect(/*startup_min_level=*/30, kClippedLevelStep,
+                             kClippedRatioThreshold, kClippedWaitFrames);
+  EXPECT_EQ(manager->channel_agcs_[0]->min_mic_level(), 50);
+  EXPECT_EQ(manager->channel_agcs_[0]->startup_min_level(), 50);
+}
+
+// TODO(bugs.webrtc.org/12774): Test the behavior of `clipped_level_step`.
+// TODO(bugs.webrtc.org/12774): Test the behavior of `clipped_ratio_threshold`.
+// TODO(bugs.webrtc.org/12774): Test the behavior of `clipped_wait_frames`.
+// Verifies that configurable clipping parameters are initialized as intended.
+TEST(AgcManagerDirectStandaloneTest, ClippingParametersVerified) {
+  std::unique_ptr<AgcManagerDirect> manager =
+      CreateAgcManagerDirect(kInitialVolume, kClippedLevelStep,
+                             kClippedRatioThreshold, kClippedWaitFrames);
+  manager->Initialize();
+  EXPECT_EQ(manager->clipped_level_step_, kClippedLevelStep);
+  EXPECT_EQ(manager->clipped_ratio_threshold_, kClippedRatioThreshold);
+  EXPECT_EQ(manager->clipped_wait_frames_, kClippedWaitFrames);
+  std::unique_ptr<AgcManagerDirect> manager_custom =
+      CreateAgcManagerDirect(kInitialVolume,
+                             /*clipped_level_step=*/10,
+                             /*clipped_ratio_threshold=*/0.2f,
+                             /*clipped_wait_frames=*/50);
+  manager_custom->Initialize();
+  EXPECT_EQ(manager_custom->clipped_level_step_, 10);
+  EXPECT_EQ(manager_custom->clipped_ratio_threshold_, 0.2f);
+  EXPECT_EQ(manager_custom->clipped_wait_frames_, 50);
+}
+
+TEST(AgcManagerDirectStandaloneTest,
+     DisableClippingPredictorDisablesClippingPredictor) {
+  ClippingPredictorConfig default_config;
+  EXPECT_FALSE(default_config.enabled);
+  std::unique_ptr<AgcManagerDirect> manager = CreateAgcManagerDirect(
+      kInitialVolume, kClippedLevelStep, kClippedRatioThreshold,
+      kClippedWaitFrames, default_config);
+  manager->Initialize();
+  EXPECT_FALSE(manager->clipping_predictor_enabled());
+  EXPECT_FALSE(manager->use_clipping_predictor_step());
+}
+
+TEST(AgcManagerDirectStandaloneTest, ClippingPredictorDisabledByDefault) {
+  constexpr ClippingPredictorConfig kDefaultConfig;
+  EXPECT_FALSE(kDefaultConfig.enabled);
+}
+
+TEST(AgcManagerDirectStandaloneTest,
+     EnableClippingPredictorEnablesClippingPredictor) {
+  // TODO(bugs.webrtc.org/12874): Use designated initializers once fixed.
+  ClippingPredictorConfig config;
+  config.enabled = true;
+  config.use_predicted_step = true;
+  std::unique_ptr<AgcManagerDirect> manager = CreateAgcManagerDirect(
+      kInitialVolume, kClippedLevelStep, kClippedRatioThreshold,
+      kClippedWaitFrames, config);
+  manager->Initialize();
+  EXPECT_TRUE(manager->clipping_predictor_enabled());
+  EXPECT_TRUE(manager->use_clipping_predictor_step());
+}
+
+TEST(AgcManagerDirectStandaloneTest,
+     DisableClippingPredictorDoesNotLowerVolume) {
+  // TODO(bugs.webrtc.org/12874): Use designated initializers once fixed.
+  constexpr ClippingPredictorConfig kConfig{/*enabled=*/false};
+  AgcManagerDirect manager(new ::testing::NiceMock<MockAgc>(), kInitialVolume,
+                           kClippedMin, kSampleRateHz, kClippedLevelStep,
+                           kClippedRatioThreshold, kClippedWaitFrames, kConfig);
+  manager.Initialize();
+  manager.set_stream_analog_level(/*level=*/255);
+  EXPECT_FALSE(manager.clipping_predictor_enabled());
+  EXPECT_FALSE(manager.use_clipping_predictor_step());
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  manager.Process(nullptr);
+  CallPreProcessAudioBuffer(/*num_calls=*/10, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  CallPreProcessAudioBuffer(/*num_calls=*/300, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  CallPreProcessAudioBuffer(/*num_calls=*/10, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+}
+
+TEST(AgcManagerDirectStandaloneTest,
+     EnableClippingPredictorWithUnusedPredictedStepDoesNotLowerVolume) {
+  // TODO(bugs.webrtc.org/12874): Use designated initializers once fixed.
+  ClippingPredictorConfig config;
+  config.enabled = true;
+  config.use_predicted_step = false;
+  AgcManagerDirect manager(new ::testing::NiceMock<MockAgc>(), kInitialVolume,
+                           kClippedMin, kSampleRateHz, kClippedLevelStep,
+                           kClippedRatioThreshold, kClippedWaitFrames, config);
+  manager.Initialize();
+  manager.set_stream_analog_level(/*level=*/255);
+  EXPECT_TRUE(manager.clipping_predictor_enabled());
+  EXPECT_FALSE(manager.use_clipping_predictor_step());
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  manager.Process(nullptr);
+  CallPreProcessAudioBuffer(/*num_calls=*/10, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  CallPreProcessAudioBuffer(/*num_calls=*/300, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  CallPreProcessAudioBuffer(/*num_calls=*/10, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+}
+
+TEST(AgcManagerDirectStandaloneTest, EnableClippingPredictorLowersVolume) {
+  // TODO(bugs.webrtc.org/12874): Use designated initializers once fixed.
+  ClippingPredictorConfig config;
+  config.enabled = true;
+  config.use_predicted_step = true;
+  AgcManagerDirect manager(new ::testing::NiceMock<MockAgc>(), kInitialVolume,
+                           kClippedMin, kSampleRateHz, kClippedLevelStep,
+                           kClippedRatioThreshold, kClippedWaitFrames, config);
+  manager.Initialize();
+  manager.set_stream_analog_level(/*level=*/255);
+  EXPECT_TRUE(manager.clipping_predictor_enabled());
+  EXPECT_TRUE(manager.use_clipping_predictor_step());
+  EXPECT_EQ(manager.stream_analog_level(), 255);
+  manager.Process(nullptr);
+  CallPreProcessAudioBuffer(/*num_calls=*/10, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 240);
+  CallPreProcessAudioBuffer(/*num_calls=*/300, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 240);
+  CallPreProcessAudioBuffer(/*num_calls=*/10, /*peak_ratio=*/0.99f, manager);
+  EXPECT_EQ(manager.stream_analog_level(), 225);
+}
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc/clipping_predictor.cc b/modules/audio_processing/agc/clipping_predictor.cc
new file mode 100644
index 0000000000..982bbca2ee
--- /dev/null
+++ b/modules/audio_processing/agc/clipping_predictor.cc
@@ -0,0 +1,383 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc/clipping_predictor.h"
+
+#include <algorithm>
+#include <memory>
+
+#include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc/clipping_predictor_level_buffer.h"
+#include "modules/audio_processing/agc/gain_map_internal.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/numerics/safe_minmax.h"
+
+namespace webrtc {
+namespace {
+
+constexpr int kClippingPredictorMaxGainChange = 15;
+
+// Estimates the new level from the gain error; a copy of the function
+// `LevelFromGainError` in agc_manager_direct.cc.
+int LevelFromGainError(int gain_error, + int level, + int min_mic_level, + int max_mic_level) { + RTC_DCHECK_GE(level, 0); + RTC_DCHECK_LE(level, max_mic_level); + if (gain_error == 0) { + return level; + } + int new_level = level; + if (gain_error > 0) { + while (kGainMap[new_level] - kGainMap[level] < gain_error && + new_level < max_mic_level) { + ++new_level; + } + } else { + while (kGainMap[new_level] - kGainMap[level] > gain_error && + new_level > min_mic_level) { + --new_level; + } + } + return new_level; +} + +float ComputeCrestFactor(const ClippingPredictorLevelBuffer::Level& level) { + const float crest_factor = + FloatS16ToDbfs(level.max) - FloatS16ToDbfs(std::sqrt(level.average)); + return crest_factor; +} + +// Crest factor-based clipping prediction and clipped level step estimation. +class ClippingEventPredictor : public ClippingPredictor { + public: + // ClippingEventPredictor with `num_channels` channels (limited to values + // higher than zero); window size `window_length` and reference window size + // `reference_window_length` (both referring to the number of frames in the + // respective sliding windows and limited to values higher than zero); + // reference window delay `reference_window_delay` (delay in frames, limited + // to values zero and higher with an additional requirement of + // `window_length` < `reference_window_length` + reference_window_delay`); + // and an estimation peak threshold `clipping_threshold` and a crest factor + // drop threshold `crest_factor_margin` (both in dB). + ClippingEventPredictor(int num_channels, + int window_length, + int reference_window_length, + int reference_window_delay, + float clipping_threshold, + float crest_factor_margin) + : window_length_(window_length), + reference_window_length_(reference_window_length), + reference_window_delay_(reference_window_delay), + clipping_threshold_(clipping_threshold), + crest_factor_margin_(crest_factor_margin) { + RTC_DCHECK_GT(num_channels, 0); + RTC_DCHECK_GT(window_length, 0); + RTC_DCHECK_GT(reference_window_length, 0); + RTC_DCHECK_GE(reference_window_delay, 0); + RTC_DCHECK_GT(reference_window_length + reference_window_delay, + window_length); + const int buffer_length = GetMinFramesProcessed(); + RTC_DCHECK_GT(buffer_length, 0); + for (int i = 0; i < num_channels; ++i) { + ch_buffers_.push_back( + std::make_unique(buffer_length)); + } + } + + ClippingEventPredictor(const ClippingEventPredictor&) = delete; + ClippingEventPredictor& operator=(const ClippingEventPredictor&) = delete; + ~ClippingEventPredictor() {} + + void Reset() { + const int num_channels = ch_buffers_.size(); + for (int i = 0; i < num_channels; ++i) { + ch_buffers_[i]->Reset(); + } + } + + // Analyzes a frame of audio and stores the framewise metrics in + // `ch_buffers_`. + void Analyze(const AudioFrameView& frame) { + const int num_channels = frame.num_channels(); + RTC_DCHECK_EQ(num_channels, ch_buffers_.size()); + const int samples_per_channel = frame.samples_per_channel(); + RTC_DCHECK_GT(samples_per_channel, 0); + for (int channel = 0; channel < num_channels; ++channel) { + float sum_squares = 0.0f; + float peak = 0.0f; + for (const auto& sample : frame.channel(channel)) { + sum_squares += sample * sample; + peak = std::max(std::fabs(sample), peak); + } + ch_buffers_[channel]->Push( + {sum_squares / static_cast(samples_per_channel), peak}); + } + } + + // Estimates the analog gain adjustment for channel `channel` using a + // sliding window over the frame-wise metrics in `ch_buffers_`. 
Returns an + // estimate for the clipped level step equal to `default_clipped_level_step_` + // if at least `GetMinFramesProcessed()` frames have been processed since the + // last reset and a clipping event is predicted. `level`, `min_mic_level`, and + // `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255]. + absl::optional EstimateClippedLevelStep(int channel, + int level, + int default_step, + int min_mic_level, + int max_mic_level) const { + RTC_CHECK_GE(channel, 0); + RTC_CHECK_LT(channel, ch_buffers_.size()); + RTC_DCHECK_GE(level, 0); + RTC_DCHECK_LE(level, 255); + RTC_DCHECK_GT(default_step, 0); + RTC_DCHECK_LE(default_step, 255); + RTC_DCHECK_GE(min_mic_level, 0); + RTC_DCHECK_LE(min_mic_level, 255); + RTC_DCHECK_GE(max_mic_level, 0); + RTC_DCHECK_LE(max_mic_level, 255); + if (level <= min_mic_level) { + return absl::nullopt; + } + if (PredictClippingEvent(channel)) { + const int new_level = + rtc::SafeClamp(level - default_step, min_mic_level, max_mic_level); + const int step = level - new_level; + if (step > 0) { + return step; + } + } + return absl::nullopt; + } + + private: + int GetMinFramesProcessed() const { + return reference_window_delay_ + reference_window_length_; + } + + // Predicts clipping events based on the processed audio frames. Returns + // true if a clipping event is likely. + bool PredictClippingEvent(int channel) const { + const auto metrics = + ch_buffers_[channel]->ComputePartialMetrics(0, window_length_); + if (!metrics.has_value() || + !(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) { + return false; + } + const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics( + reference_window_delay_, reference_window_length_); + if (!reference_metrics.has_value()) { + return false; + } + const float crest_factor = ComputeCrestFactor(metrics.value()); + const float reference_crest_factor = + ComputeCrestFactor(reference_metrics.value()); + if (crest_factor < reference_crest_factor - crest_factor_margin_) { + return true; + } + return false; + } + + std::vector> ch_buffers_; + const int window_length_; + const int reference_window_length_; + const int reference_window_delay_; + const float clipping_threshold_; + const float crest_factor_margin_; +}; + +// Performs crest factor-based clipping peak prediction. +class ClippingPeakPredictor : public ClippingPredictor { + public: + // Ctor. ClippingPeakPredictor with `num_channels` channels (limited to values + // higher than zero); window size `window_length` and reference window size + // `reference_window_length` (both referring to the number of frames in the + // respective sliding windows and limited to values higher than zero); + // reference window delay `reference_window_delay` (delay in frames, limited + // to values zero and higher with an additional requirement of + // `window_length` < `reference_window_length` + reference_window_delay`); + // and a clipping prediction threshold `clipping_threshold` (in dB). Adaptive + // clipped level step estimation is used if `adaptive_step_estimation` is + // true. 
+ explicit ClippingPeakPredictor(int num_channels, + int window_length, + int reference_window_length, + int reference_window_delay, + int clipping_threshold, + bool adaptive_step_estimation) + : window_length_(window_length), + reference_window_length_(reference_window_length), + reference_window_delay_(reference_window_delay), + clipping_threshold_(clipping_threshold), + adaptive_step_estimation_(adaptive_step_estimation) { + RTC_DCHECK_GT(num_channels, 0); + RTC_DCHECK_GT(window_length, 0); + RTC_DCHECK_GT(reference_window_length, 0); + RTC_DCHECK_GE(reference_window_delay, 0); + RTC_DCHECK_GT(reference_window_length + reference_window_delay, + window_length); + const int buffer_length = GetMinFramesProcessed(); + RTC_DCHECK_GT(buffer_length, 0); + for (int i = 0; i < num_channels; ++i) { + ch_buffers_.push_back( + std::make_unique(buffer_length)); + } + } + + ClippingPeakPredictor(const ClippingPeakPredictor&) = delete; + ClippingPeakPredictor& operator=(const ClippingPeakPredictor&) = delete; + ~ClippingPeakPredictor() {} + + void Reset() { + const int num_channels = ch_buffers_.size(); + for (int i = 0; i < num_channels; ++i) { + ch_buffers_[i]->Reset(); + } + } + + // Analyzes a frame of audio and stores the framewise metrics in + // `ch_buffers_`. + void Analyze(const AudioFrameView& frame) { + const int num_channels = frame.num_channels(); + RTC_DCHECK_EQ(num_channels, ch_buffers_.size()); + const int samples_per_channel = frame.samples_per_channel(); + RTC_DCHECK_GT(samples_per_channel, 0); + for (int channel = 0; channel < num_channels; ++channel) { + float sum_squares = 0.0f; + float peak = 0.0f; + for (const auto& sample : frame.channel(channel)) { + sum_squares += sample * sample; + peak = std::max(std::fabs(sample), peak); + } + ch_buffers_[channel]->Push( + {sum_squares / static_cast(samples_per_channel), peak}); + } + } + + // Estimates the analog gain adjustment for channel `channel` using a + // sliding window over the frame-wise metrics in `ch_buffers_`. Returns an + // estimate for the clipped level step (equal to + // `default_clipped_level_step_` if `adaptive_estimation_` is false) if at + // least `GetMinFramesProcessed()` frames have been processed since the last + // reset and a clipping event is predicted. `level`, `min_mic_level`, and + // `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255]. 
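+  // For example (illustrative figures), with adaptive step estimation, if the
+  // projected peak is 4 dB above full scale (0 dBFS), the step is chosen so
+  // that, according to `kGainMap`, lowering the level by the step reduces the
+  // analog gain by at least 4 dB; the result is never smaller than
+  // `default_step` and the new level is clamped to [`min_mic_level`,
+  // `max_mic_level`].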
+ absl::optional EstimateClippedLevelStep(int channel, + int level, + int default_step, + int min_mic_level, + int max_mic_level) const { + RTC_DCHECK_GE(channel, 0); + RTC_DCHECK_LT(channel, ch_buffers_.size()); + RTC_DCHECK_GE(level, 0); + RTC_DCHECK_LE(level, 255); + RTC_DCHECK_GT(default_step, 0); + RTC_DCHECK_LE(default_step, 255); + RTC_DCHECK_GE(min_mic_level, 0); + RTC_DCHECK_LE(min_mic_level, 255); + RTC_DCHECK_GE(max_mic_level, 0); + RTC_DCHECK_LE(max_mic_level, 255); + if (level <= min_mic_level) { + return absl::nullopt; + } + absl::optional estimate_db = EstimatePeakValue(channel); + if (estimate_db.has_value() && estimate_db.value() > clipping_threshold_) { + int step = 0; + if (!adaptive_step_estimation_) { + step = default_step; + } else { + const int estimated_gain_change = + rtc::SafeClamp(-static_cast(std::ceil(estimate_db.value())), + -kClippingPredictorMaxGainChange, 0); + step = + std::max(level - LevelFromGainError(estimated_gain_change, level, + min_mic_level, max_mic_level), + default_step); + } + const int new_level = + rtc::SafeClamp(level - step, min_mic_level, max_mic_level); + if (level > new_level) { + return level - new_level; + } + } + return absl::nullopt; + } + + private: + int GetMinFramesProcessed() { + return reference_window_delay_ + reference_window_length_; + } + + // Predicts clipping sample peaks based on the processed audio frames. + // Returns the estimated peak value if clipping is predicted. Otherwise + // returns absl::nullopt. + absl::optional EstimatePeakValue(int channel) const { + const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics( + reference_window_delay_, reference_window_length_); + if (!reference_metrics.has_value()) { + return absl::nullopt; + } + const auto metrics = + ch_buffers_[channel]->ComputePartialMetrics(0, window_length_); + if (!metrics.has_value() || + !(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) { + return absl::nullopt; + } + const float reference_crest_factor = + ComputeCrestFactor(reference_metrics.value()); + const float& mean_squares = metrics.value().average; + const float projected_peak = + reference_crest_factor + FloatS16ToDbfs(std::sqrt(mean_squares)); + return projected_peak; + } + + std::vector> ch_buffers_; + const int window_length_; + const int reference_window_length_; + const int reference_window_delay_; + const int clipping_threshold_; + const bool adaptive_step_estimation_; +}; + +} // namespace + +std::unique_ptr CreateClippingPredictor( + int num_channels, + const AudioProcessing::Config::GainController1::AnalogGainController:: + ClippingPredictor& config) { + if (!config.enabled) { + RTC_LOG(LS_INFO) << "[agc] Clipping prediction disabled."; + return nullptr; + } + RTC_LOG(LS_INFO) << "[agc] Clipping prediction enabled."; + using ClippingPredictorMode = AudioProcessing::Config::GainController1:: + AnalogGainController::ClippingPredictor::Mode; + switch (config.mode) { + case ClippingPredictorMode::kClippingEventPrediction: + return std::make_unique( + num_channels, config.window_length, config.reference_window_length, + config.reference_window_delay, config.clipping_threshold, + config.crest_factor_margin); + case ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction: + return std::make_unique( + num_channels, config.window_length, config.reference_window_length, + config.reference_window_delay, config.clipping_threshold, + /*adaptive_step_estimation=*/true); + case ClippingPredictorMode::kFixedStepClippingPeakPrediction: + return std::make_unique( + 
num_channels, config.window_length, config.reference_window_length,
+          config.reference_window_delay, config.clipping_threshold,
+          /*adaptive_step_estimation=*/false);
+  }
+  RTC_NOTREACHED();
+}
+
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc/clipping_predictor.h b/modules/audio_processing/agc/clipping_predictor.h
new file mode 100644
index 0000000000..ee2b6ef1e7
--- /dev/null
+++ b/modules/audio_processing/agc/clipping_predictor.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_H_
+#define MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "modules/audio_processing/include/audio_frame_view.h"
+#include "modules/audio_processing/include/audio_processing.h"
+
+namespace webrtc {
+
+// Frame-wise clipping prediction and clipped level step estimation. Analyzes
+// 10 ms multi-channel frames and estimates an analog mic level decrease step
+// to possibly avoid clipping when predicted. `Analyze()` and
+// `EstimateClippedLevelStep()` can be called in any order.
+class ClippingPredictor {
+ public:
+  virtual ~ClippingPredictor() = default;
+
+  virtual void Reset() = 0;
+
+  // Analyzes a 10 ms multi-channel audio frame.
+  virtual void Analyze(const AudioFrameView<const float>& frame) = 0;
+
+  // Predicts if clipping is going to occur for the specified `channel` in the
+  // near future and, if so, returns a recommended analog mic level decrease
+  // step. Returns absl::nullopt if clipping is not predicted.
+  // `level` is the current analog mic level, `default_step` is the amount the
+  // mic level is lowered by the analog controller with every clipping event,
+  // and `min_mic_level` and `max_mic_level` define the range of allowed analog
+  // mic levels.
+  virtual absl::optional<int> EstimateClippedLevelStep(
+      int channel,
+      int level,
+      int default_step,
+      int min_mic_level,
+      int max_mic_level) const = 0;
+
+};
+
+// Creates a ClippingPredictor based on the provided `config`. When enabled,
+// the following must hold for `config`:
+// `window_length < reference_window_length + reference_window_delay`.
+// Returns `nullptr` if `config.enabled` is false.
+std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
+    int num_channels,
+    const AudioProcessing::Config::GainController1::AnalogGainController::
+        ClippingPredictor& config);
+
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_H_
diff --git a/modules/audio_processing/agc/clipping_predictor_evaluator.cc b/modules/audio_processing/agc/clipping_predictor_evaluator.cc
new file mode 100644
index 0000000000..2a4ea922cf
--- /dev/null
+++ b/modules/audio_processing/agc/clipping_predictor_evaluator.cc
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */ + +#include "modules/audio_processing/agc/clipping_predictor_evaluator.h" + +#include + +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace webrtc { +namespace { + +// Returns the index of the oldest item in the ring buffer for a non-empty +// ring buffer with give `size`, `tail` index and `capacity`. +int OldestExpectedDetectionIndex(int size, int tail, int capacity) { + RTC_DCHECK_GT(size, 0); + return tail - size + (tail < size ? capacity : 0); +} + +} // namespace + +ClippingPredictorEvaluator::ClippingPredictorEvaluator(int history_size) + : history_size_(history_size), + ring_buffer_capacity_(history_size + 1), + ring_buffer_(ring_buffer_capacity_), + true_positives_(0), + true_negatives_(0), + false_positives_(0), + false_negatives_(0) { + RTC_DCHECK_GT(history_size_, 0); + Reset(); +} + +ClippingPredictorEvaluator::~ClippingPredictorEvaluator() = default; + +absl::optional ClippingPredictorEvaluator::Observe( + bool clipping_detected, + bool clipping_predicted) { + RTC_DCHECK_GE(ring_buffer_size_, 0); + RTC_DCHECK_LE(ring_buffer_size_, ring_buffer_capacity_); + RTC_DCHECK_GE(ring_buffer_tail_, 0); + RTC_DCHECK_LT(ring_buffer_tail_, ring_buffer_capacity_); + + DecreaseTimesToLive(); + if (clipping_predicted) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + Push(/*expected_detection=*/{/*ttl=*/history_size_, /*detected=*/false}); + } + // Clipping is expected if there are expected detections regardless of + // whether all the expected detections have been previously matched - i.e., + // `ExpectedDetection::detected` is true. + const bool clipping_expected = ring_buffer_size_ > 0; + + absl::optional prediction_interval; + if (clipping_expected && clipping_detected) { + prediction_interval = FindEarliestPredictionInterval(); + // Add a true positive for each unexpired expected detection. + const int num_modified_items = MarkExpectedDetectionAsDetected(); + true_positives_ += num_modified_items; + RTC_DCHECK(prediction_interval.has_value() || num_modified_items == 0); + RTC_DCHECK(!prediction_interval.has_value() || num_modified_items > 0); + } else if (clipping_expected && !clipping_detected) { + // Add a false positive if there is one expected detection that has expired + // and that has never been matched before. Note that there is at most one + // unmatched expired detection. + if (HasExpiredUnmatchedExpectedDetection()) { + false_positives_++; + } + } else if (!clipping_expected && clipping_detected) { + false_negatives_++; + } else { + RTC_DCHECK(!clipping_expected && !clipping_detected); + true_negatives_++; + } + return prediction_interval; +} + +void ClippingPredictorEvaluator::Reset() { + // Empty the ring buffer of expected detections. + ring_buffer_tail_ = 0; + ring_buffer_size_ = 0; +} + +// Cost: O(1). +void ClippingPredictorEvaluator::Push(ExpectedDetection value) { + ring_buffer_[ring_buffer_tail_] = value; + ring_buffer_tail_++; + if (ring_buffer_tail_ == ring_buffer_capacity_) { + ring_buffer_tail_ = 0; + } + ring_buffer_size_ = std::min(ring_buffer_capacity_, ring_buffer_size_ + 1); +} + +// Cost: O(N). +void ClippingPredictorEvaluator::DecreaseTimesToLive() { + bool expired_found = false; + for (int i = ring_buffer_tail_ - ring_buffer_size_; i < ring_buffer_tail_; + ++i) { + int index = i >= 0 ? 
i : ring_buffer_capacity_ + i; + RTC_DCHECK_GE(index, 0); + RTC_DCHECK_LT(index, ring_buffer_.size()); + RTC_DCHECK_GE(ring_buffer_[index].ttl, 0); + if (ring_buffer_[index].ttl == 0) { + RTC_DCHECK(!expired_found) + << "There must be at most one expired item in the ring buffer."; + expired_found = true; + RTC_DCHECK_EQ(index, OldestExpectedDetectionIndex(ring_buffer_size_, + ring_buffer_tail_, + ring_buffer_capacity_)) + << "The expired item must be the oldest in the ring buffer."; + } + ring_buffer_[index].ttl--; + } + if (expired_found) { + ring_buffer_size_--; + } +} + +// Cost: O(N). +absl::optional ClippingPredictorEvaluator::FindEarliestPredictionInterval() + const { + absl::optional prediction_interval; + for (int i = ring_buffer_tail_ - ring_buffer_size_; i < ring_buffer_tail_; + ++i) { + int index = i >= 0 ? i : ring_buffer_capacity_ + i; + RTC_DCHECK_GE(index, 0); + RTC_DCHECK_LT(index, ring_buffer_.size()); + if (!ring_buffer_[index].detected) { + prediction_interval = std::max(prediction_interval.value_or(0), + history_size_ - ring_buffer_[index].ttl); + } + } + return prediction_interval; +} + +// Cost: O(N). +int ClippingPredictorEvaluator::MarkExpectedDetectionAsDetected() { + int num_modified_items = 0; + for (int i = ring_buffer_tail_ - ring_buffer_size_; i < ring_buffer_tail_; + ++i) { + int index = i >= 0 ? i : ring_buffer_capacity_ + i; + RTC_DCHECK_GE(index, 0); + RTC_DCHECK_LT(index, ring_buffer_.size()); + if (!ring_buffer_[index].detected) { + num_modified_items++; + } + ring_buffer_[index].detected = true; + } + return num_modified_items; +} + +// Cost: O(1). +bool ClippingPredictorEvaluator::HasExpiredUnmatchedExpectedDetection() const { + if (ring_buffer_size_ == 0) { + return false; + } + // If an expired item, that is `ttl` equal to 0, exists, it must be the + // oldest. + const int oldest_index = OldestExpectedDetectionIndex( + ring_buffer_size_, ring_buffer_tail_, ring_buffer_capacity_); + RTC_DCHECK_GE(oldest_index, 0); + RTC_DCHECK_LT(oldest_index, ring_buffer_.size()); + return ring_buffer_[oldest_index].ttl == 0 && + !ring_buffer_[oldest_index].detected; +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc/clipping_predictor_evaluator.h b/modules/audio_processing/agc/clipping_predictor_evaluator.h new file mode 100644 index 0000000000..e76f25d5e1 --- /dev/null +++ b/modules/audio_processing/agc/clipping_predictor_evaluator.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_EVALUATOR_H_ +#define MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_EVALUATOR_H_ + +#include + +#include "absl/types/optional.h" + +namespace webrtc { + +// Counts true/false positives/negatives while observing sequences of flag pairs +// that indicate whether clipping has been detected and/or if clipping is +// predicted. When a true positive is found measures the time interval between +// prediction and detection events. +// From the time a prediction is observed and for a period equal to +// `history_size` calls to `Observe()`, one or more detections are expected. 
If
+// the expectation is met, a true positive is added and the time interval
+// between the earliest prediction and the detection is recorded; otherwise,
+// when the deadline is reached, a false positive is added. Note that one
+// detection matches all the expected detections that have not expired - i.e.,
+// one detection counts as multiple true positives.
+// If a detection is observed but no prediction has been observed over the past
+// `history_size` calls to `Observe()`, then a false negative is added;
+// otherwise, a true negative is added.
+class ClippingPredictorEvaluator {
+ public:
+  // Ctor. `history_size` indicates how long to wait for a call to `Observe()`
+  // having `clipping_detected` set to true from the time clipping is predicted.
+  explicit ClippingPredictorEvaluator(int history_size);
+  ClippingPredictorEvaluator(const ClippingPredictorEvaluator&) = delete;
+  ClippingPredictorEvaluator& operator=(const ClippingPredictorEvaluator&) =
+      delete;
+  ~ClippingPredictorEvaluator();
+
+  // Observes whether clipping has been detected and/or if clipping is
+  // predicted. When predicted, one or more detections are expected in the next
+  // `history_size_` calls to `Observe()`. When a true positive is found, returns
+  // the prediction interval between the earliest prediction and the detection.
+  absl::optional<int> Observe(bool clipping_detected, bool clipping_predicted);
+
+  // Removes any expectation set by recent calls to `Observe()` having
+  // `clipping_predicted` set to true.
+  void Reset();
+
+  // Metrics getters.
+  int true_positives() const { return true_positives_; }
+  int true_negatives() const { return true_negatives_; }
+  int false_positives() const { return false_positives_; }
+  int false_negatives() const { return false_negatives_; }
+
+ private:
+  const int history_size_;
+
+  // State of a detection expected to be observed after a prediction.
+  struct ExpectedDetection {
+    // Time to live (TTL); remaining number of `Observe()` calls to match a call
+    // having `clipping_detected` set to true.
+    int ttl;
+    // True if an `Observe()` call having `clipping_detected` set to true has
+    // been observed.
+    bool detected;
+  };
+  // Ring buffer of expected detections.
+  const int ring_buffer_capacity_;
+  std::vector<ExpectedDetection> ring_buffer_;
+  int ring_buffer_tail_;
+  int ring_buffer_size_;
+
+  // Pushes `expected_detection` into `ring_buffer_`.
+  void Push(ExpectedDetection expected_detection);
+  // Decreases the TTLs in `ring_buffer_` and removes the expired item, if any.
+  void DecreaseTimesToLive();
+  // Returns the prediction interval for the earliest unexpired expected
+  // detection, if any.
+  absl::optional<int> FindEarliestPredictionInterval() const;
+  // Marks all the items in `ring_buffer_` as `detected` and returns the number
+  // of newly marked items.
+  int MarkExpectedDetectionAsDetected();
+  // Returns true if `ring_buffer_` has an item with `ttl` equal to 0 (expired)
+  // and `detected` equal to false (unmatched).
+  bool HasExpiredUnmatchedExpectedDetection() const;
+
+  // Metrics.
+ int true_positives_; + int true_negatives_; + int false_positives_; + int false_negatives_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_EVALUATOR_H_ diff --git a/modules/audio_processing/agc/clipping_predictor_evaluator_unittest.cc b/modules/audio_processing/agc/clipping_predictor_evaluator_unittest.cc new file mode 100644 index 0000000000..1eb83eae61 --- /dev/null +++ b/modules/audio_processing/agc/clipping_predictor_evaluator_unittest.cc @@ -0,0 +1,568 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc/clipping_predictor_evaluator.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/random.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +using testing::Eq; +using testing::Optional; + +constexpr bool kDetected = true; +constexpr bool kNotDetected = false; + +constexpr bool kPredicted = true; +constexpr bool kNotPredicted = false; + +int SumTrueFalsePositivesNegatives( + const ClippingPredictorEvaluator& evaluator) { + return evaluator.true_positives() + evaluator.true_negatives() + + evaluator.false_positives() + evaluator.false_negatives(); +} + +// Checks the metrics after init - i.e., no call to `Observe()`. +TEST(ClippingPredictorEvaluatorTest, Init) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +class ClippingPredictorEvaluatorParameterization + : public ::testing::TestWithParam> { + protected: + uint64_t seed() const { + return rtc::checked_cast(std::get<0>(GetParam())); + } + int history_size() const { return std::get<1>(GetParam()); } +}; + +// Checks that after each call to `Observe()` at most one metric changes. +TEST_P(ClippingPredictorEvaluatorParameterization, AtMostOneMetricChanges) { + constexpr int kNumCalls = 123; + Random random_generator(seed()); + ClippingPredictorEvaluator evaluator(history_size()); + + for (int i = 0; i < kNumCalls; ++i) { + SCOPED_TRACE(i); + // Read metrics before `Observe()` is called. + const int last_tp = evaluator.true_positives(); + const int last_tn = evaluator.true_negatives(); + const int last_fp = evaluator.false_positives(); + const int last_fn = evaluator.false_negatives(); + // `Observe()` a random observation. + bool clipping_detected = random_generator.Rand(); + bool clipping_predicted = random_generator.Rand(); + evaluator.Observe(clipping_detected, clipping_predicted); + + // Check that at most one metric has changed. + int num_changes = 0; + num_changes += last_tp == evaluator.true_positives() ? 0 : 1; + num_changes += last_tn == evaluator.true_negatives() ? 0 : 1; + num_changes += last_fp == evaluator.false_positives() ? 0 : 1; + num_changes += last_fn == evaluator.false_negatives() ? 
0 : 1; + EXPECT_GE(num_changes, 0); + EXPECT_LE(num_changes, 1); + } +} + +// Checks that after each call to `Observe()` each metric either remains +// unchanged or grows. +TEST_P(ClippingPredictorEvaluatorParameterization, MetricsAreWeaklyMonotonic) { + constexpr int kNumCalls = 123; + Random random_generator(seed()); + ClippingPredictorEvaluator evaluator(history_size()); + + for (int i = 0; i < kNumCalls; ++i) { + SCOPED_TRACE(i); + // Read metrics before `Observe()` is called. + const int last_tp = evaluator.true_positives(); + const int last_tn = evaluator.true_negatives(); + const int last_fp = evaluator.false_positives(); + const int last_fn = evaluator.false_negatives(); + // `Observe()` a random observation. + bool clipping_detected = random_generator.Rand(); + bool clipping_predicted = random_generator.Rand(); + evaluator.Observe(clipping_detected, clipping_predicted); + + // Check that metrics are weakly monotonic. + EXPECT_GE(evaluator.true_positives(), last_tp); + EXPECT_GE(evaluator.true_negatives(), last_tn); + EXPECT_GE(evaluator.false_positives(), last_fp); + EXPECT_GE(evaluator.false_negatives(), last_fn); + } +} + +// Checks that after each call to `Observe()` the growth speed of each metrics +// is bounded. +TEST_P(ClippingPredictorEvaluatorParameterization, BoundedMetricsGrowth) { + constexpr int kNumCalls = 123; + Random random_generator(seed()); + ClippingPredictorEvaluator evaluator(history_size()); + + for (int i = 0; i < kNumCalls; ++i) { + SCOPED_TRACE(i); + // Read metrics before `Observe()` is called. + const int last_tp = evaluator.true_positives(); + const int last_tn = evaluator.true_negatives(); + const int last_fp = evaluator.false_positives(); + const int last_fn = evaluator.false_negatives(); + // `Observe()` a random observation. + bool clipping_detected = random_generator.Rand(); + bool clipping_predicted = random_generator.Rand(); + evaluator.Observe(clipping_detected, clipping_predicted); + + // Check that TPs grow by at most `history_size() + 1`. Such an upper bound + // is reached when multiple predictions are matched by a single detection. + EXPECT_LE(evaluator.true_positives() - last_tp, history_size() + 1); + // Check that TNs, FPs and FNs grow by at most one. `max_growth`. + EXPECT_LE(evaluator.true_negatives() - last_tn, 1); + EXPECT_LE(evaluator.false_positives() - last_fp, 1); + EXPECT_LE(evaluator.false_negatives() - last_fn, 1); + } +} + +// Checks that `Observe()` returns a prediction interval if and only if one or +// more true positives are found. +TEST_P(ClippingPredictorEvaluatorParameterization, + PredictionIntervalIfAndOnlyIfTruePositives) { + constexpr int kNumCalls = 123; + Random random_generator(seed()); + ClippingPredictorEvaluator evaluator(history_size()); + + for (int i = 0; i < kNumCalls; ++i) { + SCOPED_TRACE(i); + // Read true positives before `Observe()` is called. + const int last_tp = evaluator.true_positives(); + // `Observe()` a random observation. + bool clipping_detected = random_generator.Rand(); + bool clipping_predicted = random_generator.Rand(); + absl::optional prediction_interval = + evaluator.Observe(clipping_detected, clipping_predicted); + + // Check that the prediction interval is returned when a true positive is + // found. 
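+    // True/false positive/negative counters never decrease, so an unchanged
+    // true positive count means that this call did not add a true positive.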
+ if (evaluator.true_positives() == last_tp) { + EXPECT_FALSE(prediction_interval.has_value()); + } else { + EXPECT_TRUE(prediction_interval.has_value()); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + ClippingPredictorEvaluatorTest, + ClippingPredictorEvaluatorParameterization, + ::testing::Combine(::testing::Values(4, 8, 15, 16, 23, 42), + ::testing::Values(1, 10, 21))); + +// Checks that, observing a detection and a prediction after init, produces a +// true positive. +TEST(ClippingPredictorEvaluatorTest, OneTruePositiveAfterInit) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kDetected, kPredicted); + EXPECT_EQ(evaluator.true_positives(), 1); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that, observing a detection but no prediction after init, produces a +// false negative. +TEST(ClippingPredictorEvaluatorTest, OneFalseNegativeAfterInit) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_negatives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); +} + +// Checks that, observing no detection but a prediction after init, produces a +// false positive after expiration. +TEST(ClippingPredictorEvaluatorTest, OneFalsePositiveAfterInit) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that, observing no detection and no prediction after init, produces a +// true negative. +TEST(ClippingPredictorEvaluatorTest, OneTrueNegativeAfterInit) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_negatives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that the evaluator detects true negatives when clipping is neither +// predicted nor detected. +TEST(ClippingPredictorEvaluatorTest, NeverDetectedAndNotPredicted) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_negatives(), 4); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that the evaluator detects a false negative when clipping is detected +// but not predicted. 
+TEST(ClippingPredictorEvaluatorTest, DetectedButNotPredicted) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_negatives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.true_negatives(), 3); + EXPECT_EQ(evaluator.false_positives(), 0); +} + +// Checks that the evaluator does not detect a false positive when clipping is +// predicted but not detected until the observation period expires. +TEST(ClippingPredictorEvaluatorTest, + PredictedOnceAndNeverDetectedBeforeDeadline) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that the evaluator detects a false positive when clipping is predicted +// but detected after the observation period expires. +TEST(ClippingPredictorEvaluatorTest, PredictedOnceButDetectedAfterDeadline) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 0); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 1); +} + +// Checks that a prediction followed by a detection counts as true positive. +TEST(ClippingPredictorEvaluatorTest, PredictedOnceAndThenImmediatelyDetected) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_positives(), 1); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that a prediction followed by a delayed detection counts as true +// positive if the delay is within the observation period. +TEST(ClippingPredictorEvaluatorTest, PredictedOnceAndDetectedBeforeDeadline) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_positives(), 1); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that a prediction followed by a delayed detection counts as true +// positive if the delay equals the observation period. 
+TEST(ClippingPredictorEvaluatorTest, PredictedOnceAndDetectedAtDeadline) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_positives(), 0); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_positives(), 1); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that a prediction followed by a multiple adjacent detections within +// the deadline counts as a single true positive and that, after the deadline, +// a detection counts as a false negative. +TEST(ClippingPredictorEvaluatorTest, PredictedOnceAndDetectedMultipleTimes) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + // Multiple detections. + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_positives(), 1); + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_positives(), 1); + // A detection outside of the observation period counts as false negative. + evaluator.Observe(kDetected, kNotPredicted); + EXPECT_EQ(evaluator.false_negatives(), 1); + EXPECT_EQ(SumTrueFalsePositivesNegatives(evaluator), 2); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); +} + +// Checks that a false positive is added when clipping is detected after a too +// early prediction. +TEST(ClippingPredictorEvaluatorTest, + PredictedMultipleTimesAndDetectedOnceAfterDeadline) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); // ---+ + evaluator.Observe(kNotDetected, kPredicted); // | + evaluator.Observe(kNotDetected, kPredicted); // | + evaluator.Observe(kNotDetected, kPredicted); // <--+ Not matched. + // The time to match a detection after the first prediction expired. + EXPECT_EQ(evaluator.false_positives(), 1); + evaluator.Observe(kDetected, kNotPredicted); + // The detection above does not match the first prediction because it happened + // after the deadline of the 1st prediction. + EXPECT_EQ(evaluator.false_positives(), 1); + + EXPECT_EQ(evaluator.true_positives(), 3); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that multiple consecutive predictions match the first detection +// observed before the expected detection deadline expires. +TEST(ClippingPredictorEvaluatorTest, PredictedMultipleTimesAndDetectedOnce) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); // --+ + evaluator.Observe(kNotDetected, kPredicted); // | --+ + evaluator.Observe(kNotDetected, kPredicted); // | | --+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ <-+ <-+ + EXPECT_EQ(evaluator.true_positives(), 3); + // The following observations do not generate any true negatives as they + // belong to the observation period of the last prediction - for which a + // detection has already been matched. 
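+  // The last prediction keeps an expectation alive for `history_size` more
+  // calls; since that expectation has already been matched, the observations
+  // below produce neither a true negative nor a false positive.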
+ const int true_negatives = evaluator.true_negatives(); + evaluator.Observe(kNotDetected, kNotPredicted); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_negatives(), true_negatives); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that multiple consecutive predictions match the multiple detections +// observed before the expected detection deadline expires. +TEST(ClippingPredictorEvaluatorTest, + PredictedMultipleTimesAndDetectedMultipleTimes) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); // --+ + evaluator.Observe(kNotDetected, kPredicted); // | --+ + evaluator.Observe(kNotDetected, kPredicted); // | | --+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ <-+ <-+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ <-+ + EXPECT_EQ(evaluator.true_positives(), 3); + // The following observation does not generate a true negative as it belongs + // to the observation period of the last prediction - for which two detections + // have already been matched. + const int true_negatives = evaluator.true_negatives(); + evaluator.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(evaluator.true_negatives(), true_negatives); + + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that multiple consecutive predictions match all the detections +// observed before the expected detection deadline expires. +TEST(ClippingPredictorEvaluatorTest, PredictedMultipleTimesAndAllDetected) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); // --+ + evaluator.Observe(kNotDetected, kPredicted); // | --+ + evaluator.Observe(kNotDetected, kPredicted); // | | --+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ <-+ <-+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ <-+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ + EXPECT_EQ(evaluator.true_positives(), 3); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +// Checks that multiple non-consecutive predictions match all the detections +// observed before the expected detection deadline expires. +TEST(ClippingPredictorEvaluatorTest, + PredictedMultipleTimesWithGapAndAllDetected) { + ClippingPredictorEvaluator evaluator(/*history_size=*/3); + evaluator.Observe(kNotDetected, kPredicted); // --+ + evaluator.Observe(kNotDetected, kNotPredicted); // | + evaluator.Observe(kNotDetected, kPredicted); // | --+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ <-+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ + evaluator.Observe(kDetected, kNotPredicted); // <-+ + EXPECT_EQ(evaluator.true_positives(), 2); + EXPECT_EQ(evaluator.true_negatives(), 0); + EXPECT_EQ(evaluator.false_positives(), 0); + EXPECT_EQ(evaluator.false_negatives(), 0); +} + +class ClippingPredictorEvaluatorPredictionIntervalParameterization + : public ::testing::TestWithParam> { + protected: + int num_extra_observe_calls() const { return std::get<0>(GetParam()); } + int history_size() const { return std::get<1>(GetParam()); } +}; + +// Checks that the minimum prediction interval is returned if clipping is +// correctly predicted as soon as detected - i.e., no anticipation. 
+TEST_P(ClippingPredictorEvaluatorPredictionIntervalParameterization, + MinimumPredictionInterval) { + ClippingPredictorEvaluator evaluator(history_size()); + for (int i = 0; i < num_extra_observe_calls(); ++i) { + EXPECT_EQ(evaluator.Observe(kNotDetected, kNotPredicted), absl::nullopt); + } + absl::optional prediction_interval = + evaluator.Observe(kDetected, kPredicted); + EXPECT_THAT(prediction_interval, Optional(Eq(0))); +} + +// Checks that a prediction interval between the minimum and the maximum is +// returned if clipping is correctly predicted before it is detected but not as +// early as possible. +TEST_P(ClippingPredictorEvaluatorPredictionIntervalParameterization, + IntermediatePredictionInterval) { + ClippingPredictorEvaluator evaluator(history_size()); + for (int i = 0; i < num_extra_observe_calls(); ++i) { + EXPECT_EQ(evaluator.Observe(kNotDetected, kNotPredicted), absl::nullopt); + } + EXPECT_EQ(evaluator.Observe(kNotDetected, kPredicted), absl::nullopt); + EXPECT_EQ(evaluator.Observe(kNotDetected, kPredicted), absl::nullopt); + EXPECT_EQ(evaluator.Observe(kNotDetected, kPredicted), absl::nullopt); + absl::optional prediction_interval = + evaluator.Observe(kDetected, kPredicted); + EXPECT_THAT(prediction_interval, Optional(Eq(3))); +} + +// Checks that the maximum prediction interval is returned if clipping is +// correctly predicted as early as possible. +TEST_P(ClippingPredictorEvaluatorPredictionIntervalParameterization, + MaximumPredictionInterval) { + ClippingPredictorEvaluator evaluator(history_size()); + for (int i = 0; i < num_extra_observe_calls(); ++i) { + EXPECT_EQ(evaluator.Observe(kNotDetected, kNotPredicted), absl::nullopt); + } + for (int i = 0; i < history_size(); ++i) { + EXPECT_EQ(evaluator.Observe(kNotDetected, kPredicted), absl::nullopt); + } + absl::optional prediction_interval = + evaluator.Observe(kDetected, kPredicted); + EXPECT_THAT(prediction_interval, Optional(Eq(history_size()))); +} + +// Checks that `Observe()` returns the prediction interval as soon as a true +// positive is found and never again while ongoing detections are matched to a +// previously observed prediction. +TEST_P(ClippingPredictorEvaluatorPredictionIntervalParameterization, + PredictionIntervalReturnedOnce) { + ASSERT_LT(num_extra_observe_calls(), history_size()); + ClippingPredictorEvaluator evaluator(history_size()); + // Observe predictions before detection. + for (int i = 0; i < num_extra_observe_calls(); ++i) { + EXPECT_EQ(evaluator.Observe(kNotDetected, kPredicted), absl::nullopt); + } + // Observe a detection. + absl::optional prediction_interval = + evaluator.Observe(kDetected, kPredicted); + EXPECT_TRUE(prediction_interval.has_value()); + // `Observe()` does not return a prediction interval anymore during ongoing + // detections observed while a detection is still expected. + for (int i = 0; i < history_size(); ++i) { + EXPECT_EQ(evaluator.Observe(kDetected, kNotPredicted), absl::nullopt); + } +} + +INSTANTIATE_TEST_SUITE_P( + ClippingPredictorEvaluatorTest, + ClippingPredictorEvaluatorPredictionIntervalParameterization, + ::testing::Combine(::testing::Values(0, 3, 5), ::testing::Values(7, 11))); + +// Checks that, when a detection is expected, the expectation is removed if and +// only if `Reset()` is called after a prediction is observed. 
+TEST(ClippingPredictorEvaluatorTest, NoFalsePositivesAfterReset) { + constexpr int kHistorySize = 2; + + ClippingPredictorEvaluator with_reset(kHistorySize); + with_reset.Observe(kNotDetected, kPredicted); + with_reset.Reset(); + with_reset.Observe(kNotDetected, kNotPredicted); + with_reset.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(with_reset.true_positives(), 0); + EXPECT_EQ(with_reset.true_negatives(), 2); + EXPECT_EQ(with_reset.false_positives(), 0); + EXPECT_EQ(with_reset.false_negatives(), 0); + + ClippingPredictorEvaluator no_reset(kHistorySize); + no_reset.Observe(kNotDetected, kPredicted); + no_reset.Observe(kNotDetected, kNotPredicted); + no_reset.Observe(kNotDetected, kNotPredicted); + EXPECT_EQ(no_reset.true_positives(), 0); + EXPECT_EQ(no_reset.true_negatives(), 0); + EXPECT_EQ(no_reset.false_positives(), 1); + EXPECT_EQ(no_reset.false_negatives(), 0); +} + +} // namespace +} // namespace webrtc diff --git a/modules/audio_processing/agc/clipping_predictor_level_buffer.cc b/modules/audio_processing/agc/clipping_predictor_level_buffer.cc new file mode 100644 index 0000000000..bc33cda040 --- /dev/null +++ b/modules/audio_processing/agc/clipping_predictor_level_buffer.cc @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc/clipping_predictor_level_buffer.h" + +#include +#include + +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace webrtc { + +bool ClippingPredictorLevelBuffer::Level::operator==(const Level& level) const { + constexpr float kEpsilon = 1e-6f; + return std::fabs(average - level.average) < kEpsilon && + std::fabs(max - level.max) < kEpsilon; +} + +ClippingPredictorLevelBuffer::ClippingPredictorLevelBuffer(int capacity) + : tail_(-1), size_(0), data_(std::max(1, capacity)) { + if (capacity > kMaxCapacity) { + RTC_LOG(LS_WARNING) << "[agc]: ClippingPredictorLevelBuffer exceeds the " + << "maximum allowed capacity. Capacity: " << capacity; + } + RTC_DCHECK(!data_.empty()); +} + +void ClippingPredictorLevelBuffer::Reset() { + tail_ = -1; + size_ = 0; +} + +void ClippingPredictorLevelBuffer::Push(Level level) { + ++tail_; + if (tail_ == Capacity()) { + tail_ = 0; + } + if (size_ < Capacity()) { + size_++; + } + data_[tail_] = level; +} + +// TODO(bugs.webrtc.org/12774): Optimize partial computation for long buffers. 
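+// Usage sketch (illustrative values):
+//   ClippingPredictorLevelBuffer buffer(/*capacity=*/3);
+//   buffer.Push({/*average=*/4.0f, /*max=*/9.0f});
+//   buffer.Push({/*average=*/16.0f, /*max=*/25.0f});
+//   // Aggregate over the 2 most recent items: {(4 + 16) / 2, max(9, 25)}.
+//   buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/2);  // {10, 25}.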
+absl::optional +ClippingPredictorLevelBuffer::ComputePartialMetrics(int delay, + int num_items) const { + RTC_DCHECK_GE(delay, 0); + RTC_DCHECK_LT(delay, Capacity()); + RTC_DCHECK_GT(num_items, 0); + RTC_DCHECK_LE(num_items, Capacity()); + RTC_DCHECK_LE(delay + num_items, Capacity()); + if (delay + num_items > Size()) { + return absl::nullopt; + } + float sum = 0.0f; + float max = 0.0f; + for (int i = 0; i < num_items && i < Size(); ++i) { + int idx = tail_ - delay - i; + if (idx < 0) { + idx += Capacity(); + } + sum += data_[idx].average; + max = std::fmax(data_[idx].max, max); + } + return absl::optional({sum / static_cast(num_items), max}); +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc/clipping_predictor_level_buffer.h b/modules/audio_processing/agc/clipping_predictor_level_buffer.h new file mode 100644 index 0000000000..f3e8368194 --- /dev/null +++ b/modules/audio_processing/agc/clipping_predictor_level_buffer.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_ + +#include +#include + +#include "absl/types/optional.h" + +namespace webrtc { + +// A circular buffer to store frame-wise `Level` items for clipping prediction. +// The current implementation is not optimized for large buffer lengths. +class ClippingPredictorLevelBuffer { + public: + struct Level { + float average; + float max; + bool operator==(const Level& level) const; + }; + + // Recommended maximum capacity. It is possible to create a buffer with a + // larger capacity, but the implementation is not optimized for large values. + static constexpr int kMaxCapacity = 100; + + // Ctor. Sets the buffer capacity to max(1, `capacity`) and logs a warning + // message if the capacity is greater than `kMaxCapacity`. + explicit ClippingPredictorLevelBuffer(int capacity); + ~ClippingPredictorLevelBuffer() {} + ClippingPredictorLevelBuffer(const ClippingPredictorLevelBuffer&) = delete; + ClippingPredictorLevelBuffer& operator=(const ClippingPredictorLevelBuffer&) = + delete; + + void Reset(); + + // Returns the current number of items stored in the buffer. + int Size() const { return size_; } + + // Returns the capacity of the buffer. + int Capacity() const { return data_.size(); } + + // Adds a `level` item into the circular buffer `data_`. Stores at most + // `Capacity()` items. If more items are pushed, the new item replaces the + // least recently pushed item. + void Push(Level level); + + // If at least `num_items` + `delay` items have been pushed, returns the + // average and maximum value for the `num_items` most recently pushed items + // from `delay` to `delay` - `num_items` (a delay equal to zero corresponds + // to the most recently pushed item). The value of `delay` is limited to + // [0, N] and `num_items` to [1, M] where N + M is the capacity of the buffer. 
+ absl::optional ComputePartialMetrics(int delay, int num_items) const; + + private: + int tail_; + int size_; + std::vector data_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_ diff --git a/modules/audio_processing/agc/clipping_predictor_level_buffer_unittest.cc b/modules/audio_processing/agc/clipping_predictor_level_buffer_unittest.cc new file mode 100644 index 0000000000..7e594a1eca --- /dev/null +++ b/modules/audio_processing/agc/clipping_predictor_level_buffer_unittest.cc @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc/clipping_predictor_level_buffer.h" + +#include + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +using ::testing::Eq; +using ::testing::Optional; + +class ClippingPredictorLevelBufferParametrization + : public ::testing::TestWithParam { + protected: + int capacity() const { return GetParam(); } +}; + +TEST_P(ClippingPredictorLevelBufferParametrization, CheckEmptyBufferSize) { + ClippingPredictorLevelBuffer buffer(capacity()); + EXPECT_EQ(buffer.Capacity(), std::max(capacity(), 1)); + EXPECT_EQ(buffer.Size(), 0); +} + +TEST_P(ClippingPredictorLevelBufferParametrization, CheckHalfEmptyBufferSize) { + ClippingPredictorLevelBuffer buffer(capacity()); + for (int i = 0; i < buffer.Capacity() / 2; ++i) { + buffer.Push({2, 4}); + } + EXPECT_EQ(buffer.Capacity(), std::max(capacity(), 1)); + EXPECT_EQ(buffer.Size(), std::max(capacity(), 1) / 2); +} + +TEST_P(ClippingPredictorLevelBufferParametrization, CheckFullBufferSize) { + ClippingPredictorLevelBuffer buffer(capacity()); + for (int i = 0; i < buffer.Capacity(); ++i) { + buffer.Push({2, 4}); + } + EXPECT_EQ(buffer.Capacity(), std::max(capacity(), 1)); + EXPECT_EQ(buffer.Size(), std::max(capacity(), 1)); +} + +TEST_P(ClippingPredictorLevelBufferParametrization, CheckLargeBufferSize) { + ClippingPredictorLevelBuffer buffer(capacity()); + for (int i = 0; i < 2 * buffer.Capacity(); ++i) { + buffer.Push({2, 4}); + } + EXPECT_EQ(buffer.Capacity(), std::max(capacity(), 1)); + EXPECT_EQ(buffer.Size(), std::max(capacity(), 1)); +} + +TEST_P(ClippingPredictorLevelBufferParametrization, CheckSizeAfterReset) { + ClippingPredictorLevelBuffer buffer(capacity()); + buffer.Push({1, 1}); + buffer.Push({1, 1}); + buffer.Reset(); + EXPECT_EQ(buffer.Capacity(), std::max(capacity(), 1)); + EXPECT_EQ(buffer.Size(), 0); + buffer.Push({1, 1}); + EXPECT_EQ(buffer.Capacity(), std::max(capacity(), 1)); + EXPECT_EQ(buffer.Size(), 1); +} + +INSTANTIATE_TEST_SUITE_P(ClippingPredictorLevelBufferTest, + ClippingPredictorLevelBufferParametrization, + ::testing::Values(-1, 0, 1, 123)); + +TEST(ClippingPredictorLevelBufferTest, CheckMetricsAfterFullBuffer) { + ClippingPredictorLevelBuffer buffer(/*capacity=*/2); + buffer.Push({1, 2}); + buffer.Push({3, 6}); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/1), + Optional(Eq(ClippingPredictorLevelBuffer::Level{3, 6}))); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/1, /*num_items=*/1), + Optional(Eq(ClippingPredictorLevelBuffer::Level{1, 2}))); + 
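+  // Aggregating both items yields the mean of the `average` fields,
+  // (1 + 3) / 2 = 2, and the overall `max`, max(2, 6) = 6.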
EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/2), + Optional(Eq(ClippingPredictorLevelBuffer::Level{2, 6}))); +} + +TEST(ClippingPredictorLevelBufferTest, CheckMetricsAfterPushBeyondCapacity) { + ClippingPredictorLevelBuffer buffer(/*capacity=*/2); + buffer.Push({1, 1}); + buffer.Push({3, 6}); + buffer.Push({5, 10}); + buffer.Push({7, 14}); + buffer.Push({6, 12}); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/1), + Optional(Eq(ClippingPredictorLevelBuffer::Level{6, 12}))); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/1, /*num_items=*/1), + Optional(Eq(ClippingPredictorLevelBuffer::Level{7, 14}))); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/2), + Optional(Eq(ClippingPredictorLevelBuffer::Level{6.5f, 14}))); +} + +TEST(ClippingPredictorLevelBufferTest, CheckMetricsAfterTooFewItems) { + ClippingPredictorLevelBuffer buffer(/*capacity=*/4); + buffer.Push({1, 2}); + buffer.Push({3, 6}); + EXPECT_EQ(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/3), + absl::nullopt); + EXPECT_EQ(buffer.ComputePartialMetrics(/*delay=*/2, /*num_items=*/1), + absl::nullopt); +} + +TEST(ClippingPredictorLevelBufferTest, CheckMetricsAfterReset) { + ClippingPredictorLevelBuffer buffer(/*capacity=*/2); + buffer.Push({1, 2}); + buffer.Reset(); + buffer.Push({5, 10}); + buffer.Push({7, 14}); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/1), + Optional(Eq(ClippingPredictorLevelBuffer::Level{7, 14}))); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/2), + Optional(Eq(ClippingPredictorLevelBuffer::Level{6, 14}))); + EXPECT_THAT(buffer.ComputePartialMetrics(/*delay=*/1, /*num_items=*/1), + Optional(Eq(ClippingPredictorLevelBuffer::Level{5, 10}))); +} + +} // namespace +} // namespace webrtc diff --git a/modules/audio_processing/agc/clipping_predictor_unittest.cc b/modules/audio_processing/agc/clipping_predictor_unittest.cc new file mode 100644 index 0000000000..e848e1a724 --- /dev/null +++ b/modules/audio_processing/agc/clipping_predictor_unittest.cc @@ -0,0 +1,491 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc/clipping_predictor.h" + +#include +#include +#include + +#include "rtc_base/checks.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +using ::testing::Eq; +using ::testing::Optional; +using ClippingPredictorConfig = AudioProcessing::Config::GainController1:: + AnalogGainController::ClippingPredictor; +using ClippingPredictorMode = AudioProcessing::Config::GainController1:: + AnalogGainController::ClippingPredictor::Mode; + +constexpr int kSampleRateHz = 32000; +constexpr int kNumChannels = 1; +constexpr int kSamplesPerChannel = kSampleRateHz / 100; +constexpr int kMaxMicLevel = 255; +constexpr int kMinMicLevel = 12; +constexpr int kDefaultClippedLevelStep = 15; +constexpr float kMaxSampleS16 = + static_cast(std::numeric_limits::max()); + +// Threshold in dB corresponding to a signal with an amplitude equal to 99% of +// the dynamic range - i.e., computed as `20*log10(0.99)`. 
+constexpr float kClippingThresholdDb = -0.08729610804900176f; + +void CallAnalyze(int num_calls, + const AudioFrameView& frame, + ClippingPredictor& predictor) { + for (int i = 0; i < num_calls; ++i) { + predictor.Analyze(frame); + } +} + +// Creates and analyzes an audio frame with a non-zero (approx. 4.15dB) crest +// factor. +void AnalyzeNonZeroCrestFactorAudio(int num_calls, + int num_channels, + float peak_ratio, + ClippingPredictor& predictor) { + RTC_DCHECK_GT(num_calls, 0); + RTC_DCHECK_GT(num_channels, 0); + RTC_DCHECK_LE(peak_ratio, 1.0f); + std::vector audio(num_channels); + std::vector audio_data(num_channels * kSamplesPerChannel, 0.0f); + for (int channel = 0; channel < num_channels; ++channel) { + audio[channel] = &audio_data[channel * kSamplesPerChannel]; + for (int sample = 0; sample < kSamplesPerChannel; sample += 10) { + audio[channel][sample] = 0.1f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 1] = 0.2f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 2] = 0.3f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 3] = 0.4f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 4] = 0.5f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 5] = 0.6f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 6] = 0.7f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 7] = 0.8f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 8] = 0.9f * peak_ratio * kMaxSampleS16; + audio[channel][sample + 9] = 1.0f * peak_ratio * kMaxSampleS16; + } + } + AudioFrameView frame(audio.data(), num_channels, + kSamplesPerChannel); + CallAnalyze(num_calls, frame, predictor); +} + +void CheckChannelEstimatesWithValue(int num_channels, + int level, + int default_step, + int min_mic_level, + int max_mic_level, + const ClippingPredictor& predictor, + int expected) { + for (int i = 0; i < num_channels; ++i) { + SCOPED_TRACE(i); + EXPECT_THAT(predictor.EstimateClippedLevelStep( + i, level, default_step, min_mic_level, max_mic_level), + Optional(Eq(expected))); + } +} + +void CheckChannelEstimatesWithoutValue(int num_channels, + int level, + int default_step, + int min_mic_level, + int max_mic_level, + const ClippingPredictor& predictor) { + for (int i = 0; i < num_channels; ++i) { + SCOPED_TRACE(i); + EXPECT_EQ(predictor.EstimateClippedLevelStep(i, level, default_step, + min_mic_level, max_mic_level), + absl::nullopt); + } +} + +// Creates and analyzes an audio frame with a zero crest factor. +void AnalyzeZeroCrestFactorAudio(int num_calls, + int num_channels, + float peak_ratio, + ClippingPredictor& predictor) { + RTC_DCHECK_GT(num_calls, 0); + RTC_DCHECK_GT(num_channels, 0); + RTC_DCHECK_LE(peak_ratio, 1.f); + std::vector audio(num_channels); + std::vector audio_data(num_channels * kSamplesPerChannel, 0.f); + for (int channel = 0; channel < num_channels; ++channel) { + audio[channel] = &audio_data[channel * kSamplesPerChannel]; + for (int sample = 0; sample < kSamplesPerChannel; ++sample) { + audio[channel][sample] = peak_ratio * kMaxSampleS16; + } + } + auto frame = AudioFrameView(audio.data(), num_channels, + kSamplesPerChannel); + CallAnalyze(num_calls, frame, predictor); +} + +TEST(ClippingPeakPredictorTest, NoPredictorCreated) { + auto predictor = + CreateClippingPredictor(kNumChannels, /*config=*/{/*enabled=*/false}); + EXPECT_FALSE(predictor); +} + +TEST(ClippingPeakPredictorTest, ClippingEventPredictionCreated) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. 
+ auto predictor = CreateClippingPredictor( + kNumChannels, + /*config=*/{/*enabled=*/true, + /*mode=*/ClippingPredictorMode::kClippingEventPrediction}); + EXPECT_TRUE(predictor); +} + +TEST(ClippingPeakPredictorTest, AdaptiveStepClippingPeakPredictionCreated) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + auto predictor = CreateClippingPredictor( + kNumChannels, /*config=*/{ + /*enabled=*/true, + /*mode=*/ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction}); + EXPECT_TRUE(predictor); +} + +TEST(ClippingPeakPredictorTest, FixedStepClippingPeakPredictionCreated) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + auto predictor = CreateClippingPredictor( + kNumChannels, /*config=*/{ + /*enabled=*/true, + /*mode=*/ClippingPredictorMode::kFixedStepClippingPeakPrediction}); + EXPECT_TRUE(predictor); +} + +class ClippingPredictorParameterization + : public ::testing::TestWithParam> { + protected: + int num_channels() const { return std::get<0>(GetParam()); } + ClippingPredictorConfig GetConfig(ClippingPredictorMode mode) const { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + return {/*enabled=*/true, + /*mode=*/mode, + /*window_length=*/std::get<1>(GetParam()), + /*reference_window_length=*/std::get<2>(GetParam()), + /*reference_window_delay=*/std::get<3>(GetParam()), + /*clipping_threshold=*/-1.0f, + /*crest_factor_margin=*/0.5f}; + } +}; + +TEST_P(ClippingPredictorParameterization, + CheckClippingEventPredictorEstimateAfterCrestFactorDrop) { + const ClippingPredictorConfig config = + GetConfig(ClippingPredictorMode::kClippingEventPrediction); + if (config.reference_window_length + config.reference_window_delay <= + config.window_length) { + return; + } + auto predictor = CreateClippingPredictor(num_channels(), config); + AnalyzeNonZeroCrestFactorAudio( + /*num_calls=*/config.reference_window_length + + config.reference_window_delay - config.window_length, + num_channels(), /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(num_channels(), /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeZeroCrestFactorAudio(config.window_length, num_channels(), + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithValue( + num_channels(), /*level=*/255, kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor, kDefaultClippedLevelStep); +} + +TEST_P(ClippingPredictorParameterization, + CheckClippingEventPredictorNoEstimateAfterConstantCrestFactor) { + const ClippingPredictorConfig config = + GetConfig(ClippingPredictorMode::kClippingEventPrediction); + if (config.reference_window_length + config.reference_window_delay <= + config.window_length) { + return; + } + auto predictor = CreateClippingPredictor(num_channels(), config); + AnalyzeNonZeroCrestFactorAudio( + /*num_calls=*/config.reference_window_length + + config.reference_window_delay - config.window_length, + num_channels(), /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(num_channels(), /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/config.window_length, + num_channels(), + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(num_channels(), /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); +} + +TEST_P(ClippingPredictorParameterization, + CheckClippingPeakPredictorEstimateAfterHighCrestFactor) { + const 
ClippingPredictorConfig config = + GetConfig(ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction); + if (config.reference_window_length + config.reference_window_delay <= + config.window_length) { + return; + } + auto predictor = CreateClippingPredictor(num_channels(), config); + AnalyzeNonZeroCrestFactorAudio( + /*num_calls=*/config.reference_window_length + + config.reference_window_delay - config.window_length, + num_channels(), /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(num_channels(), /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/config.window_length, + num_channels(), + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithValue( + num_channels(), /*level=*/255, kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor, kDefaultClippedLevelStep); +} + +TEST_P(ClippingPredictorParameterization, + CheckClippingPeakPredictorNoEstimateAfterLowCrestFactor) { + const ClippingPredictorConfig config = + GetConfig(ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction); + if (config.reference_window_length + config.reference_window_delay <= + config.window_length) { + return; + } + auto predictor = CreateClippingPredictor(num_channels(), config); + AnalyzeZeroCrestFactorAudio( + /*num_calls=*/config.reference_window_length + + config.reference_window_delay - config.window_length, + num_channels(), /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(num_channels(), /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/config.window_length, + num_channels(), + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(num_channels(), /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); +} + +INSTANTIATE_TEST_SUITE_P(GainController1ClippingPredictor, + ClippingPredictorParameterization, + ::testing::Combine(::testing::Values(1, 5), + ::testing::Values(1, 5, 10), + ::testing::Values(1, 5), + ::testing::Values(0, 1, 5))); + +class ClippingEventPredictorParameterization + : public ::testing::TestWithParam> { + protected: + ClippingPredictorConfig GetConfig() const { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + return {/*enabled=*/true, + /*mode=*/ClippingPredictorMode::kClippingEventPrediction, + /*window_length=*/5, + /*reference_window_length=*/5, + /*reference_window_delay=*/5, + /*clipping_threshold=*/std::get<0>(GetParam()), + /*crest_factor_margin=*/std::get<1>(GetParam())}; + } +}; + +TEST_P(ClippingEventPredictorParameterization, + CheckEstimateAfterCrestFactorDrop) { + const ClippingPredictorConfig config = GetConfig(); + auto predictor = CreateClippingPredictor(kNumChannels, config); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/config.reference_window_length, + kNumChannels, /*peak_ratio=*/0.99f, + *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeZeroCrestFactorAudio(config.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + // TODO(bugs.webrtc.org/12774): Add clarifying comment. + // TODO(bugs.webrtc.org/12774): Remove 4.15f threshold and split tests. 
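// Editor's illustrative sketch (not part of the patch): the 4.15f threshold
// below matches the "approx. 4.15 dB" crest factor produced by
// AnalyzeNonZeroCrestFactorAudio(). That helper writes the repeating ramp
// {0.1, 0.2, ..., 1.0} * peak_ratio, so the peak is peak_ratio and the RMS is
// peak_ratio * sqrt(mean of the squared ramp values) = peak_ratio * sqrt(0.385);
// the crest factor is therefore 20 * log10(peak / RMS) = -10 * log10(0.385)
// ≈ 4.145 dB. Switching to zero-crest-factor audio drops the crest factor by
// that amount, which is presumably why an estimate is expected only when
// `crest_factor_margin` is below ~4.15 dB. A self-contained check, assuming
// only <cmath> and GoogleTest (the test name is hypothetical):
TEST(ClippingPredictorTestHelpers, SyntheticRampCrestFactorIsAbout4Dot15Db) {
  float sum_squares = 0.0f;
  for (int k = 1; k <= 10; ++k) {
    const float sample = 0.1f * k;  // Ramp samples 0.1, 0.2, ..., 1.0.
    sum_squares += sample * sample;
  }
  // RMS of one ramp period; the peak amplitude is 1.0.
  const float rms = std::sqrt(sum_squares / 10.0f);
  const float crest_factor_db = 20.0f * std::log10(1.0f / rms);
  EXPECT_NEAR(crest_factor_db, 4.145f, 0.005f);
}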
+ if (config.clipping_threshold < kClippingThresholdDb && + config.crest_factor_margin < 4.15f) { + CheckChannelEstimatesWithValue( + kNumChannels, /*level=*/255, kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor, kDefaultClippedLevelStep); + } else { + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + } +} + +INSTANTIATE_TEST_SUITE_P(GainController1ClippingPredictor, + ClippingEventPredictorParameterization, + ::testing::Combine(::testing::Values(-1.0f, 0.0f), + ::testing::Values(3.0f, 4.16f))); + +class ClippingPredictorModeParameterization + : public ::testing::TestWithParam { + protected: + ClippingPredictorConfig GetConfig(float clipping_threshold_dbfs) const { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + return {/*enabled=*/true, + /*mode=*/GetParam(), + /*window_length=*/5, + /*reference_window_length=*/5, + /*reference_window_delay=*/5, + /*clipping_threshold=*/clipping_threshold_dbfs, + /*crest_factor_margin=*/3.0f}; + } +}; + +TEST_P(ClippingPredictorModeParameterization, + CheckEstimateAfterHighCrestFactorWithNoClippingMargin) { + const ClippingPredictorConfig config = GetConfig( + /*clipping_threshold_dbfs=*/0.0f); + auto predictor = CreateClippingPredictor(kNumChannels, config); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/config.reference_window_length, + kNumChannels, /*peak_ratio=*/0.99f, + *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeZeroCrestFactorAudio(config.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + // Since the clipping threshold is set to 0 dBFS, `EstimateClippedLevelStep()` + // is expected to return an unavailable value. + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); +} + +TEST_P(ClippingPredictorModeParameterization, + CheckEstimateAfterHighCrestFactorWithClippingMargin) { + const ClippingPredictorConfig config = + GetConfig(/*clipping_threshold_dbfs=*/-1.0f); + auto predictor = CreateClippingPredictor(kNumChannels, config); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/config.reference_window_length, + kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeZeroCrestFactorAudio(config.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + // TODO(bugs.webrtc.org/12774): Add clarifying comment. + const float expected_step = + config.mode == ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction + ? 17 + : kDefaultClippedLevelStep; + CheckChannelEstimatesWithValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor, expected_step); +} + +INSTANTIATE_TEST_SUITE_P( + GainController1ClippingPredictor, + ClippingPredictorModeParameterization, + ::testing::Values( + ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction, + ClippingPredictorMode::kFixedStepClippingPeakPrediction)); + +TEST(ClippingEventPredictorTest, CheckEstimateAfterReset) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. 
+ constexpr ClippingPredictorConfig kConfig{ + /*enabled=*/true, + /*mode=*/ClippingPredictorMode::kClippingEventPrediction, + /*window_length=*/5, + /*reference_window_length=*/5, + /*reference_window_delay=*/5, + /*clipping_threshold=*/-1.0f, + /*crest_factor_margin=*/3.0f}; + auto predictor = CreateClippingPredictor(kNumChannels, kConfig); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/kConfig.reference_window_length, + kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + predictor->Reset(); + AnalyzeZeroCrestFactorAudio(kConfig.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); +} + +TEST(ClippingPeakPredictorTest, CheckNoEstimateAfterReset) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + constexpr ClippingPredictorConfig kConfig{ + /*enabled=*/true, + /*mode=*/ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction, + /*window_length=*/5, + /*reference_window_length=*/5, + /*reference_window_delay=*/5, + /*clipping_threshold=*/-1.0f}; + auto predictor = CreateClippingPredictor(kNumChannels, kConfig); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/kConfig.reference_window_length, + kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + predictor->Reset(); + AnalyzeZeroCrestFactorAudio(kConfig.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); +} + +TEST(ClippingPeakPredictorTest, CheckAdaptiveStepEstimate) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. + constexpr ClippingPredictorConfig kConfig{ + /*enabled=*/true, + /*mode=*/ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction, + /*window_length=*/5, + /*reference_window_length=*/5, + /*reference_window_delay=*/5, + /*clipping_threshold=*/-1.0f}; + auto predictor = CreateClippingPredictor(kNumChannels, kConfig); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/kConfig.reference_window_length, + kNumChannels, /*peak_ratio=*/0.99f, + *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeZeroCrestFactorAudio(kConfig.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor, /*expected=*/17); +} + +TEST(ClippingPeakPredictorTest, CheckFixedStepEstimate) { + // TODO(bugs.webrtc.org/12874): Use designated initializers one fixed. 
+ constexpr ClippingPredictorConfig kConfig{ + /*enabled=*/true, + /*mode=*/ClippingPredictorMode::kFixedStepClippingPeakPrediction, + /*window_length=*/5, + /*reference_window_length=*/5, + /*reference_window_delay=*/5, + /*clipping_threshold=*/-1.0f}; + auto predictor = CreateClippingPredictor(kNumChannels, kConfig); + AnalyzeNonZeroCrestFactorAudio(/*num_calls=*/kConfig.reference_window_length, + kNumChannels, /*peak_ratio=*/0.99f, + *predictor); + CheckChannelEstimatesWithoutValue(kNumChannels, /*level=*/255, + kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor); + AnalyzeZeroCrestFactorAudio(kConfig.window_length, kNumChannels, + /*peak_ratio=*/0.99f, *predictor); + CheckChannelEstimatesWithValue( + kNumChannels, /*level=*/255, kDefaultClippedLevelStep, kMinMicLevel, + kMaxMicLevel, *predictor, kDefaultClippedLevelStep); +} + +} // namespace +} // namespace webrtc diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index 7b71f6a8e7..4c6cfab576 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn @@ -15,31 +15,6 @@ group("agc2") { ] } -rtc_library("level_estimation_agc") { - sources = [ - "adaptive_mode_level_estimator_agc.cc", - "adaptive_mode_level_estimator_agc.h", - ] - configs += [ "..:apm_debug_dump" ] - deps = [ - ":adaptive_digital", - ":common", - ":gain_applier", - ":noise_level_estimator", - ":rnn_vad_with_level", - "..:api", - "..:apm_logging", - "..:audio_frame_view", - "../../../api:array_view", - "../../../common_audio", - "../../../rtc_base:checks", - "../../../rtc_base:rtc_base_approved", - "../../../rtc_base:safe_minmax", - "../agc:level_estimation", - "../vad", - ] -} - rtc_library("adaptive_digital") { sources = [ "adaptive_agc.cc", @@ -50,6 +25,8 @@ rtc_library("adaptive_digital") { "adaptive_mode_level_estimator.h", "saturation_protector.cc", "saturation_protector.h", + "saturation_protector_buffer.cc", + "saturation_protector_buffer.h", ] configs += [ "..:apm_debug_dump" ] @@ -202,6 +179,7 @@ rtc_library("adaptive_digital_unittests") { "adaptive_digital_gain_applier_unittest.cc", "adaptive_mode_level_estimator_unittest.cc", "gain_applier_unittest.cc", + "saturation_protector_buffer_unittest.cc", "saturation_protector_unittest.cc", ] deps = [ @@ -273,6 +251,7 @@ rtc_library("noise_estimator_unittests") { "..:apm_logging", "..:audio_frame_view", "../../../api:array_view", + "../../../api:function_view", "../../../rtc_base:checks", "../../../rtc_base:gunit_helpers", "../../../rtc_base:rtc_base_approved", diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index e72942a646..3fc9008db1 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc @@ -20,22 +20,14 @@ namespace webrtc { namespace { -void DumpDebugData(const AdaptiveDigitalGainApplier::FrameInfo& info, - ApmDataDumper& dumper) { - dumper.DumpRaw("agc2_vad_probability", info.vad_result.speech_probability); - dumper.DumpRaw("agc2_vad_rms_dbfs", info.vad_result.rms_dbfs); - dumper.DumpRaw("agc2_vad_peak_dbfs", info.vad_result.peak_dbfs); - dumper.DumpRaw("agc2_noise_estimate_dbfs", info.input_noise_level_dbfs); - dumper.DumpRaw("agc2_last_limiter_audio_level", info.limiter_envelope_dbfs); -} - -constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1; -constexpr float kMaxGainChangePerSecondDb = 3.f; -constexpr float kMaxOutputNoiseLevelDbfs = -50.f; +using AdaptiveDigitalConfig = + 
AudioProcessing::Config::GainController2::AdaptiveDigital; +using NoiseEstimatorType = + AudioProcessing::Config::GainController2::NoiseEstimator; // Detects the available CPU features and applies any kill-switches. AvailableCpuFeatures GetAllowedCpuFeatures( - const AudioProcessing::Config::GainController2::AdaptiveDigital& config) { + const AdaptiveDigitalConfig& config) { AvailableCpuFeatures features = GetAvailableCpuFeatures(); if (!config.sse2_allowed) { features.sse2 = false; @@ -49,60 +41,86 @@ AvailableCpuFeatures GetAllowedCpuFeatures( return features; } -} // namespace - -AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper) - : speech_level_estimator_(apm_data_dumper), - gain_applier_(apm_data_dumper, - kGainApplierAdjacentSpeechFramesThreshold, - kMaxGainChangePerSecondDb, - kMaxOutputNoiseLevelDbfs), - apm_data_dumper_(apm_data_dumper), - noise_level_estimator_(apm_data_dumper) { - RTC_DCHECK(apm_data_dumper); +std::unique_ptr CreateNoiseLevelEstimator( + NoiseEstimatorType estimator_type, + ApmDataDumper* apm_data_dumper) { + switch (estimator_type) { + case NoiseEstimatorType::kStationaryNoise: + return CreateStationaryNoiseEstimator(apm_data_dumper); + case NoiseEstimatorType::kNoiseFloor: + return CreateNoiseFloorEstimator(apm_data_dumper); + } } +} // namespace + AdaptiveAgc::AdaptiveAgc(ApmDataDumper* apm_data_dumper, - const AudioProcessing::Config::GainController2& config) - : speech_level_estimator_( - apm_data_dumper, - config.adaptive_digital.level_estimator, - config.adaptive_digital - .level_estimator_adjacent_speech_frames_threshold, - config.adaptive_digital.initial_saturation_margin_db, - config.adaptive_digital.extra_saturation_margin_db), - vad_(config.adaptive_digital.vad_probability_attack, - GetAllowedCpuFeatures(config.adaptive_digital)), - gain_applier_( - apm_data_dumper, - config.adaptive_digital.gain_applier_adjacent_speech_frames_threshold, - config.adaptive_digital.max_gain_change_db_per_second, - config.adaptive_digital.max_output_noise_level_dbfs), + const AdaptiveDigitalConfig& config) + : speech_level_estimator_(apm_data_dumper, + config.adjacent_speech_frames_threshold), + vad_(config.vad_reset_period_ms, GetAllowedCpuFeatures(config)), + gain_controller_(apm_data_dumper, + config.adjacent_speech_frames_threshold, + config.max_gain_change_db_per_second, + config.max_output_noise_level_dbfs, + config.dry_run), apm_data_dumper_(apm_data_dumper), - noise_level_estimator_(apm_data_dumper) { + noise_level_estimator_( + CreateNoiseLevelEstimator(config.noise_estimator, apm_data_dumper)), + saturation_protector_( + CreateSaturationProtector(kSaturationProtectorInitialHeadroomDb, + kSaturationProtectorExtraHeadroomDb, + config.adjacent_speech_frames_threshold, + apm_data_dumper)) { RTC_DCHECK(apm_data_dumper); - if (!config.adaptive_digital.use_saturation_protector) { + RTC_DCHECK(noise_level_estimator_); + RTC_DCHECK(saturation_protector_); + if (!config.use_saturation_protector) { RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled."; } } AdaptiveAgc::~AdaptiveAgc() = default; +void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) { + gain_controller_.Initialize(sample_rate_hz, num_channels); +} + void AdaptiveAgc::Process(AudioFrameView frame, float limiter_envelope) { AdaptiveDigitalGainApplier::FrameInfo info; - info.vad_result = vad_.AnalyzeFrame(frame); - speech_level_estimator_.Update(info.vad_result); - info.input_level_dbfs = speech_level_estimator_.level_dbfs(); - info.input_noise_level_dbfs = 
noise_level_estimator_.Analyze(frame); - info.limiter_envelope_dbfs = - limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.f; - info.estimate_is_confident = speech_level_estimator_.IsConfident(); - DumpDebugData(info, *apm_data_dumper_); - gain_applier_.Process(info, frame); + + VadLevelAnalyzer::Result vad_result = vad_.AnalyzeFrame(frame); + info.speech_probability = vad_result.speech_probability; + apm_data_dumper_->DumpRaw("agc2_speech_probability", + vad_result.speech_probability); + apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", vad_result.rms_dbfs); + apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", vad_result.peak_dbfs); + + speech_level_estimator_.Update(vad_result); + info.speech_level_dbfs = speech_level_estimator_.level_dbfs(); + info.speech_level_reliable = speech_level_estimator_.IsConfident(); + apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", info.speech_level_dbfs); + apm_data_dumper_->DumpRaw("agc2_speech_level_reliable", + info.speech_level_reliable); + + info.noise_rms_dbfs = noise_level_estimator_->Analyze(frame); + apm_data_dumper_->DumpRaw("agc2_noise_rms_dbfs", info.noise_rms_dbfs); + + saturation_protector_->Analyze(info.speech_probability, vad_result.peak_dbfs, + info.speech_level_dbfs); + info.headroom_db = saturation_protector_->HeadroomDb(); + apm_data_dumper_->DumpRaw("agc2_headroom_db", info.headroom_db); + + info.limiter_envelope_dbfs = FloatS16ToDbfs(limiter_envelope); + apm_data_dumper_->DumpRaw("agc2_limiter_envelope_dbfs", + info.limiter_envelope_dbfs); + + gain_controller_.Process(info, frame); } -void AdaptiveAgc::Reset() { +void AdaptiveAgc::HandleInputGainChange() { speech_level_estimator_.Reset(); + saturation_protector_->Reset(); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h index f3c7854e16..43c7787e36 100644 --- a/modules/audio_processing/agc2/adaptive_agc.h +++ b/modules/audio_processing/agc2/adaptive_agc.h @@ -11,9 +11,12 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_ #define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_ +#include + #include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h" #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h" #include "modules/audio_processing/agc2/noise_level_estimator.h" +#include "modules/audio_processing/agc2/saturation_protector.h" #include "modules/audio_processing/agc2/vad_with_level.h" #include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_processing.h" @@ -22,27 +25,33 @@ namespace webrtc { class ApmDataDumper; // Adaptive digital gain controller. -// TODO(crbug.com/webrtc/7494): Unify with `AdaptiveDigitalGainApplier`. +// TODO(crbug.com/webrtc/7494): Rename to `AdaptiveDigitalGainController`. class AdaptiveAgc { public: - explicit AdaptiveAgc(ApmDataDumper* apm_data_dumper); - // TODO(crbug.com/webrtc/7494): Remove ctor above. - AdaptiveAgc(ApmDataDumper* apm_data_dumper, - const AudioProcessing::Config::GainController2& config); + AdaptiveAgc( + ApmDataDumper* apm_data_dumper, + const AudioProcessing::Config::GainController2::AdaptiveDigital& config); ~AdaptiveAgc(); + void Initialize(int sample_rate_hz, int num_channels); + + // TODO(crbug.com/webrtc/7494): Add `SetLimiterEnvelope()`. + // Analyzes `frame` and applies a digital adaptive gain to it. Takes into // account the envelope measured by the limiter. - // TODO(crbug.com/webrtc/7494): Make the class depend on the limiter. 
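// Editor's sketch (not from this CL): intended call pattern for the updated
// interface, assuming a 10 ms processing loop and a limiter object exposing
// its last audio level (the names below are illustrative only):
//
//   AdaptiveAgc agc(&data_dumper, config.adaptive_digital);
//   agc.Initialize(sample_rate_hz, num_channels);
//   // For every 10 ms capture frame:
//   agc.Process(float_frame_view, limiter.LastAudioLevel());
//   // Whenever the analog input gain changes:
//   agc.HandleInputGainChange();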
+ // TODO(crbug.com/webrtc/7494): Remove `limiter_envelope`. void Process(AudioFrameView<float> frame, float limiter_envelope); - void Reset(); + + // Handles a gain change applied to the input signal (e.g., analog gain). + void HandleInputGainChange(); private: AdaptiveModeLevelEstimator speech_level_estimator_; VadLevelAnalyzer vad_; - AdaptiveDigitalGainApplier gain_applier_; + AdaptiveDigitalGainApplier gain_controller_; ApmDataDumper* const apm_data_dumper_; - NoiseLevelEstimator noise_level_estimator_; + std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_; + std::unique_ptr<SaturationProtector> saturation_protector_; }; } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc b/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc index 36ef9be561..e59b110efe 100644 --- a/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc +++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc @@ -23,6 +23,9 @@ namespace webrtc { namespace { +constexpr int kHeadroomHistogramMin = 0; +constexpr int kHeadroomHistogramMax = 50; + // This function maps input level to desired applied gain. We want to // boost the signal so that peaks are at -kHeadroomDbfs. We can't // apply more than kMaxGainDb gain. @@ -31,17 +34,13 @@ float ComputeGainDb(float input_level_dbfs) { if (input_level_dbfs < -(kHeadroomDbfs + kMaxGainDb)) { return kMaxGainDb; } - // We expect to end up here most of the time: the level is below // -headroom, but we can boost it to -headroom. if (input_level_dbfs < -kHeadroomDbfs) { return -kHeadroomDbfs - input_level_dbfs; } - - // Otherwise, the level is too high and we can't boost. The - // LevelEstimator is responsible for not reporting bogus gain - // values. - RTC_DCHECK_LE(input_level_dbfs, 0.f); + // Otherwise, the level is too high and we can't boost. + RTC_DCHECK_GE(input_level_dbfs, -kHeadroomDbfs); return 0.f; } @@ -52,10 +51,11 @@ float LimitGainByNoise(float target_gain, float input_noise_level_dbfs, float max_output_noise_level_dbfs, ApmDataDumper& apm_data_dumper) { - const float noise_headroom_db = + const float max_allowed_gain_db = max_output_noise_level_dbfs - input_noise_level_dbfs; - apm_data_dumper.DumpRaw("agc2_noise_headroom_db", noise_headroom_db); - return std::min(target_gain, std::max(noise_headroom_db, 0.f)); + apm_data_dumper.DumpRaw("agc2_adaptive_gain_applier_max_allowed_gain_db", + max_allowed_gain_db); + return std::min(target_gain, std::max(max_allowed_gain_db, 0.f)); } float LimitGainByLowConfidence(float target_gain, @@ -68,8 +68,8 @@ float LimitGainByLowConfidence(float target_gain, } const float limiter_level_before_gain = limiter_audio_level_dbfs - last_gain; - // Compute a new gain so that limiter_level_before_gain + new_gain <= - // kLimiterThreshold. + // Compute a new gain so that `limiter_level_before_gain` + `new_target_gain` + // is not greater than `kLimiterThresholdForAgcGainDbfs`.
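// Editor's note (illustration of the cap computed below): written out,
//   new_target_gain = max(kLimiterThresholdForAgcGainDbfs -
//                         (limiter_audio_level_dbfs - last_gain), 0)
// and the function returns min(new_target_gain, target_gain), the idea being
// that while the speech level estimate is not yet confident, the gain is held
// back so that the limiter input does not exceed the threshold.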
const float new_target_gain = std::max( kLimiterThresholdForAgcGainDbfs - limiter_level_before_gain, 0.f); return std::min(new_target_gain, target_gain); @@ -80,13 +80,30 @@ float LimitGainByLowConfidence(float target_gain, float ComputeGainChangeThisFrameDb(float target_gain_db, float last_gain_db, bool gain_increase_allowed, - float max_gain_change_db) { + float max_gain_decrease_db, + float max_gain_increase_db) { + RTC_DCHECK_GT(max_gain_decrease_db, 0); + RTC_DCHECK_GT(max_gain_increase_db, 0); float target_gain_difference_db = target_gain_db - last_gain_db; if (!gain_increase_allowed) { target_gain_difference_db = std::min(target_gain_difference_db, 0.f); } - return rtc::SafeClamp(target_gain_difference_db, -max_gain_change_db, - max_gain_change_db); + return rtc::SafeClamp(target_gain_difference_db, -max_gain_decrease_db, + max_gain_increase_db); +} + +// Copies the (multichannel) audio samples from `src` into `dst`. +void CopyAudio(AudioFrameView src, + std::vector>& dst) { + RTC_DCHECK_GT(src.num_channels(), 0); + RTC_DCHECK_GT(src.samples_per_channel(), 0); + RTC_DCHECK_EQ(dst.size(), src.num_channels()); + for (size_t c = 0; c < src.num_channels(); ++c) { + rtc::ArrayView channel_view = src.channel(c); + RTC_DCHECK_EQ(channel_view.size(), src.samples_per_channel()); + RTC_DCHECK_EQ(dst[c].size(), src.samples_per_channel()); + std::copy(channel_view.begin(), channel_view.end(), dst[c].begin()); + } } } // namespace @@ -95,7 +112,8 @@ AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier( ApmDataDumper* apm_data_dumper, int adjacent_speech_frames_threshold, float max_gain_change_db_per_second, - float max_output_noise_level_dbfs) + float max_output_noise_level_dbfs, + bool dry_run) : apm_data_dumper_(apm_data_dumper), gain_applier_( /*hard_clip_samples=*/false, @@ -104,18 +122,44 @@ AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier( max_gain_change_db_per_10ms_(max_gain_change_db_per_second * kFrameDurationMs / 1000.f), max_output_noise_level_dbfs_(max_output_noise_level_dbfs), + dry_run_(dry_run), calls_since_last_gain_log_(0), frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold_), last_gain_db_(kInitialAdaptiveDigitalGainDb) { - RTC_DCHECK_GT(max_gain_change_db_per_second, 0.f); + RTC_DCHECK_GT(max_gain_change_db_per_second, 0.0f); RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1); - RTC_DCHECK_GE(max_output_noise_level_dbfs_, -90.f); - RTC_DCHECK_LE(max_output_noise_level_dbfs_, 0.f); + RTC_DCHECK_GE(max_output_noise_level_dbfs_, -90.0f); + RTC_DCHECK_LE(max_output_noise_level_dbfs_, 0.0f); + Initialize(/*sample_rate_hz=*/48000, /*num_channels=*/1); +} + +void AdaptiveDigitalGainApplier::Initialize(int sample_rate_hz, + int num_channels) { + if (!dry_run_) { + return; + } + RTC_DCHECK_GT(sample_rate_hz, 0); + RTC_DCHECK_GT(num_channels, 0); + int frame_size = rtc::CheckedDivExact(sample_rate_hz, 100); + bool sample_rate_changed = + dry_run_frame_.empty() || // Handle initialization. + dry_run_frame_[0].size() != static_cast(frame_size); + bool num_channels_changed = + dry_run_channels_.size() != static_cast(num_channels); + if (sample_rate_changed || num_channels_changed) { + // Resize the multichannel audio vector and update the channel pointers. 
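// (frame_size above is sample_rate_hz / 100 because APM operates on 10 ms
// frames, e.g. 480 samples per channel at 48 kHz.)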
+ dry_run_frame_.resize(num_channels); + dry_run_channels_.resize(num_channels); + for (int c = 0; c < num_channels; ++c) { + dry_run_frame_[c].resize(frame_size); + dry_run_channels_[c] = dry_run_frame_[c].data(); + } + } } void AdaptiveDigitalGainApplier::Process(const FrameInfo& info, AudioFrameView frame) { - RTC_DCHECK_GE(info.input_level_dbfs, -150.f); + RTC_DCHECK_GE(info.speech_level_dbfs, -150.f); RTC_DCHECK_GE(frame.num_channels(), 1); RTC_DCHECK( frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 || @@ -123,28 +167,46 @@ void AdaptiveDigitalGainApplier::Process(const FrameInfo& info, << "`frame` does not look like a 10 ms frame for an APM supported sample " "rate"; + // Compute the input level used to select the desired gain. + RTC_DCHECK_GT(info.headroom_db, 0.0f); + const float input_level_dbfs = info.speech_level_dbfs + info.headroom_db; + const float target_gain_db = LimitGainByLowConfidence( - LimitGainByNoise(ComputeGainDb(std::min(info.input_level_dbfs, 0.f)), - info.input_noise_level_dbfs, + LimitGainByNoise(ComputeGainDb(input_level_dbfs), info.noise_rms_dbfs, max_output_noise_level_dbfs_, *apm_data_dumper_), - last_gain_db_, info.limiter_envelope_dbfs, info.estimate_is_confident); + last_gain_db_, info.limiter_envelope_dbfs, info.speech_level_reliable); // Forbid increasing the gain until enough adjacent speech frames are // observed. - if (info.vad_result.speech_probability < kVadConfidenceThreshold) { + bool first_confident_speech_frame = false; + if (info.speech_probability < kVadConfidenceThreshold) { frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_; } else if (frames_to_gain_increase_allowed_ > 0) { frames_to_gain_increase_allowed_--; + first_confident_speech_frame = frames_to_gain_increase_allowed_ == 0; + } + apm_data_dumper_->DumpRaw( + "agc2_adaptive_gain_applier_frames_to_gain_increase_allowed", + frames_to_gain_increase_allowed_); + + const bool gain_increase_allowed = frames_to_gain_increase_allowed_ == 0; + + float max_gain_increase_db = max_gain_change_db_per_10ms_; + if (first_confident_speech_frame) { + // No gain increase happened while waiting for a long enough speech + // sequence. Therefore, temporarily allow a faster gain increase. + RTC_DCHECK(gain_increase_allowed); + max_gain_increase_db *= adjacent_speech_frames_threshold_; } const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb( - target_gain_db, last_gain_db_, - /*gain_increase_allowed=*/frames_to_gain_increase_allowed_ == 0, - max_gain_change_db_per_10ms_); + target_gain_db, last_gain_db_, gain_increase_allowed, + /*max_gain_decrease_db=*/max_gain_change_db_per_10ms_, + max_gain_increase_db); - apm_data_dumper_->DumpRaw("agc2_want_to_change_by_db", + apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_want_to_change_by_db", target_gain_db - last_gain_db_); - apm_data_dumper_->DumpRaw("agc2_will_change_by_db", + apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_will_change_by_db", gain_change_this_frame_db); // Optimization: avoid calling math functions if gain does not @@ -153,27 +215,45 @@ void AdaptiveDigitalGainApplier::Process(const FrameInfo& info, gain_applier_.SetGainFactor( DbToRatio(last_gain_db_ + gain_change_this_frame_db)); } - gain_applier_.ApplyGain(frame); + + // Modify `frame` only if not running in "dry run" mode. + if (!dry_run_) { + gain_applier_.ApplyGain(frame); + } else { + // Copy `frame` so that `ApplyGain()` is called (on a copy). 
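// (Editor's note: running ApplyGain() on the copy exercises the same gain
// ramping code path while leaving the caller's samples untouched; the
// `last_gain_db_` update further below does not depend on which buffer was
// processed.)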
+ CopyAudio(frame, dry_run_frame_); + RTC_DCHECK(!dry_run_channels_.empty()); + AudioFrameView frame_copy(&dry_run_channels_[0], + frame.num_channels(), + frame.samples_per_channel()); + gain_applier_.ApplyGain(frame_copy); + } // Remember that the gain has changed for the next iteration. last_gain_db_ = last_gain_db_ + gain_change_this_frame_db; - apm_data_dumper_->DumpRaw("agc2_applied_gain_db", last_gain_db_); + apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_applied_gain_db", + last_gain_db_); // Log every 10 seconds. calls_since_last_gain_log_++; if (calls_since_last_gain_log_ == 1000) { calls_since_last_gain_log_ = 0; + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedSpeechLevel", + -info.speech_level_dbfs, 0, 100, 101); + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel", + -info.noise_rms_dbfs, 0, 100, 101); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.Agc2.Headroom", info.headroom_db, kHeadroomHistogramMin, + kHeadroomHistogramMax, + kHeadroomHistogramMax - kHeadroomHistogramMin + 1); RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied", last_gain_db_, 0, kMaxGainDb, kMaxGainDb + 1); - RTC_HISTOGRAM_COUNTS_LINEAR( - "WebRTC.Audio.Agc2.EstimatedSpeechPlusNoiseLevel", - -info.input_level_dbfs, 0, 100, 101); - RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel", - -info.input_noise_level_dbfs, 0, 100, 101); RTC_LOG(LS_INFO) << "AGC2 adaptive digital" - << " | speech_plus_noise_dbfs: " << info.input_level_dbfs - << " | noise_dbfs: " << info.input_noise_level_dbfs + << " | speech_dbfs: " << info.speech_level_dbfs + << " | noise_dbfs: " << info.noise_rms_dbfs + << " | headroom_db: " << info.headroom_db << " | gain_db: " << last_gain_db_; } } + } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier.h b/modules/audio_processing/agc2/adaptive_digital_gain_applier.h index a65379f5be..8b58ea00b2 100644 --- a/modules/audio_processing/agc2/adaptive_digital_gain_applier.h +++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier.h @@ -11,42 +11,46 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_ #define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_APPLIER_H_ +#include + #include "modules/audio_processing/agc2/gain_applier.h" -#include "modules/audio_processing/agc2/vad_with_level.h" #include "modules/audio_processing/include/audio_frame_view.h" namespace webrtc { class ApmDataDumper; -// Part of the adaptive digital controller that applies a digital adaptive gain. -// The gain is updated towards a target. The logic decides when gain updates are -// allowed, it controls the adaptation speed and caps the target based on the -// estimated noise level and the speech level estimate confidence. +// TODO(bugs.webrtc.org): Split into `GainAdaptor` and `GainApplier`. +// Selects the target digital gain, decides when and how quickly to adapt to the +// target and applies the current gain to 10 ms frames. class AdaptiveDigitalGainApplier { public: // Information about a frame to process. struct FrameInfo { - float input_level_dbfs; // Estimated speech plus noise level. - float input_noise_level_dbfs; // Estimated noise level. - VadLevelAnalyzer::Result vad_result; - float limiter_envelope_dbfs; // Envelope level from the limiter. - bool estimate_is_confident; + float speech_probability; // Probability of speech in the [0, 1] range. + float speech_level_dbfs; // Estimated speech level (dBFS). 
+ bool speech_level_reliable; // True with reliable speech level estimation. + float noise_rms_dbfs; // Estimated noise RMS level (dBFS). + float headroom_db; // Headroom (dB). + float limiter_envelope_dbfs; // Envelope level from the limiter (dBFS). }; - // Ctor. - // `adjacent_speech_frames_threshold` indicates how many speech frames are - // required before a gain increase is allowed. `max_gain_change_db_per_second` - // limits the adaptation speed (uniformly operated across frames). - // `max_output_noise_level_dbfs` limits the output noise level. + // Ctor. `adjacent_speech_frames_threshold` indicates how many adjacent speech + // frames must be observed in order to consider the sequence as speech. + // `max_gain_change_db_per_second` limits the adaptation speed (uniformly + // operated across frames). `max_output_noise_level_dbfs` limits the output + // noise level. If `dry_run` is true, `Process()` will not modify the audio. AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper, int adjacent_speech_frames_threshold, float max_gain_change_db_per_second, - float max_output_noise_level_dbfs); + float max_output_noise_level_dbfs, + bool dry_run); AdaptiveDigitalGainApplier(const AdaptiveDigitalGainApplier&) = delete; AdaptiveDigitalGainApplier& operator=(const AdaptiveDigitalGainApplier&) = delete; + void Initialize(int sample_rate_hz, int num_channels); + // Analyzes `info`, updates the digital gain and applies it to a 10 ms // `frame`. Supports any sample rate supported by APM. void Process(const FrameInfo& info, AudioFrameView frame); @@ -58,10 +62,14 @@ class AdaptiveDigitalGainApplier { const int adjacent_speech_frames_threshold_; const float max_gain_change_db_per_10ms_; const float max_output_noise_level_dbfs_; + const bool dry_run_; int calls_since_last_gain_log_; int frames_to_gain_increase_allowed_; float last_gain_db_; + + std::vector> dry_run_frame_; + std::vector dry_run_channels_; }; } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc b/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc index e2df700422..f4a23a92b9 100644 --- a/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc +++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h" #include +#include #include "common_audio/include/audio_util.h" #include "modules/audio_processing/agc2/agc2_common.h" @@ -26,104 +27,79 @@ constexpr int kStereo = 2; constexpr int kFrameLen10ms8kHz = 80; constexpr int kFrameLen10ms48kHz = 480; +constexpr float kMaxSpeechProbability = 1.0f; + // Constants used in place of estimated noise levels. -constexpr float kNoNoiseDbfs = -90.f; +constexpr float kNoNoiseDbfs = kMinLevelDbfs; constexpr float kWithNoiseDbfs = -20.f; -static_assert(std::is_trivially_destructible::value, - ""); -constexpr VadLevelAnalyzer::Result kVadSpeech{1.f, -20.f, 0.f}; -constexpr float kMaxGainChangePerSecondDb = 3.f; +constexpr float kMaxGainChangePerSecondDb = 3.0f; constexpr float kMaxGainChangePerFrameDb = - kMaxGainChangePerSecondDb * kFrameDurationMs / 1000.f; -constexpr float kMaxOutputNoiseLevelDbfs = -50.f; + kMaxGainChangePerSecondDb * kFrameDurationMs / 1000.0f; +constexpr float kMaxOutputNoiseLevelDbfs = -50.0f; -// Helper to instance `AdaptiveDigitalGainApplier`. +// Helper to create initialized `AdaptiveDigitalGainApplier` objects. 
struct GainApplierHelper { GainApplierHelper() : GainApplierHelper(/*adjacent_speech_frames_threshold=*/1) {} explicit GainApplierHelper(int adjacent_speech_frames_threshold) : apm_data_dumper(0), - gain_applier(&apm_data_dumper, - adjacent_speech_frames_threshold, - kMaxGainChangePerSecondDb, - kMaxOutputNoiseLevelDbfs) {} + gain_applier(std::make_unique( + &apm_data_dumper, + adjacent_speech_frames_threshold, + kMaxGainChangePerSecondDb, + kMaxOutputNoiseLevelDbfs, + /*dry_run=*/false)) {} ApmDataDumper apm_data_dumper; - AdaptiveDigitalGainApplier gain_applier; + std::unique_ptr gain_applier; }; -// Runs gain applier and returns the applied gain in linear scale. -float RunOnConstantLevel(int num_iterations, - VadLevelAnalyzer::Result vad_level, - float input_level_dbfs, - AdaptiveDigitalGainApplier* gain_applier) { - float gain_linear = 0.f; - - for (int i = 0; i < num_iterations; ++i) { - VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, 1.f); - AdaptiveDigitalGainApplier::FrameInfo info; - info.input_level_dbfs = input_level_dbfs; - info.input_noise_level_dbfs = kNoNoiseDbfs; - info.vad_result = vad_level; - info.limiter_envelope_dbfs = -2.f; - info.estimate_is_confident = true; - gain_applier->Process(info, fake_audio.float_frame_view()); - gain_linear = fake_audio.float_frame_view().channel(0)[0]; - } - return gain_linear; -} - // Voice on, no noise, low limiter, confident level. +static_assert(std::is_trivially_destructible< + AdaptiveDigitalGainApplier::FrameInfo>::value, + ""); constexpr AdaptiveDigitalGainApplier::FrameInfo kFrameInfo{ - /*input_level_dbfs=*/-1.f, - /*input_noise_level_dbfs=*/kNoNoiseDbfs, - /*vad_result=*/kVadSpeech, - /*limiter_envelope_dbfs=*/-2.f, - /*estimate_is_confident=*/true}; - -TEST(AutomaticGainController2AdaptiveGainApplier, GainApplierShouldNotCrash) { + /*speech_probability=*/kMaxSpeechProbability, + /*speech_level_dbfs=*/kInitialSpeechLevelEstimateDbfs, + /*speech_level_reliable=*/true, + /*noise_rms_dbfs=*/kNoNoiseDbfs, + /*headroom_db=*/kSaturationProtectorInitialHeadroomDb, + /*limiter_envelope_dbfs=*/-2.0f}; + +TEST(GainController2AdaptiveGainApplier, GainApplierShouldNotCrash) { GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kStereo); // Make one call with reasonable audio level values and settings. - VectorFloatFrame fake_audio(kStereo, kFrameLen10ms48kHz, 10000.f); + VectorFloatFrame fake_audio(kStereo, kFrameLen10ms48kHz, 10000.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = -5.0; - helper.gain_applier.Process(kFrameInfo, fake_audio.float_frame_view()); -} - -// Check that the output is -kHeadroom dBFS. -TEST(AutomaticGainController2AdaptiveGainApplier, TargetLevelIsReached) { - GainApplierHelper helper; - - constexpr float initial_level_dbfs = -5.f; - - const float applied_gain = RunOnConstantLevel( - 200, kVadSpeech, initial_level_dbfs, &helper.gain_applier); - - EXPECT_NEAR(applied_gain, DbToRatio(-kHeadroomDbfs - initial_level_dbfs), - 0.1f); + info.speech_level_dbfs = -5.0f; + helper.gain_applier->Process(kFrameInfo, fake_audio.float_frame_view()); } -// Check that the output is -kHeadroom dBFS -TEST(AutomaticGainController2AdaptiveGainApplier, GainApproachesMaxGain) { - GainApplierHelper helper; - - constexpr float initial_level_dbfs = -kHeadroomDbfs - kMaxGainDb - 10.f; - // A few extra frames for safety. +// Checks that the maximum allowed gain is applied. 
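// (With kMaxGainChangePerSecondDb = 3 dB/s and 10 ms frames, the applied gain
// can move by at most kMaxGainChangePerFrameDb = 0.03 dB per frame, so
// kMaxGainDb / 0.03 frames, plus the 10-frame safety margin used below, are
// enough to reach the maximum gain from any starting level.)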
+TEST(GainController2AdaptiveGainApplier, MaxGainApplied) { constexpr int kNumFramesToAdapt = static_cast(kMaxGainDb / kMaxGainChangePerFrameDb) + 10; - const float applied_gain = RunOnConstantLevel( - kNumFramesToAdapt, kVadSpeech, initial_level_dbfs, &helper.gain_applier); - EXPECT_NEAR(applied_gain, DbToRatio(kMaxGainDb), 0.1f); - - const float applied_gain_db = 20.f * std::log10(applied_gain); + GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/8000, kMono); + AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; + info.speech_level_dbfs = -60.0f; + float applied_gain; + for (int i = 0; i < kNumFramesToAdapt; ++i) { + VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, 1.0f); + helper.gain_applier->Process(info, fake_audio.float_frame_view()); + applied_gain = fake_audio.float_frame_view().channel(0)[0]; + } + const float applied_gain_db = 20.0f * std::log10f(applied_gain); EXPECT_NEAR(applied_gain_db, kMaxGainDb, 0.1f); } -TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) { +TEST(GainController2AdaptiveGainApplier, GainDoesNotChangeFast) { GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/8000, kMono); - constexpr float initial_level_dbfs = -25.f; + constexpr float initial_level_dbfs = -25.0f; // A few extra frames for safety. constexpr int kNumFramesToAdapt = static_cast(initial_level_dbfs / kMaxGainChangePerFrameDb) + 10; @@ -133,10 +109,10 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) { float last_gain_linear = 1.f; for (int i = 0; i < kNumFramesToAdapt; ++i) { SCOPED_TRACE(i); - VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, 1.f); + VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, 1.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = initial_level_dbfs; - helper.gain_applier.Process(info, fake_audio.float_frame_view()); + info.speech_level_dbfs = initial_level_dbfs; + helper.gain_applier->Process(info, fake_audio.float_frame_view()); float current_gain_linear = fake_audio.float_frame_view().channel(0)[0]; EXPECT_LE(std::abs(current_gain_linear - last_gain_linear), kMaxChangePerFrameLinear); @@ -146,10 +122,10 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) { // Check that the same is true when gain decreases as well. 
for (int i = 0; i < kNumFramesToAdapt; ++i) { SCOPED_TRACE(i); - VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, 1.f); + VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, 1.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = 0.f; - helper.gain_applier.Process(info, fake_audio.float_frame_view()); + info.speech_level_dbfs = 0.f; + helper.gain_applier->Process(info, fake_audio.float_frame_view()); float current_gain_linear = fake_audio.float_frame_view().channel(0)[0]; EXPECT_LE(std::abs(current_gain_linear - last_gain_linear), kMaxChangePerFrameLinear); @@ -157,17 +133,18 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) { } } -TEST(AutomaticGainController2AdaptiveGainApplier, GainIsRampedInAFrame) { +TEST(GainController2AdaptiveGainApplier, GainIsRampedInAFrame) { GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kMono); - constexpr float initial_level_dbfs = -25.f; + constexpr float initial_level_dbfs = -25.0f; - VectorFloatFrame fake_audio(kMono, kFrameLen10ms48kHz, 1.f); + VectorFloatFrame fake_audio(kMono, kFrameLen10ms48kHz, 1.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = initial_level_dbfs; - helper.gain_applier.Process(info, fake_audio.float_frame_view()); - float maximal_difference = 0.f; - float current_value = 1.f * DbToRatio(kInitialAdaptiveDigitalGainDb); + info.speech_level_dbfs = initial_level_dbfs; + helper.gain_applier->Process(info, fake_audio.float_frame_view()); + float maximal_difference = 0.0f; + float current_value = 1.0f * DbToRatio(kInitialAdaptiveDigitalGainDb); for (const auto& x : fake_audio.float_frame_view().channel(0)) { const float difference = std::abs(x - current_value); maximal_difference = std::max(maximal_difference, difference); @@ -181,10 +158,11 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainIsRampedInAFrame) { EXPECT_LE(maximal_difference, kMaxChangePerSample); } -TEST(AutomaticGainController2AdaptiveGainApplier, NoiseLimitsGain) { +TEST(GainController2AdaptiveGainApplier, NoiseLimitsGain) { GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kMono); - constexpr float initial_level_dbfs = -25.f; + constexpr float initial_level_dbfs = -25.0f; constexpr int num_initial_frames = kInitialAdaptiveDigitalGainDb / kMaxGainChangePerFrameDb; constexpr int num_frames = 50; @@ -193,11 +171,11 @@ TEST(AutomaticGainController2AdaptiveGainApplier, NoiseLimitsGain) { << "kWithNoiseDbfs is too low"; for (int i = 0; i < num_initial_frames + num_frames; ++i) { - VectorFloatFrame fake_audio(kMono, kFrameLen10ms48kHz, 1.f); + VectorFloatFrame fake_audio(kMono, kFrameLen10ms48kHz, 1.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = initial_level_dbfs; - info.input_noise_level_dbfs = kWithNoiseDbfs; - helper.gain_applier.Process(info, fake_audio.float_frame_view()); + info.speech_level_dbfs = initial_level_dbfs; + info.noise_rms_dbfs = kWithNoiseDbfs; + helper.gain_applier->Process(info, fake_audio.float_frame_view()); // Wait so that the adaptive gain applier has time to lower the gain. 
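// (num_initial_frames = kInitialAdaptiveDigitalGainDb / kMaxGainChangePerFrameDb
// is the number of frames needed to walk the initial gain down to 0 dB at the
// maximum allowed rate; only after that should the output match the input
// level, which is what the ratio check below verifies.)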
if (i > num_initial_frames) { @@ -205,25 +183,27 @@ TEST(AutomaticGainController2AdaptiveGainApplier, NoiseLimitsGain) { *std::max_element(fake_audio.float_frame_view().channel(0).begin(), fake_audio.float_frame_view().channel(0).end()); - EXPECT_NEAR(maximal_ratio, 1.f, 0.001f); + EXPECT_NEAR(maximal_ratio, 1.0f, 0.001f); } } } -TEST(AutomaticGainController2GainApplier, CanHandlePositiveSpeechLevels) { +TEST(GainController2GainApplier, CanHandlePositiveSpeechLevels) { GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kStereo); // Make one call with positive audio level values and settings. - VectorFloatFrame fake_audio(kStereo, kFrameLen10ms48kHz, 10000.f); + VectorFloatFrame fake_audio(kStereo, kFrameLen10ms48kHz, 10000.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = 5.f; - helper.gain_applier.Process(info, fake_audio.float_frame_view()); + info.speech_level_dbfs = 5.0f; + helper.gain_applier->Process(info, fake_audio.float_frame_view()); } -TEST(AutomaticGainController2GainApplier, AudioLevelLimitsGain) { +TEST(GainController2GainApplier, AudioLevelLimitsGain) { GainApplierHelper helper; + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kMono); - constexpr float initial_level_dbfs = -25.f; + constexpr float initial_level_dbfs = -25.0f; constexpr int num_initial_frames = kInitialAdaptiveDigitalGainDb / kMaxGainChangePerFrameDb; constexpr int num_frames = 50; @@ -232,12 +212,12 @@ TEST(AutomaticGainController2GainApplier, AudioLevelLimitsGain) { << "kWithNoiseDbfs is too low"; for (int i = 0; i < num_initial_frames + num_frames; ++i) { - VectorFloatFrame fake_audio(kMono, kFrameLen10ms48kHz, 1.f); + VectorFloatFrame fake_audio(kMono, kFrameLen10ms48kHz, 1.0f); AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = initial_level_dbfs; - info.limiter_envelope_dbfs = 1.f; - info.estimate_is_confident = false; - helper.gain_applier.Process(info, fake_audio.float_frame_view()); + info.speech_level_dbfs = initial_level_dbfs; + info.limiter_envelope_dbfs = 1.0f; + info.speech_level_reliable = false; + helper.gain_applier->Process(info, fake_audio.float_frame_view()); // Wait so that the adaptive gain applier has time to lower the gain. if (i > num_initial_frames) { @@ -245,7 +225,7 @@ TEST(AutomaticGainController2GainApplier, AudioLevelLimitsGain) { *std::max_element(fake_audio.float_frame_view().channel(0).begin(), fake_audio.float_frame_view().channel(0).end()); - EXPECT_NEAR(maximal_ratio, 1.f, 0.001f); + EXPECT_NEAR(maximal_ratio, 1.0f, 0.001f); } } } @@ -259,15 +239,13 @@ TEST_P(AdaptiveDigitalGainApplierTest, DoNotIncreaseGainWithTooFewSpeechFrames) { const int adjacent_speech_frames_threshold = AdjacentSpeechFramesThreshold(); GainApplierHelper helper(adjacent_speech_frames_threshold); + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kMono); - AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = -25.0; - - float prev_gain = 0.f; + float prev_gain = 0.0f; for (int i = 0; i < adjacent_speech_frames_threshold; ++i) { SCOPED_TRACE(i); - VectorFloatFrame audio(kMono, kFrameLen10ms48kHz, 1.f); - helper.gain_applier.Process(info, audio.float_frame_view()); + VectorFloatFrame audio(kMono, kFrameLen10ms48kHz, 1.0f); + helper.gain_applier->Process(kFrameInfo, audio.float_frame_view()); const float gain = audio.float_frame_view().channel(0)[0]; if (i > 0) { EXPECT_EQ(prev_gain, gain); // No gain increase. 
@@ -279,28 +257,90 @@ TEST_P(AdaptiveDigitalGainApplierTest, TEST_P(AdaptiveDigitalGainApplierTest, IncreaseGainWithEnoughSpeechFrames) { const int adjacent_speech_frames_threshold = AdjacentSpeechFramesThreshold(); GainApplierHelper helper(adjacent_speech_frames_threshold); + helper.gain_applier->Initialize(/*sample_rate_hz=*/48000, kMono); - AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; - info.input_level_dbfs = -25.0; - - float prev_gain = 0.f; + float prev_gain = 0.0f; for (int i = 0; i < adjacent_speech_frames_threshold; ++i) { - VectorFloatFrame audio(kMono, kFrameLen10ms48kHz, 1.f); - helper.gain_applier.Process(info, audio.float_frame_view()); + SCOPED_TRACE(i); + VectorFloatFrame audio(kMono, kFrameLen10ms48kHz, 1.0f); + helper.gain_applier->Process(kFrameInfo, audio.float_frame_view()); prev_gain = audio.float_frame_view().channel(0)[0]; } // Process one more speech frame. - VectorFloatFrame audio(kMono, kFrameLen10ms48kHz, 1.f); - helper.gain_applier.Process(info, audio.float_frame_view()); + VectorFloatFrame audio(kMono, kFrameLen10ms48kHz, 1.0f); + helper.gain_applier->Process(kFrameInfo, audio.float_frame_view()); // The gain has increased. EXPECT_GT(audio.float_frame_view().channel(0)[0], prev_gain); } -INSTANTIATE_TEST_SUITE_P(AutomaticGainController2, +INSTANTIATE_TEST_SUITE_P(GainController2, AdaptiveDigitalGainApplierTest, ::testing::Values(1, 7, 31)); +// Checks that the input is never modified when running in dry run mode. +TEST(GainController2GainApplier, DryRunDoesNotChangeInput) { + ApmDataDumper apm_data_dumper(0); + AdaptiveDigitalGainApplier gain_applier( + &apm_data_dumper, /*adjacent_speech_frames_threshold=*/1, + kMaxGainChangePerSecondDb, kMaxOutputNoiseLevelDbfs, /*dry_run=*/true); + // Simulate an input signal with log speech level. + AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; + info.speech_level_dbfs = -60.0f; + // Allow enough time to reach the maximum gain. + constexpr int kNumFramesToAdapt = + static_cast(kMaxGainDb / kMaxGainChangePerFrameDb) + 10; + constexpr float kPcmSamples = 123.456f; + // Run the gain applier and check that the PCM samples are not modified. + gain_applier.Initialize(/*sample_rate_hz=*/8000, kMono); + for (int i = 0; i < kNumFramesToAdapt; ++i) { + SCOPED_TRACE(i); + VectorFloatFrame fake_audio(kMono, kFrameLen10ms8kHz, kPcmSamples); + gain_applier.Process(info, fake_audio.float_frame_view()); + EXPECT_FLOAT_EQ(fake_audio.float_frame_view().channel(0)[0], kPcmSamples); + } +} + +// Checks that no sample is modified before and after the sample rate changes. 
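// (Initialize() re-allocates the internal dry-run buffers whenever the sample
// rate or the channel count changes, so the copy made inside Process() always
// matches the size of the incoming frame.)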
+TEST(GainController2GainApplier, DryRunHandlesSampleRateChange) { + ApmDataDumper apm_data_dumper(0); + AdaptiveDigitalGainApplier gain_applier( + &apm_data_dumper, /*adjacent_speech_frames_threshold=*/1, + kMaxGainChangePerSecondDb, kMaxOutputNoiseLevelDbfs, /*dry_run=*/true); + AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; + info.speech_level_dbfs = -60.0f; + constexpr float kPcmSamples = 123.456f; + VectorFloatFrame fake_audio_8k(kMono, kFrameLen10ms8kHz, kPcmSamples); + gain_applier.Initialize(/*sample_rate_hz=*/8000, kMono); + gain_applier.Process(info, fake_audio_8k.float_frame_view()); + EXPECT_FLOAT_EQ(fake_audio_8k.float_frame_view().channel(0)[0], kPcmSamples); + gain_applier.Initialize(/*sample_rate_hz=*/48000, kMono); + VectorFloatFrame fake_audio_48k(kMono, kFrameLen10ms48kHz, kPcmSamples); + gain_applier.Process(info, fake_audio_48k.float_frame_view()); + EXPECT_FLOAT_EQ(fake_audio_48k.float_frame_view().channel(0)[0], kPcmSamples); +} + +// Checks that no sample is modified before and after the number of channels +// changes. +TEST(GainController2GainApplier, DryRunHandlesNumChannelsChange) { + ApmDataDumper apm_data_dumper(0); + AdaptiveDigitalGainApplier gain_applier( + &apm_data_dumper, /*adjacent_speech_frames_threshold=*/1, + kMaxGainChangePerSecondDb, kMaxOutputNoiseLevelDbfs, /*dry_run=*/true); + AdaptiveDigitalGainApplier::FrameInfo info = kFrameInfo; + info.speech_level_dbfs = -60.0f; + constexpr float kPcmSamples = 123.456f; + VectorFloatFrame fake_audio_8k(kMono, kFrameLen10ms8kHz, kPcmSamples); + gain_applier.Initialize(/*sample_rate_hz=*/8000, kMono); + gain_applier.Process(info, fake_audio_8k.float_frame_view()); + EXPECT_FLOAT_EQ(fake_audio_8k.float_frame_view().channel(0)[0], kPcmSamples); + VectorFloatFrame fake_audio_48k(kStereo, kFrameLen10ms8kHz, kPcmSamples); + gain_applier.Initialize(/*sample_rate_hz=*/8000, kStereo); + gain_applier.Process(info, fake_audio_48k.float_frame_view()); + EXPECT_FLOAT_EQ(fake_audio_48k.float_frame_view().channel(0)[0], kPcmSamples); + EXPECT_FLOAT_EQ(fake_audio_48k.float_frame_view().channel(1)[0], kPcmSamples); +} + } // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc index 739997f5e3..507aa12cb4 100644 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc +++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc @@ -22,37 +22,17 @@ namespace { using LevelEstimatorType = AudioProcessing::Config::GainController2::LevelEstimator; -// Combines a level estimation with the saturation protector margins. -float ComputeLevelEstimateDbfs(float level_estimate_dbfs, - float saturation_margin_db, - float extra_saturation_margin_db) { - return rtc::SafeClamp( - level_estimate_dbfs + saturation_margin_db + extra_saturation_margin_db, - -90.f, 30.f); -} - -// Returns the level of given type from `vad_level`. 
-float GetLevel(const VadLevelAnalyzer::Result& vad_level, - LevelEstimatorType type) { - switch (type) { - case LevelEstimatorType::kRms: - return vad_level.rms_dbfs; - break; - case LevelEstimatorType::kPeak: - return vad_level.peak_dbfs; - break; - } - RTC_CHECK_NOTREACHED(); +float ClampLevelEstimateDbfs(float level_estimate_dbfs) { + return rtc::SafeClamp(level_estimate_dbfs, -90.f, 30.f); } } // namespace bool AdaptiveModeLevelEstimator::LevelEstimatorState::operator==( const AdaptiveModeLevelEstimator::LevelEstimatorState& b) const { - return time_to_full_buffer_ms == b.time_to_full_buffer_ms && + return time_to_confidence_ms == b.time_to_confidence_ms && level_dbfs.numerator == b.level_dbfs.numerator && - level_dbfs.denominator == b.level_dbfs.denominator && - saturation_protector == b.saturation_protector; + level_dbfs.denominator == b.level_dbfs.denominator; } float AdaptiveModeLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const { @@ -64,25 +44,14 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( ApmDataDumper* apm_data_dumper) : AdaptiveModeLevelEstimator( apm_data_dumper, - AudioProcessing::Config::GainController2::LevelEstimator::kRms, - kDefaultLevelEstimatorAdjacentSpeechFramesThreshold, - kDefaultInitialSaturationMarginDb, - kDefaultExtraSaturationMarginDb) {} + kDefaultLevelEstimatorAdjacentSpeechFramesThreshold) {} AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( ApmDataDumper* apm_data_dumper, - AudioProcessing::Config::GainController2::LevelEstimator level_estimator, - int adjacent_speech_frames_threshold, - float initial_saturation_margin_db, - float extra_saturation_margin_db) + int adjacent_speech_frames_threshold) : apm_data_dumper_(apm_data_dumper), - level_estimator_type_(level_estimator), adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold), - initial_saturation_margin_db_(initial_saturation_margin_db), - extra_saturation_margin_db_(extra_saturation_margin_db), - level_dbfs_(ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs, - initial_saturation_margin_db_, - extra_saturation_margin_db_)) { + level_dbfs_(ClampLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs)) { RTC_DCHECK(apm_data_dumper_); RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1); Reset(); @@ -96,8 +65,6 @@ void AdaptiveModeLevelEstimator::Update( RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f); RTC_DCHECK_GE(vad_level.speech_probability, 0.f); RTC_DCHECK_LE(vad_level.speech_probability, 1.f); - DumpDebugData(); - if (vad_level.speech_probability < kVadConfidenceThreshold) { // Not a speech frame. if (adjacent_speech_frames_threshold_ > 1) { @@ -115,85 +82,82 @@ void AdaptiveModeLevelEstimator::Update( } } num_adjacent_speech_frames_ = 0; - return; - } - - // Speech frame observed. - num_adjacent_speech_frames_++; - - // Update preliminary level estimate. - RTC_DCHECK_GE(preliminary_state_.time_to_full_buffer_ms, 0); - const bool buffer_is_full = preliminary_state_.time_to_full_buffer_ms == 0; - if (!buffer_is_full) { - preliminary_state_.time_to_full_buffer_ms -= kFrameDurationMs; - } - // Weighted average of levels with speech probability as weight. - RTC_DCHECK_GT(vad_level.speech_probability, 0.f); - const float leak_factor = buffer_is_full ? 
kFullBufferLeakFactor : 1.f; - preliminary_state_.level_dbfs.numerator = - preliminary_state_.level_dbfs.numerator * leak_factor + - GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability; - preliminary_state_.level_dbfs.denominator = - preliminary_state_.level_dbfs.denominator * leak_factor + - vad_level.speech_probability; - - const float level_dbfs = preliminary_state_.level_dbfs.GetRatio(); - - UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs, - preliminary_state_.saturation_protector); - - if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) { - // `preliminary_state_` is now reliable. Update the last level estimation. - level_dbfs_ = ComputeLevelEstimateDbfs( - level_dbfs, preliminary_state_.saturation_protector.margin_db, - extra_saturation_margin_db_); + } else { + // Speech frame observed. + num_adjacent_speech_frames_++; + + // Update preliminary level estimate. + RTC_DCHECK_GE(preliminary_state_.time_to_confidence_ms, 0); + const bool buffer_is_full = preliminary_state_.time_to_confidence_ms == 0; + if (!buffer_is_full) { + preliminary_state_.time_to_confidence_ms -= kFrameDurationMs; + } + // Weighted average of levels with speech probability as weight. + RTC_DCHECK_GT(vad_level.speech_probability, 0.f); + const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.f; + preliminary_state_.level_dbfs.numerator = + preliminary_state_.level_dbfs.numerator * leak_factor + + vad_level.rms_dbfs * vad_level.speech_probability; + preliminary_state_.level_dbfs.denominator = + preliminary_state_.level_dbfs.denominator * leak_factor + + vad_level.speech_probability; + + const float level_dbfs = preliminary_state_.level_dbfs.GetRatio(); + + if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) { + // `preliminary_state_` is now reliable. Update the last level estimation. + level_dbfs_ = ClampLevelEstimateDbfs(level_dbfs); + } } + DumpDebugData(); } bool AdaptiveModeLevelEstimator::IsConfident() const { if (adjacent_speech_frames_threshold_ == 1) { // Ignore `reliable_state_` when a single frame is enough to update the // level estimate (because it is not used). - return preliminary_state_.time_to_full_buffer_ms == 0; + return preliminary_state_.time_to_confidence_ms == 0; } // Once confident, it remains confident. - RTC_DCHECK(reliable_state_.time_to_full_buffer_ms != 0 || - preliminary_state_.time_to_full_buffer_ms == 0); + RTC_DCHECK(reliable_state_.time_to_confidence_ms != 0 || + preliminary_state_.time_to_confidence_ms == 0); // During the first long enough speech sequence, `reliable_state_` must be // ignored since `preliminary_state_` is used. 
- return reliable_state_.time_to_full_buffer_ms == 0 || + return reliable_state_.time_to_confidence_ms == 0 || (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ && - preliminary_state_.time_to_full_buffer_ms == 0); + preliminary_state_.time_to_confidence_ms == 0); } void AdaptiveModeLevelEstimator::Reset() { ResetLevelEstimatorState(preliminary_state_); ResetLevelEstimatorState(reliable_state_); - level_dbfs_ = ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs, - initial_saturation_margin_db_, - extra_saturation_margin_db_); + level_dbfs_ = ClampLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs); num_adjacent_speech_frames_ = 0; } void AdaptiveModeLevelEstimator::ResetLevelEstimatorState( LevelEstimatorState& state) const { - state.time_to_full_buffer_ms = kFullBufferSizeMs; - state.level_dbfs.numerator = 0.f; - state.level_dbfs.denominator = 0.f; - ResetSaturationProtectorState(initial_saturation_margin_db_, - state.saturation_protector); + state.time_to_confidence_ms = kLevelEstimatorTimeToConfidenceMs; + state.level_dbfs.numerator = kInitialSpeechLevelEstimateDbfs; + state.level_dbfs.denominator = 1.0f; } void AdaptiveModeLevelEstimator::DumpDebugData() const { - apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_); - apm_data_dumper_->DumpRaw("agc2_adaptive_num_adjacent_speech_frames_", - num_adjacent_speech_frames_); - apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_num", - preliminary_state_.level_dbfs.numerator); - apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_den", - preliminary_state_.level_dbfs.denominator); - apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_saturation_margin_db", - preliminary_state_.saturation_protector.margin_db); + apm_data_dumper_->DumpRaw( + "agc2_adaptive_level_estimator_num_adjacent_speech_frames", + num_adjacent_speech_frames_); + apm_data_dumper_->DumpRaw( + "agc2_adaptive_level_estimator_preliminary_level_estimate_num", + preliminary_state_.level_dbfs.numerator); + apm_data_dumper_->DumpRaw( + "agc2_adaptive_level_estimator_preliminary_level_estimate_den", + preliminary_state_.level_dbfs.denominator); + apm_data_dumper_->DumpRaw( + "agc2_adaptive_level_estimator_preliminary_time_to_confidence_ms", + preliminary_state_.time_to_confidence_ms); + apm_data_dumper_->DumpRaw( + "agc2_adaptive_level_estimator_reliable_time_to_confidence_ms", + reliable_state_.time_to_confidence_ms); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h index 213fc0f0c8..6d44938587 100644 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h +++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h @@ -15,7 +15,6 @@ #include #include "modules/audio_processing/agc2/agc2_common.h" -#include "modules/audio_processing/agc2/saturation_protector.h" #include "modules/audio_processing/agc2/vad_with_level.h" #include "modules/audio_processing/include/audio_processing.h" @@ -29,12 +28,8 @@ class AdaptiveModeLevelEstimator { AdaptiveModeLevelEstimator(const AdaptiveModeLevelEstimator&) = delete; AdaptiveModeLevelEstimator& operator=(const AdaptiveModeLevelEstimator&) = delete; - AdaptiveModeLevelEstimator( - ApmDataDumper* apm_data_dumper, - AudioProcessing::Config::GainController2::LevelEstimator level_estimator, - int adjacent_speech_frames_threshold, - float initial_saturation_margin_db, - float extra_saturation_margin_db); + 
AdaptiveModeLevelEstimator(ApmDataDumper* apm_data_dumper, + int adjacent_speech_frames_threshold); // Updates the level estimation. void Update(const VadLevelAnalyzer::Result& vad_data); @@ -57,10 +52,9 @@ class AdaptiveModeLevelEstimator { float denominator; float GetRatio() const; }; - // TODO(crbug.com/webrtc/7494): Remove time_to_full_buffer_ms if redundant. - int time_to_full_buffer_ms; + // TODO(crbug.com/webrtc/7494): Remove time_to_confidence_ms if redundant. + int time_to_confidence_ms; Ratio level_dbfs; - SaturationProtectorState saturation_protector; }; static_assert(std::is_trivially_copyable::value, ""); @@ -70,11 +64,7 @@ class AdaptiveModeLevelEstimator { ApmDataDumper* const apm_data_dumper_; - const AudioProcessing::Config::GainController2::LevelEstimator - level_estimator_type_; const int adjacent_speech_frames_threshold_; - const float initial_saturation_margin_db_; - const float extra_saturation_margin_db_; LevelEstimatorState preliminary_state_; LevelEstimatorState reliable_state_; float level_dbfs_; diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc deleted file mode 100644 index 5ceeb7df77..0000000000 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h" - -#include -#include - -#include "modules/audio_processing/agc2/agc2_common.h" -#include "modules/audio_processing/include/audio_frame_view.h" - -namespace webrtc { - -AdaptiveModeLevelEstimatorAgc::AdaptiveModeLevelEstimatorAgc( - ApmDataDumper* apm_data_dumper) - : level_estimator_(apm_data_dumper) { - set_target_level_dbfs(kDefaultAgc2LevelHeadroomDbfs); -} - -// |audio| must be mono; in a multi-channel stream, provide the first (usually -// left) channel. -void AdaptiveModeLevelEstimatorAgc::Process(const int16_t* audio, - size_t length, - int sample_rate_hz) { - std::vector float_audio_frame(audio, audio + length); - const float* const first_channel = &float_audio_frame[0]; - AudioFrameView frame_view(&first_channel, 1 /* num channels */, - length); - const auto vad_prob = agc2_vad_.AnalyzeFrame(frame_view); - latest_voice_probability_ = vad_prob.speech_probability; - if (latest_voice_probability_ > kVadConfidenceThreshold) { - time_in_ms_since_last_estimate_ += kFrameDurationMs; - } - level_estimator_.Update(vad_prob); -} - -// Retrieves the difference between the target RMS level and the current -// signal RMS level in dB. Returns true if an update is available and false -// otherwise, in which case |error| should be ignored and no action taken. 
-bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) { - if (time_in_ms_since_last_estimate_ <= kTimeUntilConfidentMs) { - return false; - } - *error = - std::floor(target_level_dbfs() - level_estimator_.level_dbfs() + 0.5f); - time_in_ms_since_last_estimate_ = 0; - return true; -} - -void AdaptiveModeLevelEstimatorAgc::Reset() { - level_estimator_.Reset(); -} - -float AdaptiveModeLevelEstimatorAgc::voice_probability() const { - return latest_voice_probability_; -} - -} // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h deleted file mode 100644 index bc6fa843b5..0000000000 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_agc.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_ -#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_ - -#include -#include - -#include "modules/audio_processing/agc/agc.h" -#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h" -#include "modules/audio_processing/agc2/saturation_protector.h" -#include "modules/audio_processing/agc2/vad_with_level.h" - -namespace webrtc { -class AdaptiveModeLevelEstimatorAgc : public Agc { - public: - explicit AdaptiveModeLevelEstimatorAgc(ApmDataDumper* apm_data_dumper); - - // |audio| must be mono; in a multi-channel stream, provide the first (usually - // left) channel. - void Process(const int16_t* audio, - size_t length, - int sample_rate_hz) override; - - // Retrieves the difference between the target RMS level and the current - // signal RMS level in dB. Returns true if an update is available and false - // otherwise, in which case |error| should be ignored and no action taken. 
-  bool GetRmsErrorDb(int* error) override;
-  void Reset() override;
-
-  float voice_probability() const override;
-
- private:
-  static constexpr int kTimeUntilConfidentMs = 700;
-  static constexpr int kDefaultAgc2LevelHeadroomDbfs = -1;
-  int32_t time_in_ms_since_last_estimate_ = 0;
-  AdaptiveModeLevelEstimator level_estimator_;
-  VadLevelAnalyzer agc2_vad_;
-  float latest_voice_probability_ = 0.f;
-};
-}  // namespace webrtc
-
-#endif  // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_AGC_H_
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
index ea35797f5e..c55950ac29 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
@@ -19,22 +19,34 @@ namespace webrtc {
 namespace {
-constexpr float kInitialSaturationMarginDb = 20.f;
-constexpr float kExtraSaturationMarginDb = 2.f;
-
-static_assert(kInitialSpeechLevelEstimateDbfs < 0.f, "");
-constexpr float kVadLevelRms = kInitialSpeechLevelEstimateDbfs / 2.f;
-constexpr float kVadLevelPeak = kInitialSpeechLevelEstimateDbfs / 3.f;
-
-constexpr VadLevelAnalyzer::Result kVadDataSpeech{/*speech_probability=*/1.f,
+// Number of speech frames that the level estimator must observe in order to
+// become confident about the estimated level.
+constexpr int kNumFramesToConfidence =
+    kLevelEstimatorTimeToConfidenceMs / kFrameDurationMs;
+static_assert(kNumFramesToConfidence > 0, "");
+
+// Fake levels and speech probabilities used in the tests.
+static_assert(kInitialSpeechLevelEstimateDbfs < 0.0f, "");
+constexpr float kVadLevelRms = kInitialSpeechLevelEstimateDbfs / 2.0f;
+constexpr float kVadLevelPeak = kInitialSpeechLevelEstimateDbfs / 3.0f;
+static_assert(kVadLevelRms < kVadLevelPeak, "");
+static_assert(kVadLevelRms > kInitialSpeechLevelEstimateDbfs, "");
+static_assert(kVadLevelRms - kInitialSpeechLevelEstimateDbfs > 5.0f,
+              "Adjust `kVadLevelRms` so that the difference from the initial "
+              "level is wide enough for the tests.");
+
+constexpr VadLevelAnalyzer::Result kVadDataSpeech{/*speech_probability=*/1.0f,
                                                   kVadLevelRms,
                                                   kVadLevelPeak};
 constexpr VadLevelAnalyzer::Result kVadDataNonSpeech{
-    /*speech_probability=*/kVadConfidenceThreshold / 2.f, kVadLevelRms,
+    /*speech_probability=*/kVadConfidenceThreshold / 2.0f, kVadLevelRms,
     kVadLevelPeak};
 
-constexpr float kMinSpeechProbability = 0.f;
-constexpr float kMaxSpeechProbability = 1.f;
+constexpr float kMinSpeechProbability = 0.0f;
+constexpr float kMaxSpeechProbability = 1.0f;
+
+constexpr float kConvergenceSpeedTestsLevelTolerance = 0.5f;
 
+// Provides the `vad_level` value `num_iterations` times to `level_estimator`.
 void RunOnConstantLevel(int num_iterations,
                         const VadLevelAnalyzer::Result& vad_level,
                         AdaptiveModeLevelEstimator& level_estimator) {
@@ -43,172 +55,125 @@ void RunOnConstantLevel(int num_iterations,
   }
 }
 
+// Level estimator with data dumper.
 struct TestLevelEstimator {
   TestLevelEstimator()
       : data_dumper(0),
         estimator(std::make_unique<AdaptiveModeLevelEstimator>(
             &data_dumper,
-            AudioProcessing::Config::GainController2::LevelEstimator::kRms,
-            /*adjacent_speech_frames_threshold=*/1,
-            kInitialSaturationMarginDb,
-            kExtraSaturationMarginDb)) {}
+            /*adjacent_speech_frames_threshold=*/1)) {}
   ApmDataDumper data_dumper;
   std::unique_ptr<AdaptiveModeLevelEstimator> estimator;
 };
 
-TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
-     EstimatorShouldNotCrash) {
+// Checks the initially estimated level.
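+// Before `Update()` is called, `level_dbfs()` must return
+// `kInitialSpeechLevelEstimateDbfs`.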
+TEST(GainController2AdaptiveModeLevelEstimator, CheckInitialEstimate) { TestLevelEstimator level_estimator; - - VadLevelAnalyzer::Result vad_level{kMaxSpeechProbability, /*rms_dbfs=*/-20.f, - /*peak_dbfs=*/-10.f}; - level_estimator.estimator->Update(vad_level); - static_cast(level_estimator.estimator->level_dbfs()); + EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(), + kInitialSpeechLevelEstimateDbfs); } -TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) { +// Checks that the level estimator converges to a constant input speech level. +TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) { TestLevelEstimator level_estimator; - - constexpr float kSpeechPeakDbfs = -15.f; - RunOnConstantLevel(100, - VadLevelAnalyzer::Result{kMaxSpeechProbability, - /*rms_dbfs=*/kSpeechPeakDbfs - - kInitialSaturationMarginDb, - kSpeechPeakDbfs}, + RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence, kVadDataSpeech, *level_estimator.estimator); - - EXPECT_NEAR( - level_estimator.estimator->level_dbfs() - kExtraSaturationMarginDb, - kSpeechPeakDbfs, 0.1f); + const float estimated_level_dbfs = level_estimator.estimator->level_dbfs(); + RunOnConstantLevel(/*num_iterations=*/1, kVadDataSpeech, + *level_estimator.estimator); + EXPECT_NEAR(level_estimator.estimator->level_dbfs(), estimated_level_dbfs, + 0.1f); } -TEST(AutomaticGainController2AdaptiveModeLevelEstimator, - EstimatorIgnoresZeroProbabilityFrames) { +// Checks that the level controller does not become confident when too few +// speech frames are observed. +TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) { TestLevelEstimator level_estimator; + RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence / 2, + kVadDataSpeech, *level_estimator.estimator); + EXPECT_FALSE(level_estimator.estimator->IsConfident()); +} - // Run for one second of fake audio. - constexpr float kSpeechRmsDbfs = -25.f; - RunOnConstantLevel(100, - VadLevelAnalyzer::Result{kMaxSpeechProbability, - /*rms_dbfs=*/kSpeechRmsDbfs - - kInitialSaturationMarginDb, - /*peak_dbfs=*/kSpeechRmsDbfs}, +// Checks that the level controller becomes confident when enough speech frames +// are observed. +TEST(GainController2AdaptiveModeLevelEstimator, IsConfident) { + TestLevelEstimator level_estimator; + RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence, kVadDataSpeech, *level_estimator.estimator); + EXPECT_TRUE(level_estimator.estimator->IsConfident()); +} - // Run for one more second, but mark as not speech. - constexpr float kNoiseRmsDbfs = 0.f; - RunOnConstantLevel(100, +// Checks that the estimated level is not affected by the level of non-speech +// frames. +TEST(GainController2AdaptiveModeLevelEstimator, + EstimatorIgnoresNonSpeechFrames) { + TestLevelEstimator level_estimator; + // Simulate speech. + RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence, kVadDataSpeech, + *level_estimator.estimator); + const float estimated_level_dbfs = level_estimator.estimator->level_dbfs(); + // Simulate full-scale non-speech. + RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence, VadLevelAnalyzer::Result{kMinSpeechProbability, - /*rms_dbfs=*/kNoiseRmsDbfs, - /*peak_dbfs=*/kNoiseRmsDbfs}, + /*rms_dbfs=*/0.0f, + /*peak_dbfs=*/0.0f}, *level_estimator.estimator); - - // Level should not have changed. - EXPECT_NEAR( - level_estimator.estimator->level_dbfs() - kExtraSaturationMarginDb, - kSpeechRmsDbfs, 0.1f); + // No estimated level change is expected. 
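+  // Frames with a speech probability below `kVadConfidenceThreshold` never
+  // update the preliminary level estimate.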
+ EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(), + estimated_level_dbfs); } -TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) { +// Checks the convergence speed of the estimator before it becomes confident. +TEST(GainController2AdaptiveModeLevelEstimator, + ConvergenceSpeedBeforeConfidence) { TestLevelEstimator level_estimator; - - // Run for one 'window size' interval. - constexpr float kInitialSpeechRmsDbfs = -30.f; - RunOnConstantLevel( - kFullBufferSizeMs / kFrameDurationMs, - VadLevelAnalyzer::Result{ - kMaxSpeechProbability, - /*rms_dbfs=*/kInitialSpeechRmsDbfs - kInitialSaturationMarginDb, - /*peak_dbfs=*/kInitialSpeechRmsDbfs}, - *level_estimator.estimator); - - // Run for one half 'window size' interval. This should not be enough to - // adapt. - constexpr float kDifferentSpeechRmsDbfs = -10.f; - // It should at most differ by 25% after one half 'window size' interval. - // TODO(crbug.com/webrtc/7494): Add constexpr for repeated expressions. - const float kMaxDifferenceDb = - 0.25f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs); - RunOnConstantLevel( - static_cast(kFullBufferSizeMs / kFrameDurationMs / 2), - VadLevelAnalyzer::Result{ - kMaxSpeechProbability, - /*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb, - /*peak_dbfs=*/kDifferentSpeechRmsDbfs}, - *level_estimator.estimator); - EXPECT_GT(std::abs(kDifferentSpeechRmsDbfs - - level_estimator.estimator->level_dbfs()), - kMaxDifferenceDb); - - // Run for some more time. Afterwards, we should have adapted. - RunOnConstantLevel( - static_cast(3 * kFullBufferSizeMs / kFrameDurationMs), - VadLevelAnalyzer::Result{ - kMaxSpeechProbability, - /*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb, - /*peak_dbfs=*/kDifferentSpeechRmsDbfs}, - *level_estimator.estimator); - EXPECT_NEAR( - level_estimator.estimator->level_dbfs() - kExtraSaturationMarginDb, - kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f); + RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence, kVadDataSpeech, + *level_estimator.estimator); + EXPECT_NEAR(level_estimator.estimator->level_dbfs(), kVadDataSpeech.rms_dbfs, + kConvergenceSpeedTestsLevelTolerance); } -TEST(AutomaticGainController2AdaptiveModeLevelEstimator, - ResetGivesFastAdaptation) { +// Checks the convergence speed of the estimator after it becomes confident. +TEST(GainController2AdaptiveModeLevelEstimator, + ConvergenceSpeedAfterConfidence) { TestLevelEstimator level_estimator; - - // Run the level estimator for one window size interval. This gives time to - // adapt. - constexpr float kInitialSpeechRmsDbfs = -30.f; + // Reach confidence using the initial level estimate. RunOnConstantLevel( - kFullBufferSizeMs / kFrameDurationMs, + /*num_iterations=*/kNumFramesToConfidence, VadLevelAnalyzer::Result{ kMaxSpeechProbability, - /*rms_dbfs=*/kInitialSpeechRmsDbfs - kInitialSaturationMarginDb, - /*peak_dbfs=*/kInitialSpeechRmsDbfs}, + /*rms_dbfs=*/kInitialSpeechLevelEstimateDbfs, + /*peak_dbfs=*/kInitialSpeechLevelEstimateDbfs + 6.0f}, *level_estimator.estimator); - - constexpr float kDifferentSpeechRmsDbfs = -10.f; - // Reset and run one half window size interval. - level_estimator.estimator->Reset(); - + // No estimate change should occur, but confidence is achieved. + ASSERT_FLOAT_EQ(level_estimator.estimator->level_dbfs(), + kInitialSpeechLevelEstimateDbfs); + ASSERT_TRUE(level_estimator.estimator->IsConfident()); + // After confidence. + constexpr float kConvergenceTimeAfterConfidenceNumFrames = 600; // 6 seconds. 
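+  // With 10 ms frames, 600 frames correspond to 6 seconds and must exceed
+  // `kNumFramesToConfidence` (see the `static_assert` below).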
+ static_assert( + kConvergenceTimeAfterConfidenceNumFrames > kNumFramesToConfidence, ""); RunOnConstantLevel( - kFullBufferSizeMs / kFrameDurationMs / 2, - VadLevelAnalyzer::Result{ - kMaxSpeechProbability, - /*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb, - /*peak_dbfs=*/kDifferentSpeechRmsDbfs}, - *level_estimator.estimator); - - // The level should be close to 'kDifferentSpeechRmsDbfs'. - const float kMaxDifferenceDb = - 0.1f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs); - EXPECT_LT(std::abs(kDifferentSpeechRmsDbfs - - (level_estimator.estimator->level_dbfs() - - kExtraSaturationMarginDb)), - kMaxDifferenceDb); + /*num_iterations=*/kConvergenceTimeAfterConfidenceNumFrames, + kVadDataSpeech, *level_estimator.estimator); + EXPECT_NEAR(level_estimator.estimator->level_dbfs(), kVadDataSpeech.rms_dbfs, + kConvergenceSpeedTestsLevelTolerance); } -struct TestConfig { - int min_consecutive_speech_frames; - float initial_saturation_margin_db; - float extra_saturation_margin_db; +class AdaptiveModeLevelEstimatorParametrization + : public ::testing::TestWithParam { + protected: + int adjacent_speech_frames_threshold() const { return GetParam(); } }; -class AdaptiveModeLevelEstimatorTest - : public ::testing::TestWithParam {}; - -TEST_P(AdaptiveModeLevelEstimatorTest, DoNotAdaptToShortSpeechSegments) { - const auto params = GetParam(); +TEST_P(AdaptiveModeLevelEstimatorParametrization, + DoNotAdaptToShortSpeechSegments) { ApmDataDumper apm_data_dumper(0); AdaptiveModeLevelEstimator level_estimator( - &apm_data_dumper, - AudioProcessing::Config::GainController2::LevelEstimator::kRms, - params.min_consecutive_speech_frames, params.initial_saturation_margin_db, - params.extra_saturation_margin_db); + &apm_data_dumper, adjacent_speech_frames_threshold()); const float initial_level = level_estimator.level_dbfs(); - ASSERT_LT(initial_level, kVadDataSpeech.rms_dbfs); - for (int i = 0; i < params.min_consecutive_speech_frames - 1; ++i) { + ASSERT_LT(initial_level, kVadDataSpeech.peak_dbfs); + for (int i = 0; i < adjacent_speech_frames_threshold() - 1; ++i) { SCOPED_TRACE(i); level_estimator.Update(kVadDataSpeech); EXPECT_EQ(initial_level, level_estimator.level_dbfs()); @@ -217,26 +182,21 @@ TEST_P(AdaptiveModeLevelEstimatorTest, DoNotAdaptToShortSpeechSegments) { EXPECT_EQ(initial_level, level_estimator.level_dbfs()); } -TEST_P(AdaptiveModeLevelEstimatorTest, AdaptToEnoughSpeechSegments) { - const auto params = GetParam(); +TEST_P(AdaptiveModeLevelEstimatorParametrization, AdaptToEnoughSpeechSegments) { ApmDataDumper apm_data_dumper(0); AdaptiveModeLevelEstimator level_estimator( - &apm_data_dumper, - AudioProcessing::Config::GainController2::LevelEstimator::kRms, - params.min_consecutive_speech_frames, params.initial_saturation_margin_db, - params.extra_saturation_margin_db); + &apm_data_dumper, adjacent_speech_frames_threshold()); const float initial_level = level_estimator.level_dbfs(); - ASSERT_LT(initial_level, kVadDataSpeech.rms_dbfs); - for (int i = 0; i < params.min_consecutive_speech_frames; ++i) { + ASSERT_LT(initial_level, kVadDataSpeech.peak_dbfs); + for (int i = 0; i < adjacent_speech_frames_threshold(); ++i) { level_estimator.Update(kVadDataSpeech); } EXPECT_LT(initial_level, level_estimator.level_dbfs()); } -INSTANTIATE_TEST_SUITE_P(AutomaticGainController2, - AdaptiveModeLevelEstimatorTest, - ::testing::Values(TestConfig{1, 0.f, 0.f}, - TestConfig{9, 0.f, 0.f})); +INSTANTIATE_TEST_SUITE_P(GainController2, + AdaptiveModeLevelEstimatorParametrization, + 
::testing::Values(1, 9, 17)); } // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h index 5d01100eb7..adb1614926 100644 --- a/modules/audio_processing/agc2/agc2_common.h +++ b/modules/audio_processing/agc2/agc2_common.h @@ -11,74 +11,59 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_ #define MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_ -#include - namespace webrtc { -constexpr float kMinFloatS16Value = -32768.f; -constexpr float kMaxFloatS16Value = 32767.f; +constexpr float kMinFloatS16Value = -32768.0f; +constexpr float kMaxFloatS16Value = 32767.0f; constexpr float kMaxAbsFloatS16Value = 32768.0f; -constexpr size_t kFrameDurationMs = 10; -constexpr size_t kSubFramesInFrame = 20; -constexpr size_t kMaximalNumberOfSamplesPerChannel = 480; +// Minimum audio level in dBFS scale for S16 samples. +constexpr float kMinLevelDbfs = -90.31f; -constexpr float kAttackFilterConstant = 0.f; +constexpr int kFrameDurationMs = 10; +constexpr int kSubFramesInFrame = 20; +constexpr int kMaximalNumberOfSamplesPerChannel = 480; // Adaptive digital gain applier settings below. -constexpr float kHeadroomDbfs = 1.f; -constexpr float kMaxGainDb = 30.f; -constexpr float kInitialAdaptiveDigitalGainDb = 8.f; +constexpr float kHeadroomDbfs = 1.0f; +constexpr float kMaxGainDb = 30.0f; +constexpr float kInitialAdaptiveDigitalGainDb = 8.0f; // At what limiter levels should we start decreasing the adaptive digital gain. constexpr float kLimiterThresholdForAgcGainDbfs = -kHeadroomDbfs; // This is the threshold for speech. Speech frames are used for updating the // speech level, measuring the amount of speech, and decide when to allow target // gain reduction. -constexpr float kVadConfidenceThreshold = 0.9f; - -// The amount of 'memory' of the Level Estimator. Decides leak factors. -constexpr size_t kFullBufferSizeMs = 1200; -constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs; +constexpr float kVadConfidenceThreshold = 0.95f; -constexpr float kInitialSpeechLevelEstimateDbfs = -30.f; +// Adaptive digital level estimator parameters. +// Number of milliseconds of speech frames to observe to make the estimator +// confident. +constexpr float kLevelEstimatorTimeToConfidenceMs = 400; +constexpr float kLevelEstimatorLeakFactor = + 1.0f - 1.0f / kLevelEstimatorTimeToConfidenceMs; // Robust VAD probability and speech decisions. -constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f; -constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 1; +constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 12; // Saturation Protector settings. -constexpr float kDefaultInitialSaturationMarginDb = 20.f; -constexpr float kDefaultExtraSaturationMarginDb = 2.f; - -constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400; -static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0, - "Full buffer size should be a multiple of super frame length for " - "optimal Saturation Protector performance."); - -constexpr size_t kPeakEnveloperBufferSize = - kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1; - -// This value is 10 ** (-1/20 * frame_size_ms / satproc_attack_ms), -// where satproc_attack_ms is 5000. -constexpr float kSaturationProtectorAttackConstant = 0.9988493699365052f; - -// This value is 10 ** (-1/20 * frame_size_ms / satproc_decay_ms), -// where satproc_decay_ms is 1000. 
-constexpr float kSaturationProtectorDecayConstant = 0.9997697679981565f; +constexpr float kSaturationProtectorInitialHeadroomDb = 20.0f; +constexpr float kSaturationProtectorExtraHeadroomDb = 5.0f; +constexpr int kSaturationProtectorBufferSize = 4; -// This is computed from kDecayMs by -// 10 ** (-1/20 * subframe_duration / kDecayMs). -// |subframe_duration| is |kFrameDurationMs / kSubFramesInFrame|. -// kDecayMs is defined in agc2_testing_common.h -constexpr float kDecayFilterConstant = 0.9998848773724686f; +// Set the initial speech level estimate so that `kInitialAdaptiveDigitalGainDb` +// is applied at the beginning of the call. +constexpr float kInitialSpeechLevelEstimateDbfs = + -kSaturationProtectorExtraHeadroomDb - + kSaturationProtectorInitialHeadroomDb - kInitialAdaptiveDigitalGainDb - + kHeadroomDbfs; // Number of interpolation points for each region of the limiter. // These values have been tuned to limit the interpolated gain curve error given // the limiter parameters and allowing a maximum error of +/- 32768^-1. -constexpr size_t kInterpolatedGainCurveKneePoints = 22; -constexpr size_t kInterpolatedGainCurveBeyondKneePoints = 10; -constexpr size_t kInterpolatedGainCurveTotalPoints = +constexpr int kInterpolatedGainCurveKneePoints = 22; +constexpr int kInterpolatedGainCurveBeyondKneePoints = 10; +constexpr int kInterpolatedGainCurveTotalPoints = kInterpolatedGainCurveKneePoints + kInterpolatedGainCurveBeyondKneePoints; } // namespace webrtc diff --git a/modules/audio_processing/agc2/agc2_testing_common.cc b/modules/audio_processing/agc2/agc2_testing_common.cc index 6c22492e88..125e551b72 100644 --- a/modules/audio_processing/agc2/agc2_testing_common.cc +++ b/modules/audio_processing/agc2/agc2_testing_common.cc @@ -10,24 +10,84 @@ #include "modules/audio_processing/agc2/agc2_testing_common.h" +#include + #include "rtc_base/checks.h" namespace webrtc { - namespace test { -std::vector LinSpace(const double l, - const double r, - size_t num_points) { - RTC_CHECK(num_points >= 2); +std::vector LinSpace(double l, double r, int num_points) { + RTC_CHECK_GE(num_points, 2); std::vector points(num_points); const double step = (r - l) / (num_points - 1.0); points[0] = l; - for (size_t i = 1; i < num_points - 1; i++) { + for (int i = 1; i < num_points - 1; i++) { points[i] = static_cast(l) + i * step; } points[num_points - 1] = r; return points; } + +WhiteNoiseGenerator::WhiteNoiseGenerator(int min_amplitude, int max_amplitude) + : rand_gen_(42), + min_amplitude_(min_amplitude), + max_amplitude_(max_amplitude) { + RTC_DCHECK_LT(min_amplitude_, max_amplitude_); + RTC_DCHECK_LE(kMinS16, min_amplitude_); + RTC_DCHECK_LE(min_amplitude_, kMaxS16); + RTC_DCHECK_LE(kMinS16, max_amplitude_); + RTC_DCHECK_LE(max_amplitude_, kMaxS16); +} + +float WhiteNoiseGenerator::operator()() { + return static_cast(rand_gen_.Rand(min_amplitude_, max_amplitude_)); +} + +SineGenerator::SineGenerator(float amplitude, + float frequency_hz, + int sample_rate_hz) + : amplitude_(amplitude), + frequency_hz_(frequency_hz), + sample_rate_hz_(sample_rate_hz), + x_radians_(0.0f) { + RTC_DCHECK_GT(amplitude_, 0); + RTC_DCHECK_LE(amplitude_, kMaxS16); +} + +float SineGenerator::operator()() { + constexpr float kPi = 3.1415926536f; + x_radians_ += frequency_hz_ / sample_rate_hz_ * 2 * kPi; + if (x_radians_ >= 2 * kPi) { + x_radians_ -= 2 * kPi; + } + return amplitude_ * std::sinf(x_radians_); +} + +PulseGenerator::PulseGenerator(float pulse_amplitude, + float no_pulse_amplitude, + float frequency_hz, + int sample_rate_hz) + 
: pulse_amplitude_(pulse_amplitude), + no_pulse_amplitude_(no_pulse_amplitude), + samples_period_( + static_cast(static_cast(sample_rate_hz) / frequency_hz)), + sample_counter_(0) { + RTC_DCHECK_GE(pulse_amplitude_, kMinS16); + RTC_DCHECK_LE(pulse_amplitude_, kMaxS16); + RTC_DCHECK_GT(no_pulse_amplitude_, kMinS16); + RTC_DCHECK_LE(no_pulse_amplitude_, kMaxS16); + RTC_DCHECK_GT(sample_rate_hz, frequency_hz); +} + +float PulseGenerator::operator()() { + sample_counter_++; + if (sample_counter_ >= samples_period_) { + sample_counter_ -= samples_period_; + } + return static_cast(sample_counter_ == 0 ? pulse_amplitude_ + : no_pulse_amplitude_); +} + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/agc2/agc2_testing_common.h b/modules/audio_processing/agc2/agc2_testing_common.h index 7bfadbb3fc..4572d9cffd 100644 --- a/modules/audio_processing/agc2/agc2_testing_common.h +++ b/modules/audio_processing/agc2/agc2_testing_common.h @@ -11,17 +11,19 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ #define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_ -#include - #include #include -#include "rtc_base/checks.h" +#include "rtc_base/random.h" namespace webrtc { - namespace test { +constexpr float kMinS16 = + static_cast(std::numeric_limits::min()); +constexpr float kMaxS16 = + static_cast(std::numeric_limits::max()); + // Level Estimator test parameters. constexpr float kDecayMs = 500.f; @@ -29,47 +31,49 @@ constexpr float kDecayMs = 500.f; constexpr float kLimiterMaxInputLevelDbFs = 1.f; constexpr float kLimiterKneeSmoothnessDb = 1.f; constexpr float kLimiterCompressionRatio = 5.f; -constexpr float kPi = 3.1415926536f; -std::vector LinSpace(const double l, const double r, size_t num_points); +// Returns evenly spaced `num_points` numbers over a specified interval [l, r]. +std::vector LinSpace(double l, double r, int num_points); + +// Generates white noise. +class WhiteNoiseGenerator { + public: + WhiteNoiseGenerator(int min_amplitude, int max_amplitude); + float operator()(); + + private: + Random rand_gen_; + const int min_amplitude_; + const int max_amplitude_; +}; +// Generates a sine function. class SineGenerator { public: - SineGenerator(float frequency, int rate) - : frequency_(frequency), rate_(rate) {} - float operator()() { - x_radians_ += frequency_ / rate_ * 2 * kPi; - if (x_radians_ > 2 * kPi) { - x_radians_ -= 2 * kPi; - } - return 1000.f * sinf(x_radians_); - } + SineGenerator(float amplitude, float frequency_hz, int sample_rate_hz); + float operator()(); private: - float frequency_; - int rate_; - float x_radians_ = 0.f; + const float amplitude_; + const float frequency_hz_; + const int sample_rate_hz_; + float x_radians_; }; +// Generates periodic pulses. class PulseGenerator { public: - PulseGenerator(float frequency, int rate) - : samples_period_( - static_cast(static_cast(rate) / frequency)) { - RTC_DCHECK_GT(rate, frequency); - } - float operator()() { - sample_counter_++; - if (sample_counter_ >= samples_period_) { - sample_counter_ -= samples_period_; - } - return static_cast( - sample_counter_ == 0 ? 
std::numeric_limits::max() : 10.f); - } + PulseGenerator(float pulse_amplitude, + float no_pulse_amplitude, + float frequency_hz, + int sample_rate_hz); + float operator()(); private: - int samples_period_; - int sample_counter_ = 0; + const float pulse_amplitude_; + const float no_pulse_amplitude_; + const int samples_period_; + int sample_counter_; }; } // namespace test diff --git a/modules/audio_processing/agc2/agc2_testing_common_unittest.cc b/modules/audio_processing/agc2/agc2_testing_common_unittest.cc index f52ea3caf5..79c3cc95d9 100644 --- a/modules/audio_processing/agc2/agc2_testing_common_unittest.cc +++ b/modules/audio_processing/agc2/agc2_testing_common_unittest.cc @@ -14,7 +14,7 @@ namespace webrtc { -TEST(AutomaticGainController2Common, TestLinSpace) { +TEST(GainController2TestingCommon, LinSpace) { std::vector points1 = test::LinSpace(-1.0, 2.0, 4); const std::vector expected_points1{{-1.0, 0.0, 1.0, 2.0}}; EXPECT_EQ(expected_points1, points1); diff --git a/modules/audio_processing/agc2/down_sampler.cc b/modules/audio_processing/agc2/down_sampler.cc index 654ed4be37..fd1a2c3a46 100644 --- a/modules/audio_processing/agc2/down_sampler.cc +++ b/modules/audio_processing/agc2/down_sampler.cc @@ -72,7 +72,7 @@ void DownSampler::Initialize(int sample_rate_hz) { void DownSampler::DownSample(rtc::ArrayView in, rtc::ArrayView out) { - data_dumper_->DumpWav("lc_down_sampler_input", in, sample_rate_hz_, 1); + data_dumper_->DumpWav("agc2_down_sampler_input", in, sample_rate_hz_, 1); RTC_DCHECK_EQ(sample_rate_hz_ * kChunkSizeMs / 1000, in.size()); RTC_DCHECK_EQ(kSampleRate8kHz * kChunkSizeMs / 1000, out.size()); const size_t kMaxNumFrames = kSampleRate48kHz * kChunkSizeMs / 1000; @@ -93,7 +93,7 @@ void DownSampler::DownSample(rtc::ArrayView in, std::copy(in.data(), in.data() + in.size(), out.data()); } - data_dumper_->DumpWav("lc_down_sampler_output", out, kSampleRate8kHz, 1); + data_dumper_->DumpWav("agc2_down_sampler_output", out, kSampleRate8kHz, 1); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/down_sampler.h b/modules/audio_processing/agc2/down_sampler.h index be7cbb3da7..a44f96fa2d 100644 --- a/modules/audio_processing/agc2/down_sampler.h +++ b/modules/audio_processing/agc2/down_sampler.h @@ -31,7 +31,7 @@ class DownSampler { void DownSample(rtc::ArrayView in, rtc::ArrayView out); private: - ApmDataDumper* data_dumper_; + ApmDataDumper* const data_dumper_; int sample_rate_hz_; int down_sampling_factor_; BiQuadFilter low_pass_filter_; diff --git a/modules/audio_processing/agc2/fixed_digital_level_estimator.cc b/modules/audio_processing/agc2/fixed_digital_level_estimator.cc index 971f4f62b7..3e9bb2efbd 100644 --- a/modules/audio_processing/agc2/fixed_digital_level_estimator.cc +++ b/modules/audio_processing/agc2/fixed_digital_level_estimator.cc @@ -22,10 +22,18 @@ namespace { constexpr float kInitialFilterStateLevel = 0.f; +// Instant attack. +constexpr float kAttackFilterConstant = 0.f; +// This is computed from kDecayMs by +// 10 ** (-1/20 * subframe_duration / kDecayMs). +// |subframe_duration| is |kFrameDurationMs / kSubFramesInFrame|. 
+// kDecayMs is defined in agc2_testing_common.h +constexpr float kDecayFilterConstant = 0.9998848773724686f; + } // namespace FixedDigitalLevelEstimator::FixedDigitalLevelEstimator( - size_t sample_rate_hz, + int sample_rate_hz, ApmDataDumper* apm_data_dumper) : apm_data_dumper_(apm_data_dumper), filter_state_level_(kInitialFilterStateLevel) { @@ -52,8 +60,8 @@ std::array FixedDigitalLevelEstimator::ComputeLevel( for (size_t channel_idx = 0; channel_idx < float_frame.num_channels(); ++channel_idx) { const auto channel = float_frame.channel(channel_idx); - for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) { - for (size_t sample_in_sub_frame = 0; + for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) { + for (int sample_in_sub_frame = 0; sample_in_sub_frame < samples_in_sub_frame_; ++sample_in_sub_frame) { envelope[sub_frame] = std::max(envelope[sub_frame], @@ -66,14 +74,14 @@ std::array FixedDigitalLevelEstimator::ComputeLevel( // Make sure envelope increases happen one step earlier so that the // corresponding *gain decrease* doesn't miss a sudden signal // increase due to interpolation. - for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) { + for (int sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) { if (envelope[sub_frame] < envelope[sub_frame + 1]) { envelope[sub_frame] = envelope[sub_frame + 1]; } } // Add attack / decay smoothing. - for (size_t sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) { + for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) { const float envelope_value = envelope[sub_frame]; if (envelope_value > filter_state_level_) { envelope[sub_frame] = envelope_value * (1 - kAttackFilterConstant) + @@ -97,9 +105,9 @@ std::array FixedDigitalLevelEstimator::ComputeLevel( return envelope; } -void FixedDigitalLevelEstimator::SetSampleRate(size_t sample_rate_hz) { - samples_in_frame_ = rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs, - static_cast(1000)); +void FixedDigitalLevelEstimator::SetSampleRate(int sample_rate_hz) { + samples_in_frame_ = + rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs, 1000); samples_in_sub_frame_ = rtc::CheckedDivExact(samples_in_frame_, kSubFramesInFrame); CheckParameterCombination(); diff --git a/modules/audio_processing/agc2/fixed_digital_level_estimator.h b/modules/audio_processing/agc2/fixed_digital_level_estimator.h index aa84a2e0f1..d96aedaf9e 100644 --- a/modules/audio_processing/agc2/fixed_digital_level_estimator.h +++ b/modules/audio_processing/agc2/fixed_digital_level_estimator.h @@ -31,7 +31,7 @@ class FixedDigitalLevelEstimator { // kSubFramesInSample. For kFrameDurationMs=10 and // kSubFramesInSample=20, this means that sample_rate_hz has to be // divisible by 2000. - FixedDigitalLevelEstimator(size_t sample_rate_hz, + FixedDigitalLevelEstimator(int sample_rate_hz, ApmDataDumper* apm_data_dumper); // The input is assumed to be in FloatS16 format. Scaled input will @@ -43,7 +43,7 @@ class FixedDigitalLevelEstimator { // Rate may be changed at any time (but not concurrently) from the // value passed to the constructor. The class is not thread safe. - void SetSampleRate(size_t sample_rate_hz); + void SetSampleRate(int sample_rate_hz); // Resets the level estimator internal state. 
void Reset(); @@ -55,8 +55,8 @@ class FixedDigitalLevelEstimator { ApmDataDumper* const apm_data_dumper_ = nullptr; float filter_state_level_; - size_t samples_in_frame_; - size_t samples_in_sub_frame_; + int samples_in_frame_; + int samples_in_sub_frame_; RTC_DISALLOW_COPY_AND_ASSIGN(FixedDigitalLevelEstimator); }; diff --git a/modules/audio_processing/agc2/fixed_digital_level_estimator_unittest.cc b/modules/audio_processing/agc2/fixed_digital_level_estimator_unittest.cc index 7547f8e2ed..97b421d04c 100644 --- a/modules/audio_processing/agc2/fixed_digital_level_estimator_unittest.cc +++ b/modules/audio_processing/agc2/fixed_digital_level_estimator_unittest.cc @@ -101,25 +101,25 @@ float TimeMsToDecreaseLevel(int sample_rate_hz, } } // namespace -TEST(AutomaticGainController2LevelEstimator, EstimatorShouldNotCrash) { +TEST(GainController2FixedDigitalLevelEstimator, EstimatorShouldNotCrash) { TestLevelEstimator(8000, 1, 0, std::numeric_limits::lowest(), std::numeric_limits::max()); } -TEST(AutomaticGainController2LevelEstimator, +TEST(GainController2FixedDigitalLevelEstimator, EstimatorShouldEstimateConstantLevel) { TestLevelEstimator(10000, 1, kInputLevel, kInputLevel * 0.99, kInputLevel * 1.01); } -TEST(AutomaticGainController2LevelEstimator, +TEST(GainController2FixedDigitalLevelEstimator, EstimatorShouldEstimateConstantLevelForManyChannels) { constexpr size_t num_channels = 10; TestLevelEstimator(20000, num_channels, kInputLevel, kInputLevel * 0.99, kInputLevel * 1.01); } -TEST(AutomaticGainController2LevelEstimator, TimeToDecreaseForLowLevel) { +TEST(GainController2FixedDigitalLevelEstimator, TimeToDecreaseForLowLevel) { constexpr float kLevelReductionDb = 25; constexpr float kInitialLowLevel = -40; constexpr float kExpectedTime = kLevelReductionDb * test::kDecayMs; @@ -131,7 +131,8 @@ TEST(AutomaticGainController2LevelEstimator, TimeToDecreaseForLowLevel) { EXPECT_LE(time_to_decrease, kExpectedTime * 1.1); } -TEST(AutomaticGainController2LevelEstimator, TimeToDecreaseForFullScaleLevel) { +TEST(GainController2FixedDigitalLevelEstimator, + TimeToDecreaseForFullScaleLevel) { constexpr float kLevelReductionDb = 25; constexpr float kExpectedTime = kLevelReductionDb * test::kDecayMs; @@ -142,7 +143,7 @@ TEST(AutomaticGainController2LevelEstimator, TimeToDecreaseForFullScaleLevel) { EXPECT_LE(time_to_decrease, kExpectedTime * 1.1); } -TEST(AutomaticGainController2LevelEstimator, +TEST(GainController2FixedDigitalLevelEstimator, TimeToDecreaseForMultipleChannels) { constexpr float kLevelReductionDb = 25; constexpr float kExpectedTime = kLevelReductionDb * test::kDecayMs; diff --git a/modules/audio_processing/agc2/interpolated_gain_curve.h b/modules/audio_processing/agc2/interpolated_gain_curve.h index 69652c5a72..af993204ce 100644 --- a/modules/audio_processing/agc2/interpolated_gain_curve.h +++ b/modules/audio_processing/agc2/interpolated_gain_curve.h @@ -75,7 +75,7 @@ class InterpolatedGainCurve { private: // For comparing 'approximation_params_*_' with ones computed by // ComputeInterpolatedGainCurve. 
- FRIEND_TEST_ALL_PREFIXES(AutomaticGainController2InterpolatedGainCurve, + FRIEND_TEST_ALL_PREFIXES(GainController2InterpolatedGainCurve, CheckApproximationParams); struct RegionLogger { diff --git a/modules/audio_processing/agc2/interpolated_gain_curve_unittest.cc b/modules/audio_processing/agc2/interpolated_gain_curve_unittest.cc index 67d34e517b..7861ae997d 100644 --- a/modules/audio_processing/agc2/interpolated_gain_curve_unittest.cc +++ b/modules/audio_processing/agc2/interpolated_gain_curve_unittest.cc @@ -34,7 +34,7 @@ const LimiterDbGainCurve limiter; } // namespace -TEST(AutomaticGainController2InterpolatedGainCurve, CreateUse) { +TEST(GainController2InterpolatedGainCurve, CreateUse) { InterpolatedGainCurve igc(&apm_data_dumper, ""); const auto levels = test::LinSpace( @@ -44,7 +44,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, CreateUse) { } } -TEST(AutomaticGainController2InterpolatedGainCurve, CheckValidOutput) { +TEST(GainController2InterpolatedGainCurve, CheckValidOutput) { InterpolatedGainCurve igc(&apm_data_dumper, ""); const auto levels = test::LinSpace( @@ -57,7 +57,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, CheckValidOutput) { } } -TEST(AutomaticGainController2InterpolatedGainCurve, CheckMonotonicity) { +TEST(GainController2InterpolatedGainCurve, CheckMonotonicity) { InterpolatedGainCurve igc(&apm_data_dumper, ""); const auto levels = test::LinSpace( @@ -71,7 +71,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, CheckMonotonicity) { } } -TEST(AutomaticGainController2InterpolatedGainCurve, CheckApproximation) { +TEST(GainController2InterpolatedGainCurve, CheckApproximation) { InterpolatedGainCurve igc(&apm_data_dumper, ""); const auto levels = test::LinSpace( @@ -84,7 +84,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, CheckApproximation) { } } -TEST(AutomaticGainController2InterpolatedGainCurve, CheckRegionBoundaries) { +TEST(GainController2InterpolatedGainCurve, CheckRegionBoundaries) { InterpolatedGainCurve igc(&apm_data_dumper, ""); const std::vector levels{ @@ -102,7 +102,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, CheckRegionBoundaries) { EXPECT_EQ(1ul, stats.look_ups_saturation_region); } -TEST(AutomaticGainController2InterpolatedGainCurve, CheckIdentityRegion) { +TEST(GainController2InterpolatedGainCurve, CheckIdentityRegion) { constexpr size_t kNumSteps = 10; InterpolatedGainCurve igc(&apm_data_dumper, ""); @@ -120,8 +120,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, CheckIdentityRegion) { EXPECT_EQ(0ul, stats.look_ups_saturation_region); } -TEST(AutomaticGainController2InterpolatedGainCurve, - CheckNoOverApproximationKnee) { +TEST(GainController2InterpolatedGainCurve, CheckNoOverApproximationKnee) { constexpr size_t kNumSteps = 10; InterpolatedGainCurve igc(&apm_data_dumper, ""); @@ -142,8 +141,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, EXPECT_EQ(0ul, stats.look_ups_saturation_region); } -TEST(AutomaticGainController2InterpolatedGainCurve, - CheckNoOverApproximationBeyondKnee) { +TEST(GainController2InterpolatedGainCurve, CheckNoOverApproximationBeyondKnee) { constexpr size_t kNumSteps = 10; InterpolatedGainCurve igc(&apm_data_dumper, ""); @@ -164,7 +162,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, EXPECT_EQ(0ul, stats.look_ups_saturation_region); } -TEST(AutomaticGainController2InterpolatedGainCurve, +TEST(GainController2InterpolatedGainCurve, CheckNoOverApproximationWithSaturation) { constexpr size_t kNumSteps = 3; InterpolatedGainCurve igc(&apm_data_dumper, ""); @@ 
-184,7 +182,7 @@ TEST(AutomaticGainController2InterpolatedGainCurve, EXPECT_EQ(kNumSteps, stats.look_ups_saturation_region); } -TEST(AutomaticGainController2InterpolatedGainCurve, CheckApproximationParams) { +TEST(GainController2InterpolatedGainCurve, CheckApproximationParams) { test::InterpolatedParameters parameters = test::ComputeInterpolatedGainCurveApproximationParams(); diff --git a/modules/audio_processing/agc2/limiter.cc b/modules/audio_processing/agc2/limiter.cc index 11473326e1..ed7d3ee5f2 100644 --- a/modules/audio_processing/agc2/limiter.cc +++ b/modules/audio_processing/agc2/limiter.cc @@ -125,9 +125,11 @@ void Limiter::Process(AudioFrameView signal) { last_scaling_factor_ = scaling_factors_.back(); // Dump data for debug. - apm_data_dumper_->DumpRaw("agc2_gain_curve_applier_scaling_factors", - samples_per_channel, - per_sample_scaling_factors_.data()); + apm_data_dumper_->DumpRaw("agc2_limiter_last_scaling_factor", + last_scaling_factor_); + apm_data_dumper_->DumpRaw( + "agc2_limiter_region", + static_cast(interp_gain_curve_.get_stats().region)); } InterpolatedGainCurve::Stats Limiter::GetGainCurveStats() const { diff --git a/modules/audio_processing/agc2/noise_level_estimator.cc b/modules/audio_processing/agc2/noise_level_estimator.cc index 2ca5034334..10e8437d3f 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.cc +++ b/modules/audio_processing/agc2/noise_level_estimator.cc @@ -18,19 +18,19 @@ #include "api/array_view.h" #include "common_audio/include/audio_util.h" +#include "modules/audio_processing/agc2/signal_classifier.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" namespace webrtc { - namespace { constexpr int kFramesPerSecond = 100; float FrameEnergy(const AudioFrameView& audio) { - float energy = 0.f; + float energy = 0.0f; for (size_t k = 0; k < audio.num_channels(); ++k) { float channel_energy = - std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.f, + std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.0f, [](float a, float b) -> float { return a + b * b; }); energy = std::max(channel_energy, energy); } @@ -41,74 +41,220 @@ float EnergyToDbfs(float signal_energy, size_t num_samples) { const float rms = std::sqrt(signal_energy / num_samples); return FloatS16ToDbfs(rms); } -} // namespace -NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper) - : signal_classifier_(data_dumper) { - Initialize(48000); -} +class NoiseLevelEstimatorImpl : public NoiseLevelEstimator { + public: + NoiseLevelEstimatorImpl(ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), signal_classifier_(data_dumper) { + // Initially assume that 48 kHz will be used. `Analyze()` will detect the + // used sample rate and call `Initialize()` again if needed. 
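+    // The rate is inferred from the frame size as
+    // `samples_per_channel * kFramesPerSecond` (e.g., 480 samples -> 48 kHz).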
+ Initialize(/*sample_rate_hz=*/48000); + } + NoiseLevelEstimatorImpl(const NoiseLevelEstimatorImpl&) = delete; + NoiseLevelEstimatorImpl& operator=(const NoiseLevelEstimatorImpl&) = delete; + ~NoiseLevelEstimatorImpl() = default; -NoiseLevelEstimator::~NoiseLevelEstimator() {} + float Analyze(const AudioFrameView& frame) override { + data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter", + noise_energy_hold_counter_); + const int sample_rate_hz = + static_cast(frame.samples_per_channel() * kFramesPerSecond); + if (sample_rate_hz != sample_rate_hz_) { + Initialize(sample_rate_hz); + } + const float frame_energy = FrameEnergy(frame); + if (frame_energy <= 0.f) { + RTC_DCHECK_GE(frame_energy, 0.f); + data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); + } -void NoiseLevelEstimator::Initialize(int sample_rate_hz) { - sample_rate_hz_ = sample_rate_hz; - noise_energy_ = 1.f; - first_update_ = true; - min_noise_energy_ = sample_rate_hz * 2.f * 2.f / kFramesPerSecond; - noise_energy_hold_counter_ = 0; - signal_classifier_.Initialize(sample_rate_hz); -} + if (first_update_) { + // Initialize the noise energy to the frame energy. + first_update_ = false; + data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1); + noise_energy_ = std::max(frame_energy, min_noise_energy_); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); + } -float NoiseLevelEstimator::Analyze(const AudioFrameView& frame) { - const int rate = - static_cast(frame.samples_per_channel() * kFramesPerSecond); - if (rate != sample_rate_hz_) { - Initialize(rate); - } - const float frame_energy = FrameEnergy(frame); - if (frame_energy <= 0.f) { - RTC_DCHECK_GE(frame_energy, 0.f); + const SignalClassifier::SignalType signal_type = + signal_classifier_.Analyze(frame.channel(0)); + data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", + static_cast(signal_type)); + + // Update the noise estimate in a minimum statistics-type manner. + if (signal_type == SignalClassifier::SignalType::kStationary) { + if (frame_energy > noise_energy_) { + // Leak the estimate upwards towards the frame energy if no recent + // downward update. + noise_energy_hold_counter_ = + std::max(noise_energy_hold_counter_ - 1, 0); + + if (noise_energy_hold_counter_ == 0) { + constexpr float kMaxNoiseEnergyFactor = 1.01f; + noise_energy_ = + std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy); + } + } else { + // Update smoothly downwards with a limited maximum update magnitude. + constexpr float kMinNoiseEnergyFactor = 0.9f; + constexpr float kNoiseEnergyDeltaFactor = 0.05f; + noise_energy_ = + std::max(noise_energy_ * kMinNoiseEnergyFactor, + noise_energy_ - kNoiseEnergyDeltaFactor * + (noise_energy_ - frame_energy)); + // Prevent an energy increase for the next 10 seconds. + constexpr int kNumFramesToEnergyIncreaseAllowed = 1000; + noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed; + } + } else { + // TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level. + // For a non-stationary signal, leak the estimate downwards in order to + // avoid estimate locking due to incorrect signal classification. + noise_energy_ = noise_energy_ * 0.99f; + } + + // Ensure a minimum of the estimate. + noise_energy_ = std::max(noise_energy_, min_noise_energy_); return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); } - if (first_update_) { - // Initialize the noise energy to the frame energy. 
- first_update_ = false; - return EnergyToDbfs( - noise_energy_ = std::max(frame_energy, min_noise_energy_), - frame.samples_per_channel()); + private: + void Initialize(int sample_rate_hz) { + sample_rate_hz_ = sample_rate_hz; + noise_energy_ = 1.0f; + first_update_ = true; + // Initialize the minimum noise energy to -84 dBFS. + min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond; + noise_energy_hold_counter_ = 0; + signal_classifier_.Initialize(sample_rate_hz); } - const SignalClassifier::SignalType signal_type = - signal_classifier_.Analyze(frame.channel(0)); + ApmDataDumper* const data_dumper_; + int sample_rate_hz_; + float min_noise_energy_; + bool first_update_; + float noise_energy_; + int noise_energy_hold_counter_; + SignalClassifier signal_classifier_; +}; - // Update the noise estimate in a minimum statistics-type manner. - if (signal_type == SignalClassifier::SignalType::kStationary) { - if (frame_energy > noise_energy_) { - // Leak the estimate upwards towards the frame energy if no recent - // downward update. - noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0); +// Updates the noise floor with instant decay and slow attack. This tuning is +// specific for AGC2, so that (i) it can promptly increase the gain if the noise +// floor drops (instant decay) and (ii) in case of music or fast speech, due to +// which the noise floor can be overestimated, the gain reduction is slowed +// down. +float SmoothNoiseFloorEstimate(float current_estimate, float new_estimate) { + constexpr float kAttack = 0.5f; + if (current_estimate < new_estimate) { + // Attack phase. + return kAttack * new_estimate + (1.0f - kAttack) * current_estimate; + } + // Instant attack. + return new_estimate; +} - if (noise_energy_hold_counter_ == 0) { - noise_energy_ = std::min(noise_energy_ * 1.01f, frame_energy); - } +class NoiseFloorEstimator : public NoiseLevelEstimator { + public: + // Update the noise floor every 5 seconds. + static constexpr int kUpdatePeriodNumFrames = 500; + static_assert(kUpdatePeriodNumFrames >= 200, + "A too small value may cause noise level overestimation."); + static_assert(kUpdatePeriodNumFrames <= 1500, + "A too large value may make AGC2 slow at reacting to increased " + "noise levels."); + + NoiseFloorEstimator(ApmDataDumper* data_dumper) : data_dumper_(data_dumper) { + // Initially assume that 48 kHz will be used. `Analyze()` will detect the + // used sample rate and call `Initialize()` again if needed. + Initialize(/*sample_rate_hz=*/48000); + } + NoiseFloorEstimator(const NoiseFloorEstimator&) = delete; + NoiseFloorEstimator& operator=(const NoiseFloorEstimator&) = delete; + ~NoiseFloorEstimator() = default; + + float Analyze(const AudioFrameView& frame) override { + // Detect sample rate changes. + const int sample_rate_hz = + static_cast(frame.samples_per_channel() * kFramesPerSecond); + if (sample_rate_hz != sample_rate_hz_) { + Initialize(sample_rate_hz); + } + + const float frame_energy = FrameEnergy(frame); + if (frame_energy <= min_noise_energy_) { + // Ignore frames when muted or below the minimum measurable energy. 
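SmoothNoiseFloorEstimate() above lets the estimate fall immediately and rise only half-way per update; the branch that returns new_estimate directly is the instant-decay path described in the function comment. A tiny standalone sketch of that behavior:

#include <cstdio>

// Same smoothing rule as SmoothNoiseFloorEstimate() above: rises are averaged
// with weight kAttack, drops are taken immediately.
float SmoothEstimate(float current, float incoming) {
  constexpr float kAttack = 0.5f;
  if (current < incoming) {
    return kAttack * incoming + (1.0f - kAttack) * current;  // Slow attack.
  }
  return incoming;  // Instant decay.
}

int main() {
  float estimate = 100.0f;
  estimate = SmoothEstimate(estimate, 400.0f);  // Rise: 0.5*400 + 0.5*100 = 250.
  std::printf("after rise: %.1f\n", estimate);
  estimate = SmoothEstimate(estimate, 50.0f);   // Drop: taken immediately -> 50.
  std::printf("after drop: %.1f\n", estimate);
  return 0;
}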
+ data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level", + noise_energy_); + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); + } + + if (preliminary_noise_energy_set_) { + preliminary_noise_energy_ = + std::min(preliminary_noise_energy_, frame_energy); + } else { + preliminary_noise_energy_ = frame_energy; + preliminary_noise_energy_set_ = true; + } + data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level", + preliminary_noise_energy_); + + if (counter_ == 0) { + // Full period observed. + first_period_ = false; + // Update the estimated noise floor energy with the preliminary + // estimation. + noise_energy_ = SmoothNoiseFloorEstimate( + /*current_estimate=*/noise_energy_, + /*new_estimate=*/preliminary_noise_energy_); + // Reset for a new observation period. + counter_ = kUpdatePeriodNumFrames; + preliminary_noise_energy_set_ = false; + } else if (first_period_) { + // While analyzing the signal during the initial period, continuously + // update the estimated noise energy, which is monotonic. + noise_energy_ = preliminary_noise_energy_; + counter_--; } else { - // Update smoothly downwards with a limited maximum update magnitude. - noise_energy_ = - std::max(noise_energy_ * 0.9f, - noise_energy_ + 0.05f * (frame_energy - noise_energy_)); - noise_energy_hold_counter_ = 1000; + // During the observation period it's only allowed to lower the energy. + noise_energy_ = std::min(noise_energy_, preliminary_noise_energy_); + counter_--; } - } else { - // For a non-stationary signal, leak the estimate downwards in order to - // avoid estimate locking due to incorrect signal classification. - noise_energy_ = noise_energy_ * 0.99f; + return EnergyToDbfs(noise_energy_, frame.samples_per_channel()); } - // Ensure a minimum of the estimate. - return EnergyToDbfs( - noise_energy_ = std::max(noise_energy_, min_noise_energy_), - frame.samples_per_channel()); + private: + void Initialize(int sample_rate_hz) { + sample_rate_hz_ = sample_rate_hz; + first_period_ = true; + preliminary_noise_energy_set_ = false; + // Initialize the minimum noise energy to -84 dBFS. + min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond; + preliminary_noise_energy_ = min_noise_energy_; + noise_energy_ = min_noise_energy_; + counter_ = kUpdatePeriodNumFrames; + } + + ApmDataDumper* const data_dumper_; + int sample_rate_hz_; + float min_noise_energy_; + bool first_period_; + bool preliminary_noise_energy_set_; + float preliminary_noise_energy_; + float noise_energy_; + int counter_; +}; + +} // namespace + +std::unique_ptr CreateStationaryNoiseEstimator( + ApmDataDumper* data_dumper) { + return std::make_unique(data_dumper); +} + +std::unique_ptr CreateNoiseFloorEstimator( + ApmDataDumper* data_dumper) { + return std::make_unique(data_dumper); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_level_estimator.h b/modules/audio_processing/agc2/noise_level_estimator.h index ca2f9f2e2f..94aecda7fc 100644 --- a/modules/audio_processing/agc2/noise_level_estimator.h +++ b/modules/audio_processing/agc2/noise_level_estimator.h @@ -11,33 +11,30 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ #define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ -#include "modules/audio_processing/agc2/signal_classifier.h" +#include + #include "modules/audio_processing/include/audio_frame_view.h" -#include "rtc_base/constructor_magic.h" namespace webrtc { class ApmDataDumper; +// Noise level estimator interface. 
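The factory functions defined above become the only entry points to this file. A sketch of the intended usage, assuming the test-only VectorFloatFrame helper (used the same way by the unit tests later in this patch) to build 10 ms mono frames:

#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/vector_float_frame.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

namespace webrtc {

// Feeds one second of silent 10 ms mono frames at 48 kHz and returns the last
// noise level estimate in dBFS.
float NoiseFloorOfSilence() {
  ApmDataDumper data_dumper(0);
  auto estimator = CreateNoiseFloorEstimator(&data_dumper);
  VectorFloatFrame frame(/*num_channels=*/1, /*samples_per_channel=*/480, 0.0f);
  float noise_level_dbfs = 0.0f;
  for (int i = 0; i < 100; ++i) {  // 100 frames of 10 ms.
    noise_level_dbfs = estimator->Analyze(frame.float_frame_view());
  }
  return noise_level_dbfs;
}

}  // namespace webrtc

CreateStationaryNoiseEstimator() is used in exactly the same way; only the estimation strategy behind the NoiseLevelEstimator interface differs.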
class NoiseLevelEstimator { public: - NoiseLevelEstimator(ApmDataDumper* data_dumper); - ~NoiseLevelEstimator(); - // Returns the estimated noise level in dBFS. - float Analyze(const AudioFrameView& frame); - - private: - void Initialize(int sample_rate_hz); - - int sample_rate_hz_; - float min_noise_energy_; - bool first_update_; - float noise_energy_; - int noise_energy_hold_counter_; - SignalClassifier signal_classifier_; - - RTC_DISALLOW_COPY_AND_ASSIGN(NoiseLevelEstimator); + virtual ~NoiseLevelEstimator() = default; + // Analyzes a 10 ms `frame`, updates the noise level estimation and returns + // the value for the latter in dBFS. + virtual float Analyze(const AudioFrameView& frame) = 0; }; +// Creates a noise level estimator based on stationarity detection. +std::unique_ptr CreateStationaryNoiseEstimator( + ApmDataDumper* data_dumper); + +// Creates a noise level estimator based on noise floor detection. +std::unique_ptr CreateNoiseFloorEstimator( + ApmDataDumper* data_dumper); + } // namespace webrtc #endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_ diff --git a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc index c4fd33b0a0..51ad1ba00a 100644 --- a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc +++ b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc @@ -11,33 +11,33 @@ #include "modules/audio_processing/agc2/noise_level_estimator.h" #include +#include #include #include +#include "api/function_view.h" #include "modules/audio_processing/agc2/agc2_testing_common.h" #include "modules/audio_processing/agc2/vector_float_frame.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/gunit.h" -#include "rtc_base/random.h" namespace webrtc { namespace { -Random rand_gen(42); -ApmDataDumper data_dumper(0); + constexpr int kNumIterations = 200; constexpr int kFramesPerSecond = 100; // Runs the noise estimator on audio generated by 'sample_generator' // for kNumIterations. Returns the last noise level estimate. -float RunEstimator(std::function sample_generator, int rate) { - NoiseLevelEstimator estimator(&data_dumper); - const size_t samples_per_channel = - rtc::CheckedDivExact(rate, kFramesPerSecond); - VectorFloatFrame signal(1, static_cast(samples_per_channel), 0.f); - +float RunEstimator(rtc::FunctionView sample_generator, + NoiseLevelEstimator& estimator, + int sample_rate_hz) { + const int samples_per_channel = + rtc::CheckedDivExact(sample_rate_hz, kFramesPerSecond); + VectorFloatFrame signal(1, samples_per_channel, 0.0f); for (int i = 0; i < kNumIterations; ++i) { AudioFrameView frame_view = signal.float_frame_view(); - for (size_t j = 0; j < samples_per_channel; ++j) { + for (int j = 0; j < samples_per_channel; ++j) { frame_view.channel(0)[j] = sample_generator(); } estimator.Analyze(frame_view); @@ -45,39 +45,92 @@ float RunEstimator(std::function sample_generator, int rate) { return estimator.Analyze(signal.float_frame_view()); } -float WhiteNoiseGenerator() { - return static_cast(rand_gen.Rand(std::numeric_limits::min(), - std::numeric_limits::max())); -} -} // namespace +class NoiseEstimatorParametrization : public ::testing::TestWithParam { + protected: + int sample_rate_hz() const { return GetParam(); } +}; // White random noise is stationary, but does not trigger the detector // every frame due to the randomness. 
-TEST(AutomaticGainController2NoiseEstimator, RandomNoise) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - const float noise_level = RunEstimator(WhiteNoiseGenerator, rate); - EXPECT_NEAR(noise_level, -5.f, 1.f); - } +TEST_P(NoiseEstimatorParametrization, StationaryNoiseEstimatorWithRandomNoise) { + ApmDataDumper data_dumper(0); + auto estimator = CreateStationaryNoiseEstimator(&data_dumper); + + test::WhiteNoiseGenerator gen(/*min_amplitude=*/test::kMinS16, + /*max_amplitude=*/test::kMaxS16); + const float noise_level_dbfs = + RunEstimator(gen, *estimator, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -5.5f, 1.0f); } // Sine curves are (very) stationary. They trigger the detector all // the time. Except for a few initial frames. -TEST(AutomaticGainController2NoiseEstimator, SineTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::SineGenerator gen(600.f, rate); - const float noise_level = RunEstimator(gen, rate); - EXPECT_NEAR(noise_level, -33.f, 1.f); - } +TEST_P(NoiseEstimatorParametrization, StationaryNoiseEstimatorWithSineTone) { + ApmDataDumper data_dumper(0); + auto estimator = CreateStationaryNoiseEstimator(&data_dumper); + + test::SineGenerator gen(/*amplitude=*/test::kMaxS16, /*frequency_hz=*/600.0f, + sample_rate_hz()); + const float noise_level_dbfs = + RunEstimator(gen, *estimator, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -3.0f, 1.0f); } // Pulses are transient if they are far enough apart. They shouldn't // trigger the noise detector. -TEST(AutomaticGainController2NoiseEstimator, PulseTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::PulseGenerator gen(20.f, rate); - const int noise_level = RunEstimator(gen, rate); - EXPECT_NEAR(noise_level, -79.f, 1.f); - } +TEST_P(NoiseEstimatorParametrization, StationaryNoiseEstimatorWithPulseTone) { + ApmDataDumper data_dumper(0); + auto estimator = CreateStationaryNoiseEstimator(&data_dumper); + + test::PulseGenerator gen(/*pulse_amplitude=*/test::kMaxS16, + /*no_pulse_amplitude=*/10.0f, /*frequency_hz=*/20.0f, + sample_rate_hz()); + const int noise_level_dbfs = RunEstimator(gen, *estimator, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -79.0f, 1.0f); } +// Checks that full scale white noise maps to about -5.5 dBFS. +TEST_P(NoiseEstimatorParametrization, NoiseFloorEstimatorWithRandomNoise) { + ApmDataDumper data_dumper(0); + auto estimator = CreateNoiseFloorEstimator(&data_dumper); + + test::WhiteNoiseGenerator gen(/*min_amplitude=*/test::kMinS16, + /*max_amplitude=*/test::kMaxS16); + const float noise_level_dbfs = + RunEstimator(gen, *estimator, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -5.5f, 0.5f); +} + +// Checks that a full scale sine wave maps to about -3 dBFS. +TEST_P(NoiseEstimatorParametrization, NoiseFloorEstimatorWithSineTone) { + ApmDataDumper data_dumper(0); + auto estimator = CreateNoiseFloorEstimator(&data_dumper); + + test::SineGenerator gen(/*amplitude=*/test::kMaxS16, /*frequency_hz=*/600.0f, + sample_rate_hz()); + const float noise_level_dbfs = + RunEstimator(gen, *estimator, sample_rate_hz()); + EXPECT_NEAR(noise_level_dbfs, -3.0f, 0.1f); +} + +// Check that sufficiently spaced periodic pulses do not raise the estimated +// noise floor, which is determined by the amplitude of the non-pulse samples. 
+TEST_P(NoiseEstimatorParametrization, NoiseFloorEstimatorWithPulseTone) { + ApmDataDumper data_dumper(0); + auto estimator = CreateNoiseFloorEstimator(&data_dumper); + + constexpr float kNoPulseAmplitude = 10.0f; + test::PulseGenerator gen(/*pulse_amplitude=*/test::kMaxS16, kNoPulseAmplitude, + /*frequency_hz=*/20.0f, sample_rate_hz()); + const int noise_level_dbfs = RunEstimator(gen, *estimator, sample_rate_hz()); + const float expected_noise_floor_dbfs = + 20.0f * std::log10f(kNoPulseAmplitude / test::kMaxS16); + EXPECT_NEAR(noise_level_dbfs, expected_noise_floor_dbfs, 0.5f); +} + +INSTANTIATE_TEST_SUITE_P(GainController2NoiseEstimator, + NoiseEstimatorParametrization, + ::testing::Values(8000, 16000, 32000, 48000)); + +} // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/noise_spectrum_estimator.cc b/modules/audio_processing/agc2/noise_spectrum_estimator.cc index 31438b1f49..f283f4e27f 100644 --- a/modules/audio_processing/agc2/noise_spectrum_estimator.cc +++ b/modules/audio_processing/agc2/noise_spectrum_estimator.cc @@ -63,8 +63,8 @@ void NoiseSpectrumEstimator::Update(rtc::ArrayView spectrum, v = std::max(v, kMinNoisePower); } - data_dumper_->DumpRaw("lc_noise_spectrum", 65, noise_spectrum_); - data_dumper_->DumpRaw("lc_signal_spectrum", spectrum); + data_dumper_->DumpRaw("agc2_noise_spectrum", 65, noise_spectrum_); + data_dumper_->DumpRaw("agc2_signal_spectrum", spectrum); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 4732efd082..bc848b3e13 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -312,20 +312,22 @@ if (rtc_include_tests) { } } - rtc_executable("rnn_vad_tool") { - testonly = true - sources = [ "rnn_vad_tool.cc" ] - deps = [ - ":rnn_vad", - ":rnn_vad_common", - "..:cpu_features", - "../../../../api:array_view", - "../../../../common_audio", - "../../../../rtc_base:rtc_base_approved", - "../../../../rtc_base:safe_compare", - "../../../../test:test_support", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] + if (!build_with_chromium) { + rtc_executable("rnn_vad_tool") { + testonly = true + sources = [ "rnn_vad_tool.cc" ] + deps = [ + ":rnn_vad", + ":rnn_vad_common", + "..:cpu_features", + "../../../../api:array_view", + "../../../../common_audio", + "../../../../rtc_base:rtc_base_approved", + "../../../../rtc_base:safe_compare", + "../../../../test:test_support", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } } } diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_fc.cc b/modules/audio_processing/agc2/rnn_vad/rnn_fc.cc index b04807f19f..ecbb198c96 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_fc.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_fc.cc @@ -56,10 +56,8 @@ rtc::FunctionView GetActivationFunction( switch (activation_function) { case ActivationFunction::kTansigApproximated: return ::rnnoise::TansigApproximated; - break; case ActivationFunction::kSigmoidApproximated: return ::rnnoise::SigmoidApproximated; - break; } } diff --git a/modules/audio_processing/agc2/saturation_protector.cc b/modules/audio_processing/agc2/saturation_protector.cc index b64fcdb71f..d6f21ef891 100644 --- a/modules/audio_processing/agc2/saturation_protector.cc +++ b/modules/audio_processing/agc2/saturation_protector.cc @@ -10,84 +10,59 @@ #include 
"modules/audio_processing/agc2/saturation_protector.h" +#include + +#include "modules/audio_processing/agc2/agc2_common.h" +#include "modules/audio_processing/agc2/saturation_protector_buffer.h" #include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" #include "rtc_base/numerics/safe_minmax.h" namespace webrtc { namespace { -constexpr float kMinLevelDbfs = -90.f; - -// Min/max margins are based on speech crest-factor. -constexpr float kMinMarginDb = 12.f; -constexpr float kMaxMarginDb = 25.f; - -using saturation_protector_impl::RingBuffer; - -} // namespace - -bool RingBuffer::operator==(const RingBuffer& b) const { - RTC_DCHECK_LE(size_, buffer_.size()); - RTC_DCHECK_LE(b.size_, b.buffer_.size()); - if (size_ != b.size_) { - return false; - } - for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_; - ++i, ++i0, ++i1) { - if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) { - return false; - } - } - return true; -} - -void RingBuffer::Reset() { - next_ = 0; - size_ = 0; -} - -void RingBuffer::PushBack(float v) { - RTC_DCHECK_GE(next_, 0); - RTC_DCHECK_GE(size_, 0); - RTC_DCHECK_LT(next_, buffer_.size()); - RTC_DCHECK_LE(size_, buffer_.size()); - buffer_[next_++] = v; - if (rtc::SafeEq(next_, buffer_.size())) { - next_ = 0; +constexpr int kPeakEnveloperSuperFrameLengthMs = 400; +constexpr float kMinMarginDb = 12.0f; +constexpr float kMaxMarginDb = 25.0f; +constexpr float kAttack = 0.9988493699365052f; +constexpr float kDecay = 0.9997697679981565f; + +// Saturation protector state. Defined outside of `SaturationProtectorImpl` to +// implement check-point and restore ops. +struct SaturationProtectorState { + bool operator==(const SaturationProtectorState& s) const { + return headroom_db == s.headroom_db && + peak_delay_buffer == s.peak_delay_buffer && + max_peaks_dbfs == s.max_peaks_dbfs && + time_since_push_ms == s.time_since_push_ms; } - if (rtc::SafeLt(size_, buffer_.size())) { - size_++; + inline bool operator!=(const SaturationProtectorState& s) const { + return !(*this == s); } -} -absl::optional RingBuffer::Front() const { - if (size_ == 0) { - return absl::nullopt; - } - RTC_DCHECK_LT(FrontIndex(), buffer_.size()); - return buffer_[FrontIndex()]; -} + float headroom_db; + SaturationProtectorBuffer peak_delay_buffer; + float max_peaks_dbfs; + int time_since_push_ms; // Time since the last ring buffer push operation. +}; -bool SaturationProtectorState::operator==( - const SaturationProtectorState& b) const { - return margin_db == b.margin_db && peak_delay_buffer == b.peak_delay_buffer && - max_peaks_dbfs == b.max_peaks_dbfs && - time_since_push_ms == b.time_since_push_ms; -} - -void ResetSaturationProtectorState(float initial_margin_db, +// Resets the saturation protector state. +void ResetSaturationProtectorState(float initial_headroom_db, SaturationProtectorState& state) { - state.margin_db = initial_margin_db; + state.headroom_db = initial_headroom_db; state.peak_delay_buffer.Reset(); state.max_peaks_dbfs = kMinLevelDbfs; state.time_since_push_ms = 0; } -void UpdateSaturationProtectorState(float speech_peak_dbfs, +// Updates `state` by analyzing the estimated speech level `speech_level_dbfs` +// and the peak level `peak_dbfs` for an observed frame. `state` must not be +// modified without calling this function. +void UpdateSaturationProtectorState(float peak_dbfs, float speech_level_dbfs, SaturationProtectorState& state) { // Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms. 
- state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, speech_peak_dbfs); + state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, peak_dbfs); state.time_since_push_ms += kFrameDurationMs; if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) { // Push `max_peaks_dbfs` back into the ring buffer. @@ -97,25 +72,117 @@ void UpdateSaturationProtectorState(float speech_peak_dbfs, state.time_since_push_ms = 0; } - // Update margin by comparing the estimated speech level and the delayed max - // speech peak power. - // TODO(alessiob): Check with aleloi@ why we use a delay and how to tune it. + // Update the headroom by comparing the estimated speech level and the delayed + // max speech peak. const float delayed_peak_dbfs = state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs); const float difference_db = delayed_peak_dbfs - speech_level_dbfs; - if (difference_db > state.margin_db) { + if (difference_db > state.headroom_db) { // Attack. - state.margin_db = - state.margin_db * kSaturationProtectorAttackConstant + - difference_db * (1.f - kSaturationProtectorAttackConstant); + state.headroom_db = + state.headroom_db * kAttack + difference_db * (1.0f - kAttack); } else { // Decay. - state.margin_db = state.margin_db * kSaturationProtectorDecayConstant + - difference_db * (1.f - kSaturationProtectorDecayConstant); + state.headroom_db = + state.headroom_db * kDecay + difference_db * (1.0f - kDecay); + } + + state.headroom_db = + rtc::SafeClamp(state.headroom_db, kMinMarginDb, kMaxMarginDb); +} + +// Saturation protector which recommends a headroom based on the recent peaks. +class SaturationProtectorImpl : public SaturationProtector { + public: + explicit SaturationProtectorImpl(float initial_headroom_db, + float extra_headroom_db, + int adjacent_speech_frames_threshold, + ApmDataDumper* apm_data_dumper) + : apm_data_dumper_(apm_data_dumper), + initial_headroom_db_(initial_headroom_db), + extra_headroom_db_(extra_headroom_db), + adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold) { + Reset(); + } + SaturationProtectorImpl(const SaturationProtectorImpl&) = delete; + SaturationProtectorImpl& operator=(const SaturationProtectorImpl&) = delete; + ~SaturationProtectorImpl() = default; + + float HeadroomDb() override { return headroom_db_; } + + void Analyze(float speech_probability, + float peak_dbfs, + float speech_level_dbfs) override { + if (speech_probability < kVadConfidenceThreshold) { + // Not a speech frame. + if (adjacent_speech_frames_threshold_ > 1) { + // When two or more adjacent speech frames are required in order to + // update the state, we need to decide whether to discard or confirm the + // updates based on the speech sequence length. + if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) { + // First non-speech frame after a long enough sequence of speech + // frames. Update the reliable state. + reliable_state_ = preliminary_state_; + } else if (num_adjacent_speech_frames_ > 0) { + // First non-speech frame after a too short sequence of speech frames. + // Reset to the last reliable state. + preliminary_state_ = reliable_state_; + } + } + num_adjacent_speech_frames_ = 0; + } else { + // Speech frame observed. + num_adjacent_speech_frames_++; + + // Update preliminary level estimate. + UpdateSaturationProtectorState(peak_dbfs, speech_level_dbfs, + preliminary_state_); + + if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) { + // `preliminary_state_` is now reliable. Update the headroom. 
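The preliminary/reliable split above is a checkpoint-and-restore pattern: preliminary updates are confirmed only after adjacent_speech_frames_threshold consecutive speech frames and rolled back otherwise. A simplified, generic sketch of that gating, with a single float standing in for the full saturation-protector state:

// Confirm-or-roll-back gate: `observation` updates a preliminary value that is
// only published after `threshold` consecutive speech frames.
class SpeechGatedValue {
 public:
  SpeechGatedValue(int threshold, float initial_value)
      : threshold_(threshold),
        preliminary_(initial_value),
        reliable_(initial_value),
        published_(initial_value) {}

  void Analyze(bool is_speech_frame, float observation) {
    if (!is_speech_frame) {
      if (num_adjacent_speech_frames_ >= threshold_) {
        reliable_ = preliminary_;  // Long enough speech sequence: confirm.
      } else if (num_adjacent_speech_frames_ > 0) {
        preliminary_ = reliable_;  // Too short: roll back.
      }
      num_adjacent_speech_frames_ = 0;
      return;
    }
    ++num_adjacent_speech_frames_;
    preliminary_ = observation;  // Stand-in for the real state update.
    if (num_adjacent_speech_frames_ >= threshold_) {
      published_ = preliminary_;  // Now considered reliable.
    }
  }

  float value() const { return published_; }

 private:
  const int threshold_;
  int num_adjacent_speech_frames_ = 0;
  float preliminary_;
  float reliable_;
  float published_;
};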
+ headroom_db_ = preliminary_state_.headroom_db + extra_headroom_db_; + } + } + DumpDebugData(); } - state.margin_db = - rtc::SafeClamp(state.margin_db, kMinMarginDb, kMaxMarginDb); + void Reset() override { + num_adjacent_speech_frames_ = 0; + headroom_db_ = initial_headroom_db_ + extra_headroom_db_; + ResetSaturationProtectorState(initial_headroom_db_, preliminary_state_); + ResetSaturationProtectorState(initial_headroom_db_, reliable_state_); + } + + private: + void DumpDebugData() { + apm_data_dumper_->DumpRaw( + "agc2_saturation_protector_preliminary_max_peak_dbfs", + preliminary_state_.max_peaks_dbfs); + apm_data_dumper_->DumpRaw( + "agc2_saturation_protector_reliable_max_peak_dbfs", + reliable_state_.max_peaks_dbfs); + } + + ApmDataDumper* const apm_data_dumper_; + const float initial_headroom_db_; + const float extra_headroom_db_; + const int adjacent_speech_frames_threshold_; + int num_adjacent_speech_frames_; + float headroom_db_; + SaturationProtectorState preliminary_state_; + SaturationProtectorState reliable_state_; +}; + +} // namespace + +std::unique_ptr CreateSaturationProtector( + float initial_headroom_db, + float extra_headroom_db, + int adjacent_speech_frames_threshold, + ApmDataDumper* apm_data_dumper) { + return std::make_unique( + initial_headroom_db, extra_headroom_db, adjacent_speech_frames_threshold, + apm_data_dumper); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/saturation_protector.h b/modules/audio_processing/agc2/saturation_protector.h index 88be91a79b..0c384f1fa0 100644 --- a/modules/audio_processing/agc2/saturation_protector.h +++ b/modules/audio_processing/agc2/saturation_protector.h @@ -11,71 +11,36 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_ #define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_ -#include - -#include "absl/types/optional.h" -#include "modules/audio_processing/agc2/agc2_common.h" -#include "rtc_base/numerics/safe_compare.h" +#include namespace webrtc { -namespace saturation_protector_impl { +class ApmDataDumper; -// Ring buffer which only supports (i) push back and (ii) read oldest item. -class RingBuffer { +// Saturation protector. Analyzes peak levels and recommends a headroom to +// reduce the chances of clipping. +class SaturationProtector { public: - bool operator==(const RingBuffer& b) const; - inline bool operator!=(const RingBuffer& b) const { return !(*this == b); } - - // Maximum number of values that the buffer can contain. - int Capacity() const { return buffer_.size(); } - // Number of values in the buffer. - int Size() const { return size_; } - - void Reset(); - // Pushes back `v`. If the buffer is full, the oldest value is replaced. - void PushBack(float v); - // Returns the oldest item in the buffer. Returns an empty value if the - // buffer is empty. - absl::optional Front() const; + virtual ~SaturationProtector() = default; - private: - inline int FrontIndex() const { - return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0; - } - // `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is - // the position where the next new value is written in `buffer_`. - std::array buffer_; - int next_ = 0; - int size_ = 0; -}; - -} // namespace saturation_protector_impl + // Returns the recommended headroom in dB. + virtual float HeadroomDb() = 0; -// Saturation protector state. Exposed publicly for check-pointing and restore -// ops. 
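A usage sketch of the factory defined above: call Analyze() once per 10 ms frame with the VAD probability, the frame peak and the current speech level estimate, then read HeadroomDb(). The constructor arguments below mirror the values used by the unit tests in this patch and are illustrative, not tuned defaults:

#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

namespace webrtc {

// Feeds one second of steady speech frames and reads the recommended headroom.
float RecommendedHeadroomAfterOneSecond() {
  ApmDataDumper data_dumper(0);
  auto protector = CreateSaturationProtector(
      /*initial_headroom_db=*/20.0f, /*extra_headroom_db=*/0.0f,
      /*adjacent_speech_frames_threshold=*/1, &data_dumper);
  for (int i = 0; i < 100; ++i) {  // 100 frames of 10 ms.
    protector->Analyze(/*speech_probability=*/1.0f,
                       /*peak_dbfs=*/-10.0f,
                       /*speech_level_dbfs=*/-28.0f);
  }
  return protector->HeadroomDb();
}

}  // namespace webrtc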
-struct SaturationProtectorState { - bool operator==(const SaturationProtectorState& s) const; - inline bool operator!=(const SaturationProtectorState& s) const { - return !(*this == s); - } + // Analyzes the peak level of a 10 ms frame along with its speech probability + // and the current speech level estimate to update the recommended headroom. + virtual void Analyze(float speech_probability, + float peak_dbfs, + float speech_level_dbfs) = 0; - float margin_db; // Recommended margin. - saturation_protector_impl::RingBuffer peak_delay_buffer; - float max_peaks_dbfs; - int time_since_push_ms; // Time since the last ring buffer push operation. + // Resets the internal state. + virtual void Reset() = 0; }; -// Resets the saturation protector state. -void ResetSaturationProtectorState(float initial_margin_db, - SaturationProtectorState& state); - -// Updates `state` by analyzing the estimated speech level `speech_level_dbfs` -// and the peak power `speech_peak_dbfs` for an observed frame which is -// reliably classified as "speech". `state` must not be modified without calling -// this function. -void UpdateSaturationProtectorState(float speech_peak_dbfs, - float speech_level_dbfs, - SaturationProtectorState& state); +// Creates a saturation protector that starts at `initial_headroom_db`. +std::unique_ptr CreateSaturationProtector( + float initial_headroom_db, + float extra_headroom_db, + int adjacent_speech_frames_threshold, + ApmDataDumper* apm_data_dumper); } // namespace webrtc diff --git a/modules/audio_processing/agc2/saturation_protector_buffer.cc b/modules/audio_processing/agc2/saturation_protector_buffer.cc new file mode 100644 index 0000000000..41efdad2c8 --- /dev/null +++ b/modules/audio_processing/agc2/saturation_protector_buffer.cc @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/agc2/saturation_protector_buffer.h" + +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_compare.h" + +namespace webrtc { + +SaturationProtectorBuffer::SaturationProtectorBuffer() = default; + +SaturationProtectorBuffer::~SaturationProtectorBuffer() = default; + +bool SaturationProtectorBuffer::operator==( + const SaturationProtectorBuffer& b) const { + RTC_DCHECK_LE(size_, buffer_.size()); + RTC_DCHECK_LE(b.size_, b.buffer_.size()); + if (size_ != b.size_) { + return false; + } + for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_; + ++i, ++i0, ++i1) { + if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) { + return false; + } + } + return true; +} + +int SaturationProtectorBuffer::Capacity() const { + return buffer_.size(); +} + +int SaturationProtectorBuffer::Size() const { + return size_; +} + +void SaturationProtectorBuffer::Reset() { + next_ = 0; + size_ = 0; +} + +void SaturationProtectorBuffer::PushBack(float v) { + RTC_DCHECK_GE(next_, 0); + RTC_DCHECK_GE(size_, 0); + RTC_DCHECK_LT(next_, buffer_.size()); + RTC_DCHECK_LE(size_, buffer_.size()); + buffer_[next_++] = v; + if (rtc::SafeEq(next_, buffer_.size())) { + next_ = 0; + } + if (rtc::SafeLt(size_, buffer_.size())) { + size_++; + } +} + +absl::optional SaturationProtectorBuffer::Front() const { + if (size_ == 0) { + return absl::nullopt; + } + RTC_DCHECK_LT(FrontIndex(), buffer_.size()); + return buffer_[FrontIndex()]; +} + +int SaturationProtectorBuffer::FrontIndex() const { + return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0; +} + +} // namespace webrtc diff --git a/modules/audio_processing/agc2/saturation_protector_buffer.h b/modules/audio_processing/agc2/saturation_protector_buffer.h new file mode 100644 index 0000000000..e17d0998c4 --- /dev/null +++ b/modules/audio_processing/agc2/saturation_protector_buffer.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_ + +#include + +#include "absl/types/optional.h" +#include "modules/audio_processing/agc2/agc2_common.h" + +namespace webrtc { + +// Ring buffer for the saturation protector which only supports (i) push back +// and (ii) read oldest item. +class SaturationProtectorBuffer { + public: + SaturationProtectorBuffer(); + ~SaturationProtectorBuffer(); + + bool operator==(const SaturationProtectorBuffer& b) const; + inline bool operator!=(const SaturationProtectorBuffer& b) const { + return !(*this == b); + } + + // Maximum number of values that the buffer can contain. + int Capacity() const; + + // Number of values in the buffer. + int Size() const; + + void Reset(); + + // Pushes back `v`. If the buffer is full, the oldest value is replaced. + void PushBack(float v); + + // Returns the oldest item in the buffer. Returns an empty value if the + // buffer is empty. 
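Minimal usage of the new buffer: once full, Front() returns the value pushed Capacity() calls earlier, so it behaves as a fixed delay line (the unit tests below verify exactly this):

#include "modules/audio_processing/agc2/saturation_protector_buffer.h"

namespace webrtc {

void SaturationProtectorBufferDelayExample() {
  SaturationProtectorBuffer buffer;
  // Fill the buffer with 0, 1, 2, ... until it is full.
  for (int i = 0; i < buffer.Capacity(); ++i) {
    buffer.PushBack(static_cast<float>(i));
  }
  auto front = buffer.Front();  // Oldest value: contains 0.0f.
  buffer.PushBack(100.0f);      // Replaces the oldest value.
  front = buffer.Front();       // Now contains 1.0f: a fixed delay line.
  (void)front;
}

}  // namespace webrtc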
+ absl::optional Front() const; + + private: + int FrontIndex() const; + // `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is + // the position where the next new value is written in `buffer_`. + std::array buffer_; + int next_ = 0; + int size_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_ diff --git a/modules/audio_processing/agc2/saturation_protector_buffer_unittest.cc b/modules/audio_processing/agc2/saturation_protector_buffer_unittest.cc new file mode 100644 index 0000000000..22187bf027 --- /dev/null +++ b/modules/audio_processing/agc2/saturation_protector_buffer_unittest.cc @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/saturation_protector_buffer.h" + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +using ::testing::Eq; +using ::testing::Optional; + +TEST(GainController2SaturationProtectorBuffer, Init) { + SaturationProtectorBuffer b; + EXPECT_EQ(b.Size(), 0); + EXPECT_FALSE(b.Front().has_value()); +} + +TEST(GainController2SaturationProtectorBuffer, PushBack) { + SaturationProtectorBuffer b; + constexpr float kValue = 123.0f; + b.PushBack(kValue); + EXPECT_EQ(b.Size(), 1); + EXPECT_THAT(b.Front(), Optional(Eq(kValue))); +} + +TEST(GainController2SaturationProtectorBuffer, Reset) { + SaturationProtectorBuffer b; + b.PushBack(123.0f); + b.Reset(); + EXPECT_EQ(b.Size(), 0); + EXPECT_FALSE(b.Front().has_value()); +} + +// Checks that the front value does not change until the ring buffer gets full. +TEST(GainController2SaturationProtectorBuffer, FrontUntilBufferIsFull) { + SaturationProtectorBuffer b; + constexpr float kValue = 123.0f; + b.PushBack(kValue); + for (int i = 1; i < b.Capacity(); ++i) { + SCOPED_TRACE(i); + EXPECT_THAT(b.Front(), Optional(Eq(kValue))); + b.PushBack(kValue + i); + } +} + +// Checks that when the buffer is full it behaves as a shift register. +TEST(GainController2SaturationProtectorBuffer, FrontIsDelayed) { + SaturationProtectorBuffer b; + // Fill the buffer. + for (int i = 0; i < b.Capacity(); ++i) { + b.PushBack(i); + } + // The ring buffer should now behave as a shift register with a delay equal to + // its capacity. 
+ for (int i = b.Capacity(); i < 2 * b.Capacity() + 1; ++i) { + SCOPED_TRACE(i); + EXPECT_THAT(b.Front(), Optional(Eq(i - b.Capacity()))); + b.PushBack(i); + } +} + +} // namespace +} // namespace webrtc diff --git a/modules/audio_processing/agc2/saturation_protector_unittest.cc b/modules/audio_processing/agc2/saturation_protector_unittest.cc index 2c5ee5b036..dc16dc254c 100644 --- a/modules/audio_processing/agc2/saturation_protector_unittest.cc +++ b/modules/audio_processing/agc2/saturation_protector_unittest.cc @@ -10,181 +10,166 @@ #include "modules/audio_processing/agc2/saturation_protector.h" -#include - #include "modules/audio_processing/agc2/agc2_common.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/gunit.h" -#include "test/gmock.h" namespace webrtc { namespace { -constexpr float kInitialMarginDb = 20.f; - -using saturation_protector_impl::RingBuffer; - -SaturationProtectorState CreateSaturationProtectorState() { - SaturationProtectorState state; - ResetSaturationProtectorState(kInitialMarginDb, state); - return state; -} +constexpr float kInitialHeadroomDb = 20.0f; +constexpr float kNoExtraHeadroomDb = 0.0f; +constexpr int kNoAdjacentSpeechFramesRequired = 1; +constexpr float kMaxSpeechProbability = 1.0f; -// Updates `state` for `num_iterations` times with constant speech level and -// peak powers and returns the maximum margin. +// Calls `Analyze(speech_probability, peak_dbfs, speech_level_dbfs)` +// `num_iterations` times on `saturation_protector` and return the largest +// headroom difference between two consecutive calls. float RunOnConstantLevel(int num_iterations, - float speech_peak_dbfs, + float speech_probability, + float peak_dbfs, float speech_level_dbfs, - SaturationProtectorState& state) { - float last_margin = state.margin_db; - float max_difference = 0.f; + SaturationProtector& saturation_protector) { + float last_headroom = saturation_protector.HeadroomDb(); + float max_difference = 0.0f; for (int i = 0; i < num_iterations; ++i) { - UpdateSaturationProtectorState(speech_peak_dbfs, speech_level_dbfs, state); - const float new_margin = state.margin_db; + saturation_protector.Analyze(speech_probability, peak_dbfs, + speech_level_dbfs); + const float new_headroom = saturation_protector.HeadroomDb(); max_difference = - std::max(max_difference, std::abs(new_margin - last_margin)); - last_margin = new_margin; + std::max(max_difference, std::fabs(new_headroom - last_headroom)); + last_headroom = new_headroom; } return max_difference; } -} // namespace - -TEST(AutomaticGainController2SaturationProtector, RingBufferInit) { - RingBuffer b; - EXPECT_EQ(b.Size(), 0); - EXPECT_FALSE(b.Front().has_value()); -} - -TEST(AutomaticGainController2SaturationProtector, RingBufferPushBack) { - RingBuffer b; - constexpr float kValue = 123.f; - b.PushBack(kValue); - EXPECT_EQ(b.Size(), 1); - ASSERT_TRUE(b.Front().has_value()); - EXPECT_EQ(b.Front().value(), kValue); +// Checks that the returned headroom value is correctly reset. +TEST(GainController2SaturationProtector, Reset) { + ApmDataDumper apm_data_dumper(0); + auto saturation_protector = CreateSaturationProtector( + kInitialHeadroomDb, kNoExtraHeadroomDb, kNoAdjacentSpeechFramesRequired, + &apm_data_dumper); + const float initial_headroom_db = saturation_protector->HeadroomDb(); + RunOnConstantLevel(/*num_iterations=*/10, kMaxSpeechProbability, + /*peak_dbfs=*/0.0f, + /*speech_level_dbfs=*/-10.0f, *saturation_protector); + // Make sure that there are side-effects. 
+ ASSERT_NE(initial_headroom_db, saturation_protector->HeadroomDb()); + saturation_protector->Reset(); + EXPECT_EQ(initial_headroom_db, saturation_protector->HeadroomDb()); } -TEST(AutomaticGainController2SaturationProtector, RingBufferReset) { - RingBuffer b; - b.PushBack(123.f); - b.Reset(); - EXPECT_EQ(b.Size(), 0); - EXPECT_FALSE(b.Front().has_value()); +// Checks that the estimate converges to the ratio between peaks and level +// estimator values after a while. +TEST(GainController2SaturationProtector, EstimatesCrestRatio) { + constexpr int kNumIterations = 2000; + constexpr float kPeakLevelDbfs = -20.0f; + constexpr float kCrestFactorDb = kInitialHeadroomDb + 1.0f; + constexpr float kSpeechLevelDbfs = kPeakLevelDbfs - kCrestFactorDb; + const float kMaxDifferenceDb = + 0.5f * std::fabs(kInitialHeadroomDb - kCrestFactorDb); + + ApmDataDumper apm_data_dumper(0); + auto saturation_protector = CreateSaturationProtector( + kInitialHeadroomDb, kNoExtraHeadroomDb, kNoAdjacentSpeechFramesRequired, + &apm_data_dumper); + RunOnConstantLevel(kNumIterations, kMaxSpeechProbability, kPeakLevelDbfs, + kSpeechLevelDbfs, *saturation_protector); + EXPECT_NEAR(saturation_protector->HeadroomDb(), kCrestFactorDb, + kMaxDifferenceDb); } -// Checks that the front value does not change until the ring buffer gets full. -TEST(AutomaticGainController2SaturationProtector, - RingBufferFrontUntilBufferIsFull) { - RingBuffer b; - constexpr float kValue = 123.f; - b.PushBack(kValue); - for (int i = 1; i < b.Capacity(); ++i) { - EXPECT_EQ(b.Front().value(), kValue); - b.PushBack(kValue + i); +// Checks that the extra headroom is applied. +TEST(GainController2SaturationProtector, ExtraHeadroomApplied) { + constexpr float kExtraHeadroomDb = 5.1234f; + constexpr int kNumIterations = 10; + constexpr float kPeakLevelDbfs = -20.0f; + constexpr float kSpeechLevelDbfs = kPeakLevelDbfs - 15.0f; + + ApmDataDumper apm_data_dumper(0); + + auto saturation_protector_no_extra = CreateSaturationProtector( + kInitialHeadroomDb, kNoExtraHeadroomDb, kNoAdjacentSpeechFramesRequired, + &apm_data_dumper); + for (int i = 0; i < kNumIterations; ++i) { + saturation_protector_no_extra->Analyze(kMaxSpeechProbability, + kPeakLevelDbfs, kSpeechLevelDbfs); } -} -// Checks that when the buffer is full it behaves as a shift register. -TEST(AutomaticGainController2SaturationProtector, - FullRingBufferFrontIsDelayed) { - RingBuffer b; - // Fill the buffer. - for (int i = 0; i < b.Capacity(); ++i) { - b.PushBack(i); - } - // The ring buffer should now behave as a shift register with a delay equal to - // its capacity. - for (int i = b.Capacity(); i < 2 * b.Capacity() + 1; ++i) { - EXPECT_EQ(b.Front().value(), i - b.Capacity()); - b.PushBack(i); + auto saturation_protector_extra = CreateSaturationProtector( + kInitialHeadroomDb, kExtraHeadroomDb, kNoAdjacentSpeechFramesRequired, + &apm_data_dumper); + for (int i = 0; i < kNumIterations; ++i) { + saturation_protector_extra->Analyze(kMaxSpeechProbability, kPeakLevelDbfs, + kSpeechLevelDbfs); } -} -// Checks that a state after reset equals a state after construction. -TEST(AutomaticGainController2SaturationProtector, ResetState) { - SaturationProtectorState init_state; - ResetSaturationProtectorState(kInitialMarginDb, init_state); - - SaturationProtectorState state; - ResetSaturationProtectorState(kInitialMarginDb, state); - RunOnConstantLevel(/*num_iterations=*/10, /*speech_level_dbfs=*/-20.f, - /*speech_peak_dbfs=*/-10.f, state); - ASSERT_NE(init_state, state); // Make sure that there are side-effects. 
- ResetSaturationProtectorState(kInitialMarginDb, state); - - EXPECT_EQ(init_state, state); -} - -// Checks that the estimate converges to the ratio between peaks and level -// estimator values after a while. -TEST(AutomaticGainController2SaturationProtector, - ProtectorEstimatesCrestRatio) { - constexpr int kNumIterations = 2000; - constexpr float kPeakLevel = -20.f; - constexpr float kCrestFactor = kInitialMarginDb + 1.f; - constexpr float kSpeechLevel = kPeakLevel - kCrestFactor; - const float kMaxDifference = 0.5f * std::abs(kInitialMarginDb - kCrestFactor); - - auto state = CreateSaturationProtectorState(); - RunOnConstantLevel(kNumIterations, kPeakLevel, kSpeechLevel, state); - - EXPECT_NEAR(state.margin_db, kCrestFactor, kMaxDifference); + EXPECT_EQ(saturation_protector_no_extra->HeadroomDb() + kExtraHeadroomDb, + saturation_protector_extra->HeadroomDb()); } -// Checks that the margin does not change too quickly. -TEST(AutomaticGainController2SaturationProtector, ChangeSlowly) { +// Checks that the headroom does not change too quickly. +TEST(GainController2SaturationProtector, ChangeSlowly) { constexpr int kNumIterations = 1000; - constexpr float kPeakLevel = -20.f; - constexpr float kCrestFactor = kInitialMarginDb - 5.f; - constexpr float kOtherCrestFactor = kInitialMarginDb; - constexpr float kSpeechLevel = kPeakLevel - kCrestFactor; - constexpr float kOtherSpeechLevel = kPeakLevel - kOtherCrestFactor; - - auto state = CreateSaturationProtectorState(); - float max_difference = - RunOnConstantLevel(kNumIterations, kPeakLevel, kSpeechLevel, state); - max_difference = std::max( - RunOnConstantLevel(kNumIterations, kPeakLevel, kOtherSpeechLevel, state), - max_difference); - + constexpr float kPeakLevelDbfs = -20.f; + constexpr float kCrestFactorDb = kInitialHeadroomDb - 5.f; + constexpr float kOtherCrestFactorDb = kInitialHeadroomDb; + constexpr float kSpeechLevelDbfs = kPeakLevelDbfs - kCrestFactorDb; + constexpr float kOtherSpeechLevelDbfs = kPeakLevelDbfs - kOtherCrestFactorDb; + + ApmDataDumper apm_data_dumper(0); + auto saturation_protector = CreateSaturationProtector( + kInitialHeadroomDb, kNoExtraHeadroomDb, kNoAdjacentSpeechFramesRequired, + &apm_data_dumper); + float max_difference_db = + RunOnConstantLevel(kNumIterations, kMaxSpeechProbability, kPeakLevelDbfs, + kSpeechLevelDbfs, *saturation_protector); + max_difference_db = std::max( + RunOnConstantLevel(kNumIterations, kMaxSpeechProbability, kPeakLevelDbfs, + kOtherSpeechLevelDbfs, *saturation_protector), + max_difference_db); constexpr float kMaxChangeSpeedDbPerSecond = 0.5f; // 1 db / 2 seconds. - EXPECT_LE(max_difference, + EXPECT_LE(max_difference_db, kMaxChangeSpeedDbPerSecond / 1000 * kFrameDurationMs); } -// Checks that there is a delay between input change and margin adaptations. -TEST(AutomaticGainController2SaturationProtector, AdaptToDelayedChanges) { - constexpr int kDelayIterations = kFullBufferSizeMs / kFrameDurationMs; - constexpr float kInitialSpeechLevelDbfs = -30.f; - constexpr float kLaterSpeechLevelDbfs = -15.f; - - auto state = CreateSaturationProtectorState(); - // First run on initial level. - float max_difference = RunOnConstantLevel( - kDelayIterations, kInitialSpeechLevelDbfs + kInitialMarginDb, - kInitialSpeechLevelDbfs, state); - // Then peak changes, but not RMS. - max_difference = - std::max(RunOnConstantLevel(kDelayIterations, - kLaterSpeechLevelDbfs + kInitialMarginDb, - kInitialSpeechLevelDbfs, state), - max_difference); - // Then both change. 
- max_difference = - std::max(RunOnConstantLevel(kDelayIterations, - kLaterSpeechLevelDbfs + kInitialMarginDb, - kLaterSpeechLevelDbfs, state), - max_difference); - - // The saturation protector expects that the RMS changes roughly - // 'kFullBufferSizeMs' after peaks change. This is to account for delay - // introduced by the level estimator. Therefore, the input above is 'normal' - // and 'expected', and shouldn't influence the margin by much. - const float total_difference = std::abs(state.margin_db - kInitialMarginDb); - - EXPECT_LE(total_difference, 0.05f); - EXPECT_LE(max_difference, 0.01f); +class SaturationProtectorParametrization + : public ::testing::TestWithParam { + protected: + int adjacent_speech_frames_threshold() const { return GetParam(); } +}; + +TEST_P(SaturationProtectorParametrization, DoNotAdaptToShortSpeechSegments) { + ApmDataDumper apm_data_dumper(0); + auto saturation_protector = CreateSaturationProtector( + kInitialHeadroomDb, kNoExtraHeadroomDb, + adjacent_speech_frames_threshold(), &apm_data_dumper); + const float initial_headroom_db = saturation_protector->HeadroomDb(); + RunOnConstantLevel(/*num_iterations=*/adjacent_speech_frames_threshold() - 1, + kMaxSpeechProbability, + /*peak_dbfs=*/0.0f, + /*speech_level_dbfs=*/-10.0f, *saturation_protector); + // No adaptation expected. + EXPECT_EQ(initial_headroom_db, saturation_protector->HeadroomDb()); } +TEST_P(SaturationProtectorParametrization, AdaptToEnoughSpeechSegments) { + ApmDataDumper apm_data_dumper(0); + auto saturation_protector = CreateSaturationProtector( + kInitialHeadroomDb, kNoExtraHeadroomDb, + adjacent_speech_frames_threshold(), &apm_data_dumper); + const float initial_headroom_db = saturation_protector->HeadroomDb(); + RunOnConstantLevel(/*num_iterations=*/adjacent_speech_frames_threshold() + 1, + kMaxSpeechProbability, + /*peak_dbfs=*/0.0f, + /*speech_level_dbfs=*/-10.0f, *saturation_protector); + // Adaptation expected. 
+ EXPECT_NE(initial_headroom_db, saturation_protector->HeadroomDb()); +} + +INSTANTIATE_TEST_SUITE_P(GainController2, + SaturationProtectorParametrization, + ::testing::Values(2, 9, 17)); + +} // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/signal_classifier.cc b/modules/audio_processing/agc2/signal_classifier.cc index a06413d166..3ef8dd775b 100644 --- a/modules/audio_processing/agc2/signal_classifier.cc +++ b/modules/audio_processing/agc2/signal_classifier.cc @@ -84,8 +84,8 @@ webrtc::SignalClassifier::SignalType ClassifySignal( } } - data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands); - data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1, + data_dumper->DumpRaw("agc2_num_stationary_bands", 1, &num_stationary_bands); + data_dumper->DumpRaw("agc2_num_highly_nonstationary_bands", 1, &num_highly_nonstationary_bands); // Use the detected number of bands to classify the overall signal diff --git a/modules/audio_processing/agc2/signal_classifier_unittest.cc b/modules/audio_processing/agc2/signal_classifier_unittest.cc index 62171b32e6..f1a3a664f0 100644 --- a/modules/audio_processing/agc2/signal_classifier_unittest.cc +++ b/modules/audio_processing/agc2/signal_classifier_unittest.cc @@ -14,25 +14,25 @@ #include #include +#include "api/function_view.h" #include "modules/audio_processing/agc2/agc2_testing_common.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/gunit.h" #include "rtc_base/random.h" namespace webrtc { - namespace { -Random rand_gen(42); -ApmDataDumper data_dumper(0); constexpr int kNumIterations = 100; // Runs the signal classifier on audio generated by 'sample_generator' // for kNumIterations. Returns the number of frames classified as noise. -int RunClassifier(std::function sample_generator, int rate) { +float RunClassifier(rtc::FunctionView sample_generator, + int sample_rate_hz) { + ApmDataDumper data_dumper(0); SignalClassifier classifier(&data_dumper); std::array signal; - classifier.Initialize(rate); - const size_t samples_per_channel = rtc::CheckedDivExact(rate, 100); + classifier.Initialize(sample_rate_hz); + const size_t samples_per_channel = rtc::CheckedDivExact(sample_rate_hz, 100); int number_of_noise_frames = 0; for (int i = 0; i < kNumIterations; ++i) { for (size_t j = 0; j < samples_per_channel; ++j) { @@ -45,38 +45,42 @@ int RunClassifier(std::function sample_generator, int rate) { return number_of_noise_frames; } -float WhiteNoiseGenerator() { - return static_cast(rand_gen.Rand(std::numeric_limits::min(), - std::numeric_limits::max())); -} -} // namespace +class SignalClassifierParametrization : public ::testing::TestWithParam { + protected: + int sample_rate_hz() const { return GetParam(); } +}; // White random noise is stationary, but does not trigger the detector // every frame due to the randomness. -TEST(AutomaticGainController2SignalClassifier, WhiteNoise) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - const int number_of_noise_frames = RunClassifier(WhiteNoiseGenerator, rate); - EXPECT_GT(number_of_noise_frames, kNumIterations / 2); - } +TEST_P(SignalClassifierParametrization, WhiteNoise) { + test::WhiteNoiseGenerator gen(/*min_amplitude=*/test::kMinS16, + /*max_amplitude=*/test::kMaxS16); + const int number_of_noise_frames = RunClassifier(gen, sample_rate_hz()); + EXPECT_GT(number_of_noise_frames, kNumIterations / 2); } // Sine curves are (very) stationary. They trigger the detector all // the time. Except for a few initial frames. 
-TEST(AutomaticGainController2SignalClassifier, SineTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::SineGenerator gen(600.f, rate); - const int number_of_noise_frames = RunClassifier(gen, rate); - EXPECT_GE(number_of_noise_frames, kNumIterations - 5); - } +TEST_P(SignalClassifierParametrization, SineTone) { + test::SineGenerator gen(/*amplitude=*/test::kMaxS16, /*frequency_hz=*/600.0f, + sample_rate_hz()); + const int number_of_noise_frames = RunClassifier(gen, sample_rate_hz()); + EXPECT_GE(number_of_noise_frames, kNumIterations - 5); } // Pulses are transient if they are far enough apart. They shouldn't // trigger the noise detector. -TEST(AutomaticGainController2SignalClassifier, PulseTone) { - for (const auto rate : {8000, 16000, 32000, 48000}) { - test::PulseGenerator gen(30.f, rate); - const int number_of_noise_frames = RunClassifier(gen, rate); - EXPECT_EQ(number_of_noise_frames, 0); - } +TEST_P(SignalClassifierParametrization, PulseTone) { + test::PulseGenerator gen(/*pulse_amplitude=*/test::kMaxS16, + /*no_pulse_amplitude=*/10.0f, /*frequency_hz=*/20.0f, + sample_rate_hz()); + const int number_of_noise_frames = RunClassifier(gen, sample_rate_hz()); + EXPECT_EQ(number_of_noise_frames, 0); } + +INSTANTIATE_TEST_SUITE_P(GainController2SignalClassifier, + SignalClassifierParametrization, + ::testing::Values(8000, 16000, 32000, 48000)); + +} // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc index b54ae564da..9747ca2370 100644 --- a/modules/audio_processing/agc2/vad_with_level.cc +++ b/modules/audio_processing/agc2/vad_with_level.cc @@ -38,6 +38,8 @@ class Vad : public VoiceActivityDetector { Vad& operator=(const Vad&) = delete; ~Vad() = default; + void Reset() override { rnn_vad_.Reset(); } + float ComputeProbability(AudioFrameView frame) override { // The source number of channels is 1, because we always use the 1st // channel. @@ -63,53 +65,41 @@ class Vad : public VoiceActivityDetector { rnn_vad::RnnVad rnn_vad_; }; -// Returns an updated version of `p_old` by using instant decay and the given -// `attack` on a new VAD probability value `p_new`. -float SmoothedVadProbability(float p_old, float p_new, float attack) { - RTC_DCHECK_GT(attack, 0.f); - RTC_DCHECK_LE(attack, 1.f); - if (p_new < p_old || attack == 1.f) { - // Instant decay (or no smoothing). - return p_new; - } else { - // Attack phase. 
- return attack * p_new + (1.f - attack) * p_old; - } -} - } // namespace -VadLevelAnalyzer::VadLevelAnalyzer() - : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack, - GetAvailableCpuFeatures()) {} - -VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, +VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms, const AvailableCpuFeatures& cpu_features) - : VadLevelAnalyzer(vad_probability_attack, + : VadLevelAnalyzer(vad_reset_period_ms, std::make_unique(cpu_features)) {} -VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack, +VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms, std::unique_ptr vad) - : vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) { + : vad_(std::move(vad)), + vad_reset_period_frames_( + rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)), + time_to_vad_reset_(vad_reset_period_frames_) { RTC_DCHECK(vad_); + RTC_DCHECK_GT(vad_reset_period_frames_, 1); } VadLevelAnalyzer::~VadLevelAnalyzer() = default; VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame( AudioFrameView frame) { + // Periodically reset the VAD. + time_to_vad_reset_--; + if (time_to_vad_reset_ <= 0) { + vad_->Reset(); + time_to_vad_reset_ = vad_reset_period_frames_; + } // Compute levels. - float peak = 0.f; - float rms = 0.f; + float peak = 0.0f; + float rms = 0.0f; for (const auto& x : frame.channel(0)) { peak = std::max(std::fabs(x), peak); rms += x * x; } - // Compute smoothed speech probability. - vad_probability_ = SmoothedVadProbability( - /*p_old=*/vad_probability_, /*p_new=*/vad_->ComputeProbability(frame), - vad_probability_attack_); - return {vad_probability_, + return {vad_->ComputeProbability(frame), FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())), FloatS16ToDbfs(peak)}; } diff --git a/modules/audio_processing/agc2/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h index 2a6788278e..8d2ae45762 100644 --- a/modules/audio_processing/agc2/vad_with_level.h +++ b/modules/audio_processing/agc2/vad_with_level.h @@ -31,17 +31,21 @@ class VadLevelAnalyzer { class VoiceActivityDetector { public: virtual ~VoiceActivityDetector() = default; + // Resets the internal state. + virtual void Reset() = 0; // Analyzes an audio frame and returns the speech probability. virtual float ComputeProbability(AudioFrameView frame) = 0; }; - // Ctor. Uses the default VAD. - VadLevelAnalyzer(); - VadLevelAnalyzer(float vad_probability_attack, + // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call + // `VadLevelAnalyzer::Reset()`; it must be equal to or greater than the + // duration of two frames. Uses `cpu_features` to instantiate the default VAD. + VadLevelAnalyzer(int vad_reset_period_ms, const AvailableCpuFeatures& cpu_features); // Ctor. Uses a custom `vad`. 
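With this change the analyzer takes a reset period instead of an attack coefficient, and the returned speech probability is no longer smoothed. A usage sketch, assuming GetAvailableCpuFeatures() from agc2's cpu_features helper (the removed default constructor above relied on the same call) and a 48 kHz mono frame:

#include <array>

#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"

namespace webrtc {

// Analyzes one silent 10 ms mono frame at 48 kHz and returns the speech
// probability reported by the default (RNN) VAD.
float SpeechProbabilityOfSilentFrame() {
  VadLevelAnalyzer analyzer(/*vad_reset_period_ms=*/1500,
                            GetAvailableCpuFeatures());
  std::array<float, 480> samples{};  // 10 ms of silence at 48 kHz.
  float* channel0 = samples.data();
  AudioFrameView<float> frame(&channel0, /*num_channels=*/1, samples.size());
  const auto result = analyzer.AnalyzeFrame(frame);
  // result.rms_dbfs and result.peak_dbfs carry the frame levels in dBFS.
  return result.speech_probability;
}

}  // namespace webrtc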
- VadLevelAnalyzer(float vad_probability_attack, + VadLevelAnalyzer(int vad_reset_period_ms, std::unique_ptr vad); + VadLevelAnalyzer(const VadLevelAnalyzer&) = delete; VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete; ~VadLevelAnalyzer(); @@ -51,8 +55,8 @@ class VadLevelAnalyzer { private: std::unique_ptr vad_; - const float vad_probability_attack_; - float vad_probability_ = 0.f; + const int vad_reset_period_frames_; + int time_to_vad_reset_; }; } // namespace webrtc diff --git a/modules/audio_processing/agc2/vad_with_level_unittest.cc b/modules/audio_processing/agc2/vad_with_level_unittest.cc index fb93c86417..ec8e476965 100644 --- a/modules/audio_processing/agc2/vad_with_level_unittest.cc +++ b/modules/audio_processing/agc2/vad_with_level_unittest.cc @@ -10,6 +10,7 @@ #include "modules/audio_processing/agc2/vad_with_level.h" +#include #include #include @@ -25,13 +26,14 @@ namespace { using ::testing::AnyNumber; using ::testing::ReturnRoundRobin; -constexpr float kInstantAttack = 1.f; -constexpr float kSlowAttack = 0.1f; +constexpr int kNoVadPeriodicReset = + kFrameDurationMs * (std::numeric_limits::max() / kFrameDurationMs); constexpr int kSampleRateHz = 8000; class MockVad : public VadLevelAnalyzer::VoiceActivityDetector { public: + MOCK_METHOD(void, Reset, (), (override)); MOCK_METHOD(float, ComputeProbability, (AudioFrameView frame), @@ -42,20 +44,24 @@ class MockVad : public VadLevelAnalyzer::VoiceActivityDetector { // the next value from `speech_probabilities` until it reaches the end and will // restart from the beginning. std::unique_ptr CreateVadLevelAnalyzerWithMockVad( - float vad_probability_attack, - const std::vector& speech_probabilities) { + int vad_reset_period_ms, + const std::vector& speech_probabilities, + int expected_vad_reset_calls = 0) { auto vad = std::make_unique(); EXPECT_CALL(*vad, ComputeProbability) .Times(AnyNumber()) .WillRepeatedly(ReturnRoundRobin(speech_probabilities)); - return std::make_unique(vad_probability_attack, + if (expected_vad_reset_calls >= 0) { + EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls); + } + return std::make_unique(vad_reset_period_ms, std::move(vad)); } // 10 ms mono frame. struct FrameWithView { // Ctor. Initializes the frame samples with `value`. - FrameWithView(float value = 0.f) + FrameWithView(float value = 0.0f) : channel0(samples.data()), view(&channel0, /*num_channels=*/1, samples.size()) { samples.fill(value); @@ -65,27 +71,26 @@ struct FrameWithView { const AudioFrameView view; }; -TEST(AutomaticGainController2VadLevelAnalyzer, PeakLevelGreaterThanRmsLevel) { +TEST(GainController2VadLevelAnalyzer, RmsLessThanPeakLevel) { + auto analyzer = CreateVadLevelAnalyzerWithMockVad( + /*vad_reset_period_ms=*/1500, + /*speech_probabilities=*/{1.0f}, + /*expected_vad_reset_calls=*/0); // Handcrafted frame so that the average is lower than the peak value. - FrameWithView frame(1000.f); // Constant frame. - frame.samples[10] = 2000.f; // Except for one peak value. - - // Compute audio frame levels (the VAD result is ignored). - VadLevelAnalyzer analyzer; - auto levels_and_vad_prob = analyzer.AnalyzeFrame(frame.view); - - // Compare peak and RMS levels. + FrameWithView frame(1000.0f); // Constant frame. + frame.samples[10] = 2000.0f; // Except for one peak value. + // Compute audio frame levels. 
+  auto levels_and_vad_prob = analyzer->AnalyzeFrame(frame.view);
   EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
 }
 
-// Checks that the unprocessed and the smoothed speech probabilities match when
-// instant attack is used.
-TEST(AutomaticGainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
+// Checks that the expected VAD probabilities are returned.
+TEST(GainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
   const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
                                                 0.44f,  0.525f, 0.858f, 0.314f,
-                                                0.653f, 0.965f, 0.413f, 0.f};
-  auto analyzer =
-      CreateVadLevelAnalyzerWithMockVad(kInstantAttack, speech_probabilities);
+                                                0.653f, 0.965f, 0.413f, 0.0f};
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(kNoVadPeriodicReset,
+                                                    speech_probabilities);
   FrameWithView frame;
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
@@ -94,37 +99,41 @@ TEST(AutomaticGainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
   }
 }
 
-// Checks that the smoothed speech probability does not instantly converge to
-// the unprocessed one when slow attack is used.
-TEST(AutomaticGainController2VadLevelAnalyzer,
-     SlowAttackSpeechProbabilitySmoothing) {
-  const std::vector<float> speech_probabilities{0.f, 0.f, 1.f, 1.f, 1.f, 1.f};
-  auto analyzer =
-      CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
+// Checks that the VAD is not periodically reset.
+TEST(GainController2VadLevelAnalyzer, VadNoPeriodicReset) {
+  constexpr int kNumFrames = 19;
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+      kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
+      /*expected_vad_reset_calls=*/0);
   FrameWithView frame;
-  float prev_probability = 0.f;
-  for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
-    SCOPED_TRACE(i);
-    const float smoothed_probability =
-        analyzer->AnalyzeFrame(frame.view).speech_probability;
-    EXPECT_LT(smoothed_probability, 1.f);  // Not enough time to reach 1.
-    EXPECT_LE(prev_probability, smoothed_probability);  // Converge towards 1.
-    prev_probability = smoothed_probability;
+  for (int i = 0; i < kNumFrames; ++i) {
+    analyzer->AnalyzeFrame(frame.view);
   }
 }
 
-// Checks that the smoothed speech probability instantly decays to the
-// unprocessed one when slow attack is used.
-TEST(AutomaticGainController2VadLevelAnalyzer, SpeechProbabilityInstantDecay) {
-  const std::vector<float> speech_probabilities{1.f, 1.f, 1.f, 1.f, 1.f, 0.f};
-  auto analyzer =
-      CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
+class VadPeriodResetParametrization
+    : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+  int num_frames() const { return std::get<0>(GetParam()); }
+  int vad_reset_period_frames() const { return std::get<1>(GetParam()); }
+};
+
+// Checks that the VAD is periodically reset with the expected period.
+TEST_P(VadPeriodResetParametrization, VadPeriodicReset) { + auto analyzer = CreateVadLevelAnalyzerWithMockVad( + /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs, + /*speech_probabilities=*/{1.0f}, + /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames()); FrameWithView frame; - for (int i = 0; rtc::SafeLt(i, speech_probabilities.size() - 1); ++i) { + for (int i = 0; i < num_frames(); ++i) { analyzer->AnalyzeFrame(frame.view); } - EXPECT_EQ(0.f, analyzer->AnalyzeFrame(frame.view).speech_probability); } +INSTANTIATE_TEST_SUITE_P(GainController2VadLevelAnalyzer, + VadPeriodResetParametrization, + ::testing::Combine(::testing::Values(1, 19, 123), + ::testing::Values(2, 5, 20, 53))); + } // namespace } // namespace webrtc diff --git a/modules/audio_processing/audio_processing_impl.cc b/modules/audio_processing/audio_processing_impl.cc index 37112f0888..4a1985545f 100644 --- a/modules/audio_processing/audio_processing_impl.cc +++ b/modules/audio_processing/audio_processing_impl.cc @@ -23,7 +23,6 @@ #include "common_audio/audio_converter.h" #include "common_audio/include/audio_util.h" #include "modules/audio_processing/aec_dump/aec_dump_factory.h" -#include "modules/audio_processing/agc2/gain_applier.h" #include "modules/audio_processing/audio_buffer.h" #include "modules/audio_processing/common.h" #include "modules/audio_processing/include/audio_frame_view.h" @@ -49,8 +48,6 @@ namespace webrtc { -constexpr int kRuntimeSettingQueueSize = 100; - namespace { static bool LayoutHasKeyboard(AudioProcessing::ChannelLayout layout) { @@ -117,6 +114,10 @@ GainControl::Mode Agc1ConfigModeToInterfaceMode( RTC_CHECK_NOTREACHED(); } +bool MinimizeProcessingForUnusedOutput() { + return !field_trial::IsEnabled("WebRTC-MutedStateKillSwitch"); +} + // Maximum lengths that frame of samples being passed from the render side to // the capture side can have (does not apply to AEC3). 
static const size_t kMaxAllowedValuesOfSamplesPerBand = 160; @@ -147,7 +148,7 @@ bool AudioProcessingImpl::SubmoduleStates::Update( bool noise_suppressor_enabled, bool adaptive_gain_controller_enabled, bool gain_controller2_enabled, - bool pre_amplifier_enabled, + bool gain_adjustment_enabled, bool echo_controller_enabled, bool voice_detector_enabled, bool transient_suppressor_enabled) { @@ -161,7 +162,7 @@ bool AudioProcessingImpl::SubmoduleStates::Update( changed |= (adaptive_gain_controller_enabled != adaptive_gain_controller_enabled_); changed |= (gain_controller2_enabled != gain_controller2_enabled_); - changed |= (pre_amplifier_enabled_ != pre_amplifier_enabled); + changed |= (gain_adjustment_enabled != gain_adjustment_enabled_); changed |= (echo_controller_enabled != echo_controller_enabled_); changed |= (voice_detector_enabled != voice_detector_enabled_); changed |= (transient_suppressor_enabled != transient_suppressor_enabled_); @@ -172,7 +173,7 @@ bool AudioProcessingImpl::SubmoduleStates::Update( noise_suppressor_enabled_ = noise_suppressor_enabled; adaptive_gain_controller_enabled_ = adaptive_gain_controller_enabled; gain_controller2_enabled_ = gain_controller2_enabled; - pre_amplifier_enabled_ = pre_amplifier_enabled; + gain_adjustment_enabled_ = gain_adjustment_enabled; echo_controller_enabled_ = echo_controller_enabled; voice_detector_enabled_ = voice_detector_enabled; transient_suppressor_enabled_ = transient_suppressor_enabled; @@ -204,7 +205,7 @@ bool AudioProcessingImpl::SubmoduleStates::CaptureMultiBandProcessingActive( bool AudioProcessingImpl::SubmoduleStates::CaptureFullBandProcessingActive() const { return gain_controller2_enabled_ || capture_post_processor_enabled_ || - pre_amplifier_enabled_; + gain_adjustment_enabled_; } bool AudioProcessingImpl::SubmoduleStates::CaptureAnalyzerActive() const { @@ -253,8 +254,8 @@ AudioProcessingImpl::AudioProcessingImpl( new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), use_setup_specific_default_aec3_config_( UseSetupSpecificDefaultAec3Congfig()), - capture_runtime_settings_(kRuntimeSettingQueueSize), - render_runtime_settings_(kRuntimeSettingQueueSize), + capture_runtime_settings_(RuntimeSettingQueueSize()), + render_runtime_settings_(RuntimeSettingQueueSize()), capture_runtime_settings_enqueuer_(&capture_runtime_settings_), render_runtime_settings_enqueuer_(&render_runtime_settings_), echo_control_factory_(std::move(echo_control_factory)), @@ -269,7 +270,10 @@ AudioProcessingImpl::AudioProcessingImpl( "WebRTC-ApmExperimentalMultiChannelRenderKillSwitch"), !field_trial::IsEnabled( "WebRTC-ApmExperimentalMultiChannelCaptureKillSwitch"), - EnforceSplitBandHpf()), + EnforceSplitBandHpf(), + MinimizeProcessingForUnusedOutput(), + field_trial::IsEnabled("WebRTC-TransientSuppressorForcedOff")), + capture_(), capture_nonlocked_() { RTC_LOG(LS_INFO) << "Injected APM submodules:" "\nEcho control factory: " @@ -287,8 +291,7 @@ AudioProcessingImpl::AudioProcessingImpl( // If no echo detector is injected, use the ResidualEchoDetector. 
if (!submodules_.echo_detector) { - submodules_.echo_detector = - new rtc::RefCountedObject(); + submodules_.echo_detector = rtc::make_ref_counted(); } #if !(defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS)) @@ -304,8 +307,6 @@ AudioProcessingImpl::AudioProcessingImpl( config.Get().startup_min_volume; config_.gain_controller1.analog_gain_controller.clipped_level_min = config.Get().clipped_level_min; - config_.gain_controller1.analog_gain_controller.enable_agc2_level_estimator = - config.Get().enabled_agc2_level_estimator; config_.gain_controller1.analog_gain_controller.enable_digital_adaptive = !config.Get().digital_adaptive_disabled; #endif @@ -426,6 +427,7 @@ void AudioProcessingImpl::InitializeLocked() { InitializeAnalyzer(); InitializePostProcessor(); InitializePreProcessor(); + InitializeCaptureLevelsAdjuster(); if (aec_dump_) { aec_dump_->WriteInitMessage(formats_.api_format, rtc::TimeUTCMillis()); @@ -567,6 +569,9 @@ void AudioProcessingImpl::ApplyConfig(const AudioProcessing::Config& config) { config_.pre_amplifier.fixed_gain_factor != config.pre_amplifier.fixed_gain_factor; + const bool gain_adjustment_config_changed = + config_.capture_level_adjustment != config.capture_level_adjustment; + config_ = config; if (aec_config_changed) { @@ -598,8 +603,8 @@ void AudioProcessingImpl::ApplyConfig(const AudioProcessing::Config& config) { InitializeGainController2(); } - if (pre_amplifier_config_changed) { - InitializePreAmplifier(); + if (pre_amplifier_config_changed || gain_adjustment_config_changed) { + InitializeCaptureLevelsAdjuster(); } if (config_.level_estimation.enabled && !submodules_.output_level_estimator) { @@ -666,35 +671,60 @@ size_t AudioProcessingImpl::num_output_channels() const { void AudioProcessingImpl::set_output_will_be_muted(bool muted) { MutexLock lock(&mutex_capture_); - capture_.output_will_be_muted = muted; + HandleCaptureOutputUsedSetting(!muted); +} + +void AudioProcessingImpl::HandleCaptureOutputUsedSetting( + bool capture_output_used) { + capture_.capture_output_used = + capture_output_used || !constants_.minimize_processing_for_unused_output; + if (submodules_.agc_manager.get()) { - submodules_.agc_manager->SetCaptureMuted(capture_.output_will_be_muted); + submodules_.agc_manager->HandleCaptureOutputUsedChange( + capture_.capture_output_used); + } + if (submodules_.echo_controller) { + submodules_.echo_controller->SetCaptureOutputUsage( + capture_.capture_output_used); + } + if (submodules_.noise_suppressor) { + submodules_.noise_suppressor->SetCaptureOutputUsage( + capture_.capture_output_used); } } void AudioProcessingImpl::SetRuntimeSetting(RuntimeSetting setting) { + PostRuntimeSetting(setting); +} + +bool AudioProcessingImpl::PostRuntimeSetting(RuntimeSetting setting) { switch (setting.type()) { case RuntimeSetting::Type::kCustomRenderProcessingRuntimeSetting: case RuntimeSetting::Type::kPlayoutAudioDeviceChange: - render_runtime_settings_enqueuer_.Enqueue(setting); - return; + return render_runtime_settings_enqueuer_.Enqueue(setting); case RuntimeSetting::Type::kCapturePreGain: + case RuntimeSetting::Type::kCapturePostGain: case RuntimeSetting::Type::kCaptureCompressionGain: case RuntimeSetting::Type::kCaptureFixedPostGain: case RuntimeSetting::Type::kCaptureOutputUsed: - capture_runtime_settings_enqueuer_.Enqueue(setting); - return; - case RuntimeSetting::Type::kPlayoutVolumeChange: - capture_runtime_settings_enqueuer_.Enqueue(setting); - render_runtime_settings_enqueuer_.Enqueue(setting); - return; + return 
capture_runtime_settings_enqueuer_.Enqueue(setting); + case RuntimeSetting::Type::kPlayoutVolumeChange: { + bool enqueueing_successful; + enqueueing_successful = + capture_runtime_settings_enqueuer_.Enqueue(setting); + enqueueing_successful = + render_runtime_settings_enqueuer_.Enqueue(setting) && + enqueueing_successful; + return enqueueing_successful; + } case RuntimeSetting::Type::kNotSpecified: RTC_NOTREACHED(); - return; + return true; } // The language allows the enum to have a non-enumerator // value. Check that this doesn't happen. RTC_NOTREACHED(); + return true; } AudioProcessingImpl::RuntimeSettingEnqueuer::RuntimeSettingEnqueuer( @@ -706,20 +736,15 @@ AudioProcessingImpl::RuntimeSettingEnqueuer::RuntimeSettingEnqueuer( AudioProcessingImpl::RuntimeSettingEnqueuer::~RuntimeSettingEnqueuer() = default; -void AudioProcessingImpl::RuntimeSettingEnqueuer::Enqueue( +bool AudioProcessingImpl::RuntimeSettingEnqueuer::Enqueue( RuntimeSetting setting) { - int remaining_attempts = 10; - while (!runtime_settings_.Insert(&setting) && remaining_attempts-- > 0) { - RuntimeSetting setting_to_discard; - if (runtime_settings_.Remove(&setting_to_discard)) { - RTC_LOG(LS_ERROR) - << "The runtime settings queue is full. Oldest setting discarded."; - } - } - if (remaining_attempts == 0) { + const bool successful_insert = runtime_settings_.Insert(&setting); + + if (!successful_insert) { RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.ApmRuntimeSettingCannotEnqueue", 1); RTC_LOG(LS_ERROR) << "Cannot enqueue a new runtime setting."; } + return successful_insert; } int AudioProcessingImpl::MaybeInitializeCapture( @@ -793,17 +818,48 @@ int AudioProcessingImpl::ProcessStream(const float* const* src, void AudioProcessingImpl::HandleCaptureRuntimeSettings() { RuntimeSetting setting; + int num_settings_processed = 0; while (capture_runtime_settings_.Remove(&setting)) { if (aec_dump_) { aec_dump_->WriteRuntimeSetting(setting); } switch (setting.type()) { case RuntimeSetting::Type::kCapturePreGain: - if (config_.pre_amplifier.enabled) { + if (config_.pre_amplifier.enabled || + config_.capture_level_adjustment.enabled) { + float value; + setting.GetFloat(&value); + // If the pre-amplifier is used, apply the new gain to the + // pre-amplifier regardless if the capture level adjustment is + // activated. This approach allows both functionalities to coexist + // until they have been properly merged. + if (config_.pre_amplifier.enabled) { + config_.pre_amplifier.fixed_gain_factor = value; + } else { + config_.capture_level_adjustment.pre_gain_factor = value; + } + + // Use both the pre-amplifier and the capture level adjustment gains + // as pre-gains. + float gain = 1.f; + if (config_.pre_amplifier.enabled) { + gain *= config_.pre_amplifier.fixed_gain_factor; + } + if (config_.capture_level_adjustment.enabled) { + gain *= config_.capture_level_adjustment.pre_gain_factor; + } + + submodules_.capture_levels_adjuster->SetPreGain(gain); + } + // TODO(bugs.chromium.org/9138): Log setting handling by Aec Dump. + break; + case RuntimeSetting::Type::kCapturePostGain: + if (config_.capture_level_adjustment.enabled) { float value; setting.GetFloat(&value); - config_.pre_amplifier.fixed_gain_factor = value; - submodules_.pre_amplifier->SetGainFactor(value); + config_.capture_level_adjustment.post_gain_factor = value; + submodules_.capture_levels_adjuster->SetPostGain( + config_.capture_level_adjustment.post_gain_factor); } // TODO(bugs.chromium.org/9138): Log setting handling by Aec Dump. 
        break;
@@ -846,13 +902,27 @@ void AudioProcessingImpl::HandleCaptureRuntimeSettings() {
        RTC_NOTREACHED();
        break;
      case RuntimeSetting::Type::kCaptureOutputUsed:
-        // TODO(b/154437967): Add support for reducing complexity when it is
-        // known that the capture output will not be used.
+        bool value;
+        setting.GetBool(&value);
+        HandleCaptureOutputUsedSetting(value);
        break;
    }
+    ++num_settings_processed;
+  }
+
+  if (num_settings_processed >= RuntimeSettingQueueSize()) {
+    // Handle overrun of the runtime settings queue, which likely will have
+    // caused settings to be discarded.
+    HandleOverrunInCaptureRuntimeSettingsQueue();
  }
}

+void AudioProcessingImpl::HandleOverrunInCaptureRuntimeSettingsQueue() {
+  // Fall back to a safe state for the case when a capture output usage
+  // setting has been missed.
+  HandleCaptureOutputUsedSetting(/*capture_output_used=*/true);
+}
+
void AudioProcessingImpl::HandleRenderRuntimeSettings() {
  RuntimeSetting setting;
  while (render_runtime_settings_.Remove(&setting)) {
@@ -868,6 +938,7 @@ void AudioProcessingImpl::HandleRenderRuntimeSettings() {
      }
      break;
    case RuntimeSetting::Type::kCapturePreGain:          // fall-through
+    case RuntimeSetting::Type::kCapturePostGain:         // fall-through
    case RuntimeSetting::Type::kCaptureCompressionGain:  // fall-through
    case RuntimeSetting::Type::kCaptureFixedPostGain:    // fall-through
    case RuntimeSetting::Type::kCaptureOutputUsed:       // fall-through
@@ -1055,10 +1126,21 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() {
        /*use_split_band_data=*/false);
  }

-  if (submodules_.pre_amplifier) {
-    submodules_.pre_amplifier->ApplyGain(AudioFrameView<float>(
-        capture_buffer->channels(), capture_buffer->num_channels(),
-        capture_buffer->num_frames()));
+  if (submodules_.capture_levels_adjuster) {
+    // If the analog mic gain emulation is active, get the emulated analog mic
+    // gain and pass it to the analog gain control functionality.
+    if (config_.capture_level_adjustment.analog_mic_gain_emulation.enabled) {
+      int level = submodules_.capture_levels_adjuster->GetAnalogMicGainLevel();
+      if (submodules_.agc_manager) {
+        submodules_.agc_manager->set_stream_analog_level(level);
+      } else if (submodules_.gain_control) {
+        int error = submodules_.gain_control->set_stream_analog_level(level);
+        RTC_DCHECK_EQ(kNoError, error);
+      }
+    }
+
+    submodules_.capture_levels_adjuster->ApplyPreLevelAdjustment(
+        *capture_buffer);
  }

  capture_input_rms_.Analyze(rtc::ArrayView<const float>(
@@ -1082,14 +1164,15 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() {
      capture_.prev_analog_mic_level != -1;
  capture_.prev_analog_mic_level = analog_mic_level;

-  // Detect and flag any change in the pre-amplifier gain.
-  if (submodules_.pre_amplifier) {
-    float pre_amp_gain = submodules_.pre_amplifier->GetGainFactor();
+  // Detect and flag any change in the capture level adjustment pre-gain.
+  if (submodules_.capture_levels_adjuster) {
+    float pre_adjustment_gain =
+        submodules_.capture_levels_adjuster->GetPreAdjustmentGain();
    capture_.echo_path_gain_change =
        capture_.echo_path_gain_change ||
-        (capture_.prev_pre_amp_gain != pre_amp_gain &&
-         capture_.prev_pre_amp_gain >= 0.f);
-    capture_.prev_pre_amp_gain = pre_amp_gain;
+        (capture_.prev_pre_adjustment_gain != pre_adjustment_gain &&
+         capture_.prev_pre_adjustment_gain >= 0.f);
+    capture_.prev_pre_adjustment_gain = pre_adjustment_gain;
  }

  // Detect volume change.
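// ---------------------------------------------------------------------------
// Editorial sketch, not part of the patch above: the hunks in this file change
// SetRuntimeSetting() to delegate to a bool-returning PostRuntimeSetting(),
// and add a fallback (HandleOverrunInCaptureRuntimeSettingsQueue) when the
// settings queue overruns. The snippet below only illustrates the intended
// caller-side pattern; the helper name and the AudioProcessing reference `apm`
// are assumptions for the example.
// ---------------------------------------------------------------------------
#include "modules/audio_processing/include/audio_processing.h"

// Posts a capture-output-used setting and reports whether it was enqueued.
// With the legacy void SetRuntimeSetting(), a full queue silently dropped the
// setting; with PostRuntimeSetting() the caller can observe the failure and,
// for example, re-post the setting on the next opportunity.
bool PostCaptureOutputUsed(webrtc::AudioProcessing& apm, bool output_used) {
  return apm.PostRuntimeSetting(
      webrtc::AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting(
          output_used));
}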
@@ -1204,81 +1287,95 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() { capture_buffer->MergeFrequencyBands(); } - if (capture_.capture_fullband_audio) { - const auto& ec = submodules_.echo_controller; - bool ec_active = ec ? ec->ActiveProcessing() : false; - // Only update the fullband buffer if the multiband processing has changed - // the signal. Keep the original signal otherwise. - if (submodule_states_.CaptureMultiBandProcessingActive(ec_active)) { - capture_buffer->CopyTo(capture_.capture_fullband_audio.get()); + capture_.stats.output_rms_dbfs = absl::nullopt; + if (capture_.capture_output_used) { + if (capture_.capture_fullband_audio) { + const auto& ec = submodules_.echo_controller; + bool ec_active = ec ? ec->ActiveProcessing() : false; + // Only update the fullband buffer if the multiband processing has changed + // the signal. Keep the original signal otherwise. + if (submodule_states_.CaptureMultiBandProcessingActive(ec_active)) { + capture_buffer->CopyTo(capture_.capture_fullband_audio.get()); + } + capture_buffer = capture_.capture_fullband_audio.get(); } - capture_buffer = capture_.capture_fullband_audio.get(); - } - if (config_.residual_echo_detector.enabled) { - RTC_DCHECK(submodules_.echo_detector); - submodules_.echo_detector->AnalyzeCaptureAudio(rtc::ArrayView( - capture_buffer->channels()[0], capture_buffer->num_frames())); - } + if (config_.residual_echo_detector.enabled) { + RTC_DCHECK(submodules_.echo_detector); + submodules_.echo_detector->AnalyzeCaptureAudio( + rtc::ArrayView(capture_buffer->channels()[0], + capture_buffer->num_frames())); + } - // TODO(aluebs): Investigate if the transient suppression placement should be - // before or after the AGC. - if (submodules_.transient_suppressor) { - float voice_probability = submodules_.agc_manager.get() - ? submodules_.agc_manager->voice_probability() - : 1.f; + // TODO(aluebs): Investigate if the transient suppression placement should + // be before or after the AGC. + if (submodules_.transient_suppressor) { + float voice_probability = + submodules_.agc_manager.get() + ? submodules_.agc_manager->voice_probability() + : 1.f; + + submodules_.transient_suppressor->Suppress( + capture_buffer->channels()[0], capture_buffer->num_frames(), + capture_buffer->num_channels(), + capture_buffer->split_bands_const(0)[kBand0To8kHz], + capture_buffer->num_frames_per_band(), + capture_.keyboard_info.keyboard_data, + capture_.keyboard_info.num_keyboard_frames, voice_probability, + capture_.key_pressed); + } - submodules_.transient_suppressor->Suppress( - capture_buffer->channels()[0], capture_buffer->num_frames(), - capture_buffer->num_channels(), - capture_buffer->split_bands_const(0)[kBand0To8kHz], - capture_buffer->num_frames_per_band(), - capture_.keyboard_info.keyboard_data, - capture_.keyboard_info.num_keyboard_frames, voice_probability, - capture_.key_pressed); - } + // Experimental APM sub-module that analyzes |capture_buffer|. + if (submodules_.capture_analyzer) { + submodules_.capture_analyzer->Analyze(capture_buffer); + } - // Experimental APM sub-module that analyzes |capture_buffer|. 
- if (submodules_.capture_analyzer) { - submodules_.capture_analyzer->Analyze(capture_buffer); - } + if (submodules_.gain_controller2) { + submodules_.gain_controller2->NotifyAnalogLevel( + recommended_stream_analog_level_locked()); + submodules_.gain_controller2->Process(capture_buffer); + } - if (submodules_.gain_controller2) { - submodules_.gain_controller2->NotifyAnalogLevel( - recommended_stream_analog_level_locked()); - submodules_.gain_controller2->Process(capture_buffer); - } + if (submodules_.capture_post_processor) { + submodules_.capture_post_processor->Process(capture_buffer); + } - if (submodules_.capture_post_processor) { - submodules_.capture_post_processor->Process(capture_buffer); - } + // The level estimator operates on the recombined data. + if (config_.level_estimation.enabled) { + submodules_.output_level_estimator->ProcessStream(*capture_buffer); + capture_.stats.output_rms_dbfs = + submodules_.output_level_estimator->RMS(); + } - // The level estimator operates on the recombined data. - if (config_.level_estimation.enabled) { - submodules_.output_level_estimator->ProcessStream(*capture_buffer); - capture_.stats.output_rms_dbfs = submodules_.output_level_estimator->RMS(); - } else { - capture_.stats.output_rms_dbfs = absl::nullopt; - } + capture_output_rms_.Analyze(rtc::ArrayView( + capture_buffer->channels_const()[0], + capture_nonlocked_.capture_processing_format.num_frames())); + if (log_rms) { + RmsLevel::Levels levels = capture_output_rms_.AverageAndPeak(); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.ApmCaptureOutputLevelAverageRms", levels.average, 1, + RmsLevel::kMinLevelDb, 64); + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.ApmCaptureOutputLevelPeakRms", + levels.peak, 1, RmsLevel::kMinLevelDb, 64); + } - capture_output_rms_.Analyze(rtc::ArrayView( - capture_buffer->channels_const()[0], - capture_nonlocked_.capture_processing_format.num_frames())); - if (log_rms) { - RmsLevel::Levels levels = capture_output_rms_.AverageAndPeak(); - RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.ApmCaptureOutputLevelAverageRms", - levels.average, 1, RmsLevel::kMinLevelDb, 64); - RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.ApmCaptureOutputLevelPeakRms", - levels.peak, 1, RmsLevel::kMinLevelDb, 64); - } + if (submodules_.agc_manager) { + int level = recommended_stream_analog_level_locked(); + data_dumper_->DumpRaw("experimental_gain_control_stream_analog_level", 1, + &level); + } - if (submodules_.agc_manager) { - int level = recommended_stream_analog_level_locked(); - data_dumper_->DumpRaw("experimental_gain_control_stream_analog_level", 1, - &level); + // Compute echo-detector stats. + if (config_.residual_echo_detector.enabled) { + RTC_DCHECK(submodules_.echo_detector); + auto ed_metrics = submodules_.echo_detector->GetMetrics(); + capture_.stats.residual_echo_likelihood = ed_metrics.echo_likelihood; + capture_.stats.residual_echo_likelihood_recent_max = + ed_metrics.echo_likelihood_recent_max; + } } - // Compute echo-related stats. + // Compute echo-controller stats. 
if (submodules_.echo_controller) { auto ec_metrics = submodules_.echo_controller->GetMetrics(); capture_.stats.echo_return_loss = ec_metrics.echo_return_loss; @@ -1286,17 +1383,41 @@ int AudioProcessingImpl::ProcessCaptureStreamLocked() { ec_metrics.echo_return_loss_enhancement; capture_.stats.delay_ms = ec_metrics.delay_ms; } - if (config_.residual_echo_detector.enabled) { - RTC_DCHECK(submodules_.echo_detector); - auto ed_metrics = submodules_.echo_detector->GetMetrics(); - capture_.stats.residual_echo_likelihood = ed_metrics.echo_likelihood; - capture_.stats.residual_echo_likelihood_recent_max = - ed_metrics.echo_likelihood_recent_max; - } // Pass stats for reporting. stats_reporter_.UpdateStatistics(capture_.stats); + if (submodules_.capture_levels_adjuster) { + submodules_.capture_levels_adjuster->ApplyPostLevelAdjustment( + *capture_buffer); + + // If the analog mic gain emulation is active, retrieve the level from the + // analog gain control and set it to mic gain emulator. + if (config_.capture_level_adjustment.analog_mic_gain_emulation.enabled) { + if (submodules_.agc_manager) { + submodules_.capture_levels_adjuster->SetAnalogMicGainLevel( + submodules_.agc_manager->stream_analog_level()); + } else if (submodules_.gain_control) { + submodules_.capture_levels_adjuster->SetAnalogMicGainLevel( + submodules_.gain_control->stream_analog_level()); + } + } + } + + // Temporarily set the output to zero after the stream has been unmuted + // (capture output is again used). The purpose of this is to avoid clicks and + // artefacts in the audio that results when the processing again is + // reactivated after unmuting. + if (!capture_.capture_output_used_last_frame && + capture_.capture_output_used) { + for (size_t ch = 0; ch < capture_buffer->num_channels(); ++ch) { + rtc::ArrayView channel_view(capture_buffer->channels()[ch], + capture_buffer->num_frames()); + std::fill(channel_view.begin(), channel_view.end(), 0.f); + } + } + capture_.capture_output_used_last_frame = capture_.capture_output_used; + capture_.was_stream_delay_set = false; return kNoError; } @@ -1499,16 +1620,29 @@ void AudioProcessingImpl::set_stream_key_pressed(bool key_pressed) { void AudioProcessingImpl::set_stream_analog_level(int level) { MutexLock lock_capture(&mutex_capture_); + if (config_.capture_level_adjustment.analog_mic_gain_emulation.enabled) { + // If the analog mic gain is emulated internally, simply cache the level for + // later reporting back as the recommended stream analog level to use. + capture_.cached_stream_analog_level_ = level; + return; + } + if (submodules_.agc_manager) { submodules_.agc_manager->set_stream_analog_level(level); data_dumper_->DumpRaw("experimental_gain_control_set_stream_analog_level", 1, &level); - } else if (submodules_.gain_control) { + return; + } + + if (submodules_.gain_control) { int error = submodules_.gain_control->set_stream_analog_level(level); RTC_DCHECK_EQ(kNoError, error); - } else { - capture_.cached_stream_analog_level_ = level; + return; } + + // If no analog mic gain control functionality is in place, cache the level + // for later reporting back as the recommended stream analog level to use. 
+ capture_.cached_stream_analog_level_ = level; } int AudioProcessingImpl::recommended_stream_analog_level() const { @@ -1517,13 +1651,19 @@ int AudioProcessingImpl::recommended_stream_analog_level() const { } int AudioProcessingImpl::recommended_stream_analog_level_locked() const { + if (config_.capture_level_adjustment.analog_mic_gain_emulation.enabled) { + return capture_.cached_stream_analog_level_; + } + if (submodules_.agc_manager) { return submodules_.agc_manager->stream_analog_level(); - } else if (submodules_.gain_control) { + } + + if (submodules_.gain_control) { return submodules_.gain_control->stream_analog_level(); - } else { - return capture_.cached_stream_analog_level_; } + + return capture_.cached_stream_analog_level_; } bool AudioProcessingImpl::CreateAndAttachAecDump(const std::string& file_name, @@ -1576,14 +1716,6 @@ void AudioProcessingImpl::DetachAecDump() { } } -void AudioProcessingImpl::MutateConfig( - rtc::FunctionView mutator) { - MutexLock lock_render(&mutex_render_); - MutexLock lock_capture(&mutex_capture_); - mutator(&config_); - ApplyConfig(config_); -} - AudioProcessing::Config AudioProcessingImpl::GetConfig() const { MutexLock lock_render(&mutex_render_); MutexLock lock_capture(&mutex_capture_); @@ -1595,12 +1727,14 @@ bool AudioProcessingImpl::UpdateActiveSubmoduleStates() { config_.high_pass_filter.enabled, !!submodules_.echo_control_mobile, config_.residual_echo_detector.enabled, !!submodules_.noise_suppressor, !!submodules_.gain_control, !!submodules_.gain_controller2, - config_.pre_amplifier.enabled, capture_nonlocked_.echo_controller_enabled, + config_.pre_amplifier.enabled || config_.capture_level_adjustment.enabled, + capture_nonlocked_.echo_controller_enabled, config_.voice_detection.enabled, !!submodules_.transient_suppressor); } void AudioProcessingImpl::InitializeTransientSuppressor() { - if (config_.transient_suppression.enabled) { + if (config_.transient_suppression.enabled && + !constants_.transient_suppressor_forced_off) { // Attempt to create a transient suppressor, if one is not already created. 
if (!submodules_.transient_suppressor) { submodules_.transient_suppressor = @@ -1782,11 +1916,13 @@ void AudioProcessingImpl::InitializeGainController1() { num_proc_channels(), config_.gain_controller1.analog_gain_controller.startup_min_volume, config_.gain_controller1.analog_gain_controller.clipped_level_min, - config_.gain_controller1.analog_gain_controller - .enable_agc2_level_estimator, !config_.gain_controller1.analog_gain_controller .enable_digital_adaptive, - capture_nonlocked_.split_rate)); + capture_nonlocked_.split_rate, + config_.gain_controller1.analog_gain_controller.clipped_level_step, + config_.gain_controller1.analog_gain_controller.clipped_ratio_threshold, + config_.gain_controller1.analog_gain_controller.clipped_wait_frames, + config_.gain_controller1.analog_gain_controller.clipping_predictor)); if (re_creation) { submodules_.agc_manager->set_stream_analog_level(stream_analog_level); } @@ -1794,7 +1930,8 @@ void AudioProcessingImpl::InitializeGainController1() { submodules_.agc_manager->Initialize(); submodules_.agc_manager->SetupDigitalGainControl( submodules_.gain_control.get()); - submodules_.agc_manager->SetCaptureMuted(capture_.output_will_be_muted); + submodules_.agc_manager->HandleCaptureOutputUsedChange( + capture_.capture_output_used); } void AudioProcessingImpl::InitializeGainController2() { @@ -1805,7 +1942,8 @@ void AudioProcessingImpl::InitializeGainController2() { submodules_.gain_controller2.reset(new GainController2()); } - submodules_.gain_controller2->Initialize(proc_fullband_sample_rate_hz()); + submodules_.gain_controller2->Initialize(proc_fullband_sample_rate_hz(), + num_input_channels()); submodules_.gain_controller2->ApplyConfig(config_.gain_controller2); } else { submodules_.gain_controller2.reset(); @@ -1840,12 +1978,27 @@ void AudioProcessingImpl::InitializeNoiseSuppressor() { } } -void AudioProcessingImpl::InitializePreAmplifier() { - if (config_.pre_amplifier.enabled) { - submodules_.pre_amplifier.reset( - new GainApplier(true, config_.pre_amplifier.fixed_gain_factor)); +void AudioProcessingImpl::InitializeCaptureLevelsAdjuster() { + if (config_.pre_amplifier.enabled || + config_.capture_level_adjustment.enabled) { + // Use both the pre-amplifier and the capture level adjustment gains as + // pre-gains. 
+ float pre_gain = 1.f; + if (config_.pre_amplifier.enabled) { + pre_gain *= config_.pre_amplifier.fixed_gain_factor; + } + if (config_.capture_level_adjustment.enabled) { + pre_gain *= config_.capture_level_adjustment.pre_gain_factor; + } + + submodules_.capture_levels_adjuster = + std::make_unique( + config_.capture_level_adjustment.analog_mic_gain_emulation.enabled, + config_.capture_level_adjustment.analog_mic_gain_emulation + .initial_level, + pre_gain, config_.capture_level_adjustment.post_gain_factor); } else { - submodules_.pre_amplifier.reset(); + submodules_.capture_levels_adjuster.reset(); } } @@ -2005,13 +2158,14 @@ void AudioProcessingImpl::RecordAudioProcessingState() { AudioProcessingImpl::ApmCaptureState::ApmCaptureState() : was_stream_delay_set(false), - output_will_be_muted(false), + capture_output_used(true), + capture_output_used_last_frame(true), key_pressed(false), capture_processing_format(kSampleRate16kHz), split_rate(kSampleRate16kHz), echo_path_gain_change(false), prev_analog_mic_level(-1), - prev_pre_amp_gain(-1.f), + prev_pre_adjustment_gain(-1.f), playout_volume(-1), prev_playout_volume(-1) {} diff --git a/modules/audio_processing/audio_processing_impl.h b/modules/audio_processing/audio_processing_impl.h index d0eec0eec3..c88cfcde92 100644 --- a/modules/audio_processing/audio_processing_impl.h +++ b/modules/audio_processing/audio_processing_impl.h @@ -23,6 +23,7 @@ #include "modules/audio_processing/agc/agc_manager_direct.h" #include "modules/audio_processing/agc/gain_control.h" #include "modules/audio_processing/audio_buffer.h" +#include "modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.h" #include "modules/audio_processing/echo_control_mobile_impl.h" #include "modules/audio_processing/gain_control_impl.h" #include "modules/audio_processing/gain_controller2.h" @@ -82,6 +83,7 @@ class AudioProcessingImpl : public AudioProcessing { void AttachAecDump(std::unique_ptr aec_dump) override; void DetachAecDump() override; void SetRuntimeSetting(RuntimeSetting setting) override; + bool PostRuntimeSetting(RuntimeSetting setting) override; // Capture-side exclusive methods possibly running APM in a // multi-threaded manner. Acquire the capture lock. @@ -96,6 +98,8 @@ class AudioProcessingImpl : public AudioProcessing { bool GetLinearAecOutput( rtc::ArrayView> linear_output) const override; void set_output_will_be_muted(bool muted) override; + void HandleCaptureOutputUsedSetting(bool capture_output_used) + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); int set_stream_delay_ms(int delay) override; void set_stream_key_pressed(bool key_pressed) override; void set_stream_analog_level(int level) override; @@ -133,8 +137,6 @@ class AudioProcessingImpl : public AudioProcessing { return stats_reporter_.GetStatistics(); } - // TODO(peah): Remove MutateConfig once the new API allows that. - void MutateConfig(rtc::FunctionView mutator); AudioProcessing::Config GetConfig() const override; protected: @@ -168,7 +170,9 @@ class AudioProcessingImpl : public AudioProcessing { explicit RuntimeSettingEnqueuer( SwapQueue* runtime_settings); ~RuntimeSettingEnqueuer(); - void Enqueue(RuntimeSetting setting); + + // Enqueue setting and return whether the setting was successfully enqueued. 
+ bool Enqueue(RuntimeSetting setting); private: SwapQueue& runtime_settings_; @@ -199,7 +203,7 @@ class AudioProcessingImpl : public AudioProcessing { bool noise_suppressor_enabled, bool adaptive_gain_controller_enabled, bool gain_controller2_enabled, - bool pre_amplifier_enabled, + bool gain_adjustment_enabled, bool echo_controller_enabled, bool voice_detector_enabled, bool transient_suppressor_enabled); @@ -223,7 +227,7 @@ class AudioProcessingImpl : public AudioProcessing { bool noise_suppressor_enabled_ = false; bool adaptive_gain_controller_enabled_ = false; bool gain_controller2_enabled_ = false; - bool pre_amplifier_enabled_ = false; + bool gain_adjustment_enabled_ = false; bool echo_controller_enabled_ = false; bool voice_detector_enabled_ = false; bool transient_suppressor_enabled_ = false; @@ -267,7 +271,8 @@ class AudioProcessingImpl : public AudioProcessing { RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); void InitializeGainController2() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); void InitializeNoiseSuppressor() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); - void InitializePreAmplifier() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); + void InitializeCaptureLevelsAdjuster() + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); void InitializePostProcessor() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); void InitializeAnalyzer() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); @@ -339,6 +344,12 @@ class AudioProcessingImpl : public AudioProcessing { void RecordAudioProcessingState() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); + // Ensures that overruns in the capture runtime settings queue is properly + // handled by the code, providing safe-fallbacks to mitigate the implications + // of any settings being missed. + void HandleOverrunInCaptureRuntimeSettingsQueue() + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); + // AecDump instance used for optionally logging APM config, input // and output to file in the AEC-dump format defined in debug.proto. 
std::unique_ptr aec_dump_; @@ -383,10 +394,10 @@ class AudioProcessingImpl : public AudioProcessing { std::unique_ptr transient_suppressor; std::unique_ptr capture_post_processor; std::unique_ptr render_pre_processor; - std::unique_ptr pre_amplifier; std::unique_ptr capture_analyzer; std::unique_ptr output_level_estimator; std::unique_ptr voice_detector; + std::unique_ptr capture_levels_adjuster; } submodules_; // State that is written to while holding both the render and capture locks @@ -410,20 +421,28 @@ class AudioProcessingImpl : public AudioProcessing { const struct ApmConstants { ApmConstants(bool multi_channel_render_support, bool multi_channel_capture_support, - bool enforce_split_band_hpf) + bool enforce_split_band_hpf, + bool minimize_processing_for_unused_output, + bool transient_suppressor_forced_off) : multi_channel_render_support(multi_channel_render_support), multi_channel_capture_support(multi_channel_capture_support), - enforce_split_band_hpf(enforce_split_band_hpf) {} + enforce_split_band_hpf(enforce_split_band_hpf), + minimize_processing_for_unused_output( + minimize_processing_for_unused_output), + transient_suppressor_forced_off(transient_suppressor_forced_off) {} bool multi_channel_render_support; bool multi_channel_capture_support; bool enforce_split_band_hpf; + bool minimize_processing_for_unused_output; + bool transient_suppressor_forced_off; } constants_; struct ApmCaptureState { ApmCaptureState(); ~ApmCaptureState(); bool was_stream_delay_set; - bool output_will_be_muted; + bool capture_output_used; + bool capture_output_used_last_frame; bool key_pressed; std::unique_ptr capture_audio; std::unique_ptr capture_fullband_audio; @@ -435,7 +454,7 @@ class AudioProcessingImpl : public AudioProcessing { int split_rate; bool echo_path_gain_change; int prev_analog_mic_level; - float prev_pre_amp_gain; + float prev_pre_adjustment_gain; int playout_volume; int prev_playout_volume; AudioProcessingStats stats; diff --git a/modules/audio_processing/audio_processing_impl_locking_unittest.cc b/modules/audio_processing/audio_processing_impl_locking_unittest.cc index ec165aa146..66c1251d4c 100644 --- a/modules/audio_processing/audio_processing_impl_locking_unittest.cc +++ b/modules/audio_processing/audio_processing_impl_locking_unittest.cc @@ -387,33 +387,6 @@ class AudioProcessingImplLockTest void SetUp() override; void TearDown() override; - // Thread callback for the render thread - static void RenderProcessorThreadFunc(void* context) { - AudioProcessingImplLockTest* impl = - reinterpret_cast(context); - while (!impl->MaybeEndTest()) { - impl->render_thread_state_.Process(); - } - } - - // Thread callback for the capture thread - static void CaptureProcessorThreadFunc(void* context) { - AudioProcessingImplLockTest* impl = - reinterpret_cast(context); - while (!impl->MaybeEndTest()) { - impl->capture_thread_state_.Process(); - } - } - - // Thread callback for the stats thread - static void StatsProcessorThreadFunc(void* context) { - AudioProcessingImplLockTest* impl = - reinterpret_cast(context); - while (!impl->MaybeEndTest()) { - impl->stats_thread_state_.Process(); - } - } - // Tests whether all the required render and capture side calls have been // done. bool TestDone() { @@ -423,9 +396,28 @@ class AudioProcessingImplLockTest // Start the threads used in the test. 
void StartThreads() { - render_thread_.Start(); - capture_thread_.Start(); - stats_thread_.Start(); + const auto attributes = + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime); + render_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!MaybeEndTest()) + render_thread_state_.Process(); + }, + "render", attributes); + capture_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!MaybeEndTest()) { + capture_thread_state_.Process(); + } + }, + "capture", attributes); + + stats_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (!MaybeEndTest()) + stats_thread_state_.Process(); + }, + "stats", attributes); } // Event handlers for the test. @@ -434,9 +426,6 @@ class AudioProcessingImplLockTest rtc::Event capture_call_event_; // Thread related variables. - rtc::PlatformThread render_thread_; - rtc::PlatformThread capture_thread_; - rtc::PlatformThread stats_thread_; mutable RandomGenerator rand_gen_; std::unique_ptr apm_; @@ -445,6 +434,9 @@ class AudioProcessingImplLockTest RenderProcessor render_thread_state_; CaptureProcessor capture_thread_state_; StatsProcessor stats_thread_state_; + rtc::PlatformThread render_thread_; + rtc::PlatformThread capture_thread_; + rtc::PlatformThread stats_thread_; }; // Sleeps a random time between 0 and max_sleep milliseconds. @@ -485,19 +477,7 @@ void PopulateAudioFrame(float amplitude, } AudioProcessingImplLockTest::AudioProcessingImplLockTest() - : render_thread_(RenderProcessorThreadFunc, - this, - "render", - rtc::kRealtimePriority), - capture_thread_(CaptureProcessorThreadFunc, - this, - "capture", - rtc::kRealtimePriority), - stats_thread_(StatsProcessorThreadFunc, - this, - "stats", - rtc::kNormalPriority), - apm_(AudioProcessingBuilderForTesting().Create()), + : apm_(AudioProcessingBuilderForTesting().Create()), render_thread_state_(kMaxFrameSize, &rand_gen_, &render_call_event_, @@ -549,9 +529,6 @@ void AudioProcessingImplLockTest::SetUp() { void AudioProcessingImplLockTest::TearDown() { render_call_event_.Set(); capture_call_event_.Set(); - render_thread_.Stop(); - capture_thread_.Stop(); - stats_thread_.Stop(); } StatsProcessor::StatsProcessor(RandomGenerator* rand_gen, diff --git a/modules/audio_processing/audio_processing_impl_unittest.cc b/modules/audio_processing/audio_processing_impl_unittest.cc index e289c316bc..ca8b8b4c25 100644 --- a/modules/audio_processing/audio_processing_impl_unittest.cc +++ b/modules/audio_processing/audio_processing_impl_unittest.cc @@ -14,6 +14,7 @@ #include #include "api/scoped_refptr.h" +#include "modules/audio_processing/common.h" #include "modules/audio_processing/include/audio_processing.h" #include "modules/audio_processing/optionally_built_submodule_creators.h" #include "modules/audio_processing/test/audio_processing_builder_for_testing.h" @@ -202,6 +203,154 @@ TEST(AudioProcessingImplTest, UpdateCapturePreGainRuntimeSetting) { << "Frame should be amplified."; } +TEST(AudioProcessingImplTest, + LevelAdjustmentUpdateCapturePreGainRuntimeSetting) { + std::unique_ptr apm( + AudioProcessingBuilderForTesting().Create()); + webrtc::AudioProcessing::Config apm_config; + apm_config.capture_level_adjustment.enabled = true; + apm_config.capture_level_adjustment.pre_gain_factor = 1.f; + apm->ApplyConfig(apm_config); + + constexpr int kSampleRateHz = 48000; + constexpr int16_t kAudioLevel = 10000; + constexpr size_t kNumChannels = 2; + + std::array frame; + StreamConfig config(kSampleRateHz, kNumChannels, /*has_keyboard=*/false); + frame.fill(kAudioLevel); + 
apm->ProcessStream(frame.data(), config, config, frame.data()); + EXPECT_EQ(frame[100], kAudioLevel) + << "With factor 1, frame shouldn't be modified."; + + constexpr float kGainFactor = 2.f; + apm->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePreGain(kGainFactor)); + + // Process for two frames to have time to ramp up gain. + for (int i = 0; i < 2; ++i) { + frame.fill(kAudioLevel); + apm->ProcessStream(frame.data(), config, config, frame.data()); + } + EXPECT_EQ(frame[100], kGainFactor * kAudioLevel) + << "Frame should be amplified."; +} + +TEST(AudioProcessingImplTest, + LevelAdjustmentUpdateCapturePostGainRuntimeSetting) { + std::unique_ptr apm( + AudioProcessingBuilderForTesting().Create()); + webrtc::AudioProcessing::Config apm_config; + apm_config.capture_level_adjustment.enabled = true; + apm_config.capture_level_adjustment.post_gain_factor = 1.f; + apm->ApplyConfig(apm_config); + + constexpr int kSampleRateHz = 48000; + constexpr int16_t kAudioLevel = 10000; + constexpr size_t kNumChannels = 2; + + std::array frame; + StreamConfig config(kSampleRateHz, kNumChannels, /*has_keyboard=*/false); + frame.fill(kAudioLevel); + apm->ProcessStream(frame.data(), config, config, frame.data()); + EXPECT_EQ(frame[100], kAudioLevel) + << "With factor 1, frame shouldn't be modified."; + + constexpr float kGainFactor = 2.f; + apm->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePostGain(kGainFactor)); + + // Process for two frames to have time to ramp up gain. + for (int i = 0; i < 2; ++i) { + frame.fill(kAudioLevel); + apm->ProcessStream(frame.data(), config, config, frame.data()); + } + EXPECT_EQ(frame[100], kGainFactor * kAudioLevel) + << "Frame should be amplified."; +} + +TEST(AudioProcessingImplTest, EchoControllerObservesSetCaptureUsageChange) { + // Tests that the echo controller observes that the capture usage has been + // updated. + auto echo_control_factory = std::make_unique(); + const MockEchoControlFactory* echo_control_factory_ptr = + echo_control_factory.get(); + + std::unique_ptr apm( + AudioProcessingBuilderForTesting() + .SetEchoControlFactory(std::move(echo_control_factory)) + .Create()); + + constexpr int16_t kAudioLevel = 10000; + constexpr int kSampleRateHz = 48000; + constexpr int kNumChannels = 2; + std::array frame; + StreamConfig config(kSampleRateHz, kNumChannels, /*has_keyboard=*/false); + frame.fill(kAudioLevel); + + MockEchoControl* echo_control_mock = echo_control_factory_ptr->GetNext(); + + // Ensure that SetCaptureOutputUsage is not called when no runtime settings + // are passed. + EXPECT_CALL(*echo_control_mock, SetCaptureOutputUsage(testing::_)).Times(0); + apm->ProcessStream(frame.data(), config, config, frame.data()); + + // Ensure that SetCaptureOutputUsage is called with the right information when + // a runtime setting is passed. 
+ EXPECT_CALL(*echo_control_mock, + SetCaptureOutputUsage(/*capture_output_used=*/false)) + .Times(1); + EXPECT_TRUE(apm->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + /*capture_output_used=*/false))); + apm->ProcessStream(frame.data(), config, config, frame.data()); + + EXPECT_CALL(*echo_control_mock, + SetCaptureOutputUsage(/*capture_output_used=*/true)) + .Times(1); + EXPECT_TRUE(apm->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + /*capture_output_used=*/true))); + apm->ProcessStream(frame.data(), config, config, frame.data()); + + // The number of positions to place items in the queue is equal to the queue + // size minus 1. + constexpr int kNumSlotsInQueue = RuntimeSettingQueueSize(); + + // Ensure that SetCaptureOutputUsage is called with the right information when + // many runtime settings are passed. + for (int k = 0; k < kNumSlotsInQueue - 1; ++k) { + EXPECT_TRUE(apm->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + /*capture_output_used=*/false))); + } + EXPECT_CALL(*echo_control_mock, + SetCaptureOutputUsage(/*capture_output_used=*/false)) + .Times(kNumSlotsInQueue - 1); + apm->ProcessStream(frame.data(), config, config, frame.data()); + + // Ensure that SetCaptureOutputUsage is properly called with the fallback + // value when the runtime settings queue becomes full. + for (int k = 0; k < kNumSlotsInQueue; ++k) { + EXPECT_TRUE(apm->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + /*capture_output_used=*/false))); + } + EXPECT_FALSE(apm->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + /*capture_output_used=*/false))); + EXPECT_FALSE(apm->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + /*capture_output_used=*/false))); + EXPECT_CALL(*echo_control_mock, + SetCaptureOutputUsage(/*capture_output_used=*/false)) + .Times(kNumSlotsInQueue); + EXPECT_CALL(*echo_control_mock, + SetCaptureOutputUsage(/*capture_output_used=*/true)) + .Times(1); + apm->ProcessStream(frame.data(), config, config, frame.data()); +} + TEST(AudioProcessingImplTest, EchoControllerObservesPreAmplifierEchoPathGainChange) { // Tests that the echo controller observes an echo path gain change when the @@ -245,6 +394,49 @@ TEST(AudioProcessingImplTest, apm->ProcessStream(frame.data(), config, config, frame.data()); } +TEST(AudioProcessingImplTest, + EchoControllerObservesLevelAdjustmentPreGainEchoPathGainChange) { + // Tests that the echo controller observes an echo path gain change when the + // pre-amplifier submodule changes the gain. + auto echo_control_factory = std::make_unique(); + const auto* echo_control_factory_ptr = echo_control_factory.get(); + + std::unique_ptr apm( + AudioProcessingBuilderForTesting() + .SetEchoControlFactory(std::move(echo_control_factory)) + .Create()); + // Disable AGC. 
+ webrtc::AudioProcessing::Config apm_config; + apm_config.gain_controller1.enabled = false; + apm_config.gain_controller2.enabled = false; + apm_config.capture_level_adjustment.enabled = true; + apm_config.capture_level_adjustment.pre_gain_factor = 1.f; + apm->ApplyConfig(apm_config); + + constexpr int16_t kAudioLevel = 10000; + constexpr size_t kSampleRateHz = 48000; + constexpr size_t kNumChannels = 2; + std::array frame; + StreamConfig config(kSampleRateHz, kNumChannels, /*has_keyboard=*/false); + frame.fill(kAudioLevel); + + MockEchoControl* echo_control_mock = echo_control_factory_ptr->GetNext(); + + EXPECT_CALL(*echo_control_mock, AnalyzeCapture(testing::_)).Times(1); + EXPECT_CALL(*echo_control_mock, + ProcessCapture(NotNull(), testing::_, /*echo_path_change=*/false)) + .Times(1); + apm->ProcessStream(frame.data(), config, config, frame.data()); + + EXPECT_CALL(*echo_control_mock, AnalyzeCapture(testing::_)).Times(1); + EXPECT_CALL(*echo_control_mock, + ProcessCapture(NotNull(), testing::_, /*echo_path_change=*/true)) + .Times(1); + apm->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePreGain(2.f)); + apm->ProcessStream(frame.data(), config, config, frame.data()); +} + TEST(AudioProcessingImplTest, EchoControllerObservesAnalogAgc1EchoPathGainChange) { // Tests that the echo controller observes an echo path gain change when the @@ -352,8 +544,7 @@ TEST(AudioProcessingImplTest, EchoControllerObservesPlayoutVolumeChange) { TEST(AudioProcessingImplTest, RenderPreProcessorBeforeEchoDetector) { // Make sure that signal changes caused by a render pre-processing sub-module // take place before any echo detector analysis. - rtc::scoped_refptr test_echo_detector( - new rtc::RefCountedObject()); + auto test_echo_detector = rtc::make_ref_counted(); std::unique_ptr test_render_pre_processor( new TestRenderPreProcessor()); // Create APM injecting the test echo detector and render pre-processor. @@ -413,8 +604,7 @@ TEST(AudioProcessingImplTest, RenderPreProcessorBeforeEchoDetector) { // config should be bit-exact with running APM with said submodules disabled. // This mainly tests that SetCreateOptionalSubmodulesForTesting has an effect. TEST(ApmWithSubmodulesExcludedTest, BitexactWithDisabledModules) { - rtc::scoped_refptr apm = - new rtc::RefCountedObject(webrtc::Config()); + auto apm = rtc::make_ref_counted(webrtc::Config()); ASSERT_EQ(apm->Initialize(), AudioProcessing::kNoError); ApmSubmoduleCreationOverrides overrides; @@ -462,8 +652,7 @@ TEST(ApmWithSubmodulesExcludedTest, BitexactWithDisabledModules) { // Disable transient suppressor creation and run APM in ways that should trigger // calls to the transient suppressor API. TEST(ApmWithSubmodulesExcludedTest, ReinitializeTransientSuppressor) { - rtc::scoped_refptr apm = - new rtc::RefCountedObject(webrtc::Config()); + auto apm = rtc::make_ref_counted(webrtc::Config()); ASSERT_EQ(apm->Initialize(), kNoErr); ApmSubmoduleCreationOverrides overrides; @@ -524,8 +713,7 @@ TEST(ApmWithSubmodulesExcludedTest, ReinitializeTransientSuppressor) { // Disable transient suppressor creation and run APM in ways that should trigger // calls to the transient suppressor API. 
TEST(ApmWithSubmodulesExcludedTest, ToggleTransientSuppressor) { - rtc::scoped_refptr apm = - new rtc::RefCountedObject(webrtc::Config()); + auto apm = rtc::make_ref_counted(webrtc::Config()); ASSERT_EQ(apm->Initialize(), AudioProcessing::kNoError); ApmSubmoduleCreationOverrides overrides; diff --git a/modules/audio_processing/audio_processing_performance_unittest.cc b/modules/audio_processing/audio_processing_performance_unittest.cc index 86ff0e8bfe..9585850296 100644 --- a/modules/audio_processing/audio_processing_performance_unittest.cc +++ b/modules/audio_processing/audio_processing_performance_unittest.cc @@ -391,15 +391,7 @@ class TimedThreadApiProcessor { class CallSimulator : public ::testing::TestWithParam { public: CallSimulator() - : render_thread_(new rtc::PlatformThread(RenderProcessorThreadFunc, - this, - "render", - rtc::kRealtimePriority)), - capture_thread_(new rtc::PlatformThread(CaptureProcessorThreadFunc, - this, - "capture", - rtc::kRealtimePriority)), - rand_gen_(42U), + : rand_gen_(42U), simulation_config_(static_cast(GetParam())) {} // Run the call simulation with a timeout. @@ -434,13 +426,10 @@ class CallSimulator : public ::testing::TestWithParam { static const int kMinNumFramesToProcess = 150; static const int32_t kTestTimeout = 3 * 10 * kMinNumFramesToProcess; - // ::testing::TestWithParam<> implementation. - void TearDown() override { StopThreads(); } - // Stop all running threads. void StopThreads() { - render_thread_->Stop(); - capture_thread_->Stop(); + render_thread_.Finalize(); + capture_thread_.Finalize(); } // Simulator and APM setup. @@ -531,32 +520,28 @@ class CallSimulator : public ::testing::TestWithParam { kMinNumFramesToProcess, kCaptureInputFloatLevel, num_capture_channels)); } - // Thread callback for the render thread. - static void RenderProcessorThreadFunc(void* context) { - CallSimulator* call_simulator = reinterpret_cast(context); - while (call_simulator->render_thread_state_->Process()) { - } - } - - // Thread callback for the capture thread. - static void CaptureProcessorThreadFunc(void* context) { - CallSimulator* call_simulator = reinterpret_cast(context); - while (call_simulator->capture_thread_state_->Process()) { - } - } - // Start the threads used in the test. void StartThreads() { - ASSERT_NO_FATAL_FAILURE(render_thread_->Start()); - ASSERT_NO_FATAL_FAILURE(capture_thread_->Start()); + const auto attributes = + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime); + render_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (render_thread_state_->Process()) { + } + }, + "render", attributes); + capture_thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + while (capture_thread_state_->Process()) { + } + }, + "capture", attributes); } // Event handler for the test. rtc::Event test_complete_; // Thread related variables. - std::unique_ptr render_thread_; - std::unique_ptr capture_thread_; Random rand_gen_; std::unique_ptr apm_; @@ -565,6 +550,8 @@ class CallSimulator : public ::testing::TestWithParam { LockedFlag capture_call_checker_; std::unique_ptr render_thread_state_; std::unique_ptr capture_thread_state_; + rtc::PlatformThread render_thread_; + rtc::PlatformThread capture_thread_; }; // Implements the callback functionality for the threads. 
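// ---------------------------------------------------------------------------
// Editorial sketch, not part of the patch: the two test files above migrate
// from the callback-constructor plus Start()/Stop() form of rtc::PlatformThread
// to the rtc::PlatformThread::SpawnJoinable() factory shown in those hunks.
// The example below is a minimal, self-contained illustration of that pattern;
// the worker loop and the function name are placeholders, not WebRTC API.
// ---------------------------------------------------------------------------
#include <atomic>

#include "rtc_base/platform_thread.h"

void RunWorkerExample() {
  std::atomic<bool> quit{false};
  // SpawnJoinable takes the thread body as a lambda, a name, and optional
  // attributes; the returned handle joins the thread on Finalize() or when it
  // goes out of scope.
  rtc::PlatformThread worker = rtc::PlatformThread::SpawnJoinable(
      [&quit] {
        while (!quit.load()) {
          // Do one unit of work per iteration (placeholder).
        }
      },
      "worker",
      rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kRealtime));
  // ... exercise the worker ...
  quit.store(true);
  worker.Finalize();  // Joins the thread; also happens on destruction.
}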
diff --git a/modules/audio_processing/audio_processing_unittest.cc b/modules/audio_processing/audio_processing_unittest.cc index 545c7809da..4d30a348f6 100644 --- a/modules/audio_processing/audio_processing_unittest.cc +++ b/modules/audio_processing/audio_processing_unittest.cc @@ -913,6 +913,131 @@ TEST_F(ApmTest, PreAmplifier) { EXPECT_EQ(config.pre_amplifier.fixed_gain_factor, 1.5f); } +// This is a simple test that ensures that the emulated analog mic gain +// functionality runs without crashing. +TEST_F(ApmTest, AnalogMicGainEmulation) { + // Fill the audio frame with a sawtooth pattern. + rtc::ArrayView frame_data = GetMutableFrameData(&frame_); + const size_t samples_per_channel = frame_.samples_per_channel; + for (size_t i = 0; i < samples_per_channel; i++) { + for (size_t ch = 0; ch < frame_.num_channels; ++ch) { + frame_data[i + ch * samples_per_channel] = 100 * ((i % 3) - 1); + } + } + // Cache the frame in tmp_frame. + Int16FrameData tmp_frame; + tmp_frame.CopyFrom(frame_); + + // Enable the analog gain emulation. + AudioProcessing::Config config = apm_->GetConfig(); + config.capture_level_adjustment.enabled = true; + config.capture_level_adjustment.analog_mic_gain_emulation.enabled = true; + config.capture_level_adjustment.analog_mic_gain_emulation.initial_level = 21; + config.gain_controller1.enabled = true; + config.gain_controller1.mode = + AudioProcessing::Config::GainController1::Mode::kAdaptiveAnalog; + config.gain_controller1.analog_gain_controller.enabled = true; + apm_->ApplyConfig(config); + + // Process a number of frames to ensure that the code runs without crashes. + for (int i = 0; i < 20; ++i) { + frame_.CopyFrom(tmp_frame); + EXPECT_EQ(apm_->kNoError, ProcessStreamChooser(kIntFormat)); + } +} + +// This test repeatedly reconfigures the capture level adjustment functionality +// in APM, processes a number of frames, and checks that the output signal has +// the right level. +TEST_F(ApmTest, CaptureLevelAdjustment) { + // Fill the audio frame with a sawtooth pattern. + rtc::ArrayView frame_data = GetMutableFrameData(&frame_); + const size_t samples_per_channel = frame_.samples_per_channel; + for (size_t i = 0; i < samples_per_channel; i++) { + for (size_t ch = 0; ch < frame_.num_channels; ++ch) { + frame_data[i + ch * samples_per_channel] = 100 * ((i % 3) - 1); + } + } + // Cache the frame in tmp_frame. + Int16FrameData tmp_frame; + tmp_frame.CopyFrom(frame_); + + auto compute_power = [](const Int16FrameData& frame) { + rtc::ArrayView data = GetFrameData(frame); + return std::accumulate(data.begin(), data.end(), 0.0f, + [](float a, float b) { return a + b * b; }) / + data.size() / 32768 / 32768; + }; + + const float input_power = compute_power(tmp_frame); + // Double-check that the input data is large compared to the error kEpsilon. + constexpr float kEpsilon = 1e-20f; + RTC_DCHECK_GE(input_power, 10 * kEpsilon); + + // 1. Enable pre-amp with 0 dB gain.
+ AudioProcessing::Config config = apm_->GetConfig(); + config.capture_level_adjustment.enabled = true; + config.capture_level_adjustment.pre_gain_factor = 0.5f; + config.capture_level_adjustment.post_gain_factor = 4.f; + const float expected_output_power1 = + config.capture_level_adjustment.pre_gain_factor * + config.capture_level_adjustment.pre_gain_factor * + config.capture_level_adjustment.post_gain_factor * + config.capture_level_adjustment.post_gain_factor * input_power; + apm_->ApplyConfig(config); + + for (int i = 0; i < 20; ++i) { + frame_.CopyFrom(tmp_frame); + EXPECT_EQ(apm_->kNoError, ProcessStreamChooser(kIntFormat)); + } + float output_power = compute_power(frame_); + EXPECT_NEAR(output_power, expected_output_power1, kEpsilon); + config = apm_->GetConfig(); + EXPECT_EQ(config.capture_level_adjustment.pre_gain_factor, 0.5f); + EXPECT_EQ(config.capture_level_adjustment.post_gain_factor, 4.f); + + // 2. Change pre-amp gain via ApplyConfig. + config.capture_level_adjustment.pre_gain_factor = 1.0f; + config.capture_level_adjustment.post_gain_factor = 2.f; + const float expected_output_power2 = + config.capture_level_adjustment.pre_gain_factor * + config.capture_level_adjustment.pre_gain_factor * + config.capture_level_adjustment.post_gain_factor * + config.capture_level_adjustment.post_gain_factor * input_power; + apm_->ApplyConfig(config); + + for (int i = 0; i < 20; ++i) { + frame_.CopyFrom(tmp_frame); + EXPECT_EQ(apm_->kNoError, ProcessStreamChooser(kIntFormat)); + } + output_power = compute_power(frame_); + EXPECT_NEAR(output_power, expected_output_power2, kEpsilon); + config = apm_->GetConfig(); + EXPECT_EQ(config.capture_level_adjustment.pre_gain_factor, 1.0f); + EXPECT_EQ(config.capture_level_adjustment.post_gain_factor, 2.f); + + // 3. Change pre-amp gain via a RuntimeSetting. 
+ constexpr float kPreGain3 = 0.5f; + constexpr float kPostGain3 = 3.f; + const float expected_output_power3 = + kPreGain3 * kPreGain3 * kPostGain3 * kPostGain3 * input_power; + + apm_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePreGain(kPreGain3)); + apm_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePostGain(kPostGain3)); + + for (int i = 0; i < 20; ++i) { + frame_.CopyFrom(tmp_frame); + EXPECT_EQ(apm_->kNoError, ProcessStreamChooser(kIntFormat)); + } + output_power = compute_power(frame_); + EXPECT_NEAR(output_power, expected_output_power3, kEpsilon); + config = apm_->GetConfig(); + EXPECT_EQ(config.capture_level_adjustment.pre_gain_factor, 0.5f); + EXPECT_EQ(config.capture_level_adjustment.post_gain_factor, 3.f); +} + TEST_F(ApmTest, GainControl) { AudioProcessing::Config config = apm_->GetConfig(); config.gain_controller1.enabled = false; @@ -2216,7 +2341,7 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(32000, 44100, 16000, 44100, 19, 15), std::make_tuple(32000, 32000, 48000, 32000, 40, 35), std::make_tuple(32000, 32000, 32000, 32000, 0, 0), - std::make_tuple(32000, 32000, 16000, 32000, 40, 20), + std::make_tuple(32000, 32000, 16000, 32000, 39, 20), std::make_tuple(32000, 16000, 48000, 16000, 25, 20), std::make_tuple(32000, 16000, 32000, 16000, 25, 20), std::make_tuple(32000, 16000, 16000, 16000, 25, 0), @@ -2231,7 +2356,7 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(16000, 32000, 32000, 32000, 25, 0), std::make_tuple(16000, 32000, 16000, 32000, 25, 20), std::make_tuple(16000, 16000, 48000, 16000, 39, 20), - std::make_tuple(16000, 16000, 32000, 16000, 40, 20), + std::make_tuple(16000, 16000, 32000, 16000, 39, 20), std::make_tuple(16000, 16000, 16000, 16000, 0, 0))); #elif defined(WEBRTC_AUDIOPROC_FIXED_PROFILE) @@ -2428,36 +2553,6 @@ TEST(RuntimeSettingTest, TestDefaultCtor) { EXPECT_EQ(AudioProcessing::RuntimeSetting::Type::kNotSpecified, s.type()); } -TEST(RuntimeSettingDeathTest, TestCapturePreGain) { - using Type = AudioProcessing::RuntimeSetting::Type; - { - auto s = AudioProcessing::RuntimeSetting::CreateCapturePreGain(1.25f); - EXPECT_EQ(Type::kCapturePreGain, s.type()); - float v; - s.GetFloat(&v); - EXPECT_EQ(1.25f, v); - } - -#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) - EXPECT_DEATH(AudioProcessing::RuntimeSetting::CreateCapturePreGain(0.1f), ""); -#endif -} - -TEST(RuntimeSettingDeathTest, TestCaptureFixedPostGain) { - using Type = AudioProcessing::RuntimeSetting::Type; - { - auto s = AudioProcessing::RuntimeSetting::CreateCaptureFixedPostGain(1.25f); - EXPECT_EQ(Type::kCaptureFixedPostGain, s.type()); - float v; - s.GetFloat(&v); - EXPECT_EQ(1.25f, v); - } - -#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) - EXPECT_DEATH(AudioProcessing::RuntimeSetting::CreateCapturePreGain(0.1f), ""); -#endif -} - TEST(RuntimeSettingTest, TestUsageWithSwapQueue) { SwapQueue q(1); auto s = AudioProcessing::RuntimeSetting(); @@ -2931,10 +3026,6 @@ TEST(AudioProcessing, GainController1ConfigEqual) { b_analog.clipped_level_min = a_analog.clipped_level_min; EXPECT_EQ(a, b); - Toggle(a_analog.enable_agc2_level_estimator); - b_analog.enable_agc2_level_estimator = a_analog.enable_agc2_level_estimator; - EXPECT_EQ(a, b); - Toggle(a_analog.enable_digital_adaptive); b_analog.enable_digital_adaptive = a_analog.enable_digital_adaptive; EXPECT_EQ(a, b); @@ -2948,54 +3039,50 @@ TEST(AudioProcessing, GainController1ConfigNotEqual) { Toggle(a.enabled); EXPECT_NE(a, b); - a.enabled = b.enabled; + a = b; 
a.mode = AudioProcessing::Config::GainController1::Mode::kAdaptiveDigital; EXPECT_NE(a, b); - a.mode = b.mode; + a = b; a.target_level_dbfs++; EXPECT_NE(a, b); - a.target_level_dbfs = b.target_level_dbfs; + a = b; a.compression_gain_db++; EXPECT_NE(a, b); - a.compression_gain_db = b.compression_gain_db; + a = b; Toggle(a.enable_limiter); EXPECT_NE(a, b); - a.enable_limiter = b.enable_limiter; + a = b; a.analog_level_minimum++; EXPECT_NE(a, b); - a.analog_level_minimum = b.analog_level_minimum; + a = b; a.analog_level_maximum--; EXPECT_NE(a, b); - a.analog_level_maximum = b.analog_level_maximum; + a = b; auto& a_analog = a.analog_gain_controller; const auto& b_analog = b.analog_gain_controller; Toggle(a_analog.enabled); EXPECT_NE(a, b); - a_analog.enabled = b_analog.enabled; + a_analog = b_analog; a_analog.startup_min_volume++; EXPECT_NE(a, b); - a_analog.startup_min_volume = b_analog.startup_min_volume; + a_analog = b_analog; a_analog.clipped_level_min++; EXPECT_NE(a, b); - a_analog.clipped_level_min = b_analog.clipped_level_min; - - Toggle(a_analog.enable_agc2_level_estimator); - EXPECT_NE(a, b); - a_analog.enable_agc2_level_estimator = b_analog.enable_agc2_level_estimator; + a_analog = b_analog; Toggle(a_analog.enable_digital_adaptive); EXPECT_NE(a, b); - a_analog.enable_digital_adaptive = b_analog.enable_digital_adaptive; + a_analog = b_analog; } TEST(AudioProcessing, GainController2ConfigEqual) { @@ -3007,7 +3094,7 @@ TEST(AudioProcessing, GainController2ConfigEqual) { b.enabled = a.enabled; EXPECT_EQ(a, b); - a.fixed_digital.gain_db += 1.f; + a.fixed_digital.gain_db += 1.0f; b.fixed_digital.gain_db = a.fixed_digital.gain_db; EXPECT_EQ(a, b); @@ -3018,46 +3105,44 @@ TEST(AudioProcessing, GainController2ConfigEqual) { b_adaptive.enabled = a_adaptive.enabled; EXPECT_EQ(a, b); - a_adaptive.vad_probability_attack += 1.f; - b_adaptive.vad_probability_attack = a_adaptive.vad_probability_attack; + Toggle(a_adaptive.dry_run); + b_adaptive.dry_run = a_adaptive.dry_run; EXPECT_EQ(a, b); - a_adaptive.level_estimator = - AudioProcessing::Config::GainController2::LevelEstimator::kPeak; - b_adaptive.level_estimator = a_adaptive.level_estimator; + a_adaptive.noise_estimator = AudioProcessing::Config::GainController2:: + NoiseEstimator::kStationaryNoise; + b_adaptive.noise_estimator = a_adaptive.noise_estimator; EXPECT_EQ(a, b); - a_adaptive.level_estimator_adjacent_speech_frames_threshold++; - b_adaptive.level_estimator_adjacent_speech_frames_threshold = - a_adaptive.level_estimator_adjacent_speech_frames_threshold; + a_adaptive.vad_reset_period_ms++; + b_adaptive.vad_reset_period_ms = a_adaptive.vad_reset_period_ms; EXPECT_EQ(a, b); - Toggle(a_adaptive.use_saturation_protector); - b_adaptive.use_saturation_protector = a_adaptive.use_saturation_protector; + a_adaptive.adjacent_speech_frames_threshold++; + b_adaptive.adjacent_speech_frames_threshold = + a_adaptive.adjacent_speech_frames_threshold; EXPECT_EQ(a, b); - a_adaptive.initial_saturation_margin_db += 1.f; - b_adaptive.initial_saturation_margin_db = - a_adaptive.initial_saturation_margin_db; + a_adaptive.max_gain_change_db_per_second += 1.0f; + b_adaptive.max_gain_change_db_per_second = + a_adaptive.max_gain_change_db_per_second; EXPECT_EQ(a, b); - a_adaptive.extra_saturation_margin_db += 1.f; - b_adaptive.extra_saturation_margin_db = a_adaptive.extra_saturation_margin_db; + a_adaptive.max_output_noise_level_dbfs += 1.0f; + b_adaptive.max_output_noise_level_dbfs = + a_adaptive.max_output_noise_level_dbfs; EXPECT_EQ(a, b); - 
a_adaptive.gain_applier_adjacent_speech_frames_threshold++; - b_adaptive.gain_applier_adjacent_speech_frames_threshold = - a_adaptive.gain_applier_adjacent_speech_frames_threshold; + Toggle(a_adaptive.sse2_allowed); + b_adaptive.sse2_allowed = a_adaptive.sse2_allowed; EXPECT_EQ(a, b); - a_adaptive.max_gain_change_db_per_second += 1.f; - b_adaptive.max_gain_change_db_per_second = - a_adaptive.max_gain_change_db_per_second; + Toggle(a_adaptive.avx2_allowed); + b_adaptive.avx2_allowed = a_adaptive.avx2_allowed; EXPECT_EQ(a, b); - a_adaptive.max_output_noise_level_dbfs -= 1.f; - b_adaptive.max_output_noise_level_dbfs = - a_adaptive.max_output_noise_level_dbfs; + Toggle(a_adaptive.neon_allowed); + b_adaptive.neon_allowed = a_adaptive.neon_allowed; EXPECT_EQ(a, b); } @@ -3069,60 +3154,55 @@ TEST(AudioProcessing, GainController2ConfigNotEqual) { Toggle(a.enabled); EXPECT_NE(a, b); - a.enabled = b.enabled; + a = b; - a.fixed_digital.gain_db += 1.f; + a.fixed_digital.gain_db += 1.0f; EXPECT_NE(a, b); - a.fixed_digital.gain_db = b.fixed_digital.gain_db; + a.fixed_digital = b.fixed_digital; auto& a_adaptive = a.adaptive_digital; const auto& b_adaptive = b.adaptive_digital; Toggle(a_adaptive.enabled); EXPECT_NE(a, b); - a_adaptive.enabled = b_adaptive.enabled; + a_adaptive = b_adaptive; - a_adaptive.vad_probability_attack += 1.f; + Toggle(a_adaptive.dry_run); EXPECT_NE(a, b); - a_adaptive.vad_probability_attack = b_adaptive.vad_probability_attack; + a_adaptive = b_adaptive; - a_adaptive.level_estimator = - AudioProcessing::Config::GainController2::LevelEstimator::kPeak; + a_adaptive.noise_estimator = AudioProcessing::Config::GainController2:: + NoiseEstimator::kStationaryNoise; EXPECT_NE(a, b); - a_adaptive.level_estimator = b_adaptive.level_estimator; + a_adaptive = b_adaptive; - a_adaptive.level_estimator_adjacent_speech_frames_threshold++; + a_adaptive.vad_reset_period_ms++; EXPECT_NE(a, b); - a_adaptive.level_estimator_adjacent_speech_frames_threshold = - b_adaptive.level_estimator_adjacent_speech_frames_threshold; + a_adaptive = b_adaptive; - Toggle(a_adaptive.use_saturation_protector); + a_adaptive.adjacent_speech_frames_threshold++; EXPECT_NE(a, b); - a_adaptive.use_saturation_protector = b_adaptive.use_saturation_protector; + a_adaptive = b_adaptive; - a_adaptive.initial_saturation_margin_db += 1.f; + a_adaptive.max_gain_change_db_per_second += 1.0f; EXPECT_NE(a, b); - a_adaptive.initial_saturation_margin_db = - b_adaptive.initial_saturation_margin_db; + a_adaptive = b_adaptive; - a_adaptive.extra_saturation_margin_db += 1.f; + a_adaptive.max_output_noise_level_dbfs += 1.0f; EXPECT_NE(a, b); - a_adaptive.extra_saturation_margin_db = b_adaptive.extra_saturation_margin_db; + a_adaptive = b_adaptive; - a_adaptive.gain_applier_adjacent_speech_frames_threshold++; + Toggle(a_adaptive.sse2_allowed); EXPECT_NE(a, b); - a_adaptive.gain_applier_adjacent_speech_frames_threshold = - b_adaptive.gain_applier_adjacent_speech_frames_threshold; + a_adaptive = b_adaptive; - a_adaptive.max_gain_change_db_per_second += 1.f; + Toggle(a_adaptive.avx2_allowed); EXPECT_NE(a, b); - a_adaptive.max_gain_change_db_per_second = - b_adaptive.max_gain_change_db_per_second; + a_adaptive = b_adaptive; - a_adaptive.max_output_noise_level_dbfs -= 1.f; + Toggle(a_adaptive.neon_allowed); EXPECT_NE(a, b); - a_adaptive.max_output_noise_level_dbfs = - b_adaptive.max_output_noise_level_dbfs; + a_adaptive = b_adaptive; } } // namespace webrtc diff --git a/modules/audio_processing/capture_levels_adjuster/BUILD.gn 
b/modules/audio_processing/capture_levels_adjuster/BUILD.gn new file mode 100644 index 0000000000..e7ff8482f6 --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/BUILD.gn @@ -0,0 +1,45 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("capture_levels_adjuster") { + visibility = [ "*" ] + + sources = [ + "audio_samples_scaler.cc", + "audio_samples_scaler.h", + "capture_levels_adjuster.cc", + "capture_levels_adjuster.h", + ] + + defines = [] + + deps = [ + "..:audio_buffer", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:safe_minmax", + ] +} + +rtc_library("capture_levels_adjuster_unittests") { + testonly = true + + sources = [ + "audio_samples_scaler_unittest.cc", + "capture_levels_adjuster_unittest.cc", + ] + deps = [ + ":capture_levels_adjuster", + "..:audioproc_test_utils", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:stringutils", + "../../../test:test_support", + ] +} diff --git a/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.cc b/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.cc new file mode 100644 index 0000000000..cb2336b87d --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.cc @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.h" + +#include + +#include "api/array_view.h" +#include "modules/audio_processing/audio_buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { + +AudioSamplesScaler::AudioSamplesScaler(float initial_gain) + : previous_gain_(initial_gain), target_gain_(initial_gain) {} + +void AudioSamplesScaler::Process(AudioBuffer& audio_buffer) { + if (static_cast(audio_buffer.num_frames()) != samples_per_channel_) { + // Update the members depending on audio-buffer length if needed. + RTC_DCHECK_GT(audio_buffer.num_frames(), 0); + samples_per_channel_ = static_cast(audio_buffer.num_frames()); + one_by_samples_per_channel_ = 1.f / samples_per_channel_; + } + + if (target_gain_ == 1.f && previous_gain_ == target_gain_) { + // If only a gain of 1 is to be applied, do an early return without applying + // any gain. + return; + } + + float gain = previous_gain_; + if (previous_gain_ == target_gain_) { + // Apply a non-changing gain. + for (size_t channel = 0; channel < audio_buffer.num_channels(); ++channel) { + rtc::ArrayView channel_view(audio_buffer.channels()[channel], + samples_per_channel_); + for (float& sample : channel_view) { + sample *= gain; + } + } + } else { + const float increment = + (target_gain_ - previous_gain_) * one_by_samples_per_channel_; + + if (increment > 0.f) { + // Apply an increasing gain. 
+ for (size_t channel = 0; channel < audio_buffer.num_channels(); + ++channel) { + gain = previous_gain_; + rtc::ArrayView channel_view(audio_buffer.channels()[channel], + samples_per_channel_); + for (float& sample : channel_view) { + gain = std::min(gain + increment, target_gain_); + sample *= gain; + } + } + } else { + // Apply a decreasing gain. + for (size_t channel = 0; channel < audio_buffer.num_channels(); + ++channel) { + gain = previous_gain_; + rtc::ArrayView channel_view(audio_buffer.channels()[channel], + samples_per_channel_); + for (float& sample : channel_view) { + gain = std::max(gain + increment, target_gain_); + sample *= gain; + } + } + } + } + previous_gain_ = target_gain_; + + // Saturate the samples to be in the S16 range. + for (size_t channel = 0; channel < audio_buffer.num_channels(); ++channel) { + rtc::ArrayView channel_view(audio_buffer.channels()[channel], + samples_per_channel_); + for (float& sample : channel_view) { + constexpr float kMinFloatS16Value = -32768.f; + constexpr float kMaxFloatS16Value = 32767.f; + sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value); + } + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.h b/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.h new file mode 100644 index 0000000000..2ae8533940 --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_CAPTURE_LEVELS_ADJUSTER_AUDIO_SAMPLES_SCALER_H_ +#define MODULES_AUDIO_PROCESSING_CAPTURE_LEVELS_ADJUSTER_AUDIO_SAMPLES_SCALER_H_ + +#include + +#include "modules/audio_processing/audio_buffer.h" + +namespace webrtc { + +// Handles and applies a gain to the samples in an audio buffer. +// The gain is applied for each sample and any changes in the gain take effect +// gradually (in a linear manner) over one frame. +class AudioSamplesScaler { + public: + // C-tor. The supplied `initial_gain` is used immediately at the first call to + // Process(), i.e., in contrast to the gain supplied by SetGain(...) there is + // no gradual change to the `initial_gain`. + explicit AudioSamplesScaler(float initial_gain); + AudioSamplesScaler(const AudioSamplesScaler&) = delete; + AudioSamplesScaler& operator=(const AudioSamplesScaler&) = delete; + + // Applies the specified gain to the audio in `audio_buffer`. + void Process(AudioBuffer& audio_buffer); + + // Sets the gain to apply to each sample. 
+ void SetGain(float gain) { target_gain_ = gain; } + + private: + float previous_gain_ = 1.f; + float target_gain_ = 1.f; + int samples_per_channel_ = -1; + float one_by_samples_per_channel_ = -1.f; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_CAPTURE_LEVELS_ADJUSTER_AUDIO_SAMPLES_SCALER_H_ diff --git a/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler_unittest.cc b/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler_unittest.cc new file mode 100644 index 0000000000..6e5fc2cbe3 --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/audio_samples_scaler_unittest.cc @@ -0,0 +1,204 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.h" + +#include + +#include "modules/audio_processing/test/audio_buffer_tools.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +float SampleValueForChannel(int channel) { + constexpr float kSampleBaseValue = 100.f; + constexpr float kSampleChannelOffset = 1.f; + return kSampleBaseValue + channel * kSampleChannelOffset; +} + +void PopulateBuffer(AudioBuffer& audio_buffer) { + for (size_t ch = 0; ch < audio_buffer.num_channels(); ++ch) { + test::FillBufferChannel(SampleValueForChannel(ch), ch, audio_buffer); + } +} + +constexpr int kNumFramesToProcess = 10; + +class AudioSamplesScalerTest + : public ::testing::Test, + public ::testing::WithParamInterface> { + protected: + int sample_rate_hz() const { return std::get<0>(GetParam()); } + int num_channels() const { return std::get<1>(GetParam()); } + float initial_gain() const { return std::get<2>(GetParam()); } +}; + +INSTANTIATE_TEST_SUITE_P( + AudioSamplesScalerTestSuite, + AudioSamplesScalerTest, + ::testing::Combine(::testing::Values(16000, 32000, 48000), + ::testing::Values(1, 2, 4), + ::testing::Values(0.1f, 1.f, 2.f, 4.f))); + +TEST_P(AudioSamplesScalerTest, InitialGainIsRespected) { + AudioSamplesScaler scaler(initial_gain()); + + AudioBuffer audio_buffer(sample_rate_hz(), num_channels(), sample_rate_hz(), + num_channels(), sample_rate_hz(), num_channels()); + + for (int frame = 0; frame < kNumFramesToProcess; ++frame) { + PopulateBuffer(audio_buffer); + scaler.Process(audio_buffer); + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + initial_gain() * SampleValueForChannel(ch)); + } + } + } +} + +TEST_P(AudioSamplesScalerTest, VerifyGainAdjustment) { + const float higher_gain = initial_gain(); + const float lower_gain = higher_gain / 2.f; + + AudioSamplesScaler scaler(lower_gain); + + AudioBuffer audio_buffer(sample_rate_hz(), num_channels(), sample_rate_hz(), + num_channels(), sample_rate_hz(), num_channels()); + + // Allow the intial, lower, gain to take effect. + PopulateBuffer(audio_buffer); + + scaler.Process(audio_buffer); + + // Set the new, higher, gain. + scaler.SetGain(higher_gain); + + // Ensure that the new, higher, gain is achieved gradually over one frame. 
+ PopulateBuffer(audio_buffer); + + scaler.Process(audio_buffer); + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames() - 1; ++i) { + EXPECT_LT(audio_buffer.channels_const()[ch][i], + higher_gain * SampleValueForChannel(ch)); + EXPECT_LE(audio_buffer.channels_const()[ch][i], + audio_buffer.channels_const()[ch][i + 1]); + } + EXPECT_LE(audio_buffer.channels_const()[ch][audio_buffer.num_frames() - 1], + higher_gain * SampleValueForChannel(ch)); + } + + // Ensure that the new, higher, gain is achieved and stay unchanged. + for (int frame = 0; frame < kNumFramesToProcess; ++frame) { + PopulateBuffer(audio_buffer); + scaler.Process(audio_buffer); + + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + higher_gain * SampleValueForChannel(ch)); + } + } + } + + // Set the new, lower, gain. + scaler.SetGain(lower_gain); + + // Ensure that the new, lower, gain is achieved gradually over one frame. + PopulateBuffer(audio_buffer); + scaler.Process(audio_buffer); + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames() - 1; ++i) { + EXPECT_GT(audio_buffer.channels_const()[ch][i], + lower_gain * SampleValueForChannel(ch)); + EXPECT_GE(audio_buffer.channels_const()[ch][i], + audio_buffer.channels_const()[ch][i + 1]); + } + EXPECT_GE(audio_buffer.channels_const()[ch][audio_buffer.num_frames() - 1], + lower_gain * SampleValueForChannel(ch)); + } + + // Ensure that the new, lower, gain is achieved and stay unchanged. + for (int frame = 0; frame < kNumFramesToProcess; ++frame) { + PopulateBuffer(audio_buffer); + scaler.Process(audio_buffer); + + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + lower_gain * SampleValueForChannel(ch)); + } + } + } +} + +TEST(AudioSamplesScaler, UpwardsClamping) { + constexpr int kSampleRateHz = 48000; + constexpr int kNumChannels = 1; + constexpr float kGain = 10.f; + constexpr float kMaxClampedSampleValue = 32767.f; + static_assert(kGain > 1.f, ""); + + AudioSamplesScaler scaler(kGain); + + AudioBuffer audio_buffer(kSampleRateHz, kNumChannels, kSampleRateHz, + kNumChannels, kSampleRateHz, kNumChannels); + + for (int frame = 0; frame < kNumFramesToProcess; ++frame) { + for (size_t ch = 0; ch < audio_buffer.num_channels(); ++ch) { + test::FillBufferChannel( + kMaxClampedSampleValue - audio_buffer.num_channels() + 1.f + ch, ch, + audio_buffer); + } + + scaler.Process(audio_buffer); + for (int ch = 0; ch < kNumChannels; ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + kMaxClampedSampleValue); + } + } + } +} + +TEST(AudioSamplesScaler, DownwardsClamping) { + constexpr int kSampleRateHz = 48000; + constexpr int kNumChannels = 1; + constexpr float kGain = 10.f; + constexpr float kMinClampedSampleValue = -32768.f; + static_assert(kGain > 1.f, ""); + + AudioSamplesScaler scaler(kGain); + + AudioBuffer audio_buffer(kSampleRateHz, kNumChannels, kSampleRateHz, + kNumChannels, kSampleRateHz, kNumChannels); + + for (int frame = 0; frame < kNumFramesToProcess; ++frame) { + for (size_t ch = 0; ch < audio_buffer.num_channels(); ++ch) { + test::FillBufferChannel( + kMinClampedSampleValue + audio_buffer.num_channels() - 1.f + ch, ch, + audio_buffer); + } + + scaler.Process(audio_buffer); + for (int ch = 0; ch < 
kNumChannels; ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + kMinClampedSampleValue); + } + } + } +} + +} // namespace +} // namespace webrtc diff --git a/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.cc b/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.cc new file mode 100644 index 0000000000..dfda582915 --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.cc @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.h" + +#include "modules/audio_processing/audio_buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { + +namespace { + +constexpr int kMinAnalogMicGainLevel = 0; +constexpr int kMaxAnalogMicGainLevel = 255; + +float ComputeLevelBasedGain(int emulated_analog_mic_gain_level) { + static_assert( + kMinAnalogMicGainLevel == 0, + "The minimum gain level must be 0 for the maths below to work."); + static_assert(kMaxAnalogMicGainLevel > 0, + "The minimum gain level must be larger than 0 for the maths " + "below to work."); + constexpr float kGainToLevelMultiplier = 1.f / kMaxAnalogMicGainLevel; + + RTC_DCHECK_GE(emulated_analog_mic_gain_level, kMinAnalogMicGainLevel); + RTC_DCHECK_LE(emulated_analog_mic_gain_level, kMaxAnalogMicGainLevel); + return kGainToLevelMultiplier * emulated_analog_mic_gain_level; +} + +float ComputePreGain(float pre_gain, + int emulated_analog_mic_gain_level, + bool emulated_analog_mic_gain_enabled) { + return emulated_analog_mic_gain_enabled + ? 
pre_gain * ComputeLevelBasedGain(emulated_analog_mic_gain_level) + : pre_gain; +} + +} // namespace + +CaptureLevelsAdjuster::CaptureLevelsAdjuster( + bool emulated_analog_mic_gain_enabled, + int emulated_analog_mic_gain_level, + float pre_gain, + float post_gain) + : emulated_analog_mic_gain_enabled_(emulated_analog_mic_gain_enabled), + emulated_analog_mic_gain_level_(emulated_analog_mic_gain_level), + pre_gain_(pre_gain), + pre_adjustment_gain_(ComputePreGain(pre_gain_, + emulated_analog_mic_gain_level_, + emulated_analog_mic_gain_enabled_)), + pre_scaler_(pre_adjustment_gain_), + post_scaler_(post_gain) {} + +void CaptureLevelsAdjuster::ApplyPreLevelAdjustment(AudioBuffer& audio_buffer) { + pre_scaler_.Process(audio_buffer); +} + +void CaptureLevelsAdjuster::ApplyPostLevelAdjustment( + AudioBuffer& audio_buffer) { + post_scaler_.Process(audio_buffer); +} + +void CaptureLevelsAdjuster::SetPreGain(float pre_gain) { + pre_gain_ = pre_gain; + UpdatePreAdjustmentGain(); +} + +void CaptureLevelsAdjuster::SetPostGain(float post_gain) { + post_scaler_.SetGain(post_gain); +} + +void CaptureLevelsAdjuster::SetAnalogMicGainLevel(int level) { + RTC_DCHECK_GE(level, kMinAnalogMicGainLevel); + RTC_DCHECK_LE(level, kMaxAnalogMicGainLevel); + int clamped_level = + rtc::SafeClamp(level, kMinAnalogMicGainLevel, kMaxAnalogMicGainLevel); + + emulated_analog_mic_gain_level_ = clamped_level; + UpdatePreAdjustmentGain(); +} + +void CaptureLevelsAdjuster::UpdatePreAdjustmentGain() { + pre_adjustment_gain_ = + ComputePreGain(pre_gain_, emulated_analog_mic_gain_level_, + emulated_analog_mic_gain_enabled_); + pre_scaler_.SetGain(pre_adjustment_gain_); +} + +} // namespace webrtc diff --git a/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.h b/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.h new file mode 100644 index 0000000000..38b68ad06c --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef MODULES_AUDIO_PROCESSING_CAPTURE_LEVELS_ADJUSTER_CAPTURE_LEVELS_ADJUSTER_H_ +#define MODULES_AUDIO_PROCESSING_CAPTURE_LEVELS_ADJUSTER_CAPTURE_LEVELS_ADJUSTER_H_ + +#include + +#include "modules/audio_processing/audio_buffer.h" +#include "modules/audio_processing/capture_levels_adjuster/audio_samples_scaler.h" + +namespace webrtc { + +// Adjusts the level of the capture signal before and after all capture-side +// processing is done using a combination of explicitly specified gains +// and an emulated analog gain functionality where a specified analog level +// results in an additional gain. The pre-adjustment is achieved by combining +// the gain value `pre_gain` and the level `emulated_analog_mic_gain_level` to +// form a combined gain of `pre_gain`*`emulated_analog_mic_gain_level`/255 which +// is multiplied to each sample. The intention of the +// `emulated_analog_mic_gain_level` is to be controlled by the analog AGC +// functionality and to produce an emulated analog mic gain equal to +// `emulated_analog_mic_gain_level`/255. 
The post level adjustment is achieved +// by multiplying each sample with the value of `post_gain`. Any changes in the +// gains take effect smoothly over one frame, and the scaled samples are +// clamped to fit into the allowed S16 sample range. +class CaptureLevelsAdjuster { + public: + // C-tor. The values for the level and the gains must fulfill + // 0 <= emulated_analog_mic_gain_level <= 255. + // 0.f <= pre_gain. + // 0.f <= post_gain. + CaptureLevelsAdjuster(bool emulated_analog_mic_gain_enabled, + int emulated_analog_mic_gain_level, + float pre_gain, + float post_gain); + CaptureLevelsAdjuster(const CaptureLevelsAdjuster&) = delete; + CaptureLevelsAdjuster& operator=(const CaptureLevelsAdjuster&) = delete; + + // Adjusts the level of the signal. This should be called before any of the + // other processing is performed. + void ApplyPreLevelAdjustment(AudioBuffer& audio_buffer); + + // Adjusts the level of the signal. This should be called after all of the + // other processing has been performed. + void ApplyPostLevelAdjustment(AudioBuffer& audio_buffer); + + // Sets the gain to apply to each sample before any of the other processing is + // performed. + void SetPreGain(float pre_gain); + + // Returns the total pre-adjustment gain applied, comprising both the pre_gain + // and the gain from the emulated analog mic, to each sample before any + // of the other processing is performed. + float GetPreAdjustmentGain() const { return pre_adjustment_gain_; } + + // Sets the gain to apply to each sample after all of the other processing + // has been performed. + void SetPostGain(float post_gain); + + // Sets the analog gain level to use for the emulated analog gain. + // `level` must be in the range [0...255]. + void SetAnalogMicGainLevel(int level); + + // Returns the current analog gain level used for the emulated analog gain. + int GetAnalogMicGainLevel() const { return emulated_analog_mic_gain_level_; } + + private: + // Updates the value of `pre_adjustment_gain_` based on the supplied values + // for `pre_gain` and `emulated_analog_mic_gain_level_`. + void UpdatePreAdjustmentGain(); + + const bool emulated_analog_mic_gain_enabled_; + int emulated_analog_mic_gain_level_; + float pre_gain_; + float pre_adjustment_gain_; + AudioSamplesScaler pre_scaler_; + AudioSamplesScaler post_scaler_; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_CAPTURE_LEVELS_ADJUSTER_CAPTURE_LEVELS_ADJUSTER_H_ diff --git a/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster_unittest.cc b/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster_unittest.cc new file mode 100644 index 0000000000..1183441a14 --- /dev/null +++ b/modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster_unittest.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree.
+ */ +#include "modules/audio_processing/capture_levels_adjuster/capture_levels_adjuster.h" + +#include +#include + +#include "modules/audio_processing/test/audio_buffer_tools.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +float SampleValueForChannel(int channel) { + constexpr float kSampleBaseValue = 100.f; + constexpr float kSampleChannelOffset = 1.f; + return kSampleBaseValue + channel * kSampleChannelOffset; +} + +void PopulateBuffer(AudioBuffer& audio_buffer) { + for (size_t ch = 0; ch < audio_buffer.num_channels(); ++ch) { + test::FillBufferChannel(SampleValueForChannel(ch), ch, audio_buffer); + } +} + +float ComputeExpectedSignalGainAfterApplyPreLevelAdjustment( + bool emulated_analog_mic_gain_enabled, + int emulated_analog_mic_gain_level, + float pre_gain) { + if (!emulated_analog_mic_gain_enabled) { + return pre_gain; + } + return pre_gain * std::min(emulated_analog_mic_gain_level, 255) / 255.f; +} + +float ComputeExpectedSignalGainAfterApplyPostLevelAdjustment( + bool emulated_analog_mic_gain_enabled, + int emulated_analog_mic_gain_level, + float pre_gain, + float post_gain) { + return post_gain * ComputeExpectedSignalGainAfterApplyPreLevelAdjustment( + emulated_analog_mic_gain_enabled, + emulated_analog_mic_gain_level, pre_gain); +} + +constexpr int kNumFramesToProcess = 10; + +class CaptureLevelsAdjusterTest + : public ::testing::Test, + public ::testing::WithParamInterface< + std::tuple> { + protected: + int sample_rate_hz() const { return std::get<0>(GetParam()); } + int num_channels() const { return std::get<1>(GetParam()); } + bool emulated_analog_mic_gain_enabled() const { + return std::get<2>(GetParam()); + } + int emulated_analog_mic_gain_level() const { return std::get<3>(GetParam()); } + float pre_gain() const { return std::get<4>(GetParam()); } + float post_gain() const { return std::get<5>(GetParam()); } +}; + +INSTANTIATE_TEST_SUITE_P( + CaptureLevelsAdjusterTestSuite, + CaptureLevelsAdjusterTest, + ::testing::Combine(::testing::Values(16000, 32000, 48000), + ::testing::Values(1, 2, 4), + ::testing::Values(false, true), + ::testing::Values(21, 255), + ::testing::Values(0.1f, 1.f, 4.f), + ::testing::Values(0.1f, 1.f, 4.f))); + +TEST_P(CaptureLevelsAdjusterTest, InitialGainIsInstantlyAchieved) { + CaptureLevelsAdjuster adjuster(emulated_analog_mic_gain_enabled(), + emulated_analog_mic_gain_level(), pre_gain(), + post_gain()); + + AudioBuffer audio_buffer(sample_rate_hz(), num_channels(), sample_rate_hz(), + num_channels(), sample_rate_hz(), num_channels()); + + const float expected_signal_gain_after_pre_gain = + ComputeExpectedSignalGainAfterApplyPreLevelAdjustment( + emulated_analog_mic_gain_enabled(), emulated_analog_mic_gain_level(), + pre_gain()); + const float expected_signal_gain_after_post_level_adjustment = + ComputeExpectedSignalGainAfterApplyPostLevelAdjustment( + emulated_analog_mic_gain_enabled(), emulated_analog_mic_gain_level(), + pre_gain(), post_gain()); + + for (int frame = 0; frame < kNumFramesToProcess; ++frame) { + PopulateBuffer(audio_buffer); + adjuster.ApplyPreLevelAdjustment(audio_buffer); + EXPECT_FLOAT_EQ(adjuster.GetPreAdjustmentGain(), + expected_signal_gain_after_pre_gain); + + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ( + audio_buffer.channels_const()[ch][i], + expected_signal_gain_after_pre_gain * SampleValueForChannel(ch)); + } + } + adjuster.ApplyPostLevelAdjustment(audio_buffer); + for (int 
ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + expected_signal_gain_after_post_level_adjustment * + SampleValueForChannel(ch)); + } + } + } +} + +TEST_P(CaptureLevelsAdjusterTest, NewGainsAreAchieved) { + const int lower_emulated_analog_mic_gain_level = + emulated_analog_mic_gain_level(); + const float lower_pre_gain = pre_gain(); + const float lower_post_gain = post_gain(); + const int higher_emulated_analog_mic_gain_level = + std::min(lower_emulated_analog_mic_gain_level * 2, 255); + const float higher_pre_gain = lower_pre_gain * 2.f; + const float higher_post_gain = lower_post_gain * 2.f; + + CaptureLevelsAdjuster adjuster(emulated_analog_mic_gain_enabled(), + lower_emulated_analog_mic_gain_level, + lower_pre_gain, lower_post_gain); + + AudioBuffer audio_buffer(sample_rate_hz(), num_channels(), sample_rate_hz(), + num_channels(), sample_rate_hz(), num_channels()); + + const float expected_signal_gain_after_pre_gain = + ComputeExpectedSignalGainAfterApplyPreLevelAdjustment( + emulated_analog_mic_gain_enabled(), + higher_emulated_analog_mic_gain_level, higher_pre_gain); + const float expected_signal_gain_after_post_level_adjustment = + ComputeExpectedSignalGainAfterApplyPostLevelAdjustment( + emulated_analog_mic_gain_enabled(), + higher_emulated_analog_mic_gain_level, higher_pre_gain, + higher_post_gain); + + adjuster.SetPreGain(higher_pre_gain); + adjuster.SetPostGain(higher_post_gain); + adjuster.SetAnalogMicGainLevel(higher_emulated_analog_mic_gain_level); + + PopulateBuffer(audio_buffer); + adjuster.ApplyPreLevelAdjustment(audio_buffer); + adjuster.ApplyPostLevelAdjustment(audio_buffer); + EXPECT_EQ(adjuster.GetAnalogMicGainLevel(), + higher_emulated_analog_mic_gain_level); + + for (int frame = 1; frame < kNumFramesToProcess; ++frame) { + PopulateBuffer(audio_buffer); + adjuster.ApplyPreLevelAdjustment(audio_buffer); + EXPECT_FLOAT_EQ(adjuster.GetPreAdjustmentGain(), + expected_signal_gain_after_pre_gain); + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ( + audio_buffer.channels_const()[ch][i], + expected_signal_gain_after_pre_gain * SampleValueForChannel(ch)); + } + } + + adjuster.ApplyPostLevelAdjustment(audio_buffer); + for (int ch = 0; ch < num_channels(); ++ch) { + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][i], + expected_signal_gain_after_post_level_adjustment * + SampleValueForChannel(ch)); + } + } + + EXPECT_EQ(adjuster.GetAnalogMicGainLevel(), + higher_emulated_analog_mic_gain_level); + } +} + +} // namespace +} // namespace webrtc diff --git a/modules/audio_processing/common.h b/modules/audio_processing/common.h index d8532c5749..2c88c4e46c 100644 --- a/modules/audio_processing/common.h +++ b/modules/audio_processing/common.h @@ -16,6 +16,10 @@ namespace webrtc { +constexpr int RuntimeSettingQueueSize() { + return 100; +} + static inline size_t ChannelsFromLayout(AudioProcessing::ChannelLayout layout) { switch (layout) { case AudioProcessing::kMono: diff --git a/modules/audio_processing/debug.proto b/modules/audio_processing/debug.proto index 07cce23ba3..4bc1a52160 100644 --- a/modules/audio_processing/debug.proto +++ b/modules/audio_processing/debug.proto @@ -92,6 +92,7 @@ message RuntimeSetting { optional int32 playout_volume_change = 4; optional PlayoutAudioDeviceInfo playout_audio_device_change = 5; optional bool 
capture_output_used = 6; + optional float capture_post_gain = 7; } message Event { diff --git a/modules/audio_processing/g3doc/audio_processing_module.md b/modules/audio_processing/g3doc/audio_processing_module.md new file mode 100644 index 0000000000..bb80dc9882 --- /dev/null +++ b/modules/audio_processing/g3doc/audio_processing_module.md @@ -0,0 +1,26 @@ +# Audio Processing Module (APM) + + + + +## Overview + +The APM is responsible for applying speech enhancement effects to the +microphone signal. These effects are required for VoIP calling; examples +include echo cancellation (AEC), noise suppression (NS) and automatic gain +control (AGC). + +The API for APM resides in [`/modules/audio_processing/include`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_processing/include). +APM is created using the [`AudioProcessingBuilder`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/audio_processing/include/audio_processing.h) +builder, which allows it to be customized and configured. + +Some specific aspects of APM include that: +* APM is fully thread-safe in that it can be accessed concurrently from + different threads. +* APM handles any input sample rate < 384 kHz and achieves this by + automatic reconfiguration whenever a new sample format is observed. +* APM handles any number of microphone channels and loudspeaker channels, with + the same automatic reconfiguration as for the sample rates. + + +APM can either be used as part of the WebRTC native pipeline, or standalone. diff --git a/modules/audio_processing/gain_controller2.cc b/modules/audio_processing/gain_controller2.cc index 44770653e5..74b63c9432 100644 --- a/modules/audio_processing/gain_controller2.cc +++ b/modules/audio_processing/gain_controller2.cc @@ -24,31 +24,35 @@ namespace webrtc { int GainController2::instance_count_ = 0; GainController2::GainController2() - : data_dumper_( - new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))), + : data_dumper_(rtc::AtomicOps::Increment(&instance_count_)), gain_applier_(/*hard_clip_samples=*/false, - /*initial_gain_factor=*/0.f), - limiter_(static_cast(48000), data_dumper_.get(), "Agc2"), + /*initial_gain_factor=*/0.0f), + limiter_(static_cast(48000), &data_dumper_, "Agc2"), calls_since_last_limiter_log_(0) { if (config_.adaptive_digital.enabled) { - adaptive_agc_.reset(new AdaptiveAgc(data_dumper_.get())); + adaptive_agc_ = + std::make_unique(&data_dumper_, config_.adaptive_digital); } } GainController2::~GainController2() = default; -void GainController2::Initialize(int sample_rate_hz) { +void GainController2::Initialize(int sample_rate_hz, int num_channels) { RTC_DCHECK(sample_rate_hz == AudioProcessing::kSampleRate8kHz || sample_rate_hz == AudioProcessing::kSampleRate16kHz || sample_rate_hz == AudioProcessing::kSampleRate32kHz || sample_rate_hz == AudioProcessing::kSampleRate48kHz); limiter_.SetSampleRate(sample_rate_hz); - data_dumper_->InitiateNewSetOfRecordings(); - data_dumper_->DumpRaw("sample_rate_hz", sample_rate_hz); + if (adaptive_agc_) { + adaptive_agc_->Initialize(sample_rate_hz, num_channels); + } + data_dumper_.InitiateNewSetOfRecordings(); + data_dumper_.DumpRaw("sample_rate_hz", sample_rate_hz); calls_since_last_limiter_log_ = 0; } void GainController2::Process(AudioBuffer* audio) { + data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_); AudioFrameView float_frame(audio->channels(), audio->num_channels(), audio->num_frames()); // Apply fixed gain first,
then the adaptive one. @@ -73,7 +77,7 @@ void GainController2::Process(AudioBuffer* audio) { void GainController2::NotifyAnalogLevel(int level) { if (analog_level_ != level && adaptive_agc_) { - adaptive_agc_->Reset(); + adaptive_agc_->HandleInputGainChange(); } analog_level_ = level; } @@ -90,7 +94,8 @@ void GainController2::ApplyConfig( } gain_applier_.SetGainFactor(DbToRatio(config_.fixed_digital.gain_db)); if (config_.adaptive_digital.enabled) { - adaptive_agc_.reset(new AdaptiveAgc(data_dumper_.get(), config_)); + adaptive_agc_ = + std::make_unique(&data_dumper_, config_.adaptive_digital); } else { adaptive_agc_.reset(); } diff --git a/modules/audio_processing/gain_controller2.h b/modules/audio_processing/gain_controller2.h index 31665bdeac..ce758c7834 100644 --- a/modules/audio_processing/gain_controller2.h +++ b/modules/audio_processing/gain_controller2.h @@ -18,11 +18,11 @@ #include "modules/audio_processing/agc2/gain_applier.h" #include "modules/audio_processing/agc2/limiter.h" #include "modules/audio_processing/include/audio_processing.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/constructor_magic.h" namespace webrtc { -class ApmDataDumper; class AudioBuffer; // Gain Controller 2 aims to automatically adjust levels by acting on the @@ -30,9 +30,11 @@ class AudioBuffer; class GainController2 { public: GainController2(); + GainController2(const GainController2&) = delete; + GainController2& operator=(const GainController2&) = delete; ~GainController2(); - void Initialize(int sample_rate_hz); + void Initialize(int sample_rate_hz, int num_channels); void Process(AudioBuffer* audio); void NotifyAnalogLevel(int level); @@ -41,15 +43,13 @@ class GainController2 { private: static int instance_count_; - std::unique_ptr data_dumper_; + ApmDataDumper data_dumper_; AudioProcessing::Config::GainController2 config_; GainApplier gain_applier_; std::unique_ptr adaptive_agc_; Limiter limiter_; int calls_since_last_limiter_log_; int analog_level_ = -1; - - RTC_DISALLOW_COPY_AND_ASSIGN(GainController2); }; } // namespace webrtc diff --git a/modules/audio_processing/gain_controller2_unittest.cc b/modules/audio_processing/gain_controller2_unittest.cc index 09bad5087d..85c08bb750 100644 --- a/modules/audio_processing/gain_controller2_unittest.cc +++ b/modules/audio_processing/gain_controller2_unittest.cc @@ -11,6 +11,7 @@ #include "modules/audio_processing/gain_controller2.h" #include +#include #include #include "api/array_view.h" @@ -64,11 +65,12 @@ std::unique_ptr CreateAgc2FixedDigitalMode( size_t sample_rate_hz) { auto agc2 = std::make_unique(); agc2->ApplyConfig(CreateAgc2FixedDigitalModeConfig(fixed_gain_db)); - agc2->Initialize(sample_rate_hz); + agc2->Initialize(sample_rate_hz, /*num_channels=*/1); return agc2; } -float GainAfterProcessingFile(GainController2* gain_controller) { +float GainDbAfterProcessingFile(GainController2& gain_controller, + int max_duration_ms) { // Set up an AudioBuffer to be filled from the speech file. constexpr size_t kStereo = 2u; const StreamConfig capture_config(AudioProcessing::kSampleRate48kHz, kStereo, @@ -82,24 +84,29 @@ float GainAfterProcessingFile(GainController2* gain_controller) { std::vector capture_input(capture_config.num_frames() * capture_config.num_channels()); - // The file should contain at least this many frames. Every iteration, we put - // a frame through the gain controller. 
- const int kNumFramesToProcess = 100; - for (int frame_no = 0; frame_no < kNumFramesToProcess; ++frame_no) { + // Process the input file which must be long enough to cover + // `max_duration_ms`. + RTC_DCHECK_GT(max_duration_ms, 0); + const int num_frames = rtc::CheckedDivExact(max_duration_ms, 10); + for (int i = 0; i < num_frames; ++i) { ReadFloatSamplesFromStereoFile(capture_config.num_frames(), capture_config.num_channels(), &capture_file, capture_input); - test::CopyVectorToAudioBuffer(capture_config, capture_input, &ab); - gain_controller->Process(&ab); + gain_controller.Process(&ab); } - // Send in a last frame with values constant 1 (It's low enough to detect high - // gain, and for ease of computation). The applied gain is the result. + // Send in a last frame with minimum dBFS level. constexpr float sample_value = 1.f; SetAudioBufferSamples(sample_value, &ab); - gain_controller->Process(&ab); - return ab.channels()[0][0]; + gain_controller.Process(&ab); + // Measure the RMS level after processing. + float rms = 0.0f; + for (size_t i = 0; i < capture_config.num_frames(); ++i) { + rms += ab.channels()[0][i] * ab.channels()[0][i]; + } + // Return the applied gain in dB. + return 20.0f * std::log10(std::sqrt(rms / capture_config.num_frames())); } } // namespace @@ -324,34 +331,21 @@ INSTANTIATE_TEST_SUITE_P( 48000, true))); -TEST(GainController2, UsageSaturationMargin) { - GainController2 gain_controller2; - gain_controller2.Initialize(AudioProcessing::kSampleRate48kHz); - - AudioProcessing::Config::GainController2 config; - // Check that samples are not amplified as much when extra margin is - // high. They should not be amplified at all, but only after convergence. GC2 - // starts with a gain, and it takes time until it's down to 0 dB. - config.fixed_digital.gain_db = 0.f; - config.adaptive_digital.enabled = true; - config.adaptive_digital.extra_saturation_margin_db = 50.f; - gain_controller2.ApplyConfig(config); - - EXPECT_LT(GainAfterProcessingFile(&gain_controller2), 2.f); -} - -TEST(GainController2, UsageNoSaturationMargin) { +// Checks that the gain applied at the end of a PCM samples file is close to the +// expected value. +TEST(GainController2, CheckGainAdaptiveDigital) { + constexpr float kExpectedGainDb = 4.3f; + constexpr float kToleranceDb = 0.5f; GainController2 gain_controller2; - gain_controller2.Initialize(AudioProcessing::kSampleRate48kHz); - + gain_controller2.Initialize(AudioProcessing::kSampleRate48kHz, + /*num_channels=*/1); AudioProcessing::Config::GainController2 config; - // Check that some gain is applied if there is no margin. 
- config.fixed_digital.gain_db = 0.f; + config.fixed_digital.gain_db = 0.0f; config.adaptive_digital.enabled = true; - config.adaptive_digital.extra_saturation_margin_db = 0.f; gain_controller2.ApplyConfig(config); - - EXPECT_GT(GainAfterProcessingFile(&gain_controller2), 2.f); + EXPECT_NEAR( + GainDbAfterProcessingFile(gain_controller2, /*max_duration_ms=*/2000), + kExpectedGainDb, kToleranceDb); } } // namespace test diff --git a/modules/audio_processing/include/aec_dump.h b/modules/audio_processing/include/aec_dump.h index ed5acb0943..a7769d9973 100644 --- a/modules/audio_processing/include/aec_dump.h +++ b/modules/audio_processing/include/aec_dump.h @@ -15,9 +15,9 @@ #include +#include "absl/base/attributes.h" #include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_processing.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -76,7 +76,8 @@ class AecDump { // Logs Event::Type INIT message. virtual void WriteInitMessage(const ProcessingConfig& api_format, int64_t time_now_ms) = 0; - RTC_DEPRECATED void WriteInitMessage(const ProcessingConfig& api_format) { + ABSL_DEPRECATED("") + void WriteInitMessage(const ProcessingConfig& api_format) { WriteInitMessage(api_format, 0); } diff --git a/modules/audio_processing/include/audio_processing.cc b/modules/audio_processing/include/audio_processing.cc index 3bc00751cc..44a90d6e76 100644 --- a/modules/audio_processing/include/audio_processing.cc +++ b/modules/audio_processing/include/audio_processing.cc @@ -46,25 +46,17 @@ std::string GainController1ModeToString(const Agc1Config::Mode& mode) { RTC_CHECK_NOTREACHED(); } -std::string GainController2LevelEstimatorToString( - const Agc2Config::LevelEstimator& level) { - switch (level) { - case Agc2Config::LevelEstimator::kRms: - return "Rms"; - case Agc2Config::LevelEstimator::kPeak: - return "Peak"; +std::string GainController2NoiseEstimatorToString( + const Agc2Config::NoiseEstimator& type) { + switch (type) { + case Agc2Config::NoiseEstimator::kStationaryNoise: + return "StationaryNoise"; + case Agc2Config::NoiseEstimator::kNoiseFloor: + return "NoiseFloor"; } RTC_CHECK_NOTREACHED(); } -int GetDefaultMaxInternalRate() { -#ifdef WEBRTC_ARCH_ARM_FAMILY - return 32000; -#else - return 48000; -#endif -} - } // namespace constexpr int AudioProcessing::kNativeSampleRatesHz[]; @@ -72,9 +64,6 @@ constexpr int AudioProcessing::kNativeSampleRatesHz[]; void CustomProcessing::SetRuntimeSetting( AudioProcessing::RuntimeSetting setting) {} -AudioProcessing::Config::Pipeline::Pipeline() - : maximum_internal_processing_rate(GetDefaultMaxInternalRate()) {} - bool Agc1Config::operator==(const Agc1Config& rhs) const { const auto& analog_lhs = analog_gain_controller; const auto& analog_rhs = rhs.analog_gain_controller; @@ -87,95 +76,152 @@ bool Agc1Config::operator==(const Agc1Config& rhs) const { analog_lhs.enabled == analog_rhs.enabled && analog_lhs.startup_min_volume == analog_rhs.startup_min_volume && analog_lhs.clipped_level_min == analog_rhs.clipped_level_min && - analog_lhs.enable_agc2_level_estimator == - analog_rhs.enable_agc2_level_estimator && analog_lhs.enable_digital_adaptive == - analog_rhs.enable_digital_adaptive; + analog_rhs.enable_digital_adaptive && + analog_lhs.clipped_level_step == analog_rhs.clipped_level_step && + analog_lhs.clipped_ratio_threshold == + analog_rhs.clipped_ratio_threshold && + analog_lhs.clipped_wait_frames == analog_rhs.clipped_wait_frames && + analog_lhs.clipping_predictor.mode == + 
analog_rhs.clipping_predictor.mode &&
+         analog_lhs.clipping_predictor.window_length ==
+             analog_rhs.clipping_predictor.window_length &&
+         analog_lhs.clipping_predictor.reference_window_length ==
+             analog_rhs.clipping_predictor.reference_window_length &&
+         analog_lhs.clipping_predictor.reference_window_delay ==
+             analog_rhs.clipping_predictor.reference_window_delay &&
+         analog_lhs.clipping_predictor.clipping_threshold ==
+             analog_rhs.clipping_predictor.clipping_threshold &&
+         analog_lhs.clipping_predictor.crest_factor_margin ==
+             analog_rhs.clipping_predictor.crest_factor_margin;
 }
 
-bool Agc2Config::operator==(const Agc2Config& rhs) const {
-  const auto& adaptive_lhs = adaptive_digital;
-  const auto& adaptive_rhs = rhs.adaptive_digital;
+bool Agc2Config::AdaptiveDigital::operator==(
+    const Agc2Config::AdaptiveDigital& rhs) const {
+  return enabled == rhs.enabled && dry_run == rhs.dry_run &&
+         noise_estimator == rhs.noise_estimator &&
+         vad_reset_period_ms == rhs.vad_reset_period_ms &&
+         adjacent_speech_frames_threshold ==
+             rhs.adjacent_speech_frames_threshold &&
+         max_gain_change_db_per_second == rhs.max_gain_change_db_per_second &&
+         max_output_noise_level_dbfs == rhs.max_output_noise_level_dbfs &&
+         sse2_allowed == rhs.sse2_allowed && avx2_allowed == rhs.avx2_allowed &&
+         neon_allowed == rhs.neon_allowed;
+}
 
+bool Agc2Config::operator==(const Agc2Config& rhs) const {
   return enabled == rhs.enabled &&
          fixed_digital.gain_db == rhs.fixed_digital.gain_db &&
-         adaptive_lhs.enabled == adaptive_rhs.enabled &&
-         adaptive_lhs.vad_probability_attack ==
-             adaptive_rhs.vad_probability_attack &&
-         adaptive_lhs.level_estimator == adaptive_rhs.level_estimator &&
-         adaptive_lhs.level_estimator_adjacent_speech_frames_threshold ==
-             adaptive_rhs.level_estimator_adjacent_speech_frames_threshold &&
-         adaptive_lhs.use_saturation_protector ==
-             adaptive_rhs.use_saturation_protector &&
-         adaptive_lhs.initial_saturation_margin_db ==
-             adaptive_rhs.initial_saturation_margin_db &&
-         adaptive_lhs.extra_saturation_margin_db ==
-             adaptive_rhs.extra_saturation_margin_db &&
-         adaptive_lhs.gain_applier_adjacent_speech_frames_threshold ==
-             adaptive_rhs.gain_applier_adjacent_speech_frames_threshold &&
-         adaptive_lhs.max_gain_change_db_per_second ==
-             adaptive_rhs.max_gain_change_db_per_second &&
-         adaptive_lhs.max_output_noise_level_dbfs ==
-             adaptive_rhs.max_output_noise_level_dbfs;
+         adaptive_digital == rhs.adaptive_digital;
+}
+
+bool AudioProcessing::Config::CaptureLevelAdjustment::operator==(
+    const AudioProcessing::Config::CaptureLevelAdjustment& rhs) const {
+  return enabled == rhs.enabled && pre_gain_factor == rhs.pre_gain_factor &&
+         post_gain_factor == rhs.post_gain_factor &&
+         analog_mic_gain_emulation == rhs.analog_mic_gain_emulation;
+}
+
+bool AudioProcessing::Config::CaptureLevelAdjustment::AnalogMicGainEmulation::
+operator==(const AudioProcessing::Config::CaptureLevelAdjustment::
+               AnalogMicGainEmulation& rhs) const {
+  return enabled == rhs.enabled && initial_level == rhs.initial_level;
 }
 
 std::string AudioProcessing::Config::ToString() const {
   char buf[2048];
   rtc::SimpleStringBuilder builder(buf);
-  builder << "AudioProcessing::Config{ "
-             "pipeline: { "
-             "maximum_internal_processing_rate: "
-          << pipeline.maximum_internal_processing_rate
-          << ", multi_channel_render: " << pipeline.multi_channel_render
-          << ", multi_channel_capture: " << pipeline.multi_channel_capture
-          << " }, pre_amplifier: { enabled: " << pre_amplifier.enabled
-          << ", fixed_gain_factor: " << pre_amplifier.fixed_gain_factor
-          << " 
}, high_pass_filter: { enabled: " << high_pass_filter.enabled - << " }, echo_canceller: { enabled: " << echo_canceller.enabled - << ", mobile_mode: " << echo_canceller.mobile_mode - << ", enforce_high_pass_filtering: " - << echo_canceller.enforce_high_pass_filtering - << " }, noise_suppression: { enabled: " << noise_suppression.enabled - << ", level: " - << NoiseSuppressionLevelToString(noise_suppression.level) - << " }, transient_suppression: { enabled: " - << transient_suppression.enabled - << " }, voice_detection: { enabled: " << voice_detection.enabled - << " }, gain_controller1: { enabled: " << gain_controller1.enabled - << ", mode: " << GainController1ModeToString(gain_controller1.mode) - << ", target_level_dbfs: " << gain_controller1.target_level_dbfs - << ", compression_gain_db: " << gain_controller1.compression_gain_db - << ", enable_limiter: " << gain_controller1.enable_limiter - << ", analog_level_minimum: " << gain_controller1.analog_level_minimum - << ", analog_level_maximum: " << gain_controller1.analog_level_maximum - << " }, gain_controller2: { enabled: " << gain_controller2.enabled - << ", fixed_digital: { gain_db: " - << gain_controller2.fixed_digital.gain_db - << " }, adaptive_digital: { enabled: " - << gain_controller2.adaptive_digital.enabled - << ", level_estimator: { vad_probability_attack: " - << gain_controller2.adaptive_digital.vad_probability_attack - << ", type: " - << GainController2LevelEstimatorToString( - gain_controller2.adaptive_digital.level_estimator) - << ", adjacent_speech_frames_threshold: " - << gain_controller2.adaptive_digital - .level_estimator_adjacent_speech_frames_threshold - << ", initial_saturation_margin_db: " - << gain_controller2.adaptive_digital.initial_saturation_margin_db - << ", extra_saturation_margin_db: " - << gain_controller2.adaptive_digital.extra_saturation_margin_db - << " }, gain_applier: { adjacent_speech_frames_threshold: " - << gain_controller2.adaptive_digital - .gain_applier_adjacent_speech_frames_threshold - << ", max_gain_change_db_per_second: " - << gain_controller2.adaptive_digital.max_gain_change_db_per_second - << ", max_output_noise_level_dbfs: " - << gain_controller2.adaptive_digital.max_output_noise_level_dbfs - << " }}}, residual_echo_detector: { enabled: " - << residual_echo_detector.enabled - << " }, level_estimation: { enabled: " << level_estimation.enabled - << " }}"; + builder + << "AudioProcessing::Config{ " + "pipeline: { " + "maximum_internal_processing_rate: " + << pipeline.maximum_internal_processing_rate + << ", multi_channel_render: " << pipeline.multi_channel_render + << ", multi_channel_capture: " << pipeline.multi_channel_capture + << " }, pre_amplifier: { enabled: " << pre_amplifier.enabled + << ", fixed_gain_factor: " << pre_amplifier.fixed_gain_factor + << " },capture_level_adjustment: { enabled: " + << capture_level_adjustment.enabled + << ", pre_gain_factor: " << capture_level_adjustment.pre_gain_factor + << ", post_gain_factor: " << capture_level_adjustment.post_gain_factor + << ", analog_mic_gain_emulation: { enabled: " + << capture_level_adjustment.analog_mic_gain_emulation.enabled + << ", initial_level: " + << capture_level_adjustment.analog_mic_gain_emulation.initial_level + << " }}, high_pass_filter: { enabled: " << high_pass_filter.enabled + << " }, echo_canceller: { enabled: " << echo_canceller.enabled + << ", mobile_mode: " << echo_canceller.mobile_mode + << ", enforce_high_pass_filtering: " + << echo_canceller.enforce_high_pass_filtering + << " }, noise_suppression: { enabled: " << 
noise_suppression.enabled + << ", level: " << NoiseSuppressionLevelToString(noise_suppression.level) + << " }, transient_suppression: { enabled: " + << transient_suppression.enabled + << " }, voice_detection: { enabled: " << voice_detection.enabled + << " }, gain_controller1: { enabled: " << gain_controller1.enabled + << ", mode: " << GainController1ModeToString(gain_controller1.mode) + << ", target_level_dbfs: " << gain_controller1.target_level_dbfs + << ", compression_gain_db: " << gain_controller1.compression_gain_db + << ", enable_limiter: " << gain_controller1.enable_limiter + << ", analog_level_minimum: " << gain_controller1.analog_level_minimum + << ", analog_level_maximum: " << gain_controller1.analog_level_maximum + << ", analog_gain_controller { enabled: " + << gain_controller1.analog_gain_controller.enabled + << ", startup_min_volume: " + << gain_controller1.analog_gain_controller.startup_min_volume + << ", clipped_level_min: " + << gain_controller1.analog_gain_controller.clipped_level_min + << ", enable_digital_adaptive: " + << gain_controller1.analog_gain_controller.enable_digital_adaptive + << ", clipped_level_step: " + << gain_controller1.analog_gain_controller.clipped_level_step + << ", clipped_ratio_threshold: " + << gain_controller1.analog_gain_controller.clipped_ratio_threshold + << ", clipped_wait_frames: " + << gain_controller1.analog_gain_controller.clipped_wait_frames + << ", clipping_predictor: { enabled: " + << gain_controller1.analog_gain_controller.clipping_predictor.enabled + << ", mode: " + << gain_controller1.analog_gain_controller.clipping_predictor.mode + << ", window_length: " + << gain_controller1.analog_gain_controller.clipping_predictor + .window_length + << ", reference_window_length: " + << gain_controller1.analog_gain_controller.clipping_predictor + .reference_window_length + << ", reference_window_delay: " + << gain_controller1.analog_gain_controller.clipping_predictor + .reference_window_delay + << ", clipping_threshold: " + << gain_controller1.analog_gain_controller.clipping_predictor + .clipping_threshold + << ", crest_factor_margin: " + << gain_controller1.analog_gain_controller.clipping_predictor + .crest_factor_margin + << " }}}, gain_controller2: { enabled: " << gain_controller2.enabled + << ", fixed_digital: { gain_db: " + << gain_controller2.fixed_digital.gain_db + << " }, adaptive_digital: { enabled: " + << gain_controller2.adaptive_digital.enabled + << ", dry_run: " << gain_controller2.adaptive_digital.dry_run + << ", noise_estimator: " + << GainController2NoiseEstimatorToString( + gain_controller2.adaptive_digital.noise_estimator) + << ", vad_reset_period_ms: " + << gain_controller2.adaptive_digital.vad_reset_period_ms + << ", adjacent_speech_frames_threshold: " + << gain_controller2.adaptive_digital.adjacent_speech_frames_threshold + << ", max_gain_change_db_per_second: " + << gain_controller2.adaptive_digital.max_gain_change_db_per_second + << ", max_output_noise_level_dbfs: " + << gain_controller2.adaptive_digital.max_output_noise_level_dbfs + << ", sse2_allowed: " << gain_controller2.adaptive_digital.sse2_allowed + << ", avx2_allowed: " << gain_controller2.adaptive_digital.avx2_allowed + << ", neon_allowed: " << gain_controller2.adaptive_digital.neon_allowed + << "}}, residual_echo_detector: { enabled: " + << residual_echo_detector.enabled + << " }, level_estimation: { enabled: " << level_estimation.enabled + << " }}"; return builder.str(); } diff --git a/modules/audio_processing/include/audio_processing.h 
b/modules/audio_processing/include/audio_processing.h index 942e0c0ce2..64b1b5d107 100644 --- a/modules/audio_processing/include/audio_processing.h +++ b/modules/audio_processing/include/audio_processing.h @@ -32,7 +32,6 @@ #include "modules/audio_processing/include/config.h" #include "rtc_base/arraysize.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/deprecation.h" #include "rtc_base/ref_count.h" #include "rtc_base/system/file_wrapper.h" #include "rtc_base/system/rtc_export.h" @@ -60,9 +59,9 @@ class CustomProcessing; // // Must be provided through AudioProcessingBuilder().Create(config). #if defined(WEBRTC_CHROMIUM_BUILD) -static const int kAgcStartupMinVolume = 85; +static constexpr int kAgcStartupMinVolume = 85; #else -static const int kAgcStartupMinVolume = 0; +static constexpr int kAgcStartupMinVolume = 0; #endif // defined(WEBRTC_CHROMIUM_BUILD) static constexpr int kClippedLevelMin = 70; @@ -72,32 +71,13 @@ static constexpr int kClippedLevelMin = 70; struct ExperimentalAgc { ExperimentalAgc() = default; explicit ExperimentalAgc(bool enabled) : enabled(enabled) {} - ExperimentalAgc(bool enabled, - bool enabled_agc2_level_estimator, - bool digital_adaptive_disabled) - : enabled(enabled), - enabled_agc2_level_estimator(enabled_agc2_level_estimator), - digital_adaptive_disabled(digital_adaptive_disabled) {} - // Deprecated constructor: will be removed. - ExperimentalAgc(bool enabled, - bool enabled_agc2_level_estimator, - bool digital_adaptive_disabled, - bool analyze_before_aec) - : enabled(enabled), - enabled_agc2_level_estimator(enabled_agc2_level_estimator), - digital_adaptive_disabled(digital_adaptive_disabled) {} ExperimentalAgc(bool enabled, int startup_min_volume) : enabled(enabled), startup_min_volume(startup_min_volume) {} - ExperimentalAgc(bool enabled, int startup_min_volume, int clipped_level_min) - : enabled(enabled), - startup_min_volume(startup_min_volume), - clipped_level_min(clipped_level_min) {} static const ConfigOptionID identifier = ConfigOptionID::kExperimentalAgc; bool enabled = true; int startup_min_volume = kAgcStartupMinVolume; // Lowest microphone level that will be applied in response to clipping. int clipped_level_min = kClippedLevelMin; - bool enabled_agc2_level_estimator = false; bool digital_adaptive_disabled = false; }; @@ -214,13 +194,9 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { // Sets the properties of the audio processing pipeline. struct RTC_EXPORT Pipeline { - Pipeline(); - // Maximum allowed processing rate used internally. May only be set to - // 32000 or 48000 and any differing values will be treated as 48000. The - // default rate is currently selected based on the CPU architecture, but - // that logic may change. - int maximum_internal_processing_rate; + // 32000 or 48000 and any differing values will be treated as 48000. + int maximum_internal_processing_rate = 48000; // Allow multi-channel processing of render audio. bool multi_channel_render = false; // Allow multi-channel processing of capture audio when AEC3 is active @@ -230,11 +206,37 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { // Enabled the pre-amplifier. It amplifies the capture signal // before any other processing is done. + // TODO(webrtc:5298): Deprecate and use the pre-gain functionality in + // capture_level_adjustment instead. 
struct PreAmplifier { bool enabled = false; - float fixed_gain_factor = 1.f; + float fixed_gain_factor = 1.0f; } pre_amplifier; + // Functionality for general level adjustment in the capture pipeline. This + // should not be used together with the legacy PreAmplifier functionality. + struct CaptureLevelAdjustment { + bool operator==(const CaptureLevelAdjustment& rhs) const; + bool operator!=(const CaptureLevelAdjustment& rhs) const { + return !(*this == rhs); + } + bool enabled = false; + // The `pre_gain_factor` scales the signal before any processing is done. + float pre_gain_factor = 1.0f; + // The `post_gain_factor` scales the signal after all processing is done. + float post_gain_factor = 1.0f; + struct AnalogMicGainEmulation { + bool operator==(const AnalogMicGainEmulation& rhs) const; + bool operator!=(const AnalogMicGainEmulation& rhs) const { + return !(*this == rhs); + } + bool enabled = false; + // Initial analog gain level to use for the emulated analog gain. Must + // be in the range [0...255]. + int initial_level = 255; + } analog_mic_gain_emulation; + } capture_level_adjustment; + struct HighPassFilter { bool enabled = false; bool apply_in_full_band = true; @@ -273,7 +275,7 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { // in the analog mode, prescribing an analog gain to be applied at the audio // HAL. // Recommended to be enabled on the client-side. - struct GainController1 { + struct RTC_EXPORT GainController1 { bool operator==(const GainController1& rhs) const; bool operator!=(const GainController1& rhs) const { return !(*this == rhs); @@ -331,8 +333,44 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { // Lowest analog microphone level that will be applied in response to // clipping. int clipped_level_min = kClippedLevelMin; - bool enable_agc2_level_estimator = false; bool enable_digital_adaptive = true; + // Amount the microphone level is lowered with every clipping event. + // Limited to (0, 255]. + int clipped_level_step = 15; + // Proportion of clipped samples required to declare a clipping event. + // Limited to (0.f, 1.f). + float clipped_ratio_threshold = 0.1f; + // Time in frames to wait after a clipping event before checking again. + // Limited to values higher than 0. + int clipped_wait_frames = 300; + + // Enables clipping prediction functionality. + struct ClippingPredictor { + bool enabled = false; + enum Mode { + // Clipping event prediction mode with fixed step estimation. + kClippingEventPrediction, + // Clipped peak estimation mode with adaptive step estimation. + kAdaptiveStepClippingPeakPrediction, + // Clipped peak estimation mode with fixed step estimation. + kFixedStepClippingPeakPrediction, + }; + Mode mode = kClippingEventPrediction; + // Number of frames in the sliding analysis window. + int window_length = 5; + // Number of frames in the sliding reference window. + int reference_window_length = 5; + // Reference window delay (unit: number of frames). + int reference_window_delay = 5; + // Clipping prediction threshold (dBFS). + float clipping_threshold = -1.0f; + // Crest factor drop threshold (dB). + float crest_factor_margin = 3.0f; + // If true, the recommended clipped level step is used to modify the + // analog gain. Otherwise, the predictor runs without affecting the + // analog gain. 
+ bool use_predicted_step = true; + } clipping_predictor; } analog_gain_controller; } gain_controller1; @@ -342,32 +380,44 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { // setting |fixed_gain_db|, the limiter can be turned into a compressor that // first applies a fixed gain. The adaptive digital AGC can be turned off by // setting |adaptive_digital_mode=false|. - struct GainController2 { + struct RTC_EXPORT GainController2 { bool operator==(const GainController2& rhs) const; bool operator!=(const GainController2& rhs) const { return !(*this == rhs); } + // TODO(crbug.com/webrtc/7494): Remove `LevelEstimator`. enum LevelEstimator { kRms, kPeak }; + enum NoiseEstimator { kStationaryNoise, kNoiseFloor }; bool enabled = false; struct FixedDigital { - float gain_db = 0.f; + float gain_db = 0.0f; } fixed_digital; - struct AdaptiveDigital { + struct RTC_EXPORT AdaptiveDigital { + bool operator==(const AdaptiveDigital& rhs) const; + bool operator!=(const AdaptiveDigital& rhs) const { + return !(*this == rhs); + } + bool enabled = false; - float vad_probability_attack = 1.f; - LevelEstimator level_estimator = kRms; - int level_estimator_adjacent_speech_frames_threshold = 1; - // TODO(crbug.com/webrtc/7494): Remove `use_saturation_protector`. - bool use_saturation_protector = true; - float initial_saturation_margin_db = 20.f; - float extra_saturation_margin_db = 2.f; - int gain_applier_adjacent_speech_frames_threshold = 1; - float max_gain_change_db_per_second = 3.f; - float max_output_noise_level_dbfs = -50.f; + // Run the adaptive digital controller but the signal is not modified. + bool dry_run = false; + NoiseEstimator noise_estimator = kNoiseFloor; + int vad_reset_period_ms = 1500; + int adjacent_speech_frames_threshold = 12; + float max_gain_change_db_per_second = 3.0f; + float max_output_noise_level_dbfs = -50.0f; bool sse2_allowed = true; bool avx2_allowed = true; bool neon_allowed = true; + // TODO(crbug.com/webrtc/7494): Remove deprecated settings below. + float vad_probability_attack = 1.0f; + LevelEstimator level_estimator = kRms; + int level_estimator_adjacent_speech_frames_threshold = 12; + bool use_saturation_protector = true; + float initial_saturation_margin_db = 25.0f; + float extra_saturation_margin_db = 5.0f; + int gain_applier_adjacent_speech_frames_threshold = 12; } adaptive_digital; } gain_controller2; @@ -406,6 +456,7 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { kPlayoutVolumeChange, kCustomRenderProcessingRuntimeSetting, kPlayoutAudioDeviceChange, + kCapturePostGain, kCaptureOutputUsed }; @@ -415,14 +466,17 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { int max_volume; // Maximum play-out volume. }; - RuntimeSetting() : type_(Type::kNotSpecified), value_(0.f) {} + RuntimeSetting() : type_(Type::kNotSpecified), value_(0.0f) {} ~RuntimeSetting() = default; static RuntimeSetting CreateCapturePreGain(float gain) { - RTC_DCHECK_GE(gain, 1.f) << "Attenuation is not allowed."; return {Type::kCapturePreGain, gain}; } + static RuntimeSetting CreateCapturePostGain(float gain) { + return {Type::kCapturePostGain, gain}; + } + // Corresponds to Config::GainController1::compression_gain_db, but for // runtime configuration. static RuntimeSetting CreateCompressionGainDb(int gain_db) { @@ -434,8 +488,8 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface { // Corresponds to Config::GainController2::fixed_digital::gain_db, but for // runtime configuration. 
   static RuntimeSetting CreateCaptureFixedPostGain(float gain_db) {
-    RTC_DCHECK_GE(gain_db, 0.f);
-    RTC_DCHECK_LE(gain_db, 90.f);
+    RTC_DCHECK_GE(gain_db, 0.0f);
+    RTC_DCHECK_LE(gain_db, 90.0f);
     return {Type::kCaptureFixedPostGain, gain_db};
   }
@@ -456,8 +510,9 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface {
     return {Type::kCustomRenderProcessingRuntimeSetting, payload};
   }
 
-  static RuntimeSetting CreateCaptureOutputUsedSetting(bool payload) {
-    return {Type::kCaptureOutputUsed, payload};
+  static RuntimeSetting CreateCaptureOutputUsedSetting(
+      bool capture_output_used) {
+    return {Type::kCaptureOutputUsed, capture_output_used};
   }
 
   Type type() const { return type_; }
@@ -549,12 +604,17 @@ class RTC_EXPORT AudioProcessing : public rtc::RefCountInterface {
   // Set to true when the output of AudioProcessing will be muted or in some
   // other way not used. Ideally, the captured audio would still be processed,
   // but some components may change behavior based on this information.
-  // Default false.
+  // Default false. This method takes a lock. To achieve this in a lock-less
+  // manner, PostRuntimeSetting can be used instead.
   virtual void set_output_will_be_muted(bool muted) = 0;
 
-  // Enqueue a runtime setting.
+  // Enqueues a runtime setting.
   virtual void SetRuntimeSetting(RuntimeSetting setting) = 0;
 
+  // Enqueues a runtime setting. Returns a bool indicating whether the
+  // enqueueing was successful.
+  virtual bool PostRuntimeSetting(RuntimeSetting setting) = 0;
+
   // Accepts and produces a 10 ms frame interleaved 16 bit integer audio as
   // specified in |input_config| and |output_config|. |src| and |dest| may use
   // the same memory, if desired.
diff --git a/modules/audio_processing/include/mock_audio_processing.h b/modules/audio_processing/include/mock_audio_processing.h
index db9ab975ff..46c5f0efbe 100644
--- a/modules/audio_processing/include/mock_audio_processing.h
+++ b/modules/audio_processing/include/mock_audio_processing.h
@@ -96,6 +96,7 @@ class MockAudioProcessing : public AudioProcessing {
   MOCK_METHOD(size_t, num_reverse_channels, (), (const, override));
   MOCK_METHOD(void, set_output_will_be_muted, (bool muted), (override));
   MOCK_METHOD(void, SetRuntimeSetting, (RuntimeSetting setting), (override));
+  MOCK_METHOD(bool, PostRuntimeSetting, (RuntimeSetting setting), (override));
   MOCK_METHOD(int,
               ProcessStream,
               (const int16_t* const src,
diff --git a/modules/audio_processing/logging/apm_data_dumper.cc b/modules/audio_processing/logging/apm_data_dumper.cc
index 917df60c9c..445248b0bf 100644
--- a/modules/audio_processing/logging/apm_data_dumper.cc
+++ b/modules/audio_processing/logging/apm_data_dumper.cc
@@ -61,6 +61,7 @@ ApmDataDumper::~ApmDataDumper() = default;
 
 #if WEBRTC_APM_DEBUG_DUMP == 1
 bool ApmDataDumper::recording_activated_ = false;
+absl::optional<int> ApmDataDumper::dump_set_to_use_;
 char ApmDataDumper::output_dir_[] = "";
 
 FILE* ApmDataDumper::GetRawFile(const char* name) {
diff --git a/modules/audio_processing/logging/apm_data_dumper.h b/modules/audio_processing/logging/apm_data_dumper.h
index 1824fdd2a9..9c2ac3be5d 100644
--- a/modules/audio_processing/logging/apm_data_dumper.h
+++ b/modules/audio_processing/logging/apm_data_dumper.h
@@ -21,6 +21,7 @@
 #include 
 #endif
 
+#include "absl/types/optional.h"
 #include "api/array_view.h"
#if WEBRTC_APM_DEBUG_DUMP == 1
 #include "common_audio/wav_file.h"
@@ -64,6 +65,27 @@ class ApmDataDumper {
 #endif
   }
 
+  // Returns whether dumping functionality is enabled/available.
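Illustrative aside, not part of the patch: a minimal sketch of how a client might use the PostRuntimeSetting() and CreateCaptureOutputUsedSetting() additions above. It assumes `apm` is an initialized rtc::scoped_refptr<webrtc::AudioProcessing>, and the helper name NotifyCaptureOutputUsed is invented for the sketch. Per the new comments, PostRuntimeSetting() avoids the lock taken by set_output_will_be_muted() and returns false if the setting could not be enqueued.

  #include "api/scoped_refptr.h"
  #include "modules/audio_processing/include/audio_processing.h"

  // Sketch only: tells APM whether the capture output is used, without taking
  // the lock in set_output_will_be_muted().
  void NotifyCaptureOutputUsed(
      const rtc::scoped_refptr<webrtc::AudioProcessing>& apm,
      bool capture_output_used) {
    using RuntimeSetting = webrtc::AudioProcessing::RuntimeSetting;
    if (!apm->PostRuntimeSetting(
            RuntimeSetting::CreateCaptureOutputUsedSetting(
                capture_output_used))) {
      // The setting could not be enqueued; fall back to SetRuntimeSetting().
      apm->SetRuntimeSetting(
          RuntimeSetting::CreateCaptureOutputUsedSetting(capture_output_used));
    }
  }
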
+ static bool IsAvailable() { +#if WEBRTC_APM_DEBUG_DUMP == 1 + return true; +#else + return false; +#endif + } + + // Default dump set. + static constexpr size_t kDefaultDumpSet = 0; + + // Specifies what dump set to use. All dump commands with a different dump set + // than the one specified will be discarded. If not specificed, all dump sets + // will be used. + static void SetDumpSetToUse(int dump_set_to_use) { +#if WEBRTC_APM_DEBUG_DUMP == 1 + dump_set_to_use_ = dump_set_to_use; +#endif + } + // Set an optional output directory. static void SetOutputDirectory(const std::string& output_dir) { #if WEBRTC_APM_DEBUG_DUMP == 1 @@ -82,8 +104,11 @@ class ApmDataDumper { // Methods for performing dumping of data of various types into // various formats. - void DumpRaw(const char* name, double v) { + void DumpRaw(const char* name, double v, int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(&v, sizeof(v), 1, file); @@ -91,8 +116,14 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, size_t v_length, const double* v) { + void DumpRaw(const char* name, + size_t v_length, + const double* v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(v, sizeof(v[0]), v_length, file); @@ -100,16 +131,24 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, rtc::ArrayView v) { + void DumpRaw(const char* name, + rtc::ArrayView v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpRaw(name, v.size(), v.data()); } #endif } - void DumpRaw(const char* name, float v) { + void DumpRaw(const char* name, float v, int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(&v, sizeof(v), 1, file); @@ -117,8 +156,14 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, size_t v_length, const float* v) { + void DumpRaw(const char* name, + size_t v_length, + const float* v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(v, sizeof(v[0]), v_length, file); @@ -126,24 +171,38 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, rtc::ArrayView v) { + void DumpRaw(const char* name, + rtc::ArrayView v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpRaw(name, v.size(), v.data()); } #endif } - void DumpRaw(const char* name, bool v) { + void DumpRaw(const char* name, bool v, int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpRaw(name, static_cast(v)); } #endif } - void DumpRaw(const char* name, size_t v_length, const bool* v) { + void DumpRaw(const char* name, + size_t v_length, + const bool* v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if 
(recording_activated_) { FILE* file = GetRawFile(name); for (size_t k = 0; k < v_length; ++k) { @@ -154,16 +213,24 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, rtc::ArrayView v) { + void DumpRaw(const char* name, + rtc::ArrayView v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpRaw(name, v.size(), v.data()); } #endif } - void DumpRaw(const char* name, int16_t v) { + void DumpRaw(const char* name, int16_t v, int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(&v, sizeof(v), 1, file); @@ -171,8 +238,14 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, size_t v_length, const int16_t* v) { + void DumpRaw(const char* name, + size_t v_length, + const int16_t* v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(v, sizeof(v[0]), v_length, file); @@ -180,16 +253,24 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, rtc::ArrayView v) { + void DumpRaw(const char* name, + rtc::ArrayView v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpRaw(name, v.size(), v.data()); } #endif } - void DumpRaw(const char* name, int32_t v) { + void DumpRaw(const char* name, int32_t v, int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(&v, sizeof(v), 1, file); @@ -197,8 +278,14 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, size_t v_length, const int32_t* v) { + void DumpRaw(const char* name, + size_t v_length, + const int32_t* v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(v, sizeof(v[0]), v_length, file); @@ -206,8 +293,11 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, size_t v) { + void DumpRaw(const char* name, size_t v, int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(&v, sizeof(v), 1, file); @@ -215,8 +305,14 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, size_t v_length, const size_t* v) { + void DumpRaw(const char* name, + size_t v_length, + const size_t* v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { FILE* file = GetRawFile(name); fwrite(v, sizeof(v[0]), v_length, file); @@ -224,16 +320,26 @@ class ApmDataDumper { #endif } - void DumpRaw(const char* name, rtc::ArrayView v) { + void DumpRaw(const char* name, + rtc::ArrayView v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpRaw(name, v.size(), v.data()); } #endif } - void DumpRaw(const char* name, rtc::ArrayView v) { + void 
DumpRaw(const char* name, + rtc::ArrayView v, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + DumpRaw(name, v.size(), v.data()); #endif } @@ -242,8 +348,12 @@ class ApmDataDumper { size_t v_length, const float* v, int sample_rate_hz, - int num_channels) { + int num_channels, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { WavWriter* file = GetWavFile(name, sample_rate_hz, num_channels, WavFile::SampleFormat::kFloat); @@ -255,8 +365,12 @@ class ApmDataDumper { void DumpWav(const char* name, rtc::ArrayView v, int sample_rate_hz, - int num_channels) { + int num_channels, + int dump_set = kDefaultDumpSet) { #if WEBRTC_APM_DEBUG_DUMP == 1 + if (dump_set_to_use_ && *dump_set_to_use_ != dump_set) + return; + if (recording_activated_) { DumpWav(name, v.size(), v.data(), sample_rate_hz, num_channels); } @@ -266,6 +380,7 @@ class ApmDataDumper { private: #if WEBRTC_APM_DEBUG_DUMP == 1 static bool recording_activated_; + static absl::optional dump_set_to_use_; static constexpr size_t kOutputDirMaxLength = 1024; static char output_dir_[kOutputDirMaxLength]; const int instance_index_; diff --git a/modules/audio_processing/ns/BUILD.gn b/modules/audio_processing/ns/BUILD.gn index f0842c505b..eb99c775a9 100644 --- a/modules/audio_processing/ns/BUILD.gn +++ b/modules/audio_processing/ns/BUILD.gn @@ -80,7 +80,6 @@ if (rtc_include_tests) { "..:apm_logging", "..:audio_buffer", "..:audio_processing", - "..:audio_processing_unittests", "..:high_pass_filter", "../../../api:array_view", "../../../rtc_base:checks", @@ -98,5 +97,9 @@ if (rtc_include_tests) { if (rtc_enable_protobuf) { sources += [] } + + if (!build_with_chromium) { + deps += [ "..:audio_processing_unittests" ] + } } } diff --git a/modules/audio_processing/ns/noise_suppressor.cc b/modules/audio_processing/ns/noise_suppressor.cc index 89e1fe0d91..d66faa6ed4 100644 --- a/modules/audio_processing/ns/noise_suppressor.cc +++ b/modules/audio_processing/ns/noise_suppressor.cc @@ -448,6 +448,12 @@ void NoiseSuppressor::Process(AudioBuffer* audio) { } } + // Only do the below processing if the output of the audio processing module + // is used. + if (!capture_output_used_) { + return; + } + // Aggregate the Wiener filters for all channels. std::array filter_data; rtc::ArrayView filter = filter_data; diff --git a/modules/audio_processing/ns/noise_suppressor.h b/modules/audio_processing/ns/noise_suppressor.h index d9628869bb..1e321cf4a2 100644 --- a/modules/audio_processing/ns/noise_suppressor.h +++ b/modules/audio_processing/ns/noise_suppressor.h @@ -41,12 +41,21 @@ class NoiseSuppressor { // Applies noise suppression. void Process(AudioBuffer* audio); + // Specifies whether the capture output will be used. The purpose of this is + // to allow the noise suppressor to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. 
+ void SetCaptureOutputUsage(bool capture_output_used) { + capture_output_used_ = capture_output_used; + } + private: const size_t num_bands_; const size_t num_channels_; const SuppressionParams suppression_params_; int32_t num_analyzed_frames_ = -1; NrFft fft_; + bool capture_output_used_ = true; struct ChannelState { ChannelState(const SuppressionParams& suppression_params, size_t num_bands); diff --git a/modules/audio_processing/residual_echo_detector_unittest.cc b/modules/audio_processing/residual_echo_detector_unittest.cc index 6697cf009d..a5f1409516 100644 --- a/modules/audio_processing/residual_echo_detector_unittest.cc +++ b/modules/audio_processing/residual_echo_detector_unittest.cc @@ -18,8 +18,7 @@ namespace webrtc { TEST(ResidualEchoDetectorTests, Echo) { - rtc::scoped_refptr echo_detector = - new rtc::RefCountedObject(); + auto echo_detector = rtc::make_ref_counted(); echo_detector->SetReliabilityForTest(1.0f); std::vector ones(160, 1.f); std::vector zeros(160, 0.f); @@ -46,8 +45,7 @@ TEST(ResidualEchoDetectorTests, Echo) { } TEST(ResidualEchoDetectorTests, NoEcho) { - rtc::scoped_refptr echo_detector = - new rtc::RefCountedObject(); + auto echo_detector = rtc::make_ref_counted(); echo_detector->SetReliabilityForTest(1.0f); std::vector ones(160, 1.f); std::vector zeros(160, 0.f); @@ -69,8 +67,7 @@ TEST(ResidualEchoDetectorTests, NoEcho) { } TEST(ResidualEchoDetectorTests, EchoWithRenderClockDrift) { - rtc::scoped_refptr echo_detector = - new rtc::RefCountedObject(); + auto echo_detector = rtc::make_ref_counted(); echo_detector->SetReliabilityForTest(1.0f); std::vector ones(160, 1.f); std::vector zeros(160, 0.f); @@ -107,8 +104,7 @@ TEST(ResidualEchoDetectorTests, EchoWithRenderClockDrift) { } TEST(ResidualEchoDetectorTests, EchoWithCaptureClockDrift) { - rtc::scoped_refptr echo_detector = - new rtc::RefCountedObject(); + auto echo_detector = rtc::make_ref_counted(); echo_detector->SetReliabilityForTest(1.0f); std::vector ones(160, 1.f); std::vector zeros(160, 0.f); diff --git a/modules/audio_processing/test/aec_dump_based_simulator.cc b/modules/audio_processing/test/aec_dump_based_simulator.cc index c3014d8e0b..4703ee30c7 100644 --- a/modules/audio_processing/test/aec_dump_based_simulator.cc +++ b/modules/audio_processing/test/aec_dump_based_simulator.cc @@ -14,6 +14,8 @@ #include #include "modules/audio_processing/echo_control_mobile_impl.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/aec_dump_based_simulator.h" #include "modules/audio_processing/test/protobuf_utils.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" @@ -62,6 +64,18 @@ bool VerifyFloatBitExactness(const webrtc::audioproc::Stream& msg, return true; } +// Selectively reads the next proto-buf message from dump-file or string input. +// Returns a bool indicating whether a new message was available. +bool ReadNextMessage(bool use_dump_file, + FILE* dump_input_file, + std::stringstream& input, + webrtc::audioproc::Event& event_msg) { + if (use_dump_file) { + return ReadMessageFromFile(dump_input_file, &event_msg); + } + return ReadMessageFromString(&input, &event_msg); +} + } // namespace AecDumpBasedSimulator::AecDumpBasedSimulator( @@ -150,12 +164,15 @@ void AecDumpBasedSimulator::PrepareProcessStreamCall( } } - if (!settings_.use_ts) { + if (!settings_.use_ts || *settings_.use_ts == 1) { + // Transient suppressor activated (1) or not specified. 
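Illustrative aside, not part of the patch: the tri-state --ts value handled here and in ConfigureAudioProcessor() further below maps to APM state roughly as sketched. The parameter `use_ts` stands in for the optional settings_.use_ts value, and the function name ApplyTransientSuppressorFlag is invented for the sketch; the patch's own handling continues directly after this aside.

  #include "modules/audio_processing/include/audio_processing.h"

  // Sketch only: 0 disables the transient suppressor, 1 enables it (key-press
  // state then follows the recorded keypress events), 2 enables it and forces
  // continuous key events.
  void ApplyTransientSuppressorFlag(int use_ts,
                                    webrtc::AudioProcessing::Config& config,
                                    bool& force_key_pressed) {
    config.transient_suppression.enabled = use_ts != 0;
    force_key_pressed = use_ts == 2;
  }
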
if (msg.has_keypress()) { ap_->set_stream_key_pressed(msg.keypress()); } } else { - ap_->set_stream_key_pressed(*settings_.use_ts); + // Transient suppressor deactivated (0) or activated with continuous key + // events (2). + ap_->set_stream_key_pressed(*settings_.use_ts == 2); } // Level is always logged in AEC dumps. @@ -226,36 +243,93 @@ void AecDumpBasedSimulator::Process() { rtc::CheckedDivExact(sample_rate_hz, kChunksPerSecond), 1)); } - webrtc::audioproc::Event event_msg; - int num_forward_chunks_processed = 0; - if (settings_.aec_dump_input_string.has_value()) { - std::stringstream input; - input << settings_.aec_dump_input_string.value(); - while (ReadMessageFromString(&input, &event_msg)) - HandleEvent(event_msg, &num_forward_chunks_processed); - } else { + const bool use_dump_file = !settings_.aec_dump_input_string.has_value(); + std::stringstream input; + if (use_dump_file) { dump_input_file_ = OpenFile(settings_.aec_dump_input_filename->c_str(), "rb"); - while (ReadMessageFromFile(dump_input_file_, &event_msg)) - HandleEvent(event_msg, &num_forward_chunks_processed); + } else { + input << settings_.aec_dump_input_string.value(); + } + + webrtc::audioproc::Event event_msg; + int capture_frames_since_init = 0; + int init_index = 0; + while (ReadNextMessage(use_dump_file, dump_input_file_, input, event_msg)) { + SelectivelyToggleDataDumping(init_index, capture_frames_since_init); + HandleEvent(event_msg, capture_frames_since_init, init_index); + + // Perfom an early exit if the init block to process has been fully + // processed + if (finished_processing_specified_init_block_) { + break; + } + RTC_CHECK(!settings_.init_to_process || + *settings_.init_to_process >= init_index); + } + + if (use_dump_file) { fclose(dump_input_file_); } DetachAecDump(); } +void AecDumpBasedSimulator::Analyze() { + const bool use_dump_file = !settings_.aec_dump_input_string.has_value(); + std::stringstream input; + if (use_dump_file) { + dump_input_file_ = + OpenFile(settings_.aec_dump_input_filename->c_str(), "rb"); + } else { + input << settings_.aec_dump_input_string.value(); + } + + webrtc::audioproc::Event event_msg; + int num_capture_frames = 0; + int num_render_frames = 0; + int init_index = 0; + while (ReadNextMessage(use_dump_file, dump_input_file_, input, event_msg)) { + if (event_msg.type() == webrtc::audioproc::Event::INIT) { + ++init_index; + constexpr float kNumFramesPerSecond = 100.f; + float capture_time_seconds = num_capture_frames / kNumFramesPerSecond; + float render_time_seconds = num_render_frames / kNumFramesPerSecond; + + std::cout << "Inits:" << std::endl; + std::cout << init_index << ": -->" << std::endl; + std::cout << " Time:" << std::endl; + std::cout << " Capture: " << capture_time_seconds << " s (" + << num_capture_frames << " frames) " << std::endl; + std::cout << " Render: " << render_time_seconds << " s (" + << num_render_frames << " frames) " << std::endl; + } else if (event_msg.type() == webrtc::audioproc::Event::STREAM) { + ++num_capture_frames; + } else if (event_msg.type() == webrtc::audioproc::Event::REVERSE_STREAM) { + ++num_render_frames; + } + } + + if (use_dump_file) { + fclose(dump_input_file_); + } +} + void AecDumpBasedSimulator::HandleEvent( const webrtc::audioproc::Event& event_msg, - int* num_forward_chunks_processed) { + int& capture_frames_since_init, + int& init_index) { switch (event_msg.type()) { case webrtc::audioproc::Event::INIT: RTC_CHECK(event_msg.has_init()); - HandleMessage(event_msg.init()); + ++init_index; + capture_frames_since_init = 0; 
+ HandleMessage(event_msg.init(), init_index); break; case webrtc::audioproc::Event::STREAM: RTC_CHECK(event_msg.has_stream()); + ++capture_frames_since_init; HandleMessage(event_msg.stream()); - ++num_forward_chunks_processed; break; case webrtc::audioproc::Event::REVERSE_STREAM: RTC_CHECK(event_msg.has_reverse_stream()); @@ -439,11 +513,18 @@ void AecDumpBasedSimulator::HandleMessage( } } -void AecDumpBasedSimulator::HandleMessage(const webrtc::audioproc::Init& msg) { +void AecDumpBasedSimulator::HandleMessage(const webrtc::audioproc::Init& msg, + int init_index) { RTC_CHECK(msg.has_sample_rate()); RTC_CHECK(msg.has_num_input_channels()); RTC_CHECK(msg.has_num_reverse_channels()); RTC_CHECK(msg.has_reverse_sample_rate()); + + // Do not perform the init if the init block to process is fully processed + if (settings_.init_to_process && *settings_.init_to_process < init_index) { + finished_processing_specified_init_block_ = true; + } + MaybeOpenCallOrderFile(); if (settings_.use_verbose_logging) { @@ -518,8 +599,23 @@ void AecDumpBasedSimulator::HandleMessage( RTC_CHECK(ap_.get()); if (msg.has_capture_pre_gain()) { // Handle capture pre-gain runtime setting only if not overridden. - if ((!settings_.use_pre_amplifier || *settings_.use_pre_amplifier) && - !settings_.pre_amplifier_gain_factor) { + const bool pre_amplifier_overridden = + (!settings_.use_pre_amplifier || *settings_.use_pre_amplifier) && + !settings_.pre_amplifier_gain_factor; + const bool capture_level_adjustment_overridden = + (!settings_.use_capture_level_adjustment || + *settings_.use_capture_level_adjustment) && + !settings_.pre_gain_factor; + if (pre_amplifier_overridden || capture_level_adjustment_overridden) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePreGain( + msg.capture_pre_gain())); + } + } else if (msg.has_capture_post_gain()) { + // Handle capture post-gain runtime setting only if not overridden. + if ((!settings_.use_capture_level_adjustment || + *settings_.use_capture_level_adjustment) && + !settings_.post_gain_factor) { ap_->SetRuntimeSetting( AudioProcessing::RuntimeSetting::CreateCapturePreGain( msg.capture_pre_gain())); diff --git a/modules/audio_processing/test/aec_dump_based_simulator.h b/modules/audio_processing/test/aec_dump_based_simulator.h index c8d82e6fc3..e2c1f3e4ba 100644 --- a/modules/audio_processing/test/aec_dump_based_simulator.h +++ b/modules/audio_processing/test/aec_dump_based_simulator.h @@ -44,10 +44,14 @@ class AecDumpBasedSimulator final : public AudioProcessingSimulator { // Processes the messages in the aecdump file. void Process() override; + // Analyzes the data in the aecdump file and reports the resulting statistics. 
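Illustrative aside, not part of the patch: the statistics printed by the Analyze() implementation above convert frame counts to time by assuming 10 ms frames, i.e. 100 frames per second, so for example 12000 capture frames correspond to 120 seconds. A minimal sketch of that conversion (the helper name FramesToSeconds is invented for the sketch):

  // Sketch only: APM processes audio in 10 ms chunks, so a frame count divided
  // by 100 gives the elapsed time in seconds.
  constexpr float kNumFramesPerSecond = 100.0f;

  float FramesToSeconds(int num_frames) {
    return num_frames / kNumFramesPerSecond;
  }
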
+ void Analyze() override; + private: void HandleEvent(const webrtc::audioproc::Event& event_msg, - int* num_forward_chunks_processed); - void HandleMessage(const webrtc::audioproc::Init& msg); + int& num_forward_chunks_processed, + int& init_index); + void HandleMessage(const webrtc::audioproc::Init& msg, int init_index); void HandleMessage(const webrtc::audioproc::Stream& msg); void HandleMessage(const webrtc::audioproc::ReverseStream& msg); void HandleMessage(const webrtc::audioproc::Config& msg); @@ -69,6 +73,7 @@ class AecDumpBasedSimulator final : public AudioProcessingSimulator { bool artificial_nearend_eof_reported_ = false; InterfaceType interface_used_ = InterfaceType::kNotSpecified; std::unique_ptr call_order_output_file_; + bool finished_processing_specified_init_block_ = false; }; } // namespace test diff --git a/modules/audio_processing/test/audio_buffer_tools.cc b/modules/audio_processing/test/audio_buffer_tools.cc index 0f0e5cd520..64fb9c7ab1 100644 --- a/modules/audio_processing/test/audio_buffer_tools.cc +++ b/modules/audio_processing/test/audio_buffer_tools.cc @@ -51,5 +51,18 @@ void ExtractVectorFromAudioBuffer(const StreamConfig& stream_config, source->CopyTo(stream_config, &output[0]); } +void FillBuffer(float value, AudioBuffer& audio_buffer) { + for (size_t ch = 0; ch < audio_buffer.num_channels(); ++ch) { + FillBufferChannel(value, ch, audio_buffer); + } +} + +void FillBufferChannel(float value, int channel, AudioBuffer& audio_buffer) { + RTC_CHECK_LT(channel, audio_buffer.num_channels()); + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + audio_buffer.channels()[channel][i] = value; + } +} + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/test/audio_buffer_tools.h b/modules/audio_processing/test/audio_buffer_tools.h index 9ee34e783a..faac4bf9ff 100644 --- a/modules/audio_processing/test/audio_buffer_tools.h +++ b/modules/audio_processing/test/audio_buffer_tools.h @@ -30,6 +30,12 @@ void ExtractVectorFromAudioBuffer(const StreamConfig& stream_config, AudioBuffer* source, std::vector* destination); +// Sets all values in `audio_buffer` to `value`. +void FillBuffer(float value, AudioBuffer& audio_buffer); + +// Sets all values channel `channel` for `audio_buffer` to `value`. +void FillBufferChannel(float value, int channel, AudioBuffer& audio_buffer); + } // namespace test } // namespace webrtc diff --git a/modules/audio_processing/test/audio_processing_simulator.cc b/modules/audio_processing/test/audio_processing_simulator.cc index 40ca7d11b0..1f05f43120 100644 --- a/modules/audio_processing/test/audio_processing_simulator.cc +++ b/modules/audio_processing/test/audio_processing_simulator.cc @@ -122,7 +122,16 @@ AudioProcessingSimulator::AudioProcessingSimulator( settings_.simulate_mic_gain ? 
*settings.simulated_mic_kind : 0), worker_queue_("file_writer_task_queue") { RTC_CHECK(!settings_.dump_internal_data || WEBRTC_APM_DEBUG_DUMP == 1); - ApmDataDumper::SetActivated(settings_.dump_internal_data); + if (settings_.dump_start_frame || settings_.dump_end_frame) { + ApmDataDumper::SetActivated(!settings_.dump_start_frame); + } else { + ApmDataDumper::SetActivated(settings_.dump_internal_data); + } + + if (settings_.dump_set_to_use) { + ApmDataDumper::SetDumpSetToUse(*settings_.dump_set_to_use); + } + if (settings_.dump_internal_data_output_dir.has_value()) { ApmDataDumper::SetOutputDirectory( settings_.dump_internal_data_output_dir.value()); @@ -217,6 +226,20 @@ void AudioProcessingSimulator::ProcessStream(bool fixed_interface) { : analog_mic_level_); } + // Post any scheduled runtime settings. + if (settings_.frame_for_sending_capture_output_used_false && + *settings_.frame_for_sending_capture_output_used_false == + static_cast(num_process_stream_calls_)) { + ap_->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting(false)); + } + if (settings_.frame_for_sending_capture_output_used_true && + *settings_.frame_for_sending_capture_output_used_true == + static_cast(num_process_stream_calls_)) { + ap_->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting(true)); + } + // Process the current audio frame. if (fixed_interface) { { @@ -357,6 +380,28 @@ void AudioProcessingSimulator::SetupBuffersConfigsOutputs( SetupOutput(); } +void AudioProcessingSimulator::SelectivelyToggleDataDumping( + int init_index, + int capture_frames_since_init) const { + if (!(settings_.dump_start_frame || settings_.dump_end_frame)) { + return; + } + + if (settings_.init_to_process && *settings_.init_to_process != init_index) { + return; + } + + if (settings_.dump_start_frame && + *settings_.dump_start_frame == capture_frames_since_init) { + ApmDataDumper::SetActivated(true); + } + + if (settings_.dump_end_frame && + *settings_.dump_end_frame == capture_frames_since_init) { + ApmDataDumper::SetActivated(false); + } +} + void AudioProcessingSimulator::SetupOutput() { if (settings_.output_filename) { std::string filename; @@ -422,7 +467,7 @@ void AudioProcessingSimulator::DetachAecDump() { void AudioProcessingSimulator::ConfigureAudioProcessor() { AudioProcessing::Config apm_config; if (settings_.use_ts) { - apm_config.transient_suppression.enabled = *settings_.use_ts; + apm_config.transient_suppression.enabled = *settings_.use_ts != 0; } if (settings_.multi_channel_render) { apm_config.pipeline.multi_channel_render = *settings_.multi_channel_render; @@ -454,6 +499,34 @@ void AudioProcessingSimulator::ConfigureAudioProcessor() { } } + if (settings_.use_analog_mic_gain_emulation) { + if (*settings_.use_analog_mic_gain_emulation) { + apm_config.capture_level_adjustment.enabled = true; + apm_config.capture_level_adjustment.analog_mic_gain_emulation.enabled = + true; + } else { + apm_config.capture_level_adjustment.analog_mic_gain_emulation.enabled = + false; + } + } + if (settings_.analog_mic_gain_emulation_initial_level) { + apm_config.capture_level_adjustment.analog_mic_gain_emulation + .initial_level = *settings_.analog_mic_gain_emulation_initial_level; + } + + if (settings_.use_capture_level_adjustment) { + apm_config.capture_level_adjustment.enabled = + *settings_.use_capture_level_adjustment; + } + if (settings_.pre_gain_factor) { + apm_config.capture_level_adjustment.pre_gain_factor = + *settings_.pre_gain_factor; + } + if 
(settings_.post_gain_factor) { + apm_config.capture_level_adjustment.post_gain_factor = + *settings_.post_gain_factor; + } + const bool use_aec = settings_.use_aec && *settings_.use_aec; const bool use_aecm = settings_.use_aecm && *settings_.use_aecm; if (use_aec || use_aecm) { @@ -497,11 +570,6 @@ void AudioProcessingSimulator::ConfigureAudioProcessor() { apm_config.gain_controller1.analog_gain_controller.enabled = *settings_.use_analog_agc; } - if (settings_.use_analog_agc_agc2_level_estimator) { - apm_config.gain_controller1.analog_gain_controller - .enable_agc2_level_estimator = - *settings_.use_analog_agc_agc2_level_estimator; - } if (settings_.analog_agc_disable_digital_adaptive) { apm_config.gain_controller1.analog_gain_controller.enable_digital_adaptive = *settings_.analog_agc_disable_digital_adaptive; @@ -534,7 +602,9 @@ void AudioProcessingSimulator::ConfigureAudioProcessor() { ap_->ApplyConfig(apm_config); if (settings_.use_ts) { - ap_->set_stream_key_pressed(*settings_.use_ts); + // Default to key pressed if activating the transient suppressor with + // continuous key events. + ap_->set_stream_key_pressed(*settings_.use_ts == 2); } if (settings_.aec_dump_output_filename) { diff --git a/modules/audio_processing/test/audio_processing_simulator.h b/modules/audio_processing/test/audio_processing_simulator.h index 63e644a9fa..9539e58b1b 100644 --- a/modules/audio_processing/test/audio_processing_simulator.h +++ b/modules/audio_processing/test/audio_processing_simulator.h @@ -99,14 +99,15 @@ struct SimulationSettings { absl::optional use_agc; absl::optional use_agc2; absl::optional use_pre_amplifier; + absl::optional use_capture_level_adjustment; + absl::optional use_analog_mic_gain_emulation; absl::optional use_hpf; absl::optional use_ns; - absl::optional use_ts; + absl::optional use_ts; absl::optional use_analog_agc; absl::optional use_vad; absl::optional use_le; absl::optional use_all; - absl::optional use_analog_agc_agc2_level_estimator; absl::optional analog_agc_disable_digital_adaptive; absl::optional agc_mode; absl::optional agc_target_level; @@ -117,6 +118,9 @@ struct SimulationSettings { AudioProcessing::Config::GainController2::LevelEstimator agc2_adaptive_level_estimator; absl::optional pre_amplifier_gain_factor; + absl::optional pre_gain_factor; + absl::optional post_gain_factor; + absl::optional analog_mic_gain_emulation_initial_level; absl::optional ns_level; absl::optional ns_analysis_on_linear_aec_output; absl::optional maximum_internal_processing_rate; @@ -125,6 +129,8 @@ struct SimulationSettings { absl::optional multi_channel_render; absl::optional multi_channel_capture; absl::optional simulated_mic_kind; + absl::optional frame_for_sending_capture_output_used_false; + absl::optional frame_for_sending_capture_output_used_true; bool report_performance = false; absl::optional performance_report_output_filename; bool report_bitexactness = false; @@ -139,11 +145,16 @@ struct SimulationSettings { bool dump_internal_data = false; WavFile::SampleFormat wav_output_format = WavFile::SampleFormat::kInt16; absl::optional dump_internal_data_output_dir; + absl::optional dump_set_to_use; absl::optional call_order_input_filename; absl::optional call_order_output_filename; absl::optional aec_settings_filename; absl::optional aec_dump_input_string; std::vector* processed_capture_samples = nullptr; + bool analysis_only = false; + absl::optional dump_start_frame; + absl::optional dump_end_frame; + absl::optional init_to_process; }; // Provides common functionality for performing 
audioprocessing simulations. @@ -167,6 +178,9 @@ class AudioProcessingSimulator { return api_call_statistics_; } + // Analyzes the data in the input and reports the resulting statistics. + virtual void Analyze() = 0; + // Reports whether the processed recording was bitexact. bool OutputWasBitexact() { return bitexact_output_; } @@ -188,6 +202,8 @@ class AudioProcessingSimulator { int output_num_channels, int reverse_input_num_channels, int reverse_output_num_channels); + void SelectivelyToggleDataDumping(int init_index, + int capture_frames_since_init) const; const SimulationSettings settings_; rtc::scoped_refptr ap_; diff --git a/modules/audio_processing/test/audioproc_float_impl.cc b/modules/audio_processing/test/audioproc_float_impl.cc index ab395f1018..1fc39bb6b9 100644 --- a/modules/audio_processing/test/audioproc_float_impl.cc +++ b/modules/audio_processing/test/audioproc_float_impl.cc @@ -65,11 +65,11 @@ ABSL_FLAG(bool, ABSL_FLAG(int, aec, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the echo canceller"); + "Activate (1) or deactivate (0) the echo canceller"); ABSL_FLAG(int, aecm, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the mobile echo controller"); + "Activate (1) or deactivate (0) the mobile echo controller"); ABSL_FLAG(int, ed, kParameterNotSpecifiedValue, @@ -81,39 +81,50 @@ ABSL_FLAG(std::string, ABSL_FLAG(int, agc, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the AGC"); + "Activate (1) or deactivate (0) the AGC"); ABSL_FLAG(int, agc2, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the AGC2"); + "Activate (1) or deactivate (0) the AGC2"); ABSL_FLAG(int, pre_amplifier, kParameterNotSpecifiedValue, "Activate (1) or deactivate(0) the pre amplifier"); +ABSL_FLAG( + int, + capture_level_adjustment, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate(0) the capture level adjustment functionality"); +ABSL_FLAG(int, + analog_mic_gain_emulation, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate(0) the analog mic gain emulation in the " + "production (non-test) code."); ABSL_FLAG(int, hpf, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the high-pass filter"); + "Activate (1) or deactivate (0) the high-pass filter"); ABSL_FLAG(int, ns, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the noise suppressor"); + "Activate (1) or deactivate (0) the noise suppressor"); ABSL_FLAG(int, ts, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the transient suppressor"); + "Activate (1), deactivate (0) or activate the transient suppressor " + "with continuous key events (2)"); ABSL_FLAG(int, analog_agc, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the transient suppressor"); + "Activate (1) or deactivate (0) the analog AGC"); ABSL_FLAG(int, vad, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the voice activity detector"); + "Activate (1) or deactivate (0) the voice activity detector"); ABSL_FLAG(int, le, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the level estimator"); + "Activate (1) or deactivate (0) the level estimator"); ABSL_FLAG(bool, all_default, false, @@ -124,11 +135,6 @@ ABSL_FLAG(int, kParameterNotSpecifiedValue, "Force-deactivate (1) digital adaptation in " "experimental AGC. Digital adaptation is active by default (0)."); -ABSL_FLAG(int, - analog_agc_agc2_level_estimator, - kParameterNotSpecifiedValue, - "AGC2 level estimation" - " in the experimental AGC. 
AGC1 level estimation is the default (0)"); ABSL_FLAG(int, agc_mode, kParameterNotSpecifiedValue, @@ -140,7 +146,7 @@ ABSL_FLAG(int, ABSL_FLAG(int, agc_limiter, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the level estimator"); + "Activate (1) or deactivate (0) the level estimator"); ABSL_FLAG(int, agc_compression_gain, kParameterNotSpecifiedValue, @@ -148,7 +154,7 @@ ABSL_FLAG(int, ABSL_FLAG(int, agc2_enable_adaptive_gain, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) the AGC2 adaptive gain"); + "Activate (1) or deactivate (0) the AGC2 adaptive gain"); ABSL_FLAG(float, agc2_fixed_gain_db, kParameterNotSpecifiedValue, @@ -161,6 +167,19 @@ ABSL_FLAG(float, pre_amplifier_gain_factor, kParameterNotSpecifiedValue, "Pre-amplifier gain factor (linear) to apply"); +ABSL_FLAG(float, + pre_gain_factor, + kParameterNotSpecifiedValue, + "Pre-gain factor (linear) to apply in the capture level adjustment"); +ABSL_FLAG(float, + post_gain_factor, + kParameterNotSpecifiedValue, + "Post-gain factor (linear) to apply in the capture level adjustment"); +ABSL_FLAG(float, + analog_mic_gain_emulation_initial_level, + kParameterNotSpecifiedValue, + "Emulated analog mic level to apply initially in the production " + "(non-test) code."); ABSL_FLAG(int, ns_level, kParameterNotSpecifiedValue, @@ -182,30 +201,45 @@ ABSL_FLAG(int, ABSL_FLAG(int, use_stream_delay, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) reporting the stream delay"); + "Activate (1) or deactivate (0) reporting the stream delay"); ABSL_FLAG(int, stream_drift_samples, kParameterNotSpecifiedValue, "Specify the number of stream drift samples to use"); -ABSL_FLAG(int, initial_mic_level, 100, "Initial mic level (0-255)"); +ABSL_FLAG(int, + initial_mic_level, + 100, + "Initial mic level (0-255) for the analog mic gain simulation in the " + "test code"); ABSL_FLAG(int, simulate_mic_gain, 0, - "Activate (1) or deactivate(0) the analog mic gain simulation"); + "Activate (1) or deactivate(0) the analog mic gain simulation in the " + "test code"); ABSL_FLAG(int, multi_channel_render, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) multi-channel render processing in " + "Activate (1) or deactivate (0) multi-channel render processing in " "APM pipeline"); ABSL_FLAG(int, multi_channel_capture, kParameterNotSpecifiedValue, - "Activate (1) or deactivate(0) multi-channel capture processing in " + "Activate (1) or deactivate (0) multi-channel capture processing in " "APM pipeline"); ABSL_FLAG(int, simulated_mic_kind, kParameterNotSpecifiedValue, "Specify which microphone kind to use for microphone simulation"); +ABSL_FLAG(int, + frame_for_sending_capture_output_used_false, + kParameterNotSpecifiedValue, + "Capture frame index for sending a runtime setting for that the " + "capture output is not used."); +ABSL_FLAG(int, + frame_for_sending_capture_output_used_true, + kParameterNotSpecifiedValue, + "Capture frame index for sending a runtime setting for that the " + "capture output is used."); ABSL_FLAG(bool, performance_report, false, "Report the APM performance "); ABSL_FLAG(std::string, performance_report_output_file, @@ -252,6 +286,36 @@ ABSL_FLAG(std::string, dump_data_output_dir, "", "Internal data dump output directory"); +ABSL_FLAG(int, + dump_set_to_use, + kParameterNotSpecifiedValue, + "Specifies the dump set to use (if not all the dump sets will " + "be used"); +ABSL_FLAG(bool, + analyze, + false, + "Only analyze the call setup behavior (no processing)"); +ABSL_FLAG(float, + dump_start_seconds, 
+ kParameterNotSpecifiedValue, + "Start of when to dump data (seconds)."); +ABSL_FLAG(float, + dump_end_seconds, + kParameterNotSpecifiedValue, + "End of when to dump data (seconds)."); +ABSL_FLAG(int, + dump_start_frame, + kParameterNotSpecifiedValue, + "Start of when to dump data (frames)."); +ABSL_FLAG(int, + dump_end_frame, + kParameterNotSpecifiedValue, + "End of when to dump data (frames)."); +ABSL_FLAG(int, + init_to_process, + kParameterNotSpecifiedValue, + "Init index to process."); + ABSL_FLAG(bool, float_wav_output, false, @@ -378,17 +442,19 @@ SimulationSettings CreateSettings() { SetSettingIfFlagSet(absl::GetFlag(FLAGS_agc2), &settings.use_agc2); SetSettingIfFlagSet(absl::GetFlag(FLAGS_pre_amplifier), &settings.use_pre_amplifier); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_capture_level_adjustment), + &settings.use_capture_level_adjustment); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_analog_mic_gain_emulation), + &settings.use_analog_mic_gain_emulation); SetSettingIfFlagSet(absl::GetFlag(FLAGS_hpf), &settings.use_hpf); SetSettingIfFlagSet(absl::GetFlag(FLAGS_ns), &settings.use_ns); - SetSettingIfFlagSet(absl::GetFlag(FLAGS_ts), &settings.use_ts); + SetSettingIfSpecified(absl::GetFlag(FLAGS_ts), &settings.use_ts); SetSettingIfFlagSet(absl::GetFlag(FLAGS_analog_agc), &settings.use_analog_agc); SetSettingIfFlagSet(absl::GetFlag(FLAGS_vad), &settings.use_vad); SetSettingIfFlagSet(absl::GetFlag(FLAGS_le), &settings.use_le); SetSettingIfFlagSet(absl::GetFlag(FLAGS_analog_agc_disable_digital_adaptive), &settings.analog_agc_disable_digital_adaptive); - SetSettingIfFlagSet(absl::GetFlag(FLAGS_analog_agc_agc2_level_estimator), - &settings.use_analog_agc_agc2_level_estimator); SetSettingIfSpecified(absl::GetFlag(FLAGS_agc_mode), &settings.agc_mode); SetSettingIfSpecified(absl::GetFlag(FLAGS_agc_target_level), &settings.agc_target_level); @@ -398,12 +464,20 @@ SimulationSettings CreateSettings() { &settings.agc_compression_gain); SetSettingIfFlagSet(absl::GetFlag(FLAGS_agc2_enable_adaptive_gain), &settings.agc2_use_adaptive_gain); + SetSettingIfSpecified(absl::GetFlag(FLAGS_agc2_fixed_gain_db), &settings.agc2_fixed_gain_db); settings.agc2_adaptive_level_estimator = MapAgc2AdaptiveLevelEstimator( absl::GetFlag(FLAGS_agc2_adaptive_level_estimator)); SetSettingIfSpecified(absl::GetFlag(FLAGS_pre_amplifier_gain_factor), &settings.pre_amplifier_gain_factor); + SetSettingIfSpecified(absl::GetFlag(FLAGS_pre_gain_factor), + &settings.pre_gain_factor); + SetSettingIfSpecified(absl::GetFlag(FLAGS_post_gain_factor), + &settings.post_gain_factor); + SetSettingIfSpecified( + absl::GetFlag(FLAGS_analog_mic_gain_emulation_initial_level), + &settings.analog_mic_gain_emulation_initial_level); SetSettingIfSpecified(absl::GetFlag(FLAGS_ns_level), &settings.ns_level); SetSettingIfFlagSet(absl::GetFlag(FLAGS_ns_analysis_on_linear_aec_output), &settings.ns_analysis_on_linear_aec_output); @@ -427,6 +501,12 @@ SimulationSettings CreateSettings() { settings.simulate_mic_gain = absl::GetFlag(FLAGS_simulate_mic_gain); SetSettingIfSpecified(absl::GetFlag(FLAGS_simulated_mic_kind), &settings.simulated_mic_kind); + SetSettingIfSpecified( + absl::GetFlag(FLAGS_frame_for_sending_capture_output_used_false), + &settings.frame_for_sending_capture_output_used_false); + SetSettingIfSpecified( + absl::GetFlag(FLAGS_frame_for_sending_capture_output_used_true), + &settings.frame_for_sending_capture_output_used_true); settings.report_performance = absl::GetFlag(FLAGS_performance_report); 
SetSettingIfSpecified(absl::GetFlag(FLAGS_performance_report_output_file), &settings.performance_report_output_filename); @@ -443,10 +523,36 @@ SimulationSettings CreateSettings() { settings.dump_internal_data = absl::GetFlag(FLAGS_dump_data); SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_data_output_dir), &settings.dump_internal_data_output_dir); + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_set_to_use), + &settings.dump_set_to_use); settings.wav_output_format = absl::GetFlag(FLAGS_float_wav_output) ? WavFile::SampleFormat::kFloat : WavFile::SampleFormat::kInt16; + settings.analysis_only = absl::GetFlag(FLAGS_analyze); + + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_start_frame), + &settings.dump_start_frame); + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_end_frame), + &settings.dump_end_frame); + + constexpr int kFramesPerSecond = 100; + absl::optional start_seconds; + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_start_seconds), + &start_seconds); + if (start_seconds) { + settings.dump_start_frame = *start_seconds * kFramesPerSecond; + } + + absl::optional end_seconds; + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_end_seconds), &end_seconds); + if (end_seconds) { + settings.dump_end_frame = *end_seconds * kFramesPerSecond; + } + + SetSettingIfSpecified(absl::GetFlag(FLAGS_init_to_process), + &settings.init_to_process); + return settings; } @@ -612,6 +718,18 @@ void PerformBasicParameterSanityChecks( WEBRTC_APM_DEBUG_DUMP == 0 && settings.dump_internal_data, "Error: --dump_data cannot be set without proper build support.\n"); + ReportConditionalErrorAndExit(settings.init_to_process && + *settings.init_to_process != 1 && + !settings.aec_dump_input_filename, + "Error: --init_to_process must be set to 1 for " + "wav-file based simulations.\n"); + + ReportConditionalErrorAndExit( + !settings.init_to_process && + (settings.dump_start_frame || settings.dump_end_frame), + "Error: --init_to_process must be set when specifying a start and/or end " + "frame for when to dump internal data.\n"); + ReportConditionalErrorAndExit( !settings.dump_internal_data && settings.dump_internal_data_output_dir.has_value(), @@ -684,7 +802,11 @@ int RunSimulation(rtc::scoped_refptr audio_processing, std::move(ap_builder))); } - processor->Process(); + if (settings.analysis_only) { + processor->Analyze(); + } else { + processor->Process(); + } if (settings.report_performance) { processor->GetApiCallStatistics().PrintReport(); diff --git a/modules/audio_processing/test/conversational_speech/BUILD.gn b/modules/audio_processing/test/conversational_speech/BUILD.gn index b311abdbd1..42707afda7 100644 --- a/modules/audio_processing/test/conversational_speech/BUILD.gn +++ b/modules/audio_processing/test/conversational_speech/BUILD.gn @@ -8,21 +8,23 @@ import("../../../../webrtc.gni") -group("conversational_speech") { - testonly = true - deps = [ ":conversational_speech_generator" ] -} +if (!build_with_chromium) { + group("conversational_speech") { + testonly = true + deps = [ ":conversational_speech_generator" ] + } -rtc_executable("conversational_speech_generator") { - testonly = true - sources = [ "generator.cc" ] - deps = [ - ":lib", - "../../../../test:fileutils", - "../../../../test:test_support", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] + rtc_executable("conversational_speech_generator") { + testonly = true + sources = [ "generator.cc" ] + deps = [ + ":lib", + "../../../../test:fileutils", + "../../../../test:test_support", + 
"//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } } rtc_library("lib") { diff --git a/modules/audio_processing/test/echo_control_mock.h b/modules/audio_processing/test/echo_control_mock.h index 927de43ae0..763d6e4f0b 100644 --- a/modules/audio_processing/test/echo_control_mock.h +++ b/modules/audio_processing/test/echo_control_mock.h @@ -34,6 +34,10 @@ class MockEchoControl : public EchoControl { (override)); MOCK_METHOD(EchoControl::Metrics, GetMetrics, (), (const, override)); MOCK_METHOD(void, SetAudioBufferDelay, (int delay_ms), (override)); + MOCK_METHOD(void, + SetCaptureOutputUsage, + (bool capture_output_used), + (override)); MOCK_METHOD(bool, ActiveProcessing, (), (const, override)); }; diff --git a/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/modules/audio_processing/test/py_quality_assessment/BUILD.gn index fe7c444a81..9ec86d17ec 100644 --- a/modules/audio_processing/test/py_quality_assessment/BUILD.gn +++ b/modules/audio_processing/test/py_quality_assessment/BUILD.gn @@ -8,161 +8,163 @@ import("../../../../webrtc.gni") -group("py_quality_assessment") { - testonly = true - deps = [ - ":scripts", - ":unit_tests", - ] -} +if (!build_with_chromium) { + group("py_quality_assessment") { + testonly = true + deps = [ + ":scripts", + ":unit_tests", + ] + } -copy("scripts") { - testonly = true - sources = [ - "README.md", - "apm_quality_assessment.py", - "apm_quality_assessment.sh", - "apm_quality_assessment_boxplot.py", - "apm_quality_assessment_export.py", - "apm_quality_assessment_gencfgs.py", - "apm_quality_assessment_optimize.py", - ] - outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ] - deps = [ - ":apm_configs", - ":lib", - ":output", - "../../../../resources/audio_processing/test/py_quality_assessment:probing_signals", - "../../../../rtc_tools:audioproc_f", - ] -} + copy("scripts") { + testonly = true + sources = [ + "README.md", + "apm_quality_assessment.py", + "apm_quality_assessment.sh", + "apm_quality_assessment_boxplot.py", + "apm_quality_assessment_export.py", + "apm_quality_assessment_gencfgs.py", + "apm_quality_assessment_optimize.py", + ] + outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ] + deps = [ + ":apm_configs", + ":lib", + ":output", + "../../../../resources/audio_processing/test/py_quality_assessment:probing_signals", + "../../../../rtc_tools:audioproc_f", + ] + } -copy("apm_configs") { - testonly = true - sources = [ "apm_configs/default.json" ] - visibility = [ ":*" ] # Only targets in this file can depend on this. - outputs = [ - "$root_build_dir/py_quality_assessment/apm_configs/{{source_file_part}}", - ] -} # apm_configs + copy("apm_configs") { + testonly = true + sources = [ "apm_configs/default.json" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. 
+ outputs = [ + "$root_build_dir/py_quality_assessment/apm_configs/{{source_file_part}}", + ] + } # apm_configs -copy("lib") { - testonly = true - sources = [ - "quality_assessment/__init__.py", - "quality_assessment/annotations.py", - "quality_assessment/audioproc_wrapper.py", - "quality_assessment/collect_data.py", - "quality_assessment/data_access.py", - "quality_assessment/echo_path_simulation.py", - "quality_assessment/echo_path_simulation_factory.py", - "quality_assessment/eval_scores.py", - "quality_assessment/eval_scores_factory.py", - "quality_assessment/evaluation.py", - "quality_assessment/exceptions.py", - "quality_assessment/export.py", - "quality_assessment/export_unittest.py", - "quality_assessment/external_vad.py", - "quality_assessment/input_mixer.py", - "quality_assessment/input_signal_creator.py", - "quality_assessment/results.css", - "quality_assessment/results.js", - "quality_assessment/signal_processing.py", - "quality_assessment/simulation.py", - "quality_assessment/test_data_generation.py", - "quality_assessment/test_data_generation_factory.py", - ] - visibility = [ ":*" ] # Only targets in this file can depend on this. - outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ] - deps = [ "../../../../resources/audio_processing/test/py_quality_assessment:noise_tracks" ] -} + copy("lib") { + testonly = true + sources = [ + "quality_assessment/__init__.py", + "quality_assessment/annotations.py", + "quality_assessment/audioproc_wrapper.py", + "quality_assessment/collect_data.py", + "quality_assessment/data_access.py", + "quality_assessment/echo_path_simulation.py", + "quality_assessment/echo_path_simulation_factory.py", + "quality_assessment/eval_scores.py", + "quality_assessment/eval_scores_factory.py", + "quality_assessment/evaluation.py", + "quality_assessment/exceptions.py", + "quality_assessment/export.py", + "quality_assessment/export_unittest.py", + "quality_assessment/external_vad.py", + "quality_assessment/input_mixer.py", + "quality_assessment/input_signal_creator.py", + "quality_assessment/results.css", + "quality_assessment/results.js", + "quality_assessment/signal_processing.py", + "quality_assessment/simulation.py", + "quality_assessment/test_data_generation.py", + "quality_assessment/test_data_generation_factory.py", + ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ] + deps = [ "../../../../resources/audio_processing/test/py_quality_assessment:noise_tracks" ] + } -copy("output") { - testonly = true - sources = [ "output/README.md" ] - visibility = [ ":*" ] # Only targets in this file can depend on this. - outputs = - [ "$root_build_dir/py_quality_assessment/output/{{source_file_part}}" ] -} + copy("output") { + testonly = true + sources = [ "output/README.md" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = + [ "$root_build_dir/py_quality_assessment/output/{{source_file_part}}" ] + } -group("unit_tests") { - testonly = true - visibility = [ ":*" ] # Only targets in this file can depend on this. - deps = [ - ":apm_vad", - ":fake_polqa", - ":lib_unit_tests", - ":scripts_unit_tests", - ":vad", - ] -} + group("unit_tests") { + testonly = true + visibility = [ ":*" ] # Only targets in this file can depend on this. 
+ deps = [ + ":apm_vad", + ":fake_polqa", + ":lib_unit_tests", + ":scripts_unit_tests", + ":vad", + ] + } -rtc_executable("fake_polqa") { - testonly = true - sources = [ "quality_assessment/fake_polqa.cc" ] - visibility = [ ":*" ] # Only targets in this file can depend on this. - output_dir = "${root_out_dir}/py_quality_assessment/quality_assessment" - deps = [ - "../../../../rtc_base:checks", - "../../../../rtc_base:rtc_base_approved", - ] -} + rtc_executable("fake_polqa") { + testonly = true + sources = [ "quality_assessment/fake_polqa.cc" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + output_dir = "${root_out_dir}/py_quality_assessment/quality_assessment" + deps = [ + "../../../../rtc_base:checks", + "../../../../rtc_base:rtc_base_approved", + ] + } -rtc_executable("vad") { - testonly = true - sources = [ "quality_assessment/vad.cc" ] - deps = [ - "../../../../common_audio", - "../../../../rtc_base:rtc_base_approved", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] -} + rtc_executable("vad") { + testonly = true + sources = [ "quality_assessment/vad.cc" ] + deps = [ + "../../../../common_audio", + "../../../../rtc_base:rtc_base_approved", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } -rtc_executable("apm_vad") { - testonly = true - sources = [ "quality_assessment/apm_vad.cc" ] - deps = [ - "../..", - "../../../../common_audio", - "../../../../rtc_base:rtc_base_approved", - "../../vad", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] -} + rtc_executable("apm_vad") { + testonly = true + sources = [ "quality_assessment/apm_vad.cc" ] + deps = [ + "../..", + "../../../../common_audio", + "../../../../rtc_base:rtc_base_approved", + "../../vad", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } -rtc_executable("sound_level") { - testonly = true - sources = [ "quality_assessment/sound_level.cc" ] - deps = [ - "../..", - "../../../../common_audio", - "../../../../rtc_base:rtc_base_approved", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] -} + rtc_executable("sound_level") { + testonly = true + sources = [ "quality_assessment/sound_level.cc" ] + deps = [ + "../..", + "../../../../common_audio", + "../../../../rtc_base:rtc_base_approved", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } -copy("lib_unit_tests") { - testonly = true - sources = [ - "quality_assessment/annotations_unittest.py", - "quality_assessment/echo_path_simulation_unittest.py", - "quality_assessment/eval_scores_unittest.py", - "quality_assessment/fake_external_vad.py", - "quality_assessment/input_mixer_unittest.py", - "quality_assessment/signal_processing_unittest.py", - "quality_assessment/simulation_unittest.py", - "quality_assessment/test_data_generation_unittest.py", - ] - visibility = [ ":*" ] # Only targets in this file can depend on this. 
- outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ] -} + copy("lib_unit_tests") { + testonly = true + sources = [ + "quality_assessment/annotations_unittest.py", + "quality_assessment/echo_path_simulation_unittest.py", + "quality_assessment/eval_scores_unittest.py", + "quality_assessment/fake_external_vad.py", + "quality_assessment/input_mixer_unittest.py", + "quality_assessment/signal_processing_unittest.py", + "quality_assessment/simulation_unittest.py", + "quality_assessment/test_data_generation_unittest.py", + ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ] + } -copy("scripts_unit_tests") { - testonly = true - sources = [ "apm_quality_assessment_unittest.py" ] - visibility = [ ":*" ] # Only targets in this file can depend on this. - outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ] + copy("scripts_unit_tests") { + testonly = true + sources = [ "apm_quality_assessment_unittest.py" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ] + } } diff --git a/modules/audio_processing/test/wav_based_simulator.cc b/modules/audio_processing/test/wav_based_simulator.cc index 6dab469e2b..e6a6fe92eb 100644 --- a/modules/audio_processing/test/wav_based_simulator.cc +++ b/modules/audio_processing/test/wav_based_simulator.cc @@ -14,6 +14,7 @@ #include +#include "modules/audio_processing/logging/apm_data_dumper.h" #include "modules/audio_processing/test/test_utils.h" #include "rtc_base/checks.h" #include "rtc_base/system/file_wrapper.h" @@ -106,12 +107,15 @@ void WavBasedSimulator::Process() { bool samples_left_to_process = true; int call_chain_index = 0; - int num_forward_chunks_processed = 0; + int capture_frames_since_init = 0; + constexpr int kInitIndex = 1; while (samples_left_to_process) { switch (call_chain_[call_chain_index]) { case SimulationEventType::kProcessStream: + SelectivelyToggleDataDumping(kInitIndex, capture_frames_since_init); + samples_left_to_process = HandleProcessStreamCall(); - ++num_forward_chunks_processed; + ++capture_frames_since_init; break; case SimulationEventType::kProcessReverseStream: if (settings_.reverse_input_filename) { @@ -128,6 +132,14 @@ void WavBasedSimulator::Process() { DetachAecDump(); } +void WavBasedSimulator::Analyze() { + std::cout << "Inits:" << std::endl; + std::cout << "1: -->" << std::endl; + std::cout << " Time:" << std::endl; + std::cout << " Capture: 0 s (0 frames) " << std::endl; + std::cout << " Render: 0 s (0 frames)" << std::endl; +} + bool WavBasedSimulator::HandleProcessStreamCall() { bool samples_left_to_process = buffer_reader_->Read(in_buf_.get()); if (samples_left_to_process) { diff --git a/modules/audio_processing/test/wav_based_simulator.h b/modules/audio_processing/test/wav_based_simulator.h index 286ce1f587..ff88fd5535 100644 --- a/modules/audio_processing/test/wav_based_simulator.h +++ b/modules/audio_processing/test/wav_based_simulator.h @@ -34,6 +34,10 @@ class WavBasedSimulator final : public AudioProcessingSimulator { // Processes the WAV input. void Process() override; + // Only analyzes the data for the simulation, instead of perform any + // processing. 
+ void Analyze() override; + private: enum SimulationEventType { kProcessStream, diff --git a/modules/audio_processing/transient/BUILD.gn b/modules/audio_processing/transient/BUILD.gn index 13e319f88e..5f9a13969a 100644 --- a/modules/audio_processing/transient/BUILD.gn +++ b/modules/audio_processing/transient/BUILD.gn @@ -14,10 +14,10 @@ rtc_source_set("transient_suppressor_api") { rtc_library("transient_suppressor_impl") { visibility = [ - "..:optionally_built_submodule_creators", + ":click_annotate", ":transient_suppression_test", ":transient_suppression_unittests", - ":click_annotate", + "..:optionally_built_submodule_creators", ] sources = [ "common.h", @@ -49,42 +49,44 @@ rtc_library("transient_suppressor_impl") { } if (rtc_include_tests) { - rtc_executable("click_annotate") { - testonly = true - sources = [ - "click_annotate.cc", - "file_utils.cc", - "file_utils.h", - ] - deps = [ - ":transient_suppressor_impl", - "..:audio_processing", - "../../../rtc_base/system:file_wrapper", - "../../../system_wrappers", - ] - } + if (!build_with_chromium) { + rtc_executable("click_annotate") { + testonly = true + sources = [ + "click_annotate.cc", + "file_utils.cc", + "file_utils.h", + ] + deps = [ + ":transient_suppressor_impl", + "..:audio_processing", + "../../../rtc_base/system:file_wrapper", + "../../../system_wrappers", + ] + } - rtc_executable("transient_suppression_test") { - testonly = true - sources = [ - "file_utils.cc", - "file_utils.h", - "transient_suppression_test.cc", - ] - deps = [ - ":transient_suppressor_impl", - "..:audio_processing", - "../../../common_audio", - "../../../rtc_base:rtc_base_approved", - "../../../rtc_base/system:file_wrapper", - "../../../system_wrappers", - "../../../test:fileutils", - "../../../test:test_support", - "../agc:level_estimation", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] + rtc_executable("transient_suppression_test") { + testonly = true + sources = [ + "file_utils.cc", + "file_utils.h", + "transient_suppression_test.cc", + ] + deps = [ + ":transient_suppressor_impl", + "..:audio_processing", + "../../../common_audio", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base/system:file_wrapper", + "../../../system_wrappers", + "../../../test:fileutils", + "../../../test:test_support", + "../agc:level_estimation", + "//testing/gtest", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } } rtc_library("transient_suppression_unittests") { diff --git a/modules/congestion_controller/BUILD.gn b/modules/congestion_controller/BUILD.gn index 231ff5e0dd..c0b064d9ed 100644 --- a/modules/congestion_controller/BUILD.gn +++ b/modules/congestion_controller/BUILD.gn @@ -22,12 +22,17 @@ rtc_library("congestion_controller") { sources = [ "include/receive_side_congestion_controller.h", "receive_side_congestion_controller.cc", + "remb_throttler.cc", + "remb_throttler.h", ] deps = [ "..:module_api", "../../api/transport:field_trial_based_config", "../../api/transport:network_control", + "../../api/units:data_rate", + "../../api/units:time_delta", + "../../api/units:timestamp", "../../rtc_base/synchronization:mutex", "../pacing", "../remote_bitrate_estimator", @@ -39,13 +44,21 @@ rtc_library("congestion_controller") { } } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_library("congestion_controller_unittests") { testonly = true - sources = [ "receive_side_congestion_controller_unittest.cc" ] + sources = [ + 
"receive_side_congestion_controller_unittest.cc", + "remb_throttler_unittest.cc", + ] deps = [ ":congestion_controller", + "../../api/test/network_emulation", + "../../api/test/network_emulation:create_cross_traffic", + "../../api/units:data_rate", + "../../api/units:time_delta", + "../../api/units:timestamp", "../../system_wrappers", "../../test:test_support", "../../test/scenario", diff --git a/modules/congestion_controller/goog_cc/BUILD.gn b/modules/congestion_controller/goog_cc/BUILD.gn index e3be246347..ea20da87a3 100644 --- a/modules/congestion_controller/goog_cc/BUILD.gn +++ b/modules/congestion_controller/goog_cc/BUILD.gn @@ -226,10 +226,10 @@ rtc_library("probe_controller") { "../../../rtc_base:macromagic", "../../../rtc_base:safe_conversions", "../../../rtc_base/experiments:field_trial_parser", - "../../../rtc_base/system:unused", "../../../system_wrappers:metrics", ] absl_deps = [ + "//third_party/abseil-cpp/absl/base:core_headers", "//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/types:optional", ] @@ -257,51 +257,55 @@ if (rtc_include_tests) { ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } - rtc_library("goog_cc_unittests") { - testonly = true + if (!build_with_chromium) { + rtc_library("goog_cc_unittests") { + testonly = true - sources = [ - "acknowledged_bitrate_estimator_unittest.cc", - "alr_detector_unittest.cc", - "congestion_window_pushback_controller_unittest.cc", - "delay_based_bwe_unittest.cc", - "delay_based_bwe_unittest_helper.cc", - "delay_based_bwe_unittest_helper.h", - "goog_cc_network_control_unittest.cc", - "probe_bitrate_estimator_unittest.cc", - "probe_controller_unittest.cc", - "robust_throughput_estimator_unittest.cc", - "send_side_bandwidth_estimation_unittest.cc", - "trendline_estimator_unittest.cc", - ] - deps = [ - ":alr_detector", - ":delay_based_bwe", - ":estimators", - ":goog_cc", - ":loss_based_controller", - ":probe_controller", - ":pushback_controller", - "../../../api/rtc_event_log", - "../../../api/transport:field_trial_based_config", - "../../../api/transport:goog_cc", - "../../../api/transport:network_control", - "../../../api/transport:webrtc_key_value_config", - "../../../api/units:data_rate", - "../../../api/units:timestamp", - "../../../logging:mocks", - "../../../logging:rtc_event_bwe", - "../../../rtc_base:checks", - "../../../rtc_base:rtc_base_approved", - "../../../rtc_base:rtc_base_tests_utils", - "../../../rtc_base/experiments:alr_experiment", - "../../../system_wrappers", - "../../../test:explicit_key_value_config", - "../../../test:field_trial", - "../../../test:test_support", - "../../../test/scenario", - "../../pacing", - "//testing/gmock", - ] + sources = [ + "acknowledged_bitrate_estimator_unittest.cc", + "alr_detector_unittest.cc", + "congestion_window_pushback_controller_unittest.cc", + "delay_based_bwe_unittest.cc", + "delay_based_bwe_unittest_helper.cc", + "delay_based_bwe_unittest_helper.h", + "goog_cc_network_control_unittest.cc", + "probe_bitrate_estimator_unittest.cc", + "probe_controller_unittest.cc", + "robust_throughput_estimator_unittest.cc", + "send_side_bandwidth_estimation_unittest.cc", + "trendline_estimator_unittest.cc", + ] + deps = [ + ":alr_detector", + ":delay_based_bwe", + ":estimators", + ":goog_cc", + ":loss_based_controller", + ":probe_controller", + ":pushback_controller", + "../../../api/rtc_event_log", + "../../../api/test/network_emulation", + "../../../api/test/network_emulation:create_cross_traffic", + "../../../api/transport:field_trial_based_config", + 
"../../../api/transport:goog_cc", + "../../../api/transport:network_control", + "../../../api/transport:webrtc_key_value_config", + "../../../api/units:data_rate", + "../../../api/units:timestamp", + "../../../logging:mocks", + "../../../logging:rtc_event_bwe", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base:rtc_base_tests_utils", + "../../../rtc_base/experiments:alr_experiment", + "../../../system_wrappers", + "../../../test:explicit_key_value_config", + "../../../test:field_trial", + "../../../test:test_support", + "../../../test/scenario", + "../../pacing", + "//testing/gmock", + ] + } } } diff --git a/modules/congestion_controller/goog_cc/goog_cc_network_control.cc b/modules/congestion_controller/goog_cc/goog_cc_network_control.cc index 0a0b1801f2..2344f45a65 100644 --- a/modules/congestion_controller/goog_cc/goog_cc_network_control.cc +++ b/modules/congestion_controller/goog_cc/goog_cc_network_control.cc @@ -465,7 +465,7 @@ NetworkControlUpdate GoogCcNetworkController::OnTransportPacketsFeedback( expected_packets_since_last_loss_update_ += report.PacketsWithFeedback().size(); for (const auto& packet_feedback : report.PacketsWithFeedback()) { - if (packet_feedback.receive_time.IsInfinite()) + if (!packet_feedback.IsReceived()) lost_packets_since_last_loss_update_ += 1; } if (report.feedback_time > next_loss_update_) { diff --git a/modules/congestion_controller/goog_cc/goog_cc_network_control_unittest.cc b/modules/congestion_controller/goog_cc/goog_cc_network_control_unittest.cc index 0510cb99b7..7e8d7b9ac6 100644 --- a/modules/congestion_controller/goog_cc/goog_cc_network_control_unittest.cc +++ b/modules/congestion_controller/goog_cc/goog_cc_network_control_unittest.cc @@ -10,6 +10,8 @@ #include +#include "api/test/network_emulation/create_cross_traffic.h" +#include "api/test/network_emulation/cross_traffic.h" #include "api/transport/goog_cc_factory.h" #include "api/units/data_rate.h" #include "logging/rtc_event_log/mock/mock_rtc_event_log.h" @@ -122,6 +124,35 @@ void UpdatesTargetRateBasedOnLinkCapacity(std::string test_name = "") { truth->PrintRow(); EXPECT_NEAR(client->target_rate().kbps(), 90, 25); } + +DataRate RunRembDipScenario(std::string test_name) { + Scenario s(test_name); + NetworkSimulationConfig net_conf; + net_conf.bandwidth = DataRate::KilobitsPerSec(2000); + net_conf.delay = TimeDelta::Millis(50); + auto* client = s.CreateClient("send", [&](CallClientConfig* c) { + c->transport.rates.start_rate = DataRate::KilobitsPerSec(1000); + }); + auto send_net = {s.CreateSimulationNode(net_conf)}; + auto ret_net = {s.CreateSimulationNode(net_conf)}; + auto* route = s.CreateRoutes( + client, send_net, s.CreateClient("return", CallClientConfig()), ret_net); + s.CreateVideoStream(route->forward(), VideoStreamConfig()); + + s.RunFor(TimeDelta::Seconds(10)); + EXPECT_GT(client->send_bandwidth().kbps(), 1500); + + DataRate RembLimit = DataRate::KilobitsPerSec(250); + client->SetRemoteBitrate(RembLimit); + s.RunFor(TimeDelta::Seconds(1)); + EXPECT_EQ(client->send_bandwidth(), RembLimit); + + DataRate RembLimitLifted = DataRate::KilobitsPerSec(10000); + client->SetRemoteBitrate(RembLimitLifted); + s.RunFor(TimeDelta::Seconds(10)); + + return client->send_bandwidth(); +} } // namespace class GoogCcNetworkControllerTest : public ::testing::Test { @@ -547,8 +578,9 @@ DataRate AverageBitrateAfterCrossInducedLoss(std::string name) { s.RunFor(TimeDelta::Seconds(10)); for (int i = 0; i < 4; ++i) { // Sends TCP cross traffic inducing loss. 
- auto* tcp_traffic = - s.net()->StartFakeTcpCrossTraffic(send_net, ret_net, FakeTcpConfig()); + auto* tcp_traffic = s.net()->StartCrossTraffic(CreateFakeTcpCrossTraffic( + s.net()->CreateRoute(send_net), s.net()->CreateRoute(ret_net), + FakeTcpConfig())); s.RunFor(TimeDelta::Seconds(2)); // Allow the ccongestion controller to recover. s.net()->StopCrossTraffic(tcp_traffic); @@ -836,7 +868,9 @@ TEST_F(GoogCcNetworkControllerTest, IsFairToTCP) { auto* route = s.CreateRoutes( client, send_net, s.CreateClient("return", CallClientConfig()), ret_net); s.CreateVideoStream(route->forward(), VideoStreamConfig()); - s.net()->StartFakeTcpCrossTraffic(send_net, ret_net, FakeTcpConfig()); + s.net()->StartCrossTraffic(CreateFakeTcpCrossTraffic( + s.net()->CreateRoute(send_net), s.net()->CreateRoute(ret_net), + FakeTcpConfig())); s.RunFor(TimeDelta::Seconds(10)); // Currently only testing for the upper limit as we in practice back out @@ -845,33 +879,17 @@ TEST_F(GoogCcNetworkControllerTest, IsFairToTCP) { EXPECT_LT(client->send_bandwidth().kbps(), 750); } -TEST(GoogCcScenario, RampupOnRembCapLifted) { +TEST(GoogCcScenario, FastRampupOnRembCapLiftedWithFieldTrial) { ScopedFieldTrials trial("WebRTC-Bwe-ReceiverLimitCapsOnly/Enabled/"); - Scenario s("googcc_unit/rampup_ramb_cap_lifted"); - NetworkSimulationConfig net_conf; - net_conf.bandwidth = DataRate::KilobitsPerSec(2000); - net_conf.delay = TimeDelta::Millis(50); - auto* client = s.CreateClient("send", [&](CallClientConfig* c) { - c->transport.rates.start_rate = DataRate::KilobitsPerSec(1000); - }); - auto send_net = {s.CreateSimulationNode(net_conf)}; - auto ret_net = {s.CreateSimulationNode(net_conf)}; - auto* route = s.CreateRoutes( - client, send_net, s.CreateClient("return", CallClientConfig()), ret_net); - s.CreateVideoStream(route->forward(), VideoStreamConfig()); - - s.RunFor(TimeDelta::Seconds(10)); - EXPECT_GT(client->send_bandwidth().kbps(), 1500); - - DataRate RembLimit = DataRate::KilobitsPerSec(250); - client->SetRemoteBitrate(RembLimit); - s.RunFor(TimeDelta::Seconds(1)); - EXPECT_EQ(client->send_bandwidth(), RembLimit); + DataRate final_estimate = + RunRembDipScenario("googcc_unit/fast_rampup_on_remb_cap_lifted"); + EXPECT_GT(final_estimate.kbps(), 1500); +} - DataRate RembLimitLifted = DataRate::KilobitsPerSec(10000); - client->SetRemoteBitrate(RembLimitLifted); - s.RunFor(TimeDelta::Seconds(10)); - EXPECT_GT(client->send_bandwidth().kbps(), 1500); +TEST(GoogCcScenario, SlowRampupOnRembCapLifted) { + DataRate final_estimate = + RunRembDipScenario("googcc_unit/default_slow_rampup_on_remb_cap_lifted"); + EXPECT_LT(final_estimate.kbps(), 1000); } } // namespace test diff --git a/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.cc b/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.cc index 1d2aab8521..c7f53c62f2 100644 --- a/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.cc +++ b/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.cc @@ -14,14 +14,18 @@ #include #include +#include "absl/strings/match.h" #include "api/units/data_rate.h" #include "api/units/time_delta.h" -#include "system_wrappers/include/field_trial.h" namespace webrtc { namespace { const char kBweLossBasedControl[] = "WebRTC-Bwe-LossBasedControl"; +// Expecting RTCP feedback to be sent with roughly 1s intervals, a 5s gap +// indicates a channel outage. +constexpr TimeDelta kMaxRtcpFeedbackInterval = TimeDelta::Millis(5000); + // Increase slower when RTT is high. 
double GetIncreaseFactor(const LossBasedControlConfig& config, TimeDelta rtt) { // Clamp the RTT @@ -32,7 +36,7 @@ double GetIncreaseFactor(const LossBasedControlConfig& config, TimeDelta rtt) { } auto rtt_range = config.increase_high_rtt.Get() - config.increase_low_rtt; if (rtt_range <= TimeDelta::Zero()) { - RTC_DCHECK(false); // Only on misconfiguration. + RTC_NOTREACHED(); // Only on misconfiguration. return config.min_increase_factor; } auto rtt_offset = rtt - config.increase_low_rtt; @@ -53,7 +57,7 @@ DataRate BitrateFromLoss(double loss, DataRate loss_bandwidth_balance, double exponent) { if (exponent <= 0) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return DataRate::Infinity(); } if (loss < 1e-5) @@ -65,16 +69,22 @@ double ExponentialUpdate(TimeDelta window, TimeDelta interval) { // Use the convention that exponential window length (which is really // infinite) is the time it takes to dampen to 1/e. if (window <= TimeDelta::Zero()) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return 1.0f; } return 1.0f - exp(interval / window * -1.0); } +bool IsEnabled(const webrtc::WebRtcKeyValueConfig& key_value_config, + absl::string_view name) { + return absl::StartsWith(key_value_config.Lookup(name), "Enabled"); +} + } // namespace -LossBasedControlConfig::LossBasedControlConfig() - : enabled(field_trial::IsEnabled(kBweLossBasedControl)), +LossBasedControlConfig::LossBasedControlConfig( + const WebRtcKeyValueConfig* key_value_config) + : enabled(IsEnabled(*key_value_config, kBweLossBasedControl)), min_increase_factor("min_incr", 1.02), max_increase_factor("max_incr", 1.08), increase_low_rtt("incr_low_rtt", TimeDelta::Millis(200)), @@ -88,26 +98,28 @@ LossBasedControlConfig::LossBasedControlConfig() DataRate::KilobitsPerSec(0.5)), loss_bandwidth_balance_decrease("balance_decr", DataRate::KilobitsPerSec(4)), + loss_bandwidth_balance_reset("balance_reset", + DataRate::KilobitsPerSec(0.1)), loss_bandwidth_balance_exponent("exponent", 0.5), allow_resets("resets", false), decrease_interval("decr_intvl", TimeDelta::Millis(300)), loss_report_timeout("timeout", TimeDelta::Millis(6000)) { - std::string trial_string = field_trial::FindFullName(kBweLossBasedControl); ParseFieldTrial( {&min_increase_factor, &max_increase_factor, &increase_low_rtt, &increase_high_rtt, &decrease_factor, &loss_window, &loss_max_window, &acknowledged_rate_max_window, &increase_offset, &loss_bandwidth_balance_increase, &loss_bandwidth_balance_decrease, - &loss_bandwidth_balance_exponent, &allow_resets, &decrease_interval, - &loss_report_timeout}, - trial_string); + &loss_bandwidth_balance_reset, &loss_bandwidth_balance_exponent, + &allow_resets, &decrease_interval, &loss_report_timeout}, + key_value_config->Lookup(kBweLossBasedControl)); } LossBasedControlConfig::LossBasedControlConfig(const LossBasedControlConfig&) = default; LossBasedControlConfig::~LossBasedControlConfig() = default; -LossBasedBandwidthEstimation::LossBasedBandwidthEstimation() - : config_(LossBasedControlConfig()), +LossBasedBandwidthEstimation::LossBasedBandwidthEstimation( + const WebRtcKeyValueConfig* key_value_config) + : config_(key_value_config), average_loss_(0), average_loss_max_(0), loss_based_bitrate_(DataRate::Zero()), @@ -122,12 +134,12 @@ void LossBasedBandwidthEstimation::UpdateLossStatistics( const std::vector& packet_results, Timestamp at_time) { if (packet_results.empty()) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return; } int loss_count = 0; for (const auto& pkt : packet_results) { - loss_count += pkt.receive_time.IsInfinite() ? 
1 : 0; + loss_count += !pkt.IsReceived() ? 1 : 0; } last_loss_ratio_ = static_cast(loss_count) / packet_results.size(); const TimeDelta time_passed = last_loss_packet_report_.IsFinite() @@ -164,9 +176,14 @@ void LossBasedBandwidthEstimation::UpdateAcknowledgedBitrate( } } -void LossBasedBandwidthEstimation::Update(Timestamp at_time, - DataRate min_bitrate, - TimeDelta last_round_trip_time) { +DataRate LossBasedBandwidthEstimation::Update(Timestamp at_time, + DataRate min_bitrate, + DataRate wanted_bitrate, + TimeDelta last_round_trip_time) { + if (loss_based_bitrate_.IsZero()) { + loss_based_bitrate_ = wanted_bitrate; + } + // Only increase if loss has been low for some time. const double loss_estimate_for_increase = average_loss_max_; // Avoid multiple decreases from averaging over one loss spike. @@ -176,8 +193,15 @@ void LossBasedBandwidthEstimation::Update(Timestamp at_time, !has_decreased_since_last_loss_report_ && (at_time - time_last_decrease_ >= last_round_trip_time + config_.decrease_interval); + // If packet lost reports are too old, dont increase bitrate. + const bool loss_report_valid = + at_time - last_loss_packet_report_ < 1.2 * kMaxRtcpFeedbackInterval; - if (loss_estimate_for_increase < loss_increase_threshold()) { + if (loss_report_valid && config_.allow_resets && + loss_estimate_for_increase < loss_reset_threshold()) { + loss_based_bitrate_ = wanted_bitrate; + } else if (loss_report_valid && + loss_estimate_for_increase < loss_increase_threshold()) { // Increase bitrate by RTT-adaptive ratio. DataRate new_increased_bitrate = min_bitrate * GetIncreaseFactor(config_, last_round_trip_time) + @@ -203,14 +227,21 @@ void LossBasedBandwidthEstimation::Update(Timestamp at_time, loss_based_bitrate_ = new_decreased_bitrate; } } + return loss_based_bitrate_; } -void LossBasedBandwidthEstimation::Reset(DataRate bitrate) { +void LossBasedBandwidthEstimation::Initialize(DataRate bitrate) { loss_based_bitrate_ = bitrate; average_loss_ = 0; average_loss_max_ = 0; } +double LossBasedBandwidthEstimation::loss_reset_threshold() const { + return LossFromBitrate(loss_based_bitrate_, + config_.loss_bandwidth_balance_reset, + config_.loss_bandwidth_balance_exponent); +} + double LossBasedBandwidthEstimation::loss_increase_threshold() const { return LossFromBitrate(loss_based_bitrate_, config_.loss_bandwidth_balance_increase, @@ -226,14 +257,4 @@ double LossBasedBandwidthEstimation::loss_decrease_threshold() const { DataRate LossBasedBandwidthEstimation::decreased_bitrate() const { return config_.decrease_factor * acknowledged_bitrate_max_; } - -void LossBasedBandwidthEstimation::MaybeReset(DataRate bitrate) { - if (config_.allow_resets) - Reset(bitrate); -} - -void LossBasedBandwidthEstimation::SetInitialBitrate(DataRate bitrate) { - Reset(bitrate); -} - } // namespace webrtc diff --git a/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.h b/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.h index b63363cadd..20ff092e6f 100644 --- a/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.h +++ b/modules/congestion_controller/goog_cc/loss_based_bandwidth_estimation.h @@ -14,6 +14,7 @@ #include #include "api/transport/network_types.h" +#include "api/transport/webrtc_key_value_config.h" #include "api/units/data_rate.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" @@ -22,7 +23,7 @@ namespace webrtc { struct LossBasedControlConfig { - LossBasedControlConfig(); + explicit LossBasedControlConfig(const WebRtcKeyValueConfig* 
key_value_config); LossBasedControlConfig(const LossBasedControlConfig&); LossBasedControlConfig& operator=(const LossBasedControlConfig&) = default; ~LossBasedControlConfig(); @@ -38,23 +39,34 @@ struct LossBasedControlConfig { FieldTrialParameter increase_offset; FieldTrialParameter loss_bandwidth_balance_increase; FieldTrialParameter loss_bandwidth_balance_decrease; + FieldTrialParameter loss_bandwidth_balance_reset; FieldTrialParameter loss_bandwidth_balance_exponent; FieldTrialParameter allow_resets; FieldTrialParameter decrease_interval; FieldTrialParameter loss_report_timeout; }; +// Estimates an upper BWE limit based on loss. +// It requires knowledge about lost packets and acknowledged bitrate. +// Ie, this class require transport feedback. class LossBasedBandwidthEstimation { public: - LossBasedBandwidthEstimation(); - void Update(Timestamp at_time, - DataRate min_bitrate, - TimeDelta last_round_trip_time); + explicit LossBasedBandwidthEstimation( + const WebRtcKeyValueConfig* key_value_config); + // Returns the new estimate. + DataRate Update(Timestamp at_time, + DataRate min_bitrate, + DataRate wanted_bitrate, + TimeDelta last_round_trip_time); void UpdateAcknowledgedBitrate(DataRate acknowledged_bitrate, Timestamp at_time); - void MaybeReset(DataRate bitrate); - void SetInitialBitrate(DataRate bitrate); + void Initialize(DataRate bitrate); bool Enabled() const { return config_.enabled; } + // Returns true if LossBasedBandwidthEstimation is enabled and have + // received loss statistics. Ie, this class require transport feedback. + bool InUse() const { + return Enabled() && last_loss_packet_report_.IsFinite(); + } void UpdateLossStatistics(const std::vector& packet_results, Timestamp at_time); DataRate GetEstimate() const { return loss_based_bitrate_; } @@ -64,9 +76,11 @@ class LossBasedBandwidthEstimation { void Reset(DataRate bitrate); double loss_increase_threshold() const; double loss_decrease_threshold() const; + double loss_reset_threshold() const; + DataRate decreased_bitrate() const; - LossBasedControlConfig config_; + const LossBasedControlConfig config_; double average_loss_; double average_loss_max_; DataRate loss_based_bitrate_; diff --git a/modules/congestion_controller/goog_cc/probe_controller.h b/modules/congestion_controller/goog_cc/probe_controller.h index 11e92b97ae..bcaa293209 100644 --- a/modules/congestion_controller/goog_cc/probe_controller.h +++ b/modules/congestion_controller/goog_cc/probe_controller.h @@ -16,6 +16,7 @@ #include #include +#include "absl/base/attributes.h" #include "absl/types/optional.h" #include "api/rtc_event_log/rtc_event_log.h" #include "api/transport/network_control.h" @@ -23,7 +24,6 @@ #include "api/units/data_rate.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/experiments/field_trial_parser.h" -#include "rtc_base/system/unused.h" namespace webrtc { @@ -63,7 +63,7 @@ class ProbeController { RtcEventLog* event_log); ~ProbeController(); - RTC_WARN_UNUSED_RESULT std::vector SetBitrates( + ABSL_MUST_USE_RESULT std::vector SetBitrates( int64_t min_bitrate_bps, int64_t start_bitrate_bps, int64_t max_bitrate_bps, @@ -71,14 +71,14 @@ class ProbeController { // The total bitrate, as opposed to the max bitrate, is the sum of the // configured bitrates for all active streams. 
- RTC_WARN_UNUSED_RESULT std::vector + ABSL_MUST_USE_RESULT std::vector OnMaxTotalAllocatedBitrate(int64_t max_total_allocated_bitrate, int64_t at_time_ms); - RTC_WARN_UNUSED_RESULT std::vector OnNetworkAvailability( + ABSL_MUST_USE_RESULT std::vector OnNetworkAvailability( NetworkAvailability msg); - RTC_WARN_UNUSED_RESULT std::vector SetEstimatedBitrate( + ABSL_MUST_USE_RESULT std::vector SetEstimatedBitrate( int64_t bitrate_bps, int64_t at_time_ms); @@ -87,7 +87,7 @@ class ProbeController { void SetAlrStartTimeMs(absl::optional alr_start_time); void SetAlrEndedTimeMs(int64_t alr_end_time); - RTC_WARN_UNUSED_RESULT std::vector RequestProbe( + ABSL_MUST_USE_RESULT std::vector RequestProbe( int64_t at_time_ms); // Sets a new maximum probing bitrate, without generating a new probe cluster. @@ -97,7 +97,7 @@ class ProbeController { // created EXCEPT for |enable_periodic_alr_probing_|. void Reset(int64_t at_time_ms); - RTC_WARN_UNUSED_RESULT std::vector Process( + ABSL_MUST_USE_RESULT std::vector Process( int64_t at_time_ms); private: @@ -110,9 +110,9 @@ class ProbeController { kProbingComplete, }; - RTC_WARN_UNUSED_RESULT std::vector + ABSL_MUST_USE_RESULT std::vector InitiateExponentialProbing(int64_t at_time_ms); - RTC_WARN_UNUSED_RESULT std::vector InitiateProbing( + ABSL_MUST_USE_RESULT std::vector InitiateProbing( int64_t now_ms, std::vector bitrates_to_probe, bool probe_further); diff --git a/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.cc b/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.cc index 8de2a91114..c5f51df99b 100644 --- a/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.cc +++ b/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.cc @@ -226,6 +226,7 @@ SendSideBandwidthEstimation::SendSideBandwidthEstimation( low_loss_threshold_(kDefaultLowLossThreshold), high_loss_threshold_(kDefaultHighLossThreshold), bitrate_threshold_(kDefaultBitrateThreshold), + loss_based_bandwidth_estimation_(key_value_config), receiver_limit_caps_only_("Enabled") { RTC_DCHECK(event_log); if (BweLossExperimentIsEnabled()) { @@ -287,9 +288,6 @@ void SendSideBandwidthEstimation::SetSendBitrate(DataRate bitrate, RTC_DCHECK_GT(bitrate, DataRate::Zero()); // Reset to avoid being capped by the estimate. delay_based_limit_ = DataRate::PlusInfinity(); - if (loss_based_bandwidth_estimation_.Enabled()) { - loss_based_bandwidth_estimation_.MaybeReset(bitrate); - } UpdateTargetBitrate(bitrate, at_time); // Clear last sent bitrate history so the new value can be used directly // and not capped. 
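For reference, a minimal caller sketch for the ABSL_MUST_USE_RESULT switch in probe_controller.h above: the returned probe clusters now have to be consumed. The final SetBitrates parameter name (at_time_ms) and the pending_probes sink are assumptions for illustration only, not part of the patch.

    #include <cstdint>
    #include <vector>

    #include "api/transport/network_types.h"
    #include "modules/congestion_controller/goog_cc/probe_controller.h"

    void RequestStartupProbes(webrtc::ProbeController& probe_controller,
                              int64_t at_time_ms,
                              std::vector<webrtc::ProbeClusterConfig>* pending_probes) {
      // SetBitrates is now annotated ABSL_MUST_USE_RESULT, so discarding the
      // returned probe clusters triggers a compiler diagnostic where the
      // attribute is supported.
      std::vector<webrtc::ProbeClusterConfig> probes =
          probe_controller.SetBitrates(/*min_bitrate_bps=*/30000,
                                       /*start_bitrate_bps=*/300000,
                                       /*max_bitrate_bps=*/2000000, at_time_ms);
      pending_probes->insert(pending_probes->end(), probes.begin(), probes.end());
    }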
@@ -462,7 +460,7 @@ void SendSideBandwidthEstimation::UpdateEstimate(Timestamp at_time) { if (delay_based_limit_.IsFinite()) new_bitrate = std::max(delay_based_limit_, new_bitrate); if (loss_based_bandwidth_estimation_.Enabled()) { - loss_based_bandwidth_estimation_.SetInitialBitrate(new_bitrate); + loss_based_bandwidth_estimation_.Initialize(new_bitrate); } if (new_bitrate != current_target_) { @@ -485,10 +483,10 @@ void SendSideBandwidthEstimation::UpdateEstimate(Timestamp at_time) { return; } - if (loss_based_bandwidth_estimation_.Enabled()) { - loss_based_bandwidth_estimation_.Update( - at_time, min_bitrate_history_.front().second, last_round_trip_time_); - DataRate new_bitrate = MaybeRampupOrBackoff(current_target_, at_time); + if (loss_based_bandwidth_estimation_.InUse()) { + DataRate new_bitrate = loss_based_bandwidth_estimation_.Update( + at_time, min_bitrate_history_.front().second, delay_based_limit_, + last_round_trip_time_); UpdateTargetBitrate(new_bitrate, at_time); return; } @@ -585,30 +583,11 @@ void SendSideBandwidthEstimation::UpdateMinHistory(Timestamp at_time) { min_bitrate_history_.push_back(std::make_pair(at_time, current_target_)); } -DataRate SendSideBandwidthEstimation::MaybeRampupOrBackoff(DataRate new_bitrate, - Timestamp at_time) { - // TODO(crodbro): reuse this code in UpdateEstimate instead of current - // inlining of very similar functionality. - const TimeDelta time_since_loss_packet_report = - at_time - last_loss_packet_report_; - if (time_since_loss_packet_report < 1.2 * kMaxRtcpFeedbackInterval) { - new_bitrate = min_bitrate_history_.front().second * 1.08; - new_bitrate += DataRate::BitsPerSec(1000); - } - return new_bitrate; -} - DataRate SendSideBandwidthEstimation::GetUpperLimit() const { DataRate upper_limit = delay_based_limit_; if (!receiver_limit_caps_only_) upper_limit = std::min(upper_limit, receiver_limit_); - upper_limit = std::min(upper_limit, max_bitrate_configured_); - if (loss_based_bandwidth_estimation_.Enabled() && - loss_based_bandwidth_estimation_.GetEstimate() > DataRate::Zero()) { - upper_limit = - std::min(upper_limit, loss_based_bandwidth_estimation_.GetEstimate()); - } - return upper_limit; + return std::min(upper_limit, max_bitrate_configured_); } void SendSideBandwidthEstimation::MaybeLogLowBitrateWarning(DataRate bitrate, diff --git a/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.h b/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.h index 3fa8c4b282..b97b940db0 100644 --- a/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.h +++ b/modules/congestion_controller/goog_cc/send_side_bandwidth_estimation.h @@ -131,8 +131,6 @@ class SendSideBandwidthEstimation { // min bitrate used during last kBweIncreaseIntervalMs. void UpdateMinHistory(Timestamp at_time); - DataRate MaybeRampupOrBackoff(DataRate new_bitrate, Timestamp at_time); - // Gets the upper limit for the target bitrate. This is the minimum of the // delay based limit, the receiver limit and the loss based controller limit. 
DataRate GetUpperLimit() const; diff --git a/modules/congestion_controller/include/receive_side_congestion_controller.h b/modules/congestion_controller/include/receive_side_congestion_controller.h index 034f2e9517..84661c05b7 100644 --- a/modules/congestion_controller/include/receive_side_congestion_controller.h +++ b/modules/congestion_controller/include/receive_side_congestion_controller.h @@ -16,7 +16,10 @@ #include "api/transport/field_trial_based_config.h" #include "api/transport/network_control.h" +#include "api/units/data_rate.h" +#include "modules/congestion_controller/remb_throttler.h" #include "modules/include/module.h" +#include "modules/pacing/packet_router.h" #include "modules/remote_bitrate_estimator/remote_estimator_proxy.h" #include "rtc_base/synchronization/mutex.h" @@ -32,10 +35,10 @@ class RemoteBitrateObserver; class ReceiveSideCongestionController : public CallStatsObserver, public Module { public: - ReceiveSideCongestionController(Clock* clock, PacketRouter* packet_router); ReceiveSideCongestionController( Clock* clock, - PacketRouter* packet_router, + RemoteEstimatorProxy::TransportFeedbackSender feedback_sender, + RembThrottler::RembSender remb_sender, NetworkStateEstimator* network_state_estimator); ~ReceiveSideCongestionController() override {} @@ -56,6 +59,10 @@ class ReceiveSideCongestionController : public CallStatsObserver, // This is send bitrate, used to control the rate of feedback messages. void OnBitrateChanged(int bitrate_bps); + // Ensures the remote party is notified of the receive bitrate no larger than + // |bitrate| using RTCP REMB. + void SetMaxDesiredReceiveBitrate(DataRate bitrate); + // Implements Module. int64_t TimeUntilNextProcess() override; void Process() override; @@ -103,6 +110,7 @@ class ReceiveSideCongestionController : public CallStatsObserver, }; const FieldTrialBasedConfig field_trial_config_; + RembThrottler remb_throttler_; WrappingBitrateEstimator remote_bitrate_estimator_; RemoteEstimatorProxy remote_estimator_proxy_; }; diff --git a/modules/congestion_controller/pcc/BUILD.gn b/modules/congestion_controller/pcc/BUILD.gn index 2f378769e7..38a3b8ad7c 100644 --- a/modules/congestion_controller/pcc/BUILD.gn +++ b/modules/congestion_controller/pcc/BUILD.gn @@ -98,7 +98,7 @@ rtc_library("bitrate_controller") { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_library("pcc_unittests") { testonly = true sources = [ diff --git a/modules/congestion_controller/pcc/monitor_interval.cc b/modules/congestion_controller/pcc/monitor_interval.cc index c8efd5b59a..6bc9f4a7ef 100644 --- a/modules/congestion_controller/pcc/monitor_interval.cc +++ b/modules/congestion_controller/pcc/monitor_interval.cc @@ -47,7 +47,7 @@ void PccMonitorInterval::OnPacketsFeedback( feedback_collection_done_ = true; return; } - if (packet_result.receive_time.IsInfinite()) { + if (!packet_result.IsReceived()) { lost_packets_sent_time_.push_back(packet_result.sent_packet.send_time); } else { received_packets_.push_back( diff --git a/modules/congestion_controller/pcc/rtt_tracker.cc b/modules/congestion_controller/pcc/rtt_tracker.cc index 0814912b49..af9dc8f11b 100644 --- a/modules/congestion_controller/pcc/rtt_tracker.cc +++ b/modules/congestion_controller/pcc/rtt_tracker.cc @@ -23,7 +23,7 @@ void RttTracker::OnPacketsFeedback( Timestamp feedback_received_time) { TimeDelta packet_rtt = TimeDelta::MinusInfinity(); for (const PacketResult& packet_result : packet_feedbacks) { - if 
(packet_result.receive_time.IsInfinite()) + if (!packet_result.IsReceived()) continue; packet_rtt = std::max( packet_rtt, diff --git a/modules/congestion_controller/receive_side_congestion_controller.cc b/modules/congestion_controller/receive_side_congestion_controller.cc index 638cb2d295..61a126fbe3 100644 --- a/modules/congestion_controller/receive_side_congestion_controller.cc +++ b/modules/congestion_controller/receive_side_congestion_controller.cc @@ -10,6 +10,7 @@ #include "modules/congestion_controller/include/receive_side_congestion_controller.h" +#include "api/units/data_rate.h" #include "modules/pacing/packet_router.h" #include "modules/remote_bitrate_estimator/include/bwe_defines.h" #include "modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h" @@ -120,16 +121,13 @@ void ReceiveSideCongestionController::WrappingBitrateEstimator:: ReceiveSideCongestionController::ReceiveSideCongestionController( Clock* clock, - PacketRouter* packet_router) - : ReceiveSideCongestionController(clock, packet_router, nullptr) {} - -ReceiveSideCongestionController::ReceiveSideCongestionController( - Clock* clock, - PacketRouter* packet_router, + RemoteEstimatorProxy::TransportFeedbackSender feedback_sender, + RembThrottler::RembSender remb_sender, NetworkStateEstimator* network_state_estimator) - : remote_bitrate_estimator_(packet_router, clock), + : remb_throttler_(std::move(remb_sender), clock), + remote_bitrate_estimator_(&remb_throttler_, clock), remote_estimator_proxy_(clock, - packet_router, + std::move(feedback_sender), &field_trial_config_, network_state_estimator) {} @@ -186,4 +184,9 @@ void ReceiveSideCongestionController::Process() { remote_bitrate_estimator_.Process(); } +void ReceiveSideCongestionController::SetMaxDesiredReceiveBitrate( + DataRate bitrate) { + remb_throttler_.SetMaxDesiredReceiveBitrate(bitrate); +} + } // namespace webrtc diff --git a/modules/congestion_controller/receive_side_congestion_controller_unittest.cc b/modules/congestion_controller/receive_side_congestion_controller_unittest.cc index b5846237ee..5e03179f42 100644 --- a/modules/congestion_controller/receive_side_congestion_controller_unittest.cc +++ b/modules/congestion_controller/receive_side_congestion_controller_unittest.cc @@ -10,6 +10,8 @@ #include "modules/congestion_controller/include/receive_side_congestion_controller.h" +#include "api/test/network_emulation/create_cross_traffic.h" +#include "api/test/network_emulation/cross_traffic.h" #include "modules/pacing/packet_router.h" #include "system_wrappers/include/clock.h" #include "test/gmock.h" @@ -18,10 +20,8 @@ using ::testing::_; using ::testing::AtLeast; -using ::testing::NiceMock; -using ::testing::Return; -using ::testing::SaveArg; -using ::testing::StrictMock; +using ::testing::ElementsAre; +using ::testing::MockFunction; namespace webrtc { @@ -35,34 +35,28 @@ uint32_t AbsSendTime(int64_t t, int64_t denom) { return (((t << 18) + (denom >> 1)) / denom) & 0x00fffffful; } -class MockPacketRouter : public PacketRouter { - public: - MOCK_METHOD(void, - OnReceiveBitrateChanged, - (const std::vector& ssrcs, uint32_t bitrate), - (override)); -}; - const uint32_t kInitialBitrateBps = 60000; } // namespace namespace test { -TEST(ReceiveSideCongestionControllerTest, OnReceivedPacketWithAbsSendTime) { - StrictMock packet_router; +TEST(ReceiveSideCongestionControllerTest, SendsRembWithAbsSendTime) { + MockFunction>)> + feedback_sender; + MockFunction)> remb_sender; SimulatedClock clock_(123456); - ReceiveSideCongestionController 
controller(&clock_, &packet_router); + ReceiveSideCongestionController controller( + &clock_, feedback_sender.AsStdFunction(), remb_sender.AsStdFunction(), + nullptr); size_t payload_size = 1000; RTPHeader header; header.ssrc = 0x11eb21c; header.extension.hasAbsoluteSendTime = true; - std::vector<uint32_t> ssrcs; - EXPECT_CALL(packet_router, OnReceiveBitrateChanged(_, _)) - .WillRepeatedly(SaveArg<0>(&ssrcs)); + EXPECT_CALL(remb_sender, Call(_, ElementsAre(header.ssrc))).Times(AtLeast(1)); for (int i = 0; i < 10; ++i) { clock_.AdvanceTimeMilliseconds((1000 * payload_size) / kInitialBitrateBps); @@ -70,9 +64,20 @@ TEST(ReceiveSideCongestionControllerTest, OnReceivedPacketWithAbsSendTime) { header.extension.absoluteSendTime = AbsSendTime(now_ms, 1000); controller.OnReceivedPacket(now_ms, payload_size, header); } +} + +TEST(ReceiveSideCongestionControllerTest, + SendsRembAfterSetMaxDesiredReceiveBitrate) { + MockFunction<void(std::vector<std::unique_ptr<rtcp::RtcpPacket>>)> + feedback_sender; + MockFunction<void(int64_t, std::vector<uint32_t>)> remb_sender; + SimulatedClock clock_(123456); - ASSERT_EQ(1u, ssrcs.size()); - EXPECT_EQ(header.ssrc, ssrcs[0]); + ReceiveSideCongestionController controller( + &clock_, feedback_sender.AsStdFunction(), remb_sender.AsStdFunction(), + nullptr); + EXPECT_CALL(remb_sender, Call(123, _)); + controller.SetMaxDesiredReceiveBitrate(DataRate::BitsPerSec(123)); } TEST(ReceiveSideCongestionControllerTest, ConvergesToCapacity) { @@ -109,7 +114,9 @@ TEST(ReceiveSideCongestionControllerTest, IsFairToTCP) { VideoStreamConfig video; video.stream.packet_feedback = false; s.CreateVideoStream(route->forward(), video); - s.net()->StartFakeTcpCrossTraffic(send_net, ret_net, FakeTcpConfig()); + s.net()->StartCrossTraffic(CreateFakeTcpCrossTraffic( + s.net()->CreateRoute(send_net), s.net()->CreateRoute(ret_net), + FakeTcpConfig())); s.RunFor(TimeDelta::Seconds(30)); // For some reason we get outcompeted by TCP here, this should probably be // fixed and a lower bound should be added to the test.
diff --git a/modules/congestion_controller/remb_throttler.cc b/modules/congestion_controller/remb_throttler.cc new file mode 100644 index 0000000000..fcc30af9a8 --- /dev/null +++ b/modules/congestion_controller/remb_throttler.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/congestion_controller/remb_throttler.h" + +#include <algorithm> +#include <utility> + +namespace webrtc { + +namespace { +constexpr TimeDelta kRembSendInterval = TimeDelta::Millis(200); +} // namespace + +RembThrottler::RembThrottler(RembSender remb_sender, Clock* clock) + : remb_sender_(std::move(remb_sender)), + clock_(clock), + last_remb_time_(Timestamp::MinusInfinity()), + last_send_remb_bitrate_(DataRate::PlusInfinity()), + max_remb_bitrate_(DataRate::PlusInfinity()) {} + +void RembThrottler::OnReceiveBitrateChanged(const std::vector<uint32_t>& ssrcs, + uint32_t bitrate_bps) { + DataRate receive_bitrate = DataRate::BitsPerSec(bitrate_bps); + Timestamp now = clock_->CurrentTime(); + { + MutexLock lock(&mutex_); + // % threshold for if we should send a new REMB asap.
+ const int64_t kSendThresholdPercent = 103; + if (receive_bitrate * kSendThresholdPercent / 100 > + last_send_remb_bitrate_ && + now < last_remb_time_ + kRembSendInterval) { + return; + } + last_remb_time_ = now; + last_send_remb_bitrate_ = receive_bitrate; + receive_bitrate = std::min(last_send_remb_bitrate_, max_remb_bitrate_); + } + remb_sender_(receive_bitrate.bps(), ssrcs); +} + +void RembThrottler::SetMaxDesiredReceiveBitrate(DataRate bitrate) { + Timestamp now = clock_->CurrentTime(); + { + MutexLock lock(&mutex_); + max_remb_bitrate_ = bitrate; + if (now - last_remb_time_ < kRembSendInterval && + !last_send_remb_bitrate_.IsZero() && + last_send_remb_bitrate_ <= max_remb_bitrate_) { + return; + } + } + remb_sender_(bitrate.bps(), /*ssrcs=*/{}); +} + +} // namespace webrtc
diff --git a/modules/congestion_controller/remb_throttler.h b/modules/congestion_controller/remb_throttler.h new file mode 100644 index 0000000000..67c0280749 --- /dev/null +++ b/modules/congestion_controller/remb_throttler.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef MODULES_CONGESTION_CONTROLLER_REMB_THROTTLER_H_ +#define MODULES_CONGESTION_CONTROLLER_REMB_THROTTLER_H_ + +#include <functional> +#include <vector> + +#include "api/units/data_rate.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" +#include "modules/remote_bitrate_estimator/remote_estimator_proxy.h" +#include "rtc_base/synchronization/mutex.h" + +namespace webrtc { + +// RembThrottler is a helper class used for throttling RTCP REMB messages. +// Throttles small changes to the received BWE within 200ms. +class RembThrottler : public RemoteBitrateObserver { + public: + using RembSender = + std::function<void(int64_t bitrate_bps, std::vector<uint32_t> ssrcs)>; + RembThrottler(RembSender remb_sender, Clock* clock); + + // Ensures the remote party is notified of the receive bitrate no larger than + // |bitrate| using RTCP REMB. + void SetMaxDesiredReceiveBitrate(DataRate bitrate); + + // Implements RemoteBitrateObserver; + // Called every time there is a new bitrate estimate for a receive channel + // group. This call will trigger a new RTCP REMB packet if the bitrate + // estimate has decreased or if no RTCP REMB packet has been sent for + // a certain time interval. + void OnReceiveBitrateChanged(const std::vector<uint32_t>& ssrcs, + uint32_t bitrate_bps) override; + + private: + const RembSender remb_sender_; + Clock* const clock_; + mutable Mutex mutex_; + Timestamp last_remb_time_ RTC_GUARDED_BY(mutex_); + DataRate last_send_remb_bitrate_ RTC_GUARDED_BY(mutex_); + DataRate max_remb_bitrate_ RTC_GUARDED_BY(mutex_); +}; + +} // namespace webrtc +#endif // MODULES_CONGESTION_CONTROLLER_REMB_THROTTLER_H_
diff --git a/modules/congestion_controller/remb_throttler_unittest.cc b/modules/congestion_controller/remb_throttler_unittest.cc new file mode 100644 index 0000000000..3f8df8a7bb --- /dev/null +++ b/modules/congestion_controller/remb_throttler_unittest.cc @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree.
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/congestion_controller/remb_throttler.h" + +#include + +#include "api/units/data_rate.h" +#include "api/units/time_delta.h" +#include "system_wrappers/include/clock.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { + +using ::testing::_; +using ::testing::MockFunction; + +TEST(RembThrottlerTest, CallRembSenderOnFirstReceiveBitrateChange) { + SimulatedClock clock(Timestamp::Zero()); + MockFunction)> remb_sender; + RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock); + + EXPECT_CALL(remb_sender, Call(12345, std::vector({1, 2, 3}))); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12345); +} + +TEST(RembThrottlerTest, ThrottlesSmallReceiveBitrateDecrease) { + SimulatedClock clock(Timestamp::Zero()); + MockFunction)> remb_sender; + RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock); + + EXPECT_CALL(remb_sender, Call); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12346); + clock.AdvanceTime(TimeDelta::Millis(100)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12345); + + EXPECT_CALL(remb_sender, Call(12345, _)); + clock.AdvanceTime(TimeDelta::Millis(101)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12345); +} + +TEST(RembThrottlerTest, DoNotThrottleLargeReceiveBitrateDecrease) { + SimulatedClock clock(Timestamp::Zero()); + MockFunction)> remb_sender; + RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock); + + EXPECT_CALL(remb_sender, Call(2345, _)); + EXPECT_CALL(remb_sender, Call(1234, _)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/2345); + clock.AdvanceTime(TimeDelta::Millis(1)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/1234); +} + +TEST(RembThrottlerTest, ThrottlesReceiveBitrateIncrease) { + SimulatedClock clock(Timestamp::Zero()); + MockFunction)> remb_sender; + RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock); + + EXPECT_CALL(remb_sender, Call); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/1234); + clock.AdvanceTime(TimeDelta::Millis(100)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/2345); + + // Updates 200ms after previous callback is not throttled. 
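The tests above pin down the throttling rule implemented in remb_throttler.cc: a REMB goes out immediately when the estimate drops by roughly 3% or more (or when nothing has been sent within the 200 ms window), while small decreases and all increases are rate-limited to one REMB per 200 ms. A standalone restatement of that gate, with illustrative names and plain integers in place of DataRate/Timestamp:

#include <cstdint>

// Sketch of the send/skip decision in RembThrottler::OnReceiveBitrateChanged.
// The real class additionally clamps the sent value to the configured maximum
// and guards its state with a mutex.
bool ShouldSendRembNowSketch(int64_t receive_bitrate_bps,
                             int64_t last_sent_bitrate_bps,
                             int64_t ms_since_last_remb) {
  const int64_t kSendIntervalMs = 200;
  const int64_t kSendThresholdPercent = 103;
  // Skip only when the change is small (less than ~3% below the last sent
  // value, or any increase) and the 200 ms window has not expired yet.
  const bool small_change = receive_bitrate_bps * kSendThresholdPercent / 100 >
                            last_sent_bitrate_bps;
  const bool within_window = ms_since_last_remb < kSendIntervalMs;
  return !(small_change && within_window);
}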
+ EXPECT_CALL(remb_sender, Call(2345, _)); + clock.AdvanceTime(TimeDelta::Millis(101)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/2345); +} + +TEST(RembThrottlerTest, CallRembSenderOnSetMaxDesiredReceiveBitrate) { + SimulatedClock clock(Timestamp::Zero()); + MockFunction)> remb_sender; + RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock); + EXPECT_CALL(remb_sender, Call(1234, _)); + remb_throttler.SetMaxDesiredReceiveBitrate(DataRate::BitsPerSec(1234)); +} + +TEST(RembThrottlerTest, CallRembSenderWithMinOfMaxDesiredAndOnReceivedBitrate) { + SimulatedClock clock(Timestamp::Zero()); + MockFunction)> remb_sender; + RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock); + + EXPECT_CALL(remb_sender, Call(1234, _)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/1234); + clock.AdvanceTime(TimeDelta::Millis(1)); + remb_throttler.SetMaxDesiredReceiveBitrate(DataRate::BitsPerSec(4567)); + + clock.AdvanceTime(TimeDelta::Millis(200)); + EXPECT_CALL(remb_sender, Call(4567, _)); + remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/5678); +} + +} // namespace webrtc diff --git a/modules/congestion_controller/rtp/BUILD.gn b/modules/congestion_controller/rtp/BUILD.gn index a030976a96..1a70447307 100644 --- a/modules/congestion_controller/rtp/BUILD.gn +++ b/modules/congestion_controller/rtp/BUILD.gn @@ -24,13 +24,13 @@ rtc_library("control_handler") { ] deps = [ + "../../../api:sequence_checker", "../../../api/transport:network_control", "../../../api/units:data_rate", "../../../api/units:data_size", "../../../api/units:time_delta", "../../../rtc_base:checks", "../../../rtc_base:safe_minmax", - "../../../rtc_base/synchronization:sequence_checker", "../../../rtc_base/system:no_unique_address", "../../../system_wrappers:field_trial", "../../pacing", @@ -52,6 +52,7 @@ rtc_library("transport_feedback") { deps = [ "../..:module_api_public", + "../../../api:sequence_checker", "../../../api/transport:network_control", "../../../api/units:data_size", "../../../api/units:timestamp", diff --git a/modules/congestion_controller/rtp/control_handler.h b/modules/congestion_controller/rtp/control_handler.h index e3450f3eb1..1da6463219 100644 --- a/modules/congestion_controller/rtp/control_handler.h +++ b/modules/congestion_controller/rtp/control_handler.h @@ -14,12 +14,12 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/transport/network_types.h" #include "api/units/data_size.h" #include "api/units/time_delta.h" #include "modules/pacing/paced_sender.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { diff --git a/modules/congestion_controller/rtp/transport_feedback_adapter.h b/modules/congestion_controller/rtp/transport_feedback_adapter.h index c41a7c67f8..deb7925d77 100644 --- a/modules/congestion_controller/rtp/transport_feedback_adapter.h +++ b/modules/congestion_controller/rtp/transport_feedback_adapter.h @@ -16,13 +16,13 @@ #include #include +#include "api/sequence_checker.h" #include "api/transport/network_types.h" #include "modules/include/module_common_types_public.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "rtc_base/network/sent_packet.h" #include "rtc_base/network_route.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { diff --git 
a/modules/congestion_controller/rtp/transport_feedback_adapter_unittest.cc b/modules/congestion_controller/rtp/transport_feedback_adapter_unittest.cc index 3849cb3707..933abd9bf0 100644 --- a/modules/congestion_controller/rtp/transport_feedback_adapter_unittest.cc +++ b/modules/congestion_controller/rtp/transport_feedback_adapter_unittest.cc @@ -27,9 +27,9 @@ using ::testing::_; using ::testing::Invoke; namespace webrtc { -namespace webrtc_cc { namespace { +constexpr uint32_t kSsrc = 8492; const PacedPacketInfo kPacingInfo0(0, 5, 2000); const PacedPacketInfo kPacingInfo1(1, 8, 4000); const PacedPacketInfo kPacingInfo2(2, 14, 7000); @@ -49,8 +49,8 @@ void ComparePacketFeedbackVectors(const std::vector& truth, // equal. However, the difference must be the same for all x. TimeDelta arrival_time_delta = truth[0].receive_time - input[0].receive_time; for (size_t i = 0; i < len; ++i) { - RTC_CHECK(truth[i].receive_time.IsFinite()); - if (input[i].receive_time.IsFinite()) { + RTC_CHECK(truth[i].IsReceived()); + if (input[i].IsReceived()) { EXPECT_EQ(truth[i].receive_time - input[i].receive_time, arrival_time_delta); } @@ -77,10 +77,6 @@ PacketResult CreatePacket(int64_t receive_time_ms, return res; } -} // namespace - -namespace test { - class MockStreamFeedbackObserver : public webrtc::StreamFeedbackObserver { public: MOCK_METHOD(void, @@ -89,6 +85,8 @@ class MockStreamFeedbackObserver : public webrtc::StreamFeedbackObserver { (override)); }; +} // namespace + class TransportFeedbackAdapterTest : public ::testing::Test { public: TransportFeedbackAdapterTest() : clock_(0) {} @@ -108,7 +106,7 @@ class TransportFeedbackAdapterTest : public ::testing::Test { void OnSentPacket(const PacketResult& packet_feedback) { RtpPacketSendInfo packet_info; - packet_info.ssrc = kSsrc; + packet_info.media_ssrc = kSsrc; packet_info.transport_sequence_number = packet_feedback.sent_packet.sequence_number; packet_info.rtp_sequence_number = 0; @@ -122,8 +120,6 @@ class TransportFeedbackAdapterTest : public ::testing::Test { packet_feedback.sent_packet.send_time.ms(), rtc::PacketInfo())); } - static constexpr uint32_t kSsrc = 8492; - SimulatedClock clock_; std::unique_ptr adapter_; }; @@ -393,7 +389,7 @@ TEST_F(TransportFeedbackAdapterTest, IgnoreDuplicatePacketSentCalls) { // Add a packet and then mark it as sent. 
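Several hunks above (pcc/monitor_interval.cc, pcc/rtt_tracker.cc and this unittest) switch from checking packet_result.receive_time.IsInfinite() to the PacketResult::IsReceived() helper. A small example of the idiom; the counting helper itself is illustrative only:

#include <vector>

#include "api/transport/network_types.h"

// Counts packets that the transport feedback marks as lost, using
// PacketResult::IsReceived() rather than inspecting receive_time directly.
int CountLostPacketsSketch(const std::vector<webrtc::PacketResult>& feedback) {
  int lost = 0;
  for (const webrtc::PacketResult& result : feedback) {
    if (!result.IsReceived())
      ++lost;
  }
  return lost;
}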
RtpPacketSendInfo packet_info; - packet_info.ssrc = kSsrc; + packet_info.media_ssrc = kSsrc; packet_info.transport_sequence_number = packet.sent_packet.sequence_number; packet_info.length = packet.sent_packet.size.bytes(); packet_info.pacing_info = packet.sent_packet.pacing_info; @@ -412,6 +408,4 @@ TEST_F(TransportFeedbackAdapterTest, IgnoreDuplicatePacketSentCalls) { EXPECT_FALSE(duplicate_packet.has_value()); } -} // namespace test -} // namespace webrtc_cc } // namespace webrtc diff --git a/modules/congestion_controller/rtp/transport_feedback_demuxer.cc b/modules/congestion_controller/rtp/transport_feedback_demuxer.cc index c958a1c3cb..6ab3ad80fa 100644 --- a/modules/congestion_controller/rtp/transport_feedback_demuxer.cc +++ b/modules/congestion_controller/rtp/transport_feedback_demuxer.cc @@ -38,15 +38,16 @@ void TransportFeedbackDemuxer::DeRegisterStreamFeedbackObserver( void TransportFeedbackDemuxer::AddPacket(const RtpPacketSendInfo& packet_info) { MutexLock lock(&lock_); - if (packet_info.ssrc != 0) { - StreamFeedbackObserver::StreamPacketInfo info; - info.ssrc = packet_info.ssrc; - info.rtp_sequence_number = packet_info.rtp_sequence_number; - info.received = false; - history_.insert( - {seq_num_unwrapper_.Unwrap(packet_info.transport_sequence_number), - info}); - } + + StreamFeedbackObserver::StreamPacketInfo info; + info.ssrc = packet_info.media_ssrc; + info.rtp_sequence_number = packet_info.rtp_sequence_number; + info.received = false; + info.is_retransmission = + packet_info.packet_type == RtpPacketMediaType::kRetransmission; + history_.insert( + {seq_num_unwrapper_.Unwrap(packet_info.transport_sequence_number), info}); + while (history_.size() > kMaxPacketsInHistory) { history_.erase(history_.begin()); } diff --git a/modules/congestion_controller/rtp/transport_feedback_demuxer_unittest.cc b/modules/congestion_controller/rtp/transport_feedback_demuxer_unittest.cc index 6514a4eda7..482f58d1bb 100644 --- a/modules/congestion_controller/rtp/transport_feedback_demuxer_unittest.cc +++ b/modules/congestion_controller/rtp/transport_feedback_demuxer_unittest.cc @@ -16,7 +16,11 @@ namespace webrtc { namespace { -using ::testing::_; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Field; +using PacketInfo = StreamFeedbackObserver::StreamPacketInfo; + static constexpr uint32_t kSsrc = 8492; class MockStreamFeedbackObserver : public webrtc::StreamFeedbackObserver { @@ -28,41 +32,65 @@ class MockStreamFeedbackObserver : public webrtc::StreamFeedbackObserver { }; RtpPacketSendInfo CreatePacket(uint32_t ssrc, - int16_t rtp_sequence_number, - int64_t transport_sequence_number) { + uint16_t rtp_sequence_number, + int64_t transport_sequence_number, + bool is_retransmission) { RtpPacketSendInfo res; - res.ssrc = ssrc; + res.media_ssrc = ssrc; res.transport_sequence_number = transport_sequence_number; res.rtp_sequence_number = rtp_sequence_number; + res.packet_type = is_retransmission ? 
RtpPacketMediaType::kRetransmission + : RtpPacketMediaType::kVideo; return res; } } // namespace + TEST(TransportFeedbackDemuxerTest, ObserverSanity) { TransportFeedbackDemuxer demuxer; MockStreamFeedbackObserver mock; demuxer.RegisterStreamFeedbackObserver({kSsrc}, &mock); - demuxer.AddPacket(CreatePacket(kSsrc, 55, 1)); - demuxer.AddPacket(CreatePacket(kSsrc, 56, 2)); - demuxer.AddPacket(CreatePacket(kSsrc, 57, 3)); + const uint16_t kRtpStartSeq = 55; + const int64_t kTransportStartSeq = 1; + demuxer.AddPacket(CreatePacket(kSsrc, kRtpStartSeq, kTransportStartSeq, + /*is_retransmit=*/false)); + demuxer.AddPacket(CreatePacket(kSsrc, kRtpStartSeq + 1, + kTransportStartSeq + 1, + /*is_retransmit=*/false)); + demuxer.AddPacket(CreatePacket( + kSsrc, kRtpStartSeq + 2, kTransportStartSeq + 2, /*is_retransmit=*/true)); rtcp::TransportFeedback feedback; - feedback.SetBase(1, 1000); - ASSERT_TRUE(feedback.AddReceivedPacket(1, 1000)); - ASSERT_TRUE(feedback.AddReceivedPacket(2, 2000)); - ASSERT_TRUE(feedback.AddReceivedPacket(3, 3000)); + feedback.SetBase(kTransportStartSeq, 1000); + ASSERT_TRUE(feedback.AddReceivedPacket(kTransportStartSeq, 1000)); + // Drop middle packet. + ASSERT_TRUE(feedback.AddReceivedPacket(kTransportStartSeq + 2, 3000)); - EXPECT_CALL(mock, OnPacketFeedbackVector(_)).Times(1); + EXPECT_CALL( + mock, OnPacketFeedbackVector(ElementsAre( + AllOf(Field(&PacketInfo::received, true), + Field(&PacketInfo::ssrc, kSsrc), + Field(&PacketInfo::rtp_sequence_number, kRtpStartSeq), + Field(&PacketInfo::is_retransmission, false)), + AllOf(Field(&PacketInfo::received, false), + Field(&PacketInfo::ssrc, kSsrc), + Field(&PacketInfo::rtp_sequence_number, kRtpStartSeq + 1), + Field(&PacketInfo::is_retransmission, false)), + AllOf(Field(&PacketInfo::received, true), + Field(&PacketInfo::ssrc, kSsrc), + Field(&PacketInfo::rtp_sequence_number, kRtpStartSeq + 2), + Field(&PacketInfo::is_retransmission, true))))); demuxer.OnTransportFeedback(feedback); demuxer.DeRegisterStreamFeedbackObserver(&mock); - demuxer.AddPacket(CreatePacket(kSsrc, 58, 4)); + demuxer.AddPacket( + CreatePacket(kSsrc, kRtpStartSeq + 3, kTransportStartSeq + 3, false)); rtcp::TransportFeedback second_feedback; - second_feedback.SetBase(4, 4000); - ASSERT_TRUE(second_feedback.AddReceivedPacket(4, 4000)); + second_feedback.SetBase(kTransportStartSeq + 3, 4000); + ASSERT_TRUE(second_feedback.AddReceivedPacket(kTransportStartSeq + 3, 4000)); - EXPECT_CALL(mock, OnPacketFeedbackVector(_)).Times(0); + EXPECT_CALL(mock, OnPacketFeedbackVector).Times(0); demuxer.OnTransportFeedback(second_feedback); } } // namespace webrtc diff --git a/modules/desktop_capture/BUILD.gn b/modules/desktop_capture/BUILD.gn index 70344e5ba8..25b92bed45 100644 --- a/modules/desktop_capture/BUILD.gn +++ b/modules/desktop_capture/BUILD.gn @@ -44,6 +44,7 @@ rtc_library("primitives") { "../../api:scoped_refptr", "../../rtc_base:checks", "../../rtc_base/system:rtc_export", + "//third_party/libyuv", ] if (!build_with_mozilla) { @@ -143,7 +144,16 @@ if (rtc_include_tests) { if (is_mac) { sources += [ "screen_capturer_mac_unittest.cc" ] } - deps += [ ":desktop_capture_mock" ] + if (rtc_enable_win_wgc) { + sources += [ + "win/wgc_capture_source_unittest.cc", + "win/wgc_capturer_win_unittest.cc", + ] + } + deps += [ + ":desktop_capture_mock", + "../../system_wrappers:metrics", + ] public_configs = [ ":x11_config" ] } } @@ -322,6 +332,8 @@ rtc_library("desktop_capture_generic") { "cropping_window_capturer.h", "desktop_and_cursor_composer.cc", 
"desktop_and_cursor_composer.h", + "desktop_capture_metrics_helper.cc", + "desktop_capture_metrics_helper.h", "desktop_capture_options.cc", "desktop_capture_options.h", "desktop_capturer.cc", @@ -451,6 +463,7 @@ rtc_library("desktop_capture_generic") { "../../api:function_view", "../../api:refcountedbase", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../rtc_base", # TODO(kjellander): Cleanup in bugs.webrtc.org/3806. "../../rtc_base:checks", "../../rtc_base/synchronization:mutex", @@ -564,11 +577,13 @@ rtc_library("desktop_capture_generic") { sources += [ "win/wgc_capture_session.cc", "win/wgc_capture_session.h", - "win/window_capturer_win_wgc.cc", - "win/window_capturer_win_wgc.h", + "win/wgc_capture_source.cc", + "win/wgc_capture_source.h", + "win/wgc_capturer_win.cc", + "win/wgc_capturer_win.h", + "win/wgc_desktop_frame.cc", + "win/wgc_desktop_frame.h", ] - - defines += [ "RTC_ENABLE_WIN_WGC" ] } } diff --git a/modules/desktop_capture/cropping_window_capturer_win.cc b/modules/desktop_capture/cropping_window_capturer_win.cc index de36adb01e..31ddbe1b33 100644 --- a/modules/desktop_capture/cropping_window_capturer_win.cc +++ b/modules/desktop_capture/cropping_window_capturer_win.cc @@ -118,7 +118,7 @@ struct TopWindowVerifierContext : public SelectedWindowContext { // firing an assert when enabled, report that the selected window isn't // topmost to avoid inadvertent capture of other windows. RTC_LOG(LS_ERROR) << "Failed to enumerate windows: " << lastError; - RTC_DCHECK(false); + RTC_NOTREACHED(); return false; } } @@ -130,6 +130,8 @@ class CroppingWindowCapturerWin : public CroppingWindowCapturer { public: explicit CroppingWindowCapturerWin(const DesktopCaptureOptions& options) : CroppingWindowCapturer(options), + enumerate_current_process_windows_( + options.enumerate_current_process_windows()), full_screen_window_detector_(options.full_screen_window_detector()) {} void CaptureFrame() override; @@ -148,6 +150,8 @@ class CroppingWindowCapturerWin : public CroppingWindowCapturer { WindowCaptureHelperWin window_capture_helper_; + bool enumerate_current_process_windows_; + rtc::scoped_refptr full_screen_window_detector_; }; @@ -164,7 +168,12 @@ void CroppingWindowCapturerWin::CaptureFrame() { // it uses responsiveness check which could lead to performance // issues. SourceList result; - if (!webrtc::GetWindowList(GetWindowListFlags::kNone, &result)) + int window_list_flags = + enumerate_current_process_windows_ + ? 
GetWindowListFlags::kNone + : GetWindowListFlags::kIgnoreCurrentProcessWindows; + + if (!webrtc::GetWindowList(window_list_flags, &result)) return false; // Filter out windows not visible on current desktop diff --git a/modules/desktop_capture/desktop_and_cursor_composer.cc b/modules/desktop_capture/desktop_and_cursor_composer.cc index f282c1d500..69b8b40c73 100644 --- a/modules/desktop_capture/desktop_and_cursor_composer.cc +++ b/modules/desktop_capture/desktop_and_cursor_composer.cc @@ -207,7 +207,8 @@ void DesktopAndCursorComposer::OnCaptureResult( DesktopCapturer::Result result, std::unique_ptr frame) { if (frame && cursor_) { - if (frame->rect().Contains(cursor_position_) && + if (!frame->may_contain_cursor() && + frame->rect().Contains(cursor_position_) && !desktop_capturer_->IsOccluded(cursor_position_)) { DesktopVector relative_position = cursor_position_.subtract(frame->top_left()); @@ -228,6 +229,7 @@ void DesktopAndCursorComposer::OnCaptureResult( previous_cursor_rect_ = frame_with_cursor->cursor_rect(); cursor_changed_ = false; frame = std::move(frame_with_cursor); + frame->set_may_contain_cursor(true); } } diff --git a/modules/desktop_capture/desktop_and_cursor_composer_unittest.cc b/modules/desktop_capture/desktop_and_cursor_composer_unittest.cc index c9cb56d8c2..00253d38e2 100644 --- a/modules/desktop_capture/desktop_and_cursor_composer_unittest.cc +++ b/modules/desktop_capture/desktop_and_cursor_composer_unittest.cc @@ -27,6 +27,8 @@ namespace webrtc { namespace { +const int kFrameXCoord = 100; +const int kFrameYCoord = 200; const int kScreenWidth = 100; const int kScreenHeight = 100; const int kCursorWidth = 10; @@ -249,11 +251,61 @@ TEST_F(DesktopAndCursorComposerTest, CursorShouldBeIgnoredIfNoFrameCaptured) { } } +TEST_F(DesktopAndCursorComposerTest, CursorShouldBeIgnoredIfFrameMayContainIt) { + // We can't use a shared frame because we need to detect modifications + // compared to a control. + std::unique_ptr control_frame(CreateTestFrame()); + control_frame->set_top_left(DesktopVector(kFrameXCoord, kFrameYCoord)); + + struct { + int x; + int y; + bool may_contain_cursor; + } tests[] = { + {100, 200, true}, + {100, 200, false}, + {150, 250, true}, + {150, 250, false}, + }; + + for (size_t i = 0; i < arraysize(tests); i++) { + SCOPED_TRACE(i); + + std::unique_ptr frame(CreateTestFrame()); + frame->set_top_left(DesktopVector(kFrameXCoord, kFrameYCoord)); + frame->set_may_contain_cursor(tests[i].may_contain_cursor); + fake_screen_->SetNextFrame(std::move(frame)); + + const DesktopVector abs_pos(tests[i].x, tests[i].y); + fake_cursor_->SetState(MouseCursorMonitor::INSIDE, abs_pos); + blender_.CaptureFrame(); + + // If the frame may already have contained the cursor, then |CaptureFrame()| + // should not have modified it, so it should be the same as the control. + EXPECT_TRUE(frame_); + const DesktopVector rel_pos(abs_pos.subtract(control_frame->top_left())); + if (tests[i].may_contain_cursor) { + EXPECT_EQ( + *reinterpret_cast(frame_->GetFrameDataAtPos(rel_pos)), + *reinterpret_cast( + control_frame->GetFrameDataAtPos(rel_pos))); + + } else { + // |CaptureFrame()| should have modified the frame to have the cursor. 
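This test exercises the new rule in DesktopAndCursorComposer::OnCaptureResult(): the composer only paints its own cursor when the capturer has not already embedded one, and afterwards marks the outgoing frame via set_may_contain_cursor(true). A reduced sketch of that decision; the standalone function is illustrative, and the real code also computes the relative cursor position and blends the bitmap:

// Sketch of the blend/skip decision added in this patch.
bool ShouldBlendCursorSketch(bool frame_may_contain_cursor,
                             bool cursor_inside_frame,
                             bool cursor_occluded) {
  return !frame_may_contain_cursor && cursor_inside_frame && !cursor_occluded;
}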
+ EXPECT_NE( + *reinterpret_cast(frame_->GetFrameDataAtPos(rel_pos)), + *reinterpret_cast( + control_frame->GetFrameDataAtPos(rel_pos))); + EXPECT_TRUE(frame_->may_contain_cursor()); + } + } +} + TEST_F(DesktopAndCursorComposerTest, CursorShouldBeIgnoredIfItIsOutOfDesktopFrame) { std::unique_ptr frame( SharedDesktopFrame::Wrap(CreateTestFrame())); - frame->set_top_left(DesktopVector(100, 200)); + frame->set_top_left(DesktopVector(kFrameXCoord, kFrameYCoord)); // The frame covers (100, 200) - (200, 300). struct { @@ -279,7 +331,7 @@ TEST_F(DesktopAndCursorComposerTest, TEST_F(DesktopAndCursorComposerTest, IsOccludedShouldBeConsidered) { std::unique_ptr frame( SharedDesktopFrame::Wrap(CreateTestFrame())); - frame->set_top_left(DesktopVector(100, 200)); + frame->set_top_left(DesktopVector(kFrameXCoord, kFrameYCoord)); // The frame covers (100, 200) - (200, 300). struct { @@ -304,7 +356,7 @@ TEST_F(DesktopAndCursorComposerTest, IsOccludedShouldBeConsidered) { TEST_F(DesktopAndCursorComposerTest, CursorIncluded) { std::unique_ptr frame( SharedDesktopFrame::Wrap(CreateTestFrame())); - frame->set_top_left(DesktopVector(100, 200)); + frame->set_top_left(DesktopVector(kFrameXCoord, kFrameYCoord)); // The frame covers (100, 200) - (200, 300). struct { diff --git a/modules/desktop_capture/desktop_capture_metrics_helper.cc b/modules/desktop_capture/desktop_capture_metrics_helper.cc new file mode 100644 index 0000000000..6b741ef4bb --- /dev/null +++ b/modules/desktop_capture/desktop_capture_metrics_helper.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/desktop_capture/desktop_capture_metrics_helper.h" + +#include "modules/desktop_capture/desktop_capture_types.h" +#include "system_wrappers/include/metrics.h" + +namespace webrtc { +namespace { +// This enum is logged via UMA so entries should not be reordered or have their +// values changed. This should also be kept in sync with the values in the +// DesktopCapturerId namespace. 
+enum class SequentialDesktopCapturerId { + kUnknown = 0, + kWgcCapturerWin = 1, + kScreenCapturerWinMagnifier = 2, + kWindowCapturerWinGdi = 3, + kScreenCapturerWinGdi = 4, + kScreenCapturerWinDirectx = 5, + kMaxValue = kScreenCapturerWinDirectx +}; +} // namespace + +void RecordCapturerImpl(uint32_t capturer_id) { + SequentialDesktopCapturerId sequential_id; + switch (capturer_id) { + case DesktopCapturerId::kWgcCapturerWin: + sequential_id = SequentialDesktopCapturerId::kWgcCapturerWin; + break; + case DesktopCapturerId::kScreenCapturerWinMagnifier: + sequential_id = SequentialDesktopCapturerId::kScreenCapturerWinMagnifier; + break; + case DesktopCapturerId::kWindowCapturerWinGdi: + sequential_id = SequentialDesktopCapturerId::kWindowCapturerWinGdi; + break; + case DesktopCapturerId::kScreenCapturerWinGdi: + sequential_id = SequentialDesktopCapturerId::kScreenCapturerWinGdi; + break; + case DesktopCapturerId::kScreenCapturerWinDirectx: + sequential_id = SequentialDesktopCapturerId::kScreenCapturerWinDirectx; + break; + case DesktopCapturerId::kUnknown: + default: + sequential_id = SequentialDesktopCapturerId::kUnknown; + } + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.DesktopCapture.Win.DesktopCapturerImpl", + static_cast(sequential_id), + static_cast(SequentialDesktopCapturerId::kMaxValue)); +} + +} // namespace webrtc diff --git a/api/video_codecs/video_decoder_factory.cc b/modules/desktop_capture/desktop_capture_metrics_helper.h similarity index 53% rename from api/video_codecs/video_decoder_factory.cc rename to modules/desktop_capture/desktop_capture_metrics_helper.h index 511a3c7e92..37542b84bb 100644 --- a/api/video_codecs/video_decoder_factory.cc +++ b/modules/desktop_capture/desktop_capture_metrics_helper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source @@ -8,16 +8,15 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "api/video_codecs/video_decoder_factory.h" +#ifndef MODULES_DESKTOP_CAPTURE_DESKTOP_CAPTURE_METRICS_HELPER_H_ +#define MODULES_DESKTOP_CAPTURE_DESKTOP_CAPTURE_METRICS_HELPER_H_ -#include "api/video_codecs/video_decoder.h" +#include namespace webrtc { -std::unique_ptr VideoDecoderFactory::LegacyCreateVideoDecoder( - const SdpVideoFormat& format, - const std::string& receive_stream_id) { - return CreateVideoDecoder(format); -} +void RecordCapturerImpl(uint32_t capturer_id); } // namespace webrtc + +#endif // MODULES_DESKTOP_CAPTURE_DESKTOP_CAPTURE_METRICS_HELPER_H_ diff --git a/modules/desktop_capture/desktop_capture_options.h b/modules/desktop_capture/desktop_capture_options.h index 521c80b5c5..a693803aa0 100644 --- a/modules/desktop_capture/desktop_capture_options.h +++ b/modules/desktop_capture/desktop_capture_options.h @@ -98,6 +98,24 @@ class RTC_EXPORT DesktopCaptureOptions { } #if defined(WEBRTC_WIN) + // Enumerating windows owned by the current process on Windows has some + // complications due to |GetWindowText*()| APIs potentially causing a + // deadlock (see the comments in the |GetWindowListHandler()| function in + // window_capture_utils.cc for more details on the deadlock). 
+ // To avoid this issue, consumers can either ensure that the thread that runs + // their message loop never waits on |GetSourceList()|, or they can set this + // flag to false which will prevent windows running in the current process + // from being enumerated and included in the results. Consumers can still + // provide the WindowId for their own windows to |SelectSource()| and capture + // them. + bool enumerate_current_process_windows() const { + return enumerate_current_process_windows_; + } + void set_enumerate_current_process_windows( + bool enumerate_current_process_windows) { + enumerate_current_process_windows_ = enumerate_current_process_windows; + } + bool allow_use_magnification_api() const { return allow_use_magnification_api_; } @@ -126,7 +144,19 @@ class RTC_EXPORT DesktopCaptureOptions { void set_allow_cropping_window_capturer(bool allow) { allow_cropping_window_capturer_ = allow; } -#endif + +#if defined(RTC_ENABLE_WIN_WGC) + // This flag enables the WGC capturer for both window and screen capture. + // This capturer should offer similar or better performance than the cropping + // capturer without the disadvantages listed above. However, the WGC capturer + // is only available on Windows 10 version 1809 (Redstone 5) and up. This flag + // will have no affect on older versions. + // If set, and running a supported version of Win10, this flag will take + // precedence over the cropping, directx, and magnification flags. + bool allow_wgc_capturer() const { return allow_wgc_capturer_; } + void set_allow_wgc_capturer(bool allow) { allow_wgc_capturer_ = allow; } +#endif // defined(RTC_ENABLE_WIN_WGC) +#endif // defined(WEBRTC_WIN) #if defined(WEBRTC_USE_PIPEWIRE) bool allow_pipewire() const { return allow_pipewire_; } @@ -146,9 +176,13 @@ class RTC_EXPORT DesktopCaptureOptions { rtc::scoped_refptr full_screen_window_detector_; #if defined(WEBRTC_WIN) + bool enumerate_current_process_windows_ = true; bool allow_use_magnification_api_ = false; bool allow_directx_capturer_ = false; bool allow_cropping_window_capturer_ = false; +#if defined(RTC_ENABLE_WIN_WGC) + bool allow_wgc_capturer_ = false; +#endif #endif #if defined(WEBRTC_USE_X11) bool use_update_notifications_ = false; diff --git a/modules/desktop_capture/desktop_capture_types.h b/modules/desktop_capture/desktop_capture_types.h index 5031cbf3ac..5f9966bb6d 100644 --- a/modules/desktop_capture/desktop_capture_types.h +++ b/modules/desktop_capture/desktop_capture_types.h @@ -36,8 +36,11 @@ const ScreenId kFullDesktopScreenId = -1; const ScreenId kInvalidScreenId = -2; -// An integer to attach to each DesktopFrame to differentiate the generator of -// the frame. +// Integers to attach to each DesktopFrame to differentiate the generator of +// the frame. The entries in this namespace should remain in sync with the +// SequentialDesktopCapturerId enum, which is logged via UMA. 
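Taken together, the two options introduced above would typically be wired up by the embedder before creating a capturer. The snippet below is a hypothetical call site, not part of this patch; it only uses setters and factory functions that appear in this diff, and the WGC path is still gated on Windows 10 RS5+ at runtime:

#include <memory>

#include "modules/desktop_capture/desktop_capture_options.h"
#include "modules/desktop_capture/desktop_capturer.h"

// Hypothetical embedder snippet exercising the new options.
std::unique_ptr<webrtc::DesktopCapturer> MakeWindowCapturerSketch() {
  webrtc::DesktopCaptureOptions options =
      webrtc::DesktopCaptureOptions::CreateDefault();
#if defined(WEBRTC_WIN)
  // Skip windows owned by the capturing process to avoid the GetWindowText*
  // deadlock described in the comment above.
  options.set_enumerate_current_process_windows(false);
#if defined(RTC_ENABLE_WIN_WGC)
  // Opt in to the Windows Graphics Capture backend where available.
  options.set_allow_wgc_capturer(true);
#endif
#endif
  return webrtc::DesktopCapturer::CreateWindowCapturer(options);
}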
+// |kScreenCapturerWinGdi| and |kScreenCapturerWinDirectx| values are preserved +// to maintain compatibility namespace DesktopCapturerId { constexpr uint32_t CreateFourCC(char a, char b, char c, char d) { return ((static_cast(a)) | (static_cast(b) << 8) | @@ -45,6 +48,9 @@ constexpr uint32_t CreateFourCC(char a, char b, char c, char d) { } constexpr uint32_t kUnknown = 0; +constexpr uint32_t kWgcCapturerWin = 1; +constexpr uint32_t kScreenCapturerWinMagnifier = 2; +constexpr uint32_t kWindowCapturerWinGdi = 3; constexpr uint32_t kScreenCapturerWinGdi = CreateFourCC('G', 'D', 'I', ' '); constexpr uint32_t kScreenCapturerWinDirectx = CreateFourCC('D', 'X', 'G', 'I'); } // namespace DesktopCapturerId diff --git a/modules/desktop_capture/desktop_capturer.cc b/modules/desktop_capture/desktop_capturer.cc index e1fff4ea57..735aa4d530 100644 --- a/modules/desktop_capture/desktop_capturer.cc +++ b/modules/desktop_capture/desktop_capturer.cc @@ -21,10 +21,8 @@ #include "modules/desktop_capture/desktop_capturer_differ_wrapper.h" #if defined(RTC_ENABLE_WIN_WGC) -#include "modules/desktop_capture/win/window_capturer_win_wgc.h" +#include "modules/desktop_capture/win/wgc_capturer_win.h" #include "rtc_base/win/windows_version.h" - -const bool kUseWinWgcCapturer = false; #endif // defined(RTC_ENABLE_WIN_WGC) namespace webrtc { @@ -56,12 +54,9 @@ bool DesktopCapturer::IsOccluded(const DesktopVector& pos) { std::unique_ptr DesktopCapturer::CreateWindowCapturer( const DesktopCaptureOptions& options) { #if defined(RTC_ENABLE_WIN_WGC) - // TODO(bugs.webrtc.org/11760): Add a WebRTC field trial (or similar - // mechanism) check here that leads to use of the WGC capturer once it is - // fully implemented. - if (kUseWinWgcCapturer && + if (options.allow_wgc_capturer() && rtc::rtc_win::GetVersion() >= rtc::rtc_win::Version::VERSION_WIN10_RS5) { - return WindowCapturerWinWgc::CreateRawWindowCapturer(options); + return WgcCapturerWin::CreateRawWindowCapturer(options); } #endif // defined(RTC_ENABLE_WIN_WGC) @@ -82,6 +77,13 @@ std::unique_ptr DesktopCapturer::CreateWindowCapturer( // static std::unique_ptr DesktopCapturer::CreateScreenCapturer( const DesktopCaptureOptions& options) { +#if defined(RTC_ENABLE_WIN_WGC) + if (options.allow_wgc_capturer() && + rtc::rtc_win::GetVersion() >= rtc::rtc_win::Version::VERSION_WIN10_RS5) { + return WgcCapturerWin::CreateRawScreenCapturer(options); + } +#endif // defined(RTC_ENABLE_WIN_WGC) + std::unique_ptr capturer = CreateRawScreenCapturer(options); if (capturer && options.detect_updated_region()) { capturer.reset(new DesktopCapturerDifferWrapper(std::move(capturer))); diff --git a/modules/desktop_capture/desktop_frame.cc b/modules/desktop_capture/desktop_frame.cc index fd10dd5d23..9e4a899fd2 100644 --- a/modules/desktop_capture/desktop_frame.cc +++ b/modules/desktop_capture/desktop_frame.cc @@ -19,6 +19,7 @@ #include "modules/desktop_capture/desktop_capture_types.h" #include "modules/desktop_capture/desktop_geometry.h" #include "rtc_base/checks.h" +#include "third_party/libyuv/include/libyuv/planar_functions.h" namespace webrtc { @@ -44,11 +45,9 @@ void DesktopFrame::CopyPixelsFrom(const uint8_t* src_buffer, RTC_CHECK(DesktopRect::MakeSize(size()).ContainsRect(dest_rect)); uint8_t* dest = GetFrameDataAtPos(dest_rect.top_left()); - for (int y = 0; y < dest_rect.height(); ++y) { - memcpy(dest, src_buffer, DesktopFrame::kBytesPerPixel * dest_rect.width()); - src_buffer += src_stride; - dest += stride(); - } + libyuv::CopyPlane(src_buffer, src_stride, dest, stride(), + 
DesktopFrame::kBytesPerPixel * dest_rect.width(), + dest_rect.height()); } void DesktopFrame::CopyPixelsFrom(const DesktopFrame& src_frame, @@ -158,11 +157,9 @@ BasicDesktopFrame::~BasicDesktopFrame() { // static DesktopFrame* BasicDesktopFrame::CopyOf(const DesktopFrame& frame) { DesktopFrame* result = new BasicDesktopFrame(frame.size()); - for (int y = 0; y < frame.size().height(); ++y) { - memcpy(result->data() + y * result->stride(), - frame.data() + y * frame.stride(), - frame.size().width() * kBytesPerPixel); - } + libyuv::CopyPlane(frame.data(), frame.stride(), result->data(), + result->stride(), frame.size().width() * kBytesPerPixel, + frame.size().height()); result->CopyFrameInfoFrom(frame); return result; } diff --git a/modules/desktop_capture/desktop_frame.h b/modules/desktop_capture/desktop_frame.h index 4ee3680670..bc47cc50f2 100644 --- a/modules/desktop_capture/desktop_frame.h +++ b/modules/desktop_capture/desktop_frame.h @@ -72,6 +72,15 @@ class RTC_EXPORT DesktopFrame { const DesktopVector& dpi() const { return dpi_; } void set_dpi(const DesktopVector& dpi) { dpi_ = dpi; } + // Indicates if this frame may have the mouse cursor in it. Capturers that + // support cursor capture may set this to true. If the cursor was + // outside of the captured area, this may be true even though the cursor is + // not in the image. + bool may_contain_cursor() const { return may_contain_cursor_; } + void set_may_contain_cursor(bool may_contain_cursor) { + may_contain_cursor_ = may_contain_cursor; + } + // Time taken to capture the frame in milliseconds. int64_t capture_time_ms() const { return capture_time_ms_; } void set_capture_time_ms(int64_t time_ms) { capture_time_ms_ = time_ms; } @@ -150,6 +159,7 @@ class RTC_EXPORT DesktopFrame { DesktopRegion updated_region_; DesktopVector top_left_; DesktopVector dpi_; + bool may_contain_cursor_ = false; int64_t capture_time_ms_; uint32_t capturer_id_; std::vector icc_profile_; diff --git a/modules/desktop_capture/desktop_region.cc b/modules/desktop_capture/desktop_region.cc index befbcc6f41..96f142d3dd 100644 --- a/modules/desktop_capture/desktop_region.cc +++ b/modules/desktop_capture/desktop_region.cc @@ -10,11 +10,11 @@ #include "modules/desktop_capture/desktop_region.h" -#include - #include #include +#include "rtc_base/checks.h" + namespace webrtc { DesktopRegion::RowSpan::RowSpan(int32_t left, int32_t right) @@ -109,7 +109,7 @@ void DesktopRegion::AddRect(const DesktopRect& rect) { // If the |top| falls in the middle of the |row| then split |row| into // two, at |top|, and leave |row| referring to the lower of the two, // ready to insert a new span into. - assert(top <= row->second->bottom); + RTC_DCHECK_LE(top, row->second->bottom); Rows::iterator new_row = rows_.insert( row, Rows::value_type(top, new Row(row->second->top, top))); row->second->top = top; @@ -148,7 +148,7 @@ void DesktopRegion::AddRects(const DesktopRect* rects, int count) { } void DesktopRegion::MergeWithPrecedingRow(Rows::iterator row) { - assert(row != rows_.end()); + RTC_DCHECK(row != rows_.end()); if (row != rows_.begin()) { Rows::iterator previous_row = row; @@ -230,7 +230,7 @@ void DesktopRegion::IntersectRows(const RowSpanSet& set1, RowSpanSet::const_iterator end1 = set1.end(); RowSpanSet::const_iterator it2 = set2.begin(); RowSpanSet::const_iterator end2 = set2.end(); - assert(it1 != end1 && it2 != end2); + RTC_DCHECK(it1 != end1 && it2 != end2); do { // Arrange for |it1| to always be the left-most of the spans. 
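On the desktop_frame.cc change above: the removed per-row memcpy loops and libyuv::CopyPlane do the same work, copying height rows of width bytes with independent source and destination strides. A minimal sketch of the equivalence, with illustrative buffer and function names:

#include <cstdint>
#include <cstring>

#include "third_party/libyuv/include/libyuv/planar_functions.h"

// The two bodies below are interchangeable for ARGB desktop frames; the patch
// simply delegates the row loop to libyuv.
void CopyRegionWithLoopSketch(const uint8_t* src, int src_stride, uint8_t* dst,
                              int dst_stride, int width_bytes, int height) {
  for (int y = 0; y < height; ++y)
    std::memcpy(dst + y * dst_stride, src + y * src_stride, width_bytes);
}

void CopyRegionWithLibyuvSketch(const uint8_t* src, int src_stride,
                                uint8_t* dst, int dst_stride, int width_bytes,
                                int height) {
  libyuv::CopyPlane(src, src_stride, dst, dst_stride, width_bytes, height);
}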
@@ -247,7 +247,7 @@ void DesktopRegion::IntersectRows(const RowSpanSet& set1, int32_t left = it2->left; int32_t right = std::min(it1->right, it2->right); - assert(left < right); + RTC_DCHECK_LT(left, right); output->push_back(RowSpan(left, right)); @@ -302,7 +302,7 @@ void DesktopRegion::Subtract(const DesktopRegion& region) { // If |top| falls in the middle of |row_a| then split |row_a| into two, at // |top|, and leave |row_a| referring to the lower of the two, ready to // subtract spans from. - assert(top <= row_a->second->bottom); + RTC_DCHECK_LE(top, row_a->second->bottom); Rows::iterator new_row = rows_.insert( row_a, Rows::value_type(top, new Row(row_a->second->top, top))); row_a->second->top = top; @@ -420,7 +420,7 @@ void DesktopRegion::AddSpanToRow(Row* row, int left, int right) { // Find the first span that ends at or after |left|. RowSpanSet::iterator start = std::lower_bound( row->spans.begin(), row->spans.end(), left, CompareSpanRight); - assert(start < row->spans.end()); + RTC_DCHECK(start < row->spans.end()); // Find the first span that starts after |right|. RowSpanSet::iterator end = @@ -467,7 +467,7 @@ bool DesktopRegion::IsSpanInRow(const Row& row, const RowSpan& span) { void DesktopRegion::SubtractRows(const RowSpanSet& set_a, const RowSpanSet& set_b, RowSpanSet* output) { - assert(!set_a.empty() && !set_b.empty()); + RTC_DCHECK(!set_a.empty() && !set_b.empty()); RowSpanSet::const_iterator it_b = set_b.begin(); @@ -503,7 +503,7 @@ DesktopRegion::Iterator::Iterator(const DesktopRegion& region) row_(region.rows_.begin()), previous_row_(region.rows_.end()) { if (!IsAtEnd()) { - assert(row_->second->spans.size() > 0); + RTC_DCHECK_GT(row_->second->spans.size(), 0); row_span_ = row_->second->spans.begin(); UpdateCurrentRect(); } @@ -516,7 +516,7 @@ bool DesktopRegion::Iterator::IsAtEnd() const { } void DesktopRegion::Iterator::Advance() { - assert(!IsAtEnd()); + RTC_DCHECK(!IsAtEnd()); while (true) { ++row_span_; @@ -524,7 +524,7 @@ void DesktopRegion::Iterator::Advance() { previous_row_ = row_; ++row_; if (row_ != region_.rows_.end()) { - assert(row_->second->spans.size() > 0); + RTC_DCHECK_GT(row_->second->spans.size(), 0); row_span_ = row_->second->spans.begin(); } } @@ -544,7 +544,7 @@ void DesktopRegion::Iterator::Advance() { break; } - assert(!IsAtEnd()); + RTC_DCHECK(!IsAtEnd()); UpdateCurrentRect(); } diff --git a/modules/desktop_capture/fallback_desktop_capturer_wrapper.cc b/modules/desktop_capture/fallback_desktop_capturer_wrapper.cc index 206791ca78..0b1ab7ed37 100644 --- a/modules/desktop_capture/fallback_desktop_capturer_wrapper.cc +++ b/modules/desktop_capture/fallback_desktop_capturer_wrapper.cc @@ -14,8 +14,8 @@ #include +#include "api/sequence_checker.h" #include "rtc_base/checks.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/metrics.h" namespace webrtc { @@ -42,7 +42,7 @@ class SharedMemoryFactoryProxy : public SharedMemoryFactory { explicit SharedMemoryFactoryProxy(SharedMemoryFactory* factory); SharedMemoryFactory* factory_ = nullptr; - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; }; } // namespace diff --git a/modules/desktop_capture/full_screen_window_detector.h b/modules/desktop_capture/full_screen_window_detector.h index 46fb607b7d..ca30d95de4 100644 --- a/modules/desktop_capture/full_screen_window_detector.h +++ b/modules/desktop_capture/full_screen_window_detector.h @@ -32,7 +32,8 @@ namespace webrtc { // window using criteria provided by application specific // FullScreenApplicationHandler. 
-class FullScreenWindowDetector : public rtc::RefCountedBase { +class FullScreenWindowDetector + : public rtc::RefCountedNonVirtual { public: using ApplicationHandlerFactory = std::function( diff --git a/modules/desktop_capture/linux/base_capturer_pipewire.cc b/modules/desktop_capture/linux/base_capturer_pipewire.cc index c302a086ea..e5d001e476 100644 --- a/modules/desktop_capture/linux/base_capturer_pipewire.cc +++ b/modules/desktop_capture/linux/base_capturer_pipewire.cc @@ -772,37 +772,27 @@ void BaseCapturerPipeWire::HandleBuffer(pw_buffer* buffer) { // Use video metadata when video size from metadata is set and smaller than // video stream size, so we need to adjust it. - bool video_is_full_width = true; - bool video_is_full_height = true; + bool video_metadata_use = false; + #if PW_CHECK_VERSION(0, 3, 0) - if (video_metadata && video_metadata->region.size.width != 0 && - video_metadata->region.size.height != 0) { - if (video_metadata->region.size.width < - static_cast(desktop_size_.width())) { - video_is_full_width = false; - } else if (video_metadata->region.size.height < - static_cast(desktop_size_.height())) { - video_is_full_height = false; - } - } + const struct spa_rectangle* video_metadata_size = + video_metadata ? &video_metadata->region.size : nullptr; #else - if (video_metadata && video_metadata->width != 0 && - video_metadata->height != 0) { - if (video_metadata->width < desktop_size_.width()) { - } else if (video_metadata->height < desktop_size_.height()) { - video_is_full_height = false; - } - } + const struct spa_meta_video_crop* video_metadata_size = video_metadata; #endif + if (video_metadata_size && video_metadata_size->width != 0 && + video_metadata_size->height != 0 && + (static_cast(video_metadata_size->width) < desktop_size_.width() || + static_cast(video_metadata_size->height) < + desktop_size_.height())) { + video_metadata_use = true; + } + DesktopSize video_size_prev = video_size_; - if (!video_is_full_height || !video_is_full_width) { -#if PW_CHECK_VERSION(0, 3, 0) - video_size_ = DesktopSize(video_metadata->region.size.width, - video_metadata->region.size.height); -#else - video_size_ = DesktopSize(video_metadata->width, video_metadata->height); -#endif + if (video_metadata_use) { + video_size_ = + DesktopSize(video_metadata_size->width, video_metadata_size->height); } else { video_size_ = desktop_size_; } @@ -827,25 +817,25 @@ void BaseCapturerPipeWire::HandleBuffer(pw_buffer* buffer) { // Adjust source content based on metadata video position #if PW_CHECK_VERSION(0, 3, 0) - if (!video_is_full_height && + if (video_metadata_use && (video_metadata->region.position.y + video_size_.height() <= desktop_size_.height())) { src += src_stride * video_metadata->region.position.y; } const int x_offset = - !video_is_full_width && + video_metadata_use && (video_metadata->region.position.x + video_size_.width() <= desktop_size_.width()) ? video_metadata->region.position.x * kBytesPerPixel : 0; #else - if (!video_is_full_height && + if (video_metadata_use && (video_metadata->y + video_size_.height() <= desktop_size_.height())) { src += src_stride * video_metadata->y; } const int x_offset = - !video_is_full_width && + video_metadata_use && (video_metadata->x + video_size_.width() <= desktop_size_.width()) ? video_metadata->x * kBytesPerPixel : 0; @@ -1036,6 +1026,23 @@ void BaseCapturerPipeWire::SourcesRequest() { // We don't want to allow selection of multiple sources. 
g_variant_builder_add(&builder, "{sv}", "multiple", g_variant_new_boolean(false)); + + Scoped variant( + g_dbus_proxy_get_cached_property(proxy_, "AvailableCursorModes")); + if (variant.get()) { + uint32_t modes = 0; + g_variant_get(variant.get(), "u", &modes); + // Request mouse cursor to be embedded as part of the stream, otherwise it + // is hidden by default. Make request only if this mode is advertised by + // the portal implementation. + if (modes & + static_cast(BaseCapturerPipeWire::CursorMode::kEmbedded)) { + g_variant_builder_add(&builder, "{sv}", "cursor_mode", + g_variant_new_uint32(static_cast( + BaseCapturerPipeWire::CursorMode::kEmbedded))); + } + } + variant_string = g_strdup_printf("webrtc%d", g_random_int_range(0, G_MAXINT)); g_variant_builder_add(&builder, "{sv}", "handle_token", g_variant_new_string(variant_string.get())); diff --git a/modules/desktop_capture/linux/base_capturer_pipewire.h b/modules/desktop_capture/linux/base_capturer_pipewire.h index 75d20dbf1d..52264188a7 100644 --- a/modules/desktop_capture/linux/base_capturer_pipewire.h +++ b/modules/desktop_capture/linux/base_capturer_pipewire.h @@ -47,6 +47,12 @@ class BaseCapturerPipeWire : public DesktopCapturer { kAny = 0b11 }; + enum class CursorMode : uint32_t { + kHidden = 0b01, + kEmbedded = 0b10, + kMetadata = 0b100 + }; + explicit BaseCapturerPipeWire(CaptureSourceType source_type); ~BaseCapturerPipeWire() override; diff --git a/modules/desktop_capture/linux/shared_x_display.h b/modules/desktop_capture/linux/shared_x_display.h index 64c498c134..dd52e456ca 100644 --- a/modules/desktop_capture/linux/shared_x_display.h +++ b/modules/desktop_capture/linux/shared_x_display.h @@ -28,7 +28,8 @@ typedef union _XEvent XEvent; namespace webrtc { // A ref-counted object to store XDisplay connection. -class RTC_EXPORT SharedXDisplay : public rtc::RefCountedBase { +class RTC_EXPORT SharedXDisplay + : public rtc::RefCountedNonVirtual { public: class XEventHandler { public: @@ -38,9 +39,6 @@ class RTC_EXPORT SharedXDisplay : public rtc::RefCountedBase { virtual bool HandleXEvent(const XEvent& event) = 0; }; - // Takes ownership of |display|. - explicit SharedXDisplay(Display* display); - // Creates a new X11 Display for the |display_name|. NULL is returned if X11 // connection failed. Equivalent to CreateDefault() when |display_name| is // empty. @@ -65,8 +63,11 @@ class RTC_EXPORT SharedXDisplay : public rtc::RefCountedBase { void IgnoreXServerGrabs(); + ~SharedXDisplay(); + protected: - ~SharedXDisplay() override; + // Takes ownership of |display|. + explicit SharedXDisplay(Display* display); private: typedef std::map > EventHandlersMap; diff --git a/modules/desktop_capture/linux/x_error_trap.cc b/modules/desktop_capture/linux/x_error_trap.cc index 53c907fc45..13233d8274 100644 --- a/modules/desktop_capture/linux/x_error_trap.cc +++ b/modules/desktop_capture/linux/x_error_trap.cc @@ -10,55 +10,40 @@ #include "modules/desktop_capture/linux/x_error_trap.h" -#include #include -#if defined(TOOLKIT_GTK) -#include -#endif // !defined(TOOLKIT_GTK) +#include "rtc_base/checks.h" namespace webrtc { namespace { -#if !defined(TOOLKIT_GTK) - // TODO(sergeyu): This code is not thread safe. Fix it. Bug 2202. 
static bool g_xserver_error_trap_enabled = false; static int g_last_xserver_error_code = 0; int XServerErrorHandler(Display* display, XErrorEvent* error_event) { - assert(g_xserver_error_trap_enabled); + RTC_DCHECK(g_xserver_error_trap_enabled); g_last_xserver_error_code = error_event->error_code; return 0; } -#endif // !defined(TOOLKIT_GTK) - } // namespace XErrorTrap::XErrorTrap(Display* display) : original_error_handler_(NULL), enabled_(true) { -#if defined(TOOLKIT_GTK) - gdk_error_trap_push(); -#else // !defined(TOOLKIT_GTK) - assert(!g_xserver_error_trap_enabled); + RTC_DCHECK(!g_xserver_error_trap_enabled); original_error_handler_ = XSetErrorHandler(&XServerErrorHandler); g_xserver_error_trap_enabled = true; g_last_xserver_error_code = 0; -#endif // !defined(TOOLKIT_GTK) } int XErrorTrap::GetLastErrorAndDisable() { enabled_ = false; -#if defined(TOOLKIT_GTK) - return gdk_error_trap_push(); -#else // !defined(TOOLKIT_GTK) - assert(g_xserver_error_trap_enabled); + RTC_DCHECK(g_xserver_error_trap_enabled); XSetErrorHandler(original_error_handler_); g_xserver_error_trap_enabled = false; return g_last_xserver_error_code; -#endif // !defined(TOOLKIT_GTK) } XErrorTrap::~XErrorTrap() { diff --git a/modules/desktop_capture/mac/desktop_configuration_monitor.h b/modules/desktop_capture/mac/desktop_configuration_monitor.h index 46a66d1d4c..aa0ebfbacc 100644 --- a/modules/desktop_capture/mac/desktop_configuration_monitor.h +++ b/modules/desktop_capture/mac/desktop_configuration_monitor.h @@ -25,15 +25,15 @@ namespace webrtc { // The class provides functions to synchronize capturing and display // reconfiguring across threads, and the up-to-date MacDesktopConfiguration. -class DesktopConfigurationMonitor : public rtc::RefCountedBase { +class DesktopConfigurationMonitor final + : public rtc::RefCountedNonVirtual { public: DesktopConfigurationMonitor(); + ~DesktopConfigurationMonitor(); + // Returns the current desktop configuration. MacDesktopConfiguration desktop_configuration(); - protected: - ~DesktopConfigurationMonitor() override; - private: static void DisplaysReconfiguredCallback(CGDirectDisplayID display, CGDisplayChangeSummaryFlags flags, diff --git a/modules/desktop_capture/mac/desktop_frame_provider.h b/modules/desktop_capture/mac/desktop_frame_provider.h index 4826f99e8e..f71959bda1 100644 --- a/modules/desktop_capture/mac/desktop_frame_provider.h +++ b/modules/desktop_capture/mac/desktop_frame_provider.h @@ -17,8 +17,8 @@ #include #include +#include "api/sequence_checker.h" #include "modules/desktop_capture/shared_desktop_frame.h" -#include "rtc_base/thread_checker.h" #include "sdk/objc/helpers/scoped_cftyperef.h" namespace webrtc { @@ -44,7 +44,7 @@ class DesktopFrameProvider { void Release(); private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; const bool allow_iosurface_; // Most recent IOSurface that contains a capture of matching display. 
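A number of headers in this section replace rtc::ThreadChecker from rtc_base/thread_checker.h with webrtc::SequenceChecker from the new api/sequence_checker.h; the member keeps its role of asserting that certain calls stay on the construction sequence. A minimal sketch of the pattern, with a hypothetical class name:

#include "api/sequence_checker.h"
#include "rtc_base/checks.h"

// Hypothetical class showing the replacement pattern: the checker latches the
// construction sequence and member functions assert they run on it.
class SequenceBoundSketch {
 public:
  void DoWork() {
    RTC_DCHECK(sequence_checker_.IsCurrent());
    // ... work that must stay on the construction sequence ...
  }

 private:
  webrtc::SequenceChecker sequence_checker_;
};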
diff --git a/modules/desktop_capture/mac/screen_capturer_mac.h b/modules/desktop_capture/mac/screen_capturer_mac.h index 8076e5b09a..68b8655b1c 100644 --- a/modules/desktop_capture/mac/screen_capturer_mac.h +++ b/modules/desktop_capture/mac/screen_capturer_mac.h @@ -16,6 +16,7 @@ #include #include +#include "api/sequence_checker.h" #include "modules/desktop_capture/desktop_capture_options.h" #include "modules/desktop_capture/desktop_capturer.h" #include "modules/desktop_capture/desktop_frame.h" @@ -27,7 +28,6 @@ #include "modules/desktop_capture/screen_capture_frame_queue.h" #include "modules/desktop_capture/screen_capturer_helper.h" #include "modules/desktop_capture/shared_desktop_frame.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -110,7 +110,7 @@ class ScreenCapturerMac final : public DesktopCapturer { DesktopFrameProvider desktop_frame_provider_; // Start, CaptureFrame and destructor have to called in the same thread. - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; RTC_DISALLOW_COPY_AND_ASSIGN(ScreenCapturerMac); }; diff --git a/modules/desktop_capture/mouse_cursor.cc b/modules/desktop_capture/mouse_cursor.cc index 3b61e10a8b..e826552b0f 100644 --- a/modules/desktop_capture/mouse_cursor.cc +++ b/modules/desktop_capture/mouse_cursor.cc @@ -10,9 +10,8 @@ #include "modules/desktop_capture/mouse_cursor.h" -#include - #include "modules/desktop_capture/desktop_frame.h" +#include "rtc_base/checks.h" namespace webrtc { @@ -20,8 +19,8 @@ MouseCursor::MouseCursor() {} MouseCursor::MouseCursor(DesktopFrame* image, const DesktopVector& hotspot) : image_(image), hotspot_(hotspot) { - assert(0 <= hotspot_.x() && hotspot_.x() <= image_->size().width()); - assert(0 <= hotspot_.y() && hotspot_.y() <= image_->size().height()); + RTC_DCHECK(0 <= hotspot_.x() && hotspot_.x() <= image_->size().width()); + RTC_DCHECK(0 <= hotspot_.y() && hotspot_.y() <= image_->size().height()); } MouseCursor::~MouseCursor() {} diff --git a/modules/desktop_capture/mouse_cursor_monitor_unittest.cc b/modules/desktop_capture/mouse_cursor_monitor_unittest.cc index ee2dff32af..268e5e3475 100644 --- a/modules/desktop_capture/mouse_cursor_monitor_unittest.cc +++ b/modules/desktop_capture/mouse_cursor_monitor_unittest.cc @@ -65,7 +65,7 @@ TEST_F(MouseCursorMonitorTest, MAYBE(FromScreen)) { MouseCursorMonitor::CreateForScreen( DesktopCaptureOptions::CreateDefault(), webrtc::kFullDesktopScreenId)); - assert(capturer.get()); + RTC_DCHECK(capturer.get()); capturer->Init(this, MouseCursorMonitor::SHAPE_AND_POSITION); capturer->Capture(); @@ -102,7 +102,7 @@ TEST_F(MouseCursorMonitorTest, MAYBE(FromWindow)) { std::unique_ptr capturer( MouseCursorMonitor::CreateForWindow( DesktopCaptureOptions::CreateDefault(), sources[i].id)); - assert(capturer.get()); + RTC_DCHECK(capturer.get()); capturer->Init(this, MouseCursorMonitor::SHAPE_AND_POSITION); capturer->Capture(); @@ -118,7 +118,7 @@ TEST_F(MouseCursorMonitorTest, MAYBE(ShapeOnly)) { MouseCursorMonitor::CreateForScreen( DesktopCaptureOptions::CreateDefault(), webrtc::kFullDesktopScreenId)); - assert(capturer.get()); + RTC_DCHECK(capturer.get()); capturer->Init(this, MouseCursorMonitor::SHAPE_ONLY); capturer->Capture(); diff --git a/modules/desktop_capture/mouse_cursor_monitor_win.cc b/modules/desktop_capture/mouse_cursor_monitor_win.cc index bf0d8534e3..5a10ee1251 100644 --- a/modules/desktop_capture/mouse_cursor_monitor_win.cc +++ b/modules/desktop_capture/mouse_cursor_monitor_win.cc @@ -77,7 +77,7 @@ 
MouseCursorMonitorWin::MouseCursorMonitorWin(ScreenId screen) callback_(NULL), mode_(SHAPE_AND_POSITION), desktop_dc_(NULL) { - assert(screen >= kFullDesktopScreenId); + RTC_DCHECK_GE(screen, kFullDesktopScreenId); memset(&last_cursor_, 0, sizeof(CURSORINFO)); } @@ -87,8 +87,8 @@ MouseCursorMonitorWin::~MouseCursorMonitorWin() { } void MouseCursorMonitorWin::Init(Callback* callback, Mode mode) { - assert(!callback_); - assert(callback); + RTC_DCHECK(!callback_); + RTC_DCHECK(callback); callback_ = callback; mode_ = mode; @@ -97,7 +97,7 @@ void MouseCursorMonitorWin::Init(Callback* callback, Mode mode) { } void MouseCursorMonitorWin::Capture() { - assert(callback_); + RTC_DCHECK(callback_); CURSORINFO cursor_info; cursor_info.cbSize = sizeof(CURSORINFO); @@ -158,7 +158,7 @@ void MouseCursorMonitorWin::Capture() { position = position.subtract(cropped_rect.top_left()); } } else { - assert(screen_ != kInvalidScreenId); + RTC_DCHECK_NE(screen_, kInvalidScreenId); DesktopRect rect = GetScreenRect(); if (inside) inside = rect.Contains(position); @@ -169,7 +169,7 @@ void MouseCursorMonitorWin::Capture() { } DesktopRect MouseCursorMonitorWin::GetScreenRect() { - assert(screen_ != kInvalidScreenId); + RTC_DCHECK_NE(screen_, kInvalidScreenId); if (screen_ == kFullDesktopScreenId) { return DesktopRect::MakeXYWH(GetSystemMetrics(SM_XVIRTUALSCREEN), GetSystemMetrics(SM_YVIRTUALSCREEN), diff --git a/modules/desktop_capture/screen_capturer_helper.cc b/modules/desktop_capture/screen_capturer_helper.cc index 535b653c08..e8bd3fc450 100644 --- a/modules/desktop_capture/screen_capturer_helper.cc +++ b/modules/desktop_capture/screen_capturer_helper.cc @@ -74,7 +74,7 @@ static int UpToMultiple(int x, int n, int nMask) { void ScreenCapturerHelper::ExpandToGrid(const DesktopRegion& region, int log_grid_size, DesktopRegion* result) { - assert(log_grid_size >= 1); + RTC_DCHECK_GE(log_grid_size, 1); int grid_size = 1 << log_grid_size; int grid_size_mask = ~(grid_size - 1); diff --git a/modules/desktop_capture/screen_capturer_unittest.cc b/modules/desktop_capture/screen_capturer_unittest.cc index ea77069278..ba6b8bfe3d 100644 --- a/modules/desktop_capture/screen_capturer_unittest.cc +++ b/modules/desktop_capture/screen_capturer_unittest.cc @@ -99,7 +99,13 @@ ACTION_P(SaveUniquePtrArg, dest) { *dest = std::move(*arg1); } -TEST_F(ScreenCapturerTest, GetScreenListAndSelectScreen) { +// TODO(bugs.webrtc.org/12950): Re-enable when libc++ issue is fixed. +#if defined(WEBRTC_LINUX) && defined(MEMORY_SANITIZER) +#define MAYBE_GetScreenListAndSelectScreen DISABLED_GetScreenListAndSelectScreen +#else +#define MAYBE_GetScreenListAndSelectScreen GetScreenListAndSelectScreen +#endif +TEST_F(ScreenCapturerTest, MAYBE_GetScreenListAndSelectScreen) { webrtc::DesktopCapturer::SourceList screens; EXPECT_TRUE(capturer_->GetSourceList(&screens)); for (const auto& screen : screens) { diff --git a/modules/desktop_capture/screen_drawer_unittest.cc b/modules/desktop_capture/screen_drawer_unittest.cc index c38eee6991..2394260105 100644 --- a/modules/desktop_capture/screen_drawer_unittest.cc +++ b/modules/desktop_capture/screen_drawer_unittest.cc @@ -48,13 +48,12 @@ void TestScreenDrawerLock( ~Task() = default; - static void RunTask(void* me) { - Task* task = static_cast(me); - std::unique_ptr lock = task->ctor_(); + void RunTask() { + std::unique_ptr lock = ctor_(); ASSERT_TRUE(!!lock); - task->created_->store(true); + created_->store(true); // Wait for the main thread to get the signal of created_. 
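// Illustrative sketch (not part of this change): rtc::PlatformThread::
// SpawnJoinable, used further down in this test, takes a callable instead of
// a static void(void*) trampoline plus opaque pointer, and the thread joins
// automatically when the returned handle is destroyed, so explicit Start()
// and Stop() calls are no longer needed.
#include "rtc_base/platform_thread.h"

void RunWorkOnBackgroundThread() {
  int result = 0;
  {
    auto thread = rtc::PlatformThread::SpawnJoinable(
        [&result] { result = 42; }, "example_thread");
    // |thread| joins when it goes out of scope at the end of this block.
  }
  // |result| is now safely visible to this thread.
}
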
- while (!task->ready_.load()) { + while (!ready_.load()) { SleepMs(1); } // At this point, main thread should begin to create a second lock. Though @@ -77,8 +76,8 @@ void TestScreenDrawerLock( const rtc::FunctionView()> ctor_; } task(&created, ready, ctor); - rtc::PlatformThread lock_thread(&Task::RunTask, &task, "lock_thread"); - lock_thread.Start(); + auto lock_thread = rtc::PlatformThread::SpawnJoinable( + [&task] { task.RunTask(); }, "lock_thread"); // Wait for the first lock in Task::RunTask() to be created. // TODO(zijiehe): Find a better solution to wait for the creation of the first @@ -95,7 +94,6 @@ void TestScreenDrawerLock( ASSERT_GT(kLockDurationMs, rtc::TimeMillis() - start_ms); ctor(); ASSERT_LE(kLockDurationMs, rtc::TimeMillis() - start_ms); - lock_thread.Stop(); } } // namespace diff --git a/modules/desktop_capture/shared_desktop_frame.h b/modules/desktop_capture/shared_desktop_frame.h index fd862d7f21..1f451b65df 100644 --- a/modules/desktop_capture/shared_desktop_frame.h +++ b/modules/desktop_capture/shared_desktop_frame.h @@ -23,7 +23,7 @@ namespace webrtc { // SharedDesktopFrame is a DesktopFrame that may have multiple instances all // sharing the same buffer. -class RTC_EXPORT SharedDesktopFrame : public DesktopFrame { +class RTC_EXPORT SharedDesktopFrame final : public DesktopFrame { public: ~SharedDesktopFrame() override; @@ -51,7 +51,7 @@ class RTC_EXPORT SharedDesktopFrame : public DesktopFrame { bool IsShared(); private: - typedef rtc::RefCountedObject> Core; + typedef rtc::FinalRefCountedObject> Core; SharedDesktopFrame(rtc::scoped_refptr core); diff --git a/modules/desktop_capture/win/dxgi_duplicator_controller.cc b/modules/desktop_capture/win/dxgi_duplicator_controller.cc index bdf495837e..4460ad94f2 100644 --- a/modules/desktop_capture/win/dxgi_duplicator_controller.cc +++ b/modules/desktop_capture/win/dxgi_duplicator_controller.cc @@ -85,14 +85,14 @@ void DxgiDuplicatorController::Release() { } bool DxgiDuplicatorController::IsSupported() { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); return Initialize(); } bool DxgiDuplicatorController::RetrieveD3dInfo(D3dInfo* info) { bool result = false; { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); result = Initialize(); *info = d3d_info_; } @@ -116,7 +116,7 @@ DxgiDuplicatorController::Result DxgiDuplicatorController::DuplicateMonitor( } DesktopVector DxgiDuplicatorController::dpi() { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); if (Initialize()) { return dpi_; } @@ -124,7 +124,7 @@ DesktopVector DxgiDuplicatorController::dpi() { } int DxgiDuplicatorController::ScreenCount() { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); if (Initialize()) { return ScreenCountUnlocked(); } @@ -133,7 +133,7 @@ int DxgiDuplicatorController::ScreenCount() { bool DxgiDuplicatorController::GetDeviceNames( std::vector* output) { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); if (Initialize()) { GetDeviceNamesUnlocked(output); return true; @@ -145,7 +145,7 @@ DxgiDuplicatorController::Result DxgiDuplicatorController::DoDuplicate( DxgiFrame* frame, int monitor_id) { RTC_DCHECK(frame); - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); // The dxgi components and APIs do not update the screen resolution without // a reinitialization. 
So we use the GetDC() function to retrieve the screen @@ -198,12 +198,12 @@ DxgiDuplicatorController::Result DxgiDuplicatorController::DoDuplicate( } void DxgiDuplicatorController::Unload() { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); Deinitialize(); } void DxgiDuplicatorController::Unregister(const Context* const context) { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); if (ContextExpired(context)) { // The Context has not been setup after a recent initialization, so it // should not been registered in duplicators. diff --git a/modules/desktop_capture/win/dxgi_duplicator_controller.h b/modules/desktop_capture/win/dxgi_duplicator_controller.h index b6f8e78649..5e714f35cf 100644 --- a/modules/desktop_capture/win/dxgi_duplicator_controller.h +++ b/modules/desktop_capture/win/dxgi_duplicator_controller.h @@ -25,7 +25,7 @@ #include "modules/desktop_capture/win/dxgi_adapter_duplicator.h" #include "modules/desktop_capture/win/dxgi_context.h" #include "modules/desktop_capture/win/dxgi_frame.h" -#include "rtc_base/deprecated/recursive_critical_section.h" +#include "rtc_base/synchronization/mutex.h" namespace webrtc { @@ -142,95 +142,103 @@ class DxgiDuplicatorController { Result DoDuplicate(DxgiFrame* frame, int monitor_id); // Unload all the DXGI components and releases the resources. This function - // wraps Deinitialize() with |lock_|. + // wraps Deinitialize() with |mutex_|. void Unload(); // Unregisters Context from this instance and all DxgiAdapterDuplicator(s) // it owns. void Unregister(const Context* const context); - // All functions below should be called in |lock_| locked scope and should be + // All functions below should be called in |mutex_| locked scope and should be // after a successful Initialize(). // If current instance has not been initialized, executes DoInitialize() // function, and returns initialize result. Otherwise directly returns true. // This function may calls Deinitialize() if initialization failed. - bool Initialize(); + bool Initialize() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Does the real initialization work, this function should only be called in // Initialize(). - bool DoInitialize(); + bool DoInitialize() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Clears all COM components referred by this instance. So next Duplicate() // call will eventually initialize this instance again. - void Deinitialize(); + void Deinitialize() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // A helper function to check whether a Context has been expired. - bool ContextExpired(const Context* const context) const; + bool ContextExpired(const Context* const context) const + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Updates Context if needed. - void Setup(Context* context); + void Setup(Context* context) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool DoDuplicateUnlocked(Context* context, int monitor_id, - SharedDesktopFrame* target); + SharedDesktopFrame* target) + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Captures all monitors. - bool DoDuplicateAll(Context* context, SharedDesktopFrame* target); + bool DoDuplicateAll(Context* context, SharedDesktopFrame* target) + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Captures one monitor. bool DoDuplicateOne(Context* context, int monitor_id, - SharedDesktopFrame* target); + SharedDesktopFrame* target) + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // The minimum GetNumFramesCaptured() returned by |duplicators_|. 
- int64_t GetNumFramesCaptured() const; + int64_t GetNumFramesCaptured() const RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Returns a DesktopSize to cover entire |desktop_rect_|. - DesktopSize desktop_size() const; + DesktopSize desktop_size() const RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Returns the size of one screen. |id| should be >= 0. If system does not // support DXGI based capturer, or |id| is greater than the total screen count // of all the Duplicators, this function returns an empty DesktopRect. - DesktopRect ScreenRect(int id) const; + DesktopRect ScreenRect(int id) const RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - int ScreenCountUnlocked() const; + int ScreenCountUnlocked() const RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - void GetDeviceNamesUnlocked(std::vector* output) const; + void GetDeviceNamesUnlocked(std::vector* output) const + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Returns the desktop size of the selected screen |monitor_id|. Setting // |monitor_id| < 0 to return the entire screen size. - DesktopSize SelectedDesktopSize(int monitor_id) const; + DesktopSize SelectedDesktopSize(int monitor_id) const + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Retries DoDuplicateAll() for several times until GetNumFramesCaptured() is // large enough. Returns false if DoDuplicateAll() returns false, or // GetNumFramesCaptured() has never reached the requirement. // According to http://crbug.com/682112, dxgi capturer returns a black frame // during first several capture attempts. - bool EnsureFrameCaptured(Context* context, SharedDesktopFrame* target); + bool EnsureFrameCaptured(Context* context, SharedDesktopFrame* target) + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Moves |desktop_rect_| and all underlying |duplicators_|, putting top left // corner of the desktop at (0, 0). This is necessary because DXGI_OUTPUT_DESC // may return negative coordinates. Called from DoInitialize() after all // DxgiAdapterDuplicator and DxgiOutputDuplicator instances are initialized. - void TranslateRect(); + void TranslateRect() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // The count of references which are now "living". std::atomic_int refcount_; // This lock must be locked whenever accessing any of the following objects. - rtc::RecursiveCriticalSection lock_; + Mutex mutex_; // A self-incremented integer to compare with the one in Context. It ensures // a Context instance is always initialized after DxgiDuplicatorController. - int identity_ = 0; - DesktopRect desktop_rect_; - DesktopVector dpi_; - std::vector duplicators_; - D3dInfo d3d_info_; - DisplayConfigurationMonitor display_configuration_monitor_; + int identity_ RTC_GUARDED_BY(mutex_) = 0; + DesktopRect desktop_rect_ RTC_GUARDED_BY(mutex_); + DesktopVector dpi_ RTC_GUARDED_BY(mutex_); + std::vector duplicators_ RTC_GUARDED_BY(mutex_); + D3dInfo d3d_info_ RTC_GUARDED_BY(mutex_); + DisplayConfigurationMonitor display_configuration_monitor_ + RTC_GUARDED_BY(mutex_); // A number to indicate how many succeeded duplications have been performed. 
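// Illustrative sketch (not part of this change): the combination applied in
// this class — webrtc::Mutex, RTC_GUARDED_BY() on members, and
// RTC_EXCLUSIVE_LOCKS_REQUIRED() on helpers that expect the lock to already be
// held — lets clang's thread-safety analysis verify locking at compile time.
// The class below is hypothetical.
#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/thread_annotations.h"

class GuardedCounter {
 public:
  void Increment() {
    webrtc::MutexLock lock(&mutex_);
    IncrementLocked();
  }

 private:
  void IncrementLocked() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { ++value_; }

  webrtc::Mutex mutex_;
  int value_ RTC_GUARDED_BY(mutex_) = 0;
};
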
- uint32_t succeeded_duplications_ = 0; + uint32_t succeeded_duplications_ RTC_GUARDED_BY(mutex_) = 0; }; } // namespace webrtc diff --git a/modules/desktop_capture/win/screen_capture_utils.cc b/modules/desktop_capture/win/screen_capture_utils.cc index 95f6d92059..53b6dd399c 100644 --- a/modules/desktop_capture/win/screen_capture_utils.cc +++ b/modules/desktop_capture/win/screen_capture_utils.cc @@ -16,7 +16,9 @@ #include #include "modules/desktop_capture/desktop_capturer.h" +#include "modules/desktop_capture/desktop_geometry.h" #include "rtc_base/checks.h" +#include "rtc_base/logging.h" #include "rtc_base/string_utils.h" #include "rtc_base/win32.h" @@ -36,12 +38,14 @@ bool GetScreenList(DesktopCapturer::SourceList* screens, enum_result = EnumDisplayDevicesW(NULL, device_index, &device, 0); // |enum_result| is 0 if we have enumerated all devices. - if (!enum_result) + if (!enum_result) { break; + } // We only care about active displays. - if (!(device.StateFlags & DISPLAY_DEVICE_ACTIVE)) + if (!(device.StateFlags & DISPLAY_DEVICE_ACTIVE)) { continue; + } screens->push_back({device_index, std::string()}); if (device_names) { @@ -51,7 +55,64 @@ bool GetScreenList(DesktopCapturer::SourceList* screens, return true; } -bool IsScreenValid(DesktopCapturer::SourceId screen, std::wstring* device_key) { +bool GetHmonitorFromDeviceIndex(const DesktopCapturer::SourceId device_index, + HMONITOR* hmonitor) { + // A device index of |kFullDesktopScreenId| or -1 represents all screens, an + // HMONITOR of 0 indicates the same. + if (device_index == kFullDesktopScreenId) { + *hmonitor = 0; + return true; + } + + std::wstring device_key; + if (!IsScreenValid(device_index, &device_key)) { + return false; + } + + DesktopRect screen_rect = GetScreenRect(device_index, device_key); + if (screen_rect.is_empty()) { + return false; + } + + RECT rect = {screen_rect.left(), screen_rect.top(), screen_rect.right(), + screen_rect.bottom()}; + + HMONITOR monitor = MonitorFromRect(&rect, MONITOR_DEFAULTTONULL); + if (monitor == NULL) { + RTC_LOG(LS_WARNING) << "No HMONITOR found for supplied device index."; + return false; + } + + *hmonitor = monitor; + return true; +} + +bool IsMonitorValid(const HMONITOR monitor) { + // An HMONITOR of 0 refers to a virtual monitor that spans all physical + // monitors. 
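// Illustrative usage sketch (not part of this change): the helpers added in
// this file let a caller translate a screen enumeration index into an HMONITOR
// and validate it before use. The function below is hypothetical; index 0 is
// simply the first enumerated display.
#include <windows.h>

#include "modules/desktop_capture/desktop_geometry.h"
#include "modules/desktop_capture/win/screen_capture_utils.h"

bool GetFirstScreenRect(webrtc::DesktopRect* rect) {
  HMONITOR monitor;
  if (!webrtc::GetHmonitorFromDeviceIndex(/*device_index=*/0, &monitor))
    return false;
  if (!webrtc::IsMonitorValid(monitor))
    return false;
  *rect = webrtc::GetMonitorRect(monitor);
  return !rect->is_empty();
}
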
+ if (monitor == 0) { + return true; + } + + MONITORINFO monitor_info; + monitor_info.cbSize = sizeof(MONITORINFO); + return GetMonitorInfoA(monitor, &monitor_info); +} + +DesktopRect GetMonitorRect(const HMONITOR monitor) { + MONITORINFO monitor_info; + monitor_info.cbSize = sizeof(MONITORINFO); + if (!GetMonitorInfoA(monitor, &monitor_info)) { + return DesktopRect(); + } + + return DesktopRect::MakeLTRB( + monitor_info.rcMonitor.left, monitor_info.rcMonitor.top, + monitor_info.rcMonitor.right, monitor_info.rcMonitor.bottom); +} + +bool IsScreenValid(const DesktopCapturer::SourceId screen, + std::wstring* device_key) { if (screen == kFullDesktopScreenId) { *device_key = L""; return true; @@ -60,8 +121,9 @@ bool IsScreenValid(DesktopCapturer::SourceId screen, std::wstring* device_key) { DISPLAY_DEVICEW device; device.cb = sizeof(device); BOOL enum_result = EnumDisplayDevicesW(NULL, screen, &device, 0); - if (enum_result) + if (enum_result) { *device_key = device.DeviceKey; + } return !!enum_result; } @@ -73,7 +135,7 @@ DesktopRect GetFullscreenRect() { GetSystemMetrics(SM_CYVIRTUALSCREEN)); } -DesktopRect GetScreenRect(DesktopCapturer::SourceId screen, +DesktopRect GetScreenRect(const DesktopCapturer::SourceId screen, const std::wstring& device_key) { if (screen == kFullDesktopScreenId) { return GetFullscreenRect(); @@ -82,23 +144,26 @@ DesktopRect GetScreenRect(DesktopCapturer::SourceId screen, DISPLAY_DEVICEW device; device.cb = sizeof(device); BOOL result = EnumDisplayDevicesW(NULL, screen, &device, 0); - if (!result) + if (!result) { return DesktopRect(); + } // Verifies the device index still maps to the same display device, to make // sure we are capturing the same device when devices are added or removed. // DeviceKey is documented as reserved, but it actually contains the registry // key for the device and is unique for each monitor, while DeviceID is not. - if (device_key != device.DeviceKey) + if (device_key != device.DeviceKey) { return DesktopRect(); + } DEVMODEW device_mode; device_mode.dmSize = sizeof(device_mode); device_mode.dmDriverExtra = 0; result = EnumDisplaySettingsExW(device.DeviceName, ENUM_CURRENT_SETTINGS, &device_mode, 0); - if (!result) + if (!result) { return DesktopRect(); + } return DesktopRect::MakeXYWH( device_mode.dmPosition.x, device_mode.dmPosition.y, diff --git a/modules/desktop_capture/win/screen_capture_utils.h b/modules/desktop_capture/win/screen_capture_utils.h index 5c4c11d542..dc993dad25 100644 --- a/modules/desktop_capture/win/screen_capture_utils.h +++ b/modules/desktop_capture/win/screen_capture_utils.h @@ -27,11 +27,26 @@ namespace webrtc { bool GetScreenList(DesktopCapturer::SourceList* screens, std::vector* device_names = nullptr); +// Converts a device index (which are returned by |GetScreenList|) into an +// HMONITOR. +bool GetHmonitorFromDeviceIndex(const DesktopCapturer::SourceId device_index, + HMONITOR* hmonitor); + +// Returns true if |monitor| represents a valid display +// monitor. Consumers should recheck the validity of HMONITORs before use if a +// WM_DISPLAYCHANGE message has been received. +bool IsMonitorValid(const HMONITOR monitor); + +// Returns the rect of the monitor identified by |monitor|, relative to the +// primary display's top-left. On failure, returns an empty rect. +DesktopRect GetMonitorRect(const HMONITOR monitor); + // Returns true if |screen| is a valid screen. The screen device key is // returned through |device_key| if the screen is valid. 
The device key can be // used in GetScreenRect to verify the screen matches the previously obtained // id. -bool IsScreenValid(DesktopCapturer::SourceId screen, std::wstring* device_key); +bool IsScreenValid(const DesktopCapturer::SourceId screen, + std::wstring* device_key); // Get the rect of the entire system in system coordinate system. I.e. the // primary monitor always starts from (0, 0). @@ -40,7 +55,7 @@ DesktopRect GetFullscreenRect(); // Get the rect of the screen identified by |screen|, relative to the primary // display's top-left. If the screen device key does not match |device_key|, or // the screen does not exist, or any error happens, an empty rect is returned. -RTC_EXPORT DesktopRect GetScreenRect(DesktopCapturer::SourceId screen, +RTC_EXPORT DesktopRect GetScreenRect(const DesktopCapturer::SourceId screen, const std::wstring& device_key); } // namespace webrtc diff --git a/modules/desktop_capture/win/screen_capture_utils_unittest.cc b/modules/desktop_capture/win/screen_capture_utils_unittest.cc index a71c4f7610..80d1fb3242 100644 --- a/modules/desktop_capture/win/screen_capture_utils_unittest.cc +++ b/modules/desktop_capture/win/screen_capture_utils_unittest.cc @@ -13,7 +13,9 @@ #include #include +#include "modules/desktop_capture/desktop_capture_types.h" #include "modules/desktop_capture/desktop_capturer.h" +#include "rtc_base/logging.h" #include "test/gtest.h" namespace webrtc { @@ -29,4 +31,29 @@ TEST(ScreenCaptureUtilsTest, GetScreenList) { ASSERT_EQ(screens.size(), device_names.size()); } +TEST(ScreenCaptureUtilsTest, DeviceIndexToHmonitor) { + DesktopCapturer::SourceList screens; + ASSERT_TRUE(GetScreenList(&screens)); + if (screens.size() == 0) { + RTC_LOG(LS_INFO) << "Skip screen capture test on systems with no monitors."; + GTEST_SKIP(); + } + + HMONITOR hmonitor; + ASSERT_TRUE(GetHmonitorFromDeviceIndex(screens[0].id, &hmonitor)); + ASSERT_TRUE(IsMonitorValid(hmonitor)); +} + +TEST(ScreenCaptureUtilsTest, FullScreenDeviceIndexToHmonitor) { + HMONITOR hmonitor; + ASSERT_TRUE(GetHmonitorFromDeviceIndex(kFullDesktopScreenId, &hmonitor)); + ASSERT_EQ(hmonitor, static_cast(0)); + ASSERT_TRUE(IsMonitorValid(hmonitor)); +} + +TEST(ScreenCaptureUtilsTest, InvalidDeviceIndexToHmonitor) { + HMONITOR hmonitor; + ASSERT_FALSE(GetHmonitorFromDeviceIndex(kInvalidScreenId, &hmonitor)); +} + } // namespace webrtc diff --git a/modules/desktop_capture/win/screen_capturer_win_directx.cc b/modules/desktop_capture/win/screen_capturer_win_directx.cc index df3bee8f26..1556d7c787 100644 --- a/modules/desktop_capture/win/screen_capturer_win_directx.cc +++ b/modules/desktop_capture/win/screen_capturer_win_directx.cc @@ -16,12 +16,15 @@ #include #include +#include "modules/desktop_capture/desktop_capture_metrics_helper.h" +#include "modules/desktop_capture/desktop_capture_types.h" #include "modules/desktop_capture/desktop_frame.h" #include "modules/desktop_capture/win/screen_capture_utils.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" +#include "system_wrappers/include/metrics.h" namespace webrtc { @@ -106,6 +109,7 @@ ScreenCapturerWinDirectx::~ScreenCapturerWinDirectx() = default; void ScreenCapturerWinDirectx::Start(Callback* callback) { RTC_DCHECK(!callback_); RTC_DCHECK(callback); + RecordCapturerImpl(DesktopCapturerId::kScreenCapturerWinDirectx); callback_ = callback; } @@ -169,8 +173,13 @@ void ScreenCapturerWinDirectx::CaptureFrame() { case DuplicateResult::SUCCEEDED: { std::unique_ptr frame = 
frames_.current_frame()->frame()->Share(); - frame->set_capture_time_ms((rtc::TimeNanos() - capture_start_time_nanos) / - rtc::kNumNanosecsPerMillisec); + + int capture_time_ms = (rtc::TimeNanos() - capture_start_time_nanos) / + rtc::kNumNanosecsPerMillisec; + RTC_HISTOGRAM_COUNTS_1000( + "WebRTC.DesktopCapture.Win.DirectXCapturerFrameTime", + capture_time_ms); + frame->set_capture_time_ms(capture_time_ms); frame->set_capturer_id(DesktopCapturerId::kScreenCapturerWinDirectx); // TODO(julien.isorce): http://crbug.com/945468. Set the icc profile on diff --git a/modules/desktop_capture/win/screen_capturer_win_gdi.cc b/modules/desktop_capture/win/screen_capturer_win_gdi.cc index bf6cb162a0..dc27344f82 100644 --- a/modules/desktop_capture/win/screen_capturer_win_gdi.cc +++ b/modules/desktop_capture/win/screen_capturer_win_gdi.cc @@ -12,7 +12,9 @@ #include +#include "modules/desktop_capture/desktop_capture_metrics_helper.h" #include "modules/desktop_capture/desktop_capture_options.h" +#include "modules/desktop_capture/desktop_capture_types.h" #include "modules/desktop_capture/desktop_frame.h" #include "modules/desktop_capture/desktop_frame_win.h" #include "modules/desktop_capture/desktop_region.h" @@ -24,6 +26,7 @@ #include "rtc_base/logging.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" +#include "system_wrappers/include/metrics.h" namespace webrtc { @@ -92,8 +95,12 @@ void ScreenCapturerWinGdi::CaptureFrame() { GetDeviceCaps(desktop_dc_, LOGPIXELSY))); frame->mutable_updated_region()->SetRect( DesktopRect::MakeSize(frame->size())); - frame->set_capture_time_ms((rtc::TimeNanos() - capture_start_time_nanos) / - rtc::kNumNanosecsPerMillisec); + + int capture_time_ms = (rtc::TimeNanos() - capture_start_time_nanos) / + rtc::kNumNanosecsPerMillisec; + RTC_HISTOGRAM_COUNTS_1000( + "WebRTC.DesktopCapture.Win.ScreenGdiCapturerFrameTime", capture_time_ms); + frame->set_capture_time_ms(capture_time_ms); frame->set_capturer_id(DesktopCapturerId::kScreenCapturerWinGdi); callback_->OnCaptureResult(Result::SUCCESS, std::move(frame)); } @@ -112,6 +119,7 @@ bool ScreenCapturerWinGdi::SelectSource(SourceId id) { void ScreenCapturerWinGdi::Start(Callback* callback) { RTC_DCHECK(!callback_); RTC_DCHECK(callback); + RecordCapturerImpl(DesktopCapturerId::kScreenCapturerWinGdi); callback_ = callback; diff --git a/modules/desktop_capture/win/screen_capturer_win_magnifier.cc b/modules/desktop_capture/win/screen_capturer_win_magnifier.cc index 1a7bbc18c8..214eb0e463 100644 --- a/modules/desktop_capture/win/screen_capturer_win_magnifier.cc +++ b/modules/desktop_capture/win/screen_capturer_win_magnifier.cc @@ -12,7 +12,9 @@ #include +#include "modules/desktop_capture/desktop_capture_metrics_helper.h" #include "modules/desktop_capture/desktop_capture_options.h" +#include "modules/desktop_capture/desktop_capture_types.h" #include "modules/desktop_capture/desktop_frame.h" #include "modules/desktop_capture/desktop_frame_win.h" #include "modules/desktop_capture/desktop_region.h" @@ -23,6 +25,7 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/time_utils.h" +#include "system_wrappers/include/metrics.h" namespace webrtc { @@ -62,6 +65,8 @@ ScreenCapturerWinMagnifier::~ScreenCapturerWinMagnifier() { void ScreenCapturerWinMagnifier::Start(Callback* callback) { RTC_DCHECK(!callback_); RTC_DCHECK(callback); + RecordCapturerImpl(DesktopCapturerId::kScreenCapturerWinMagnifier); + callback_ = callback; if (!InitializeMagnifier()) { @@ -115,8 +120,13 @@ void 
ScreenCapturerWinMagnifier::CaptureFrame() { GetDeviceCaps(desktop_dc_, LOGPIXELSY))); frame->mutable_updated_region()->SetRect( DesktopRect::MakeSize(frame->size())); - frame->set_capture_time_ms((rtc::TimeNanos() - capture_start_time_nanos) / - rtc::kNumNanosecsPerMillisec); + + int capture_time_ms = (rtc::TimeNanos() - capture_start_time_nanos) / + rtc::kNumNanosecsPerMillisec; + RTC_HISTOGRAM_COUNTS_1000( + "WebRTC.DesktopCapture.Win.MagnifierCapturerFrameTime", capture_time_ms); + frame->set_capture_time_ms(capture_time_ms); + frame->set_capturer_id(DesktopCapturerId::kScreenCapturerWinMagnifier); callback_->OnCaptureResult(Result::SUCCESS, std::move(frame)); } diff --git a/modules/desktop_capture/win/test_support/test_window.cc b/modules/desktop_capture/win/test_support/test_window.cc index dc94ee0d6e..c07ff74aa5 100644 --- a/modules/desktop_capture/win/test_support/test_window.cc +++ b/modules/desktop_capture/win/test_support/test_window.cc @@ -17,15 +17,36 @@ const WCHAR kWindowClass[] = L"DesktopCaptureTestWindowClass"; const int kWindowHeight = 200; const int kWindowWidth = 300; +LRESULT CALLBACK WindowProc(HWND hwnd, + UINT msg, + WPARAM w_param, + LPARAM l_param) { + switch (msg) { + case WM_PAINT: + PAINTSTRUCT paint_struct; + HDC hdc = BeginPaint(hwnd, &paint_struct); + + // Paint the window so the color is consistent and we can inspect the + // pixels in tests and know what to expect. + FillRect(hdc, &paint_struct.rcPaint, + CreateSolidBrush(RGB(kTestWindowRValue, kTestWindowGValue, + kTestWindowBValue))); + + EndPaint(hwnd, &paint_struct); + } + return DefWindowProc(hwnd, msg, w_param, l_param); +} + } // namespace WindowInfo CreateTestWindow(const WCHAR* window_title, const int height, - const int width) { + const int width, + const LONG extended_styles) { WindowInfo info; ::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, - reinterpret_cast(&::DefWindowProc), + reinterpret_cast(&WindowProc), &info.window_instance); WNDCLASSEXW wcex; @@ -33,7 +54,7 @@ WindowInfo CreateTestWindow(const WCHAR* window_title, wcex.cbSize = sizeof(wcex); wcex.style = CS_HREDRAW | CS_VREDRAW; wcex.hInstance = info.window_instance; - wcex.lpfnWndProc = &::DefWindowProc; + wcex.lpfnWndProc = &WindowProc; wcex.lpszClassName = kWindowClass; info.window_class = ::RegisterClassExW(&wcex); @@ -41,11 +62,12 @@ WindowInfo CreateTestWindow(const WCHAR* window_title, // height and width parameters, or if they supplied invalid values. int window_height = height <= 0 ? kWindowHeight : height; int window_width = width <= 0 ? kWindowWidth : width; - info.hwnd = ::CreateWindowW(kWindowClass, window_title, WS_OVERLAPPEDWINDOW, - CW_USEDEFAULT, CW_USEDEFAULT, window_width, - window_height, /*parent_window=*/nullptr, - /*menu_bar=*/nullptr, info.window_instance, - /*additional_params=*/nullptr); + info.hwnd = + ::CreateWindowExW(extended_styles, kWindowClass, window_title, + WS_OVERLAPPEDWINDOW, CW_USEDEFAULT, CW_USEDEFAULT, + window_width, window_height, /*parent_window=*/nullptr, + /*menu_bar=*/nullptr, info.window_instance, + /*additional_params=*/nullptr); ::ShowWindow(info.hwnd, SW_SHOWNORMAL); ::UpdateWindow(info.hwnd); @@ -53,8 +75,16 @@ WindowInfo CreateTestWindow(const WCHAR* window_title, } void ResizeTestWindow(const HWND hwnd, const int width, const int height) { + // SWP_NOMOVE results in the x and y params being ignored. 
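// Illustrative usage sketch (not part of this change): how a capturer test can
// drive the helpers in this file — create the solid-colored test window,
// resize and reposition it, then tear it down. The function below is
// hypothetical.
#include "modules/desktop_capture/win/test_support/test_window.h"

void ExerciseTestWindow() {
  webrtc::WindowInfo info = webrtc::CreateTestWindow(L"Example test window");
  webrtc::ResizeTestWindow(info.hwnd, /*width=*/400, /*height=*/300);
  webrtc::MoveTestWindow(info.hwnd, /*x=*/100, /*y=*/100);
  webrtc::DestroyTestWindow(info);
}
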
::SetWindowPos(hwnd, HWND_TOP, /*x-coord=*/0, /*y-coord=*/0, width, height, - SWP_SHOWWINDOW); + SWP_SHOWWINDOW | SWP_NOMOVE); + ::UpdateWindow(hwnd); +} + +void MoveTestWindow(const HWND hwnd, const int x, const int y) { + // SWP_NOSIZE results in the width and height params being ignored. + ::SetWindowPos(hwnd, HWND_TOP, x, y, /*width=*/0, /*height=*/0, + SWP_SHOWWINDOW | SWP_NOSIZE); ::UpdateWindow(hwnd); } diff --git a/modules/desktop_capture/win/test_support/test_window.h b/modules/desktop_capture/win/test_support/test_window.h index a5962b5819..8701dc990b 100644 --- a/modules/desktop_capture/win/test_support/test_window.h +++ b/modules/desktop_capture/win/test_support/test_window.h @@ -17,6 +17,14 @@ namespace webrtc { +typedef unsigned char uint8_t; + +// Define an arbitrary color for the test window with unique R, G, and B values +// so consumers can verify captured content in tests. +const uint8_t kTestWindowRValue = 191; +const uint8_t kTestWindowGValue = 99; +const uint8_t kTestWindowBValue = 12; + struct WindowInfo { HWND hwnd; HINSTANCE window_instance; @@ -25,10 +33,13 @@ struct WindowInfo { WindowInfo CreateTestWindow(const WCHAR* window_title, const int height = 0, - const int width = 0); + const int width = 0, + const LONG extended_styles = 0); void ResizeTestWindow(const HWND hwnd, const int width, const int height); +void MoveTestWindow(const HWND hwnd, const int x, const int y); + void MinimizeTestWindow(const HWND hwnd); void UnminimizeTestWindow(const HWND hwnd); diff --git a/modules/desktop_capture/win/wgc_capture_session.cc b/modules/desktop_capture/win/wgc_capture_session.cc index ee55cf6164..48c56864b3 100644 --- a/modules/desktop_capture/win/wgc_capture_session.cc +++ b/modules/desktop_capture/win/wgc_capture_session.cc @@ -10,31 +10,363 @@ #include "modules/desktop_capture/win/wgc_capture_session.h" +#include +#include +#include + +#include #include +#include +#include "modules/desktop_capture/win/wgc_desktop_frame.h" #include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/time_utils.h" +#include "rtc_base/win/create_direct3d_device.h" +#include "rtc_base/win/get_activation_factory.h" +#include "system_wrappers/include/metrics.h" using Microsoft::WRL::ComPtr; +namespace WGC = ABI::Windows::Graphics::Capture; + namespace webrtc { +namespace { + +// We must use a BGRA pixel format that has 4 bytes per pixel, as required by +// the DesktopFrame interface. +const auto kPixelFormat = ABI::Windows::Graphics::DirectX::DirectXPixelFormat:: + DirectXPixelFormat_B8G8R8A8UIntNormalized; + +// We only want 1 buffer in our frame pool to reduce latency. If we had more, +// they would sit in the pool for longer and be stale by the time we are asked +// for a new frame. +const int kNumBuffers = 1; + +// These values are persisted to logs. Entries should not be renumbered and +// numeric values should never be reused. +enum class StartCaptureResult { + kSuccess = 0, + kSourceClosed = 1, + kAddClosedFailed = 2, + kDxgiDeviceCastFailed = 3, + kD3dDelayLoadFailed = 4, + kD3dDeviceCreationFailed = 5, + kFramePoolActivationFailed = 6, + kFramePoolCastFailed = 7, + kGetItemSizeFailed = 8, + kCreateFreeThreadedFailed = 9, + kCreateCaptureSessionFailed = 10, + kStartCaptureFailed = 11, + kMaxValue = kStartCaptureFailed +}; + +// These values are persisted to logs. Entries should not be renumbered and +// numeric values should never be reused. 
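// Illustrative sketch (not part of this change): the result enums in this file
// are reported through RTC_HISTOGRAM_ENUMERATION, which needs stable numeric
// values and a kMaxValue sentinel as the upper bound. The enum and histogram
// name below are hypothetical.
#include "system_wrappers/include/metrics.h"

enum class OpenResult {
  kSuccess = 0,
  kNotFound = 1,
  kAccessDenied = 2,
  kMaxValue = kAccessDenied
};

void RecordOpenResult(OpenResult result) {
  RTC_HISTOGRAM_ENUMERATION("WebRTC.DesktopCapture.Win.ExampleOpenResult",
                            static_cast<int>(result),
                            static_cast<int>(OpenResult::kMaxValue));
}
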
+enum class GetFrameResult { + kSuccess = 0, + kItemClosed = 1, + kTryGetNextFrameFailed = 2, + kFrameDropped = 3, + kGetSurfaceFailed = 4, + kDxgiInterfaceAccessFailed = 5, + kTexture2dCastFailed = 6, + kCreateMappedTextureFailed = 7, + kMapFrameFailed = 8, + kGetContentSizeFailed = 9, + kResizeMappedTextureFailed = 10, + kRecreateFramePoolFailed = 11, + kMaxValue = kRecreateFramePoolFailed +}; + +void RecordStartCaptureResult(StartCaptureResult error) { + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.DesktopCapture.Win.WgcCaptureSessionStartResult", + static_cast(error), static_cast(StartCaptureResult::kMaxValue)); +} + +void RecordGetFrameResult(GetFrameResult error) { + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.DesktopCapture.Win.WgcCaptureSessionGetFrameResult", + static_cast(error), static_cast(GetFrameResult::kMaxValue)); +} + +} // namespace WgcCaptureSession::WgcCaptureSession(ComPtr d3d11_device, - HWND window) - : d3d11_device_(std::move(d3d11_device)), window_(window) {} + ComPtr item) + : d3d11_device_(std::move(d3d11_device)), item_(std::move(item)) {} WgcCaptureSession::~WgcCaptureSession() = default; HRESULT WgcCaptureSession::StartCapture() { + RTC_DCHECK_RUN_ON(&sequence_checker_); RTC_DCHECK(!is_capture_started_); + + if (item_closed_) { + RTC_LOG(LS_ERROR) << "The target source has been closed."; + RecordStartCaptureResult(StartCaptureResult::kSourceClosed); + return E_ABORT; + } + RTC_DCHECK(d3d11_device_); - RTC_DCHECK(window_); + RTC_DCHECK(item_); - return E_NOTIMPL; + // Listen for the Closed event, to detect if the source we are capturing is + // closed (e.g. application window is closed or monitor is disconnected). If + // it is, we should abort the capture. + auto closed_handler = + Microsoft::WRL::Callback>( + this, &WgcCaptureSession::OnItemClosed); + EventRegistrationToken item_closed_token; + HRESULT hr = item_->add_Closed(closed_handler.Get(), &item_closed_token); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kAddClosedFailed); + return hr; + } + + ComPtr dxgi_device; + hr = d3d11_device_->QueryInterface(IID_PPV_ARGS(&dxgi_device)); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kDxgiDeviceCastFailed); + return hr; + } + + if (!ResolveCoreWinRTDirect3DDelayload()) { + RecordStartCaptureResult(StartCaptureResult::kD3dDelayLoadFailed); + return E_FAIL; + } + + hr = CreateDirect3DDeviceFromDXGIDevice(dxgi_device.Get(), &direct3d_device_); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kD3dDeviceCreationFailed); + return hr; + } + + ComPtr frame_pool_statics; + hr = GetActivationFactory< + ABI::Windows::Graphics::Capture::IDirect3D11CaptureFramePoolStatics, + RuntimeClass_Windows_Graphics_Capture_Direct3D11CaptureFramePool>( + &frame_pool_statics); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kFramePoolActivationFailed); + return hr; + } + + // Cast to FramePoolStatics2 so we can use CreateFreeThreaded and avoid the + // need to have a DispatcherQueue. We don't listen for the FrameArrived event, + // so there's no difference. 
+ ComPtr frame_pool_statics2; + hr = frame_pool_statics->QueryInterface(IID_PPV_ARGS(&frame_pool_statics2)); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kFramePoolCastFailed); + return hr; + } + + ABI::Windows::Graphics::SizeInt32 item_size; + hr = item_.Get()->get_Size(&item_size); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kGetItemSizeFailed); + return hr; + } + + previous_size_ = item_size; + + hr = frame_pool_statics2->CreateFreeThreaded(direct3d_device_.Get(), + kPixelFormat, kNumBuffers, + item_size, &frame_pool_); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kCreateFreeThreadedFailed); + return hr; + } + + hr = frame_pool_->CreateCaptureSession(item_.Get(), &session_); + if (FAILED(hr)) { + RecordStartCaptureResult(StartCaptureResult::kCreateCaptureSessionFailed); + return hr; + } + + hr = session_->StartCapture(); + if (FAILED(hr)) { + RTC_LOG(LS_ERROR) << "Failed to start CaptureSession: " << hr; + RecordStartCaptureResult(StartCaptureResult::kStartCaptureFailed); + return hr; + } + + RecordStartCaptureResult(StartCaptureResult::kSuccess); + + is_capture_started_ = true; + return hr; } -HRESULT WgcCaptureSession::GetMostRecentFrame( +HRESULT WgcCaptureSession::GetFrame( std::unique_ptr* output_frame) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + + if (item_closed_) { + RTC_LOG(LS_ERROR) << "The target source has been closed."; + RecordGetFrameResult(GetFrameResult::kItemClosed); + return E_ABORT; + } + RTC_DCHECK(is_capture_started_); - return E_NOTIMPL; + ComPtr capture_frame; + HRESULT hr = frame_pool_->TryGetNextFrame(&capture_frame); + if (FAILED(hr)) { + RTC_LOG(LS_ERROR) << "TryGetNextFrame failed: " << hr; + RecordGetFrameResult(GetFrameResult::kTryGetNextFrameFailed); + return hr; + } + + if (!capture_frame) { + RecordGetFrameResult(GetFrameResult::kFrameDropped); + return hr; + } + + // We need to get this CaptureFrame as an ID3D11Texture2D so that we can get + // the raw image data in the format required by the DesktopFrame interface. + ComPtr + d3d_surface; + hr = capture_frame->get_Surface(&d3d_surface); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kGetSurfaceFailed); + return hr; + } + + ComPtr + direct3DDxgiInterfaceAccess; + hr = d3d_surface->QueryInterface(IID_PPV_ARGS(&direct3DDxgiInterfaceAccess)); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kDxgiInterfaceAccessFailed); + return hr; + } + + ComPtr texture_2D; + hr = direct3DDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&texture_2D)); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kTexture2dCastFailed); + return hr; + } + + if (!mapped_texture_) { + hr = CreateMappedTexture(texture_2D); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kCreateMappedTextureFailed); + return hr; + } + } + + // We need to copy |texture_2D| into |mapped_texture_| as the latter has the + // D3D11_CPU_ACCESS_READ flag set, which lets us access the image data. + // Otherwise it would only be readable by the GPU. 
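// Illustrative sketch (not part of this change): the general D3D11 CPU-readback
// pattern used just below — copy the GPU texture into a staging texture, then
// Map() it for reading. |staging_texture| is assumed to have been created with
// D3D11_USAGE_STAGING and D3D11_CPU_ACCESS_READ, as CreateMappedTexture() does.
#include <d3d11.h>
#include <wrl/client.h>

HRESULT ReadBackTexture(ID3D11Device* device,
                        ID3D11Texture2D* gpu_texture,
                        ID3D11Texture2D* staging_texture) {
  Microsoft::WRL::ComPtr<ID3D11DeviceContext> context;
  device->GetImmediateContext(&context);
  context->CopyResource(staging_texture, gpu_texture);

  D3D11_MAPPED_SUBRESOURCE map_info;
  HRESULT hr = context->Map(staging_texture, /*Subresource=*/0, D3D11_MAP_READ,
                            /*MapFlags=*/0, &map_info);
  if (FAILED(hr))
    return hr;
  // map_info.pData points at the pixel data; rows are map_info.RowPitch bytes
  // apart, which may be larger than width * bytes_per_pixel.
  context->Unmap(staging_texture, 0);
  return S_OK;
}
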
+ ComPtr d3d_context; + d3d11_device_->GetImmediateContext(&d3d_context); + d3d_context->CopyResource(mapped_texture_.Get(), texture_2D.Get()); + + D3D11_MAPPED_SUBRESOURCE map_info; + hr = d3d_context->Map(mapped_texture_.Get(), /*subresource_index=*/0, + D3D11_MAP_READ, /*D3D11_MAP_FLAG_DO_NOT_WAIT=*/0, + &map_info); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kMapFrameFailed); + return hr; + } + + ABI::Windows::Graphics::SizeInt32 new_size; + hr = capture_frame->get_ContentSize(&new_size); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kGetContentSizeFailed); + return hr; + } + + // If the size has changed since the last capture, we must be sure to use + // the smaller dimensions. Otherwise we might overrun our buffer, or + // read stale data from the last frame. + int image_height = std::min(previous_size_.Height, new_size.Height); + int image_width = std::min(previous_size_.Width, new_size.Width); + int row_data_length = image_width * DesktopFrame::kBytesPerPixel; + + // Make a copy of the data pointed to by |map_info.pData| so we are free to + // unmap our texture. + uint8_t* src_data = static_cast(map_info.pData); + std::vector image_data; + image_data.reserve(image_height * row_data_length); + uint8_t* image_data_ptr = image_data.data(); + for (int i = 0; i < image_height; i++) { + memcpy(image_data_ptr, src_data, row_data_length); + image_data_ptr += row_data_length; + src_data += map_info.RowPitch; + } + + // Transfer ownership of |image_data| to the output_frame. + DesktopSize size(image_width, image_height); + *output_frame = std::make_unique(size, row_data_length, + std::move(image_data)); + + d3d_context->Unmap(mapped_texture_.Get(), 0); + + // If the size changed, we must resize the texture and frame pool to fit the + // new size. + if (previous_size_.Height != new_size.Height || + previous_size_.Width != new_size.Width) { + hr = CreateMappedTexture(texture_2D, new_size.Width, new_size.Height); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kResizeMappedTextureFailed); + return hr; + } + + hr = frame_pool_->Recreate(direct3d_device_.Get(), kPixelFormat, + kNumBuffers, new_size); + if (FAILED(hr)) { + RecordGetFrameResult(GetFrameResult::kRecreateFramePoolFailed); + return hr; + } + } + + RecordGetFrameResult(GetFrameResult::kSuccess); + + previous_size_ = new_size; + return hr; +} + +HRESULT WgcCaptureSession::CreateMappedTexture( + ComPtr src_texture, + UINT width, + UINT height) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + + D3D11_TEXTURE2D_DESC src_desc; + src_texture->GetDesc(&src_desc); + D3D11_TEXTURE2D_DESC map_desc; + map_desc.Width = width == 0 ? src_desc.Width : width; + map_desc.Height = height == 0 ? 
src_desc.Height : height; + map_desc.MipLevels = src_desc.MipLevels; + map_desc.ArraySize = src_desc.ArraySize; + map_desc.Format = src_desc.Format; + map_desc.SampleDesc = src_desc.SampleDesc; + map_desc.Usage = D3D11_USAGE_STAGING; + map_desc.BindFlags = 0; + map_desc.CPUAccessFlags = D3D11_CPU_ACCESS_READ; + map_desc.MiscFlags = 0; + return d3d11_device_->CreateTexture2D(&map_desc, nullptr, &mapped_texture_); +} + +HRESULT WgcCaptureSession::OnItemClosed(WGC::IGraphicsCaptureItem* sender, + IInspectable* event_args) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + + RTC_LOG(LS_INFO) << "Capture target has been closed."; + item_closed_ = true; + is_capture_started_ = false; + + mapped_texture_ = nullptr; + session_ = nullptr; + frame_pool_ = nullptr; + direct3d_device_ = nullptr; + item_ = nullptr; + d3d11_device_ = nullptr; + + return S_OK; } } // namespace webrtc diff --git a/modules/desktop_capture/win/wgc_capture_session.h b/modules/desktop_capture/win/wgc_capture_session.h index 9f41331c92..9f08b7cf2d 100644 --- a/modules/desktop_capture/win/wgc_capture_session.h +++ b/modules/desktop_capture/win/wgc_capture_session.h @@ -11,36 +11,98 @@ #ifndef MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURE_SESSION_H_ #define MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURE_SESSION_H_ -#include #include +#include #include + #include -#include "modules/desktop_capture/desktop_frame.h" +#include "api/sequence_checker.h" +#include "modules/desktop_capture/desktop_capture_options.h" +#include "modules/desktop_capture/win/wgc_capture_source.h" namespace webrtc { class WgcCaptureSession final { public: - WgcCaptureSession(Microsoft::WRL::ComPtr d3d11_device, - HWND window); + WgcCaptureSession( + Microsoft::WRL::ComPtr d3d11_device, + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IGraphicsCaptureItem> item); - // Disallow copy and assign + // Disallow copy and assign. WgcCaptureSession(const WgcCaptureSession&) = delete; WgcCaptureSession& operator=(const WgcCaptureSession&) = delete; ~WgcCaptureSession(); HRESULT StartCapture(); - HRESULT GetMostRecentFrame(std::unique_ptr* output_frame); - bool IsCaptureStarted() const { return is_capture_started_; } + + // Returns a frame from the frame pool, if any are present. + HRESULT GetFrame(std::unique_ptr* output_frame); + + bool IsCaptureStarted() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + return is_capture_started_; + } private: + // Initializes |mapped_texture_| with the properties of the |src_texture|, + // overrides the values of some necessary properties like the + // D3D11_CPU_ACCESS_READ flag. Also has optional parameters for what size + // |mapped_texture_| should be, if they aren't provided we will use the size + // of |src_texture|. + HRESULT CreateMappedTexture( + Microsoft::WRL::ComPtr src_texture, + UINT width = 0, + UINT height = 0); + + // Event handler for |item_|'s Closed event. + HRESULT OnItemClosed( + ABI::Windows::Graphics::Capture::IGraphicsCaptureItem* sender, + IInspectable* event_args); + // A Direct3D11 Device provided by the caller. We use this to create an // IDirect3DDevice, and also to create textures that will hold the image data. Microsoft::WRL::ComPtr d3d11_device_; - HWND window_; + + // This item represents what we are capturing, we use it to create the + // capture session, and also to listen for the Closed event. + Microsoft::WRL::ComPtr + item_; + + // The IDirect3DDevice is necessary to instantiate the frame pool. 
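// Illustrative sketch (not part of this change): copying a mapped D3D11
// texture row by row, as GetFrame() does above, because RowPitch may be larger
// than the tightly packed row length that DesktopFrame expects.
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint8_t> CopyTightlyPacked(const uint8_t* src,
                                       int row_pitch,
                                       int width,
                                       int height,
                                       int bytes_per_pixel) {
  const int row_length = width * bytes_per_pixel;
  std::vector<uint8_t> dst(static_cast<size_t>(row_length) * height);
  for (int y = 0; y < height; ++y) {
    std::memcpy(dst.data() + static_cast<size_t>(y) * row_length,
                src + static_cast<size_t>(y) * row_pitch, row_length);
  }
  return dst;
}
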
+ Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice> + direct3d_device_; + + // The frame pool is where frames are deposited during capture, we retrieve + // them from here with TryGetNextFrame(). + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IDirect3D11CaptureFramePool> + frame_pool_; + + // This texture holds the final image data. We made it a member so we can + // reuse it, instead of having to create a new texture every time we grab a + // frame. + Microsoft::WRL::ComPtr mapped_texture_; + + // This lets us know when the source has been resized, which is important + // because we must resize the framepool and our texture to be able to hold + // enough data for the frame. + ABI::Windows::Graphics::SizeInt32 previous_size_; + + // The capture session lets us set properties about the capture before it + // starts such as whether to capture the mouse cursor, and it lets us tell WGC + // to start capturing frames. + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IGraphicsCaptureSession> + session_; + + bool item_closed_ = false; bool is_capture_started_ = false; + + SequenceChecker sequence_checker_; }; } // namespace webrtc diff --git a/modules/desktop_capture/win/wgc_capture_source.cc b/modules/desktop_capture/win/wgc_capture_source.cc new file mode 100644 index 0000000000..9786ca67b5 --- /dev/null +++ b/modules/desktop_capture/win/wgc_capture_source.cc @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/desktop_capture/win/wgc_capture_source.h" + +#include +#include + +#include + +#include "modules/desktop_capture/win/screen_capture_utils.h" +#include "modules/desktop_capture/win/window_capture_utils.h" +#include "rtc_base/win/get_activation_factory.h" + +using Microsoft::WRL::ComPtr; +namespace WGC = ABI::Windows::Graphics::Capture; + +namespace webrtc { + +WgcCaptureSource::WgcCaptureSource(DesktopCapturer::SourceId source_id) + : source_id_(source_id) {} +WgcCaptureSource::~WgcCaptureSource() = default; + +bool WgcCaptureSource::IsCapturable() { + // If we can create a capture item, then we can capture it. Unfortunately, + // we can't cache this item because it may be created in a different COM + // apartment than where capture will eventually start from. 
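// Illustrative sketch (not part of this change): because the capture item is
// apartment-bound, callers initialize COM before creating and using these
// sources — the unit tests added later in this change do so with an MTA
// ScopedCOMInitializer. The function below is hypothetical.
#include "rtc_base/win/scoped_com_initializer.h"

void UseWgcSourceWithMtaCom() {
  webrtc::ScopedCOMInitializer com_initializer(
      webrtc::ScopedCOMInitializer::kMTA);
  if (!com_initializer.Succeeded())
    return;
  // Create a WgcCaptureSource and call IsCapturable()/GetCaptureItem() here.
}
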
+ ComPtr item; + return SUCCEEDED(CreateCaptureItem(&item)); +} + +bool WgcCaptureSource::FocusOnSource() { + return false; +} + +HRESULT WgcCaptureSource::GetCaptureItem( + ComPtr* result) { + HRESULT hr = S_OK; + if (!item_) + hr = CreateCaptureItem(&item_); + + *result = item_; + return hr; +} + +WgcCaptureSourceFactory::~WgcCaptureSourceFactory() = default; + +WgcWindowSourceFactory::WgcWindowSourceFactory() = default; +WgcWindowSourceFactory::~WgcWindowSourceFactory() = default; + +std::unique_ptr WgcWindowSourceFactory::CreateCaptureSource( + DesktopCapturer::SourceId source_id) { + return std::make_unique(source_id); +} + +WgcScreenSourceFactory::WgcScreenSourceFactory() = default; +WgcScreenSourceFactory::~WgcScreenSourceFactory() = default; + +std::unique_ptr WgcScreenSourceFactory::CreateCaptureSource( + DesktopCapturer::SourceId source_id) { + return std::make_unique(source_id); +} + +WgcWindowSource::WgcWindowSource(DesktopCapturer::SourceId source_id) + : WgcCaptureSource(source_id) {} +WgcWindowSource::~WgcWindowSource() = default; + +DesktopVector WgcWindowSource::GetTopLeft() { + DesktopRect window_rect; + if (!GetWindowRect(reinterpret_cast(GetSourceId()), &window_rect)) + return DesktopVector(); + + return window_rect.top_left(); +} + +bool WgcWindowSource::IsCapturable() { + if (!IsWindowValidAndVisible(reinterpret_cast(GetSourceId()))) + return false; + + return WgcCaptureSource::IsCapturable(); +} + +bool WgcWindowSource::FocusOnSource() { + if (!IsWindowValidAndVisible(reinterpret_cast(GetSourceId()))) + return false; + + return ::BringWindowToTop(reinterpret_cast(GetSourceId())) && + ::SetForegroundWindow(reinterpret_cast(GetSourceId())); +} + +HRESULT WgcWindowSource::CreateCaptureItem( + ComPtr* result) { + if (!ResolveCoreWinRTDelayload()) + return E_FAIL; + + ComPtr interop; + HRESULT hr = GetActivationFactory< + IGraphicsCaptureItemInterop, + RuntimeClass_Windows_Graphics_Capture_GraphicsCaptureItem>(&interop); + if (FAILED(hr)) + return hr; + + ComPtr item; + hr = interop->CreateForWindow(reinterpret_cast(GetSourceId()), + IID_PPV_ARGS(&item)); + if (FAILED(hr)) + return hr; + + if (!item) + return E_HANDLE; + + *result = std::move(item); + return hr; +} + +WgcScreenSource::WgcScreenSource(DesktopCapturer::SourceId source_id) + : WgcCaptureSource(source_id) { + // Getting the HMONITOR could fail if the source_id is invalid. In that case, + // we leave hmonitor_ uninitialized and |IsCapturable()| will fail. 
+ HMONITOR hmon; + if (GetHmonitorFromDeviceIndex(GetSourceId(), &hmon)) + hmonitor_ = hmon; +} + +WgcScreenSource::~WgcScreenSource() = default; + +DesktopVector WgcScreenSource::GetTopLeft() { + if (!hmonitor_) + return DesktopVector(); + + return GetMonitorRect(*hmonitor_).top_left(); +} + +bool WgcScreenSource::IsCapturable() { + if (!hmonitor_) + return false; + + if (!IsMonitorValid(*hmonitor_)) + return false; + + return WgcCaptureSource::IsCapturable(); +} + +HRESULT WgcScreenSource::CreateCaptureItem( + ComPtr* result) { + if (!hmonitor_) + return E_ABORT; + + if (!ResolveCoreWinRTDelayload()) + return E_FAIL; + + ComPtr interop; + HRESULT hr = GetActivationFactory< + IGraphicsCaptureItemInterop, + RuntimeClass_Windows_Graphics_Capture_GraphicsCaptureItem>(&interop); + if (FAILED(hr)) + return hr; + + ComPtr item; + hr = interop->CreateForMonitor(*hmonitor_, IID_PPV_ARGS(&item)); + if (FAILED(hr)) + return hr; + + if (!item) + return E_HANDLE; + + *result = std::move(item); + return hr; +} + +} // namespace webrtc diff --git a/modules/desktop_capture/win/wgc_capture_source.h b/modules/desktop_capture/win/wgc_capture_source.h new file mode 100644 index 0000000000..135f92bb84 --- /dev/null +++ b/modules/desktop_capture/win/wgc_capture_source.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURE_SOURCE_H_ +#define MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURE_SOURCE_H_ + +#include +#include + +#include + +#include "absl/types/optional.h" +#include "modules/desktop_capture/desktop_capturer.h" +#include "modules/desktop_capture/desktop_geometry.h" + +namespace webrtc { + +// Abstract class to represent the source that WGC-based capturers capture +// from. Could represent an application window or a screen. Consumers should use +// the appropriate Wgc*SourceFactory class to create WgcCaptureSource objects +// of the appropriate type. +class WgcCaptureSource { + public: + explicit WgcCaptureSource(DesktopCapturer::SourceId source_id); + virtual ~WgcCaptureSource(); + + virtual DesktopVector GetTopLeft() = 0; + virtual bool IsCapturable(); + virtual bool FocusOnSource(); + HRESULT GetCaptureItem( + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IGraphicsCaptureItem>* result); + DesktopCapturer::SourceId GetSourceId() { return source_id_; } + + protected: + virtual HRESULT CreateCaptureItem( + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IGraphicsCaptureItem>* result) = 0; + + private: + Microsoft::WRL::ComPtr + item_; + const DesktopCapturer::SourceId source_id_; +}; + +class WgcCaptureSourceFactory { + public: + virtual ~WgcCaptureSourceFactory(); + + virtual std::unique_ptr CreateCaptureSource( + DesktopCapturer::SourceId) = 0; +}; + +class WgcWindowSourceFactory final : public WgcCaptureSourceFactory { + public: + WgcWindowSourceFactory(); + + // Disallow copy and assign. 
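// Illustrative sketch (not part of this change): the factory classes declared
// here let a capturer choose between window and screen sources at runtime, as
// the tests added later in this change do. The helper below is hypothetical.
#include <memory>

#include "modules/desktop_capture/desktop_capturer.h"
#include "modules/desktop_capture/win/wgc_capture_source.h"

std::unique_ptr<webrtc::WgcCaptureSource> MakeWgcSource(
    bool is_window,
    webrtc::DesktopCapturer::SourceId source_id) {
  std::unique_ptr<webrtc::WgcCaptureSourceFactory> factory;
  if (is_window) {
    factory = std::make_unique<webrtc::WgcWindowSourceFactory>();
  } else {
    factory = std::make_unique<webrtc::WgcScreenSourceFactory>();
  }
  return factory->CreateCaptureSource(source_id);
}
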
+ WgcWindowSourceFactory(const WgcWindowSourceFactory&) = delete; + WgcWindowSourceFactory& operator=(const WgcWindowSourceFactory&) = delete; + + ~WgcWindowSourceFactory() override; + + std::unique_ptr CreateCaptureSource( + DesktopCapturer::SourceId) override; +}; + +class WgcScreenSourceFactory final : public WgcCaptureSourceFactory { + public: + WgcScreenSourceFactory(); + + WgcScreenSourceFactory(const WgcScreenSourceFactory&) = delete; + WgcScreenSourceFactory& operator=(const WgcScreenSourceFactory&) = delete; + + ~WgcScreenSourceFactory() override; + + std::unique_ptr CreateCaptureSource( + DesktopCapturer::SourceId) override; +}; + +// Class for capturing application windows. +class WgcWindowSource final : public WgcCaptureSource { + public: + explicit WgcWindowSource(DesktopCapturer::SourceId source_id); + + WgcWindowSource(const WgcWindowSource&) = delete; + WgcWindowSource& operator=(const WgcWindowSource&) = delete; + + ~WgcWindowSource() override; + + DesktopVector GetTopLeft() override; + bool IsCapturable() override; + bool FocusOnSource() override; + + private: + HRESULT CreateCaptureItem( + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IGraphicsCaptureItem>* result) + override; +}; + +// Class for capturing screens/monitors/displays. +class WgcScreenSource final : public WgcCaptureSource { + public: + explicit WgcScreenSource(DesktopCapturer::SourceId source_id); + + WgcScreenSource(const WgcScreenSource&) = delete; + WgcScreenSource& operator=(const WgcScreenSource&) = delete; + + ~WgcScreenSource() override; + + DesktopVector GetTopLeft() override; + bool IsCapturable() override; + + private: + HRESULT CreateCaptureItem( + Microsoft::WRL::ComPtr< + ABI::Windows::Graphics::Capture::IGraphicsCaptureItem>* result) + override; + + // To maintain compatibility with other capturers, this class accepts a + // device index as it's SourceId. However, WGC requires we use an HMONITOR to + // describe which screen to capture. So, we internally convert the supplied + // device index into an HMONITOR when |IsCapturable()| is called. + absl::optional hmonitor_; +}; + +} // namespace webrtc + +#endif // MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURE_SOURCE_H_ diff --git a/modules/desktop_capture/win/wgc_capture_source_unittest.cc b/modules/desktop_capture/win/wgc_capture_source_unittest.cc new file mode 100644 index 0000000000..a230e12578 --- /dev/null +++ b/modules/desktop_capture/win/wgc_capture_source_unittest.cc @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/desktop_capture/win/wgc_capture_source.h" + +#include +#include + +#include + +#include "modules/desktop_capture/desktop_capture_types.h" +#include "modules/desktop_capture/desktop_geometry.h" +#include "modules/desktop_capture/win/screen_capture_utils.h" +#include "modules/desktop_capture/win/test_support/test_window.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/win/scoped_com_initializer.h" +#include "rtc_base/win/windows_version.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +const WCHAR kWindowTitle[] = L"WGC Capture Source Test Window"; + +const int kFirstXCoord = 25; +const int kFirstYCoord = 50; +const int kSecondXCoord = 50; +const int kSecondYCoord = 75; + +enum SourceType { kWindowSource = 0, kScreenSource = 1 }; + +} // namespace + +class WgcCaptureSourceTest : public ::testing::TestWithParam { + public: + void SetUp() override { + if (rtc::rtc_win::GetVersion() < rtc::rtc_win::Version::VERSION_WIN10_RS5) { + RTC_LOG(LS_INFO) + << "Skipping WgcCaptureSourceTests on Windows versions < RS5."; + GTEST_SKIP(); + } + + com_initializer_ = + std::make_unique(ScopedCOMInitializer::kMTA); + ASSERT_TRUE(com_initializer_->Succeeded()); + } + + void TearDown() override { + if (window_open_) { + DestroyTestWindow(window_info_); + } + } + + void SetUpForWindowSource() { + window_info_ = CreateTestWindow(kWindowTitle); + window_open_ = true; + source_id_ = reinterpret_cast(window_info_.hwnd); + source_factory_ = std::make_unique(); + } + + void SetUpForScreenSource() { + source_id_ = kFullDesktopScreenId; + source_factory_ = std::make_unique(); + } + + protected: + std::unique_ptr com_initializer_; + std::unique_ptr source_factory_; + std::unique_ptr source_; + DesktopCapturer::SourceId source_id_; + WindowInfo window_info_; + bool window_open_ = false; +}; + +// Window specific test +TEST_F(WgcCaptureSourceTest, WindowPosition) { + SetUpForWindowSource(); + source_ = source_factory_->CreateCaptureSource(source_id_); + ASSERT_TRUE(source_); + EXPECT_EQ(source_->GetSourceId(), source_id_); + + MoveTestWindow(window_info_.hwnd, kFirstXCoord, kFirstYCoord); + DesktopVector source_vector = source_->GetTopLeft(); + EXPECT_EQ(source_vector.x(), kFirstXCoord); + EXPECT_EQ(source_vector.y(), kFirstYCoord); + + MoveTestWindow(window_info_.hwnd, kSecondXCoord, kSecondYCoord); + source_vector = source_->GetTopLeft(); + EXPECT_EQ(source_vector.x(), kSecondXCoord); + EXPECT_EQ(source_vector.y(), kSecondYCoord); +} + +// Screen specific test +TEST_F(WgcCaptureSourceTest, ScreenPosition) { + SetUpForScreenSource(); + source_ = source_factory_->CreateCaptureSource(source_id_); + ASSERT_TRUE(source_); + EXPECT_EQ(source_id_, source_->GetSourceId()); + + DesktopRect screen_rect = GetFullscreenRect(); + DesktopVector source_vector = source_->GetTopLeft(); + EXPECT_EQ(source_vector.x(), screen_rect.left()); + EXPECT_EQ(source_vector.y(), screen_rect.top()); +} + +// Source agnostic test +TEST_P(WgcCaptureSourceTest, CreateSource) { + if (GetParam() == SourceType::kWindowSource) { + SetUpForWindowSource(); + } else { + SetUpForScreenSource(); + } + + source_ = source_factory_->CreateCaptureSource(source_id_); + ASSERT_TRUE(source_); + EXPECT_EQ(source_id_, source_->GetSourceId()); + EXPECT_TRUE(source_->IsCapturable()); + + Microsoft::WRL::ComPtr + item; + EXPECT_TRUE(SUCCEEDED(source_->GetCaptureItem(&item))); + EXPECT_TRUE(item); +} + +INSTANTIATE_TEST_SUITE_P(SourceAgnostic, + WgcCaptureSourceTest, + 
::testing::Values(SourceType::kWindowSource, + SourceType::kScreenSource)); + +} // namespace webrtc diff --git a/modules/desktop_capture/win/wgc_capturer_win.cc b/modules/desktop_capture/win/wgc_capturer_win.cc new file mode 100644 index 0000000000..442c827a67 --- /dev/null +++ b/modules/desktop_capture/win/wgc_capturer_win.cc @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/desktop_capture/win/wgc_capturer_win.h" + +#include + +#include "modules/desktop_capture/desktop_capture_metrics_helper.h" +#include "modules/desktop_capture/desktop_capture_types.h" +#include "modules/desktop_capture/win/wgc_desktop_frame.h" +#include "rtc_base/logging.h" +#include "rtc_base/time_utils.h" +#include "system_wrappers/include/metrics.h" + +namespace WGC = ABI::Windows::Graphics::Capture; +using Microsoft::WRL::ComPtr; + +namespace webrtc { + +namespace { + +enum class WgcCapturerResult { + kSuccess = 0, + kNoDirect3dDevice = 1, + kNoSourceSelected = 2, + kItemCreationFailure = 3, + kSessionStartFailure = 4, + kGetFrameFailure = 5, + kFrameDropped = 6, + kMaxValue = kFrameDropped +}; + +void RecordWgcCapturerResult(WgcCapturerResult error) { + RTC_HISTOGRAM_ENUMERATION("WebRTC.DesktopCapture.Win.WgcCapturerResult", + static_cast(error), + static_cast(WgcCapturerResult::kMaxValue)); +} + +} // namespace + +WgcCapturerWin::WgcCapturerWin( + std::unique_ptr source_factory, + std::unique_ptr source_enumerator) + : source_factory_(std::move(source_factory)), + source_enumerator_(std::move(source_enumerator)) {} +WgcCapturerWin::~WgcCapturerWin() = default; + +// static +std::unique_ptr WgcCapturerWin::CreateRawWindowCapturer( + const DesktopCaptureOptions& options) { + return std::make_unique( + std::make_unique(), + std::make_unique( + options.enumerate_current_process_windows())); +} + +// static +std::unique_ptr WgcCapturerWin::CreateRawScreenCapturer( + const DesktopCaptureOptions& options) { + return std::make_unique( + std::make_unique(), + std::make_unique()); +} + +bool WgcCapturerWin::GetSourceList(SourceList* sources) { + return source_enumerator_->FindAllSources(sources); +} + +bool WgcCapturerWin::SelectSource(DesktopCapturer::SourceId id) { + capture_source_ = source_factory_->CreateCaptureSource(id); + return capture_source_->IsCapturable(); +} + +bool WgcCapturerWin::FocusOnSelectedSource() { + if (!capture_source_) + return false; + + return capture_source_->FocusOnSource(); +} + +void WgcCapturerWin::Start(Callback* callback) { + RTC_DCHECK(!callback_); + RTC_DCHECK(callback); + RecordCapturerImpl(DesktopCapturerId::kWgcCapturerWin); + + callback_ = callback; + + // Create a Direct3D11 device to share amongst the WgcCaptureSessions. Many + // parameters are nullptr as the implemention uses defaults that work well for + // us. 
+  HRESULT hr = D3D11CreateDevice(
+      /*adapter=*/nullptr, D3D_DRIVER_TYPE_HARDWARE,
+      /*software_rasterizer=*/nullptr, D3D11_CREATE_DEVICE_BGRA_SUPPORT,
+      /*feature_levels=*/nullptr, /*feature_levels_size=*/0, D3D11_SDK_VERSION,
+      &d3d11_device_, /*feature_level=*/nullptr, /*device_context=*/nullptr);
+  if (hr == DXGI_ERROR_UNSUPPORTED) {
+    // If a hardware device could not be created, use WARP which is a high speed
+    // software device.
+    hr = D3D11CreateDevice(
+        /*adapter=*/nullptr, D3D_DRIVER_TYPE_WARP,
+        /*software_rasterizer=*/nullptr, D3D11_CREATE_DEVICE_BGRA_SUPPORT,
+        /*feature_levels=*/nullptr, /*feature_levels_size=*/0,
+        D3D11_SDK_VERSION, &d3d11_device_, /*feature_level=*/nullptr,
+        /*device_context=*/nullptr);
+  }
+
+  if (FAILED(hr)) {
+    RTC_LOG(LS_ERROR) << "Failed to create D3D11Device: " << hr;
+  }
+}
+
+void WgcCapturerWin::CaptureFrame() {
+  RTC_DCHECK(callback_);
+
+  if (!capture_source_) {
+    RTC_LOG(LS_ERROR) << "Source hasn't been selected";
+    callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT,
+                               /*frame=*/nullptr);
+    RecordWgcCapturerResult(WgcCapturerResult::kNoSourceSelected);
+    return;
+  }
+
+  if (!d3d11_device_) {
+    RTC_LOG(LS_ERROR) << "No D3D11Device, cannot capture.";
+    callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT,
+                               /*frame=*/nullptr);
+    RecordWgcCapturerResult(WgcCapturerResult::kNoDirect3dDevice);
+    return;
+  }
+
+  int64_t capture_start_time_nanos = rtc::TimeNanos();
+
+  HRESULT hr;
+  WgcCaptureSession* capture_session = nullptr;
+  std::map<DesktopCapturer::SourceId, WgcCaptureSession>::iterator
+      session_iter = ongoing_captures_.find(capture_source_->GetSourceId());
+  if (session_iter == ongoing_captures_.end()) {
+    ComPtr<WGC::IGraphicsCaptureItem> item;
+    hr = capture_source_->GetCaptureItem(&item);
+    if (FAILED(hr)) {
+      RTC_LOG(LS_ERROR) << "Failed to create a GraphicsCaptureItem: " << hr;
+      callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT,
+                                 /*frame=*/nullptr);
+      RecordWgcCapturerResult(WgcCapturerResult::kItemCreationFailure);
+      return;
+    }
+
+    std::pair<std::map<DesktopCapturer::SourceId, WgcCaptureSession>::iterator,
+              bool>
+        iter_success_pair = ongoing_captures_.emplace(
+            std::piecewise_construct,
+            std::forward_as_tuple(capture_source_->GetSourceId()),
+            std::forward_as_tuple(d3d11_device_, item));
+    RTC_DCHECK(iter_success_pair.second);
+    capture_session = &iter_success_pair.first->second;
+  } else {
+    capture_session = &session_iter->second;
+  }
+
+  if (!capture_session->IsCaptureStarted()) {
+    hr = capture_session->StartCapture();
+    if (FAILED(hr)) {
+      RTC_LOG(LS_ERROR) << "Failed to start capture: " << hr;
+      ongoing_captures_.erase(capture_source_->GetSourceId());
+      callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT,
+                                 /*frame=*/nullptr);
+      RecordWgcCapturerResult(WgcCapturerResult::kSessionStartFailure);
+      return;
+    }
+  }
+
+  std::unique_ptr<DesktopFrame> frame;
+  hr = capture_session->GetFrame(&frame);
+  if (FAILED(hr)) {
+    RTC_LOG(LS_ERROR) << "GetFrame failed: " << hr;
+    ongoing_captures_.erase(capture_source_->GetSourceId());
+    callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT,
+                               /*frame=*/nullptr);
+    RecordWgcCapturerResult(WgcCapturerResult::kGetFrameFailure);
+    return;
+  }
+
+  if (!frame) {
+    callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_TEMPORARY,
+                               /*frame=*/nullptr);
+    RecordWgcCapturerResult(WgcCapturerResult::kFrameDropped);
+    return;
+  }
+
+  int capture_time_ms = (rtc::TimeNanos() - capture_start_time_nanos) /
+                        rtc::kNumNanosecsPerMillisec;
+  RTC_HISTOGRAM_COUNTS_1000("WebRTC.DesktopCapture.Win.WgcCapturerFrameTime",
+                            capture_time_ms);
+  frame->set_capture_time_ms(capture_time_ms);
+  frame->set_capturer_id(DesktopCapturerId::kWgcCapturerWin);
+  frame->set_may_contain_cursor(true);
+  frame->set_top_left(capture_source_->GetTopLeft());
+  RecordWgcCapturerResult(WgcCapturerResult::kSuccess);
+  callback_->OnCaptureResult(DesktopCapturer::Result::SUCCESS,
+                             std::move(frame));
+}
+
+bool WgcCapturerWin::IsSourceBeingCaptured(DesktopCapturer::SourceId id) {
+  std::map<DesktopCapturer::SourceId, WgcCaptureSession>::iterator
+      session_iter = ongoing_captures_.find(id);
+  if (session_iter == ongoing_captures_.end())
+    return false;
+
+  return session_iter->second.IsCaptureStarted();
+}
+
+}  // namespace webrtc
diff --git a/modules/desktop_capture/win/wgc_capturer_win.h b/modules/desktop_capture/win/wgc_capturer_win.h
new file mode 100644
index 0000000000..58f3fc318a
--- /dev/null
+++ b/modules/desktop_capture/win/wgc_capturer_win.h
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURER_WIN_H_
+#define MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURER_WIN_H_
+
+#include <d3d11.h>
+#include <wrl/client.h>
+
+#include <map>
+#include <memory>
+
+#include "modules/desktop_capture/desktop_capture_options.h"
+#include "modules/desktop_capture/desktop_capturer.h"
+#include "modules/desktop_capture/win/screen_capture_utils.h"
+#include "modules/desktop_capture/win/wgc_capture_session.h"
+#include "modules/desktop_capture/win/wgc_capture_source.h"
+#include "modules/desktop_capture/win/window_capture_utils.h"
+
+namespace webrtc {
+
+// WgcCapturerWin is initialized with an implementation of this base class,
+// which it uses to find capturable sources of a particular type. This way,
+// WgcCapturerWin can remain source-agnostic.
+class SourceEnumerator {
+ public:
+  virtual ~SourceEnumerator() = default;
+
+  virtual bool FindAllSources(DesktopCapturer::SourceList* sources) = 0;
+};
+
+class WindowEnumerator final : public SourceEnumerator {
+ public:
+  explicit WindowEnumerator(bool enumerate_current_process_windows)
+      : enumerate_current_process_windows_(enumerate_current_process_windows) {}
+
+  WindowEnumerator(const WindowEnumerator&) = delete;
+  WindowEnumerator& operator=(const WindowEnumerator&) = delete;
+
+  ~WindowEnumerator() override = default;
+
+  bool FindAllSources(DesktopCapturer::SourceList* sources) override {
+    // WGC fails to capture windows with the WS_EX_TOOLWINDOW style, so we
+    // provide it as a filter to ensure windows with the style are not returned.
+    return window_capture_helper_.EnumerateCapturableWindows(
+        sources, enumerate_current_process_windows_, WS_EX_TOOLWINDOW);
+  }
+
+ private:
+  WindowCaptureHelperWin window_capture_helper_;
+  bool enumerate_current_process_windows_;
+};
+
+class ScreenEnumerator final : public SourceEnumerator {
+ public:
+  ScreenEnumerator() = default;
+
+  ScreenEnumerator(const ScreenEnumerator&) = delete;
+  ScreenEnumerator& operator=(const ScreenEnumerator&) = delete;
+
+  ~ScreenEnumerator() override = default;
+
+  bool FindAllSources(DesktopCapturer::SourceList* sources) override {
+    return webrtc::GetScreenList(sources);
+  }
+};
+
+// A capturer that uses the Windows.Graphics.Capture APIs.
It is suitable for +// both window and screen capture (but only one type per instance). Consumers +// should not instantiate this class directly, instead they should use +// |CreateRawWindowCapturer()| or |CreateRawScreenCapturer()| to receive a +// capturer appropriate for the type of source they want to capture. +class WgcCapturerWin : public DesktopCapturer { + public: + WgcCapturerWin(std::unique_ptr source_factory, + std::unique_ptr source_enumerator); + + WgcCapturerWin(const WgcCapturerWin&) = delete; + WgcCapturerWin& operator=(const WgcCapturerWin&) = delete; + + ~WgcCapturerWin() override; + + static std::unique_ptr CreateRawWindowCapturer( + const DesktopCaptureOptions& options); + + static std::unique_ptr CreateRawScreenCapturer( + const DesktopCaptureOptions& options); + + // DesktopCapturer interface. + bool GetSourceList(SourceList* sources) override; + bool SelectSource(SourceId id) override; + bool FocusOnSelectedSource() override; + void Start(Callback* callback) override; + void CaptureFrame() override; + + // Used in WgcCapturerTests. + bool IsSourceBeingCaptured(SourceId id); + + private: + // Factory to create a WgcCaptureSource for us whenever SelectSource is + // called. Initialized at construction with a source-specific implementation. + std::unique_ptr source_factory_; + + // The source enumerator helps us find capturable sources of the appropriate + // type. Initialized at construction with a source-specific implementation. + std::unique_ptr source_enumerator_; + + // The WgcCaptureSource represents the source we are capturing. It tells us + // if the source is capturable and it creates the GraphicsCaptureItem for us. + std::unique_ptr capture_source_; + + // A map of all the sources we are capturing and the associated + // WgcCaptureSession. Frames for the current source (indicated via + // SelectSource) will be retrieved from the appropriate session when + // requested via CaptureFrame. + // This helps us efficiently capture multiple sources (e.g. when consumers + // are trying to display a list of available capture targets with thumbnails). + std::map ongoing_captures_; + + // The callback that we deliver frames to, synchronously, before CaptureFrame + // returns. + Callback* callback_ = nullptr; + + // A Direct3D11 device that is shared amongst the WgcCaptureSessions, who + // require one to perform the capture. + Microsoft::WRL::ComPtr<::ID3D11Device> d3d11_device_; +}; + +} // namespace webrtc + +#endif // MODULES_DESKTOP_CAPTURE_WIN_WGC_CAPTURER_WIN_H_ diff --git a/modules/desktop_capture/win/wgc_capturer_win_unittest.cc b/modules/desktop_capture/win/wgc_capturer_win_unittest.cc new file mode 100644 index 0000000000..ebfb576e63 --- /dev/null +++ b/modules/desktop_capture/win/wgc_capturer_win_unittest.cc @@ -0,0 +1,508 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/desktop_capture/win/wgc_capturer_win.h" + +#include +#include +#include + +#include "modules/desktop_capture/desktop_capture_options.h" +#include "modules/desktop_capture/desktop_capture_types.h" +#include "modules/desktop_capture/desktop_capturer.h" +#include "modules/desktop_capture/win/test_support/test_window.h" +#include "modules/desktop_capture/win/window_capture_utils.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/thread.h" +#include "rtc_base/time_utils.h" +#include "rtc_base/win/scoped_com_initializer.h" +#include "rtc_base/win/windows_version.h" +#include "system_wrappers/include/metrics.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +const char kWindowThreadName[] = "wgc_capturer_test_window_thread"; +const WCHAR kWindowTitle[] = L"WGC Capturer Test Window"; + +const char kCapturerImplHistogram[] = + "WebRTC.DesktopCapture.Win.DesktopCapturerImpl"; + +const char kCapturerResultHistogram[] = + "WebRTC.DesktopCapture.Win.WgcCapturerResult"; +const int kSuccess = 0; +const int kSessionStartFailure = 4; + +const char kCaptureSessionResultHistogram[] = + "WebRTC.DesktopCapture.Win.WgcCaptureSessionStartResult"; +const int kSourceClosed = 1; + +const char kCaptureTimeHistogram[] = + "WebRTC.DesktopCapture.Win.WgcCapturerFrameTime"; + +const int kSmallWindowWidth = 200; +const int kSmallWindowHeight = 100; +const int kMediumWindowWidth = 300; +const int kMediumWindowHeight = 200; +const int kLargeWindowWidth = 400; +const int kLargeWindowHeight = 500; + +// The size of the image we capture is slightly smaller than the actual size of +// the window. +const int kWindowWidthSubtrahend = 14; +const int kWindowHeightSubtrahend = 7; + +// Custom message constants so we can direct our thread to close windows +// and quit running. +const UINT kNoOp = WM_APP; +const UINT kDestroyWindow = WM_APP + 1; +const UINT kQuitRunning = WM_APP + 2; + +enum CaptureType { kWindowCapture = 0, kScreenCapture = 1 }; + +} // namespace + +class WgcCapturerWinTest : public ::testing::TestWithParam, + public DesktopCapturer::Callback { + public: + void SetUp() override { + if (rtc::rtc_win::GetVersion() < rtc::rtc_win::Version::VERSION_WIN10_RS5) { + RTC_LOG(LS_INFO) + << "Skipping WgcCapturerWinTests on Windows versions < RS5."; + GTEST_SKIP(); + } + + com_initializer_ = + std::make_unique(ScopedCOMInitializer::kMTA); + EXPECT_TRUE(com_initializer_->Succeeded()); + } + + void SetUpForWindowCapture(int window_width = kMediumWindowWidth, + int window_height = kMediumWindowHeight) { + capturer_ = WgcCapturerWin::CreateRawWindowCapturer( + DesktopCaptureOptions::CreateDefault()); + CreateWindowOnSeparateThread(window_width, window_height); + StartWindowThreadMessageLoop(); + source_id_ = GetTestWindowIdFromSourceList(); + } + + void SetUpForScreenCapture() { + capturer_ = WgcCapturerWin::CreateRawScreenCapturer( + DesktopCaptureOptions::CreateDefault()); + source_id_ = GetScreenIdFromSourceList(); + } + + void TearDown() override { + if (window_open_) { + CloseTestWindow(); + } + } + + // The window must live on a separate thread so that we can run a message pump + // without blocking the test thread. This is necessary if we are interested in + // having GraphicsCaptureItem events (i.e. the Closed event) fire, and it more + // closely resembles how capture works in the wild. 
+ void CreateWindowOnSeparateThread(int window_width, int window_height) { + window_thread_ = rtc::Thread::Create(); + window_thread_->SetName(kWindowThreadName, nullptr); + window_thread_->Start(); + window_thread_->Invoke(RTC_FROM_HERE, [this, window_width, + window_height]() { + window_thread_id_ = GetCurrentThreadId(); + window_info_ = + CreateTestWindow(kWindowTitle, window_height, window_width); + window_open_ = true; + + while (!IsWindowResponding(window_info_.hwnd)) { + RTC_LOG(LS_INFO) << "Waiting for test window to become responsive in " + "WgcWindowCaptureTest."; + } + + while (!IsWindowValidAndVisible(window_info_.hwnd)) { + RTC_LOG(LS_INFO) << "Waiting for test window to be visible in " + "WgcWindowCaptureTest."; + } + }); + + ASSERT_TRUE(window_thread_->RunningForTest()); + ASSERT_FALSE(window_thread_->IsCurrent()); + } + + void StartWindowThreadMessageLoop() { + window_thread_->PostTask(RTC_FROM_HERE, [this]() { + MSG msg; + BOOL gm; + while ((gm = ::GetMessage(&msg, NULL, 0, 0)) != 0 && gm != -1) { + ::DispatchMessage(&msg); + if (msg.message == kDestroyWindow) { + DestroyTestWindow(window_info_); + } + if (msg.message == kQuitRunning) { + PostQuitMessage(0); + } + } + }); + } + + void CloseTestWindow() { + ::PostThreadMessage(window_thread_id_, kDestroyWindow, 0, 0); + ::PostThreadMessage(window_thread_id_, kQuitRunning, 0, 0); + window_thread_->Stop(); + window_open_ = false; + } + + DesktopCapturer::SourceId GetTestWindowIdFromSourceList() { + // Frequently, the test window will not show up in GetSourceList because it + // was created too recently. Since we are confident the window will be found + // eventually we loop here until we find it. + intptr_t src_id; + do { + DesktopCapturer::SourceList sources; + EXPECT_TRUE(capturer_->GetSourceList(&sources)); + + auto it = std::find_if( + sources.begin(), sources.end(), + [&](const DesktopCapturer::Source& src) { + return src.id == reinterpret_cast(window_info_.hwnd); + }); + + src_id = it->id; + } while (src_id != reinterpret_cast(window_info_.hwnd)); + + return src_id; + } + + DesktopCapturer::SourceId GetScreenIdFromSourceList() { + DesktopCapturer::SourceList sources; + EXPECT_TRUE(capturer_->GetSourceList(&sources)); + EXPECT_GT(sources.size(), 0ULL); + return sources[0].id; + } + + void DoCapture() { + // Sometimes the first few frames are empty becaues the capture engine is + // still starting up. We also may drop a few frames when the window is + // resized or un-minimized. + do { + capturer_->CaptureFrame(); + } while (result_ == DesktopCapturer::Result::ERROR_TEMPORARY); + + EXPECT_EQ(result_, DesktopCapturer::Result::SUCCESS); + EXPECT_TRUE(frame_); + + EXPECT_GT(metrics::NumEvents(kCapturerResultHistogram, kSuccess), + successful_captures_); + ++successful_captures_; + } + + void ValidateFrame(int expected_width, int expected_height) { + EXPECT_EQ(frame_->size().width(), expected_width - kWindowWidthSubtrahend); + EXPECT_EQ(frame_->size().height(), + expected_height - kWindowHeightSubtrahend); + + // Verify the buffer contains as much data as it should, and that the right + // colors are found. + int data_length = frame_->stride() * frame_->size().height(); + + // The first and last pixel should have the same color because they will be + // from the border of the window. + // Pixels have 4 bytes of data so the whole pixel needs a uint32_t to fit. 
+ uint32_t first_pixel = static_cast(*frame_->data()); + uint32_t last_pixel = static_cast( + *(frame_->data() + data_length - DesktopFrame::kBytesPerPixel)); + EXPECT_EQ(first_pixel, last_pixel); + + // Let's also check a pixel from the middle of the content area, which the + // TestWindow will paint a consistent color for us to verify. + uint8_t* middle_pixel = frame_->data() + (data_length / 2); + + int sub_pixel_offset = DesktopFrame::kBytesPerPixel / 4; + EXPECT_EQ(*middle_pixel, kTestWindowBValue); + middle_pixel += sub_pixel_offset; + EXPECT_EQ(*middle_pixel, kTestWindowGValue); + middle_pixel += sub_pixel_offset; + EXPECT_EQ(*middle_pixel, kTestWindowRValue); + middle_pixel += sub_pixel_offset; + + // The window is opaque so we expect 0xFF for the Alpha channel. + EXPECT_EQ(*middle_pixel, 0xFF); + } + + // DesktopCapturer::Callback interface + // The capturer synchronously invokes this method before |CaptureFrame()| + // returns. + void OnCaptureResult(DesktopCapturer::Result result, + std::unique_ptr frame) override { + result_ = result; + frame_ = std::move(frame); + } + + protected: + std::unique_ptr com_initializer_; + DWORD window_thread_id_; + std::unique_ptr window_thread_; + WindowInfo window_info_; + intptr_t source_id_; + bool window_open_ = false; + DesktopCapturer::Result result_; + int successful_captures_ = 0; + std::unique_ptr frame_; + std::unique_ptr capturer_; +}; + +TEST_P(WgcCapturerWinTest, SelectValidSource) { + if (GetParam() == CaptureType::kWindowCapture) { + SetUpForWindowCapture(); + } else { + SetUpForScreenCapture(); + } + + EXPECT_TRUE(capturer_->SelectSource(source_id_)); +} + +TEST_P(WgcCapturerWinTest, SelectInvalidSource) { + if (GetParam() == CaptureType::kWindowCapture) { + capturer_ = WgcCapturerWin::CreateRawWindowCapturer( + DesktopCaptureOptions::CreateDefault()); + source_id_ = kNullWindowId; + } else { + capturer_ = WgcCapturerWin::CreateRawScreenCapturer( + DesktopCaptureOptions::CreateDefault()); + source_id_ = kInvalidScreenId; + } + + EXPECT_FALSE(capturer_->SelectSource(source_id_)); +} + +TEST_P(WgcCapturerWinTest, Capture) { + if (GetParam() == CaptureType::kWindowCapture) { + SetUpForWindowCapture(); + } else { + SetUpForScreenCapture(); + } + + EXPECT_TRUE(capturer_->SelectSource(source_id_)); + + capturer_->Start(this); + EXPECT_GE(metrics::NumEvents(kCapturerImplHistogram, + DesktopCapturerId::kWgcCapturerWin), + 1); + + DoCapture(); + EXPECT_GT(frame_->size().width(), 0); + EXPECT_GT(frame_->size().height(), 0); +} + +TEST_P(WgcCapturerWinTest, CaptureTime) { + if (GetParam() == CaptureType::kWindowCapture) { + SetUpForWindowCapture(); + } else { + SetUpForScreenCapture(); + } + + EXPECT_TRUE(capturer_->SelectSource(source_id_)); + capturer_->Start(this); + + int64_t start_time; + do { + start_time = rtc::TimeNanos(); + capturer_->CaptureFrame(); + } while (result_ == DesktopCapturer::Result::ERROR_TEMPORARY); + + int capture_time_ms = + (rtc::TimeNanos() - start_time) / rtc::kNumNanosecsPerMillisec; + EXPECT_TRUE(frame_); + + // The test may measure the time slightly differently than the capturer. So we + // just check if it's within 5 ms. + EXPECT_NEAR(frame_->capture_time_ms(), capture_time_ms, 5); + EXPECT_GE( + metrics::NumEvents(kCaptureTimeHistogram, frame_->capture_time_ms()), 1); +} + +INSTANTIATE_TEST_SUITE_P(SourceAgnostic, + WgcCapturerWinTest, + ::testing::Values(CaptureType::kWindowCapture, + CaptureType::kScreenCapture)); + +// Monitor specific tests. 
+TEST_F(WgcCapturerWinTest, FocusOnMonitor) { + SetUpForScreenCapture(); + EXPECT_TRUE(capturer_->SelectSource(0)); + + // You can't set focus on a monitor. + EXPECT_FALSE(capturer_->FocusOnSelectedSource()); +} + +TEST_F(WgcCapturerWinTest, CaptureAllMonitors) { + SetUpForScreenCapture(); + EXPECT_TRUE(capturer_->SelectSource(kFullDesktopScreenId)); + + capturer_->Start(this); + DoCapture(); + EXPECT_GT(frame_->size().width(), 0); + EXPECT_GT(frame_->size().height(), 0); +} + +// Window specific tests. +TEST_F(WgcCapturerWinTest, FocusOnWindow) { + capturer_ = WgcCapturerWin::CreateRawWindowCapturer( + DesktopCaptureOptions::CreateDefault()); + window_info_ = CreateTestWindow(kWindowTitle); + source_id_ = GetScreenIdFromSourceList(); + + EXPECT_TRUE(capturer_->SelectSource(source_id_)); + EXPECT_TRUE(capturer_->FocusOnSelectedSource()); + + HWND hwnd = reinterpret_cast(source_id_); + EXPECT_EQ(hwnd, ::GetActiveWindow()); + EXPECT_EQ(hwnd, ::GetForegroundWindow()); + EXPECT_EQ(hwnd, ::GetFocus()); + DestroyTestWindow(window_info_); +} + +TEST_F(WgcCapturerWinTest, SelectMinimizedWindow) { + SetUpForWindowCapture(); + MinimizeTestWindow(reinterpret_cast(source_id_)); + EXPECT_FALSE(capturer_->SelectSource(source_id_)); + + UnminimizeTestWindow(reinterpret_cast(source_id_)); + EXPECT_TRUE(capturer_->SelectSource(source_id_)); +} + +TEST_F(WgcCapturerWinTest, SelectClosedWindow) { + SetUpForWindowCapture(); + EXPECT_TRUE(capturer_->SelectSource(source_id_)); + + CloseTestWindow(); + EXPECT_FALSE(capturer_->SelectSource(source_id_)); +} + +TEST_F(WgcCapturerWinTest, UnsupportedWindowStyle) { + // Create a window with the WS_EX_TOOLWINDOW style, which WGC does not + // support. + window_info_ = CreateTestWindow(kWindowTitle, kMediumWindowWidth, + kMediumWindowHeight, WS_EX_TOOLWINDOW); + capturer_ = WgcCapturerWin::CreateRawWindowCapturer( + DesktopCaptureOptions::CreateDefault()); + DesktopCapturer::SourceList sources; + EXPECT_TRUE(capturer_->GetSourceList(&sources)); + auto it = std::find_if( + sources.begin(), sources.end(), [&](const DesktopCapturer::Source& src) { + return src.id == reinterpret_cast(window_info_.hwnd); + }); + + // We should not find the window, since we filter for unsupported styles. + EXPECT_EQ(it, sources.end()); + DestroyTestWindow(window_info_); +} + +TEST_F(WgcCapturerWinTest, IncreaseWindowSizeMidCapture) { + SetUpForWindowCapture(kSmallWindowWidth, kSmallWindowHeight); + EXPECT_TRUE(capturer_->SelectSource(source_id_)); + + capturer_->Start(this); + DoCapture(); + ValidateFrame(kSmallWindowWidth, kSmallWindowHeight); + + ResizeTestWindow(window_info_.hwnd, kSmallWindowWidth, kMediumWindowHeight); + DoCapture(); + // We don't expect to see the new size until the next capture, as the frame + // pool hadn't had a chance to resize yet to fit the new, larger image. 
+  DoCapture();
+  ValidateFrame(kSmallWindowWidth, kMediumWindowHeight);
+
+  ResizeTestWindow(window_info_.hwnd, kLargeWindowWidth, kMediumWindowHeight);
+  DoCapture();
+  DoCapture();
+  ValidateFrame(kLargeWindowWidth, kMediumWindowHeight);
+}
+
+TEST_F(WgcCapturerWinTest, ReduceWindowSizeMidCapture) {
+  SetUpForWindowCapture(kLargeWindowWidth, kLargeWindowHeight);
+  EXPECT_TRUE(capturer_->SelectSource(source_id_));
+
+  capturer_->Start(this);
+  DoCapture();
+  ValidateFrame(kLargeWindowWidth, kLargeWindowHeight);
+
+  ResizeTestWindow(window_info_.hwnd, kLargeWindowWidth, kMediumWindowHeight);
+  // We expect to see the new size immediately because the image data has shrunk
+  // and will fit in the existing buffer.
+  DoCapture();
+  ValidateFrame(kLargeWindowWidth, kMediumWindowHeight);
+
+  ResizeTestWindow(window_info_.hwnd, kSmallWindowWidth, kMediumWindowHeight);
+  DoCapture();
+  ValidateFrame(kSmallWindowWidth, kMediumWindowHeight);
+}
+
+TEST_F(WgcCapturerWinTest, MinimizeWindowMidCapture) {
+  SetUpForWindowCapture();
+  EXPECT_TRUE(capturer_->SelectSource(source_id_));
+
+  capturer_->Start(this);
+
+  // Minimize the window and capture should continue but return temporary errors.
+  MinimizeTestWindow(window_info_.hwnd);
+  for (int i = 0; i < 10; ++i) {
+    capturer_->CaptureFrame();
+    EXPECT_EQ(result_, DesktopCapturer::Result::ERROR_TEMPORARY);
+  }
+
+  // Reopen the window and the capture should continue normally.
+  UnminimizeTestWindow(window_info_.hwnd);
+  DoCapture();
+  // We can't verify the window size here because the test window does not
+  // repaint itself after it is unminimized, but capturing successfully is still
+  // a good test.
+}
+
+TEST_F(WgcCapturerWinTest, CloseWindowMidCapture) {
+  SetUpForWindowCapture();
+  EXPECT_TRUE(capturer_->SelectSource(source_id_));
+
+  capturer_->Start(this);
+  DoCapture();
+  ValidateFrame(kMediumWindowWidth, kMediumWindowHeight);
+
+  CloseTestWindow();
+
+  // We need to call GetMessage to trigger the Closed event and the capturer's
+  // event handler for it. If we are too early and the Closed event hasn't
+  // arrived yet we should keep trying until the capturer receives it and stops.
+  auto* wgc_capturer = static_cast<WgcCapturerWin*>(capturer_.get());
+  while (wgc_capturer->IsSourceBeingCaptured(source_id_)) {
+    // Since the capturer handles the Closed message, there will be no message
+    // for us and GetMessage will hang, unless we send ourselves a message
+    // first.
+    ::PostThreadMessage(GetCurrentThreadId(), kNoOp, 0, 0);
+    MSG msg;
+    ::GetMessage(&msg, NULL, 0, 0);
+    ::DispatchMessage(&msg);
+  }
+
+  // Occasionally, one last frame will have made it into the frame pool before
+  // the window closed. The first call will consume it, and in that case we need
+  // to make one more call to CaptureFrame.
+  capturer_->CaptureFrame();
+  if (result_ == DesktopCapturer::Result::SUCCESS)
+    capturer_->CaptureFrame();
+
+  EXPECT_GE(metrics::NumEvents(kCapturerResultHistogram, kSessionStartFailure),
+            1);
+  EXPECT_GE(metrics::NumEvents(kCaptureSessionResultHistogram, kSourceClosed),
+            1);
+  EXPECT_EQ(result_, DesktopCapturer::Result::ERROR_PERMANENT);
+}
+
+}  // namespace webrtc
diff --git a/modules/desktop_capture/win/wgc_desktop_frame.cc b/modules/desktop_capture/win/wgc_desktop_frame.cc
new file mode 100644
index 0000000000..dd9009120b
--- /dev/null
+++ b/modules/desktop_capture/win/wgc_desktop_frame.cc
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/desktop_capture/win/wgc_desktop_frame.h" + +#include + +namespace webrtc { + +WgcDesktopFrame::WgcDesktopFrame(DesktopSize size, + int stride, + std::vector&& image_data) + : DesktopFrame(size, stride, image_data.data(), nullptr), + image_data_(std::move(image_data)) {} + +WgcDesktopFrame::~WgcDesktopFrame() = default; + +} // namespace webrtc diff --git a/modules/desktop_capture/win/wgc_desktop_frame.h b/modules/desktop_capture/win/wgc_desktop_frame.h new file mode 100644 index 0000000000..0eca763f9e --- /dev/null +++ b/modules/desktop_capture/win/wgc_desktop_frame.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_DESKTOP_CAPTURE_WIN_WGC_DESKTOP_FRAME_H_ +#define MODULES_DESKTOP_CAPTURE_WIN_WGC_DESKTOP_FRAME_H_ + +#include +#include + +#include +#include + +#include "modules/desktop_capture/desktop_frame.h" +#include "modules/desktop_capture/desktop_geometry.h" + +namespace webrtc { + +// DesktopFrame implementation used by capturers that use the +// Windows.Graphics.Capture API. +class WgcDesktopFrame final : public DesktopFrame { + public: + // WgcDesktopFrame receives an rvalue reference to the |image_data| vector + // so that it can take ownership of it (and avoid a copy). 
+ WgcDesktopFrame(DesktopSize size, + int stride, + std::vector&& image_data); + + WgcDesktopFrame(const WgcDesktopFrame&) = delete; + WgcDesktopFrame& operator=(const WgcDesktopFrame&) = delete; + + ~WgcDesktopFrame() override; + + private: + std::vector image_data_; +}; + +} // namespace webrtc + +#endif // MODULES_DESKTOP_CAPTURE_WIN_WGC_DESKTOP_FRAME_H_ diff --git a/modules/desktop_capture/win/window_capture_utils.cc b/modules/desktop_capture/win/window_capture_utils.cc index 9e33e56c2d..aaaef0a80d 100644 --- a/modules/desktop_capture/win/window_capture_utils.cc +++ b/modules/desktop_capture/win/window_capture_utils.cc @@ -27,27 +27,26 @@ namespace webrtc { namespace { struct GetWindowListParams { - GetWindowListParams(int flags, DesktopCapturer::SourceList* result) - : ignoreUntitled(flags & GetWindowListFlags::kIgnoreUntitled), - ignoreUnresponsive(flags & GetWindowListFlags::kIgnoreUnresponsive), + GetWindowListParams(int flags, + LONG ex_style_filters, + DesktopCapturer::SourceList* result) + : ignore_untitled(flags & GetWindowListFlags::kIgnoreUntitled), + ignore_unresponsive(flags & GetWindowListFlags::kIgnoreUnresponsive), + ignore_current_process_windows( + flags & GetWindowListFlags::kIgnoreCurrentProcessWindows), + ex_style_filters(ex_style_filters), result(result) {} - const bool ignoreUntitled; - const bool ignoreUnresponsive; + const bool ignore_untitled; + const bool ignore_unresponsive; + const bool ignore_current_process_windows; + const LONG ex_style_filters; DesktopCapturer::SourceList* const result; }; -// If a window is owned by the current process and unresponsive, then making a -// blocking call such as GetWindowText may lead to a deadlock. -// -// https://docs.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-getwindowtexta#remarks -bool CanSafelyMakeBlockingCalls(HWND hwnd) { +bool IsWindowOwnedByCurrentProcess(HWND hwnd) { DWORD process_id; GetWindowThreadProcessId(hwnd, &process_id); - if (process_id != GetCurrentProcessId() || IsWindowResponding(hwnd)) { - return true; - } - - return false; + return process_id == GetCurrentProcessId(); } BOOL CALLBACK GetWindowListHandler(HWND hwnd, LPARAM param) { @@ -67,7 +66,13 @@ BOOL CALLBACK GetWindowListHandler(HWND hwnd, LPARAM param) { return TRUE; } - if (params->ignoreUnresponsive && !IsWindowResponding(hwnd)) { + // Filter out windows that match the extended styles the caller has specified, + // e.g. WS_EX_TOOLWINDOW for capturers that don't support overlay windows. + if (exstyle & params->ex_style_filters) { + return TRUE; + } + + if (params->ignore_unresponsive && !IsWindowResponding(hwnd)) { return TRUE; } @@ -75,11 +80,26 @@ BOOL CALLBACK GetWindowListHandler(HWND hwnd, LPARAM param) { window.id = reinterpret_cast(hwnd); // GetWindowText* are potentially blocking operations if |hwnd| is - // owned by the current process, and can lead to a deadlock if the message - // pump is waiting on this thread. If we've filtered out unresponsive - // windows, this is not a concern, but otherwise we need to check if we can - // safely make blocking calls. - if (params->ignoreUnresponsive || CanSafelyMakeBlockingCalls(hwnd)) { + // owned by the current process. The APIs will send messages to the window's + // message loop, and if the message loop is waiting on this operation we will + // enter a deadlock. 
+ // https://docs.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-getwindowtexta#remarks + // + // To help consumers avoid this, there is a DesktopCaptureOption to ignore + // windows owned by the current process. Consumers should either ensure that + // the thread running their message loop never waits on this operation, or use + // the option to exclude these windows from the source list. + bool owned_by_current_process = IsWindowOwnedByCurrentProcess(hwnd); + if (owned_by_current_process && params->ignore_current_process_windows) { + return TRUE; + } + + // Even if consumers request to enumerate windows owned by the current + // process, we should not call GetWindowText* on unresponsive windows owned by + // the current process because we will hang. Unfortunately, we could still + // hang if the window becomes unresponsive after this check, hence the option + // to avoid these completely. + if (!owned_by_current_process || IsWindowResponding(hwnd)) { const size_t kTitleLength = 500; WCHAR window_title[kTitleLength] = L""; if (GetWindowTextLength(hwnd) != 0 && @@ -89,7 +109,7 @@ BOOL CALLBACK GetWindowListHandler(HWND hwnd, LPARAM param) { } // Skip windows when we failed to convert the title or it is empty. - if (params->ignoreUntitled && window.title.empty()) + if (params->ignore_untitled && window.title.empty()) return TRUE; // Capture the window class name, to allow specific window classes to be @@ -271,8 +291,10 @@ bool IsWindowResponding(HWND window) { nullptr); } -bool GetWindowList(int flags, DesktopCapturer::SourceList* windows) { - GetWindowListParams params(flags, windows); +bool GetWindowList(int flags, + DesktopCapturer::SourceList* windows, + LONG ex_style_filters) { + GetWindowListParams params(flags, ex_style_filters, windows); return ::EnumWindows(&GetWindowListHandler, reinterpret_cast(¶ms)) != 0; } @@ -432,10 +454,16 @@ bool WindowCaptureHelperWin::IsWindowCloaked(HWND hwnd) { } bool WindowCaptureHelperWin::EnumerateCapturableWindows( - DesktopCapturer::SourceList* results) { - if (!webrtc::GetWindowList((GetWindowListFlags::kIgnoreUntitled | - GetWindowListFlags::kIgnoreUnresponsive), - results)) { + DesktopCapturer::SourceList* results, + bool enumerate_current_process_windows, + LONG ex_style_filters) { + int flags = (GetWindowListFlags::kIgnoreUntitled | + GetWindowListFlags::kIgnoreUnresponsive); + if (!enumerate_current_process_windows) { + flags |= GetWindowListFlags::kIgnoreCurrentProcessWindows; + } + + if (!webrtc::GetWindowList(flags, results, ex_style_filters)) { return false; } diff --git a/modules/desktop_capture/win/window_capture_utils.h b/modules/desktop_capture/win/window_capture_utils.h index f636a312f5..a6a295d068 100644 --- a/modules/desktop_capture/win/window_capture_utils.h +++ b/modules/desktop_capture/win/window_capture_utils.h @@ -78,6 +78,7 @@ enum GetWindowListFlags { kNone = 0x00, kIgnoreUntitled = 1 << 0, kIgnoreUnresponsive = 1 << 1, + kIgnoreCurrentProcessWindows = 1 << 2, }; // Retrieves the list of top-level windows on the screen. @@ -85,9 +86,13 @@ enum GetWindowListFlags { // - Those that are invisible or minimized. // - Program Manager & Start menu. // - [with kIgnoreUntitled] windows with no title. -// - [with kIgnoreUnresponsive] windows that unresponsive. +// - [with kIgnoreUnresponsive] windows that are unresponsive. +// - [with kIgnoreCurrentProcessWindows] windows owned by the current process. +// - Any windows with extended styles that match |ex_style_filters|. // Returns false if native APIs failed. 
-bool GetWindowList(int flags, DesktopCapturer::SourceList* windows); +bool GetWindowList(int flags, + DesktopCapturer::SourceList* windows, + LONG ex_style_filters = 0); typedef HRESULT(WINAPI* DwmIsCompositionEnabledFunc)(BOOL* enabled); typedef HRESULT(WINAPI* DwmGetWindowAttributeFunc)(HWND hwnd, @@ -107,7 +112,13 @@ class WindowCaptureHelperWin { bool IsWindowOnCurrentDesktop(HWND hwnd); bool IsWindowVisibleOnCurrentDesktop(HWND hwnd); bool IsWindowCloaked(HWND hwnd); - bool EnumerateCapturableWindows(DesktopCapturer::SourceList* results); + + // The optional |ex_style_filters| parameter allows callers to provide + // extended window styles (e.g. WS_EX_TOOLWINDOW) and prevent windows that + // match from being included in |results|. + bool EnumerateCapturableWindows(DesktopCapturer::SourceList* results, + bool enumerate_current_process_windows, + LONG ex_style_filters = 0); private: HMODULE dwmapi_library_ = nullptr; diff --git a/modules/desktop_capture/win/window_capture_utils_unittest.cc b/modules/desktop_capture/win/window_capture_utils_unittest.cc index 52f6714383..4b426fc464 100644 --- a/modules/desktop_capture/win/window_capture_utils_unittest.cc +++ b/modules/desktop_capture/win/window_capture_utils_unittest.cc @@ -137,4 +137,18 @@ TEST(WindowCaptureUtilsTest, IgnoreUntitledWindows) { DestroyTestWindow(info); } +TEST(WindowCaptureUtilsTest, IgnoreCurrentProcessWindows) { + WindowInfo info = CreateTestWindow(kWindowTitle); + DesktopCapturer::SourceList window_list; + ASSERT_TRUE(GetWindowList(GetWindowListFlags::kIgnoreCurrentProcessWindows, + &window_list)); + EXPECT_EQ(std::find_if(window_list.begin(), window_list.end(), + [&info](DesktopCapturer::Source window) { + return reinterpret_cast(window.id) == + info.hwnd; + }), + window_list.end()); + DestroyTestWindow(info); +} + } // namespace webrtc diff --git a/modules/desktop_capture/win/window_capturer_win_gdi.cc b/modules/desktop_capture/win/window_capturer_win_gdi.cc index 04cd7f667d..25677e9868 100644 --- a/modules/desktop_capture/win/window_capturer_win_gdi.cc +++ b/modules/desktop_capture/win/window_capturer_win_gdi.cc @@ -17,6 +17,8 @@ #include #include "modules/desktop_capture/cropped_desktop_frame.h" +#include "modules/desktop_capture/desktop_capture_metrics_helper.h" +#include "modules/desktop_capture/desktop_capture_types.h" #include "modules/desktop_capture/desktop_capturer.h" #include "modules/desktop_capture/desktop_frame_win.h" #include "modules/desktop_capture/win/screen_capture_utils.h" @@ -25,8 +27,10 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/string_utils.h" +#include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" #include "rtc_base/win32.h" +#include "system_wrappers/include/metrics.h" namespace webrtc { @@ -91,11 +95,14 @@ BOOL CALLBACK OwnedWindowCollector(HWND hwnd, LPARAM param) { return TRUE; } -WindowCapturerWinGdi::WindowCapturerWinGdi() {} +WindowCapturerWinGdi::WindowCapturerWinGdi( + bool enumerate_current_process_windows) + : enumerate_current_process_windows_(enumerate_current_process_windows) {} WindowCapturerWinGdi::~WindowCapturerWinGdi() {} bool WindowCapturerWinGdi::GetSourceList(SourceList* sources) { - if (!window_capture_helper_.EnumerateCapturableWindows(sources)) + if (!window_capture_helper_.EnumerateCapturableWindows( + sources, enumerate_current_process_windows_)) return false; std::map new_map; @@ -143,14 +150,27 @@ bool WindowCapturerWinGdi::IsOccluded(const DesktopVector& pos) { void WindowCapturerWinGdi::Start(Callback* callback) { 
RTC_DCHECK(!callback_); RTC_DCHECK(callback); + RecordCapturerImpl(DesktopCapturerId::kWindowCapturerWinGdi); callback_ = callback; } void WindowCapturerWinGdi::CaptureFrame() { RTC_DCHECK(callback_); + int64_t capture_start_time_nanos = rtc::TimeNanos(); CaptureResults results = CaptureFrame(/*capture_owned_windows*/ true); + + if (results.frame) { + int capture_time_ms = (rtc::TimeNanos() - capture_start_time_nanos) / + rtc::kNumNanosecsPerMillisec; + RTC_HISTOGRAM_COUNTS_1000( + "WebRTC.DesktopCapture.Win.WindowGdiCapturerFrameTime", + capture_time_ms); + results.frame->set_capture_time_ms(capture_time_ms); + results.frame->set_capturer_id(DesktopCapturerId::kWindowCapturerWinGdi); + } + callback_->OnCaptureResult(results.result, std::move(results.frame)); } @@ -333,7 +353,8 @@ WindowCapturerWinGdi::CaptureResults WindowCapturerWinGdi::CaptureFrame( if (!owned_windows_.empty()) { if (!owned_window_capturer_) { - owned_window_capturer_ = std::make_unique(); + owned_window_capturer_ = std::make_unique( + enumerate_current_process_windows_); } // Owned windows are stored in top-down z-order, so this iterates in @@ -372,7 +393,8 @@ WindowCapturerWinGdi::CaptureResults WindowCapturerWinGdi::CaptureFrame( // static std::unique_ptr WindowCapturerWinGdi::CreateRawWindowCapturer( const DesktopCaptureOptions& options) { - return std::unique_ptr(new WindowCapturerWinGdi()); + return std::unique_ptr( + new WindowCapturerWinGdi(options.enumerate_current_process_windows())); } } // namespace webrtc diff --git a/modules/desktop_capture/win/window_capturer_win_gdi.h b/modules/desktop_capture/win/window_capturer_win_gdi.h index c954c230c9..5091458a12 100644 --- a/modules/desktop_capture/win/window_capturer_win_gdi.h +++ b/modules/desktop_capture/win/window_capturer_win_gdi.h @@ -24,7 +24,7 @@ namespace webrtc { class WindowCapturerWinGdi : public DesktopCapturer { public: - WindowCapturerWinGdi(); + explicit WindowCapturerWinGdi(bool enumerate_current_process_windows); // Disallow copy and assign WindowCapturerWinGdi(const WindowCapturerWinGdi&) = delete; @@ -61,6 +61,8 @@ class WindowCapturerWinGdi : public DesktopCapturer { WindowCaptureHelperWin window_capture_helper_; + bool enumerate_current_process_windows_; + // This map is used to avoid flickering for the case when SelectWindow() calls // are interleaved with Capture() calls. std::map window_size_map_; diff --git a/modules/desktop_capture/win/window_capturer_win_wgc.cc b/modules/desktop_capture/win/window_capturer_win_wgc.cc deleted file mode 100644 index 30a672d9ef..0000000000 --- a/modules/desktop_capture/win/window_capturer_win_wgc.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "modules/desktop_capture/win/window_capturer_win_wgc.h" - -#include - -#include "rtc_base/logging.h" - -namespace webrtc { - -WindowCapturerWinWgc::WindowCapturerWinWgc() = default; -WindowCapturerWinWgc::~WindowCapturerWinWgc() = default; - -bool WindowCapturerWinWgc::GetSourceList(SourceList* sources) { - return window_capture_helper_.EnumerateCapturableWindows(sources); -} - -bool WindowCapturerWinWgc::SelectSource(SourceId id) { - HWND window = reinterpret_cast(id); - if (!IsWindowValidAndVisible(window)) - return false; - - window_ = window; - return true; -} - -void WindowCapturerWinWgc::Start(Callback* callback) { - RTC_DCHECK(!callback_); - RTC_DCHECK(callback); - - callback_ = callback; - - // Create a Direct3D11 device to share amongst the WgcCaptureSessions. Many - // parameters are nullptr as the implemention uses defaults that work well for - // us. - HRESULT hr = D3D11CreateDevice( - /*adapter=*/nullptr, D3D_DRIVER_TYPE_HARDWARE, - /*software_rasterizer=*/nullptr, D3D11_CREATE_DEVICE_BGRA_SUPPORT, - /*feature_levels=*/nullptr, /*feature_levels_size=*/0, D3D11_SDK_VERSION, - &d3d11_device_, /*feature_level=*/nullptr, /*device_context=*/nullptr); - if (hr == DXGI_ERROR_UNSUPPORTED) { - // If a hardware device could not be created, use WARP which is a high speed - // software device. - hr = D3D11CreateDevice( - /*adapter=*/nullptr, D3D_DRIVER_TYPE_WARP, - /*software_rasterizer=*/nullptr, D3D11_CREATE_DEVICE_BGRA_SUPPORT, - /*feature_levels=*/nullptr, /*feature_levels_size=*/0, - D3D11_SDK_VERSION, &d3d11_device_, /*feature_level=*/nullptr, - /*device_context=*/nullptr); - } - - if (FAILED(hr)) { - RTC_LOG(LS_ERROR) << "Failed to create D3D11Device: " << hr; - } -} - -void WindowCapturerWinWgc::CaptureFrame() { - RTC_DCHECK(callback_); - - if (!window_) { - RTC_LOG(LS_ERROR) << "Window hasn't been selected"; - callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT, - /*frame=*/nullptr); - return; - } - - if (!d3d11_device_) { - RTC_LOG(LS_ERROR) << "No D3D11D3evice, cannot capture."; - callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT, - /*frame=*/nullptr); - return; - } - - WgcCaptureSession* capture_session = nullptr; - auto iter = ongoing_captures_.find(window_); - if (iter == ongoing_captures_.end()) { - auto iter_success_pair = ongoing_captures_.emplace( - std::piecewise_construct, std::forward_as_tuple(window_), - std::forward_as_tuple(d3d11_device_, window_)); - if (iter_success_pair.second) { - capture_session = &iter_success_pair.first->second; - } else { - RTC_LOG(LS_ERROR) << "Failed to create new WgcCaptureSession."; - callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT, - /*frame=*/nullptr); - return; - } - } else { - capture_session = &iter->second; - } - - HRESULT hr; - if (!capture_session->IsCaptureStarted()) { - hr = capture_session->StartCapture(); - if (FAILED(hr)) { - RTC_LOG(LS_ERROR) << "Failed to start capture: " << hr; - callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT, - /*frame=*/nullptr); - return; - } - } - - std::unique_ptr frame; - hr = capture_session->GetMostRecentFrame(&frame); - if (FAILED(hr)) { - RTC_LOG(LS_ERROR) << "GetMostRecentFrame failed: " << hr; - callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_PERMANENT, - /*frame=*/nullptr); - return; - } - - if (!frame) { - RTC_LOG(LS_WARNING) << "GetMostRecentFrame returned an empty frame."; - callback_->OnCaptureResult(DesktopCapturer::Result::ERROR_TEMPORARY, - /*frame=*/nullptr); - return; - } - - 
callback_->OnCaptureResult(DesktopCapturer::Result::SUCCESS, - std::move(frame)); -} - -// static -std::unique_ptr WindowCapturerWinWgc::CreateRawWindowCapturer( - const DesktopCaptureOptions& options) { - return std::unique_ptr(new WindowCapturerWinWgc()); -} - -} // namespace webrtc diff --git a/modules/desktop_capture/win/window_capturer_win_wgc.h b/modules/desktop_capture/win/window_capturer_win_wgc.h deleted file mode 100644 index 7e05b0e541..0000000000 --- a/modules/desktop_capture/win/window_capturer_win_wgc.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef MODULES_DESKTOP_CAPTURE_WIN_WINDOW_CAPTURER_WIN_WGC_H_ -#define MODULES_DESKTOP_CAPTURE_WIN_WINDOW_CAPTURER_WIN_WGC_H_ - -#include -#include -#include -#include - -#include "modules/desktop_capture/desktop_capture_options.h" -#include "modules/desktop_capture/desktop_capturer.h" -#include "modules/desktop_capture/win/wgc_capture_session.h" -#include "modules/desktop_capture/win/window_capture_utils.h" - -namespace webrtc { - -class WindowCapturerWinWgc final : public DesktopCapturer { - public: - WindowCapturerWinWgc(); - - WindowCapturerWinWgc(const WindowCapturerWinWgc&) = delete; - WindowCapturerWinWgc& operator=(const WindowCapturerWinWgc&) = delete; - - ~WindowCapturerWinWgc() override; - - static std::unique_ptr CreateRawWindowCapturer( - const DesktopCaptureOptions& options); - - // DesktopCapturer interface. - void Start(Callback* callback) override; - void CaptureFrame() override; - bool GetSourceList(SourceList* sources) override; - bool SelectSource(SourceId id) override; - - private: - // The callback that we deliver frames to, synchronously, before CaptureFrame - // returns. - Callback* callback_ = nullptr; - - // HWND for the currently selected window or nullptr if a window is not - // selected. We may be capturing many other windows, but this is the window - // that we will return a frame for when CaptureFrame is called. - HWND window_ = nullptr; - - // This helps us enumerate the list of windows that we can capture. - WindowCaptureHelperWin window_capture_helper_; - - // A Direct3D11 device that is shared amongst the WgcCaptureSessions, who - // require one to perform the capture. - Microsoft::WRL::ComPtr<::ID3D11Device> d3d11_device_; - - // A map of all the windows we are capturing and the associated - // WgcCaptureSession. This is where we will get the frames for the window - // from, when requested. - std::map ongoing_captures_; -}; - -} // namespace webrtc - -#endif // MODULES_DESKTOP_CAPTURE_WIN_WINDOW_CAPTURER_WIN_WGC_H_ diff --git a/modules/desktop_capture/window_capturer_null.cc b/modules/desktop_capture/window_capturer_null.cc index 66e76a50fb..e7c7b0a134 100644 --- a/modules/desktop_capture/window_capturer_null.cc +++ b/modules/desktop_capture/window_capturer_null.cc @@ -8,10 +8,9 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ -#include - #include "modules/desktop_capture/desktop_capturer.h" #include "modules/desktop_capture/desktop_frame.h" +#include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" namespace webrtc { @@ -49,8 +48,8 @@ bool WindowCapturerNull::SelectSource(SourceId id) { } void WindowCapturerNull::Start(Callback* callback) { - assert(!callback_); - assert(callback); + RTC_DCHECK(!callback_); + RTC_DCHECK(callback); callback_ = callback; } diff --git a/modules/desktop_capture/window_capturer_unittest.cc b/modules/desktop_capture/window_capturer_unittest.cc index 8a611e760a..519c04601b 100644 --- a/modules/desktop_capture/window_capturer_unittest.cc +++ b/modules/desktop_capture/window_capturer_unittest.cc @@ -44,7 +44,13 @@ class WindowCapturerTest : public ::testing::Test, }; // Verify that we can enumerate windows. -TEST_F(WindowCapturerTest, Enumerate) { +// TODO(bugs.webrtc.org/12950): Re-enable when libc++ issue is fixed +#if defined(WEBRTC_LINUX) && defined(MEMORY_SANITIZER) +#define MAYBE_Enumerate DISABLED_Enumerate +#else +#define MAYBE_Enumerate Enumerate +#endif +TEST_F(WindowCapturerTest, MAYBE_Enumerate) { DesktopCapturer::SourceList sources; EXPECT_TRUE(capturer_->GetSourceList(&sources)); @@ -54,8 +60,9 @@ TEST_F(WindowCapturerTest, Enumerate) { } } -// Flaky on Linux. See: crbug.com/webrtc/7830 -#if defined(WEBRTC_LINUX) +// Flaky on Linux. See: crbug.com/webrtc/7830. +// Failing on macOS 11: See bugs.webrtc.org/12801 +#if defined(WEBRTC_LINUX) || defined(WEBRTC_MAC) #define MAYBE_Capture DISABLED_Capture #else #define MAYBE_Capture Capture diff --git a/modules/pacing/BUILD.gn b/modules/pacing/BUILD.gn index cabcd9300b..0787105f14 100644 --- a/modules/pacing/BUILD.gn +++ b/modules/pacing/BUILD.gn @@ -34,6 +34,7 @@ rtc_library("pacing") { ":interval_budget", "..:module_api", "../../api:function_view", + "../../api:sequence_checker", "../../api/rtc_event_log", "../../api/task_queue:task_queue", "../../api/transport:field_trial_based_config", @@ -50,7 +51,6 @@ rtc_library("pacing") { "../../rtc_base:rtc_task_queue", "../../rtc_base/experiments:field_trial_parser", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/task_utils:to_queued_task", "../../system_wrappers", "../../system_wrappers:metrics", diff --git a/modules/pacing/g3doc/index.md b/modules/pacing/g3doc/index.md new file mode 100644 index 0000000000..4187a8bd9b --- /dev/null +++ b/modules/pacing/g3doc/index.md @@ -0,0 +1,169 @@ + + + +# Paced Sending + +The paced sender, often referred to as just the "pacer", is a part of the WebRTC +RTP stack used primarily to smooth the flow of packets sent onto the network. + +## Background + +Consider a video stream at 5Mbps and 30fps. This would in an ideal world result +in each frame being ~21kB large and packetized into 18 RTP packets. While the +average bitrate over say a one second sliding window would be a correct 5Mbps, +on a shorter time scale it can be seen as a burst of 167Mbps every 33ms, each +followed by a 32ms silent period. Further, it is quite common that video +encoders overshoot the target frame size in case of sudden movement especially +dealing with screensharing. Frames being 10x or even 100x larger than the ideal +size is an all too real scenario. These packet bursts can cause several issues, +such as congesting networks and causing buffer bloat or even packet loss. Most +sessions have more than one media stream, e.g. a video and an audio track. 
If
+you put a frame on the wire in one go, and those packets take 100ms to reach
+the other side - that means you have now blocked any audio packets from
+reaching the remote end in time as well.
+
+The paced sender solves this by having a buffer in which media is queued, and
+then using a _leaky bucket_ algorithm to pace them onto the network. The buffer
+contains separate FIFO streams for all media tracks so that e.g. audio can be
+prioritized over video - and equal-priority streams can be sent in a
+round-robin fashion to avoid any one stream blocking others.
+
+Since the pacer is in control of the bitrate sent on the wire, it is also used
+to generate padding in cases where a minimum send rate is required - and to
+generate packet trains if bitrate probing is used.
+
+## Life of a Packet
+
+The typical path for media packets when using the paced sender looks something
+like this:
+
+1. `RTPSenderVideo` or `RTPSenderAudio` packetizes media into RTP packets.
+2. The packets are sent to the [RTPSender] class for transmission.
+3. The pacer is called via the [RtpPacketSender] interface to enqueue the
+   packet batch.
+4. The packets are put into a queue within the pacer, awaiting opportune
+   moments to send them.
+5. At a calculated time, the pacer calls the `PacingController::PacketSender()`
+   callback method, normally implemented by the [PacketRouter] class.
+6. The router forwards the packet to the correct RTP module based on the
+   packet's SSRC, where the `RTPSenderEgress` class applies the final
+   timestamping, potentially records it for retransmission, etc.
+7. The packet is sent to the low-level `Transport` interface, after which it
+   is out of scope.
+
+Asynchronously to this, the estimated available send bandwidth is determined -
+and the target send rate is set on the `RtpPacketPacer` via the `void
+SetPacingRates(DataRate pacing_rate, DataRate padding_rate)` method.
+
+## Packet Prioritization
+
+The pacer prioritizes packets based on two criteria:
+
+* Packet type, with most to least prioritized:
+  1. Audio
+  2. Retransmissions
+  3. Video and FEC
+  4. Padding
+* Enqueue order
+
+The enqueue order is enforced on a per-stream (SSRC) basis. Given equal
+priority, the [RoundRobinPacketQueue] alternates between media streams to
+ensure no stream needlessly blocks others.
+
+## Implementations
+
+There are currently two implementations of the paced sender (although they
+share a large amount of logic via the `PacingController` class). The legacy
+[PacedSender] uses a dedicated thread to poll the pacing controller at 5ms
+intervals, and has a lock to protect internal state. The newer
+[TaskQueuePacedSender], as the name implies, uses a TaskQueue to both protect
+state and schedule packet processing; the latter is dynamic, based on actual
+send rates and constraints. Avoid using the legacy PacedSender in new
+applications, as we are planning to remove it.
+
+## The Packet Router
+
+An adjacent component called [PacketRouter] is used to route packets coming out
+of the pacer and into the correct RTP module. It has the following functions:
+
+* The `SendPacket` method looks up an RTP module with an SSRC corresponding to
+  the packet for further routing to the network.
+* If send-side bandwidth estimation is used, it populates the transport-wide
+  sequence number extension.
+* Generates padding. Modules supporting payload-based padding are prioritized,
+  with the last module to have sent media always being the first choice.
+* Returns any generated FEC after having sent media.
+* Forwards REMB and/or TransportFeedback messages to suitable RTP modules.
+
+At present, FEC is generated on a per-SSRC basis, so it is always returned from
+an RTP module after sending media. Hopefully one day we will support covering
+multiple streams with a single FlexFEC stream - and the packet router is the
+likely place for that FEC generator to live. It may even be used for FEC
+padding as an alternative to RTX.
+
+## The API
+
+This section outlines the classes and methods relevant to a few different use
+cases of the pacer; two illustrative sketches follow at the end of the section.
+
+### Packet sending
+
+For sending packets, use
+`RtpPacketSender::EnqueuePackets(std::vector<std::unique_ptr<RtpPacketToSend>>
+packets)`. The pacer takes a `PacingController::PacketSender` as a constructor
+argument; this callback is used when it's time to actually send packets.
+
+### Send rates
+
+To control the send rate, use `void SetPacingRates(DataRate pacing_rate,
+DataRate padding_rate)`. If the packet queue becomes empty and the send rate
+drops below `padding_rate`, the pacer will request padding packets from the
+`PacketRouter`.
+
+In order to completely suspend/resume sending data (e.g. due to network
+availability), use the `Pause()` and `Resume()` methods.
+
+The specified pacing rate may be overridden in some cases, e.g. due to extreme
+encoder overshoot. Use `void SetQueueTimeLimit(TimeDelta limit)` to specify the
+longest time you want packets to spend waiting in the pacer queue (pausing
+excluded). The actual send rate may then be increased past the `pacing_rate` to
+try to make the _average_ queue time less than the requested limit. The
+rationale for this is that if the send queue is, say, longer than three
+seconds, it's better to risk packet loss and then try to recover using a
+key-frame rather than cause severe delays.
+
+### Bandwidth estimation
+
+If the bandwidth estimator supports bandwidth probing, it may request a cluster
+of packets to be sent at a specified rate in order to gauge if this causes
+increased delay/loss on the network. Use the `void CreateProbeCluster(DataRate
+bitrate, int cluster_id)` method - packets sent via this `PacketRouter` will be
+marked with the corresponding `cluster_id` in the attached `PacedPacketInfo`
+struct.
+
+If congestion window pushback is used, the state can be updated using
+`SetCongestionWindow()` and `UpdateOutstandingData()`.
+
+A few more methods control how we pace:
+
+* `SetAccountForAudioPackets()` determines if audio packets count into
+  bandwidth consumed.
+* `SetIncludeOverhead()` determines if the entire RTP packet size counts into
+  bandwidth used (otherwise just the media payload).
+* `SetTransportOverhead()` sets an additional data size consumed per packet,
+  representing e.g. UDP/IP headers.
+
+### Stats
+
+Several methods are used to gather statistics on pacer state:
+
+* `OldestPacketWaitTime()`: time since the oldest packet in the queue was
+  added.
+* `QueueSizeData()`: total bytes currently in the queue.
+* `FirstSentPacketTime()`: absolute time at which the first packet was sent.
+* `ExpectedQueueTime()`: total bytes in the queue divided by the send rate.
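+### Example: wiring up a pacer
+
+The sketch below is illustrative only and is not compiled as part of this
+change. It shows how the pieces described above fit together: a `PacketRouter`
+acting as the `PacingController::PacketSender` callback, and a
+`TaskQueuePacedSender` being configured and fed packets. The header paths and
+the reliance on default constructor arguments (event log, field trials,
+hold-back window) are assumptions based on the code in this CL; in the real
+stack this wiring lives in the send-side transport controller.
+
+```cpp
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "api/task_queue/default_task_queue_factory.h"
+#include "api/units/data_rate.h"
+#include "api/units/time_delta.h"
+#include "modules/pacing/packet_router.h"
+#include "modules/pacing/task_queue_paced_sender.h"
+#include "modules/rtp_rtcp/source/rtp_packet_to_send.h"
+#include "system_wrappers/include/clock.h"
+
+namespace webrtc {
+
+void PaceSomePackets(std::vector<std::unique_ptr<RtpPacketToSend>> packets) {
+  Clock* clock = Clock::GetRealTimeClock();
+  auto task_queue_factory = CreateDefaultTaskQueueFactory();
+
+  // PacketRouter implements PacingController::PacketSender. In a real setup,
+  // RTP modules are registered on it (AddSendRtpModule) so that SendPacket()
+  // can route packets by SSRC.
+  PacketRouter packet_router;
+
+  TaskQueuePacedSender pacer(clock, &packet_router,
+                             /*event_log=*/nullptr,
+                             /*field_trials=*/nullptr,
+                             task_queue_factory.get());
+
+  // Delayed processing is only scheduled once the pacer has been started.
+  pacer.EnsureStarted();
+
+  // Pace media at 1 Mbps; request padding from the router if the send rate
+  // would otherwise drop below 50 kbps.
+  pacer.SetPacingRates(DataRate::KilobitsPerSec(1000),
+                       DataRate::KilobitsPerSec(50));
+
+  // Cap the average queueing delay; the pacer may exceed the pacing rate to
+  // honor this limit.
+  pacer.SetQueueTimeLimit(TimeDelta::Seconds(2));
+
+  // Hand the packets to the pacer; they come back out through
+  // PacketRouter::SendPacket() at the calculated send times.
+  pacer.EnqueuePackets(std::move(packets));
+}
+
+}  // namespace webrtc
+```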
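+### Example: sending REMB through the packet router
+
+With the REMB throttling and capping logic removed from `PacketRouter` in this
+CL (see the `packet_router.cc` changes further down), the receive side decides
+when a REMB should go out and simply forwards the already-capped bitrate. Below
+is a minimal sketch of the new call pattern; the function name and the
+assumption that the caller does its own rate limiting are illustrative:
+
+```cpp
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "modules/pacing/packet_router.h"
+
+namespace webrtc {
+
+// Called by the receive-side bandwidth estimator whenever it wants a REMB
+// message sent. The router picks the registered REMB-candidate RTP module and
+// calls SetRemb() on it; if no candidate is registered, this is a no-op.
+void ForwardRemb(PacketRouter* packet_router,
+                 int64_t bitrate_bps,
+                 std::vector<uint32_t> ssrcs) {
+  packet_router->SendRemb(bitrate_bps, std::move(ssrcs));
+}
+
+}  // namespace webrtc
+```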
+ +[RTPSender]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/rtp_rtcp/source/rtp_sender.h;drc=77ee8542dd35d5143b5788ddf47fb7cdb96eb08e +[RtpPacketSender]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/rtp_rtcp/include/rtp_packet_sender.h;drc=ea55b0872f14faab23a4e5dbcb6956369c8ed5dc +[RtpPacketPacer]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/pacing/rtp_packet_pacer.h;drc=e7bc3a347760023dd4840cf6ebdd1e6c8592f4d7 +[PacketRouter]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/pacing/packet_router.h;drc=3d2210876e31d0bb5c7de88b27fd02ceb1f4e03e +[PacedSender]: https://source.chromium.org/chromium/chromium/src/+/master:media/cast/net/pacing/paced_sender.h;drc=df00acf8f3cea9a947e11dc687aa1147971a1883 +[TaskQueuePacedSender]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/pacing/task_queue_paced_sender.h;drc=5051693ada61bc7b78855c6fb3fa87a0394fa813 +[RoundRobinPacketQueue]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/pacing/round_robin_packet_queue.h;drc=b571ff48f8fe07678da5a854cd6c3f5dde02855f diff --git a/modules/pacing/paced_sender.cc b/modules/pacing/paced_sender.cc index a0e76761e7..51d3edc301 100644 --- a/modules/pacing/paced_sender.cc +++ b/modules/pacing/paced_sender.cc @@ -58,13 +58,13 @@ PacedSender::~PacedSender() { } void PacedSender::CreateProbeCluster(DataRate bitrate, int cluster_id) { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); return pacing_controller_.CreateProbeCluster(bitrate, cluster_id); } void PacedSender::Pause() { { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.Pause(); } @@ -77,7 +77,7 @@ void PacedSender::Pause() { void PacedSender::Resume() { { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.Resume(); } @@ -90,7 +90,7 @@ void PacedSender::Resume() { void PacedSender::SetCongestionWindow(DataSize congestion_window_size) { { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.SetCongestionWindow(congestion_window_size); } MaybeWakupProcessThread(); @@ -98,7 +98,7 @@ void PacedSender::SetCongestionWindow(DataSize congestion_window_size) { void PacedSender::UpdateOutstandingData(DataSize outstanding_data) { { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.UpdateOutstandingData(outstanding_data); } MaybeWakupProcessThread(); @@ -106,7 +106,7 @@ void PacedSender::UpdateOutstandingData(DataSize outstanding_data) { void PacedSender::SetPacingRates(DataRate pacing_rate, DataRate padding_rate) { { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.SetPacingRates(pacing_rate, padding_rate); } MaybeWakupProcessThread(); @@ -117,13 +117,14 @@ void PacedSender::EnqueuePackets( { TRACE_EVENT0(TRACE_DISABLED_BY_DEFAULT("webrtc"), "PacedSender::EnqueuePackets"); - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); for (auto& packet : packets) { TRACE_EVENT2(TRACE_DISABLED_BY_DEFAULT("webrtc"), "PacedSender::EnqueuePackets::Loop", "sequence_number", packet->SequenceNumber(), "rtp_timestamp", packet->Timestamp()); + RTC_DCHECK_GE(packet->capture_time_ms(), 0); pacing_controller_.EnqueuePacket(std::move(packet)); } } @@ -131,42 +132,42 @@ void PacedSender::EnqueuePackets( } void PacedSender::SetAccountForAudioPackets(bool account_for_audio) { - rtc::CritScope cs(&critsect_); + 
MutexLock lock(&mutex_); pacing_controller_.SetAccountForAudioPackets(account_for_audio); } void PacedSender::SetIncludeOverhead() { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.SetIncludeOverhead(); } void PacedSender::SetTransportOverhead(DataSize overhead_per_packet) { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.SetTransportOverhead(overhead_per_packet); } TimeDelta PacedSender::ExpectedQueueTime() const { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); return pacing_controller_.ExpectedQueueTime(); } DataSize PacedSender::QueueSizeData() const { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); return pacing_controller_.QueueSizeData(); } absl::optional PacedSender::FirstSentPacketTime() const { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); return pacing_controller_.FirstSentPacketTime(); } TimeDelta PacedSender::OldestPacketWaitTime() const { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); return pacing_controller_.OldestPacketWaitTime(); } int64_t PacedSender::TimeUntilNextProcess() { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); Timestamp next_send_time = pacing_controller_.NextSendTime(); TimeDelta sleep_time = @@ -178,7 +179,7 @@ int64_t PacedSender::TimeUntilNextProcess() { } void PacedSender::Process() { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.ProcessPackets(); } @@ -198,7 +199,7 @@ void PacedSender::MaybeWakupProcessThread() { void PacedSender::SetQueueTimeLimit(TimeDelta limit) { { - rtc::CritScope cs(&critsect_); + MutexLock lock(&mutex_); pacing_controller_.SetQueueTimeLimit(limit); } MaybeWakupProcessThread(); diff --git a/modules/pacing/paced_sender.h b/modules/pacing/paced_sender.h index d255efdc3b..c819f3fb79 100644 --- a/modules/pacing/paced_sender.h +++ b/modules/pacing/paced_sender.h @@ -32,7 +32,7 @@ #include "modules/rtp_rtcp/include/rtp_packet_sender.h" #include "modules/rtp_rtcp/source/rtp_packet_to_send.h" #include "modules/utility/include/process_thread.h" -#include "rtc_base/deprecated/recursive_critical_section.h" +#include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" namespace webrtc { @@ -157,9 +157,9 @@ class PacedSender : public Module, PacedSender* const delegate_; } module_proxy_{this}; - rtc::RecursiveCriticalSection critsect_; + mutable Mutex mutex_; const PacingController::ProcessMode process_mode_; - PacingController pacing_controller_ RTC_GUARDED_BY(critsect_); + PacingController pacing_controller_ RTC_GUARDED_BY(mutex_); Clock* const clock_; ProcessThread* const process_thread_; diff --git a/modules/pacing/pacing_controller.cc b/modules/pacing/pacing_controller.cc index 3ac7afa7ea..e0ace4e65e 100644 --- a/modules/pacing/pacing_controller.cc +++ b/modules/pacing/pacing_controller.cc @@ -295,11 +295,7 @@ void PacingController::EnqueuePacketInternal( int priority) { prober_.OnIncomingPacket(DataSize::Bytes(packet->payload_size())); - // TODO(sprang): Make sure tests respect this, replace with DCHECK. 
Timestamp now = CurrentTime(); - if (packet->capture_time_ms() < 0) { - packet->set_capture_time_ms(now.ms()); - } if (mode_ == ProcessMode::kDynamic && packet_queue_.Empty() && NextSendTime() <= now) { diff --git a/modules/pacing/packet_router.cc b/modules/pacing/packet_router.cc index 5317f510c9..3b1278e504 100644 --- a/modules/pacing/packet_router.cc +++ b/modules/pacing/packet_router.cc @@ -27,20 +27,11 @@ #include "rtc_base/trace_event.h" namespace webrtc { -namespace { - -constexpr int kRembSendIntervalMs = 200; - -} // namespace PacketRouter::PacketRouter() : PacketRouter(0) {} PacketRouter::PacketRouter(uint16_t start_transport_seq) : last_send_module_(nullptr), - last_remb_time_ms_(rtc::TimeMillis()), - last_send_bitrate_bps_(0), - bitrate_bps_(0), - max_bitrate_bps_(std::numeric_limits::max()), active_remb_module_(nullptr), transport_seq_(start_transport_seq) {} @@ -235,77 +226,19 @@ uint16_t PacketRouter::CurrentTransportSequenceNumber() const { return transport_seq_ & 0xFFFF; } -void PacketRouter::OnReceiveBitrateChanged(const std::vector& ssrcs, - uint32_t bitrate_bps) { - // % threshold for if we should send a new REMB asap. - const int64_t kSendThresholdPercent = 97; - // TODO(danilchap): Remove receive_bitrate_bps variable and the cast - // when OnReceiveBitrateChanged takes bitrate as int64_t. - int64_t receive_bitrate_bps = static_cast(bitrate_bps); - - int64_t now_ms = rtc::TimeMillis(); - { - MutexLock lock(&remb_mutex_); - - // If we already have an estimate, check if the new total estimate is below - // kSendThresholdPercent of the previous estimate. - if (last_send_bitrate_bps_ > 0) { - int64_t new_remb_bitrate_bps = - last_send_bitrate_bps_ - bitrate_bps_ + receive_bitrate_bps; - - if (new_remb_bitrate_bps < - kSendThresholdPercent * last_send_bitrate_bps_ / 100) { - // The new bitrate estimate is less than kSendThresholdPercent % of the - // last report. Send a REMB asap. - last_remb_time_ms_ = now_ms - kRembSendIntervalMs; - } - } - bitrate_bps_ = receive_bitrate_bps; - - if (now_ms - last_remb_time_ms_ < kRembSendIntervalMs) { - return; - } - // NOTE: Updated if we intend to send the data; we might not have - // a module to actually send it. - last_remb_time_ms_ = now_ms; - last_send_bitrate_bps_ = receive_bitrate_bps; - // Cap the value to send in remb with configured value. - receive_bitrate_bps = std::min(receive_bitrate_bps, max_bitrate_bps_); - } - SendRemb(receive_bitrate_bps, ssrcs); -} - -void PacketRouter::SetMaxDesiredReceiveBitrate(int64_t bitrate_bps) { - RTC_DCHECK_GE(bitrate_bps, 0); - { - MutexLock lock(&remb_mutex_); - max_bitrate_bps_ = bitrate_bps; - if (rtc::TimeMillis() - last_remb_time_ms_ < kRembSendIntervalMs && - last_send_bitrate_bps_ > 0 && - last_send_bitrate_bps_ <= max_bitrate_bps_) { - // Recent measured bitrate is already below the cap. - return; - } - } - SendRemb(bitrate_bps, /*ssrcs=*/{}); -} - -bool PacketRouter::SendRemb(int64_t bitrate_bps, - const std::vector& ssrcs) { +void PacketRouter::SendRemb(int64_t bitrate_bps, std::vector ssrcs) { MutexLock lock(&modules_mutex_); if (!active_remb_module_) { - return false; + return; } // The Add* and Remove* methods above ensure that REMB is disabled on all // other modules, because otherwise, they will send REMB with stale info. 
- active_remb_module_->SetRemb(bitrate_bps, ssrcs); - - return true; + active_remb_module_->SetRemb(bitrate_bps, std::move(ssrcs)); } -bool PacketRouter::SendCombinedRtcpPacket( +void PacketRouter::SendCombinedRtcpPacket( std::vector> packets) { MutexLock lock(&modules_mutex_); @@ -315,15 +248,14 @@ bool PacketRouter::SendCombinedRtcpPacket( continue; } rtp_module->SendCombinedRtcpPacket(std::move(packets)); - return true; + return; } if (rtcp_feedback_senders_.empty()) { - return false; + return; } auto* rtcp_sender = rtcp_feedback_senders_[0]; rtcp_sender->SendCombinedRtcpPacket(std::move(packets)); - return true; } void PacketRouter::AddRembModuleCandidate( diff --git a/modules/pacing/packet_router.h b/modules/pacing/packet_router.h index 2fa104b4cd..7a6e24d7ea 100644 --- a/modules/pacing/packet_router.h +++ b/modules/pacing/packet_router.h @@ -39,9 +39,7 @@ class RtpRtcpInterface; // module if possible (sender report), otherwise on receive module // (receiver report). For the latter case, we also keep track of the // receive modules. -class PacketRouter : public RemoteBitrateObserver, - public TransportFeedbackSenderInterface, - public PacingController::PacketSender { +class PacketRouter : public PacingController::PacketSender { public: PacketRouter(); explicit PacketRouter(uint16_t start_transport_seq); @@ -62,24 +60,12 @@ class PacketRouter : public RemoteBitrateObserver, uint16_t CurrentTransportSequenceNumber() const; - // Called every time there is a new bitrate estimate for a receive channel - // group. This call will trigger a new RTCP REMB packet if the bitrate - // estimate has decreased or if no RTCP REMB packet has been sent for - // a certain time interval. - // Implements RtpReceiveBitrateUpdate. - void OnReceiveBitrateChanged(const std::vector& ssrcs, - uint32_t bitrate_bps) override; - - // Ensures remote party notified of the receive bitrate limit no larger than - // |bitrate_bps|. - void SetMaxDesiredReceiveBitrate(int64_t bitrate_bps); - // Send REMB feedback. - bool SendRemb(int64_t bitrate_bps, const std::vector& ssrcs); + void SendRemb(int64_t bitrate_bps, std::vector ssrcs); // Sends |packets| in one or more IP packets. - bool SendCombinedRtcpPacket( - std::vector> packets) override; + void SendCombinedRtcpPacket( + std::vector> packets); private: void AddRembModuleCandidate(RtcpFeedbackSenderInterface* candidate_module, @@ -107,16 +93,6 @@ class PacketRouter : public RemoteBitrateObserver, std::vector rtcp_feedback_senders_ RTC_GUARDED_BY(modules_mutex_); - // TODO(eladalon): remb_mutex_ only ever held from one function, and it's not - // clear if that function can actually be called from more than one thread. - Mutex remb_mutex_; - // The last time a REMB was sent. - int64_t last_remb_time_ms_ RTC_GUARDED_BY(remb_mutex_); - int64_t last_send_bitrate_bps_ RTC_GUARDED_BY(remb_mutex_); - // The last bitrate update. - int64_t bitrate_bps_ RTC_GUARDED_BY(remb_mutex_); - int64_t max_bitrate_bps_ RTC_GUARDED_BY(remb_mutex_); - // Candidates for the REMB module can be RTP sender/receiver modules, with // the sender modules taking precedence. 
std::vector sender_remb_candidates_ diff --git a/modules/pacing/packet_router_unittest.cc b/modules/pacing/packet_router_unittest.cc index 10cf98b3dd..77fe5f9f8d 100644 --- a/modules/pacing/packet_router_unittest.cc +++ b/modules/pacing/packet_router_unittest.cc @@ -74,25 +74,19 @@ TEST_F(PacketRouterTest, Sanity_NoModuleRegistered_GeneratePadding) { EXPECT_TRUE(packet_router_.GeneratePadding(bytes).empty()); } -TEST_F(PacketRouterTest, Sanity_NoModuleRegistered_OnReceiveBitrateChanged) { - const std::vector ssrcs = {1, 2, 3}; - constexpr uint32_t bitrate_bps = 10000; - - packet_router_.OnReceiveBitrateChanged(ssrcs, bitrate_bps); -} TEST_F(PacketRouterTest, Sanity_NoModuleRegistered_SendRemb) { const std::vector ssrcs = {1, 2, 3}; constexpr uint32_t bitrate_bps = 10000; - - EXPECT_FALSE(packet_router_.SendRemb(bitrate_bps, ssrcs)); + // Expect not to crash + packet_router_.SendRemb(bitrate_bps, ssrcs); } TEST_F(PacketRouterTest, Sanity_NoModuleRegistered_SendTransportFeedback) { std::vector> feedback; feedback.push_back(std::make_unique()); - - EXPECT_FALSE(packet_router_.SendCombinedRtcpPacket(std::move(feedback))); + // Expect not to crash + packet_router_.SendCombinedRtcpPacket(std::move(feedback)); } TEST_F(PacketRouterTest, GeneratePaddingPrioritizesRtx) { @@ -327,10 +321,10 @@ TEST_F(PacketRouterTest, SendTransportFeedback) { std::vector> feedback; feedback.push_back(std::make_unique()); - EXPECT_CALL(rtp_1, SendCombinedRtcpPacket).Times(1); + EXPECT_CALL(rtp_1, SendCombinedRtcpPacket); packet_router_.SendCombinedRtcpPacket(std::move(feedback)); packet_router_.RemoveSendRtpModule(&rtp_1); - EXPECT_CALL(rtp_2, SendCombinedRtcpPacket).Times(1); + EXPECT_CALL(rtp_2, SendCombinedRtcpPacket); std::vector> new_feedback; new_feedback.push_back(std::make_unique()); packet_router_.SendCombinedRtcpPacket(std::move(new_feedback)); @@ -442,86 +436,7 @@ TEST_F(PacketRouterDeathTest, RemovalOfNeverAddedReceiveModuleDisallowed) { } #endif // RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) -TEST(PacketRouterRembTest, LowerEstimateToSendRemb) { - rtc::ScopedFakeClock clock; - NiceMock rtp; - PacketRouter packet_router; - - packet_router.AddSendRtpModule(&rtp, true); - - uint32_t bitrate_estimate = 456; - const std::vector ssrcs = {1234}; - - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Call OnReceiveBitrateChanged twice to get a first estimate. - clock.AdvanceTime(TimeDelta::Millis(1000)); - EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Lower the estimate with more than 3% to trigger a call to SetRemb right - // away. - bitrate_estimate = bitrate_estimate - 100; - EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - packet_router.RemoveSendRtpModule(&rtp); -} - -TEST(PacketRouterRembTest, VerifyIncreasingAndDecreasing) { - rtc::ScopedFakeClock clock; - NiceMock rtp; - PacketRouter packet_router; - packet_router.AddSendRtpModule(&rtp, true); - - uint32_t bitrate_estimate[] = {456, 789}; - std::vector ssrcs = {1234, 5678}; - - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate[0]); - - // Call OnReceiveBitrateChanged twice to get a first estimate. 
- EXPECT_CALL(rtp, SetRemb(bitrate_estimate[0], ssrcs)).Times(1); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate[0]); - - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate[1] + 100); - - // Lower the estimate to trigger a callback. - EXPECT_CALL(rtp, SetRemb(bitrate_estimate[1], ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate[1]); - - packet_router.RemoveSendRtpModule(&rtp); -} - -TEST(PacketRouterRembTest, NoRembForIncreasedBitrate) { - rtc::ScopedFakeClock clock; - NiceMock rtp; - PacketRouter packet_router; - packet_router.AddSendRtpModule(&rtp, true); - - uint32_t bitrate_estimate = 456; - std::vector ssrcs = {1234, 5678}; - - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Call OnReceiveBitrateChanged twice to get a first estimate. - EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)).Times(1); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Increased estimate shouldn't trigger a callback right away. - EXPECT_CALL(rtp, SetRemb(_, _)).Times(0); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate + 1); - - // Decreasing the estimate less than 3% shouldn't trigger a new callback. - EXPECT_CALL(rtp, SetRemb(_, _)).Times(0); - int lower_estimate = bitrate_estimate * 98 / 100; - packet_router.OnReceiveBitrateChanged(ssrcs, lower_estimate); - - packet_router.RemoveSendRtpModule(&rtp); -} - -TEST(PacketRouterRembTest, ChangeSendRtpModule) { +TEST(PacketRouterRembTest, ChangeSendRtpModuleChangeRembSender) { rtc::ScopedFakeClock clock; NiceMock rtp_send; NiceMock rtp_recv; @@ -532,191 +447,18 @@ TEST(PacketRouterRembTest, ChangeSendRtpModule) { uint32_t bitrate_estimate = 456; std::vector ssrcs = {1234, 5678}; - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Call OnReceiveBitrateChanged twice to get a first estimate. - clock.AdvanceTime(TimeDelta::Millis(1000)); - EXPECT_CALL(rtp_send, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Decrease estimate to trigger a REMB. - bitrate_estimate = bitrate_estimate - 100; - EXPECT_CALL(rtp_send, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + EXPECT_CALL(rtp_send, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Remove the sending module -> should get remb on the second module. packet_router.RemoveSendRtpModule(&rtp_send); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - bitrate_estimate = bitrate_estimate - 100; - EXPECT_CALL(rtp_recv, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + EXPECT_CALL(rtp_recv, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); packet_router.RemoveReceiveRtpModule(&rtp_recv); } -TEST(PacketRouterRembTest, OnlyOneRembForRepeatedOnReceiveBitrateChanged) { - rtc::ScopedFakeClock clock; - NiceMock rtp; - PacketRouter packet_router; - packet_router.AddSendRtpModule(&rtp, true); - - uint32_t bitrate_estimate = 456; - const std::vector ssrcs = {1234}; - - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Call OnReceiveBitrateChanged twice to get a first estimate. 
- clock.AdvanceTime(TimeDelta::Millis(1000)); - EXPECT_CALL(rtp, SetRemb(_, _)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Lower the estimate, should trigger a call to SetRemb right away. - bitrate_estimate = bitrate_estimate - 100; - EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Call OnReceiveBitrateChanged again, this should not trigger a new callback. - EXPECT_CALL(rtp, SetRemb(_, _)).Times(0); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - packet_router.RemoveSendRtpModule(&rtp); -} - -TEST(PacketRouterRembTest, SetMaxDesiredReceiveBitrateLimitsSetRemb) { - rtc::ScopedFakeClock clock; - PacketRouter packet_router; - clock.AdvanceTime(TimeDelta::Millis(1000)); - NiceMock remb_sender; - constexpr bool remb_candidate = true; - packet_router.AddSendRtpModule(&remb_sender, remb_candidate); - - const int64_t cap_bitrate = 100000; - EXPECT_CALL(remb_sender, SetRemb(Le(cap_bitrate), _)).Times(AtLeast(1)); - EXPECT_CALL(remb_sender, SetRemb(Gt(cap_bitrate), _)).Times(0); - - const std::vector ssrcs = {1234}; - packet_router.SetMaxDesiredReceiveBitrate(cap_bitrate); - packet_router.OnReceiveBitrateChanged(ssrcs, cap_bitrate + 5000); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, cap_bitrate - 5000); - - // Test tear-down. - packet_router.RemoveSendRtpModule(&remb_sender); -} - -TEST(PacketRouterRembTest, - SetMaxDesiredReceiveBitrateTriggersRembWhenMoreRestrictive) { - rtc::ScopedFakeClock clock; - PacketRouter packet_router; - clock.AdvanceTime(TimeDelta::Millis(1000)); - NiceMock remb_sender; - constexpr bool remb_candidate = true; - packet_router.AddSendRtpModule(&remb_sender, remb_candidate); - - const int64_t measured_bitrate_bps = 150000; - const int64_t cap_bitrate_bps = measured_bitrate_bps - 5000; - const std::vector ssrcs = {1234}; - EXPECT_CALL(remb_sender, SetRemb(measured_bitrate_bps, _)); - packet_router.OnReceiveBitrateChanged(ssrcs, measured_bitrate_bps); - - EXPECT_CALL(remb_sender, SetRemb(cap_bitrate_bps, _)); - packet_router.SetMaxDesiredReceiveBitrate(cap_bitrate_bps); - - // Test tear-down. - packet_router.RemoveSendRtpModule(&remb_sender); -} - -TEST(PacketRouterRembTest, - SetMaxDesiredReceiveBitrateDoesNotTriggerRembWhenAsRestrictive) { - rtc::ScopedFakeClock clock; - PacketRouter packet_router; - clock.AdvanceTime(TimeDelta::Millis(1000)); - NiceMock remb_sender; - constexpr bool remb_candidate = true; - packet_router.AddSendRtpModule(&remb_sender, remb_candidate); - - const uint32_t measured_bitrate_bps = 150000; - const uint32_t cap_bitrate_bps = measured_bitrate_bps; - const std::vector ssrcs = {1234}; - EXPECT_CALL(remb_sender, SetRemb(measured_bitrate_bps, _)); - packet_router.OnReceiveBitrateChanged(ssrcs, measured_bitrate_bps); - - EXPECT_CALL(remb_sender, SetRemb(_, _)).Times(0); - packet_router.SetMaxDesiredReceiveBitrate(cap_bitrate_bps); - - // Test tear-down. 
- packet_router.RemoveSendRtpModule(&remb_sender); -} - -TEST(PacketRouterRembTest, - SetMaxDesiredReceiveBitrateDoesNotTriggerRembWhenLessRestrictive) { - rtc::ScopedFakeClock clock; - PacketRouter packet_router; - clock.AdvanceTime(TimeDelta::Millis(1000)); - NiceMock remb_sender; - constexpr bool remb_candidate = true; - packet_router.AddSendRtpModule(&remb_sender, remb_candidate); - - const uint32_t measured_bitrate_bps = 150000; - const uint32_t cap_bitrate_bps = measured_bitrate_bps + 500; - const std::vector ssrcs = {1234}; - EXPECT_CALL(remb_sender, SetRemb(measured_bitrate_bps, _)); - packet_router.OnReceiveBitrateChanged(ssrcs, measured_bitrate_bps); - - EXPECT_CALL(remb_sender, SetRemb(_, _)).Times(0); - packet_router.SetMaxDesiredReceiveBitrate(cap_bitrate_bps); - - // Test tear-down. - packet_router.RemoveSendRtpModule(&remb_sender); -} - -TEST(PacketRouterRembTest, - SetMaxDesiredReceiveBitrateTriggersRembWhenNoRecentMeasure) { - rtc::ScopedFakeClock clock; - PacketRouter packet_router; - clock.AdvanceTime(TimeDelta::Millis(1000)); - NiceMock remb_sender; - constexpr bool remb_candidate = true; - packet_router.AddSendRtpModule(&remb_sender, remb_candidate); - - const uint32_t measured_bitrate_bps = 150000; - const uint32_t cap_bitrate_bps = measured_bitrate_bps + 5000; - const std::vector ssrcs = {1234}; - EXPECT_CALL(remb_sender, SetRemb(measured_bitrate_bps, _)); - packet_router.OnReceiveBitrateChanged(ssrcs, measured_bitrate_bps); - clock.AdvanceTime(TimeDelta::Millis(1000)); - - EXPECT_CALL(remb_sender, SetRemb(cap_bitrate_bps, _)); - packet_router.SetMaxDesiredReceiveBitrate(cap_bitrate_bps); - - // Test tear-down. - packet_router.RemoveSendRtpModule(&remb_sender); -} - -TEST(PacketRouterRembTest, - SetMaxDesiredReceiveBitrateTriggersRembWhenNoMeasures) { - rtc::ScopedFakeClock clock; - PacketRouter packet_router; - clock.AdvanceTime(TimeDelta::Millis(1000)); - NiceMock remb_sender; - constexpr bool remb_candidate = true; - packet_router.AddSendRtpModule(&remb_sender, remb_candidate); - - // Set cap. - EXPECT_CALL(remb_sender, SetRemb(100000, _)).Times(1); - packet_router.SetMaxDesiredReceiveBitrate(100000); - // Increase cap. - EXPECT_CALL(remb_sender, SetRemb(200000, _)).Times(1); - packet_router.SetMaxDesiredReceiveBitrate(200000); - // Decrease cap. - EXPECT_CALL(remb_sender, SetRemb(150000, _)).Times(1); - packet_router.SetMaxDesiredReceiveBitrate(150000); - - // Test tear-down. - packet_router.RemoveSendRtpModule(&remb_sender); -} - // Only register receiving modules and make sure we fallback to trigger a REMB // packet on this one. TEST(PacketRouterRembTest, NoSendingRtpModule) { @@ -729,18 +471,14 @@ TEST(PacketRouterRembTest, NoSendingRtpModule) { uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); - - // Call OnReceiveBitrateChanged twice to get a first estimate. - clock.AdvanceTime(TimeDelta::Millis(1000)); - EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Lower the estimate to trigger a new packet REMB packet. 
- EXPECT_CALL(rtp, SetRemb(bitrate_estimate - 100, ssrcs)).Times(1); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate - 100); + EXPECT_CALL(rtp, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); - EXPECT_CALL(rtp, UnsetRemb()).Times(1); + EXPECT_CALL(rtp, UnsetRemb()); packet_router.RemoveReceiveRtpModule(&rtp); } @@ -756,8 +494,7 @@ TEST(PacketRouterRembTest, NonCandidateSendRtpModuleNotUsedForRemb) { constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; EXPECT_CALL(module, SetRemb(_, _)).Times(0); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveSendRtpModule(&module); @@ -774,9 +511,8 @@ TEST(PacketRouterRembTest, CandidateSendRtpModuleUsedForRemb) { constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; - EXPECT_CALL(module, SetRemb(bitrate_estimate, ssrcs)).Times(1); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + EXPECT_CALL(module, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveSendRtpModule(&module); @@ -794,8 +530,7 @@ TEST(PacketRouterRembTest, NonCandidateReceiveRtpModuleNotUsedForRemb) { constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; EXPECT_CALL(module, SetRemb(_, _)).Times(0); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveReceiveRtpModule(&module); @@ -812,9 +547,8 @@ TEST(PacketRouterRembTest, CandidateReceiveRtpModuleUsedForRemb) { constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; - EXPECT_CALL(module, SetRemb(bitrate_estimate, ssrcs)).Times(1); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + EXPECT_CALL(module, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveReceiveRtpModule(&module); @@ -837,11 +571,10 @@ TEST(PacketRouterRembTest, constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; - EXPECT_CALL(send_module, SetRemb(bitrate_estimate, ssrcs)).Times(1); + EXPECT_CALL(send_module, SetRemb(bitrate_estimate, ssrcs)); EXPECT_CALL(receive_module, SetRemb(_, _)).Times(0); - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveReceiveRtpModule(&receive_module); @@ -865,11 +598,11 @@ TEST(PacketRouterRembTest, constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; - EXPECT_CALL(send_module, SetRemb(bitrate_estimate, ssrcs)).Times(1); + EXPECT_CALL(send_module, SetRemb(bitrate_estimate, ssrcs)); EXPECT_CALL(receive_module, SetRemb(_, _)).Times(0); clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveReceiveRtpModule(&receive_module); @@ -893,10 +626,8 @@ TEST(PacketRouterRembTest, ReceiveModuleTakesOverWhenLastSendModuleRemoved) { constexpr uint32_t bitrate_estimate = 456; const std::vector ssrcs = {1234}; 
EXPECT_CALL(send_module, SetRemb(_, _)).Times(0); - EXPECT_CALL(receive_module, SetRemb(bitrate_estimate, ssrcs)).Times(1); - - clock.AdvanceTime(TimeDelta::Millis(1000)); - packet_router.OnReceiveBitrateChanged(ssrcs, bitrate_estimate); + EXPECT_CALL(receive_module, SetRemb(bitrate_estimate, ssrcs)); + packet_router.SendRemb(bitrate_estimate, ssrcs); // Test tear-down packet_router.RemoveReceiveRtpModule(&receive_module); diff --git a/modules/pacing/round_robin_packet_queue.h b/modules/pacing/round_robin_packet_queue.h index 9446a8e174..cad555a1af 100644 --- a/modules/pacing/round_robin_packet_queue.h +++ b/modules/pacing/round_robin_packet_queue.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "absl/types/optional.h" #include "api/transport/webrtc_key_value_config.h" @@ -163,7 +164,7 @@ class RoundRobinPacketQueue { std::multimap stream_priorities_; // A map of SSRCs to Streams. - std::map streams_; + std::unordered_map streams_; // The enqueue time of every packet currently in the queue. Used to figure out // the age of the oldest packet in the queue. diff --git a/modules/pacing/task_queue_paced_sender.cc b/modules/pacing/task_queue_paced_sender.cc index 69ec5457ad..709718ff16 100644 --- a/modules/pacing/task_queue_paced_sender.cc +++ b/modules/pacing/task_queue_paced_sender.cc @@ -32,7 +32,7 @@ constexpr TimeDelta kMinTimeBetweenStatsUpdates = TimeDelta::Millis(1); TaskQueuePacedSender::TaskQueuePacedSender( Clock* clock, - PacketRouter* packet_router, + PacingController::PacketSender* packet_sender, RtcEventLog* event_log, const WebRtcKeyValueConfig* field_trials, TaskQueueFactory* task_queue_factory, @@ -40,7 +40,7 @@ TaskQueuePacedSender::TaskQueuePacedSender( : clock_(clock), hold_back_window_(hold_back_window), pacing_controller_(clock, - packet_router, + packet_sender, event_log, field_trials, PacingController::ProcessMode::kDynamic), @@ -62,6 +62,14 @@ TaskQueuePacedSender::~TaskQueuePacedSender() { }); } +void TaskQueuePacedSender::EnsureStarted() { + task_queue_.PostTask([this]() { + RTC_DCHECK_RUN_ON(&task_queue_); + is_started_ = true; + MaybeProcessPackets(Timestamp::MinusInfinity()); + }); +} + void TaskQueuePacedSender::CreateProbeCluster(DataRate bitrate, int cluster_id) { task_queue_.PostTask([this, bitrate, cluster_id]() { @@ -136,6 +144,7 @@ void TaskQueuePacedSender::EnqueuePackets( task_queue_.PostTask([this, packets_ = std::move(packets)]() mutable { RTC_DCHECK_RUN_ON(&task_queue_); for (auto& packet : packets_) { + RTC_DCHECK_GE(packet->capture_time_ms(), 0); pacing_controller_.EnqueuePacket(std::move(packet)); } MaybeProcessPackets(Timestamp::MinusInfinity()); @@ -196,7 +205,7 @@ void TaskQueuePacedSender::MaybeProcessPackets( Timestamp scheduled_process_time) { RTC_DCHECK_RUN_ON(&task_queue_); - if (is_shutdown_) { + if (is_shutdown_ || !is_started_) { return; } diff --git a/modules/pacing/task_queue_paced_sender.h b/modules/pacing/task_queue_paced_sender.h index ba4f4667b7..0673441e52 100644 --- a/modules/pacing/task_queue_paced_sender.h +++ b/modules/pacing/task_queue_paced_sender.h @@ -20,17 +20,16 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/units/data_size.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" #include "modules/include/module.h" #include "modules/pacing/pacing_controller.h" -#include "modules/pacing/packet_router.h" #include "modules/pacing/rtp_packet_pacer.h" #include "modules/rtp_rtcp/source/rtp_packet_to_send.h" 
#include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_queue.h" #include "rtc_base/thread_annotations.h" @@ -47,7 +46,7 @@ class TaskQueuePacedSender : public RtpPacketPacer, public RtpPacketSender { // TODO(bugs.webrtc.org/10809): Remove default value for hold_back_window. TaskQueuePacedSender( Clock* clock, - PacketRouter* packet_router, + PacingController::PacketSender* packet_sender, RtcEventLog* event_log, const WebRtcKeyValueConfig* field_trials, TaskQueueFactory* task_queue_factory, @@ -55,10 +54,13 @@ class TaskQueuePacedSender : public RtpPacketPacer, public RtpPacketSender { ~TaskQueuePacedSender() override; + // Ensure that necessary delayed tasks are scheduled. + void EnsureStarted(); + // Methods implementing RtpPacketSender. - // Adds the packet to the queue and calls PacketRouter::SendPacket() when - // it's time to send. + // Adds the packet to the queue and calls + // PacingController::PacketSender::SendPacket() when it's time to send. void EnqueuePackets( std::vector> packets) override; @@ -150,6 +152,10 @@ class TaskQueuePacedSender : public RtpPacketPacer, public RtpPacketSender { // Last time stats were updated. Timestamp last_stats_time_ RTC_GUARDED_BY(task_queue_); + // Indicates if this task queue is started. If not, don't allow + // posting delayed tasks yet. + bool is_started_ RTC_GUARDED_BY(task_queue_) = false; + // Indicates if this task queue is shutting down. If so, don't allow // posting any more delayed tasks as that can cause the task queue to // never drain. diff --git a/modules/pacing/task_queue_paced_sender_unittest.cc b/modules/pacing/task_queue_paced_sender_unittest.cc index d389e271f7..3806ec28d2 100644 --- a/modules/pacing/task_queue_paced_sender_unittest.cc +++ b/modules/pacing/task_queue_paced_sender_unittest.cc @@ -157,6 +157,7 @@ namespace test { pacer.SetPacingRates( DataRate::BitsPerSec(kDefaultPacketSize * 8 * kPacketsToSend), DataRate::Zero()); + pacer.EnsureStarted(); pacer.EnqueuePackets( GeneratePackets(RtpPacketMediaType::kVideo, kPacketsToSend)); @@ -196,6 +197,7 @@ namespace test { const DataRate kPacingRate = DataRate::BitsPerSec(kDefaultPacketSize * 8 * kPacketsPerSecond); pacer.SetPacingRates(kPacingRate, DataRate::Zero()); + pacer.EnsureStarted(); // Send some initial packets to be rid of any probes. EXPECT_CALL(packet_router, SendPacket).Times(kPacketsPerSecond); @@ -247,6 +249,7 @@ namespace test { const TimeDelta kPacketPacingTime = kPacketSize / kPacingDataRate; pacer.SetPacingRates(kPacingDataRate, DataRate::Zero()); + pacer.EnsureStarted(); // Add some initial video packets, only one should be sent. EXPECT_CALL(packet_router, SendPacket); @@ -280,6 +283,7 @@ namespace test { const DataRate kPacingDataRate = kPacketSize / kPacketPacingTime; pacer.SetPacingRates(kPacingDataRate, DataRate::Zero()); + pacer.EnsureStarted(); // Add 10 packets. The first should be sent immediately since the buffers // are clear. @@ -316,6 +320,7 @@ namespace test { const DataRate kPacingDataRate = kPacketSize / kPacketPacingTime; pacer.SetPacingRates(kPacingDataRate, DataRate::Zero()); + pacer.EnsureStarted(); // Add 10 packets. The first should be sent immediately since the buffers // are clear. This will also trigger the probe to start. 
@@ -342,6 +347,7 @@ namespace test { kCoalescingWindow); const DataRate kPacingDataRate = DataRate::KilobitsPerSec(300); pacer.SetPacingRates(kPacingDataRate, DataRate::Zero()); + pacer.EnsureStarted(); const TimeDelta kMinTimeBetweenStatsUpdates = TimeDelta::Millis(1); @@ -388,6 +394,7 @@ namespace test { size_t num_expected_stats_updates = 0; EXPECT_EQ(pacer.num_stats_updates_, num_expected_stats_updates); pacer.SetPacingRates(kPacingDataRate, DataRate::Zero()); + pacer.EnsureStarted(); time_controller.AdvanceTime(kMinTimeBetweenStatsUpdates); // Updating pacing rates refreshes stats. EXPECT_EQ(pacer.num_stats_updates_, ++num_expected_stats_updates); @@ -443,6 +450,7 @@ namespace test { const TimeDelta kPacketPacingTime = TimeDelta::Millis(4); const DataRate kPacingDataRate = kPacketSize / kPacketPacingTime; pacer.SetPacingRates(kPacingDataRate, /*padding_rate=*/DataRate::Zero()); + pacer.EnsureStarted(); EXPECT_CALL(packet_router, FetchFec).WillRepeatedly([]() { return std::vector>(); }); @@ -514,6 +522,7 @@ namespace test { const TimeDelta kPacketPacingTime = TimeDelta::Millis(4); const DataRate kPacingDataRate = kPacketSize / kPacketPacingTime; pacer.SetPacingRates(kPacingDataRate, /*padding_rate=*/DataRate::Zero()); + pacer.EnsureStarted(); EXPECT_CALL(packet_router, FetchFec).WillRepeatedly([]() { return std::vector>(); }); @@ -552,5 +561,33 @@ namespace test { EXPECT_EQ(data_sent, kProbingRate * TimeDelta::Millis(1) + DataSize::Bytes(1)); } + + TEST(TaskQueuePacedSenderTest, NoStatsUpdatesBeforeStart) { + const TimeDelta kCoalescingWindow = TimeDelta::Millis(5); + GlobalSimulatedTimeController time_controller(Timestamp::Millis(1234)); + MockPacketRouter packet_router; + TaskQueuePacedSenderForTest pacer( + time_controller.GetClock(), &packet_router, + /*event_log=*/nullptr, + /*field_trials=*/nullptr, time_controller.GetTaskQueueFactory(), + kCoalescingWindow); + const DataRate kPacingDataRate = DataRate::KilobitsPerSec(300); + pacer.SetPacingRates(kPacingDataRate, DataRate::Zero()); + + const TimeDelta kMinTimeBetweenStatsUpdates = TimeDelta::Millis(1); + + // Nothing inserted, no stats updates yet. + EXPECT_EQ(pacer.num_stats_updates_, 0u); + + // Insert one packet, stats should not be updated. + pacer.EnqueuePackets(GeneratePackets(RtpPacketMediaType::kVideo, 1)); + time_controller.AdvanceTime(TimeDelta::Zero()); + EXPECT_EQ(pacer.num_stats_updates_, 0u); + + // Advance time of the min stats update interval, and trigger a + // refresh - stats should not be updated still. 
+ time_controller.AdvanceTime(kMinTimeBetweenStatsUpdates); + EXPECT_EQ(pacer.num_stats_updates_, 0u); + } } // namespace test } // namespace webrtc diff --git a/modules/remote_bitrate_estimator/BUILD.gn b/modules/remote_bitrate_estimator/BUILD.gn index 81aa1efdda..923f00a74c 100644 --- a/modules/remote_bitrate_estimator/BUILD.gn +++ b/modules/remote_bitrate_estimator/BUILD.gn @@ -21,6 +21,8 @@ rtc_library("remote_bitrate_estimator") { "overuse_detector.h", "overuse_estimator.cc", "overuse_estimator.h", + "packet_arrival_map.cc", + "packet_arrival_map.h", "remote_bitrate_estimator_abs_send_time.cc", "remote_bitrate_estimator_abs_send_time.h", "remote_bitrate_estimator_single_stream.cc", @@ -45,6 +47,8 @@ rtc_library("remote_bitrate_estimator") { "../../api/transport:network_control", "../../api/transport:webrtc_key_value_config", "../../api/units:data_rate", + "../../api/units:data_size", + "../../api/units:time_delta", "../../api/units:timestamp", "../../modules:module_api", "../../modules:module_api_public", @@ -74,10 +78,9 @@ if (!build_with_chromium) { "tools/bwe_rtp.h", ] deps = [ - ":remote_bitrate_estimator", "../../rtc_base:rtc_base_approved", "../../test:rtp_test_utils", - "../rtp_rtcp", + "../rtp_rtcp:rtp_rtcp_format", ] absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag", @@ -90,10 +93,10 @@ if (!build_with_chromium) { sources = [ "tools/rtp_to_text.cc" ] deps = [ ":bwe_rtp", - "../../modules/rtp_rtcp", "../../rtc_base:macromagic", "../../rtc_base:stringutils", "../../test:rtp_test_utils", + "../rtp_rtcp:rtp_rtcp_format", ] } } @@ -106,6 +109,7 @@ if (rtc_include_tests) { "aimd_rate_control_unittest.cc", "inter_arrival_unittest.cc", "overuse_detector_unittest.cc", + "packet_arrival_map_test.cc", "remote_bitrate_estimator_abs_send_time_unittest.cc", "remote_bitrate_estimator_single_stream_unittest.cc", "remote_bitrate_estimator_unittest_helper.cc", diff --git a/modules/remote_bitrate_estimator/aimd_rate_control.cc b/modules/remote_bitrate_estimator/aimd_rate_control.cc index 2ca298b7fa..bf7119cc7d 100644 --- a/modules/remote_bitrate_estimator/aimd_rate_control.cc +++ b/modules/remote_bitrate_estimator/aimd_rate_control.cc @@ -362,7 +362,7 @@ void AimdRateControl::ChangeBitrate(const RateControlInput& input, break; } default: - assert(false); + RTC_NOTREACHED(); } current_bitrate_ = ClampBitrate(new_bitrate.value_or(current_bitrate_)); @@ -417,7 +417,7 @@ void AimdRateControl::ChangeState(const RateControlInput& input, rate_control_state_ = RateControlState::kRcHold; break; default: - assert(false); + RTC_NOTREACHED(); } } diff --git a/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h b/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h index c60c030e8d..ac937bbfe0 100644 --- a/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h +++ b/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h @@ -38,14 +38,6 @@ class RemoteBitrateObserver { virtual ~RemoteBitrateObserver() {} }; -class TransportFeedbackSenderInterface { - public: - virtual ~TransportFeedbackSenderInterface() = default; - - virtual bool SendCombinedRtcpPacket( - std::vector> packets) = 0; -}; - // TODO(holmer): Remove when all implementations have been updated. 
struct ReceiveBandwidthEstimatorStats {}; diff --git a/modules/remote_bitrate_estimator/inter_arrival.cc b/modules/remote_bitrate_estimator/inter_arrival.cc index b8e683b89a..a8cf47fbfe 100644 --- a/modules/remote_bitrate_estimator/inter_arrival.cc +++ b/modules/remote_bitrate_estimator/inter_arrival.cc @@ -37,9 +37,9 @@ bool InterArrival::ComputeDeltas(uint32_t timestamp, uint32_t* timestamp_delta, int64_t* arrival_time_delta_ms, int* packet_size_delta) { - assert(timestamp_delta != NULL); - assert(arrival_time_delta_ms != NULL); - assert(packet_size_delta != NULL); + RTC_DCHECK(timestamp_delta); + RTC_DCHECK(arrival_time_delta_ms); + RTC_DCHECK(packet_size_delta); bool calculated_deltas = false; if (current_timestamp_group_.IsFirstPacket()) { // We don't have enough data to update the filter, so we store it until we @@ -85,7 +85,7 @@ bool InterArrival::ComputeDeltas(uint32_t timestamp, } else { num_consecutive_reordered_packets_ = 0; } - assert(*arrival_time_delta_ms >= 0); + RTC_DCHECK_GE(*arrival_time_delta_ms, 0); *packet_size_delta = static_cast(current_timestamp_group_.size) - static_cast(prev_timestamp_group_.size); calculated_deltas = true; @@ -141,7 +141,7 @@ bool InterArrival::BelongsToBurst(int64_t arrival_time_ms, if (!burst_grouping_) { return false; } - assert(current_timestamp_group_.complete_time_ms >= 0); + RTC_DCHECK_GE(current_timestamp_group_.complete_time_ms, 0); int64_t arrival_time_delta_ms = arrival_time_ms - current_timestamp_group_.complete_time_ms; uint32_t timestamp_diff = timestamp - current_timestamp_group_.timestamp; diff --git a/modules/remote_bitrate_estimator/overuse_estimator.cc b/modules/remote_bitrate_estimator/overuse_estimator.cc index 74449bec66..3427d5880c 100644 --- a/modules/remote_bitrate_estimator/overuse_estimator.cc +++ b/modules/remote_bitrate_estimator/overuse_estimator.cc @@ -110,7 +110,7 @@ void OveruseEstimator::Update(int64_t t_delta, bool positive_semi_definite = E_[0][0] + E_[1][1] >= 0 && E_[0][0] * E_[1][1] - E_[0][1] * E_[1][0] >= 0 && E_[0][0] >= 0; - assert(positive_semi_definite); + RTC_DCHECK(positive_semi_definite); if (!positive_semi_definite) { RTC_LOG(LS_ERROR) << "The over-use estimator's covariance matrix is no longer " diff --git a/modules/remote_bitrate_estimator/packet_arrival_map.cc b/modules/remote_bitrate_estimator/packet_arrival_map.cc new file mode 100644 index 0000000000..72696f6c80 --- /dev/null +++ b/modules/remote_bitrate_estimator/packet_arrival_map.cc @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/remote_bitrate_estimator/packet_arrival_map.h" + +#include + +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { + +constexpr size_t PacketArrivalTimeMap::kMaxNumberOfPackets; + +void PacketArrivalTimeMap::AddPacket(int64_t sequence_number, + int64_t arrival_time_ms) { + if (!has_seen_packet_) { + // First packet. 
+ has_seen_packet_ = true; + begin_sequence_number_ = sequence_number; + arrival_times.push_back(arrival_time_ms); + return; + } + + int64_t pos = sequence_number - begin_sequence_number_; + if (pos >= 0 && pos < static_cast(arrival_times.size())) { + // The packet is within the buffer - no need to expand it. + arrival_times[pos] = arrival_time_ms; + return; + } + + if (pos < 0) { + // The packet goes before the current buffer. Expand to add packet, but only + // if it fits within kMaxNumberOfPackets. + size_t missing_packets = -pos; + if (missing_packets + arrival_times.size() > kMaxNumberOfPackets) { + // Don't expand the buffer further, as that would remove newly received + // packets. + return; + } + + arrival_times.insert(arrival_times.begin(), missing_packets, 0); + arrival_times[0] = arrival_time_ms; + begin_sequence_number_ = sequence_number; + return; + } + + // The packet goes after the buffer. + + if (static_cast(pos) >= kMaxNumberOfPackets) { + // The buffer grows too large - old packets have to be removed. + size_t packets_to_remove = pos - kMaxNumberOfPackets + 1; + if (packets_to_remove >= arrival_times.size()) { + arrival_times.clear(); + begin_sequence_number_ = sequence_number; + pos = 0; + } else { + // Also trim the buffer to remove leading non-received packets, to + // ensure that the buffer only spans received packets. + while (packets_to_remove < arrival_times.size() && + arrival_times[packets_to_remove] == 0) { + ++packets_to_remove; + } + + arrival_times.erase(arrival_times.begin(), + arrival_times.begin() + packets_to_remove); + begin_sequence_number_ += packets_to_remove; + pos -= packets_to_remove; + RTC_DCHECK_GE(pos, 0); + } + } + + // Packets can be received out-of-order. If this isn't the next expected + // packet, add enough placeholders to fill the gap. + size_t missing_gap_packets = pos - arrival_times.size(); + if (missing_gap_packets > 0) { + arrival_times.insert(arrival_times.end(), missing_gap_packets, 0); + } + RTC_DCHECK_EQ(arrival_times.size(), pos); + arrival_times.push_back(arrival_time_ms); + RTC_DCHECK_LE(arrival_times.size(), kMaxNumberOfPackets); +} + +void PacketArrivalTimeMap::RemoveOldPackets(int64_t sequence_number, + int64_t arrival_time_limit) { + while (!arrival_times.empty() && begin_sequence_number_ < sequence_number && + arrival_times.front() <= arrival_time_limit) { + arrival_times.pop_front(); + ++begin_sequence_number_; + } +} + +bool PacketArrivalTimeMap::has_received(int64_t sequence_number) const { + int64_t pos = sequence_number - begin_sequence_number_; + if (pos >= 0 && pos < static_cast(arrival_times.size()) && + arrival_times[pos] != 0) { + return true; + } + return false; +} + +void PacketArrivalTimeMap::EraseTo(int64_t sequence_number) { + if (sequence_number > begin_sequence_number_) { + size_t count = + std::min(static_cast(sequence_number - begin_sequence_number_), + arrival_times.size()); + + arrival_times.erase(arrival_times.begin(), arrival_times.begin() + count); + begin_sequence_number_ += count; + } +} + +int64_t PacketArrivalTimeMap::clamp(int64_t sequence_number) const { + return rtc::SafeClamp(sequence_number, begin_sequence_number(), + end_sequence_number()); +} + +} // namespace webrtc diff --git a/modules/remote_bitrate_estimator/packet_arrival_map.h b/modules/remote_bitrate_estimator/packet_arrival_map.h new file mode 100644 index 0000000000..10659e0f65 --- /dev/null +++ b/modules/remote_bitrate_estimator/packet_arrival_map.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. 
All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef MODULES_REMOTE_BITRATE_ESTIMATOR_PACKET_ARRIVAL_MAP_H_ +#define MODULES_REMOTE_BITRATE_ESTIMATOR_PACKET_ARRIVAL_MAP_H_ + +#include +#include +#include + +#include "rtc_base/checks.h" + +namespace webrtc { + +// PacketArrivalTimeMap is an optimized map of packet sequence number to arrival +// time, limited in size to never exceed `kMaxNumberOfPackets`. It will grow as +// needed, and remove old packets, and will expand to allow earlier packets to +// be added (out-of-order). +// +// Not yet received packets have the arrival time zero. The queue will not span +// larger than necessary and the last packet should always be received. The +// first packet in the queue doesn't have to be received in case of receiving +// packets out-of-order. +class PacketArrivalTimeMap { + public: + // Impossible to request feedback older than what can be represented by 15 + // bits. + static constexpr size_t kMaxNumberOfPackets = (1 << 15); + + // Indicates if the packet with `sequence_number` has already been received. + bool has_received(int64_t sequence_number) const; + + // Returns the sequence number of the first entry in the map, i.e. the + // sequence number that a `begin()` iterator would represent. + int64_t begin_sequence_number() const { return begin_sequence_number_; } + + // Returns the sequence number of the element just after the map, i.e. the + // sequence number that an `end()` iterator would represent. + int64_t end_sequence_number() const { + return begin_sequence_number_ + arrival_times.size(); + } + + // Returns an element by `sequence_number`, which must be valid, i.e. + // between [begin_sequence_number, end_sequence_number). + int64_t get(int64_t sequence_number) { + int64_t pos = sequence_number - begin_sequence_number_; + RTC_DCHECK(pos >= 0 && pos < static_cast(arrival_times.size())); + return arrival_times[pos]; + } + + // Clamps `sequence_number` between [begin_sequence_number, + // end_sequence_number]. + int64_t clamp(int64_t sequence_number) const; + + // Erases all elements from the beginning of the map until `sequence_number`. + void EraseTo(int64_t sequence_number); + + // Records the fact that a packet with `sequence_number` arrived at + // `arrival_time_ms`. + void AddPacket(int64_t sequence_number, int64_t arrival_time_ms); + + // Removes packets from the beginning of the map as long as they are received + // before `sequence_number` and with an age older than `arrival_time_limit` + void RemoveOldPackets(int64_t sequence_number, int64_t arrival_time_limit); + + private: + // Deque representing unwrapped sequence number -> time, where the index + + // `begin_sequence_number_` represents the packet's sequence number. + std::deque arrival_times; + + // The unwrapped sequence number for the first element in + // `arrival_times`. + int64_t begin_sequence_number_ = 0; + + // Indicates if this map has had any packet added to it. The first packet + // decides the initial sequence number. 
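For orientation, here is a minimal standalone usage sketch of the map declared above (illustration only, not part of the patch); the function name SketchPacketArrivalTimeMapUsage and the concrete sequence numbers and arrival times are invented, while every call is one of the members declared in this header.

#include "modules/remote_bitrate_estimator/packet_arrival_map.h"
#include "rtc_base/checks.h"

namespace webrtc {

void SketchPacketArrivalTimeMapUsage() {
  PacketArrivalTimeMap map;

  // Record two packets with a gap; the map expands to span [100, 104).
  map.AddPacket(/*sequence_number=*/100, /*arrival_time_ms=*/10);
  map.AddPacket(/*sequence_number=*/103, /*arrival_time_ms=*/13);
  RTC_DCHECK_EQ(map.begin_sequence_number(), 100);
  RTC_DCHECK_EQ(map.end_sequence_number(), 104);

  // Packets in the gap read back as not received (arrival time 0).
  RTC_DCHECK(!map.has_received(101));
  RTC_DCHECK_EQ(map.get(103), 13);

  // Sequence numbers outside the span are clamped to [begin, end].
  RTC_DCHECK_EQ(map.clamp(1), 100);
  RTC_DCHECK_EQ(map.clamp(100000), 104);

  // Once a range has been reported in feedback it can be dropped in front.
  map.EraseTo(103);
  RTC_DCHECK_EQ(map.begin_sequence_number(), 103);
}

}  // namespace webrtc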
+ bool has_seen_packet_ = false; +}; + +} // namespace webrtc + +#endif // MODULES_REMOTE_BITRATE_ESTIMATOR_PACKET_ARRIVAL_MAP_H_ diff --git a/modules/remote_bitrate_estimator/packet_arrival_map_test.cc b/modules/remote_bitrate_estimator/packet_arrival_map_test.cc new file mode 100644 index 0000000000..afc7038832 --- /dev/null +++ b/modules/remote_bitrate_estimator/packet_arrival_map_test.cc @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/remote_bitrate_estimator/packet_arrival_map.h" + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +TEST(PacketArrivalMapTest, IsConsistentWhenEmpty) { + PacketArrivalTimeMap map; + + EXPECT_EQ(map.begin_sequence_number(), map.end_sequence_number()); + EXPECT_FALSE(map.has_received(0)); + EXPECT_EQ(map.clamp(-5), 0); + EXPECT_EQ(map.clamp(5), 0); +} + +TEST(PacketArrivalMapTest, InsertsFirstItemIntoMap) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + EXPECT_EQ(map.begin_sequence_number(), 42); + EXPECT_EQ(map.end_sequence_number(), 43); + + EXPECT_FALSE(map.has_received(41)); + EXPECT_TRUE(map.has_received(42)); + EXPECT_FALSE(map.has_received(44)); + + EXPECT_EQ(map.clamp(-100), 42); + EXPECT_EQ(map.clamp(42), 42); + EXPECT_EQ(map.clamp(100), 43); +} + +TEST(PacketArrivalMapTest, InsertsWithGaps) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + map.AddPacket(45, 11); + EXPECT_EQ(map.begin_sequence_number(), 42); + EXPECT_EQ(map.end_sequence_number(), 46); + + EXPECT_FALSE(map.has_received(41)); + EXPECT_TRUE(map.has_received(42)); + EXPECT_FALSE(map.has_received(43)); + EXPECT_FALSE(map.has_received(44)); + EXPECT_TRUE(map.has_received(45)); + EXPECT_FALSE(map.has_received(46)); + + EXPECT_EQ(map.get(42), 10); + EXPECT_EQ(map.get(43), 0); + EXPECT_EQ(map.get(44), 0); + EXPECT_EQ(map.get(45), 11); + + EXPECT_EQ(map.clamp(-100), 42); + EXPECT_EQ(map.clamp(44), 44); + EXPECT_EQ(map.clamp(100), 46); +} + +TEST(PacketArrivalMapTest, InsertsWithinBuffer) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + map.AddPacket(45, 11); + + map.AddPacket(43, 12); + map.AddPacket(44, 13); + + EXPECT_EQ(map.begin_sequence_number(), 42); + EXPECT_EQ(map.end_sequence_number(), 46); + + EXPECT_FALSE(map.has_received(41)); + EXPECT_TRUE(map.has_received(42)); + EXPECT_TRUE(map.has_received(43)); + EXPECT_TRUE(map.has_received(44)); + EXPECT_TRUE(map.has_received(45)); + EXPECT_FALSE(map.has_received(46)); + + EXPECT_EQ(map.get(42), 10); + EXPECT_EQ(map.get(43), 12); + EXPECT_EQ(map.get(44), 13); + EXPECT_EQ(map.get(45), 11); +} + +TEST(PacketArrivalMapTest, GrowsBufferAndRemoveOld) { + PacketArrivalTimeMap map; + + constexpr int64_t kLargeSeq = 42 + PacketArrivalTimeMap::kMaxNumberOfPackets; + map.AddPacket(42, 10); + map.AddPacket(43, 11); + map.AddPacket(44, 12); + map.AddPacket(45, 13); + map.AddPacket(kLargeSeq, 12); + + EXPECT_EQ(map.begin_sequence_number(), 43); + EXPECT_EQ(map.end_sequence_number(), kLargeSeq + 1); + EXPECT_EQ(static_cast(map.end_sequence_number() - + map.begin_sequence_number()), + PacketArrivalTimeMap::kMaxNumberOfPackets); + + EXPECT_FALSE(map.has_received(41)); + 
EXPECT_FALSE(map.has_received(42)); + EXPECT_TRUE(map.has_received(43)); + EXPECT_TRUE(map.has_received(44)); + EXPECT_TRUE(map.has_received(45)); + EXPECT_FALSE(map.has_received(46)); + EXPECT_TRUE(map.has_received(kLargeSeq)); + EXPECT_FALSE(map.has_received(kLargeSeq + 1)); +} + +TEST(PacketArrivalMapTest, GrowsBufferAndRemoveOldTrimsBeginning) { + PacketArrivalTimeMap map; + + constexpr int64_t kLargeSeq = 42 + PacketArrivalTimeMap::kMaxNumberOfPackets; + map.AddPacket(42, 10); + // Missing: 43, 44 + map.AddPacket(45, 13); + map.AddPacket(kLargeSeq, 12); + + EXPECT_EQ(map.begin_sequence_number(), 45); + EXPECT_EQ(map.end_sequence_number(), kLargeSeq + 1); + + EXPECT_FALSE(map.has_received(44)); + EXPECT_TRUE(map.has_received(45)); + EXPECT_FALSE(map.has_received(46)); + EXPECT_TRUE(map.has_received(kLargeSeq)); + EXPECT_FALSE(map.has_received(kLargeSeq + 1)); +} + +TEST(PacketArrivalMapTest, SequenceNumberJumpsDeletesAll) { + PacketArrivalTimeMap map; + + constexpr int64_t kLargeSeq = + 42 + 2 * PacketArrivalTimeMap::kMaxNumberOfPackets; + map.AddPacket(42, 10); + map.AddPacket(kLargeSeq, 12); + + EXPECT_EQ(map.begin_sequence_number(), kLargeSeq); + EXPECT_EQ(map.end_sequence_number(), kLargeSeq + 1); + + EXPECT_FALSE(map.has_received(42)); + EXPECT_TRUE(map.has_received(kLargeSeq)); + EXPECT_FALSE(map.has_received(kLargeSeq + 1)); +} + +TEST(PacketArrivalMapTest, ExpandsBeforeBeginning) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + map.AddPacket(-1000, 13); + + EXPECT_EQ(map.begin_sequence_number(), -1000); + EXPECT_EQ(map.end_sequence_number(), 43); + + EXPECT_FALSE(map.has_received(-1001)); + EXPECT_TRUE(map.has_received(-1000)); + EXPECT_FALSE(map.has_received(-999)); + EXPECT_TRUE(map.has_received(42)); + EXPECT_FALSE(map.has_received(43)); +} + +TEST(PacketArrivalMapTest, ExpandingBeforeBeginningKeepsReceived) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + constexpr int64_t kSmallSeq = + static_cast(42) - 2 * PacketArrivalTimeMap::kMaxNumberOfPackets; + map.AddPacket(kSmallSeq, 13); + + EXPECT_EQ(map.begin_sequence_number(), 42); + EXPECT_EQ(map.end_sequence_number(), 43); +} + +TEST(PacketArrivalMapTest, ErasesToRemoveElements) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + map.AddPacket(43, 11); + map.AddPacket(44, 12); + map.AddPacket(45, 13); + + map.EraseTo(44); + + EXPECT_EQ(map.begin_sequence_number(), 44); + EXPECT_EQ(map.end_sequence_number(), 46); + + EXPECT_FALSE(map.has_received(43)); + EXPECT_TRUE(map.has_received(44)); + EXPECT_TRUE(map.has_received(45)); + EXPECT_FALSE(map.has_received(46)); +} + +TEST(PacketArrivalMapTest, ErasesInEmptyMap) { + PacketArrivalTimeMap map; + + EXPECT_EQ(map.begin_sequence_number(), map.end_sequence_number()); + + map.EraseTo(map.end_sequence_number()); + EXPECT_EQ(map.begin_sequence_number(), map.end_sequence_number()); +} + +TEST(PacketArrivalMapTest, IsTolerantToWrongArgumentsForErase) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + map.AddPacket(43, 11); + + map.EraseTo(1); + + EXPECT_EQ(map.begin_sequence_number(), 42); + EXPECT_EQ(map.end_sequence_number(), 44); + + map.EraseTo(100); + + EXPECT_EQ(map.begin_sequence_number(), 44); + EXPECT_EQ(map.end_sequence_number(), 44); +} + +TEST(PacketArrivalMapTest, EraseAllRemembersBeginningSeqNbr) { + PacketArrivalTimeMap map; + + map.AddPacket(42, 10); + map.AddPacket(43, 11); + map.AddPacket(44, 12); + map.AddPacket(45, 13); + + map.EraseTo(46); + + map.AddPacket(50, 10); + + EXPECT_EQ(map.begin_sequence_number(), 46); + 
EXPECT_EQ(map.end_sequence_number(), 51); + + EXPECT_FALSE(map.has_received(45)); + EXPECT_FALSE(map.has_received(46)); + EXPECT_FALSE(map.has_received(47)); + EXPECT_FALSE(map.has_received(48)); + EXPECT_FALSE(map.has_received(49)); + EXPECT_TRUE(map.has_received(50)); + EXPECT_FALSE(map.has_received(51)); +} + +} // namespace +} // namespace webrtc diff --git a/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.cc b/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.cc index 4196f6dc57..ae960ab960 100644 --- a/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.cc +++ b/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.cc @@ -13,18 +13,36 @@ #include #include +#include +#include #include "api/transport/field_trial_based_config.h" +#include "api/units/data_rate.h" +#include "api/units/data_size.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "modules/remote_bitrate_estimator/include/bwe_defines.h" #include "modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h" #include "rtc_base/checks.h" -#include "rtc_base/constructor_magic.h" #include "rtc_base/logging.h" #include "rtc_base/thread_annotations.h" #include "system_wrappers/include/metrics.h" namespace webrtc { namespace { + +constexpr TimeDelta kMinClusterDelta = TimeDelta::Millis(1); +constexpr TimeDelta kInitialProbingInterval = TimeDelta::Seconds(2); +constexpr int kTimestampGroupLengthMs = 5; +constexpr int kAbsSendTimeInterArrivalUpshift = 8; +constexpr int kInterArrivalShift = + RTPHeaderExtension::kAbsSendTimeFraction + kAbsSendTimeInterArrivalUpshift; +constexpr int kMinClusterSize = 4; +constexpr int kMaxProbePackets = 15; +constexpr int kExpectedNumberOfProbes = 3; +constexpr double kTimestampToMs = + 1000.0 / static_cast(1 << kInterArrivalShift); + absl::optional OptionalRateFromOptionalBps( absl::optional bitrate_bps) { if (bitrate_bps) { @@ -33,62 +51,48 @@ absl::optional OptionalRateFromOptionalBps( return absl::nullopt; } } -} // namespace - -enum { - kTimestampGroupLengthMs = 5, - kAbsSendTimeInterArrivalUpshift = 8, - kInterArrivalShift = RTPHeaderExtension::kAbsSendTimeFraction + - kAbsSendTimeInterArrivalUpshift, - kInitialProbingIntervalMs = 2000, - kMinClusterSize = 4, - kMaxProbePackets = 15, - kExpectedNumberOfProbes = 3 -}; - -static const double kTimestampToMs = - 1000.0 / static_cast(1 << kInterArrivalShift); template std::vector Keys(const std::map& map) { std::vector keys; keys.reserve(map.size()); - for (typename std::map::const_iterator it = map.begin(); - it != map.end(); ++it) { - keys.push_back(it->first); + for (const auto& kv_pair : map) { + keys.push_back(kv_pair.first); } return keys; } -uint32_t ConvertMsTo24Bits(int64_t time_ms) { - uint32_t time_24_bits = - static_cast(((static_cast(time_ms) - << RTPHeaderExtension::kAbsSendTimeFraction) + - 500) / - 1000) & - 0x00FFFFFF; - return time_24_bits; -} +} // namespace RemoteBitrateEstimatorAbsSendTime::~RemoteBitrateEstimatorAbsSendTime() = default; bool RemoteBitrateEstimatorAbsSendTime::IsWithinClusterBounds( - int send_delta_ms, + TimeDelta send_delta, const Cluster& cluster_aggregate) { if (cluster_aggregate.count == 0) return true; - float cluster_mean = cluster_aggregate.send_mean_ms / - static_cast(cluster_aggregate.count); - return fabs(static_cast(send_delta_ms) - cluster_mean) < 2.5f; + TimeDelta cluster_mean = + cluster_aggregate.send_mean / cluster_aggregate.count; + return (send_delta - cluster_mean).Abs() < 
TimeDelta::Micros(2'500); } -void RemoteBitrateEstimatorAbsSendTime::AddCluster(std::list* clusters, - Cluster* cluster) { - cluster->send_mean_ms /= static_cast(cluster->count); - cluster->recv_mean_ms /= static_cast(cluster->count); - cluster->mean_size /= cluster->count; - clusters->push_back(*cluster); +void RemoteBitrateEstimatorAbsSendTime::MaybeAddCluster( + const Cluster& cluster_aggregate, + std::list& clusters) { + if (cluster_aggregate.count < kMinClusterSize || + cluster_aggregate.send_mean <= TimeDelta::Zero() || + cluster_aggregate.recv_mean <= TimeDelta::Zero()) { + return; + } + + Cluster cluster; + cluster.send_mean = cluster_aggregate.send_mean / cluster_aggregate.count; + cluster.recv_mean = cluster_aggregate.recv_mean / cluster_aggregate.count; + cluster.mean_size = cluster_aggregate.mean_size / cluster_aggregate.count; + cluster.count = cluster_aggregate.count; + cluster.num_above_min_delta = cluster_aggregate.num_above_min_delta; + clusters.push_back(cluster); } RemoteBitrateEstimatorAbsSendTime::RemoteBitrateEstimatorAbsSendTime( @@ -96,91 +100,77 @@ RemoteBitrateEstimatorAbsSendTime::RemoteBitrateEstimatorAbsSendTime( Clock* clock) : clock_(clock), observer_(observer), - inter_arrival_(), - estimator_(), detector_(&field_trials_), - incoming_bitrate_(kBitrateWindowMs, 8000), - incoming_bitrate_initialized_(false), - total_probes_received_(0), - first_packet_time_ms_(-1), - last_update_ms_(-1), - uma_recorded_(false), remote_rate_(&field_trials_) { RTC_DCHECK(clock_); RTC_DCHECK(observer_); RTC_LOG(LS_INFO) << "RemoteBitrateEstimatorAbsSendTime: Instantiating."; } -void RemoteBitrateEstimatorAbsSendTime::ComputeClusters( - std::list* clusters) const { - Cluster current; - int64_t prev_send_time = -1; - int64_t prev_recv_time = -1; - for (std::list::const_iterator it = probes_.begin(); - it != probes_.end(); ++it) { - if (prev_send_time >= 0) { - int send_delta_ms = it->send_time_ms - prev_send_time; - int recv_delta_ms = it->recv_time_ms - prev_recv_time; - if (send_delta_ms >= 1 && recv_delta_ms >= 1) { - ++current.num_above_min_delta; +std::list +RemoteBitrateEstimatorAbsSendTime::ComputeClusters() const { + std::list clusters; + Cluster cluster_aggregate; + Timestamp prev_send_time = Timestamp::MinusInfinity(); + Timestamp prev_recv_time = Timestamp::MinusInfinity(); + for (const Probe& probe : probes_) { + if (prev_send_time.IsFinite()) { + TimeDelta send_delta = probe.send_time - prev_send_time; + TimeDelta recv_delta = probe.recv_time - prev_recv_time; + if (send_delta >= kMinClusterDelta && recv_delta >= kMinClusterDelta) { + ++cluster_aggregate.num_above_min_delta; } - if (!IsWithinClusterBounds(send_delta_ms, current)) { - if (current.count >= kMinClusterSize && current.send_mean_ms > 0.0f && - current.recv_mean_ms > 0.0f) { - AddCluster(clusters, ¤t); - } - current = Cluster(); + if (!IsWithinClusterBounds(send_delta, cluster_aggregate)) { + MaybeAddCluster(cluster_aggregate, clusters); + cluster_aggregate = Cluster(); } - current.send_mean_ms += send_delta_ms; - current.recv_mean_ms += recv_delta_ms; - current.mean_size += it->payload_size; - ++current.count; + cluster_aggregate.send_mean += send_delta; + cluster_aggregate.recv_mean += recv_delta; + cluster_aggregate.mean_size += probe.payload_size; + ++cluster_aggregate.count; } - prev_send_time = it->send_time_ms; - prev_recv_time = it->recv_time_ms; - } - if (current.count >= kMinClusterSize && current.send_mean_ms > 0.0f && - current.recv_mean_ms > 0.0f) { - AddCluster(clusters, ¤t); + prev_send_time = 
probe.send_time; + prev_recv_time = probe.recv_time; } + MaybeAddCluster(cluster_aggregate, clusters); + return clusters; } -std::list::const_iterator +const RemoteBitrateEstimatorAbsSendTime::Cluster* RemoteBitrateEstimatorAbsSendTime::FindBestProbe( const std::list& clusters) const { - int highest_probe_bitrate_bps = 0; - std::list::const_iterator best_it = clusters.end(); - for (std::list::const_iterator it = clusters.begin(); - it != clusters.end(); ++it) { - if (it->send_mean_ms == 0 || it->recv_mean_ms == 0) + DataRate highest_probe_bitrate = DataRate::Zero(); + const Cluster* best = nullptr; + for (const auto& cluster : clusters) { + if (cluster.send_mean == TimeDelta::Zero() || + cluster.recv_mean == TimeDelta::Zero()) { continue; - if (it->num_above_min_delta > it->count / 2 && - (it->recv_mean_ms - it->send_mean_ms <= 2.0f && - it->send_mean_ms - it->recv_mean_ms <= 5.0f)) { - int probe_bitrate_bps = - std::min(it->GetSendBitrateBps(), it->GetRecvBitrateBps()); - if (probe_bitrate_bps > highest_probe_bitrate_bps) { - highest_probe_bitrate_bps = probe_bitrate_bps; - best_it = it; + } + if (cluster.num_above_min_delta > cluster.count / 2 && + (cluster.recv_mean - cluster.send_mean <= TimeDelta::Millis(2) && + cluster.send_mean - cluster.recv_mean <= TimeDelta::Millis(5))) { + DataRate probe_bitrate = + std::min(cluster.SendBitrate(), cluster.RecvBitrate()); + if (probe_bitrate > highest_probe_bitrate) { + highest_probe_bitrate = probe_bitrate; + best = &cluster; } } else { - int send_bitrate_bps = it->mean_size * 8 * 1000 / it->send_mean_ms; - int recv_bitrate_bps = it->mean_size * 8 * 1000 / it->recv_mean_ms; - RTC_LOG(LS_INFO) << "Probe failed, sent at " << send_bitrate_bps - << " bps, received at " << recv_bitrate_bps - << " bps. Mean send delta: " << it->send_mean_ms - << " ms, mean recv delta: " << it->recv_mean_ms - << " ms, num probes: " << it->count; + RTC_LOG(LS_INFO) << "Probe failed, sent at " + << cluster.SendBitrate().bps() << " bps, received at " + << cluster.RecvBitrate().bps() + << " bps. Mean send delta: " << cluster.send_mean.ms() + << " ms, mean recv delta: " << cluster.recv_mean.ms() + << " ms, num probes: " << cluster.count; break; } } - return best_it; + return best; } RemoteBitrateEstimatorAbsSendTime::ProbeResult -RemoteBitrateEstimatorAbsSendTime::ProcessClusters(int64_t now_ms) { - std::list clusters; - ComputeClusters(&clusters); +RemoteBitrateEstimatorAbsSendTime::ProcessClusters(Timestamp now) { + std::list clusters = ComputeClusters(); if (clusters.empty()) { // If we reach the max number of probe packets and still have no clusters, // we will remove the oldest one. @@ -189,21 +179,18 @@ RemoteBitrateEstimatorAbsSendTime::ProcessClusters(int64_t now_ms) { return ProbeResult::kNoUpdate; } - std::list::const_iterator best_it = FindBestProbe(clusters); - if (best_it != clusters.end()) { - int probe_bitrate_bps = - std::min(best_it->GetSendBitrateBps(), best_it->GetRecvBitrateBps()); + if (const Cluster* best = FindBestProbe(clusters)) { + DataRate probe_bitrate = std::min(best->SendBitrate(), best->RecvBitrate()); // Make sure that a probe sent on a lower bitrate than our estimate can't // reduce the estimate. - if (IsBitrateImproving(probe_bitrate_bps)) { + if (IsBitrateImproving(probe_bitrate)) { RTC_LOG(LS_INFO) << "Probe successful, sent at " - << best_it->GetSendBitrateBps() << " bps, received at " - << best_it->GetRecvBitrateBps() - << " bps. 
Mean send delta: " << best_it->send_mean_ms - << " ms, mean recv delta: " << best_it->recv_mean_ms - << " ms, num probes: " << best_it->count; - remote_rate_.SetEstimate(DataRate::BitsPerSec(probe_bitrate_bps), - Timestamp::Millis(now_ms)); + << best->SendBitrate().bps() << " bps, received at " + << best->RecvBitrate().bps() + << " bps. Mean send delta: " << best->send_mean.ms() + << " ms, mean recv delta: " << best->recv_mean.ms() + << " ms, num probes: " << best->count; + remote_rate_.SetEstimate(probe_bitrate, now); return ProbeResult::kBitrateUpdated; } } @@ -216,11 +203,11 @@ RemoteBitrateEstimatorAbsSendTime::ProcessClusters(int64_t now_ms) { } bool RemoteBitrateEstimatorAbsSendTime::IsBitrateImproving( - int new_bitrate_bps) const { - bool initial_probe = !remote_rate_.ValidEstimate() && new_bitrate_bps > 0; - bool bitrate_above_estimate = - remote_rate_.ValidEstimate() && - new_bitrate_bps > remote_rate_.LatestEstimate().bps(); + DataRate probe_bitrate) const { + bool initial_probe = + !remote_rate_.ValidEstimate() && probe_bitrate > DataRate::Zero(); + bool bitrate_above_estimate = remote_rate_.ValidEstimate() && + probe_bitrate > remote_rate_.LatestEstimate(); return initial_probe || bitrate_above_estimate; } @@ -235,14 +222,15 @@ void RemoteBitrateEstimatorAbsSendTime::IncomingPacket( "is missing absolute send time extension!"; return; } - IncomingPacketInfo(arrival_time_ms, header.extension.absoluteSendTime, - payload_size, header.ssrc); + IncomingPacketInfo(Timestamp::Millis(arrival_time_ms), + header.extension.absoluteSendTime, + DataSize::Bytes(payload_size), header.ssrc); } void RemoteBitrateEstimatorAbsSendTime::IncomingPacketInfo( - int64_t arrival_time_ms, + Timestamp arrival_time, uint32_t send_time_24bits, - size_t payload_size, + DataSize payload_size, uint32_t ssrc) { RTC_CHECK(send_time_24bits < (1ul << 24)); if (!uma_recorded_) { @@ -253,15 +241,16 @@ void RemoteBitrateEstimatorAbsSendTime::IncomingPacketInfo( // Shift up send time to use the full 32 bits that inter_arrival works with, // so wrapping works properly. uint32_t timestamp = send_time_24bits << kAbsSendTimeInterArrivalUpshift; - int64_t send_time_ms = static_cast(timestamp) * kTimestampToMs; + Timestamp send_time = + Timestamp::Millis(static_cast(timestamp) * kTimestampToMs); - int64_t now_ms = clock_->TimeInMilliseconds(); + Timestamp now = clock_->CurrentTime(); // TODO(holmer): SSRCs are only needed for REMB, should be broken out from // here. // Check if incoming bitrate estimate is valid, and if it needs to be reset. 
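+  // Rate() returns nullopt once the rate window no longer holds enough recent
+  // samples; if a valid estimate existed before, that is treated as the
+  // stream having stopped, and the window is reset below before the new
+  // sample is added.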
absl::optional incoming_bitrate = - incoming_bitrate_.Rate(arrival_time_ms); + incoming_bitrate_.Rate(arrival_time.ms()); if (incoming_bitrate) { incoming_bitrate_initialized_ = true; } else if (incoming_bitrate_initialized_) { @@ -271,74 +260,82 @@ void RemoteBitrateEstimatorAbsSendTime::IncomingPacketInfo( incoming_bitrate_.Reset(); incoming_bitrate_initialized_ = false; } - incoming_bitrate_.Update(payload_size, arrival_time_ms); + incoming_bitrate_.Update(payload_size.bytes(), arrival_time.ms()); - if (first_packet_time_ms_ == -1) - first_packet_time_ms_ = now_ms; + if (first_packet_time_.IsInfinite()) { + first_packet_time_ = now; + } uint32_t ts_delta = 0; int64_t t_delta = 0; int size_delta = 0; bool update_estimate = false; - uint32_t target_bitrate_bps = 0; + DataRate target_bitrate = DataRate::Zero(); std::vector ssrcs; { MutexLock lock(&mutex_); - TimeoutStreams(now_ms); - RTC_DCHECK(inter_arrival_.get()); - RTC_DCHECK(estimator_.get()); - ssrcs_[ssrc] = now_ms; + TimeoutStreams(now); + RTC_DCHECK(inter_arrival_); + RTC_DCHECK(estimator_); + // TODO(danilchap): Replace 5 lines below with insert_or_assign when that + // c++17 function is available. + auto inserted = ssrcs_.insert(std::make_pair(ssrc, now)); + if (!inserted.second) { + // Already inserted, update. + inserted.first->second = now; + } // For now only try to detect probes while we don't have a valid estimate. // We currently assume that only packets larger than 200 bytes are paced by // the sender. - const size_t kMinProbePacketSize = 200; + static constexpr DataSize kMinProbePacketSize = DataSize::Bytes(200); if (payload_size > kMinProbePacketSize && (!remote_rate_.ValidEstimate() || - now_ms - first_packet_time_ms_ < kInitialProbingIntervalMs)) { + now - first_packet_time_ < kInitialProbingInterval)) { // TODO(holmer): Use a map instead to get correct order? if (total_probes_received_ < kMaxProbePackets) { - int send_delta_ms = -1; - int recv_delta_ms = -1; + TimeDelta send_delta = TimeDelta::Millis(-1); + TimeDelta recv_delta = TimeDelta::Millis(-1); if (!probes_.empty()) { - send_delta_ms = send_time_ms - probes_.back().send_time_ms; - recv_delta_ms = arrival_time_ms - probes_.back().recv_time_ms; + send_delta = send_time - probes_.back().send_time; + recv_delta = arrival_time - probes_.back().recv_time; } - RTC_LOG(LS_INFO) << "Probe packet received: send time=" << send_time_ms - << " ms, recv time=" << arrival_time_ms - << " ms, send delta=" << send_delta_ms - << " ms, recv delta=" << recv_delta_ms << " ms."; + RTC_LOG(LS_INFO) << "Probe packet received: send time=" + << send_time.ms() + << " ms, recv time=" << arrival_time.ms() + << " ms, send delta=" << send_delta.ms() + << " ms, recv delta=" << recv_delta.ms() << " ms."; } - probes_.push_back(Probe(send_time_ms, arrival_time_ms, payload_size)); + probes_.emplace_back(send_time, arrival_time, payload_size); ++total_probes_received_; // Make sure that a probe which updated the bitrate immediately has an // effect by calling the OnReceiveBitrateChanged callback. 
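+      // ProcessClusters() only returns kBitrateUpdated after a validated
+      // probe raised (or initialized) the AIMD estimate, so this flag forces
+      // the OnReceiveBitrateChanged callback at the end of this function even
+      // if the periodic update interval has not yet elapsed.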
- if (ProcessClusters(now_ms) == ProbeResult::kBitrateUpdated) + if (ProcessClusters(now) == ProbeResult::kBitrateUpdated) update_estimate = true; } - if (inter_arrival_->ComputeDeltas(timestamp, arrival_time_ms, now_ms, - payload_size, &ts_delta, &t_delta, + if (inter_arrival_->ComputeDeltas(timestamp, arrival_time.ms(), now.ms(), + payload_size.bytes(), &ts_delta, &t_delta, &size_delta)) { double ts_delta_ms = (1000.0 * ts_delta) / (1 << kInterArrivalShift); estimator_->Update(t_delta, ts_delta_ms, size_delta, detector_.State(), - arrival_time_ms); + arrival_time.ms()); detector_.Detect(estimator_->offset(), ts_delta_ms, - estimator_->num_of_deltas(), arrival_time_ms); + estimator_->num_of_deltas(), arrival_time.ms()); } if (!update_estimate) { // Check if it's time for a periodic update or if we should update because // of an over-use. - if (last_update_ms_ == -1 || - now_ms - last_update_ms_ > remote_rate_.GetFeedbackInterval().ms()) { + if (last_update_.IsInfinite() || + now.ms() - last_update_.ms() > + remote_rate_.GetFeedbackInterval().ms()) { update_estimate = true; } else if (detector_.State() == BandwidthUsage::kBwOverusing) { absl::optional incoming_rate = - incoming_bitrate_.Rate(arrival_time_ms); + incoming_bitrate_.Rate(arrival_time.ms()); if (incoming_rate && remote_rate_.TimeToReduceFurther( - Timestamp::Millis(now_ms), - DataRate::BitsPerSec(*incoming_rate))) { + now, DataRate::BitsPerSec(*incoming_rate))) { update_estimate = true; } } @@ -349,18 +346,16 @@ void RemoteBitrateEstimatorAbsSendTime::IncomingPacketInfo( // We also have to update the estimate immediately if we are overusing // and the target bitrate is too high compared to what we are receiving. const RateControlInput input( - detector_.State(), - OptionalRateFromOptionalBps(incoming_bitrate_.Rate(arrival_time_ms))); - target_bitrate_bps = - remote_rate_.Update(&input, Timestamp::Millis(now_ms)) - .bps(); + detector_.State(), OptionalRateFromOptionalBps( + incoming_bitrate_.Rate(arrival_time.ms()))); + target_bitrate = remote_rate_.Update(&input, now); update_estimate = remote_rate_.ValidEstimate(); ssrcs = Keys(ssrcs_); } } if (update_estimate) { - last_update_ms_ = now_ms; - observer_->OnReceiveBitrateChanged(ssrcs, target_bitrate_bps); + last_update_ = now; + observer_->OnReceiveBitrateChanged(ssrcs, target_bitrate.bps()); } } @@ -371,9 +366,9 @@ int64_t RemoteBitrateEstimatorAbsSendTime::TimeUntilNextProcess() { return kDisabledModuleTime; } -void RemoteBitrateEstimatorAbsSendTime::TimeoutStreams(int64_t now_ms) { - for (Ssrcs::iterator it = ssrcs_.begin(); it != ssrcs_.end();) { - if ((now_ms - it->second) > kStreamTimeOutMs) { +void RemoteBitrateEstimatorAbsSendTime::TimeoutStreams(Timestamp now) { + for (auto it = ssrcs_.begin(); it != ssrcs_.end();) { + if (now - it->second > TimeDelta::Millis(kStreamTimeOutMs)) { ssrcs_.erase(it++); } else { ++it; @@ -381,17 +376,17 @@ void RemoteBitrateEstimatorAbsSendTime::TimeoutStreams(int64_t now_ms) { } if (ssrcs_.empty()) { // We can't update the estimate if we don't have any active streams. 
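+    // Re-creating inter_arrival_ and estimator_ discards their state, so
+    // inter-arrival grouping and the overuse estimate are rebuilt from
+    // scratch once a stream becomes active again.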
- inter_arrival_.reset( - new InterArrival((kTimestampGroupLengthMs << kInterArrivalShift) / 1000, - kTimestampToMs, true)); - estimator_.reset(new OveruseEstimator(OverUseDetectorOptions())); + inter_arrival_ = std::make_unique( + (kTimestampGroupLengthMs << kInterArrivalShift) / 1000, kTimestampToMs, + true); + estimator_ = std::make_unique(OverUseDetectorOptions()); // We deliberately don't reset the first_packet_time_ms_ here for now since // we only probe for bandwidth in the beginning of a call right now. } } void RemoteBitrateEstimatorAbsSendTime::OnRttUpdate(int64_t avg_rtt_ms, - int64_t max_rtt_ms) { + int64_t /*max_rtt_ms*/) { MutexLock lock(&mutex_); remote_rate_.SetRtt(TimeDelta::Millis(avg_rtt_ms)); } diff --git a/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h b/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h index f42a28f8c8..4117382577 100644 --- a/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h +++ b/modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h @@ -21,6 +21,10 @@ #include "api/rtp_headers.h" #include "api/transport/field_trial_based_config.h" +#include "api/units/data_rate.h" +#include "api/units/data_size.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "modules/remote_bitrate_estimator/aimd_rate_control.h" #include "modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h" #include "modules/remote_bitrate_estimator/inter_arrival.h" @@ -35,42 +39,6 @@ namespace webrtc { -struct Probe { - Probe(int64_t send_time_ms, int64_t recv_time_ms, size_t payload_size) - : send_time_ms(send_time_ms), - recv_time_ms(recv_time_ms), - payload_size(payload_size) {} - int64_t send_time_ms; - int64_t recv_time_ms; - size_t payload_size; -}; - -struct Cluster { - Cluster() - : send_mean_ms(0.0f), - recv_mean_ms(0.0f), - mean_size(0), - count(0), - num_above_min_delta(0) {} - - int GetSendBitrateBps() const { - RTC_CHECK_GT(send_mean_ms, 0.0f); - return mean_size * 8 * 1000 / send_mean_ms; - } - - int GetRecvBitrateBps() const { - RTC_CHECK_GT(recv_mean_ms, 0.0f); - return mean_size * 8 * 1000 / recv_mean_ms; - } - - float send_mean_ms; - float recv_mean_ms; - // TODO(holmer): Add some variance metric as well? - size_t mean_size; - int count; - int num_above_min_delta; -}; - class RemoteBitrateEstimatorAbsSendTime : public RemoteBitrateEstimator { public: RemoteBitrateEstimatorAbsSendTime(RemoteBitrateObserver* observer, @@ -100,32 +68,54 @@ class RemoteBitrateEstimatorAbsSendTime : public RemoteBitrateEstimator { void SetMinBitrate(int min_bitrate_bps) override; private: - typedef std::map Ssrcs; + struct Probe { + Probe(Timestamp send_time, Timestamp recv_time, DataSize payload_size) + : send_time(send_time), + recv_time(recv_time), + payload_size(payload_size) {} + + Timestamp send_time; + Timestamp recv_time; + DataSize payload_size; + }; + + struct Cluster { + DataRate SendBitrate() const { return mean_size / send_mean; } + DataRate RecvBitrate() const { return mean_size / recv_mean; } + + TimeDelta send_mean = TimeDelta::Zero(); + TimeDelta recv_mean = TimeDelta::Zero(); + // TODO(holmer): Add some variance metric as well? 
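+    // DataSize divided by TimeDelta yields a DataRate directly, which is what
+    // SendBitrate()/RecvBitrate() above rely on instead of the previous
+    // mean_size * 8 * 1000 / send_mean_ms arithmetic.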
+ DataSize mean_size = DataSize::Zero(); + int count = 0; + int num_above_min_delta = 0; + }; + enum class ProbeResult { kBitrateUpdated, kNoUpdate }; - static bool IsWithinClusterBounds(int send_delta_ms, + static bool IsWithinClusterBounds(TimeDelta send_delta, const Cluster& cluster_aggregate); - static void AddCluster(std::list* clusters, Cluster* cluster); + static void MaybeAddCluster(const Cluster& cluster_aggregate, + std::list& clusters); - void IncomingPacketInfo(int64_t arrival_time_ms, + void IncomingPacketInfo(Timestamp arrival_time, uint32_t send_time_24bits, - size_t payload_size, + DataSize payload_size, uint32_t ssrc); - void ComputeClusters(std::list* clusters) const; + std::list ComputeClusters() const; - std::list::const_iterator FindBestProbe( - const std::list& clusters) const; + const Cluster* FindBestProbe(const std::list& clusters) const; // Returns true if a probe which changed the estimate was detected. - ProbeResult ProcessClusters(int64_t now_ms) + ProbeResult ProcessClusters(Timestamp now) RTC_EXCLUSIVE_LOCKS_REQUIRED(&mutex_); - bool IsBitrateImproving(int probe_bitrate_bps) const + bool IsBitrateImproving(DataRate probe_bitrate) const RTC_EXCLUSIVE_LOCKS_REQUIRED(&mutex_); - void TimeoutStreams(int64_t now_ms) RTC_EXCLUSIVE_LOCKS_REQUIRED(&mutex_); + void TimeoutStreams(Timestamp now) RTC_EXCLUSIVE_LOCKS_REQUIRED(&mutex_); rtc::RaceChecker network_race_; Clock* const clock_; @@ -134,18 +124,16 @@ class RemoteBitrateEstimatorAbsSendTime : public RemoteBitrateEstimator { std::unique_ptr inter_arrival_; std::unique_ptr estimator_; OveruseDetector detector_; - RateStatistics incoming_bitrate_; - bool incoming_bitrate_initialized_; - std::vector recent_propagation_delta_ms_; - std::vector recent_update_time_ms_; + RateStatistics incoming_bitrate_{kBitrateWindowMs, 8000}; + bool incoming_bitrate_initialized_ = false; std::list probes_; - size_t total_probes_received_; - int64_t first_packet_time_ms_; - int64_t last_update_ms_; - bool uma_recorded_; + size_t total_probes_received_ = 0; + Timestamp first_packet_time_ = Timestamp::MinusInfinity(); + Timestamp last_update_ = Timestamp::MinusInfinity(); + bool uma_recorded_ = false; mutable Mutex mutex_; - Ssrcs ssrcs_ RTC_GUARDED_BY(&mutex_); + std::map ssrcs_ RTC_GUARDED_BY(&mutex_); AimdRateControl remote_rate_ RTC_GUARDED_BY(&mutex_); }; diff --git a/modules/remote_bitrate_estimator/remote_bitrate_estimator_single_stream.cc b/modules/remote_bitrate_estimator/remote_bitrate_estimator_single_stream.cc index 46d8fbc434..ddaa1de088 100644 --- a/modules/remote_bitrate_estimator/remote_bitrate_estimator_single_stream.cc +++ b/modules/remote_bitrate_estimator/remote_bitrate_estimator_single_stream.cc @@ -234,7 +234,7 @@ bool RemoteBitrateEstimatorSingleStream::LatestEstimate( std::vector* ssrcs, uint32_t* bitrate_bps) const { MutexLock lock(&mutex_); - assert(bitrate_bps); + RTC_DCHECK(bitrate_bps); if (!remote_rate_->ValidEstimate()) { return false; } @@ -248,7 +248,7 @@ bool RemoteBitrateEstimatorSingleStream::LatestEstimate( void RemoteBitrateEstimatorSingleStream::GetSsrcs( std::vector* ssrcs) const { - assert(ssrcs); + RTC_DCHECK(ssrcs); ssrcs->resize(overuse_detectors_.size()); int i = 0; for (SsrcOveruseEstimatorMap::const_iterator it = overuse_detectors_.begin(); diff --git a/modules/remote_bitrate_estimator/remote_bitrate_estimator_unittest_helper.cc b/modules/remote_bitrate_estimator/remote_bitrate_estimator_unittest_helper.cc index 5e117942c1..66f8ca053a 100644 --- 
a/modules/remote_bitrate_estimator/remote_bitrate_estimator_unittest_helper.cc +++ b/modules/remote_bitrate_estimator/remote_bitrate_estimator_unittest_helper.cc @@ -46,7 +46,7 @@ RtpStream::RtpStream(int fps, next_rtcp_time_(rtcp_receive_time), rtp_timestamp_offset_(timestamp_offset), kNtpFracPerMs(4.294967296E6) { - assert(fps_ > 0); + RTC_DCHECK_GT(fps_, 0); } void RtpStream::set_rtp_timestamp_offset(uint32_t offset) { @@ -60,7 +60,7 @@ int64_t RtpStream::GenerateFrame(int64_t time_now_us, PacketList* packets) { if (time_now_us < next_rtp_time_) { return next_rtp_time_; } - assert(packets != NULL); + RTC_DCHECK(packets); size_t bits_per_frame = (bitrate_bps_ + fps_ / 2) / fps_; size_t n_packets = std::max((bits_per_frame + 4 * kMtu) / (8 * kMtu), 1u); @@ -173,9 +173,9 @@ void StreamGenerator::set_rtp_timestamp_offset(uint32_t ssrc, uint32_t offset) { // it possible to simulate different types of channels. int64_t StreamGenerator::GenerateFrame(RtpStream::PacketList* packets, int64_t time_now_us) { - assert(packets != NULL); - assert(packets->empty()); - assert(capacity_ > 0); + RTC_DCHECK(packets); + RTC_DCHECK(packets->empty()); + RTC_DCHECK_GT(capacity_, 0); StreamMap::iterator it = std::min_element(streams_.begin(), streams_.end(), RtpStream::Compare); (*it).second->GenerateFrame(time_now_us, packets); diff --git a/modules/remote_bitrate_estimator/remote_estimator_proxy.cc b/modules/remote_bitrate_estimator/remote_estimator_proxy.cc index a9cc170a35..7764e60ef2 100644 --- a/modules/remote_bitrate_estimator/remote_estimator_proxy.cc +++ b/modules/remote_bitrate_estimator/remote_estimator_proxy.cc @@ -23,9 +23,6 @@ namespace webrtc { -// Impossible to request feedback older than what can be represented by 15 bits. -const int RemoteEstimatorProxy::kMaxNumberOfPackets = (1 << 15); - // The maximum allowed value for a timestamp in milliseconds. This is lower // than the numerical limit since we often convert to microseconds. static constexpr int64_t kMaxTimeMs = @@ -33,11 +30,11 @@ static constexpr int64_t kMaxTimeMs = RemoteEstimatorProxy::RemoteEstimatorProxy( Clock* clock, - TransportFeedbackSenderInterface* feedback_sender, + TransportFeedbackSender feedback_sender, const WebRtcKeyValueConfig* key_value_config, NetworkStateEstimator* network_state_estimator) : clock_(clock), - feedback_sender_(feedback_sender), + feedback_sender_(std::move(feedback_sender)), send_config_(key_value_config), last_process_time_ms_(-1), network_state_estimator_(network_state_estimator), @@ -54,6 +51,18 @@ RemoteEstimatorProxy::RemoteEstimatorProxy( RemoteEstimatorProxy::~RemoteEstimatorProxy() {} +void RemoteEstimatorProxy::MaybeCullOldPackets(int64_t sequence_number, + int64_t arrival_time_ms) { + if (periodic_window_start_seq_.has_value()) { + if (*periodic_window_start_seq_ >= + packet_arrival_times_.end_sequence_number()) { + // Start new feedback packet, cull old packets. + packet_arrival_times_.RemoveOldPackets( + sequence_number, arrival_time_ms - send_config_.back_window->ms()); + } + } +} + void RemoteEstimatorProxy::IncomingPacket(int64_t arrival_time_ms, size_t payload_size, const RTPHeader& header) { @@ -69,39 +78,26 @@ void RemoteEstimatorProxy::IncomingPacket(int64_t arrival_time_ms, seq = unwrapper_.Unwrap(header.extension.transportSequenceNumber); if (send_periodic_feedback_) { - if (periodic_window_start_seq_ && - packet_arrival_times_.lower_bound(*periodic_window_start_seq_) == - packet_arrival_times_.end()) { - // Start new feedback packet, cull old packets. 
- for (auto it = packet_arrival_times_.begin(); - it != packet_arrival_times_.end() && it->first < seq && - arrival_time_ms - it->second >= send_config_.back_window->ms();) { - it = packet_arrival_times_.erase(it); - } - } + MaybeCullOldPackets(seq, arrival_time_ms); + if (!periodic_window_start_seq_ || seq < *periodic_window_start_seq_) { periodic_window_start_seq_ = seq; } } // We are only interested in the first time a packet is received. - if (packet_arrival_times_.find(seq) != packet_arrival_times_.end()) + if (packet_arrival_times_.has_received(seq)) { return; + } - packet_arrival_times_[seq] = arrival_time_ms; + packet_arrival_times_.AddPacket(seq, arrival_time_ms); // Limit the range of sequence numbers to send feedback for. - auto first_arrival_time_to_keep = packet_arrival_times_.lower_bound( - packet_arrival_times_.rbegin()->first - kMaxNumberOfPackets); - if (first_arrival_time_to_keep != packet_arrival_times_.begin()) { - packet_arrival_times_.erase(packet_arrival_times_.begin(), - first_arrival_time_to_keep); - if (send_periodic_feedback_) { - // |packet_arrival_times_| cannot be empty since we just added one - // element and the last element is not deleted. - RTC_DCHECK(!packet_arrival_times_.empty()); - periodic_window_start_seq_ = packet_arrival_times_.begin()->first; - } + if (!periodic_window_start_seq_.has_value() || + periodic_window_start_seq_.value() < + packet_arrival_times_.begin_sequence_number()) { + periodic_window_start_seq_ = + packet_arrival_times_.begin_sequence_number(); } if (header.extension.feedback_request) { @@ -113,8 +109,8 @@ void RemoteEstimatorProxy::IncomingPacket(int64_t arrival_time_ms, if (network_state_estimator_ && header.extension.hasAbsoluteSendTime) { PacketResult packet_result; packet_result.receive_time = Timestamp::Millis(arrival_time_ms); - // Ignore reordering of packets and assume they have approximately the same - // send time. + // Ignore reordering of packets and assume they have approximately the + // same send time. abs_send_timestamp_ += std::max( header.extension.GetAbsoluteSendTimeDelta(previous_abs_send_time_), TimeDelta::Millis(0)); @@ -183,9 +179,9 @@ void RemoteEstimatorProxy::SetSendPeriodicFeedback( } void RemoteEstimatorProxy::SendPeriodicFeedbacks() { - // |periodic_window_start_seq_| is the first sequence number to include in the - // current feedback packet. Some older may still be in the map, in case a - // reordering happens and we need to retransmit them. + // |periodic_window_start_seq_| is the first sequence number to include in + // the current feedback packet. Some older may still be in the map, in case + // a reordering happens and we need to retransmit them. 
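+  // Such older entries are only dropped once their arrival time falls outside
+  // send_config_.back_window, which MaybeCullOldPackets() takes care of as
+  // new packets arrive.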
if (!periodic_window_start_seq_) return; @@ -199,15 +195,17 @@ void RemoteEstimatorProxy::SendPeriodicFeedbacks() { } } - for (auto begin_iterator = - packet_arrival_times_.lower_bound(*periodic_window_start_seq_); - begin_iterator != packet_arrival_times_.cend(); - begin_iterator = - packet_arrival_times_.lower_bound(*periodic_window_start_seq_)) { - auto feedback_packet = std::make_unique(); - periodic_window_start_seq_ = BuildFeedbackPacket( - feedback_packet_count_++, media_ssrc_, *periodic_window_start_seq_, - begin_iterator, packet_arrival_times_.cend(), feedback_packet.get()); + int64_t packet_arrival_times_end_seq = + packet_arrival_times_.end_sequence_number(); + while (periodic_window_start_seq_ < packet_arrival_times_end_seq) { + auto feedback_packet = MaybeBuildFeedbackPacket( + /*include_timestamps=*/true, periodic_window_start_seq_.value(), + packet_arrival_times_end_seq, + /*is_periodic_update=*/true); + + if (feedback_packet == nullptr) { + break; + } RTC_DCHECK(feedback_sender_ != nullptr); @@ -217,10 +215,10 @@ void RemoteEstimatorProxy::SendPeriodicFeedbacks() { } packets.push_back(std::move(feedback_packet)); - feedback_sender_->SendCombinedRtcpPacket(std::move(packets)); - // Note: Don't erase items from packet_arrival_times_ after sending, in case - // they need to be re-sent after a reordering. Removal will be handled - // by OnPacketArrival once packets are too old. + feedback_sender_(std::move(packets)); + // Note: Don't erase items from packet_arrival_times_ after sending, in + // case they need to be re-sent after a reordering. Removal will be + // handled by OnPacketArrival once packets are too old. } } @@ -231,61 +229,79 @@ void RemoteEstimatorProxy::SendFeedbackOnRequest( return; } - auto feedback_packet = std::make_unique( - feedback_request.include_timestamps); - int64_t first_sequence_number = sequence_number - feedback_request.sequence_count + 1; - auto begin_iterator = - packet_arrival_times_.lower_bound(first_sequence_number); - auto end_iterator = packet_arrival_times_.upper_bound(sequence_number); - BuildFeedbackPacket(feedback_packet_count_++, media_ssrc_, - first_sequence_number, begin_iterator, end_iterator, - feedback_packet.get()); + auto feedback_packet = MaybeBuildFeedbackPacket( + feedback_request.include_timestamps, first_sequence_number, + sequence_number + 1, /*is_periodic_update=*/false); + + // This is called when a packet has just been added. + RTC_DCHECK(feedback_packet != nullptr); // Clear up to the first packet that is included in this feedback packet. - packet_arrival_times_.erase(packet_arrival_times_.begin(), begin_iterator); + packet_arrival_times_.EraseTo(first_sequence_number); RTC_DCHECK(feedback_sender_ != nullptr); std::vector> packets; packets.push_back(std::move(feedback_packet)); - feedback_sender_->SendCombinedRtcpPacket(std::move(packets)); + feedback_sender_(std::move(packets)); } -int64_t RemoteEstimatorProxy::BuildFeedbackPacket( - uint8_t feedback_packet_count, - uint32_t media_ssrc, - int64_t base_sequence_number, - std::map::const_iterator begin_iterator, - std::map::const_iterator end_iterator, - rtcp::TransportFeedback* feedback_packet) { - RTC_DCHECK(begin_iterator != end_iterator); - - // TODO(sprang): Measure receive times in microseconds and remove the - // conversions below. - feedback_packet->SetMediaSsrc(media_ssrc); - // Base sequence number is the expected first sequence number. 
This is known, - // but we might not have actually received it, so the base time shall be the - // time of the first received packet in the feedback. - feedback_packet->SetBase(static_cast(base_sequence_number & 0xFFFF), - begin_iterator->second * 1000); - feedback_packet->SetFeedbackSequenceNumber(feedback_packet_count); - int64_t next_sequence_number = base_sequence_number; - for (auto it = begin_iterator; it != end_iterator; ++it) { - if (!feedback_packet->AddReceivedPacket( - static_cast(it->first & 0xFFFF), it->second * 1000)) { - // If we can't even add the first seq to the feedback packet, we won't be - // able to build it at all. - RTC_CHECK(begin_iterator != it); +std::unique_ptr +RemoteEstimatorProxy::MaybeBuildFeedbackPacket( + bool include_timestamps, + int64_t begin_sequence_number_inclusive, + int64_t end_sequence_number_exclusive, + bool is_periodic_update) { + RTC_DCHECK_LT(begin_sequence_number_inclusive, end_sequence_number_exclusive); + + int64_t start_seq = + packet_arrival_times_.clamp(begin_sequence_number_inclusive); + + int64_t end_seq = packet_arrival_times_.clamp(end_sequence_number_exclusive); + + // Create the packet on demand, as it's not certain that there are packets + // in the range that have been received. + std::unique_ptr feedback_packet = nullptr; + + int64_t next_sequence_number = begin_sequence_number_inclusive; + for (int64_t seq = start_seq; seq < end_seq; ++seq) { + int64_t arrival_time_ms = packet_arrival_times_.get(seq); + if (arrival_time_ms == 0) { + // Packet not received. + continue; + } + + if (feedback_packet == nullptr) { + feedback_packet = + std::make_unique(include_timestamps); + // TODO(sprang): Measure receive times in microseconds and remove the + // conversions below. + feedback_packet->SetMediaSsrc(media_ssrc_); + // Base sequence number is the expected first sequence number. This is + // known, but we might not have actually received it, so the base time + // shall be the time of the first received packet in the feedback. + feedback_packet->SetBase( + static_cast(begin_sequence_number_inclusive & 0xFFFF), + arrival_time_ms * 1000); + feedback_packet->SetFeedbackSequenceNumber(feedback_packet_count_++); + } + + if (!feedback_packet->AddReceivedPacket(static_cast(seq & 0xFFFF), + arrival_time_ms * 1000)) { // Could not add timestamp, feedback packet might be full. Return and // try again with a fresh packet. 
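+      // Since next_sequence_number is only advanced for packets that were
+      // actually added, a periodic update resumes from this sequence number
+      // in the next feedback packet.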
break; } - next_sequence_number = it->first + 1; + + next_sequence_number = seq + 1; + } + if (is_periodic_update) { + periodic_window_start_seq_ = next_sequence_number; } - return next_sequence_number; + return feedback_packet; } } // namespace webrtc diff --git a/modules/remote_bitrate_estimator/remote_estimator_proxy.h b/modules/remote_bitrate_estimator/remote_estimator_proxy.h index a4adefc5ee..4f89409995 100644 --- a/modules/remote_bitrate_estimator/remote_estimator_proxy.h +++ b/modules/remote_bitrate_estimator/remote_estimator_proxy.h @@ -11,12 +11,15 @@ #ifndef MODULES_REMOTE_BITRATE_ESTIMATOR_REMOTE_ESTIMATOR_PROXY_H_ #define MODULES_REMOTE_BITRATE_ESTIMATOR_REMOTE_ESTIMATOR_PROXY_H_ -#include +#include +#include +#include #include #include "api/transport/network_control.h" #include "api/transport/webrtc_key_value_config.h" #include "modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h" +#include "modules/remote_bitrate_estimator/packet_arrival_map.h" #include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/numerics/sequence_number_util.h" #include "rtc_base/synchronization/mutex.h" @@ -24,7 +27,6 @@ namespace webrtc { class Clock; -class PacketRouter; namespace rtcp { class TransportFeedback; } @@ -32,11 +34,14 @@ class TransportFeedback; // Class used when send-side BWE is enabled: This proxy is instantiated on the // receive side. It buffers a number of receive timestamps and then sends // transport feedback messages back too the send side. - class RemoteEstimatorProxy : public RemoteBitrateEstimator { public: + // Used for sending transport feedback messages when send side + // BWE is used. + using TransportFeedbackSender = std::function> packets)>; RemoteEstimatorProxy(Clock* clock, - TransportFeedbackSenderInterface* feedback_sender, + TransportFeedbackSender feedback_sender, const WebRtcKeyValueConfig* key_value_config, NetworkStateEstimator* network_state_estimator); ~RemoteEstimatorProxy() override; @@ -71,24 +76,33 @@ class RemoteEstimatorProxy : public RemoteBitrateEstimator { } }; - static const int kMaxNumberOfPackets; - + void MaybeCullOldPackets(int64_t sequence_number, int64_t arrival_time_ms) + RTC_EXCLUSIVE_LOCKS_REQUIRED(&lock_); void SendPeriodicFeedbacks() RTC_EXCLUSIVE_LOCKS_REQUIRED(&lock_); void SendFeedbackOnRequest(int64_t sequence_number, const FeedbackRequest& feedback_request) RTC_EXCLUSIVE_LOCKS_REQUIRED(&lock_); - static int64_t BuildFeedbackPacket( - uint8_t feedback_packet_count, - uint32_t media_ssrc, - int64_t base_sequence_number, - std::map::const_iterator - begin_iterator, // |begin_iterator| is inclusive. - std::map::const_iterator - end_iterator, // |end_iterator| is exclusive. - rtcp::TransportFeedback* feedback_packet); + + // Returns a Transport Feedback packet with information about as many packets + // that has been received between [`begin_sequence_number_incl`, + // `end_sequence_number_excl`) that can fit in it. If `is_periodic_update`, + // this represents sending a periodic feedback message, which will make it + // update the `periodic_window_start_seq_` variable with the first packet that + // was not included in the feedback packet, so that the next update can + // continue from that sequence number. + // + // If no incoming packets were added, nullptr is returned. + // + // `include_timestamps` decide if the returned TransportFeedback should + // include timestamps. 
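+  // A single rtcp::TransportFeedback packet can only hold a limited number of
+  // receive deltas, so one call may cover just part of the requested range;
+  // callers loop (see SendPeriodicFeedbacks) until the range is exhausted.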
+ std::unique_ptr MaybeBuildFeedbackPacket( + bool include_timestamps, + int64_t begin_sequence_number_inclusive, + int64_t end_sequence_number_exclusive, + bool is_periodic_update) RTC_EXCLUSIVE_LOCKS_REQUIRED(&lock_); Clock* const clock_; - TransportFeedbackSenderInterface* const feedback_sender_; + const TransportFeedbackSender feedback_sender_; const TransportWideFeedbackConfig send_config_; int64_t last_process_time_ms_; @@ -99,9 +113,14 @@ class RemoteEstimatorProxy : public RemoteBitrateEstimator { uint32_t media_ssrc_ RTC_GUARDED_BY(&lock_); uint8_t feedback_packet_count_ RTC_GUARDED_BY(&lock_); SeqNumUnwrapper unwrapper_ RTC_GUARDED_BY(&lock_); + + // The next sequence number that should be the start sequence number during + // periodic reporting. Will be absl::nullopt before the first seen packet. absl::optional periodic_window_start_seq_ RTC_GUARDED_BY(&lock_); - // Map unwrapped seq -> time. - std::map packet_arrival_times_ RTC_GUARDED_BY(&lock_); + + // Packet arrival times, by sequence number. + PacketArrivalTimeMap packet_arrival_times_ RTC_GUARDED_BY(&lock_); + int64_t send_interval_ms_ RTC_GUARDED_BY(&lock_); bool send_periodic_feedback_ RTC_GUARDED_BY(&lock_); diff --git a/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc b/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc index da995922d9..296724fa71 100644 --- a/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc +++ b/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc @@ -16,8 +16,8 @@ #include "api/transport/field_trial_based_config.h" #include "api/transport/network_types.h" #include "api/transport/test/mock_network_control.h" -#include "modules/pacing/packet_router.h" #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" +#include "modules/rtp_rtcp/source/rtp_header_extensions.h" #include "system_wrappers/include/clock.h" #include "test/gmock.h" #include "test/gtest.h" @@ -25,6 +25,7 @@ using ::testing::_; using ::testing::ElementsAre; using ::testing::Invoke; +using ::testing::MockFunction; using ::testing::Return; using ::testing::SizeIs; @@ -63,20 +64,12 @@ std::vector TimestampsMs( return timestamps; } -class MockTransportFeedbackSender : public TransportFeedbackSenderInterface { - public: - MOCK_METHOD(bool, - SendCombinedRtcpPacket, - (std::vector> feedback_packets), - (override)); -}; - class RemoteEstimatorProxyTest : public ::testing::Test { public: RemoteEstimatorProxyTest() : clock_(0), proxy_(&clock_, - &router_, + feedback_sender_.AsStdFunction(), &field_trial_config_, &network_state_estimator_) {} @@ -113,7 +106,8 @@ class RemoteEstimatorProxyTest : public ::testing::Test { FieldTrialBasedConfig field_trial_config_; SimulatedClock clock_; - ::testing::StrictMock router_; + MockFunction>)> + feedback_sender_; ::testing::NiceMock network_state_estimator_; RemoteEstimatorProxy proxy_; }; @@ -121,7 +115,7 @@ class RemoteEstimatorProxyTest : public ::testing::Test { TEST_F(RemoteEstimatorProxyTest, SendsSinglePacketFeedback) { IncomingPacket(kBaseSeq, kBaseTimeMs); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -134,7 +128,6 @@ TEST_F(RemoteEstimatorProxyTest, SendsSinglePacketFeedback) { ElementsAre(kBaseSeq)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs)); - return true; })); Process(); @@ -144,7 +137,7 @@ TEST_F(RemoteEstimatorProxyTest, DuplicatedPackets) { 
IncomingPacket(kBaseSeq, kBaseTimeMs); IncomingPacket(kBaseSeq, kBaseTimeMs + 1000); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -167,13 +160,13 @@ TEST_F(RemoteEstimatorProxyTest, FeedbackWithMissingStart) { // First feedback. IncomingPacket(kBaseSeq, kBaseTimeMs); IncomingPacket(kBaseSeq + 1, kBaseTimeMs + 1000); - EXPECT_CALL(router_, SendCombinedRtcpPacket).WillOnce(Return(true)); + EXPECT_CALL(feedback_sender_, Call); Process(); // Second feedback starts with a missing packet (DROP kBaseSeq + 2). IncomingPacket(kBaseSeq + 3, kBaseTimeMs + 3000); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -186,7 +179,6 @@ TEST_F(RemoteEstimatorProxyTest, FeedbackWithMissingStart) { ElementsAre(kBaseSeq + 3)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + 3000)); - return true; })); Process(); @@ -197,7 +189,7 @@ TEST_F(RemoteEstimatorProxyTest, SendsFeedbackWithVaryingDeltas) { IncomingPacket(kBaseSeq + 1, kBaseTimeMs + kMaxSmallDeltaMs); IncomingPacket(kBaseSeq + 2, kBaseTimeMs + (2 * kMaxSmallDeltaMs) + 1); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -211,7 +203,6 @@ TEST_F(RemoteEstimatorProxyTest, SendsFeedbackWithVaryingDeltas) { EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs, kBaseTimeMs + kMaxSmallDeltaMs, kBaseTimeMs + (2 * kMaxSmallDeltaMs) + 1)); - return true; })); Process(); @@ -224,7 +215,7 @@ TEST_F(RemoteEstimatorProxyTest, SendsFragmentedFeedback) { IncomingPacket(kBaseSeq, kBaseTimeMs); IncomingPacket(kBaseSeq + 1, kBaseTimeMs + kTooLargeDelta); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -237,7 +228,6 @@ TEST_F(RemoteEstimatorProxyTest, SendsFragmentedFeedback) { ElementsAre(kBaseSeq)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs)); - return true; })) .WillOnce(Invoke( [](std::vector> feedback_packets) { @@ -251,7 +241,6 @@ TEST_F(RemoteEstimatorProxyTest, SendsFragmentedFeedback) { ElementsAre(kBaseSeq + 1)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + kTooLargeDelta)); - return true; })); Process(); @@ -263,7 +252,7 @@ TEST_F(RemoteEstimatorProxyTest, HandlesReorderingAndWrap) { IncomingPacket(kBaseSeq, kBaseTimeMs); IncomingPacket(kLargeSeq, kBaseTimeMs + kDeltaMs); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [&](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -274,7 +263,6 @@ TEST_F(RemoteEstimatorProxyTest, HandlesReorderingAndWrap) { EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + kDeltaMs, kBaseTimeMs)); - return true; })); Process(); @@ -293,7 +281,7 @@ TEST_F(RemoteEstimatorProxyTest, HandlesMalformedSequenceNumbers) { } // Only expect feedback for the last two packets. 
- EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [&](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -306,7 +294,6 @@ TEST_F(RemoteEstimatorProxyTest, HandlesMalformedSequenceNumbers) { EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + 28 * kDeltaMs, kBaseTimeMs + 29 * kDeltaMs)); - return true; })); Process(); @@ -324,7 +311,7 @@ TEST_F(RemoteEstimatorProxyTest, HandlesBackwardsWrappingSequenceNumbers) { } // Only expect feedback for the first two packets. - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [&](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -336,7 +323,6 @@ TEST_F(RemoteEstimatorProxyTest, HandlesBackwardsWrappingSequenceNumbers) { ElementsAre(kBaseSeq + 40000, kBaseSeq)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + kDeltaMs, kBaseTimeMs)); - return true; })); Process(); @@ -346,7 +332,7 @@ TEST_F(RemoteEstimatorProxyTest, ResendsTimestampsOnReordering) { IncomingPacket(kBaseSeq, kBaseTimeMs); IncomingPacket(kBaseSeq + 2, kBaseTimeMs + 2); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -359,14 +345,13 @@ TEST_F(RemoteEstimatorProxyTest, ResendsTimestampsOnReordering) { ElementsAre(kBaseSeq, kBaseSeq + 2)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs, kBaseTimeMs + 2)); - return true; })); Process(); IncomingPacket(kBaseSeq + 1, kBaseTimeMs + 1); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -379,7 +364,6 @@ TEST_F(RemoteEstimatorProxyTest, ResendsTimestampsOnReordering) { ElementsAre(kBaseSeq + 1, kBaseSeq + 2)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + 1, kBaseTimeMs + 2)); - return true; })); Process(); @@ -390,7 +374,7 @@ TEST_F(RemoteEstimatorProxyTest, RemovesTimestampsOutOfScope) { IncomingPacket(kBaseSeq + 2, kBaseTimeMs); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -400,14 +384,13 @@ TEST_F(RemoteEstimatorProxyTest, RemovesTimestampsOutOfScope) { EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs)); - return true; })); Process(); IncomingPacket(kBaseSeq + 3, kTimeoutTimeMs); // kBaseSeq + 2 times out here. 
- EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [&](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -417,7 +400,6 @@ TEST_F(RemoteEstimatorProxyTest, RemovesTimestampsOutOfScope) { EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kTimeoutTimeMs)); - return true; })); Process(); @@ -427,7 +409,7 @@ TEST_F(RemoteEstimatorProxyTest, RemovesTimestampsOutOfScope) { IncomingPacket(kBaseSeq, kBaseTimeMs - 1); IncomingPacket(kBaseSeq + 1, kTimeoutTimeMs - 1); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [&](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -440,7 +422,6 @@ TEST_F(RemoteEstimatorProxyTest, RemovesTimestampsOutOfScope) { EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs - 1, kTimeoutTimeMs - 1, kTimeoutTimeMs)); - return true; })); Process(); @@ -496,7 +477,7 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, TimeUntilNextProcessIsHigh) { TEST_F(RemoteEstimatorProxyOnRequestTest, ProcessDoesNotSendFeedback) { proxy_.SetSendPeriodicFeedback(false); IncomingPacket(kBaseSeq, kBaseTimeMs); - EXPECT_CALL(router_, SendCombinedRtcpPacket).Times(0); + EXPECT_CALL(feedback_sender_, Call).Times(0); Process(); } @@ -506,7 +487,7 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, RequestSinglePacketFeedback) { IncomingPacket(kBaseSeq + 1, kBaseTimeMs + kMaxSmallDeltaMs); IncomingPacket(kBaseSeq + 2, kBaseTimeMs + 2 * kMaxSmallDeltaMs); - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -519,7 +500,6 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, RequestSinglePacketFeedback) { ElementsAre(kBaseSeq + 3)); EXPECT_THAT(TimestampsMs(*feedback_packet), ElementsAre(kBaseTimeMs + 3 * kMaxSmallDeltaMs)); - return true; })); constexpr FeedbackRequest kSinglePacketFeedbackRequest = { @@ -535,7 +515,7 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, RequestLastFivePacketFeedback) { IncomingPacket(kBaseSeq + i, kBaseTimeMs + i * kMaxSmallDeltaMs); } - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -553,7 +533,6 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, RequestLastFivePacketFeedback) { kBaseTimeMs + 8 * kMaxSmallDeltaMs, kBaseTimeMs + 9 * kMaxSmallDeltaMs, kBaseTimeMs + 10 * kMaxSmallDeltaMs)); - return true; })); constexpr FeedbackRequest kFivePacketsFeedbackRequest = { @@ -571,7 +550,7 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, IncomingPacket(kBaseSeq + i, kBaseTimeMs + i * kMaxSmallDeltaMs); } - EXPECT_CALL(router_, SendCombinedRtcpPacket) + EXPECT_CALL(feedback_sender_, Call) .WillOnce(Invoke( [](std::vector> feedback_packets) { rtcp::TransportFeedback* feedback_packet = @@ -586,7 +565,6 @@ TEST_F(RemoteEstimatorProxyOnRequestTest, ElementsAre(kBaseTimeMs + 6 * kMaxSmallDeltaMs, kBaseTimeMs + 8 * kMaxSmallDeltaMs, kBaseTimeMs + 10 * kMaxSmallDeltaMs)); - return true; })); constexpr FeedbackRequest kFivePacketsFeedbackRequest = { @@ -658,13 +636,7 @@ TEST_F(RemoteEstimatorProxyTest, SendTransportFeedbackAndNetworkStateUpdate) { AbsoluteSendTime::MsTo24Bits(kBaseTimeMs - 1))); EXPECT_CALL(network_state_estimator_, GetCurrentEstimate()) .WillOnce(Return(NetworkStateEstimate())); - EXPECT_CALL(router_, SendCombinedRtcpPacket) - .WillOnce( - [](std::vector> 
feedback_packets) { - EXPECT_THAT(feedback_packets, SizeIs(2)); - return true; - }); - + EXPECT_CALL(feedback_sender_, Call(SizeIs(2))); Process(); } diff --git a/modules/remote_bitrate_estimator/tools/bwe_rtp.cc b/modules/remote_bitrate_estimator/tools/bwe_rtp.cc index c0b3a37ba5..403f81fd03 100644 --- a/modules/remote_bitrate_estimator/tools/bwe_rtp.cc +++ b/modules/remote_bitrate_estimator/tools/bwe_rtp.cc @@ -18,10 +18,8 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" -#include "modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h" -#include "modules/remote_bitrate_estimator/remote_bitrate_estimator_single_stream.h" +#include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "test/rtp_file_reader.h" -#include "test/rtp_header_parser.h" ABSL_FLAG(std::string, extension_type, @@ -65,14 +63,11 @@ std::set SsrcFilter() { return ssrcs; } -std::unique_ptr ParseArgsAndSetupEstimator( +bool ParseArgsAndSetupRtpReader( int argc, char** argv, - webrtc::Clock* clock, - webrtc::RemoteBitrateObserver* observer, - std::unique_ptr* rtp_reader, - std::unique_ptr* estimator, - std::string* estimator_used) { + std::unique_ptr& rtp_reader, + webrtc::RtpHeaderExtensionMap& rtp_header_extensions) { absl::ParseCommandLine(argc, argv); std::string filename = InputFile(); @@ -84,16 +79,16 @@ std::unique_ptr ParseArgsAndSetupEstimator( fprintf(stderr, "\n"); if (filename.substr(filename.find_last_of('.')) == ".pcap") { fprintf(stderr, "Opening as pcap\n"); - rtp_reader->reset(webrtc::test::RtpFileReader::Create( + rtp_reader.reset(webrtc::test::RtpFileReader::Create( webrtc::test::RtpFileReader::kPcap, filename.c_str(), SsrcFilter())); } else { fprintf(stderr, "Opening as rtp\n"); - rtp_reader->reset(webrtc::test::RtpFileReader::Create( + rtp_reader.reset(webrtc::test::RtpFileReader::Create( webrtc::test::RtpFileReader::kRtpDump, filename.c_str())); } - if (!*rtp_reader) { + if (!rtp_reader) { fprintf(stderr, "Cannot open input file %s\n", filename.c_str()); - return nullptr; + return false; } fprintf(stderr, "Input file: %s\n\n", filename.c_str()); @@ -105,31 +100,10 @@ std::unique_ptr ParseArgsAndSetupEstimator( fprintf(stderr, "Extension: abs\n"); } else { fprintf(stderr, "Unknown extension type\n"); - return nullptr; + return false; } - // Setup the RTP header parser and the bitrate estimator. 
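The helper no longer constructs an estimator or an RtpHeaderParser; it only registers the chosen extension on an RtpHeaderExtensionMap, and callers parse packets with RtpPacket, as the rtp_to_text.cc changes below show. A standalone sketch of that parsing flow, with an illustrative extension id:

#include <cstdint>

#include "api/array_view.h"
#include "modules/rtp_rtcp/include/rtp_header_extension_map.h"
#include "modules/rtp_rtcp/source/rtp_header_extensions.h"
#include "modules/rtp_rtcp/source/rtp_packet.h"

// Reads the absolute send time from a raw RTP packet, returning false if the
// packet does not parse or the extension is absent.
bool ReadAbsSendTime(rtc::ArrayView<const uint8_t> raw_packet,
                     uint32_t* abs_send_time) {
  webrtc::RtpHeaderExtensionMap extensions;
  extensions.RegisterByType(/*id=*/3, webrtc::kRtpExtensionAbsoluteSendTime);

  webrtc::RtpPacket packet(&extensions);
  if (!packet.Parse(raw_packet.data(), raw_packet.size()))
    return false;
  return packet.GetExtension<webrtc::AbsoluteSendTime>(abs_send_time);
}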
- auto parser = webrtc::RtpHeaderParser::CreateForTest(); - parser->RegisterRtpHeaderExtension(extension, ExtensionId()); - if (estimator) { - switch (extension) { - case webrtc::kRtpExtensionAbsoluteSendTime: { - estimator->reset( - new webrtc::RemoteBitrateEstimatorAbsSendTime(observer, clock)); - *estimator_used = "AbsoluteSendTimeRemoteBitrateEstimator"; - break; - } - case webrtc::kRtpExtensionTransmissionTimeOffset: { - estimator->reset( - new webrtc::RemoteBitrateEstimatorSingleStream(observer, clock)); - *estimator_used = "RemoteBitrateEstimator"; - break; - } - default: - assert(false); - return nullptr; - } - } + rtp_header_extensions.RegisterByType(ExtensionId(), extension); - return parser; + return true; } diff --git a/modules/remote_bitrate_estimator/tools/bwe_rtp.h b/modules/remote_bitrate_estimator/tools/bwe_rtp.h index 4285f926b5..3b161db37b 100644 --- a/modules/remote_bitrate_estimator/tools/bwe_rtp.h +++ b/modules/remote_bitrate_estimator/tools/bwe_rtp.h @@ -12,25 +12,14 @@ #define MODULES_REMOTE_BITRATE_ESTIMATOR_TOOLS_BWE_RTP_H_ #include -#include -namespace webrtc { -class Clock; -class RemoteBitrateEstimator; -class RemoteBitrateObserver; -class RtpHeaderParser; -namespace test { -class RtpFileReader; -} -} // namespace webrtc +#include "modules/rtp_rtcp/include/rtp_header_extension_map.h" +#include "test/rtp_file_reader.h" -std::unique_ptr ParseArgsAndSetupEstimator( +bool ParseArgsAndSetupRtpReader( int argc, char** argv, - webrtc::Clock* clock, - webrtc::RemoteBitrateObserver* observer, - std::unique_ptr* rtp_reader, - std::unique_ptr* estimator, - std::string* estimator_used); + std::unique_ptr& rtp_reader, + webrtc::RtpHeaderExtensionMap& rtp_header_extensions); #endif // MODULES_REMOTE_BITRATE_ESTIMATOR_TOOLS_BWE_RTP_H_ diff --git a/modules/remote_bitrate_estimator/tools/rtp_to_text.cc b/modules/remote_bitrate_estimator/tools/rtp_to_text.cc index 7f1e009793..98f502a42e 100644 --- a/modules/remote_bitrate_estimator/tools/rtp_to_text.cc +++ b/modules/remote_bitrate_estimator/tools/rtp_to_text.cc @@ -13,17 +13,19 @@ #include #include "modules/remote_bitrate_estimator/tools/bwe_rtp.h" +#include "modules/rtp_rtcp/include/rtp_header_extension_map.h" +#include "modules/rtp_rtcp/source/rtp_header_extensions.h" +#include "modules/rtp_rtcp/source/rtp_packet.h" #include "rtc_base/format_macros.h" #include "rtc_base/strings/string_builder.h" #include "test/rtp_file_reader.h" -#include "test/rtp_header_parser.h" int main(int argc, char* argv[]) { std::unique_ptr reader; - std::unique_ptr parser(ParseArgsAndSetupEstimator( - argc, argv, nullptr, nullptr, &reader, nullptr, nullptr)); - if (!parser) + webrtc::RtpHeaderExtensionMap rtp_header_extensions; + if (!ParseArgsAndSetupRtpReader(argc, argv, reader, rtp_header_extensions)) { return -1; + } bool arrival_time_only = (argc >= 5 && strncmp(argv[4], "-t", 2) == 0); @@ -35,11 +37,15 @@ int main(int argc, char* argv[]) { int non_zero_ts_offsets = 0; webrtc::test::RtpPacket packet; while (reader->NextPacket(&packet)) { - webrtc::RTPHeader header; - parser->Parse(packet.data, packet.length, &header); - if (header.extension.absoluteSendTime != 0) + webrtc::RtpPacket header(&rtp_header_extensions); + header.Parse(packet.data, packet.length); + uint32_t abs_send_time = 0; + if (header.GetExtension(&abs_send_time) && + abs_send_time != 0) ++non_zero_abs_send_time; - if (header.extension.transmissionTimeOffset != 0) + int32_t toffset = 0; + if (header.GetExtension(&toffset) && + toffset != 0) ++non_zero_ts_offsets; if 
(arrival_time_only) { rtc::StringBuilder ss; @@ -47,11 +53,9 @@ int main(int argc, char* argv[]) { fprintf(stdout, "%s\n", ss.str().c_str()); } else { fprintf(stdout, "%u %u %d %u %u %d %u %" RTC_PRIuS " %" RTC_PRIuS "\n", - header.sequenceNumber, header.timestamp, - header.extension.transmissionTimeOffset, - header.extension.absoluteSendTime, packet.time_ms, - header.markerBit, header.ssrc, packet.length, - packet.original_length); + header.SequenceNumber(), header.Timestamp(), toffset, + abs_send_time, packet.time_ms, header.Marker(), header.Ssrc(), + packet.length, packet.original_length); } ++packet_counter; } diff --git a/modules/rtp_rtcp/BUILD.gn b/modules/rtp_rtcp/BUILD.gn index e10d8463fd..778baf6e15 100644 --- a/modules/rtp_rtcp/BUILD.gn +++ b/modules/rtp_rtcp/BUILD.gn @@ -52,6 +52,7 @@ rtc_library("rtp_rtcp_format") { "source/rtp_packet.h", "source/rtp_packet_received.h", "source/rtp_packet_to_send.h", + "source/rtp_util.h", "source/rtp_video_layers_allocation_extension.h", ] sources = [ @@ -96,6 +97,7 @@ rtc_library("rtp_rtcp_format") { "source/rtp_packet.cc", "source/rtp_packet_received.cc", "source/rtp_packet_to_send.cc", + "source/rtp_util.cc", "source/rtp_video_layers_allocation_extension.cc", ] @@ -103,21 +105,22 @@ rtc_library("rtp_rtcp_format") { "..:module_api_public", "../../api:array_view", "../../api:function_view", + "../../api:refcountedbase", "../../api:rtp_headers", "../../api:rtp_parameters", + "../../api:scoped_refptr", "../../api/audio_codecs:audio_codecs_api", "../../api/transport:network_control", "../../api/transport/rtp:dependency_descriptor", "../../api/units:time_delta", + "../../api/units:timestamp", "../../api/video:video_frame", "../../api/video:video_layers_allocation", "../../api/video:video_rtp_headers", "../../common_video", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:divide_round", "../../rtc_base:rtc_base_approved", - "../../rtc_base/system:unused", "../../system_wrappers", "../video_coding:codec_globals_headers", ] @@ -136,18 +139,17 @@ rtc_library("rtp_rtcp") { "include/flexfec_sender.h", "include/receive_statistics.h", "include/remote_ntp_time_estimator.h", - "include/rtp_rtcp.h", # deprecated "include/ulpfec_receiver.h", - "source/absolute_capture_time_receiver.cc", - "source/absolute_capture_time_receiver.h", + "source/absolute_capture_time_interpolator.cc", + "source/absolute_capture_time_interpolator.h", "source/absolute_capture_time_sender.cc", "source/absolute_capture_time_sender.h", "source/active_decode_targets_helper.cc", "source/active_decode_targets_helper.h", + "source/capture_clock_offset_updater.cc", + "source/capture_clock_offset_updater.h", "source/create_video_rtp_depacketizer.cc", "source/create_video_rtp_depacketizer.h", - "source/deprecated/deprecated_rtp_sender_egress.cc", - "source/deprecated/deprecated_rtp_sender_egress.h", "source/dtmf_queue.cc", "source/dtmf_queue.h", "source/fec_private_tables_bursty.cc", @@ -164,6 +166,8 @@ rtc_library("rtp_rtcp") { "source/forward_error_correction_internal.h", "source/packet_loss_stats.cc", "source/packet_loss_stats.h", + "source/packet_sequencer.cc", + "source/packet_sequencer.h", "source/receive_statistics_impl.cc", "source/receive_statistics_impl.h", "source/remote_ntp_time_estimator.cc", @@ -192,8 +196,6 @@ rtc_library("rtp_rtcp") { "source/rtp_packetizer_av1.cc", "source/rtp_packetizer_av1.h", "source/rtp_rtcp_config.h", - "source/rtp_rtcp_impl.cc", - "source/rtp_rtcp_impl.h", "source/rtp_rtcp_impl2.cc", "source/rtp_rtcp_impl2.h", 
"source/rtp_rtcp_interface.h", @@ -249,7 +251,6 @@ rtc_library("rtp_rtcp") { deps = [ ":rtp_rtcp_format", ":rtp_video_header", - "..:module_api", "..:module_api_public", "..:module_fec_api", "../../api:array_view", @@ -260,6 +261,7 @@ rtc_library("rtp_rtcp") { "../../api:rtp_packet_info", "../../api:rtp_parameters", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api:transport_api", "../../api/audio_codecs:audio_codecs_api", "../../api/crypto:frame_encryptor_interface", @@ -288,21 +290,21 @@ rtc_library("rtp_rtcp") { "../../logging:rtc_event_rtp_rtcp", "../../modules/audio_coding:audio_coding_module_typedefs", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:divide_round", "../../rtc_base:gtest_prod", "../../rtc_base:rate_limiter", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_numerics", "../../rtc_base:safe_minmax", + "../../rtc_base/containers:flat_map", "../../rtc_base/experiments:field_trial_parser", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/system:no_unique_address", "../../rtc_base/task_utils:pending_task_safety_flag", "../../rtc_base/task_utils:repeating_task", "../../rtc_base/task_utils:to_queued_task", "../../rtc_base/time:timestamp_extrapolator", + "../../rtc_base/containers:flat_map", "../../system_wrappers", "../../system_wrappers:metrics", "../remote_bitrate_estimator", @@ -320,8 +322,36 @@ rtc_library("rtp_rtcp") { } rtc_source_set("rtp_rtcp_legacy") { - # TODO(bugs.webrtc.org/11581): The files "source/rtp_rtcp_impl.cc" - # and "source/rtp_rtcp_impl.h" should be moved to this target. + sources = [ + "include/rtp_rtcp.h", + "source/deprecated/deprecated_rtp_sender_egress.cc", + "source/deprecated/deprecated_rtp_sender_egress.h", + "source/rtp_rtcp_impl.cc", + "source/rtp_rtcp_impl.h", + ] + deps = [ + ":rtp_rtcp", + ":rtp_rtcp_format", + "..:module_api", + "..:module_fec_api", + "../../api:rtp_headers", + "../../api:transport_api", + "../../api/rtc_event_log", + "../../api/transport:field_trial_based_config", + "../../api/units:data_rate", + "../../api/video:video_bitrate_allocation", + "../../logging:rtc_event_rtp_rtcp", + "../../rtc_base:checks", + "../../rtc_base:gtest_prod", + "../../rtc_base:rtc_base_approved", + "../../rtc_base/synchronization:mutex", + "../../system_wrappers", + "../remote_bitrate_estimator", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] } rtc_library("rtcp_transceiver") { @@ -343,6 +373,7 @@ rtc_library("rtcp_transceiver") { "../../api:rtp_headers", "../../api:transport_api", "../../api/task_queue", + "../../api/units:timestamp", "../../api/video:video_bitrate_allocation", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", @@ -352,6 +383,7 @@ rtc_library("rtcp_transceiver") { ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/base:core_headers", "//third_party/abseil-cpp/absl/memory", "//third_party/abseil-cpp/absl/types:optional", ] @@ -411,23 +443,33 @@ rtc_library("mock_rtp_rtcp") { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } +rtc_library("rtp_packetizer_av1_test_helper") { + testonly = true + sources = [ + "source/rtp_packetizer_av1_test_helper.cc", + "source/rtp_packetizer_av1_test_helper.h", + ] +} + if (rtc_include_tests) { - rtc_executable("test_packet_masks_metrics") { - testonly = true + if (!build_with_chromium) { + rtc_executable("test_packet_masks_metrics") { + testonly = 
true - sources = [ - "test/testFec/average_residual_loss_xor_codes.h", - "test/testFec/test_packet_masks_metrics.cc", - ] + sources = [ + "test/testFec/average_residual_loss_xor_codes.h", + "test/testFec/test_packet_masks_metrics.cc", + ] - deps = [ - ":rtp_rtcp", - "../../test:fileutils", - "../../test:test_main", - "../../test:test_support", - "//testing/gtest", - ] - } # test_packet_masks_metrics + deps = [ + ":rtp_rtcp", + "../../test:fileutils", + "../../test:test_main", + "../../test:test_support", + "//testing/gtest", + ] + } # test_packet_masks_metrics + } rtc_library("rtp_rtcp_modules_tests") { testonly = true @@ -446,10 +488,11 @@ if (rtc_include_tests) { testonly = true sources = [ - "source/absolute_capture_time_receiver_unittest.cc", + "source/absolute_capture_time_interpolator_unittest.cc", "source/absolute_capture_time_sender_unittest.cc", "source/active_decode_targets_helper_unittest.cc", "source/byte_io_unittest.cc", + "source/capture_clock_offset_updater_unittest.cc", "source/fec_private_tables_bursty_unittest.cc", "source/flexfec_header_reader_writer_unittest.cc", "source/flexfec_receiver_unittest.cc", @@ -505,9 +548,11 @@ if (rtc_include_tests) { "source/rtp_rtcp_impl2_unittest.cc", "source/rtp_rtcp_impl_unittest.cc", "source/rtp_sender_audio_unittest.cc", + "source/rtp_sender_egress_unittest.cc", "source/rtp_sender_unittest.cc", "source/rtp_sender_video_unittest.cc", "source/rtp_sequence_number_map_unittest.cc", + "source/rtp_util_unittest.cc", "source/rtp_utility_unittest.cc", "source/rtp_video_layers_allocation_extension_unittest.cc", "source/source_tracker_unittest.cc", @@ -526,6 +571,7 @@ if (rtc_include_tests) { ":fec_test_helper", ":mock_rtp_rtcp", ":rtcp_transceiver", + ":rtp_packetizer_av1_test_helper", ":rtp_rtcp", ":rtp_rtcp_format", ":rtp_rtcp_legacy", @@ -540,6 +586,8 @@ if (rtc_include_tests) { "../../api/rtc_event_log", "../../api/transport:field_trial_based_config", "../../api/transport/rtp:dependency_descriptor", + "../../api/units:data_size", + "../../api/units:time_delta", "../../api/units:timestamp", "../../api/video:encoded_image", "../../api/video:video_bitrate_allocation", diff --git a/modules/rtp_rtcp/include/flexfec_receiver.h b/modules/rtp_rtcp/include/flexfec_receiver.h index f9bac9c7fa..b0caea68ff 100644 --- a/modules/rtp_rtcp/include/flexfec_receiver.h +++ b/modules/rtp_rtcp/include/flexfec_receiver.h @@ -15,11 +15,11 @@ #include +#include "api/sequence_checker.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/include/ulpfec_receiver.h" #include "modules/rtp_rtcp/source/forward_error_correction.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" diff --git a/modules/rtp_rtcp/include/receive_statistics.h b/modules/rtp_rtcp/include/receive_statistics.h index 4e6441340c..ce87b99a42 100644 --- a/modules/rtp_rtcp/include/receive_statistics.h +++ b/modules/rtp_rtcp/include/receive_statistics.h @@ -17,11 +17,9 @@ #include "absl/types/optional.h" #include "call/rtp_packet_sink_interface.h" -#include "modules/include/module.h" #include "modules/rtp_rtcp/include/rtcp_statistics.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/source/rtcp_packet/report_block.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -57,7 +55,12 @@ class ReceiveStatistics : public ReceiveStatisticsProvider, public: ~ReceiveStatistics() override = 
default; + // Returns a thread-safe instance of ReceiveStatistics. + // https://chromium.googlesource.com/chromium/src/+/lkgr/docs/threading_and_tasks.md#threading-lexicon static std::unique_ptr Create(Clock* clock); + // Returns a thread-compatible instance of ReceiveStatistics. + static std::unique_ptr CreateThreadCompatible( + Clock* clock); // Returns a pointer to the statistician of an ssrc. virtual StreamStatistician* GetStatistician(uint32_t ssrc) const = 0; diff --git a/modules/rtp_rtcp/include/rtcp_statistics.h b/modules/rtp_rtcp/include/rtcp_statistics.h index e26c475e31..de70c14943 100644 --- a/modules/rtp_rtcp/include/rtcp_statistics.h +++ b/modules/rtp_rtcp/include/rtcp_statistics.h @@ -17,22 +17,6 @@ namespace webrtc { -// Statistics for an RTCP channel -struct RtcpStatistics { - uint8_t fraction_lost = 0; - int32_t packets_lost = 0; // Defined as a 24 bit signed integer in RTCP - uint32_t extended_highest_sequence_number = 0; - uint32_t jitter = 0; -}; - -class RtcpStatisticsCallback { - public: - virtual ~RtcpStatisticsCallback() {} - - virtual void StatisticsUpdated(const RtcpStatistics& statistics, - uint32_t ssrc) = 0; -}; - // Statistics for RTCP packet types. struct RtcpPacketTypeCounter { RtcpPacketTypeCounter() diff --git a/modules/rtp_rtcp/include/rtp_header_extension_map.h b/modules/rtp_rtcp/include/rtp_header_extension_map.h index ff2d34d60d..72e5541d37 100644 --- a/modules/rtp_rtcp/include/rtp_header_extension_map.h +++ b/modules/rtp_rtcp/include/rtp_header_extension_map.h @@ -19,7 +19,6 @@ #include "api/rtp_parameters.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "rtc_base/checks.h" -#include "rtc_base/deprecation.h" namespace webrtc { diff --git a/modules/rtp_rtcp/include/rtp_rtcp.h b/modules/rtp_rtcp/include/rtp_rtcp.h index 8663296eba..727fc6e649 100644 --- a/modules/rtp_rtcp/include/rtp_rtcp.h +++ b/modules/rtp_rtcp/include/rtp_rtcp.h @@ -12,12 +12,10 @@ #define MODULES_RTP_RTCP_INCLUDE_RTP_RTCP_H_ #include -#include -#include +#include "absl/base/attributes.h" #include "modules/include/module.h" #include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -25,52 +23,14 @@ namespace webrtc { class RtpRtcp : public Module, public RtpRtcpInterface { public: // Instantiates a deprecated version of the RtpRtcp module. - static std::unique_ptr RTC_DEPRECATED - Create(const Configuration& configuration) { + static std::unique_ptr ABSL_DEPRECATED("") + Create(const Configuration& configuration) { return DEPRECATED_Create(configuration); } static std::unique_ptr DEPRECATED_Create( const Configuration& configuration); - // (TMMBR) Temporary Max Media Bit Rate - RTC_DEPRECATED virtual bool TMMBR() const = 0; - - RTC_DEPRECATED virtual void SetTMMBRStatus(bool enable) = 0; - - // Returns -1 on failure else 0. - RTC_DEPRECATED virtual int32_t AddMixedCNAME(uint32_t ssrc, - const char* cname) = 0; - - // Returns -1 on failure else 0. - RTC_DEPRECATED virtual int32_t RemoveMixedCNAME(uint32_t ssrc) = 0; - - // Returns remote CName. - // Returns -1 on failure else 0. - RTC_DEPRECATED virtual int32_t RemoteCNAME( - uint32_t remote_ssrc, - char cname[RTCP_CNAME_SIZE]) const = 0; - - // (De)registers RTP header extension type and id. - // Returns -1 on failure else 0. - RTC_DEPRECATED virtual int32_t RegisterSendRtpHeaderExtension( - RTPExtensionType type, - uint8_t id) = 0; - - // (APP) Sets application specific data. - // Returns -1 on failure else 0. 
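A usage sketch for the ReceiveStatistics factory split introduced above, not part of this change: Create() returns a thread-safe instance, while CreateThreadCompatible() assumes all calls are already serialized, for example on a single task queue.

#include <memory>

#include "modules/rtp_rtcp/include/receive_statistics.h"
#include "system_wrappers/include/clock.h"

// Picks the cheaper thread-compatible variant when every OnRtpPacket() and
// GetStatistician() call is guaranteed to come from one sequence.
std::unique_ptr<webrtc::ReceiveStatistics> MakeReceiveStatistics(
    webrtc::Clock* clock, bool single_sequence) {
  if (single_sequence)
    return webrtc::ReceiveStatistics::CreateThreadCompatible(clock);
  return webrtc::ReceiveStatistics::Create(clock);
}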
- RTC_DEPRECATED virtual int32_t SetRTCPApplicationSpecificData( - uint8_t sub_type, - uint32_t name, - const uint8_t* data, - uint16_t length) = 0; - - // Returns statistics of the amount of data sent. - // Returns -1 on failure else 0. - RTC_DEPRECATED virtual int32_t DataCountersRTP( - size_t* bytes_sent, - uint32_t* packets_sent) const = 0; - // Requests new key frame. // using PLI, https://tools.ietf.org/html/rfc4585#section-6.3.1.1 void SendPictureLossIndication() { SendRTCP(kRtcpPli); } diff --git a/modules/rtp_rtcp/include/rtp_rtcp_defines.h b/modules/rtp_rtcp/include/rtp_rtcp_defines.h index cbc2d92111..998a754cc0 100644 --- a/modules/rtp_rtcp/include/rtp_rtcp_defines.h +++ b/modules/rtp_rtcp/include/rtp_rtcp_defines.h @@ -57,6 +57,7 @@ enum RTPExtensionType : int { kRtpExtensionNone, kRtpExtensionTransmissionTimeOffset, kRtpExtensionAudioLevel, + kRtpExtensionCsrcAudioLevel, kRtpExtensionInbandComfortNoise, kRtpExtensionAbsoluteSendTime, kRtpExtensionAbsoluteCaptureTime, @@ -74,6 +75,7 @@ enum RTPExtensionType : int { kRtpExtensionGenericFrameDescriptor = kRtpExtensionGenericFrameDescriptor00, kRtpExtensionGenericFrameDescriptor02, kRtpExtensionColorSpace, + kRtpExtensionVideoFrameTrackingId, kRtpExtensionNumberOfExtensions // Must be the last entity in the enum. }; @@ -226,8 +228,11 @@ struct RtpPacketSendInfo { RtpPacketSendInfo() = default; uint16_t transport_sequence_number = 0; + // TODO(bugs.webrtc.org/12713): Remove once downstream usage is gone. uint32_t ssrc = 0; - uint16_t rtp_sequence_number = 0; + absl::optional media_ssrc; + uint16_t rtp_sequence_number = 0; // Only valid if |media_ssrc| is set. + uint32_t rtp_timestamp = 0; size_t length = 0; absl::optional packet_type; PacedPacketInfo pacing_info; @@ -264,9 +269,13 @@ class RtcpFeedbackSenderInterface { class StreamFeedbackObserver { public: struct StreamPacketInfo { - uint32_t ssrc; - uint16_t rtp_sequence_number; bool received; + + // |rtp_sequence_number| and |is_retransmission| are only valid if |ssrc| + // is populated. 
+ absl::optional ssrc; + uint16_t rtp_sequence_number; + bool is_retransmission; }; virtual ~StreamFeedbackObserver() = default; diff --git a/modules/rtp_rtcp/mocks/mock_rtp_rtcp.h b/modules/rtp_rtcp/mocks/mock_rtp_rtcp.h index 77289c993b..a7707ecc19 100644 --- a/modules/rtp_rtcp/mocks/mock_rtp_rtcp.h +++ b/modules/rtp_rtcp/mocks/mock_rtp_rtcp.h @@ -34,6 +34,7 @@ class MockRtpRtcpInterface : public RtpRtcpInterface { (const uint8_t* incoming_packet, size_t packet_length), (override)); MOCK_METHOD(void, SetRemoteSSRC, (uint32_t ssrc), (override)); + MOCK_METHOD(void, SetLocalSsrc, (uint32_t ssrc), (override)); MOCK_METHOD(void, SetMaxRtpPacketSize, (size_t size), (override)); MOCK_METHOD(size_t, MaxRtpPacketSize, (), (const, override)); MOCK_METHOD(void, @@ -141,14 +142,14 @@ class MockRtpRtcpInterface : public RtpRtcpInterface { GetSendStreamDataCounters, (StreamDataCounters*, StreamDataCounters*), (const, override)); - MOCK_METHOD(int32_t, - RemoteRTCPStat, - (std::vector * receive_blocks), - (const, override)); MOCK_METHOD(std::vector, GetLatestReportBlockData, (), (const, override)); + MOCK_METHOD(absl::optional, + GetSenderReportStats, + (), + (const, override)); MOCK_METHOD(void, SetRemb, (int64_t bitrate, std::vector ssrcs), diff --git a/modules/rtp_rtcp/source/absolute_capture_time_receiver.cc b/modules/rtp_rtcp/source/absolute_capture_time_interpolator.cc similarity index 70% rename from modules/rtp_rtcp/source/absolute_capture_time_receiver.cc rename to modules/rtp_rtcp/source/absolute_capture_time_interpolator.cc index 529ed7eef6..99fc030aca 100644 --- a/modules/rtp_rtcp/source/absolute_capture_time_receiver.cc +++ b/modules/rtp_rtcp/source/absolute_capture_time_interpolator.cc @@ -8,7 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. 
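Returning to the RtpPacketSendInfo change in rtp_rtcp_defines.h above: senders now fill the optional media_ssrc while still populating the legacy ssrc field for downstream users, as the deprecated_rtp_sender_egress.cc hunk further down does. A sketch with an illustrative helper name:

#include <cstdint>

#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h"
#include "modules/rtp_rtcp/source/rtp_packet_to_send.h"

// Builds the send info reported towards congestion control for one packet.
webrtc::RtpPacketSendInfo MakeSendInfo(const webrtc::RtpPacketToSend& packet,
                                       uint16_t transport_sequence_number) {
  webrtc::RtpPacketSendInfo info;
  info.transport_sequence_number = transport_sequence_number;
  info.ssrc = packet.Ssrc();        // Legacy field, kept for downstream users.
  info.media_ssrc = packet.Ssrc();  // New optional field.
  info.rtp_sequence_number = packet.SequenceNumber();
  info.rtp_timestamp = packet.Timestamp();
  info.length = packet.size();
  return info;
}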
*/ -#include "modules/rtp_rtcp/source/absolute_capture_time_receiver.h" +#include "modules/rtp_rtcp/source/absolute_capture_time_interpolator.h" #include @@ -20,14 +20,12 @@ namespace { constexpr Timestamp kInvalidLastReceiveTime = Timestamp::MinusInfinity(); } // namespace -constexpr TimeDelta AbsoluteCaptureTimeReceiver::kInterpolationMaxInterval; +constexpr TimeDelta AbsoluteCaptureTimeInterpolator::kInterpolationMaxInterval; -AbsoluteCaptureTimeReceiver::AbsoluteCaptureTimeReceiver(Clock* clock) - : clock_(clock), - remote_to_local_clock_offset_(absl::nullopt), - last_receive_time_(kInvalidLastReceiveTime) {} +AbsoluteCaptureTimeInterpolator::AbsoluteCaptureTimeInterpolator(Clock* clock) + : clock_(clock), last_receive_time_(kInvalidLastReceiveTime) {} -uint32_t AbsoluteCaptureTimeReceiver::GetSource( +uint32_t AbsoluteCaptureTimeInterpolator::GetSource( uint32_t ssrc, rtc::ArrayView csrcs) { if (csrcs.empty()) { @@ -37,15 +35,8 @@ uint32_t AbsoluteCaptureTimeReceiver::GetSource( return csrcs[0]; } -void AbsoluteCaptureTimeReceiver::SetRemoteToLocalClockOffset( - absl::optional value_q32x32) { - MutexLock lock(&mutex_); - - remote_to_local_clock_offset_ = value_q32x32; -} - absl::optional -AbsoluteCaptureTimeReceiver::OnReceivePacket( +AbsoluteCaptureTimeInterpolator::OnReceivePacket( uint32_t source, uint32_t rtp_timestamp, uint32_t rtp_clock_frequency, @@ -81,13 +72,10 @@ AbsoluteCaptureTimeReceiver::OnReceivePacket( extension = *received_extension; } - extension.estimated_capture_clock_offset = AdjustEstimatedCaptureClockOffset( - extension.estimated_capture_clock_offset); - return extension; } -uint64_t AbsoluteCaptureTimeReceiver::InterpolateAbsoluteCaptureTimestamp( +uint64_t AbsoluteCaptureTimeInterpolator::InterpolateAbsoluteCaptureTimestamp( uint32_t rtp_timestamp, uint32_t rtp_clock_frequency, uint32_t last_rtp_timestamp, @@ -101,7 +89,7 @@ uint64_t AbsoluteCaptureTimeReceiver::InterpolateAbsoluteCaptureTimestamp( rtp_clock_frequency; } -bool AbsoluteCaptureTimeReceiver::ShouldInterpolateExtension( +bool AbsoluteCaptureTimeInterpolator::ShouldInterpolateExtension( Timestamp receive_time, uint32_t source, uint32_t rtp_timestamp, @@ -134,17 +122,4 @@ bool AbsoluteCaptureTimeReceiver::ShouldInterpolateExtension( return true; } -absl::optional -AbsoluteCaptureTimeReceiver::AdjustEstimatedCaptureClockOffset( - absl::optional received_value) const { - if (received_value == absl::nullopt || - remote_to_local_clock_offset_ == absl::nullopt) { - return absl::nullopt; - } - - // Do calculations as "unsigned" to make overflows deterministic. - return static_cast(*received_value) + - static_cast(*remote_to_local_clock_offset_); -} - } // namespace webrtc diff --git a/modules/rtp_rtcp/source/absolute_capture_time_receiver.h b/modules/rtp_rtcp/source/absolute_capture_time_interpolator.h similarity index 70% rename from modules/rtp_rtcp/source/absolute_capture_time_receiver.h rename to modules/rtp_rtcp/source/absolute_capture_time_interpolator.h index ce3442b386..89d7f0850c 100644 --- a/modules/rtp_rtcp/source/absolute_capture_time_receiver.h +++ b/modules/rtp_rtcp/source/absolute_capture_time_interpolator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source @@ -8,8 +8,8 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ -#ifndef MODULES_RTP_RTCP_SOURCE_ABSOLUTE_CAPTURE_TIME_RECEIVER_H_ -#define MODULES_RTP_RTCP_SOURCE_ABSOLUTE_CAPTURE_TIME_RECEIVER_H_ +#ifndef MODULES_RTP_RTCP_SOURCE_ABSOLUTE_CAPTURE_TIME_INTERPOLATOR_H_ +#define MODULES_RTP_RTCP_SOURCE_ABSOLUTE_CAPTURE_TIME_INTERPOLATOR_H_ #include "api/array_view.h" #include "api/rtp_headers.h" @@ -22,7 +22,7 @@ namespace webrtc { // -// Helper class for receiving the |AbsoluteCaptureTime| header extension. +// Helper class for interpolating the |AbsoluteCaptureTime| header extension. // // Supports the "timestamp interpolation" optimization: // A receiver SHOULD memorize the capture system (i.e. CSRC/SSRC), capture @@ -33,25 +33,17 @@ namespace webrtc { // // See: https://webrtc.org/experiments/rtp-hdrext/abs-capture-time/ // -class AbsoluteCaptureTimeReceiver { +class AbsoluteCaptureTimeInterpolator { public: static constexpr TimeDelta kInterpolationMaxInterval = TimeDelta::Millis(5000); - explicit AbsoluteCaptureTimeReceiver(Clock* clock); + explicit AbsoluteCaptureTimeInterpolator(Clock* clock); // Returns the source (i.e. SSRC or CSRC) of the capture system. static uint32_t GetSource(uint32_t ssrc, rtc::ArrayView csrcs); - // Sets the NTP clock offset between the sender system (which may be different - // from the capture system) and the local system. This information is normally - // provided by passing half the value of the Round-Trip Time estimation given - // by RTCP sender reports (see DLSR/DLRR). - // - // Note that the value must be in Q32.32-formatted fixed-point seconds. - void SetRemoteToLocalClockOffset(absl::optional value_q32x32); - // Returns a received header extension, an interpolated header extension, or // |absl::nullopt| if it's not possible to interpolate a header extension. absl::optional OnReceivePacket( @@ -75,16 +67,10 @@ class AbsoluteCaptureTimeReceiver { uint32_t rtp_clock_frequency) const RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - absl::optional AdjustEstimatedCaptureClockOffset( - absl::optional received_value) const - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - Clock* const clock_; Mutex mutex_; - absl::optional remote_to_local_clock_offset_ RTC_GUARDED_BY(mutex_); - Timestamp last_receive_time_ RTC_GUARDED_BY(mutex_); uint32_t last_source_ RTC_GUARDED_BY(mutex_); @@ -93,8 +79,8 @@ class AbsoluteCaptureTimeReceiver { uint64_t last_absolute_capture_timestamp_ RTC_GUARDED_BY(mutex_); absl::optional last_estimated_capture_clock_offset_ RTC_GUARDED_BY(mutex_); -}; // AbsoluteCaptureTimeReceiver +}; // AbsoluteCaptureTimeInterpolator } // namespace webrtc -#endif // MODULES_RTP_RTCP_SOURCE_ABSOLUTE_CAPTURE_TIME_RECEIVER_H_ +#endif // MODULES_RTP_RTCP_SOURCE_ABSOLUTE_CAPTURE_TIME_INTERPOLATOR_H_ diff --git a/modules/rtp_rtcp/source/absolute_capture_time_receiver_unittest.cc b/modules/rtp_rtcp/source/absolute_capture_time_interpolator_unittest.cc similarity index 61% rename from modules/rtp_rtcp/source/absolute_capture_time_receiver_unittest.cc rename to modules/rtp_rtcp/source/absolute_capture_time_interpolator_unittest.cc index ecf256734d..6a312f9b43 100644 --- a/modules/rtp_rtcp/source/absolute_capture_time_receiver_unittest.cc +++ b/modules/rtp_rtcp/source/absolute_capture_time_interpolator_unittest.cc @@ -8,7 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. 
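The interpolation the renamed class performs boils down to converting the wrap-aware RTP timestamp delta into Q32.32 seconds and adding it to the last received absolute capture timestamp. A self-contained illustration of that arithmetic follows; it is a sketch, not the upstream implementation.

#include <cstdint>

// Interpolates an absolute capture timestamp (UQ32.32 seconds) for a packet
// that arrived without the extension. The uint32_t subtraction handles RTP
// timestamp wrap-around; the signed cast allows reordered (earlier) packets.
uint64_t InterpolateCaptureTimestamp(uint32_t rtp_timestamp,
                                     uint32_t rtp_clock_frequency,
                                     uint32_t last_rtp_timestamp,
                                     uint64_t last_absolute_capture_time) {
  int32_t delta_ticks =
      static_cast<int32_t>(rtp_timestamp - last_rtp_timestamp);
  int64_t delta_q32x32 =
      (static_cast<int64_t>(delta_ticks) << 32) / rtp_clock_frequency;
  return last_absolute_capture_time + static_cast<uint64_t>(delta_q32x32);
}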
*/ -#include "modules/rtp_rtcp/source/absolute_capture_time_receiver.h" +#include "modules/rtp_rtcp/source/absolute_capture_time_interpolator.h" #include "system_wrappers/include/ntp_time.h" #include "test/gmock.h" @@ -16,20 +16,21 @@ namespace webrtc { -TEST(AbsoluteCaptureTimeReceiverTest, GetSourceWithoutCsrcs) { +TEST(AbsoluteCaptureTimeInterpolatorTest, GetSourceWithoutCsrcs) { constexpr uint32_t kSsrc = 12; - EXPECT_EQ(AbsoluteCaptureTimeReceiver::GetSource(kSsrc, nullptr), kSsrc); + EXPECT_EQ(AbsoluteCaptureTimeInterpolator::GetSource(kSsrc, nullptr), kSsrc); } -TEST(AbsoluteCaptureTimeReceiverTest, GetSourceWithCsrcs) { +TEST(AbsoluteCaptureTimeInterpolatorTest, GetSourceWithCsrcs) { constexpr uint32_t kSsrc = 12; constexpr uint32_t kCsrcs[] = {34, 56, 78, 90}; - EXPECT_EQ(AbsoluteCaptureTimeReceiver::GetSource(kSsrc, kCsrcs), kCsrcs[0]); + EXPECT_EQ(AbsoluteCaptureTimeInterpolator::GetSource(kSsrc, kCsrcs), + kCsrcs[0]); } -TEST(AbsoluteCaptureTimeReceiverTest, ReceiveExtensionReturnsExtension) { +TEST(AbsoluteCaptureTimeInterpolatorTest, ReceiveExtensionReturnsExtension) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; constexpr uint32_t kRtpTimestamp0 = 1020300000; @@ -40,20 +41,19 @@ TEST(AbsoluteCaptureTimeReceiverTest, ReceiveExtensionReturnsExtension) { AbsoluteCaptureTime{Int64MsToUQ32x32(9020), absl::nullopt}; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - receiver.SetRemoteToLocalClockOffset(0); - - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp1, - kRtpClockFrequency, kExtension1), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp1, + kRtpClockFrequency, kExtension1), kExtension1); } -TEST(AbsoluteCaptureTimeReceiverTest, ReceiveNoExtensionReturnsNoExtension) { +TEST(AbsoluteCaptureTimeInterpolatorTest, + ReceiveNoExtensionReturnsNoExtension) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; constexpr uint32_t kRtpTimestamp0 = 1020300000; @@ -62,20 +62,18 @@ TEST(AbsoluteCaptureTimeReceiverTest, ReceiveNoExtensionReturnsNoExtension) { static const absl::optional kExtension1 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), absl::nullopt); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp1, - kRtpClockFrequency, kExtension1), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp1, + kRtpClockFrequency, kExtension1), absl::nullopt); } -TEST(AbsoluteCaptureTimeReceiverTest, InterpolateLaterPacketArrivingLater) { +TEST(AbsoluteCaptureTimeInterpolatorTest, InterpolateLaterPacketArrivingLater) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; constexpr uint32_t kRtpTimestamp0 = 1020300000; @@ -87,15 +85,13 @@ TEST(AbsoluteCaptureTimeReceiverTest, InterpolateLaterPacketArrivingLater) { static const absl::optional kExtension2 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver 
receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - absl::optional extension = receiver.OnReceivePacket( + absl::optional extension = interpolator.OnReceivePacket( kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), @@ -103,8 +99,8 @@ TEST(AbsoluteCaptureTimeReceiverTest, InterpolateLaterPacketArrivingLater) { EXPECT_EQ(extension->estimated_capture_clock_offset, kExtension0->estimated_capture_clock_offset); - extension = receiver.OnReceivePacket(kSource, kRtpTimestamp2, - kRtpClockFrequency, kExtension2); + extension = interpolator.OnReceivePacket(kSource, kRtpTimestamp2, + kRtpClockFrequency, kExtension2); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), UQ32x32ToInt64Ms(kExtension0->absolute_capture_timestamp) + 40); @@ -112,7 +108,8 @@ TEST(AbsoluteCaptureTimeReceiverTest, InterpolateLaterPacketArrivingLater) { kExtension0->estimated_capture_clock_offset); } -TEST(AbsoluteCaptureTimeReceiverTest, InterpolateEarlierPacketArrivingLater) { +TEST(AbsoluteCaptureTimeInterpolatorTest, + InterpolateEarlierPacketArrivingLater) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; constexpr uint32_t kRtpTimestamp0 = 1020300000; @@ -124,15 +121,13 @@ TEST(AbsoluteCaptureTimeReceiverTest, InterpolateEarlierPacketArrivingLater) { static const absl::optional kExtension2 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - receiver.SetRemoteToLocalClockOffset(0); - - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - absl::optional extension = receiver.OnReceivePacket( + absl::optional extension = interpolator.OnReceivePacket( kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), @@ -140,8 +135,8 @@ TEST(AbsoluteCaptureTimeReceiverTest, InterpolateEarlierPacketArrivingLater) { EXPECT_EQ(extension->estimated_capture_clock_offset, kExtension0->estimated_capture_clock_offset); - extension = receiver.OnReceivePacket(kSource, kRtpTimestamp2, - kRtpClockFrequency, kExtension2); + extension = interpolator.OnReceivePacket(kSource, kRtpTimestamp2, + kRtpClockFrequency, kExtension2); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), UQ32x32ToInt64Ms(kExtension0->absolute_capture_timestamp) - 40); @@ -149,7 +144,7 @@ TEST(AbsoluteCaptureTimeReceiverTest, InterpolateEarlierPacketArrivingLater) { kExtension0->estimated_capture_clock_offset); } -TEST(AbsoluteCaptureTimeReceiverTest, +TEST(AbsoluteCaptureTimeInterpolatorTest, InterpolateLaterPacketArrivingLaterWithRtpTimestampWrapAround) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; @@ -162,15 +157,13 @@ TEST(AbsoluteCaptureTimeReceiverTest, static const absl::optional kExtension2 = absl::nullopt; SimulatedClock clock(0); - 
AbsoluteCaptureTimeReceiver receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - absl::optional extension = receiver.OnReceivePacket( + absl::optional extension = interpolator.OnReceivePacket( kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), @@ -178,8 +171,8 @@ TEST(AbsoluteCaptureTimeReceiverTest, EXPECT_EQ(extension->estimated_capture_clock_offset, kExtension0->estimated_capture_clock_offset); - extension = receiver.OnReceivePacket(kSource, kRtpTimestamp2, - kRtpClockFrequency, kExtension2); + extension = interpolator.OnReceivePacket(kSource, kRtpTimestamp2, + kRtpClockFrequency, kExtension2); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), UQ32x32ToInt64Ms(kExtension0->absolute_capture_timestamp) + 40); @@ -187,7 +180,7 @@ TEST(AbsoluteCaptureTimeReceiverTest, kExtension0->estimated_capture_clock_offset); } -TEST(AbsoluteCaptureTimeReceiverTest, +TEST(AbsoluteCaptureTimeInterpolatorTest, InterpolateEarlierPacketArrivingLaterWithRtpTimestampWrapAround) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; @@ -200,15 +193,13 @@ TEST(AbsoluteCaptureTimeReceiverTest, static const absl::optional kExtension2 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - receiver.SetRemoteToLocalClockOffset(0); - - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - absl::optional extension = receiver.OnReceivePacket( + absl::optional extension = interpolator.OnReceivePacket( kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), @@ -216,8 +207,8 @@ TEST(AbsoluteCaptureTimeReceiverTest, EXPECT_EQ(extension->estimated_capture_clock_offset, kExtension0->estimated_capture_clock_offset); - extension = receiver.OnReceivePacket(kSource, kRtpTimestamp2, - kRtpClockFrequency, kExtension2); + extension = interpolator.OnReceivePacket(kSource, kRtpTimestamp2, + kRtpClockFrequency, kExtension2); EXPECT_TRUE(extension.has_value()); EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), UQ32x32ToInt64Ms(kExtension0->absolute_capture_timestamp) - 40); @@ -225,51 +216,7 @@ TEST(AbsoluteCaptureTimeReceiverTest, kExtension0->estimated_capture_clock_offset); } -TEST(AbsoluteCaptureTimeReceiverTest, - SkipEstimatedCaptureClockOffsetIfRemoteToLocalClockOffsetIsUnknown) { - constexpr uint32_t kSource = 1337; - constexpr uint32_t kRtpClockFrequency = 64000; - constexpr uint32_t kRtpTimestamp0 = 1020300000; - constexpr uint32_t kRtpTimestamp1 = kRtpTimestamp0 + 1280; - constexpr uint32_t kRtpTimestamp2 = kRtpTimestamp0 + 2560; - static const absl::optional kExtension0 = - AbsoluteCaptureTime{Int64MsToUQ32x32(9000), Int64MsToQ32x32(-350)}; - static const absl::optional kExtension1 = absl::nullopt; - static const absl::optional kExtension2 = absl::nullopt; - static const absl::optional 
kRemoteToLocalClockOffset2 = - Int64MsToQ32x32(-7000007); - - SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); - - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), - kExtension0); - - receiver.SetRemoteToLocalClockOffset(absl::nullopt); - - absl::optional extension = receiver.OnReceivePacket( - kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1); - EXPECT_TRUE(extension.has_value()); - EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), - UQ32x32ToInt64Ms(kExtension0->absolute_capture_timestamp) + 20); - EXPECT_EQ(extension->estimated_capture_clock_offset, absl::nullopt); - - receiver.SetRemoteToLocalClockOffset(kRemoteToLocalClockOffset2); - - extension = receiver.OnReceivePacket(kSource, kRtpTimestamp2, - kRtpClockFrequency, kExtension2); - EXPECT_TRUE(extension.has_value()); - EXPECT_EQ(UQ32x32ToInt64Ms(extension->absolute_capture_timestamp), - UQ32x32ToInt64Ms(kExtension0->absolute_capture_timestamp) + 40); - EXPECT_EQ(extension->estimated_capture_clock_offset, - *kExtension0->estimated_capture_clock_offset + - *kRemoteToLocalClockOffset2); -} - -TEST(AbsoluteCaptureTimeReceiverTest, SkipInterpolateIfTooLate) { +TEST(AbsoluteCaptureTimeInterpolatorTest, SkipInterpolateIfTooLate) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 64000; constexpr uint32_t kRtpTimestamp0 = 1020300000; @@ -281,30 +228,28 @@ TEST(AbsoluteCaptureTimeReceiverTest, SkipInterpolateIfTooLate) { static const absl::optional kExtension2 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - clock.AdvanceTime(AbsoluteCaptureTimeReceiver::kInterpolationMaxInterval); + clock.AdvanceTime(AbsoluteCaptureTimeInterpolator::kInterpolationMaxInterval); - EXPECT_TRUE(receiver + EXPECT_TRUE(interpolator .OnReceivePacket(kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1) .has_value()); clock.AdvanceTimeMilliseconds(1); - EXPECT_FALSE(receiver + EXPECT_FALSE(interpolator .OnReceivePacket(kSource, kRtpTimestamp2, kRtpClockFrequency, kExtension2) .has_value()); } -TEST(AbsoluteCaptureTimeReceiverTest, SkipInterpolateIfSourceChanged) { +TEST(AbsoluteCaptureTimeInterpolatorTest, SkipInterpolateIfSourceChanged) { constexpr uint32_t kSource0 = 1337; constexpr uint32_t kSource1 = 1338; constexpr uint32_t kRtpClockFrequency = 64000; @@ -315,21 +260,19 @@ TEST(AbsoluteCaptureTimeReceiverTest, SkipInterpolateIfSourceChanged) { static const absl::optional kExtension1 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - receiver.SetRemoteToLocalClockOffset(0); - - EXPECT_EQ(receiver.OnReceivePacket(kSource0, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource0, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - EXPECT_FALSE(receiver + EXPECT_FALSE(interpolator .OnReceivePacket(kSource1, kRtpTimestamp1, kRtpClockFrequency, kExtension1) .has_value()); } -TEST(AbsoluteCaptureTimeReceiverTest, +TEST(AbsoluteCaptureTimeInterpolatorTest, 
SkipInterpolateIfRtpClockFrequencyChanged) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency0 = 64000; @@ -341,21 +284,19 @@ TEST(AbsoluteCaptureTimeReceiverTest, static const absl::optional kExtension1 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency0, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency0, kExtension0), kExtension0); - EXPECT_FALSE(receiver + EXPECT_FALSE(interpolator .OnReceivePacket(kSource, kRtpTimestamp1, kRtpClockFrequency1, kExtension1) .has_value()); } -TEST(AbsoluteCaptureTimeReceiverTest, +TEST(AbsoluteCaptureTimeInterpolatorTest, SkipInterpolateIfRtpClockFrequencyIsInvalid) { constexpr uint32_t kSource = 1337; constexpr uint32_t kRtpClockFrequency = 0; @@ -366,21 +307,19 @@ TEST(AbsoluteCaptureTimeReceiverTest, static const absl::optional kExtension1 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - receiver.SetRemoteToLocalClockOffset(0); - - EXPECT_EQ(receiver.OnReceivePacket(kSource, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - EXPECT_FALSE(receiver + EXPECT_FALSE(interpolator .OnReceivePacket(kSource, kRtpTimestamp1, kRtpClockFrequency, kExtension1) .has_value()); } -TEST(AbsoluteCaptureTimeReceiverTest, SkipInterpolateIsSticky) { +TEST(AbsoluteCaptureTimeInterpolatorTest, SkipInterpolateIsSticky) { constexpr uint32_t kSource0 = 1337; constexpr uint32_t kSource1 = 1338; constexpr uint32_t kSource2 = 1337; @@ -394,20 +333,18 @@ TEST(AbsoluteCaptureTimeReceiverTest, SkipInterpolateIsSticky) { static const absl::optional kExtension2 = absl::nullopt; SimulatedClock clock(0); - AbsoluteCaptureTimeReceiver receiver(&clock); - - receiver.SetRemoteToLocalClockOffset(0); + AbsoluteCaptureTimeInterpolator interpolator(&clock); - EXPECT_EQ(receiver.OnReceivePacket(kSource0, kRtpTimestamp0, - kRtpClockFrequency, kExtension0), + EXPECT_EQ(interpolator.OnReceivePacket(kSource0, kRtpTimestamp0, + kRtpClockFrequency, kExtension0), kExtension0); - EXPECT_FALSE(receiver + EXPECT_FALSE(interpolator .OnReceivePacket(kSource1, kRtpTimestamp1, kRtpClockFrequency, kExtension1) .has_value()); - EXPECT_FALSE(receiver + EXPECT_FALSE(interpolator .OnReceivePacket(kSource2, kRtpTimestamp2, kRtpClockFrequency, kExtension2) .has_value()); diff --git a/modules/rtp_rtcp/source/absolute_capture_time_sender.cc b/modules/rtp_rtcp/source/absolute_capture_time_sender.cc index 83ba6cac91..28266769ff 100644 --- a/modules/rtp_rtcp/source/absolute_capture_time_sender.cc +++ b/modules/rtp_rtcp/source/absolute_capture_time_sender.cc @@ -12,7 +12,7 @@ #include -#include "modules/rtp_rtcp/source/absolute_capture_time_receiver.h" +#include "modules/rtp_rtcp/source/absolute_capture_time_interpolator.h" #include "system_wrappers/include/ntp_time.h" namespace webrtc { @@ -26,7 +26,7 @@ constexpr TimeDelta AbsoluteCaptureTimeSender::kInterpolationMaxInterval; constexpr TimeDelta AbsoluteCaptureTimeSender::kInterpolationMaxError; static_assert( - AbsoluteCaptureTimeReceiver::kInterpolationMaxInterval >= + AbsoluteCaptureTimeInterpolator::kInterpolationMaxInterval >= 
AbsoluteCaptureTimeSender::kInterpolationMaxInterval, "Receivers should be as willing to interpolate timestamps as senders."); @@ -36,7 +36,7 @@ AbsoluteCaptureTimeSender::AbsoluteCaptureTimeSender(Clock* clock) uint32_t AbsoluteCaptureTimeSender::GetSource( uint32_t ssrc, rtc::ArrayView csrcs) { - return AbsoluteCaptureTimeReceiver::GetSource(ssrc, csrcs); + return AbsoluteCaptureTimeInterpolator::GetSource(ssrc, csrcs); } absl::optional AbsoluteCaptureTimeSender::OnSendPacket( @@ -108,7 +108,7 @@ bool AbsoluteCaptureTimeSender::ShouldSendExtension( // Should if interpolation would introduce too much error. const uint64_t interpolated_absolute_capture_timestamp = - AbsoluteCaptureTimeReceiver::InterpolateAbsoluteCaptureTimestamp( + AbsoluteCaptureTimeInterpolator::InterpolateAbsoluteCaptureTimestamp( rtp_timestamp, rtp_clock_frequency, last_rtp_timestamp_, last_absolute_capture_timestamp_); const int64_t interpolation_error_ms = UQ32x32ToInt64Ms(std::min( diff --git a/modules/rtp_rtcp/source/capture_clock_offset_updater.cc b/modules/rtp_rtcp/source/capture_clock_offset_updater.cc new file mode 100644 index 0000000000..a5b12cb422 --- /dev/null +++ b/modules/rtp_rtcp/source/capture_clock_offset_updater.cc @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/rtp_rtcp/source/capture_clock_offset_updater.h" + +namespace webrtc { + +absl::optional +CaptureClockOffsetUpdater::AdjustEstimatedCaptureClockOffset( + absl::optional remote_capture_clock_offset) const { + if (remote_capture_clock_offset == absl::nullopt || + remote_to_local_clock_offset_ == absl::nullopt) { + return absl::nullopt; + } + + // Do calculations as "unsigned" to make overflows deterministic. + return static_cast(*remote_capture_clock_offset) + + static_cast(*remote_to_local_clock_offset_); +} + +void CaptureClockOffsetUpdater::SetRemoteToLocalClockOffset( + absl::optional offset_q32x32) { + remote_to_local_clock_offset_ = offset_q32x32; +} + +} // namespace webrtc diff --git a/modules/rtp_rtcp/source/capture_clock_offset_updater.h b/modules/rtp_rtcp/source/capture_clock_offset_updater.h new file mode 100644 index 0000000000..71d3eb4831 --- /dev/null +++ b/modules/rtp_rtcp/source/capture_clock_offset_updater.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_RTP_RTCP_SOURCE_CAPTURE_CLOCK_OFFSET_UPDATER_H_ +#define MODULES_RTP_RTCP_SOURCE_CAPTURE_CLOCK_OFFSET_UPDATER_H_ + +#include + +#include "absl/types/optional.h" + +namespace webrtc { + +// +// Helper class for calculating the clock offset against the capturer's clock. 
+// +// This is achieved by adjusting the estimated capture clock offset in received +// Absolute Capture Time RTP header extension (see +// https://webrtc.org/experiments/rtp-hdrext/abs-capture-time/), which +// represents the clock offset between a remote sender and the capturer, by +// adding local-to-remote clock offset. + +class CaptureClockOffsetUpdater { + public: + // Adjusts remote_capture_clock_offset, which originates from Absolute Capture + // Time RTP header extension, to get the local clock offset against the + // capturer's clock. + absl::optional AdjustEstimatedCaptureClockOffset( + absl::optional remote_capture_clock_offset) const; + + // Sets the NTP clock offset between the sender system (which may be different + // from the capture system) and the local system. This information is normally + // provided by passing half the value of the Round-Trip Time estimation given + // by RTCP sender reports (see DLSR/DLRR). + // + // Note that the value must be in Q32.32-formatted fixed-point seconds. + void SetRemoteToLocalClockOffset(absl::optional offset_q32x32); + + private: + absl::optional remote_to_local_clock_offset_; +}; + +} // namespace webrtc + +#endif // MODULES_RTP_RTCP_SOURCE_CAPTURE_CLOCK_OFFSET_UPDATER_H_ diff --git a/modules/rtp_rtcp/source/capture_clock_offset_updater_unittest.cc b/modules/rtp_rtcp/source/capture_clock_offset_updater_unittest.cc new file mode 100644 index 0000000000..43e1dd1379 --- /dev/null +++ b/modules/rtp_rtcp/source/capture_clock_offset_updater_unittest.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
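A minimal usage sketch of the CaptureClockOffsetUpdater introduced above (not part of the patch; the millisecond values are illustrative, and Int64MsToQ32x32() is the Q32.32 conversion helper from system_wrappers/include/ntp_time.h that the new unit tests further below also use):

#include "modules/rtp_rtcp/source/capture_clock_offset_updater.h"
#include "system_wrappers/include/ntp_time.h"

void CaptureClockOffsetExample() {
  webrtc::CaptureClockOffsetUpdater updater;
  // Offset between the remote sender's clock and the local clock, e.g.
  // derived from RTT/2 of RTCP reports (DLSR/DLRR), in Q32.32 seconds.
  updater.SetRemoteToLocalClockOffset(webrtc::Int64MsToQ32x32(40));
  // Offset between the capturer and the remote sender, as carried in the
  // Absolute Capture Time header extension.
  absl::optional<int64_t> remote_offset = webrtc::Int64MsToQ32x32(-350);
  // Result: offset between the capturer and the local clock, still Q32.32.
  absl::optional<int64_t> local_offset =
      updater.AdjustEstimatedCaptureClockOffset(remote_offset);
  // local_offset now holds Int64MsToQ32x32(-350) + Int64MsToQ32x32(40),
  // i.e. roughly -310 ms expressed in Q32.32 seconds.
}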
+ */ + +#include "modules/rtp_rtcp/source/capture_clock_offset_updater.h" + +#include "system_wrappers/include/ntp_time.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { + +TEST(AbsoluteCaptureTimeReceiverTest, + SkipEstimatedCaptureClockOffsetIfRemoteToLocalClockOffsetIsUnknown) { + static const absl::optional kRemoteCaptureClockOffset = + Int64MsToQ32x32(-350); + CaptureClockOffsetUpdater updater; + updater.SetRemoteToLocalClockOffset(absl::nullopt); + EXPECT_EQ( + updater.AdjustEstimatedCaptureClockOffset(kRemoteCaptureClockOffset), + absl::nullopt); +} + +TEST(AbsoluteCaptureTimeReceiverTest, + SkipEstimatedCaptureClockOffsetIfRemoteCaptureClockOffsetIsUnknown) { + static const absl::optional kCaptureClockOffsetNull = absl::nullopt; + CaptureClockOffsetUpdater updater; + updater.SetRemoteToLocalClockOffset(0); + EXPECT_EQ(updater.AdjustEstimatedCaptureClockOffset(kCaptureClockOffsetNull), + kCaptureClockOffsetNull); + + static const absl::optional kRemoteCaptureClockOffset = + Int64MsToQ32x32(-350); + EXPECT_EQ( + updater.AdjustEstimatedCaptureClockOffset(kRemoteCaptureClockOffset), + kRemoteCaptureClockOffset); +} + +TEST(AbsoluteCaptureTimeReceiverTest, EstimatedCaptureClockOffsetArithmetic) { + static const absl::optional kRemoteCaptureClockOffset = + Int64MsToQ32x32(-350); + static const absl::optional kRemoteToLocalClockOffset = + Int64MsToQ32x32(-7000007); + CaptureClockOffsetUpdater updater; + updater.SetRemoteToLocalClockOffset(kRemoteToLocalClockOffset); + EXPECT_THAT( + updater.AdjustEstimatedCaptureClockOffset(kRemoteCaptureClockOffset), + ::testing::Optional(::testing::Eq(*kRemoteCaptureClockOffset + + *kRemoteToLocalClockOffset))); +} + +} // namespace webrtc diff --git a/modules/rtp_rtcp/source/deprecated/deprecated_rtp_sender_egress.cc b/modules/rtp_rtcp/source/deprecated/deprecated_rtp_sender_egress.cc index 6cb9d9330c..c542557526 100644 --- a/modules/rtp_rtcp/source/deprecated/deprecated_rtp_sender_egress.cc +++ b/modules/rtp_rtcp/source/deprecated/deprecated_rtp_sender_egress.cc @@ -176,8 +176,7 @@ void DEPRECATED_RtpSenderEgress::SendPacket( AddPacketToTransportFeedback(*packet_id, *packet, pacing_info); } - options.application_data.assign(packet->application_data().begin(), - packet->application_data().end()); + options.additional_data = packet->additional_data(); if (packet->packet_type() != RtpPacketMediaType::kPadding && packet->packet_type() != RtpPacketMediaType::kRetransmission) { @@ -314,7 +313,9 @@ void DEPRECATED_RtpSenderEgress::AddPacketToTransportFeedback( } RtpPacketSendInfo packet_info; + // TODO(bugs.webrtc.org/12713): Remove once downstream usage is gone. packet_info.ssrc = ssrc_; + packet_info.media_ssrc = ssrc_; packet_info.transport_sequence_number = packet_id; packet_info.rtp_sequence_number = packet.SequenceNumber(); packet_info.length = packet_size; diff --git a/modules/rtp_rtcp/source/fec_test_helper.cc b/modules/rtp_rtcp/source/fec_test_helper.cc index ff736fd5f2..b9ac25e4a8 100644 --- a/modules/rtp_rtcp/source/fec_test_helper.cc +++ b/modules/rtp_rtcp/source/fec_test_helper.cc @@ -184,19 +184,21 @@ UlpfecPacketGenerator::UlpfecPacketGenerator(uint32_t ssrc) RtpPacketReceived UlpfecPacketGenerator::BuildMediaRedPacket( const AugmentedPacket& packet, bool is_recovered) { - RtpPacketReceived red_packet; - // Copy RTP header. + // Create a temporary buffer used to wrap the media packet in RED. 
+ rtc::CopyOnWriteBuffer red_buffer; const size_t kHeaderLength = packet.header.headerLength; - red_packet.Parse(packet.data.cdata(), kHeaderLength); - RTC_DCHECK_EQ(red_packet.headers_size(), kHeaderLength); - uint8_t* rtp_payload = - red_packet.AllocatePayload(packet.data.size() + 1 - kHeaderLength); - // Move payload type into rtp payload. - rtp_payload[0] = red_packet.PayloadType(); + // Append header. + red_buffer.SetData(packet.data.data(), kHeaderLength); + // Find payload type and add it as RED header. + uint8_t media_payload_type = red_buffer[1] & 0x7F; + red_buffer.AppendData({media_payload_type}); + // Append rest of payload/padding. + red_buffer.AppendData( + packet.data.Slice(kHeaderLength, packet.data.size() - kHeaderLength)); + + RtpPacketReceived red_packet; + RTC_CHECK(red_packet.Parse(std::move(red_buffer))); red_packet.SetPayloadType(kRedPayloadType); - // Copy the payload. - memcpy(rtp_payload + 1, packet.data.cdata() + kHeaderLength, - packet.data.size() - kHeaderLength); red_packet.set_recovered(is_recovered); return red_packet; diff --git a/modules/rtp_rtcp/source/flexfec_header_reader_writer.cc b/modules/rtp_rtcp/source/flexfec_header_reader_writer.cc index 8b4162fe2f..40426f16bf 100644 --- a/modules/rtp_rtcp/source/flexfec_header_reader_writer.cc +++ b/modules/rtp_rtcp/source/flexfec_header_reader_writer.cc @@ -25,6 +25,11 @@ namespace { // Maximum number of media packets that can be protected in one batch. constexpr size_t kMaxMediaPackets = 48; // Since we are reusing ULPFEC masks. +// Maximum number of media packets tracked by FEC decoder. +// Maintain a sufficiently larger tracking window than |kMaxMediaPackets| +// to account for packet reordering in pacer/ network. +constexpr size_t kMaxTrackedMediaPackets = 4 * kMaxMediaPackets; + // Maximum number of FEC packets stored inside ForwardErrorCorrection. constexpr size_t kMaxFecPackets = kMaxMediaPackets; @@ -72,7 +77,7 @@ size_t FlexfecHeaderSize(size_t packet_mask_size) { } // namespace FlexfecHeaderReader::FlexfecHeaderReader() - : FecHeaderReader(kMaxMediaPackets, kMaxFecPackets) {} + : FecHeaderReader(kMaxTrackedMediaPackets, kMaxFecPackets) {} FlexfecHeaderReader::~FlexfecHeaderReader() = default; diff --git a/modules/rtp_rtcp/source/flexfec_receiver_unittest.cc b/modules/rtp_rtcp/source/flexfec_receiver_unittest.cc index b9391eeb74..7261280aef 100644 --- a/modules/rtp_rtcp/source/flexfec_receiver_unittest.cc +++ b/modules/rtp_rtcp/source/flexfec_receiver_unittest.cc @@ -374,7 +374,8 @@ TEST_F(FlexfecReceiverTest, RecoversFrom50PercentLoss) { TEST_F(FlexfecReceiverTest, DelayedFecPacketDoesHelp) { // These values need to be updated if the underlying erasure code // implementation changes. - const size_t kNumFrames = 48; + // Delay FEC packet by maximum number of media packets tracked by receiver. + const size_t kNumFrames = 192; const size_t kNumMediaPacketsPerFrame = 1; const size_t kNumFecPackets = 1; @@ -412,14 +413,16 @@ TEST_F(FlexfecReceiverTest, DelayedFecPacketDoesHelp) { TEST_F(FlexfecReceiverTest, TooDelayedFecPacketDoesNotHelp) { // These values need to be updated if the underlying erasure code // implementation changes. - const size_t kNumFrames = 49; + // Delay FEC packet by one more than maximum number of media packets + // tracked by receiver. 
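The rewritten UlpfecPacketGenerator::BuildMediaRedPacket above builds a single-block RED (RFC 2198) packet by copying the RTP header, inserting one RED header octet that carries the original payload type, and appending the original payload. A standalone sketch of that byte layout, assuming plain std::vector storage instead of the WebRTC buffer classes:

#include <cstddef>
#include <cstdint>
#include <vector>

// Wraps an already-serialized RTP packet in single-block RED. |header_length|
// is the size of the RTP header, |red_payload_type| the payload type that
// signals RED to the receiver.
std::vector<uint8_t> WrapInRed(const std::vector<uint8_t>& rtp_packet,
                               size_t header_length,
                               uint8_t red_payload_type) {
  std::vector<uint8_t> red(rtp_packet.begin(),
                           rtp_packet.begin() + header_length);
  // Outer payload type becomes RED; the marker bit (0x80) is preserved.
  red[1] = (red[1] & 0x80) | (red_payload_type & 0x7F);
  // RED header for the final (and only) block: F bit = 0, then the original
  // media payload type.
  red.push_back(rtp_packet[1] & 0x7F);
  // Original payload and padding follow unchanged.
  red.insert(red.end(), rtp_packet.begin() + header_length, rtp_packet.end());
  return red;
}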
+ const size_t kNumFrames = 193; const size_t kNumMediaPacketsPerFrame = 1; const size_t kNumFecPackets = 1; PacketList media_packets; PacketizeFrame(kNumMediaPacketsPerFrame, 0, &media_packets); PacketizeFrame(kNumMediaPacketsPerFrame, 1, &media_packets); - // Protect two first frames. + // Protect first two frames. std::list fec_packets = EncodeFec(media_packets, kNumFecPackets); for (size_t i = 2; i < kNumFrames; ++i) { PacketizeFrame(kNumMediaPacketsPerFrame, i, &media_packets); @@ -646,4 +649,58 @@ TEST_F(FlexfecReceiverTest, CalculatesNumberOfPackets) { EXPECT_EQ(1U, packet_counter.num_recovered_packets); } +TEST_F(FlexfecReceiverTest, DoesNotDecodeWrappedMediaSequenceUsingOldFec) { + const size_t kFirstFrameNumMediaPackets = 2; + const size_t kFirstFrameNumFecPackets = 1; + + PacketList media_packets; + PacketizeFrame(kFirstFrameNumMediaPackets, 0, &media_packets); + + // Protect first frame (sequences 0 and 1) with 1 FEC packet. + std::list fec_packets = + EncodeFec(media_packets, kFirstFrameNumFecPackets); + + // Generate enough media packets to simulate media sequence number wraparound. + // Use no FEC for these frames to make sure old FEC is not purged due to age. + const size_t kNumFramesSequenceWrapAround = + std::numeric_limits::max(); + const size_t kNumMediaPacketsPerFrame = 1; + + for (size_t i = 1; i <= kNumFramesSequenceWrapAround; ++i) { + PacketizeFrame(kNumMediaPacketsPerFrame, i, &media_packets); + } + + // Receive first (|kFirstFrameNumMediaPackets| + 192) media packets. + // Simulate an old FEC packet by separating it from its encoded media + // packets by at least 192 packets. + auto media_it = media_packets.begin(); + for (size_t i = 0; i < (kFirstFrameNumMediaPackets + 192); i++) { + if (i == 1) { + // Drop the second packet of the first frame. + media_it++; + } else { + receiver_.OnRtpPacket(ParsePacket(**media_it++)); + } + } + + // Receive FEC packet. Although a protected packet was dropped, + // expect no recovery callback since it is delayed from first frame + // by more than 192 packets. + auto fec_it = fec_packets.begin(); + std::unique_ptr fec_packet_with_rtp_header = + packet_generator_.BuildFlexfecPacket(**fec_it); + receiver_.OnRtpPacket(ParsePacket(*fec_packet_with_rtp_header)); + + // Receive remaining media packets. + // NOTE: Because we sent enough to simulate wrap around, sequence 0 is + // received again, but is a different packet than the original first + // packet of first frame. + while (media_it != media_packets.end()) { + receiver_.OnRtpPacket(ParsePacket(**media_it++)); + } + + // Do not expect a recovery callback, the FEC packet is old + // and should not decode wrapped around media sequences. +} + } // namespace webrtc diff --git a/modules/rtp_rtcp/source/forward_error_correction.cc b/modules/rtp_rtcp/source/forward_error_correction.cc index 56eabc8a7f..da8025d3db 100644 --- a/modules/rtp_rtcp/source/forward_error_correction.cc +++ b/modules/rtp_rtcp/source/forward_error_correction.cc @@ -31,6 +31,8 @@ namespace webrtc { namespace { // Transport header size in bytes. Assume UDP/IPv4 as a reasonable minimum. constexpr size_t kTransportOverhead = 28; + +constexpr uint16_t kOldSequenceThreshold = 0x3fff; } // namespace ForwardErrorCorrection::Packet::Packet() : data(0), ref_count_(0) {} @@ -508,9 +510,6 @@ void ForwardErrorCorrection::InsertPacket( // This is important for keeping |received_fec_packets_| sorted, and may // also reduce the possibility of incorrect decoding due to sequence number // wrap-around. 
- // TODO(marpan/holmer): We should be able to improve detection/discarding of - // old FEC packets based on timestamp information or better sequence number - // thresholding (e.g., to distinguish between wrap-around and reordering). if (!received_fec_packets_.empty() && received_packet.ssrc == received_fec_packets_.front()->ssrc) { // It only makes sense to detect wrap-around when |received_packet| @@ -521,7 +520,7 @@ void ForwardErrorCorrection::InsertPacket( auto it = received_fec_packets_.begin(); while (it != received_fec_packets_.end()) { uint16_t seq_num_diff = MinDiff(received_packet.seq_num, (*it)->seq_num); - if (seq_num_diff > 0x3fff) { + if (seq_num_diff > kOldSequenceThreshold) { it = received_fec_packets_.erase(it); } else { // No need to keep iterating, since |received_fec_packets_| is sorted. @@ -698,9 +697,10 @@ void ForwardErrorCorrection::AttemptRecovery( // this may allow additional packets to be recovered. // Restart for first FEC packet. fec_packet_it = received_fec_packets_.begin(); - } else if (packets_missing == 0) { - // Either all protected packets arrived or have been recovered. We can - // discard this FEC packet. + } else if (packets_missing == 0 || + IsOldFecPacket(**fec_packet_it, recovered_packets)) { + // Either all protected packets arrived or have been recovered, or the FEC + // packet is old. We can discard this FEC packet. fec_packet_it = received_fec_packets_.erase(fec_packet_it); } else { fec_packet_it++; @@ -731,6 +731,23 @@ void ForwardErrorCorrection::DiscardOldRecoveredPackets( RTC_DCHECK_LE(recovered_packets->size(), max_media_packets); } +bool ForwardErrorCorrection::IsOldFecPacket( + const ReceivedFecPacket& fec_packet, + const RecoveredPacketList* recovered_packets) { + if (recovered_packets->empty()) { + return false; + } + + const uint16_t back_recovered_seq_num = recovered_packets->back()->seq_num; + const uint16_t last_protected_seq_num = + fec_packet.protected_packets.back()->seq_num; + + // FEC packet is old if its last protected sequence number is much + // older than the latest protected sequence number received. + return (MinDiff(back_recovered_seq_num, last_protected_seq_num) > + kOldSequenceThreshold); +} + uint16_t ForwardErrorCorrection::ParseSequenceNumber(const uint8_t* packet) { return (packet[2] << 8) + packet[3]; } diff --git a/modules/rtp_rtcp/source/forward_error_correction.h b/modules/rtp_rtcp/source/forward_error_correction.h index 0c54ad984c..b97693d01f 100644 --- a/modules/rtp_rtcp/source/forward_error_correction.h +++ b/modules/rtp_rtcp/source/forward_error_correction.h @@ -330,6 +330,11 @@ class ForwardErrorCorrection { // for recovering lost packets. void DiscardOldRecoveredPackets(RecoveredPacketList* recovered_packets); + // Checks if the FEC packet is old enough and no longer relevant for + // recovering lost media packets. + bool IsOldFecPacket(const ReceivedFecPacket& fec_packet, + const RecoveredPacketList* recovered_packets); + // These SSRCs are only used by the decoder. const uint32_t ssrc_; const uint32_t protected_media_ssrc_; diff --git a/modules/rtp_rtcp/source/nack_rtx_unittest.cc b/modules/rtp_rtcp/source/nack_rtx_unittest.cc index 8afaf3ee61..fc035047b0 100644 --- a/modules/rtp_rtcp/source/nack_rtx_unittest.cc +++ b/modules/rtp_rtcp/source/nack_rtx_unittest.cc @@ -218,7 +218,6 @@ class RtpRtcpRtxNackTest : public ::testing::Test { if (length > 0) rtp_rtcp_module_->SendNACK(nack_list, length); fake_clock.AdvanceTimeMilliseconds(28); // 33ms - 5ms delay. 
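Both the enlarged FlexFEC tracking window and the new IsOldFecPacket() helper above rest on the same rule: take the minimal circular distance between two 16-bit sequence numbers and treat anything beyond 0x3fff (roughly 16k packets) as stale, since such a gap is more plausibly a wrap-around than reordering. A self-contained sketch of that predicate, with a local MinDiff standing in for the helper from modules/include/module_common_types_public.h:

#include <algorithm>
#include <cstdint>

// Minimal circular distance between two 16-bit sequence numbers.
uint16_t MinDiff(uint16_t a, uint16_t b) {
  const uint16_t forward = static_cast<uint16_t>(b - a);
  const uint16_t backward = static_cast<uint16_t>(a - b);
  return std::min(forward, backward);
}

constexpr uint16_t kOldSequenceThreshold = 0x3fff;

// Returns true if a FEC packet whose newest protected media sequence number
// is |last_protected_seq| is still useful, given that |newest_recovered_seq|
// is the newest media sequence number received or recovered so far. Beyond
// the threshold the FEC packet is discarded so it cannot wrongly "recover"
// packets after a sequence number wrap-around.
bool FecPacketStillRelevant(uint16_t newest_recovered_seq,
                            uint16_t last_protected_seq) {
  return MinDiff(newest_recovered_seq, last_protected_seq) <=
         kOldSequenceThreshold;
}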
- rtp_rtcp_module_->Process(); // Prepare next frame. timestamp += 3000; } @@ -265,7 +264,6 @@ TEST_F(RtpRtcpRtxNackTest, LongNackList) { // Prepare next frame. timestamp += 3000; fake_clock.AdvanceTimeMilliseconds(33); - rtp_rtcp_module_->Process(); } EXPECT_FALSE(transport_.expected_sequence_numbers_.empty()); EXPECT_FALSE(media_stream_.sequence_numbers_.empty()); diff --git a/modules/rtp_rtcp/source/packet_sequencer.cc b/modules/rtp_rtcp/source/packet_sequencer.cc new file mode 100644 index 0000000000..03ea9b8154 --- /dev/null +++ b/modules/rtp_rtcp/source/packet_sequencer.cc @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/rtp_rtcp/source/packet_sequencer.h" + +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { +// RED header is first byte of payload, if present. +constexpr size_t kRedForFecHeaderLength = 1; + +// Timestamps use a 90kHz clock. +constexpr uint32_t kTimestampTicksPerMs = 90; +} // namespace + +PacketSequencer::PacketSequencer(uint32_t media_ssrc, + uint32_t rtx_ssrc, + bool require_marker_before_media_padding, + Clock* clock) + : media_ssrc_(media_ssrc), + rtx_ssrc_(rtx_ssrc), + require_marker_before_media_padding_(require_marker_before_media_padding), + clock_(clock), + media_sequence_number_(0), + rtx_sequence_number_(0), + last_payload_type_(-1), + last_rtp_timestamp_(0), + last_capture_time_ms_(0), + last_timestamp_time_ms_(0), + last_packet_marker_bit_(false) {} + +bool PacketSequencer::Sequence(RtpPacketToSend& packet) { + if (packet.packet_type() == RtpPacketMediaType::kPadding && + !PopulatePaddingFields(packet)) { + // This padding packet can't be sent with current state, return without + // updating the sequence number. + return false; + } + + if (packet.Ssrc() == media_ssrc_) { + packet.SetSequenceNumber(media_sequence_number_++); + if (packet.packet_type() != RtpPacketMediaType::kPadding) { + UpdateLastPacketState(packet); + } + return true; + } + + RTC_DCHECK_EQ(packet.Ssrc(), rtx_ssrc_); + packet.SetSequenceNumber(rtx_sequence_number_++); + return true; +} + +void PacketSequencer::SetRtpState(const RtpState& state) { + media_sequence_number_ = state.sequence_number; + last_rtp_timestamp_ = state.timestamp; + last_capture_time_ms_ = state.capture_time_ms; + last_timestamp_time_ms_ = state.last_timestamp_time_ms; +} + +void PacketSequencer::PupulateRtpState(RtpState& state) const { + state.sequence_number = media_sequence_number_; + state.timestamp = last_rtp_timestamp_; + state.capture_time_ms = last_capture_time_ms_; + state.last_timestamp_time_ms = last_timestamp_time_ms_; +} + +void PacketSequencer::UpdateLastPacketState(const RtpPacketToSend& packet) { + // Remember marker bit to determine if padding can be inserted with + // sequence number following |packet|. + last_packet_marker_bit_ = packet.Marker(); + // Remember media payload type to use in the padding packet if rtx is + // disabled. 
+ if (packet.is_red()) { + RTC_DCHECK_GE(packet.payload_size(), kRedForFecHeaderLength); + last_payload_type_ = packet.PayloadBuffer()[0]; + } else { + last_payload_type_ = packet.PayloadType(); + } + // Save timestamps to generate timestamp field and extensions for the padding. + last_rtp_timestamp_ = packet.Timestamp(); + last_timestamp_time_ms_ = clock_->TimeInMilliseconds(); + last_capture_time_ms_ = packet.capture_time_ms(); +} + +bool PacketSequencer::PopulatePaddingFields(RtpPacketToSend& packet) { + if (packet.Ssrc() == media_ssrc_) { + if (last_payload_type_ == -1) { + return false; + } + + // Without RTX we can't send padding in the middle of frames. + // For audio marker bits doesn't mark the end of a frame and frames + // are usually a single packet, so for now we don't apply this rule + // for audio. + if (require_marker_before_media_padding_ && !last_packet_marker_bit_) { + return false; + } + + packet.SetTimestamp(last_rtp_timestamp_); + packet.set_capture_time_ms(last_capture_time_ms_); + packet.SetPayloadType(last_payload_type_); + return true; + } + + RTC_DCHECK_EQ(packet.Ssrc(), rtx_ssrc_); + if (packet.payload_size() > 0) { + // This is payload padding packet, don't update timestamp fields. + return true; + } + + packet.SetTimestamp(last_rtp_timestamp_); + packet.set_capture_time_ms(last_capture_time_ms_); + + // Only change the timestamp of padding packets sent over RTX. + // Padding only packets over RTP has to be sent as part of a media + // frame (and therefore the same timestamp). + int64_t now_ms = clock_->TimeInMilliseconds(); + if (last_timestamp_time_ms_ > 0) { + packet.SetTimestamp(packet.Timestamp() + + (now_ms - last_timestamp_time_ms_) * + kTimestampTicksPerMs); + if (packet.capture_time_ms() > 0) { + packet.set_capture_time_ms(packet.capture_time_ms() + + (now_ms - last_timestamp_time_ms_)); + } + } + + return true; +} + +} // namespace webrtc diff --git a/modules/rtp_rtcp/source/packet_sequencer.h b/modules/rtp_rtcp/source/packet_sequencer.h new file mode 100644 index 0000000000..67255164f3 --- /dev/null +++ b/modules/rtp_rtcp/source/packet_sequencer.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_RTP_RTCP_SOURCE_PACKET_SEQUENCER_H_ +#define MODULES_RTP_RTCP_SOURCE_PACKET_SEQUENCER_H_ + +#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "modules/rtp_rtcp/source/rtp_packet_to_send.h" +#include "system_wrappers/include/clock.h" + +namespace webrtc { + +// Helper class used to assign RTP sequence numbers and populate some fields for +// padding packets based on the last sequenced packets. +// This class is not thread safe, the caller must provide that. +class PacketSequencer { + public: + // If |require_marker_before_media_padding_| is true, padding packets on the + // media ssrc is not allowed unless the last sequenced media packet had the + // marker bit set (i.e. don't insert padding packets between the first and + // last packets of a video frame). 
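For padding sent on the RTX SSRC, PopulatePaddingFields() above advances the RTP timestamp by the wall-clock time elapsed since the last media packet, using the fixed 90 kHz RTP video clock. A small worked sketch of that arithmetic (illustrative values, not part of the patch):

#include <cstdint>

constexpr uint32_t kTimestampTicksPerMs = 90;  // 90 kHz RTP clock.

uint32_t RtxPaddingTimestamp(uint32_t last_rtp_timestamp,
                             int64_t last_timestamp_time_ms,
                             int64_t now_ms) {
  // Unsigned 32-bit arithmetic, so timestamp wrap-around is well defined.
  return last_rtp_timestamp +
         static_cast<uint32_t>((now_ms - last_timestamp_time_ms) *
                               kTimestampTicksPerMs);
}

// Example: the last media packet had RTP timestamp 3000 and was sequenced at
// t = 100 ms. RTX-only padding sent at t = 110 ms carries timestamp
// 3000 + (110 - 100) * 90 = 3900, while payload padding keeps the original
// timestamp untouched.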
+ PacketSequencer(uint32_t media_ssrc, + uint32_t rtx_ssrc, + bool require_marker_before_media_padding, + Clock* clock); + + // Assigns sequence number, and in the case of non-RTX padding also timestamps + // and payload type. + // Returns false if sequencing failed, which it can do for instance if the + // packet to squence is padding on the media ssrc, but the media is mid frame + // (the last marker bit is false). + bool Sequence(RtpPacketToSend& packet); + + void set_media_sequence_number(uint16_t sequence_number) { + media_sequence_number_ = sequence_number; + } + void set_rtx_sequence_number(uint16_t sequence_number) { + rtx_sequence_number_ = sequence_number; + } + + void SetRtpState(const RtpState& state); + void PupulateRtpState(RtpState& state) const; + + uint16_t media_sequence_number() const { return media_sequence_number_; } + uint16_t rtx_sequence_number() const { return rtx_sequence_number_; } + + private: + void UpdateLastPacketState(const RtpPacketToSend& packet); + bool PopulatePaddingFields(RtpPacketToSend& packet); + + const uint32_t media_ssrc_; + const uint32_t rtx_ssrc_; + const bool require_marker_before_media_padding_; + Clock* const clock_; + + uint16_t media_sequence_number_; + uint16_t rtx_sequence_number_; + + int8_t last_payload_type_; + uint32_t last_rtp_timestamp_; + int64_t last_capture_time_ms_; + int64_t last_timestamp_time_ms_; + bool last_packet_marker_bit_; +}; + +} // namespace webrtc + +#endif // MODULES_RTP_RTCP_SOURCE_PACKET_SEQUENCER_H_ diff --git a/modules/rtp_rtcp/source/receive_statistics_impl.cc b/modules/rtp_rtcp/source/receive_statistics_impl.cc index 6ec41a1eb0..f5c3eafbf3 100644 --- a/modules/rtp_rtcp/source/receive_statistics_impl.cc +++ b/modules/rtp_rtcp/source/receive_statistics_impl.cc @@ -13,9 +13,11 @@ #include #include #include +#include #include #include "modules/remote_bitrate_estimator/test/bwe_test_logging.h" +#include "modules/rtp_rtcp/source/rtcp_packet/report_block.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" #include "modules/rtp_rtcp/source/time_util.h" @@ -23,9 +25,14 @@ #include "system_wrappers/include/clock.h" namespace webrtc { +namespace { +constexpr int64_t kStatisticsTimeoutMs = 8000; +constexpr int64_t kStatisticsProcessIntervalMs = 1000; -const int64_t kStatisticsTimeoutMs = 8000; -const int64_t kStatisticsProcessIntervalMs = 1000; +// Number of seconds since 1900 January 1 00:00 GMT (see +// https://tools.ietf.org/html/rfc868). 
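A minimal usage sketch of the PacketSequencer declared above (not part of the patch; the SSRC values are made up, and the RtpPacketToSend setters and SimulatedClock are assumed to be available from the existing RTP utilities):

#include "modules/rtp_rtcp/source/packet_sequencer.h"
#include "modules/rtp_rtcp/source/rtp_packet_to_send.h"
#include "system_wrappers/include/clock.h"

void SequencerExample() {
  webrtc::SimulatedClock clock(123456);
  webrtc::PacketSequencer sequencer(
      /*media_ssrc=*/0x12345678, /*rtx_ssrc=*/0x23456789,
      /*require_marker_before_media_padding=*/true, &clock);

  // A media packet on the media SSRC gets the next media sequence number and
  // its marker bit, timestamp and payload type are remembered.
  webrtc::RtpPacketToSend media(/*extensions=*/nullptr);
  media.SetSsrc(0x12345678);
  media.set_packet_type(webrtc::RtpPacketMediaType::kVideo);
  media.SetMarker(true);  // End of frame.
  sequencer.Sequence(media);

  // Padding on the media SSRC is only sequenced because the previous media
  // packet carried the marker bit; it inherits that packet's timestamp and
  // payload type.
  webrtc::RtpPacketToSend padding(/*extensions=*/nullptr);
  padding.SetSsrc(0x12345678);
  padding.set_packet_type(webrtc::RtpPacketMediaType::kPadding);
  bool sequenced = sequencer.Sequence(padding);
  (void)sequenced;  // true here; would be false mid-frame.
}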
+constexpr int64_t kNtpJan1970Millisecs = 2'208'988'800'000; +} // namespace StreamStatistician::~StreamStatistician() {} @@ -34,10 +41,14 @@ StreamStatisticianImpl::StreamStatisticianImpl(uint32_t ssrc, int max_reordering_threshold) : ssrc_(ssrc), clock_(clock), + delta_internal_unix_epoch_ms_(clock_->CurrentNtpInMilliseconds() - + clock_->TimeInMilliseconds() - + kNtpJan1970Millisecs), incoming_bitrate_(kStatisticsProcessIntervalMs, RateStatistics::kBpsScale), max_reordering_threshold_(max_reordering_threshold), enable_retransmit_detection_(false), + cumulative_loss_is_capped_(false), jitter_q4_(0), cumulative_loss_(0), cumulative_loss_rtcp_offset_(0), @@ -100,7 +111,6 @@ bool StreamStatisticianImpl::UpdateOutOfOrder(const RtpPacketReceived& packet, } void StreamStatisticianImpl::UpdateCounters(const RtpPacketReceived& packet) { - MutexLock lock(&stream_lock_); RTC_DCHECK_EQ(ssrc_, packet.Ssrc()); int64_t now_ms = clock_->TimeInMilliseconds(); @@ -159,47 +169,42 @@ void StreamStatisticianImpl::UpdateJitter(const RtpPacketReceived& packet, void StreamStatisticianImpl::SetMaxReorderingThreshold( int max_reordering_threshold) { - MutexLock lock(&stream_lock_); max_reordering_threshold_ = max_reordering_threshold; } void StreamStatisticianImpl::EnableRetransmitDetection(bool enable) { - MutexLock lock(&stream_lock_); enable_retransmit_detection_ = enable; } RtpReceiveStats StreamStatisticianImpl::GetStats() const { - MutexLock lock(&stream_lock_); RtpReceiveStats stats; stats.packets_lost = cumulative_loss_; // TODO(nisse): Can we return a float instead? // Note: internal jitter value is in Q4 and needs to be scaled by 1/16. stats.jitter = jitter_q4_ >> 4; - stats.last_packet_received_timestamp_ms = - receive_counters_.last_packet_received_timestamp_ms; + if (receive_counters_.last_packet_received_timestamp_ms.has_value()) { + stats.last_packet_received_timestamp_ms = + *receive_counters_.last_packet_received_timestamp_ms + + delta_internal_unix_epoch_ms_; + } stats.packet_counter = receive_counters_.transmitted; return stats; } -bool StreamStatisticianImpl::GetActiveStatisticsAndReset( - RtcpStatistics* statistics) { - MutexLock lock(&stream_lock_); - if (clock_->TimeInMilliseconds() - last_receive_time_ms_ >= - kStatisticsTimeoutMs) { +void StreamStatisticianImpl::MaybeAppendReportBlockAndReset( + std::vector& report_blocks) { + int64_t now_ms = clock_->TimeInMilliseconds(); + if (now_ms - last_receive_time_ms_ >= kStatisticsTimeoutMs) { // Not active. - return false; + return; } if (!ReceivedRtpPacket()) { - return false; + return; } - *statistics = CalculateRtcpStatistics(); - - return true; -} - -RtcpStatistics StreamStatisticianImpl::CalculateRtcpStatistics() { - RtcpStatistics stats; + report_blocks.emplace_back(); + rtcp::ReportBlock& stats = report_blocks.back(); + stats.SetMediaSsrc(ssrc_); // Calculate fraction lost. int64_t exp_since_last = received_seq_max_ - last_report_seq_max_; RTC_DCHECK_GE(exp_since_last, 0); @@ -207,41 +212,42 @@ RtcpStatistics StreamStatisticianImpl::CalculateRtcpStatistics() { int32_t lost_since_last = cumulative_loss_ - last_report_cumulative_loss_; if (exp_since_last > 0 && lost_since_last > 0) { // Scale 0 to 255, where 255 is 100% loss. - stats.fraction_lost = - static_cast(255 * lost_since_last / exp_since_last); - } else { - stats.fraction_lost = 0; + stats.SetFractionLost(255 * lost_since_last / exp_since_last); } - // TODO(danilchap): Ensure |stats.packets_lost| is clamped to fit in a signed - // 24-bit value. 
- stats.packets_lost = cumulative_loss_ + cumulative_loss_rtcp_offset_; - if (stats.packets_lost < 0) { + int packets_lost = cumulative_loss_ + cumulative_loss_rtcp_offset_; + if (packets_lost < 0) { // Clamp to zero. Work around to accomodate for senders that misbehave with // negative cumulative loss. - stats.packets_lost = 0; + packets_lost = 0; cumulative_loss_rtcp_offset_ = -cumulative_loss_; } - stats.extended_highest_sequence_number = - static_cast(received_seq_max_); + if (packets_lost > 0x7fffff) { + // Packets lost is a 24 bit signed field, and thus should be clamped, as + // described in https://datatracker.ietf.org/doc/html/rfc3550#appendix-A.3 + if (!cumulative_loss_is_capped_) { + cumulative_loss_is_capped_ = true; + RTC_LOG(LS_WARNING) << "Cumulative loss reached maximum value for ssrc " + << ssrc_; + } + packets_lost = 0x7fffff; + } + stats.SetCumulativeLost(packets_lost); + stats.SetExtHighestSeqNum(received_seq_max_); // Note: internal jitter value is in Q4 and needs to be scaled by 1/16. - stats.jitter = jitter_q4_ >> 4; + stats.SetJitter(jitter_q4_ >> 4); // Only for report blocks in RTCP SR and RR. last_report_cumulative_loss_ = cumulative_loss_; last_report_seq_max_ = received_seq_max_; - BWE_TEST_LOGGING_PLOT_WITH_SSRC(1, "cumulative_loss_pkts", - clock_->TimeInMilliseconds(), + BWE_TEST_LOGGING_PLOT_WITH_SSRC(1, "cumulative_loss_pkts", now_ms, cumulative_loss_, ssrc_); - BWE_TEST_LOGGING_PLOT_WITH_SSRC( - 1, "received_seq_max_pkts", clock_->TimeInMilliseconds(), - (received_seq_max_ - received_seq_first_), ssrc_); - - return stats; + BWE_TEST_LOGGING_PLOT_WITH_SSRC(1, "received_seq_max_pkts", now_ms, + (received_seq_max_ - received_seq_first_), + ssrc_); } absl::optional StreamStatisticianImpl::GetFractionLostInPercent() const { - MutexLock lock(&stream_lock_); if (!ReceivedRtpPacket()) { return absl::nullopt; } @@ -257,12 +263,10 @@ absl::optional StreamStatisticianImpl::GetFractionLostInPercent() const { StreamDataCounters StreamStatisticianImpl::GetReceiveStreamDataCounters() const { - MutexLock lock(&stream_lock_); return receive_counters_; } uint32_t StreamStatisticianImpl::BitrateReceived() const { - MutexLock lock(&stream_lock_); return incoming_bitrate_.Rate(clock_->TimeInMilliseconds()).value_or(0); } @@ -295,21 +299,33 @@ bool StreamStatisticianImpl::IsRetransmitOfOldPacket( } std::unique_ptr ReceiveStatistics::Create(Clock* clock) { - return std::make_unique(clock); + return std::make_unique( + clock, [](uint32_t ssrc, Clock* clock, int max_reordering_threshold) { + return std::make_unique( + ssrc, clock, max_reordering_threshold); + }); +} + +std::unique_ptr ReceiveStatistics::CreateThreadCompatible( + Clock* clock) { + return std::make_unique( + clock, [](uint32_t ssrc, Clock* clock, int max_reordering_threshold) { + return std::make_unique( + ssrc, clock, max_reordering_threshold); + }); } -ReceiveStatisticsImpl::ReceiveStatisticsImpl(Clock* clock) +ReceiveStatisticsImpl::ReceiveStatisticsImpl( + Clock* clock, + std::function( + uint32_t ssrc, + Clock* clock, + int max_reordering_threshold)> stream_statistician_factory) : clock_(clock), - last_returned_ssrc_(0), + stream_statistician_factory_(std::move(stream_statistician_factory)), + last_returned_ssrc_idx_(0), max_reordering_threshold_(kDefaultMaxReorderingThreshold) {} -ReceiveStatisticsImpl::~ReceiveStatisticsImpl() { - while (!statisticians_.empty()) { - delete statisticians_.begin()->second; - statisticians_.erase(statisticians_.begin()); - } -} - void ReceiveStatisticsImpl::OnRtpPacket(const 
RtpPacketReceived& packet) { // StreamStatisticianImpl instance is created once and only destroyed when // this whole ReceiveStatisticsImpl is destroyed. StreamStatisticianImpl has @@ -318,34 +334,29 @@ void ReceiveStatisticsImpl::OnRtpPacket(const RtpPacketReceived& packet) { GetOrCreateStatistician(packet.Ssrc())->UpdateCounters(packet); } -StreamStatisticianImpl* ReceiveStatisticsImpl::GetStatistician( +StreamStatistician* ReceiveStatisticsImpl::GetStatistician( uint32_t ssrc) const { - MutexLock lock(&receive_statistics_lock_); const auto& it = statisticians_.find(ssrc); if (it == statisticians_.end()) - return NULL; - return it->second; + return nullptr; + return it->second.get(); } -StreamStatisticianImpl* ReceiveStatisticsImpl::GetOrCreateStatistician( +StreamStatisticianImplInterface* ReceiveStatisticsImpl::GetOrCreateStatistician( uint32_t ssrc) { - MutexLock lock(&receive_statistics_lock_); - StreamStatisticianImpl*& impl = statisticians_[ssrc]; + std::unique_ptr& impl = statisticians_[ssrc]; if (impl == nullptr) { // new element - impl = new StreamStatisticianImpl(ssrc, clock_, max_reordering_threshold_); + impl = + stream_statistician_factory_(ssrc, clock_, max_reordering_threshold_); + all_ssrcs_.push_back(ssrc); } - return impl; + return impl.get(); } void ReceiveStatisticsImpl::SetMaxReorderingThreshold( int max_reordering_threshold) { - std::map statisticians; - { - MutexLock lock(&receive_statistics_lock_); - max_reordering_threshold_ = max_reordering_threshold; - statisticians = statisticians_; - } - for (auto& statistician : statisticians) { + max_reordering_threshold_ = max_reordering_threshold; + for (auto& statistician : statisticians_) { statistician.second->SetMaxReorderingThreshold(max_reordering_threshold); } } @@ -364,42 +375,18 @@ void ReceiveStatisticsImpl::EnableRetransmitDetection(uint32_t ssrc, std::vector ReceiveStatisticsImpl::RtcpReportBlocks( size_t max_blocks) { - std::map statisticians; - { - MutexLock lock(&receive_statistics_lock_); - statisticians = statisticians_; - } std::vector result; - result.reserve(std::min(max_blocks, statisticians.size())); - auto add_report_block = [&result](uint32_t media_ssrc, - StreamStatisticianImpl* statistician) { - // Do we have receive statistics to send? 
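The report-block math rewritten above packs losses into two RTCP fields: an 8-bit fraction lost since the previous report (255 == 100 % loss) and a signed 24-bit cumulative counter, now clamped at both ends. A small illustrative sketch of that packing (the helper names and numbers are not part of the patch):

#include <cstdint>

// Fraction lost since the last report, scaled so that 255 means 100 % loss.
uint8_t FractionLost(int64_t expected_since_last, int64_t lost_since_last) {
  if (expected_since_last <= 0 || lost_since_last <= 0)
    return 0;
  return static_cast<uint8_t>(255 * lost_since_last / expected_since_last);
}

// Cumulative loss clamped to the signed 24-bit range of the report block.
int32_t ClampCumulativeLost(int64_t cumulative_lost) {
  if (cumulative_lost < 0)
    return 0;          // Negative values confuse some senders; report 0.
  if (cumulative_lost > 0x7fffff)
    return 0x7fffff;   // 24-bit signed maximum, per RFC 3550 appendix A.3.
  return static_cast<int32_t>(cumulative_lost);
}

// Example: 10 packets expected and 2 lost since the last report gives
// FractionLost(10, 2) == 51, i.e. the 20 % that GetFractionLostInPercent()
// returns in the unit tests further below.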
- RtcpStatistics stats; - if (!statistician->GetActiveStatisticsAndReset(&stats)) - return; - result.emplace_back(); - rtcp::ReportBlock& block = result.back(); - block.SetMediaSsrc(media_ssrc); - block.SetFractionLost(stats.fraction_lost); - if (!block.SetCumulativeLost(stats.packets_lost)) { - RTC_LOG(LS_WARNING) << "Cumulative lost is oversized."; - result.pop_back(); - return; - } - block.SetExtHighestSeqNum(stats.extended_highest_sequence_number); - block.SetJitter(stats.jitter); - }; - - const auto start_it = statisticians.upper_bound(last_returned_ssrc_); - for (auto it = start_it; - result.size() < max_blocks && it != statisticians.end(); ++it) - add_report_block(it->first, it->second); - for (auto it = statisticians.begin(); - result.size() < max_blocks && it != start_it; ++it) - add_report_block(it->first, it->second); - - if (!result.empty()) - last_returned_ssrc_ = result.back().source_ssrc(); + result.reserve(std::min(max_blocks, all_ssrcs_.size())); + + size_t ssrc_idx = 0; + for (size_t i = 0; i < all_ssrcs_.size() && result.size() < max_blocks; ++i) { + ssrc_idx = (last_returned_ssrc_idx_ + i + 1) % all_ssrcs_.size(); + const uint32_t media_ssrc = all_ssrcs_[ssrc_idx]; + auto statistician_it = statisticians_.find(media_ssrc); + RTC_DCHECK(statistician_it != statisticians_.end()); + statistician_it->second->MaybeAppendReportBlockAndReset(result); + } + last_returned_ssrc_idx_ = ssrc_idx; return result; } diff --git a/modules/rtp_rtcp/source/receive_statistics_impl.h b/modules/rtp_rtcp/source/receive_statistics_impl.h index 41830b0b48..1a70fe4ad7 100644 --- a/modules/rtp_rtcp/source/receive_statistics_impl.h +++ b/modules/rtp_rtcp/source/receive_statistics_impl.h @@ -12,98 +12,162 @@ #define MODULES_RTP_RTCP_SOURCE_RECEIVE_STATISTICS_IMPL_H_ #include -#include +#include +#include +#include #include #include "absl/types/optional.h" #include "modules/include/module_common_types_public.h" #include "modules/rtp_rtcp/include/receive_statistics.h" +#include "modules/rtp_rtcp/source/rtcp_packet/report_block.h" +#include "rtc_base/containers/flat_map.h" #include "rtc_base/rate_statistics.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" namespace webrtc { -class StreamStatisticianImpl : public StreamStatistician { +// Extends StreamStatistician with methods needed by the implementation. +class StreamStatisticianImplInterface : public StreamStatistician { + public: + virtual ~StreamStatisticianImplInterface() = default; + virtual void MaybeAppendReportBlockAndReset( + std::vector& report_blocks) = 0; + virtual void SetMaxReorderingThreshold(int max_reordering_threshold) = 0; + virtual void EnableRetransmitDetection(bool enable) = 0; + virtual void UpdateCounters(const RtpPacketReceived& packet) = 0; +}; + +// Thread-compatible implementation of StreamStatisticianImplInterface. 
+class StreamStatisticianImpl : public StreamStatisticianImplInterface { public: StreamStatisticianImpl(uint32_t ssrc, Clock* clock, int max_reordering_threshold); ~StreamStatisticianImpl() override; + // Implements StreamStatistician RtpReceiveStats GetStats() const override; - - bool GetActiveStatisticsAndReset(RtcpStatistics* statistics); absl::optional GetFractionLostInPercent() const override; StreamDataCounters GetReceiveStreamDataCounters() const override; uint32_t BitrateReceived() const override; - void SetMaxReorderingThreshold(int max_reordering_threshold); - void EnableRetransmitDetection(bool enable); - + // Implements StreamStatisticianImplInterface + void MaybeAppendReportBlockAndReset( + std::vector& report_blocks) override; + void SetMaxReorderingThreshold(int max_reordering_threshold) override; + void EnableRetransmitDetection(bool enable) override; // Updates StreamStatistician for incoming packets. - void UpdateCounters(const RtpPacketReceived& packet); + void UpdateCounters(const RtpPacketReceived& packet) override; private: bool IsRetransmitOfOldPacket(const RtpPacketReceived& packet, - int64_t now_ms) const - RTC_EXCLUSIVE_LOCKS_REQUIRED(stream_lock_); - RtcpStatistics CalculateRtcpStatistics() - RTC_EXCLUSIVE_LOCKS_REQUIRED(stream_lock_); - void UpdateJitter(const RtpPacketReceived& packet, int64_t receive_time_ms) - RTC_EXCLUSIVE_LOCKS_REQUIRED(stream_lock_); + int64_t now_ms) const; + void UpdateJitter(const RtpPacketReceived& packet, int64_t receive_time_ms); // Updates StreamStatistician for out of order packets. // Returns true if packet considered to be out of order. bool UpdateOutOfOrder(const RtpPacketReceived& packet, int64_t sequence_number, - int64_t now_ms) - RTC_EXCLUSIVE_LOCKS_REQUIRED(stream_lock_); + int64_t now_ms); // Checks if this StreamStatistician received any rtp packets. - bool ReceivedRtpPacket() const RTC_EXCLUSIVE_LOCKS_REQUIRED(stream_lock_) { - return received_seq_first_ >= 0; - } + bool ReceivedRtpPacket() const { return received_seq_first_ >= 0; } const uint32_t ssrc_; Clock* const clock_; - mutable Mutex stream_lock_; - RateStatistics incoming_bitrate_ RTC_GUARDED_BY(&stream_lock_); + // Delta used to map internal timestamps to Unix epoch ones. + const int64_t delta_internal_unix_epoch_ms_; + RateStatistics incoming_bitrate_; // In number of packets or sequence numbers. - int max_reordering_threshold_ RTC_GUARDED_BY(&stream_lock_); - bool enable_retransmit_detection_ RTC_GUARDED_BY(&stream_lock_); + int max_reordering_threshold_; + bool enable_retransmit_detection_; + bool cumulative_loss_is_capped_; // Stats on received RTP packets. - uint32_t jitter_q4_ RTC_GUARDED_BY(&stream_lock_); + uint32_t jitter_q4_; // Cumulative loss according to RFC 3550, which may be negative (and often is, // if packets are reordered and there are non-RTX retransmissions). - int32_t cumulative_loss_ RTC_GUARDED_BY(&stream_lock_); + int32_t cumulative_loss_; // Offset added to outgoing rtcp reports, to make ensure that the reported // cumulative loss is non-negative. Reports with negative values confuse some // senders, in particular, our own loss-based bandwidth estimator. 
- int32_t cumulative_loss_rtcp_offset_ RTC_GUARDED_BY(&stream_lock_); + int32_t cumulative_loss_rtcp_offset_; - int64_t last_receive_time_ms_ RTC_GUARDED_BY(&stream_lock_); - uint32_t last_received_timestamp_ RTC_GUARDED_BY(&stream_lock_); - SequenceNumberUnwrapper seq_unwrapper_ RTC_GUARDED_BY(&stream_lock_); - int64_t received_seq_first_ RTC_GUARDED_BY(&stream_lock_); - int64_t received_seq_max_ RTC_GUARDED_BY(&stream_lock_); + int64_t last_receive_time_ms_; + uint32_t last_received_timestamp_; + SequenceNumberUnwrapper seq_unwrapper_; + int64_t received_seq_first_; + int64_t received_seq_max_; // Assume that the other side restarted when there are two sequential packets // with large jump from received_seq_max_. - absl::optional received_seq_out_of_order_ - RTC_GUARDED_BY(&stream_lock_); + absl::optional received_seq_out_of_order_; // Current counter values. - StreamDataCounters receive_counters_ RTC_GUARDED_BY(&stream_lock_); + StreamDataCounters receive_counters_; // Counter values when we sent the last report. - int32_t last_report_cumulative_loss_ RTC_GUARDED_BY(&stream_lock_); - int64_t last_report_seq_max_ RTC_GUARDED_BY(&stream_lock_); + int32_t last_report_cumulative_loss_; + int64_t last_report_seq_max_; }; -class ReceiveStatisticsImpl : public ReceiveStatistics { +// Thread-safe implementation of StreamStatisticianImplInterface. +class StreamStatisticianLocked : public StreamStatisticianImplInterface { public: - explicit ReceiveStatisticsImpl(Clock* clock); + StreamStatisticianLocked(uint32_t ssrc, + Clock* clock, + int max_reordering_threshold) + : impl_(ssrc, clock, max_reordering_threshold) {} + ~StreamStatisticianLocked() override = default; + + RtpReceiveStats GetStats() const override { + MutexLock lock(&stream_lock_); + return impl_.GetStats(); + } + absl::optional GetFractionLostInPercent() const override { + MutexLock lock(&stream_lock_); + return impl_.GetFractionLostInPercent(); + } + StreamDataCounters GetReceiveStreamDataCounters() const override { + MutexLock lock(&stream_lock_); + return impl_.GetReceiveStreamDataCounters(); + } + uint32_t BitrateReceived() const override { + MutexLock lock(&stream_lock_); + return impl_.BitrateReceived(); + } + void MaybeAppendReportBlockAndReset( + std::vector& report_blocks) override { + MutexLock lock(&stream_lock_); + impl_.MaybeAppendReportBlockAndReset(report_blocks); + } + void SetMaxReorderingThreshold(int max_reordering_threshold) override { + MutexLock lock(&stream_lock_); + return impl_.SetMaxReorderingThreshold(max_reordering_threshold); + } + void EnableRetransmitDetection(bool enable) override { + MutexLock lock(&stream_lock_); + return impl_.EnableRetransmitDetection(enable); + } + void UpdateCounters(const RtpPacketReceived& packet) override { + MutexLock lock(&stream_lock_); + return impl_.UpdateCounters(packet); + } + + private: + mutable Mutex stream_lock_; + StreamStatisticianImpl impl_ RTC_GUARDED_BY(&stream_lock_); +}; - ~ReceiveStatisticsImpl() override; +// Thread-compatible implementation. +class ReceiveStatisticsImpl : public ReceiveStatistics { + public: + ReceiveStatisticsImpl( + Clock* clock, + std::function( + uint32_t ssrc, + Clock* clock, + int max_reordering_threshold)> stream_statistician_factory); + ~ReceiveStatisticsImpl() override = default; // Implements ReceiveStatisticsProvider. 
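StreamStatisticianLocked above (and ReceiveStatisticsLocked below) follow the same pattern: a thread-compatible core class plus a thin wrapper that owns a mutex and forwards every call under the lock, so callers that stay on one task queue avoid the locking cost entirely. A generic sketch of the pattern, outside the WebRTC class hierarchy:

#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/thread_annotations.h"

// Thread-compatible core: no locking, the caller guarantees serialized access.
class Counter {
 public:
  void Add(int v) { value_ += v; }
  int value() const { return value_; }

 private:
  int value_ = 0;
};

// Thread-safe decorator: same interface, every call holds the mutex.
class LockedCounter {
 public:
  void Add(int v) {
    webrtc::MutexLock lock(&mutex_);
    impl_.Add(v);
  }
  int value() const {
    webrtc::MutexLock lock(&mutex_);
    return impl_.value();
  }

 private:
  mutable webrtc::Mutex mutex_;
  Counter impl_ RTC_GUARDED_BY(mutex_);
};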
std::vector RtcpReportBlocks(size_t max_blocks) override; @@ -112,22 +176,71 @@ class ReceiveStatisticsImpl : public ReceiveStatistics { void OnRtpPacket(const RtpPacketReceived& packet) override; // Implements ReceiveStatistics. - // Note: More specific return type for use in the implementation. - StreamStatisticianImpl* GetStatistician(uint32_t ssrc) const override; + StreamStatistician* GetStatistician(uint32_t ssrc) const override; void SetMaxReorderingThreshold(int max_reordering_threshold) override; void SetMaxReorderingThreshold(uint32_t ssrc, int max_reordering_threshold) override; void EnableRetransmitDetection(uint32_t ssrc, bool enable) override; private: - StreamStatisticianImpl* GetOrCreateStatistician(uint32_t ssrc); + StreamStatisticianImplInterface* GetOrCreateStatistician(uint32_t ssrc); Clock* const clock_; + std::function( + uint32_t ssrc, + Clock* clock, + int max_reordering_threshold)> + stream_statistician_factory_; + // The index within `all_ssrcs_` that was last returned. + size_t last_returned_ssrc_idx_; + std::vector all_ssrcs_; + int max_reordering_threshold_; + flat_map> + statisticians_; +}; + +// Thread-safe implementation wrapping access to ReceiveStatisticsImpl with a +// mutex. +class ReceiveStatisticsLocked : public ReceiveStatistics { + public: + explicit ReceiveStatisticsLocked( + Clock* clock, + std::function( + uint32_t ssrc, + Clock* clock, + int max_reordering_threshold)> stream_statitician_factory) + : impl_(clock, std::move(stream_statitician_factory)) {} + ~ReceiveStatisticsLocked() override = default; + std::vector RtcpReportBlocks(size_t max_blocks) override { + MutexLock lock(&receive_statistics_lock_); + return impl_.RtcpReportBlocks(max_blocks); + } + void OnRtpPacket(const RtpPacketReceived& packet) override { + MutexLock lock(&receive_statistics_lock_); + return impl_.OnRtpPacket(packet); + } + StreamStatistician* GetStatistician(uint32_t ssrc) const override { + MutexLock lock(&receive_statistics_lock_); + return impl_.GetStatistician(ssrc); + } + void SetMaxReorderingThreshold(int max_reordering_threshold) override { + MutexLock lock(&receive_statistics_lock_); + return impl_.SetMaxReorderingThreshold(max_reordering_threshold); + } + void SetMaxReorderingThreshold(uint32_t ssrc, + int max_reordering_threshold) override { + MutexLock lock(&receive_statistics_lock_); + return impl_.SetMaxReorderingThreshold(ssrc, max_reordering_threshold); + } + void EnableRetransmitDetection(uint32_t ssrc, bool enable) override { + MutexLock lock(&receive_statistics_lock_); + return impl_.EnableRetransmitDetection(ssrc, enable); + } + + private: mutable Mutex receive_statistics_lock_; - uint32_t last_returned_ssrc_; - int max_reordering_threshold_ RTC_GUARDED_BY(receive_statistics_lock_); - std::map statisticians_ - RTC_GUARDED_BY(receive_statistics_lock_); + ReceiveStatisticsImpl impl_ RTC_GUARDED_BY(&receive_statistics_lock_); }; + } // namespace webrtc #endif // MODULES_RTP_RTCP_SOURCE_RECEIVE_STATISTICS_IMPL_H_ diff --git a/modules/rtp_rtcp/source/receive_statistics_unittest.cc b/modules/rtp_rtcp/source/receive_statistics_unittest.cc index 053460e2ba..d40a743469 100644 --- a/modules/rtp_rtcp/source/receive_statistics_unittest.cc +++ b/modules/rtp_rtcp/source/receive_statistics_unittest.cc @@ -65,10 +65,13 @@ void IncrementSequenceNumber(RtpPacketReceived* packet) { IncrementSequenceNumber(packet, 1); } -class ReceiveStatisticsTest : public ::testing::Test { +class ReceiveStatisticsTest : public ::testing::TestWithParam { public: ReceiveStatisticsTest() 
- : clock_(0), receive_statistics_(ReceiveStatistics::Create(&clock_)) { + : clock_(0), + receive_statistics_( + GetParam() ? ReceiveStatistics::Create(&clock_) + : ReceiveStatistics::CreateThreadCompatible(&clock_)) { packet1_ = CreateRtpPacket(kSsrc1, kPacketSize1); packet2_ = CreateRtpPacket(kSsrc2, kPacketSize2); } @@ -80,7 +83,14 @@ class ReceiveStatisticsTest : public ::testing::Test { RtpPacketReceived packet2_; }; -TEST_F(ReceiveStatisticsTest, TwoIncomingSsrcs) { +INSTANTIATE_TEST_SUITE_P(All, + ReceiveStatisticsTest, + ::testing::Bool(), + [](::testing::TestParamInfo info) { + return info.param ? "WithMutex" : "WithoutMutex"; + }); + +TEST_P(ReceiveStatisticsTest, TwoIncomingSsrcs) { receive_statistics_->OnRtpPacket(packet1_); IncrementSequenceNumber(&packet1_); receive_statistics_->OnRtpPacket(packet2_); @@ -133,7 +143,7 @@ TEST_F(ReceiveStatisticsTest, TwoIncomingSsrcs) { EXPECT_EQ(3u, counters.transmitted.packets); } -TEST_F(ReceiveStatisticsTest, +TEST_P(ReceiveStatisticsTest, RtcpReportBlocksReturnsMaxBlocksWhenThereAreMoreStatisticians) { RtpPacketReceived packet1 = CreateRtpPacket(kSsrc1, kPacketSize1); RtpPacketReceived packet2 = CreateRtpPacket(kSsrc2, kPacketSize1); @@ -147,7 +157,7 @@ TEST_F(ReceiveStatisticsTest, EXPECT_THAT(receive_statistics_->RtcpReportBlocks(2), SizeIs(2)); } -TEST_F(ReceiveStatisticsTest, +TEST_P(ReceiveStatisticsTest, RtcpReportBlocksReturnsAllObservedSsrcsWithMultipleCalls) { RtpPacketReceived packet1 = CreateRtpPacket(kSsrc1, kPacketSize1); RtpPacketReceived packet2 = CreateRtpPacket(kSsrc2, kPacketSize1); @@ -174,7 +184,7 @@ TEST_F(ReceiveStatisticsTest, UnorderedElementsAre(kSsrc1, kSsrc2, kSsrc3, kSsrc4)); } -TEST_F(ReceiveStatisticsTest, ActiveStatisticians) { +TEST_P(ReceiveStatisticsTest, ActiveStatisticians) { receive_statistics_->OnRtpPacket(packet1_); IncrementSequenceNumber(&packet1_); clock_.AdvanceTimeMilliseconds(1000); @@ -206,7 +216,7 @@ TEST_F(ReceiveStatisticsTest, ActiveStatisticians) { EXPECT_EQ(2u, counters.transmitted.packets); } -TEST_F(ReceiveStatisticsTest, +TEST_P(ReceiveStatisticsTest, DoesntCreateRtcpReportBlockUntilFirstReceivedPacketForSsrc) { // Creates a statistician object for the ssrc. 
receive_statistics_->EnableRetransmitDetection(kSsrc1, true); @@ -217,7 +227,7 @@ TEST_F(ReceiveStatisticsTest, EXPECT_EQ(1u, receive_statistics_->RtcpReportBlocks(3).size()); } -TEST_F(ReceiveStatisticsTest, GetReceiveStreamDataCounters) { +TEST_P(ReceiveStatisticsTest, GetReceiveStreamDataCounters) { receive_statistics_->OnRtpPacket(packet1_); StreamStatistician* statistician = receive_statistics_->GetStatistician(kSsrc1); @@ -233,7 +243,7 @@ TEST_F(ReceiveStatisticsTest, GetReceiveStreamDataCounters) { EXPECT_EQ(2u, counters.transmitted.packets); } -TEST_F(ReceiveStatisticsTest, SimpleLossComputation) { +TEST_P(ReceiveStatisticsTest, SimpleLossComputation) { packet1_.SetSequenceNumber(1); receive_statistics_->OnRtpPacket(packet1_); packet1_.SetSequenceNumber(3); @@ -256,7 +266,7 @@ TEST_F(ReceiveStatisticsTest, SimpleLossComputation) { EXPECT_EQ(20, statistician->GetFractionLostInPercent()); } -TEST_F(ReceiveStatisticsTest, LossComputationWithReordering) { +TEST_P(ReceiveStatisticsTest, LossComputationWithReordering) { packet1_.SetSequenceNumber(1); receive_statistics_->OnRtpPacket(packet1_); packet1_.SetSequenceNumber(3); @@ -279,7 +289,7 @@ TEST_F(ReceiveStatisticsTest, LossComputationWithReordering) { EXPECT_EQ(20, statistician->GetFractionLostInPercent()); } -TEST_F(ReceiveStatisticsTest, LossComputationWithDuplicates) { +TEST_P(ReceiveStatisticsTest, LossComputationWithDuplicates) { // Lose 2 packets, but also receive 1 duplicate. Should actually count as // only 1 packet being lost. packet1_.SetSequenceNumber(1); @@ -304,7 +314,7 @@ TEST_F(ReceiveStatisticsTest, LossComputationWithDuplicates) { EXPECT_EQ(20, statistician->GetFractionLostInPercent()); } -TEST_F(ReceiveStatisticsTest, LossComputationWithSequenceNumberWrapping) { +TEST_P(ReceiveStatisticsTest, LossComputationWithSequenceNumberWrapping) { // First, test loss computation over a period that included a sequence number // rollover. 
packet1_.SetSequenceNumber(0xfffd); @@ -344,7 +354,7 @@ TEST_F(ReceiveStatisticsTest, LossComputationWithSequenceNumberWrapping) { EXPECT_EQ(28, statistician->GetFractionLostInPercent()); } -TEST_F(ReceiveStatisticsTest, StreamRestartDoesntCountAsLoss) { +TEST_P(ReceiveStatisticsTest, StreamRestartDoesntCountAsLoss) { receive_statistics_->SetMaxReorderingThreshold(kSsrc1, 200); packet1_.SetSequenceNumber(0); @@ -377,7 +387,7 @@ TEST_F(ReceiveStatisticsTest, StreamRestartDoesntCountAsLoss) { EXPECT_EQ(0, statistician->GetFractionLostInPercent()); } -TEST_F(ReceiveStatisticsTest, CountsLossAfterStreamRestart) { +TEST_P(ReceiveStatisticsTest, CountsLossAfterStreamRestart) { receive_statistics_->SetMaxReorderingThreshold(kSsrc1, 200); packet1_.SetSequenceNumber(0); @@ -405,7 +415,7 @@ TEST_F(ReceiveStatisticsTest, CountsLossAfterStreamRestart) { EXPECT_EQ(0, statistician->GetFractionLostInPercent()); } -TEST_F(ReceiveStatisticsTest, StreamCanRestartAtSequenceNumberWrapAround) { +TEST_P(ReceiveStatisticsTest, StreamCanRestartAtSequenceNumberWrapAround) { receive_statistics_->SetMaxReorderingThreshold(kSsrc1, 200); packet1_.SetSequenceNumber(0xffff - 401); @@ -428,7 +438,7 @@ TEST_F(ReceiveStatisticsTest, StreamCanRestartAtSequenceNumberWrapAround) { EXPECT_EQ(1, report_blocks[0].cumulative_lost_signed()); } -TEST_F(ReceiveStatisticsTest, StreamRestartNeedsTwoConsecutivePackets) { +TEST_P(ReceiveStatisticsTest, StreamRestartNeedsTwoConsecutivePackets) { receive_statistics_->SetMaxReorderingThreshold(kSsrc1, 200); packet1_.SetSequenceNumber(400); @@ -458,7 +468,7 @@ TEST_F(ReceiveStatisticsTest, StreamRestartNeedsTwoConsecutivePackets) { EXPECT_EQ(4u, report_blocks[0].extended_high_seq_num()); } -TEST_F(ReceiveStatisticsTest, WrapsAroundExtendedHighestSequenceNumber) { +TEST_P(ReceiveStatisticsTest, WrapsAroundExtendedHighestSequenceNumber) { packet1_.SetSequenceNumber(0xffff); receive_statistics_->OnRtpPacket(packet1_); @@ -503,8 +513,7 @@ TEST_F(ReceiveStatisticsTest, WrapsAroundExtendedHighestSequenceNumber) { EXPECT_EQ(0x20001u, report_blocks[0].extended_high_seq_num()); } -TEST_F(ReceiveStatisticsTest, StreamDataCounters) { - receive_statistics_ = ReceiveStatistics::Create(&clock_); +TEST_P(ReceiveStatisticsTest, StreamDataCounters) { receive_statistics_->EnableRetransmitDetection(kSsrc1, true); const size_t kHeaderLength = 20; @@ -554,9 +563,7 @@ TEST_F(ReceiveStatisticsTest, StreamDataCounters) { EXPECT_EQ(counters.retransmitted.packets, 1u); } -TEST_F(ReceiveStatisticsTest, LastPacketReceivedTimestamp) { - receive_statistics_ = ReceiveStatistics::Create(&clock_); - +TEST_P(ReceiveStatisticsTest, LastPacketReceivedTimestamp) { clock_.AdvanceTimeMilliseconds(42); receive_statistics_->OnRtpPacket(packet1_); StreamDataCounters counters = receive_statistics_->GetStatistician(kSsrc1) diff --git a/modules/rtp_rtcp/source/remote_ntp_time_estimator.cc b/modules/rtp_rtcp/source/remote_ntp_time_estimator.cc index 6fed7314c0..723064eeba 100644 --- a/modules/rtp_rtcp/source/remote_ntp_time_estimator.cc +++ b/modules/rtp_rtcp/source/remote_ntp_time_estimator.cc @@ -15,6 +15,7 @@ #include "modules/rtp_rtcp/source/time_util.h" #include "rtc_base/logging.h" #include "system_wrappers/include/clock.h" +#include "system_wrappers/include/ntp_time.h" namespace webrtc { @@ -51,9 +52,8 @@ bool RemoteNtpTimeEstimator::UpdateRtcpTimestamp(int64_t rtt, // Update extrapolator with the new arrival time. // The extrapolator assumes the ntp time. 
- int64_t receiver_arrival_time_ms = - clock_->TimeInMilliseconds() + NtpOffsetMs(); - int64_t sender_send_time_ms = Clock::NtpToMs(ntp_secs, ntp_frac); + int64_t receiver_arrival_time_ms = clock_->CurrentNtpInMilliseconds(); + int64_t sender_send_time_ms = NtpTime(ntp_secs, ntp_frac).ToMs(); int64_t sender_arrival_time_ms = sender_send_time_ms + rtt / 2; int64_t remote_to_local_clocks_offset = receiver_arrival_time_ms - sender_arrival_time_ms; @@ -72,16 +72,7 @@ int64_t RemoteNtpTimeEstimator::Estimate(uint32_t rtp_timestamp) { int64_t receiver_capture_ntp_ms = sender_capture_ntp_ms + remote_to_local_clocks_offset; - // TODO(bugs.webrtc.org/11327): Clock::CurrentNtpInMilliseconds() was - // previously used to calculate the offset between the local and the remote - // clock. However, rtc::TimeMillis() + NtpOffsetMs() is now used as the local - // ntp clock value. To preserve the old behavior of this method, the return - // value is adjusted with the difference between the two local ntp clocks. int64_t now_ms = clock_->TimeInMilliseconds(); - int64_t offset_between_local_ntp_clocks = - clock_->CurrentNtpInMilliseconds() - now_ms - NtpOffsetMs(); - receiver_capture_ntp_ms += offset_between_local_ntp_clocks; - if (now_ms - last_timing_log_ms_ > kTimingLogIntervalMs) { RTC_LOG(LS_INFO) << "RTP timestamp: " << rtp_timestamp << " in NTP clock: " << sender_capture_ntp_ms @@ -89,6 +80,7 @@ int64_t RemoteNtpTimeEstimator::Estimate(uint32_t rtp_timestamp) { << receiver_capture_ntp_ms; last_timing_log_ms_ = now_ms; } + return receiver_capture_ntp_ms; } diff --git a/modules/rtp_rtcp/source/remote_ntp_time_estimator_unittest.cc b/modules/rtp_rtcp/source/remote_ntp_time_estimator_unittest.cc index 85f08483ea..73c3e9b9b8 100644 --- a/modules/rtp_rtcp/source/remote_ntp_time_estimator_unittest.cc +++ b/modules/rtp_rtcp/source/remote_ntp_time_estimator_unittest.cc @@ -10,7 +10,6 @@ #include "modules/rtp_rtcp/include/remote_ntp_time_estimator.h" #include "absl/types/optional.h" -#include "modules/rtp_rtcp/source/time_util.h" #include "system_wrappers/include/clock.h" #include "system_wrappers/include/ntp_time.h" #include "test/gmock.h" @@ -43,9 +42,7 @@ class RemoteNtpTimeEstimatorTest : public ::testing::Test { kTimestampOffset; } - NtpTime GetRemoteNtpTime() { - return TimeMicrosToNtp(remote_clock_.TimeInMicroseconds()); - } + NtpTime GetRemoteNtpTime() { return remote_clock_.CurrentNtpTime(); } void SendRtcpSr() { uint32_t rtcp_timestamp = GetRemoteTimestamp(); diff --git a/modules/rtp_rtcp/source/rtcp_packet/extended_reports.h b/modules/rtp_rtcp/source/rtcp_packet/extended_reports.h index 9627aac959..6c804bbc7b 100644 --- a/modules/rtp_rtcp/source/rtcp_packet/extended_reports.h +++ b/modules/rtp_rtcp/source/rtcp_packet/extended_reports.h @@ -62,7 +62,6 @@ class ExtendedReports : public RtcpPacket { void ParseRrtrBlock(const uint8_t* block, uint16_t block_length); void ParseDlrrBlock(const uint8_t* block, uint16_t block_length); - void ParseVoipMetricBlock(const uint8_t* block, uint16_t block_length); void ParseTargetBitrateBlock(const uint8_t* block, uint16_t block_length); absl::optional rrtr_block_; diff --git a/modules/rtp_rtcp/source/rtcp_packet/loss_notification.h b/modules/rtp_rtcp/source/rtcp_packet/loss_notification.h index 2603a6715e..99f6d12da4 100644 --- a/modules/rtp_rtcp/source/rtcp_packet/loss_notification.h +++ b/modules/rtp_rtcp/source/rtcp_packet/loss_notification.h @@ -11,9 +11,9 @@ #ifndef MODULES_RTP_RTCP_SOURCE_RTCP_PACKET_LOSS_NOTIFICATION_H_ #define 
MODULES_RTP_RTCP_SOURCE_RTCP_PACKET_LOSS_NOTIFICATION_H_ +#include "absl/base/attributes.h" #include "modules/rtp_rtcp/source/rtcp_packet/common_header.h" #include "modules/rtp_rtcp/source/rtcp_packet/psfb.h" -#include "rtc_base/system/unused.h" namespace webrtc { namespace rtcp { @@ -29,14 +29,15 @@ class LossNotification : public Psfb { size_t BlockLength() const override; + ABSL_MUST_USE_RESULT bool Create(uint8_t* packet, size_t* index, size_t max_length, - PacketReadyCallback callback) const override - RTC_WARN_UNUSED_RESULT; + PacketReadyCallback callback) const override; // Parse assumes header is already parsed and validated. - bool Parse(const CommonHeader& packet) RTC_WARN_UNUSED_RESULT; + ABSL_MUST_USE_RESULT + bool Parse(const CommonHeader& packet); // Set all of the values transmitted by the loss notification message. // If the values may not be represented by a loss notification message, @@ -44,9 +45,10 @@ class LossNotification : public Psfb { // when |last_recieved| is ahead of |last_decoded| by more than 0x7fff. // This is because |last_recieved| is represented on the wire as a delta, // and only 15 bits are available for that delta. + ABSL_MUST_USE_RESULT bool Set(uint16_t last_decoded, uint16_t last_received, - bool decodability_flag) RTC_WARN_UNUSED_RESULT; + bool decodability_flag); // RTP sequence number of the first packet belong to the last decoded // non-discardable frame. diff --git a/modules/rtp_rtcp/source/rtcp_receiver.cc b/modules/rtp_rtcp/source/rtcp_receiver.cc index a9ec2a10a0..3ab78df17c 100644 --- a/modules/rtp_rtcp/source/rtcp_receiver.cc +++ b/modules/rtp_rtcp/source/rtcp_receiver.cc @@ -39,6 +39,7 @@ #include "modules/rtp_rtcp/source/rtcp_packet/tmmbr.h" #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" +#include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" #include "modules/rtp_rtcp/source/time_util.h" #include "modules/rtp_rtcp/source/tmmbr_help.h" #include "rtc_base/checks.h" @@ -67,22 +68,6 @@ const size_t kMaxNumberOfStoredRrtrs = 300; constexpr TimeDelta kDefaultVideoReportInterval = TimeDelta::Seconds(1); constexpr TimeDelta kDefaultAudioReportInterval = TimeDelta::Seconds(5); -std::set GetRegisteredSsrcs( - const RtpRtcpInterface::Configuration& config) { - std::set ssrcs; - ssrcs.insert(config.local_media_ssrc); - if (config.rtx_send_ssrc) { - ssrcs.insert(*config.rtx_send_ssrc); - } - if (config.fec_generator) { - absl::optional flexfec_ssrc = config.fec_generator->FecSsrc(); - if (flexfec_ssrc) { - ssrcs.insert(*flexfec_ssrc); - } - } - return ssrcs; -} - // Returns true if the |timestamp| has exceeded the |interval * // kRrTimeoutIntervals| period and was reset (set to PlusInfinity()). Returns // false if the timer was either already reset or if it has not expired. 
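The LossNotification changes above keep the documented constraint on Set(): the last-received sequence number is carried on the wire as a 15-bit delta from the last decoded one, so Set() fails when the gap exceeds 0x7fff. A small worked example (values invented; RTC_CHECK is used only to show the expected return values):

    webrtc::rtcp::LossNotification loss_notification;
    // last_received is 100 ahead of last_decoded: the delta fits in 15 bits.
    RTC_CHECK(loss_notification.Set(/*last_decoded=*/1000,
                                    /*last_received=*/1100,
                                    /*decodability_flag=*/true));
    // last_received is 0xA000 (40960) ahead: the delta exceeds 0x7fff, so the
    // message cannot be represented and Set() returns false.
    RTC_CHECK(!loss_notification.Set(/*last_decoded=*/1000,
                                     /*last_received=*/1000 + 0xA000,
                                     /*decodability_flag=*/true));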
@@ -100,6 +85,43 @@ bool ResetTimestampIfExpired(const Timestamp now, } // namespace +constexpr size_t RTCPReceiver::RegisteredSsrcs::kMediaSsrcIndex; +constexpr size_t RTCPReceiver::RegisteredSsrcs::kMaxSsrcs; + +RTCPReceiver::RegisteredSsrcs::RegisteredSsrcs( + bool disable_sequence_checker, + const RtpRtcpInterface::Configuration& config) + : packet_sequence_checker_(disable_sequence_checker) { + packet_sequence_checker_.Detach(); + ssrcs_.push_back(config.local_media_ssrc); + if (config.rtx_send_ssrc) { + ssrcs_.push_back(*config.rtx_send_ssrc); + } + if (config.fec_generator) { + absl::optional flexfec_ssrc = config.fec_generator->FecSsrc(); + if (flexfec_ssrc) { + ssrcs_.push_back(*flexfec_ssrc); + } + } + // Ensure that the RegisteredSsrcs can inline the SSRCs. + RTC_DCHECK_LE(ssrcs_.size(), RTCPReceiver::RegisteredSsrcs::kMaxSsrcs); +} + +bool RTCPReceiver::RegisteredSsrcs::contains(uint32_t ssrc) const { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + return absl::c_linear_search(ssrcs_, ssrc); +} + +uint32_t RTCPReceiver::RegisteredSsrcs::media_ssrc() const { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + return ssrcs_[kMediaSsrcIndex]; +} + +void RTCPReceiver::RegisteredSsrcs::set_media_ssrc(uint32_t ssrc) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + ssrcs_[kMediaSsrcIndex] = ssrc; +} + struct RTCPReceiver::PacketInformation { uint32_t packet_type_flags = 0; // RTCPPacketTypeFlags bit field. @@ -116,43 +138,39 @@ struct RTCPReceiver::PacketInformation { std::unique_ptr loss_notification; }; -// Structure for handing TMMBR and TMMBN rtcp messages (RFC5104, section 3.5.4). -struct RTCPReceiver::TmmbrInformation { - struct TimedTmmbrItem { - rtcp::TmmbItem tmmbr_item; - int64_t last_updated_ms; - }; - - int64_t last_time_received_ms = 0; - - bool ready_for_delete = false; - - std::vector tmmbn; - std::map tmmbr; -}; - -// Structure for storing received RRTR RTCP messages (RFC3611, section 4.4). -struct RTCPReceiver::RrtrInformation { - RrtrInformation(uint32_t ssrc, - uint32_t received_remote_mid_ntp_time, - uint32_t local_receive_mid_ntp_time) - : ssrc(ssrc), - received_remote_mid_ntp_time(received_remote_mid_ntp_time), - local_receive_mid_ntp_time(local_receive_mid_ntp_time) {} - - uint32_t ssrc; - // Received NTP timestamp in compact representation. - uint32_t received_remote_mid_ntp_time; - // NTP time when the report was received in compact representation. - uint32_t local_receive_mid_ntp_time; -}; - -struct RTCPReceiver::LastFirStatus { - LastFirStatus(int64_t now_ms, uint8_t sequence_number) - : request_ms(now_ms), sequence_number(sequence_number) {} - int64_t request_ms; - uint8_t sequence_number; -}; +RTCPReceiver::RTCPReceiver(const RtpRtcpInterface::Configuration& config, + ModuleRtpRtcpImpl2* owner) + : clock_(config.clock), + receiver_only_(config.receiver_only), + rtp_rtcp_(owner), + main_ssrc_(config.local_media_ssrc), + registered_ssrcs_(false, config), + rtcp_bandwidth_observer_(config.bandwidth_callback), + rtcp_intra_frame_observer_(config.intra_frame_callback), + rtcp_loss_notification_observer_(config.rtcp_loss_notification_observer), + network_state_estimate_observer_(config.network_state_estimate_observer), + transport_feedback_observer_(config.transport_feedback_callback), + bitrate_allocation_observer_(config.bitrate_allocation_observer), + report_interval_(config.rtcp_report_interval_ms > 0 + ? TimeDelta::Millis(config.rtcp_report_interval_ms) + : (config.audio ? 
kDefaultAudioReportInterval + : kDefaultVideoReportInterval)), + // TODO(bugs.webrtc.org/10774): Remove fallback. + remote_ssrc_(0), + remote_sender_rtp_time_(0), + remote_sender_packet_count_(0), + remote_sender_octet_count_(0), + remote_sender_reports_count_(0), + xr_rrtr_status_(config.non_sender_rtt_measurement), + xr_rr_rtt_ms_(0), + oldest_tmmbr_info_ms_(0), + cname_callback_(config.rtcp_cname_callback), + report_block_data_observer_(config.report_block_data_observer), + packet_type_counter_observer_(config.rtcp_packet_type_counter_observer), + num_skipped_packets_(0), + last_skipped_packets_warning_ms_(clock_->TimeInMilliseconds()) { + RTC_DCHECK(owner); +} RTCPReceiver::RTCPReceiver(const RtpRtcpInterface::Configuration& config, ModuleRtpRtcp* owner) @@ -160,7 +178,7 @@ RTCPReceiver::RTCPReceiver(const RtpRtcpInterface::Configuration& config, receiver_only_(config.receiver_only), rtp_rtcp_(owner), main_ssrc_(config.local_media_ssrc), - registered_ssrcs_(GetRegisteredSsrcs(config)), + registered_ssrcs_(true, config), rtcp_bandwidth_observer_(config.bandwidth_callback), rtcp_intra_frame_observer_(config.intra_frame_callback), rtcp_loss_notification_observer_(config.rtcp_loss_notification_observer), @@ -174,16 +192,31 @@ RTCPReceiver::RTCPReceiver(const RtpRtcpInterface::Configuration& config, // TODO(bugs.webrtc.org/10774): Remove fallback. remote_ssrc_(0), remote_sender_rtp_time_(0), + remote_sender_packet_count_(0), + remote_sender_octet_count_(0), + remote_sender_reports_count_(0), xr_rrtr_status_(config.non_sender_rtt_measurement), xr_rr_rtt_ms_(0), oldest_tmmbr_info_ms_(0), - stats_callback_(config.rtcp_statistics_callback), cname_callback_(config.rtcp_cname_callback), report_block_data_observer_(config.report_block_data_observer), packet_type_counter_observer_(config.rtcp_packet_type_counter_observer), num_skipped_packets_(0), last_skipped_packets_warning_ms_(clock_->TimeInMilliseconds()) { RTC_DCHECK(owner); + // Dear reader - if you're here because of this log statement and are + // wondering what this is about, chances are that you are using an instance + // of RTCPReceiver without using the webrtc APIs. This creates a bit of a + // problem for WebRTC because this class is a part of an internal + // implementation that is constantly changing and being improved. + // The intention of this log statement is to give a heads up that changes + // are coming and encourage you to use the public APIs or be prepared that + // things might break down the line as more changes land. A thing you could + // try out for now is to replace the `CustomSequenceChecker` in the header + // with a regular `SequenceChecker` and see if that triggers an + // error in your code. If it does, chances are you have your own threading + // model that is not the same as WebRTC internally has. + RTC_LOG(LS_INFO) << "************** !!!DEPRECATION WARNING!! 
**************"; } RTCPReceiver::~RTCPReceiver() {} @@ -214,11 +247,31 @@ void RTCPReceiver::SetRemoteSSRC(uint32_t ssrc) { remote_ssrc_ = ssrc; } +void RTCPReceiver::set_local_media_ssrc(uint32_t ssrc) { + registered_ssrcs_.set_media_ssrc(ssrc); +} + +uint32_t RTCPReceiver::local_media_ssrc() const { + return registered_ssrcs_.media_ssrc(); +} + uint32_t RTCPReceiver::RemoteSSRC() const { MutexLock lock(&rtcp_receiver_lock_); return remote_ssrc_; } +void RTCPReceiver::RttStats::AddRtt(TimeDelta rtt) { + last_rtt_ = rtt; + if (rtt < min_rtt_) { + min_rtt_ = rtt; + } + if (rtt > max_rtt_) { + max_rtt_ = rtt; + } + sum_rtt_ += rtt; + ++num_rtts_; +} + int32_t RTCPReceiver::RTT(uint32_t remote_ssrc, int64_t* last_rtt_ms, int64_t* avg_rtt_ms, @@ -226,32 +279,26 @@ int32_t RTCPReceiver::RTT(uint32_t remote_ssrc, int64_t* max_rtt_ms) const { MutexLock lock(&rtcp_receiver_lock_); - auto it = received_report_blocks_.find(main_ssrc_); - if (it == received_report_blocks_.end()) - return -1; - - auto it_info = it->second.find(remote_ssrc); - if (it_info == it->second.end()) - return -1; - - const ReportBlockData* report_block_data = &it_info->second; - - if (report_block_data->num_rtts() == 0) + auto it = rtts_.find(remote_ssrc); + if (it == rtts_.end()) { return -1; + } - if (last_rtt_ms) - *last_rtt_ms = report_block_data->last_rtt_ms(); + if (last_rtt_ms) { + *last_rtt_ms = it->second.last_rtt().ms(); + } if (avg_rtt_ms) { - *avg_rtt_ms = - report_block_data->sum_rtt_ms() / report_block_data->num_rtts(); + *avg_rtt_ms = it->second.average_rtt().ms(); } - if (min_rtt_ms) - *min_rtt_ms = report_block_data->min_rtt_ms(); + if (min_rtt_ms) { + *min_rtt_ms = it->second.min_rtt().ms(); + } - if (max_rtt_ms) - *max_rtt_ms = report_block_data->max_rtt_ms(); + if (max_rtt_ms) { + *max_rtt_ms = it->second.max_rtt().ms(); + } return 0; } @@ -279,26 +326,14 @@ absl::optional RTCPReceiver::OnPeriodicRttUpdate( // amount of time. MutexLock lock(&rtcp_receiver_lock_); if (last_received_rb_.IsInfinite() || last_received_rb_ > newer_than) { - // Stow away the report block for the main ssrc. We'll use the associated - // data map to look up each sender and check the last_rtt_ms(). 
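RttStats::AddRtt() above keeps the last, minimum, maximum and running sum of RTT samples per remote SSRC, and average_rtt() divides the sum by the sample count. A conceptual example with invented values (RttStats is a private helper of RTCPReceiver, so this is illustrative rather than compilable as-is):

    RttStats stats;
    stats.AddRtt(TimeDelta::Millis(30));
    stats.AddRtt(TimeDelta::Millis(10));
    stats.AddRtt(TimeDelta::Millis(20));
    // last_rtt()    == 20 ms  (most recent sample)
    // min_rtt()     == 10 ms
    // max_rtt()     == 30 ms
    // average_rtt() == (30 + 10 + 20) ms / 3 == 20 ms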
- auto main_report_it = received_report_blocks_.find(main_ssrc_); - if (main_report_it != received_report_blocks_.end()) { - const ReportBlockDataMap& main_data_map = main_report_it->second; - int64_t max_rtt = 0; - for (const auto& reports_per_receiver : received_report_blocks_) { - for (const auto& report : reports_per_receiver.second) { - const RTCPReportBlock& block = report.second.report_block(); - auto it_info = main_data_map.find(block.sender_ssrc); - if (it_info != main_data_map.end()) { - const ReportBlockData* report_block_data = &it_info->second; - if (report_block_data->num_rtts() > 0) { - max_rtt = std::max(report_block_data->last_rtt_ms(), max_rtt); - } - } - } + TimeDelta max_rtt = TimeDelta::MinusInfinity(); + for (const auto& rtt_stats : rtts_) { + if (rtt_stats.second.last_rtt() > max_rtt) { + max_rtt = rtt_stats.second.last_rtt(); } - if (max_rtt) - rtt.emplace(TimeDelta::Millis(max_rtt)); + } + if (max_rtt.IsFinite()) { + rtt = max_rtt; } } @@ -325,7 +360,10 @@ bool RTCPReceiver::NTP(uint32_t* received_ntp_secs, uint32_t* received_ntp_frac, uint32_t* rtcp_arrival_time_secs, uint32_t* rtcp_arrival_time_frac, - uint32_t* rtcp_timestamp) const { + uint32_t* rtcp_timestamp, + uint32_t* remote_sender_packet_count, + uint64_t* remote_sender_octet_count, + uint64_t* remote_sender_reports_count) const { MutexLock lock(&rtcp_receiver_lock_); if (!last_received_sr_ntp_.Valid()) return false; @@ -335,7 +373,6 @@ bool RTCPReceiver::NTP(uint32_t* received_ntp_secs, *received_ntp_secs = remote_sender_ntp_time_.seconds(); if (received_ntp_frac) *received_ntp_frac = remote_sender_ntp_time_.fractions(); - // Rtp time from incoming SenderReport. if (rtcp_timestamp) *rtcp_timestamp = remote_sender_rtp_time_; @@ -346,6 +383,14 @@ bool RTCPReceiver::NTP(uint32_t* received_ntp_secs, if (rtcp_arrival_time_frac) *rtcp_arrival_time_frac = last_received_sr_ntp_.fractions(); + // Counters. + if (remote_sender_packet_count) + *remote_sender_packet_count = remote_sender_packet_count_; + if (remote_sender_octet_count) + *remote_sender_octet_count = remote_sender_octet_count_; + if (remote_sender_reports_count) + *remote_sender_reports_count = remote_sender_reports_count_; + return true; } @@ -358,8 +403,7 @@ RTCPReceiver::ConsumeReceivedXrReferenceTimeInfo() { std::vector last_xr_rtis; last_xr_rtis.reserve(last_xr_rtis_size); - const uint32_t now_ntp = - CompactNtp(TimeMicrosToNtp(clock_->TimeInMicroseconds())); + const uint32_t now_ntp = CompactNtp(clock_->CurrentNtpTime()); for (size_t i = 0; i < last_xr_rtis_size; ++i) { RrtrInformation& rrtr = received_rrtrs_.front(); @@ -372,23 +416,12 @@ RTCPReceiver::ConsumeReceivedXrReferenceTimeInfo() { return last_xr_rtis; } -// We can get multiple receive reports when we receive the report from a CE. 
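ConsumeReceivedXrReferenceTimeInfo() and the other call sites in this change now feed CompactNtp() directly from clock_->CurrentNtpTime(). Compact NTP is the middle 32 bits of the 64-bit NTP timestamp, as used for the RFC 3550 LSR/DLSR fields. An illustration, not part of the CL, written as it would read inside rtcp_receiver.cc:

    // CompactNtp() keeps the low 16 bits of the seconds and the high 16 bits
    // of the fraction, so one unit is 1/65536 of a second.
    NtpTime ntp = clock_->CurrentNtpTime();
    uint32_t compact = CompactNtp(ntp);
    // compact == ((ntp.seconds() & 0xFFFF) << 16) | (ntp.fractions() >> 16)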
-int32_t RTCPReceiver::StatisticsReceived( - std::vector* receive_blocks) const { - RTC_DCHECK(receive_blocks); - MutexLock lock(&rtcp_receiver_lock_); - for (const auto& reports_per_receiver : received_report_blocks_) - for (const auto& report : reports_per_receiver.second) - receive_blocks->push_back(report.second.report_block()); - return 0; -} - std::vector RTCPReceiver::GetLatestReportBlockData() const { std::vector result; MutexLock lock(&rtcp_receiver_lock_); - for (const auto& reports_per_receiver : received_report_blocks_) - for (const auto& report : reports_per_receiver.second) - result.push_back(report.second); + for (const auto& report : received_report_blocks_) { + result.push_back(report.second); + } return result; } @@ -518,7 +551,10 @@ void RTCPReceiver::HandleSenderReport(const CommonHeader& rtcp_block, remote_sender_ntp_time_ = sender_report.ntp(); remote_sender_rtp_time_ = sender_report.rtp_timestamp(); - last_received_sr_ntp_ = TimeMicrosToNtp(clock_->TimeInMicroseconds()); + last_received_sr_ntp_ = clock_->CurrentNtpTime(); + remote_sender_packet_count_ = sender_report.sender_packet_count(); + remote_sender_octet_count_ = sender_report.sender_octet_count(); + remote_sender_reports_count_++; } else { // We will only store the send report from one source, but // we will store all the receive blocks. @@ -562,13 +598,13 @@ void RTCPReceiver::HandleReportBlock(const ReportBlock& report_block, // which the information in this reception report block pertains. // Filter out all report blocks that are not for us. - if (registered_ssrcs_.count(report_block.source_ssrc()) == 0) + if (!registered_ssrcs_.contains(report_block.source_ssrc())) return; last_received_rb_ = clock_->CurrentTime(); ReportBlockData* report_block_data = - &received_report_blocks_[report_block.source_ssrc()][remote_ssrc]; + &received_report_blocks_[report_block.source_ssrc()]; RTCPReportBlock rtcp_report_block; rtcp_report_block.sender_ssrc = remote_ssrc; rtcp_report_block.source_ssrc = report_block.source_ssrc(); @@ -605,13 +641,16 @@ void RTCPReceiver::HandleReportBlock(const ReportBlock& report_block, uint32_t delay_ntp = report_block.delay_since_last_sr(); // Local NTP time. uint32_t receive_time_ntp = - CompactNtp(TimeMicrosToNtp(last_received_rb_.us())); + CompactNtp(clock_->ConvertTimestampToNtpTime(last_received_rb_)); // RTT in 1/(2^16) seconds. uint32_t rtt_ntp = receive_time_ntp - delay_ntp - send_time_ntp; // Convert to 1/1000 seconds (milliseconds). rtt_ms = CompactNtpRttToMs(rtt_ntp); report_block_data->AddRoundTripTimeSample(rtt_ms); + if (report_block.source_ssrc() == main_ssrc_) { + rtts_[remote_ssrc].AddRtt(TimeDelta::Millis(rtt_ms)); + } packet_information->rtt_ms = rtt_ms; } @@ -714,7 +753,6 @@ void RTCPReceiver::HandleSdes(const CommonHeader& rtcp_block, } for (const rtcp::Sdes::Chunk& chunk : sdes.chunks()) { - received_cnames_[chunk.ssrc] = chunk.cname; if (cname_callback_) cname_callback_->OnCname(chunk.ssrc, chunk.cname); } @@ -770,15 +808,16 @@ void RTCPReceiver::HandleBye(const CommonHeader& rtcp_block) { } // Clear our lists. 
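HandleReportBlock() above derives the RTT from three compact NTP values: the local receive time, the delay the remote side reports since it saw our last SR (DLSR), and the SR send time echoed back in the report block (LSR). A worked example with invented numbers, where one compact NTP unit is 1/65536 s (about 15.26 µs):

    // rtt_ntp = receive_time_ntp - delay_ntp - send_time_ntp
    uint32_t send_time_ntp = 0x10000;                   // Echoed LSR value.
    uint32_t receive_time_ntp = send_time_ntp + 26214;  // Arrived ~400 ms later.
    uint32_t delay_ntp = 19661;                         // Remote held it ~300 ms (DLSR).
    uint32_t rtt_ntp = receive_time_ntp - delay_ntp - send_time_ntp;  // 6553
    int64_t rtt_ms = CompactNtpRttToMs(rtt_ntp);        // ~100 ms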
- for (auto& reports_per_receiver : received_report_blocks_) - reports_per_receiver.second.erase(bye.sender_ssrc()); + rtts_.erase(bye.sender_ssrc()); + EraseIf(received_report_blocks_, [&](const auto& elem) { + return elem.second.report_block().sender_ssrc == bye.sender_ssrc(); + }); TmmbrInformation* tmmbr_info = GetTmmbrInformation(bye.sender_ssrc()); if (tmmbr_info) tmmbr_info->ready_for_delete = true; last_fir_.erase(bye.sender_ssrc()); - received_cnames_.erase(bye.sender_ssrc()); auto it = received_rrtrs_ssrc_it_.find(bye.sender_ssrc()); if (it != received_rrtrs_ssrc_it_.end()) { received_rrtrs_.erase(it->second); @@ -810,8 +849,7 @@ void RTCPReceiver::HandleXr(const CommonHeader& rtcp_block, void RTCPReceiver::HandleXrReceiveReferenceTime(uint32_t sender_ssrc, const rtcp::Rrtr& rrtr) { uint32_t received_remote_mid_ntp_time = CompactNtp(rrtr.ntp()); - uint32_t local_receive_mid_ntp_time = - CompactNtp(TimeMicrosToNtp(clock_->TimeInMicroseconds())); + uint32_t local_receive_mid_ntp_time = CompactNtp(clock_->CurrentNtpTime()); auto it = received_rrtrs_ssrc_it_.find(sender_ssrc); if (it != received_rrtrs_ssrc_it_.end()) { @@ -830,7 +868,7 @@ void RTCPReceiver::HandleXrReceiveReferenceTime(uint32_t sender_ssrc, } void RTCPReceiver::HandleXrDlrrReportBlock(const rtcp::ReceiveTimeInfo& rti) { - if (registered_ssrcs_.count(rti.ssrc) == 0) // Not to us. + if (!registered_ssrcs_.contains(rti.ssrc)) // Not to us. return; // Caller should explicitly enable rtt calculation using extended reports. @@ -845,7 +883,7 @@ void RTCPReceiver::HandleXrDlrrReportBlock(const rtcp::ReceiveTimeInfo& rti) { return; uint32_t delay_ntp = rti.delay_since_last_rr; - uint32_t now_ntp = CompactNtp(TimeMicrosToNtp(clock_->TimeInMicroseconds())); + uint32_t now_ntp = CompactNtp(clock_->CurrentNtpTime()); uint32_t rtt_ntp = now_ntp - delay_ntp - send_time_ntp; xr_rr_rtt_ms_ = CompactNtpRttToMs(rtt_ntp); @@ -1053,14 +1091,7 @@ void RTCPReceiver::TriggerCallbacksFromRtcpPacket( // Might trigger a OnReceivedBandwidthEstimateUpdate. NotifyTmmbrUpdated(); } - uint32_t local_ssrc; - std::set registered_ssrcs; - { - // We don't want to hold this critsect when triggering the callbacks below. 
- MutexLock lock(&rtcp_receiver_lock_); - local_ssrc = main_ssrc_; - registered_ssrcs = registered_ssrcs_; - } + if (!receiver_only_ && (packet_information.packet_type_flags & kRtcpSrReq)) { rtp_rtcp_->OnRequestSendReport(); } @@ -1087,7 +1118,7 @@ void RTCPReceiver::TriggerCallbacksFromRtcpPacket( RTC_LOG(LS_VERBOSE) << "Incoming FIR from SSRC " << packet_information.remote_ssrc; } - rtcp_intra_frame_observer_->OnReceivedIntraFrameRequest(local_ssrc); + rtcp_intra_frame_observer_->OnReceivedIntraFrameRequest(main_ssrc_); } } if (rtcp_loss_notification_observer_ && @@ -1095,7 +1126,7 @@ void RTCPReceiver::TriggerCallbacksFromRtcpPacket( rtcp::LossNotification* loss_notification = packet_information.loss_notification.get(); RTC_DCHECK(loss_notification); - if (loss_notification->media_ssrc() == local_ssrc) { + if (loss_notification->media_ssrc() == main_ssrc_) { rtcp_loss_notification_observer_->OnReceivedLossNotification( loss_notification->media_ssrc(), loss_notification->last_decoded(), loss_notification->last_received(), @@ -1127,8 +1158,8 @@ void RTCPReceiver::TriggerCallbacksFromRtcpPacket( (packet_information.packet_type_flags & kRtcpTransportFeedback)) { uint32_t media_source_ssrc = packet_information.transport_feedback->media_ssrc(); - if (media_source_ssrc == local_ssrc || - registered_ssrcs.find(media_source_ssrc) != registered_ssrcs.end()) { + if (media_source_ssrc == main_ssrc_ || + registered_ssrcs_.contains(media_source_ssrc)) { transport_feedback_observer_->OnTransportFeedback( *packet_information.transport_feedback); } @@ -1147,18 +1178,6 @@ void RTCPReceiver::TriggerCallbacksFromRtcpPacket( } if (!receiver_only_) { - if (stats_callback_) { - for (const auto& report_block : packet_information.report_blocks) { - RtcpStatistics stats; - stats.packets_lost = report_block.packets_lost; - stats.extended_highest_sequence_number = - report_block.extended_highest_sequence_number; - stats.fraction_lost = report_block.fraction_lost; - stats.jitter = report_block.jitter; - - stats_callback_->StatisticsUpdated(stats, report_block.source_ssrc); - } - } if (report_block_data_observer_) { for (const auto& report_block_data : packet_information.report_block_datas) { @@ -1169,20 +1188,6 @@ void RTCPReceiver::TriggerCallbacksFromRtcpPacket( } } -int32_t RTCPReceiver::CNAME(uint32_t remoteSSRC, - char cName[RTCP_CNAME_SIZE]) const { - RTC_DCHECK(cName); - - MutexLock lock(&rtcp_receiver_lock_); - auto received_cname_it = received_cnames_.find(remoteSSRC); - if (received_cname_it == received_cnames_.end()) - return -1; - - size_t length = received_cname_it->second.copy(cName, RTCP_CNAME_SIZE - 1); - cName[length] = 0; - return 0; -} - std::vector RTCPReceiver::TmmbrReceived() { MutexLock lock(&rtcp_receiver_lock_); std::vector candidates; diff --git a/modules/rtp_rtcp/source/rtcp_receiver.h b/modules/rtp_rtcp/source/rtcp_receiver.h index d735653f41..fa9f367c9e 100644 --- a/modules/rtp_rtcp/source/rtcp_receiver.h +++ b/modules/rtp_rtcp/source/rtcp_receiver.h @@ -13,23 +13,29 @@ #include #include -#include #include #include #include "api/array_view.h" +#include "api/sequence_checker.h" #include "modules/rtp_rtcp/include/report_block_data.h" #include "modules/rtp_rtcp/include/rtcp_statistics.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/source/rtcp_nack_stats.h" #include "modules/rtp_rtcp/source/rtcp_packet/dlrr.h" +#include "modules/rtp_rtcp/source/rtcp_packet/tmmb_item.h" #include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" +#include 
"rtc_base/containers/flat_map.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" #include "system_wrappers/include/ntp_time.h" namespace webrtc { + +class ModuleRtpRtcpImpl2; class VideoBitrateAllocationObserver; + namespace rtcp { class CommonHeader; class ReportBlock; @@ -55,6 +61,10 @@ class RTCPReceiver final { RTCPReceiver(const RtpRtcpInterface::Configuration& config, ModuleRtpRtcp* owner); + + RTCPReceiver(const RtpRtcpInterface::Configuration& config, + ModuleRtpRtcpImpl2* owner); + ~RTCPReceiver(); void IncomingPacket(const uint8_t* packet, size_t packet_size) { @@ -64,18 +74,30 @@ class RTCPReceiver final { int64_t LastReceivedReportBlockMs() const; + void set_local_media_ssrc(uint32_t ssrc); + uint32_t local_media_ssrc() const; + void SetRemoteSSRC(uint32_t ssrc); uint32_t RemoteSSRC() const; - // Get received cname. - int32_t CNAME(uint32_t remote_ssrc, char cname[RTCP_CNAME_SIZE]) const; + bool receiver_only() const { return receiver_only_; } // Get received NTP. + // The types for the arguments below derive from the specification: + // - `remote_sender_packet_count`: `RTCSentRtpStreamStats.packetsSent` [1] + // - `remote_sender_octet_count`: `RTCSentRtpStreamStats.bytesSent` [1] + // - `remote_sender_reports_count`: + // `RTCRemoteOutboundRtpStreamStats.reportsSent` [2] + // [1] https://www.w3.org/TR/webrtc-stats/#remoteoutboundrtpstats-dict* + // [2] https://www.w3.org/TR/webrtc-stats/#dom-rtcsentrtpstreamstats bool NTP(uint32_t* received_ntp_secs, uint32_t* received_ntp_frac, uint32_t* rtcp_arrival_time_secs, uint32_t* rtcp_arrival_time_frac, - uint32_t* rtcp_timestamp) const; + uint32_t* rtcp_timestamp, + uint32_t* remote_sender_packet_count, + uint64_t* remote_sender_octet_count, + uint64_t* remote_sender_reports_count) const; std::vector ConsumeReceivedXrReferenceTimeInfo(); @@ -93,12 +115,9 @@ class RTCPReceiver final { absl::optional OnPeriodicRttUpdate(Timestamp newer_than, bool sending); - // Get statistics. - int32_t StatisticsReceived(std::vector* receiveBlocks) const; // A snapshot of Report Blocks with additional data of interest to statistics. - // Within this list, the sender-source SSRC pair is unique and per-pair the - // ReportBlockData represents the latest Report Block that was received for - // that pair. + // Within this list, the source SSRC is unique and ReportBlockData represents + // the latest Report Block that was received for that SSRC. std::vector GetLatestReportBlockData() const; // Returns true if we haven't received an RTCP RR for several RTCP @@ -119,14 +138,111 @@ class RTCPReceiver final { void NotifyTmmbrUpdated(); private: +#if RTC_DCHECK_IS_ON + class CustomSequenceChecker : public SequenceChecker { + public: + explicit CustomSequenceChecker(bool disable_checks) + : disable_checks_(disable_checks) {} + bool IsCurrent() const { + if (disable_checks_) + return true; + return SequenceChecker::IsCurrent(); + } + + private: + const bool disable_checks_; + }; +#else + class CustomSequenceChecker : public SequenceChecker { + public: + explicit CustomSequenceChecker(bool) {} + }; +#endif + + // A lightweight inlined set of local SSRCs. + class RegisteredSsrcs { + public: + static constexpr size_t kMediaSsrcIndex = 0; + static constexpr size_t kMaxSsrcs = 3; + // Initializes the set of registered local SSRCS by extracting them from the + // provided `config`. 
The `disable_sequence_checker` flag is a workaround + // to be able to use a sequence checker without breaking downstream + // code that currently doesn't follow the same threading rules as webrtc. + RegisteredSsrcs(bool disable_sequence_checker, + const RtpRtcpInterface::Configuration& config); + + // Indicates if `ssrc` is in the set of registered local SSRCs. + bool contains(uint32_t ssrc) const; + uint32_t media_ssrc() const; + void set_media_ssrc(uint32_t ssrc); + + private: + RTC_NO_UNIQUE_ADDRESS CustomSequenceChecker packet_sequence_checker_; + absl::InlinedVector ssrcs_ + RTC_GUARDED_BY(packet_sequence_checker_); + }; + struct PacketInformation; - struct TmmbrInformation; - struct RrtrInformation; - struct LastFirStatus; - // RTCP report blocks mapped by remote SSRC. - using ReportBlockDataMap = std::map; - // RTCP report blocks map mapped by source SSRC. - using ReportBlockMap = std::map; + + // Structure for handing TMMBR and TMMBN rtcp messages (RFC5104, + // section 3.5.4). + struct TmmbrInformation { + struct TimedTmmbrItem { + rtcp::TmmbItem tmmbr_item; + int64_t last_updated_ms; + }; + + int64_t last_time_received_ms = 0; + + bool ready_for_delete = false; + + std::vector tmmbn; + std::map tmmbr; + }; + + // Structure for storing received RRTR RTCP messages (RFC3611, section 4.4). + struct RrtrInformation { + RrtrInformation(uint32_t ssrc, + uint32_t received_remote_mid_ntp_time, + uint32_t local_receive_mid_ntp_time) + : ssrc(ssrc), + received_remote_mid_ntp_time(received_remote_mid_ntp_time), + local_receive_mid_ntp_time(local_receive_mid_ntp_time) {} + + uint32_t ssrc; + // Received NTP timestamp in compact representation. + uint32_t received_remote_mid_ntp_time; + // NTP time when the report was received in compact representation. + uint32_t local_receive_mid_ntp_time; + }; + + struct LastFirStatus { + LastFirStatus(int64_t now_ms, uint8_t sequence_number) + : request_ms(now_ms), sequence_number(sequence_number) {} + int64_t request_ms; + uint8_t sequence_number; + }; + + class RttStats { + public: + RttStats() = default; + RttStats(const RttStats&) = default; + RttStats& operator=(const RttStats&) = default; + + void AddRtt(TimeDelta rtt); + + TimeDelta last_rtt() const { return last_rtt_; } + TimeDelta min_rtt() const { return min_rtt_; } + TimeDelta max_rtt() const { return max_rtt_; } + TimeDelta average_rtt() const { return sum_rtt_ / num_rtts_; } + + private: + TimeDelta last_rtt_ = TimeDelta::Zero(); + TimeDelta min_rtt_ = TimeDelta::PlusInfinity(); + TimeDelta max_rtt_ = TimeDelta::MinusInfinity(); + TimeDelta sum_rtt_ = TimeDelta::Zero(); + size_t num_rtts_ = 0; + }; bool ParseCompoundPacket(rtc::ArrayView packet, PacketInformation* packet_information); @@ -224,7 +340,8 @@ class RTCPReceiver final { const bool receiver_only_; ModuleRtpRtcp* const rtp_rtcp_; const uint32_t main_ssrc_; - const std::set registered_ssrcs_; + // The set of registered local SSRCs. + RegisteredSsrcs registered_ssrcs_; RtcpBandwidthObserver* const rtcp_bandwidth_observer_; RtcpIntraFrameObserver* const rtcp_intra_frame_observer_; @@ -242,12 +359,15 @@ class RTCPReceiver final { uint32_t remote_sender_rtp_time_ RTC_GUARDED_BY(rtcp_receiver_lock_); // When did we receive the last send report. 
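The RegisteredSsrcs declaration above holds at most kMaxSsrcs entries (media, RTX and FlexFEC), so an inlined vector plus a linear scan replaces the previous heap-allocated std::set without changing lookup cost in practice. A stand-alone illustration of the container behaviour, not taken from the CL:

    // Up to 3 entries stay in the inline storage; no heap allocation.
    absl::InlinedVector<uint32_t, 3> ssrcs = {0x1111, 0x2222};
    // O(n) scan, but n <= 3 here.
    bool known = absl::c_linear_search(ssrcs, 0x2222);  // true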
NtpTime last_received_sr_ntp_ RTC_GUARDED_BY(rtcp_receiver_lock_); + uint32_t remote_sender_packet_count_ RTC_GUARDED_BY(rtcp_receiver_lock_); + uint64_t remote_sender_octet_count_ RTC_GUARDED_BY(rtcp_receiver_lock_); + uint64_t remote_sender_reports_count_ RTC_GUARDED_BY(rtcp_receiver_lock_); // Received RRTR information in ascending receive time order. std::list received_rrtrs_ RTC_GUARDED_BY(rtcp_receiver_lock_); // Received RRTR information mapped by remote ssrc. - std::map::iterator> + flat_map::iterator> received_rrtrs_ssrc_it_ RTC_GUARDED_BY(rtcp_receiver_lock_); // Estimated rtt, zero when there is no valid estimate. @@ -256,13 +376,16 @@ class RTCPReceiver final { int64_t oldest_tmmbr_info_ms_ RTC_GUARDED_BY(rtcp_receiver_lock_); // Mapped by remote ssrc. - std::map tmmbr_infos_ + flat_map tmmbr_infos_ RTC_GUARDED_BY(rtcp_receiver_lock_); - ReportBlockMap received_report_blocks_ RTC_GUARDED_BY(rtcp_receiver_lock_); - std::map last_fir_ + // Round-Trip Time per remote sender ssrc. + flat_map rtts_ RTC_GUARDED_BY(rtcp_receiver_lock_); + + // Report blocks per local source ssrc. + flat_map received_report_blocks_ RTC_GUARDED_BY(rtcp_receiver_lock_); - std::map received_cnames_ + flat_map last_fir_ RTC_GUARDED_BY(rtcp_receiver_lock_); // The last time we received an RTCP Report block for this module. @@ -273,11 +396,7 @@ class RTCPReceiver final { // delivered RTP packet to the remote side. Timestamp last_increased_sequence_number_ = Timestamp::PlusInfinity(); - RtcpStatisticsCallback* const stats_callback_; RtcpCnameCallback* const cname_callback_; - // TODO(hbos): Remove RtcpStatisticsCallback in favor of - // ReportBlockDataObserver; the ReportBlockData contains a superset of the - // RtcpStatistics data. ReportBlockDataObserver* const report_block_data_observer_; RtcpPacketTypeCounterObserver* const packet_type_counter_observer_; diff --git a/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc b/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc index 1a1d94a4f0..3065534108 100644 --- a/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc +++ b/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc @@ -51,6 +51,7 @@ using rtcp::ReceiveTimeInfo; using ::testing::_; using ::testing::AllOf; using ::testing::ElementsAreArray; +using ::testing::Eq; using ::testing::Field; using ::testing::InSequence; using ::testing::IsEmpty; @@ -86,14 +87,6 @@ class MockRtcpLossNotificationObserver : public RtcpLossNotificationObserver { (override)); }; -class MockRtcpCallbackImpl : public RtcpStatisticsCallback { - public: - MOCK_METHOD(void, - StatisticsUpdated, - (const RtcpStatistics&, uint32_t), - (override)); -}; - class MockCnameCallbackImpl : public RtcpCnameCallback { public: MOCK_METHOD(void, OnCname, (uint32_t, absl::string_view), (override)); @@ -208,7 +201,8 @@ TEST(RtcpReceiverTest, InjectSrPacket) { RTCPReceiver receiver(DefaultConfiguration(&mocks), &mocks.rtp_rtcp_impl); receiver.SetRemoteSSRC(kSenderSsrc); - EXPECT_FALSE(receiver.NTP(nullptr, nullptr, nullptr, nullptr, nullptr)); + EXPECT_FALSE(receiver.NTP(nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr)); int64_t now = mocks.clock.TimeInMilliseconds(); rtcp::SenderReport sr; @@ -219,7 +213,8 @@ TEST(RtcpReceiverTest, InjectSrPacket) { OnReceivedRtcpReceiverReport(IsEmpty(), _, now)); receiver.IncomingPacket(sr.Build()); - EXPECT_TRUE(receiver.NTP(nullptr, nullptr, nullptr, nullptr, nullptr)); + EXPECT_TRUE(receiver.NTP(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr)); } TEST(RtcpReceiverTest, 
InjectSrPacketFromUnknownSender) { @@ -239,7 +234,8 @@ TEST(RtcpReceiverTest, InjectSrPacketFromUnknownSender) { receiver.IncomingPacket(sr.Build()); // But will not flag that he's gotten sender information. - EXPECT_FALSE(receiver.NTP(nullptr, nullptr, nullptr, nullptr, nullptr)); + EXPECT_FALSE(receiver.NTP(nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr)); } TEST(RtcpReceiverTest, InjectSrPacketCalculatesRTT) { @@ -254,8 +250,7 @@ TEST(RtcpReceiverTest, InjectSrPacketCalculatesRTT) { int64_t rtt_ms = 0; EXPECT_EQ(-1, receiver.RTT(kSenderSsrc, &rtt_ms, nullptr, nullptr, nullptr)); - uint32_t sent_ntp = - CompactNtp(TimeMicrosToNtp(mocks.clock.TimeInMicroseconds())); + uint32_t sent_ntp = CompactNtp(mocks.clock.CurrentNtpTime()); mocks.clock.AdvanceTimeMilliseconds(kRttMs + kDelayMs); rtcp::SenderReport sr; @@ -286,8 +281,7 @@ TEST(RtcpReceiverTest, InjectSrPacketCalculatesNegativeRTTAsOne) { int64_t rtt_ms = 0; EXPECT_EQ(-1, receiver.RTT(kSenderSsrc, &rtt_ms, nullptr, nullptr, nullptr)); - uint32_t sent_ntp = - CompactNtp(TimeMicrosToNtp(mocks.clock.TimeInMicroseconds())); + uint32_t sent_ntp = CompactNtp(mocks.clock.CurrentNtpTime()); mocks.clock.AdvanceTimeMilliseconds(kRttMs + kDelayMs); rtcp::SenderReport sr; @@ -317,8 +311,7 @@ TEST(RtcpReceiverTest, const uint32_t kDelayNtp = 123000; const int64_t kDelayMs = CompactNtpRttToMs(kDelayNtp); - uint32_t sent_ntp = - CompactNtp(TimeMicrosToNtp(mocks.clock.TimeInMicroseconds())); + uint32_t sent_ntp = CompactNtp(mocks.clock.CurrentNtpTime()); mocks.clock.AdvanceTimeMilliseconds(kRttMs + kDelayMs); rtcp::SenderReport sr; @@ -352,9 +345,7 @@ TEST(RtcpReceiverTest, InjectRrPacket) { OnReceivedRtcpReceiverReport(IsEmpty(), _, now)); receiver.IncomingPacket(rr.Build()); - std::vector report_blocks; - receiver.StatisticsReceived(&report_blocks); - EXPECT_TRUE(report_blocks.empty()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), IsEmpty()); } TEST(RtcpReceiverTest, InjectRrPacketWithReportBlockNotToUsIgnored) { @@ -375,9 +366,7 @@ TEST(RtcpReceiverTest, InjectRrPacketWithReportBlockNotToUsIgnored) { receiver.IncomingPacket(rr.Build()); EXPECT_EQ(0, receiver.LastReceivedReportBlockMs()); - std::vector received_blocks; - receiver.StatisticsReceived(&received_blocks); - EXPECT_TRUE(received_blocks.empty()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), IsEmpty()); } TEST(RtcpReceiverTest, InjectRrPacketWithOneReportBlock) { @@ -399,9 +388,7 @@ TEST(RtcpReceiverTest, InjectRrPacketWithOneReportBlock) { receiver.IncomingPacket(rr.Build()); EXPECT_EQ(now, receiver.LastReceivedReportBlockMs()); - std::vector received_blocks; - receiver.StatisticsReceived(&received_blocks); - EXPECT_EQ(1u, received_blocks.size()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), SizeIs(1)); } TEST(RtcpReceiverTest, InjectSrPacketWithOneReportBlock) { @@ -423,9 +410,7 @@ TEST(RtcpReceiverTest, InjectSrPacketWithOneReportBlock) { receiver.IncomingPacket(sr.Build()); EXPECT_EQ(now, receiver.LastReceivedReportBlockMs()); - std::vector received_blocks; - receiver.StatisticsReceived(&received_blocks); - EXPECT_EQ(1u, received_blocks.size()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), SizeIs(1)); } TEST(RtcpReceiverTest, InjectRrPacketWithTwoReportBlocks) { @@ -459,11 +444,12 @@ TEST(RtcpReceiverTest, InjectRrPacketWithTwoReportBlocks) { receiver.IncomingPacket(rr1.Build()); EXPECT_EQ(now, receiver.LastReceivedReportBlockMs()); - std::vector received_blocks; - receiver.StatisticsReceived(&received_blocks); - 
EXPECT_THAT(received_blocks, - UnorderedElementsAre(Field(&RTCPReportBlock::fraction_lost, 0), - Field(&RTCPReportBlock::fraction_lost, 10))); + EXPECT_THAT(receiver.GetLatestReportBlockData(), + UnorderedElementsAre( + Property(&ReportBlockData::report_block, + Field(&RTCPReportBlock::fraction_lost, 0)), + Property(&ReportBlockData::report_block, + Field(&RTCPReportBlock::fraction_lost, 10)))); // Insert next receiver report with same ssrc but new values. rtcp::ReportBlock rb3; @@ -492,25 +478,27 @@ TEST(RtcpReceiverTest, InjectRrPacketWithTwoReportBlocks) { OnReceivedRtcpReceiverReport(SizeIs(2), _, now)); receiver.IncomingPacket(rr2.Build()); - received_blocks.clear(); - receiver.StatisticsReceived(&received_blocks); - EXPECT_EQ(2u, received_blocks.size()); EXPECT_THAT( - received_blocks, + receiver.GetLatestReportBlockData(), UnorderedElementsAre( - AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverMainSsrc), - Field(&RTCPReportBlock::fraction_lost, kFracLost[0]), - Field(&RTCPReportBlock::packets_lost, kCumLost[0]), - Field(&RTCPReportBlock::extended_highest_sequence_number, - kSequenceNumbers[0])), - AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverExtraSsrc), - Field(&RTCPReportBlock::fraction_lost, kFracLost[1]), - Field(&RTCPReportBlock::packets_lost, kCumLost[1]), - Field(&RTCPReportBlock::extended_highest_sequence_number, - kSequenceNumbers[1])))); + Property( + &ReportBlockData::report_block, + AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverMainSsrc), + Field(&RTCPReportBlock::fraction_lost, kFracLost[0]), + Field(&RTCPReportBlock::packets_lost, kCumLost[0]), + Field(&RTCPReportBlock::extended_highest_sequence_number, + kSequenceNumbers[0]))), + Property( + &ReportBlockData::report_block, + AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverExtraSsrc), + Field(&RTCPReportBlock::fraction_lost, kFracLost[1]), + Field(&RTCPReportBlock::packets_lost, kCumLost[1]), + Field(&RTCPReportBlock::extended_highest_sequence_number, + kSequenceNumbers[1]))))); } -TEST(RtcpReceiverTest, InjectRrPacketsFromTwoRemoteSsrcs) { +TEST(RtcpReceiverTest, + InjectRrPacketsFromTwoRemoteSsrcsReturnsLatestReportBlock) { const uint32_t kSenderSsrc2 = 0x20304; const uint16_t kSequenceNumbers[] = {10, 12423}; const int32_t kCumLost[] = {13, 555}; @@ -537,15 +525,16 @@ TEST(RtcpReceiverTest, InjectRrPacketsFromTwoRemoteSsrcs) { EXPECT_EQ(now, receiver.LastReceivedReportBlockMs()); - std::vector received_blocks; - receiver.StatisticsReceived(&received_blocks); - EXPECT_EQ(1u, received_blocks.size()); - EXPECT_EQ(kSenderSsrc, received_blocks[0].sender_ssrc); - EXPECT_EQ(kReceiverMainSsrc, received_blocks[0].source_ssrc); - EXPECT_EQ(kFracLost[0], received_blocks[0].fraction_lost); - EXPECT_EQ(kCumLost[0], received_blocks[0].packets_lost); - EXPECT_EQ(kSequenceNumbers[0], - received_blocks[0].extended_highest_sequence_number); + EXPECT_THAT( + receiver.GetLatestReportBlockData(), + ElementsAre(Property( + &ReportBlockData::report_block, + AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverMainSsrc), + Field(&RTCPReportBlock::sender_ssrc, kSenderSsrc), + Field(&RTCPReportBlock::fraction_lost, kFracLost[0]), + Field(&RTCPReportBlock::packets_lost, kCumLost[0]), + Field(&RTCPReportBlock::extended_highest_sequence_number, + kSequenceNumbers[0]))))); rtcp::ReportBlock rb2; rb2.SetMediaSsrc(kReceiverMainSsrc); @@ -561,24 +550,17 @@ TEST(RtcpReceiverTest, InjectRrPacketsFromTwoRemoteSsrcs) { OnReceivedRtcpReceiverReport(SizeIs(1), _, now)); receiver.IncomingPacket(rr2.Build()); - received_blocks.clear(); - 
receiver.StatisticsReceived(&received_blocks); - ASSERT_EQ(2u, received_blocks.size()); EXPECT_THAT( - received_blocks, + receiver.GetLatestReportBlockData(), UnorderedElementsAre( - AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverMainSsrc), - Field(&RTCPReportBlock::sender_ssrc, kSenderSsrc), - Field(&RTCPReportBlock::fraction_lost, kFracLost[0]), - Field(&RTCPReportBlock::packets_lost, kCumLost[0]), - Field(&RTCPReportBlock::extended_highest_sequence_number, - kSequenceNumbers[0])), - AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverMainSsrc), - Field(&RTCPReportBlock::sender_ssrc, kSenderSsrc2), - Field(&RTCPReportBlock::fraction_lost, kFracLost[1]), - Field(&RTCPReportBlock::packets_lost, kCumLost[1]), - Field(&RTCPReportBlock::extended_highest_sequence_number, - kSequenceNumbers[1])))); + Property( + &ReportBlockData::report_block, + AllOf(Field(&RTCPReportBlock::source_ssrc, kReceiverMainSsrc), + Field(&RTCPReportBlock::sender_ssrc, kSenderSsrc2), + Field(&RTCPReportBlock::fraction_lost, kFracLost[1]), + Field(&RTCPReportBlock::packets_lost, kCumLost[1]), + Field(&RTCPReportBlock::extended_highest_sequence_number, + kSequenceNumbers[1]))))); } TEST(RtcpReceiverTest, GetRtt) { @@ -648,33 +630,6 @@ TEST(RtcpReceiverTest, InjectSdesWithOneChunk) { EXPECT_CALL(callback, OnCname(kSenderSsrc, StrEq(kCname))); receiver.IncomingPacket(sdes.Build()); - - char cName[RTCP_CNAME_SIZE]; - EXPECT_EQ(0, receiver.CNAME(kSenderSsrc, cName)); - EXPECT_EQ(0, strncmp(cName, kCname, RTCP_CNAME_SIZE)); -} - -TEST(RtcpReceiverTest, InjectByePacket_RemovesCname) { - ReceiverMocks mocks; - RTCPReceiver receiver(DefaultConfiguration(&mocks), &mocks.rtp_rtcp_impl); - receiver.SetRemoteSSRC(kSenderSsrc); - - const char kCname[] = "alice@host"; - rtcp::Sdes sdes; - sdes.AddCName(kSenderSsrc, kCname); - - receiver.IncomingPacket(sdes.Build()); - - char cName[RTCP_CNAME_SIZE]; - EXPECT_EQ(0, receiver.CNAME(kSenderSsrc, cName)); - - // Verify that BYE removes the CNAME. - rtcp::Bye bye; - bye.SetSenderSsrc(kSenderSsrc); - - receiver.IncomingPacket(bye.Build()); - - EXPECT_EQ(-1, receiver.CNAME(kSenderSsrc, cName)); } TEST(RtcpReceiverTest, InjectByePacket_RemovesReportBlocks) { @@ -695,9 +650,7 @@ TEST(RtcpReceiverTest, InjectByePacket_RemovesReportBlocks) { EXPECT_CALL(mocks.bandwidth_observer, OnReceivedRtcpReceiverReport); receiver.IncomingPacket(rr.Build()); - std::vector received_blocks; - receiver.StatisticsReceived(&received_blocks); - EXPECT_EQ(2u, received_blocks.size()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), SizeIs(2)); // Verify that BYE removes the report blocks. rtcp::Bye bye; @@ -705,18 +658,14 @@ TEST(RtcpReceiverTest, InjectByePacket_RemovesReportBlocks) { receiver.IncomingPacket(bye.Build()); - received_blocks.clear(); - receiver.StatisticsReceived(&received_blocks); - EXPECT_TRUE(received_blocks.empty()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), IsEmpty()); // Inject packet again. 
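With the CNAME() accessor and the per-SSRC CNAME map removed, received CNAMEs are only surfaced through the RtcpCnameCallback configured on the receiver, as the SDES test above now exercises. A hedged sketch of such an observer; the callback signature matches the mock used in this test file:

    class LoggingCnameObserver : public webrtc::RtcpCnameCallback {
     public:
      void OnCname(uint32_t ssrc, absl::string_view cname) override {
        RTC_LOG(LS_INFO) << "CNAME for SSRC " << ssrc << ": " << cname;
      }
    };
    // Installed via RtpRtcpInterface::Configuration::rtcp_cname_callback
    // before the RTCPReceiver is constructed.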
EXPECT_CALL(mocks.rtp_rtcp_impl, OnReceivedRtcpReportBlocks); EXPECT_CALL(mocks.bandwidth_observer, OnReceivedRtcpReceiverReport); receiver.IncomingPacket(rr.Build()); - received_blocks.clear(); - receiver.StatisticsReceived(&received_blocks); - EXPECT_EQ(2u, received_blocks.size()); + EXPECT_THAT(receiver.GetLatestReportBlockData(), SizeIs(2)); } TEST(RtcpReceiverTest, InjectByePacketRemovesReferenceTimeInfo) { @@ -872,8 +821,7 @@ TEST(RtcpReceiverTest, InjectExtendedReportsDlrrPacketWithSubBlock) { receiver.IncomingPacket(xr.Build()); - uint32_t compact_ntp_now = - CompactNtp(TimeMicrosToNtp(mocks.clock.TimeInMicroseconds())); + uint32_t compact_ntp_now = CompactNtp(mocks.clock.CurrentNtpTime()); EXPECT_TRUE(receiver.GetAndResetXrRrRtt(&rtt_ms)); uint32_t rtt_ntp = compact_ntp_now - kDelay - kLastRR; EXPECT_NEAR(CompactNtpRttToMs(rtt_ntp), rtt_ms, 1); @@ -897,8 +845,7 @@ TEST(RtcpReceiverTest, InjectExtendedReportsDlrrPacketWithMultipleSubBlocks) { receiver.IncomingPacket(xr.Build()); - uint32_t compact_ntp_now = - CompactNtp(TimeMicrosToNtp(mocks.clock.TimeInMicroseconds())); + uint32_t compact_ntp_now = CompactNtp(mocks.clock.CurrentNtpTime()); int64_t rtt_ms = 0; EXPECT_TRUE(receiver.GetAndResetXrRrRtt(&rtt_ms)); uint32_t rtt_ntp = compact_ntp_now - kDelay - kLastRR; @@ -977,7 +924,7 @@ TEST(RtcpReceiverTest, RttCalculatedAfterExtendedReportsDlrr) { const int64_t kRttMs = rand.Rand(1, 9 * 3600 * 1000); const uint32_t kDelayNtp = rand.Rand(0, 0x7fffffff); const int64_t kDelayMs = CompactNtpRttToMs(kDelayNtp); - NtpTime now = TimeMicrosToNtp(mocks.clock.TimeInMicroseconds()); + NtpTime now = mocks.clock.CurrentNtpTime(); uint32_t sent_ntp = CompactNtp(now); mocks.clock.AdvanceTimeMilliseconds(kRttMs + kDelayMs); @@ -1003,7 +950,7 @@ TEST(RtcpReceiverTest, XrDlrrCalculatesNegativeRttAsOne) { const int64_t kRttMs = rand.Rand(-3600 * 1000, -1); const uint32_t kDelayNtp = rand.Rand(0, 0x7fffffff); const int64_t kDelayMs = CompactNtpRttToMs(kDelayNtp); - NtpTime now = TimeMicrosToNtp(mocks.clock.TimeInMicroseconds()); + NtpTime now = mocks.clock.CurrentNtpTime(); uint32_t sent_ntp = CompactNtp(now); mocks.clock.AdvanceTimeMilliseconds(kRttMs + kDelayMs); @@ -1300,53 +1247,17 @@ TEST(RtcpReceiverTest, TmmbrThreeConstraintsTimeOut) { mocks.clock.AdvanceTimeMilliseconds(5000); } // It is now starttime + 15. - std::vector candidate_set = receiver.TmmbrReceived(); - ASSERT_EQ(3u, candidate_set.size()); - EXPECT_EQ(30000U, candidate_set[0].bitrate_bps()); + EXPECT_THAT(receiver.TmmbrReceived(), + AllOf(SizeIs(3), + Each(Property(&rtcp::TmmbItem::bitrate_bps, Eq(30'000U))))); // We expect the timeout to be 25 seconds. Advance the clock by 12 // seconds, timing out the first packet. mocks.clock.AdvanceTimeMilliseconds(12000); - candidate_set = receiver.TmmbrReceived(); - ASSERT_EQ(2u, candidate_set.size()); - EXPECT_EQ(kSenderSsrc + 1, candidate_set[0].ssrc()); -} - -TEST(RtcpReceiverTest, Callbacks) { - ReceiverMocks mocks; - MockRtcpCallbackImpl callback; - RtpRtcpInterface::Configuration config = DefaultConfiguration(&mocks); - config.rtcp_statistics_callback = &callback; - RTCPReceiver receiver(config, &mocks.rtp_rtcp_impl); - receiver.SetRemoteSSRC(kSenderSsrc); - - const uint8_t kFractionLoss = 3; - const uint32_t kCumulativeLoss = 7; - const uint32_t kJitter = 9; - const uint16_t kSequenceNumber = 1234; - - // First packet, all numbers should just propagate. 
- rtcp::ReportBlock rb1; - rb1.SetMediaSsrc(kReceiverMainSsrc); - rb1.SetExtHighestSeqNum(kSequenceNumber); - rb1.SetFractionLost(kFractionLoss); - rb1.SetCumulativeLost(kCumulativeLoss); - rb1.SetJitter(kJitter); - - rtcp::ReceiverReport rr1; - rr1.SetSenderSsrc(kSenderSsrc); - rr1.AddReportBlock(rb1); - EXPECT_CALL(callback, - StatisticsUpdated( - AllOf(Field(&RtcpStatistics::fraction_lost, kFractionLoss), - Field(&RtcpStatistics::packets_lost, kCumulativeLoss), - Field(&RtcpStatistics::extended_highest_sequence_number, - kSequenceNumber), - Field(&RtcpStatistics::jitter, kJitter)), - kReceiverMainSsrc)); - EXPECT_CALL(mocks.rtp_rtcp_impl, OnReceivedRtcpReportBlocks); - EXPECT_CALL(mocks.bandwidth_observer, OnReceivedRtcpReceiverReport); - receiver.IncomingPacket(rr1.Build()); + EXPECT_THAT(receiver.TmmbrReceived(), + UnorderedElementsAre( + Property(&rtcp::TmmbItem::ssrc, Eq(kSenderSsrc + 1)), + Property(&rtcp::TmmbItem::ssrc, Eq(kSenderSsrc + 2)))); } TEST(RtcpReceiverTest, @@ -1411,8 +1322,7 @@ TEST(RtcpReceiverTest, VerifyRttObtainedFromReportBlockDataObserver) { const uint32_t kDelayNtp = 123000; const int64_t kDelayMs = CompactNtpRttToMs(kDelayNtp); - uint32_t sent_ntp = - CompactNtp(TimeMicrosToNtp(mocks.clock.TimeInMicroseconds())); + uint32_t sent_ntp = CompactNtp(mocks.clock.CurrentNtpTime()); mocks.clock.AdvanceTimeMilliseconds(kRttMs + kDelayMs); rtcp::SenderReport sr; diff --git a/modules/rtp_rtcp/source/rtcp_sender.cc b/modules/rtp_rtcp/source/rtcp_sender.cc index 79f5aa6c67..8f5e3b104c 100644 --- a/modules/rtp_rtcp/source/rtcp_sender.cc +++ b/modules/rtp_rtcp/source/rtcp_sender.cc @@ -16,7 +16,11 @@ #include #include +#include "absl/types/optional.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/rtp_headers.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "logging/rtc_event_log/events/rtc_event_rtcp_packet_outgoing.h" #include "modules/rtp_rtcp/source/rtcp_packet/app.h" #include "modules/rtp_rtcp/source/rtcp_packet/bye.h" @@ -34,6 +38,7 @@ #include "modules/rtp_rtcp/source/rtcp_packet/tmmbr.h" #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" +#include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" #include "modules/rtp_rtcp/source/time_util.h" #include "modules/rtp_rtcp/source/tmmbr_help.h" #include "rtc_base/checks.h" @@ -49,37 +54,10 @@ const uint32_t kRtcpAnyExtendedReports = kRtcpXrReceiverReferenceTime | kRtcpXrTargetBitrate; constexpr int32_t kDefaultVideoReportInterval = 1000; constexpr int32_t kDefaultAudioReportInterval = 5000; - -class PacketContainer : public rtcp::CompoundPacket { - public: - PacketContainer(Transport* transport, RtcEventLog* event_log) - : transport_(transport), event_log_(event_log) {} - - PacketContainer() = delete; - PacketContainer(const PacketContainer&) = delete; - PacketContainer& operator=(const PacketContainer&) = delete; - - size_t SendPackets(size_t max_payload_length) { - size_t bytes_sent = 0; - Build(max_payload_length, [&](rtc::ArrayView packet) { - if (transport_->SendRtcp(packet.data(), packet.size())) { - bytes_sent += packet.size(); - if (event_log_) { - event_log_->Log(std::make_unique(packet)); - } - } - }); - return bytes_sent; - } - - private: - Transport* transport_; - RtcEventLog* const event_log_; -}; +} // namespace // Helper to put several RTCP packets into lower layer datagram RTCP packet. -// Prefer to use this class instead of PacketContainer. 
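The deleted Callbacks test covered the RtcpStatisticsCallback path, which this change removes in favour of ReportBlockDataObserver; ReportBlockData carries a superset of the old RtcpStatistics fields plus the RTT samples used elsewhere in this CL. A hedged sketch of the replacement observer (method name and pass-by-value as declared in report_block_data.h, to the best of my knowledge):

    class LossStatsObserver : public webrtc::ReportBlockDataObserver {
     public:
      void OnReportBlockDataUpdated(webrtc::ReportBlockData data) override {
        const webrtc::RTCPReportBlock& block = data.report_block();
        // block.fraction_lost, block.packets_lost, block.jitter and
        // block.extended_highest_sequence_number replace the fields the old
        // RtcpStatisticsCallback delivered.
      }
    };
    // Installed via RtpRtcpInterface::Configuration::report_block_data_observer.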
-class PacketSender { +class RTCPSender::PacketSender { public: PacketSender(rtcp::RtcpPacket::PacketReadyCallback callback, size_t max_packet_size) @@ -102,8 +80,6 @@ class PacketSender { } } - bool IsEmpty() const { return index_ == 0; } - private: const rtcp::RtcpPacket::PacketReadyCallback callback_; const size_t max_packet_size_; @@ -111,8 +87,6 @@ class PacketSender { uint8_t buffer_[IP_PACKET_SIZE]; }; -} // namespace - RTCPSender::FeedbackState::FeedbackState() : packets_sent(0), media_bytes_sent(0), @@ -133,19 +107,38 @@ class RTCPSender::RtcpContext { RtcpContext(const FeedbackState& feedback_state, int32_t nack_size, const uint16_t* nack_list, - int64_t now_us) + Timestamp now) : feedback_state_(feedback_state), nack_size_(nack_size), nack_list_(nack_list), - now_us_(now_us) {} + now_(now) {} const FeedbackState& feedback_state_; const int32_t nack_size_; const uint16_t* nack_list_; - const int64_t now_us_; + const Timestamp now_; }; -RTCPSender::RTCPSender(const RtpRtcpInterface::Configuration& config) +RTCPSender::Configuration RTCPSender::Configuration::FromRtpRtcpConfiguration( + const RtpRtcpInterface::Configuration& configuration) { + RTCPSender::Configuration result; + result.audio = configuration.audio; + result.local_media_ssrc = configuration.local_media_ssrc; + result.clock = configuration.clock; + result.outgoing_transport = configuration.outgoing_transport; + result.non_sender_rtt_measurement = configuration.non_sender_rtt_measurement; + result.event_log = configuration.event_log; + if (configuration.rtcp_report_interval_ms) { + result.rtcp_report_interval = + TimeDelta::Millis(configuration.rtcp_report_interval_ms); + } + result.receive_statistics = configuration.receive_statistics; + result.rtcp_packet_type_counter_observer = + configuration.rtcp_packet_type_counter_observer; + return result; +} + +RTCPSender::RTCPSender(Configuration config) : audio_(config.audio), ssrc_(config.local_media_ssrc), clock_(config.clock), @@ -153,15 +146,14 @@ RTCPSender::RTCPSender(const RtpRtcpInterface::Configuration& config) method_(RtcpMode::kOff), event_log_(config.event_log), transport_(config.outgoing_transport), - report_interval_ms_(config.rtcp_report_interval_ms > 0 - ? config.rtcp_report_interval_ms - : (config.audio ? kDefaultAudioReportInterval - : kDefaultVideoReportInterval)), + report_interval_(config.rtcp_report_interval.value_or( + TimeDelta::Millis(config.audio ? 
kDefaultAudioReportInterval + : kDefaultVideoReportInterval))), + schedule_next_rtcp_send_evaluation_function_( + std::move(config.schedule_next_rtcp_send_evaluation_function)), sending_(false), - next_time_to_send_rtcp_(0), timestamp_offset_(0), last_rtp_timestamp_(0), - last_frame_capture_time_ms_(-1), remote_ssrc_(0), receive_statistics_(config.receive_statistics), @@ -204,10 +196,11 @@ RtcpMode RTCPSender::Status() const { void RTCPSender::SetRTCPStatus(RtcpMode new_method) { MutexLock lock(&mutex_rtcp_sender_); - if (method_ == RtcpMode::kOff && new_method != RtcpMode::kOff) { + if (new_method == RtcpMode::kOff) { + next_time_to_send_rtcp_ = absl::nullopt; + } else if (method_ == RtcpMode::kOff) { // When switching on, reschedule the next packet - next_time_to_send_rtcp_ = - clock_->TimeInMilliseconds() + (report_interval_ms_ / 2); + SetNextRtcpSendEvaluationDuration(report_interval_ / 2); } method_ = new_method; } @@ -217,8 +210,8 @@ bool RTCPSender::Sending() const { return sending_; } -int32_t RTCPSender::SetSendingStatus(const FeedbackState& feedback_state, - bool sending) { +void RTCPSender::SetSendingStatus(const FeedbackState& feedback_state, + bool sending) { bool sendRTCPBye = false; { MutexLock lock(&mutex_rtcp_sender_); @@ -231,9 +224,11 @@ int32_t RTCPSender::SetSendingStatus(const FeedbackState& feedback_state, } sending_ = sending; } - if (sendRTCPBye) - return SendRTCP(feedback_state, kRtcpBye); - return 0; + if (sendRTCPBye) { + if (SendRTCP(feedback_state, kRtcpBye) != 0) { + RTC_LOG(LS_WARNING) << "Failed to send RTCP BYE"; + } + } } int32_t RTCPSender::SendLossNotification(const FeedbackState& feedback_state, @@ -241,21 +236,42 @@ int32_t RTCPSender::SendLossNotification(const FeedbackState& feedback_state, uint16_t last_received_seq_num, bool decodability_flag, bool buffering_allowed) { - MutexLock lock(&mutex_rtcp_sender_); + int32_t error_code = -1; + auto callback = [&](rtc::ArrayView packet) { + transport_->SendRtcp(packet.data(), packet.size()); + error_code = 0; + if (event_log_) { + event_log_->Log(std::make_unique(packet)); + } + }; + absl::optional sender; + { + MutexLock lock(&mutex_rtcp_sender_); - loss_notification_state_.last_decoded_seq_num = last_decoded_seq_num; - loss_notification_state_.last_received_seq_num = last_received_seq_num; - loss_notification_state_.decodability_flag = decodability_flag; + if (!loss_notification_.Set(last_decoded_seq_num, last_received_seq_num, + decodability_flag)) { + return -1; + } + + SetFlag(kRtcpLossNotification, /*is_volatile=*/true); - SetFlag(kRtcpLossNotification, /*is_volatile=*/true); + if (buffering_allowed) { + // The loss notification will be batched with additional feedback + // messages. + return 0; + } - if (buffering_allowed) { - // The loss notification will be batched with additional feedback messages. - return 0; + sender.emplace(callback, max_packet_size_); + auto result = ComputeCompoundRTCPPacket( + feedback_state, RTCPPacketType::kRtcpLossNotification, 0, nullptr, + *sender); + if (result) { + return *result; + } } + sender->Send(); - return SendCompoundRTCPLocked( - feedback_state, {RTCPPacketType::kRtcpLossNotification}, 0, nullptr); + return error_code; } void RTCPSender::SetRemb(int64_t bitrate_bps, std::vector ssrcs) { @@ -267,7 +283,7 @@ void RTCPSender::SetRemb(int64_t bitrate_bps, std::vector ssrcs) { SetFlag(kRtcpRemb, /*is_volatile=*/false); // Send a REMB immediately if we have a new REMB. The frequency of REMBs is // throttled by the caller. 
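RTCPSender now has its own Configuration struct; callers that still hold an RtpRtcpInterface::Configuration can convert it with FromRtpRtcpConfiguration() and then fill in the sender-only fields. The lambda type for the scheduling hook below is an assumption inferred from the constructor above, so treat this as a sketch only:

    // `rtp_config` is an already-populated RtpRtcpInterface::Configuration
    // owned by the caller.
    RTCPSender::Configuration config =
        RTCPSender::Configuration::FromRtpRtcpConfiguration(rtp_config);
    config.schedule_next_rtcp_send_evaluation_function =
        [](webrtc::TimeDelta duration) {
          // Ask the owning module to evaluate TimeToSendRTCPReport() and
          // possibly call SendRTCP() again after `duration`.
        };
    RTCPSender rtcp_sender(std::move(config));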
- next_time_to_send_rtcp_ = clock_->TimeInMilliseconds(); + SetNextRtcpSendEvaluationDuration(TimeDelta::Zero()); } void RTCPSender::UnsetRemb() { @@ -281,15 +297,6 @@ bool RTCPSender::TMMBR() const { return IsFlagPresent(RTCPPacketType::kRtcpTmmbr); } -void RTCPSender::SetTMMBRStatus(bool enable) { - MutexLock lock(&mutex_rtcp_sender_); - if (enable) { - SetFlag(RTCPPacketType::kRtcpTmmbr, false); - } else { - ConsumeFlag(RTCPPacketType::kRtcpTmmbr, true); - } -} - void RTCPSender::SetMaxRtpPacketSize(size_t max_packet_size) { MutexLock lock(&mutex_rtcp_sender_); max_packet_size_ = max_packet_size; @@ -301,20 +308,20 @@ void RTCPSender::SetTimestampOffset(uint32_t timestamp_offset) { } void RTCPSender::SetLastRtpTime(uint32_t rtp_timestamp, - int64_t capture_time_ms, - int8_t payload_type) { + absl::optional capture_time, + absl::optional payload_type) { MutexLock lock(&mutex_rtcp_sender_); // For compatibility with clients who don't set payload type correctly on all // calls. - if (payload_type != -1) { - last_payload_type_ = payload_type; + if (payload_type.has_value()) { + last_payload_type_ = *payload_type; } last_rtp_timestamp_ = rtp_timestamp; - if (capture_time_ms <= 0) { + if (!capture_time.has_value()) { // We don't currently get a capture time from VoiceEngine. - last_frame_capture_time_ms_ = clock_->TimeInMilliseconds(); + last_frame_capture_time_ = clock_->CurrentTime(); } else { - last_frame_capture_time_ms_ = capture_time_ms; + last_frame_capture_time_ = *capture_time; } } @@ -323,6 +330,16 @@ void RTCPSender::SetRtpClockRate(int8_t payload_type, int rtp_clock_rate_hz) { rtp_clock_rates_khz_[payload_type] = rtp_clock_rate_hz / 1000; } +uint32_t RTCPSender::SSRC() const { + MutexLock lock(&mutex_rtcp_sender_); + return ssrc_; +} + +void RTCPSender::SetSsrc(uint32_t ssrc) { + MutexLock lock(&mutex_rtcp_sender_); + ssrc_ = ssrc; +} + void RTCPSender::SetRemoteSSRC(uint32_t ssrc) { MutexLock lock(&mutex_rtcp_sender_); remote_ssrc_ = ssrc; @@ -338,31 +355,6 @@ int32_t RTCPSender::SetCNAME(const char* c_name) { return 0; } -int32_t RTCPSender::AddMixedCNAME(uint32_t SSRC, const char* c_name) { - RTC_DCHECK(c_name); - RTC_DCHECK_LT(strlen(c_name), RTCP_CNAME_SIZE); - MutexLock lock(&mutex_rtcp_sender_); - // One spot is reserved for ssrc_/cname_. - // TODO(danilchap): Add support for more than 30 contributes by sending - // several sdes packets. 
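SetLastRtpTime() now takes absl::optional values instead of the old sentinels (-1 payload type, non-positive capture time), and the recorded capture time is what BuildSR(), further below, extrapolates from when stamping the sender report. Hedged call-site sketches (variable names are illustrative):

    // Video path: capture time and payload type are known.
    rtcp_sender.SetLastRtpTime(rtp_timestamp, clock->CurrentTime(),
                               payload_type);
    // Audio path: no capture time available; the sender falls back to "now".
    rtcp_sender.SetLastRtpTime(rtp_timestamp, absl::nullopt, payload_type);
    // Caller that does not know the payload type (previously passed -1):
    rtcp_sender.SetLastRtpTime(rtp_timestamp, capture_time, absl::nullopt);

With a 90 kHz video clock, a frame captured 20 ms before the SR is stamped is extrapolated by 20 * 90 = 1800 RTP ticks, which is exactly the (now - last_frame_capture_time_) * rtp_rate term in the BuildSR() hunk below.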
- if (csrc_cnames_.size() >= rtcp::Sdes::kMaxNumberOfChunks - 1) - return -1; - - csrc_cnames_[SSRC] = c_name; - return 0; -} - -int32_t RTCPSender::RemoveMixedCNAME(uint32_t SSRC) { - MutexLock lock(&mutex_rtcp_sender_); - auto it = csrc_cnames_.find(SSRC); - - if (it == csrc_cnames_.end()) - return -1; - - csrc_cnames_.erase(it); - return 0; -} - bool RTCPSender::TimeToSendRTCPReport(bool sendKeyframeBeforeRTP) const { /* For audio we use a configurable interval (default: 5 seconds) @@ -422,25 +414,27 @@ bool RTCPSender::TimeToSendRTCPReport(bool sendKeyframeBeforeRTP) const { a value of the RTCP bandwidth below the intended average */ - int64_t now = clock_->TimeInMilliseconds(); + Timestamp now = clock_->CurrentTime(); MutexLock lock(&mutex_rtcp_sender_); - + RTC_DCHECK( + (method_ == RtcpMode::kOff && !next_time_to_send_rtcp_.has_value()) || + (method_ != RtcpMode::kOff && next_time_to_send_rtcp_.has_value())); if (method_ == RtcpMode::kOff) return false; if (!audio_ && sendKeyframeBeforeRTP) { // for video key-frames we want to send the RTCP before the large key-frame // if we have a 100 ms margin - now += RTCP_SEND_BEFORE_KEY_FRAME_MS; + now += RTCP_SEND_BEFORE_KEY_FRAME; } - return now >= next_time_to_send_rtcp_; + return now >= *next_time_to_send_rtcp_; } -std::unique_ptr RTCPSender::BuildSR(const RtcpContext& ctx) { +void RTCPSender::BuildSR(const RtcpContext& ctx, PacketSender& sender) { // Timestamp shouldn't be estimated before first media frame. - RTC_DCHECK_GE(last_frame_capture_time_ms_, 0); + RTC_DCHECK(last_frame_capture_time_.has_value()); // The timestamp of this RTCP packet should be estimated as the timestamp of // the frame being captured at this moment. We are calculating that // timestamp as the last frame's timestamp + the time since the last frame @@ -455,71 +449,61 @@ std::unique_ptr RTCPSender::BuildSR(const RtcpContext& ctx) { // when converted to milliseconds, uint32_t rtp_timestamp = timestamp_offset_ + last_rtp_timestamp_ + - ((ctx.now_us_ + 500) / 1000 - last_frame_capture_time_ms_) * rtp_rate; - - rtcp::SenderReport* report = new rtcp::SenderReport(); - report->SetSenderSsrc(ssrc_); - report->SetNtp(TimeMicrosToNtp(ctx.now_us_)); - report->SetRtpTimestamp(rtp_timestamp); - report->SetPacketCount(ctx.feedback_state_.packets_sent); - report->SetOctetCount(ctx.feedback_state_.media_bytes_sent); - report->SetReportBlocks(CreateReportBlocks(ctx.feedback_state_)); + ((ctx.now_.us() + 500) / 1000 - last_frame_capture_time_->ms()) * + rtp_rate; - return std::unique_ptr(report); + rtcp::SenderReport report; + report.SetSenderSsrc(ssrc_); + report.SetNtp(clock_->ConvertTimestampToNtpTime(ctx.now_)); + report.SetRtpTimestamp(rtp_timestamp); + report.SetPacketCount(ctx.feedback_state_.packets_sent); + report.SetOctetCount(ctx.feedback_state_.media_bytes_sent); + report.SetReportBlocks(CreateReportBlocks(ctx.feedback_state_)); + sender.AppendPacket(report); } -std::unique_ptr RTCPSender::BuildSDES( - const RtcpContext& ctx) { +void RTCPSender::BuildSDES(const RtcpContext& ctx, PacketSender& sender) { size_t length_cname = cname_.length(); RTC_CHECK_LT(length_cname, RTCP_CNAME_SIZE); - rtcp::Sdes* sdes = new rtcp::Sdes(); - sdes->AddCName(ssrc_, cname_); - - for (const auto& it : csrc_cnames_) - RTC_CHECK(sdes->AddCName(it.first, it.second)); - - return std::unique_ptr(sdes); + rtcp::Sdes sdes; + sdes.AddCName(ssrc_, cname_); + sender.AppendPacket(sdes); } -std::unique_ptr RTCPSender::BuildRR(const RtcpContext& ctx) { - rtcp::ReceiverReport* report = new 
rtcp::ReceiverReport(); - report->SetSenderSsrc(ssrc_); - report->SetReportBlocks(CreateReportBlocks(ctx.feedback_state_)); - - return std::unique_ptr(report); +void RTCPSender::BuildRR(const RtcpContext& ctx, PacketSender& sender) { + rtcp::ReceiverReport report; + report.SetSenderSsrc(ssrc_); + report.SetReportBlocks(CreateReportBlocks(ctx.feedback_state_)); + sender.AppendPacket(report); } -std::unique_ptr RTCPSender::BuildPLI(const RtcpContext& ctx) { - rtcp::Pli* pli = new rtcp::Pli(); - pli->SetSenderSsrc(ssrc_); - pli->SetMediaSsrc(remote_ssrc_); +void RTCPSender::BuildPLI(const RtcpContext& ctx, PacketSender& sender) { + rtcp::Pli pli; + pli.SetSenderSsrc(ssrc_); + pli.SetMediaSsrc(remote_ssrc_); ++packet_type_counter_.pli_packets; - - return std::unique_ptr(pli); + sender.AppendPacket(pli); } -std::unique_ptr RTCPSender::BuildFIR(const RtcpContext& ctx) { +void RTCPSender::BuildFIR(const RtcpContext& ctx, PacketSender& sender) { ++sequence_number_fir_; - rtcp::Fir* fir = new rtcp::Fir(); - fir->SetSenderSsrc(ssrc_); - fir->AddRequestTo(remote_ssrc_, sequence_number_fir_); + rtcp::Fir fir; + fir.SetSenderSsrc(ssrc_); + fir.AddRequestTo(remote_ssrc_, sequence_number_fir_); ++packet_type_counter_.fir_packets; - - return std::unique_ptr(fir); + sender.AppendPacket(fir); } -std::unique_ptr RTCPSender::BuildREMB( - const RtcpContext& ctx) { - rtcp::Remb* remb = new rtcp::Remb(); - remb->SetSenderSsrc(ssrc_); - remb->SetBitrateBps(remb_bitrate_); - remb->SetSsrcs(remb_ssrcs_); - - return std::unique_ptr(remb); +void RTCPSender::BuildREMB(const RtcpContext& ctx, PacketSender& sender) { + rtcp::Remb remb; + remb.SetSenderSsrc(ssrc_); + remb.SetBitrateBps(remb_bitrate_); + remb.SetSsrcs(remb_ssrcs_); + sender.AppendPacket(remb); } void RTCPSender::SetTargetBitrate(unsigned int target_bitrate) { @@ -527,10 +511,9 @@ void RTCPSender::SetTargetBitrate(unsigned int target_bitrate) { tmmbr_send_bps_ = target_bitrate; } -std::unique_ptr RTCPSender::BuildTMMBR( - const RtcpContext& ctx) { +void RTCPSender::BuildTMMBR(const RtcpContext& ctx, PacketSender& sender) { if (ctx.feedback_state_.receiver == nullptr) - return nullptr; + return; // Before sending the TMMBR check the received TMMBN, only an owner is // allowed to raise the bitrate: // * If the sender is an owner of the TMMBN -> send TMMBR @@ -550,7 +533,7 @@ std::unique_ptr RTCPSender::BuildTMMBR( if (candidate.bitrate_bps() == tmmbr_send_bps_ && candidate.packet_overhead() == packet_oh_send_) { // Do not send the same tuple. - return nullptr; + return; } } if (!tmmbr_owner) { @@ -564,62 +547,53 @@ std::unique_ptr RTCPSender::BuildTMMBR( tmmbr_owner = TMMBRHelp::IsOwner(bounding, ssrc_); if (!tmmbr_owner) { // Did not enter bounding set, no meaning to send this request. 
- return nullptr; + return; } } } if (!tmmbr_send_bps_) - return nullptr; + return; - rtcp::Tmmbr* tmmbr = new rtcp::Tmmbr(); - tmmbr->SetSenderSsrc(ssrc_); + rtcp::Tmmbr tmmbr; + tmmbr.SetSenderSsrc(ssrc_); rtcp::TmmbItem request; request.set_ssrc(remote_ssrc_); request.set_bitrate_bps(tmmbr_send_bps_); request.set_packet_overhead(packet_oh_send_); - tmmbr->AddTmmbr(request); - - return std::unique_ptr(tmmbr); + tmmbr.AddTmmbr(request); + sender.AppendPacket(tmmbr); } -std::unique_ptr RTCPSender::BuildTMMBN( - const RtcpContext& ctx) { - rtcp::Tmmbn* tmmbn = new rtcp::Tmmbn(); - tmmbn->SetSenderSsrc(ssrc_); +void RTCPSender::BuildTMMBN(const RtcpContext& ctx, PacketSender& sender) { + rtcp::Tmmbn tmmbn; + tmmbn.SetSenderSsrc(ssrc_); for (const rtcp::TmmbItem& tmmbr : tmmbn_to_send_) { if (tmmbr.bitrate_bps() > 0) { - tmmbn->AddTmmbr(tmmbr); + tmmbn.AddTmmbr(tmmbr); } } - - return std::unique_ptr(tmmbn); + sender.AppendPacket(tmmbn); } -std::unique_ptr RTCPSender::BuildAPP(const RtcpContext& ctx) { - rtcp::App* app = new rtcp::App(); - app->SetSenderSsrc(ssrc_); - - return std::unique_ptr(app); +void RTCPSender::BuildAPP(const RtcpContext& ctx, PacketSender& sender) { + rtcp::App app; + app.SetSenderSsrc(ssrc_); + sender.AppendPacket(app); } -std::unique_ptr RTCPSender::BuildLossNotification( - const RtcpContext& ctx) { - auto loss_notification = std::make_unique( - loss_notification_state_.last_decoded_seq_num, - loss_notification_state_.last_received_seq_num, - loss_notification_state_.decodability_flag); - loss_notification->SetSenderSsrc(ssrc_); - loss_notification->SetMediaSsrc(remote_ssrc_); - return std::move(loss_notification); +void RTCPSender::BuildLossNotification(const RtcpContext& ctx, + PacketSender& sender) { + loss_notification_.SetSenderSsrc(ssrc_); + loss_notification_.SetMediaSsrc(remote_ssrc_); + sender.AppendPacket(loss_notification_); } -std::unique_ptr RTCPSender::BuildNACK( - const RtcpContext& ctx) { - rtcp::Nack* nack = new rtcp::Nack(); - nack->SetSenderSsrc(ssrc_); - nack->SetMediaSsrc(remote_ssrc_); - nack->SetPacketIds(ctx.nack_list_, ctx.nack_size_); +void RTCPSender::BuildNACK(const RtcpContext& ctx, PacketSender& sender) { + rtcp::Nack nack; + nack.SetSenderSsrc(ssrc_); + nack.SetMediaSsrc(remote_ssrc_); + nack.SetPacketIds(ctx.nack_list_, ctx.nack_size_); // Report stats. 
for (int idx = 0; idx < ctx.nack_size_; ++idx) { @@ -629,31 +603,29 @@ std::unique_ptr RTCPSender::BuildNACK( packet_type_counter_.unique_nack_requests = nack_stats_.unique_requests(); ++packet_type_counter_.nack_packets; - - return std::unique_ptr(nack); + sender.AppendPacket(nack); } -std::unique_ptr RTCPSender::BuildBYE(const RtcpContext& ctx) { - rtcp::Bye* bye = new rtcp::Bye(); - bye->SetSenderSsrc(ssrc_); - bye->SetCsrcs(csrcs_); - - return std::unique_ptr(bye); +void RTCPSender::BuildBYE(const RtcpContext& ctx, PacketSender& sender) { + rtcp::Bye bye; + bye.SetSenderSsrc(ssrc_); + bye.SetCsrcs(csrcs_); + sender.AppendPacket(bye); } -std::unique_ptr RTCPSender::BuildExtendedReports( - const RtcpContext& ctx) { - std::unique_ptr xr(new rtcp::ExtendedReports()); - xr->SetSenderSsrc(ssrc_); +void RTCPSender::BuildExtendedReports(const RtcpContext& ctx, + PacketSender& sender) { + rtcp::ExtendedReports xr; + xr.SetSenderSsrc(ssrc_); if (!sending_ && xr_send_receiver_reference_time_enabled_) { rtcp::Rrtr rrtr; - rrtr.SetNtp(TimeMicrosToNtp(ctx.now_us_)); - xr->SetRrtr(rrtr); + rrtr.SetNtp(clock_->ConvertTimestampToNtpTime(ctx.now_)); + xr.SetRrtr(rrtr); } for (const rtcp::ReceiveTimeInfo& rti : ctx.feedback_state_.last_xr_rtis) { - xr->AddDlrrItem(rti); + xr.AddDlrrItem(rti); } if (send_video_bitrate_allocation_) { @@ -668,75 +640,56 @@ std::unique_ptr RTCPSender::BuildExtendedReports( } } - xr->SetTargetBitrate(target_bitrate); + xr.SetTargetBitrate(target_bitrate); send_video_bitrate_allocation_ = false; } - - return std::move(xr); + sender.AppendPacket(xr); } int32_t RTCPSender::SendRTCP(const FeedbackState& feedback_state, - RTCPPacketType packetType, + RTCPPacketType packet_type, int32_t nack_size, const uint16_t* nack_list) { - return SendCompoundRTCP( - feedback_state, std::set(&packetType, &packetType + 1), - nack_size, nack_list); -} - -int32_t RTCPSender::SendCompoundRTCP( - const FeedbackState& feedback_state, - const std::set& packet_types, - int32_t nack_size, - const uint16_t* nack_list) { - PacketContainer container(transport_, event_log_); - size_t max_packet_size; - + int32_t error_code = -1; + auto callback = [&](rtc::ArrayView packet) { + if (transport_->SendRtcp(packet.data(), packet.size())) { + error_code = 0; + if (event_log_) { + event_log_->Log(std::make_unique(packet)); + } + } + }; + absl::optional sender; { MutexLock lock(&mutex_rtcp_sender_); - auto result = ComputeCompoundRTCPPacket(feedback_state, packet_types, - nack_size, nack_list, &container); + sender.emplace(callback, max_packet_size_); + auto result = ComputeCompoundRTCPPacket(feedback_state, packet_type, + nack_size, nack_list, *sender); if (result) { return *result; } - max_packet_size = max_packet_size_; } + sender->Send(); - size_t bytes_sent = container.SendPackets(max_packet_size); - return bytes_sent == 0 ? -1 : 0; -} - -int32_t RTCPSender::SendCompoundRTCPLocked( - const FeedbackState& feedback_state, - const std::set& packet_types, - int32_t nack_size, - const uint16_t* nack_list) { - PacketContainer container(transport_, event_log_); - auto result = ComputeCompoundRTCPPacket(feedback_state, packet_types, - nack_size, nack_list, &container); - if (result) { - return *result; - } - size_t bytes_sent = container.SendPackets(max_packet_size_); - return bytes_sent == 0 ? 
-1 : 0; + return error_code; } absl::optional RTCPSender::ComputeCompoundRTCPPacket( const FeedbackState& feedback_state, - const std::set& packet_types, + RTCPPacketType packet_type, int32_t nack_size, const uint16_t* nack_list, - rtcp::CompoundPacket* out_packet) { + PacketSender& sender) { if (method_ == RtcpMode::kOff) { RTC_LOG(LS_WARNING) << "Can't send rtcp if it is disabled."; return -1; } - // Add all flags as volatile. Non volatile entries will not be overwritten. - // All new volatile flags added will be consumed by the end of this call. - SetFlags(packet_types, true); + // Add the flag as volatile. Non volatile entries will not be overwritten. + // The new volatile flag will be consumed by the end of this call. + SetFlag(packet_type, true); // Prevent sending streams to send SR before any media has been sent. - const bool can_calculate_rtp_timestamp = (last_frame_capture_time_ms_ >= 0); + const bool can_calculate_rtp_timestamp = last_frame_capture_time_.has_value(); if (!can_calculate_rtp_timestamp) { bool consumed_sr_flag = ConsumeFlag(kRtcpSr); bool consumed_report_flag = sending_ && ConsumeFlag(kRtcpReport); @@ -756,41 +709,41 @@ absl::optional RTCPSender::ComputeCompoundRTCPPacket( // We need to send our NTP even if we haven't received any reports. RtcpContext context(feedback_state, nack_size, nack_list, - clock_->TimeInMicroseconds()); + clock_->CurrentTime()); PrepareReport(feedback_state); - std::unique_ptr packet_bye; + bool create_bye = false; auto it = report_flags_.begin(); while (it != report_flags_.end()) { - auto builder_it = builders_.find(it->type); + uint32_t rtcp_packet_type = it->type; + if (it->is_volatile) { report_flags_.erase(it++); } else { ++it; } + // If there is a BYE, don't append now - save it and append it + // at the end later. + if (rtcp_packet_type == kRtcpBye) { + create_bye = true; + continue; + } + auto builder_it = builders_.find(rtcp_packet_type); if (builder_it == builders_.end()) { - RTC_NOTREACHED() << "Could not find builder for packet type " << it->type; + RTC_NOTREACHED() << "Could not find builder for packet type " + << rtcp_packet_type; } else { BuilderFunc func = builder_it->second; - std::unique_ptr packet = (this->*func)(context); - if (packet == nullptr) - return -1; - // If there is a BYE, don't append now - save it and append it - // at the end later. - if (builder_it->first == kRtcpBye) { - packet_bye = std::move(packet); - } else { - out_packet->Append(std::move(packet)); - } + (this->*func)(context, sender); } } // Append the BYE now at the end - if (packet_bye) { - out_packet->Append(std::move(packet_bye)); + if (create_bye) { + BuildBYE(context, sender); } if (packet_type_counter_observer_ != nullptr) { @@ -827,24 +780,25 @@ void RTCPSender::PrepareReport(const FeedbackState& feedback_state) { } // generate next time to send an RTCP report - int min_interval_ms = report_interval_ms_; + TimeDelta min_interval = report_interval_; if (!audio_ && sending_) { // Calculate bandwidth for video; 360 / send bandwidth in kbit/s. int send_bitrate_kbit = feedback_state.send_bitrate / 1000; if (send_bitrate_kbit != 0) { - min_interval_ms = 360000 / send_bitrate_kbit; - min_interval_ms = std::min(min_interval_ms, report_interval_ms_); + min_interval = std::min(TimeDelta::Millis(360000 / send_bitrate_kbit), + report_interval_); } } // The interval between RTCP packets is varied randomly over the // range [1/2,3/2] times the calculated interval. 
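The interval computation that follows can be read in isolation as below; the helper name NextReportDelay is illustrative. With a 500 kbps video send rate and a 1 s configured report interval, the bandwidth cap is 360000 / 500 = 720 ms, and the next report is drawn uniformly from [360 ms, 1080 ms].

```cpp
#include <algorithm>

#include "api/units/time_delta.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/random.h"

// Standalone sketch of the delay computed in PrepareReport(): cap the interval
// by send bandwidth for video senders, then randomize over [1/2, 3/2] of it.
webrtc::TimeDelta NextReportDelay(webrtc::TimeDelta report_interval,
                                  int send_bitrate_bps,
                                  webrtc::Random& random) {
  webrtc::TimeDelta min_interval = report_interval;
  int send_bitrate_kbit = send_bitrate_bps / 1000;
  if (send_bitrate_kbit != 0) {
    min_interval = std::min(
        webrtc::TimeDelta::Millis(360000 / send_bitrate_kbit), report_interval);
  }
  int min_interval_ms = rtc::dchecked_cast<int>(min_interval.ms());
  return webrtc::TimeDelta::Millis(
      random.Rand(min_interval_ms * 1 / 2, min_interval_ms * 3 / 2));
}
```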
- int time_to_next = - random_.Rand(min_interval_ms * 1 / 2, min_interval_ms * 3 / 2); + int min_interval_int = rtc::dchecked_cast(min_interval.ms()); + TimeDelta time_to_next = TimeDelta::Millis( + random_.Rand(min_interval_int * 1 / 2, min_interval_int * 3 / 2)); - RTC_DCHECK_GT(time_to_next, 0); - next_time_to_send_rtcp_ = clock_->TimeInMilliseconds() + time_to_next; + RTC_DCHECK(!time_to_next.IsZero()); + SetNextRtcpSendEvaluationDuration(time_to_next); // RtcpSender expected to be used for sending either just sender reports // or just receiver reports. @@ -866,7 +820,7 @@ std::vector RTCPSender::CreateReportBlocks( if (!result.empty() && ((feedback_state.last_rr_ntp_secs != 0) || (feedback_state.last_rr_ntp_frac != 0))) { // Get our NTP as late as possible to avoid a race. - uint32_t now = CompactNtp(TimeMicrosToNtp(clock_->TimeInMicroseconds())); + uint32_t now = CompactNtp(clock_->CurrentNtpTime()); uint32_t receive_time = feedback_state.last_rr_ntp_secs & 0x0000FFFF; receive_time <<= 16; @@ -904,12 +858,6 @@ void RTCPSender::SetFlag(uint32_t type, bool is_volatile) { } } -void RTCPSender::SetFlags(const std::set& types, - bool is_volatile) { - for (RTCPPacketType type : types) - SetFlag(type, is_volatile); -} - bool RTCPSender::IsFlagPresent(uint32_t type) const { return report_flags_.find(ReportFlag(type, false)) != report_flags_.end(); } @@ -944,7 +892,7 @@ void RTCPSender::SetVideoBitrateAllocation( RTC_LOG(LS_INFO) << "Emitting TargetBitrate XR for SSRC " << ssrc_ << " with new layers enabled/disabled: " << video_bitrate_allocation_.ToString(); - next_time_to_send_rtcp_ = clock_->TimeInMilliseconds(); + SetNextRtcpSendEvaluationDuration(TimeDelta::Zero()); } else { video_bitrate_allocation_ = bitrate; } @@ -1005,4 +953,12 @@ void RTCPSender::SendCombinedRtcpPacket( sender.Send(); } +void RTCPSender::SetNextRtcpSendEvaluationDuration(TimeDelta duration) { + next_time_to_send_rtcp_ = clock_->CurrentTime() + duration; + // TODO(bugs.webrtc.org/11581): make unconditional once downstream consumers + // are using the callback method. + if (schedule_next_rtcp_send_evaluation_function_) + schedule_next_rtcp_send_evaluation_function_(duration); +} + } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtcp_sender.h b/modules/rtp_rtcp/source/rtcp_sender.h index cc9091dfc7..2d1c7da0fc 100644 --- a/modules/rtp_rtcp/source/rtcp_sender.h +++ b/modules/rtp_rtcp/source/rtcp_sender.h @@ -19,6 +19,8 @@ #include "absl/types/optional.h" #include "api/call/transport.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "api/video/video_bitrate_allocation.h" #include "modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h" #include "modules/rtp_rtcp/include/receive_statistics.h" @@ -27,6 +29,7 @@ #include "modules/rtp_rtcp/source/rtcp_packet.h" #include "modules/rtp_rtcp/source/rtcp_packet/compound_packet.h" #include "modules/rtp_rtcp/source/rtcp_packet/dlrr.h" +#include "modules/rtp_rtcp/source/rtcp_packet/loss_notification.h" #include "modules/rtp_rtcp/source/rtcp_packet/report_block.h" #include "modules/rtp_rtcp/source/rtcp_packet/tmmb_item.h" #include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" @@ -41,6 +44,43 @@ class RtcEventLog; class RTCPSender final { public: + struct Configuration { + // TODO(bugs.webrtc.org/11581): Remove this temporary conversion utility + // once rtc_rtcp_impl.cc/h are gone. 
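The new SetNextRtcpSendEvaluationDuration() above only records the deadline and invokes the optional callback; the client still drives TimeToSendRTCPReport()/SendRTCP(). One way a client might honor the callback is sketched below, assuming a TaskQueueBase is available and that the sender, feedback state and queue outlive the posted task (ScheduleRtcpSendEvaluation is a hypothetical helper, not part of this change).

```cpp
#include <cstdint>

#include "api/task_queue/task_queue_base.h"
#include "api/units/time_delta.h"
#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h"
#include "modules/rtp_rtcp/source/rtcp_sender.h"
#include "rtc_base/task_utils/to_queued_task.h"

// Sketch: hook this up as schedule_next_rtcp_send_evaluation_function.
void ScheduleRtcpSendEvaluation(webrtc::RTCPSender& rtcp_sender,
                                const webrtc::RTCPSender::FeedbackState& state,
                                webrtc::TaskQueueBase& rtcp_queue,
                                webrtc::TimeDelta duration) {
  rtcp_queue.PostDelayedTask(
      webrtc::ToQueuedTask([&rtcp_sender, &state] {
        // The callback only schedules an evaluation; the client still checks
        // whether a report is due and triggers the actual send.
        if (rtcp_sender.TimeToSendRTCPReport())
          rtcp_sender.SendRTCP(state, webrtc::kRtcpReport);
      }),
      static_cast<uint32_t>(duration.ms()));
}
```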
+ static Configuration FromRtpRtcpConfiguration( + const RtpRtcpInterface::Configuration& config); + + // True for a audio version of the RTP/RTCP module object false will create + // a video version. + bool audio = false; + // SSRCs for media and retransmission, respectively. + // FlexFec SSRC is fetched from |flexfec_sender|. + uint32_t local_media_ssrc = 0; + // The clock to use to read time. If nullptr then system clock will be used. + Clock* clock = nullptr; + // Transport object that will be called when packets are ready to be sent + // out on the network. + Transport* outgoing_transport = nullptr; + // Estimate RTT as non-sender as described in + // https://tools.ietf.org/html/rfc3611#section-4.4 and #section-4.5 + bool non_sender_rtt_measurement = false; + // Optional callback which, if specified, is used by RTCPSender to schedule + // the next time to evaluate if RTCP should be sent by means of + // TimeToSendRTCPReport/SendRTCP. + // The RTCPSender client still needs to call TimeToSendRTCPReport/SendRTCP + // to actually get RTCP sent. + // + // Note: It's recommended to use the callback to ensure program design that + // doesn't use polling. + // TODO(bugs.webrtc.org/11581): Make mandatory once downstream consumers + // have migrated to the callback solution. + std::function schedule_next_rtcp_send_evaluation_function; + + RtcEventLog* event_log = nullptr; + absl::optional rtcp_report_interval; + ReceiveStatisticsProvider* receive_statistics = nullptr; + RtcpPacketTypeCounterObserver* rtcp_packet_type_counter_observer = nullptr; + }; struct FeedbackState { FeedbackState(); FeedbackState(const FeedbackState&); @@ -62,7 +102,7 @@ class RTCPSender final { RTCPReceiver* receiver; }; - explicit RTCPSender(const RtpRtcpInterface::Configuration& config); + explicit RTCPSender(Configuration config); RTCPSender() = delete; RTCPSender(const RTCPSender&) = delete; @@ -74,8 +114,8 @@ class RTCPSender final { void SetRTCPStatus(RtcpMode method) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); bool Sending() const RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - int32_t SetSendingStatus(const FeedbackState& feedback_state, - bool enabled) + void SetSendingStatus(const FeedbackState& feedback_state, + bool enabled) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); // combine the functions int32_t SetNackStatus(bool enable) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); @@ -83,28 +123,21 @@ class RTCPSender final { void SetTimestampOffset(uint32_t timestamp_offset) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - // TODO(bugs.webrtc.org/6458): Remove default parameter value when all the - // depending projects are updated to correctly set payload type. 
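For reference, filling in the new dedicated Configuration directly (rather than converting from RtpRtcpInterface::Configuration) might look like the sketch below. MakeRtcpSenderConfig is an illustrative name, and the callback is assumed to take a TimeDelta, matching its use in SetNextRtcpSendEvaluationDuration().

```cpp
#include "api/call/transport.h"
#include "api/units/time_delta.h"
#include "modules/rtp_rtcp/source/rtcp_sender.h"
#include "system_wrappers/include/clock.h"

webrtc::RTCPSender::Configuration MakeRtcpSenderConfig(
    webrtc::Clock* clock,
    webrtc::Transport* transport,
    uint32_t local_ssrc) {
  webrtc::RTCPSender::Configuration config;
  config.audio = false;                    // Video flavor of the module.
  config.clock = clock;
  config.outgoing_transport = transport;
  config.local_media_ssrc = local_ssrc;
  config.rtcp_report_interval = webrtc::TimeDelta::Millis(1000);
  config.schedule_next_rtcp_send_evaluation_function =
      [](webrtc::TimeDelta /*duration*/) {
        // E.g. post a delayed task; see the scheduling sketch above.
      };
  return config;
}
```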
void SetLastRtpTime(uint32_t rtp_timestamp, - int64_t capture_time_ms, - int8_t payload_type = -1) + absl::optional capture_time, + absl::optional payload_type) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); void SetRtpClockRate(int8_t payload_type, int rtp_clock_rate_hz) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - uint32_t SSRC() const { return ssrc_; } + uint32_t SSRC() const; + void SetSsrc(uint32_t ssrc); void SetRemoteSSRC(uint32_t ssrc) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); int32_t SetCNAME(const char* cName) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - int32_t AddMixedCNAME(uint32_t SSRC, const char* c_name) - RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - - int32_t RemoveMixedCNAME(uint32_t SSRC) - RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - bool TimeToSendRTCPReport(bool sendKeyframeBeforeRTP = false) const RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); @@ -114,12 +147,6 @@ class RTCPSender final { const uint16_t* nackList = 0) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - int32_t SendCompoundRTCP(const FeedbackState& feedback_state, - const std::set& packetTypes, - int32_t nackSize = 0, - const uint16_t* nackList = nullptr) - RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - int32_t SendLossNotification(const FeedbackState& feedback_state, uint16_t last_decoded_seq_num, uint16_t last_received_seq_num, @@ -134,8 +161,6 @@ class RTCPSender final { bool TMMBR() const RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - void SetTMMBRStatus(bool enable) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); - void SetMaxRtpPacketSize(size_t max_packet_size) RTC_LOCKS_EXCLUDED(mutex_rtcp_sender_); @@ -155,20 +180,14 @@ class RTCPSender final { private: class RtcpContext; - - int32_t SendCompoundRTCPLocked(const FeedbackState& feedback_state, - const std::set& packet_types, - int32_t nack_size, - const uint16_t* nack_list) - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); + class PacketSender; absl::optional ComputeCompoundRTCPPacket( const FeedbackState& feedback_state, - const std::set& packet_types, + RTCPPacketType packet_type, int32_t nack_size, const uint16_t* nack_list, - rtcp::CompoundPacket* out_packet) - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); + PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); // Determine which RTCP messages should be sent and setup flags. 
void PrepareReport(const FeedbackState& feedback_state) @@ -178,38 +197,43 @@ class RTCPSender final { const FeedbackState& feedback_state) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildSR(const RtcpContext& context) + void BuildSR(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildRR(const RtcpContext& context) + void BuildRR(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildSDES(const RtcpContext& context) + void BuildSDES(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildPLI(const RtcpContext& context) + void BuildPLI(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildREMB(const RtcpContext& context) + void BuildREMB(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildTMMBR(const RtcpContext& context) + void BuildTMMBR(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildTMMBN(const RtcpContext& context) + void BuildTMMBN(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildAPP(const RtcpContext& context) + void BuildAPP(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildLossNotification( - const RtcpContext& context) + void BuildLossNotification(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildExtendedReports( - const RtcpContext& context) + void BuildExtendedReports(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildBYE(const RtcpContext& context) + void BuildBYE(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildFIR(const RtcpContext& context) + void BuildFIR(const RtcpContext& context, PacketSender& sender) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - std::unique_ptr BuildNACK(const RtcpContext& context) + void BuildNACK(const RtcpContext& context, PacketSender& sender) + RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); + + // |duration| being TimeDelta::Zero() means schedule immediately. + void SetNextRtcpSendEvaluationDuration(TimeDelta duration) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - private: const bool audio_; - const uint32_t ssrc_; + // TODO(bugs.webrtc.org/11581): `mutex_rtcp_sender_` shouldn't be required if + // we consistently run network related operations on the network thread. + // This is currently not possible due to callbacks from the process thread in + // ModuleRtpRtcpImpl2. + uint32_t ssrc_ RTC_GUARDED_BY(mutex_rtcp_sender_); Clock* const clock_; Random random_ RTC_GUARDED_BY(mutex_rtcp_sender_); RtcpMode method_ RTC_GUARDED_BY(mutex_rtcp_sender_); @@ -217,24 +241,28 @@ class RTCPSender final { RtcEventLog* const event_log_; Transport* const transport_; - const int report_interval_ms_; + const TimeDelta report_interval_; + // Set from + // RTCPSender::Configuration::schedule_next_rtcp_send_evaluation_function. 
+ const std::function + schedule_next_rtcp_send_evaluation_function_; mutable Mutex mutex_rtcp_sender_; bool sending_ RTC_GUARDED_BY(mutex_rtcp_sender_); - int64_t next_time_to_send_rtcp_ RTC_GUARDED_BY(mutex_rtcp_sender_); + absl::optional next_time_to_send_rtcp_ + RTC_GUARDED_BY(mutex_rtcp_sender_); uint32_t timestamp_offset_ RTC_GUARDED_BY(mutex_rtcp_sender_); uint32_t last_rtp_timestamp_ RTC_GUARDED_BY(mutex_rtcp_sender_); - int64_t last_frame_capture_time_ms_ RTC_GUARDED_BY(mutex_rtcp_sender_); + absl::optional last_frame_capture_time_ + RTC_GUARDED_BY(mutex_rtcp_sender_); // SSRC that we receive on our RTP channel uint32_t remote_ssrc_ RTC_GUARDED_BY(mutex_rtcp_sender_); std::string cname_ RTC_GUARDED_BY(mutex_rtcp_sender_); ReceiveStatisticsProvider* receive_statistics_ RTC_GUARDED_BY(mutex_rtcp_sender_); - std::map csrc_cnames_ - RTC_GUARDED_BY(mutex_rtcp_sender_); // send CSRCs std::vector csrcs_ RTC_GUARDED_BY(mutex_rtcp_sender_); @@ -242,14 +270,7 @@ class RTCPSender final { // Full intra request uint8_t sequence_number_fir_ RTC_GUARDED_BY(mutex_rtcp_sender_); - // Loss Notification - struct LossNotificationState { - uint16_t last_decoded_seq_num; - uint16_t last_received_seq_num; - bool decodability_flag; - }; - LossNotificationState loss_notification_state_ - RTC_GUARDED_BY(mutex_rtcp_sender_); + rtcp::LossNotification loss_notification_ RTC_GUARDED_BY(mutex_rtcp_sender_); // REMB int64_t remb_bitrate_ RTC_GUARDED_BY(mutex_rtcp_sender_); @@ -281,8 +302,6 @@ class RTCPSender final { void SetFlag(uint32_t type, bool is_volatile) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); - void SetFlags(const std::set& types, bool is_volatile) - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); bool IsFlagPresent(uint32_t type) const RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_rtcp_sender_); bool ConsumeFlag(uint32_t type, bool forced = false) @@ -300,8 +319,7 @@ class RTCPSender final { std::set report_flags_ RTC_GUARDED_BY(mutex_rtcp_sender_); - typedef std::unique_ptr (RTCPSender::*BuilderFunc)( - const RtcpContext&); + typedef void (RTCPSender::*BuilderFunc)(const RtcpContext&, PacketSender&); // Map from RTCPPacketType to builder. 
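The BuilderFunc typedef above keeps the pointer-to-member dispatch, now with a PacketSender out-parameter instead of a returned packet. A standalone illustration of that pattern follows; PacketFactory and its methods are made up and are not WebRTC code.

```cpp
#include <cstdint>
#include <map>

class PacketFactory {
 public:
  PacketFactory() {
    builders_[1] = &PacketFactory::BuildA;
    builders_[2] = &PacketFactory::BuildB;
  }

  void Build(uint32_t type, int& out) {
    auto it = builders_.find(type);
    if (it == builders_.end())
      return;  // Mirrors the "could not find builder" branch.
    BuilderFunc func = it->second;
    (this->*func)(out);  // Same shape as (this->*func)(context, sender).
  }

 private:
  void BuildA(int& out) { out += 1; }
  void BuildB(int& out) { out += 2; }

  typedef void (PacketFactory::*BuilderFunc)(int&);
  std::map<uint32_t, BuilderFunc> builders_;
};
```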
std::map builders_; }; diff --git a/modules/rtp_rtcp/source/rtcp_sender_unittest.cc b/modules/rtp_rtcp/source/rtcp_sender_unittest.cc index 4c8038fd04..347be79398 100644 --- a/modules/rtp_rtcp/source/rtcp_sender_unittest.cc +++ b/modules/rtp_rtcp/source/rtcp_sender_unittest.cc @@ -14,12 +14,12 @@ #include #include "absl/base/macros.h" +#include "api/units/time_delta.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/source/rtcp_packet/bye.h" #include "modules/rtp_rtcp/source/rtcp_packet/common_header.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" -#include "modules/rtp_rtcp/source/time_util.h" #include "rtc_base/rate_limiter.h" #include "test/gmock.h" #include "test/gtest.h" @@ -28,7 +28,9 @@ using ::testing::_; using ::testing::ElementsAre; +using ::testing::Eq; using ::testing::Invoke; +using ::testing::Property; using ::testing::SizeIs; namespace webrtc { @@ -70,43 +72,50 @@ static const uint32_t kStartRtpTimestamp = 0x34567; static const uint32_t kRtpTimestamp = 0x45678; std::unique_ptr CreateRtcpSender( - const RtpRtcpInterface::Configuration& config, + const RTCPSender::Configuration& config, bool init_timestamps = true) { auto rtcp_sender = std::make_unique(config); rtcp_sender->SetRemoteSSRC(kRemoteSsrc); if (init_timestamps) { rtcp_sender->SetTimestampOffset(kStartRtpTimestamp); - rtcp_sender->SetLastRtpTime(kRtpTimestamp, - config.clock->TimeInMilliseconds(), + rtcp_sender->SetLastRtpTime(kRtpTimestamp, config.clock->CurrentTime(), /*payload_type=*/0); } return rtcp_sender; } - } // namespace class RtcpSenderTest : public ::testing::Test { protected: RtcpSenderTest() : clock_(1335900000), - receive_statistics_(ReceiveStatistics::Create(&clock_)), - retransmission_rate_limiter_(&clock_, 1000) { - RtpRtcpInterface::Configuration configuration = GetDefaultConfig(); - rtp_rtcp_impl_.reset(new ModuleRtpRtcpImpl2(configuration)); + receive_statistics_(ReceiveStatistics::Create(&clock_)) { + rtp_rtcp_impl_.reset(new ModuleRtpRtcpImpl2(GetDefaultRtpRtcpConfig())); } - RtpRtcpInterface::Configuration GetDefaultConfig() { - RtpRtcpInterface::Configuration configuration; + RTCPSender::Configuration GetDefaultConfig() { + RTCPSender::Configuration configuration; configuration.audio = false; configuration.clock = &clock_; configuration.outgoing_transport = &test_transport_; - configuration.retransmission_rate_limiter = &retransmission_rate_limiter_; - configuration.rtcp_report_interval_ms = 1000; + configuration.rtcp_report_interval = TimeDelta::Millis(1000); configuration.receive_statistics = receive_statistics_.get(); configuration.local_media_ssrc = kSenderSsrc; return configuration; } + RtpRtcpInterface::Configuration GetDefaultRtpRtcpConfig() { + RTCPSender::Configuration config = GetDefaultConfig(); + RtpRtcpInterface::Configuration result; + result.audio = config.audio; + result.clock = config.clock; + result.outgoing_transport = config.outgoing_transport; + result.rtcp_report_interval_ms = config.rtcp_report_interval->ms(); + result.receive_statistics = config.receive_statistics; + result.local_media_ssrc = config.local_media_ssrc; + return result; + } + void InsertIncomingPacket(uint32_t remote_ssrc, uint16_t seq_num) { RtpPacketReceived packet; packet.SetSsrc(remote_ssrc); @@ -126,7 +135,6 @@ class RtcpSenderTest : public ::testing::Test { TestTransport test_transport_; std::unique_ptr receive_statistics_; std::unique_ptr rtp_rtcp_impl_; - RateLimiter retransmission_rate_limiter_; }; 
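A hypothetical fixture test, not part of this change, showing how the rescheduling behaves under SimulatedClock: SetRTCPStatus() schedules the first evaluation at half the 1 s report interval, so a report only becomes due once the clock is advanced past that point.

```cpp
TEST_F(RtcpSenderTest, TimeToSendReportBecomesTrueAfterSchedule) {
  auto rtcp_sender = CreateRtcpSender(GetDefaultConfig());
  rtcp_sender->SetRTCPStatus(RtcpMode::kCompound);
  // Switching on schedules the first evaluation report_interval / 2 = 500 ms
  // from now, so nothing is due yet.
  EXPECT_FALSE(rtcp_sender->TimeToSendRTCPReport());
  // Advance well past the scheduled time; the report becomes due.
  clock_.AdvanceTimeMilliseconds(1500);
  EXPECT_TRUE(rtcp_sender->TimeToSendRTCPReport());
}
```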
TEST_F(RtcpSenderTest, SetRtcpStatus) { @@ -139,7 +147,7 @@ TEST_F(RtcpSenderTest, SetRtcpStatus) { TEST_F(RtcpSenderTest, SetSendingStatus) { auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); EXPECT_FALSE(rtcp_sender->Sending()); - EXPECT_EQ(0, rtcp_sender->SetSendingStatus(feedback_state(), true)); + rtcp_sender->SetSendingStatus(feedback_state(), true); EXPECT_TRUE(rtcp_sender->Sending()); } @@ -158,7 +166,7 @@ TEST_F(RtcpSenderTest, SendSr) { rtcp_sender->SetSendingStatus(feedback_state, true); feedback_state.packets_sent = kPacketCount; feedback_state.media_bytes_sent = kOctetCount; - NtpTime ntp = TimeMicrosToNtp(clock_.TimeInMicroseconds()); + NtpTime ntp = clock_.CurrentNtpTime(); EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state, kRtcpSr)); EXPECT_EQ(1, parser()->sender_report()->num_packets()); EXPECT_EQ(kSenderSsrc, parser()->sender_report()->sender_ssrc()); @@ -205,11 +213,11 @@ TEST_F(RtcpSenderTest, SendConsecutiveSrWithExactSlope) { } TEST_F(RtcpSenderTest, DoNotSendSrBeforeRtp) { - RtpRtcpInterface::Configuration config; + RTCPSender::Configuration config; config.clock = &clock_; config.receive_statistics = receive_statistics_.get(); config.outgoing_transport = &test_transport_; - config.rtcp_report_interval_ms = 1000; + config.rtcp_report_interval = TimeDelta::Millis(1000); config.local_media_ssrc = kSenderSsrc; auto rtcp_sender = CreateRtcpSender(config, /*init_timestamps=*/false); rtcp_sender->SetRTCPStatus(RtcpMode::kReducedSize); @@ -226,11 +234,11 @@ TEST_F(RtcpSenderTest, DoNotSendSrBeforeRtp) { } TEST_F(RtcpSenderTest, DoNotSendCompundBeforeRtp) { - RtpRtcpInterface::Configuration config; + RTCPSender::Configuration config; config.clock = &clock_; config.receive_statistics = receive_statistics_.get(); config.outgoing_transport = &test_transport_; - config.rtcp_report_interval_ms = 1000; + config.rtcp_report_interval = TimeDelta::Millis(1000); config.local_media_ssrc = kSenderSsrc; auto rtcp_sender = CreateRtcpSender(config, /*init_timestamps=*/false); rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); @@ -276,11 +284,11 @@ TEST_F(RtcpSenderTest, SendRrWithTwoReportBlocks) { EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpRr)); EXPECT_EQ(1, parser()->receiver_report()->num_packets()); EXPECT_EQ(kSenderSsrc, parser()->receiver_report()->sender_ssrc()); - EXPECT_EQ(2U, parser()->receiver_report()->report_blocks().size()); - EXPECT_EQ(kRemoteSsrc, - parser()->receiver_report()->report_blocks()[0].source_ssrc()); - EXPECT_EQ(kRemoteSsrc + 1, - parser()->receiver_report()->report_blocks()[1].source_ssrc()); + EXPECT_THAT( + parser()->receiver_report()->report_blocks(), + UnorderedElementsAre( + Property(&rtcp::ReportBlock::source_ssrc, Eq(kRemoteSsrc)), + Property(&rtcp::ReportBlock::source_ssrc, Eq(kRemoteSsrc + 1)))); } TEST_F(RtcpSenderTest, SendSdes) { @@ -294,20 +302,6 @@ TEST_F(RtcpSenderTest, SendSdes) { EXPECT_EQ("alice@host", parser()->sdes()->chunks()[0].cname); } -TEST_F(RtcpSenderTest, SendSdesWithMaxChunks) { - auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); - rtcp_sender->SetRTCPStatus(RtcpMode::kReducedSize); - EXPECT_EQ(0, rtcp_sender->SetCNAME("alice@host")); - const char cname[] = "smith@host"; - for (size_t i = 0; i < 30; ++i) { - const uint32_t csrc = 0x1234 + i; - EXPECT_EQ(0, rtcp_sender->AddMixedCNAME(csrc, cname)); - } - EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpSdes)); - EXPECT_EQ(1, parser()->sdes()->num_packets()); - EXPECT_EQ(31U, parser()->sdes()->chunks().size()); -} - TEST_F(RtcpSenderTest, 
SdesIncludedInCompoundPacket) { auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); @@ -329,8 +323,8 @@ TEST_F(RtcpSenderTest, SendBye) { TEST_F(RtcpSenderTest, StopSendingTriggersBye) { auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); rtcp_sender->SetRTCPStatus(RtcpMode::kReducedSize); - EXPECT_EQ(0, rtcp_sender->SetSendingStatus(feedback_state(), true)); - EXPECT_EQ(0, rtcp_sender->SetSendingStatus(feedback_state(), false)); + rtcp_sender->SetSendingStatus(feedback_state(), true); + rtcp_sender->SetSendingStatus(feedback_state(), false); EXPECT_EQ(1, parser()->bye()->num_packets()); EXPECT_EQ(kSenderSsrc, parser()->bye()->sender_ssrc()); } @@ -523,12 +517,12 @@ TEST_F(RtcpSenderTest, SendXrWithMultipleDlrrSubBlocks) { } TEST_F(RtcpSenderTest, SendXrWithRrtr) { - RtpRtcpInterface::Configuration config = GetDefaultConfig(); + RTCPSender::Configuration config = GetDefaultConfig(); config.non_sender_rtt_measurement = true; auto rtcp_sender = CreateRtcpSender(config); rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); - EXPECT_EQ(0, rtcp_sender->SetSendingStatus(feedback_state(), false)); - NtpTime ntp = TimeMicrosToNtp(clock_.TimeInMicroseconds()); + rtcp_sender->SetSendingStatus(feedback_state(), false); + NtpTime ntp = clock_.CurrentNtpTime(); EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpReport)); EXPECT_EQ(1, parser()->xr()->num_packets()); EXPECT_EQ(kSenderSsrc, parser()->xr()->sender_ssrc()); @@ -538,33 +532,33 @@ TEST_F(RtcpSenderTest, SendXrWithRrtr) { } TEST_F(RtcpSenderTest, TestNoXrRrtrSentIfSending) { - RtpRtcpInterface::Configuration config = GetDefaultConfig(); + RTCPSender::Configuration config = GetDefaultConfig(); config.non_sender_rtt_measurement = true; auto rtcp_sender = CreateRtcpSender(config); rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); - EXPECT_EQ(0, rtcp_sender->SetSendingStatus(feedback_state(), true)); + rtcp_sender->SetSendingStatus(feedback_state(), true); EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpReport)); EXPECT_EQ(0, parser()->xr()->num_packets()); } TEST_F(RtcpSenderTest, TestNoXrRrtrSentIfNotEnabled) { - RtpRtcpInterface::Configuration config = GetDefaultConfig(); + RTCPSender::Configuration config = GetDefaultConfig(); config.non_sender_rtt_measurement = false; auto rtcp_sender = CreateRtcpSender(config); rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); - EXPECT_EQ(0, rtcp_sender->SetSendingStatus(feedback_state(), false)); + rtcp_sender->SetSendingStatus(feedback_state(), false); EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpReport)); EXPECT_EQ(0, parser()->xr()->num_packets()); } TEST_F(RtcpSenderTest, TestRegisterRtcpPacketTypeObserver) { RtcpPacketTypeCounterObserverImpl observer; - RtpRtcpInterface::Configuration config; + RTCPSender::Configuration config; config.clock = &clock_; config.receive_statistics = receive_statistics_.get(); config.outgoing_transport = &test_transport_; config.rtcp_packet_type_counter_observer = &observer; - config.rtcp_report_interval_ms = 1000; + config.rtcp_report_interval = TimeDelta::Millis(1000); auto rtcp_sender = CreateRtcpSender(config); rtcp_sender->SetRTCPStatus(RtcpMode::kReducedSize); EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpPli)); @@ -588,25 +582,6 @@ TEST_F(RtcpSenderTest, SendTmmbr) { // TODO(asapersson): tmmbr_item()->Overhead() looks broken, always zero. 
} -TEST_F(RtcpSenderTest, TmmbrIncludedInCompoundPacketIfEnabled) { - const unsigned int kBitrateBps = 312000; - auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); - rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); - EXPECT_FALSE(rtcp_sender->TMMBR()); - rtcp_sender->SetTMMBRStatus(true); - EXPECT_TRUE(rtcp_sender->TMMBR()); - rtcp_sender->SetTargetBitrate(kBitrateBps); - EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpReport)); - EXPECT_EQ(1, parser()->tmmbr()->num_packets()); - EXPECT_EQ(1U, parser()->tmmbr()->requests().size()); - // TMMBR should be included in each compound packet. - EXPECT_EQ(0, rtcp_sender->SendRTCP(feedback_state(), kRtcpReport)); - EXPECT_EQ(2, parser()->tmmbr()->num_packets()); - - rtcp_sender->SetTMMBRStatus(false); - EXPECT_FALSE(rtcp_sender->TMMBR()); -} - TEST_F(RtcpSenderTest, SendTmmbn) { auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); @@ -648,21 +623,6 @@ TEST_F(RtcpSenderTest, SendsTmmbnIfSetAndEmpty) { EXPECT_EQ(0U, parser()->tmmbn()->items().size()); } -TEST_F(RtcpSenderTest, SendCompoundPliRemb) { - const int kBitrate = 261011; - auto rtcp_sender = CreateRtcpSender(GetDefaultConfig()); - std::vector ssrcs; - ssrcs.push_back(kRemoteSsrc); - rtcp_sender->SetRTCPStatus(RtcpMode::kCompound); - rtcp_sender->SetRemb(kBitrate, ssrcs); - std::set packet_types; - packet_types.insert(kRtcpRemb); - packet_types.insert(kRtcpPli); - EXPECT_EQ(0, rtcp_sender->SendCompoundRTCP(feedback_state(), packet_types)); - EXPECT_EQ(1, parser()->remb()->num_packets()); - EXPECT_EQ(1, parser()->pli()->num_packets()); -} - // This test is written to verify that BYE is always the last packet // type in a RTCP compoud packet. The rtcp_sender is recreated with // mock_transport, which is used to check for whether BYE at the end @@ -690,16 +650,16 @@ TEST_F(RtcpSenderTest, ByeMustBeLast) { })); // Re-configure rtcp_sender with mock_transport_ - RtpRtcpInterface::Configuration config; + RTCPSender::Configuration config; config.clock = &clock_; config.receive_statistics = receive_statistics_.get(); config.outgoing_transport = &mock_transport; - config.rtcp_report_interval_ms = 1000; + config.rtcp_report_interval = TimeDelta::Millis(1000); config.local_media_ssrc = kSenderSsrc; auto rtcp_sender = CreateRtcpSender(config); rtcp_sender->SetTimestampOffset(kStartRtpTimestamp); - rtcp_sender->SetLastRtpTime(kRtpTimestamp, clock_.TimeInMilliseconds(), + rtcp_sender->SetLastRtpTime(kRtpTimestamp, clock_.CurrentTime(), /*payload_type=*/0); // Set up REMB info to be included with BYE. 
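The report-block assertions above switch to order-independent gmock matchers instead of indexing into the container. The same pattern in isolation, with a made-up Item type:

```cpp
#include <cstdint>
#include <vector>

#include "test/gmock.h"
#include "test/gtest.h"

namespace {

// Item is purely illustrative; the point is matching on a field of each
// element without depending on container order.
struct Item {
  uint32_t ssrc_value;
  uint32_t ssrc() const { return ssrc_value; }
};

TEST(MatcherStyleExample, MatchesOnFieldsRegardlessOfOrder) {
  std::vector<Item> items = {Item{2}, Item{1}};
  EXPECT_THAT(items,
              ::testing::UnorderedElementsAre(
                  ::testing::Property(&Item::ssrc, ::testing::Eq(1u)),
                  ::testing::Property(&Item::ssrc, ::testing::Eq(2u))));
}

}  // namespace
```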
diff --git a/modules/rtp_rtcp/source/rtcp_transceiver.cc b/modules/rtp_rtcp/source/rtcp_transceiver.cc index 1de581849b..41fa5e6206 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver.cc @@ -14,6 +14,7 @@ #include #include +#include "api/units/timestamp.h" #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" #include "rtc_base/checks.h" #include "rtc_base/event.h" @@ -23,7 +24,8 @@ namespace webrtc { RtcpTransceiver::RtcpTransceiver(const RtcpTransceiverConfig& config) - : task_queue_(config.task_queue), + : clock_(config.clock), + task_queue_(config.task_queue), rtcp_transceiver_(std::make_unique(config)) { RTC_DCHECK(task_queue_); } @@ -82,9 +84,9 @@ void RtcpTransceiver::SetReadyToSend(bool ready) { void RtcpTransceiver::ReceivePacket(rtc::CopyOnWriteBuffer packet) { RTC_CHECK(rtcp_transceiver_); RtcpTransceiverImpl* ptr = rtcp_transceiver_.get(); - int64_t now_us = rtc::TimeMicros(); - task_queue_->PostTask(ToQueuedTask( - [ptr, packet, now_us] { ptr->ReceivePacket(packet, now_us); })); + Timestamp now = clock_->CurrentTime(); + task_queue_->PostTask( + ToQueuedTask([ptr, packet, now] { ptr->ReceivePacket(packet, now); })); } void RtcpTransceiver::SendCompoundPacket() { diff --git a/modules/rtp_rtcp/source/rtcp_transceiver.h b/modules/rtp_rtcp/source/rtcp_transceiver.h index 2d1f37cd44..52f4610716 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver.h +++ b/modules/rtp_rtcp/source/rtcp_transceiver.h @@ -20,6 +20,7 @@ #include "modules/rtp_rtcp/source/rtcp_transceiver_config.h" #include "modules/rtp_rtcp/source/rtcp_transceiver_impl.h" #include "rtc_base/copy_on_write_buffer.h" +#include "system_wrappers/include/clock.h" namespace webrtc { // @@ -93,6 +94,7 @@ class RtcpTransceiver : public RtcpFeedbackSenderInterface { void SendFullIntraRequest(std::vector ssrcs, bool new_request); private: + Clock* const clock_; TaskQueueBase* const task_queue_; std::unique_ptr rtcp_transceiver_; }; diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_config.h b/modules/rtp_rtcp/source/rtcp_transceiver_config.h index 8a8fd6aed8..0501b9af7f 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_config.h +++ b/modules/rtp_rtcp/source/rtcp_transceiver_config.h @@ -17,6 +17,7 @@ #include "api/task_queue/task_queue_base.h" #include "api/video/video_bitrate_allocation.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "system_wrappers/include/clock.h" #include "system_wrappers/include/ntp_time.h" namespace webrtc { @@ -61,6 +62,9 @@ struct RtcpTransceiverConfig { // Maximum packet size outgoing transport accepts. size_t max_packet_size = 1200; + // The clock to use when querying for the NTP time. Should be set. + Clock* clock = nullptr; + // Transport to send rtcp packets to. Should be set. 
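With the clock added to RtcpTransceiverConfig, a minimal setup looks roughly like the sketch below; MakeTransceiverConfig is an illustrative name and the SSRC/cname values are arbitrary.

```cpp
#include "api/call/transport.h"
#include "api/task_queue/task_queue_base.h"
#include "modules/rtp_rtcp/source/rtcp_transceiver_config.h"
#include "system_wrappers/include/clock.h"

webrtc::RtcpTransceiverConfig MakeTransceiverConfig(
    webrtc::Clock* clock,
    webrtc::Transport* transport,
    webrtc::TaskQueueBase* queue) {
  webrtc::RtcpTransceiverConfig config;
  config.clock = clock;                   // Now required: CurrentTime()/NTP time.
  config.outgoing_transport = transport;  // Where serialized RTCP is written.
  config.task_queue = queue;              // Internal work is posted here.
  config.feedback_ssrc = 12345;           // Arbitrary example SSRC.
  config.cname = "example@host";          // Emitted in the SDES block.
  config.schedule_periodic_compound_packets = true;
  return config;
}
```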
Transport* outgoing_transport = nullptr; diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc b/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc index 0102616d59..5753ffd692 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver_impl.cc @@ -40,7 +40,7 @@ namespace webrtc { namespace { struct SenderReportTimes { - int64_t local_received_time_us; + Timestamp local_received_time; NtpTime remote_sent_time; }; @@ -92,9 +92,7 @@ RtcpTransceiverImpl::RtcpTransceiverImpl(const RtcpTransceiverConfig& config) : config_(config), ready_to_send_(config.initial_ready_to_send) { RTC_CHECK(config_.Validate()); if (ready_to_send_ && config_.schedule_periodic_compound_packets) { - config_.task_queue->PostTask(ToQueuedTask([this] { - SchedulePeriodicCompoundPackets(config_.initial_report_delay_ms); - })); + SchedulePeriodicCompoundPackets(config_.initial_report_delay_ms); } } @@ -133,13 +131,13 @@ void RtcpTransceiverImpl::SetReadyToSend(bool ready) { } void RtcpTransceiverImpl::ReceivePacket(rtc::ArrayView packet, - int64_t now_us) { + Timestamp now) { while (!packet.empty()) { rtcp::CommonHeader rtcp_block; if (!rtcp_block.Parse(packet.data(), packet.size())) return; - HandleReceivedPacket(rtcp_block, now_us); + HandleReceivedPacket(rtcp_block, now); // TODO(danilchap): Use packet.remove_prefix() when that function exists. packet = packet.subview(rtcp_block.packet_size()); @@ -228,16 +226,16 @@ void RtcpTransceiverImpl::SendFullIntraRequest( void RtcpTransceiverImpl::HandleReceivedPacket( const rtcp::CommonHeader& rtcp_packet_header, - int64_t now_us) { + Timestamp now) { switch (rtcp_packet_header.type()) { case rtcp::Bye::kPacketType: HandleBye(rtcp_packet_header); break; case rtcp::SenderReport::kPacketType: - HandleSenderReport(rtcp_packet_header, now_us); + HandleSenderReport(rtcp_packet_header, now); break; case rtcp::ExtendedReports::kPacketType: - HandleExtendedReports(rtcp_packet_header, now_us); + HandleExtendedReports(rtcp_packet_header, now); break; } } @@ -256,17 +254,14 @@ void RtcpTransceiverImpl::HandleBye( void RtcpTransceiverImpl::HandleSenderReport( const rtcp::CommonHeader& rtcp_packet_header, - int64_t now_us) { + Timestamp now) { rtcp::SenderReport sender_report; if (!sender_report.Parse(rtcp_packet_header)) return; RemoteSenderState& remote_sender = remote_senders_[sender_report.sender_ssrc()]; - absl::optional& last = - remote_sender.last_received_sender_report; - last.emplace(); - last->local_received_time_us = now_us; - last->remote_sent_time = sender_report.ntp(); + remote_sender.last_received_sender_report = + absl::optional({now, sender_report.ntp()}); for (MediaReceiverRtcpObserver* observer : remote_sender.observers) observer->OnSenderReport(sender_report.sender_ssrc(), sender_report.ntp(), @@ -275,26 +270,27 @@ void RtcpTransceiverImpl::HandleSenderReport( void RtcpTransceiverImpl::HandleExtendedReports( const rtcp::CommonHeader& rtcp_packet_header, - int64_t now_us) { + Timestamp now) { rtcp::ExtendedReports extended_reports; if (!extended_reports.Parse(rtcp_packet_header)) return; if (extended_reports.dlrr()) - HandleDlrr(extended_reports.dlrr(), now_us); + HandleDlrr(extended_reports.dlrr(), now); if (extended_reports.target_bitrate()) HandleTargetBitrate(*extended_reports.target_bitrate(), extended_reports.sender_ssrc()); } -void RtcpTransceiverImpl::HandleDlrr(const rtcp::Dlrr& dlrr, int64_t now_us) { +void RtcpTransceiverImpl::HandleDlrr(const rtcp::Dlrr& dlrr, Timestamp now) { if 
(!config_.non_sender_rtt_measurement || config_.rtt_observer == nullptr) return; // Delay and last_rr are transferred using 32bit compact ntp resolution. // Convert packet arrival time to same format through 64bit ntp format. - uint32_t receive_time_ntp = CompactNtp(TimeMicrosToNtp(now_us)); + uint32_t receive_time_ntp = + CompactNtp(config_.clock->ConvertTimestampToNtpTime(now)); for (const rtcp::ReceiveTimeInfo& rti : dlrr.sub_blocks()) { if (rti.ssrc != config_.feedback_ssrc) continue; @@ -353,13 +349,16 @@ void RtcpTransceiverImpl::SchedulePeriodicCompoundPackets(int64_t delay_ms) { void RtcpTransceiverImpl::CreateCompoundPacket(PacketSender* sender) { RTC_DCHECK(sender->IsEmpty()); const uint32_t sender_ssrc = config_.feedback_ssrc; - int64_t now_us = rtc::TimeMicros(); + Timestamp now = config_.clock->CurrentTime(); rtcp::ReceiverReport receiver_report; receiver_report.SetSenderSsrc(sender_ssrc); - receiver_report.SetReportBlocks(CreateReportBlocks(now_us)); - sender->AppendPacket(receiver_report); + receiver_report.SetReportBlocks(CreateReportBlocks(now)); + if (config_.rtcp_mode == RtcpMode::kCompound || + !receiver_report.report_blocks().empty()) { + sender->AppendPacket(receiver_report); + } - if (!config_.cname.empty()) { + if (!config_.cname.empty() && !sender->IsEmpty()) { rtcp::Sdes sdes; bool added = sdes.AddCName(config_.feedback_ssrc, config_.cname); RTC_DCHECK(added) << "Failed to add cname " << config_.cname @@ -377,7 +376,7 @@ void RtcpTransceiverImpl::CreateCompoundPacket(PacketSender* sender) { rtcp::ExtendedReports xr; rtcp::Rrtr rrtr; - rrtr.SetNtp(TimeMicrosToNtp(now_us)); + rrtr.SetNtp(config_.clock->ConvertTimestampToNtpTime(now)); xr.SetRrtr(rrtr); xr.SetSenderSsrc(sender_ssrc); @@ -428,7 +427,7 @@ void RtcpTransceiverImpl::SendImmediateFeedback( } std::vector RtcpTransceiverImpl::CreateReportBlocks( - int64_t now_us) { + Timestamp now) { if (!config_.receive_statistics) return {}; // TODO(danilchap): Support sending more than @@ -448,7 +447,7 @@ std::vector RtcpTransceiverImpl::CreateReportBlocks( *it->second.last_received_sender_report; last_sr = CompactNtp(last_sender_report.remote_sent_time); last_delay = SaturatedUsToCompactNtp( - now_us - last_sender_report.local_received_time_us); + now.us() - last_sender_report.local_received_time.us()); report_block.SetLastSr(last_sr); report_block.SetDelayLastSr(last_delay); } diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_impl.h b/modules/rtp_rtcp/source/rtcp_transceiver_impl.h index 6a6454662c..bcdee83e56 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_impl.h +++ b/modules/rtp_rtcp/source/rtcp_transceiver_impl.h @@ -18,6 +18,7 @@ #include "absl/types/optional.h" #include "api/array_view.h" +#include "api/units/timestamp.h" #include "modules/rtp_rtcp/source/rtcp_packet/common_header.h" #include "modules/rtp_rtcp/source/rtcp_packet/dlrr.h" #include "modules/rtp_rtcp/source/rtcp_packet/remb.h" @@ -48,7 +49,7 @@ class RtcpTransceiverImpl { void SetReadyToSend(bool ready); - void ReceivePacket(rtc::ArrayView packet, int64_t now_us); + void ReceivePacket(rtc::ArrayView packet, Timestamp now); void SendCompoundPacket(); @@ -76,15 +77,15 @@ class RtcpTransceiverImpl { struct RemoteSenderState; void HandleReceivedPacket(const rtcp::CommonHeader& rtcp_packet_header, - int64_t now_us); + Timestamp now); // Individual rtcp packet handlers. 
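HandleDlrr() above now derives the compact-NTP receive time via config_.clock. The RFC 3611 RTT arithmetic it feeds can be sketched in isolation as below; CompactNtpRttMs is an illustrative name, and the real conversion helper may round or clamp small values slightly differently.

```cpp
#include <cstdint>

// All three inputs are 32-bit "compact NTP" values (the middle 32 bits of the
// 64-bit NTP timestamp, i.e. 16.16 fixed-point seconds). Unsigned wrap-around
// keeps the subtraction valid across an NTP rollover.
int64_t CompactNtpRttMs(uint32_t receive_time_ntp,     // When the DLRR arrived.
                        uint32_t delay_since_last_rr,  // Reported by the remote.
                        uint32_t last_rr) {            // Our RRTR, echoed back.
  uint32_t rtt_ntp = receive_time_ntp - delay_since_last_rr - last_rr;
  // One compact-NTP unit is 1/65536 s; convert to milliseconds with rounding.
  return (static_cast<int64_t>(rtt_ntp) * 1000 + (1 << 15)) >> 16;
}
```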
void HandleBye(const rtcp::CommonHeader& rtcp_packet_header); void HandleSenderReport(const rtcp::CommonHeader& rtcp_packet_header, - int64_t now_us); + Timestamp now); void HandleExtendedReports(const rtcp::CommonHeader& rtcp_packet_header, - int64_t now_us); + Timestamp now); // Extended Reports blocks handlers. - void HandleDlrr(const rtcp::Dlrr& dlrr, int64_t now_us); + void HandleDlrr(const rtcp::Dlrr& dlrr, Timestamp now); void HandleTargetBitrate(const rtcp::TargetBitrate& target_bitrate, uint32_t remote_ssrc); @@ -97,7 +98,7 @@ class RtcpTransceiverImpl { void SendPeriodicCompoundPacket(); void SendImmediateFeedback(const rtcp::RtcpPacket& rtcp_packet); // Generate Report Blocks to be send in Sender or Receiver Report. - std::vector CreateReportBlocks(int64_t now_us); + std::vector CreateReportBlocks(Timestamp now); const RtcpTransceiverConfig config_; diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc b/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc index b7694df1e8..06e1083aa8 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver_impl_unittest.cc @@ -16,6 +16,8 @@ #include "absl/memory/memory.h" #include "api/rtp_headers.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "api/video/video_bitrate_allocation.h" #include "modules/rtp_rtcp/include/receive_statistics.h" #include "modules/rtp_rtcp/mocks/mock_rtcp_rtt_stats.h" @@ -24,8 +26,9 @@ #include "modules/rtp_rtcp/source/rtcp_packet/compound_packet.h" #include "modules/rtp_rtcp/source/time_util.h" #include "rtc_base/event.h" -#include "rtc_base/fake_clock.h" #include "rtc_base/task_queue_for_test.h" +#include "rtc_base/time_utils.h" +#include "system_wrappers/include/clock.h" #include "test/gmock.h" #include "test/gtest.h" #include "test/mock_transport.h" @@ -35,6 +38,7 @@ namespace { using ::testing::_; using ::testing::ElementsAre; +using ::testing::NiceMock; using ::testing::Return; using ::testing::SizeIs; using ::testing::StrictMock; @@ -46,8 +50,10 @@ using ::webrtc::NtpTime; using ::webrtc::RtcpTransceiverConfig; using ::webrtc::RtcpTransceiverImpl; using ::webrtc::SaturatedUsToCompactNtp; +using ::webrtc::SimulatedClock; using ::webrtc::TaskQueueForTest; -using ::webrtc::TimeMicrosToNtp; +using ::webrtc::TimeDelta; +using ::webrtc::Timestamp; using ::webrtc::VideoBitrateAllocation; using ::webrtc::rtcp::Bye; using ::webrtc::rtcp::CompoundPacket; @@ -142,9 +148,11 @@ RtcpTransceiverConfig DefaultTestConfig() { } TEST(RtcpTransceiverImplTest, NeedToStopPeriodicTaskToDestroyOnTaskQueue) { + SimulatedClock clock(0); FakeRtcpTransport transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; config.task_queue = queue.Get(); config.schedule_periodic_compound_packets = true; config.outgoing_transport = &transport; @@ -161,10 +169,31 @@ TEST(RtcpTransceiverImplTest, NeedToStopPeriodicTaskToDestroyOnTaskQueue) { ASSERT_TRUE(done.Wait(/*milliseconds=*/1000)); } +TEST(RtcpTransceiverImplTest, CanBeDestroyedRightAfterCreation) { + SimulatedClock clock(0); + FakeRtcpTransport transport; + TaskQueueForTest queue("rtcp"); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + config.task_queue = queue.Get(); + config.schedule_periodic_compound_packets = true; + config.outgoing_transport = &transport; + + rtc::Event done; + queue.PostTask([&] { + RtcpTransceiverImpl rtcp_transceiver(config); + 
rtcp_transceiver.StopPeriodicTask(); + done.Set(); + }); + ASSERT_TRUE(done.Wait(/*milliseconds=*/1000)); +} + TEST(RtcpTransceiverImplTest, CanDestroyAfterTaskQueue) { + SimulatedClock clock(0); FakeRtcpTransport transport; auto* queue = new TaskQueueForTest("rtcp"); RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; config.task_queue = queue->Get(); config.schedule_periodic_compound_packets = true; config.outgoing_transport = &transport; @@ -177,9 +206,11 @@ TEST(RtcpTransceiverImplTest, CanDestroyAfterTaskQueue) { } TEST(RtcpTransceiverImplTest, DelaysSendingFirstCompondPacket) { + SimulatedClock clock(0); TaskQueueForTest queue("rtcp"); FakeRtcpTransport transport; RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &transport; config.initial_report_delay_ms = 10; config.task_queue = queue.Get(); @@ -202,9 +233,11 @@ TEST(RtcpTransceiverImplTest, DelaysSendingFirstCompondPacket) { } TEST(RtcpTransceiverImplTest, PeriodicallySendsPackets) { + SimulatedClock clock(0); TaskQueueForTest queue("rtcp"); FakeRtcpTransport transport; RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &transport; config.initial_report_delay_ms = 0; config.report_period_ms = kReportPeriodMs; @@ -236,9 +269,11 @@ TEST(RtcpTransceiverImplTest, PeriodicallySendsPackets) { } TEST(RtcpTransceiverImplTest, SendCompoundPacketDelaysPeriodicSendPackets) { + SimulatedClock clock(0); TaskQueueForTest queue("rtcp"); FakeRtcpTransport transport; RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &transport; config.initial_report_delay_ms = 0; config.report_period_ms = kReportPeriodMs; @@ -282,8 +317,10 @@ TEST(RtcpTransceiverImplTest, SendCompoundPacketDelaysPeriodicSendPackets) { } TEST(RtcpTransceiverImplTest, SendsNoRtcpWhenNetworkStateIsDown) { + SimulatedClock clock(0); MockTransport mock_transport; RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; config.initial_ready_to_send = false; config.outgoing_transport = &mock_transport; RtcpTransceiverImpl rtcp_transceiver(config); @@ -301,8 +338,10 @@ TEST(RtcpTransceiverImplTest, SendsNoRtcpWhenNetworkStateIsDown) { } TEST(RtcpTransceiverImplTest, SendsRtcpWhenNetworkStateIsUp) { + SimulatedClock clock(0); MockTransport mock_transport; RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; config.initial_ready_to_send = false; config.outgoing_transport = &mock_transport; RtcpTransceiverImpl rtcp_transceiver(config); @@ -322,9 +361,11 @@ TEST(RtcpTransceiverImplTest, SendsRtcpWhenNetworkStateIsUp) { } TEST(RtcpTransceiverImplTest, SendsPeriodicRtcpWhenNetworkStateIsUp) { + SimulatedClock clock(0); TaskQueueForTest queue("rtcp"); FakeRtcpTransport transport; RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; config.schedule_periodic_compound_packets = true; config.initial_ready_to_send = false; config.outgoing_transport = &transport; @@ -348,7 +389,9 @@ TEST(RtcpTransceiverImplTest, SendsPeriodicRtcpWhenNetworkStateIsUp) { TEST(RtcpTransceiverImplTest, SendsMinimalCompoundPacket) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.cname = "cname"; RtcpPacketParser rtcp_parser; @@ -369,9 +412,52 @@ TEST(RtcpTransceiverImplTest, SendsMinimalCompoundPacket) { EXPECT_EQ(rtcp_parser.sdes()->chunks()[0].cname, config.cname); } +TEST(RtcpTransceiverImplTest, 
AvoidsEmptyPacketsInReducedMode) { + MockTransport transport; + EXPECT_CALL(transport, SendRtcp).Times(0); + NiceMock receive_statistics; + SimulatedClock clock(0); + + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + config.outgoing_transport = &transport; + config.rtcp_mode = webrtc::RtcpMode::kReducedSize; + config.schedule_periodic_compound_packets = false; + config.receive_statistics = &receive_statistics; + RtcpTransceiverImpl rtcp_transceiver(config); + + rtcp_transceiver.SendCompoundPacket(); +} + +TEST(RtcpTransceiverImplTest, AvoidsEmptyReceiverReportsInReducedMode) { + RtcpPacketParser rtcp_parser; + RtcpParserTransport transport(&rtcp_parser); + NiceMock receive_statistics; + SimulatedClock clock(0); + + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + config.outgoing_transport = &transport; + config.rtcp_mode = webrtc::RtcpMode::kReducedSize; + config.schedule_periodic_compound_packets = false; + config.receive_statistics = &receive_statistics; + // Set it to produce something (RRTR) in the "periodic" rtcp packets. + config.non_sender_rtt_measurement = true; + RtcpTransceiverImpl rtcp_transceiver(config); + + // Rather than waiting for the right time to produce the periodic packet, + // trigger it manually. + rtcp_transceiver.SendCompoundPacket(); + + EXPECT_EQ(rtcp_parser.receiver_report()->num_packets(), 0); + EXPECT_GT(rtcp_parser.xr()->num_packets(), 0); +} + TEST(RtcpTransceiverImplTest, SendsNoRembInitially) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -387,7 +473,9 @@ TEST(RtcpTransceiverImplTest, SendsNoRembInitially) { TEST(RtcpTransceiverImplTest, SetRembIncludesRembInNextCompoundPacket) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -406,7 +494,9 @@ TEST(RtcpTransceiverImplTest, SetRembIncludesRembInNextCompoundPacket) { TEST(RtcpTransceiverImplTest, SetRembUpdatesValuesToSend) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -431,7 +521,9 @@ TEST(RtcpTransceiverImplTest, SetRembUpdatesValuesToSend) { TEST(RtcpTransceiverImplTest, SetRembSendsImmediatelyIfSendRembOnChange) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.send_remb_on_change = true; config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; @@ -457,7 +549,9 @@ TEST(RtcpTransceiverImplTest, SetRembSendsImmediatelyIfSendRembOnChange) { TEST(RtcpTransceiverImplTest, SetRembSendsImmediatelyIfSendRembOnChangeReducedSize) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.send_remb_on_change = true; config.rtcp_mode = webrtc::RtcpMode::kReducedSize; config.feedback_ssrc = kSenderSsrc; @@ -475,7 +569,9 @@ TEST(RtcpTransceiverImplTest, TEST(RtcpTransceiverImplTest, SetRembIncludesRembInAllCompoundPackets) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; 
config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -493,7 +589,9 @@ TEST(RtcpTransceiverImplTest, SetRembIncludesRembInAllCompoundPackets) { TEST(RtcpTransceiverImplTest, SendsNoRembAfterUnset) { const uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -522,7 +620,9 @@ TEST(RtcpTransceiverImplTest, ReceiverReportUsesReceiveStatistics) { EXPECT_CALL(receive_statistics, RtcpReportBlocks(_)) .WillRepeatedly(Return(report_blocks)); + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -543,9 +643,12 @@ TEST(RtcpTransceiverImplTest, ReceiverReportUsesReceiveStatistics) { TEST(RtcpTransceiverImplTest, MultipleObserversOnSameSsrc) { const uint32_t kRemoteSsrc = 12345; + SimulatedClock clock(0); StrictMock observer1; StrictMock observer2; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer1); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer2); @@ -559,14 +662,17 @@ TEST(RtcpTransceiverImplTest, MultipleObserversOnSameSsrc) { EXPECT_CALL(observer1, OnSenderReport(kRemoteSsrc, kRemoteNtp, kRemoteRtp)); EXPECT_CALL(observer2, OnSenderReport(kRemoteSsrc, kRemoteNtp, kRemoteRtp)); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, DoesntCallsObserverAfterRemoved) { const uint32_t kRemoteSsrc = 12345; + SimulatedClock clock(0); StrictMock observer1; StrictMock observer2; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer1); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer2); @@ -578,15 +684,18 @@ TEST(RtcpTransceiverImplTest, DoesntCallsObserverAfterRemoved) { EXPECT_CALL(observer1, OnSenderReport(_, _, _)).Times(0); EXPECT_CALL(observer2, OnSenderReport(_, _, _)); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, CallsObserverOnSenderReportBySenderSsrc) { const uint32_t kRemoteSsrc1 = 12345; const uint32_t kRemoteSsrc2 = 22345; + SimulatedClock clock(0); StrictMock observer1; StrictMock observer2; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc1, &observer1); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc2, &observer2); @@ -600,15 +709,18 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnSenderReportBySenderSsrc) { EXPECT_CALL(observer1, OnSenderReport(kRemoteSsrc1, kRemoteNtp, kRemoteRtp)); EXPECT_CALL(observer2, OnSenderReport(_, _, _)).Times(0); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, 
Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, CallsObserverOnByeBySenderSsrc) { const uint32_t kRemoteSsrc1 = 12345; const uint32_t kRemoteSsrc2 = 22345; + SimulatedClock clock(0); StrictMock observer1; StrictMock observer2; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc1, &observer1); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc2, &observer2); @@ -618,15 +730,18 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnByeBySenderSsrc) { EXPECT_CALL(observer1, OnBye(kRemoteSsrc1)); EXPECT_CALL(observer2, OnBye(_)).Times(0); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, CallsObserverOnTargetBitrateBySenderSsrc) { const uint32_t kRemoteSsrc1 = 12345; const uint32_t kRemoteSsrc2 = 22345; + SimulatedClock clock(0); StrictMock observer1; StrictMock observer2; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc1, &observer1); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc2, &observer2); @@ -647,13 +762,16 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnTargetBitrateBySenderSsrc) { bitrate_allocation.SetBitrate(1, 1, /*bitrate_bps=*/80000); EXPECT_CALL(observer1, OnBitrateAllocation(kRemoteSsrc1, bitrate_allocation)); EXPECT_CALL(observer2, OnBitrateAllocation(_, _)).Times(0); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, SkipsIncorrectTargetBitrateEntries) { const uint32_t kRemoteSsrc = 12345; + SimulatedClock clock(0); MockMediaReceiverRtcpObserver observer; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer); webrtc::rtcp::TargetBitrate target_bitrate; @@ -669,13 +787,16 @@ TEST(RtcpTransceiverImplTest, SkipsIncorrectTargetBitrateEntries) { VideoBitrateAllocation expected_allocation; expected_allocation.SetBitrate(0, 0, /*bitrate_bps=*/10000); EXPECT_CALL(observer, OnBitrateAllocation(kRemoteSsrc, expected_allocation)); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, CallsObserverOnByeBehindSenderReport) { const uint32_t kRemoteSsrc = 12345; + SimulatedClock clock(0); MockMediaReceiverRtcpObserver observer; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer); CompoundPacket compound; @@ -689,13 +810,16 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnByeBehindSenderReport) { EXPECT_CALL(observer, OnBye(kRemoteSsrc)); EXPECT_CALL(observer, OnSenderReport(kRemoteSsrc, _, _)); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, 
CallsObserverOnByeBehindUnknownRtcpPacket) { const uint32_t kRemoteSsrc = 12345; + SimulatedClock clock(0); MockMediaReceiverRtcpObserver observer; - RtcpTransceiverImpl rtcp_transceiver(DefaultTestConfig()); + RtcpTransceiverConfig config = DefaultTestConfig(); + config.clock = &clock; + RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.AddMediaReceiverRtcpObserver(kRemoteSsrc, &observer); CompoundPacket compound; @@ -708,7 +832,7 @@ TEST(RtcpTransceiverImplTest, CallsObserverOnByeBehindUnknownRtcpPacket) { auto raw_packet = compound.Build(); EXPECT_CALL(observer, OnBye(kRemoteSsrc)); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); } TEST(RtcpTransceiverImplTest, @@ -722,7 +846,9 @@ TEST(RtcpTransceiverImplTest, EXPECT_CALL(receive_statistics, RtcpReportBlocks(_)) .WillOnce(Return(statistics_report_blocks)); + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -736,7 +862,7 @@ TEST(RtcpTransceiverImplTest, sr.SetSenderSsrc(kRemoteSsrc1); sr.SetNtp(kRemoteNtp); auto raw_packet = sr.Build(); - rtcp_transceiver.ReceivePacket(raw_packet, /*now_us=*/0); + rtcp_transceiver.ReceivePacket(raw_packet, Timestamp::Micros(0)); // Trigger sending ReceiverReport. rtcp_transceiver.SendCompoundPacket(); @@ -759,7 +885,7 @@ TEST(RtcpTransceiverImplTest, WhenSendsReceiverReportCalculatesDelaySinceLastSenderReport) { const uint32_t kRemoteSsrc1 = 4321; const uint32_t kRemoteSsrc2 = 5321; - rtc::ScopedFakeClock clock; + std::vector statistics_report_blocks(2); statistics_report_blocks[0].SetMediaSsrc(kRemoteSsrc1); statistics_report_blocks[1].SetMediaSsrc(kRemoteSsrc2); @@ -767,7 +893,9 @@ TEST(RtcpTransceiverImplTest, EXPECT_CALL(receive_statistics, RtcpReportBlocks(_)) .WillOnce(Return(statistics_report_blocks)); + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -775,18 +903,19 @@ TEST(RtcpTransceiverImplTest, config.receive_statistics = &receive_statistics; RtcpTransceiverImpl rtcp_transceiver(config); - auto receive_sender_report = [&rtcp_transceiver](uint32_t remote_ssrc) { + auto receive_sender_report = [&rtcp_transceiver, + &clock](uint32_t remote_ssrc) { SenderReport sr; sr.SetSenderSsrc(remote_ssrc); auto raw_packet = sr.Build(); - rtcp_transceiver.ReceivePacket(raw_packet, rtc::TimeMicros()); + rtcp_transceiver.ReceivePacket(raw_packet, clock.CurrentTime()); }; receive_sender_report(kRemoteSsrc1); - clock.AdvanceTime(webrtc::TimeDelta::Millis(100)); + clock.AdvanceTime(TimeDelta::Millis(100)); receive_sender_report(kRemoteSsrc2); - clock.AdvanceTime(webrtc::TimeDelta::Millis(100)); + clock.AdvanceTime(TimeDelta::Millis(100)); // Trigger ReceiverReport back. 
rtcp_transceiver.SendCompoundPacket(); @@ -808,7 +937,9 @@ TEST(RtcpTransceiverImplTest, SendsNack) { const uint32_t kSenderSsrc = 1234; const uint32_t kRemoteSsrc = 4321; std::vector kMissingSequenceNumbers = {34, 37, 38}; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; @@ -827,7 +958,9 @@ TEST(RtcpTransceiverImplTest, SendsNack) { TEST(RtcpTransceiverImplTest, RequestKeyFrameWithPictureLossIndication) { const uint32_t kSenderSsrc = 1234; const uint32_t kRemoteSsrc = 4321; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; @@ -846,7 +979,9 @@ TEST(RtcpTransceiverImplTest, RequestKeyFrameWithPictureLossIndication) { TEST(RtcpTransceiverImplTest, RequestKeyFrameWithFullIntraRequest) { const uint32_t kSenderSsrc = 1234; const uint32_t kRemoteSsrcs[] = {4321, 5321}; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; @@ -863,7 +998,9 @@ TEST(RtcpTransceiverImplTest, RequestKeyFrameWithFullIntraRequest) { } TEST(RtcpTransceiverImplTest, RequestKeyFrameWithFirIncreaseSeqNoPerSsrc) { + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -893,7 +1030,9 @@ TEST(RtcpTransceiverImplTest, RequestKeyFrameWithFirIncreaseSeqNoPerSsrc) { } TEST(RtcpTransceiverImplTest, SendFirDoesNotIncreaseSeqNoIfOldRequest) { + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -919,7 +1058,9 @@ TEST(RtcpTransceiverImplTest, SendFirDoesNotIncreaseSeqNoIfOldRequest) { TEST(RtcpTransceiverImplTest, KeyFrameRequestCreatesCompoundPacket) { const uint32_t kRemoteSsrcs[] = {4321}; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; // Turn periodic off to ensure sent rtcp packet is explicitly requested. config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; @@ -938,7 +1079,9 @@ TEST(RtcpTransceiverImplTest, KeyFrameRequestCreatesCompoundPacket) { TEST(RtcpTransceiverImplTest, KeyFrameRequestCreatesReducedSizePacket) { const uint32_t kRemoteSsrcs[] = {4321}; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; // Turn periodic off to ensure sent rtcp packet is explicitly requested. 
config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; @@ -957,8 +1100,9 @@ TEST(RtcpTransceiverImplTest, KeyFrameRequestCreatesReducedSizePacket) { TEST(RtcpTransceiverImplTest, SendsXrRrtrWhenEnabled) { const uint32_t kSenderSsrc = 4321; - rtc::ScopedFakeClock clock; + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; @@ -968,7 +1112,7 @@ TEST(RtcpTransceiverImplTest, SendsXrRrtrWhenEnabled) { RtcpTransceiverImpl rtcp_transceiver(config); rtcp_transceiver.SendCompoundPacket(); - NtpTime ntp_time_now = TimeMicrosToNtp(rtc::TimeMicros()); + NtpTime ntp_time_now = clock.CurrentNtpTime(); EXPECT_EQ(rtcp_parser.xr()->num_packets(), 1); EXPECT_EQ(rtcp_parser.xr()->sender_ssrc(), kSenderSsrc); @@ -977,7 +1121,9 @@ TEST(RtcpTransceiverImplTest, SendsXrRrtrWhenEnabled) { } TEST(RtcpTransceiverImplTest, SendsNoXrRrtrWhenDisabled) { + SimulatedClock clock(0); RtcpTransceiverConfig config; + config.clock = &clock; config.schedule_periodic_compound_packets = false; RtcpPacketParser rtcp_parser; RtcpParserTransport transport(&rtcp_parser); @@ -995,9 +1141,11 @@ TEST(RtcpTransceiverImplTest, SendsNoXrRrtrWhenDisabled) { TEST(RtcpTransceiverImplTest, CalculatesRoundTripTimeOnDlrr) { const uint32_t kSenderSsrc = 4321; + SimulatedClock clock(0); MockRtcpRttStats rtt_observer; MockTransport null_transport; RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.schedule_periodic_compound_packets = false; config.outgoing_transport = &null_transport; @@ -1005,25 +1153,27 @@ TEST(RtcpTransceiverImplTest, CalculatesRoundTripTimeOnDlrr) { config.rtt_observer = &rtt_observer; RtcpTransceiverImpl rtcp_transceiver(config); - int64_t time_us = 12345678; + Timestamp time = Timestamp::Micros(12345678); webrtc::rtcp::ReceiveTimeInfo rti; rti.ssrc = kSenderSsrc; - rti.last_rr = CompactNtp(TimeMicrosToNtp(time_us)); + rti.last_rr = CompactNtp(clock.ConvertTimestampToNtpTime(time)); rti.delay_since_last_rr = SaturatedUsToCompactNtp(10 * 1000); webrtc::rtcp::ExtendedReports xr; xr.AddDlrrItem(rti); auto raw_packet = xr.Build(); EXPECT_CALL(rtt_observer, OnRttUpdate(100 /* rtt_ms */)); - rtcp_transceiver.ReceivePacket(raw_packet, time_us + 110 * 1000); + rtcp_transceiver.ReceivePacket(raw_packet, time + TimeDelta::Millis(110)); } TEST(RtcpTransceiverImplTest, IgnoresUnknownSsrcInDlrr) { const uint32_t kSenderSsrc = 4321; const uint32_t kUnknownSsrc = 4322; + SimulatedClock clock(0); MockRtcpRttStats rtt_observer; MockTransport null_transport; RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.schedule_periodic_compound_packets = false; config.outgoing_transport = &null_transport; @@ -1031,16 +1181,16 @@ TEST(RtcpTransceiverImplTest, IgnoresUnknownSsrcInDlrr) { config.rtt_observer = &rtt_observer; RtcpTransceiverImpl rtcp_transceiver(config); - int64_t time_us = 12345678; + Timestamp time = Timestamp::Micros(12345678); webrtc::rtcp::ReceiveTimeInfo rti; rti.ssrc = kUnknownSsrc; - rti.last_rr = CompactNtp(TimeMicrosToNtp(time_us)); + rti.last_rr = CompactNtp(clock.ConvertTimestampToNtpTime(time)); webrtc::rtcp::ExtendedReports xr; xr.AddDlrrItem(rti); auto raw_packet = xr.Build(); EXPECT_CALL(rtt_observer, OnRttUpdate(_)).Times(0); - rtcp_transceiver.ReceivePacket(raw_packet, time_us + 100000); + rtcp_transceiver.ReceivePacket(raw_packet, time + 
TimeDelta::Millis(100)); } } // namespace diff --git a/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc b/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc index 9c181c6526..290aa48ff4 100644 --- a/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc +++ b/modules/rtp_rtcp/source/rtcp_transceiver_unittest.cc @@ -18,6 +18,7 @@ #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" #include "rtc_base/event.h" #include "rtc_base/task_queue_for_test.h" +#include "system_wrappers/include/clock.h" #include "test/gmock.h" #include "test/gtest.h" #include "test/mock_transport.h" @@ -34,6 +35,7 @@ using ::testing::NiceMock; using ::webrtc::MockTransport; using ::webrtc::RtcpTransceiver; using ::webrtc::RtcpTransceiverConfig; +using ::webrtc::SimulatedClock; using ::webrtc::TaskQueueForTest; using ::webrtc::rtcp::RemoteEstimate; using ::webrtc::rtcp::RtcpPacket; @@ -57,9 +59,11 @@ void WaitPostedTasks(TaskQueueForTest* queue) { } TEST(RtcpTransceiverTest, SendsRtcpOnTaskQueueWhenCreatedOffTaskQueue) { + SimulatedClock clock(0); MockTransport outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); EXPECT_CALL(outgoing_transport, SendRtcp(_, _)) @@ -74,9 +78,11 @@ TEST(RtcpTransceiverTest, SendsRtcpOnTaskQueueWhenCreatedOffTaskQueue) { } TEST(RtcpTransceiverTest, SendsRtcpOnTaskQueueWhenCreatedOnTaskQueue) { + SimulatedClock clock(0); MockTransport outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); EXPECT_CALL(outgoing_transport, SendRtcp(_, _)) @@ -94,9 +100,11 @@ TEST(RtcpTransceiverTest, SendsRtcpOnTaskQueueWhenCreatedOnTaskQueue) { } TEST(RtcpTransceiverTest, CanBeDestroyedOnTaskQueue) { + SimulatedClock clock(0); NiceMock outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); auto rtcp_transceiver = std::make_unique(config); @@ -110,9 +118,11 @@ TEST(RtcpTransceiverTest, CanBeDestroyedOnTaskQueue) { } TEST(RtcpTransceiverTest, CanBeDestroyedWithoutBlocking) { + SimulatedClock clock(0); TaskQueueForTest queue("rtcp"); NiceMock outgoing_transport; RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); auto* rtcp_transceiver = new RtcpTransceiver(config); @@ -131,9 +141,11 @@ TEST(RtcpTransceiverTest, CanBeDestroyedWithoutBlocking) { } TEST(RtcpTransceiverTest, MaySendPacketsAfterDestructor) { // i.e. Be careful! + SimulatedClock clock(0); NiceMock outgoing_transport; // Must outlive queue below. 
TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); auto* rtcp_transceiver = new RtcpTransceiver(config); @@ -162,9 +174,11 @@ rtc::CopyOnWriteBuffer CreateSenderReport(uint32_t ssrc, uint32_t rtp_time) { TEST(RtcpTransceiverTest, DoesntPostToRtcpObserverAfterCallToRemove) { const uint32_t kRemoteSsrc = 1234; + SimulatedClock clock(0); MockTransport null_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &null_transport; config.task_queue = queue.Get(); RtcpTransceiver rtcp_transceiver(config); @@ -189,9 +203,11 @@ TEST(RtcpTransceiverTest, DoesntPostToRtcpObserverAfterCallToRemove) { TEST(RtcpTransceiverTest, RemoveMediaReceiverRtcpObserverIsNonBlocking) { const uint32_t kRemoteSsrc = 1234; + SimulatedClock clock(0); MockTransport null_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &null_transport; config.task_queue = queue.Get(); RtcpTransceiver rtcp_transceiver(config); @@ -213,9 +229,11 @@ TEST(RtcpTransceiverTest, RemoveMediaReceiverRtcpObserverIsNonBlocking) { } TEST(RtcpTransceiverTest, CanCallSendCompoundPacketFromAnyThread) { + SimulatedClock clock(0); MockTransport outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); @@ -242,9 +260,11 @@ TEST(RtcpTransceiverTest, CanCallSendCompoundPacketFromAnyThread) { } TEST(RtcpTransceiverTest, DoesntSendPacketsAfterStopCallback) { + SimulatedClock clock(0); NiceMock outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); config.schedule_periodic_compound_packets = true; @@ -263,9 +283,11 @@ TEST(RtcpTransceiverTest, DoesntSendPacketsAfterStopCallback) { TEST(RtcpTransceiverTest, SendsCombinedRtcpPacketOnTaskQueue) { static constexpr uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); MockTransport outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); @@ -300,9 +322,11 @@ TEST(RtcpTransceiverTest, SendsCombinedRtcpPacketOnTaskQueue) { TEST(RtcpTransceiverTest, SendFrameIntraRequestDefaultsToNewRequest) { static constexpr uint32_t kSenderSsrc = 12345; + SimulatedClock clock(0); MockTransport outgoing_transport; TaskQueueForTest queue("rtcp"); RtcpTransceiverConfig config; + config.clock = &clock; config.feedback_ssrc = kSenderSsrc; config.outgoing_transport = &outgoing_transport; config.task_queue = queue.Get(); diff --git a/modules/rtp_rtcp/source/rtp_dependency_descriptor_extension_unittest.cc b/modules/rtp_rtcp/source/rtp_dependency_descriptor_extension_unittest.cc index 11d809693c..974557ce6e 100644 --- a/modules/rtp_rtcp/source/rtp_dependency_descriptor_extension_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_dependency_descriptor_extension_unittest.cc @@ -115,5 +115,23 @@ TEST(RtpDependencyDescriptorExtensionTest, buffer, structure, active_chains, descriptor)); } +TEST(RtpDependencyDescriptorExtensionTest, FailsToWriteInvalidDescriptor) { + uint8_t buffer[256]; + FrameDependencyStructure structure; + structure.num_decode_targets = 2; + 
structure.num_chains = 2; + structure.templates = { + FrameDependencyTemplate().T(0).Dtis("SR").ChainDiffs({2, 2})}; + DependencyDescriptor descriptor; + descriptor.frame_dependencies = structure.templates[0]; + descriptor.frame_dependencies.temporal_id = 1; + + EXPECT_EQ( + RtpDependencyDescriptorExtension::ValueSize(structure, 0b11, descriptor), + 0u); + EXPECT_FALSE(RtpDependencyDescriptorExtension::Write(buffer, structure, 0b11, + descriptor)); +} + } // namespace } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_dependency_descriptor_reader.cc b/modules/rtp_rtcp/source/rtp_dependency_descriptor_reader.cc index cba594dc6f..8f0cb349bc 100644 --- a/modules/rtp_rtcp/source/rtp_dependency_descriptor_reader.cc +++ b/modules/rtp_rtcp/source/rtp_dependency_descriptor_reader.cc @@ -47,14 +47,14 @@ RtpDependencyDescriptorReader::RtpDependencyDescriptorReader( uint32_t RtpDependencyDescriptorReader::ReadBits(size_t bit_count) { uint32_t value = 0; - if (!buffer_.ReadBits(&value, bit_count)) + if (!buffer_.ReadBits(bit_count, value)) parsing_failed_ = true; return value; } uint32_t RtpDependencyDescriptorReader::ReadNonSymmetric(size_t num_values) { uint32_t value = 0; - if (!buffer_.ReadNonSymmetric(&value, num_values)) + if (!buffer_.ReadNonSymmetric(num_values, value)) parsing_failed_ = true; return value; } diff --git a/modules/rtp_rtcp/source/rtp_dependency_descriptor_writer.cc b/modules/rtp_rtcp/source/rtp_dependency_descriptor_writer.cc index 25d221253b..31df783064 100644 --- a/modules/rtp_rtcp/source/rtp_dependency_descriptor_writer.cc +++ b/modules/rtp_rtcp/source/rtp_dependency_descriptor_writer.cc @@ -66,6 +66,9 @@ RtpDependencyDescriptorWriter::RtpDependencyDescriptorWriter( } bool RtpDependencyDescriptorWriter::Write() { + if (build_failed_) { + return false; + } WriteMandatoryFields(); if (HasExtendedFields()) { WriteExtendedFields(); @@ -83,6 +86,9 @@ bool RtpDependencyDescriptorWriter::Write() { } int RtpDependencyDescriptorWriter::ValueSizeBits() const { + if (build_failed_) { + return 0; + } static constexpr int kMandatoryFields = 1 + 1 + 6 + 16; int value_size_bits = kMandatoryFields + best_template_.extra_size_bits; if (HasExtendedFields()) { @@ -172,7 +178,10 @@ void RtpDependencyDescriptorWriter::FindBestTemplate() { frame_template.temporal_id; }; auto first = absl::c_find_if(templates, same_layer); - RTC_CHECK(first != templates.end()); + if (first == templates.end()) { + build_failed_ = true; + return; + } auto last = std::find_if_not(first, templates.end(), same_layer); best_template_ = CalculateMatch(first); diff --git a/modules/rtp_rtcp/source/rtp_header_extension_map.cc b/modules/rtp_rtcp/source/rtp_header_extension_map.cc index c16dcaf6f7..0b5ba474c7 100644 --- a/modules/rtp_rtcp/source/rtp_header_extension_map.cc +++ b/modules/rtp_rtcp/source/rtp_header_extension_map.cc @@ -34,6 +34,7 @@ constexpr ExtensionInfo CreateExtensionInfo() { constexpr ExtensionInfo kExtensions[] = { CreateExtensionInfo(), CreateExtensionInfo(), + CreateExtensionInfo(), CreateExtensionInfo(), CreateExtensionInfo(), CreateExtensionInfo(), @@ -50,6 +51,7 @@ constexpr ExtensionInfo kExtensions[] = { CreateExtensionInfo(), CreateExtensionInfo(), CreateExtensionInfo(), + CreateExtensionInfo(), }; // Because of kRtpExtensionNone, NumberOfExtension is 1 bigger than the actual diff --git a/modules/rtp_rtcp/source/rtp_header_extensions.cc b/modules/rtp_rtcp/source/rtp_header_extensions.cc index b540e4b22e..1dd4f54759 100644 --- a/modules/rtp_rtcp/source/rtp_header_extensions.cc +++ 
b/modules/rtp_rtcp/source/rtp_header_extensions.cc @@ -13,6 +13,7 @@ #include #include +#include #include #include "modules/rtp_rtcp/include/rtp_cvo.h" @@ -186,6 +187,60 @@ bool AudioLevel::Write(rtc::ArrayView data, return true; } +// An RTP Header Extension for Mixer-to-Client Audio Level Indication +// +// https://tools.ietf.org/html/rfc6465 +// +// The form of the audio level extension block: +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=2 |0| level 1 |0| level 2 |0| level 3 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Sample Audio Level Encoding Using the One-Byte Header Format +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=3 |0| level 1 |0| level 2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |0| level 3 | 0 (pad) | ... | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Sample Audio Level Encoding Using the Two-Byte Header Format +constexpr RTPExtensionType CsrcAudioLevel::kId; +constexpr uint8_t CsrcAudioLevel::kMaxValueSizeBytes; +constexpr const char CsrcAudioLevel::kUri[]; + +bool CsrcAudioLevel::Parse(rtc::ArrayView data, + std::vector* csrc_audio_levels) { + if (data.size() > kRtpCsrcSize) { + return false; + } + csrc_audio_levels->resize(data.size()); + for (size_t i = 0; i < data.size(); i++) { + (*csrc_audio_levels)[i] = data[i] & 0x7F; + } + return true; +} + +size_t CsrcAudioLevel::ValueSize( + rtc::ArrayView csrc_audio_levels) { + return csrc_audio_levels.size(); +} + +bool CsrcAudioLevel::Write(rtc::ArrayView data, + rtc::ArrayView csrc_audio_levels) { + RTC_CHECK_LE(csrc_audio_levels.size(), kRtpCsrcSize); + if (csrc_audio_levels.size() != data.size()) { + return false; + } + for (size_t i = 0; i < csrc_audio_levels.size(); i++) { + data[i] = csrc_audio_levels[i] & 0x7F; + } + return true; +} + // From RFC 5450: Transmission Time Offsets in RTP Streams. 
// // The transmission time is signaled to the receiver in-band using the @@ -823,4 +878,32 @@ bool InbandComfortNoiseExtension::Write(rtc::ArrayView data, return true; } +// VideoFrameTrackingIdExtension +// +// 0 1 2 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | L=1 | video-frame-tracking-id | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +constexpr RTPExtensionType VideoFrameTrackingIdExtension::kId; +constexpr uint8_t VideoFrameTrackingIdExtension::kValueSizeBytes; +constexpr const char VideoFrameTrackingIdExtension::kUri[]; + +bool VideoFrameTrackingIdExtension::Parse(rtc::ArrayView data, + uint16_t* video_frame_tracking_id) { + if (data.size() != kValueSizeBytes) { + return false; + } + *video_frame_tracking_id = ByteReader::ReadBigEndian(data.data()); + return true; +} + +bool VideoFrameTrackingIdExtension::Write(rtc::ArrayView data, + uint16_t video_frame_tracking_id) { + RTC_DCHECK_EQ(data.size(), kValueSizeBytes); + ByteWriter::WriteBigEndian(data.data(), video_frame_tracking_id); + return true; +} + } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_header_extensions.h b/modules/rtp_rtcp/source/rtp_header_extensions.h index 1352611fb1..b47824afdb 100644 --- a/modules/rtp_rtcp/source/rtp_header_extensions.h +++ b/modules/rtp_rtcp/source/rtp_header_extensions.h @@ -14,6 +14,7 @@ #include #include +#include #include "api/array_view.h" #include "api/rtp_headers.h" @@ -77,6 +78,20 @@ class AudioLevel { uint8_t audio_level); }; +class CsrcAudioLevel { + public: + static constexpr RTPExtensionType kId = kRtpExtensionCsrcAudioLevel; + static constexpr uint8_t kMaxValueSizeBytes = 15; + static constexpr const char kUri[] = + "urn:ietf:params:rtp-hdrext:csrc-audio-level"; + + static bool Parse(rtc::ArrayView data, + std::vector* csrc_audio_levels); + static size_t ValueSize(rtc::ArrayView csrc_audio_levels); + static bool Write(rtc::ArrayView data, + rtc::ArrayView csrc_audio_levels); +}; + class TransmissionOffset { public: using value_type = int32_t; @@ -307,5 +322,21 @@ class InbandComfortNoiseExtension { absl::optional level); }; +class VideoFrameTrackingIdExtension { + public: + using value_type = uint16_t; + static constexpr RTPExtensionType kId = kRtpExtensionVideoFrameTrackingId; + static constexpr uint8_t kValueSizeBytes = 2; + static constexpr const char kUri[] = + "http://www.webrtc.org/experiments/rtp-hdrext/video-frame-tracking-id"; + static bool Parse(rtc::ArrayView data, + uint16_t* video_frame_tracking_id); + static size_t ValueSize(uint16_t /*video_frame_tracking_id*/) { + return kValueSizeBytes; + } + static bool Write(rtc::ArrayView data, + uint16_t video_frame_tracking_id); +}; + } // namespace webrtc #endif // MODULES_RTP_RTCP_SOURCE_RTP_HEADER_EXTENSIONS_H_ diff --git a/modules/rtp_rtcp/source/rtp_packet.cc b/modules/rtp_rtcp/source/rtp_packet.cc index 38d29cc2b4..8523637feb 100644 --- a/modules/rtp_rtcp/source/rtp_packet.cc +++ b/modules/rtp_rtcp/source/rtp_packet.cc @@ -27,6 +27,7 @@ constexpr size_t kFixedHeaderSize = 12; constexpr uint8_t kRtpVersion = 2; constexpr uint16_t kOneByteExtensionProfileId = 0xBEDE; constexpr uint16_t kTwoByteExtensionProfileId = 0x1000; +constexpr uint16_t kTwobyteExtensionProfileIdAppBitsFilter = 0xfff0; constexpr size_t kOneByteExtensionHeaderLength = 1; constexpr size_t kTwoByteExtensionHeaderLength = 2; constexpr size_t kDefaultPacketSize = 1500; @@ -70,8 +71,8 @@ RtpPacket::RtpPacket(const ExtensionManager* extensions, size_t capacity) 
RtpPacket::~RtpPacket() {} -void RtpPacket::IdentifyExtensions(const ExtensionManager& extensions) { - extensions_ = extensions; +void RtpPacket::IdentifyExtensions(ExtensionManager extensions) { + extensions_ = std::move(extensions); } bool RtpPacket::Parse(const uint8_t* buffer, size_t buffer_size) { @@ -111,8 +112,6 @@ std::vector RtpPacket::Csrcs() const { } void RtpPacket::CopyHeaderFrom(const RtpPacket& packet) { - RTC_DCHECK_GE(capacity(), packet.headers_size()); - marker_ = packet.marker_; payload_type_ = packet.payload_type_; sequence_number_ = packet.sequence_number_; @@ -186,6 +185,7 @@ void RtpPacket::ZeroMutableExtensions() { break; } case RTPExtensionType::kRtpExtensionAudioLevel: + case RTPExtensionType::kRtpExtensionCsrcAudioLevel: case RTPExtensionType::kRtpExtensionAbsoluteCaptureTime: case RTPExtensionType::kRtpExtensionColorSpace: case RTPExtensionType::kRtpExtensionGenericFrameDescriptor00: @@ -198,7 +198,8 @@ void RtpPacket::ZeroMutableExtensions() { case RTPExtensionType::kRtpExtensionVideoContentType: case RTPExtensionType::kRtpExtensionVideoLayersAllocation: case RTPExtensionType::kRtpExtensionVideoRotation: - case RTPExtensionType::kRtpExtensionInbandComfortNoise: { + case RTPExtensionType::kRtpExtensionInbandComfortNoise: + case RTPExtensionType::kRtpExtensionVideoFrameTrackingId: { // Non-mutable extension. Don't change it. break; } @@ -465,16 +466,6 @@ bool RtpPacket::ParseBuffer(const uint8_t* buffer, size_t size) { } payload_offset_ = kFixedHeaderSize + number_of_crcs * 4; - if (has_padding) { - padding_size_ = buffer[size - 1]; - if (padding_size_ == 0) { - RTC_LOG(LS_WARNING) << "Padding was set, but padding size is zero"; - return false; - } - } else { - padding_size_ = 0; - } - extensions_size_ = 0; extension_entries_.clear(); if (has_extension) { @@ -500,7 +491,8 @@ bool RtpPacket::ParseBuffer(const uint8_t* buffer, size_t size) { return false; } if (profile != kOneByteExtensionProfileId && - profile != kTwoByteExtensionProfileId) { + (profile & kTwobyteExtensionProfileIdAppBitsFilter) != + kTwoByteExtensionProfileId) { RTC_LOG(LS_WARNING) << "Unsupported rtp extension " << profile; } else { size_t extension_header_length = profile == kOneByteExtensionProfileId @@ -554,6 +546,16 @@ bool RtpPacket::ParseBuffer(const uint8_t* buffer, size_t size) { payload_offset_ = extension_offset + extensions_capacity; } + if (has_padding && payload_offset_ < size) { + padding_size_ = buffer[size - 1]; + if (padding_size_ == 0) { + RTC_LOG(LS_WARNING) << "Padding was set, but padding size is zero"; + return false; + } + } else { + padding_size_ = 0; + } + if (payload_offset_ + padding_size_ > size) { return false; } diff --git a/modules/rtp_rtcp/source/rtp_packet.h b/modules/rtp_rtcp/source/rtp_packet.h index aa854f35ab..e2e291cf5d 100644 --- a/modules/rtp_rtcp/source/rtp_packet.h +++ b/modules/rtp_rtcp/source/rtp_packet.h @@ -51,7 +51,7 @@ class RtpPacket { bool Parse(rtc::CopyOnWriteBuffer packet); // Maps extensions id to their types. - void IdentifyExtensions(const ExtensionManager& extensions); + void IdentifyExtensions(ExtensionManager extensions); // Header. bool Marker() const { return marker_; } @@ -65,6 +65,7 @@ class RtpPacket { // Payload. 
size_t payload_size() const { return payload_size_; } + bool has_padding() const { return buffer_[0] & 0x20; } size_t padding_size() const { return padding_size_; } rtc::ArrayView payload() const { return rtc::MakeArrayView(data() + payload_offset_, payload_size_); @@ -114,6 +115,11 @@ class RtpPacket { bool HasExtension() const; bool HasExtension(ExtensionType type) const; + // Returns whether there is an associated id for the extension and thus it is + // possible to set the extension. + template + bool IsRegistered() const; + template bool GetExtension(FirstValue, Values...) const; @@ -207,6 +213,11 @@ bool RtpPacket::HasExtension() const { return HasExtension(Extension::kId); } +template +bool RtpPacket::IsRegistered() const { + return extensions_.IsRegistered(Extension::kId); +} + template bool RtpPacket::GetExtension(FirstValue first, Values... values) const { auto raw = FindExtension(Extension::kId); diff --git a/modules/rtp_rtcp/source/rtp_packet_history.cc b/modules/rtp_rtcp/source/rtp_packet_history.cc index 1fbfb7651d..5089933051 100644 --- a/modules/rtp_rtcp/source/rtp_packet_history.cc +++ b/modules/rtp_rtcp/source/rtp_packet_history.cc @@ -134,7 +134,7 @@ void RtpPacketHistory::PutRtpPacket(std::unique_ptr packet, // Store packet. const uint16_t rtp_seq_no = packet->SequenceNumber(); int packet_index = GetPacketIndex(rtp_seq_no); - if (packet_index >= 0u && + if (packet_index >= 0 && static_cast(packet_index) < packet_history_.size() && packet_history_[packet_index].packet_ != nullptr) { RTC_LOG(LS_WARNING) << "Duplicate packet inserted: " << rtp_seq_no; diff --git a/modules/rtp_rtcp/source/rtp_packet_received.cc b/modules/rtp_rtcp/source/rtp_packet_received.cc index feadee1db1..6b2cc76981 100644 --- a/modules/rtp_rtcp/source/rtp_packet_received.cc +++ b/modules/rtp_rtcp/source/rtp_packet_received.cc @@ -21,8 +21,10 @@ namespace webrtc { RtpPacketReceived::RtpPacketReceived() = default; -RtpPacketReceived::RtpPacketReceived(const ExtensionManager* extensions) - : RtpPacket(extensions) {} +RtpPacketReceived::RtpPacketReceived( + const ExtensionManager* extensions, + webrtc::Timestamp arrival_time /*= webrtc::Timestamp::MinusInfinity()*/) + : RtpPacket(extensions), arrival_time_(arrival_time) {} RtpPacketReceived::RtpPacketReceived(const RtpPacketReceived& packet) = default; RtpPacketReceived::RtpPacketReceived(RtpPacketReceived&& packet) = default; diff --git a/modules/rtp_rtcp/source/rtp_packet_received.h b/modules/rtp_rtcp/source/rtp_packet_received.h index 6727b67750..431d3f52be 100644 --- a/modules/rtp_rtcp/source/rtp_packet_received.h +++ b/modules/rtp_rtcp/source/rtp_packet_received.h @@ -12,18 +12,26 @@ #include -#include +#include +#include "absl/base/attributes.h" #include "api/array_view.h" +#include "api/ref_counted_base.h" #include "api/rtp_headers.h" +#include "api/scoped_refptr.h" +#include "api/units/timestamp.h" #include "modules/rtp_rtcp/source/rtp_packet.h" namespace webrtc { // Class to hold rtp packet with metadata for receiver side. +// The metadata is not parsed from the rtp packet, but may be derived from the +// data that is parsed from the rtp packet. 
class RtpPacketReceived : public RtpPacket { public: RtpPacketReceived(); - explicit RtpPacketReceived(const ExtensionManager* extensions); + explicit RtpPacketReceived( + const ExtensionManager* extensions, + webrtc::Timestamp arrival_time = webrtc::Timestamp::MinusInfinity()); RtpPacketReceived(const RtpPacketReceived& packet); RtpPacketReceived(RtpPacketReceived&& packet); @@ -38,8 +46,17 @@ class RtpPacketReceived : public RtpPacket { // Time in local time base as close as it can to packet arrived on the // network. - int64_t arrival_time_ms() const { return arrival_time_ms_; } - void set_arrival_time_ms(int64_t time) { arrival_time_ms_ = time; } + webrtc::Timestamp arrival_time() const { return arrival_time_; } + void set_arrival_time(webrtc::Timestamp time) { arrival_time_ = time; } + + ABSL_DEPRECATED("Use arrival_time() instead") + int64_t arrival_time_ms() const { + return arrival_time_.IsMinusInfinity() ? -1 : arrival_time_.ms(); + } + ABSL_DEPRECATED("Use set_arrival_time() instead") + void set_arrival_time_ms(int64_t time) { + arrival_time_ = webrtc::Timestamp::Millis(time); + } // Flag if packet was recovered via RTX or FEC. bool recovered() const { return recovered_; } @@ -50,20 +67,20 @@ class RtpPacketReceived : public RtpPacket { payload_type_frequency_ = value; } - // Additional data bound to the RTP packet for use in application code, - // outside of WebRTC. - rtc::ArrayView application_data() const { - return application_data_; + // An application can attach arbitrary data to an RTP packet using + // `additional_data`. The additional data does not affect WebRTC processing. + rtc::scoped_refptr additional_data() const { + return additional_data_; } - void set_application_data(rtc::ArrayView data) { - application_data_.assign(data.begin(), data.end()); + void set_additional_data(rtc::scoped_refptr data) { + additional_data_ = std::move(data); } private: - int64_t arrival_time_ms_ = 0; + webrtc::Timestamp arrival_time_ = Timestamp::MinusInfinity(); int payload_type_frequency_ = 0; bool recovered_ = false; - std::vector application_data_; + rtc::scoped_refptr additional_data_; }; } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_packet_to_send.h b/modules/rtp_rtcp/source/rtp_packet_to_send.h index 9aaf9a52e6..12341ef6cf 100644 --- a/modules/rtp_rtcp/source/rtp_packet_to_send.h +++ b/modules/rtp_rtcp/source/rtp_packet_to_send.h @@ -13,10 +13,12 @@ #include #include -#include +#include #include "absl/types/optional.h" #include "api/array_view.h" +#include "api/ref_counted_base.h" +#include "api/scoped_refptr.h" #include "api/video/video_timing.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/source/rtp_header_extensions.h" @@ -24,6 +26,8 @@ namespace webrtc { // Class to hold rtp packet with metadata for sender side. +// The metadata is not sent over the wire, but packet sender may use it to +// create rtp header extensions or other data that is sent over the wire. class RtpPacketToSend : public RtpPacket { public: // RtpPacketToSend::Type is deprecated. Use RtpPacketMediaType directly.
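The application_data -> additional_data change above (the next hunk applies the same change to RtpPacketToSend) replaces a copied byte vector with a ref-counted handle that WebRTC itself never inspects. A minimal usage sketch, not part of the patch, assuming only the api/ref_counted_base.h and api/scoped_refptr.h headers this patch already adds; MyAppTag, AttachTag and ReadTag are hypothetical names for illustration:

#include "api/ref_counted_base.h"
#include "api/scoped_refptr.h"
#include "modules/rtp_rtcp/source/rtp_packet_received.h"

namespace {

// Hypothetical application-side payload. Anything deriving from
// rtc::RefCountedBase can be attached; WebRTC only stores the pointer.
class MyAppTag : public rtc::RefCountedBase {
 public:
  explicit MyAppTag(int stream_id) : stream_id_(stream_id) {}
  int stream_id() const { return stream_id_; }

 private:
  const int stream_id_;
};

void AttachTag(webrtc::RtpPacketReceived& packet, int stream_id) {
  // Upcast to scoped_refptr<rtc::RefCountedBase> happens implicitly.
  packet.set_additional_data(
      rtc::scoped_refptr<MyAppTag>(new MyAppTag(stream_id)));
}

int ReadTag(const webrtc::RtpPacketReceived& packet) {
  // The application must know the concrete type it attached earlier.
  rtc::scoped_refptr<rtc::RefCountedBase> data = packet.additional_data();
  return data ? static_cast<const MyAppTag*>(data.get())->stream_id() : -1;
}

}  // namespace

Compared with the old application_data(), nothing is copied into the packet, and because the packet's defaulted copy/move constructors copy the scoped_refptr, every copy of the packet shares the same attached object.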
@@ -55,23 +59,22 @@ class RtpPacketToSend : public RtpPacket { void set_retransmitted_sequence_number(uint16_t sequence_number) { retransmitted_sequence_number_ = sequence_number; } - absl::optional retransmitted_sequence_number() { + absl::optional retransmitted_sequence_number() const { return retransmitted_sequence_number_; } void set_allow_retransmission(bool allow_retransmission) { allow_retransmission_ = allow_retransmission; } - bool allow_retransmission() { return allow_retransmission_; } + bool allow_retransmission() const { return allow_retransmission_; } - // Additional data bound to the RTP packet for use in application code, - // outside of WebRTC. - rtc::ArrayView application_data() const { - return application_data_; + // An application can attach arbitrary data to an RTP packet using + // `additional_data`. The additional data does not affect WebRTC processing. + rtc::scoped_refptr additional_data() const { + return additional_data_; } - - void set_application_data(rtc::ArrayView data) { - application_data_.assign(data.begin(), data.end()); + void set_additional_data(rtc::scoped_refptr data) { + additional_data_ = std::move(data); } void set_packetization_finish_time_ms(int64_t time) { @@ -122,7 +125,7 @@ class RtpPacketToSend : public RtpPacket { absl::optional packet_type_; bool allow_retransmission_ = false; absl::optional retransmitted_sequence_number_; - std::vector application_data_; + rtc::scoped_refptr additional_data_; bool is_first_packet_of_frame_ = false; bool is_key_frame_ = false; bool fec_protect_packet_ = false; diff --git a/modules/rtp_rtcp/source/rtp_packet_unittest.cc b/modules/rtp_rtcp/source/rtp_packet_unittest.cc index f7f21af41d..8c5df1a0ad 100644 --- a/modules/rtp_rtcp/source/rtp_packet_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_packet_unittest.cc @@ -354,6 +354,35 @@ TEST(RtpPacketTest, CreateWithMaxSizeHeaderExtension) { EXPECT_EQ(read, kValue); } +TEST(RtpPacketTest, SetsRegisteredExtension) { + RtpPacketToSend::ExtensionManager extensions; + extensions.Register(kTransmissionOffsetExtensionId); + RtpPacketToSend packet(&extensions); + + EXPECT_TRUE(packet.IsRegistered()); + EXPECT_FALSE(packet.HasExtension()); + + // Try to set the extensions. 
+ EXPECT_TRUE(packet.SetExtension(kTimeOffset)); + + EXPECT_TRUE(packet.HasExtension()); + EXPECT_EQ(packet.GetExtension(), kTimeOffset); +} + +TEST(RtpPacketTest, FailsToSetUnregisteredExtension) { + RtpPacketToSend::ExtensionManager extensions; + extensions.Register(kTransmissionOffsetExtensionId); + RtpPacketToSend packet(&extensions); + + EXPECT_FALSE(packet.IsRegistered()); + EXPECT_FALSE(packet.HasExtension()); + + EXPECT_FALSE(packet.SetExtension(42)); + + EXPECT_FALSE(packet.HasExtension()); + EXPECT_EQ(packet.GetExtension(), absl::nullopt); +} + TEST(RtpPacketTest, SetReservedExtensionsAfterPayload) { const size_t kPayloadSize = 4; RtpPacketToSend::ExtensionManager extensions; @@ -475,6 +504,76 @@ TEST(RtpPacketTest, ParseWithExtension) { EXPECT_EQ(0u, packet.padding_size()); } +TEST(RtpPacketTest, ParseHeaderOnly) { + // clang-format off + constexpr uint8_t kPaddingHeader[] = { + 0x80, 0x62, 0x35, 0x79, + 0x65, 0x43, 0x12, 0x78, + 0x12, 0x34, 0x56, 0x78}; + // clang-format on + + RtpPacket packet; + EXPECT_TRUE(packet.Parse(rtc::CopyOnWriteBuffer(kPaddingHeader))); + EXPECT_EQ(packet.PayloadType(), 0x62u); + EXPECT_EQ(packet.SequenceNumber(), 0x3579u); + EXPECT_EQ(packet.Timestamp(), 0x65431278u); + EXPECT_EQ(packet.Ssrc(), 0x12345678u); + + EXPECT_FALSE(packet.has_padding()); + EXPECT_EQ(packet.padding_size(), 0u); + EXPECT_EQ(packet.payload_size(), 0u); +} + +TEST(RtpPacketTest, ParseHeaderOnlyWithPadding) { + // clang-format off + constexpr uint8_t kPaddingHeader[] = { + 0xa0, 0x62, 0x35, 0x79, + 0x65, 0x43, 0x12, 0x78, + 0x12, 0x34, 0x56, 0x78}; + // clang-format on + + RtpPacket packet; + EXPECT_TRUE(packet.Parse(rtc::CopyOnWriteBuffer(kPaddingHeader))); + + EXPECT_TRUE(packet.has_padding()); + EXPECT_EQ(packet.padding_size(), 0u); + EXPECT_EQ(packet.payload_size(), 0u); +} + +TEST(RtpPacketTest, ParseHeaderOnlyWithExtensionAndPadding) { + // clang-format off + constexpr uint8_t kPaddingHeader[] = { + 0xb0, 0x62, 0x35, 0x79, + 0x65, 0x43, 0x12, 0x78, + 0x12, 0x34, 0x56, 0x78, + 0xbe, 0xde, 0x00, 0x01, + 0x11, 0x00, 0x00, 0x00}; + // clang-format on + + RtpHeaderExtensionMap extensions; + extensions.Register(1); + RtpPacket packet(&extensions); + EXPECT_TRUE(packet.Parse(rtc::CopyOnWriteBuffer(kPaddingHeader))); + EXPECT_TRUE(packet.has_padding()); + EXPECT_TRUE(packet.HasExtension()); + EXPECT_EQ(packet.padding_size(), 0u); +} + +TEST(RtpPacketTest, ParsePaddingOnlyPacket) { + // clang-format off + constexpr uint8_t kPaddingHeader[] = { + 0xa0, 0x62, 0x35, 0x79, + 0x65, 0x43, 0x12, 0x78, + 0x12, 0x34, 0x56, 0x78, + 0, 0, 3}; + // clang-format on + + RtpPacket packet; + EXPECT_TRUE(packet.Parse(rtc::CopyOnWriteBuffer(kPaddingHeader))); + EXPECT_TRUE(packet.has_padding()); + EXPECT_EQ(packet.padding_size(), 3u); +} + TEST(RtpPacketTest, GetExtensionWithoutParametersReturnsOptionalValue) { RtpPacket::ExtensionManager extensions; extensions.Register(kTransmissionOffsetExtensionId); diff --git a/modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.cc b/modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.cc new file mode 100644 index 0000000000..3d62bcef44 --- /dev/null +++ b/modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.h" + +#include + +#include +#include + +namespace webrtc { + +Av1Obu::Av1Obu(uint8_t obu_type) : header_(obu_type | kAv1ObuSizePresentBit) {} + +Av1Obu& Av1Obu::WithExtension(uint8_t extension) { + extension_ = extension; + header_ |= kAv1ObuExtensionPresentBit; + return *this; +} +Av1Obu& Av1Obu::WithoutSize() { + header_ &= ~kAv1ObuSizePresentBit; + return *this; +} +Av1Obu& Av1Obu::WithPayload(std::vector payload) { + payload_ = std::move(payload); + return *this; +} + +std::vector BuildAv1Frame(std::initializer_list obus) { + std::vector raw; + for (const Av1Obu& obu : obus) { + raw.push_back(obu.header_); + if (obu.header_ & kAv1ObuExtensionPresentBit) { + raw.push_back(obu.extension_); + } + if (obu.header_ & kAv1ObuSizePresentBit) { + // write size in leb128 format. + size_t payload_size = obu.payload_.size(); + while (payload_size >= 0x80) { + raw.push_back(0x80 | (payload_size & 0x7F)); + payload_size >>= 7; + } + raw.push_back(payload_size); + } + raw.insert(raw.end(), obu.payload_.begin(), obu.payload_.end()); + } + return raw; +} + +} // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.h b/modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.h new file mode 100644 index 0000000000..04a902fe56 --- /dev/null +++ b/modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_RTP_RTCP_SOURCE_RTP_PACKETIZER_AV1_TEST_HELPER_H_ +#define MODULES_RTP_RTCP_SOURCE_RTP_PACKETIZER_AV1_TEST_HELPER_H_ + +#include + +#include +#include +#include + +namespace webrtc { +// All obu types offset by 3 to take correct position in the obu_header. 
+constexpr uint8_t kAv1ObuTypeSequenceHeader = 1 << 3; +constexpr uint8_t kAv1ObuTypeTemporalDelimiter = 2 << 3; +constexpr uint8_t kAv1ObuTypeFrameHeader = 3 << 3; +constexpr uint8_t kAv1ObuTypeTileGroup = 4 << 3; +constexpr uint8_t kAv1ObuTypeMetadata = 5 << 3; +constexpr uint8_t kAv1ObuTypeFrame = 6 << 3; +constexpr uint8_t kAv1ObuTypeTileList = 8 << 3; +constexpr uint8_t kAv1ObuExtensionPresentBit = 0b0'0000'100; +constexpr uint8_t kAv1ObuSizePresentBit = 0b0'0000'010; +constexpr uint8_t kAv1ObuExtensionS1T1 = 0b001'01'000; + +class Av1Obu { + public: + explicit Av1Obu(uint8_t obu_type); + + Av1Obu& WithExtension(uint8_t extension); + Av1Obu& WithoutSize(); + Av1Obu& WithPayload(std::vector payload); + + private: + friend std::vector BuildAv1Frame(std::initializer_list obus); + uint8_t header_; + uint8_t extension_ = 0; + std::vector payload_; +}; + +std::vector BuildAv1Frame(std::initializer_list obus); + +} // namespace webrtc +#endif // MODULES_RTP_RTCP_SOURCE_RTP_PACKETIZER_AV1_TEST_HELPER_H_ diff --git a/modules/rtp_rtcp/source/rtp_packetizer_av1_unittest.cc b/modules/rtp_rtcp/source/rtp_packetizer_av1_unittest.cc index 84d2b35bc6..2151a59295 100644 --- a/modules/rtp_rtcp/source/rtp_packetizer_av1_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_packetizer_av1_unittest.cc @@ -21,6 +21,7 @@ #include "api/scoped_refptr.h" #include "api/video/encoded_image.h" #include "modules/rtp_rtcp/source/rtp_packet_to_send.h" +#include "modules/rtp_rtcp/source/rtp_packetizer_av1_test_helper.h" #include "modules/rtp_rtcp/source/video_rtp_depacketizer_av1.h" #include "test/gmock.h" #include "test/gtest.h" @@ -35,17 +36,6 @@ using ::testing::Le; using ::testing::SizeIs; constexpr uint8_t kNewCodedVideoSequenceBit = 0b00'00'1000; -// All obu types offset by 3 to take correct position in the obu_header. -constexpr uint8_t kObuTypeSequenceHeader = 1 << 3; -constexpr uint8_t kObuTypeTemporalDelimiter = 2 << 3; -constexpr uint8_t kObuTypeFrameHeader = 3 << 3; -constexpr uint8_t kObuTypeTileGroup = 4 << 3; -constexpr uint8_t kObuTypeMetadata = 5 << 3; -constexpr uint8_t kObuTypeFrame = 6 << 3; -constexpr uint8_t kObuTypeTileList = 8 << 3; -constexpr uint8_t kObuExtensionPresentBit = 0b0'0000'100; -constexpr uint8_t kObuSizePresentBit = 0b0'0000'010; -constexpr uint8_t kObuExtensionS1T1 = 0b001'01'000; // Wrapper around rtp_packet to make it look like container of payload bytes. struct RtpPayload { @@ -109,135 +99,90 @@ Av1Frame ReassembleFrame(rtc::ArrayView rtp_payloads) { return Av1Frame(VideoRtpDepacketizerAv1().AssembleFrame(payloads)); } -class Obu { - public: - explicit Obu(uint8_t obu_type) : header_(obu_type | kObuSizePresentBit) { - EXPECT_EQ(obu_type & 0b0'1111'000, obu_type); - } - - Obu& WithExtension(uint8_t extension) { - extension_ = extension; - header_ |= kObuExtensionPresentBit; - return *this; - } - Obu& WithoutSize() { - header_ &= ~kObuSizePresentBit; - return *this; - } - Obu& WithPayload(std::vector payload) { - payload_ = std::move(payload); - return *this; - } - - private: - friend std::vector BuildAv1Frame(std::initializer_list obus); - uint8_t header_; - uint8_t extension_ = 0; - std::vector payload_; -}; - -std::vector BuildAv1Frame(std::initializer_list obus) { - std::vector raw; - for (const Obu& obu : obus) { - raw.push_back(obu.header_); - if (obu.header_ & kObuExtensionPresentBit) { - raw.push_back(obu.extension_); - } - if (obu.header_ & kObuSizePresentBit) { - // write size in leb128 format. 
- size_t payload_size = obu.payload_.size(); - while (payload_size >= 0x80) { - raw.push_back(0x80 | (payload_size & 0x7F)); - payload_size >>= 7; - } - raw.push_back(payload_size); - } - raw.insert(raw.end(), obu.payload_.begin(), obu.payload_.end()); - } - return raw; -} - TEST(RtpPacketizerAv1Test, PacketizeOneObuWithoutSizeAndExtension) { - auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithoutSize().WithPayload({1, 2, 3, 4, 5, 6, 7})}); + auto kFrame = BuildAv1Frame({Av1Obu(kAv1ObuTypeFrame) + .WithoutSize() + .WithPayload({1, 2, 3, 4, 5, 6, 7})}); EXPECT_THAT(Packetize(kFrame, {}), ElementsAre(ElementsAre(0b00'01'0000, // aggregation header - kObuTypeFrame, 1, 2, 3, 4, 5, 6, 7))); + kAv1ObuTypeFrame, 1, 2, 3, 4, 5, 6, 7))); } TEST(RtpPacketizerAv1Test, PacketizeOneObuWithoutSizeWithExtension) { - auto kFrame = BuildAv1Frame({Obu(kObuTypeFrame) + auto kFrame = BuildAv1Frame({Av1Obu(kAv1ObuTypeFrame) .WithoutSize() - .WithExtension(kObuExtensionS1T1) + .WithExtension(kAv1ObuExtensionS1T1) .WithPayload({2, 3, 4, 5, 6, 7})}); - EXPECT_THAT(Packetize(kFrame, {}), - ElementsAre(ElementsAre(0b00'01'0000, // aggregation header - kObuTypeFrame | kObuExtensionPresentBit, - kObuExtensionS1T1, 2, 3, 4, 5, 6, 7))); + EXPECT_THAT( + Packetize(kFrame, {}), + ElementsAre(ElementsAre(0b00'01'0000, // aggregation header + kAv1ObuTypeFrame | kAv1ObuExtensionPresentBit, + kAv1ObuExtensionS1T1, 2, 3, 4, 5, 6, 7))); } TEST(RtpPacketizerAv1Test, RemovesObuSizeFieldWithoutExtension) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithPayload({11, 12, 13, 14, 15, 16, 17})}); + {Av1Obu(kAv1ObuTypeFrame).WithPayload({11, 12, 13, 14, 15, 16, 17})}); EXPECT_THAT( Packetize(kFrame, {}), ElementsAre(ElementsAre(0b00'01'0000, // aggregation header - kObuTypeFrame, 11, 12, 13, 14, 15, 16, 17))); + kAv1ObuTypeFrame, 11, 12, 13, 14, 15, 16, 17))); } TEST(RtpPacketizerAv1Test, RemovesObuSizeFieldWithExtension) { - auto kFrame = BuildAv1Frame({Obu(kObuTypeFrame) - .WithExtension(kObuExtensionS1T1) + auto kFrame = BuildAv1Frame({Av1Obu(kAv1ObuTypeFrame) + .WithExtension(kAv1ObuExtensionS1T1) .WithPayload({1, 2, 3, 4, 5, 6, 7})}); - EXPECT_THAT(Packetize(kFrame, {}), - ElementsAre(ElementsAre(0b00'01'0000, // aggregation header - kObuTypeFrame | kObuExtensionPresentBit, - kObuExtensionS1T1, 1, 2, 3, 4, 5, 6, 7))); + EXPECT_THAT( + Packetize(kFrame, {}), + ElementsAre(ElementsAre(0b00'01'0000, // aggregation header + kAv1ObuTypeFrame | kAv1ObuExtensionPresentBit, + kAv1ObuExtensionS1T1, 1, 2, 3, 4, 5, 6, 7))); } TEST(RtpPacketizerAv1Test, OmitsSizeForLastObuWhenThreeObusFitsIntoThePacket) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6}), - Obu(kObuTypeMetadata).WithPayload({11, 12, 13, 14}), - Obu(kObuTypeFrame).WithPayload({21, 22, 23, 24, 25, 26})}); - EXPECT_THAT( - Packetize(kFrame, {}), - ElementsAre(ElementsAre(0b00'11'0000, // aggregation header - 7, kObuTypeSequenceHeader, 1, 2, 3, 4, 5, 6, // - 5, kObuTypeMetadata, 11, 12, 13, 14, // - kObuTypeFrame, 21, 22, 23, 24, 25, 26))); + {Av1Obu(kAv1ObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6}), + Av1Obu(kAv1ObuTypeMetadata).WithPayload({11, 12, 13, 14}), + Av1Obu(kAv1ObuTypeFrame).WithPayload({21, 22, 23, 24, 25, 26})}); + EXPECT_THAT(Packetize(kFrame, {}), + ElementsAre(ElementsAre( + 0b00'11'0000, // aggregation header + 7, kAv1ObuTypeSequenceHeader, 1, 2, 3, 4, 5, 6, // + 5, kAv1ObuTypeMetadata, 11, 12, 13, 14, // + kAv1ObuTypeFrame, 21, 22, 23, 24, 25, 26))); } TEST(RtpPacketizerAv1Test, 
UseSizeForAllObusWhenFourObusFitsIntoThePacket) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6}), - Obu(kObuTypeMetadata).WithPayload({11, 12, 13, 14}), - Obu(kObuTypeFrameHeader).WithPayload({21, 22, 23}), - Obu(kObuTypeTileGroup).WithPayload({31, 32, 33, 34, 35, 36})}); - EXPECT_THAT( - Packetize(kFrame, {}), - ElementsAre(ElementsAre(0b00'00'0000, // aggregation header - 7, kObuTypeSequenceHeader, 1, 2, 3, 4, 5, 6, // - 5, kObuTypeMetadata, 11, 12, 13, 14, // - 4, kObuTypeFrameHeader, 21, 22, 23, // - 7, kObuTypeTileGroup, 31, 32, 33, 34, 35, 36))); + {Av1Obu(kAv1ObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6}), + Av1Obu(kAv1ObuTypeMetadata).WithPayload({11, 12, 13, 14}), + Av1Obu(kAv1ObuTypeFrameHeader).WithPayload({21, 22, 23}), + Av1Obu(kAv1ObuTypeTileGroup).WithPayload({31, 32, 33, 34, 35, 36})}); + EXPECT_THAT(Packetize(kFrame, {}), + ElementsAre(ElementsAre( + 0b00'00'0000, // aggregation header + 7, kAv1ObuTypeSequenceHeader, 1, 2, 3, 4, 5, 6, // + 5, kAv1ObuTypeMetadata, 11, 12, 13, 14, // + 4, kAv1ObuTypeFrameHeader, 21, 22, 23, // + 7, kAv1ObuTypeTileGroup, 31, 32, 33, 34, 35, 36))); } TEST(RtpPacketizerAv1Test, DiscardsTemporalDelimiterAndTileListObu) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeTemporalDelimiter), Obu(kObuTypeMetadata), - Obu(kObuTypeTileList).WithPayload({1, 2, 3, 4, 5, 6}), - Obu(kObuTypeFrameHeader).WithPayload({21, 22, 23}), - Obu(kObuTypeTileGroup).WithPayload({31, 32, 33, 34, 35, 36})}); + {Av1Obu(kAv1ObuTypeTemporalDelimiter), Av1Obu(kAv1ObuTypeMetadata), + Av1Obu(kAv1ObuTypeTileList).WithPayload({1, 2, 3, 4, 5, 6}), + Av1Obu(kAv1ObuTypeFrameHeader).WithPayload({21, 22, 23}), + Av1Obu(kAv1ObuTypeTileGroup).WithPayload({31, 32, 33, 34, 35, 36})}); EXPECT_THAT( Packetize(kFrame, {}), ElementsAre(ElementsAre(0b00'11'0000, // aggregation header 1, - kObuTypeMetadata, // - 4, kObuTypeFrameHeader, 21, 22, + kAv1ObuTypeMetadata, // + 4, kAv1ObuTypeFrameHeader, 21, 22, 23, // - kObuTypeTileGroup, 31, 32, 33, 34, 35, 36))); + kAv1ObuTypeTileGroup, 31, 32, 33, 34, 35, 36))); } TEST(RtpPacketizerAv1Test, SplitTwoObusIntoTwoPacketForceSplitObuHeader) { @@ -246,17 +191,17 @@ TEST(RtpPacketizerAv1Test, SplitTwoObusIntoTwoPacketForceSplitObuHeader) { const uint8_t kExpectPayload1[6] = { 0b01'10'0000, // aggregation_header 3, - kObuTypeFrameHeader | kObuExtensionPresentBit, - kObuExtensionS1T1, + kAv1ObuTypeFrameHeader | kAv1ObuExtensionPresentBit, + kAv1ObuExtensionS1T1, 21, // - kObuTypeTileGroup | kObuExtensionPresentBit}; + kAv1ObuTypeTileGroup | kAv1ObuExtensionPresentBit}; const uint8_t kExpectPayload2[6] = {0b10'01'0000, // aggregation_header - kObuExtensionS1T1, 11, 12, 13, 14}; - auto kFrame = BuildAv1Frame({Obu(kObuTypeFrameHeader) - .WithExtension(kObuExtensionS1T1) + kAv1ObuExtensionS1T1, 11, 12, 13, 14}; + auto kFrame = BuildAv1Frame({Av1Obu(kAv1ObuTypeFrameHeader) + .WithExtension(kAv1ObuExtensionS1T1) .WithPayload({21}), - Obu(kObuTypeTileGroup) - .WithExtension(kObuExtensionS1T1) + Av1Obu(kAv1ObuTypeTileGroup) + .WithExtension(kAv1ObuExtensionS1T1) .WithPayload({11, 12, 13, 14})}); RtpPacketizer::PayloadSizeLimits limits; @@ -269,7 +214,7 @@ TEST(RtpPacketizerAv1Test, SplitTwoObusIntoTwoPacketForceSplitObuHeader) { TEST(RtpPacketizerAv1Test, SetsNbitAtTheFirstPacketOfAKeyFrameWithSequenceHeader) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6, 7})}); + {Av1Obu(kAv1ObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6, 7})}); RtpPacketizer::PayloadSizeLimits 
limits; limits.max_payload_len = 6; auto packets = Packetize(kFrame, limits, VideoFrameType::kVideoFrameKey); @@ -280,8 +225,8 @@ TEST(RtpPacketizerAv1Test, TEST(RtpPacketizerAv1Test, DoesntSetNbitAtThePacketsOfAKeyFrameWithoutSequenceHeader) { - auto kFrame = - BuildAv1Frame({Obu(kObuTypeFrame).WithPayload({1, 2, 3, 4, 5, 6, 7})}); + auto kFrame = BuildAv1Frame( + {Av1Obu(kAv1ObuTypeFrame).WithPayload({1, 2, 3, 4, 5, 6, 7})}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 6; auto packets = Packetize(kFrame, limits, VideoFrameType::kVideoFrameKey); @@ -293,7 +238,7 @@ TEST(RtpPacketizerAv1Test, TEST(RtpPacketizerAv1Test, DoesntSetNbitAtThePacketsOfADeltaFrame) { // Even when that delta frame starts with a (redundant) sequence header. auto kFrame = BuildAv1Frame( - {Obu(kObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6, 7})}); + {Av1Obu(kAv1ObuTypeSequenceHeader).WithPayload({1, 2, 3, 4, 5, 6, 7})}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 6; auto packets = Packetize(kFrame, limits, VideoFrameType::kVideoFrameDelta); @@ -308,8 +253,9 @@ TEST(RtpPacketizerAv1Test, DoesntSetNbitAtThePacketsOfADeltaFrame) { // RtpDepacketizer always inserts obu_size fields in the output, use frame where // each obu has obu_size fields for more streight forward validation. TEST(RtpPacketizerAv1Test, SplitSingleObuIntoTwoPackets) { - auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithPayload({11, 12, 13, 14, 15, 16, 17, 18, 19})}); + auto kFrame = + BuildAv1Frame({Av1Obu(kAv1ObuTypeFrame) + .WithPayload({11, 12, 13, 14, 15, 16, 17, 18, 19})}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 8; @@ -322,7 +268,7 @@ TEST(RtpPacketizerAv1Test, SplitSingleObuIntoTwoPackets) { TEST(RtpPacketizerAv1Test, SplitSingleObuIntoManyPackets) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithPayload(std::vector(1200, 27))}); + {Av1Obu(kAv1ObuTypeFrame).WithPayload(std::vector(1200, 27))}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 100; @@ -336,7 +282,7 @@ TEST(RtpPacketizerAv1Test, SplitSingleObuIntoManyPackets) { TEST(RtpPacketizerAv1Test, SetMarkerBitForLastPacketInEndOfPictureFrame) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithPayload(std::vector(200, 27))}); + {Av1Obu(kAv1ObuTypeFrame).WithPayload(std::vector(200, 27))}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 100; @@ -350,7 +296,7 @@ TEST(RtpPacketizerAv1Test, SetMarkerBitForLastPacketInEndOfPictureFrame) { TEST(RtpPacketizerAv1Test, DoesntSetMarkerBitForPacketsNotInEndOfPictureFrame) { auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithPayload(std::vector(200, 27))}); + {Av1Obu(kAv1ObuTypeFrame).WithPayload(std::vector(200, 27))}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 100; @@ -366,8 +312,8 @@ TEST(RtpPacketizerAv1Test, SplitTwoObusIntoTwoPackets) { // 2nd OBU is too large to fit into one packet, so its head would be in the // same packet as the 1st OBU. 
auto kFrame = BuildAv1Frame( - {Obu(kObuTypeSequenceHeader).WithPayload({11, 12}), - Obu(kObuTypeFrame).WithPayload({1, 2, 3, 4, 5, 6, 7, 8, 9})}); + {Av1Obu(kAv1ObuTypeSequenceHeader).WithPayload({11, 12}), + Av1Obu(kAv1ObuTypeFrame).WithPayload({1, 2, 3, 4, 5, 6, 7, 8, 9})}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 8; @@ -380,8 +326,9 @@ TEST(RtpPacketizerAv1Test, SplitTwoObusIntoTwoPackets) { TEST(RtpPacketizerAv1Test, SplitSingleObuIntoTwoPacketsBecauseOfSinglePacketLimit) { - auto kFrame = BuildAv1Frame( - {Obu(kObuTypeFrame).WithPayload({11, 12, 13, 14, 15, 16, 17, 18, 19})}); + auto kFrame = + BuildAv1Frame({Av1Obu(kAv1ObuTypeFrame) + .WithPayload({11, 12, 13, 14, 15, 16, 17, 18, 19})}); RtpPacketizer::PayloadSizeLimits limits; limits.max_payload_len = 10; limits.single_packet_reduction_len = 8; diff --git a/modules/rtp_rtcp/source/rtp_rtcp_config.h b/modules/rtp_rtcp/source/rtp_rtcp_config.h index 6863c4c353..66caadd578 100644 --- a/modules/rtp_rtcp/source/rtp_rtcp_config.h +++ b/modules/rtp_rtcp/source/rtp_rtcp_config.h @@ -11,13 +11,15 @@ #ifndef MODULES_RTP_RTCP_SOURCE_RTP_RTCP_CONFIG_H_ #define MODULES_RTP_RTCP_SOURCE_RTP_RTCP_CONFIG_H_ +#include "api/units/time_delta.h" + // Configuration file for RTP utilities (RTPSender, RTPReceiver ...) namespace webrtc { -enum { kDefaultMaxReorderingThreshold = 50 }; // In sequence numbers. -enum { kRtcpMaxNackFields = 253 }; +constexpr int kDefaultMaxReorderingThreshold = 50; // In sequence numbers. +constexpr int kRtcpMaxNackFields = 253; -enum { RTCP_SEND_BEFORE_KEY_FRAME_MS = 100 }; -enum { RTCP_MAX_REPORT_BLOCKS = 31 }; // RFC 3550 page 37 +constexpr TimeDelta RTCP_SEND_BEFORE_KEY_FRAME = TimeDelta::Millis(100); +constexpr int RTCP_MAX_REPORT_BLOCKS = 31; // RFC 3550 page 37 } // namespace webrtc #endif // MODULES_RTP_RTCP_SOURCE_RTP_RTCP_CONFIG_H_ diff --git a/modules/rtp_rtcp/source/rtp_rtcp_impl.cc b/modules/rtp_rtcp/source/rtp_rtcp_impl.cc index 69a64fe3f6..3f985e213a 100644 --- a/modules/rtp_rtcp/source/rtp_rtcp_impl.cc +++ b/modules/rtp_rtcp/source/rtp_rtcp_impl.cc @@ -21,9 +21,12 @@ #include "api/transport/field_trial_based_config.h" #include "modules/rtp_rtcp/source/rtcp_packet/dlrr.h" +#include "modules/rtp_rtcp/source/rtcp_sender.h" #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" +#include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "system_wrappers/include/ntp_time.h" #ifdef _WIN32 // Disable warning C4355: 'this' : used in base member initializer list. @@ -57,7 +60,8 @@ std::unique_ptr<RtpRtcp> RtpRtcp::DEPRECATED_Create( } ModuleRtpRtcpImpl::ModuleRtpRtcpImpl(const Configuration& configuration) - : rtcp_sender_(configuration), + : rtcp_sender_( + RTCPSender::Configuration::FromRtpRtcpConfiguration(configuration)), rtcp_receiver_(configuration, this), clock_(configuration.clock), last_bitrate_process_time_(clock_->TimeInMilliseconds()), @@ -122,20 +126,18 @@ void ModuleRtpRtcpImpl::Process() { // processed RTT for at least |kRtpRtcpRttProcessTimeMs| milliseconds. // Note that LastReceivedReportBlockMs() grabs a lock, so check // |process_rtt| first.
- if (process_rtt && + if (process_rtt && rtt_stats_ != nullptr && rtcp_receiver_.LastReceivedReportBlockMs() > last_rtt_process_time_) { - std::vector receive_blocks; - rtcp_receiver_.StatisticsReceived(&receive_blocks); - int64_t max_rtt = 0; - for (std::vector::iterator it = receive_blocks.begin(); - it != receive_blocks.end(); ++it) { - int64_t rtt = 0; - rtcp_receiver_.RTT(it->sender_ssrc, &rtt, NULL, NULL, NULL); - max_rtt = (rtt > max_rtt) ? rtt : max_rtt; + int64_t max_rtt_ms = 0; + for (const auto& block : rtcp_receiver_.GetLatestReportBlockData()) { + if (block.last_rtt_ms() > max_rtt_ms) { + max_rtt_ms = block.last_rtt_ms(); + } } // Report the rtt. - if (rtt_stats_ && max_rtt != 0) - rtt_stats_->OnRttUpdate(max_rtt); + if (max_rtt_ms > 0) { + rtt_stats_->OnRttUpdate(max_rtt_ms); + } } // Verify receiver reports are delivered and the reported sequence number @@ -192,7 +194,7 @@ void ModuleRtpRtcpImpl::Process() { if (rtcp_sender_.TimeToSendRTCPReport()) rtcp_sender_.SendRTCP(GetFeedbackState(), kRtcpReport); - if (TMMBR() && rtcp_receiver_.UpdateTmmbrTimers()) { + if (rtcp_sender_.TMMBR() && rtcp_receiver_.UpdateTmmbrTimers()) { rtcp_receiver_.NotifyTmmbrUpdated(); } } @@ -312,8 +314,19 @@ RTCPSender::FeedbackState ModuleRtpRtcpImpl::GetFeedbackState() { } state.receiver = &rtcp_receiver_; - LastReceivedNTP(&state.last_rr_ntp_secs, &state.last_rr_ntp_frac, - &state.remote_sr); + uint32_t received_ntp_secs = 0; + uint32_t received_ntp_frac = 0; + state.remote_sr = 0; + if (rtcp_receiver_.NTP(&received_ntp_secs, &received_ntp_frac, + /*rtcp_arrival_time_secs=*/&state.last_rr_ntp_secs, + /*rtcp_arrival_time_frac=*/&state.last_rr_ntp_frac, + /*rtcp_timestamp=*/nullptr, + /*remote_sender_packet_count=*/nullptr, + /*remote_sender_octet_count=*/nullptr, + /*remote_sender_reports_count=*/nullptr)) { + state.remote_sr = ((received_ntp_secs & 0x0000ffff) << 16) + + ((received_ntp_frac & 0xffff0000) >> 16); + } state.last_xr_rtis = rtcp_receiver_.ConsumeReceivedXrReferenceTimeInfo(); @@ -326,9 +339,7 @@ RTCPSender::FeedbackState ModuleRtpRtcpImpl::GetFeedbackState() { int32_t ModuleRtpRtcpImpl::SetSendingStatus(const bool sending) { if (rtcp_sender_.Sending() != sending) { // Sends RTCP BYE when going from true to false - if (rtcp_sender_.SetSendingStatus(GetFeedbackState(), sending) != 0) { - RTC_LOG(LS_WARNING) << "Failed to send RTCP BYE"; - } + rtcp_sender_.SetSendingStatus(GetFeedbackState(), sending); } return 0; } @@ -370,7 +381,16 @@ bool ModuleRtpRtcpImpl::OnSendingRtpFrame(uint32_t timestamp, if (!Sending()) return false; - rtcp_sender_.SetLastRtpTime(timestamp, capture_time_ms, payload_type); + // TODO(bugs.webrtc.org/12873): Migrate this method and it's users to use + // optional Timestamps. + absl::optional capture_time; + if (capture_time_ms > 0) { + capture_time = Timestamp::Millis(capture_time_ms); + } + absl::optional payload_type_optional; + if (payload_type >= 0) + payload_type_optional = payload_type; + rtcp_sender_.SetLastRtpTime(timestamp, capture_time, payload_type_optional); // Make sure an RTCP report isn't queued behind a key frame. 
if (rtcp_sender_.TimeToSendRTCPReport(force_sender_report)) rtcp_sender_.SendRTCP(GetFeedbackState(), kRtcpReport); @@ -467,19 +487,6 @@ int32_t ModuleRtpRtcpImpl::SetCNAME(const char* c_name) { return rtcp_sender_.SetCNAME(c_name); } -int32_t ModuleRtpRtcpImpl::AddMixedCNAME(uint32_t ssrc, const char* c_name) { - return rtcp_sender_.AddMixedCNAME(ssrc, c_name); -} - -int32_t ModuleRtpRtcpImpl::RemoveMixedCNAME(const uint32_t ssrc) { - return rtcp_sender_.RemoveMixedCNAME(ssrc); -} - -int32_t ModuleRtpRtcpImpl::RemoteCNAME(const uint32_t remote_ssrc, - char c_name[RTCP_CNAME_SIZE]) const { - return rtcp_receiver_.CNAME(remote_ssrc, c_name); -} - int32_t ModuleRtpRtcpImpl::RemoteNTP(uint32_t* received_ntpsecs, uint32_t* received_ntpfrac, uint32_t* rtcp_arrival_time_secs, @@ -487,7 +494,10 @@ int32_t ModuleRtpRtcpImpl::RemoteNTP(uint32_t* received_ntpsecs, uint32_t* rtcp_timestamp) const { return rtcp_receiver_.NTP(received_ntpsecs, received_ntpfrac, rtcp_arrival_time_secs, rtcp_arrival_time_frac, - rtcp_timestamp) + rtcp_timestamp, + /*remote_sender_packet_count=*/nullptr, + /*remote_sender_octet_count=*/nullptr, + /*remote_sender_reports_count=*/nullptr) ? 0 : -1; } @@ -527,39 +537,6 @@ int32_t ModuleRtpRtcpImpl::SendRTCP(RTCPPacketType packet_type) { return rtcp_sender_.SendRTCP(GetFeedbackState(), packet_type); } -int32_t ModuleRtpRtcpImpl::SetRTCPApplicationSpecificData( - const uint8_t sub_type, - const uint32_t name, - const uint8_t* data, - const uint16_t length) { - RTC_NOTREACHED() << "Not implemented"; - return -1; -} - -// TODO(asapersson): Replace this method with the one below. -int32_t ModuleRtpRtcpImpl::DataCountersRTP(size_t* bytes_sent, - uint32_t* packets_sent) const { - StreamDataCounters rtp_stats; - StreamDataCounters rtx_stats; - rtp_sender_->packet_sender.GetDataCounters(&rtp_stats, &rtx_stats); - - if (bytes_sent) { - // TODO(http://crbug.com/webrtc/10525): Bytes sent should only include - // payload bytes, not header and padding bytes. - *bytes_sent = rtp_stats.transmitted.payload_bytes + - rtp_stats.transmitted.padding_bytes + - rtp_stats.transmitted.header_bytes + - rtx_stats.transmitted.payload_bytes + - rtx_stats.transmitted.padding_bytes + - rtx_stats.transmitted.header_bytes; - } - if (packets_sent) { - *packets_sent = - rtp_stats.transmitted.packets + rtx_stats.transmitted.packets; - } - return 0; -} - void ModuleRtpRtcpImpl::GetSendStreamDataCounters( StreamDataCounters* rtp_counters, StreamDataCounters* rtx_counters) const { @@ -567,16 +544,31 @@ void ModuleRtpRtcpImpl::GetSendStreamDataCounters( } // Received RTCP report. 
-int32_t ModuleRtpRtcpImpl::RemoteRTCPStat( - std::vector* receive_blocks) const { - return rtcp_receiver_.StatisticsReceived(receive_blocks); -} - std::vector ModuleRtpRtcpImpl::GetLatestReportBlockData() const { return rtcp_receiver_.GetLatestReportBlockData(); } +absl::optional +ModuleRtpRtcpImpl::GetSenderReportStats() const { + SenderReportStats stats; + uint32_t remote_timestamp_secs; + uint32_t remote_timestamp_frac; + uint32_t arrival_timestamp_secs; + uint32_t arrival_timestamp_frac; + if (rtcp_receiver_.NTP(&remote_timestamp_secs, &remote_timestamp_frac, + &arrival_timestamp_secs, &arrival_timestamp_frac, + /*rtcp_timestamp=*/nullptr, &stats.packets_sent, + &stats.bytes_sent, &stats.reports_count)) { + stats.last_remote_timestamp.Set(remote_timestamp_secs, + remote_timestamp_frac); + stats.last_arrival_timestamp.Set(arrival_timestamp_secs, + arrival_timestamp_frac); + return stats; + } + return absl::nullopt; +} + // (REMB) Receiver Estimated Max Bitrate. void ModuleRtpRtcpImpl::SetRemb(int64_t bitrate_bps, std::vector ssrcs) { @@ -591,12 +583,6 @@ void ModuleRtpRtcpImpl::SetExtmapAllowMixed(bool extmap_allow_mixed) { rtp_sender_->packet_generator.SetExtmapAllowMixed(extmap_allow_mixed); } -int32_t ModuleRtpRtcpImpl::RegisterSendRtpHeaderExtension( - const RTPExtensionType type, - const uint8_t id) { - return rtp_sender_->packet_generator.RegisterRtpHeaderExtension(type, id); -} - void ModuleRtpRtcpImpl::RegisterRtpHeaderExtension(absl::string_view uri, int id) { bool registered = @@ -613,15 +599,6 @@ void ModuleRtpRtcpImpl::DeregisterSendRtpHeaderExtension( rtp_sender_->packet_generator.DeregisterRtpHeaderExtension(uri); } -// (TMMBR) Temporary Max Media Bit Rate. -bool ModuleRtpRtcpImpl::TMMBR() const { - return rtcp_sender_.TMMBR(); -} - -void ModuleRtpRtcpImpl::SetTMMBRStatus(const bool enable) { - rtcp_sender_.SetTMMBRStatus(enable); -} - void ModuleRtpRtcpImpl::SetTmmbn(std::vector bounding_set) { rtcp_sender_.SetTmmbn(std::move(bounding_set)); } @@ -718,6 +695,11 @@ void ModuleRtpRtcpImpl::SetRemoteSSRC(const uint32_t ssrc) { rtcp_receiver_.SetRemoteSSRC(ssrc); } +void ModuleRtpRtcpImpl::SetLocalSsrc(uint32_t local_ssrc) { + rtcp_receiver_.set_local_media_ssrc(local_ssrc); + rtcp_sender_.SetSsrc(local_ssrc); +} + RtpSendRates ModuleRtpRtcpImpl::GetSendRates() const { return rtp_sender_->packet_sender.GetSendRates(); } @@ -763,23 +745,6 @@ void ModuleRtpRtcpImpl::OnReceivedRtcpReportBlocks( } } -bool ModuleRtpRtcpImpl::LastReceivedNTP( - uint32_t* rtcp_arrival_time_secs, // When we got the last report. - uint32_t* rtcp_arrival_time_frac, - uint32_t* remote_sr) const { - // Remote SR: NTP inside the last received (mid 16 bits from sec and frac). - uint32_t ntp_secs = 0; - uint32_t ntp_frac = 0; - - if (!rtcp_receiver_.NTP(&ntp_secs, &ntp_frac, rtcp_arrival_time_secs, - rtcp_arrival_time_frac, NULL)) { - return false; - } - *remote_sr = - ((ntp_secs & 0x0000ffff) << 16) + ((ntp_frac & 0xffff0000) >> 16); - return true; -} - void ModuleRtpRtcpImpl::set_rtt_ms(int64_t rtt_ms) { { MutexLock lock(&mutex_rtt_); diff --git a/modules/rtp_rtcp/source/rtp_rtcp_impl.h b/modules/rtp_rtcp/source/rtp_rtcp_impl.h index e30f1cc3d0..b0e0b41c48 100644 --- a/modules/rtp_rtcp/source/rtp_rtcp_impl.h +++ b/modules/rtp_rtcp/source/rtp_rtcp_impl.h @@ -63,6 +63,7 @@ class ModuleRtpRtcpImpl : public RtpRtcp, public RTCPReceiver::ModuleRtpRtcp { size_t incoming_packet_length) override; void SetRemoteSSRC(uint32_t ssrc) override; + void SetLocalSsrc(uint32_t ssrc) override; // Sender part. 
void RegisterSendPayloadFrequency(int payload_type, @@ -73,8 +74,6 @@ class ModuleRtpRtcpImpl : public RtpRtcp, public RTCPReceiver::ModuleRtpRtcp { void SetExtmapAllowMixed(bool extmap_allow_mixed) override; // Register RTP header extension. - int32_t RegisterSendRtpHeaderExtension(RTPExtensionType type, - uint8_t id) override; void RegisterRtpHeaderExtension(absl::string_view uri, int id) override; int32_t DeregisterSendRtpHeaderExtension(RTPExtensionType type) override; void DeregisterSendRtpHeaderExtension(absl::string_view uri) override; @@ -166,10 +165,6 @@ class ModuleRtpRtcpImpl : public RtpRtcp, public RTCPReceiver::ModuleRtpRtcp { // Set RTCP CName. int32_t SetCNAME(const char* c_name) override; - // Get remote CName. - int32_t RemoteCNAME(uint32_t remote_ssrc, - char c_name[RTCP_CNAME_SIZE]) const override; - // Get remote NTP. int32_t RemoteNTP(uint32_t* received_ntp_secs, uint32_t* received_ntp_frac, @@ -177,10 +172,6 @@ class ModuleRtpRtcpImpl : public RtpRtcp, public RTCPReceiver::ModuleRtpRtcp { uint32_t* rtcp_arrival_time_frac, uint32_t* rtcp_timestamp) const override; - int32_t AddMixedCNAME(uint32_t ssrc, const char* c_name) override; - - int32_t RemoveMixedCNAME(uint32_t ssrc) override; - // Get RoundTripTime. int32_t RTT(uint32_t remote_ssrc, int64_t* rtt, @@ -194,32 +185,21 @@ class ModuleRtpRtcpImpl : public RtpRtcp, public RTCPReceiver::ModuleRtpRtcp { // Normal SR and RR are triggered via the process function. int32_t SendRTCP(RTCPPacketType rtcpPacketType) override; - // Statistics of the amount of data sent and received. - int32_t DataCountersRTP(size_t* bytes_sent, - uint32_t* packets_sent) const override; - void GetSendStreamDataCounters( StreamDataCounters* rtp_counters, StreamDataCounters* rtx_counters) const override; - // Get received RTCP report, report block. - int32_t RemoteRTCPStat( - std::vector* receive_blocks) const override; // A snapshot of the most recent Report Block with additional data of // interest to statistics. Used to implement RTCRemoteInboundRtpStreamStats. // Within this list, the ReportBlockData::RTCPReportBlock::source_ssrc(), // which is the SSRC of the corresponding outbound RTP stream, is unique. std::vector GetLatestReportBlockData() const override; + absl::optional GetSenderReportStats() const override; // (REMB) Receiver Estimated Max Bitrate. void SetRemb(int64_t bitrate_bps, std::vector ssrcs) override; void UnsetRemb() override; - // (TMMBR) Temporary Max Media Bit Rate. - bool TMMBR() const override; - - void SetTMMBRStatus(bool enable) override; - void SetTmmbn(std::vector bounding_set) override; size_t MaxRtpPacketSize() const override; @@ -241,22 +221,12 @@ class ModuleRtpRtcpImpl : public RtpRtcp, public RTCPReceiver::ModuleRtpRtcp { void SendCombinedRtcpPacket( std::vector> rtcp_packets) override; - // (APP) Application specific data. - int32_t SetRTCPApplicationSpecificData(uint8_t sub_type, - uint32_t name, - const uint8_t* data, - uint16_t length) override; - // Video part. 
int32_t SendLossNotification(uint16_t last_decoded_seq_num, uint16_t last_received_seq_num, bool decodability_flag, bool buffering_allowed) override; - bool LastReceivedNTP(uint32_t* NTPsecs, - uint32_t* NTPfrac, - uint32_t* remote_sr) const; - RtpSendRates GetSendRates() const override; void OnReceivedNack( diff --git a/modules/rtp_rtcp/source/rtp_rtcp_impl2.cc b/modules/rtp_rtcp/source/rtp_rtcp_impl2.cc index 94dc2977e0..7fae1e3bd0 100644 --- a/modules/rtp_rtcp/source/rtp_rtcp_impl2.cc +++ b/modules/rtp_rtcp/source/rtp_rtcp_impl2.cc @@ -19,11 +19,18 @@ #include #include +#include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/transport/field_trial_based_config.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "modules/rtp_rtcp/source/rtcp_packet/dlrr.h" #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/time_utils.h" +#include "system_wrappers/include/ntp_time.h" #ifdef _WIN32 // Disable warning C4355: 'this' : used in base member initializer list. @@ -32,10 +39,25 @@ namespace webrtc { namespace { -const int64_t kRtpRtcpMaxIdleTimeProcessMs = 5; const int64_t kDefaultExpectedRetransmissionTimeMs = 125; constexpr TimeDelta kRttUpdateInterval = TimeDelta::Millis(1000); + +RTCPSender::Configuration AddRtcpSendEvaluationCallback( + RTCPSender::Configuration config, + std::function send_evaluation_callback) { + config.schedule_next_rtcp_send_evaluation_function = + std::move(send_evaluation_callback); + return config; +} + +int DelayMillisForDuration(TimeDelta duration) { + // TimeDelta::ms() rounds downwards sometimes which leads to too little time + // slept. Account for this, unless |duration| is exactly representable in + // millisecs. + return (duration.us() + rtc::kNumMillisecsPerSec - 1) / + rtc::kNumMicrosecsPerMillisec; +} } // namespace ModuleRtpRtcpImpl2::RtpSenderContext::RtpSenderContext( @@ -54,12 +76,13 @@ void ModuleRtpRtcpImpl2::RtpSenderContext::AssignSequenceNumber( ModuleRtpRtcpImpl2::ModuleRtpRtcpImpl2(const Configuration& configuration) : worker_queue_(TaskQueueBase::Current()), - rtcp_sender_(configuration), + rtcp_sender_(AddRtcpSendEvaluationCallback( + RTCPSender::Configuration::FromRtpRtcpConfiguration(configuration), + [this](TimeDelta duration) { + ScheduleRtcpSendEvaluation(duration); + })), rtcp_receiver_(configuration, this), clock_(configuration.clock), - last_rtt_process_time_(clock_->TimeInMilliseconds()), - next_process_time_(clock_->TimeInMilliseconds() + - kRtpRtcpMaxIdleTimeProcessMs), packet_overhead_(28), // IPV4 UDP. nack_last_time_sent_full_ms_(0), nack_last_seq_number_sent_(0), @@ -67,7 +90,7 @@ ModuleRtpRtcpImpl2::ModuleRtpRtcpImpl2(const Configuration& configuration) rtt_stats_(configuration.rtt_stats), rtt_ms_(0) { RTC_DCHECK(worker_queue_); - process_thread_checker_.Detach(); + packet_sequence_checker_.Detach(); if (!configuration.receiver_only) { rtp_sender_ = std::make_unique(configuration); // Make sure rtcp sender use same timestamp offset as rtp sender. @@ -103,44 +126,6 @@ std::unique_ptr ModuleRtpRtcpImpl2::Create( return std::make_unique(configuration); } -// Returns the number of milliseconds until the module want a worker thread -// to call Process. 
-int64_t ModuleRtpRtcpImpl2::TimeUntilNextProcess() { - RTC_DCHECK_RUN_ON(&process_thread_checker_); - return std::max(0, - next_process_time_ - clock_->TimeInMilliseconds()); -} - -// Process any pending tasks such as timeouts (non time critical events). -void ModuleRtpRtcpImpl2::Process() { - RTC_DCHECK_RUN_ON(&process_thread_checker_); - - const Timestamp now = clock_->CurrentTime(); - - // TODO(bugs.webrtc.org/11581): Figure out why we need to call Process() 200 - // times a second. - next_process_time_ = now.ms() + kRtpRtcpMaxIdleTimeProcessMs; - - // TODO(bugs.webrtc.org/11581): once we don't use Process() to trigger - // calls to SendRTCP(), the only remaining timer will require remote_bitrate_ - // to be not null. In that case, we can disable the timer when it is null. - if (remote_bitrate_ && rtcp_sender_.Sending() && rtcp_sender_.TMMBR()) { - unsigned int target_bitrate = 0; - std::vector ssrcs; - if (remote_bitrate_->LatestEstimate(&ssrcs, &target_bitrate)) { - if (!ssrcs.empty()) { - target_bitrate = target_bitrate / ssrcs.size(); - } - rtcp_sender_.SetTargetBitrate(target_bitrate); - } - } - - // TODO(bugs.webrtc.org/11581): Run this on a separate set of delayed tasks - // based off of next_time_to_send_rtcp_ in RTCPSender. - if (rtcp_sender_.TimeToSendRTCPReport()) - rtcp_sender_.SendRTCP(GetFeedbackState(), kRtcpReport); -} - void ModuleRtpRtcpImpl2::SetRtxSendStatus(int mode) { rtp_sender_->packet_generator.SetRtxStatus(mode); } @@ -168,6 +153,7 @@ absl::optional ModuleRtpRtcpImpl2::FlexfecSsrc() const { void ModuleRtpRtcpImpl2::IncomingRtcpPacket(const uint8_t* rtcp_packet, const size_t length) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); rtcp_receiver_.IncomingPacket(rtcp_packet, length); } @@ -218,6 +204,12 @@ RtpState ModuleRtpRtcpImpl2::GetRtxState() const { return rtp_sender_->packet_generator.GetRtxRtpState(); } +uint32_t ModuleRtpRtcpImpl2::local_media_ssrc() const { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + RTC_DCHECK_EQ(rtcp_receiver_.local_media_ssrc(), rtcp_sender_.SSRC()); + return rtcp_receiver_.local_media_ssrc(); +} + void ModuleRtpRtcpImpl2::SetRid(const std::string& rid) { if (rtp_sender_) { rtp_sender_->packet_generator.SetRid(rid); @@ -260,8 +252,19 @@ RTCPSender::FeedbackState ModuleRtpRtcpImpl2::GetFeedbackState() { } state.receiver = &rtcp_receiver_; - LastReceivedNTP(&state.last_rr_ntp_secs, &state.last_rr_ntp_frac, - &state.remote_sr); + uint32_t received_ntp_secs = 0; + uint32_t received_ntp_frac = 0; + state.remote_sr = 0; + if (rtcp_receiver_.NTP(&received_ntp_secs, &received_ntp_frac, + /*rtcp_arrival_time_secs=*/&state.last_rr_ntp_secs, + /*rtcp_arrival_time_frac=*/&state.last_rr_ntp_frac, + /*rtcp_timestamp=*/nullptr, + /*remote_sender_packet_count=*/nullptr, + /*remote_sender_octet_count=*/nullptr, + /*remote_sender_reports_count=*/nullptr)) { + state.remote_sr = ((received_ntp_secs & 0x0000ffff) << 16) + + ((received_ntp_frac & 0xffff0000) >> 16); + } state.last_xr_rtis = rtcp_receiver_.ConsumeReceivedXrReferenceTimeInfo(); @@ -274,9 +277,7 @@ RTCPSender::FeedbackState ModuleRtpRtcpImpl2::GetFeedbackState() { int32_t ModuleRtpRtcpImpl2::SetSendingStatus(const bool sending) { if (rtcp_sender_.Sending() != sending) { // Sends RTCP BYE when going from true to false - if (rtcp_sender_.SetSendingStatus(GetFeedbackState(), sending) != 0) { - RTC_LOG(LS_WARNING) << "Failed to send RTCP BYE"; - } + rtcp_sender_.SetSendingStatus(GetFeedbackState(), sending); } return 0; } @@ -318,7 +319,16 @@ bool 
ModuleRtpRtcpImpl2::OnSendingRtpFrame(uint32_t timestamp, if (!Sending()) return false; - rtcp_sender_.SetLastRtpTime(timestamp, capture_time_ms, payload_type); + // TODO(bugs.webrtc.org/12873): Migrate this method and it's users to use + // optional Timestamps. + absl::optional capture_time; + if (capture_time_ms > 0) { + capture_time = Timestamp::Millis(capture_time_ms); + } + absl::optional payload_type_optional; + if (payload_type >= 0) + payload_type_optional = payload_type; + rtcp_sender_.SetLastRtpTime(timestamp, capture_time, payload_type_optional); // Make sure an RTCP report isn't queued behind a key frame. if (rtcp_sender_.TimeToSendRTCPReport(force_sender_report)) rtcp_sender_.SendRTCP(GetFeedbackState(), kRtcpReport); @@ -436,7 +446,10 @@ int32_t ModuleRtpRtcpImpl2::RemoteNTP(uint32_t* received_ntpsecs, uint32_t* rtcp_timestamp) const { return rtcp_receiver_.NTP(received_ntpsecs, received_ntpfrac, rtcp_arrival_time_secs, rtcp_arrival_time_frac, - rtcp_timestamp) + rtcp_timestamp, + /*remote_sender_packet_count=*/nullptr, + /*remote_sender_octet_count=*/nullptr, + /*remote_sender_reports_count=*/nullptr) ? 0 : -1; } @@ -486,16 +499,31 @@ void ModuleRtpRtcpImpl2::GetSendStreamDataCounters( } // Received RTCP report. -int32_t ModuleRtpRtcpImpl2::RemoteRTCPStat( - std::vector* receive_blocks) const { - return rtcp_receiver_.StatisticsReceived(receive_blocks); -} - std::vector ModuleRtpRtcpImpl2::GetLatestReportBlockData() const { return rtcp_receiver_.GetLatestReportBlockData(); } +absl::optional +ModuleRtpRtcpImpl2::GetSenderReportStats() const { + SenderReportStats stats; + uint32_t remote_timestamp_secs; + uint32_t remote_timestamp_frac; + uint32_t arrival_timestamp_secs; + uint32_t arrival_timestamp_frac; + if (rtcp_receiver_.NTP(&remote_timestamp_secs, &remote_timestamp_frac, + &arrival_timestamp_secs, &arrival_timestamp_frac, + /*rtcp_timestamp=*/nullptr, &stats.packets_sent, + &stats.bytes_sent, &stats.reports_count)) { + stats.last_remote_timestamp.Set(remote_timestamp_secs, + remote_timestamp_frac); + stats.last_arrival_timestamp.Set(arrival_timestamp_secs, + arrival_timestamp_frac); + return stats; + } + return absl::nullopt; +} + // (REMB) Receiver Estimated Max Bitrate. void ModuleRtpRtcpImpl2::SetRemb(int64_t bitrate_bps, std::vector ssrcs) { @@ -622,8 +650,15 @@ void ModuleRtpRtcpImpl2::SetRemoteSSRC(const uint32_t ssrc) { rtcp_receiver_.SetRemoteSSRC(ssrc); } +void ModuleRtpRtcpImpl2::SetLocalSsrc(uint32_t local_ssrc) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtcp_receiver_.set_local_media_ssrc(local_ssrc); + rtcp_sender_.SetSsrc(local_ssrc); +} + RtpSendRates ModuleRtpRtcpImpl2::GetSendRates() const { - RTC_DCHECK_RUN_ON(worker_queue_); + // Typically called on the `rtp_transport_queue_` owned by an + // RtpTransportControllerSendInterface instance. return rtp_sender_->packet_sender.GetSendRates(); } @@ -668,23 +703,6 @@ void ModuleRtpRtcpImpl2::OnReceivedRtcpReportBlocks( } } -bool ModuleRtpRtcpImpl2::LastReceivedNTP( - uint32_t* rtcp_arrival_time_secs, // When we got the last report. - uint32_t* rtcp_arrival_time_frac, - uint32_t* remote_sr) const { - // Remote SR: NTP inside the last received (mid 16 bits from sec and frac). 
- uint32_t ntp_secs = 0; - uint32_t ntp_frac = 0; - - if (!rtcp_receiver_.NTP(&ntp_secs, &ntp_frac, rtcp_arrival_time_secs, - rtcp_arrival_time_frac, NULL)) { - return false; - } - *remote_sr = - ((ntp_secs & 0x0000ffff) << 16) + ((ntp_frac & 0xffff0000) >> 16); - return true; -} - void ModuleRtpRtcpImpl2::set_rtt_ms(int64_t rtt_ms) { RTC_DCHECK_RUN_ON(worker_queue_); { @@ -724,13 +742,62 @@ void ModuleRtpRtcpImpl2::PeriodicUpdate() { rtt_stats_->OnRttUpdate(rtt->ms()); set_rtt_ms(rtt->ms()); } +} + +// RTC_RUN_ON(worker_queue_); +void ModuleRtpRtcpImpl2::MaybeSendRtcp() { + if (rtcp_sender_.TimeToSendRTCPReport()) + rtcp_sender_.SendRTCP(GetFeedbackState(), kRtcpReport); +} + +// TODO(bugs.webrtc.org/12889): Consider removing this function when the issue +// is resolved. +// RTC_RUN_ON(worker_queue_); +void ModuleRtpRtcpImpl2::MaybeSendRtcpAtOrAfterTimestamp( + Timestamp execution_time) { + Timestamp now = clock_->CurrentTime(); + if (now >= execution_time) { + MaybeSendRtcp(); + return; + } + + RTC_DLOG(LS_WARNING) + << "BUGBUG: Task queue scheduled delayed call too early."; + + ScheduleMaybeSendRtcpAtOrAfterTimestamp(execution_time, execution_time - now); +} + +void ModuleRtpRtcpImpl2::ScheduleRtcpSendEvaluation(TimeDelta duration) { + // We end up here under various sequences including the worker queue, and + // the RTCPSender lock is held. + // We're assuming that the fact that RTCPSender executes under other sequences + // than the worker queue on which it's created on implies that external + // synchronization is present and removes this activity before destruction. + if (duration.IsZero()) { + worker_queue_->PostTask(ToQueuedTask(task_safety_, [this] { + RTC_DCHECK_RUN_ON(worker_queue_); + MaybeSendRtcp(); + })); + } else { + Timestamp execution_time = clock_->CurrentTime() + duration; + ScheduleMaybeSendRtcpAtOrAfterTimestamp(execution_time, duration); + } +} - // kTmmbrTimeoutIntervalMs is 25 seconds, so an order of seconds. - // Instead of this polling approach, consider having an optional timer in the - // RTCPReceiver class that is started/stopped based on the state of - // rtcp_sender_.TMMBR(). - if (rtcp_sender_.TMMBR() && rtcp_receiver_.UpdateTmmbrTimers()) - rtcp_receiver_.NotifyTmmbrUpdated(); +void ModuleRtpRtcpImpl2::ScheduleMaybeSendRtcpAtOrAfterTimestamp( + Timestamp execution_time, + TimeDelta duration) { + // We end up here under various sequences including the worker queue, and + // the RTCPSender lock is held. + // See note in ScheduleRtcpSendEvaluation about why |worker_queue_| can be + // accessed. 
+ worker_queue_->PostDelayedTask( + ToQueuedTask(task_safety_, + [this, execution_time] { + RTC_DCHECK_RUN_ON(worker_queue_); + MaybeSendRtcpAtOrAfterTimestamp(execution_time); + }), + DelayMillisForDuration(duration)); } } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_rtcp_impl2.h b/modules/rtp_rtcp/source/rtp_rtcp_impl2.h index 9431e75884..0ad495593d 100644 --- a/modules/rtp_rtcp/source/rtp_rtcp_impl2.h +++ b/modules/rtp_rtcp/source/rtp_rtcp_impl2.h @@ -21,7 +21,9 @@ #include "absl/types/optional.h" #include "api/rtp_headers.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" +#include "api/units/time_delta.h" #include "api/video/video_bitrate_allocation.h" #include "modules/include/module_fec_types.h" #include "modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h" @@ -31,15 +33,15 @@ #include "modules/rtp_rtcp/source/rtcp_sender.h" #include "modules/rtp_rtcp/source/rtp_packet_history.h" #include "modules/rtp_rtcp/source/rtp_packet_to_send.h" -#include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" #include "modules/rtp_rtcp/source/rtp_sender.h" #include "modules/rtp_rtcp/source/rtp_sender_egress.h" #include "rtc_base/gtest_prod_util.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/task_utils/repeating_task.h" +#include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { @@ -48,7 +50,6 @@ struct PacedPacketInfo; struct RTPVideoHeader; class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, - public Module, public RTCPReceiver::ModuleRtpRtcp { public: explicit ModuleRtpRtcpImpl2( @@ -62,13 +63,6 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, static std::unique_ptr Create( const Configuration& configuration); - // Returns the number of milliseconds until the module want a worker thread to - // call Process. - int64_t TimeUntilNextProcess() override; - - // Process any pending tasks such as timeouts. - void Process() override; - // Receiver part. // Called when we receive an RTCP packet. @@ -77,6 +71,8 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, void SetRemoteSSRC(uint32_t ssrc) override; + void SetLocalSsrc(uint32_t local_ssrc) override; + // Sender part. void RegisterSendPayloadFrequency(int payload_type, int payload_frequency) override; @@ -110,6 +106,11 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, uint32_t SSRC() const override { return rtcp_sender_.SSRC(); } + // Semantically identical to `SSRC()` but must be called on the packet + // delivery thread/tq and returns the ssrc that maps to + // RtpRtcpInterface::Configuration::local_media_ssrc. + uint32_t local_media_ssrc() const; + void SetRid(const std::string& rid) override; void SetMid(const std::string& mid) override; @@ -193,21 +194,20 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, int64_t ExpectedRetransmissionTimeMs() const override; // Force a send of an RTCP packet. - // Normal SR and RR are triggered via the process function. + // Normal SR and RR are triggered via the task queue that's current when this + // object is created. int32_t SendRTCP(RTCPPacketType rtcpPacketType) override; void GetSendStreamDataCounters( StreamDataCounters* rtp_counters, StreamDataCounters* rtx_counters) const override; - // Get received RTCP report, report block. 
- int32_t RemoteRTCPStat( - std::vector* receive_blocks) const override; // A snapshot of the most recent Report Block with additional data of // interest to statistics. Used to implement RTCRemoteInboundRtpStreamStats. // Within this list, the ReportBlockData::RTCPReportBlock::source_ssrc(), // which is the SSRC of the corresponding outbound RTP stream, is unique. std::vector GetLatestReportBlockData() const override; + absl::optional GetSenderReportStats() const override; // (REMB) Receiver Estimated Max Bitrate. void SetRemb(int64_t bitrate_bps, std::vector ssrcs) override; @@ -240,10 +240,6 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, bool decodability_flag, bool buffering_allowed) override; - bool LastReceivedNTP(uint32_t* NTPsecs, - uint32_t* NTPfrac, - uint32_t* remote_sr) const; - RtpSendRates GetSendRates() const override; void OnReceivedNack( @@ -288,18 +284,32 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, // Returns true if the module is configured to store packets. bool StorePackets() const; + // Used from RtcpSenderMediator to maybe send rtcp. + void MaybeSendRtcp() RTC_RUN_ON(worker_queue_); + + // Called when |rtcp_sender_| informs of the next RTCP instant. The method may + // be called on various sequences, and is called under a RTCPSenderLock. + void ScheduleRtcpSendEvaluation(TimeDelta duration); + + // Helper method combating too early delayed calls from task queues. + // TODO(bugs.webrtc.org/12889): Consider removing this function when the issue + // is resolved. + void MaybeSendRtcpAtOrAfterTimestamp(Timestamp execution_time) + RTC_RUN_ON(worker_queue_); + + // Schedules a call to MaybeSendRtcpAtOrAfterTimestamp delayed by |duration|. + void ScheduleMaybeSendRtcpAtOrAfterTimestamp(Timestamp execution_time, + TimeDelta duration); + TaskQueueBase* const worker_queue_; - RTC_NO_UNIQUE_ADDRESS SequenceChecker process_thread_checker_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_sequence_checker_; std::unique_ptr rtp_sender_; - RTCPSender rtcp_sender_; RTCPReceiver rtcp_receiver_; Clock* const clock_; - int64_t last_rtt_process_time_; - int64_t next_process_time_; uint16_t packet_overhead_; // Send side @@ -314,6 +324,8 @@ class ModuleRtpRtcpImpl2 final : public RtpRtcpInterface, // The processed RTT from RtcpRttStats. 
mutable Mutex mutex_rtt_; int64_t rtt_ms_ RTC_GUARDED_BY(mutex_rtt_); + + RTC_NO_UNIQUE_ADDRESS ScopedTaskSafety task_safety_; }; } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_rtcp_impl2_unittest.cc b/modules/rtp_rtcp/source/rtp_rtcp_impl2_unittest.cc index 3b666422b8..c8ab15de78 100644 --- a/modules/rtp_rtcp/source/rtp_rtcp_impl2_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_rtcp_impl2_unittest.cc @@ -10,17 +10,24 @@ #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" +#include #include #include #include +#include +#include "absl/types/optional.h" #include "api/transport/field_trial_based_config.h" +#include "api/units/time_delta.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/source/rtcp_packet.h" #include "modules/rtp_rtcp/source/rtcp_packet/nack.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" +#include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" #include "modules/rtp_rtcp/source/rtp_sender_video.h" +#include "rtc_base/logging.h" #include "rtc_base/rate_limiter.h" +#include "rtc_base/strings/string_builder.h" #include "test/gmock.h" #include "test/gtest.h" #include "test/rtcp_packet_parser.h" @@ -28,19 +35,35 @@ #include "test/run_loop.h" #include "test/time_controller/simulated_time_controller.h" +using ::testing::AllOf; using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Field; +using ::testing::Gt; +using ::testing::Not; +using ::testing::Optional; +using ::testing::SizeIs; namespace webrtc { namespace { -const uint32_t kSenderSsrc = 0x12345; -const uint32_t kReceiverSsrc = 0x23456; -const int64_t kOneWayNetworkDelayMs = 100; -const uint8_t kBaseLayerTid = 0; -const uint8_t kHigherLayerTid = 1; -const uint16_t kSequenceNumber = 100; -const uint8_t kPayloadType = 100; -const int kWidth = 320; -const int kHeight = 100; +constexpr uint32_t kSenderSsrc = 0x12345; +constexpr uint32_t kReceiverSsrc = 0x23456; +constexpr TimeDelta kOneWayNetworkDelay = TimeDelta::Millis(100); +constexpr uint8_t kBaseLayerTid = 0; +constexpr uint8_t kHigherLayerTid = 1; +constexpr uint16_t kSequenceNumber = 100; +constexpr uint8_t kPayloadType = 100; +constexpr int kWidth = 320; +constexpr int kHeight = 100; +constexpr int kCaptureTimeMsToRtpTimestamp = 90; // 90 kHz clock. +constexpr TimeDelta kDefaultReportInterval = TimeDelta::Millis(1000); + +// RTP header extension ids. +enum : int { + kAbsoluteSendTimeExtensionId = 1, + kTransportSequenceNumberExtensionId, + kTransmissionOffsetExtensionId, +}; class RtcpRttStatsTestImpl : public RtcpRttStats { public: @@ -52,74 +75,150 @@ class RtcpRttStatsTestImpl : public RtcpRttStats { int64_t rtt_ms_; }; -class SendTransport : public Transport { +// TODO(bugs.webrtc.org/11581): remove inheritance once the ModuleRtpRtcpImpl2 +// Module/ProcessThread dependency is gone. 
+class SendTransport : public Transport, + public sim_time_impl::SimulatedSequenceRunner { public: - SendTransport() + SendTransport(TimeDelta delay, GlobalSimulatedTimeController* time_controller) : receiver_(nullptr), - time_controller_(nullptr), - delay_ms_(0), + time_controller_(time_controller), + delay_(delay), rtp_packets_sent_(0), - rtcp_packets_sent_(0) {} + rtcp_packets_sent_(0), + last_packet_(&header_extensions_) { + time_controller_->Register(this); + } + + ~SendTransport() { time_controller_->Unregister(this); } void SetRtpRtcpModule(ModuleRtpRtcpImpl2* receiver) { receiver_ = receiver; } - void SimulateNetworkDelay(int64_t delay_ms, TimeController* time_controller) { - time_controller_ = time_controller; - delay_ms_ = delay_ms; - } + void SimulateNetworkDelay(TimeDelta delay) { delay_ = delay; } bool SendRtp(const uint8_t* data, size_t len, const PacketOptions& options) override { - RTPHeader header; - std::unique_ptr parser(RtpHeaderParser::CreateForTest()); - EXPECT_TRUE(parser->Parse(static_cast(data), len, &header)); + EXPECT_TRUE(last_packet_.Parse(data, len)); ++rtp_packets_sent_; - last_rtp_header_ = header; return true; } bool SendRtcp(const uint8_t* data, size_t len) override { test::RtcpPacketParser parser; parser.Parse(data, len); last_nack_list_ = parser.nack()->packet_ids(); - - if (time_controller_) { - time_controller_->AdvanceTime(TimeDelta::Millis(delay_ms_)); - } - EXPECT_TRUE(receiver_); - receiver_->IncomingRtcpPacket(data, len); + Timestamp current_time = time_controller_->GetClock()->CurrentTime(); + Timestamp delivery_time = current_time + delay_; + rtcp_packets_.push_back( + Packet{delivery_time, std::vector(data, data + len)}); ++rtcp_packets_sent_; + RunReady(current_time); return true; } + // sim_time_impl::SimulatedSequenceRunner + Timestamp GetNextRunTime() const override { + if (!rtcp_packets_.empty()) + return rtcp_packets_.front().send_time; + return Timestamp::PlusInfinity(); + } + void RunReady(Timestamp at_time) override { + while (!rtcp_packets_.empty() && + rtcp_packets_.front().send_time <= at_time) { + Packet packet = std::move(rtcp_packets_.front()); + rtcp_packets_.pop_front(); + EXPECT_TRUE(receiver_); + receiver_->IncomingRtcpPacket(packet.data.data(), packet.data.size()); + } + } + TaskQueueBase* GetAsTaskQueue() override { + return reinterpret_cast(this); + } + size_t NumRtcpSent() { return rtcp_packets_sent_; } ModuleRtpRtcpImpl2* receiver_; - TimeController* time_controller_; - int64_t delay_ms_; + GlobalSimulatedTimeController* const time_controller_; + TimeDelta delay_; int rtp_packets_sent_; size_t rtcp_packets_sent_; - RTPHeader last_rtp_header_; std::vector last_nack_list_; + RtpHeaderExtensionMap header_extensions_; + RtpPacketReceived last_packet_; + struct Packet { + Timestamp send_time; + std::vector data; + }; + std::deque rtcp_packets_; +}; + +struct TestConfig { + explicit TestConfig(bool with_overhead) : with_overhead(with_overhead) {} + + bool with_overhead = false; }; -class RtpRtcpModule : public RtcpPacketTypeCounterObserver { +class FieldTrialConfig : public WebRtcKeyValueConfig { public: - RtpRtcpModule(TimeController* time_controller, bool is_sender) - : is_sender_(is_sender), + static FieldTrialConfig GetFromTestConfig(const TestConfig& config) { + FieldTrialConfig trials; + trials.overhead_enabled_ = config.with_overhead; + return trials; + } + + FieldTrialConfig() : overhead_enabled_(false), max_padding_factor_(1200) {} + ~FieldTrialConfig() override {} + + void SetOverHeadEnabled(bool enabled) { 
overhead_enabled_ = enabled; } + void SetMaxPaddingFactor(double factor) { max_padding_factor_ = factor; } + + std::string Lookup(absl::string_view key) const override { + if (key == "WebRTC-LimitPaddingSize") { + char string_buf[32]; + rtc::SimpleStringBuilder ssb(string_buf); + ssb << "factor:" << max_padding_factor_; + return ssb.str(); + } else if (key == "WebRTC-SendSideBwe-WithOverhead") { + return overhead_enabled_ ? "Enabled" : "Disabled"; + } + return ""; + } + + private: + bool overhead_enabled_; + double max_padding_factor_; +}; + +class RtpRtcpModule : public RtcpPacketTypeCounterObserver, + public SendPacketObserver { + public: + struct SentPacket { + SentPacket(uint16_t packet_id, int64_t capture_time_ms, uint32_t ssrc) + : packet_id(packet_id), capture_time_ms(capture_time_ms), ssrc(ssrc) {} + uint16_t packet_id; + int64_t capture_time_ms; + uint32_t ssrc; + }; + + RtpRtcpModule(GlobalSimulatedTimeController* time_controller, + bool is_sender, + const FieldTrialConfig& trials) + : time_controller_(time_controller), + is_sender_(is_sender), + trials_(trials), receive_statistics_( ReceiveStatistics::Create(time_controller->GetClock())), - time_controller_(time_controller) { + transport_(kOneWayNetworkDelay, time_controller) { CreateModuleImpl(); - transport_.SimulateNetworkDelay(kOneWayNetworkDelayMs, time_controller); } + TimeController* const time_controller_; const bool is_sender_; + const FieldTrialConfig& trials_; RtcpPacketTypeCounter packets_sent_; RtcpPacketTypeCounter packets_received_; std::unique_ptr receive_statistics_; SendTransport transport_; RtcpRttStatsTestImpl rtt_stats_; std::unique_ptr impl_; - int rtcp_report_interval_ms_ = 0; void RtcpPacketTypesCounterUpdated( uint32_t ssrc, @@ -127,6 +226,16 @@ class RtpRtcpModule : public RtcpPacketTypeCounterObserver { counter_map_[ssrc] = packet_counter; } + void OnSendPacket(uint16_t packet_id, + int64_t capture_time_ms, + uint32_t ssrc) override { + last_sent_packet_.emplace(packet_id, capture_time_ms, ssrc); + } + + absl::optional last_sent_packet() const { + return last_sent_packet_; + } + RtcpPacketTypeCounter RtcpSent() { // RTCP counters for remote SSRC. return counter_map_[is_sender_ ? 
kReceiverSsrc : kSenderSsrc]; @@ -137,14 +246,22 @@ class RtpRtcpModule : public RtcpPacketTypeCounterObserver { return counter_map_[impl_->SSRC()]; } int RtpSent() { return transport_.rtp_packets_sent_; } - uint16_t LastRtpSequenceNumber() { - return transport_.last_rtp_header_.sequenceNumber; - } + uint16_t LastRtpSequenceNumber() { return last_packet().SequenceNumber(); } std::vector LastNackListSent() { return transport_.last_nack_list_; } - void SetRtcpReportIntervalAndReset(int rtcp_report_interval_ms) { - rtcp_report_interval_ms_ = rtcp_report_interval_ms; + void SetRtcpReportIntervalAndReset(TimeDelta rtcp_report_interval) { + rtcp_report_interval_ = rtcp_report_interval; + CreateModuleImpl(); + } + const RtpPacketReceived& last_packet() { return transport_.last_packet_; } + void RegisterHeaderExtension(absl::string_view uri, int id) { + impl_->RegisterRtpHeaderExtension(uri, id); + transport_.header_extensions_.RegisterByUri(id, uri); + transport_.last_packet_.IdentifyExtensions(transport_.header_extensions_); + } + void ReinintWithFec(VideoFecGenerator* fec_generator) { + fec_generator_ = fec_generator; CreateModuleImpl(); } @@ -157,27 +274,36 @@ class RtpRtcpModule : public RtcpPacketTypeCounterObserver { config.receive_statistics = receive_statistics_.get(); config.rtcp_packet_type_counter_observer = this; config.rtt_stats = &rtt_stats_; - config.rtcp_report_interval_ms = rtcp_report_interval_ms_; + config.rtcp_report_interval_ms = rtcp_report_interval_.ms(); config.local_media_ssrc = is_sender_ ? kSenderSsrc : kReceiverSsrc; config.need_rtp_packet_infos = true; config.non_sender_rtt_measurement = true; - + config.field_trials = &trials_; + config.send_packet_observer = this; + config.fec_generator = fec_generator_; impl_.reset(new ModuleRtpRtcpImpl2(config)); impl_->SetRemoteSSRC(is_sender_ ? kReceiverSsrc : kSenderSsrc); impl_->SetRTCPStatus(RtcpMode::kCompound); } - TimeController* const time_controller_; std::map counter_map_; + absl::optional last_sent_packet_; + VideoFecGenerator* fec_generator_ = nullptr; + TimeDelta rtcp_report_interval_ = kDefaultReportInterval; }; } // namespace -class RtpRtcpImpl2Test : public ::testing::Test { +class RtpRtcpImpl2Test : public ::testing::TestWithParam { protected: RtpRtcpImpl2Test() : time_controller_(Timestamp::Micros(133590000000000)), - sender_(&time_controller_, /*is_sender=*/true), - receiver_(&time_controller_, /*is_sender=*/false) {} + field_trials_(FieldTrialConfig::GetFromTestConfig(GetParam())), + sender_(&time_controller_, + /*is_sender=*/true, + field_trials_), + receiver_(&time_controller_, + /*is_sender=*/false, + field_trials_) {} void SetUp() override { // Send module. @@ -186,11 +312,10 @@ class RtpRtcpImpl2Test : public ::testing::Test { sender_.impl_->SetSequenceNumber(kSequenceNumber); sender_.impl_->SetStorePacketsStatus(true, 100); - FieldTrialBasedConfig field_trials; RTPSenderVideo::Config video_config; video_config.clock = time_controller_.GetClock(); video_config.rtp_sender = sender_.impl_->RtpSender(); - video_config.field_trials = &field_trials; + video_config.field_trials = &field_trials_; sender_video_ = std::make_unique(video_config); // Receive module. 
@@ -201,20 +326,49 @@ class RtpRtcpImpl2Test : public ::testing::Test { receiver_.transport_.SetRtpRtcpModule(sender_.impl_.get()); } - void AdvanceTimeMs(int64_t milliseconds) { - time_controller_.AdvanceTime(TimeDelta::Millis(milliseconds)); + void AdvanceTime(TimeDelta duration) { + time_controller_.AdvanceTime(duration); + } + + void ReinitWithFec(VideoFecGenerator* fec_generator, + absl::optional red_payload_type) { + sender_.ReinintWithFec(fec_generator); + EXPECT_EQ(0, sender_.impl_->SetSendingStatus(true)); + sender_.impl_->SetSendingMediaStatus(true); + sender_.impl_->SetSequenceNumber(kSequenceNumber); + sender_.impl_->SetStorePacketsStatus(true, 100); + receiver_.transport_.SetRtpRtcpModule(sender_.impl_.get()); + + RTPSenderVideo::Config video_config; + video_config.clock = time_controller_.GetClock(); + video_config.rtp_sender = sender_.impl_->RtpSender(); + video_config.field_trials = &field_trials_; + video_config.fec_overhead_bytes = fec_generator->MaxPacketOverhead(); + video_config.fec_type = fec_generator->GetFecType(); + video_config.red_payload_type = red_payload_type; + sender_video_ = std::make_unique(video_config); } GlobalSimulatedTimeController time_controller_; - // test::RunLoop loop_; - // SimulatedClock clock_; + FieldTrialConfig field_trials_; RtpRtcpModule sender_; std::unique_ptr sender_video_; RtpRtcpModule receiver_; - void SendFrame(const RtpRtcpModule* module, + bool SendFrame(const RtpRtcpModule* module, RTPSenderVideo* sender, uint8_t tid) { + int64_t now_ms = time_controller_.GetClock()->TimeInMilliseconds(); + return SendFrame( + module, sender, tid, + static_cast(now_ms * kCaptureTimeMsToRtpTimestamp), now_ms); + } + + bool SendFrame(const RtpRtcpModule* module, + RTPSenderVideo* sender, + uint8_t tid, + uint32_t rtp_timestamp, + int64_t capture_time_ms) { RTPVideoHeaderVP8 vp8_header = {}; vp8_header.temporalIdx = tid; RTPVideoHeader rtp_video_header; @@ -231,9 +385,12 @@ class RtpRtcpImpl2Test : public ::testing::Test { rtp_video_header.video_timing = {0u, 0u, 0u, 0u, 0u, 0u, false}; const uint8_t payload[100] = {0}; - EXPECT_TRUE(module->impl_->OnSendingRtpFrame(0, 0, kPayloadType, true)); - EXPECT_TRUE(sender->SendVideo(kPayloadType, VideoCodecType::kVideoCodecVP8, - 0, 0, payload, rtp_video_header, 0)); + bool success = module->impl_->OnSendingRtpFrame(0, 0, kPayloadType, true); + + success &= sender->SendVideo(kPayloadType, VideoCodecType::kVideoCodecVP8, + rtp_timestamp, capture_time_ms, payload, + rtp_video_header, 0); + return success; } void IncomingRtcpNack(const RtpRtcpModule* module, uint16_t sequence_number) { @@ -250,19 +407,20 @@ class RtpRtcpImpl2Test : public ::testing::Test { } }; -TEST_F(RtpRtcpImpl2Test, RetransmitsAllLayers) { +TEST_P(RtpRtcpImpl2Test, RetransmitsAllLayers) { // Send frames. EXPECT_EQ(0, sender_.RtpSent()); - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); // kSequenceNumber - SendFrame(&sender_, sender_video_.get(), - kHigherLayerTid); // kSequenceNumber + 1 - SendFrame(&sender_, sender_video_.get(), - kNoTemporalIdx); // kSequenceNumber + 2 + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), + kBaseLayerTid)); // kSequenceNumber + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), + kHigherLayerTid)); // kSequenceNumber + 1 + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), + kNoTemporalIdx)); // kSequenceNumber + 2 EXPECT_EQ(3, sender_.RtpSent()); EXPECT_EQ(kSequenceNumber + 2, sender_.LastRtpSequenceNumber()); // Min required delay until retransmit = 5 + RTT ms (RTT = 0). 
- AdvanceTimeMs(5); + AdvanceTime(TimeDelta::Millis(5)); // Frame with kBaseLayerTid re-sent. IncomingRtcpNack(&sender_, kSequenceNumber); @@ -278,7 +436,7 @@ TEST_F(RtpRtcpImpl2Test, RetransmitsAllLayers) { EXPECT_EQ(kSequenceNumber + 2, sender_.LastRtpSequenceNumber()); } -TEST_F(RtpRtcpImpl2Test, Rtt) { +TEST_P(RtpRtcpImpl2Test, Rtt) { RtpPacketReceived packet; packet.SetTimestamp(1); packet.SetSequenceNumber(123); @@ -287,13 +445,14 @@ TEST_F(RtpRtcpImpl2Test, Rtt) { receiver_.receive_statistics_->OnRtpPacket(packet); // Send Frame before sending an SR. - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); // Sender module should send an SR. EXPECT_EQ(0, sender_.impl_->SendRTCP(kRtcpReport)); + AdvanceTime(kOneWayNetworkDelay); // Receiver module should send a RR with a response to the last received SR. - AdvanceTimeMs(1000); EXPECT_EQ(0, receiver_.impl_->SendRTCP(kRtcpReport)); + AdvanceTime(kOneWayNetworkDelay); // Verify RTT. int64_t rtt; @@ -302,10 +461,10 @@ TEST_F(RtpRtcpImpl2Test, Rtt) { int64_t max_rtt; EXPECT_EQ( 0, sender_.impl_->RTT(kReceiverSsrc, &rtt, &avg_rtt, &min_rtt, &max_rtt)); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, rtt, 1); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, avg_rtt, 1); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, min_rtt, 1); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, max_rtt, 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), rtt, 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), avg_rtt, 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), min_rtt, 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), max_rtt, 1); // No RTT from other ssrc. EXPECT_EQ(-1, sender_.impl_->RTT(kReceiverSsrc + 1, &rtt, &avg_rtt, &min_rtt, @@ -314,54 +473,51 @@ TEST_F(RtpRtcpImpl2Test, Rtt) { // Verify RTT from rtt_stats config. EXPECT_EQ(0, sender_.rtt_stats_.LastProcessedRtt()); EXPECT_EQ(0, sender_.impl_->rtt_ms()); - AdvanceTimeMs(1000); + AdvanceTime(TimeDelta::Millis(1000)); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, sender_.rtt_stats_.LastProcessedRtt(), - 1); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, sender_.impl_->rtt_ms(), 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), + sender_.rtt_stats_.LastProcessedRtt(), 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), sender_.impl_->rtt_ms(), 1); } -TEST_F(RtpRtcpImpl2Test, RttForReceiverOnly) { +TEST_P(RtpRtcpImpl2Test, RttForReceiverOnly) { // Receiver module should send a Receiver time reference report (RTRR). EXPECT_EQ(0, receiver_.impl_->SendRTCP(kRtcpReport)); // Sender module should send a response to the last received RTRR (DLRR). - AdvanceTimeMs(1000); + AdvanceTime(TimeDelta::Millis(1000)); // Send Frame before sending a SR. - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); EXPECT_EQ(0, sender_.impl_->SendRTCP(kRtcpReport)); // Verify RTT. EXPECT_EQ(0, receiver_.rtt_stats_.LastProcessedRtt()); EXPECT_EQ(0, receiver_.impl_->rtt_ms()); - AdvanceTimeMs(1000); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, + AdvanceTime(TimeDelta::Millis(1000)); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), receiver_.rtt_stats_.LastProcessedRtt(), 1); - EXPECT_NEAR(2 * kOneWayNetworkDelayMs, receiver_.impl_->rtt_ms(), 1); + EXPECT_NEAR(2 * kOneWayNetworkDelay.ms(), receiver_.impl_->rtt_ms(), 1); } -TEST_F(RtpRtcpImpl2Test, NoSrBeforeMedia) { +TEST_P(RtpRtcpImpl2Test, NoSrBeforeMedia) { // Ignore fake transport delays in this test. 
- sender_.transport_.SimulateNetworkDelay(0, &time_controller_); - receiver_.transport_.SimulateNetworkDelay(0, &time_controller_); - - sender_.impl_->Process(); - EXPECT_EQ(-1, sender_.RtcpSent().first_packet_time_ms); + sender_.transport_.SimulateNetworkDelay(TimeDelta::Millis(0)); + receiver_.transport_.SimulateNetworkDelay(TimeDelta::Millis(0)); + // Move ahead to the instant a rtcp is expected. // Verify no SR is sent before media has been sent, RR should still be sent // from the receiving module though. - AdvanceTimeMs(2000); + AdvanceTime(kDefaultReportInterval / 2); int64_t current_time = time_controller_.GetClock()->TimeInMilliseconds(); - sender_.impl_->Process(); - receiver_.impl_->Process(); EXPECT_EQ(-1, sender_.RtcpSent().first_packet_time_ms); EXPECT_EQ(receiver_.RtcpSent().first_packet_time_ms, current_time); - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + // RTCP should be triggered by the RTP send. + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); EXPECT_EQ(sender_.RtcpSent().first_packet_time_ms, current_time); } -TEST_F(RtpRtcpImpl2Test, RtcpPacketTypeCounter_Nack) { +TEST_P(RtpRtcpImpl2Test, RtcpPacketTypeCounter_Nack) { EXPECT_EQ(-1, receiver_.RtcpSent().first_packet_time_ms); EXPECT_EQ(-1, sender_.RtcpReceived().first_packet_time_ms); EXPECT_EQ(0U, sender_.RtcpReceived().nack_packets); @@ -371,6 +527,7 @@ TEST_F(RtpRtcpImpl2Test, RtcpPacketTypeCounter_Nack) { const uint16_t kNackLength = 1; uint16_t nack_list[kNackLength] = {123}; EXPECT_EQ(0, receiver_.impl_->SendNACK(nack_list, kNackLength)); + AdvanceTime(kOneWayNetworkDelay); EXPECT_EQ(1U, receiver_.RtcpSent().nack_packets); EXPECT_GT(receiver_.RtcpSent().first_packet_time_ms, -1); @@ -379,7 +536,7 @@ TEST_F(RtpRtcpImpl2Test, RtcpPacketTypeCounter_Nack) { EXPECT_GT(sender_.RtcpReceived().first_packet_time_ms, -1); } -TEST_F(RtpRtcpImpl2Test, AddStreamDataCounters) { +TEST_P(RtpRtcpImpl2Test, AddStreamDataCounters) { StreamDataCounters rtp; const int64_t kStartTimeMs = 1; rtp.first_packet_time_ms = kStartTimeMs; @@ -422,25 +579,25 @@ TEST_F(RtpRtcpImpl2Test, AddStreamDataCounters) { EXPECT_EQ(kStartTimeMs, sum.first_packet_time_ms); // Holds oldest time. } -TEST_F(RtpRtcpImpl2Test, SendsInitialNackList) { +TEST_P(RtpRtcpImpl2Test, SendsInitialNackList) { // Send module sends a NACK. const uint16_t kNackLength = 1; uint16_t nack_list[kNackLength] = {123}; EXPECT_EQ(0U, sender_.RtcpSent().nack_packets); // Send Frame before sending a compound RTCP that starts with SR. - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); EXPECT_EQ(0, sender_.impl_->SendNACK(nack_list, kNackLength)); EXPECT_EQ(1U, sender_.RtcpSent().nack_packets); EXPECT_THAT(sender_.LastNackListSent(), ElementsAre(123)); } -TEST_F(RtpRtcpImpl2Test, SendsExtendedNackList) { +TEST_P(RtpRtcpImpl2Test, SendsExtendedNackList) { // Send module sends a NACK. const uint16_t kNackLength = 1; uint16_t nack_list[kNackLength] = {123}; EXPECT_EQ(0U, sender_.RtcpSent().nack_packets); // Send Frame before sending a compound RTCP that starts with SR. 
- SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); EXPECT_EQ(0, sender_.impl_->SendNACK(nack_list, kNackLength)); EXPECT_EQ(1U, sender_.RtcpSent().nack_packets); EXPECT_THAT(sender_.LastNackListSent(), ElementsAre(123)); @@ -458,33 +615,33 @@ TEST_F(RtpRtcpImpl2Test, SendsExtendedNackList) { EXPECT_THAT(sender_.LastNackListSent(), ElementsAre(124)); } -TEST_F(RtpRtcpImpl2Test, ReSendsNackListAfterRttMs) { - sender_.transport_.SimulateNetworkDelay(0, &time_controller_); +TEST_P(RtpRtcpImpl2Test, ReSendsNackListAfterRttMs) { + sender_.transport_.SimulateNetworkDelay(TimeDelta::Millis(0)); // Send module sends a NACK. const uint16_t kNackLength = 2; uint16_t nack_list[kNackLength] = {123, 125}; EXPECT_EQ(0U, sender_.RtcpSent().nack_packets); // Send Frame before sending a compound RTCP that starts with SR. - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); EXPECT_EQ(0, sender_.impl_->SendNACK(nack_list, kNackLength)); EXPECT_EQ(1U, sender_.RtcpSent().nack_packets); EXPECT_THAT(sender_.LastNackListSent(), ElementsAre(123, 125)); // Same list not re-send, rtt interval has not passed. - const int kStartupRttMs = 100; - AdvanceTimeMs(kStartupRttMs); + const TimeDelta kStartupRtt = TimeDelta::Millis(100); + AdvanceTime(kStartupRtt); EXPECT_EQ(0, sender_.impl_->SendNACK(nack_list, kNackLength)); EXPECT_EQ(1U, sender_.RtcpSent().nack_packets); // Rtt interval passed, full list sent. - AdvanceTimeMs(1); + AdvanceTime(TimeDelta::Millis(1)); EXPECT_EQ(0, sender_.impl_->SendNACK(nack_list, kNackLength)); EXPECT_EQ(2U, sender_.RtcpSent().nack_packets); EXPECT_THAT(sender_.LastNackListSent(), ElementsAre(123, 125)); } -TEST_F(RtpRtcpImpl2Test, UniqueNackRequests) { - receiver_.transport_.SimulateNetworkDelay(0, &time_controller_); +TEST_P(RtpRtcpImpl2Test, UniqueNackRequests) { + receiver_.transport_.SimulateNetworkDelay(TimeDelta::Millis(0)); EXPECT_EQ(0U, receiver_.RtcpSent().nack_packets); EXPECT_EQ(0U, receiver_.RtcpSent().nack_requests); EXPECT_EQ(0U, receiver_.RtcpSent().unique_nack_requests); @@ -506,8 +663,8 @@ TEST_F(RtpRtcpImpl2Test, UniqueNackRequests) { EXPECT_EQ(100, sender_.RtcpReceived().UniqueNackRequestsInPercent()); // Receive module sends new request with duplicated packets. - const int kStartupRttMs = 100; - AdvanceTimeMs(kStartupRttMs + 1); + const TimeDelta kStartupRtt = TimeDelta::Millis(100); + AdvanceTime(kStartupRtt + TimeDelta::Millis(1)); const uint16_t kNackLength2 = 4; uint16_t nack_list2[kNackLength2] = {11, 18, 20, 21}; EXPECT_EQ(0, receiver_.impl_->SendNACK(nack_list2, kNackLength2)); @@ -523,59 +680,54 @@ TEST_F(RtpRtcpImpl2Test, UniqueNackRequests) { EXPECT_EQ(75, sender_.RtcpReceived().UniqueNackRequestsInPercent()); } -TEST_F(RtpRtcpImpl2Test, ConfigurableRtcpReportInterval) { - const int kVideoReportInterval = 3000; +TEST_P(RtpRtcpImpl2Test, ConfigurableRtcpReportInterval) { + const TimeDelta kVideoReportInterval = TimeDelta::Millis(3000); // Recreate sender impl with new configuration, and redo setup. 
sender_.SetRtcpReportIntervalAndReset(kVideoReportInterval); SetUp(); - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); // Initial state - sender_.impl_->Process(); EXPECT_EQ(sender_.RtcpSent().first_packet_time_ms, -1); EXPECT_EQ(0u, sender_.transport_.NumRtcpSent()); // Move ahead to the last ms before a rtcp is expected, no action. - AdvanceTimeMs(kVideoReportInterval / 2 - 1); - sender_.impl_->Process(); + AdvanceTime(kVideoReportInterval / 2 - TimeDelta::Millis(1)); EXPECT_EQ(sender_.RtcpSent().first_packet_time_ms, -1); EXPECT_EQ(sender_.transport_.NumRtcpSent(), 0u); // Move ahead to the first rtcp. Send RTCP. - AdvanceTimeMs(1); - sender_.impl_->Process(); + AdvanceTime(TimeDelta::Millis(1)); EXPECT_GT(sender_.RtcpSent().first_packet_time_ms, -1); EXPECT_EQ(sender_.transport_.NumRtcpSent(), 1u); - SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); // Move ahead to the last possible second before second rtcp is expected. - AdvanceTimeMs(kVideoReportInterval * 1 / 2 - 1); - sender_.impl_->Process(); + AdvanceTime(kVideoReportInterval * 1 / 2 - TimeDelta::Millis(1)); EXPECT_EQ(sender_.transport_.NumRtcpSent(), 1u); // Move ahead into the range of second rtcp, the second rtcp may be sent. - AdvanceTimeMs(1); - sender_.impl_->Process(); + AdvanceTime(TimeDelta::Millis(1)); EXPECT_GE(sender_.transport_.NumRtcpSent(), 1u); - AdvanceTimeMs(kVideoReportInterval / 2); - sender_.impl_->Process(); + AdvanceTime(kVideoReportInterval / 2); EXPECT_GE(sender_.transport_.NumRtcpSent(), 1u); // Move out the range of second rtcp, the second rtcp must have been sent. - AdvanceTimeMs(kVideoReportInterval / 2); - sender_.impl_->Process(); + AdvanceTime(kVideoReportInterval / 2); EXPECT_EQ(sender_.transport_.NumRtcpSent(), 2u); } -TEST_F(RtpRtcpImpl2Test, StoresPacketInfoForSentPackets) { +TEST_P(RtpRtcpImpl2Test, StoresPacketInfoForSentPackets) { const uint32_t kStartTimestamp = 1u; SetUp(); sender_.impl_->SetStartTimestamp(kStartTimestamp); + sender_.impl_->SetSequenceNumber(1); + PacedPacketInfo pacing_info; RtpPacketToSend packet(nullptr); packet.set_packet_type(RtpPacketToSend::Type::kVideo); @@ -587,7 +739,7 @@ TEST_F(RtpRtcpImpl2Test, StoresPacketInfoForSentPackets) { packet.set_first_packet_of_frame(true); packet.SetMarker(true); sender_.impl_->TrySendPacket(&packet, pacing_info); - AdvanceTimeMs(1); + AdvanceTime(TimeDelta::Millis(1)); std::vector seqno_info = sender_.impl_->GetSentRtpPacketInfos(std::vector{1}); @@ -612,7 +764,7 @@ TEST_F(RtpRtcpImpl2Test, StoresPacketInfoForSentPackets) { packet.SetMarker(true); sender_.impl_->TrySendPacket(&packet, pacing_info); - AdvanceTimeMs(1); + AdvanceTime(TimeDelta::Millis(1)); seqno_info = sender_.impl_->GetSentRtpPacketInfos(std::vector{2, 3, 4}); @@ -631,4 +783,302 @@ TEST_F(RtpRtcpImpl2Test, StoresPacketInfoForSentPackets) { /*is_last=*/1))); } +// Checks that the sender report stats are not available if no RTCP SR was sent. +TEST_P(RtpRtcpImpl2Test, SenderReportStatsNotAvailable) { + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), Eq(absl::nullopt)); +} + +// Checks that the sender report stats are available if an RTCP SR was sent. +TEST_P(RtpRtcpImpl2Test, SenderReportStatsAvailable) { + // Send a frame in order to send an SR. + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + // Send an SR. 
+ ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + AdvanceTime(kOneWayNetworkDelay); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), Not(Eq(absl::nullopt))); +} + +// Checks that the sender report stats are not available if an RTCP SR with an +// unexpected SSRC is received. +TEST_P(RtpRtcpImpl2Test, SenderReportStatsNotUpdatedWithUnexpectedSsrc) { + constexpr uint32_t kUnexpectedSenderSsrc = 0x87654321; + static_assert(kUnexpectedSenderSsrc != kSenderSsrc, ""); + // Forge a sender report and pass it to the receiver as if an RTCP SR were + // sent by an unexpected sender. + rtcp::SenderReport sr; + sr.SetSenderSsrc(kUnexpectedSenderSsrc); + sr.SetNtp({/*seconds=*/1u, /*fractions=*/1u << 31}); + sr.SetPacketCount(123u); + sr.SetOctetCount(456u); + auto raw_packet = sr.Build(); + receiver_.impl_->IncomingRtcpPacket(raw_packet.data(), raw_packet.size()); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), Eq(absl::nullopt)); +} + +// Checks the stats derived from the last received RTCP SR are set correctly. +TEST_P(RtpRtcpImpl2Test, SenderReportStatsCheckStatsFromLastReport) { + using SenderReportStats = RtpRtcpInterface::SenderReportStats; + const NtpTime ntp(/*seconds=*/1u, /*fractions=*/1u << 31); + constexpr uint32_t kPacketCount = 123u; + constexpr uint32_t kOctetCount = 456u; + // Forge a sender report and pass it to the receiver as if an RTCP SR were + // sent by the sender. + rtcp::SenderReport sr; + sr.SetSenderSsrc(kSenderSsrc); + sr.SetNtp(ntp); + sr.SetPacketCount(kPacketCount); + sr.SetOctetCount(kOctetCount); + auto raw_packet = sr.Build(); + receiver_.impl_->IncomingRtcpPacket(raw_packet.data(), raw_packet.size()); + + EXPECT_THAT( + receiver_.impl_->GetSenderReportStats(), + Optional(AllOf(Field(&SenderReportStats::last_remote_timestamp, Eq(ntp)), + Field(&SenderReportStats::packets_sent, Eq(kPacketCount)), + Field(&SenderReportStats::bytes_sent, Eq(kOctetCount))))); +} + +// Checks that the sender report stats count equals the number of sent RTCP SRs. +TEST_P(RtpRtcpImpl2Test, SenderReportStatsCount) { + using SenderReportStats = RtpRtcpInterface::SenderReportStats; + // Send a frame in order to send an SR. + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + // Send the first SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + AdvanceTime(kOneWayNetworkDelay); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), + Optional(Field(&SenderReportStats::reports_count, Eq(1u)))); + // Send the second SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + AdvanceTime(kOneWayNetworkDelay); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), + Optional(Field(&SenderReportStats::reports_count, Eq(2u)))); +} + +// Checks that the sender report stats include a valid arrival time if an RTCP +// SR was sent. +TEST_P(RtpRtcpImpl2Test, SenderReportStatsArrivalTimestampSet) { + // Send a frame in order to send an SR. + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + // Send an SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + AdvanceTime(kOneWayNetworkDelay); + auto stats = receiver_.impl_->GetSenderReportStats(); + ASSERT_THAT(stats, Not(Eq(absl::nullopt))); + EXPECT_TRUE(stats->last_arrival_timestamp.Valid()); +} + +// Checks that the packet and byte counters from an RTCP SR are not zero once +// a frame is sent. 
+TEST_P(RtpRtcpImpl2Test, SenderReportStatsPacketByteCounters) { + using SenderReportStats = RtpRtcpInterface::SenderReportStats; + // Send a frame in order to send an SR. + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + ASSERT_THAT(sender_.transport_.rtp_packets_sent_, Gt(0)); + // Advance time otherwise the RTCP SR report will not include any packets + // generated by `SendFrame()`. + AdvanceTime(TimeDelta::Millis(1)); + // Send an SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + AdvanceTime(kOneWayNetworkDelay); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), + Optional(AllOf(Field(&SenderReportStats::packets_sent, Gt(0u)), + Field(&SenderReportStats::bytes_sent, Gt(0u))))); +} + +TEST_P(RtpRtcpImpl2Test, SendingVideoAdvancesSequenceNumber) { + const uint16_t sequence_number = sender_.impl_->SequenceNumber(); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + ASSERT_THAT(sender_.transport_.rtp_packets_sent_, Gt(0)); + EXPECT_EQ(sequence_number + 1, sender_.impl_->SequenceNumber()); +} + +TEST_P(RtpRtcpImpl2Test, SequenceNumberNotAdvancedWhenNotSending) { + const uint16_t sequence_number = sender_.impl_->SequenceNumber(); + sender_.impl_->SetSendingMediaStatus(false); + EXPECT_FALSE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + ASSERT_THAT(sender_.transport_.rtp_packets_sent_, Eq(0)); + EXPECT_EQ(sequence_number, sender_.impl_->SequenceNumber()); +} + +TEST_P(RtpRtcpImpl2Test, PaddingNotAllowedInMiddleOfFrame) { + constexpr size_t kPaddingSize = 100; + + // Can't send padding before media. + EXPECT_THAT(sender_.impl_->GeneratePadding(kPaddingSize), SizeIs(0u)); + + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + + // Padding is now ok. + EXPECT_THAT(sender_.impl_->GeneratePadding(kPaddingSize), SizeIs(Gt(0u))); + + // Send half a video frame. + PacedPacketInfo pacing_info; + std::unique_ptr packet = + sender_.impl_->RtpSender()->AllocatePacket(); + packet->set_packet_type(RtpPacketToSend::Type::kVideo); + packet->set_first_packet_of_frame(true); + packet->SetMarker(false); // Marker false - not last packet of frame. + sender_.impl_->RtpSender()->AssignSequenceNumber(packet.get()); + + EXPECT_TRUE(sender_.impl_->TrySendPacket(packet.get(), pacing_info)); + + // Padding not allowed in middle of frame. + EXPECT_THAT(sender_.impl_->GeneratePadding(kPaddingSize), SizeIs(0u)); + + packet = sender_.impl_->RtpSender()->AllocatePacket(); + packet->set_packet_type(RtpPacketToSend::Type::kVideo); + packet->set_first_packet_of_frame(true); + packet->SetMarker(true); + sender_.impl_->RtpSender()->AssignSequenceNumber(packet.get()); + + EXPECT_TRUE(sender_.impl_->TrySendPacket(packet.get(), pacing_info)); + + // Padding is OK again. + EXPECT_THAT(sender_.impl_->GeneratePadding(kPaddingSize), SizeIs(Gt(0u))); +} + +TEST_P(RtpRtcpImpl2Test, PaddingTimestampMatchesMedia) { + constexpr size_t kPaddingSize = 100; + const uint32_t kTimestamp = 123; + + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid, + kTimestamp, /*capture_time_ms=*/0)); + EXPECT_EQ(sender_.last_packet().Timestamp(), kTimestamp); + uint16_t media_seq = sender_.last_packet().SequenceNumber(); + + // Generate and send padding. + auto padding = sender_.impl_->GeneratePadding(kPaddingSize); + ASSERT_FALSE(padding.empty()); + for (auto& packet : padding) { + sender_.impl_->TrySendPacket(packet.get(), PacedPacketInfo()); + } + + // Verify we sent a new packet, but with the same timestamp. 
+ EXPECT_NE(sender_.last_packet().SequenceNumber(), media_seq); + EXPECT_EQ(sender_.last_packet().Timestamp(), kTimestamp); +} + +TEST_P(RtpRtcpImpl2Test, AssignsTransportSequenceNumber) { + sender_.RegisterHeaderExtension(TransportSequenceNumber::kUri, + kTransportSequenceNumberExtensionId); + + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + uint16_t first_transport_seq = 0; + EXPECT_TRUE(sender_.last_packet().GetExtension( + &first_transport_seq)); + + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + uint16_t second_transport_seq = 0; + EXPECT_TRUE(sender_.last_packet().GetExtension( + &second_transport_seq)); + + EXPECT_EQ(first_transport_seq + 1, second_transport_seq); +} + +TEST_P(RtpRtcpImpl2Test, AssignsAbsoluteSendTime) { + sender_.RegisterHeaderExtension(AbsoluteSendTime::kUri, + kAbsoluteSendTimeExtensionId); + + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + EXPECT_NE(sender_.last_packet().GetExtension(), 0u); +} + +TEST_P(RtpRtcpImpl2Test, AssignsTransmissionTimeOffset) { + sender_.RegisterHeaderExtension(TransmissionOffset::kUri, + kTransmissionOffsetExtensionId); + + constexpr TimeDelta kOffset = TimeDelta::Millis(100); + // Transmission offset is calculated from difference between capture time + // and send time. + int64_t capture_time_ms = time_controller_.GetClock()->TimeInMilliseconds(); + time_controller_.AdvanceTime(kOffset); + + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid, + /*timestamp=*/0, capture_time_ms)); + EXPECT_EQ(sender_.last_packet().GetExtension(), + kOffset.ms() * kCaptureTimeMsToRtpTimestamp); +} + +TEST_P(RtpRtcpImpl2Test, PropagatesSentPacketInfo) { + sender_.RegisterHeaderExtension(TransportSequenceNumber::kUri, + kTransportSequenceNumberExtensionId); + int64_t now_ms = time_controller_.GetClock()->TimeInMilliseconds(); + EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid)); + EXPECT_THAT( + sender_.last_sent_packet(), + Optional( + AllOf(Field(&RtpRtcpModule::SentPacket::packet_id, + Eq(sender_.last_packet() + .GetExtension())), + Field(&RtpRtcpModule::SentPacket::capture_time_ms, Eq(now_ms)), + Field(&RtpRtcpModule::SentPacket::ssrc, Eq(kSenderSsrc))))); +} + +TEST_P(RtpRtcpImpl2Test, GeneratesFlexfec) { + constexpr int kFlexfecPayloadType = 118; + constexpr uint32_t kFlexfecSsrc = 17; + const char kNoMid[] = ""; + const std::vector kNoRtpExtensions; + const std::vector kNoRtpExtensionSizes; + + // Make sure FlexFec sequence numbers start at a different point than media. + const uint16_t fec_start_seq = sender_.impl_->SequenceNumber() + 100; + RtpState start_state; + start_state.sequence_number = fec_start_seq; + FlexfecSender flexfec_sender(kFlexfecPayloadType, kFlexfecSsrc, kSenderSsrc, + kNoMid, kNoRtpExtensions, kNoRtpExtensionSizes, + &start_state, time_controller_.GetClock()); + ReinitWithFec(&flexfec_sender, /*red_payload_type=*/absl::nullopt); + + // Parameters selected to generate a single FEC packet per media packet. + FecProtectionParams params; + params.fec_rate = 15; + params.max_fec_frames = 1; + params.fec_mask_type = kFecMaskRandom; + sender_.impl_->SetFecProtectionParams(params, params); + + // Send a one packet frame, expect one media packet and one FEC packet. 
+  EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid));
+  ASSERT_THAT(sender_.transport_.rtp_packets_sent_, Eq(2));
+
+  const RtpPacketReceived& fec_packet = sender_.last_packet();
+  EXPECT_EQ(fec_packet.SequenceNumber(), fec_start_seq);
+  EXPECT_EQ(fec_packet.Ssrc(), kFlexfecSsrc);
+  EXPECT_EQ(fec_packet.PayloadType(), kFlexfecPayloadType);
+}
+
+TEST_P(RtpRtcpImpl2Test, GeneratesUlpfec) {
+  constexpr int kUlpfecPayloadType = 118;
+  constexpr int kRedPayloadType = 119;
+  UlpfecGenerator ulpfec_sender(kRedPayloadType, kUlpfecPayloadType,
+                                time_controller_.GetClock());
+  ReinitWithFec(&ulpfec_sender, kRedPayloadType);
+
+  // Parameters selected to generate a single FEC packet per media packet.
+  FecProtectionParams params;
+  params.fec_rate = 15;
+  params.max_fec_frames = 1;
+  params.fec_mask_type = kFecMaskRandom;
+  sender_.impl_->SetFecProtectionParams(params, params);
+
+  // Send a one packet frame, expect one media packet and one FEC packet.
+  EXPECT_TRUE(SendFrame(&sender_, sender_video_.get(), kBaseLayerTid));
+  ASSERT_THAT(sender_.transport_.rtp_packets_sent_, Eq(2));
+
+  // Ulpfec is sent on the media ssrc, sharing the sequence number series.
+  const RtpPacketReceived& fec_packet = sender_.last_packet();
+  EXPECT_EQ(fec_packet.SequenceNumber(), kSequenceNumber + 1);
+  EXPECT_EQ(fec_packet.Ssrc(), kSenderSsrc);
+  // The packets are encapsulated in RED packets, check that, and that the RED
+  // header (first byte of payload) indicates the desired FEC payload type.
+  EXPECT_EQ(fec_packet.PayloadType(), kRedPayloadType);
+  EXPECT_EQ(fec_packet.payload()[0], kUlpfecPayloadType);
+}
+
+INSTANTIATE_TEST_SUITE_P(WithAndWithoutOverhead,
+                         RtpRtcpImpl2Test,
+                         ::testing::Values(TestConfig{false},
+                                           TestConfig{true}));
+
 }  // namespace webrtc
diff --git a/modules/rtp_rtcp/source/rtp_rtcp_impl_unittest.cc b/modules/rtp_rtcp/source/rtp_rtcp_impl_unittest.cc
index 05c6ae1cbf..ac05584e18 100644
--- a/modules/rtp_rtcp/source/rtp_rtcp_impl_unittest.cc
+++ b/modules/rtp_rtcp/source/rtp_rtcp_impl_unittest.cc
@@ -27,6 +27,11 @@
 #include "test/rtp_header_parser.h"
 
 using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Field;
+using ::testing::Gt;
+using ::testing::Not;
+using ::testing::Optional;
 
 namespace webrtc {
 namespace {
@@ -67,11 +72,10 @@ class SendTransport : public Transport {
   bool SendRtp(const uint8_t* data,
                size_t len,
                const PacketOptions& options) override {
-    RTPHeader header;
-    std::unique_ptr<RtpHeaderParser> parser(RtpHeaderParser::CreateForTest());
-    EXPECT_TRUE(parser->Parse(static_cast<const uint8_t*>(data), len, &header));
+    RtpPacket packet;
+    EXPECT_TRUE(packet.Parse(data, len));
     ++rtp_packets_sent_;
-    last_rtp_header_ = header;
+    last_rtp_sequence_number_ = packet.SequenceNumber();
     return true;
   }
   bool SendRtcp(const uint8_t* data, size_t len) override {
@@ -93,7 +97,7 @@ class SendTransport : public Transport {
   int64_t delay_ms_;
   int rtp_packets_sent_;
   size_t rtcp_packets_sent_;
-  RTPHeader last_rtp_header_;
+  uint16_t last_rtp_sequence_number_;
   std::vector<uint16_t> last_nack_list_;
 };
 
@@ -133,7 +137,7 @@ class RtpRtcpModule : public RtcpPacketTypeCounterObserver {
   }
   int RtpSent() { return transport_.rtp_packets_sent_; }
   uint16_t LastRtpSequenceNumber() {
-    return transport_.last_rtp_header_.sequenceNumber;
+    return transport_.last_rtp_sequence_number_;
   }
   std::vector<uint16_t> LastNackListSent() {
     return transport_.last_nack_list_;
   }
@@ -616,4 +620,102 @@ TEST_F(RtpRtcpImplTest, StoresPacketInfoForSentPackets) {
                   /*is_last=*/1)));
 }
 
+// Checks that the sender report stats are not available if no RTCP SR
was sent. +TEST_F(RtpRtcpImplTest, SenderReportStatsNotAvailable) { + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), Eq(absl::nullopt)); +} + +// Checks that the sender report stats are available if an RTCP SR was sent. +TEST_F(RtpRtcpImplTest, SenderReportStatsAvailable) { + // Send a frame in order to send an SR. + SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + // Send an SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), Not(Eq(absl::nullopt))); +} + +// Checks that the sender report stats are not available if an RTCP SR with an +// unexpected SSRC is received. +TEST_F(RtpRtcpImplTest, SenderReportStatsNotUpdatedWithUnexpectedSsrc) { + constexpr uint32_t kUnexpectedSenderSsrc = 0x87654321; + static_assert(kUnexpectedSenderSsrc != kSenderSsrc, ""); + // Forge a sender report and pass it to the receiver as if an RTCP SR were + // sent by an unexpected sender. + rtcp::SenderReport sr; + sr.SetSenderSsrc(kUnexpectedSenderSsrc); + sr.SetNtp({/*seconds=*/1u, /*fractions=*/1u << 31}); + sr.SetPacketCount(123u); + sr.SetOctetCount(456u); + auto raw_packet = sr.Build(); + receiver_.impl_->IncomingRtcpPacket(raw_packet.data(), raw_packet.size()); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), Eq(absl::nullopt)); +} + +// Checks the stats derived from the last received RTCP SR are set correctly. +TEST_F(RtpRtcpImplTest, SenderReportStatsCheckStatsFromLastReport) { + using SenderReportStats = RtpRtcpInterface::SenderReportStats; + const NtpTime ntp(/*seconds=*/1u, /*fractions=*/1u << 31); + constexpr uint32_t kPacketCount = 123u; + constexpr uint32_t kOctetCount = 456u; + // Forge a sender report and pass it to the receiver as if an RTCP SR were + // sent by the sender. + rtcp::SenderReport sr; + sr.SetSenderSsrc(kSenderSsrc); + sr.SetNtp(ntp); + sr.SetPacketCount(kPacketCount); + sr.SetOctetCount(kOctetCount); + auto raw_packet = sr.Build(); + receiver_.impl_->IncomingRtcpPacket(raw_packet.data(), raw_packet.size()); + + EXPECT_THAT( + receiver_.impl_->GetSenderReportStats(), + Optional(AllOf(Field(&SenderReportStats::last_remote_timestamp, Eq(ntp)), + Field(&SenderReportStats::packets_sent, Eq(kPacketCount)), + Field(&SenderReportStats::bytes_sent, Eq(kOctetCount))))); +} + +// Checks that the sender report stats count equals the number of sent RTCP SRs. +TEST_F(RtpRtcpImplTest, SenderReportStatsCount) { + using SenderReportStats = RtpRtcpInterface::SenderReportStats; + // Send a frame in order to send an SR. + SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + // Send the first SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), + Optional(Field(&SenderReportStats::reports_count, Eq(1u)))); + // Send the second SR. + ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0)); + EXPECT_THAT(receiver_.impl_->GetSenderReportStats(), + Optional(Field(&SenderReportStats::reports_count, Eq(2u)))); +} + +// Checks that the sender report stats include a valid arrival time if an RTCP +// SR was sent. +TEST_F(RtpRtcpImplTest, SenderReportStatsArrivalTimestampSet) { + // Send a frame in order to send an SR. + SendFrame(&sender_, sender_video_.get(), kBaseLayerTid); + // Send an SR. 
+  ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0));
+  auto stats = receiver_.impl_->GetSenderReportStats();
+  ASSERT_THAT(stats, Not(Eq(absl::nullopt)));
+  EXPECT_TRUE(stats->last_arrival_timestamp.Valid());
+}
+
+// Checks that the packet and byte counters from an RTCP SR are not zero once
+// a frame is sent.
+TEST_F(RtpRtcpImplTest, SenderReportStatsPacketByteCounters) {
+  using SenderReportStats = RtpRtcpInterface::SenderReportStats;
+  // Send a frame in order to send an SR.
+  SendFrame(&sender_, sender_video_.get(), kBaseLayerTid);
+  ASSERT_THAT(sender_.transport_.rtp_packets_sent_, Gt(0));
+  // Advance time otherwise the RTCP SR report will not include any packets
+  // generated by `SendFrame()`.
+  clock_.AdvanceTimeMilliseconds(1);
+  // Send an SR.
+  ASSERT_THAT(sender_.impl_->SendRTCP(kRtcpReport), Eq(0));
+  EXPECT_THAT(receiver_.impl_->GetSenderReportStats(),
+              Optional(AllOf(Field(&SenderReportStats::packets_sent, Gt(0u)),
+                             Field(&SenderReportStats::bytes_sent, Gt(0u)))));
+}
+
 }  // namespace webrtc
diff --git a/modules/rtp_rtcp/source/rtp_rtcp_interface.h b/modules/rtp_rtcp/source/rtp_rtcp_interface.h
index 5bb3eb55e2..dd5744ec54 100644
--- a/modules/rtp_rtcp/source/rtp_rtcp_interface.h
+++ b/modules/rtp_rtcp/source/rtp_rtcp_interface.h
@@ -28,6 +28,7 @@
 #include "modules/rtp_rtcp/source/rtp_sequence_number_map.h"
 #include "modules/rtp_rtcp/source/video_fec_generator.h"
 #include "rtc_base/constructor_magic.h"
+#include "system_wrappers/include/ntp_time.h"
 
 namespace webrtc {
 
@@ -76,13 +77,10 @@ class RtpRtcpInterface : public RtcpFeedbackSenderInterface {
     RtcpRttStats* rtt_stats = nullptr;
     RtcpPacketTypeCounterObserver* rtcp_packet_type_counter_observer = nullptr;
     // Called on receipt of RTCP report block from remote side.
-    // TODO(bugs.webrtc.org/10678): Remove RtcpStatisticsCallback in
-    // favor of ReportBlockDataObserver.
     // TODO(bugs.webrtc.org/10679): Consider whether we want to use
     // only getters or only callbacks. If we decide on getters, the
     // ReportBlockDataObserver should also be removed in favor of
     // GetLatestReportBlockData().
-    RtcpStatisticsCallback* rtcp_statistics_callback = nullptr;
     RtcpCnameCallback* rtcp_cname_callback = nullptr;
     ReportBlockDataObserver* report_block_data_observer = nullptr;
 
@@ -152,6 +150,27 @@ class RtpRtcpInterface : public RtcpFeedbackSenderInterface {
     RTC_DISALLOW_COPY_AND_ASSIGN(Configuration);
   };
 
+  // Stats for RTCP sender reports (SR) for a specific SSRC.
+  // Refer to https://tools.ietf.org/html/rfc3550#section-6.4.1.
+  struct SenderReportStats {
+    // Arrival NTP timestamp for the last received RTCP SR.
+    NtpTime last_arrival_timestamp;
+    // Received (a.k.a., remote) NTP timestamp for the last received RTCP SR.
+    NtpTime last_remote_timestamp;
+    // Total number of RTP data packets transmitted by the sender since starting
+    // transmission up until the time this SR packet was generated. The count
+    // should be reset if the sender changes its SSRC identifier.
+    uint32_t packets_sent;
+    // Total number of payload octets (i.e., not including header or padding)
+    // transmitted in RTP data packets by the sender since starting transmission
+    // up until the time this SR packet was generated. The count should be reset
+    // if the sender changes its SSRC identifier.
+    uint64_t bytes_sent;
+    // Total number of RTCP SR blocks received.
+    // https://www.w3.org/TR/webrtc-stats/#dom-rtcremoteoutboundrtpstreamstats-reportssent.
+ uint64_t reports_count; + }; + // ************************************************************************** // Receiver functions // ************************************************************************** @@ -161,6 +180,10 @@ class RtpRtcpInterface : public RtcpFeedbackSenderInterface { virtual void SetRemoteSSRC(uint32_t ssrc) = 0; + // Called when the local ssrc changes (post initialization) for receive + // streams to match with send. Called on the packet receive thread/tq. + virtual void SetLocalSsrc(uint32_t ssrc) = 0; + // ************************************************************************** // Sender // ************************************************************************** @@ -361,17 +384,13 @@ class RtpRtcpInterface : public RtcpFeedbackSenderInterface { StreamDataCounters* rtp_counters, StreamDataCounters* rtx_counters) const = 0; - // Returns received RTCP report block. - // Returns -1 on failure else 0. - // TODO(https://crbug.com/webrtc/10678): Remove this in favor of - // GetLatestReportBlockData(). - virtual int32_t RemoteRTCPStat( - std::vector* receive_blocks) const = 0; // A snapshot of Report Blocks with additional data of interest to statistics. // Within this list, the sender-source SSRC pair is unique and per-pair the // ReportBlockData represents the latest Report Block that was received for // that pair. virtual std::vector GetLatestReportBlockData() const = 0; + // Returns stats based on the received RTCP SRs. + virtual absl::optional GetSenderReportStats() const = 0; // (REMB) Receiver Estimated Max Bitrate. // Schedules sending REMB on next and following sender/receiver reports. diff --git a/modules/rtp_rtcp/source/rtp_sender.cc b/modules/rtp_rtcp/source/rtp_sender.cc index 584fced397..80c319f4f2 100644 --- a/modules/rtp_rtcp/source/rtp_sender.cc +++ b/modules/rtp_rtcp/source/rtp_sender.cc @@ -42,7 +42,6 @@ constexpr size_t kMaxPaddingLength = 224; constexpr size_t kMinAudioPaddingLength = 50; constexpr size_t kRtpHeaderLength = 12; constexpr uint16_t kMaxInitRtpSeqNumber = 32767; // 2^15 -1. -constexpr uint32_t kTimestampTicksPerMs = 90; // Min size needed to get payload padding from packet history. constexpr int kMinPayloadPaddingBytes = 50; @@ -105,6 +104,7 @@ bool IsNonVolatile(RTPExtensionType type) { switch (type) { case kRtpExtensionTransmissionTimeOffset: case kRtpExtensionAudioLevel: + case kRtpExtensionCsrcAudioLevel: case kRtpExtensionAbsoluteSendTime: case kRtpExtensionTransportSequenceNumber: case kRtpExtensionTransportSequenceNumber02: @@ -122,6 +122,7 @@ bool IsNonVolatile(RTPExtensionType type) { case kRtpExtensionVideoTiming: case kRtpExtensionRepairedRtpStreamId: case kRtpExtensionColorSpace: + case kRtpExtensionVideoFrameTrackingId: return false; case kRtpExtensionNone: case kRtpExtensionNumberOfExtensions: @@ -170,28 +171,25 @@ RTPSender::RTPSender(const RtpRtcpInterface::Configuration& config, paced_sender_(packet_sender), sending_media_(true), // Default to sending media. max_packet_size_(IP_PACKET_SIZE - 28), // Default is IP-v4/UDP. 
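Editor's note on the SenderReportStats additions above: the sketch below is illustrative only and not part of the patch. The RemoteOutboundSnapshot struct and ReadRemoteOutbound helper are made-up names; the fields and the GetSenderReportStats() accessor are the ones declared in the rtp_rtcp_interface.h hunk above. It shows how a stats consumer might map the last received SR onto remote-outbound style counters.

// Hypothetical helper (not part of this change) consuming the new accessor.
#include "absl/types/optional.h"
#include "modules/rtp_rtcp/source/rtp_rtcp_interface.h"
#include "system_wrappers/include/ntp_time.h"

namespace webrtc {

struct RemoteOutboundSnapshot {
  NtpTime remote_ntp;      // NTP time the sender stamped into the SR.
  uint32_t packets_sent;   // Cumulative RTP packets reported by the sender.
  uint64_t bytes_sent;     // Cumulative payload bytes reported by the sender.
  uint64_t reports_count;  // Number of SRs received so far.
};

absl::optional<RemoteOutboundSnapshot> ReadRemoteOutbound(
    const RtpRtcpInterface& rtp_rtcp) {
  absl::optional<RtpRtcpInterface::SenderReportStats> sr =
      rtp_rtcp.GetSenderReportStats();
  if (!sr || !sr->last_arrival_timestamp.Valid()) {
    return absl::nullopt;  // No SR received yet.
  }
  return RemoteOutboundSnapshot{sr->last_remote_timestamp, sr->packets_sent,
                                sr->bytes_sent, sr->reports_count};
}

}  // namespace webrtc

The diff resumes below.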
- last_payload_type_(-1), rtp_header_extension_map_(config.extmap_allow_mixed), - max_media_packet_header_(kRtpHeaderSize), - max_padding_fec_packet_header_(kRtpHeaderSize), // RTP variables - sequence_number_forced_(false), + sequencer_(config.local_media_ssrc, + config.rtx_send_ssrc.value_or(config.local_media_ssrc), + /*require_marker_before_media_padding_=*/!config.audio, + config.clock), always_send_mid_and_rid_(config.always_send_mid_and_rid), ssrc_has_acked_(false), rtx_ssrc_has_acked_(false), - last_rtp_timestamp_(0), - capture_time_ms_(0), - last_timestamp_time_ms_(0), - last_packet_marker_bit_(false), csrcs_(), rtx_(kRtxOff), supports_bwe_extension_(false), retransmission_rate_limiter_(config.retransmission_rate_limiter) { + UpdateHeaderSizes(); // This random initialization is not intended to be cryptographic strong. timestamp_offset_ = random_.Rand(); // Random start, 16 bits. Can't be 0. - sequence_number_rtx_ = random_.Rand(1, kMaxInitRtpSeqNumber); - sequence_number_ = random_.Rand(1, kMaxInitRtpSeqNumber); + sequencer_.set_rtx_sequence_number(random_.Rand(1, kMaxInitRtpSeqNumber)); + sequencer_.set_media_sequence_number(random_.Rand(1, kMaxInitRtpSeqNumber)); RTC_DCHECK(paced_sender_); RTC_DCHECK(packet_history_); @@ -229,15 +227,6 @@ void RTPSender::SetExtmapAllowMixed(bool extmap_allow_mixed) { rtp_header_extension_map_.SetExtmapAllowMixed(extmap_allow_mixed); } -int32_t RTPSender::RegisterRtpHeaderExtension(RTPExtensionType type, - uint8_t id) { - MutexLock lock(&send_mutex_); - bool registered = rtp_header_extension_map_.RegisterByType(id, type); - supports_bwe_extension_ = HasBweExtension(rtp_header_extension_map_); - UpdateHeaderSizes(); - return registered ? 0 : -1; -} - bool RTPSender::RegisterRtpHeaderExtension(absl::string_view uri, int id) { MutexLock lock(&send_mutex_); bool registered = rtp_header_extension_map_.RegisterByUri(id, uri); @@ -360,7 +349,11 @@ void RTPSender::OnReceivedAckOnSsrc(int64_t extended_highest_sequence_number) { void RTPSender::OnReceivedAckOnRtxSsrc( int64_t extended_highest_sequence_number) { MutexLock lock(&send_mutex_); + bool update_required = !rtx_ssrc_has_acked_; rtx_ssrc_has_acked_ = true; + if (update_required) { + UpdateHeaderSizes(); + } } void RTPSender::OnReceivedNack( @@ -452,23 +445,11 @@ std::vector> RTPSender::GeneratePadding( std::make_unique(&rtp_header_extension_map_); padding_packet->set_packet_type(RtpPacketMediaType::kPadding); padding_packet->SetMarker(false); - padding_packet->SetTimestamp(last_rtp_timestamp_); - padding_packet->set_capture_time_ms(capture_time_ms_); if (rtx_ == kRtxOff) { - if (last_payload_type_ == -1) { - break; - } - // Without RTX we can't send padding in the middle of frames. - // For audio marker bits doesn't mark the end of a frame and frames - // are usually a single packet, so for now we don't apply this rule - // for audio. - if (!audio_configured_ && !last_packet_marker_bit_) { + padding_packet->SetSsrc(ssrc_); + if (!sequencer_.Sequence(*padding_packet)) { break; } - - padding_packet->SetSsrc(ssrc_); - padding_packet->SetPayloadType(last_payload_type_); - padding_packet->SetSequenceNumber(sequence_number_++); } else { // Without abs-send-time or transport sequence number a media packet // must be sent before padding so that the timestamps used for @@ -479,24 +460,13 @@ std::vector> RTPSender::GeneratePadding( TransportSequenceNumber::kId))) { break; } - // Only change the timestamp of padding packets sent over RTX. 
- // Padding only packets over RTP has to be sent as part of a media - // frame (and therefore the same timestamp). - int64_t now_ms = clock_->TimeInMilliseconds(); - if (last_timestamp_time_ms_ > 0) { - padding_packet->SetTimestamp(padding_packet->Timestamp() + - (now_ms - last_timestamp_time_ms_) * - kTimestampTicksPerMs); - if (padding_packet->capture_time_ms() > 0) { - padding_packet->set_capture_time_ms( - padding_packet->capture_time_ms() + - (now_ms - last_timestamp_time_ms_)); - } - } + RTC_DCHECK(rtx_ssrc_); padding_packet->SetSsrc(*rtx_ssrc_); - padding_packet->SetSequenceNumber(sequence_number_rtx_++); padding_packet->SetPayloadType(rtx_payload_type_map_.begin()->second); + if (!sequencer_.Sequence(*padding_packet)) { + break; + } } if (rtp_header_extension_map_.IsRegistered(TransportSequenceNumber::kId)) { @@ -561,13 +531,6 @@ size_t RTPSender::ExpectedPerPacketOverhead() const { return max_media_packet_header_; } -uint16_t RTPSender::AllocateSequenceNumber(uint16_t packets_to_send) { - MutexLock lock(&send_mutex_); - uint16_t first_allocated_sequence_number = sequence_number_; - sequence_number_ += packets_to_send; - return first_allocated_sequence_number; -} - std::unique_ptr RTPSender::AllocatePacket() const { MutexLock lock(&send_mutex_); // TODO(danilchap): Find better motivator and value for extra capacity. @@ -614,18 +577,18 @@ bool RTPSender::AssignSequenceNumber(RtpPacketToSend* packet) { MutexLock lock(&send_mutex_); if (!sending_media_) return false; - RTC_DCHECK(packet->Ssrc() == ssrc_); - packet->SetSequenceNumber(sequence_number_++); - - // Remember marker bit to determine if padding can be inserted with - // sequence number following |packet|. - last_packet_marker_bit_ = packet->Marker(); - // Remember payload type to use in the padding packet if rtx is disabled. - last_payload_type_ = packet->PayloadType(); - // Save timestamps to generate timestamp field and extensions for the padding. - last_rtp_timestamp_ = packet->Timestamp(); - last_timestamp_time_ms_ = clock_->TimeInMilliseconds(); - capture_time_ms_ = packet->capture_time_ms(); + return sequencer_.Sequence(*packet); +} + +bool RTPSender::AssignSequenceNumbersAndStoreLastPacketState( + rtc::ArrayView> packets) { + RTC_DCHECK(!packets.empty()); + MutexLock lock(&send_mutex_); + if (!sending_media_) + return false; + for (auto& packet : packets) { + sequencer_.Sequence(*packet); + } return true; } @@ -680,11 +643,10 @@ void RTPSender::SetSequenceNumber(uint16_t seq) { bool updated_sequence_number = false; { MutexLock lock(&send_mutex_); - sequence_number_forced_ = true; - if (sequence_number_ != seq) { + if (sequencer_.media_sequence_number() != seq) { updated_sequence_number = true; } - sequence_number_ = seq; + sequencer_.set_media_sequence_number(seq); } if (updated_sequence_number) { @@ -696,7 +658,7 @@ void RTPSender::SetSequenceNumber(uint16_t seq) { uint16_t RTPSender::SequenceNumber() const { MutexLock lock(&send_mutex_); - return sequence_number_; + return sequencer_.media_sequence_number(); } static void CopyHeaderAndExtensionsToRtxPacket(const RtpPacketToSend& packet, @@ -769,12 +731,12 @@ std::unique_ptr RTPSender::BuildRtxPacket( rtx_packet->SetPayloadType(kv->second); - // Replace sequence number. - rtx_packet->SetSequenceNumber(sequence_number_rtx_++); - // Replace SSRC. rtx_packet->SetSsrc(*rtx_ssrc_); + // Replace sequence number. 
+  sequencer_.Sequence(*rtx_packet);
+
   CopyHeaderAndExtensionsToRtxPacket(packet, rtx_packet.get());
 
   // RTX packets are sent on an SSRC different from the main media, so the
@@ -809,8 +771,8 @@ std::unique_ptr<RtpPacketToSend> RTPSender::BuildRtxPacket(
   auto payload = packet.payload();
   memcpy(rtx_payload + kRtxHeaderSize, payload.data(), payload.size());
 
-  // Add original application data.
-  rtx_packet->set_application_data(packet.application_data());
+  // Add original additional data.
+  rtx_packet->set_additional_data(packet.additional_data());
 
   // Copy capture time so e.g. TransmissionOffset is correctly set.
   rtx_packet->set_capture_time_ms(packet.capture_time_ms());
@@ -820,12 +782,9 @@ std::unique_ptr<RtpPacketToSend> RTPSender::BuildRtxPacket(
 
 void RTPSender::SetRtpState(const RtpState& rtp_state) {
   MutexLock lock(&send_mutex_);
-  sequence_number_ = rtp_state.sequence_number;
-  sequence_number_forced_ = true;
+
   timestamp_offset_ = rtp_state.start_timestamp;
-  last_rtp_timestamp_ = rtp_state.timestamp;
-  capture_time_ms_ = rtp_state.capture_time_ms;
-  last_timestamp_time_ms_ = rtp_state.last_timestamp_time_ms;
+  sequencer_.SetRtpState(rtp_state);
   ssrc_has_acked_ = rtp_state.ssrc_has_acked;
   UpdateHeaderSizes();
 }
@@ -834,18 +793,15 @@ RtpState RTPSender::GetRtpState() const {
   MutexLock lock(&send_mutex_);
 
   RtpState state;
-  state.sequence_number = sequence_number_;
   state.start_timestamp = timestamp_offset_;
-  state.timestamp = last_rtp_timestamp_;
-  state.capture_time_ms = capture_time_ms_;
-  state.last_timestamp_time_ms = last_timestamp_time_ms_;
   state.ssrc_has_acked = ssrc_has_acked_;
+  sequencer_.PupulateRtpState(state);
 
   return state;
 }
 
 void RTPSender::SetRtxRtpState(const RtpState& rtp_state) {
   MutexLock lock(&send_mutex_);
-  sequence_number_rtx_ = rtp_state.sequence_number;
+  sequencer_.set_rtx_sequence_number(rtp_state.sequence_number);
   rtx_ssrc_has_acked_ = rtp_state.ssrc_has_acked;
 }
 
@@ -853,18 +809,13 @@ RtpState RTPSender::GetRtxRtpState() const {
   MutexLock lock(&send_mutex_);
 
   RtpState state;
-  state.sequence_number = sequence_number_rtx_;
+  state.sequence_number = sequencer_.rtx_sequence_number();
   state.start_timestamp = timestamp_offset_;
   state.ssrc_has_acked = rtx_ssrc_has_acked_;
 
   return state;
 }
 
-int64_t RTPSender::LastTimestampTimeMs() const {
-  MutexLock lock(&send_mutex_);
-  return last_timestamp_time_ms_;
-}
-
 void RTPSender::UpdateHeaderSizes() {
   const size_t rtp_header_length =
       kRtpHeaderLength + sizeof(uint32_t) * csrcs_.size();
@@ -874,10 +825,12 @@ void RTPSender::UpdateHeaderSizes() {
       rtp_header_extension_map_);
 
   // RtpStreamId and Mid are treated specially in that we check if they
-  // currently are being sent. RepairedRtpStreamId is still ignored since we
-  // assume RTX will not make up large enough bitrate to treat overhead
-  // differently.
-  const bool send_mid_rid = always_send_mid_and_rid_ || !ssrc_has_acked_;
+  // currently are being sent. RepairedRtpStreamId is ignored because it is sent
+  // instead of RtpStreamId on rtx packets and requires the same size.
+  const bool send_mid_rid_on_rtx =
+      rtx_ssrc_.has_value() && !rtx_ssrc_has_acked_;
+  const bool send_mid_rid =
+      always_send_mid_and_rid_ || !ssrc_has_acked_ || send_mid_rid_on_rtx;
   std::vector<RtpExtensionSize> non_volatile_extensions;
   for (auto& extension : audio_configured_ ?
AudioExtensionSizes() : VideoExtensionSizes()) { @@ -901,5 +854,9 @@ void RTPSender::UpdateHeaderSizes() { max_media_packet_header_ = rtp_header_length + RtpHeaderExtensionSize(non_volatile_extensions, rtp_header_extension_map_); + // Reserve extra bytes if packet might be resent in an rtx packet. + if (rtx_ssrc_.has_value()) { + max_media_packet_header_ += kRtxHeaderSize; + } } } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_sender.h b/modules/rtp_rtcp/source/rtp_sender.h index 1580259b36..fbf135049c 100644 --- a/modules/rtp_rtcp/source/rtp_sender.h +++ b/modules/rtp_rtcp/source/rtp_sender.h @@ -26,10 +26,10 @@ #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/include/rtp_packet_sender.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "modules/rtp_rtcp/source/packet_sequencer.h" #include "modules/rtp_rtcp/source/rtp_packet_history.h" #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" #include "modules/rtp_rtcp/source/rtp_rtcp_interface.h" -#include "rtc_base/deprecation.h" #include "rtc_base/random.h" #include "rtc_base/rate_statistics.h" #include "rtc_base/synchronization/mutex.h" @@ -78,8 +78,6 @@ class RTPSender { RTC_LOCKS_EXCLUDED(send_mutex_); // RTP header extension - int32_t RegisterRtpHeaderExtension(RTPExtensionType type, uint8_t id) - RTC_LOCKS_EXCLUDED(send_mutex_); bool RegisterRtpHeaderExtension(absl::string_view uri, int id) RTC_LOCKS_EXCLUDED(send_mutex_); bool IsRtpHeaderExtensionRegistered(RTPExtensionType type) const @@ -139,13 +137,16 @@ class RTPSender { // Return false if sending was turned off. bool AssignSequenceNumber(RtpPacketToSend* packet) RTC_LOCKS_EXCLUDED(send_mutex_); + // Same as AssignSequenceNumber(), but applies sequence numbers atomically to + // a batch of packets. + bool AssignSequenceNumbersAndStoreLastPacketState( + rtc::ArrayView> packets) + RTC_LOCKS_EXCLUDED(send_mutex_); // Maximum header overhead per fec/padding packet. size_t FecOrPaddingPacketMaxRtpHeaderLength() const RTC_LOCKS_EXCLUDED(send_mutex_); // Expected header overhead per media packet. size_t ExpectedPerPacketOverhead() const RTC_LOCKS_EXCLUDED(send_mutex_); - uint16_t AllocateSequenceNumber(uint16_t packets_to_send) - RTC_LOCKS_EXCLUDED(send_mutex_); // Including RTP headers. 
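Editor's note on the batch sequencing API declared in the rtp_sender.h hunk above: a minimal, hypothetical caller (SequenceFrame is a made-up function, not part of WebRTC) illustrating that a whole frame's packets are sequenced under a single lock acquisition instead of per-packet AssignSequenceNumber() calls.

// Hypothetical usage sketch of AssignSequenceNumbersAndStoreLastPacketState().
#include <memory>
#include <vector>

#include "api/array_view.h"
#include "modules/rtp_rtcp/source/rtp_packet_to_send.h"
#include "modules/rtp_rtcp/source/rtp_sender.h"

namespace webrtc {

bool SequenceFrame(RTPSender& rtp_sender,
                   std::vector<std::unique_ptr<RtpPacketToSend>>& packets) {
  if (packets.empty()) {
    return false;  // The API expects a non-empty batch (it RTC_DCHECKs this).
  }
  // Returns false if sending media has been turned off; the caller should then
  // drop the packets rather than hand them to the pacer.
  return rtp_sender.AssignSequenceNumbersAndStoreLastPacketState(packets);
}

}  // namespace webrtc

The diff resumes below.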
size_t MaxRtpPacketSize() const RTC_LOCKS_EXCLUDED(send_mutex_); @@ -171,8 +172,6 @@ class RTPSender { RTC_LOCKS_EXCLUDED(send_mutex_); RtpState GetRtxRtpState() const RTC_LOCKS_EXCLUDED(send_mutex_); - int64_t LastTimestampTimeMs() const RTC_LOCKS_EXCLUDED(send_mutex_); - private: std::unique_ptr BuildRtxPacket( const RtpPacketToSend& packet); @@ -181,6 +180,9 @@ class RTPSender { void UpdateHeaderSizes() RTC_EXCLUSIVE_LOCKS_REQUIRED(send_mutex_); + void UpdateLastPacketState(const RtpPacketToSend& packet) + RTC_EXCLUSIVE_LOCKS_REQUIRED(send_mutex_); + Clock* const clock_; Random random_ RTC_GUARDED_BY(send_mutex_); @@ -201,17 +203,13 @@ class RTPSender { bool sending_media_ RTC_GUARDED_BY(send_mutex_); size_t max_packet_size_; - int8_t last_payload_type_ RTC_GUARDED_BY(send_mutex_); - RtpHeaderExtensionMap rtp_header_extension_map_ RTC_GUARDED_BY(send_mutex_); size_t max_media_packet_header_ RTC_GUARDED_BY(send_mutex_); size_t max_padding_fec_packet_header_ RTC_GUARDED_BY(send_mutex_); // RTP variables uint32_t timestamp_offset_ RTC_GUARDED_BY(send_mutex_); - bool sequence_number_forced_ RTC_GUARDED_BY(send_mutex_); - uint16_t sequence_number_ RTC_GUARDED_BY(send_mutex_); - uint16_t sequence_number_rtx_ RTC_GUARDED_BY(send_mutex_); + PacketSequencer sequencer_ RTC_GUARDED_BY(send_mutex_); // RID value to send in the RID or RepairedRID header extension. std::string rid_ RTC_GUARDED_BY(send_mutex_); // MID value to send in the MID header extension. @@ -222,10 +220,6 @@ class RTPSender { // when to stop sending the MID and RID header extensions. bool ssrc_has_acked_ RTC_GUARDED_BY(send_mutex_); bool rtx_ssrc_has_acked_ RTC_GUARDED_BY(send_mutex_); - uint32_t last_rtp_timestamp_ RTC_GUARDED_BY(send_mutex_); - int64_t capture_time_ms_ RTC_GUARDED_BY(send_mutex_); - int64_t last_timestamp_time_ms_ RTC_GUARDED_BY(send_mutex_); - bool last_packet_marker_bit_ RTC_GUARDED_BY(send_mutex_); std::vector csrcs_ RTC_GUARDED_BY(send_mutex_); int rtx_ RTC_GUARDED_BY(send_mutex_); // Mapping rtx_payload_type_map_[associated] = rtx. diff --git a/modules/rtp_rtcp/source/rtp_sender_audio.cc b/modules/rtp_rtcp/source/rtp_sender_audio.cc index 8cf60aaecd..4d72211b7c 100644 --- a/modules/rtp_rtcp/source/rtp_sender_audio.cc +++ b/modules/rtp_rtcp/source/rtp_sender_audio.cc @@ -157,7 +157,7 @@ bool RTPSenderAudio::SendAudio(AudioFrameType frame_type, return SendAudio(frame_type, payload_type, rtp_timestamp, payload_data, payload_size, // TODO(bugs.webrtc.org/10739) replace once plumbed. - /*absolute_capture_timestamp_ms=*/0); + /*absolute_capture_timestamp_ms=*/-1); } bool RTPSenderAudio::SendAudio(AudioFrameType frame_type, @@ -277,22 +277,26 @@ bool RTPSenderAudio::SendAudio(AudioFrameType frame_type, packet->SetExtension( frame_type == AudioFrameType::kAudioFrameSpeech, audio_level_dbov); - // Send absolute capture time periodically in order to optimize and save - // network traffic. Missing absolute capture times can be interpolated on the - // receiving end if sending intervals are small enough. - auto absolute_capture_time = absolute_capture_time_sender_.OnSendPacket( - AbsoluteCaptureTimeSender::GetSource(packet->Ssrc(), packet->Csrcs()), - packet->Timestamp(), - // Replace missing value with 0 (invalid frequency), this will trigger - // absolute capture time sending. - encoder_rtp_timestamp_frequency.value_or(0), - Int64MsToUQ32x32(absolute_capture_timestamp_ms + NtpOffsetMs()), - /*estimated_capture_clock_offset=*/ - include_capture_clock_offset_ ? 
absl::make_optional(0) : absl::nullopt); - if (absolute_capture_time) { - // It also checks that extension was registered during SDP negotiation. If - // not then setter won't do anything. - packet->SetExtension(*absolute_capture_time); + if (absolute_capture_timestamp_ms > 0) { + // Send absolute capture time periodically in order to optimize and save + // network traffic. Missing absolute capture times can be interpolated on + // the receiving end if sending intervals are small enough. + auto absolute_capture_time = absolute_capture_time_sender_.OnSendPacket( + AbsoluteCaptureTimeSender::GetSource(packet->Ssrc(), packet->Csrcs()), + packet->Timestamp(), + // Replace missing value with 0 (invalid frequency), this will trigger + // absolute capture time sending. + encoder_rtp_timestamp_frequency.value_or(0), + Int64MsToUQ32x32(clock_->ConvertTimestampToNtpTimeInMilliseconds( + absolute_capture_timestamp_ms)), + /*estimated_capture_clock_offset=*/ + include_capture_clock_offset_ ? absl::make_optional(0) : absl::nullopt); + if (absolute_capture_time) { + // It also checks that extension was registered during SDP negotiation. If + // not then setter won't do anything. + packet->SetExtension( + *absolute_capture_time); + } } uint8_t* payload = packet->AllocatePayload(payload_size); diff --git a/modules/rtp_rtcp/source/rtp_sender_audio.h b/modules/rtp_rtcp/source/rtp_sender_audio.h index 57b9dd7ce6..6d61facc9a 100644 --- a/modules/rtp_rtcp/source/rtp_sender_audio.h +++ b/modules/rtp_rtcp/source/rtp_sender_audio.h @@ -51,6 +51,8 @@ class RTPSenderAudio { const uint8_t* payload_data, size_t payload_size); + // `absolute_capture_timestamp_ms` and `Clock::CurrentTime` + // should be using the same epoch. bool SendAudio(AudioFrameType frame_type, int8_t payload_type, uint32_t rtp_timestamp, diff --git a/modules/rtp_rtcp/source/rtp_sender_audio_unittest.cc b/modules/rtp_rtcp/source/rtp_sender_audio_unittest.cc index d75f4e8947..0221800ea8 100644 --- a/modules/rtp_rtcp/source/rtp_sender_audio_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_sender_audio_unittest.cc @@ -19,7 +19,6 @@ #include "modules/rtp_rtcp/source/rtp_header_extensions.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" -#include "modules/rtp_rtcp/source/time_util.h" #include "test/field_trial.h" #include "test/gmock.h" #include "test/gtest.h" @@ -167,8 +166,10 @@ TEST_F(RtpSenderAudioTest, SendAudioWithAbsoluteCaptureTime) { transport_.last_sent_packet() .GetExtension(); EXPECT_TRUE(absolute_capture_time); - EXPECT_EQ(absolute_capture_time->absolute_capture_timestamp, - Int64MsToUQ32x32(kAbsoluteCaptureTimestampMs + NtpOffsetMs())); + EXPECT_EQ( + absolute_capture_time->absolute_capture_timestamp, + Int64MsToUQ32x32(fake_clock_.ConvertTimestampToNtpTimeInMilliseconds( + kAbsoluteCaptureTimestampMs))); EXPECT_FALSE( absolute_capture_time->estimated_capture_clock_offset.has_value()); } @@ -201,8 +202,10 @@ TEST_F(RtpSenderAudioTest, transport_.last_sent_packet() .GetExtension(); EXPECT_TRUE(absolute_capture_time); - EXPECT_EQ(absolute_capture_time->absolute_capture_timestamp, - Int64MsToUQ32x32(kAbsoluteCaptureTimestampMs + NtpOffsetMs())); + EXPECT_EQ( + absolute_capture_time->absolute_capture_timestamp, + Int64MsToUQ32x32(fake_clock_.ConvertTimestampToNtpTimeInMilliseconds( + kAbsoluteCaptureTimestampMs))); EXPECT_TRUE( absolute_capture_time->estimated_capture_clock_offset.has_value()); EXPECT_EQ(0, *absolute_capture_time->estimated_capture_clock_offset); diff --git 
a/modules/rtp_rtcp/source/rtp_sender_egress.cc b/modules/rtp_rtcp/source/rtp_sender_egress.cc
index aba23ddc4b..126b89c8c8 100644
--- a/modules/rtp_rtcp/source/rtp_sender_egress.cc
+++ b/modules/rtp_rtcp/source/rtp_sender_egress.cc
@@ -142,6 +142,9 @@ void RtpSenderEgress::SendPacket(RtpPacketToSend* packet,
   RTC_DCHECK(packet->packet_type().has_value());
   RTC_DCHECK(HasCorrectSsrc(*packet));
+  if (packet->packet_type() == RtpPacketMediaType::kRetransmission) {
+    RTC_DCHECK(packet->retransmitted_sequence_number().has_value());
+  }
 
   const uint32_t packet_ssrc = packet->Ssrc();
   const int64_t now_ms = clock_->TimeInMilliseconds();
@@ -250,8 +253,7 @@ void RtpSenderEgress::SendPacket(RtpPacketToSend* packet,
     AddPacketToTransportFeedback(*packet_id, *packet, pacing_info);
   }
 
-  options.application_data.assign(packet->application_data().begin(),
-                                  packet->application_data().end());
+  options.additional_data = packet->additional_data();
 
   if (packet->packet_type() != RtpPacketMediaType::kPadding &&
       packet->packet_type() != RtpPacketMediaType::kRetransmission) {
@@ -410,12 +412,34 @@ void RtpSenderEgress::AddPacketToTransportFeedback(
     }
 
     RtpPacketSendInfo packet_info;
-    packet_info.ssrc = ssrc_;
     packet_info.transport_sequence_number = packet_id;
-    packet_info.rtp_sequence_number = packet.SequenceNumber();
+    packet_info.rtp_timestamp = packet.Timestamp();
     packet_info.length = packet_size;
    packet_info.pacing_info = pacing_info;
    packet_info.packet_type = packet.packet_type();
+
+    switch (*packet_info.packet_type) {
+      case RtpPacketMediaType::kAudio:
+      case RtpPacketMediaType::kVideo:
+        packet_info.media_ssrc = ssrc_;
+        packet_info.rtp_sequence_number = packet.SequenceNumber();
+        break;
+      case RtpPacketMediaType::kRetransmission:
+        // For retransmissions, we want to remove the original media packet
+        // if the retransmit arrives, so populate that in the packet info.
+        packet_info.media_ssrc = ssrc_;
+        packet_info.rtp_sequence_number =
+            *packet.retransmitted_sequence_number();
+        break;
+      case RtpPacketMediaType::kPadding:
+      case RtpPacketMediaType::kForwardErrorCorrection:
+        // We're not interested in feedback about these packets being received
+        // or lost.
+        break;
+    }
+    // TODO(bugs.webrtc.org/12713): Remove once downstream usage is gone.
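Editor's note on the RtpPacketSendInfo changes above: a sketch of a hypothetical TransportFeedbackObserver (OutstandingPacketBook is a made-up class, not WebRTC's real transport feedback adapter) showing why retransmissions now carry the original packet's media_ssrc and rtp_sequence_number: feedback about the retransmission can then be credited to the original entry.

// Illustrative-only consumer of the richer RtpPacketSendInfo populated above.
#include <cstddef>
#include <cstdint>
#include <map>
#include <utility>

#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h"

namespace webrtc {

class OutstandingPacketBook : public TransportFeedbackObserver {
 public:
  void OnAddPacket(const RtpPacketSendInfo& info) override {
    if (!info.media_ssrc.has_value()) {
      return;  // Padding/FEC: no per-stream bookkeeping wanted.
    }
    // For a retransmission this key identifies the original media packet.
    outstanding_[{*info.media_ssrc, info.rtp_sequence_number}] = info.length;
  }
  void OnTransportFeedback(const rtcp::TransportFeedback& feedback) override {
    // Real adapters match transport sequence numbers here; omitted in this
    // sketch.
  }

 private:
  std::map<std::pair<uint32_t, uint16_t>, size_t> outstanding_;
};

}  // namespace webrtc

The diff resumes below.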
+ packet_info.ssrc = packet_info.media_ssrc.value_or(0); + transport_feedback_observer_->OnAddPacket(packet_info); } } diff --git a/modules/rtp_rtcp/source/rtp_sender_egress.h b/modules/rtp_rtcp/source/rtp_sender_egress.h index d7d71e2f1f..c767a1fe1b 100644 --- a/modules/rtp_rtcp/source/rtp_sender_egress.h +++ b/modules/rtp_rtcp/source/rtp_sender_egress.h @@ -19,6 +19,7 @@ #include "absl/types/optional.h" #include "api/call/transport.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "api/units/data_rate.h" #include "modules/remote_bitrate_estimator/test/bwe_test_logging.h" @@ -29,7 +30,6 @@ #include "modules/rtp_rtcp/source/rtp_sequence_number_map.h" #include "rtc_base/rate_statistics.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/task_utils/repeating_task.h" diff --git a/modules/rtp_rtcp/source/rtp_sender_egress_unittest.cc b/modules/rtp_rtcp/source/rtp_sender_egress_unittest.cc new file mode 100644 index 0000000000..4f3990cc3e --- /dev/null +++ b/modules/rtp_rtcp/source/rtp_sender_egress_unittest.cc @@ -0,0 +1,982 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/rtp_rtcp/source/rtp_sender_egress.h" + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/call/transport.h" +#include "api/units/data_size.h" +#include "api/units/timestamp.h" +#include "logging/rtc_event_log/mock/mock_rtc_event_log.h" +#include "modules/rtp_rtcp/include/flexfec_sender.h" +#include "modules/rtp_rtcp/include/rtp_rtcp.h" +#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "modules/rtp_rtcp/source/rtp_header_extensions.h" +#include "modules/rtp_rtcp/source/rtp_packet_history.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" +#include "modules/rtp_rtcp/source/rtp_packet_to_send.h" +#include "test/gmock.h" +#include "test/gtest.h" +#include "test/time_controller/simulated_time_controller.h" + +namespace webrtc { +namespace { + +using ::testing::_; +using ::testing::Field; +using ::testing::NiceMock; +using ::testing::Optional; +using ::testing::StrictMock; + +constexpr Timestamp kStartTime = Timestamp::Millis(123456789); +constexpr int kDefaultPayloadType = 100; +constexpr int kFlexfectPayloadType = 110; +constexpr uint16_t kStartSequenceNumber = 33; +constexpr uint32_t kSsrc = 725242; +constexpr uint32_t kRtxSsrc = 12345; +constexpr uint32_t kFlexFecSsrc = 23456; +enum : int { + kTransportSequenceNumberExtensionId = 1, + kAbsoluteSendTimeExtensionId, + kTransmissionOffsetExtensionId, + kVideoTimingExtensionId, +}; + +struct TestConfig { + explicit TestConfig(bool with_overhead) : with_overhead(with_overhead) {} + bool with_overhead = false; +}; + +class MockSendPacketObserver : public SendPacketObserver { + public: + MOCK_METHOD(void, OnSendPacket, (uint16_t, int64_t, uint32_t), (override)); +}; + +class MockTransportFeedbackObserver : public TransportFeedbackObserver { + public: + MOCK_METHOD(void, 
OnAddPacket, (const RtpPacketSendInfo&), (override)); + MOCK_METHOD(void, + OnTransportFeedback, + (const rtcp::TransportFeedback&), + (override)); +}; + +class MockStreamDataCountersCallback : public StreamDataCountersCallback { + public: + MOCK_METHOD(void, + DataCountersUpdated, + (const StreamDataCounters& counters, uint32_t ssrc), + (override)); +}; + +class MockSendSideDelayObserver : public SendSideDelayObserver { + public: + MOCK_METHOD(void, + SendSideDelayUpdated, + (int, int, uint64_t, uint32_t), + (override)); +}; + +class FieldTrialConfig : public WebRtcKeyValueConfig { + public: + FieldTrialConfig() : overhead_enabled_(false) {} + ~FieldTrialConfig() override {} + + void SetOverHeadEnabled(bool enabled) { overhead_enabled_ = enabled; } + + std::string Lookup(absl::string_view key) const override { + if (key == "WebRTC-SendSideBwe-WithOverhead") { + return overhead_enabled_ ? "Enabled" : "Disabled"; + } + return ""; + } + + private: + bool overhead_enabled_; +}; + +struct TransmittedPacket { + TransmittedPacket(rtc::ArrayView data, + const PacketOptions& packet_options, + RtpHeaderExtensionMap* extensions) + : packet(extensions), options(packet_options) { + EXPECT_TRUE(packet.Parse(data)); + } + RtpPacketReceived packet; + PacketOptions options; +}; + +class TestTransport : public Transport { + public: + explicit TestTransport(RtpHeaderExtensionMap* extensions) + : total_data_sent_(DataSize::Zero()), extensions_(extensions) {} + bool SendRtp(const uint8_t* packet, + size_t length, + const PacketOptions& options) override { + total_data_sent_ += DataSize::Bytes(length); + last_packet_.emplace(rtc::MakeArrayView(packet, length), options, + extensions_); + return true; + } + + bool SendRtcp(const uint8_t*, size_t) override { RTC_CHECK_NOTREACHED(); } + + absl::optional last_packet() { return last_packet_; } + + private: + DataSize total_data_sent_; + absl::optional last_packet_; + RtpHeaderExtensionMap* const extensions_; +}; + +} // namespace + +class RtpSenderEgressTest : public ::testing::TestWithParam { + protected: + RtpSenderEgressTest() + : time_controller_(kStartTime), + clock_(time_controller_.GetClock()), + transport_(&header_extensions_), + packet_history_(clock_, /*enable_rtx_padding_prioritization=*/true), + sequence_number_(kStartSequenceNumber) { + trials_.SetOverHeadEnabled(GetParam().with_overhead); + } + + std::unique_ptr CreateRtpSenderEgress() { + return std::make_unique(DefaultConfig(), &packet_history_); + } + + RtpRtcp::Configuration DefaultConfig() { + RtpRtcp::Configuration config; + config.clock = clock_; + config.outgoing_transport = &transport_; + config.local_media_ssrc = kSsrc; + config.rtx_send_ssrc = kRtxSsrc; + config.fec_generator = nullptr; + config.event_log = &mock_rtc_event_log_; + config.send_packet_observer = &send_packet_observer_; + config.rtp_stats_callback = &mock_rtp_stats_callback_; + config.transport_feedback_callback = &feedback_observer_; + config.populate_network2_timestamp = false; + config.field_trials = &trials_; + return config; + } + + std::unique_ptr BuildRtpPacket(bool marker_bit, + int64_t capture_time_ms) { + auto packet = std::make_unique(&header_extensions_); + packet->SetSsrc(kSsrc); + packet->ReserveExtension(); + packet->ReserveExtension(); + packet->ReserveExtension(); + + packet->SetPayloadType(kDefaultPayloadType); + packet->set_packet_type(RtpPacketMediaType::kVideo); + packet->SetMarker(marker_bit); + packet->SetTimestamp(capture_time_ms * 90); + packet->set_capture_time_ms(capture_time_ms); + 
packet->SetSequenceNumber(sequence_number_++); + return packet; + } + + std::unique_ptr BuildRtpPacket() { + return BuildRtpPacket(/*marker_bit=*/true, clock_->CurrentTime().ms()); + } + + GlobalSimulatedTimeController time_controller_; + Clock* const clock_; + NiceMock mock_rtc_event_log_; + NiceMock mock_rtp_stats_callback_; + NiceMock send_packet_observer_; + NiceMock feedback_observer_; + RtpHeaderExtensionMap header_extensions_; + TestTransport transport_; + RtpPacketHistory packet_history_; + FieldTrialConfig trials_; + uint16_t sequence_number_; +}; + +TEST_P(RtpSenderEgressTest, TransportFeedbackObserverGetsCorrectByteCount) { + constexpr size_t kRtpOverheadBytesPerPacket = 12 + 8; + constexpr size_t kPayloadSize = 1400; + const uint16_t kTransportSequenceNumber = 17; + + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + const size_t expected_bytes = GetParam().with_overhead + ? kPayloadSize + kRtpOverheadBytesPerPacket + : kPayloadSize; + + EXPECT_CALL( + feedback_observer_, + OnAddPacket(AllOf( + Field(&RtpPacketSendInfo::media_ssrc, kSsrc), + Field(&RtpPacketSendInfo::transport_sequence_number, + kTransportSequenceNumber), + Field(&RtpPacketSendInfo::rtp_sequence_number, kStartSequenceNumber), + Field(&RtpPacketSendInfo::length, expected_bytes), + Field(&RtpPacketSendInfo::pacing_info, PacedPacketInfo())))); + + std::unique_ptr packet = BuildRtpPacket(); + packet->SetExtension(kTransportSequenceNumber); + packet->AllocatePayload(kPayloadSize); + + std::unique_ptr sender = CreateRtpSenderEgress(); + sender->SendPacket(packet.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, PacketOptionsIsRetransmitSetByPacketType) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + std::unique_ptr media_packet = BuildRtpPacket(); + media_packet->set_packet_type(RtpPacketMediaType::kVideo); + sender->SendPacket(media_packet.get(), PacedPacketInfo()); + EXPECT_FALSE(transport_.last_packet()->options.is_retransmit); + + std::unique_ptr retransmission = BuildRtpPacket(); + retransmission->set_packet_type(RtpPacketMediaType::kRetransmission); + retransmission->set_retransmitted_sequence_number( + media_packet->SequenceNumber()); + sender->SendPacket(retransmission.get(), PacedPacketInfo()); + EXPECT_TRUE(transport_.last_packet()->options.is_retransmit); +} + +TEST_P(RtpSenderEgressTest, DoesnSetIncludedInAllocationByDefault) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + std::unique_ptr packet = BuildRtpPacket(); + sender->SendPacket(packet.get(), PacedPacketInfo()); + EXPECT_FALSE(transport_.last_packet()->options.included_in_feedback); + EXPECT_FALSE(transport_.last_packet()->options.included_in_allocation); +} + +TEST_P(RtpSenderEgressTest, + SetsIncludedInFeedbackWhenTransportSequenceNumberExtensionIsRegistered) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + std::unique_ptr packet = BuildRtpPacket(); + sender->SendPacket(packet.get(), PacedPacketInfo()); + EXPECT_TRUE(transport_.last_packet()->options.included_in_feedback); +} + +TEST_P( + RtpSenderEgressTest, + SetsIncludedInAllocationWhenTransportSequenceNumberExtensionIsRegistered) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + std::unique_ptr packet = BuildRtpPacket(); + sender->SendPacket(packet.get(), PacedPacketInfo()); + 
EXPECT_TRUE(transport_.last_packet()->options.included_in_allocation); +} + +TEST_P(RtpSenderEgressTest, + SetsIncludedInAllocationWhenForcedAsPartOfAllocation) { + std::unique_ptr sender = CreateRtpSenderEgress(); + sender->ForceIncludeSendPacketsInAllocation(true); + + std::unique_ptr packet = BuildRtpPacket(); + sender->SendPacket(packet.get(), PacedPacketInfo()); + EXPECT_FALSE(transport_.last_packet()->options.included_in_feedback); + EXPECT_TRUE(transport_.last_packet()->options.included_in_allocation); +} + +TEST_P(RtpSenderEgressTest, OnSendSideDelayUpdated) { + StrictMock send_side_delay_observer; + RtpRtcpInterface::Configuration config = DefaultConfig(); + config.send_side_delay_observer = &send_side_delay_observer; + auto sender = std::make_unique(config, &packet_history_); + + // Send packet with 10 ms send-side delay. The average, max and total should + // be 10 ms. + EXPECT_CALL(send_side_delay_observer, + SendSideDelayUpdated(10, 10, 10, kSsrc)); + int64_t capture_time_ms = clock_->TimeInMilliseconds(); + time_controller_.AdvanceTime(TimeDelta::Millis(10)); + sender->SendPacket(BuildRtpPacket(/*marker=*/true, capture_time_ms).get(), + PacedPacketInfo()); + + // Send another packet with 20 ms delay. The average, max and total should be + // 15, 20 and 30 ms respectively. + EXPECT_CALL(send_side_delay_observer, + SendSideDelayUpdated(15, 20, 30, kSsrc)); + capture_time_ms = clock_->TimeInMilliseconds(); + time_controller_.AdvanceTime(TimeDelta::Millis(20)); + sender->SendPacket(BuildRtpPacket(/*marker=*/true, capture_time_ms).get(), + PacedPacketInfo()); + + // Send another packet at the same time, which replaces the last packet. + // Since this packet has 0 ms delay, the average is now 5 ms and max is 10 ms. + // The total counter stays the same though. + // TODO(terelius): Is is not clear that this is the right behavior. + EXPECT_CALL(send_side_delay_observer, SendSideDelayUpdated(5, 10, 30, kSsrc)); + capture_time_ms = clock_->TimeInMilliseconds(); + sender->SendPacket(BuildRtpPacket(/*marker=*/true, capture_time_ms).get(), + PacedPacketInfo()); + + // Send a packet 1 second later. The earlier packets should have timed + // out, so both max and average should be the delay of this packet. The total + // keeps increasing. 
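As a reader aid, the averages expected in the comments above reduce to the following compile-time arithmetic; this merely restates the commented numbers and is not the send-delay bookkeeping the sender actually uses.

// Reader aid only: expected send-side delay statistics for this test.
static_assert((10 + 20) / 2 == 15, "average after the 20 ms packet");
static_assert((10 + 0) / 2 == 5, "average after the replacing 0 ms packet");
static_assert(10 + 20 + 1 == 31, "running total after the final 1 ms packet");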
+ time_controller_.AdvanceTime(TimeDelta::Seconds(1)); + EXPECT_CALL(send_side_delay_observer, SendSideDelayUpdated(1, 1, 31, kSsrc)); + capture_time_ms = clock_->TimeInMilliseconds(); + time_controller_.AdvanceTime(TimeDelta::Millis(1)); + sender->SendPacket(BuildRtpPacket(/*marker=*/true, capture_time_ms).get(), + PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, WritesPacerExitToTimingExtension) { + std::unique_ptr sender = CreateRtpSenderEgress(); + header_extensions_.RegisterByUri(kVideoTimingExtensionId, + VideoTimingExtension::kUri); + + std::unique_ptr packet = BuildRtpPacket(); + packet->SetExtension(VideoSendTiming{}); + + const int kStoredTimeInMs = 100; + time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); + sender->SendPacket(packet.get(), PacedPacketInfo()); + ASSERT_TRUE(transport_.last_packet().has_value()); + + VideoSendTiming video_timing; + EXPECT_TRUE( + transport_.last_packet()->packet.GetExtension( + &video_timing)); + EXPECT_EQ(video_timing.pacer_exit_delta_ms, kStoredTimeInMs); +} + +TEST_P(RtpSenderEgressTest, WritesNetwork2ToTimingExtension) { + RtpRtcpInterface::Configuration rtp_config = DefaultConfig(); + rtp_config.populate_network2_timestamp = true; + auto sender = std::make_unique(rtp_config, &packet_history_); + header_extensions_.RegisterByUri(kVideoTimingExtensionId, + VideoTimingExtension::kUri); + + const uint16_t kPacerExitMs = 1234u; + std::unique_ptr packet = BuildRtpPacket(); + VideoSendTiming send_timing = {}; + send_timing.pacer_exit_delta_ms = kPacerExitMs; + packet->SetExtension(send_timing); + + const int kStoredTimeInMs = 100; + time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); + sender->SendPacket(packet.get(), PacedPacketInfo()); + ASSERT_TRUE(transport_.last_packet().has_value()); + + VideoSendTiming video_timing; + EXPECT_TRUE( + transport_.last_packet()->packet.GetExtension( + &video_timing)); + EXPECT_EQ(video_timing.network2_timestamp_delta_ms, kStoredTimeInMs); + EXPECT_EQ(video_timing.pacer_exit_delta_ms, kPacerExitMs); +} + +TEST_P(RtpSenderEgressTest, OnSendPacketUpdated) { + std::unique_ptr sender = CreateRtpSenderEgress(); + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + const uint16_t kTransportSequenceNumber = 1; + EXPECT_CALL(send_packet_observer_, + OnSendPacket(kTransportSequenceNumber, + clock_->TimeInMilliseconds(), kSsrc)); + std::unique_ptr packet = BuildRtpPacket(); + packet->SetExtension(kTransportSequenceNumber); + sender->SendPacket(packet.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, OnSendPacketNotUpdatedForRetransmits) { + std::unique_ptr sender = CreateRtpSenderEgress(); + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + const uint16_t kTransportSequenceNumber = 1; + EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(0); + std::unique_ptr packet = BuildRtpPacket(); + packet->SetExtension(kTransportSequenceNumber); + packet->set_packet_type(RtpPacketMediaType::kRetransmission); + packet->set_retransmitted_sequence_number(packet->SequenceNumber()); + sender->SendPacket(packet.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, ReportsFecRate) { + constexpr int kNumPackets = 10; + constexpr TimeDelta kTimeBetweenPackets = TimeDelta::Millis(33); + + std::unique_ptr sender = CreateRtpSenderEgress(); + DataSize total_fec_data_sent = DataSize::Zero(); + // Send some packets, alternating between media and FEC. 
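The EXPECT_NEAR at the end of this test checks that the reported FEC send rate is roughly the total FEC bytes divided by the elapsed window (10 packets spaced 33 ms apart). A standalone sketch of that arithmetic follows; BitrateBps and the 135-byte packet size are illustrative assumptions, not the RateStatistics machinery the sender uses internally.

// Hypothetical helper showing the rate arithmetic the test verifies:
// bits per second = 8 * bytes_sent * 1000 / window_ms.
#include <cstdint>

constexpr int64_t BitrateBps(int64_t bytes_sent, int64_t window_ms) {
  return window_ms > 0 ? (bytes_sent * 8 * 1000) / window_ms : 0;
}

// Example: if each of the 10 FEC packets were 135 bytes and they were sent
// over a 330 ms window, the expected rate would be about 32.7 kbps.
static_assert(BitrateBps(10 * 135, 330) == 32727, "example arithmetic");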
+ for (size_t i = 0; i < kNumPackets; ++i) { + std::unique_ptr media_packet = BuildRtpPacket(); + media_packet->set_packet_type(RtpPacketMediaType::kVideo); + media_packet->SetPayloadSize(500); + sender->SendPacket(media_packet.get(), PacedPacketInfo()); + + std::unique_ptr fec_packet = BuildRtpPacket(); + fec_packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); + fec_packet->SetPayloadSize(123); + sender->SendPacket(fec_packet.get(), PacedPacketInfo()); + total_fec_data_sent += DataSize::Bytes(fec_packet->size()); + + time_controller_.AdvanceTime(kTimeBetweenPackets); + } + + EXPECT_NEAR( + (sender->GetSendRates()[RtpPacketMediaType::kForwardErrorCorrection]) + .bps(), + (total_fec_data_sent / (kTimeBetweenPackets * kNumPackets)).bps(), 500); +} + +TEST_P(RtpSenderEgressTest, BitrateCallbacks) { + class MockBitrateStaticsObserver : public BitrateStatisticsObserver { + public: + MOCK_METHOD(void, Notify, (uint32_t, uint32_t, uint32_t), (override)); + } observer; + + RtpRtcpInterface::Configuration config = DefaultConfig(); + config.send_bitrate_observer = &observer; + auto sender = std::make_unique(config, &packet_history_); + + // Simulate kNumPackets sent with kPacketInterval intervals, with the + // number of packets selected so that we fill (but don't overflow) the one + // second averaging window. + const TimeDelta kWindowSize = TimeDelta::Seconds(1); + const TimeDelta kPacketInterval = TimeDelta::Millis(20); + const int kNumPackets = (kWindowSize - kPacketInterval) / kPacketInterval; + + DataSize total_data_sent = DataSize::Zero(); + + // Send all but on of the packets, expect a call for each packet but don't + // verify bitrate yet (noisy measurements in the beginning). + for (int i = 0; i < kNumPackets; ++i) { + std::unique_ptr packet = BuildRtpPacket(); + packet->SetPayloadSize(500); + // Mark all packets as retransmissions - will cause total and retransmission + // rates to be equal. + packet->set_packet_type(RtpPacketMediaType::kRetransmission); + packet->set_retransmitted_sequence_number(packet->SequenceNumber()); + total_data_sent += DataSize::Bytes(packet->size()); + + EXPECT_CALL(observer, Notify(_, _, kSsrc)) + .WillOnce([&](uint32_t total_bitrate_bps, + uint32_t retransmission_bitrate_bps, uint32_t /*ssrc*/) { + TimeDelta window_size = i * kPacketInterval + TimeDelta::Millis(1); + // If there is just a single data point, there is no well defined + // averaging window so a bitrate of zero will be reported. + const double expected_bitrate_bps = + i == 0 ? 
0.0 : (total_data_sent / window_size).bps(); + EXPECT_NEAR(total_bitrate_bps, expected_bitrate_bps, 500); + EXPECT_NEAR(retransmission_bitrate_bps, expected_bitrate_bps, 500); + }); + + sender->SendPacket(packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(kPacketInterval); + } +} + +TEST_P(RtpSenderEgressTest, DoesNotPutNotRetransmittablePacketsInHistory) { + std::unique_ptr sender = CreateRtpSenderEgress(); + packet_history_.SetStorePacketsStatus( + RtpPacketHistory::StorageMode::kStoreAndCull, 10); + + std::unique_ptr packet = BuildRtpPacket(); + packet->set_allow_retransmission(false); + sender->SendPacket(packet.get(), PacedPacketInfo()); + EXPECT_FALSE( + packet_history_.GetPacketState(packet->SequenceNumber()).has_value()); +} + +TEST_P(RtpSenderEgressTest, PutsRetransmittablePacketsInHistory) { + std::unique_ptr sender = CreateRtpSenderEgress(); + packet_history_.SetStorePacketsStatus( + RtpPacketHistory::StorageMode::kStoreAndCull, 10); + + std::unique_ptr packet = BuildRtpPacket(); + packet->set_allow_retransmission(true); + sender->SendPacket(packet.get(), PacedPacketInfo()); + EXPECT_THAT( + packet_history_.GetPacketState(packet->SequenceNumber()), + Optional( + Field(&RtpPacketHistory::PacketState::pending_transmission, false))); +} + +TEST_P(RtpSenderEgressTest, DoesNotPutNonMediaInHistory) { + std::unique_ptr sender = CreateRtpSenderEgress(); + packet_history_.SetStorePacketsStatus( + RtpPacketHistory::StorageMode::kStoreAndCull, 10); + + // Non-media packets, even when marked as retransmittable, are not put into + // the packet history. + std::unique_ptr retransmission = BuildRtpPacket(); + retransmission->set_allow_retransmission(true); + retransmission->set_packet_type(RtpPacketMediaType::kRetransmission); + retransmission->set_retransmitted_sequence_number( + retransmission->SequenceNumber()); + sender->SendPacket(retransmission.get(), PacedPacketInfo()); + EXPECT_FALSE(packet_history_.GetPacketState(retransmission->SequenceNumber()) + .has_value()); + + std::unique_ptr fec = BuildRtpPacket(); + fec->set_allow_retransmission(true); + fec->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); + sender->SendPacket(fec.get(), PacedPacketInfo()); + EXPECT_FALSE( + packet_history_.GetPacketState(fec->SequenceNumber()).has_value()); + + std::unique_ptr padding = BuildRtpPacket(); + padding->set_allow_retransmission(true); + padding->set_packet_type(RtpPacketMediaType::kPadding); + sender->SendPacket(padding.get(), PacedPacketInfo()); + EXPECT_FALSE( + packet_history_.GetPacketState(padding->SequenceNumber()).has_value()); +} + +TEST_P(RtpSenderEgressTest, UpdatesSendStatusOfRetransmittedPackets) { + std::unique_ptr sender = CreateRtpSenderEgress(); + packet_history_.SetStorePacketsStatus( + RtpPacketHistory::StorageMode::kStoreAndCull, 10); + + // Send a packet, putting it in the history. + std::unique_ptr media_packet = BuildRtpPacket(); + media_packet->set_allow_retransmission(true); + sender->SendPacket(media_packet.get(), PacedPacketInfo()); + EXPECT_THAT( + packet_history_.GetPacketState(media_packet->SequenceNumber()), + Optional( + Field(&RtpPacketHistory::PacketState::pending_transmission, false))); + + // Simulate a retransmission, marking the packet as pending. 
+ std::unique_ptr retransmission = + packet_history_.GetPacketAndMarkAsPending(media_packet->SequenceNumber()); + retransmission->set_retransmitted_sequence_number( + media_packet->SequenceNumber()); + retransmission->set_packet_type(RtpPacketMediaType::kRetransmission); + EXPECT_THAT(packet_history_.GetPacketState(media_packet->SequenceNumber()), + Optional(Field( + &RtpPacketHistory::PacketState::pending_transmission, true))); + + // Simulate packet leaving pacer, the packet should be marked as non-pending. + sender->SendPacket(retransmission.get(), PacedPacketInfo()); + EXPECT_THAT( + packet_history_.GetPacketState(media_packet->SequenceNumber()), + Optional( + Field(&RtpPacketHistory::PacketState::pending_transmission, false))); +} + +TEST_P(RtpSenderEgressTest, StreamDataCountersCallbacks) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + const RtpPacketCounter kEmptyCounter; + RtpPacketCounter expected_transmitted_counter; + RtpPacketCounter expected_retransmission_counter; + + // Send a media packet. + std::unique_ptr media_packet = BuildRtpPacket(); + media_packet->SetPayloadSize(6); + expected_transmitted_counter.packets += 1; + expected_transmitted_counter.payload_bytes += media_packet->payload_size(); + expected_transmitted_counter.header_bytes += media_packet->headers_size(); + + EXPECT_CALL( + mock_rtp_stats_callback_, + DataCountersUpdated(AllOf(Field(&StreamDataCounters::transmitted, + expected_transmitted_counter), + Field(&StreamDataCounters::retransmitted, + expected_retransmission_counter), + Field(&StreamDataCounters::fec, kEmptyCounter)), + kSsrc)); + sender->SendPacket(media_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); + + // Send a retransmission. Retransmissions are counted into both transmitted + // and retransmitted packet statistics. + std::unique_ptr retransmission_packet = BuildRtpPacket(); + retransmission_packet->set_packet_type(RtpPacketMediaType::kRetransmission); + retransmission_packet->set_retransmitted_sequence_number( + retransmission_packet->SequenceNumber()); + media_packet->SetPayloadSize(7); + expected_transmitted_counter.packets += 1; + expected_transmitted_counter.payload_bytes += + retransmission_packet->payload_size(); + expected_transmitted_counter.header_bytes += + retransmission_packet->headers_size(); + + expected_retransmission_counter.packets += 1; + expected_retransmission_counter.payload_bytes += + retransmission_packet->payload_size(); + expected_retransmission_counter.header_bytes += + retransmission_packet->headers_size(); + + EXPECT_CALL( + mock_rtp_stats_callback_, + DataCountersUpdated(AllOf(Field(&StreamDataCounters::transmitted, + expected_transmitted_counter), + Field(&StreamDataCounters::retransmitted, + expected_retransmission_counter), + Field(&StreamDataCounters::fec, kEmptyCounter)), + kSsrc)); + sender->SendPacket(retransmission_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); + + // Send a padding packet. 
+ std::unique_ptr padding_packet = BuildRtpPacket(); + padding_packet->set_packet_type(RtpPacketMediaType::kPadding); + padding_packet->SetPadding(224); + expected_transmitted_counter.packets += 1; + expected_transmitted_counter.padding_bytes += padding_packet->padding_size(); + expected_transmitted_counter.header_bytes += padding_packet->headers_size(); + + EXPECT_CALL( + mock_rtp_stats_callback_, + DataCountersUpdated(AllOf(Field(&StreamDataCounters::transmitted, + expected_transmitted_counter), + Field(&StreamDataCounters::retransmitted, + expected_retransmission_counter), + Field(&StreamDataCounters::fec, kEmptyCounter)), + kSsrc)); + sender->SendPacket(padding_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); +} + +TEST_P(RtpSenderEgressTest, StreamDataCountersCallbacksFec) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + const RtpPacketCounter kEmptyCounter; + RtpPacketCounter expected_transmitted_counter; + RtpPacketCounter expected_fec_counter; + + // Send a media packet. + std::unique_ptr media_packet = BuildRtpPacket(); + media_packet->SetPayloadSize(6); + expected_transmitted_counter.packets += 1; + expected_transmitted_counter.payload_bytes += media_packet->payload_size(); + expected_transmitted_counter.header_bytes += media_packet->headers_size(); + + EXPECT_CALL( + mock_rtp_stats_callback_, + DataCountersUpdated( + AllOf(Field(&StreamDataCounters::transmitted, + expected_transmitted_counter), + Field(&StreamDataCounters::retransmitted, kEmptyCounter), + Field(&StreamDataCounters::fec, expected_fec_counter)), + kSsrc)); + sender->SendPacket(media_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); + + // Send and FEC packet. FEC is counted into both transmitted and FEC packet + // statistics. + std::unique_ptr fec_packet = BuildRtpPacket(); + fec_packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); + fec_packet->SetPayloadSize(6); + expected_transmitted_counter.packets += 1; + expected_transmitted_counter.payload_bytes += fec_packet->payload_size(); + expected_transmitted_counter.header_bytes += fec_packet->headers_size(); + + expected_fec_counter.packets += 1; + expected_fec_counter.payload_bytes += fec_packet->payload_size(); + expected_fec_counter.header_bytes += fec_packet->headers_size(); + + EXPECT_CALL( + mock_rtp_stats_callback_, + DataCountersUpdated( + AllOf(Field(&StreamDataCounters::transmitted, + expected_transmitted_counter), + Field(&StreamDataCounters::retransmitted, kEmptyCounter), + Field(&StreamDataCounters::fec, expected_fec_counter)), + kSsrc)); + sender->SendPacket(fec_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); +} + +TEST_P(RtpSenderEgressTest, UpdatesDataCounters) { + std::unique_ptr sender = CreateRtpSenderEgress(); + + const RtpPacketCounter kEmptyCounter; + + // Send a media packet. + std::unique_ptr media_packet = BuildRtpPacket(); + media_packet->SetPayloadSize(6); + sender->SendPacket(media_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); + + // Send an RTX retransmission packet. 
+ std::unique_ptr rtx_packet = BuildRtpPacket(); + rtx_packet->set_packet_type(RtpPacketMediaType::kRetransmission); + rtx_packet->SetSsrc(kRtxSsrc); + rtx_packet->SetPayloadSize(7); + rtx_packet->set_retransmitted_sequence_number(media_packet->SequenceNumber()); + sender->SendPacket(rtx_packet.get(), PacedPacketInfo()); + time_controller_.AdvanceTime(TimeDelta::Zero()); + + StreamDataCounters rtp_stats; + StreamDataCounters rtx_stats; + sender->GetDataCounters(&rtp_stats, &rtx_stats); + + EXPECT_EQ(rtp_stats.transmitted.packets, 1u); + EXPECT_EQ(rtp_stats.transmitted.payload_bytes, media_packet->payload_size()); + EXPECT_EQ(rtp_stats.transmitted.padding_bytes, media_packet->padding_size()); + EXPECT_EQ(rtp_stats.transmitted.header_bytes, media_packet->headers_size()); + EXPECT_EQ(rtp_stats.retransmitted, kEmptyCounter); + EXPECT_EQ(rtp_stats.fec, kEmptyCounter); + + // Retransmissions are counted both into transmitted and retransmitted + // packet counts. + EXPECT_EQ(rtx_stats.transmitted.packets, 1u); + EXPECT_EQ(rtx_stats.transmitted.payload_bytes, rtx_packet->payload_size()); + EXPECT_EQ(rtx_stats.transmitted.padding_bytes, rtx_packet->padding_size()); + EXPECT_EQ(rtx_stats.transmitted.header_bytes, rtx_packet->headers_size()); + EXPECT_EQ(rtx_stats.retransmitted, rtx_stats.transmitted); + EXPECT_EQ(rtx_stats.fec, kEmptyCounter); +} + +TEST_P(RtpSenderEgressTest, SendPacketUpdatesExtensions) { + header_extensions_.RegisterByUri(kVideoTimingExtensionId, + VideoTimingExtension::kUri); + header_extensions_.RegisterByUri(kAbsoluteSendTimeExtensionId, + AbsoluteSendTime::kUri); + header_extensions_.RegisterByUri(kTransmissionOffsetExtensionId, + TransmissionOffset::kUri); + std::unique_ptr sender = CreateRtpSenderEgress(); + + std::unique_ptr packet = BuildRtpPacket(); + packet->set_packetization_finish_time_ms(clock_->TimeInMilliseconds()); + + const int32_t kDiffMs = 10; + time_controller_.AdvanceTime(TimeDelta::Millis(kDiffMs)); + + sender->SendPacket(packet.get(), PacedPacketInfo()); + + RtpPacketReceived received_packet = transport_.last_packet()->packet; + + EXPECT_EQ(received_packet.GetExtension(), kDiffMs * 90); + + EXPECT_EQ(received_packet.GetExtension(), + AbsoluteSendTime::MsTo24Bits(clock_->TimeInMilliseconds())); + + VideoSendTiming timing; + EXPECT_TRUE(received_packet.GetExtension(&timing)); + EXPECT_EQ(timing.pacer_exit_delta_ms, kDiffMs); +} + +TEST_P(RtpSenderEgressTest, SendPacketSetsPacketOptions) { + const uint16_t kPacketId = 42; + std::unique_ptr sender = CreateRtpSenderEgress(); + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + std::unique_ptr packet = BuildRtpPacket(); + packet->SetExtension(kPacketId); + EXPECT_CALL(send_packet_observer_, OnSendPacket); + sender->SendPacket(packet.get(), PacedPacketInfo()); + + PacketOptions packet_options = transport_.last_packet()->options; + + EXPECT_EQ(packet_options.packet_id, kPacketId); + EXPECT_TRUE(packet_options.included_in_allocation); + EXPECT_TRUE(packet_options.included_in_feedback); + EXPECT_FALSE(packet_options.is_retransmit); + + // Send another packet as retransmission, verify options are populated. 
+ std::unique_ptr retransmission = BuildRtpPacket(); + retransmission->SetExtension(kPacketId + 1); + retransmission->set_packet_type(RtpPacketMediaType::kRetransmission); + retransmission->set_retransmitted_sequence_number(packet->SequenceNumber()); + sender->SendPacket(retransmission.get(), PacedPacketInfo()); + EXPECT_TRUE(transport_.last_packet()->options.is_retransmit); +} + +TEST_P(RtpSenderEgressTest, SendPacketUpdatesStats) { + const size_t kPayloadSize = 1000; + StrictMock send_side_delay_observer; + + const rtc::ArrayView kNoRtpHeaderExtensionSizes; + FlexfecSender flexfec(kFlexfectPayloadType, kFlexFecSsrc, kSsrc, /*mid=*/"", + /*header_extensions=*/{}, kNoRtpHeaderExtensionSizes, + /*rtp_state=*/nullptr, time_controller_.GetClock()); + RtpRtcpInterface::Configuration config = DefaultConfig(); + config.fec_generator = &flexfec; + config.send_side_delay_observer = &send_side_delay_observer; + auto sender = std::make_unique(config, &packet_history_); + + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + const int64_t capture_time_ms = clock_->TimeInMilliseconds(); + + std::unique_ptr video_packet = BuildRtpPacket(); + video_packet->set_packet_type(RtpPacketMediaType::kVideo); + video_packet->SetPayloadSize(kPayloadSize); + video_packet->SetExtension(1); + + std::unique_ptr rtx_packet = BuildRtpPacket(); + rtx_packet->SetSsrc(kRtxSsrc); + rtx_packet->set_packet_type(RtpPacketMediaType::kRetransmission); + rtx_packet->set_retransmitted_sequence_number(video_packet->SequenceNumber()); + rtx_packet->SetPayloadSize(kPayloadSize); + rtx_packet->SetExtension(2); + + std::unique_ptr fec_packet = BuildRtpPacket(); + fec_packet->SetSsrc(kFlexFecSsrc); + fec_packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); + fec_packet->SetPayloadSize(kPayloadSize); + fec_packet->SetExtension(3); + + const int64_t kDiffMs = 25; + time_controller_.AdvanceTime(TimeDelta::Millis(kDiffMs)); + + EXPECT_CALL(send_side_delay_observer, + SendSideDelayUpdated(kDiffMs, kDiffMs, kDiffMs, kSsrc)); + EXPECT_CALL( + send_side_delay_observer, + SendSideDelayUpdated(kDiffMs, kDiffMs, 2 * kDiffMs, kFlexFecSsrc)); + + EXPECT_CALL(send_packet_observer_, OnSendPacket(1, capture_time_ms, kSsrc)); + + sender->SendPacket(video_packet.get(), PacedPacketInfo()); + + // Send packet observer not called for padding/retransmissions. 
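Because OnSendPacket fires only for first-time sends carrying a transport-wide sequence number (not for the retransmission and padding cases exercised below), an observer can safely key its state by that id. A minimal illustrative observer follows; only the OnSendPacket parameter list mirrors the SendPacketObserver interface above, everything else is a hypothetical example rather than WebRTC code.

// Hypothetical observer: records the first-send time per transport-wide
// sequence number so later feedback can be matched against it.
#include <cstdint>
#include <map>

class FirstSendTimeRecorder {
 public:
  void OnSendPacket(uint16_t packet_id, int64_t capture_time_ms, uint32_t ssrc) {
    // Retransmissions and padding never reach the observer, so each id is
    // recorded at most once, together with its original capture time.
    first_send_ms_.emplace(packet_id, capture_time_ms);
    (void)ssrc;
  }

  bool HasRecord(uint16_t packet_id) const {
    return first_send_ms_.count(packet_id) > 0;
  }

 private:
  std::map<uint16_t, int64_t> first_send_ms_;
};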
+ EXPECT_CALL(send_packet_observer_, OnSendPacket(2, _, _)).Times(0); + sender->SendPacket(rtx_packet.get(), PacedPacketInfo()); + + EXPECT_CALL(send_packet_observer_, + OnSendPacket(3, capture_time_ms, kFlexFecSsrc)); + sender->SendPacket(fec_packet.get(), PacedPacketInfo()); + + time_controller_.AdvanceTime(TimeDelta::Zero()); + StreamDataCounters rtp_stats; + StreamDataCounters rtx_stats; + sender->GetDataCounters(&rtp_stats, &rtx_stats); + EXPECT_EQ(rtp_stats.transmitted.packets, 2u); + EXPECT_EQ(rtp_stats.fec.packets, 1u); + EXPECT_EQ(rtx_stats.retransmitted.packets, 1u); +} + +TEST_P(RtpSenderEgressTest, TransportFeedbackObserverWithRetransmission) { + const uint16_t kTransportSequenceNumber = 17; + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + std::unique_ptr retransmission = BuildRtpPacket(); + retransmission->set_packet_type(RtpPacketMediaType::kRetransmission); + retransmission->SetExtension( + kTransportSequenceNumber); + uint16_t retransmitted_seq = retransmission->SequenceNumber() - 2; + retransmission->set_retransmitted_sequence_number(retransmitted_seq); + + std::unique_ptr sender = CreateRtpSenderEgress(); + EXPECT_CALL( + feedback_observer_, + OnAddPacket(AllOf( + Field(&RtpPacketSendInfo::media_ssrc, kSsrc), + Field(&RtpPacketSendInfo::rtp_sequence_number, retransmitted_seq), + Field(&RtpPacketSendInfo::transport_sequence_number, + kTransportSequenceNumber)))); + sender->SendPacket(retransmission.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, TransportFeedbackObserverWithRtxRetransmission) { + const uint16_t kTransportSequenceNumber = 17; + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + std::unique_ptr rtx_retransmission = BuildRtpPacket(); + rtx_retransmission->SetSsrc(kRtxSsrc); + rtx_retransmission->SetExtension( + kTransportSequenceNumber); + rtx_retransmission->set_packet_type(RtpPacketMediaType::kRetransmission); + uint16_t rtx_retransmitted_seq = rtx_retransmission->SequenceNumber() - 2; + rtx_retransmission->set_retransmitted_sequence_number(rtx_retransmitted_seq); + + std::unique_ptr sender = CreateRtpSenderEgress(); + EXPECT_CALL( + feedback_observer_, + OnAddPacket(AllOf( + Field(&RtpPacketSendInfo::media_ssrc, kSsrc), + Field(&RtpPacketSendInfo::rtp_sequence_number, rtx_retransmitted_seq), + Field(&RtpPacketSendInfo::transport_sequence_number, + kTransportSequenceNumber)))); + sender->SendPacket(rtx_retransmission.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, TransportFeedbackObserverPadding) { + const uint16_t kTransportSequenceNumber = 17; + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + std::unique_ptr padding = BuildRtpPacket(); + padding->SetPadding(224); + padding->set_packet_type(RtpPacketMediaType::kPadding); + padding->SetExtension(kTransportSequenceNumber); + + std::unique_ptr sender = CreateRtpSenderEgress(); + EXPECT_CALL( + feedback_observer_, + OnAddPacket(AllOf(Field(&RtpPacketSendInfo::media_ssrc, absl::nullopt), + Field(&RtpPacketSendInfo::transport_sequence_number, + kTransportSequenceNumber)))); + sender->SendPacket(padding.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, TransportFeedbackObserverRtxPadding) { + const uint16_t kTransportSequenceNumber = 17; + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + std::unique_ptr rtx_padding = BuildRtpPacket(); + 
rtx_padding->SetPadding(224); + rtx_padding->SetSsrc(kRtxSsrc); + rtx_padding->set_packet_type(RtpPacketMediaType::kPadding); + rtx_padding->SetExtension(kTransportSequenceNumber); + + std::unique_ptr sender = CreateRtpSenderEgress(); + EXPECT_CALL( + feedback_observer_, + OnAddPacket(AllOf(Field(&RtpPacketSendInfo::media_ssrc, absl::nullopt), + Field(&RtpPacketSendInfo::transport_sequence_number, + kTransportSequenceNumber)))); + sender->SendPacket(rtx_padding.get(), PacedPacketInfo()); +} + +TEST_P(RtpSenderEgressTest, TransportFeedbackObserverFec) { + const uint16_t kTransportSequenceNumber = 17; + header_extensions_.RegisterByUri(kTransportSequenceNumberExtensionId, + TransportSequenceNumber::kUri); + + std::unique_ptr fec_packet = BuildRtpPacket(); + fec_packet->SetSsrc(kFlexFecSsrc); + fec_packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); + fec_packet->SetExtension(kTransportSequenceNumber); + + const rtc::ArrayView kNoRtpHeaderExtensionSizes; + FlexfecSender flexfec(kFlexfectPayloadType, kFlexFecSsrc, kSsrc, /*mid=*/"", + /*header_extensions=*/{}, kNoRtpHeaderExtensionSizes, + /*rtp_state=*/nullptr, time_controller_.GetClock()); + RtpRtcpInterface::Configuration config = DefaultConfig(); + config.fec_generator = &flexfec; + auto sender = std::make_unique(config, &packet_history_); + EXPECT_CALL( + feedback_observer_, + OnAddPacket(AllOf(Field(&RtpPacketSendInfo::media_ssrc, absl::nullopt), + Field(&RtpPacketSendInfo::transport_sequence_number, + kTransportSequenceNumber)))); + sender->SendPacket(fec_packet.get(), PacedPacketInfo()); +} + +INSTANTIATE_TEST_SUITE_P(WithAndWithoutOverhead, + RtpSenderEgressTest, + ::testing::Values(TestConfig(false), + TestConfig(true))); + +} // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_sender_unittest.cc b/modules/rtp_rtcp/source/rtp_sender_unittest.cc index 807d63dab7..e9be016143 100644 --- a/modules/rtp_rtcp/source/rtp_sender_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_sender_unittest.cc @@ -22,18 +22,17 @@ #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/include/rtp_packet_sender.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" -#include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h" #include "modules/rtp_rtcp/source/rtp_format_video_generic.h" #include "modules/rtp_rtcp/source/rtp_generic_frame_descriptor.h" #include "modules/rtp_rtcp/source/rtp_generic_frame_descriptor_extension.h" #include "modules/rtp_rtcp/source/rtp_header_extensions.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_packet_to_send.h" -#include "modules/rtp_rtcp/source/rtp_sender_egress.h" #include "modules/rtp_rtcp/source/rtp_sender_video.h" #include "modules/rtp_rtcp/source/rtp_utility.h" #include "modules/rtp_rtcp/source/video_fec_generator.h" #include "rtc_base/arraysize.h" +#include "rtc_base/logging.h" #include "rtc_base/rate_limiter.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/task_utils/to_queued_task.h" @@ -67,20 +66,19 @@ const uint16_t kSeqNum = 33; const uint32_t kSsrc = 725242; const uint32_t kRtxSsrc = 12345; const uint32_t kFlexFecSsrc = 45678; -const uint16_t kTransportSequenceNumber = 1; const uint64_t kStartTime = 123456789; const size_t kMaxPaddingSize = 224u; const uint8_t kPayloadData[] = {47, 11, 32, 93, 89}; const int64_t kDefaultExpectedRetransmissionTimeMs = 125; -const char kNoRid[] = ""; -const char kNoMid[] = ""; +const size_t kMaxPaddingLength = 224; // Value taken from 
rtp_sender.cc. +const uint32_t kTimestampTicksPerMs = 90; // 90kHz clock. using ::testing::_; using ::testing::AllOf; using ::testing::AtLeast; using ::testing::Contains; using ::testing::Each; -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Field; using ::testing::Gt; @@ -91,62 +89,6 @@ using ::testing::Pointee; using ::testing::Property; using ::testing::Return; using ::testing::SizeIs; -using ::testing::StrictMock; - -uint64_t ConvertMsToAbsSendTime(int64_t time_ms) { - return (((time_ms << 18) + 500) / 1000) & 0x00ffffff; -} - -class LoopbackTransportTest : public webrtc::Transport { - public: - LoopbackTransportTest() : total_bytes_sent_(0) { - receivers_extensions_.Register( - kTransmissionTimeOffsetExtensionId); - receivers_extensions_.Register( - kAbsoluteSendTimeExtensionId); - receivers_extensions_.Register( - kTransportSequenceNumberExtensionId); - receivers_extensions_.Register(kVideoRotationExtensionId); - receivers_extensions_.Register(kAudioLevelExtensionId); - receivers_extensions_.Register( - kVideoTimingExtensionId); - receivers_extensions_.Register(kMidExtensionId); - receivers_extensions_.Register( - kGenericDescriptorId); - receivers_extensions_.Register(kRidExtensionId); - receivers_extensions_.Register( - kRepairedRidExtensionId); - } - - bool SendRtp(const uint8_t* data, - size_t len, - const PacketOptions& options) override { - last_options_ = options; - total_bytes_sent_ += len; - sent_packets_.push_back(RtpPacketReceived(&receivers_extensions_)); - EXPECT_TRUE(sent_packets_.back().Parse(data, len)); - return true; - } - bool SendRtcp(const uint8_t* data, size_t len) override { return false; } - const RtpPacketReceived& last_sent_packet() { return sent_packets_.back(); } - int packets_sent() { return sent_packets_.size(); } - - size_t total_bytes_sent_; - PacketOptions last_options_; - std::vector sent_packets_; - - private: - RtpHeaderExtensionMap receivers_extensions_; -}; - -MATCHER_P(SameRtcEventTypeAs, value, "") { - return value == arg->GetType(); -} - -struct TestConfig { - explicit TestConfig(bool with_overhead) : with_overhead(with_overhead) {} - bool with_overhead = false; -}; class MockRtpPacketPacer : public RtpPacketSender { public: @@ -159,133 +101,11 @@ class MockRtpPacketPacer : public RtpPacketSender { (override)); }; -class MockSendSideDelayObserver : public SendSideDelayObserver { - public: - MOCK_METHOD(void, - SendSideDelayUpdated, - (int, int, uint64_t, uint32_t), - (override)); -}; - -class MockSendPacketObserver : public SendPacketObserver { - public: - MOCK_METHOD(void, OnSendPacket, (uint16_t, int64_t, uint32_t), (override)); -}; - -class MockTransportFeedbackObserver : public TransportFeedbackObserver { - public: - MOCK_METHOD(void, OnAddPacket, (const RtpPacketSendInfo&), (override)); - MOCK_METHOD(void, - OnTransportFeedback, - (const rtcp::TransportFeedback&), - (override)); -}; - -class StreamDataTestCallback : public StreamDataCountersCallback { - public: - StreamDataTestCallback() - : StreamDataCountersCallback(), ssrc_(0), counters_() {} - ~StreamDataTestCallback() override = default; - - void DataCountersUpdated(const StreamDataCounters& counters, - uint32_t ssrc) override { - ssrc_ = ssrc; - counters_ = counters; - } - - uint32_t ssrc_; - StreamDataCounters counters_; - - void MatchPacketCounter(const RtpPacketCounter& expected, - const RtpPacketCounter& actual) { - EXPECT_EQ(expected.payload_bytes, actual.payload_bytes); - EXPECT_EQ(expected.header_bytes, 
actual.header_bytes); - EXPECT_EQ(expected.padding_bytes, actual.padding_bytes); - EXPECT_EQ(expected.packets, actual.packets); - } - - void Matches(uint32_t ssrc, const StreamDataCounters& counters) { - EXPECT_EQ(ssrc, ssrc_); - MatchPacketCounter(counters.transmitted, counters_.transmitted); - MatchPacketCounter(counters.retransmitted, counters_.retransmitted); - EXPECT_EQ(counters.fec.packets, counters_.fec.packets); - } -}; - -class TaskQueuePacketSender : public RtpPacketSender { - public: - TaskQueuePacketSender(TimeController* time_controller, - std::unique_ptr packet_sender) - : time_controller_(time_controller), - packet_sender_(std::move(packet_sender)), - queue_(time_controller_->CreateTaskQueueFactory()->CreateTaskQueue( - "PacerQueue", - TaskQueueFactory::Priority::NORMAL)) {} - - void EnqueuePackets( - std::vector> packets) override { - queue_->PostTask(ToQueuedTask([sender = packet_sender_.get(), - packets_ = std::move(packets)]() mutable { - sender->EnqueuePackets(std::move(packets_)); - })); - // Trigger task we just enqueued to be executed by updating the simulated - // time controller. - time_controller_->AdvanceTime(TimeDelta::Zero()); - } - - TaskQueueBase* task_queue() const { return queue_.get(); } - - TimeController* const time_controller_; - std::unique_ptr packet_sender_; - std::unique_ptr queue_; -}; - -// Mimics ModuleRtpRtcp::RtpSenderContext. -// TODO(sprang): Split up unit tests and test these components individually -// wherever possible. -struct RtpSenderContext : public SequenceNumberAssigner { - RtpSenderContext(const RtpRtcpInterface::Configuration& config, - TimeController* time_controller) - : time_controller_(time_controller), - packet_history_(config.clock, config.enable_rtx_padding_prioritization), - packet_sender_(config, &packet_history_), - pacer_(time_controller, - std::make_unique( - &packet_sender_, - this)), - packet_generator_(config, - &packet_history_, - config.paced_sender ? config.paced_sender : &pacer_) { - } - void AssignSequenceNumber(RtpPacketToSend* packet) override { - packet_generator_.AssignSequenceNumber(packet); - } - // Inject packet straight into RtpSenderEgress without passing through the - // pacer, but while still running on the pacer task queue. - void InjectPacket(std::unique_ptr packet, - const PacedPacketInfo& packet_info) { - pacer_.task_queue()->PostTask( - ToQueuedTask([sender_ = &packet_sender_, packet_ = std::move(packet), - packet_info]() mutable { - sender_->SendPacket(packet_.get(), packet_info); - })); - time_controller_->AdvanceTime(TimeDelta::Zero()); - } - TimeController* time_controller_; - RtpPacketHistory packet_history_; - RtpSenderEgress packet_sender_; - TaskQueuePacketSender pacer_; - RTPSender packet_generator_; -}; - class FieldTrialConfig : public WebRtcKeyValueConfig { public: - FieldTrialConfig() - : overhead_enabled_(false), - max_padding_factor_(1200) {} + FieldTrialConfig() : max_padding_factor_(1200) {} ~FieldTrialConfig() override {} - void SetOverHeadEnabled(bool enabled) { overhead_enabled_ = enabled; } void SetMaxPaddingFactor(double factor) { max_padding_factor_ = factor; } std::string Lookup(absl::string_view key) const override { @@ -294,20 +114,17 @@ class FieldTrialConfig : public WebRtcKeyValueConfig { rtc::SimpleStringBuilder ssb(string_buf); ssb << "factor:" << max_padding_factor_; return ssb.str(); - } else if (key == "WebRTC-SendSideBwe-WithOverhead") { - return overhead_enabled_ ? 
"Enabled" : "Disabled"; } return ""; } private: - bool overhead_enabled_; double max_padding_factor_; }; } // namespace -class RtpSenderTest : public ::testing::TestWithParam { +class RtpSenderTest : public ::testing::Test { protected: RtpSenderTest() : time_controller_(Timestamp::Millis(kStartTime)), @@ -322,80 +139,65 @@ class RtpSenderTest : public ::testing::TestWithParam { nullptr, clock_), kMarkerBit(true) { - field_trials_.SetOverHeadEnabled(GetParam().with_overhead); - } - - void SetUp() override { SetUpRtpSender(true, false, false); } - - RTPSender* rtp_sender() { - RTC_DCHECK(rtp_sender_context_); - return &rtp_sender_context_->packet_generator_; } - RtpSenderEgress* rtp_egress() { - RTC_DCHECK(rtp_sender_context_); - return &rtp_sender_context_->packet_sender_; - } - - void SetUpRtpSender(bool pacer, - bool populate_network2, - bool always_send_mid_and_rid) { - SetUpRtpSender(pacer, populate_network2, always_send_mid_and_rid, - &flexfec_sender_); - } + void SetUp() override { SetUpRtpSender(true, false, nullptr); } - void SetUpRtpSender(bool pacer, - bool populate_network2, + void SetUpRtpSender(bool populate_network2, bool always_send_mid_and_rid, VideoFecGenerator* fec_generator) { + RtpRtcpInterface::Configuration config = GetDefaultConfig(); + config.fec_generator = fec_generator; + config.populate_network2_timestamp = populate_network2; + config.always_send_mid_and_rid = always_send_mid_and_rid; + CreateSender(config); + } + + RtpRtcpInterface::Configuration GetDefaultConfig() { RtpRtcpInterface::Configuration config; config.clock = clock_; - config.outgoing_transport = &transport_; config.local_media_ssrc = kSsrc; config.rtx_send_ssrc = kRtxSsrc; - config.fec_generator = fec_generator; config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; config.retransmission_rate_limiter = &retransmission_rate_limiter_; - config.paced_sender = pacer ? 
&mock_paced_sender_ : nullptr; - config.populate_network2_timestamp = populate_network2; - config.rtp_stats_callback = &rtp_stats_callback_; - config.always_send_mid_and_rid = always_send_mid_and_rid; + config.paced_sender = &mock_paced_sender_; config.field_trials = &field_trials_; + return config; + } - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - rtp_sender()->SetSequenceNumber(kSeqNum); - rtp_sender()->SetTimestampOffset(0); + void CreateSender(const RtpRtcpInterface::Configuration& config) { + packet_history_ = std::make_unique( + config.clock, config.enable_rtx_padding_prioritization); + rtp_sender_ = std::make_unique(config, packet_history_.get(), + config.paced_sender); + rtp_sender_->SetSequenceNumber(kSeqNum); + rtp_sender_->SetTimestampOffset(0); } GlobalSimulatedTimeController time_controller_; Clock* const clock_; NiceMock mock_rtc_event_log_; MockRtpPacketPacer mock_paced_sender_; - StrictMock send_packet_observer_; - StrictMock feedback_observer_; RateLimiter retransmission_rate_limiter_; FlexfecSender flexfec_sender_; - std::unique_ptr rtp_sender_context_; + std::unique_ptr packet_history_; + std::unique_ptr rtp_sender_; - LoopbackTransportTest transport_; const bool kMarkerBit; FieldTrialConfig field_trials_; - StreamDataTestCallback rtp_stats_callback_; std::unique_ptr BuildRtpPacket(int payload_type, bool marker_bit, uint32_t timestamp, int64_t capture_time_ms) { - auto packet = rtp_sender()->AllocatePacket(); + auto packet = rtp_sender_->AllocatePacket(); packet->SetPayloadType(payload_type); packet->set_packet_type(RtpPacketMediaType::kVideo); packet->SetMarker(marker_bit); packet->SetTimestamp(timestamp); packet->set_capture_time_ms(capture_time_ms); - EXPECT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); + EXPECT_TRUE(rtp_sender_->AssignSequenceNumber(packet.get())); return packet; } @@ -408,22 +210,26 @@ class RtpSenderTest : public ::testing::TestWithParam { packet->set_allow_retransmission(true); // Packet should be stored in a send bucket. - EXPECT_TRUE(rtp_sender()->SendToNetwork( - std::make_unique(*packet))); + EXPECT_TRUE( + rtp_sender_->SendToNetwork(std::make_unique(*packet))); return packet; } std::unique_ptr SendGenericPacket() { const int64_t kCaptureTimeMs = clock_->TimeInMilliseconds(); - return SendPacket(kCaptureTimeMs, sizeof(kPayloadData)); + // Use maximum allowed size to catch corner cases when packet is dropped + // because of lack of capacity for the media packet, or for an rtx packet + // containing the media packet. + return SendPacket(kCaptureTimeMs, + /*payload_length=*/rtp_sender_->MaxRtpPacketSize() - + rtp_sender_->ExpectedPerPacketOverhead()); } size_t GenerateAndSendPadding(size_t target_size_bytes) { size_t generated_bytes = 0; - for (auto& packet : - rtp_sender()->GeneratePadding(target_size_bytes, true)) { + for (auto& packet : rtp_sender_->GeneratePadding(target_size_bytes, true)) { generated_bytes += packet->payload_size() + packet->padding_size(); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + rtp_sender_->SendToNetwork(std::move(packet)); } return generated_bytes; } @@ -436,67 +242,56 @@ class RtpSenderTest : public ::testing::TestWithParam { // RTX needs to be able to read the source packets from the packet store. // Pick a number of packets to store big enough for any unit test. 
constexpr uint16_t kNumberOfPacketsToStore = 100; - rtp_sender_context_->packet_history_.SetStorePacketsStatus( + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, kNumberOfPacketsToStore); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); + rtp_sender_->SetRtxPayloadType(kRtxPayload, kPayload); + rtp_sender_->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); } // Enable sending of the MID header extension for both the primary SSRC and // the RTX SSRC. void EnableMidSending(const std::string& mid) { - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionMid, kMidExtensionId); - rtp_sender()->SetMid(mid); + rtp_sender_->RegisterRtpHeaderExtension(RtpMid::kUri, kMidExtensionId); + rtp_sender_->SetMid(mid); } // Enable sending of the RSID header extension for the primary SSRC and the // RRSID header extension for the RTX SSRC. void EnableRidSending(const std::string& rid) { - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionRtpStreamId, - kRidExtensionId); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionRepairedRtpStreamId, - kRepairedRidExtensionId); - rtp_sender()->SetRid(rid); + rtp_sender_->RegisterRtpHeaderExtension(RtpStreamId::kUri, kRidExtensionId); + rtp_sender_->RegisterRtpHeaderExtension(RepairedRtpStreamId::kUri, + kRepairedRidExtensionId); + rtp_sender_->SetRid(rid); } }; -// TODO(pbos): Move tests over from WithoutPacer to RtpSenderTest as this is our -// default code path. -class RtpSenderTestWithoutPacer : public RtpSenderTest { - public: - void SetUp() override { SetUpRtpSender(false, false, false); } -}; - -TEST_P(RtpSenderTestWithoutPacer, AllocatePacketSetCsrc) { +TEST_F(RtpSenderTest, AllocatePacketSetCsrc) { // Configure rtp_sender with csrc. std::vector csrcs; csrcs.push_back(0x23456789); - rtp_sender()->SetCsrcs(csrcs); + rtp_sender_->SetCsrcs(csrcs); - auto packet = rtp_sender()->AllocatePacket(); + auto packet = rtp_sender_->AllocatePacket(); ASSERT_TRUE(packet); - EXPECT_EQ(rtp_sender()->SSRC(), packet->Ssrc()); + EXPECT_EQ(rtp_sender_->SSRC(), packet->Ssrc()); EXPECT_EQ(csrcs, packet->Csrcs()); } -TEST_P(RtpSenderTestWithoutPacer, AllocatePacketReserveExtensions) { +TEST_F(RtpSenderTest, AllocatePacketReserveExtensions) { // Configure rtp_sender with extensions. 
- ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId)); - ASSERT_EQ(0, - rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId)); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAudioLevel, kAudioLevelExtensionId)); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionVideoRotation, kVideoRotationExtensionId)); - - auto packet = rtp_sender()->AllocatePacket(); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransmissionOffset::kUri, kTransmissionTimeOffsetExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + AbsoluteSendTime::kUri, kAbsoluteSendTimeExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension(AudioLevel::kUri, + kAudioLevelExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + VideoOrientation::kUri, kVideoRotationExtensionId)); + + auto packet = rtp_sender_->AllocatePacket(); ASSERT_TRUE(packet); // Preallocate BWE extensions RtpSender set itself. @@ -508,1016 +303,284 @@ TEST_P(RtpSenderTestWithoutPacer, AllocatePacketReserveExtensions) { EXPECT_FALSE(packet->HasExtension()); } -TEST_P(RtpSenderTestWithoutPacer, AssignSequenceNumberAdvanceSequenceNumber) { - auto packet = rtp_sender()->AllocatePacket(); - ASSERT_TRUE(packet); - const uint16_t sequence_number = rtp_sender()->SequenceNumber(); - - EXPECT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - - EXPECT_EQ(sequence_number, packet->SequenceNumber()); - EXPECT_EQ(sequence_number + 1, rtp_sender()->SequenceNumber()); -} - -TEST_P(RtpSenderTestWithoutPacer, AssignSequenceNumberFailsOnNotSending) { - auto packet = rtp_sender()->AllocatePacket(); - ASSERT_TRUE(packet); - - rtp_sender()->SetSendingMediaStatus(false); - EXPECT_FALSE(rtp_sender()->AssignSequenceNumber(packet.get())); -} - -TEST_P(RtpSenderTestWithoutPacer, AssignSequenceNumberMayAllowPaddingOnVideo) { - constexpr size_t kPaddingSize = 100; - auto packet = rtp_sender()->AllocatePacket(); - ASSERT_TRUE(packet); - - ASSERT_TRUE(rtp_sender()->GeneratePadding(kPaddingSize, true).empty()); - packet->SetMarker(false); - ASSERT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - // Packet without marker bit doesn't allow padding on video stream. - ASSERT_TRUE(rtp_sender()->GeneratePadding(kPaddingSize, true).empty()); - - packet->SetMarker(true); - ASSERT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - // Packet with marker bit allows send padding. 
- ASSERT_FALSE(rtp_sender()->GeneratePadding(kPaddingSize, true).empty()); -} - -TEST_P(RtpSenderTest, AssignSequenceNumberAllowsPaddingOnAudio) { - MockTransport transport; - RtpRtcpInterface::Configuration config; +TEST_F(RtpSenderTest, PaddingAlwaysAllowedOnAudio) { + RtpRtcpInterface::Configuration config = GetDefaultConfig(); config.audio = true; - config.clock = clock_; - config.outgoing_transport = &transport; - config.paced_sender = &mock_paced_sender_; - config.local_media_ssrc = kSsrc; - config.event_log = &mock_rtc_event_log_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - rtp_sender()->SetTimestampOffset(0); - - std::unique_ptr audio_packet = - rtp_sender()->AllocatePacket(); + CreateSender(config); + + std::unique_ptr audio_packet = rtp_sender_->AllocatePacket(); // Padding on audio stream allowed regardless of marker in the last packet. audio_packet->SetMarker(false); audio_packet->SetPayloadType(kPayload); - rtp_sender()->AssignSequenceNumber(audio_packet.get()); + rtp_sender_->AssignSequenceNumber(audio_packet.get()); const size_t kPaddingSize = 59; - EXPECT_CALL(transport, SendRtp(_, kPaddingSize + kRtpHeaderSize, _)) - .WillOnce(Return(true)); + + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property(&RtpPacketToSend::packet_type, + RtpPacketMediaType::kPadding)), + Pointee(Property(&RtpPacketToSend::padding_size, kPaddingSize)))))); EXPECT_EQ(kPaddingSize, GenerateAndSendPadding(kPaddingSize)); // Requested padding size is too small, will send a larger one. const size_t kMinPaddingSize = 50; - EXPECT_CALL(transport, SendRtp(_, kMinPaddingSize + kRtpHeaderSize, _)) - .WillOnce(Return(true)); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre( + AllOf(Pointee(Property(&RtpPacketToSend::packet_type, + RtpPacketMediaType::kPadding)), + Pointee(Property(&RtpPacketToSend::padding_size, + kMinPaddingSize)))))); EXPECT_EQ(kMinPaddingSize, GenerateAndSendPadding(kMinPaddingSize - 5)); } -TEST_P(RtpSenderTestWithoutPacer, AssignSequenceNumberSetPaddingTimestamps) { - constexpr size_t kPaddingSize = 100; - auto packet = rtp_sender()->AllocatePacket(); - ASSERT_TRUE(packet); - packet->SetMarker(true); - packet->SetTimestamp(kTimestamp); - - ASSERT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - auto padding_packets = rtp_sender()->GeneratePadding(kPaddingSize, true); - - ASSERT_EQ(1u, padding_packets.size()); - // Verify padding packet timestamp. - EXPECT_EQ(kTimestamp, padding_packets[0]->Timestamp()); -} - -TEST_P(RtpSenderTestWithoutPacer, - TransportFeedbackObserverGetsCorrectByteCount) { - constexpr size_t kRtpOverheadBytesPerPacket = 12 + 8; - - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.transport_feedback_callback = &feedback_observer_; - config.event_log = &mock_rtc_event_log_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - config.field_trials = &field_trials_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - - const size_t expected_bytes = - GetParam().with_overhead - ? 
sizeof(kPayloadData) + kRtpOverheadBytesPerPacket - : sizeof(kPayloadData); - - EXPECT_CALL(feedback_observer_, - OnAddPacket(AllOf( - Field(&RtpPacketSendInfo::ssrc, rtp_sender()->SSRC()), - Field(&RtpPacketSendInfo::transport_sequence_number, - kTransportSequenceNumber), - Field(&RtpPacketSendInfo::rtp_sequence_number, - rtp_sender()->SequenceNumber()), - Field(&RtpPacketSendInfo::length, expected_bytes), - Field(&RtpPacketSendInfo::pacing_info, PacedPacketInfo())))) - .Times(1); - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), - kRtpOverheadBytesPerPacket); - SendGenericPacket(); -} - -TEST_P(RtpSenderTestWithoutPacer, SendsPacketsWithTransportSequenceNumber) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.transport_feedback_callback = &feedback_observer_; - config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - - EXPECT_CALL(send_packet_observer_, - OnSendPacket(kTransportSequenceNumber, _, _)) - .Times(1); - - EXPECT_CALL(feedback_observer_, - OnAddPacket(AllOf( - Field(&RtpPacketSendInfo::ssrc, rtp_sender()->SSRC()), - Field(&RtpPacketSendInfo::transport_sequence_number, - kTransportSequenceNumber), - Field(&RtpPacketSendInfo::rtp_sequence_number, - rtp_sender()->SequenceNumber()), - Field(&RtpPacketSendInfo::pacing_info, PacedPacketInfo())))) - .Times(1); - - SendGenericPacket(); - - const auto& packet = transport_.last_sent_packet(); - uint16_t transport_seq_no; - ASSERT_TRUE(packet.GetExtension(&transport_seq_no)); - EXPECT_EQ(kTransportSequenceNumber, transport_seq_no); - EXPECT_EQ(transport_.last_options_.packet_id, transport_seq_no); - EXPECT_TRUE(transport_.last_options_.included_in_allocation); -} - -TEST_P(RtpSenderTestWithoutPacer, PacketOptionsNoRetransmission) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.transport_feedback_callback = &feedback_observer_; - config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - SendGenericPacket(); - - EXPECT_FALSE(transport_.last_options_.is_retransmit); -} - -TEST_P(RtpSenderTestWithoutPacer, - SetsIncludedInFeedbackWhenTransportSequenceNumberExtensionIsRegistered) { - SetUpRtpSender(false, false, false); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId); - EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(1); - SendGenericPacket(); - EXPECT_TRUE(transport_.last_options_.included_in_feedback); -} - -TEST_P( - RtpSenderTestWithoutPacer, - SetsIncludedInAllocationWhenTransportSequenceNumberExtensionIsRegistered) { - SetUpRtpSender(false, false, false); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId); - EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(1); - SendGenericPacket(); - EXPECT_TRUE(transport_.last_options_.included_in_allocation); -} - 
-TEST_P(RtpSenderTestWithoutPacer, - SetsIncludedInAllocationWhenForcedAsPartOfAllocation) { - SetUpRtpSender(false, false, false); - rtp_egress()->ForceIncludeSendPacketsInAllocation(true); - SendGenericPacket(); - EXPECT_FALSE(transport_.last_options_.included_in_feedback); - EXPECT_TRUE(transport_.last_options_.included_in_allocation); -} - -TEST_P(RtpSenderTestWithoutPacer, DoesnSetIncludedInAllocationByDefault) { - SetUpRtpSender(false, false, false); - SendGenericPacket(); - EXPECT_FALSE(transport_.last_options_.included_in_feedback); - EXPECT_FALSE(transport_.last_options_.included_in_allocation); -} - -TEST_P(RtpSenderTestWithoutPacer, OnSendSideDelayUpdated) { - StrictMock send_side_delay_observer_; - - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.send_side_delay_observer = &send_side_delay_observer_; - config.event_log = &mock_rtc_event_log_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); - - const uint8_t kPayloadType = 127; - const absl::optional kCodecType = - VideoCodecType::kVideoCodecGeneric; - - const uint32_t kCaptureTimeMsToRtpTimestamp = 90; // 90 kHz clock - RTPVideoHeader video_header; - - // Send packet with 10 ms send-side delay. The average, max and total should - // be 10 ms. - EXPECT_CALL(send_side_delay_observer_, - SendSideDelayUpdated(10, 10, 10, kSsrc)) - .Times(1); - int64_t capture_time_ms = clock_->TimeInMilliseconds(); - time_controller_.AdvanceTime(TimeDelta::Millis(10)); - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kPayloadType, kCodecType, capture_time_ms * kCaptureTimeMsToRtpTimestamp, - capture_time_ms, kPayloadData, video_header, - kDefaultExpectedRetransmissionTimeMs)); - - // Send another packet with 20 ms delay. The average, max and total should be - // 15, 20 and 30 ms respectively. - EXPECT_CALL(send_side_delay_observer_, - SendSideDelayUpdated(15, 20, 30, kSsrc)) - .Times(1); - time_controller_.AdvanceTime(TimeDelta::Millis(10)); - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kPayloadType, kCodecType, capture_time_ms * kCaptureTimeMsToRtpTimestamp, - capture_time_ms, kPayloadData, video_header, - kDefaultExpectedRetransmissionTimeMs)); - - // Send another packet at the same time, which replaces the last packet. - // Since this packet has 0 ms delay, the average is now 5 ms and max is 10 ms. - // The total counter stays the same though. - // TODO(terelius): Is is not clear that this is the right behavior. - EXPECT_CALL(send_side_delay_observer_, SendSideDelayUpdated(5, 10, 30, kSsrc)) - .Times(1); - capture_time_ms = clock_->TimeInMilliseconds(); - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kPayloadType, kCodecType, capture_time_ms * kCaptureTimeMsToRtpTimestamp, - capture_time_ms, kPayloadData, video_header, - kDefaultExpectedRetransmissionTimeMs)); - - // Send a packet 1 second later. The earlier packets should have timed - // out, so both max and average should be the delay of this packet. The total - // keeps increasing. 
- time_controller_.AdvanceTime(TimeDelta::Millis(1000)); - capture_time_ms = clock_->TimeInMilliseconds(); - time_controller_.AdvanceTime(TimeDelta::Millis(1)); - EXPECT_CALL(send_side_delay_observer_, SendSideDelayUpdated(1, 1, 31, kSsrc)) - .Times(1); - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kPayloadType, kCodecType, capture_time_ms * kCaptureTimeMsToRtpTimestamp, - capture_time_ms, kPayloadData, video_header, - kDefaultExpectedRetransmissionTimeMs)); -} - -TEST_P(RtpSenderTestWithoutPacer, OnSendPacketUpdated) { - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - EXPECT_CALL(send_packet_observer_, - OnSendPacket(kTransportSequenceNumber, _, _)) - .Times(1); - - SendGenericPacket(); -} - -TEST_P(RtpSenderTest, SendsPacketsWithTransportSequenceNumber) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.paced_sender = &mock_paced_sender_; - config.local_media_ssrc = kSsrc; - config.transport_feedback_callback = &feedback_observer_; - config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - rtp_sender()->SetSequenceNumber(kSeqNum); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - - EXPECT_CALL(send_packet_observer_, - OnSendPacket(kTransportSequenceNumber, _, _)) - .Times(1); - EXPECT_CALL(feedback_observer_, - OnAddPacket(AllOf( - Field(&RtpPacketSendInfo::ssrc, rtp_sender()->SSRC()), - Field(&RtpPacketSendInfo::transport_sequence_number, - kTransportSequenceNumber), - Field(&RtpPacketSendInfo::rtp_sequence_number, - rtp_sender()->SequenceNumber()), - Field(&RtpPacketSendInfo::pacing_info, PacedPacketInfo())))) - .Times(1); - - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - auto packet = SendGenericPacket(); - packet->set_packet_type(RtpPacketMediaType::kVideo); - // Transport sequence number is set by PacketRouter, before SendPacket(). 
- packet->SetExtension(kTransportSequenceNumber); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - uint16_t transport_seq_no; - EXPECT_TRUE( - transport_.last_sent_packet().GetExtension( - &transport_seq_no)); - EXPECT_EQ(kTransportSequenceNumber, transport_seq_no); - EXPECT_EQ(transport_.last_options_.packet_id, transport_seq_no); -} - -TEST_P(RtpSenderTest, WritesPacerExitToTimingExtension) { - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionVideoTiming, kVideoTimingExtensionId)); - int64_t capture_time_ms = clock_->TimeInMilliseconds(); - auto packet = rtp_sender()->AllocatePacket(); - packet->SetPayloadType(kPayload); - packet->SetMarker(true); - packet->SetTimestamp(kTimestamp); - packet->set_capture_time_ms(capture_time_ms); - const VideoSendTiming kVideoTiming = {0u, 0u, 0u, 0u, 0u, 0u, true}; - packet->SetExtension(kVideoTiming); - EXPECT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - size_t packet_size = packet->size(); - - const int kStoredTimeInMs = 100; - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->set_allow_retransmission(true); - EXPECT_CALL(mock_paced_sender_, EnqueuePackets(Contains(Pointee(Property( - &RtpPacketToSend::Ssrc, kSsrc))))); - EXPECT_TRUE( - rtp_sender()->SendToNetwork(std::make_unique(*packet))); - time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(1, transport_.packets_sent()); - EXPECT_EQ(packet_size, transport_.last_sent_packet().size()); - - VideoSendTiming video_timing; - EXPECT_TRUE(transport_.last_sent_packet().GetExtension( - &video_timing)); - EXPECT_EQ(kStoredTimeInMs, video_timing.pacer_exit_delta_ms); -} - -TEST_P(RtpSenderTest, WritesNetwork2ToTimingExtensionWithPacer) { - SetUpRtpSender(/*pacer=*/true, /*populate_network2=*/true, false); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionVideoTiming, kVideoTimingExtensionId)); - int64_t capture_time_ms = clock_->TimeInMilliseconds(); - auto packet = rtp_sender()->AllocatePacket(); - packet->SetPayloadType(kPayload); - packet->SetMarker(true); - packet->SetTimestamp(kTimestamp); - packet->set_capture_time_ms(capture_time_ms); - const uint16_t kPacerExitMs = 1234u; - const VideoSendTiming kVideoTiming = {0u, 0u, 0u, kPacerExitMs, 0u, 0u, true}; - packet->SetExtension(kVideoTiming); - EXPECT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - size_t packet_size = packet->size(); - - const int kStoredTimeInMs = 100; - - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->set_allow_retransmission(true); - EXPECT_CALL(mock_paced_sender_, EnqueuePackets(Contains(Pointee(Property( - &RtpPacketToSend::Ssrc, kSsrc))))); - EXPECT_TRUE( - rtp_sender()->SendToNetwork(std::make_unique(*packet))); - time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - EXPECT_EQ(1, transport_.packets_sent()); - EXPECT_EQ(packet_size, transport_.last_sent_packet().size()); - - VideoSendTiming video_timing; - EXPECT_TRUE(transport_.last_sent_packet().GetExtension( - &video_timing)); - EXPECT_EQ(kStoredTimeInMs, video_timing.network2_timestamp_delta_ms); - EXPECT_EQ(kPacerExitMs, 
video_timing.pacer_exit_delta_ms); -} - -TEST_P(RtpSenderTest, WritesNetwork2ToTimingExtensionWithoutPacer) { - SetUpRtpSender(/*pacer=*/false, /*populate_network2=*/true, false); - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionVideoTiming, kVideoTimingExtensionId)); - auto packet = rtp_sender()->AllocatePacket(); - packet->SetMarker(true); - packet->set_capture_time_ms(clock_->TimeInMilliseconds()); - const VideoSendTiming kVideoTiming = {0u, 0u, 0u, 0u, 0u, 0u, true}; - packet->SetExtension(kVideoTiming); - packet->set_allow_retransmission(true); - EXPECT_TRUE(rtp_sender()->AssignSequenceNumber(packet.get())); - packet->set_packet_type(RtpPacketMediaType::kVideo); - - const int kPropagateTimeMs = 10; - time_controller_.AdvanceTime(TimeDelta::Millis(kPropagateTimeMs)); - - EXPECT_TRUE(rtp_sender()->SendToNetwork(std::move(packet))); - - EXPECT_EQ(1, transport_.packets_sent()); - absl::optional video_timing = - transport_.last_sent_packet().GetExtension(); - ASSERT_TRUE(video_timing); - EXPECT_EQ(kPropagateTimeMs, video_timing->network2_timestamp_delta_ms); -} - -TEST_P(RtpSenderTest, TrafficSmoothingWithExtensions) { - EXPECT_CALL(mock_rtc_event_log_, - LogProxy(SameRtcEventTypeAs(RtcEvent::Type::RtpPacketOutgoing))); +TEST_F(RtpSenderTest, SendToNetworkForwardsPacketsToPacer) { + auto packet = BuildRtpPacket(kPayload, kMarkerBit, kTimestamp, 0); + int64_t now_ms = clock_->TimeInMilliseconds(); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId)); - EXPECT_EQ(0, - rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId)); - int64_t capture_time_ms = clock_->TimeInMilliseconds(); - auto packet = - BuildRtpPacket(kPayload, kMarkerBit, kTimestamp, capture_time_ms); - size_t packet_size = packet->size(); - - const int kStoredTimeInMs = 100; EXPECT_CALL( mock_paced_sender_, - EnqueuePackets(Contains(AllOf( + EnqueuePackets(ElementsAre(AllOf( Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->set_allow_retransmission(true); + Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)), + Pointee(Property(&RtpPacketToSend::capture_time_ms, now_ms)))))); EXPECT_TRUE( - rtp_sender()->SendToNetwork(std::make_unique(*packet))); - EXPECT_EQ(0, transport_.packets_sent()); - time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - // Process send bucket. Packet should now be sent. - EXPECT_EQ(1, transport_.packets_sent()); - EXPECT_EQ(packet_size, transport_.last_sent_packet().size()); - - webrtc::RTPHeader rtp_header; - transport_.last_sent_packet().GetHeader(&rtp_header); - - // Verify transmission time offset. 
- EXPECT_EQ(kStoredTimeInMs * 90, rtp_header.extension.transmissionTimeOffset); - uint64_t expected_send_time = - ConvertMsToAbsSendTime(clock_->TimeInMilliseconds()); - EXPECT_EQ(expected_send_time, rtp_header.extension.absoluteSendTime); + rtp_sender_->SendToNetwork(std::make_unique(*packet))); } -TEST_P(RtpSenderTest, TrafficSmoothingRetransmits) { - EXPECT_CALL(mock_rtc_event_log_, - LogProxy(SameRtcEventTypeAs(RtcEvent::Type::RtpPacketOutgoing))); - - rtp_sender_context_->packet_history_.SetStorePacketsStatus( +TEST_F(RtpSenderTest, ReSendPacketForwardsPacketsToPacer) { + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, 10); - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId)); - EXPECT_EQ(0, - rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId)); - int64_t capture_time_ms = clock_->TimeInMilliseconds(); - auto packet = - BuildRtpPacket(kPayload, kMarkerBit, kTimestamp, capture_time_ms); - size_t packet_size = packet->size(); - - // Packet should be stored in a send bucket. - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - packet->set_packet_type(RtpPacketMediaType::kVideo); + int64_t now_ms = clock_->TimeInMilliseconds(); + auto packet = BuildRtpPacket(kPayload, kMarkerBit, kTimestamp, now_ms); + uint16_t seq_no = packet->SequenceNumber(); packet->set_allow_retransmission(true); - EXPECT_TRUE( - rtp_sender()->SendToNetwork(std::make_unique(*packet))); - // Immediately process send bucket and send packet. - rtp_sender_context_->InjectPacket(std::make_unique(*packet), - PacedPacketInfo()); - - EXPECT_EQ(1, transport_.packets_sent()); + packet_history_->PutRtpPacket(std::move(packet), now_ms); - // Retransmit packet. - const int kStoredTimeInMs = 100; - time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); - - EXPECT_CALL(mock_rtc_event_log_, - LogProxy(SameRtcEventTypeAs(RtcEvent::Type::RtpPacketOutgoing))); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - packet->set_retransmitted_sequence_number(kSeqNum); - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - EXPECT_EQ(static_cast(packet_size), rtp_sender()->ReSendPacket(kSeqNum)); - EXPECT_EQ(1, transport_.packets_sent()); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - // Process send bucket. Packet should now be sent. - EXPECT_EQ(2, transport_.packets_sent()); - EXPECT_EQ(packet_size, transport_.last_sent_packet().size()); - - webrtc::RTPHeader rtp_header; - transport_.last_sent_packet().GetHeader(&rtp_header); - - // Verify transmission time offset. 
- EXPECT_EQ(kStoredTimeInMs * 90, rtp_header.extension.transmissionTimeOffset); - uint64_t expected_send_time = - ConvertMsToAbsSendTime(clock_->TimeInMilliseconds()); - EXPECT_EQ(expected_send_time, rtp_header.extension.absoluteSendTime); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), + Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)), + Pointee(Property(&RtpPacketToSend::capture_time_ms, now_ms)), + Pointee(Property(&RtpPacketToSend::packet_type, + RtpPacketMediaType::kRetransmission)))))); + EXPECT_TRUE(rtp_sender_->ReSendPacket(seq_no)); } // This test sends 1 regular video packet, then 4 padding packets, and then // 1 more regular packet. -TEST_P(RtpSenderTest, SendPadding) { - // Make all (non-padding) packets go to send queue. - EXPECT_CALL(mock_rtc_event_log_, - LogProxy(SameRtcEventTypeAs(RtcEvent::Type::RtpPacketOutgoing))) - .Times(1 + 4 + 1); - - uint16_t seq_num = kSeqNum; - uint32_t timestamp = kTimestamp; - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - size_t rtp_header_len = kRtpHeaderSize; - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId)); - rtp_header_len += 4; // 4 bytes extension. - EXPECT_EQ(0, - rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId)); - rtp_header_len += 4; // 4 bytes extension. - rtp_header_len += 4; // 4 extra bytes common to all extension headers. - - webrtc::RTPHeader rtp_header; - - int64_t capture_time_ms = clock_->TimeInMilliseconds(); - auto packet = - BuildRtpPacket(kPayload, kMarkerBit, timestamp, capture_time_ms); - const uint32_t media_packet_timestamp = timestamp; - size_t packet_size = packet->size(); - int total_packets_sent = 0; - const int kStoredTimeInMs = 100; - - // Packet should be stored in a send bucket. - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->set_allow_retransmission(true); - EXPECT_TRUE( - rtp_sender()->SendToNetwork(std::make_unique(*packet))); - EXPECT_EQ(total_packets_sent, transport_.packets_sent()); - time_controller_.AdvanceTime(TimeDelta::Millis(kStoredTimeInMs)); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - ++seq_num; - - // Packet should now be sent. This test doesn't verify the regular video - // packet, since it is tested in another test. - EXPECT_EQ(++total_packets_sent, transport_.packets_sent()); - timestamp += 90 * kStoredTimeInMs; - - // Send padding 4 times, waiting 50 ms between each. - for (int i = 0; i < 4; ++i) { - const int kPaddingPeriodMs = 50; - const size_t kPaddingBytes = 100; - const size_t kMaxPaddingLength = 224; // Value taken from rtp_sender.cc. - // Padding will be forced to full packets. - EXPECT_EQ(kMaxPaddingLength, GenerateAndSendPadding(kPaddingBytes)); - - // Process send bucket. Padding should now be sent. - EXPECT_EQ(++total_packets_sent, transport_.packets_sent()); - EXPECT_EQ(kMaxPaddingLength + rtp_header_len, - transport_.last_sent_packet().size()); - - transport_.last_sent_packet().GetHeader(&rtp_header); - EXPECT_EQ(kMaxPaddingLength, rtp_header.paddingLength); - - // Verify sequence number and timestamp. 
The timestamp should be the same - as the last media packet. - EXPECT_EQ(seq_num++, rtp_header.sequenceNumber); - EXPECT_EQ(media_packet_timestamp, rtp_header.timestamp); - // Verify transmission time offset. - int offset = timestamp - media_packet_timestamp; - EXPECT_EQ(offset, rtp_header.extension.transmissionTimeOffset); - uint64_t expected_send_time = - ConvertMsToAbsSendTime(clock_->TimeInMilliseconds()); - EXPECT_EQ(expected_send_time, rtp_header.extension.absoluteSendTime); - time_controller_.AdvanceTime(TimeDelta::Millis(kPaddingPeriodMs)); - timestamp += 90 * kPaddingPeriodMs; +TEST_F(RtpSenderTest, SendPadding) { + constexpr int kNumPaddingPackets = 4; + EXPECT_CALL(mock_paced_sender_, EnqueuePackets); + std::unique_ptr media_packet = + SendPacket(/*capture_time_ms=*/clock_->TimeInMilliseconds(), + /*payload_size=*/100); + + // Wait 50 ms before generating each padding packet. + for (int i = 0; i < kNumPaddingPackets; ++i) { + time_controller_.AdvanceTime(TimeDelta::Millis(50)); + const size_t kPaddingTargetBytes = 100; // Request 100 bytes of padding. + + // Padding should be sent on the media SSRC, with a continuous sequence + // number range. Size will be forced to full packet size and the timestamp + // shall be that of the last media packet. + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), + Pointee(Property(&RtpPacketToSend::SequenceNumber, + media_packet->SequenceNumber() + i + 1)), + Pointee(Property(&RtpPacketToSend::padding_size, + kMaxPaddingLength)), + Pointee(Property(&RtpPacketToSend::Timestamp, + media_packet->Timestamp())))))); + std::vector> padding_packets = + rtp_sender_->GeneratePadding(kPaddingTargetBytes, + /*media_has_been_sent=*/true); + ASSERT_THAT(padding_packets, SizeIs(1)); + rtp_sender_->SendToNetwork(std::move(padding_packets[0])); } // Send a regular video packet again. - capture_time_ms = clock_->TimeInMilliseconds(); - packet = BuildRtpPacket(kPayload, kMarkerBit, timestamp, capture_time_ms); - packet_size = packet->size(); - - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->set_allow_retransmission(true); - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, seq_num)))))); - EXPECT_TRUE( - rtp_sender()->SendToNetwork(std::make_unique(*packet))); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - // Process send bucket. - EXPECT_EQ(++total_packets_sent, transport_.packets_sent()); - EXPECT_EQ(packet_size, transport_.last_sent_packet().size()); - transport_.last_sent_packet().GetHeader(&rtp_header); - - // Verify sequence number and timestamp. - EXPECT_EQ(seq_num, rtp_header.sequenceNumber); - EXPECT_EQ(timestamp, rtp_header.timestamp); - // Verify transmission time offset. This packet is sent without delay.
- EXPECT_EQ(0, rtp_header.extension.transmissionTimeOffset); - uint64_t expected_send_time = - ConvertMsToAbsSendTime(clock_->TimeInMilliseconds()); - EXPECT_EQ(expected_send_time, rtp_header.extension.absoluteSendTime); -} - -TEST_P(RtpSenderTest, OnSendPacketUpdated) { - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - - EXPECT_CALL(send_packet_observer_, - OnSendPacket(kTransportSequenceNumber, _, _)) - .Times(1); - - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - auto packet = SendGenericPacket(); - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->SetExtension(kTransportSequenceNumber); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property( + &RtpPacketToSend::SequenceNumber, + media_packet->SequenceNumber() + kNumPaddingPackets + 1)), + Pointee(Property(&RtpPacketToSend::Timestamp, + Gt(media_packet->Timestamp()))))))); - EXPECT_EQ(1, transport_.packets_sent()); + std::unique_ptr next_media_packet = + SendPacket(/*capture_time_ms=*/clock_->TimeInMilliseconds(), + /*payload_size=*/100); } -TEST_P(RtpSenderTest, OnSendPacketNotUpdatedForRetransmits) { - EXPECT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - - EXPECT_CALL(send_packet_observer_, OnSendPacket(_, _, _)).Times(0); +TEST_F(RtpSenderTest, NoPaddingAsFirstPacketWithoutBweExtensions) { + EXPECT_THAT(rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/false), + IsEmpty()); - EXPECT_CALL( - mock_paced_sender_, - EnqueuePackets(Contains(AllOf( - Pointee(Property(&RtpPacketToSend::Ssrc, kSsrc)), - Pointee(Property(&RtpPacketToSend::SequenceNumber, kSeqNum)))))); - auto packet = SendGenericPacket(); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - packet->SetExtension(kTransportSequenceNumber); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - EXPECT_EQ(1, transport_.packets_sent()); - EXPECT_TRUE(transport_.last_options_.is_retransmit); -} - -TEST_P(RtpSenderTestWithoutPacer, SendGenericVideo) { - const uint8_t kPayloadType = 127; - const VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); - uint8_t payload[] = {47, 11, 32, 93, 89}; - - // Send keyframe - RTPVideoHeader video_header; - video_header.frame_type = VideoFrameType::kVideoFrameKey; - ASSERT_TRUE(rtp_sender_video.SendVideo(kPayloadType, kCodecType, 1234, 4321, - payload, video_header, - kDefaultExpectedRetransmissionTimeMs)); - - auto sent_payload = transport_.last_sent_packet().payload(); - uint8_t generic_header = sent_payload[0]; - EXPECT_TRUE(generic_header & RtpFormatVideoGeneric::kKeyFrameBit); - EXPECT_TRUE(generic_header & RtpFormatVideoGeneric::kFirstPacketBit); - 
EXPECT_THAT(sent_payload.subview(1), ElementsAreArray(payload)); - - // Send delta frame - payload[0] = 13; - payload[1] = 42; - payload[4] = 13; - - video_header.frame_type = VideoFrameType::kVideoFrameDelta; - ASSERT_TRUE(rtp_sender_video.SendVideo(kPayloadType, kCodecType, 1234, 4321, - payload, video_header, - kDefaultExpectedRetransmissionTimeMs)); - - sent_payload = transport_.last_sent_packet().payload(); - generic_header = sent_payload[0]; - EXPECT_FALSE(generic_header & RtpFormatVideoGeneric::kKeyFrameBit); - EXPECT_TRUE(generic_header & RtpFormatVideoGeneric::kFirstPacketBit); - EXPECT_THAT(sent_payload.subview(1), ElementsAreArray(payload)); + // Don't send padding before media even with RTX. + EnableRtx(); + EXPECT_THAT(rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/false), + IsEmpty()); } -TEST_P(RtpSenderTestWithoutPacer, SendRawVideo) { - const uint8_t kPayloadType = 111; - const uint8_t payload[] = {11, 22, 33, 44, 55}; - - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); +TEST_F(RtpSenderTest, AllowPaddingAsFirstPacketOnRtxWithTransportCc) { + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); - // Send a frame. - RTPVideoHeader video_header; - video_header.frame_type = VideoFrameType::kVideoFrameKey; - ASSERT_TRUE(rtp_sender_video.SendVideo(kPayloadType, absl::nullopt, 1234, - 4321, payload, video_header, - kDefaultExpectedRetransmissionTimeMs)); + // Padding can't be sent as first packet on media SSRC since we don't know + // what payload type to assign. + EXPECT_THAT(rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/false), + IsEmpty()); - auto sent_payload = transport_.last_sent_packet().payload(); - EXPECT_THAT(sent_payload, ElementsAreArray(payload)); + // With transportcc padding can be sent as first packet on the RTX SSRC. + EnableRtx(); + EXPECT_THAT(rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/false), + Not(IsEmpty())); } -TEST_P(RtpSenderTest, SendFlexfecPackets) { - constexpr uint32_t kTimestamp = 1234; - constexpr int kMediaPayloadType = 127; - constexpr VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - constexpr int kFlexfecPayloadType = 118; - const std::vector kNoRtpExtensions; - const std::vector kNoRtpExtensionSizes; - FlexfecSender flexfec_sender(kFlexfecPayloadType, kFlexFecSsrc, kSsrc, kNoMid, - kNoRtpExtensions, kNoRtpExtensionSizes, - nullptr /* rtp_state */, clock_); - - // Reset |rtp_sender_| to use FlexFEC. 
- RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.paced_sender = &mock_paced_sender_; - config.local_media_ssrc = kSsrc; - config.fec_generator = &flexfec_sender_; - config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - config.field_trials = &field_trials_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - rtp_sender()->SetSequenceNumber(kSeqNum); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); +TEST_F(RtpSenderTest, AllowPaddingAsFirstPacketOnRtxWithAbsSendTime) { + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + AbsoluteSendTime::kUri, kAbsoluteSendTimeExtensionId)); - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.fec_type = flexfec_sender.GetFecType(); - video_config.fec_overhead_bytes = flexfec_sender.MaxPacketOverhead(); - video_config.fec_type = flexfec_sender.GetFecType(); - video_config.fec_overhead_bytes = flexfec_sender.MaxPacketOverhead(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); - - // Parameters selected to generate a single FEC packet per media packet. - FecProtectionParams params; - params.fec_rate = 15; - params.max_fec_frames = 1; - params.fec_mask_type = kFecMaskRandom; - flexfec_sender.SetProtectionParameters(params, params); - - uint16_t flexfec_seq_num; - RTPVideoHeader video_header; - - std::unique_ptr media_packet; - std::unique_ptr fec_packet; + // Padding can't be sent as first packet on media SSRC since we don't know + // what payload type to assign. + EXPECT_THAT(rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/false), + IsEmpty()); - EXPECT_CALL(mock_paced_sender_, EnqueuePackets) - .WillOnce([&](std::vector> packets) { - for (auto& packet : packets) { - if (packet->packet_type() == RtpPacketMediaType::kVideo) { - EXPECT_EQ(packet->Ssrc(), kSsrc); - EXPECT_EQ(packet->SequenceNumber(), kSeqNum); - media_packet = std::move(packet); - - // Simulate RtpSenderEgress adding packet to fec generator. 
- flexfec_sender.AddPacketAndGenerateFec(*media_packet); - auto fec_packets = flexfec_sender.GetFecPackets(); - EXPECT_EQ(fec_packets.size(), 1u); - fec_packet = std::move(fec_packets[0]); - EXPECT_EQ(fec_packet->packet_type(), - RtpPacketMediaType::kForwardErrorCorrection); - EXPECT_EQ(fec_packet->Ssrc(), kFlexFecSsrc); - } else { - EXPECT_EQ(packet->packet_type(), - RtpPacketMediaType::kForwardErrorCorrection); - fec_packet = std::move(packet); - EXPECT_EQ(fec_packet->Ssrc(), kFlexFecSsrc); - } - } - }); - - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kMediaPayloadType, kCodecType, kTimestamp, clock_->TimeInMilliseconds(), - kPayloadData, video_header, kDefaultExpectedRetransmissionTimeMs)); - ASSERT_TRUE(media_packet != nullptr); - ASSERT_TRUE(fec_packet != nullptr); - - flexfec_seq_num = fec_packet->SequenceNumber(); - rtp_sender_context_->InjectPacket(std::move(media_packet), PacedPacketInfo()); - rtp_sender_context_->InjectPacket(std::move(fec_packet), PacedPacketInfo()); - - ASSERT_EQ(2, transport_.packets_sent()); - const RtpPacketReceived& sent_media_packet = transport_.sent_packets_[0]; - EXPECT_EQ(kMediaPayloadType, sent_media_packet.PayloadType()); - EXPECT_EQ(kSeqNum, sent_media_packet.SequenceNumber()); - EXPECT_EQ(kSsrc, sent_media_packet.Ssrc()); - const RtpPacketReceived& sent_flexfec_packet = transport_.sent_packets_[1]; - EXPECT_EQ(kFlexfecPayloadType, sent_flexfec_packet.PayloadType()); - EXPECT_EQ(flexfec_seq_num, sent_flexfec_packet.SequenceNumber()); - EXPECT_EQ(kFlexFecSsrc, sent_flexfec_packet.Ssrc()); + // With abs send time, padding can be sent as first packet on the RTX SSRC. + EnableRtx(); + EXPECT_THAT(rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/false), + Not(IsEmpty())); } -TEST_P(RtpSenderTestWithoutPacer, SendFlexfecPackets) { - constexpr uint32_t kTimestamp = 1234; - constexpr int kMediaPayloadType = 127; - constexpr VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - constexpr int kFlexfecPayloadType = 118; - const std::vector kNoRtpExtensions; - const std::vector kNoRtpExtensionSizes; - FlexfecSender flexfec_sender(kFlexfecPayloadType, kFlexFecSsrc, kSsrc, kNoMid, - kNoRtpExtensions, kNoRtpExtensionSizes, - nullptr /* rtp_state */, clock_); - - // Reset |rtp_sender_| to use FlexFEC. - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.fec_generator = &flexfec_sender; - config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - config.field_trials = &field_trials_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - rtp_sender()->SetSequenceNumber(kSeqNum); - - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.fec_type = flexfec_sender.GetFecType(); - video_config.fec_overhead_bytes = flexfec_sender_.MaxPacketOverhead(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); +TEST_F(RtpSenderTest, UpdatesTimestampsOnPlainRtxPadding) { + EnableRtx(); + // Timestamps as set based on capture time in RtpSenderTest. 
+ const int64_t start_time = clock_->TimeInMilliseconds(); + const uint32_t start_timestamp = start_time * kTimestampTicksPerMs; - // Parameters selected to generate a single FEC packet per media packet. - FecProtectionParams params; - params.fec_rate = 15; - params.max_fec_frames = 1; - params.fec_mask_type = kFecMaskRandom; - rtp_egress()->SetFecProtectionParameters(params, params); + // Start by sending one media packet. + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property(&RtpPacketToSend::padding_size, 0u)), + Pointee(Property(&RtpPacketToSend::Timestamp, start_timestamp)), + Pointee(Property(&RtpPacketToSend::capture_time_ms, start_time)))))); + std::unique_ptr media_packet = + SendPacket(start_time, /*payload_size=*/600); + + // Advance time before sending padding. + const TimeDelta kTimeDiff = TimeDelta::Millis(17); + time_controller_.AdvanceTime(kTimeDiff); + + // Timestamps on padding should be offset from the sent media. + EXPECT_THAT( + rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/true), + Each(AllOf( + Pointee(Property(&RtpPacketToSend::padding_size, kMaxPaddingLength)), + Pointee(Property( + &RtpPacketToSend::Timestamp, + start_timestamp + (kTimestampTicksPerMs * kTimeDiff.ms()))), + Pointee(Property(&RtpPacketToSend::capture_time_ms, + start_time + kTimeDiff.ms()))))); +} + +TEST_F(RtpSenderTest, KeepsTimestampsOnPayloadPadding) { + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); + EnableRtx(); + // Timestamps as set based on capture time in RtpSenderTest. + const int64_t start_time = clock_->TimeInMilliseconds(); + const uint32_t start_timestamp = start_time * kTimestampTicksPerMs; + const size_t kPayloadSize = 600; + const size_t kRtxHeaderSize = 2; - EXPECT_CALL(mock_rtc_event_log_, - LogProxy(SameRtcEventTypeAs(RtcEvent::Type::RtpPacketOutgoing))) - .Times(2); - RTPVideoHeader video_header; - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kMediaPayloadType, kCodecType, kTimestamp, clock_->TimeInMilliseconds(), - kPayloadData, video_header, kDefaultExpectedRetransmissionTimeMs)); - - ASSERT_EQ(2, transport_.packets_sent()); - const RtpPacketReceived& media_packet = transport_.sent_packets_[0]; - EXPECT_EQ(kMediaPayloadType, media_packet.PayloadType()); - EXPECT_EQ(kSsrc, media_packet.Ssrc()); - const RtpPacketReceived& flexfec_packet = transport_.sent_packets_[1]; - EXPECT_EQ(kFlexfecPayloadType, flexfec_packet.PayloadType()); - EXPECT_EQ(kFlexFecSsrc, flexfec_packet.Ssrc()); + // Start by sending one media packet and putting in the packet history. + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property(&RtpPacketToSend::padding_size, 0u)), + Pointee(Property(&RtpPacketToSend::Timestamp, start_timestamp)), + Pointee(Property(&RtpPacketToSend::capture_time_ms, start_time)))))); + std::unique_ptr media_packet = + SendPacket(start_time, kPayloadSize); + packet_history_->PutRtpPacket(std::move(media_packet), start_time); + + // Advance time before sending padding. + const TimeDelta kTimeDiff = TimeDelta::Millis(17); + time_controller_.AdvanceTime(kTimeDiff); + + // Timestamps on payload padding should be set to original. 
+ EXPECT_THAT( + rtp_sender_->GeneratePadding(/*target_size_bytes=*/100, + /*media_has_been_sent=*/true), + Each(AllOf( + Pointee(Property(&RtpPacketToSend::padding_size, 0u)), + Pointee(Property(&RtpPacketToSend::payload_size, + kPayloadSize + kRtxHeaderSize)), + Pointee(Property(&RtpPacketToSend::Timestamp, start_timestamp)), + Pointee(Property(&RtpPacketToSend::capture_time_ms, start_time))))); } // Test that the MID header extension is included on sent packets when // configured. -TEST_P(RtpSenderTestWithoutPacer, MidIncludedOnSentPackets) { +TEST_F(RtpSenderTest, MidIncludedOnSentPackets) { const char kMid[] = "mid"; - EnableMidSending(kMid); - // Send a couple packets. + // Send a couple packets, expect both packets to have the MID set. + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee( + Property(&RtpPacketToSend::GetExtension, kMid))))) + .Times(2); SendGenericPacket(); SendGenericPacket(); - - // Expect both packets to have the MID set. - ASSERT_EQ(2u, transport_.sent_packets_.size()); - for (const RtpPacketReceived& packet : transport_.sent_packets_) { - std::string mid; - ASSERT_TRUE(packet.GetExtension(&mid)); - EXPECT_EQ(kMid, mid); - } } -TEST_P(RtpSenderTestWithoutPacer, RidIncludedOnSentPackets) { +TEST_F(RtpSenderTest, RidIncludedOnSentPackets) { const char kRid[] = "f"; - EnableRidSending(kRid); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(Property( + &RtpPacketToSend::GetExtension, kRid))))); SendGenericPacket(); - - ASSERT_EQ(1u, transport_.sent_packets_.size()); - const RtpPacketReceived& packet = transport_.sent_packets_[0]; - std::string rid; - ASSERT_TRUE(packet.GetExtension(&rid)); - EXPECT_EQ(kRid, rid); } -TEST_P(RtpSenderTestWithoutPacer, RidIncludedOnRtxSentPackets) { +TEST_F(RtpSenderTest, RidIncludedOnRtxSentPackets) { const char kRid[] = "f"; - EnableRtx(); EnableRidSending(kRid); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::GetExtension, kRid), + Property(&RtpPacketToSend::HasExtension, + false)))))) + .WillOnce([&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); SendGenericPacket(); - ASSERT_EQ(1u, transport_.sent_packets_.size()); - const RtpPacketReceived& packet = transport_.sent_packets_[0]; - std::string rid; - ASSERT_TRUE(packet.GetExtension(&rid)); - EXPECT_EQ(kRid, rid); - rid = kNoRid; - EXPECT_FALSE(packet.HasExtension()); - - uint16_t packet_id = packet.SequenceNumber(); - rtp_sender()->ReSendPacket(packet_id); - ASSERT_EQ(2u, transport_.sent_packets_.size()); - const RtpPacketReceived& rtx_packet = transport_.sent_packets_[1]; - ASSERT_TRUE(rtx_packet.GetExtension(&rid)); - EXPECT_EQ(kRid, rid); - EXPECT_FALSE(rtx_packet.HasExtension()); + + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::GetExtension, kRid), + Property(&RtpPacketToSend::HasExtension, false)))))); + rtp_sender_->ReSendPacket(kSeqNum); } -TEST_P(RtpSenderTestWithoutPacer, MidAndRidNotIncludedOnSentPacketsAfterAck) { +TEST_F(RtpSenderTest, MidAndRidNotIncludedOnSentPacketsAfterAck) { const char kMid[] = "mid"; const char kRid[] = "f"; @@ -1525,53 +588,48 @@ TEST_P(RtpSenderTestWithoutPacer, MidAndRidNotIncludedOnSentPacketsAfterAck) { EnableRidSending(kRid); // This first packet should include both MID and RID. 
+ EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::GetExtension, kMid), + Property(&RtpPacketToSend::GetExtension, kRid)))))); auto first_built_packet = SendGenericPacket(); - - rtp_sender()->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); + rtp_sender_->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); // The second packet should include neither since an ack was received. + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::HasExtension, false), + Property(&RtpPacketToSend::HasExtension, false)))))); SendGenericPacket(); - - ASSERT_EQ(2u, transport_.sent_packets_.size()); - - const RtpPacketReceived& first_packet = transport_.sent_packets_[0]; - std::string mid, rid; - ASSERT_TRUE(first_packet.GetExtension(&mid)); - EXPECT_EQ(kMid, mid); - ASSERT_TRUE(first_packet.GetExtension(&rid)); - EXPECT_EQ(kRid, rid); - - const RtpPacketReceived& second_packet = transport_.sent_packets_[1]; - EXPECT_FALSE(second_packet.HasExtension()); - EXPECT_FALSE(second_packet.HasExtension()); } -TEST_P(RtpSenderTestWithoutPacer, - MidAndRidAlwaysIncludedOnSentPacketsWhenConfigured) { - SetUpRtpSender(false, false, /*always_send_mid_and_rid=*/true); +TEST_F(RtpSenderTest, MidAndRidAlwaysIncludedOnSentPacketsWhenConfigured) { + SetUpRtpSender(false, /*always_send_mid_and_rid=*/true, nullptr); const char kMid[] = "mid"; const char kRid[] = "f"; EnableMidSending(kMid); EnableRidSending(kRid); // Send two media packets: one before and one after the ack. - auto first_packet = SendGenericPacket(); - rtp_sender()->OnReceivedAckOnSsrc(first_packet->SequenceNumber()); - SendGenericPacket(); - // Due to the configuration, both sent packets should contain MID and RID. - ASSERT_EQ(2u, transport_.sent_packets_.size()); - for (const RtpPacketReceived& packet : transport_.sent_packets_) { - EXPECT_EQ(packet.GetExtension(), kMid); - EXPECT_EQ(packet.GetExtension(), kRid); - } + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee( + AllOf(Property(&RtpPacketToSend::GetExtension, kMid), + Property(&RtpPacketToSend::GetExtension, kRid)))))) + .Times(2); + auto first_built_packet = SendGenericPacket(); + rtp_sender_->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); + SendGenericPacket(); } // Test that the first RTX packet includes both MID and RRID even if the packet // being retransmitted did not have MID or RID. The MID and RID are needed on // the first packets for a given SSRC, and RTX packets are sent on a separate // SSRC. -TEST_P(RtpSenderTestWithoutPacer, MidAndRidIncludedOnFirstRtxPacket) { +TEST_F(RtpSenderTest, MidAndRidIncludedOnFirstRtxPacket) { const char kMid[] = "mid"; const char kRid[] = "f"; @@ -1580,30 +638,32 @@ TEST_P(RtpSenderTestWithoutPacer, MidAndRidIncludedOnFirstRtxPacket) { EnableRidSending(kRid); // This first packet will include both MID and RID. + EXPECT_CALL(mock_paced_sender_, EnqueuePackets); auto first_built_packet = SendGenericPacket(); - rtp_sender()->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); + rtp_sender_->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); - // The second packet will include neither since an ack was received. + // The second packet will include neither since an ack was received, put + // it in the packet history for retransmission. 
+ EXPECT_CALL(mock_paced_sender_, EnqueuePackets(SizeIs(1))) + .WillOnce([&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); auto second_built_packet = SendGenericPacket(); // The first RTX packet should include MID and RRID. - ASSERT_LT(0, - rtp_sender()->ReSendPacket(second_built_packet->SequenceNumber())); - - ASSERT_EQ(3u, transport_.sent_packets_.size()); - - const RtpPacketReceived& rtx_packet = transport_.sent_packets_[2]; - std::string mid, rrid; - ASSERT_TRUE(rtx_packet.GetExtension(&mid)); - EXPECT_EQ(kMid, mid); - ASSERT_TRUE(rtx_packet.GetExtension(&rrid)); - EXPECT_EQ(kRid, rrid); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::GetExtension, kMid), + Property(&RtpPacketToSend::GetExtension, + kRid)))))); + rtp_sender_->ReSendPacket(second_built_packet->SequenceNumber()); } // Test that the RTX packets sent after receiving an ACK on the RTX SSRC do // not include either MID or RRID even if the packet being retransmitted did // have a MID or RID. -TEST_P(RtpSenderTestWithoutPacer, MidAndRidNotIncludedOnRtxPacketsAfterAck) { +TEST_F(RtpSenderTest, MidAndRidNotIncludedOnRtxPacketsAfterAck) { const char kMid[] = "mid"; const char kRid[] = "f"; @@ -1612,41 +672,44 @@ TEST_P(RtpSenderTestWithoutPacer, MidAndRidNotIncludedOnRtxPacketsAfterAck) { EnableRidSending(kRid); // This first packet will include both MID and RID. + EXPECT_CALL(mock_paced_sender_, EnqueuePackets(SizeIs(1))) + .WillOnce([&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); auto first_built_packet = SendGenericPacket(); - rtp_sender()->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); + rtp_sender_->OnReceivedAckOnSsrc(first_built_packet->SequenceNumber()); // The second packet will include neither since an ack was received. + EXPECT_CALL(mock_paced_sender_, EnqueuePackets(SizeIs(1))) + .WillOnce([&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); auto second_built_packet = SendGenericPacket(); // The first RTX packet will include MID and RRID. - ASSERT_LT(0, - rtp_sender()->ReSendPacket(second_built_packet->SequenceNumber())); - - ASSERT_EQ(3u, transport_.sent_packets_.size()); - const RtpPacketReceived& first_rtx_packet = transport_.sent_packets_[2]; - - rtp_sender()->OnReceivedAckOnRtxSsrc(first_rtx_packet.SequenceNumber()); + EXPECT_CALL(mock_paced_sender_, EnqueuePackets(SizeIs(1))) + .WillOnce([&](std::vector> packets) { + rtp_sender_->OnReceivedAckOnRtxSsrc(packets[0]->SequenceNumber()); + packet_history_->MarkPacketAsSent( + *packets[0]->retransmitted_sequence_number()); + }); + rtp_sender_->ReSendPacket(second_built_packet->SequenceNumber()); // The second and third RTX packets should not include MID nor RRID.
- ASSERT_LT(0, - rtp_sender()->ReSendPacket(first_built_packet->SequenceNumber())); - ASSERT_LT(0, - rtp_sender()->ReSendPacket(second_built_packet->SequenceNumber())); - - ASSERT_EQ(5u, transport_.sent_packets_.size()); - - const RtpPacketReceived& second_rtx_packet = transport_.sent_packets_[3]; - EXPECT_FALSE(second_rtx_packet.HasExtension()); - EXPECT_FALSE(second_rtx_packet.HasExtension()); - - const RtpPacketReceived& third_rtx_packet = transport_.sent_packets_[4]; - EXPECT_FALSE(third_rtx_packet.HasExtension()); - EXPECT_FALSE(third_rtx_packet.HasExtension()); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::HasExtension, false), + Property(&RtpPacketToSend::HasExtension, + false)))))) + .Times(2); + rtp_sender_->ReSendPacket(first_built_packet->SequenceNumber()); + rtp_sender_->ReSendPacket(second_built_packet->SequenceNumber()); } -TEST_P(RtpSenderTestWithoutPacer, - MidAndRidAlwaysIncludedOnRtxPacketsWhenConfigured) { - SetUpRtpSender(false, false, /*always_send_mid_and_rid=*/true); +TEST_F(RtpSenderTest, MidAndRidAlwaysIncludedOnRtxPacketsWhenConfigured) { + SetUpRtpSender(false, /*always_send_mid_and_rid=*/true, nullptr); const char kMid[] = "mid"; const char kRid[] = "f"; EnableRtx(); @@ -1654,63 +717,68 @@ TEST_P(RtpSenderTestWithoutPacer, EnableRidSending(kRid); // Send two media packets: one before and one after the ack. + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee( + AllOf(Property(&RtpPacketToSend::GetExtension, kMid), + Property(&RtpPacketToSend::GetExtension, kRid)))))) + .Times(2) + .WillRepeatedly( + [&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); auto media_packet1 = SendGenericPacket(); - rtp_sender()->OnReceivedAckOnSsrc(media_packet1->SequenceNumber()); + rtp_sender_->OnReceivedAckOnSsrc(media_packet1->SequenceNumber()); auto media_packet2 = SendGenericPacket(); // Send three RTX packets with different combinations of orders w.r.t. the // media and RTX acks. - ASSERT_LT(0, rtp_sender()->ReSendPacket(media_packet2->SequenceNumber())); - ASSERT_EQ(3u, transport_.sent_packets_.size()); - rtp_sender()->OnReceivedAckOnRtxSsrc( - transport_.sent_packets_[2].SequenceNumber()); - ASSERT_LT(0, rtp_sender()->ReSendPacket(media_packet1->SequenceNumber())); - ASSERT_LT(0, rtp_sender()->ReSendPacket(media_packet2->SequenceNumber())); - // Due to the configuration, all sent packets should contain MID // and either RID (media) or RRID (RTX). 
- ASSERT_EQ(5u, transport_.sent_packets_.size()); - for (const auto& packet : transport_.sent_packets_) { - EXPECT_EQ(packet.GetExtension(), kMid); - } - for (size_t i = 0; i < 2; ++i) { - const RtpPacketReceived& packet = transport_.sent_packets_[i]; - EXPECT_EQ(packet.GetExtension(), kRid); - } - for (size_t i = 2; i < transport_.sent_packets_.size(); ++i) { - const RtpPacketReceived& packet = transport_.sent_packets_[i]; - EXPECT_EQ(packet.GetExtension(), kRid); - } + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::GetExtension, kMid), + Property(&RtpPacketToSend::GetExtension, + kRid)))))) + .Times(3) + .WillRepeatedly( + [&](std::vector> packets) { + rtp_sender_->OnReceivedAckOnRtxSsrc(packets[0]->SequenceNumber()); + packet_history_->MarkPacketAsSent( + *packets[0]->retransmitted_sequence_number()); + }); + rtp_sender_->ReSendPacket(media_packet2->SequenceNumber()); + rtp_sender_->ReSendPacket(media_packet1->SequenceNumber()); + rtp_sender_->ReSendPacket(media_packet2->SequenceNumber()); } // Test that if the RtpState indicates an ACK has been received on that SSRC // then neither the MID nor RID header extensions will be sent. -TEST_P(RtpSenderTestWithoutPacer, - MidAndRidNotIncludedOnSentPacketsAfterRtpStateRestored) { +TEST_F(RtpSenderTest, MidAndRidNotIncludedOnSentPacketsAfterRtpStateRestored) { const char kMid[] = "mid"; const char kRid[] = "f"; EnableMidSending(kMid); EnableRidSending(kRid); - RtpState state = rtp_sender()->GetRtpState(); + RtpState state = rtp_sender_->GetRtpState(); EXPECT_FALSE(state.ssrc_has_acked); state.ssrc_has_acked = true; - rtp_sender()->SetRtpState(state); + rtp_sender_->SetRtpState(state); + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::HasExtension, false), + Property(&RtpPacketToSend::HasExtension, false)))))); SendGenericPacket(); - - ASSERT_EQ(1u, transport_.sent_packets_.size()); - const RtpPacketReceived& packet = transport_.sent_packets_[0]; - EXPECT_FALSE(packet.HasExtension()); - EXPECT_FALSE(packet.HasExtension()); } // Test that if the RTX RtpState indicates an ACK has been received on that // RTX SSRC then neither the MID nor RRID header extensions will be sent on // RTX packets. 
-TEST_P(RtpSenderTestWithoutPacer, - MidAndRridNotIncludedOnRtxPacketsAfterRtpStateRestored) { +TEST_F(RtpSenderTest, MidAndRridNotIncludedOnRtxPacketsAfterRtpStateRestored) { const char kMid[] = "mid"; const char kRid[] = "f"; @@ -1718,767 +786,256 @@ TEST_P(RtpSenderTestWithoutPacer, EnableMidSending(kMid); EnableRidSending(kRid); - RtpState rtx_state = rtp_sender()->GetRtxRtpState(); + RtpState rtx_state = rtp_sender_->GetRtxRtpState(); EXPECT_FALSE(rtx_state.ssrc_has_acked); rtx_state.ssrc_has_acked = true; - rtp_sender()->SetRtxRtpState(rtx_state); + rtp_sender_->SetRtxRtpState(rtx_state); + EXPECT_CALL(mock_paced_sender_, EnqueuePackets(SizeIs(1))) + .WillOnce([&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); auto built_packet = SendGenericPacket(); - ASSERT_LT(0, rtp_sender()->ReSendPacket(built_packet->SequenceNumber())); - - ASSERT_EQ(2u, transport_.sent_packets_.size()); - const RtpPacketReceived& rtx_packet = transport_.sent_packets_[1]; - EXPECT_FALSE(rtx_packet.HasExtension()); - EXPECT_FALSE(rtx_packet.HasExtension()); -} - -TEST_P(RtpSenderTest, FecOverheadRate) { - constexpr uint32_t kTimestamp = 1234; - constexpr int kMediaPayloadType = 127; - constexpr VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - constexpr int kFlexfecPayloadType = 118; - const std::vector kNoRtpExtensions; - const std::vector kNoRtpExtensionSizes; - FlexfecSender flexfec_sender(kFlexfecPayloadType, kFlexFecSsrc, kSsrc, kNoMid, - kNoRtpExtensions, kNoRtpExtensionSizes, - nullptr /* rtp_state */, clock_); - - // Reset |rtp_sender_| to use this FlexFEC instance. - SetUpRtpSender(false, false, false, &flexfec_sender); - - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.fec_type = flexfec_sender.GetFecType(); - video_config.fec_overhead_bytes = flexfec_sender.MaxPacketOverhead(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); - // Parameters selected to generate a single FEC packet per media packet. 
- FecProtectionParams params; - params.fec_rate = 15; - params.max_fec_frames = 1; - params.fec_mask_type = kFecMaskRandom; - rtp_egress()->SetFecProtectionParameters(params, params); - - constexpr size_t kNumMediaPackets = 10; - constexpr size_t kNumFecPackets = kNumMediaPackets; - constexpr int64_t kTimeBetweenPacketsMs = 10; - for (size_t i = 0; i < kNumMediaPackets; ++i) { - RTPVideoHeader video_header; - - video_header.frame_type = VideoFrameType::kVideoFrameKey; - EXPECT_TRUE(rtp_sender_video.SendVideo( - kMediaPayloadType, kCodecType, kTimestamp, clock_->TimeInMilliseconds(), - kPayloadData, video_header, kDefaultExpectedRetransmissionTimeMs)); - - time_controller_.AdvanceTime(TimeDelta::Millis(kTimeBetweenPacketsMs)); - } - constexpr size_t kRtpHeaderLength = 12; - constexpr size_t kFlexfecHeaderLength = 20; - constexpr size_t kGenericCodecHeaderLength = 1; - constexpr size_t kPayloadLength = sizeof(kPayloadData); - constexpr size_t kPacketLength = kRtpHeaderLength + kFlexfecHeaderLength + - kGenericCodecHeaderLength + kPayloadLength; - - EXPECT_NEAR( - kNumFecPackets * kPacketLength * 8 / - (kNumFecPackets * kTimeBetweenPacketsMs / 1000.0f), - rtp_egress() - ->GetSendRates()[RtpPacketMediaType::kForwardErrorCorrection] - .bps(), - 500); -} - -TEST_P(RtpSenderTest, BitrateCallbacks) { - class TestCallback : public BitrateStatisticsObserver { - public: - TestCallback() - : BitrateStatisticsObserver(), - num_calls_(0), - ssrc_(0), - total_bitrate_(0), - retransmit_bitrate_(0) {} - ~TestCallback() override = default; - - void Notify(uint32_t total_bitrate, - uint32_t retransmit_bitrate, - uint32_t ssrc) override { - ++num_calls_; - ssrc_ = ssrc; - total_bitrate_ = total_bitrate; - retransmit_bitrate_ = retransmit_bitrate; - } - - uint32_t num_calls_; - uint32_t ssrc_; - uint32_t total_bitrate_; - uint32_t retransmit_bitrate_; - } callback; - - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.send_bitrate_observer = &callback; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); - const VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - const uint8_t kPayloadType = 127; - - // Simulate kNumPackets sent with kPacketInterval ms intervals, with the - // number of packets selected so that we fill (but don't overflow) the one - // second averaging window. - const uint32_t kWindowSizeMs = 1000; - const uint32_t kPacketInterval = 20; - const uint32_t kNumPackets = - (kWindowSizeMs - kPacketInterval) / kPacketInterval; - // Overhead = 12 bytes RTP header + 1 byte generic header. - const uint32_t kPacketOverhead = 13; - - uint8_t payload[] = {47, 11, 32, 93, 89}; - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 1); - uint32_t ssrc = rtp_sender()->SSRC(); - - // Send a few frames. 
- RTPVideoHeader video_header; - for (uint32_t i = 0; i < kNumPackets; ++i) { - video_header.frame_type = VideoFrameType::kVideoFrameKey; - ASSERT_TRUE(rtp_sender_video.SendVideo( - kPayloadType, kCodecType, 1234, 4321, payload, video_header, - kDefaultExpectedRetransmissionTimeMs)); - time_controller_.AdvanceTime(TimeDelta::Millis(kPacketInterval)); - } - - // We get one call for every stats updated, thus two calls since both the - // stream stats and the retransmit stats are updated once. - EXPECT_EQ(kNumPackets, callback.num_calls_); - EXPECT_EQ(ssrc, callback.ssrc_); - const uint32_t kTotalPacketSize = kPacketOverhead + sizeof(payload); - // Bitrate measured over delta between last and first timestamp, plus one. - const uint32_t kExpectedWindowMs = (kNumPackets - 1) * kPacketInterval + 1; - const uint32_t kExpectedBitsAccumulated = kTotalPacketSize * kNumPackets * 8; - const uint32_t kExpectedRateBps = - (kExpectedBitsAccumulated * 1000 + (kExpectedWindowMs / 2)) / - kExpectedWindowMs; - EXPECT_EQ(kExpectedRateBps, callback.total_bitrate_); -} - -TEST_P(RtpSenderTestWithoutPacer, StreamDataCountersCallbacks) { - const uint8_t kPayloadType = 127; - const VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - FieldTrialBasedConfig field_trials; - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.field_trials = &field_trials; - RTPSenderVideo rtp_sender_video(video_config); - uint8_t payload[] = {47, 11, 32, 93, 89}; - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 1); - uint32_t ssrc = rtp_sender()->SSRC(); - - // Send a frame. - RTPVideoHeader video_header; - video_header.frame_type = VideoFrameType::kVideoFrameKey; - ASSERT_TRUE(rtp_sender_video.SendVideo(kPayloadType, kCodecType, 1234, 4321, - payload, video_header, - kDefaultExpectedRetransmissionTimeMs)); - StreamDataCounters expected; - expected.transmitted.payload_bytes = 6; - expected.transmitted.header_bytes = 12; - expected.transmitted.padding_bytes = 0; - expected.transmitted.packets = 1; - expected.retransmitted.payload_bytes = 0; - expected.retransmitted.header_bytes = 0; - expected.retransmitted.padding_bytes = 0; - expected.retransmitted.packets = 0; - expected.fec.packets = 0; - rtp_stats_callback_.Matches(ssrc, expected); - - // Retransmit a frame. - uint16_t seqno = rtp_sender()->SequenceNumber() - 1; - rtp_sender()->ReSendPacket(seqno); - expected.transmitted.payload_bytes = 12; - expected.transmitted.header_bytes = 24; - expected.transmitted.packets = 2; - expected.retransmitted.payload_bytes = 6; - expected.retransmitted.header_bytes = 12; - expected.retransmitted.padding_bytes = 0; - expected.retransmitted.packets = 1; - rtp_stats_callback_.Matches(ssrc, expected); - - // Send padding. 
- GenerateAndSendPadding(kMaxPaddingSize); - expected.transmitted.payload_bytes = 12; - expected.transmitted.header_bytes = 36; - expected.transmitted.padding_bytes = kMaxPaddingSize; - expected.transmitted.packets = 3; - rtp_stats_callback_.Matches(ssrc, expected); -} - -TEST_P(RtpSenderTestWithoutPacer, StreamDataCountersCallbacksUlpfec) { - const uint8_t kRedPayloadType = 96; - const uint8_t kUlpfecPayloadType = 97; - const uint8_t kPayloadType = 127; - const VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; - - UlpfecGenerator ulpfec_generator(kRedPayloadType, kUlpfecPayloadType, clock_); - SetUpRtpSender(false, false, false, &ulpfec_generator); - RTPSenderVideo::Config video_config; - video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); - video_config.field_trials = &field_trials_; - video_config.red_payload_type = kRedPayloadType; - video_config.fec_type = ulpfec_generator.GetFecType(); - video_config.fec_overhead_bytes = ulpfec_generator.MaxPacketOverhead(); - RTPSenderVideo rtp_sender_video(video_config); - uint8_t payload[] = {47, 11, 32, 93, 89}; - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 1); - uint32_t ssrc = rtp_sender()->SSRC(); - - RTPVideoHeader video_header; - StreamDataCounters expected; - - // Send ULPFEC. - FecProtectionParams fec_params; - fec_params.fec_mask_type = kFecMaskRandom; - fec_params.fec_rate = 1; - fec_params.max_fec_frames = 1; - rtp_egress()->SetFecProtectionParameters(fec_params, fec_params); - video_header.frame_type = VideoFrameType::kVideoFrameDelta; - ASSERT_TRUE(rtp_sender_video.SendVideo(kPayloadType, kCodecType, 1234, 4321, - payload, video_header, - kDefaultExpectedRetransmissionTimeMs)); - expected.transmitted.payload_bytes = 28; - expected.transmitted.header_bytes = 24; - expected.transmitted.packets = 2; - expected.fec.packets = 1; - rtp_stats_callback_.Matches(ssrc, expected); -} - -TEST_P(RtpSenderTestWithoutPacer, BytesReportedCorrectly) { - // XXX const char* kPayloadName = "GENERIC"; - const uint8_t kPayloadType = 127; - rtp_sender()->SetRtxPayloadType(kPayloadType - 1, kPayloadType); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - SendGenericPacket(); - // Will send 2 full-size padding packets. 
- GenerateAndSendPadding(1); - GenerateAndSendPadding(1); - - StreamDataCounters rtp_stats; - StreamDataCounters rtx_stats; - rtp_egress()->GetDataCounters(&rtp_stats, &rtx_stats); - - // Payload - EXPECT_GT(rtp_stats.first_packet_time_ms, -1); - EXPECT_EQ(rtp_stats.transmitted.payload_bytes, sizeof(kPayloadData)); - EXPECT_EQ(rtp_stats.transmitted.header_bytes, 12u); - EXPECT_EQ(rtp_stats.transmitted.padding_bytes, 0u); - EXPECT_EQ(rtx_stats.transmitted.payload_bytes, 0u); - EXPECT_EQ(rtx_stats.transmitted.header_bytes, 24u); - EXPECT_EQ(rtx_stats.transmitted.padding_bytes, 2 * kMaxPaddingSize); - - EXPECT_EQ(rtp_stats.transmitted.TotalBytes(), - rtp_stats.transmitted.payload_bytes + - rtp_stats.transmitted.header_bytes + - rtp_stats.transmitted.padding_bytes); - EXPECT_EQ(rtx_stats.transmitted.TotalBytes(), - rtx_stats.transmitted.payload_bytes + - rtx_stats.transmitted.header_bytes + - rtx_stats.transmitted.padding_bytes); - - EXPECT_EQ( - transport_.total_bytes_sent_, - rtp_stats.transmitted.TotalBytes() + rtx_stats.transmitted.TotalBytes()); + EXPECT_CALL( + mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(AllOf( + Property(&RtpPacketToSend::HasExtension, false), + Property(&RtpPacketToSend::HasExtension, false)))))); + ASSERT_LT(0, rtp_sender_->ReSendPacket(built_packet->SequenceNumber())); } -TEST_P(RtpSenderTestWithoutPacer, RespectsNackBitrateLimit) { +TEST_F(RtpSenderTest, RespectsNackBitrateLimit) { const int32_t kPacketSize = 1400; const int32_t kNumPackets = 30; retransmission_rate_limiter_.SetMaxRate(kPacketSize * kNumPackets * 8); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, kNumPackets); - const uint16_t kStartSequenceNumber = rtp_sender()->SequenceNumber(); + const uint16_t kStartSequenceNumber = rtp_sender_->SequenceNumber(); std::vector sequence_numbers; for (int32_t i = 0; i < kNumPackets; ++i) { sequence_numbers.push_back(kStartSequenceNumber + i); time_controller_.AdvanceTime(TimeDelta::Millis(1)); + EXPECT_CALL(mock_paced_sender_, EnqueuePackets(SizeIs(1))) + .WillOnce([&](std::vector> packets) { + packet_history_->PutRtpPacket(std::move(packets[0]), + clock_->TimeInMilliseconds()); + }); SendPacket(clock_->TimeInMilliseconds(), kPacketSize); } - EXPECT_EQ(kNumPackets, transport_.packets_sent()); time_controller_.AdvanceTime(TimeDelta::Millis(1000 - kNumPackets)); // Resending should work - brings the bandwidth up to the limit. // NACK bitrate is capped to the same bitrate as the encoder, since the max // protection overhead is 50% (see MediaOptimization::SetTargetRates). - rtp_sender()->OnReceivedNack(sequence_numbers, 0); - EXPECT_EQ(kNumPackets * 2, transport_.packets_sent()); + EXPECT_CALL(mock_paced_sender_, EnqueuePackets(ElementsAre(Pointee(Property( + &RtpPacketToSend::packet_type, + RtpPacketMediaType::kRetransmission))))) + .Times(kNumPackets) + .WillRepeatedly( + [&](std::vector> packets) { + for (const auto& packet : packets) { + packet_history_->MarkPacketAsSent( + *packet->retransmitted_sequence_number()); + } + }); + rtp_sender_->OnReceivedNack(sequence_numbers, 0); // Must be at least 5ms in between retransmission attempts. time_controller_.AdvanceTime(TimeDelta::Millis(5)); // Resending should not work, bandwidth exceeded. 
- rtp_sender()->OnReceivedNack(sequence_numbers, 0); - EXPECT_EQ(kNumPackets * 2, transport_.packets_sent()); + EXPECT_CALL(mock_paced_sender_, EnqueuePackets).Times(0); + rtp_sender_->OnReceivedNack(sequence_numbers, 0); } -TEST_P(RtpSenderTest, UpdatingCsrcsUpdatedOverhead) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); +TEST_F(RtpSenderTest, UpdatingCsrcsUpdatedOverhead) { + RtpRtcpInterface::Configuration config = GetDefaultConfig(); + config.rtx_send_ssrc = {}; + CreateSender(config); // Base RTP overhead is 12B. - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); // Adding two csrcs adds 2*4 bytes to the header. - rtp_sender()->SetCsrcs({1, 2}); - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 20u); + rtp_sender_->SetCsrcs({1, 2}); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 20u); } -TEST_P(RtpSenderTest, OnOverheadChanged) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); +TEST_F(RtpSenderTest, OnOverheadChanged) { + RtpRtcpInterface::Configuration config = GetDefaultConfig(); + config.rtx_send_ssrc = {}; + CreateSender(config); // Base RTP overhead is 12B. - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId); + rtp_sender_->RegisterRtpHeaderExtension(TransmissionOffset::kUri, + kTransmissionTimeOffsetExtensionId); // TransmissionTimeOffset extension has a size of 3B, but with the addition // of header index and rounding to 4 byte boundary we end up with 20B total. - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 20u); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 20u); } -TEST_P(RtpSenderTest, CountMidOnlyUntilAcked) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); +TEST_F(RtpSenderTest, CountMidOnlyUntilAcked) { + RtpRtcpInterface::Configuration config = GetDefaultConfig(); + config.rtx_send_ssrc = {}; + CreateSender(config); // Base RTP overhead is 12B. - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionMid, kMidExtensionId); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionRtpStreamId, - kRidExtensionId); + rtp_sender_->RegisterRtpHeaderExtension(RtpMid::kUri, kMidExtensionId); + rtp_sender_->RegisterRtpHeaderExtension(RtpStreamId::kUri, kRidExtensionId); // Counted only if set. 
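// Illustrative sketch (not from the patch): the arithmetic behind
// RespectsNackBitrateLimit above. One full round of retransmissions exactly
// consumes the configured budget, which is why the second OnReceivedNack()
// issued 5 ms later enqueues nothing. Constant names below are local to this
// sketch.
#include <cstdint>

namespace {
constexpr int64_t kSketchPacketSizeBytes = 1400;
constexpr int64_t kSketchNumPackets = 30;
// SetMaxRate(kPacketSize * kNumPackets * 8) == 336000 bps over a 1 s window.
constexpr int64_t kSketchMaxRateBps =
    kSketchPacketSizeBytes * kSketchNumPackets * 8;
// Retransmitting every packet once sends kNumPackets * kPacketSize * 8 bits,
// i.e. the entire one-second budget.
static_assert(kSketchNumPackets * kSketchPacketSizeBytes * 8 ==
                  kSketchMaxRateBps,
              "one NACK round uses the whole retransmission budget");
}  // namespace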
- EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); - rtp_sender()->SetMid("foo"); - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 36u); - rtp_sender()->SetRid("bar"); - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 52u); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); + rtp_sender_->SetMid("foo"); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 36u); + rtp_sender_->SetRid("bar"); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 52u); // Ack received, mid/rid no longer sent. - rtp_sender()->OnReceivedAckOnSsrc(0); - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); + rtp_sender_->OnReceivedAckOnSsrc(0); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); } -TEST_P(RtpSenderTest, DontCountVolatileExtensionsIntoOverhead) { - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.retransmission_rate_limiter = &retransmission_rate_limiter_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); +TEST_F(RtpSenderTest, DontCountVolatileExtensionsIntoOverhead) { + RtpRtcpInterface::Configuration config = GetDefaultConfig(); + config.rtx_send_ssrc = {}; + CreateSender(config); // Base RTP overhead is 12B. - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); - - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionInbandComfortNoise, 1); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionAbsoluteCaptureTime, 2); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionVideoRotation, 3); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionPlayoutDelay, 4); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionVideoContentType, 5); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionVideoTiming, 6); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionRepairedRtpStreamId, 7); - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionColorSpace, 8); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); + + rtp_sender_->RegisterRtpHeaderExtension(InbandComfortNoiseExtension::kUri, 1); + rtp_sender_->RegisterRtpHeaderExtension(AbsoluteCaptureTimeExtension::kUri, + 2); + rtp_sender_->RegisterRtpHeaderExtension(VideoOrientation::kUri, 3); + rtp_sender_->RegisterRtpHeaderExtension(PlayoutDelayLimits::kUri, 4); + rtp_sender_->RegisterRtpHeaderExtension(VideoContentTypeExtension::kUri, 5); + rtp_sender_->RegisterRtpHeaderExtension(VideoTimingExtension::kUri, 6); + rtp_sender_->RegisterRtpHeaderExtension(RepairedRtpStreamId::kUri, 7); + rtp_sender_->RegisterRtpHeaderExtension(ColorSpaceExtension::kUri, 8); // Still only 12B counted since can't count on above being sent. - EXPECT_EQ(rtp_sender()->ExpectedPerPacketOverhead(), 12u); -} - -TEST_P(RtpSenderTest, SendPacketMatchesVideo) { - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packet_type(RtpPacketMediaType::kVideo); - - // Verify sent with correct SSRC. - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kSsrc); - packet->set_packet_type(RtpPacketMediaType::kVideo); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 1); -} - -TEST_P(RtpSenderTest, SendPacketMatchesAudio) { - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packet_type(RtpPacketMediaType::kAudio); - - // Verify sent with correct SSRC. 
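// Illustrative sketch (not from the patch): a model that reproduces the
// ExpectedPerPacketOverhead() values asserted above. The 16-byte worst-case
// reservation for MID/RID and the 4-byte alignment step are assumptions
// inferred from the expected numbers, not taken from this diff.
#include <cstddef>

namespace {
constexpr size_t kSketchFixedRtpHeader = 12;        // V/P/X/CC, M/PT, seq, ts, SSRC.
constexpr size_t kSketchOneByteExtBlockHeader = 4;  // 0xBEDE profile + length word.

constexpr size_t SketchAlignTo4(size_t bytes) {
  return (bytes + 3) / 4 * 4;
}

// Two CSRCs: 12 + 2 * 4 = 20.
static_assert(kSketchFixedRtpHeader + 2 * 4 == 20, "");
// TransmissionOffset: 1 id/len byte + 3 data bytes -> 20 after alignment.
static_assert(SketchAlignTo4(kSketchFixedRtpHeader +
                             kSketchOneByteExtBlockHeader + 1 + 3) == 20,
              "");
// MID reserved at 16 bytes: 12 + 4 + (1 + 16) = 33 -> 36 after alignment.
static_assert(SketchAlignTo4(kSketchFixedRtpHeader +
                             kSketchOneByteExtBlockHeader + (1 + 16)) == 36,
              "");
// MID + RID: 12 + 4 + 2 * (1 + 16) = 50 -> 52 after alignment.
static_assert(SketchAlignTo4(kSketchFixedRtpHeader +
                             kSketchOneByteExtBlockHeader + 2 * (1 + 16)) == 52,
              "");
}  // namespace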
- packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kSsrc); - packet->set_packet_type(RtpPacketMediaType::kAudio); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 1); -} - -TEST_P(RtpSenderTest, SendPacketMatchesRetransmissions) { - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - - // Verify sent with correct SSRC (non-RTX). - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kSsrc); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 1); - - // RTX retransmission. - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kRtxSsrc); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 2); -} - -TEST_P(RtpSenderTest, SendPacketMatchesPadding) { - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packet_type(RtpPacketMediaType::kPadding); - - // Verify sent with correct SSRC (non-RTX). - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kSsrc); - packet->set_packet_type(RtpPacketMediaType::kPadding); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 1); - - // RTX padding. - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kRtxSsrc); - packet->set_packet_type(RtpPacketMediaType::kPadding); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 2); -} - -TEST_P(RtpSenderTest, SendPacketMatchesFlexfec) { - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); - - // Verify sent with correct SSRC. - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kFlexFecSsrc); - packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 1); -} - -TEST_P(RtpSenderTest, SendPacketMatchesUlpfec) { - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); - - // Verify sent with correct SSRC. - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetSsrc(kSsrc); - packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_EQ(transport_.packets_sent(), 1); + EXPECT_EQ(rtp_sender_->ExpectedPerPacketOverhead(), 12u); } -TEST_P(RtpSenderTest, SendPacketHandlesRetransmissionHistory) { - rtp_sender_context_->packet_history_.SetStorePacketsStatus( +TEST_F(RtpSenderTest, SendPacketHandlesRetransmissionHistory) { + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, 10); // Ignore calls to EnqueuePackets() for this test. 
EXPECT_CALL(mock_paced_sender_, EnqueuePackets).WillRepeatedly(Return()); - // Build a media packet and send it. + // Build a media packet and put in the packet history. std::unique_ptr packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); const uint16_t media_sequence_number = packet->SequenceNumber(); - packet->set_packet_type(RtpPacketMediaType::kVideo); packet->set_allow_retransmission(true); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + packet_history_->PutRtpPacket(std::move(packet), + clock_->TimeInMilliseconds()); - // Simulate retransmission request. + // Simulate successful retransmission request. time_controller_.AdvanceTime(TimeDelta::Millis(30)); - EXPECT_GT(rtp_sender()->ReSendPacket(media_sequence_number), 0); + EXPECT_THAT(rtp_sender_->ReSendPacket(media_sequence_number), Gt(0)); // Packet already pending, retransmission not allowed. time_controller_.AdvanceTime(TimeDelta::Millis(30)); - EXPECT_EQ(rtp_sender()->ReSendPacket(media_sequence_number), 0); - - // Packet exiting pacer, mark as not longer pending. - packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - EXPECT_NE(packet->SequenceNumber(), media_sequence_number); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - packet->SetSsrc(kRtxSsrc); - packet->set_retransmitted_sequence_number(media_sequence_number); - packet->set_allow_retransmission(false); - uint16_t seq_no = packet->SequenceNumber(); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + EXPECT_THAT(rtp_sender_->ReSendPacket(media_sequence_number), Eq(0)); + + // Simulate packet exiting pacer, mark as not longer pending. + packet_history_->MarkPacketAsSent(media_sequence_number); // Retransmissions allowed again. time_controller_.AdvanceTime(TimeDelta::Millis(30)); - EXPECT_GT(rtp_sender()->ReSendPacket(media_sequence_number), 0); - - // Retransmission of RTX packet should not be allowed. 
- EXPECT_EQ(rtp_sender()->ReSendPacket(seq_no), 0); + EXPECT_THAT(rtp_sender_->ReSendPacket(media_sequence_number), Gt(0)); } -TEST_P(RtpSenderTest, SendPacketUpdatesExtensions) { - ASSERT_EQ(rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId), - 0); - ASSERT_EQ(rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId), - 0); - ASSERT_EQ(rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionVideoTiming, - kVideoTimingExtensionId), - 0); - - std::unique_ptr packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->set_packetization_finish_time_ms(clock_->TimeInMilliseconds()); - - const int32_t kDiffMs = 10; - time_controller_.AdvanceTime(TimeDelta::Millis(kDiffMs)); - - packet->set_packet_type(RtpPacketMediaType::kVideo); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - const RtpPacketReceived& received_packet = transport_.last_sent_packet(); - - EXPECT_EQ(received_packet.GetExtension(), kDiffMs * 90); - - EXPECT_EQ(received_packet.GetExtension(), - AbsoluteSendTime::MsTo24Bits(clock_->TimeInMilliseconds())); - - VideoSendTiming timing; - EXPECT_TRUE(received_packet.GetExtension(&timing)); - EXPECT_EQ(timing.pacer_exit_delta_ms, kDiffMs); -} +TEST_F(RtpSenderTest, MarksRetransmittedPackets) { + packet_history_->SetStorePacketsStatus( + RtpPacketHistory::StorageMode::kStoreAndCull, 10); -TEST_P(RtpSenderTest, SendPacketSetsPacketOptions) { - const uint16_t kPacketId = 42; - ASSERT_EQ(rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId), - 0); + // Build a media packet and put in the packet history. std::unique_ptr packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetExtension(kPacketId); - - packet->set_packet_type(RtpPacketMediaType::kVideo); - EXPECT_CALL(send_packet_observer_, OnSendPacket); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - - EXPECT_EQ(transport_.last_options_.packet_id, kPacketId); - EXPECT_TRUE(transport_.last_options_.included_in_allocation); - EXPECT_TRUE(transport_.last_options_.included_in_feedback); - EXPECT_FALSE(transport_.last_options_.is_retransmit); - - // Send another packet as retransmission, verify options are populated. 
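// Illustrative sketch (not from the patch): the retransmission lifecycle that
// SendPacketHandlesRetransmissionHistory and MarksRetransmittedPackets above
// both exercise, condensed into one flow. `packet_history_`, `rtp_sender_`,
// `clock_` and BuildRtpPacket() are the fixture members/helpers used
// throughout this file; the return-value comments restate what the tests
// assert rather than documented API guarantees.
void RetransmissionFlowSketch() {
  packet_history_->SetStorePacketsStatus(
      RtpPacketHistory::StorageMode::kStoreAndCull, /*number_to_store=*/10);

  // 1. Every retransmittable media packet is stored in the history.
  std::unique_ptr<RtpPacketToSend> packet =
      BuildRtpPacket(kPayload, /*marker_bit=*/true, /*timestamp=*/0,
                     clock_->TimeInMilliseconds());
  const uint16_t seq = packet->SequenceNumber();
  packet->set_allow_retransmission(true);
  packet_history_->PutRtpPacket(std::move(packet),
                                clock_->TimeInMilliseconds());

  // 2. A NACK turns into ReSendPacket(): the stored packet is handed to the
  //    pacer and marked pending, so a duplicate request is a no-op.
  rtp_sender_->ReSendPacket(seq);  // > 0: enqueued towards the pacer.
  rtp_sender_->ReSendPacket(seq);  // == 0: still pending in the pacer queue.

  // 3. Once the pacer reports the packet as sent, the pending flag clears and
  //    a later NACK may trigger another retransmission.
  packet_history_->MarkPacketAsSent(seq);
  rtp_sender_->ReSendPacket(seq);  // > 0 again.
}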
- packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - packet->SetExtension(kPacketId + 1); - packet->set_packet_type(RtpPacketMediaType::kRetransmission); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - EXPECT_TRUE(transport_.last_options_.is_retransmit); -} - -TEST_P(RtpSenderTest, SendPacketUpdatesStats) { - const size_t kPayloadSize = 1000; - - StrictMock send_side_delay_observer; - - RtpRtcpInterface::Configuration config; - config.clock = clock_; - config.outgoing_transport = &transport_; - config.local_media_ssrc = kSsrc; - config.rtx_send_ssrc = kRtxSsrc; - config.fec_generator = &flexfec_sender_; - config.send_side_delay_observer = &send_side_delay_observer; - config.event_log = &mock_rtc_event_log_; - config.send_packet_observer = &send_packet_observer_; - rtp_sender_context_ = - std::make_unique(config, &time_controller_); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - - const int64_t capture_time_ms = clock_->TimeInMilliseconds(); - - std::unique_ptr video_packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - video_packet->set_packet_type(RtpPacketMediaType::kVideo); - video_packet->SetPayloadSize(kPayloadSize); - video_packet->SetExtension(1); - - std::unique_ptr rtx_packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - rtx_packet->SetSsrc(kRtxSsrc); - rtx_packet->set_packet_type(RtpPacketMediaType::kRetransmission); - rtx_packet->SetPayloadSize(kPayloadSize); - rtx_packet->SetExtension(2); - - std::unique_ptr fec_packet = - BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); - fec_packet->SetSsrc(kFlexFecSsrc); - fec_packet->set_packet_type(RtpPacketMediaType::kForwardErrorCorrection); - fec_packet->SetPayloadSize(kPayloadSize); - fec_packet->SetExtension(3); - - const int64_t kDiffMs = 25; - time_controller_.AdvanceTime(TimeDelta::Millis(kDiffMs)); + const uint16_t media_sequence_number = packet->SequenceNumber(); + packet->set_allow_retransmission(true); + packet_history_->PutRtpPacket(std::move(packet), + clock_->TimeInMilliseconds()); - EXPECT_CALL(send_side_delay_observer, - SendSideDelayUpdated(kDiffMs, kDiffMs, kDiffMs, kSsrc)); + // Expect a retransmission packet marked with which packet it is a + // retransmit of. EXPECT_CALL( - send_side_delay_observer, - SendSideDelayUpdated(kDiffMs, kDiffMs, 2 * kDiffMs, kFlexFecSsrc)); - - EXPECT_CALL(send_packet_observer_, OnSendPacket(1, capture_time_ms, kSsrc)); - - rtp_sender_context_->InjectPacket(std::move(video_packet), PacedPacketInfo()); - - // Send packet observer not called for padding/retransmissions. 
- EXPECT_CALL(send_packet_observer_, OnSendPacket(2, _, _)).Times(0); - rtp_sender_context_->InjectPacket(std::move(rtx_packet), PacedPacketInfo()); - - EXPECT_CALL(send_packet_observer_, - OnSendPacket(3, capture_time_ms, kFlexFecSsrc)); - rtp_sender_context_->InjectPacket(std::move(fec_packet), PacedPacketInfo()); - - StreamDataCounters rtp_stats; - StreamDataCounters rtx_stats; - rtp_egress()->GetDataCounters(&rtp_stats, &rtx_stats); - EXPECT_EQ(rtp_stats.transmitted.packets, 2u); - EXPECT_EQ(rtp_stats.fec.packets, 1u); - EXPECT_EQ(rtx_stats.retransmitted.packets, 1u); + mock_paced_sender_, + EnqueuePackets(ElementsAre(AllOf( + Pointee(Property(&RtpPacketToSend::packet_type, + RtpPacketMediaType::kRetransmission)), + Pointee(Property(&RtpPacketToSend::retransmitted_sequence_number, + Eq(media_sequence_number))))))); + EXPECT_THAT(rtp_sender_->ReSendPacket(media_sequence_number), Gt(0)); } -TEST_P(RtpSenderTest, GeneratedPaddingHasBweExtensions) { +TEST_F(RtpSenderTest, GeneratedPaddingHasBweExtensions) { // Min requested size in order to use RTX payload. const size_t kMinPaddingSize = 50; + EnableRtx(); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 1); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransmissionOffset::kUri, kTransmissionTimeOffsetExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + AbsoluteSendTime::kUri, kAbsoluteSendTimeExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId)); - ASSERT_EQ(0, - rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId)); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); - - // Send a payload packet first, to enable padding and populate the packet - // history. + // Put a packet in the history, in order to facilitate payload padding. std::unique_ptr packet = BuildRtpPacket(kPayload, true, 0, clock_->TimeInMilliseconds()); packet->set_allow_retransmission(true); packet->SetPayloadSize(kMinPaddingSize); packet->set_packet_type(RtpPacketMediaType::kVideo); - EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(1); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + packet_history_->PutRtpPacket(std::move(packet), + clock_->TimeInMilliseconds()); // Generate a plain padding packet, check that extensions are registered. std::vector> generated_packets = - rtp_sender()->GeneratePadding(/*target_size_bytes=*/1, true); + rtp_sender_->GeneratePadding(/*target_size_bytes=*/1, true); ASSERT_THAT(generated_packets, SizeIs(1)); auto& plain_padding = generated_packets.front(); EXPECT_GT(plain_padding->padding_size(), 0u); EXPECT_TRUE(plain_padding->HasExtension()); EXPECT_TRUE(plain_padding->HasExtension()); EXPECT_TRUE(plain_padding->HasExtension()); - - // Verify all header extensions have been written. 
- rtp_sender_context_->InjectPacket(std::move(plain_padding), - PacedPacketInfo()); - const auto& sent_plain_padding = transport_.last_sent_packet(); - EXPECT_TRUE(sent_plain_padding.HasExtension()); - EXPECT_TRUE(sent_plain_padding.HasExtension()); - EXPECT_TRUE(sent_plain_padding.HasExtension()); - webrtc::RTPHeader rtp_header; - sent_plain_padding.GetHeader(&rtp_header); - EXPECT_TRUE(rtp_header.extension.hasAbsoluteSendTime); - EXPECT_TRUE(rtp_header.extension.hasTransmissionTimeOffset); - EXPECT_TRUE(rtp_header.extension.hasTransportSequenceNumber); + EXPECT_GT(plain_padding->padding_size(), 0u); // Generate a payload padding packets, check that extensions are registered. - generated_packets = rtp_sender()->GeneratePadding(kMinPaddingSize, true); + generated_packets = rtp_sender_->GeneratePadding(kMinPaddingSize, true); ASSERT_EQ(generated_packets.size(), 1u); auto& payload_padding = generated_packets.front(); EXPECT_EQ(payload_padding->padding_size(), 0u); EXPECT_TRUE(payload_padding->HasExtension()); EXPECT_TRUE(payload_padding->HasExtension()); EXPECT_TRUE(payload_padding->HasExtension()); - - // Verify all header extensions have been written. - rtp_sender_context_->InjectPacket(std::move(payload_padding), - PacedPacketInfo()); - const auto& sent_payload_padding = transport_.last_sent_packet(); - EXPECT_TRUE(sent_payload_padding.HasExtension()); - EXPECT_TRUE(sent_payload_padding.HasExtension()); - EXPECT_TRUE(sent_payload_padding.HasExtension()); - sent_payload_padding.GetHeader(&rtp_header); - EXPECT_TRUE(rtp_header.extension.hasAbsoluteSendTime); - EXPECT_TRUE(rtp_header.extension.hasTransmissionTimeOffset); - EXPECT_TRUE(rtp_header.extension.hasTransportSequenceNumber); + EXPECT_GT(payload_padding->payload_size(), 0u); } -TEST_P(RtpSenderTest, GeneratePaddingResendsOldPacketsWithRtx) { +TEST_F(RtpSenderTest, GeneratePaddingResendsOldPacketsWithRtx) { // Min requested size in order to use RTX payload. const size_t kMinPaddingSize = 50; - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( + rtp_sender_->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); + rtp_sender_->SetRtxPayloadType(kRtxPayload, kPayload); + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, 1); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); const size_t kPayloadPacketSize = kMinPaddingSize; std::unique_ptr packet = @@ -2486,15 +1043,13 @@ TEST_P(RtpSenderTest, GeneratePaddingResendsOldPacketsWithRtx) { packet->set_allow_retransmission(true); packet->SetPayloadSize(kPayloadPacketSize); packet->set_packet_type(RtpPacketMediaType::kVideo); - - // Send a dummy video packet so it ends up in the packet history. - EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(1); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + packet_history_->PutRtpPacket(std::move(packet), + clock_->TimeInMilliseconds()); // Generated padding has large enough budget that the video packet should be // retransmitted as padding. 
std::vector> generated_packets = - rtp_sender()->GeneratePadding(kMinPaddingSize, true); + rtp_sender_->GeneratePadding(kMinPaddingSize, true); ASSERT_EQ(generated_packets.size(), 1u); auto& padding_packet = generated_packets.front(); EXPECT_EQ(padding_packet->packet_type(), RtpPacketMediaType::kPadding); @@ -2507,7 +1062,7 @@ TEST_P(RtpSenderTest, GeneratePaddingResendsOldPacketsWithRtx) { size_t padding_bytes_generated = 0; generated_packets = - rtp_sender()->GeneratePadding(kPaddingBytesRequested, true); + rtp_sender_->GeneratePadding(kPaddingBytesRequested, true); EXPECT_EQ(generated_packets.size(), 1u); for (auto& packet : generated_packets) { EXPECT_EQ(packet->packet_type(), RtpPacketMediaType::kPadding); @@ -2520,19 +1075,18 @@ TEST_P(RtpSenderTest, GeneratePaddingResendsOldPacketsWithRtx) { EXPECT_EQ(padding_bytes_generated, kMaxPaddingSize); } -TEST_P(RtpSenderTest, LimitsPayloadPaddingSize) { +TEST_F(RtpSenderTest, LimitsPayloadPaddingSize) { // Limit RTX payload padding to 2x target size. const double kFactor = 2.0; field_trials_.SetMaxPaddingFactor(kFactor); - SetUpRtpSender(true, false, false); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( + SetUpRtpSender(false, false, nullptr); + rtp_sender_->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); + rtp_sender_->SetRtxPayloadType(kRtxPayload, kPayload); + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, 1); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); // Send a dummy video packet so it ends up in the packet history. const size_t kPayloadPacketSize = 1234u; @@ -2541,8 +1095,8 @@ TEST_P(RtpSenderTest, LimitsPayloadPaddingSize) { packet->set_allow_retransmission(true); packet->SetPayloadSize(kPayloadPacketSize); packet->set_packet_type(RtpPacketMediaType::kVideo); - EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(1); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + packet_history_->PutRtpPacket(std::move(packet), + clock_->TimeInMilliseconds()); // Smallest target size that will result in the sent packet being returned as // padding. @@ -2552,30 +1106,27 @@ TEST_P(RtpSenderTest, LimitsPayloadPaddingSize) { // Generated padding has large enough budget that the video packet should be // retransmitted as padding. EXPECT_THAT( - rtp_sender()->GeneratePadding(kMinTargerSizeForPayload, true), + rtp_sender_->GeneratePadding(kMinTargerSizeForPayload, true), AllOf(Not(IsEmpty()), Each(Pointee(Property(&RtpPacketToSend::padding_size, Eq(0u)))))); // If payload padding is > 2x requested size, plain padding is returned // instead. 
EXPECT_THAT( - rtp_sender()->GeneratePadding(kMinTargerSizeForPayload - 1, true), + rtp_sender_->GeneratePadding(kMinTargerSizeForPayload - 1, true), AllOf(Not(IsEmpty()), Each(Pointee(Property(&RtpPacketToSend::padding_size, Gt(0u)))))); } -TEST_P(RtpSenderTest, GeneratePaddingCreatesPurePaddingWithoutRtx) { - rtp_sender_context_->packet_history_.SetStorePacketsStatus( +TEST_F(RtpSenderTest, GeneratePaddingCreatesPurePaddingWithoutRtx) { + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, 1); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId)); - ASSERT_EQ(0, - rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionAbsoluteSendTime, kAbsoluteSendTimeExtensionId)); - ASSERT_EQ(0, rtp_sender()->RegisterRtpHeaderExtension( - kRtpExtensionTransportSequenceNumber, - kTransportSequenceNumberExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransmissionOffset::kUri, kTransmissionTimeOffsetExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + AbsoluteSendTime::kUri, kAbsoluteSendTimeExtensionId)); + ASSERT_TRUE(rtp_sender_->RegisterRtpHeaderExtension( + TransportSequenceNumber::kUri, kTransportSequenceNumberExtensionId)); const size_t kPayloadPacketSize = 1234; // Send a dummy video packet so it ends up in the packet history. Since we @@ -2585,8 +1136,8 @@ TEST_P(RtpSenderTest, GeneratePaddingCreatesPurePaddingWithoutRtx) { packet->set_allow_retransmission(true); packet->SetPayloadSize(kPayloadPacketSize); packet->set_packet_type(RtpPacketMediaType::kVideo); - EXPECT_CALL(send_packet_observer_, OnSendPacket).Times(1); - rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); + packet_history_->PutRtpPacket(std::move(packet), + clock_->TimeInMilliseconds()); // Payload padding not available without RTX, only generate plain padding on // the media SSRC. @@ -2598,7 +1149,7 @@ TEST_P(RtpSenderTest, GeneratePaddingCreatesPurePaddingWithoutRtx) { (kPaddingBytesRequested + kMaxPaddingSize - 1) / kMaxPaddingSize; size_t padding_bytes_generated = 0; std::vector> padding_packets = - rtp_sender()->GeneratePadding(kPaddingBytesRequested, true); + rtp_sender_->GeneratePadding(kPaddingBytesRequested, true); EXPECT_EQ(padding_packets.size(), kExpectedNumPaddingPackets); for (auto& packet : padding_packets) { EXPECT_EQ(packet->packet_type(), RtpPacketMediaType::kPadding); @@ -2609,221 +1160,142 @@ TEST_P(RtpSenderTest, GeneratePaddingCreatesPurePaddingWithoutRtx) { EXPECT_TRUE(packet->HasExtension()); EXPECT_TRUE(packet->HasExtension()); EXPECT_TRUE(packet->HasExtension()); - - // Verify all header extensions are received. 
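// Illustrative sketch (not from the patch): the padding tests above and below
// distinguish two kinds of packets returned by GeneratePadding(). Here
// `rtp_sender_` is the fixture member used above and the bool argument is
// passed exactly as in those tests.
//  * Payload padding: an old media packet re-sent on the RTX stream;
//    padding_size() == 0 but payload_size() > 0. Requires RTX payload padding
//    support plus a suitable packet in the history.
//  * Plain padding: an otherwise empty packet with padding_size() > 0, used
//    when payload padding is unavailable or would overshoot the target.
void PaddingKindsSketch(size_t target_size_bytes) {
  std::vector<std::unique_ptr<RtpPacketToSend>> padding =
      rtp_sender_->GeneratePadding(target_size_bytes, true);
  for (const std::unique_ptr<RtpPacketToSend>& packet : padding) {
    if (packet->padding_size() > 0) {
      // Plain padding packet.
    } else {
      // RTX payload padding: a stored media payload re-sent as padding.
    }
  }
}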
- rtp_sender_context_->InjectPacket(std::move(packet), PacedPacketInfo()); - webrtc::RTPHeader rtp_header; - transport_.last_sent_packet().GetHeader(&rtp_header); - EXPECT_TRUE(rtp_header.extension.hasAbsoluteSendTime); - EXPECT_TRUE(rtp_header.extension.hasTransmissionTimeOffset); - EXPECT_TRUE(rtp_header.extension.hasTransportSequenceNumber); } EXPECT_EQ(padding_bytes_generated, kExpectedNumPaddingPackets * kMaxPaddingSize); } -TEST_P(RtpSenderTest, SupportsPadding) { +TEST_F(RtpSenderTest, SupportsPadding) { bool kSendingMediaStats[] = {true, false}; bool kEnableRedundantPayloads[] = {true, false}; - RTPExtensionType kBweExtensionTypes[] = { - kRtpExtensionTransportSequenceNumber, - kRtpExtensionTransportSequenceNumber02, kRtpExtensionAbsoluteSendTime, - kRtpExtensionTransmissionTimeOffset}; + absl::string_view kBweExtensionUris[] = { + TransportSequenceNumber::kUri, TransportSequenceNumberV2::kUri, + AbsoluteSendTime::kUri, TransmissionOffset::kUri}; const int kExtensionsId = 7; for (bool sending_media : kSendingMediaStats) { - rtp_sender()->SetSendingMediaStatus(sending_media); + rtp_sender_->SetSendingMediaStatus(sending_media); for (bool redundant_payloads : kEnableRedundantPayloads) { int rtx_mode = kRtxRetransmitted; if (redundant_payloads) { rtx_mode |= kRtxRedundantPayloads; } - rtp_sender()->SetRtxStatus(rtx_mode); + rtp_sender_->SetRtxStatus(rtx_mode); - for (auto extension_type : kBweExtensionTypes) { - EXPECT_FALSE(rtp_sender()->SupportsPadding()); - rtp_sender()->RegisterRtpHeaderExtension(extension_type, kExtensionsId); + for (auto extension_uri : kBweExtensionUris) { + EXPECT_FALSE(rtp_sender_->SupportsPadding()); + rtp_sender_->RegisterRtpHeaderExtension(extension_uri, kExtensionsId); if (!sending_media) { - EXPECT_FALSE(rtp_sender()->SupportsPadding()); + EXPECT_FALSE(rtp_sender_->SupportsPadding()); } else { - EXPECT_TRUE(rtp_sender()->SupportsPadding()); + EXPECT_TRUE(rtp_sender_->SupportsPadding()); if (redundant_payloads) { - EXPECT_TRUE(rtp_sender()->SupportsRtxPayloadPadding()); + EXPECT_TRUE(rtp_sender_->SupportsRtxPayloadPadding()); } else { - EXPECT_FALSE(rtp_sender()->SupportsRtxPayloadPadding()); + EXPECT_FALSE(rtp_sender_->SupportsRtxPayloadPadding()); } } - rtp_sender()->DeregisterRtpHeaderExtension(extension_type); - EXPECT_FALSE(rtp_sender()->SupportsPadding()); + rtp_sender_->DeregisterRtpHeaderExtension(extension_uri); + EXPECT_FALSE(rtp_sender_->SupportsPadding()); } } } } -TEST_P(RtpSenderTest, SetsCaptureTimeAndPopulatesTransmissionOffset) { - rtp_sender()->RegisterRtpHeaderExtension(kRtpExtensionTransmissionTimeOffset, - kTransmissionTimeOffsetExtensionId); - - rtp_sender()->SetSendingMediaStatus(true); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - - const int64_t kMissingCaptureTimeMs = 0; - const uint32_t kTimestampTicksPerMs = 90; - const int64_t kOffsetMs = 10; - - auto packet = - BuildRtpPacket(kPayload, kMarkerBit, clock_->TimeInMilliseconds(), - kMissingCaptureTimeMs); - packet->set_packet_type(RtpPacketMediaType::kVideo); - packet->ReserveExtension(); - packet->AllocatePayload(sizeof(kPayloadData)); - - std::unique_ptr packet_to_pace; - EXPECT_CALL(mock_paced_sender_, EnqueuePackets) - .WillOnce([&](std::vector> packets) { - EXPECT_EQ(packets.size(), 1u); - EXPECT_GT(packets[0]->capture_time_ms(), 0); - packet_to_pace = std::move(packets[0]); 
- }); +TEST_F(RtpSenderTest, SetsCaptureTimeOnRtxRetransmissions) { + EnableRtx(); + // Put a packet in the packet history, with current time as capture time. + const int64_t start_time_ms = clock_->TimeInMilliseconds(); + std::unique_ptr packet = + BuildRtpPacket(kPayload, kMarkerBit, start_time_ms, + /*capture_time_ms=*/start_time_ms); packet->set_allow_retransmission(true); - EXPECT_TRUE(rtp_sender()->SendToNetwork(std::move(packet))); - - time_controller_.AdvanceTime(TimeDelta::Millis(kOffsetMs)); - - rtp_sender_context_->InjectPacket(std::move(packet_to_pace), - PacedPacketInfo()); - - EXPECT_EQ(1, transport_.packets_sent()); - absl::optional transmission_time_extension = - transport_.sent_packets_.back().GetExtension(); - ASSERT_TRUE(transmission_time_extension.has_value()); - EXPECT_EQ(*transmission_time_extension, kOffsetMs * kTimestampTicksPerMs); - - // Retransmit packet. The RTX packet should get the same capture time as the - // original packet, so offset is delta from original packet to now. - time_controller_.AdvanceTime(TimeDelta::Millis(kOffsetMs)); - - std::unique_ptr rtx_packet_to_pace; - EXPECT_CALL(mock_paced_sender_, EnqueuePackets) - .WillOnce([&](std::vector> packets) { - EXPECT_GT(packets[0]->capture_time_ms(), 0); - rtx_packet_to_pace = std::move(packets[0]); - }); + packet_history_->PutRtpPacket(std::move(packet), start_time_ms); - EXPECT_GT(rtp_sender()->ReSendPacket(kSeqNum), 0); - rtp_sender_context_->InjectPacket(std::move(rtx_packet_to_pace), - PacedPacketInfo()); + // Advance time, request an RTX retransmission. Capture timestamp should be + // preserved. + time_controller_.AdvanceTime(TimeDelta::Millis(10)); - EXPECT_EQ(2, transport_.packets_sent()); - transmission_time_extension = - transport_.sent_packets_.back().GetExtension(); - ASSERT_TRUE(transmission_time_extension.has_value()); - EXPECT_EQ(*transmission_time_extension, 2 * kOffsetMs * kTimestampTicksPerMs); + EXPECT_CALL(mock_paced_sender_, + EnqueuePackets(ElementsAre(Pointee(Property( + &RtpPacketToSend::capture_time_ms, start_time_ms))))); + EXPECT_GT(rtp_sender_->ReSendPacket(kSeqNum), 0); } -TEST_P(RtpSenderTestWithoutPacer, ClearHistoryOnSequenceNumberCange) { - const int64_t kRtt = 10; - - rtp_sender()->SetSendingMediaStatus(true); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - rtp_sender_context_->packet_history_.SetRtt(kRtt); +TEST_F(RtpSenderTest, ClearHistoryOnSequenceNumberCange) { + EnableRtx(); - // Send a packet and record its sequence numbers. - SendGenericPacket(); - ASSERT_EQ(1u, transport_.sent_packets_.size()); - const uint16_t packet_seqence_number = - transport_.sent_packets_.back().SequenceNumber(); + // Put a packet in the packet history. + const int64_t now_ms = clock_->TimeInMilliseconds(); + std::unique_ptr packet = + BuildRtpPacket(kPayload, kMarkerBit, now_ms, now_ms); + packet->set_allow_retransmission(true); + packet_history_->PutRtpPacket(std::move(packet), now_ms); - // Advance time and make sure it can be retransmitted, even if we try to set - // the ssrc the what it already is. 
- rtp_sender()->SetSequenceNumber(rtp_sender()->SequenceNumber()); - time_controller_.AdvanceTime(TimeDelta::Millis(kRtt)); - EXPECT_GT(rtp_sender()->ReSendPacket(packet_seqence_number), 0); + EXPECT_TRUE(packet_history_->GetPacketState(kSeqNum)); - // Change the sequence number, then move the time and try to retransmit again. - // The old packet should now be gone. - rtp_sender()->SetSequenceNumber(rtp_sender()->SequenceNumber() - 1); - time_controller_.AdvanceTime(TimeDelta::Millis(kRtt)); - EXPECT_EQ(rtp_sender()->ReSendPacket(packet_seqence_number), 0); + // Update the sequence number of the RTP module, verify packet has been + // removed. + rtp_sender_->SetSequenceNumber(rtp_sender_->SequenceNumber() - 1); + EXPECT_FALSE(packet_history_->GetPacketState(kSeqNum)); } -TEST_P(RtpSenderTest, IgnoresNackAfterDisablingMedia) { +TEST_F(RtpSenderTest, IgnoresNackAfterDisablingMedia) { const int64_t kRtt = 10; - rtp_sender()->SetSendingMediaStatus(true); - rtp_sender()->SetRtxStatus(kRtxRetransmitted | kRtxRedundantPayloads); - rtp_sender()->SetRtxPayloadType(kRtxPayload, kPayload); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( - RtpPacketHistory::StorageMode::kStoreAndCull, 10); - rtp_sender_context_->packet_history_.SetRtt(kRtt); - - // Send a packet so it is in the packet history. - std::unique_ptr packet_to_pace; - EXPECT_CALL(mock_paced_sender_, EnqueuePackets) - .WillOnce([&](std::vector> packets) { - packet_to_pace = std::move(packets[0]); - }); - - SendGenericPacket(); - rtp_sender_context_->InjectPacket(std::move(packet_to_pace), - PacedPacketInfo()); + EnableRtx(); + packet_history_->SetRtt(kRtt); - ASSERT_EQ(1u, transport_.sent_packets_.size()); + // Put a packet in the history. + const int64_t start_time_ms = clock_->TimeInMilliseconds(); + std::unique_ptr packet = + BuildRtpPacket(kPayload, kMarkerBit, start_time_ms, + /*capture_time_ms=*/start_time_ms); + packet->set_allow_retransmission(true); + packet_history_->PutRtpPacket(std::move(packet), start_time_ms); - // Disable media sending and try to retransmit the packet, it should fail. - rtp_sender()->SetSendingMediaStatus(false); - time_controller_.AdvanceTime(TimeDelta::Millis(kRtt)); - EXPECT_LT(rtp_sender()->ReSendPacket(kSeqNum), 0); + // Disable media sending and try to retransmit the packet, it should fail. + rtp_sender_->SetSendingMediaStatus(false); + time_controller_.AdvanceTime(TimeDelta::Millis(kRtt)); + EXPECT_LT(rtp_sender_->ReSendPacket(kSeqNum), 0); } -TEST_P(RtpSenderTest, DoesntFecProtectRetransmissions) { +TEST_F(RtpSenderTest, DoesntFecProtectRetransmissions) { // Set up retranmission without RTX, so that a plain copy of the old packet is // re-sent instead. const int64_t kRtt = 10; - rtp_sender()->SetSendingMediaStatus(true); - rtp_sender()->SetRtxStatus(kRtxOff); - rtp_sender_context_->packet_history_.SetStorePacketsStatus( + rtp_sender_->SetSendingMediaStatus(true); + rtp_sender_->SetRtxStatus(kRtxOff); + packet_history_->SetStorePacketsStatus( RtpPacketHistory::StorageMode::kStoreAndCull, 10); - rtp_sender_context_->packet_history_.SetRtt(kRtt); + packet_history_->SetRtt(kRtt); - // Send a packet so it is in the packet history, make sure to mark it for - // FEC protection. 
- std::unique_ptr packet_to_pace; - EXPECT_CALL(mock_paced_sender_, EnqueuePackets) - .WillOnce([&](std::vector> packets) { - packet_to_pace = std::move(packets[0]); - }); - - SendGenericPacket(); - packet_to_pace->set_fec_protect_packet(true); - rtp_sender_context_->InjectPacket(std::move(packet_to_pace), - PacedPacketInfo()); - - ASSERT_EQ(1u, transport_.sent_packets_.size()); + // Put a fec protected packet in the history. + const int64_t start_time_ms = clock_->TimeInMilliseconds(); + std::unique_ptr packet = + BuildRtpPacket(kPayload, kMarkerBit, start_time_ms, + /*capture_time_ms=*/start_time_ms); + packet->set_allow_retransmission(true); + packet->set_fec_protect_packet(true); + packet_history_->PutRtpPacket(std::move(packet), start_time_ms); // Re-send packet, the retransmitted packet should not have the FEC protection // flag set. EXPECT_CALL(mock_paced_sender_, - EnqueuePackets(Each(Pointee( + EnqueuePackets(ElementsAre(Pointee( Property(&RtpPacketToSend::fec_protect_packet, false))))); time_controller_.AdvanceTime(TimeDelta::Millis(kRtt)); - EXPECT_GT(rtp_sender()->ReSendPacket(kSeqNum), 0); + EXPECT_GT(rtp_sender_->ReSendPacket(kSeqNum), 0); } -TEST_P(RtpSenderTest, MarksPacketsWithKeyframeStatus) { +TEST_F(RtpSenderTest, MarksPacketsWithKeyframeStatus) { FieldTrialBasedConfig field_trials; RTPSenderVideo::Config video_config; video_config.clock = clock_; - video_config.rtp_sender = rtp_sender(); + video_config.rtp_sender = rtp_sender_.get(); video_config.field_trials = &field_trials; RTPSenderVideo rtp_sender_video(video_config); @@ -2866,14 +1338,4 @@ TEST_P(RtpSenderTest, MarksPacketsWithKeyframeStatus) { } } -INSTANTIATE_TEST_SUITE_P(WithAndWithoutOverhead, - RtpSenderTest, - ::testing::Values(TestConfig{false}, - TestConfig{true})); - -INSTANTIATE_TEST_SUITE_P(WithAndWithoutOverhead, - RtpSenderTestWithoutPacer, - ::testing::Values(TestConfig{false}, - TestConfig{true})); - } // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_sender_video.cc b/modules/rtp_rtcp/source/rtp_sender_video.cc index 934be824a4..4919e3ebf4 100644 --- a/modules/rtp_rtcp/source/rtp_sender_video.cc +++ b/modules/rtp_rtcp/source/rtp_sender_video.cc @@ -169,8 +169,7 @@ RTPSenderVideo::RTPSenderVideo(const Config& config) absolute_capture_time_sender_(config.clock), frame_transformer_delegate_( config.frame_transformer - ? new rtc::RefCountedObject< - RTPSenderVideoFrameTransformerDelegate>( + ? rtc::make_ref_counted( this, config.frame_transformer, rtp_sender_->SSRC(), @@ -362,7 +361,8 @@ void RTPSenderVideo::AddRtpHeaderExtensions( if (video_header.generic) { bool extension_is_set = false; - if (video_structure_ != nullptr) { + if (packet->IsRegistered() && + video_structure_ != nullptr) { DependencyDescriptor descriptor; descriptor.first_packet_in_frame = first_packet; descriptor.last_packet_in_frame = last_packet; @@ -408,7 +408,8 @@ void RTPSenderVideo::AddRtpHeaderExtensions( } // Do not use generic frame descriptor when dependency descriptor is stored. 
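// Illustrative sketch (not from the patch): the rtp_sender_video.cc hunks
// above and below add IsRegistered<Extension>() guards so a descriptor is
// only built when that header extension was actually registered. The pattern
// is sketched here; FooExtension and MaybeSetExtension are placeholders, not
// WebRTC symbols.
template <typename FooExtension, typename Value>
void MaybeSetExtension(RtpPacketToSend& packet, const Value& value) {
  // Without the guard, the value would still be computed and SetExtension()
  // would simply fail (return false) for an unregistered extension.
  if (!packet.IsRegistered<FooExtension>()) {
    return;
  }
  packet.SetExtension<FooExtension>(value);
}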
- if (!extension_is_set) { + if (packet->IsRegistered() && + !extension_is_set) { RtpGenericFrameDescriptor generic_descriptor; generic_descriptor.SetFirstPacketInSubFrame(first_packet); generic_descriptor.SetLastPacketInSubFrame(last_packet); @@ -438,7 +439,8 @@ void RTPSenderVideo::AddRtpHeaderExtensions( } } - if (first_packet && + if (packet->IsRegistered() && + first_packet && send_allocation_ != SendVideoLayersAllocation::kDontSend && (video_header.frame_type == VideoFrameType::kVideoFrameKey || PacketWillLikelyBeRequestedForRestransmitionIfLost(video_header))) { @@ -447,6 +449,11 @@ void RTPSenderVideo::AddRtpHeaderExtensions( send_allocation_ == SendVideoLayersAllocation::kSendWithResolution; packet->SetExtension(allocation); } + + if (first_packet && video_header.video_frame_tracking_id) { + packet->SetExtension( + *video_header.video_frame_tracking_id); + } } bool RTPSenderVideo::SendVideo( @@ -519,7 +526,8 @@ bool RTPSenderVideo::SendVideo( AbsoluteCaptureTimeSender::GetSource(single_packet->Ssrc(), single_packet->Csrcs()), single_packet->Timestamp(), kVideoPayloadTypeFrequency, - Int64MsToUQ32x32(single_packet->capture_time_ms() + NtpOffsetMs()), + Int64MsToUQ32x32( + clock_->ConvertTimestampToNtpTimeInMilliseconds(capture_time_ms)), /*estimated_capture_clock_offset=*/ include_capture_clock_offset_ ? estimated_capture_clock_offset_ms : absl::nullopt); @@ -648,8 +656,6 @@ bool RTPSenderVideo::SendVideo( if (!packetizer->NextPacket(packet.get())) return false; RTC_DCHECK_LE(packet->payload_size(), expected_payload_capacity); - if (!rtp_sender_->AssignSequenceNumber(packet.get())) - return false; packet->set_allow_retransmission(allow_retransmission); packet->set_is_key_frame(video_header.frame_type == @@ -670,7 +676,7 @@ bool RTPSenderVideo::SendVideo( red_packet->SetPayloadType(*red_payload_type_); red_packet->set_is_red(true); - // Send |red_packet| instead of |packet| for allocated sequence number. + // Append |red_packet| instead of |packet| to output. red_packet->set_packet_type(RtpPacketMediaType::kVideo); red_packet->set_allow_retransmission(packet->allow_retransmission()); rtp_packets.emplace_back(std::move(red_packet)); @@ -691,6 +697,11 @@ bool RTPSenderVideo::SendVideo( } } + if (!rtp_sender_->AssignSequenceNumbersAndStoreLastPacketState(rtp_packets)) { + // Media not being sent. + return false; + } + LogAndSendToNetwork(std::move(rtp_packets), payload.size()); // Update details about the last sent frame. diff --git a/modules/rtp_rtcp/source/rtp_sender_video.h b/modules/rtp_rtcp/source/rtp_sender_video.h index 6e469900d6..ba8d7e8360 100644 --- a/modules/rtp_rtcp/source/rtp_sender_video.h +++ b/modules/rtp_rtcp/source/rtp_sender_video.h @@ -20,6 +20,7 @@ #include "api/array_view.h" #include "api/frame_transformer_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "api/transport/rtp/dependency_descriptor.h" #include "api/video/video_codec_type.h" @@ -37,7 +38,6 @@ #include "rtc_base/race_checker.h" #include "rtc_base/rate_statistics.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/thread_annotations.h" namespace webrtc { @@ -89,6 +89,7 @@ class RTPSenderVideo { virtual ~RTPSenderVideo(); // expected_retransmission_time_ms.has_value() -> retransmission allowed. + // `capture_time_ms` and `clock::CurrentTime` should be using the same epoch. // Calls to this method is assumed to be externally serialized. 
// |estimated_capture_clock_offset_ms| is an estimated clock offset between // this sender and the original capturer, for this video packet. See diff --git a/modules/rtp_rtcp/source/rtp_sender_video_frame_transformer_delegate.cc b/modules/rtp_rtcp/source/rtp_sender_video_frame_transformer_delegate.cc index 074b64086a..23e66bf757 100644 --- a/modules/rtp_rtcp/source/rtp_sender_video_frame_transformer_delegate.cc +++ b/modules/rtp_rtcp/source/rtp_sender_video_frame_transformer_delegate.cc @@ -129,9 +129,10 @@ void RTPSenderVideoFrameTransformerDelegate::OnTransformedFrame( std::unique_ptr frame) { MutexLock lock(&sender_lock_); - // The encoder queue gets destroyed after the sender; as long as the sender is - // alive, it's safe to post. - if (!sender_) + // The encoder queue normally gets destroyed after the sender; + // however, it might still be null by the time a previously queued frame + // arrives. + if (!sender_ || !encoder_queue_) return; rtc::scoped_refptr delegate = this; encoder_queue_->PostTask(ToQueuedTask( diff --git a/modules/rtp_rtcp/source/rtp_sender_video_unittest.cc b/modules/rtp_rtcp/source/rtp_sender_video_unittest.cc index 55bafdc790..ea727828cc 100644 --- a/modules/rtp_rtcp/source/rtp_sender_video_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_sender_video_unittest.cc @@ -34,7 +34,6 @@ #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" #include "modules/rtp_rtcp/source/rtp_video_layers_allocation_extension.h" -#include "modules/rtp_rtcp/source/time_util.h" #include "rtc_base/arraysize.h" #include "rtc_base/rate_limiter.h" #include "rtc_base/task_queue_for_test.h" @@ -1054,8 +1053,10 @@ TEST_P(RtpSenderVideoTest, AbsoluteCaptureTime) { packet.GetExtension(); if (absolute_capture_time) { ++packets_with_abs_capture_time; - EXPECT_EQ(absolute_capture_time->absolute_capture_timestamp, - Int64MsToUQ32x32(kAbsoluteCaptureTimestampMs + NtpOffsetMs())); + EXPECT_EQ( + absolute_capture_time->absolute_capture_timestamp, + Int64MsToUQ32x32(fake_clock_.ConvertTimestampToNtpTimeInMilliseconds( + kAbsoluteCaptureTimestampMs))); EXPECT_FALSE( absolute_capture_time->estimated_capture_clock_offset.has_value()); } @@ -1092,8 +1093,10 @@ TEST_P(RtpSenderVideoTest, AbsoluteCaptureTimeWithCaptureClockOffset) { packet.GetExtension(); if (absolute_capture_time) { ++packets_with_abs_capture_time; - EXPECT_EQ(absolute_capture_time->absolute_capture_timestamp, - Int64MsToUQ32x32(kAbsoluteCaptureTimestampMs + NtpOffsetMs())); + EXPECT_EQ( + absolute_capture_time->absolute_capture_timestamp, + Int64MsToUQ32x32(fake_clock_.ConvertTimestampToNtpTimeInMilliseconds( + kAbsoluteCaptureTimestampMs))); EXPECT_EQ(kExpectedCaptureClockOffset, absolute_capture_time->estimated_capture_clock_offset); } @@ -1158,6 +1161,55 @@ TEST_P(RtpSenderVideoTest, PopulatesPlayoutDelay) { EXPECT_EQ(received_delay, kExpectedDelay); } +TEST_P(RtpSenderVideoTest, SendGenericVideo) { + const uint8_t kPayloadType = 127; + const VideoCodecType kCodecType = VideoCodecType::kVideoCodecGeneric; + const uint8_t kPayload[] = {47, 11, 32, 93, 89}; + + // Send keyframe. 
+  RTPVideoHeader video_header;
+  video_header.frame_type = VideoFrameType::kVideoFrameKey;
+  ASSERT_TRUE(rtp_sender_video_->SendVideo(kPayloadType, kCodecType, 1234, 4321,
+                                           kPayload, video_header,
+                                           absl::nullopt));
+
+  rtc::ArrayView<const uint8_t> sent_payload =
+      transport_.last_sent_packet().payload();
+  uint8_t generic_header = sent_payload[0];
+  EXPECT_TRUE(generic_header & RtpFormatVideoGeneric::kKeyFrameBit);
+  EXPECT_TRUE(generic_header & RtpFormatVideoGeneric::kFirstPacketBit);
+  EXPECT_THAT(sent_payload.subview(1), ElementsAreArray(kPayload));
+
+  // Send delta frame.
+  const uint8_t kDeltaPayload[] = {13, 42, 32, 93, 13};
+  video_header.frame_type = VideoFrameType::kVideoFrameDelta;
+  ASSERT_TRUE(rtp_sender_video_->SendVideo(kPayloadType, kCodecType, 1234, 4321,
+                                           kDeltaPayload, video_header,
+                                           absl::nullopt));
+
+  sent_payload = transport_.last_sent_packet().payload();
+  generic_header = sent_payload[0];
+  EXPECT_FALSE(generic_header & RtpFormatVideoGeneric::kKeyFrameBit);
+  EXPECT_TRUE(generic_header & RtpFormatVideoGeneric::kFirstPacketBit);
+  EXPECT_THAT(sent_payload.subview(1), ElementsAreArray(kDeltaPayload));
+}
+
+TEST_P(RtpSenderVideoTest, SendRawVideo) {
+  const uint8_t kPayloadType = 111;
+  const uint8_t kPayload[] = {11, 22, 33, 44, 55};
+
+  // Send a frame.
+  RTPVideoHeader video_header;
+  video_header.frame_type = VideoFrameType::kVideoFrameKey;
+  ASSERT_TRUE(rtp_sender_video_->SendVideo(kPayloadType, absl::nullopt, 1234,
+                                           4321, kPayload, video_header,
+                                           absl::nullopt));
+
+  rtc::ArrayView<const uint8_t> sent_payload =
+      transport_.last_sent_packet().payload();
+  EXPECT_THAT(sent_payload, ElementsAreArray(kPayload));
+}
+
 INSTANTIATE_TEST_SUITE_P(WithAndWithoutOverhead,
                          RtpSenderVideoTest,
                          ::testing::Bool());
diff --git a/modules/rtp_rtcp/source/rtp_util.cc b/modules/rtp_rtcp/source/rtp_util.cc
new file mode 100644
index 0000000000..46c641ea2f
--- /dev/null
+++ b/modules/rtp_rtcp/source/rtp_util.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/rtp_rtcp/source/rtp_util.h"
+
+#include <cstddef>
+#include <cstdint>
+
+#include "api/array_view.h"
+
+namespace webrtc {
+namespace {
+
+constexpr uint8_t kRtpVersion = 2;
+constexpr size_t kMinRtpPacketLen = 12;
+constexpr size_t kMinRtcpPacketLen = 4;
+
+bool HasCorrectRtpVersion(rtc::ArrayView<const uint8_t> packet) {
+  return packet[0] >> 6 == kRtpVersion;
+}
+
+// For additional details, see http://tools.ietf.org/html/rfc5761#section-4
+bool PayloadTypeIsReservedForRtcp(uint8_t payload_type) {
+  return 64 <= payload_type && payload_type < 96;
+}
+
+}  // namespace
+
+bool IsRtpPacket(rtc::ArrayView<const uint8_t> packet) {
+  return packet.size() >= kMinRtpPacketLen && HasCorrectRtpVersion(packet) &&
+         !PayloadTypeIsReservedForRtcp(packet[1] & 0x7F);
+}
+
+bool IsRtcpPacket(rtc::ArrayView<const uint8_t> packet) {
+  return packet.size() >= kMinRtcpPacketLen && HasCorrectRtpVersion(packet) &&
+         PayloadTypeIsReservedForRtcp(packet[1] & 0x7F);
+}
+
+}  // namespace webrtc
diff --git a/test/pc/e2e/analyzer/video/id_generator.cc b/modules/rtp_rtcp/source/rtp_util.h
similarity index 52%
rename from test/pc/e2e/analyzer/video/id_generator.cc
rename to modules/rtp_rtcp/source/rtp_util.h
index f1ead37e2f..b85727bf47 100644
--- a/test/pc/e2e/analyzer/video/id_generator.cc
+++ b/modules/rtp_rtcp/source/rtp_util.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
  *
  * Use of this source code is governed by a BSD-style license
  * that can be found in the LICENSE file in the root of the source
@@ -8,17 +8,18 @@
  * be found in the AUTHORS file in the root of the source tree.
  */
 
-#include "test/pc/e2e/analyzer/video/id_generator.h"
+#ifndef MODULES_RTP_RTCP_SOURCE_RTP_UTIL_H_
+#define MODULES_RTP_RTCP_SOURCE_RTP_UTIL_H_
 
-namespace webrtc {
-namespace webrtc_pc_e2e {
+#include <cstdint>
+
+#include "api/array_view.h"
 
-IntIdGenerator::IntIdGenerator(int start_value) : next_id_(start_value) {}
-IntIdGenerator::~IntIdGenerator() = default;
+namespace webrtc {
 
-int IntIdGenerator::GetNextId() {
-  return next_id_++;
-}
+bool IsRtcpPacket(rtc::ArrayView<const uint8_t> packet);
+bool IsRtpPacket(rtc::ArrayView<const uint8_t> packet);
 
-}  // namespace webrtc_pc_e2e
 }  // namespace webrtc
+
+#endif  // MODULES_RTP_RTCP_SOURCE_RTP_UTIL_H_
diff --git a/modules/rtp_rtcp/source/rtp_util_unittest.cc b/modules/rtp_rtcp/source/rtp_util_unittest.cc
new file mode 100644
index 0000000000..8f980ecff1
--- /dev/null
+++ b/modules/rtp_rtcp/source/rtp_util_unittest.cc
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
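// Illustration only (not part of this patch): a minimal sketch of how the
// IsRtpPacket()/IsRtcpPacket() helpers defined above could be used to classify
// an incoming datagram, following the RFC 5761 rule that payload types 64-95
// are reserved for RTCP. The enum and function name here are hypothetical.
//
//   #include "modules/rtp_rtcp/source/rtp_util.h"
//
//   enum class PacketKind { kRtp, kRtcp, kUnknown };
//
//   PacketKind Classify(rtc::ArrayView<const uint8_t> packet) {
//     if (webrtc::IsRtcpPacket(packet))
//       return PacketKind::kRtcp;
//     if (webrtc::IsRtpPacket(packet))
//       return PacketKind::kRtp;
//     return PacketKind::kUnknown;  // Neither well-formed RTP nor RTCP.
//   }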
+ */ + +#include "modules/rtp_rtcp/source/rtp_util.h" + +#include "test/gmock.h" + +namespace webrtc { +namespace { + +TEST(RtpUtil, IsRtpPacket) { + constexpr uint8_t kMinimalisticRtpPacket[] = {0x80, 97, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0}; + EXPECT_TRUE(IsRtpPacket(kMinimalisticRtpPacket)); + + constexpr uint8_t kWrongRtpVersion[] = {0xc0, 97, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0}; + EXPECT_FALSE(IsRtpPacket(kWrongRtpVersion)); + + constexpr uint8_t kPacketWithPayloadForRtcp[] = {0x80, 200, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0}; + EXPECT_FALSE(IsRtpPacket(kPacketWithPayloadForRtcp)); + + constexpr uint8_t kTooSmallRtpPacket[] = {0x80, 97, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0}; + EXPECT_FALSE(IsRtpPacket(kTooSmallRtpPacket)); + + EXPECT_FALSE(IsRtpPacket({})); +} + +TEST(RtpUtil, IsRtcpPacket) { + constexpr uint8_t kMinimalisticRtcpPacket[] = {0x80, 202, 0, 0}; + EXPECT_TRUE(IsRtcpPacket(kMinimalisticRtcpPacket)); + + constexpr uint8_t kWrongRtpVersion[] = {0xc0, 202, 0, 0}; + EXPECT_FALSE(IsRtcpPacket(kWrongRtpVersion)); + + constexpr uint8_t kPacketWithPayloadForRtp[] = {0x80, 225, 0, 0}; + EXPECT_FALSE(IsRtcpPacket(kPacketWithPayloadForRtp)); + + constexpr uint8_t kTooSmallRtcpPacket[] = {0x80, 202, 0}; + EXPECT_FALSE(IsRtcpPacket(kTooSmallRtcpPacket)); + + EXPECT_FALSE(IsRtcpPacket({})); +} + +} // namespace +} // namespace webrtc diff --git a/modules/rtp_rtcp/source/rtp_utility.cc b/modules/rtp_rtcp/source/rtp_utility.cc index a3d6d6f7f1..9b68f0dead 100644 --- a/modules/rtp_rtcp/source/rtp_utility.cc +++ b/modules/rtp_rtcp/source/rtp_utility.cc @@ -131,7 +131,7 @@ bool RtpHeaderParser::RTCP() const { } bool RtpHeaderParser::ParseRtcp(RTPHeader* header) const { - assert(header != NULL); + RTC_DCHECK(header); const ptrdiff_t length = _ptrRTPDataEnd - _ptrRTPDataBegin; if (length < kRtcpMinParseLength) { @@ -364,6 +364,10 @@ void RtpHeaderParser::ParseOneByteExtensionHeader( header->extension.hasTransmissionTimeOffset = true; break; } + case kRtpExtensionCsrcAudioLevel: { + RTC_LOG(LS_WARNING) << "Csrc audio level extension not supported"; + return; + } case kRtpExtensionAudioLevel: { if (len != 0) { RTC_LOG(LS_WARNING) << "Incorrect audio level len: " << len; @@ -536,6 +540,10 @@ void RtpHeaderParser::ParseOneByteExtensionHeader( RTC_LOG(WARNING) << "Inband comfort noise extension unsupported by " "rtp header parser."; break; + case kRtpExtensionVideoFrameTrackingId: + RTC_LOG(WARNING) + << "VideoFrameTrackingId unsupported by rtp header parser."; + break; case kRtpExtensionNone: case kRtpExtensionNumberOfExtensions: { RTC_NOTREACHED() << "Invalid extension type: " << type; diff --git a/modules/rtp_rtcp/source/rtp_utility.h b/modules/rtp_rtcp/source/rtp_utility.h index cdda9ef119..cdfff4072f 100644 --- a/modules/rtp_rtcp/source/rtp_utility.h +++ b/modules/rtp_rtcp/source/rtp_utility.h @@ -15,6 +15,7 @@ #include +#include "absl/base/attributes.h" #include "absl/strings/string_view.h" #include "api/rtp_headers.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" @@ -34,6 +35,7 @@ class RtpHeaderParser { RtpHeaderParser(const uint8_t* rtpData, size_t rtpDataLength); ~RtpHeaderParser(); + ABSL_DEPRECATED("Use IsRtpPacket or IsRtcpPacket") bool RTCP() const; bool ParseRtcp(RTPHeader* header) const; bool Parse(RTPHeader* parsedPacket, diff --git a/modules/rtp_rtcp/source/rtp_video_header.h b/modules/rtp_rtcp/source/rtp_video_header.h index 8a2fcba939..c1be76fa4c 100644 --- a/modules/rtp_rtcp/source/rtp_video_header.h +++ b/modules/rtp_rtcp/source/rtp_video_header.h @@ 
-77,6 +77,9 @@ struct RTPVideoHeader { VideoPlayoutDelay playout_delay; VideoSendTiming video_timing; absl::optional color_space; + // This field is meant for media quality testing purpose only. When enabled it + // carries the webrtc::VideoFrame id field from the sender to the receiver. + absl::optional video_frame_tracking_id; RTPVideoTypeHeader video_type_header; }; diff --git a/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension.cc b/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension.cc index 1587bc34cf..93fb235dcd 100644 --- a/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension.cc +++ b/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension.cc @@ -110,14 +110,14 @@ bool AllocationIsValid(const VideoLayersAllocation& allocation) { if (spatial_layer.height <= 0) { return false; } - if (spatial_layer.frame_rate_fps < 0 || - spatial_layer.frame_rate_fps > 255) { + if (spatial_layer.frame_rate_fps > 255) { return false; } } } if (allocation.rtp_stream_index < 0 || - allocation.rtp_stream_index > max_rtp_stream_idx) { + (!allocation.active_spatial_layers.empty() && + allocation.rtp_stream_index > max_rtp_stream_idx)) { return false; } return true; @@ -201,17 +201,21 @@ SpatialLayersBitmasks SpatialLayersBitmasksPerRtpStream( // Encoded (width - 1), 16-bit, (height - 1), 16-bit, max frame rate 8-bit // per spatial layer per RTP stream. // Values are stored in (RTP stream id, spatial id) ascending order. +// +// An empty layer allocation (i.e nothing sent on ssrc) is encoded as +// special case with a single 0 byte. bool RtpVideoLayersAllocationExtension::Write( rtc::ArrayView data, const VideoLayersAllocation& allocation) { - if (allocation.active_spatial_layers.empty()) { - return false; - } - RTC_DCHECK(AllocationIsValid(allocation)); RTC_DCHECK_GE(data.size(), ValueSize(allocation)); + if (allocation.active_spatial_layers.empty()) { + data[0] = 0; + return true; + } + SpatialLayersBitmasks slb = SpatialLayersBitmasksPerRtpStream(allocation); uint8_t* write_at = data.data(); // First half of the header byte. @@ -276,10 +280,18 @@ bool RtpVideoLayersAllocationExtension::Parse( if (data.empty() || allocation == nullptr) { return false; } + + allocation->active_spatial_layers.clear(); + const uint8_t* read_at = data.data(); const uint8_t* const end = data.data() + data.size(); - allocation->active_spatial_layers.clear(); + if (data.size() == 1 && *read_at == 0) { + allocation->rtp_stream_index = 0; + allocation->resolution_and_frame_rate_is_valid = true; + return true; + } + // Header byte. 
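+  // As parsed just below: the two most significant bits of the header byte
+  // carry rtp_stream_index, and the next two bits carry the number of RTP
+  // streams minus one.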
allocation->rtp_stream_index = *read_at >> 6; int num_rtp_streams = 1 + ((*read_at >> 4) & 0b11); @@ -374,7 +386,7 @@ bool RtpVideoLayersAllocationExtension::Parse( size_t RtpVideoLayersAllocationExtension::ValueSize( const VideoLayersAllocation& allocation) { if (allocation.active_spatial_layers.empty()) { - return 0; + return 1; } size_t result = 1; // header SpatialLayersBitmasks slb = SpatialLayersBitmasksPerRtpStream(allocation); diff --git a/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension_unittest.cc b/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension_unittest.cc index c8363ae257..92e5673441 100644 --- a/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension_unittest.cc +++ b/modules/rtp_rtcp/source/rtp_video_layers_allocation_extension_unittest.cc @@ -19,13 +19,31 @@ namespace webrtc { namespace { +TEST(RtpVideoLayersAllocationExtension, WriteEmptyLayersAllocationReturnsTrue) { + VideoLayersAllocation written_allocation; + rtc::Buffer buffer( + RtpVideoLayersAllocationExtension::ValueSize(written_allocation)); + EXPECT_TRUE( + RtpVideoLayersAllocationExtension::Write(buffer, written_allocation)); +} + TEST(RtpVideoLayersAllocationExtension, - WriteEmptyLayersAllocationReturnsFalse) { + CanWriteAndParseLayersAllocationWithZeroSpatialLayers) { + // We require the resolution_and_frame_rate_is_valid to be set to true in + // order to send an "empty" allocation. VideoLayersAllocation written_allocation; + written_allocation.resolution_and_frame_rate_is_valid = true; + written_allocation.rtp_stream_index = 0; + rtc::Buffer buffer( RtpVideoLayersAllocationExtension::ValueSize(written_allocation)); - EXPECT_FALSE( + EXPECT_TRUE( RtpVideoLayersAllocationExtension::Write(buffer, written_allocation)); + + VideoLayersAllocation parsed_allocation; + EXPECT_TRUE( + RtpVideoLayersAllocationExtension::Parse(buffer, &parsed_allocation)); + EXPECT_EQ(written_allocation, parsed_allocation); } TEST(RtpVideoLayersAllocationExtension, @@ -221,5 +239,15 @@ TEST(RtpVideoLayersAllocationExtension, EXPECT_EQ(written_allocation, parsed_allocation); } +TEST(RtpVideoLayersAllocationExtension, + WriteEmptyAllocationCanHaveAnyRtpStreamIndex) { + VideoLayersAllocation written_allocation; + written_allocation.rtp_stream_index = 1; + rtc::Buffer buffer( + RtpVideoLayersAllocationExtension::ValueSize(written_allocation)); + EXPECT_TRUE( + RtpVideoLayersAllocationExtension::Write(buffer, written_allocation)); +} + } // namespace } // namespace webrtc diff --git a/modules/rtp_rtcp/source/source_tracker_unittest.cc b/modules/rtp_rtcp/source/source_tracker_unittest.cc index 32f9f4b2a3..8514e8462d 100644 --- a/modules/rtp_rtcp/source/source_tracker_unittest.cc +++ b/modules/rtp_rtcp/source/source_tracker_unittest.cc @@ -111,7 +111,7 @@ class SourceTrackerRandomTest packet_infos.emplace_back(GenerateSsrc(), GenerateCsrcs(), GenerateRtpTimestamp(), GenerateAudioLevel(), GenerateAbsoluteCaptureTime(), - GenerateReceiveTimeMs()); + GenerateReceiveTime()); } return RtpPacketInfos(std::move(packet_infos)); @@ -192,8 +192,9 @@ class SourceTrackerRandomTest return value; } - int64_t GenerateReceiveTimeMs() { - return std::uniform_int_distribution()(generator_); + Timestamp GenerateReceiveTime() { + return Timestamp::Micros( + std::uniform_int_distribution()(generator_)); } const uint32_t ssrcs_count_; @@ -239,78 +240,156 @@ TEST(SourceTrackerTest, StartEmpty) { EXPECT_THAT(tracker.GetSources(), IsEmpty()); } -TEST(SourceTrackerTest, OnFrameDeliveredRecordsSources) { +TEST(SourceTrackerTest, 
OnFrameDeliveredRecordsSourcesDistinctSsrcs) { + constexpr uint32_t kSsrc1 = 10; + constexpr uint32_t kSsrc2 = 11; + constexpr uint32_t kCsrcs0 = 20; + constexpr uint32_t kCsrcs1 = 21; + constexpr uint32_t kCsrcs2 = 22; + constexpr uint32_t kRtpTimestamp0 = 40; + constexpr uint32_t kRtpTimestamp1 = 50; + constexpr absl::optional kAudioLevel0 = 50; + constexpr absl::optional kAudioLevel1 = 20; + constexpr absl::optional kAbsoluteCaptureTime = + AbsoluteCaptureTime{/*absolute_capture_timestamp=*/12, + /*estimated_capture_clock_offset=*/absl::nullopt}; + constexpr Timestamp kReceiveTime0 = Timestamp::Millis(60); + constexpr Timestamp kReceiveTime1 = Timestamp::Millis(70); + + SimulatedClock clock(1000000000000ULL); + SourceTracker tracker(&clock); + + tracker.OnFrameDelivered(RtpPacketInfos( + {RtpPacketInfo(kSsrc1, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0, + kAbsoluteCaptureTime, kReceiveTime0), + RtpPacketInfo(kSsrc2, {kCsrcs2}, kRtpTimestamp1, kAudioLevel1, + kAbsoluteCaptureTime, kReceiveTime1)})); + + int64_t timestamp_ms = clock.TimeInMilliseconds(); + constexpr RtpSource::Extensions extensions0 = {kAudioLevel0, + kAbsoluteCaptureTime}; + constexpr RtpSource::Extensions extensions1 = {kAudioLevel1, + kAbsoluteCaptureTime}; + + EXPECT_THAT(tracker.GetSources(), + ElementsAre(RtpSource(timestamp_ms, kSsrc2, RtpSourceType::SSRC, + kRtpTimestamp1, extensions1), + RtpSource(timestamp_ms, kCsrcs2, RtpSourceType::CSRC, + kRtpTimestamp1, extensions1), + RtpSource(timestamp_ms, kSsrc1, RtpSourceType::SSRC, + kRtpTimestamp0, extensions0), + RtpSource(timestamp_ms, kCsrcs1, RtpSourceType::CSRC, + kRtpTimestamp0, extensions0), + RtpSource(timestamp_ms, kCsrcs0, RtpSourceType::CSRC, + kRtpTimestamp0, extensions0))); +} + +TEST(SourceTrackerTest, OnFrameDeliveredRecordsSourcesSameSsrc) { constexpr uint32_t kSsrc = 10; constexpr uint32_t kCsrcs0 = 20; constexpr uint32_t kCsrcs1 = 21; - constexpr uint32_t kRtpTimestamp = 40; - constexpr absl::optional kAudioLevel = 50; + constexpr uint32_t kCsrcs2 = 22; + constexpr uint32_t kRtpTimestamp0 = 40; + constexpr uint32_t kRtpTimestamp1 = 45; + constexpr uint32_t kRtpTimestamp2 = 50; + constexpr absl::optional kAudioLevel0 = 50; + constexpr absl::optional kAudioLevel1 = 20; + constexpr absl::optional kAudioLevel2 = 10; constexpr absl::optional kAbsoluteCaptureTime = AbsoluteCaptureTime{/*absolute_capture_timestamp=*/12, /*estimated_capture_clock_offset=*/absl::nullopt}; - constexpr int64_t kReceiveTimeMs = 60; + constexpr Timestamp kReceiveTime0 = Timestamp::Millis(60); + constexpr Timestamp kReceiveTime1 = Timestamp::Millis(70); + constexpr Timestamp kReceiveTime2 = Timestamp::Millis(80); SimulatedClock clock(1000000000000ULL); SourceTracker tracker(&clock); tracker.OnFrameDelivered(RtpPacketInfos( - {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp, kAudioLevel, - kAbsoluteCaptureTime, kReceiveTimeMs)})); + {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0, + kAbsoluteCaptureTime, kReceiveTime0), + RtpPacketInfo(kSsrc, {kCsrcs2}, kRtpTimestamp1, kAudioLevel1, + kAbsoluteCaptureTime, kReceiveTime1), + RtpPacketInfo(kSsrc, {kCsrcs0}, kRtpTimestamp2, kAudioLevel2, + kAbsoluteCaptureTime, kReceiveTime2)})); int64_t timestamp_ms = clock.TimeInMilliseconds(); - constexpr RtpSource::Extensions extensions = {kAudioLevel, - kAbsoluteCaptureTime}; + constexpr RtpSource::Extensions extensions0 = {kAudioLevel0, + kAbsoluteCaptureTime}; + constexpr RtpSource::Extensions extensions1 = {kAudioLevel1, + kAbsoluteCaptureTime}; + constexpr 
RtpSource::Extensions extensions2 = {kAudioLevel2, + kAbsoluteCaptureTime}; EXPECT_THAT(tracker.GetSources(), ElementsAre(RtpSource(timestamp_ms, kSsrc, RtpSourceType::SSRC, - kRtpTimestamp, extensions), - RtpSource(timestamp_ms, kCsrcs1, RtpSourceType::CSRC, - kRtpTimestamp, extensions), + kRtpTimestamp2, extensions2), RtpSource(timestamp_ms, kCsrcs0, RtpSourceType::CSRC, - kRtpTimestamp, extensions))); + kRtpTimestamp2, extensions2), + RtpSource(timestamp_ms, kCsrcs2, RtpSourceType::CSRC, + kRtpTimestamp1, extensions1), + RtpSource(timestamp_ms, kCsrcs1, RtpSourceType::CSRC, + kRtpTimestamp0, extensions0))); } TEST(SourceTrackerTest, OnFrameDeliveredUpdatesSources) { - constexpr uint32_t kSsrc = 10; + constexpr uint32_t kSsrc1 = 10; + constexpr uint32_t kSsrc2 = 11; constexpr uint32_t kCsrcs0 = 20; constexpr uint32_t kCsrcs1 = 21; constexpr uint32_t kCsrcs2 = 22; constexpr uint32_t kRtpTimestamp0 = 40; constexpr uint32_t kRtpTimestamp1 = 41; + constexpr uint32_t kRtpTimestamp2 = 42; constexpr absl::optional kAudioLevel0 = 50; constexpr absl::optional kAudioLevel1 = absl::nullopt; + constexpr absl::optional kAudioLevel2 = 10; constexpr absl::optional kAbsoluteCaptureTime0 = AbsoluteCaptureTime{12, 34}; constexpr absl::optional kAbsoluteCaptureTime1 = AbsoluteCaptureTime{56, 78}; - constexpr int64_t kReceiveTimeMs0 = 60; - constexpr int64_t kReceiveTimeMs1 = 61; + constexpr absl::optional kAbsoluteCaptureTime2 = + AbsoluteCaptureTime{89, 90}; + constexpr Timestamp kReceiveTime0 = Timestamp::Millis(60); + constexpr Timestamp kReceiveTime1 = Timestamp::Millis(61); + constexpr Timestamp kReceiveTime2 = Timestamp::Millis(62); + + constexpr RtpSource::Extensions extensions0 = {kAudioLevel0, + kAbsoluteCaptureTime0}; + constexpr RtpSource::Extensions extensions1 = {kAudioLevel1, + kAbsoluteCaptureTime1}; + constexpr RtpSource::Extensions extensions2 = {kAudioLevel2, + kAbsoluteCaptureTime2}; SimulatedClock clock(1000000000000ULL); SourceTracker tracker(&clock); tracker.OnFrameDelivered(RtpPacketInfos( - {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0, - kAbsoluteCaptureTime0, kReceiveTimeMs0)})); + {RtpPacketInfo(kSsrc1, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0, + kAbsoluteCaptureTime0, kReceiveTime0)})); int64_t timestamp_ms_0 = clock.TimeInMilliseconds(); + EXPECT_THAT( + tracker.GetSources(), + ElementsAre(RtpSource(timestamp_ms_0, kSsrc1, RtpSourceType::SSRC, + kRtpTimestamp0, extensions0), + RtpSource(timestamp_ms_0, kCsrcs1, RtpSourceType::CSRC, + kRtpTimestamp0, extensions0), + RtpSource(timestamp_ms_0, kCsrcs0, RtpSourceType::CSRC, + kRtpTimestamp0, extensions0))); - clock.AdvanceTimeMilliseconds(17); + // Deliver packets with updated sources. 
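+  // The SSRC entry and the CSRCs present in the new packet get refreshed
+  // timestamps below; kCsrcs1, which is absent from it, keeps its old entry.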
+ clock.AdvanceTimeMilliseconds(17); tracker.OnFrameDelivered(RtpPacketInfos( - {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs2}, kRtpTimestamp1, kAudioLevel1, - kAbsoluteCaptureTime1, kReceiveTimeMs1)})); + {RtpPacketInfo(kSsrc1, {kCsrcs0, kCsrcs2}, kRtpTimestamp1, kAudioLevel1, + kAbsoluteCaptureTime1, kReceiveTime1)})); int64_t timestamp_ms_1 = clock.TimeInMilliseconds(); - constexpr RtpSource::Extensions extensions0 = {kAudioLevel0, - kAbsoluteCaptureTime0}; - constexpr RtpSource::Extensions extensions1 = {kAudioLevel1, - kAbsoluteCaptureTime1}; - EXPECT_THAT( tracker.GetSources(), - ElementsAre(RtpSource(timestamp_ms_1, kSsrc, RtpSourceType::SSRC, + ElementsAre(RtpSource(timestamp_ms_1, kSsrc1, RtpSourceType::SSRC, kRtpTimestamp1, extensions1), RtpSource(timestamp_ms_1, kCsrcs2, RtpSourceType::CSRC, kRtpTimestamp1, extensions1), @@ -318,6 +397,27 @@ TEST(SourceTrackerTest, OnFrameDeliveredUpdatesSources) { kRtpTimestamp1, extensions1), RtpSource(timestamp_ms_0, kCsrcs1, RtpSourceType::CSRC, kRtpTimestamp0, extensions0))); + + // Deliver more packets with update csrcs and a new ssrc. + clock.AdvanceTimeMilliseconds(17); + tracker.OnFrameDelivered(RtpPacketInfos( + {RtpPacketInfo(kSsrc2, {kCsrcs0}, kRtpTimestamp2, kAudioLevel2, + kAbsoluteCaptureTime2, kReceiveTime2)})); + + int64_t timestamp_ms_2 = clock.TimeInMilliseconds(); + + EXPECT_THAT( + tracker.GetSources(), + ElementsAre(RtpSource(timestamp_ms_2, kSsrc2, RtpSourceType::SSRC, + kRtpTimestamp2, extensions2), + RtpSource(timestamp_ms_2, kCsrcs0, RtpSourceType::CSRC, + kRtpTimestamp2, extensions2), + RtpSource(timestamp_ms_1, kSsrc1, RtpSourceType::SSRC, + kRtpTimestamp1, extensions1), + RtpSource(timestamp_ms_1, kCsrcs2, RtpSourceType::CSRC, + kRtpTimestamp1, extensions1), + RtpSource(timestamp_ms_0, kCsrcs1, RtpSourceType::CSRC, + kRtpTimestamp0, extensions0))); } TEST(SourceTrackerTest, TimedOutSourcesAreRemoved) { @@ -333,21 +433,21 @@ TEST(SourceTrackerTest, TimedOutSourcesAreRemoved) { AbsoluteCaptureTime{12, 34}; constexpr absl::optional kAbsoluteCaptureTime1 = AbsoluteCaptureTime{56, 78}; - constexpr int64_t kReceiveTimeMs0 = 60; - constexpr int64_t kReceiveTimeMs1 = 61; + constexpr Timestamp kReceiveTime0 = Timestamp::Millis(60); + constexpr Timestamp kReceiveTime1 = Timestamp::Millis(61); SimulatedClock clock(1000000000000ULL); SourceTracker tracker(&clock); tracker.OnFrameDelivered(RtpPacketInfos( {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs1}, kRtpTimestamp0, kAudioLevel0, - kAbsoluteCaptureTime0, kReceiveTimeMs0)})); + kAbsoluteCaptureTime0, kReceiveTime0)})); clock.AdvanceTimeMilliseconds(17); tracker.OnFrameDelivered(RtpPacketInfos( {RtpPacketInfo(kSsrc, {kCsrcs0, kCsrcs2}, kRtpTimestamp1, kAudioLevel1, - kAbsoluteCaptureTime1, kReceiveTimeMs1)})); + kAbsoluteCaptureTime1, kReceiveTime1)})); int64_t timestamp_ms_1 = clock.TimeInMilliseconds(); diff --git a/modules/rtp_rtcp/source/time_util.cc b/modules/rtp_rtcp/source/time_util.cc index b5b4f8bd98..fe0cfea11f 100644 --- a/modules/rtp_rtcp/source/time_util.cc +++ b/modules/rtp_rtcp/source/time_util.cc @@ -17,48 +17,6 @@ #include "rtc_base/time_utils.h" namespace webrtc { -namespace { - -int64_t NtpOffsetMsCalledOnce() { - constexpr int64_t kNtpJan1970Sec = 2208988800; - int64_t clock_time = rtc::TimeMillis(); - int64_t utc_time = rtc::TimeUTCMillis(); - return utc_time - clock_time + kNtpJan1970Sec * rtc::kNumMillisecsPerSec; -} - -} // namespace - -int64_t NtpOffsetMs() { - // Calculate the offset once. 
- static int64_t ntp_offset_ms = NtpOffsetMsCalledOnce(); - return ntp_offset_ms; -} - -NtpTime TimeMicrosToNtp(int64_t time_us) { - // Since this doesn't return a wallclock time, but only NTP representation - // of rtc::TimeMillis() clock, the exact offset doesn't matter. - // To simplify conversions between NTP and RTP time, this offset is - // limited to milliseconds in resolution. - int64_t time_ntp_us = time_us + NtpOffsetMs() * 1000; - RTC_DCHECK_GE(time_ntp_us, 0); // Time before year 1900 is unsupported. - - // TODO(danilchap): Convert both seconds and fraction together using int128 - // when that type is easily available. - // Currently conversion is done separetly for seconds and fraction of a second - // to avoid overflow. - - // Convert seconds to uint32 through uint64 for well-defined cast. - // Wrap around (will happen in 2036) is expected for ntp time. - uint32_t ntp_seconds = - static_cast(time_ntp_us / rtc::kNumMicrosecsPerSec); - - // Scale fractions of the second to ntp resolution. - constexpr int64_t kNtpInSecond = 1LL << 32; - int64_t us_fractions = time_ntp_us % rtc::kNumMicrosecsPerSec; - uint32_t ntp_fractions = - us_fractions * kNtpInSecond / rtc::kNumMicrosecsPerSec; - return NtpTime(ntp_seconds, ntp_fractions); -} uint32_t SaturatedUsToCompactNtp(int64_t us) { constexpr uint32_t kMaxCompactNtp = 0xFFFFFFFF; diff --git a/modules/rtp_rtcp/source/time_util.h b/modules/rtp_rtcp/source/time_util.h index 94b914310c..c883e5ca38 100644 --- a/modules/rtp_rtcp/source/time_util.h +++ b/modules/rtp_rtcp/source/time_util.h @@ -17,20 +17,6 @@ namespace webrtc { -// Converts time obtained using rtc::TimeMicros to ntp format. -// TimeMicrosToNtp guarantees difference of the returned values matches -// difference of the passed values. -// As a result TimeMicrosToNtp(rtc::TimeMicros()) doesn't guarantee to match -// system time. -// However, TimeMicrosToNtp Guarantees that returned NtpTime will be offsetted -// from rtc::TimeMicros() by integral number of milliseconds. -// Use NtpOffsetMs() to get that offset value. -NtpTime TimeMicrosToNtp(int64_t time_us); - -// Difference between Ntp time and local relative time returned by -// rtc::TimeMicros() -int64_t NtpOffsetMs(); - // Helper function for compact ntp representation: // RFC 3550, Section 4. Time Format. // Wallclock time is represented using the timestamp format of diff --git a/modules/rtp_rtcp/source/time_util_unittest.cc b/modules/rtp_rtcp/source/time_util_unittest.cc index 4b469bb956..6ff55dda55 100644 --- a/modules/rtp_rtcp/source/time_util_unittest.cc +++ b/modules/rtp_rtcp/source/time_util_unittest.cc @@ -9,34 +9,10 @@ */ #include "modules/rtp_rtcp/source/time_util.h" -#include "rtc_base/fake_clock.h" -#include "rtc_base/time_utils.h" -#include "system_wrappers/include/clock.h" #include "test/gtest.h" namespace webrtc { -TEST(TimeUtilTest, TimeMicrosToNtpDoesntChangeBetweenRuns) { - rtc::ScopedFakeClock clock; - // TimeMicrosToNtp is not pure: it behave differently between different - // execution of the program, but should behave same during same execution. 
- const int64_t time_us = 12345; - clock.SetTime(Timestamp::Micros(2)); - NtpTime time_ntp = TimeMicrosToNtp(time_us); - clock.SetTime(Timestamp::Micros(time_us)); - EXPECT_EQ(TimeMicrosToNtp(time_us), time_ntp); - clock.SetTime(Timestamp::Micros(1000000)); - EXPECT_EQ(TimeMicrosToNtp(time_us), time_ntp); -} - -TEST(TimeUtilTest, TimeMicrosToNtpKeepsIntervals) { - rtc::ScopedFakeClock clock; - NtpTime time_ntp1 = TimeMicrosToNtp(rtc::TimeMicros()); - clock.AdvanceTime(TimeDelta::Millis(20)); - NtpTime time_ntp2 = TimeMicrosToNtp(rtc::TimeMicros()); - EXPECT_EQ(time_ntp2.ToMs() - time_ntp1.ToMs(), 20); -} - TEST(TimeUtilTest, CompactNtp) { const uint32_t kNtpSec = 0x12345678; const uint32_t kNtpFrac = 0x23456789; diff --git a/modules/rtp_rtcp/source/ulpfec_header_reader_writer.cc b/modules/rtp_rtcp/source/ulpfec_header_reader_writer.cc index 2aebbead68..49f483dad6 100644 --- a/modules/rtp_rtcp/source/ulpfec_header_reader_writer.cc +++ b/modules/rtp_rtcp/source/ulpfec_header_reader_writer.cc @@ -24,6 +24,11 @@ namespace { // Maximum number of media packets that can be protected in one batch. constexpr size_t kMaxMediaPackets = 48; +// Maximum number of media packets tracked by FEC decoder. +// Maintain a sufficiently larger tracking window than |kMaxMediaPackets| +// to account for packet reordering in pacer/ network. +constexpr size_t kMaxTrackedMediaPackets = 4 * kMaxMediaPackets; + // Maximum number of FEC packets stored inside ForwardErrorCorrection. constexpr size_t kMaxFecPackets = kMaxMediaPackets; @@ -51,7 +56,7 @@ size_t UlpfecHeaderSize(size_t packet_mask_size) { } // namespace UlpfecHeaderReader::UlpfecHeaderReader() - : FecHeaderReader(kMaxMediaPackets, kMaxFecPackets) {} + : FecHeaderReader(kMaxTrackedMediaPackets, kMaxFecPackets) {} UlpfecHeaderReader::~UlpfecHeaderReader() = default; diff --git a/modules/rtp_rtcp/source/ulpfec_receiver_impl.cc b/modules/rtp_rtcp/source/ulpfec_receiver_impl.cc index fee0b9c4da..fdfa475186 100644 --- a/modules/rtp_rtcp/source/ulpfec_receiver_impl.cc +++ b/modules/rtp_rtcp/source/ulpfec_receiver_impl.cc @@ -37,12 +37,13 @@ UlpfecReceiverImpl::UlpfecReceiverImpl( fec_(ForwardErrorCorrection::CreateUlpfec(ssrc_)) {} UlpfecReceiverImpl::~UlpfecReceiverImpl() { + RTC_DCHECK_RUN_ON(&sequence_checker_); received_packets_.clear(); fec_->ResetState(&recovered_packets_); } FecPacketCounter UlpfecReceiverImpl::GetPacketCounter() const { - MutexLock lock(&mutex_); + RTC_DCHECK_RUN_ON(&sequence_checker_); return packet_counter_; } @@ -77,6 +78,10 @@ FecPacketCounter UlpfecReceiverImpl::GetPacketCounter() const { bool UlpfecReceiverImpl::AddReceivedRedPacket( const RtpPacketReceived& rtp_packet, uint8_t ulpfec_payload_type) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + // TODO(bugs.webrtc.org/11993): We get here via Call::DeliverRtp, so should be + // moved to the network thread. 
+ if (rtp_packet.Ssrc() != ssrc_) { RTC_LOG(LS_WARNING) << "Received RED packet with different SSRC than expected; dropping."; @@ -87,7 +92,6 @@ bool UlpfecReceiverImpl::AddReceivedRedPacket( "packet size; dropping."; return false; } - MutexLock lock(&mutex_); static constexpr uint8_t kRedHeaderLength = 1; @@ -128,9 +132,8 @@ bool UlpfecReceiverImpl::AddReceivedRedPacket( rtp_packet.Buffer().Slice(rtp_packet.headers_size() + kRedHeaderLength, rtp_packet.payload_size() - kRedHeaderLength); } else { - auto red_payload = rtp_packet.payload().subview(kRedHeaderLength); - received_packet->pkt->data.EnsureCapacity(rtp_packet.headers_size() + - red_payload.size()); + received_packet->pkt->data.EnsureCapacity(rtp_packet.size() - + kRedHeaderLength); // Copy RTP header. received_packet->pkt->data.SetData(rtp_packet.data(), rtp_packet.headers_size()); @@ -138,9 +141,10 @@ bool UlpfecReceiverImpl::AddReceivedRedPacket( uint8_t& payload_type_byte = received_packet->pkt->data.MutableData()[1]; payload_type_byte &= 0x80; // Reset RED payload type. payload_type_byte += payload_type; // Set media payload type. - // Copy payload data. - received_packet->pkt->data.AppendData(red_payload.data(), - red_payload.size()); + // Copy payload and padding data, after the RED header. + received_packet->pkt->data.AppendData( + rtp_packet.data() + rtp_packet.headers_size() + kRedHeaderLength, + rtp_packet.size() - rtp_packet.headers_size() - kRedHeaderLength); } if (received_packet->pkt->data.size() > 0) { @@ -151,7 +155,7 @@ bool UlpfecReceiverImpl::AddReceivedRedPacket( // TODO(nisse): Drop always-zero return value. int32_t UlpfecReceiverImpl::ProcessReceivedFec() { - mutex_.Lock(); + RTC_DCHECK_RUN_ON(&sequence_checker_); // If we iterate over |received_packets_| and it contains a packet that cause // us to recurse back to this function (for example a RED packet encapsulating @@ -168,10 +172,8 @@ int32_t UlpfecReceiverImpl::ProcessReceivedFec() { // Send received media packet to VCM. if (!received_packet->is_fec) { ForwardErrorCorrection::Packet* packet = received_packet->pkt; - mutex_.Unlock(); recovered_packet_callback_->OnRecoveredPacket(packet->data.data(), packet->data.size()); - mutex_.Lock(); // Create a packet with the buffer to modify it. RtpPacketReceived rtp_packet; const uint8_t* const original_data = packet->data.cdata(); @@ -208,13 +210,10 @@ int32_t UlpfecReceiverImpl::ProcessReceivedFec() { // Set this flag first; in case the recovered packet carries a RED // header, OnRecoveredPacket will recurse back here. 
recovered_packet->returned = true; - mutex_.Unlock(); recovered_packet_callback_->OnRecoveredPacket(packet->data.data(), packet->data.size()); - mutex_.Lock(); } - mutex_.Unlock(); return 0; } diff --git a/modules/rtp_rtcp/source/ulpfec_receiver_impl.h b/modules/rtp_rtcp/source/ulpfec_receiver_impl.h index 2bed042747..f59251f848 100644 --- a/modules/rtp_rtcp/source/ulpfec_receiver_impl.h +++ b/modules/rtp_rtcp/source/ulpfec_receiver_impl.h @@ -17,12 +17,13 @@ #include #include +#include "api/sequence_checker.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/rtp_rtcp/include/ulpfec_receiver.h" #include "modules/rtp_rtcp/source/forward_error_correction.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/system/no_unique_address.h" namespace webrtc { @@ -44,17 +45,18 @@ class UlpfecReceiverImpl : public UlpfecReceiver { const uint32_t ssrc_; const RtpHeaderExtensionMap extensions_; - mutable Mutex mutex_; - RecoveredPacketReceiver* recovered_packet_callback_; - std::unique_ptr fec_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; + RecoveredPacketReceiver* const recovered_packet_callback_; + const std::unique_ptr fec_; // TODO(nisse): The AddReceivedRedPacket method adds one or two packets to // this list at a time, after which it is emptied by ProcessReceivedFec. It // will make things simpler to merge AddReceivedRedPacket and // ProcessReceivedFec into a single method, and we can then delete this list. std::vector> - received_packets_; - ForwardErrorCorrection::RecoveredPacketList recovered_packets_; - FecPacketCounter packet_counter_; + received_packets_ RTC_GUARDED_BY(&sequence_checker_); + ForwardErrorCorrection::RecoveredPacketList recovered_packets_ + RTC_GUARDED_BY(&sequence_checker_); + FecPacketCounter packet_counter_ RTC_GUARDED_BY(&sequence_checker_); }; } // namespace webrtc diff --git a/modules/rtp_rtcp/source/ulpfec_receiver_unittest.cc b/modules/rtp_rtcp/source/ulpfec_receiver_unittest.cc index 9dbaeb81f3..53d363de67 100644 --- a/modules/rtp_rtcp/source/ulpfec_receiver_unittest.cc +++ b/modules/rtp_rtcp/source/ulpfec_receiver_unittest.cc @@ -392,7 +392,7 @@ TEST_F(UlpfecReceiverTest, PacketNotDroppedTooEarly) { delayed_fec = fec_packets.front(); // Fill the FEC decoder. No packets should be dropped. - const size_t kNumMediaPacketsBatch2 = 46; + const size_t kNumMediaPacketsBatch2 = 191; std::list augmented_media_packets_batch2; ForwardErrorCorrection::PacketList media_packets_batch2; for (size_t i = 0; i < kNumMediaPacketsBatch2; ++i) { @@ -431,7 +431,7 @@ TEST_F(UlpfecReceiverTest, PacketDroppedWhenTooOld) { delayed_fec = fec_packets.front(); // Fill the FEC decoder and force the last packet to be dropped. - const size_t kNumMediaPacketsBatch2 = 48; + const size_t kNumMediaPacketsBatch2 = 192; std::list augmented_media_packets_batch2; ForwardErrorCorrection::PacketList media_packets_batch2; for (size_t i = 0; i < kNumMediaPacketsBatch2; ++i) { @@ -512,4 +512,31 @@ TEST_F(UlpfecReceiverTest, TruncatedPacketWithoutDataPastFirstBlock) { SurvivesMaliciousPacket(kPacket, sizeof(kPacket), 100); } +TEST_F(UlpfecReceiverTest, MediaWithPadding) { + const size_t kNumFecPackets = 1; + std::list augmented_media_packets; + ForwardErrorCorrection::PacketList media_packets; + PacketizeFrame(2, 0, &augmented_media_packets, &media_packets); + + // Append four bytes of padding to the first media packet. 
+ const uint8_t kPadding[] = {0, 0, 0, 4}; + augmented_media_packets.front()->data.AppendData(kPadding); + augmented_media_packets.front()->data.MutableData()[0] |= 1 << 5; // P bit. + augmented_media_packets.front()->header.paddingLength = 4; + + std::list fec_packets; + EncodeFec(media_packets, kNumFecPackets, &fec_packets); + + auto it = augmented_media_packets.begin(); + BuildAndAddRedMediaPacket(augmented_media_packets.front()); + + VerifyReconstructedMediaPacket(**it, 1); + EXPECT_EQ(0, receiver_fec_->ProcessReceivedFec()); + + BuildAndAddRedFecPacket(fec_packets.front()); + ++it; + VerifyReconstructedMediaPacket(**it, 1); + EXPECT_EQ(0, receiver_fec_->ProcessReceivedFec()); +} + } // namespace webrtc diff --git a/modules/rtp_rtcp/source/video_rtp_depacketizer_vp8.h b/modules/rtp_rtcp/source/video_rtp_depacketizer_vp8.h index a7573993f7..3d7cb3291d 100644 --- a/modules/rtp_rtcp/source/video_rtp_depacketizer_vp8.h +++ b/modules/rtp_rtcp/source/video_rtp_depacketizer_vp8.h @@ -25,7 +25,7 @@ class VideoRtpDepacketizerVp8 : public VideoRtpDepacketizer { public: VideoRtpDepacketizerVp8() = default; VideoRtpDepacketizerVp8(const VideoRtpDepacketizerVp8&) = delete; - VideoRtpDepacketizerVp8& operator=(VideoRtpDepacketizerVp8&) = delete; + VideoRtpDepacketizerVp8& operator=(const VideoRtpDepacketizerVp8&) = delete; ~VideoRtpDepacketizerVp8() override = default; // Parses vp8 rtp payload descriptor. diff --git a/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.cc b/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.cc index a719d7ab12..be05009807 100644 --- a/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.cc +++ b/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.cc @@ -40,12 +40,12 @@ constexpr int kFailedToParse = 0; bool ParsePictureId(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { uint32_t picture_id; uint32_t m_bit; - RETURN_FALSE_ON_ERROR(parser->ReadBits(&m_bit, 1)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, m_bit)); if (m_bit) { - RETURN_FALSE_ON_ERROR(parser->ReadBits(&picture_id, 15)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(15, picture_id)); vp9->max_picture_id = kMaxTwoBytePictureId; } else { - RETURN_FALSE_ON_ERROR(parser->ReadBits(&picture_id, 7)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(7, picture_id)); vp9->max_picture_id = kMaxOneBytePictureId; } vp9->picture_id = picture_id; @@ -60,10 +60,10 @@ bool ParsePictureId(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { // bool ParseLayerInfoCommon(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { uint32_t t, u_bit, s, d_bit; - RETURN_FALSE_ON_ERROR(parser->ReadBits(&t, 3)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&u_bit, 1)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&s, 3)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&d_bit, 1)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(3, t)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, u_bit)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(3, s)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, d_bit)); vp9->temporal_idx = t; vp9->temporal_up_switch = u_bit ? 
true : false; if (s >= kMaxSpatialLayers) @@ -84,7 +84,7 @@ bool ParseLayerInfoCommon(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { bool ParseLayerInfoNonFlexibleMode(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { uint8_t tl0picidx; - RETURN_FALSE_ON_ERROR(parser->ReadUInt8(&tl0picidx)); + RETURN_FALSE_ON_ERROR(parser->ReadUInt8(tl0picidx)); vp9->tl0_pic_idx = tl0picidx; return true; } @@ -117,8 +117,8 @@ bool ParseRefIndices(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { return false; uint32_t p_diff; - RETURN_FALSE_ON_ERROR(parser->ReadBits(&p_diff, 7)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&n_bit, 1)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(7, p_diff)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, n_bit)); vp9->pid_diff[vp9->num_ref_pics] = p_diff; uint32_t scaled_pid = vp9->picture_id; @@ -154,9 +154,9 @@ bool ParseRefIndices(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { // bool ParseSsData(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { uint32_t n_s, y_bit, g_bit; - RETURN_FALSE_ON_ERROR(parser->ReadBits(&n_s, 3)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&y_bit, 1)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&g_bit, 1)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(3, n_s)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, y_bit)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, g_bit)); RETURN_FALSE_ON_ERROR(parser->ConsumeBits(3)); vp9->num_spatial_layers = n_s + 1; vp9->spatial_layer_resolution_present = y_bit ? true : false; @@ -164,20 +164,20 @@ bool ParseSsData(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { if (y_bit) { for (size_t i = 0; i < vp9->num_spatial_layers; ++i) { - RETURN_FALSE_ON_ERROR(parser->ReadUInt16(&vp9->width[i])); - RETURN_FALSE_ON_ERROR(parser->ReadUInt16(&vp9->height[i])); + RETURN_FALSE_ON_ERROR(parser->ReadUInt16(vp9->width[i])); + RETURN_FALSE_ON_ERROR(parser->ReadUInt16(vp9->height[i])); } } if (g_bit) { uint8_t n_g; - RETURN_FALSE_ON_ERROR(parser->ReadUInt8(&n_g)); + RETURN_FALSE_ON_ERROR(parser->ReadUInt8(n_g)); vp9->gof.num_frames_in_gof = n_g; } for (size_t i = 0; i < vp9->gof.num_frames_in_gof; ++i) { uint32_t t, u_bit, r; - RETURN_FALSE_ON_ERROR(parser->ReadBits(&t, 3)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&u_bit, 1)); - RETURN_FALSE_ON_ERROR(parser->ReadBits(&r, 2)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(3, t)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(1, u_bit)); + RETURN_FALSE_ON_ERROR(parser->ReadBits(2, r)); RETURN_FALSE_ON_ERROR(parser->ConsumeBits(2)); vp9->gof.temporal_idx[i] = t; vp9->gof.temporal_up_switch[i] = u_bit ? true : false; @@ -185,7 +185,7 @@ bool ParseSsData(rtc::BitBuffer* parser, RTPVideoHeaderVP9* vp9) { for (uint8_t p = 0; p < vp9->gof.num_ref_pics[i]; ++p) { uint8_t p_diff; - RETURN_FALSE_ON_ERROR(parser->ReadUInt8(&p_diff)); + RETURN_FALSE_ON_ERROR(parser->ReadUInt8(p_diff)); vp9->gof.pid_diff[i][p] = p_diff; } } @@ -214,7 +214,7 @@ int VideoRtpDepacketizerVp9::ParseRtpPayload( // Parse mandatory first byte of payload descriptor. 
   rtc::BitBuffer parser(rtp_payload.data(), rtp_payload.size());
   uint8_t first_byte;
-  if (!parser.ReadUInt8(&first_byte)) {
+  if (!parser.ReadUInt8(first_byte)) {
     RTC_LOG(LS_ERROR) << "Payload length is zero.";
     return kFailedToParse;
   }
diff --git a/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.h b/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.h
index c622cbc75e..4bb358a15f 100644
--- a/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.h
+++ b/modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.h
@@ -25,7 +25,7 @@ class VideoRtpDepacketizerVp9 : public VideoRtpDepacketizer {
  public:
   VideoRtpDepacketizerVp9() = default;
   VideoRtpDepacketizerVp9(const VideoRtpDepacketizerVp9&) = delete;
-  VideoRtpDepacketizerVp9& operator=(VideoRtpDepacketizerVp9&) = delete;
+  VideoRtpDepacketizerVp9& operator=(const VideoRtpDepacketizerVp9&) = delete;
   ~VideoRtpDepacketizerVp9() override = default;
 
   // Parses vp9 rtp payload descriptor.
diff --git a/modules/rtp_rtcp/test/testFec/test_packet_masks_metrics.cc b/modules/rtp_rtcp/test/testFec/test_packet_masks_metrics.cc
index 44597b85bb..dffdf2ebf6 100644
--- a/modules/rtp_rtcp/test/testFec/test_packet_masks_metrics.cc
+++ b/modules/rtp_rtcp/test/testFec/test_packet_masks_metrics.cc
@@ -225,7 +225,7 @@ class FecPacketMaskMetricsTest : public ::testing::Test {
         }
       }
       // Check that we can only recover 1 packet.
-      assert(check_num_recovered == 1);
+      RTC_DCHECK_EQ(check_num_recovered, 1);
       // Update the state with the newly recovered media packet.
       state_tmp[jsel] = 0;
     }
@@ -260,7 +260,7 @@ class FecPacketMaskMetricsTest : public ::testing::Test {
           }
         }
       } else {  // Gilbert-Elliot model for burst model.
-        assert(loss_model_[k].loss_type == kBurstyLossModel);
+        RTC_DCHECK_EQ(loss_model_[k].loss_type, kBurstyLossModel);
         // Transition probabilities: from previous to current state.
         // Prob. of previous = lost --> current = received.
         double prob10 = 1.0 / burst_length;
@@ -425,8 +425,8 @@ class FecPacketMaskMetricsTest : public ::testing::Test {
        }
      }
    }  // Done with loop over total number of packets.
-    assert(num_media_packets_lost <= num_media_packets);
-    assert(num_packets_lost <= tot_num_packets && num_packets_lost > 0);
+    RTC_DCHECK_LE(num_media_packets_lost, num_media_packets);
+    RTC_DCHECK(num_packets_lost <= tot_num_packets && num_packets_lost > 0);
     double residual_loss = 0.0;
     // Only need to compute residual loss (number of recovered packets) for
     // configurations that have at least one media packet lost.
@@ -445,7 +445,7 @@ class FecPacketMaskMetricsTest : public ::testing::Test {
           num_recovered_packets = num_media_packets_lost;
         }
       }
-      assert(num_recovered_packets <= num_media_packets);
+      RTC_DCHECK_LE(num_recovered_packets, num_media_packets);
       // Compute the residual loss. We only care about recovering media/source
       // packets, so residual loss is based on lost/recovered media packets.
       residual_loss =
@@ -464,9 +464,9 @@ class FecPacketMaskMetricsTest : public ::testing::Test {
       // Update the distribution statistics.
       // Compute the gap of the loss (the "consecutiveness" of the loss).
int gap_loss = GapLoss(tot_num_packets, state.get()); - assert(gap_loss < kMaxGapSize); + RTC_DCHECK_LT(gap_loss, kMaxGapSize); int index = gap_loss * (2 * kMaxMediaPacketsTest) + num_packets_lost; - assert(index < kNumStatesDistribution); + RTC_DCHECK_LT(index, kNumStatesDistribution); metrics_code.residual_loss_per_loss_gap[index] += residual_loss; if (code_type == xor_random_code) { // The configuration density is only a function of the code length and @@ -492,8 +492,8 @@ class FecPacketMaskMetricsTest : public ::testing::Test { metrics_code.variance_residual_loss[k] - (metrics_code.average_residual_loss[k] * metrics_code.average_residual_loss[k]); - assert(metrics_code.variance_residual_loss[k] >= 0.0); - assert(metrics_code.average_residual_loss[k] > 0.0); + RTC_DCHECK_GE(metrics_code.variance_residual_loss[k], 0.0); + RTC_DCHECK_GT(metrics_code.average_residual_loss[k], 0.0); metrics_code.variance_residual_loss[k] = std::sqrt(metrics_code.variance_residual_loss[k]) / metrics_code.average_residual_loss[k]; @@ -509,7 +509,7 @@ class FecPacketMaskMetricsTest : public ::testing::Test { } else if (code_type == xor_bursty_code) { CopyMetrics(&kMetricsXorBursty[code_index], metrics_code); } else { - assert(false); + RTC_NOTREACHED(); } } @@ -588,7 +588,7 @@ class FecPacketMaskMetricsTest : public ::testing::Test { num_loss_models++; } } - assert(num_loss_models == kNumLossModels); + RTC_DCHECK_EQ(num_loss_models, kNumLossModels); } void SetCodeParams() { @@ -738,7 +738,7 @@ class FecPacketMaskMetricsTest : public ::testing::Test { code_index++; } } - assert(code_index == kNumberCodes); + RTC_DCHECK_EQ(code_index, kNumberCodes); return 0; } diff --git a/modules/utility/BUILD.gn b/modules/utility/BUILD.gn index df6945ab2c..aca7b1efdd 100644 --- a/modules/utility/BUILD.gn +++ b/modules/utility/BUILD.gn @@ -31,6 +31,7 @@ rtc_library("utility") { deps = [ "..:module_api", + "../../api:sequence_checker", "../../api/task_queue", "../../common_audio", "../../rtc_base:checks", diff --git a/modules/utility/include/jvm_android.h b/modules/utility/include/jvm_android.h index 3caab87761..693ee519ed 100644 --- a/modules/utility/include/jvm_android.h +++ b/modules/utility/include/jvm_android.h @@ -16,8 +16,8 @@ #include #include +#include "api/sequence_checker.h" #include "modules/utility/include/helpers_android.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -34,7 +34,7 @@ class JvmThreadConnector { ~JvmThreadConnector(); private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; bool attached_; }; @@ -111,7 +111,7 @@ class JNIEnvironment { std::string JavaToStdString(const jstring& j_string); private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; JNIEnv* const jni_; }; @@ -184,7 +184,7 @@ class JVM { private: JNIEnv* jni() const { return GetEnv(jvm_); } - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; JavaVM* const jvm_; }; diff --git a/modules/utility/source/process_thread_impl.cc b/modules/utility/source/process_thread_impl.cc index 3709306925..73fc23400b 100644 --- a/modules/utility/source/process_thread_impl.cc +++ b/modules/utility/source/process_thread_impl.cc @@ -48,7 +48,6 @@ ProcessThreadImpl::ProcessThreadImpl(const char* thread_name) ProcessThreadImpl::~ProcessThreadImpl() { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_DCHECK(!thread_.get()); RTC_DCHECK(!stop_); while (!delayed_tasks_.empty()) { @@ -69,10 +68,11 @@ void ProcessThreadImpl::Delete() { delete this; } -void ProcessThreadImpl::Start() { +// 
Doesn't need locking, because the contending thread isn't running. +void ProcessThreadImpl::Start() RTC_NO_THREAD_SAFETY_ANALYSIS { RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_DCHECK(!thread_.get()); - if (thread_.get()) + RTC_DCHECK(thread_.empty()); + if (!thread_.empty()) return; RTC_DCHECK(!stop_); @@ -80,47 +80,84 @@ void ProcessThreadImpl::Start() { for (ModuleCallback& m : modules_) m.module->ProcessThreadAttached(this); - thread_.reset( - new rtc::PlatformThread(&ProcessThreadImpl::Run, this, thread_name_)); - thread_->Start(); + thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + CurrentTaskQueueSetter set_current(this); + while (Process()) { + } + }, + thread_name_); } void ProcessThreadImpl::Stop() { RTC_DCHECK(thread_checker_.IsCurrent()); - if (!thread_.get()) + if (thread_.empty()) return; { - rtc::CritScope lock(&lock_); + // Need to take lock, for synchronization with `thread_`. + MutexLock lock(&mutex_); stop_ = true; } wake_up_.Set(); + thread_.Finalize(); + + StopNoLocks(); +} - thread_->Stop(); +// No locking needed, since this is called after the contending thread is +// stopped. +void ProcessThreadImpl::StopNoLocks() RTC_NO_THREAD_SAFETY_ANALYSIS { + RTC_DCHECK(thread_.empty()); stop_ = false; - thread_.reset(); for (ModuleCallback& m : modules_) m.module->ProcessThreadAttached(nullptr); } void ProcessThreadImpl::WakeUp(Module* module) { // Allowed to be called on any thread. - { - rtc::CritScope lock(&lock_); - for (ModuleCallback& m : modules_) { - if (m.module == module) - m.next_callback = kCallProcessImmediately; + auto holds_mutex = [this] { + if (!IsCurrent()) { + return false; } + RTC_DCHECK_RUN_ON(this); + return holds_mutex_; + }; + if (holds_mutex()) { + // Avoid locking if called on the ProcessThread, via a module's Process), + WakeUpNoLocks(module); + } else { + MutexLock lock(&mutex_); + WakeUpInternal(module); } wake_up_.Set(); } +// Must be called only indirectly from Process, which already holds the lock. +void ProcessThreadImpl::WakeUpNoLocks(Module* module) + RTC_NO_THREAD_SAFETY_ANALYSIS { + RTC_DCHECK_RUN_ON(this); + WakeUpInternal(module); +} + +void ProcessThreadImpl::WakeUpInternal(Module* module) { + for (ModuleCallback& m : modules_) { + if (m.module == module) + m.next_callback = kCallProcessImmediately; + } +} + void ProcessThreadImpl::PostTask(std::unique_ptr task) { - // Allowed to be called on any thread. + // Allowed to be called on any thread, except from a module's Process method. + if (IsCurrent()) { + RTC_DCHECK_RUN_ON(this); + RTC_DCHECK(!holds_mutex_) << "Calling ProcessThread::PostTask from " + "Module::Process is not supported"; + } { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); queue_.push(task.release()); } wake_up_.Set(); @@ -131,7 +168,7 @@ void ProcessThreadImpl::PostDelayedTask(std::unique_ptr task, int64_t run_at_ms = rtc::TimeMillis() + milliseconds; bool recalculate_wakeup_time; { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); recalculate_wakeup_time = delayed_tasks_.empty() || run_at_ms < delayed_tasks_.top().run_at_ms; delayed_tasks_.emplace(run_at_ms, std::move(task)); @@ -143,13 +180,14 @@ void ProcessThreadImpl::PostDelayedTask(std::unique_ptr task, void ProcessThreadImpl::RegisterModule(Module* module, const rtc::Location& from) { + TRACE_EVENT0("webrtc", "ProcessThreadImpl::RegisterModule"); RTC_DCHECK(thread_checker_.IsCurrent()); RTC_DCHECK(module) << from.ToString(); #if RTC_DCHECK_IS_ON { // Catch programmer error. 
- rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); for (const ModuleCallback& mc : modules_) { RTC_DCHECK(mc.module != module) << "Already registered here: " << mc.location.ToString() @@ -163,11 +201,11 @@ void ProcessThreadImpl::RegisterModule(Module* module, // Now that we know the module isn't in the list, we'll call out to notify // the module that it's attached to the worker thread. We don't hold // the lock while we make this call. - if (thread_.get()) + if (!thread_.empty()) module->ProcessThreadAttached(this); { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); modules_.push_back(ModuleCallback(module, from)); } @@ -182,7 +220,7 @@ void ProcessThreadImpl::DeRegisterModule(Module* module) { RTC_DCHECK(module); { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); modules_.remove_if( [&module](const ModuleCallback& m) { return m.module == module; }); } @@ -191,21 +229,13 @@ void ProcessThreadImpl::DeRegisterModule(Module* module) { module->ProcessThreadAttached(nullptr); } -// static -void ProcessThreadImpl::Run(void* obj) { - ProcessThreadImpl* impl = static_cast(obj); - CurrentTaskQueueSetter set_current(impl); - while (impl->Process()) { - } -} - bool ProcessThreadImpl::Process() { TRACE_EVENT1("webrtc", "ProcessThreadImpl", "name", thread_name_); int64_t now = rtc::TimeMillis(); int64_t next_checkpoint = now + (1000 * 60); - + RTC_DCHECK_RUN_ON(this); { - rtc::CritScope lock(&lock_); + MutexLock lock(&mutex_); if (stop_) return false; for (ModuleCallback& m : modules_) { @@ -216,6 +246,8 @@ bool ProcessThreadImpl::Process() { if (m.next_callback == 0) m.next_callback = GetNextCallbackTime(m.module, now); + // Set to true for the duration of the calls to modules' Process(). + holds_mutex_ = true; if (m.next_callback <= now || m.next_callback == kCallProcessImmediately) { { @@ -230,6 +262,7 @@ bool ProcessThreadImpl::Process() { int64_t new_now = rtc::TimeMillis(); m.next_callback = GetNextCallbackTime(m.module, new_now); } + holds_mutex_ = false; if (m.next_callback < next_checkpoint) next_checkpoint = m.next_callback; @@ -248,11 +281,11 @@ bool ProcessThreadImpl::Process() { while (!queue_.empty()) { QueuedTask* task = queue_.front(); queue_.pop(); - lock_.Leave(); + mutex_.Unlock(); if (task->Run()) { delete task; } - lock_.Enter(); + mutex_.Lock(); } } diff --git a/modules/utility/source/process_thread_impl.h b/modules/utility/source/process_thread_impl.h index ed9f5c3bfc..5d22e37ca1 100644 --- a/modules/utility/source/process_thread_impl.h +++ b/modules/utility/source/process_thread_impl.h @@ -17,14 +17,13 @@ #include #include +#include "api/sequence_checker.h" #include "api/task_queue/queued_task.h" #include "modules/include/module.h" #include "modules/utility/include/process_thread.h" -#include "rtc_base/deprecated/recursive_critical_section.h" #include "rtc_base/event.h" #include "rtc_base/location.h" #include "rtc_base/platform_thread.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -45,7 +44,6 @@ class ProcessThreadImpl : public ProcessThread { void DeRegisterModule(Module* module) override; protected: - static void Run(void* obj); bool Process(); private: @@ -85,25 +83,32 @@ class ProcessThreadImpl : public ProcessThread { typedef std::list ModuleList; void Delete() override; + // The part of Stop processing that doesn't need any locking. 
+ void StopNoLocks(); + void WakeUpNoLocks(Module* module); + void WakeUpInternal(Module* module) RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Warning: For some reason, if |lock_| comes immediately before |modules_| - // with the current class layout, we will start to have mysterious crashes - // on Mac 10.9 debug. I (Tommi) suspect we're hitting some obscure alignemnt - // issues, but I haven't figured out what they are, if there are alignment - // requirements for mutexes on Mac or if there's something else to it. - // So be careful with changing the layout. - rtc::RecursiveCriticalSection - lock_; // Used to guard modules_, tasks_ and stop_. + // Members protected by this mutex are accessed on the constructor thread and + // on the spawned process thread, and locking is needed only while the process + // thread is running. + Mutex mutex_; - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; rtc::Event wake_up_; - // TODO(pbos): Remove unique_ptr and stop recreating the thread. - std::unique_ptr<rtc::PlatformThread> thread_; + rtc::PlatformThread thread_; - ModuleList modules_; + ModuleList modules_ RTC_GUARDED_BY(mutex_); + // Set to true when calling Process, to allow reentrant calls to WakeUp. + bool holds_mutex_ RTC_GUARDED_BY(this) = false; std::queue<QueuedTask*> queue_; - std::priority_queue<DelayedTask> delayed_tasks_ RTC_GUARDED_BY(lock_); - bool stop_; + std::priority_queue<DelayedTask> delayed_tasks_ RTC_GUARDED_BY(mutex_); + // The `stop_` flag is modified only by the construction thread, protected by + // `thread_checker_`. It is read also by the spawned `thread_`. The latter + // thread must take `mutex_` before access, and for thread safety, the + // constructor thread needs to take `mutex_` when it modifies `stop_` while + // `thread_` is running. Annotations like RTC_GUARDED_BY don't support this + // usage pattern.
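
The comment above describes a handshake that RTC_GUARDED_BY cannot express: `stop_` needs the mutex only while the spawned thread is alive, and once that thread has been joined there is no contention left. A rough standard-library sketch of the same lifecycle (names are illustrative, not from the patch):

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>

class StoppableWorker {
 public:
  void Start() {
    thread_ = std::thread([this] { Run(); });
  }

  void Stop() {
    {
      std::lock_guard<std::mutex> lock(mutex_);  // Contends with Run().
      stop_ = true;
    }
    wake_up_.notify_one();
    thread_.join();  // After the join no other thread touches the members,
    stop_ = false;   // so resetting state needs no lock (cf. StopNoLocks()).
  }

 private:
  void Run() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (!stop_) {
      // Periodic work would go here.
      wake_up_.wait_for(lock, std::chrono::milliseconds(100));
    }
  }

  std::mutex mutex_;
  std::condition_variable wake_up_;
  bool stop_ = false;  // Needs mutex_ only while thread_ is running.
  std::thread thread_;
};

Stop() mirrors ProcessThreadImpl::Stop() followed by StopNoLocks(): set the flag under the lock, wake the worker, join, then reset state without locking.
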
+ bool stop_ RTC_GUARDED_BY(mutex_); const char* thread_name_; }; diff --git a/modules/video_capture/device_info_impl.cc b/modules/video_capture/device_info_impl.cc index 846977e68f..d5abb29407 100644 --- a/modules/video_capture/device_info_impl.cc +++ b/modules/video_capture/device_info_impl.cc @@ -52,7 +52,7 @@ int32_t DeviceInfoImpl::NumberOfCapabilities(const char* deviceUniqueIdUTF8) { int32_t DeviceInfoImpl::GetCapability(const char* deviceUniqueIdUTF8, const uint32_t deviceCapabilityNumber, VideoCaptureCapability& capability) { - assert(deviceUniqueIdUTF8 != NULL); + RTC_DCHECK(deviceUniqueIdUTF8); MutexLock lock(&_apiLock); diff --git a/modules/video_capture/linux/device_info_linux.cc b/modules/video_capture/linux/device_info_linux.cc index 3c8fdd20fa..cde3b86d5c 100644 --- a/modules/video_capture/linux/device_info_linux.cc +++ b/modules/video_capture/linux/device_info_linux.cc @@ -42,8 +42,6 @@ int32_t DeviceInfoLinux::Init() { DeviceInfoLinux::~DeviceInfoLinux() {} uint32_t DeviceInfoLinux::NumberOfDevices() { - RTC_LOG(LS_INFO) << __FUNCTION__; - uint32_t count = 0; char device[20]; int fd = -1; @@ -75,8 +73,6 @@ int32_t DeviceInfoLinux::GetDeviceName(uint32_t deviceNumber, uint32_t deviceUniqueIdUTF8Length, char* /*productUniqueIdUTF8*/, uint32_t /*productUniqueIdUTF8Length*/) { - RTC_LOG(LS_INFO) << __FUNCTION__; - // Travel through /dev/video [0-63] uint32_t count = 0; char device[20]; @@ -120,7 +116,7 @@ int32_t DeviceInfoLinux::GetDeviceName(uint32_t deviceNumber, memset(deviceNameUTF8, 0, deviceNameLength); memcpy(cameraName, cap.card, sizeof(cap.card)); - if (deviceNameLength >= strlen(cameraName)) { + if (deviceNameLength > strlen(cameraName)) { memcpy(deviceNameUTF8, cameraName, strlen(cameraName)); } else { RTC_LOG(LS_INFO) << "buffer passed is too small"; @@ -130,7 +126,7 @@ int32_t DeviceInfoLinux::GetDeviceName(uint32_t deviceNumber, if (cap.bus_info[0] != 0) // may not available in all drivers { // copy device id - if (deviceUniqueIdUTF8Length >= strlen((const char*)cap.bus_info)) { + if (deviceUniqueIdUTF8Length > strlen((const char*)cap.bus_info)) { memset(deviceUniqueIdUTF8, 0, deviceUniqueIdUTF8Length); memcpy(deviceUniqueIdUTF8, cap.bus_info, strlen((const char*)cap.bus_info)); @@ -150,7 +146,7 @@ int32_t DeviceInfoLinux::CreateCapabilityMap(const char* deviceUniqueIdUTF8) { const int32_t deviceUniqueIdUTF8Length = (int32_t)strlen((char*)deviceUniqueIdUTF8); - if (deviceUniqueIdUTF8Length > kVideoCaptureUniqueNameLength) { + if (deviceUniqueIdUTF8Length >= kVideoCaptureUniqueNameLength) { RTC_LOG(LS_INFO) << "Device name too long"; return -1; } diff --git a/modules/video_capture/linux/video_capture_linux.cc b/modules/video_capture/linux/video_capture_linux.cc index 504565f512..10f9713ec3 100644 --- a/modules/video_capture/linux/video_capture_linux.cc +++ b/modules/video_capture/linux/video_capture_linux.cc @@ -34,8 +34,7 @@ namespace webrtc { namespace videocapturemodule { rtc::scoped_refptr VideoCaptureImpl::Create( const char* deviceUniqueId) { - rtc::scoped_refptr implementation( - new rtc::RefCountedObject()); + auto implementation = rtc::make_ref_counted(); if (implementation->Init(deviceUniqueId) != 0) return nullptr; @@ -241,12 +240,15 @@ int32_t VideoCaptureModuleV4L2::StartCapture( } // start capture thread; - if (!_captureThread) { + if (_captureThread.empty()) { quit_ = false; - _captureThread.reset( - new rtc::PlatformThread(VideoCaptureModuleV4L2::CaptureThread, this, - "CaptureThread", rtc::kHighPriority)); - _captureThread->Start(); + 
_captureThread = rtc::PlatformThread::SpawnJoinable( + [this] { + while (CaptureProcess()) { + } + }, + "CaptureThread", + rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kHigh)); } // Needed to start UVC camera - from the uvcview application @@ -262,14 +264,13 @@ int32_t VideoCaptureModuleV4L2::StartCapture( } int32_t VideoCaptureModuleV4L2::StopCapture() { - if (_captureThread) { + if (!_captureThread.empty()) { { MutexLock lock(&capture_lock_); quit_ = true; } - // Make sure the capture thread stop stop using the critsect. - _captureThread->Stop(); - _captureThread.reset(); + // Make sure the capture thread stops using the mutex. + _captureThread.Finalize(); } MutexLock lock(&capture_lock_); @@ -357,11 +358,6 @@ bool VideoCaptureModuleV4L2::CaptureStarted() { return _captureStarted; } -void VideoCaptureModuleV4L2::CaptureThread(void* obj) { - VideoCaptureModuleV4L2* capture = static_cast(obj); - while (capture->CaptureProcess()) { - } -} bool VideoCaptureModuleV4L2::CaptureProcess() { int retVal = 0; fd_set rSet; diff --git a/modules/video_capture/linux/video_capture_linux.h b/modules/video_capture/linux/video_capture_linux.h index ddb5d5ba87..fa06d72b8d 100644 --- a/modules/video_capture/linux/video_capture_linux.h +++ b/modules/video_capture/linux/video_capture_linux.h @@ -41,8 +41,7 @@ class VideoCaptureModuleV4L2 : public VideoCaptureImpl { bool AllocateVideoBuffers(); bool DeAllocateVideoBuffers(); - // TODO(pbos): Stop using unique_ptr and resetting the thread. - std::unique_ptr _captureThread; + rtc::PlatformThread _captureThread; Mutex capture_lock_; bool quit_ RTC_GUARDED_BY(capture_lock_); int32_t _deviceId; diff --git a/modules/video_capture/test/video_capture_unittest.cc b/modules/video_capture/test/video_capture_unittest.cc index 1a0cf2d5da..e74a456cee 100644 --- a/modules/video_capture/test/video_capture_unittest.cc +++ b/modules/video_capture/test/video_capture_unittest.cc @@ -152,7 +152,7 @@ class VideoCaptureTest : public ::testing::Test { void SetUp() override { device_info_.reset(VideoCaptureFactory::CreateDeviceInfo()); - assert(device_info_.get()); + RTC_DCHECK(device_info_.get()); number_of_devices_ = device_info_->NumberOfDevices(); ASSERT_GT(number_of_devices_, 0u); } diff --git a/modules/video_capture/windows/device_info_ds.cc b/modules/video_capture/windows/device_info_ds.cc index f43c508bee..3731dce8bc 100644 --- a/modules/video_capture/windows/device_info_ds.cc +++ b/modules/video_capture/windows/device_info_ds.cc @@ -72,10 +72,10 @@ DeviceInfoDS::DeviceInfoDS() // Details: hr = 0x80010106 <=> "Cannot change thread mode after it is // set". 
// - RTC_LOG(LS_INFO) << __FUNCTION__ - << ": CoInitializeEx(NULL, COINIT_APARTMENTTHREADED)" - " => RPC_E_CHANGED_MODE, error 0x" - << rtc::ToHex(hr); + RTC_DLOG(LS_INFO) << __FUNCTION__ + << ": CoInitializeEx(NULL, COINIT_APARTMENTTHREADED)" + " => RPC_E_CHANGED_MODE, error 0x" + << rtc::ToHex(hr); } } } @@ -203,7 +203,7 @@ int32_t DeviceInfoDS::GetDeviceInfo(uint32_t deviceNumber, } } if (deviceNameLength) { - RTC_LOG(LS_INFO) << __FUNCTION__ << " " << deviceNameUTF8; + RTC_DLOG(LS_INFO) << __FUNCTION__ << " " << deviceNameUTF8; } return index; } @@ -213,7 +213,7 @@ IBaseFilter* DeviceInfoDS::GetDeviceFilter(const char* deviceUniqueIdUTF8, uint32_t productUniqueIdUTF8Length) { const int32_t deviceUniqueIdUTF8Length = (int32_t)strlen( (char*)deviceUniqueIdUTF8); // UTF8 is also NULL terminated - if (deviceUniqueIdUTF8Length > kVideoCaptureUniqueNameLength) { + if (deviceUniqueIdUTF8Length >= kVideoCaptureUniqueNameLength) { RTC_LOG(LS_INFO) << "Device name too long"; return NULL; } @@ -306,7 +306,7 @@ int32_t DeviceInfoDS::CreateCapabilityMap(const char* deviceUniqueIdUTF8) const int32_t deviceUniqueIdUTF8Length = (int32_t)strlen((char*)deviceUniqueIdUTF8); - if (deviceUniqueIdUTF8Length > kVideoCaptureUniqueNameLength) { + if (deviceUniqueIdUTF8Length >= kVideoCaptureUniqueNameLength) { RTC_LOG(LS_INFO) << "Device name too long"; return -1; } @@ -380,7 +380,7 @@ int32_t DeviceInfoDS::CreateCapabilityMap(const char* deviceUniqueIdUTF8) supportFORMAT_VideoInfo2 = true; VIDEOINFOHEADER2* h = reinterpret_cast(pmt->pbFormat); - assert(h); + RTC_DCHECK(h); foundInterlacedFormat |= h->dwInterlaceFlags & (AMINTERLACE_IsInterlaced | AMINTERLACE_DisplayModeBobOnly); @@ -418,7 +418,7 @@ int32_t DeviceInfoDS::CreateCapabilityMap(const char* deviceUniqueIdUTF8) if (pmt->formattype == FORMAT_VideoInfo) { VIDEOINFOHEADER* h = reinterpret_cast(pmt->pbFormat); - assert(h); + RTC_DCHECK(h); capability.directShowCapabilityIndex = tmp; capability.width = h->bmiHeader.biWidth; capability.height = h->bmiHeader.biHeight; @@ -427,7 +427,7 @@ int32_t DeviceInfoDS::CreateCapabilityMap(const char* deviceUniqueIdUTF8) if (pmt->formattype == FORMAT_VideoInfo2) { VIDEOINFOHEADER2* h = reinterpret_cast(pmt->pbFormat); - assert(h); + RTC_DCHECK(h); capability.directShowCapabilityIndex = tmp; capability.width = h->bmiHeader.biWidth; capability.height = h->bmiHeader.biHeight; @@ -568,7 +568,7 @@ void DeviceInfoDS::GetProductId(const char* devicePath, // Find the second occurrence. 
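
Several of the length checks in these capture and device-info files are tightened from `>=` to `>` (and from `>` to `>=` for the rejection case) so that a string whose length exactly equals the destination size is treated as too long: copying strlen() bytes into a buffer of the same size would leave no room for the terminating NUL. A small illustration of the rule (the helper below is hypothetical, not part of the patch):

#include <cstddef>
#include <cstring>

// Copies |name|, including its terminating NUL, into |dest| of |dest_size|
// bytes. The strict "dest_size > strlen(name)" comparison is what leaves room
// for the terminator; ">=" would allow an unterminated copy.
bool CopyNameChecked(char* dest, std::size_t dest_size, const char* name) {
  const std::size_t length = std::strlen(name);
  if (dest_size <= length)
    return false;  // Name too long for the buffer.
  std::memcpy(dest, name, length + 1);  // +1 copies the NUL as well.
  return true;
}
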
pos = strchr(pos + 1, '&'); uint32_t bytesToCopy = (uint32_t)(pos - startPos); - if (pos && (bytesToCopy <= productUniqueIdUTF8Length) && + if (pos && (bytesToCopy < productUniqueIdUTF8Length) && bytesToCopy <= kVideoCaptureProductIdLength) { strncpy_s((char*)productUniqueIdUTF8, productUniqueIdUTF8Length, (char*)startPos, bytesToCopy); diff --git a/modules/video_capture/windows/sink_filter_ds.cc b/modules/video_capture/windows/sink_filter_ds.cc index 9019b127cf..e4be7aa14f 100644 --- a/modules/video_capture/windows/sink_filter_ds.cc +++ b/modules/video_capture/windows/sink_filter_ds.cc @@ -58,7 +58,7 @@ class EnumPins : public IEnumPins { } STDMETHOD(Clone)(IEnumPins** pins) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return E_NOTIMPL; } @@ -83,7 +83,7 @@ class EnumPins : public IEnumPins { } STDMETHOD(Skip)(ULONG count) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return E_NOTIMPL; } @@ -274,7 +274,7 @@ class MediaTypesEnum : public IEnumMediaTypes { // IEnumMediaTypes STDMETHOD(Clone)(IEnumMediaTypes** pins) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return E_NOTIMPL; } @@ -364,7 +364,7 @@ class MediaTypesEnum : public IEnumMediaTypes { } STDMETHOD(Skip)(ULONG count) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return E_NOTIMPL; } @@ -538,7 +538,7 @@ STDMETHODIMP CaptureInputPin::Connect(IPin* receive_pin, return VFW_E_NOT_STOPPED; if (receive_pin_) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return VFW_E_ALREADY_CONNECTED; } @@ -564,7 +564,7 @@ STDMETHODIMP CaptureInputPin::ReceiveConnection( RTC_DCHECK(Filter()->IsStopped()); if (receive_pin_) { - RTC_DCHECK(false); + RTC_NOTREACHED(); return VFW_E_ALREADY_CONNECTED; } diff --git a/modules/video_capture/windows/sink_filter_ds.h b/modules/video_capture/windows/sink_filter_ds.h index af264a937a..b0fabda3cd 100644 --- a/modules/video_capture/windows/sink_filter_ds.h +++ b/modules/video_capture/windows/sink_filter_ds.h @@ -17,10 +17,10 @@ #include #include +#include "api/sequence_checker.h" #include "modules/video_capture/video_capture_impl.h" #include "modules/video_capture/windows/help_functions_ds.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { namespace videocapturemodule { @@ -89,8 +89,8 @@ class CaptureInputPin : public IMemInputPin, public IPin { STDMETHOD(ReceiveCanBlock)() override; // clang-format on - rtc::ThreadChecker main_checker_; - rtc::ThreadChecker capture_checker_; + SequenceChecker main_checker_; + SequenceChecker capture_checker_; VideoCaptureCapability requested_capability_ RTC_GUARDED_BY(main_checker_); // Accessed on the main thread when Filter()->IsStopped() (capture thread not @@ -147,7 +147,7 @@ class CaptureSinkFilter : public IBaseFilter { virtual ~CaptureSinkFilter(); private: - rtc::ThreadChecker main_checker_; + SequenceChecker main_checker_; const rtc::scoped_refptr> input_pin_; VideoCaptureImpl* const capture_observer_; FILTER_INFO info_ RTC_GUARDED_BY(main_checker_) = {}; diff --git a/modules/video_capture/windows/video_capture_ds.cc b/modules/video_capture/windows/video_capture_ds.cc index 6dca74750c..1a1e51934d 100644 --- a/modules/video_capture/windows/video_capture_ds.cc +++ b/modules/video_capture/windows/video_capture_ds.cc @@ -57,7 +57,7 @@ VideoCaptureDS::~VideoCaptureDS() { int32_t VideoCaptureDS::Init(const char* deviceUniqueIdUTF8) { const int32_t nameLength = (int32_t)strlen((char*)deviceUniqueIdUTF8); - if (nameLength > kVideoCaptureUniqueNameLength) + if (nameLength >= kVideoCaptureUniqueNameLength) return -1; // Store the device name diff 
--git a/modules/video_capture/windows/video_capture_factory_windows.cc b/modules/video_capture/windows/video_capture_factory_windows.cc index ea9d31add9..34cc982d7e 100644 --- a/modules/video_capture/windows/video_capture_factory_windows.cc +++ b/modules/video_capture/windows/video_capture_factory_windows.cc @@ -27,8 +27,7 @@ rtc::scoped_refptr VideoCaptureImpl::Create( return nullptr; // TODO(tommi): Use Media Foundation implementation for Vista and up. - rtc::scoped_refptr capture( - new rtc::RefCountedObject()); + auto capture = rtc::make_ref_counted(); if (capture->Init(device_id) != 0) { return nullptr; } diff --git a/modules/video_coding/BUILD.gn b/modules/video_coding/BUILD.gn index 713c41bbd6..50f2e8d836 100644 --- a/modules/video_coding/BUILD.gn +++ b/modules/video_coding/BUILD.gn @@ -81,6 +81,7 @@ rtc_library("nack_module") { deps = [ "..:module_api", + "../../api:sequence_checker", "../../api/units:time_delta", "../../api/units:timestamp", "../../rtc_base:checks", @@ -88,7 +89,6 @@ rtc_library("nack_module") { "../../rtc_base:rtc_numerics", "../../rtc_base:rtc_task_queue", "../../rtc_base/experiments:field_trial_parser", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/task_utils:pending_task_safety_flag", "../../rtc_base/task_utils:repeating_task", "../../system_wrappers", @@ -168,8 +168,10 @@ rtc_library("video_coding") { "../../api:rtp_headers", "../../api:rtp_packet_info", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/units:data_rate", "../../api/units:time_delta", + "../../api/units:timestamp", "../../api/video:builtin_video_bitrate_allocator_factory", "../../api/video:encoded_frame", "../../api/video:encoded_image", @@ -185,10 +187,10 @@ rtc_library("video_coding") { "../../common_video", "../../rtc_base", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_numerics", "../../rtc_base:rtc_task_queue", + "../../rtc_base:threading", "../../rtc_base/experiments:alr_experiment", "../../rtc_base/experiments:field_trial_parser", "../../rtc_base/experiments:jitter_upper_bound_experiment", @@ -196,7 +198,6 @@ rtc_library("video_coding") { "../../rtc_base/experiments:rate_control_settings", "../../rtc_base/experiments:rtt_mult_experiment", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/system:no_unique_address", "../../rtc_base/task_utils:repeating_task", "../../rtc_base/task_utils:to_queued_task", @@ -235,7 +236,6 @@ rtc_library("video_codec_interface") { "../../api/video_codecs:video_codecs_api", "../../common_video", "../../common_video/generic_frame_descriptor", - "../../rtc_base:deprecation", "../../rtc_base/system:rtc_export", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -273,6 +273,8 @@ rtc_library("video_coding_legacy") { "..:module_api_public", "../../api:rtp_headers", "../../api:rtp_packet_info", + "../../api:sequence_checker", + "../../api/units:timestamp", "../../api/video:encoded_image", "../../api/video:video_frame", "../../api/video:video_frame_type", @@ -285,7 +287,6 @@ rtc_library("video_coding_legacy") { "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_event", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../system_wrappers", "../rtp_rtcp:rtp_rtcp_format", "../rtp_rtcp:rtp_video_header", @@ -323,6 +324,8 @@ rtc_library("video_coding_utility") { "utility/ivf_file_reader.h", "utility/ivf_file_writer.cc", 
"utility/ivf_file_writer.h", + "utility/qp_parser.cc", + "utility/qp_parser.h", "utility/quality_scaler.cc", "utility/quality_scaler.h", "utility/simulcast_rate_allocator.cc", @@ -338,11 +341,13 @@ rtc_library("video_coding_utility") { deps = [ ":video_codec_interface", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/video:encoded_frame", "../../api/video:encoded_image", "../../api/video:video_adaptation", "../../api/video:video_bitrate_allocation", "../../api/video:video_bitrate_allocator", + "../../api/video:video_codec_constants", "../../api/video:video_frame", "../../api/video_codecs:video_codecs_api", "../../common_video", @@ -356,7 +361,7 @@ rtc_library("video_coding_utility") { "../../rtc_base/experiments:quality_scaling_experiment", "../../rtc_base/experiments:rate_control_settings", "../../rtc_base/experiments:stable_target_rate_experiment", - "../../rtc_base/synchronization:sequence_checker", + "../../rtc_base/synchronization:mutex", "../../rtc_base/system:arch", "../../rtc_base/system:file_wrapper", "../../rtc_base/system:no_unique_address", @@ -365,7 +370,10 @@ rtc_library("video_coding_utility") { "../../system_wrappers:field_trial", "../rtp_rtcp:rtp_rtcp_format", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings:strings", + "//third_party/abseil-cpp/absl/types:optional", + ] } rtc_library("webrtc_h264") { @@ -459,6 +467,15 @@ rtc_library("webrtc_libvpx_interface") { } } +rtc_library("mock_libvpx_interface") { + testonly = true + sources = [ "codecs/interface/mock_libvpx_interface.h" ] + deps = [ + ":webrtc_libvpx_interface", + "../../test:test_support", + ] +} + # This target includes the internal SW codec. rtc_library("webrtc_vp8") { visibility = [ "*" ] @@ -486,17 +503,20 @@ rtc_library("webrtc_vp8") { "../../api/video_codecs:vp8_temporal_layers_factory", "../../common_video", "../../rtc_base:checks", - "../../rtc_base:deprecation", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_numerics", "../../rtc_base/experiments:cpu_speed_experiment", + "../../rtc_base/experiments:encoder_info_settings", "../../rtc_base/experiments:field_trial_parser", "../../rtc_base/experiments:rate_control_settings", "../../system_wrappers:field_trial", "../../system_wrappers:metrics", "//third_party/libyuv", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/types:optional", + ] if (rtc_build_libvpx) { deps += [ rtc_libvpx_dir ] } @@ -573,6 +593,7 @@ rtc_library("webrtc_vp9") { ":webrtc_libvpx_interface", ":webrtc_vp9_helpers", "../../api:fec_controller_api", + "../../api:refcountedbase", "../../api:scoped_refptr", "../../api/transport:field_trial_based_config", "../../api/transport:webrtc_key_value_config", @@ -582,9 +603,9 @@ rtc_library("webrtc_vp9") { "../../api/video_codecs:video_codecs_api", "../../common_video", "../../media:rtc_media_base", - "../../media:rtc_vp9_profile", "../../rtc_base", "../../rtc_base:checks", + "../../rtc_base/experiments:encoder_info_settings", "../../rtc_base/experiments:field_trial_parser", "../../rtc_base/experiments:rate_control_settings", "../../rtc_base/synchronization:mutex", @@ -595,6 +616,7 @@ rtc_library("webrtc_vp9") { "//third_party/libyuv", ] absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", "//third_party/abseil-cpp/absl/memory", "//third_party/abseil-cpp/absl/strings:strings", ] @@ -666,15 +688,6 @@ if 
(rtc_include_tests) { ] } - rtc_library("mock_libvpx_interface") { - testonly = true - sources = [ "codecs/interface/mock_libvpx_interface.h" ] - deps = [ - ":webrtc_libvpx_interface", - "../../test:test_support", - ] - } - rtc_library("simulcast_test_fixture_impl") { testonly = true sources = [ @@ -720,6 +733,7 @@ if (rtc_include_tests) { "../../api:create_frame_generator", "../../api:frame_generator_api", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api:videocodec_test_fixture_api", "../../api/task_queue", "../../api/video:builtin_video_bitrate_allocator_factory", @@ -735,7 +749,6 @@ if (rtc_include_tests) { "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_task_queue", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/system:no_unique_address", "../../rtc_base/task_utils:to_queued_task", "../../test:test_support", @@ -802,7 +815,6 @@ if (rtc_include_tests) { "../../call:video_stream_api", "../../common_video", "../../media:rtc_audio_video", - "../../media:rtc_h264_profile_id", "../../media:rtc_internal_video_codecs", "../../media:rtc_media_base", "../../rtc_base:checks", @@ -816,7 +828,10 @@ if (rtc_include_tests) { "../../test:video_test_common", "../../test:video_test_support", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings:strings", + "//third_party/abseil-cpp/absl/types:optional", + ] } rtc_library("videocodec_test_stats_impl") { @@ -846,9 +861,12 @@ if (rtc_include_tests) { "codecs/test/video_encoder_decoder_instantiation_tests.cc", "codecs/test/videocodec_test_libvpx.cc", "codecs/vp8/test/vp8_impl_unittest.cc", - "codecs/vp9/test/vp9_impl_unittest.cc", ] + if (rtc_libvpx_build_vp9) { + sources += [ "codecs/vp9/test/vp9_impl_unittest.cc" ] + } + # TODO(jianj): Fix crash on iOS and re-enable if (enable_libaom && !is_ios) { sources += [ "codecs/test/videocodec_test_libaom.cc" ] @@ -886,11 +904,9 @@ if (rtc_include_tests) { "../../api/video_codecs:video_codecs_api", "../../common_video", "../../common_video/test:utilities", - "../../media:rtc_h264_profile_id", "../../media:rtc_internal_video_codecs", "../../media:rtc_media_base", "../../media:rtc_simulcast_encoder_adapter", - "../../media:rtc_vp9_profile", "../../rtc_base", "../../test:explicit_key_value_config", "../../test:field_trial", @@ -953,10 +969,12 @@ if (rtc_include_tests) { "packet_buffer_unittest.cc", "receiver_unittest.cc", "rtp_frame_reference_finder_unittest.cc", + "rtp_vp8_ref_finder_unittest.cc", "rtp_vp9_ref_finder_unittest.cc", "session_info_unittest.cc", "test/stream_generator.cc", "test/stream_generator.h", + "timestamp_map_unittest.cc", "timing_unittest.cc", "unique_timestamp_counter_unittest.cc", "utility/decoded_frames_history_unittest.cc", @@ -964,8 +982,10 @@ if (rtc_include_tests) { "utility/framerate_controller_unittest.cc", "utility/ivf_file_reader_unittest.cc", "utility/ivf_file_writer_unittest.cc", + "utility/qp_parser_unittest.cc", "utility/quality_scaler_unittest.cc", "utility/simulcast_rate_allocator_unittest.cc", + "utility/vp9_uncompressed_header_parser_unittest.cc", "video_codec_initializer_unittest.cc", "video_receiver_unittest.cc", ] @@ -1008,6 +1028,7 @@ if (rtc_include_tests) { "../../api/task_queue:default_task_queue_factory", "../../api/test/video:function_video_factory", "../../api/video:builtin_video_bitrate_allocator_factory", + "../../api/video:encoded_frame", "../../api/video:video_adaptation", 
"../../api/video:video_bitrate_allocation", "../../api/video:video_bitrate_allocator", diff --git a/modules/video_coding/codecs/av1/BUILD.gn b/modules/video_coding/codecs/av1/BUILD.gn index 95b5ad1274..e7c901cc9a 100644 --- a/modules/video_coding/codecs/av1/BUILD.gn +++ b/modules/video_coding/codecs/av1/BUILD.gn @@ -88,6 +88,7 @@ if (rtc_include_tests) { deps = [ ":av1_svc_config", "../../../../api/video_codecs:video_codecs_api", + "../../../../test:test_support", ] if (enable_libaom) { @@ -104,7 +105,6 @@ if (rtc_include_tests) { "../../../../api/units:data_size", "../../../../api/units:time_delta", "../../../../api/video:video_frame", - "../../../../test:test_support", "../../svc:scalability_structures", "../../svc:scalable_video_controller", ] diff --git a/modules/video_coding/codecs/av1/av1_svc_config.cc b/modules/video_coding/codecs/av1/av1_svc_config.cc index 1e61477b78..b15443c563 100644 --- a/modules/video_coding/codecs/av1/av1_svc_config.cc +++ b/modules/video_coding/codecs/av1/av1_svc_config.cc @@ -51,8 +51,9 @@ bool SetAv1SvcConfig(VideoCodec& video_codec) { if (info.num_spatial_layers == 1) { SpatialLayer& spatial_layer = video_codec.spatialLayers[0]; spatial_layer.minBitrate = video_codec.minBitrate; - spatial_layer.targetBitrate = video_codec.startBitrate; spatial_layer.maxBitrate = video_codec.maxBitrate; + spatial_layer.targetBitrate = + (video_codec.minBitrate + video_codec.maxBitrate) / 2; return true; } diff --git a/modules/video_coding/codecs/av1/av1_svc_config_unittest.cc b/modules/video_coding/codecs/av1/av1_svc_config_unittest.cc index 02ded1c70d..e6035328da 100644 --- a/modules/video_coding/codecs/av1/av1_svc_config_unittest.cc +++ b/modules/video_coding/codecs/av1/av1_svc_config_unittest.cc @@ -97,19 +97,21 @@ TEST(Av1SvcConfigTest, SetsNumberOfTemporalLayers) { EXPECT_EQ(video_codec.spatialLayers[0].numberOfTemporalLayers, 3); } -TEST(Av1SvcConfigTest, CopiesBitrateForSingleSpatialLayer) { +TEST(Av1SvcConfigTest, CopiesMinMaxBitrateForSingleSpatialLayer) { VideoCodec video_codec; video_codec.codecType = kVideoCodecAV1; video_codec.SetScalabilityMode("L1T3"); video_codec.minBitrate = 100; - video_codec.startBitrate = 200; video_codec.maxBitrate = 500; EXPECT_TRUE(SetAv1SvcConfig(video_codec)); EXPECT_EQ(video_codec.spatialLayers[0].minBitrate, 100u); - EXPECT_EQ(video_codec.spatialLayers[0].targetBitrate, 200u); EXPECT_EQ(video_codec.spatialLayers[0].maxBitrate, 500u); + EXPECT_LE(video_codec.spatialLayers[0].minBitrate, + video_codec.spatialLayers[0].targetBitrate); + EXPECT_LE(video_codec.spatialLayers[0].targetBitrate, + video_codec.spatialLayers[0].maxBitrate); } TEST(Av1SvcConfigTest, SetsBitratesForMultipleSpatialLayers) { diff --git a/modules/video_coding/codecs/av1/libaom_av1_encoder.cc b/modules/video_coding/codecs/av1/libaom_av1_encoder.cc index a99c642f07..034709a989 100644 --- a/modules/video_coding/codecs/av1/libaom_av1_encoder.cc +++ b/modules/video_coding/codecs/av1/libaom_av1_encoder.cc @@ -41,9 +41,9 @@ namespace { // Encoder configuration parameters constexpr int kQpMin = 10; -constexpr int kUsageProfile = 1; // 0 = good quality; 1 = real-time. -constexpr int kMinQindex = 58; // Min qindex threshold for QP scaling. -constexpr int kMaxQindex = 180; // Max qindex threshold for QP scaling. +constexpr int kUsageProfile = AOM_USAGE_REALTIME; +constexpr int kMinQindex = 145; // Min qindex threshold for QP scaling. +constexpr int kMaxQindex = 205; // Max qindex threshold for QP scaling. 
constexpr int kBitDepth = 8; constexpr int kLagInFrames = 0; // No look ahead. constexpr int kRtpTicksPerSecond = 90000; @@ -54,18 +54,27 @@ constexpr float kMinimumFrameRate = 1.0; int GetCpuSpeed(int width, int height, int number_of_cores) { // For smaller resolutions, use lower speed setting (get some coding gain at // the cost of increased encoding complexity). - if (number_of_cores > 2 && width * height <= 320 * 180) + if (number_of_cores > 4 && width * height < 320 * 180) return 6; else if (width * height >= 1280 * 720) + return 9; + else if (width * height >= 640 * 360) return 8; else return 7; } +aom_superblock_size_t GetSuperblockSize(int width, int height, int threads) { + int resolution = width * height; + if (threads >= 4 && resolution >= 960 * 540 && resolution < 1920 * 1080) + return AOM_SUPERBLOCK_SIZE_64X64; + else + return AOM_SUPERBLOCK_SIZE_DYNAMIC; +} + class LibaomAv1Encoder final : public VideoEncoder { public: - explicit LibaomAv1Encoder( - std::unique_ptr svc_controller); + LibaomAv1Encoder(); ~LibaomAv1Encoder(); int InitEncode(const VideoCodec* codec_settings, @@ -132,14 +141,10 @@ int32_t VerifyCodecSettings(const VideoCodec& codec_settings) { return WEBRTC_VIDEO_CODEC_OK; } -LibaomAv1Encoder::LibaomAv1Encoder( - std::unique_ptr svc_controller) - : svc_controller_(std::move(svc_controller)), - inited_(false), +LibaomAv1Encoder::LibaomAv1Encoder() + : inited_(false), frame_for_encode_(nullptr), - encoded_image_callback_(nullptr) { - RTC_DCHECK(svc_controller_); -} + encoded_image_callback_(nullptr) {} LibaomAv1Encoder::~LibaomAv1Encoder() { Release(); @@ -173,11 +178,11 @@ int LibaomAv1Encoder::InitEncode(const VideoCodec* codec_settings, return result; } absl::string_view scalability_mode = encoder_settings_.ScalabilityMode(); - // When scalability_mode is not set, keep using svc_controller_ created - // at construction of the encoder. - if (!scalability_mode.empty()) { - svc_controller_ = CreateScalabilityStructure(scalability_mode); + if (scalability_mode.empty()) { + RTC_LOG(LS_WARNING) << "Scalability mode is not set, using 'NONE'."; + scalability_mode = "NONE"; } + svc_controller_ = CreateScalabilityStructure(scalability_mode); if (svc_controller_ == nullptr) { RTC_LOG(LS_WARNING) << "Failed to set scalability mode " << scalability_mode; @@ -190,7 +195,7 @@ int LibaomAv1Encoder::InitEncode(const VideoCodec* codec_settings, // Initialize encoder configuration structure with default values aom_codec_err_t ret = - aom_codec_enc_config_default(aom_codec_av1_cx(), &cfg_, 0); + aom_codec_enc_config_default(aom_codec_av1_cx(), &cfg_, kUsageProfile); if (ret != AOM_CODEC_OK) { RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret << " on aom_codec_enc_config_default."; @@ -209,6 +214,11 @@ int LibaomAv1Encoder::InitEncode(const VideoCodec* codec_settings, cfg_.kf_mode = AOM_KF_DISABLED; cfg_.rc_min_quantizer = kQpMin; cfg_.rc_max_quantizer = encoder_settings_.qpMax; + cfg_.rc_undershoot_pct = 50; + cfg_.rc_overshoot_pct = 50; + cfg_.rc_buf_initial_sz = 600; + cfg_.rc_buf_optimal_sz = 600; + cfg_.rc_buf_sz = 1000; cfg_.g_usage = kUsageProfile; cfg_.g_error_resilient = 0; // Low-latency settings. 
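
The new GetCpuSpeed() and GetSuperblockSize() helpers above, together with the thread and tile heuristics in the following hunks, are pure functions of resolution and core or thread count, so their effect is easy to spot-check. A test-style sketch, assuming the helpers and the aom headers above are in scope (this snippet is illustrative, not part of the patch):

#include <cassert>

void CheckAv1EncoderHeuristics() {
  // 720p and larger now run at speed 9; 360p-class input runs at speed 8.
  assert(GetCpuSpeed(1280, 720, /*number_of_cores=*/8) == 9);
  assert(GetCpuSpeed(640, 360, /*number_of_cores=*/8) == 8);
  // Very small frames on more than 4 cores keep the slower speed 6.
  assert(GetCpuSpeed(320, 176, /*number_of_cores=*/8) == 6);
  // 540p up to (but not including) 1080p with 4+ threads pins 64x64
  // superblocks; everything else stays dynamic.
  assert(GetSuperblockSize(960, 540, /*threads=*/4) ==
         AOM_SUPERBLOCK_SIZE_64X64);
  assert(GetSuperblockSize(1920, 1080, /*threads=*/4) ==
         AOM_SUPERBLOCK_SIZE_DYNAMIC);
}
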
@@ -288,13 +298,13 @@ int LibaomAv1Encoder::InitEncode(const VideoCodec* codec_settings, << " on control AV1E_SET_MAX_INTRA_BITRATE_PCT."; return WEBRTC_VIDEO_CODEC_ERROR; } - ret = aom_codec_control(&ctx_, AV1E_SET_COEFF_COST_UPD_FREQ, 2); + ret = aom_codec_control(&ctx_, AV1E_SET_COEFF_COST_UPD_FREQ, 3); if (ret != AOM_CODEC_OK) { RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret << " on control AV1E_SET_COEFF_COST_UPD_FREQ."; return WEBRTC_VIDEO_CODEC_ERROR; } - ret = aom_codec_control(&ctx_, AV1E_SET_MODE_COST_UPD_FREQ, 2); + ret = aom_codec_control(&ctx_, AV1E_SET_MODE_COST_UPD_FREQ, 3); if (ret != AOM_CODEC_OK) { RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret << " on control AV1E_SET_MODE_COST_UPD_FREQ."; @@ -307,17 +317,109 @@ int LibaomAv1Encoder::InitEncode(const VideoCodec* codec_settings, return WEBRTC_VIDEO_CODEC_ERROR; } - ret = aom_codec_control(&ctx_, AV1E_SET_TILE_COLUMNS, cfg_.g_threads >> 1); + if (cfg_.g_threads == 4 && cfg_.g_w == 640 && + (cfg_.g_h == 360 || cfg_.g_h == 480)) { + ret = aom_codec_control(&ctx_, AV1E_SET_TILE_ROWS, + static_cast(log2(cfg_.g_threads))); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_TILE_ROWS."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + } else { + ret = aom_codec_control(&ctx_, AV1E_SET_TILE_COLUMNS, + static_cast(log2(cfg_.g_threads))); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_TILE_COLUMNS."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ROW_MT, 1); if (ret != AOM_CODEC_OK) { RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret - << " on control AV1E_SET_TILE_COLUMNS."; + << " on control AV1E_SET_ROW_MT."; return WEBRTC_VIDEO_CODEC_ERROR; } - ret = aom_codec_control(&ctx_, AV1E_SET_ROW_MT, 1); + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_OBMC, 0); if (ret != AOM_CODEC_OK) { RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret - << " on control AV1E_SET_ROW_MT."; + << " on control AV1E_SET_ENABLE_OBMC."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_NOISE_SENSITIVITY, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_NOISE_SENSITIVITY."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_WARPED_MOTION, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_WARPED_MOTION."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_GLOBAL_MOTION, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_GLOBAL_MOTION."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_REF_FRAME_MVS, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_REF_FRAME_MVS."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = + aom_codec_control(&ctx_, AV1E_SET_SUPERBLOCK_SIZE, + GetSuperblockSize(cfg_.g_w, cfg_.g_h, cfg_.g_threads)); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_SUPERBLOCK_SIZE."; + return 
WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_CFL_INTRA, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_CFL_INTRA."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_SMOOTH_INTRA, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_SMOOTH_INTRA."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_ANGLE_DELTA, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_ANGLE_DELTA."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_ENABLE_FILTER_INTRA, 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AV1E_SET_ENABLE_FILTER_INTRA."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + + ret = aom_codec_control(&ctx_, AV1E_SET_INTRA_DEFAULT_TX_ONLY, 1); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) + << "LibaomAv1Encoder::EncodeInit returned " << ret + << " on control AOM_CTRL_AV1E_SET_INTRA_DEFAULT_TX_ONLY."; return WEBRTC_VIDEO_CODEC_ERROR; } @@ -327,11 +429,12 @@ int LibaomAv1Encoder::InitEncode(const VideoCodec* codec_settings, int LibaomAv1Encoder::NumberOfThreads(int width, int height, int number_of_cores) { - // Keep the number of encoder threads equal to the possible number of column - // tiles, which is (1, 2, 4, 8). See comments below for AV1E_SET_TILE_COLUMNS. - if (width * height >= 1280 * 720 && number_of_cores > 4) { + // Keep the number of encoder threads equal to the possible number of + // column/row tiles, which is (1, 2, 4, 8). See comments below for + // AV1E_SET_TILE_COLUMNS/ROWS. + if (width * height >= 640 * 360 && number_of_cores > 4) { return 4; - } else if (width * height >= 640 * 360 && number_of_cores > 2) { + } else if (width * height >= 320 * 180 && number_of_cores > 2) { return 2; } else { // Use 2 threads for low res on ARM. @@ -480,9 +583,36 @@ int32_t LibaomAv1Encoder::Encode( // Convert input frame to I420, if needed. VideoFrame prepped_input_frame = frame; if (prepped_input_frame.video_frame_buffer()->type() != - VideoFrameBuffer::Type::kI420) { + VideoFrameBuffer::Type::kI420 && + prepped_input_frame.video_frame_buffer()->type() != + VideoFrameBuffer::Type::kI420A) { rtc::scoped_refptr converted_buffer( prepped_input_frame.video_frame_buffer()->ToI420()); + // The buffer should now be a mapped I420 or I420A format, but some buffer + // implementations incorrectly return the wrong buffer format, such as + // kNative. As a workaround to this, we perform ToI420() a second time. + // TODO(https://crbug.com/webrtc/12602): When Android buffers have a correct + // ToI420() implementaion, remove his workaround. + if (!converted_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString( + converted_buffer->type()) + << " image to I420. 
Can't encode frame."; + return WEBRTC_VIDEO_CODEC_ENCODER_FAILURE; + } + if (converted_buffer->type() != VideoFrameBuffer::Type::kI420 && + converted_buffer->type() != VideoFrameBuffer::Type::kI420A) { + converted_buffer = converted_buffer->ToI420(); + RTC_CHECK(converted_buffer->type() == VideoFrameBuffer::Type::kI420 || + converted_buffer->type() == VideoFrameBuffer::Type::kI420A); + } + if (!converted_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString( + converted_buffer->type()) + << " image to I420. Can't encode frame."; + return WEBRTC_VIDEO_CODEC_ENCODER_FAILURE; + } prepped_input_frame = VideoFrame(converted_buffer, frame.timestamp(), frame.render_time_ms(), frame.rotation()); } @@ -512,6 +642,15 @@ int32_t LibaomAv1Encoder::Encode( if (SvcEnabled()) { SetSvcLayerId(layer_frame); SetSvcRefFrameConfig(layer_frame); + + aom_codec_err_t ret = + aom_codec_control(&ctx_, AV1E_SET_ERROR_RESILIENT_MODE, + layer_frame.TemporalId() > 0 ? 1 : 0); + if (ret != AOM_CODEC_OK) { + RTC_LOG(LS_WARNING) << "LibaomAv1Encoder::Encode returned " << ret + << " on control AV1E_SET_ERROR_RESILIENT_MODE."; + return WEBRTC_VIDEO_CODEC_ERROR; + } } // Encode a frame. @@ -551,8 +690,15 @@ int32_t LibaomAv1Encoder::Encode( encoded_image.content_type_ = VideoContentType::UNSPECIFIED; // If encoded image width/height info are added to aom_codec_cx_pkt_t, // use those values in lieu of the values in frame. - encoded_image._encodedHeight = frame.height(); - encoded_image._encodedWidth = frame.width(); + if (svc_params_) { + int n = svc_params_->scaling_factor_num[layer_frame.SpatialId()]; + int d = svc_params_->scaling_factor_den[layer_frame.SpatialId()]; + encoded_image._encodedWidth = cfg_.g_w * n / d; + encoded_image._encodedHeight = cfg_.g_h * n / d; + } else { + encoded_image._encodedWidth = cfg_.g_w; + encoded_image._encodedHeight = cfg_.g_h; + } encoded_image.timing_.flags = VideoSendTiming::kInvalid; int qp = -1; ret = aom_codec_control(&ctx_, AOME_GET_LAST_QUANTIZER, &qp); @@ -615,15 +761,8 @@ void LibaomAv1Encoder::SetRates(const RateControlParameters& parameters) { return; } - // Check input target bit rate value. - uint32_t rc_target_bitrate_kbps = parameters.bitrate.get_sum_kbps(); - if (encoder_settings_.maxBitrate > 0) - RTC_DCHECK_LE(rc_target_bitrate_kbps, encoder_settings_.maxBitrate); - RTC_DCHECK_GE(rc_target_bitrate_kbps, encoder_settings_.minBitrate); - svc_controller_->OnRatesUpdated(parameters.bitrate); - // Set target bit rate. 
- cfg_.rc_target_bitrate = rc_target_bitrate_kbps; + cfg_.rc_target_bitrate = parameters.bitrate.get_sum_kbps(); if (SvcEnabled()) { for (int sid = 0; sid < svc_params_->number_spatial_layers; ++sid) { @@ -680,13 +819,7 @@ VideoEncoder::EncoderInfo LibaomAv1Encoder::GetEncoderInfo() const { const bool kIsLibaomAv1EncoderSupported = true; std::unique_ptr CreateLibaomAv1Encoder() { - return std::make_unique( - std::make_unique()); -} - -std::unique_ptr CreateLibaomAv1Encoder( - std::unique_ptr svc_controller) { - return std::make_unique(std::move(svc_controller)); + return std::make_unique(); } } // namespace webrtc diff --git a/modules/video_coding/codecs/av1/libaom_av1_encoder.h b/modules/video_coding/codecs/av1/libaom_av1_encoder.h index 04a2b65f54..4b0ee28d40 100644 --- a/modules/video_coding/codecs/av1/libaom_av1_encoder.h +++ b/modules/video_coding/codecs/av1/libaom_av1_encoder.h @@ -14,15 +14,12 @@ #include "absl/base/attributes.h" #include "api/video_codecs/video_encoder.h" -#include "modules/video_coding/svc/scalable_video_controller.h" namespace webrtc { ABSL_CONST_INIT extern const bool kIsLibaomAv1EncoderSupported; std::unique_ptr CreateLibaomAv1Encoder(); -std::unique_ptr CreateLibaomAv1Encoder( - std::unique_ptr controller); } // namespace webrtc diff --git a/modules/video_coding/codecs/av1/libaom_av1_encoder_unittest.cc b/modules/video_coding/codecs/av1/libaom_av1_encoder_unittest.cc index 146397ffea..96057a0ce2 100644 --- a/modules/video_coding/codecs/av1/libaom_av1_encoder_unittest.cc +++ b/modules/video_coding/codecs/av1/libaom_av1_encoder_unittest.cc @@ -18,7 +18,6 @@ #include "api/video_codecs/video_encoder.h" #include "modules/video_coding/codecs/test/encoded_video_frame_producer.h" #include "modules/video_coding/include/video_error_codes.h" -#include "modules/video_coding/svc/scalability_structure_l1t2.h" #include "test/gmock.h" #include "test/gtest.h" @@ -26,6 +25,7 @@ namespace webrtc { namespace { using ::testing::ElementsAre; +using ::testing::Field; using ::testing::IsEmpty; using ::testing::SizeIs; @@ -61,9 +61,9 @@ TEST(LibaomAv1EncoderTest, InitAndRelease) { TEST(LibaomAv1EncoderTest, NoBitrateOnTopLayerRefecltedInActiveDecodeTargets) { // Configure encoder with 2 temporal layers. - std::unique_ptr encoder = - CreateLibaomAv1Encoder(std::make_unique()); + std::unique_ptr encoder = CreateLibaomAv1Encoder(); VideoCodec codec_settings = DefaultCodecSettings(); + codec_settings.SetScalabilityMode("L1T2"); ASSERT_EQ(encoder->InitEncode(&codec_settings, DefaultEncoderSettings()), WEBRTC_VIDEO_CODEC_OK); @@ -104,6 +104,23 @@ TEST(LibaomAv1EncoderTest, SetsEndOfPictureForLastFrameInTemporalUnit) { EXPECT_TRUE(encoded_frames[5].codec_specific_info.end_of_picture); } +TEST(LibaomAv1EncoderTest, CheckOddDimensionsWithSpatialLayers) { + std::unique_ptr encoder = CreateLibaomAv1Encoder(); + VideoCodec codec_settings = DefaultCodecSettings(); + // Configure encoder with 3 spatial layers. + codec_settings.SetScalabilityMode("L3T1"); + // Odd width and height values should not make encoder crash. 
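// For reference: with the 1/4, 1/2, 1/1 per-layer scaling that the
// PopulatesEncodedFrameSize test further below relies on, plain integer
// division would give roughly 155x101, 311x202 and 623x405 for the three
// spatial layers of a 623x405 input; this test only asserts that encoding
// such odd dimensions succeeds (2 input frames x 3 layers = 6 frames).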
+ codec_settings.width = 623; + codec_settings.height = 405; + ASSERT_EQ(encoder->InitEncode(&codec_settings, DefaultEncoderSettings()), + WEBRTC_VIDEO_CODEC_OK); + EncodedVideoFrameProducer evfp(*encoder); + evfp.SetResolution(RenderResolution{623, 405}); + std::vector encoded_frames = + evfp.SetNumInputFrames(2).Encode(); + ASSERT_THAT(encoded_frames, SizeIs(6)); +} + TEST(LibaomAv1EncoderTest, EncoderInfoProvidesFpsAllocation) { std::unique_ptr encoder = CreateLibaomAv1Encoder(); VideoCodec codec_settings = DefaultCodecSettings(); @@ -119,5 +136,36 @@ TEST(LibaomAv1EncoderTest, EncoderInfoProvidesFpsAllocation) { EXPECT_THAT(encoder_info.fps_allocation[3], IsEmpty()); } +TEST(LibaomAv1EncoderTest, PopulatesEncodedFrameSize) { + std::unique_ptr encoder = CreateLibaomAv1Encoder(); + VideoCodec codec_settings = DefaultCodecSettings(); + ASSERT_GT(codec_settings.width, 4); + // Configure encoder with 3 spatial layers. + codec_settings.SetScalabilityMode("L3T1"); + ASSERT_EQ(encoder->InitEncode(&codec_settings, DefaultEncoderSettings()), + WEBRTC_VIDEO_CODEC_OK); + + using Frame = EncodedVideoFrameProducer::EncodedFrame; + std::vector encoded_frames = + EncodedVideoFrameProducer(*encoder).SetNumInputFrames(1).Encode(); + EXPECT_THAT( + encoded_frames, + ElementsAre( + Field(&Frame::encoded_image, + AllOf(Field(&EncodedImage::_encodedWidth, + codec_settings.width / 4), + Field(&EncodedImage::_encodedHeight, + codec_settings.height / 4))), + Field(&Frame::encoded_image, + AllOf(Field(&EncodedImage::_encodedWidth, + codec_settings.width / 2), + Field(&EncodedImage::_encodedHeight, + codec_settings.height / 2))), + Field(&Frame::encoded_image, + AllOf(Field(&EncodedImage::_encodedWidth, codec_settings.width), + Field(&EncodedImage::_encodedHeight, + codec_settings.height))))); +} + } // namespace } // namespace webrtc diff --git a/modules/video_coding/codecs/av1/libaom_av1_unittest.cc b/modules/video_coding/codecs/av1/libaom_av1_unittest.cc index 78725ab626..e63e0f8c94 100644 --- a/modules/video_coding/codecs/av1/libaom_av1_unittest.cc +++ b/modules/video_coding/codecs/av1/libaom_av1_unittest.cc @@ -55,6 +55,7 @@ constexpr int kFramerate = 30; VideoCodec DefaultCodecSettings() { VideoCodec codec_settings; + codec_settings.SetScalabilityMode("NONE"); codec_settings.width = kWidth; codec_settings.height = kHeight; codec_settings.maxFramerate = kFramerate; @@ -250,10 +251,10 @@ TEST_P(LibaomAv1SvcTest, SetRatesMatchMeasuredBitrate) { kv.second.bps()); } - std::unique_ptr encoder = - CreateLibaomAv1Encoder(CreateScalabilityStructure(param.name)); + std::unique_ptr encoder = CreateLibaomAv1Encoder(); ASSERT_TRUE(encoder); VideoCodec codec_settings = DefaultCodecSettings(); + codec_settings.SetScalabilityMode(param.name); codec_settings.maxBitrate = allocation.get_sum_kbps(); codec_settings.maxFramerate = 30; ASSERT_EQ(encoder->InitEncode(&codec_settings, DefaultEncoderSettings()), @@ -314,6 +315,7 @@ INSTANTIATE_TEST_SUITE_P( SvcTestParam{"L3T1", /*num_frames_to_generate=*/3}, SvcTestParam{"L3T3", /*num_frames_to_generate=*/8}, SvcTestParam{"S2T1", /*num_frames_to_generate=*/3}, + SvcTestParam{"S3T3", /*num_frames_to_generate=*/8}, SvcTestParam{"L2T2", /*num_frames_to_generate=*/4}, SvcTestParam{"L2T2_KEY", /*num_frames_to_generate=*/4}, SvcTestParam{"L2T2_KEY_SHIFT", diff --git a/modules/video_coding/codecs/h264/h264.cc b/modules/video_coding/codecs/h264/h264.cc index be5b031e88..14e1691153 100644 --- a/modules/video_coding/codecs/h264/h264.cc +++ b/modules/video_coding/codecs/h264/h264.cc @@ -17,6 
+17,7 @@ #include "absl/types/optional.h" #include "api/video_codecs/sdp_video_format.h" #include "media/base/media_constants.h" +#include "rtc_base/trace_event.h" #if defined(WEBRTC_USE_H264) #include "modules/video_coding/codecs/h264/h264_decoder_impl.h" @@ -45,11 +46,11 @@ bool IsH264CodecSupported() { } // namespace -SdpVideoFormat CreateH264Format(H264::Profile profile, - H264::Level level, +SdpVideoFormat CreateH264Format(H264Profile profile, + H264Level level, const std::string& packetization_mode) { const absl::optional profile_string = - H264::ProfileLevelIdToString(H264::ProfileLevelId(profile, level)); + H264ProfileLevelIdToString(H264ProfileLevelId(profile, level)); RTC_CHECK(profile_string); return SdpVideoFormat( cricket::kH264CodecName, @@ -65,6 +66,7 @@ void DisableRtcUseH264() { } std::vector SupportedH264Codecs() { + TRACE_EVENT0("webrtc", __func__); if (!IsH264CodecSupported()) return std::vector(); // We only support encoding Constrained Baseline Profile (CBP), but the @@ -76,12 +78,14 @@ std::vector SupportedH264Codecs() { // // We support both packetization modes 0 (mandatory) and 1 (optional, // preferred). - return { - CreateH264Format(H264::kProfileBaseline, H264::kLevel3_1, "1"), - CreateH264Format(H264::kProfileBaseline, H264::kLevel3_1, "0"), - CreateH264Format(H264::kProfileConstrainedBaseline, H264::kLevel3_1, "1"), - CreateH264Format(H264::kProfileConstrainedBaseline, H264::kLevel3_1, - "0")}; + return {CreateH264Format(H264Profile::kProfileBaseline, H264Level::kLevel3_1, + "1"), + CreateH264Format(H264Profile::kProfileBaseline, H264Level::kLevel3_1, + "0"), + CreateH264Format(H264Profile::kProfileConstrainedBaseline, + H264Level::kLevel3_1, "1"), + CreateH264Format(H264Profile::kProfileConstrainedBaseline, + H264Level::kLevel3_1, "0")}; } std::unique_ptr H264Encoder::Create( diff --git a/modules/video_coding/codecs/h264/h264_decoder_impl.cc b/modules/video_coding/codecs/h264/h264_decoder_impl.cc index 8c7a39b609..83f9a77614 100644 --- a/modules/video_coding/codecs/h264/h264_decoder_impl.cc +++ b/modules/video_coding/codecs/h264/h264_decoder_impl.cc @@ -32,7 +32,6 @@ extern "C" { #include "common_video/include/video_frame_buffer.h" #include "modules/video_coding/codecs/h264/h264_color_space.h" #include "rtc_base/checks.h" -#include "rtc_base/keep_ref_until_done.h" #include "rtc_base/logging.h" #include "system_wrappers/include/field_trial.h" #include "system_wrappers/include/metrics.h" @@ -55,6 +54,16 @@ enum H264DecoderImplEvent { kH264DecoderEventMax = 16, }; +struct ScopedPtrAVFreePacket { + void operator()(AVPacket* packet) { av_packet_free(&packet); } +}; +typedef std::unique_ptr ScopedAVPacket; + +ScopedAVPacket MakeScopedAVPacket() { + ScopedAVPacket packet(av_packet_alloc()); + return packet; +} + } // namespace int H264DecoderImpl::AVGetBuffer2(AVCodecContext* context, @@ -203,7 +212,7 @@ int32_t H264DecoderImpl::InitDecode(const VideoCodec* codec_settings, // a pointer |this|. av_context_->opaque = this; - AVCodec* codec = avcodec_find_decoder(av_context_->codec_id); + const AVCodec* codec = avcodec_find_decoder(av_context_->codec_id); if (!codec) { // This is an indication that FFmpeg has not been initialized or it has not // been compiled/initialized with the correct set of codecs. 
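
The ScopedAVPacket helper introduced above is just a std::unique_ptr with a custom deleter, so a packet obtained from av_packet_alloc() is released through av_packet_free() on every return path. A stripped-down sketch of the same idiom against the FFmpeg API (the surrounding function is illustrative, not code from the patch):

extern "C" {
#include <libavcodec/avcodec.h>
}

#include <cstdint>
#include <memory>

struct AVPacketDeleter {
  void operator()(AVPacket* packet) const { av_packet_free(&packet); }
};
using ScopedAVPacket = std::unique_ptr<AVPacket, AVPacketDeleter>;

// Returns true if the buffer was handed to the decoder. The packet is freed
// automatically on every path, which is the point of the RAII wrapper.
bool SendToDecoder(AVCodecContext* context, const uint8_t* data, int size) {
  ScopedAVPacket packet(av_packet_alloc());
  if (!packet)
    return false;
  packet->data = const_cast<uint8_t*>(data);  // Not modified by the decoder.
  packet->size = size;
  return avcodec_send_packet(context, packet.get()) == 0;
}
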
@@ -262,21 +271,25 @@ int32_t H264DecoderImpl::Decode(const EncodedImage& input_image, return WEBRTC_VIDEO_CODEC_ERR_PARAMETER; } - AVPacket packet; - av_init_packet(&packet); + ScopedAVPacket packet = MakeScopedAVPacket(); + if (!packet) { + ReportError(); + return WEBRTC_VIDEO_CODEC_ERROR; + } // packet.data has a non-const type, but isn't modified by // avcodec_send_packet. - packet.data = const_cast(input_image.data()); + packet->data = const_cast(input_image.data()); if (input_image.size() > static_cast(std::numeric_limits::max())) { ReportError(); return WEBRTC_VIDEO_CODEC_ERROR; } - packet.size = static_cast(input_image.size()); + packet->size = static_cast(input_image.size()); int64_t frame_timestamp_us = input_image.ntp_time_ms_ * 1000; // ms -> μs av_context_->reordered_opaque = frame_timestamp_us; - int result = avcodec_send_packet(av_context_.get(), &packet); + int result = avcodec_send_packet(av_context_.get(), packet.get()); + if (result < 0) { RTC_LOG(LS_ERROR) << "avcodec_send_packet error: " << result; ReportError(); @@ -302,8 +315,9 @@ int32_t H264DecoderImpl::Decode(const EncodedImage& input_image, VideoFrame* input_frame = static_cast(av_buffer_get_opaque(av_frame_->buf[0])); RTC_DCHECK(input_frame); - const webrtc::I420BufferInterface* i420_buffer = - input_frame->video_frame_buffer()->GetI420(); + rtc::scoped_refptr frame_buffer = + input_frame->video_frame_buffer(); + const webrtc::I420BufferInterface* i420_buffer = frame_buffer->GetI420(); // When needed, FFmpeg applies cropping by moving plane pointers and adjusting // frame width/height. Ensure that cropped buffers lie within the allocated @@ -330,7 +344,9 @@ int32_t H264DecoderImpl::Decode(const EncodedImage& input_image, av_frame_->width, av_frame_->height, av_frame_->data[kYPlaneIndex], av_frame_->linesize[kYPlaneIndex], av_frame_->data[kUPlaneIndex], av_frame_->linesize[kUPlaneIndex], av_frame_->data[kVPlaneIndex], - av_frame_->linesize[kVPlaneIndex], rtc::KeepRefUntilDone(i420_buffer)); + av_frame_->linesize[kVPlaneIndex], + // To keep reference alive. + [frame_buffer] {}); if (preferred_output_format_ == VideoFrameBuffer::Type::kNV12) { const I420BufferInterface* cropped_i420 = cropped_buffer->GetI420(); diff --git a/modules/video_coding/codecs/h264/h264_encoder_impl.cc b/modules/video_coding/codecs/h264/h264_encoder_impl.cc index 3f4f660ffa..af0393976e 100644 --- a/modules/video_coding/codecs/h264/h264_encoder_impl.cc +++ b/modules/video_coding/codecs/h264/h264_encoder_impl.cc @@ -16,6 +16,7 @@ #include "modules/video_coding/codecs/h264/h264_encoder_impl.h" +#include #include #include @@ -241,7 +242,8 @@ int32_t H264EncoderImpl::InitEncode(const VideoCodec* inst, configurations_[i].frame_dropping_on = codec_.H264()->frameDroppingOn; configurations_[i].key_frame_interval = codec_.H264()->keyFrameInterval; configurations_[i].num_temporal_layers = - codec_.simulcastStream[idx].numberOfTemporalLayers; + std::max(codec_.H264()->numberOfTemporalLayers, + codec_.simulcastStream[idx].numberOfTemporalLayers); // Create downscaled image buffers. if (i > 0) { @@ -373,8 +375,19 @@ int32_t H264EncoderImpl::Encode( return WEBRTC_VIDEO_CODEC_UNINITIALIZED; } - rtc::scoped_refptr frame_buffer = + rtc::scoped_refptr frame_buffer = input_frame.video_frame_buffer()->ToI420(); + // The buffer should now be a mapped I420 or I420A format, but some buffer + // implementations incorrectly return the wrong buffer format, such as + // kNative. As a workaround to this, we perform ToI420() a second time. 
+ // TODO(https://crbug.com/webrtc/12602): When Android buffers have a correct + // ToI420() implementaion, remove his workaround. + if (frame_buffer->type() != VideoFrameBuffer::Type::kI420 && + frame_buffer->type() != VideoFrameBuffer::Type::kI420A) { + frame_buffer = frame_buffer->ToI420(); + RTC_CHECK(frame_buffer->type() == VideoFrameBuffer::Type::kI420 || + frame_buffer->type() == VideoFrameBuffer::Type::kI420A); + } bool send_key_frame = false; for (size_t i = 0; i < configurations_.size(); ++i) { @@ -434,7 +447,7 @@ int32_t H264EncoderImpl::Encode( pictures_[i].iStride[0], pictures_[i].pData[1], pictures_[i].iStride[1], pictures_[i].pData[2], pictures_[i].iStride[2], configurations_[i].width, - configurations_[i].height, libyuv::kFilterBilinear); + configurations_[i].height, libyuv::kFilterBox); } if (!configurations_[i].sending) { @@ -567,7 +580,13 @@ SEncParamExt H264EncoderImpl::CreateEncoderParams(size_t i) const { encoder_params.iMaxBitrate; encoder_params.iTemporalLayerNum = configurations_[i].num_temporal_layers; if (encoder_params.iTemporalLayerNum > 1) { - encoder_params.iNumRefFrame = 1; + // iNumRefFrame specifies total number of reference buffers to allocate. + // For N temporal layers we need at least (N - 1) buffers to store last + // encoded frames of all reference temporal layers. + // Note that there is no API in OpenH264 encoder to specify exact set of + // references to be used to prediction of a given frame. Encoder can + // theoretically use all available reference buffers. + encoder_params.iNumRefFrame = encoder_params.iTemporalLayerNum - 1; } RTC_LOG(INFO) << "OpenH264 version is " << OPENH264_MAJOR << "." << OPENH264_MINOR; diff --git a/modules/video_coding/codecs/h264/include/h264.h b/modules/video_coding/codecs/h264/include/h264.h index 70ca817988..1f8f796064 100644 --- a/modules/video_coding/codecs/h264/include/h264.h +++ b/modules/video_coding/codecs/h264/include/h264.h @@ -27,8 +27,8 @@ struct SdpVideoFormat; // Creates an H264 SdpVideoFormat entry with specified paramters. RTC_EXPORT SdpVideoFormat -CreateH264Format(H264::Profile profile, - H264::Level level, +CreateH264Format(H264Profile profile, + H264Level level, const std::string& packetization_mode); // Set to disable the H.264 encoder/decoder implementations that are provided if diff --git a/modules/video_coding/codecs/multiplex/augmented_video_frame_buffer.cc b/modules/video_coding/codecs/multiplex/augmented_video_frame_buffer.cc index b48996cbcf..8740884f5b 100644 --- a/modules/video_coding/codecs/multiplex/augmented_video_frame_buffer.cc +++ b/modules/video_coding/codecs/multiplex/augmented_video_frame_buffer.cc @@ -54,4 +54,12 @@ int AugmentedVideoFrameBuffer::height() const { rtc::scoped_refptr AugmentedVideoFrameBuffer::ToI420() { return video_frame_buffer_->ToI420(); } + +const I420BufferInterface* AugmentedVideoFrameBuffer::GetI420() const { + // TODO(https://crbug.com/webrtc/12021): When AugmentedVideoFrameBuffer is + // updated to implement the buffer interfaces of relevant + // VideoFrameBuffer::Types, stop overriding GetI420() as a workaround to + // AugmentedVideoFrameBuffer not being the type that is returned by type(). 
+ return video_frame_buffer_->GetI420(); +} } // namespace webrtc diff --git a/modules/video_coding/codecs/multiplex/include/augmented_video_frame_buffer.h b/modules/video_coding/codecs/multiplex/include/augmented_video_frame_buffer.h index c45ab3b2a4..d711cd07da 100644 --- a/modules/video_coding/codecs/multiplex/include/augmented_video_frame_buffer.h +++ b/modules/video_coding/codecs/multiplex/include/augmented_video_frame_buffer.h @@ -45,6 +45,12 @@ class AugmentedVideoFrameBuffer : public VideoFrameBuffer { // Get the I140 Buffer from the underlying frame buffer rtc::scoped_refptr ToI420() final; + // Returns GetI420() of the underlying VideoFrameBuffer. + // TODO(hbos): AugmentedVideoFrameBuffer should not return a type (such as + // kI420) without also implementing that type's interface (i.e. + // I420BufferInterface). Either implement all possible Type's interfaces or + // return kNative. + const I420BufferInterface* GetI420() const final; private: uint16_t augmenting_data_size_; diff --git a/modules/video_coding/codecs/multiplex/multiplex_decoder_adapter.cc b/modules/video_coding/codecs/multiplex/multiplex_decoder_adapter.cc index 426a9f80d1..2332fcddfb 100644 --- a/modules/video_coding/codecs/multiplex/multiplex_decoder_adapter.cc +++ b/modules/video_coding/codecs/multiplex/multiplex_decoder_adapter.cc @@ -17,7 +17,6 @@ #include "common_video/libyuv/include/webrtc_libyuv.h" #include "modules/video_coding/codecs/multiplex/include/augmented_video_frame_buffer.h" #include "modules/video_coding/codecs/multiplex/multiplex_encoded_image_packer.h" -#include "rtc_base/keep_ref_until_done.h" #include "rtc_base/logging.h" namespace webrtc { @@ -249,9 +248,8 @@ void MultiplexDecoderAdapter::MergeAlphaImages( [yuv_buffer, alpha_buffer] {}); } if (supports_augmenting_data_) { - merged_buffer = rtc::scoped_refptr( - new rtc::RefCountedObject( - merged_buffer, std::move(augmenting_data), augmenting_data_length)); + merged_buffer = rtc::make_ref_counted( + merged_buffer, std::move(augmenting_data), augmenting_data_length); } VideoFrame merged_image = VideoFrame::Builder() diff --git a/modules/video_coding/codecs/multiplex/multiplex_encoder_adapter.cc b/modules/video_coding/codecs/multiplex/multiplex_encoder_adapter.cc index 0620a788e3..db525b8f98 100644 --- a/modules/video_coding/codecs/multiplex/multiplex_encoder_adapter.cc +++ b/modules/video_coding/codecs/multiplex/multiplex_encoder_adapter.cc @@ -18,7 +18,6 @@ #include "common_video/libyuv/include/webrtc_libyuv.h" #include "media/base/video_common.h" #include "modules/video_coding/codecs/multiplex/include/augmented_video_frame_buffer.h" -#include "rtc_base/keep_ref_until_done.h" #include "rtc_base/logging.h" namespace webrtc { @@ -158,20 +157,38 @@ int MultiplexEncoderAdapter::Encode( return WEBRTC_VIDEO_CODEC_UNINITIALIZED; } + // The input image is forwarded as-is, unless it is a native buffer and + // |supports_augmented_data_| is true in which case we need to map it in order + // to access the underlying AugmentedVideoFrameBuffer. + VideoFrame forwarded_image = input_image; + if (supports_augmented_data_ && + forwarded_image.video_frame_buffer()->type() == + VideoFrameBuffer::Type::kNative) { + auto info = GetEncoderInfo(); + rtc::scoped_refptr mapped_buffer = + forwarded_image.video_frame_buffer()->GetMappedFrameBuffer( + info.preferred_pixel_formats); + if (!mapped_buffer) { + // Unable to map the buffer. 
+ return WEBRTC_VIDEO_CODEC_ERROR; + } + forwarded_image.set_video_frame_buffer(std::move(mapped_buffer)); + } + std::vector adjusted_frame_types; if (key_frame_interval_ > 0 && picture_index_ % key_frame_interval_ == 0) { adjusted_frame_types.push_back(VideoFrameType::kVideoFrameKey); } else { adjusted_frame_types.push_back(VideoFrameType::kVideoFrameDelta); } - const bool has_alpha = input_image.video_frame_buffer()->type() == + const bool has_alpha = forwarded_image.video_frame_buffer()->type() == VideoFrameBuffer::Type::kI420A; std::unique_ptr augmenting_data = nullptr; uint16_t augmenting_data_length = 0; AugmentedVideoFrameBuffer* augmented_video_frame_buffer = nullptr; if (supports_augmented_data_) { augmented_video_frame_buffer = static_cast( - input_image.video_frame_buffer().get()); + forwarded_image.video_frame_buffer().get()); augmenting_data_length = augmented_video_frame_buffer->GetAugmentingDataSize(); augmenting_data = @@ -186,7 +203,7 @@ int MultiplexEncoderAdapter::Encode( MutexLock lock(&mutex_); stashed_images_.emplace( std::piecewise_construct, - std::forward_as_tuple(input_image.timestamp()), + std::forward_as_tuple(forwarded_image.timestamp()), std::forward_as_tuple( picture_index_, has_alpha ? kAlphaCodecStreams : 1, std::move(augmenting_data), augmenting_data_length)); @@ -195,7 +212,8 @@ int MultiplexEncoderAdapter::Encode( ++picture_index_; // Encode YUV - int rv = encoders_[kYUVStream]->Encode(input_image, &adjusted_frame_types); + int rv = + encoders_[kYUVStream]->Encode(forwarded_image, &adjusted_frame_types); // If we do not receive an alpha frame, we send a single frame for this // |picture_index_|. The receiver will receive |frame_count| as 1 which @@ -204,24 +222,27 @@ int MultiplexEncoderAdapter::Encode( return rv; // Encode AXX - const I420ABufferInterface* yuva_buffer = + rtc::scoped_refptr frame_buffer = supports_augmented_data_ - ? augmented_video_frame_buffer->GetVideoFrameBuffer()->GetI420A() - : input_image.video_frame_buffer()->GetI420A(); + ? augmented_video_frame_buffer->GetVideoFrameBuffer() + : forwarded_image.video_frame_buffer(); + const I420ABufferInterface* yuva_buffer = frame_buffer->GetI420A(); rtc::scoped_refptr alpha_buffer = - WrapI420Buffer(input_image.width(), input_image.height(), + WrapI420Buffer(forwarded_image.width(), forwarded_image.height(), yuva_buffer->DataA(), yuva_buffer->StrideA(), multiplex_dummy_planes_.data(), yuva_buffer->StrideU(), multiplex_dummy_planes_.data(), yuva_buffer->StrideV(), - rtc::KeepRefUntilDone(input_image.video_frame_buffer())); - VideoFrame alpha_image = VideoFrame::Builder() - .set_video_frame_buffer(alpha_buffer) - .set_timestamp_rtp(input_image.timestamp()) - .set_timestamp_ms(input_image.render_time_ms()) - .set_rotation(input_image.rotation()) - .set_id(input_image.id()) - .set_packet_infos(input_image.packet_infos()) - .build(); + // To keep reference alive. 
+ [frame_buffer] {}); + VideoFrame alpha_image = + VideoFrame::Builder() + .set_video_frame_buffer(alpha_buffer) + .set_timestamp_rtp(forwarded_image.timestamp()) + .set_timestamp_ms(forwarded_image.render_time_ms()) + .set_rotation(forwarded_image.rotation()) + .set_id(forwarded_image.id()) + .set_packet_infos(forwarded_image.packet_infos()) + .build(); rv = encoders_[kAXXStream]->Encode(alpha_image, &adjusted_frame_types); return rv; } @@ -297,9 +318,6 @@ EncodedImageCallback::Result MultiplexEncoderAdapter::OnEncodedImage( PayloadStringToCodecType(associated_format_.name); image_component.encoded_image = encodedImage; - // If we don't already own the buffer, make a copy. - image_component.encoded_image.Retain(); - MutexLock lock(&mutex_); const auto& stashed_image_itr = stashed_images_.find(encodedImage.Timestamp()); diff --git a/modules/video_coding/codecs/multiplex/test/multiplex_adapter_unittest.cc b/modules/video_coding/codecs/multiplex/test/multiplex_adapter_unittest.cc index 770d8b596c..7ecb24a87c 100644 --- a/modules/video_coding/codecs/multiplex/test/multiplex_adapter_unittest.cc +++ b/modules/video_coding/codecs/multiplex/test/multiplex_adapter_unittest.cc @@ -38,7 +38,6 @@ #include "modules/video_coding/codecs/vp9/include/vp9.h" #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/include/video_error_codes.h" -#include "rtc_base/keep_ref_until_done.h" #include "rtc_base/ref_counted_object.h" #include "test/gmock.h" #include "test/gtest.h" @@ -91,9 +90,9 @@ class TestMultiplexAdapter : public VideoCodecUnitTest, for (int i = 0; i < 16; i++) { data[i] = i; } - rtc::scoped_refptr augmented_video_frame_buffer = - new rtc::RefCountedObject( - video_buffer, std::move(data), 16); + auto augmented_video_frame_buffer = + rtc::make_ref_counted(video_buffer, + std::move(data), 16); return std::make_unique( VideoFrame::Builder() .set_video_frame_buffer(augmented_video_frame_buffer) @@ -112,7 +111,9 @@ class TestMultiplexAdapter : public VideoCodecUnitTest, yuv_buffer->width(), yuv_buffer->height(), yuv_buffer->DataY(), yuv_buffer->StrideY(), yuv_buffer->DataU(), yuv_buffer->StrideU(), yuv_buffer->DataV(), yuv_buffer->StrideV(), yuv_buffer->DataY(), - yuv_buffer->StrideY(), rtc::KeepRefUntilDone(yuv_buffer)); + yuv_buffer->StrideY(), + // To keep reference alive. 
+ [yuv_buffer] {}); return std::make_unique(VideoFrame::Builder() .set_video_frame_buffer(yuva_buffer) .set_timestamp_rtp(123) @@ -168,8 +169,7 @@ class TestMultiplexAdapter : public VideoCodecUnitTest, rtc::scoped_refptr axx_buffer = WrapI420Buffer( yuva_buffer->width(), yuva_buffer->height(), yuva_buffer->DataA(), yuva_buffer->StrideA(), yuva_buffer->DataU(), yuva_buffer->StrideU(), - yuva_buffer->DataV(), yuva_buffer->StrideV(), - rtc::KeepRefUntilDone(video_frame_buffer)); + yuva_buffer->DataV(), yuva_buffer->StrideV(), [video_frame_buffer] {}); return std::make_unique(VideoFrame::Builder() .set_video_frame_buffer(axx_buffer) .set_timestamp_rtp(123) diff --git a/modules/video_coding/codecs/test/videocodec_test_fixture_impl.cc b/modules/video_coding/codecs/test/videocodec_test_fixture_impl.cc index 7bd1ba35e0..dee5b1b939 100644 --- a/modules/video_coding/codecs/test/videocodec_test_fixture_impl.cc +++ b/modules/video_coding/codecs/test/videocodec_test_fixture_impl.cc @@ -20,16 +20,17 @@ #include #include +#include "absl/strings/str_replace.h" #include "absl/types/optional.h" #include "api/array_view.h" #include "api/transport/field_trial_based_config.h" #include "api/video/video_bitrate_allocation.h" +#include "api/video_codecs/h264_profile_level_id.h" #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_decoder.h" #include "api/video_codecs/video_encoder_config.h" #include "common_video/h264/h264_common.h" -#include "media/base/h264_profile_level_id.h" #include "media/base/media_constants.h" #include "media/engine/internal_decoder_factory.h" #include "media/engine/internal_encoder_factory.h" @@ -127,6 +128,8 @@ std::string CodecSpecificToString(const VideoCodec& codec) { case kVideoCodecH264: ss << "frame_dropping: " << codec.H264().frameDroppingOn; ss << "\nkey_frame_interval: " << codec.H264().keyFrameInterval; + ss << "\nnum_temporal_layers: " + << static_cast(codec.H264().numberOfTemporalLayers); break; default: break; @@ -213,6 +216,8 @@ void VideoCodecTestFixtureImpl::Config::SetCodecSettings( case kVideoCodecH264: codec_settings.H264()->frameDroppingOn = frame_dropper_on; codec_settings.H264()->keyFrameInterval = kBaseKeyFrameInterval; + codec_settings.H264()->numberOfTemporalLayers = + static_cast(num_temporal_layers); break; default: break; @@ -235,6 +240,8 @@ size_t VideoCodecTestFixtureImpl::Config::NumberOfTemporalLayers() const { return codec_settings.VP8().numberOfTemporalLayers; } else if (codec_settings.codecType == kVideoCodecVP9) { return codec_settings.VP9().numberOfTemporalLayers; + } else if (codec_settings.codecType == kVideoCodecH264) { + return codec_settings.H264().numberOfTemporalLayers; } else { return 1; } @@ -301,11 +308,11 @@ std::string VideoCodecTestFixtureImpl::Config::CodecName() const { name = CodecTypeToPayloadString(codec_settings.codecType); } if (codec_settings.codecType == kVideoCodecH264) { - if (h264_codec_settings.profile == H264::kProfileConstrainedHigh) { + if (h264_codec_settings.profile == H264Profile::kProfileConstrainedHigh) { return name + "-CHP"; } else { RTC_DCHECK_EQ(h264_codec_settings.profile, - H264::kProfileConstrainedBaseline); + H264Profile::kProfileConstrainedBaseline); return name + "-CBP"; } } @@ -408,8 +415,14 @@ void VideoCodecTestFixtureImpl::RunTest( // codecs on a task queue. 
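A recurring change in the hunks above replaces rtc::KeepRefUntilDone(buffer) with a no-op lambda that captures the owning scoped_refptr by value; it is the capture, not the callback body, that keeps the wrapped pixel data alive until the destruction callback itself is released. A minimal sketch of the same idiom using only standard types (PlaneView, WrapPlane and the shared_ptr owner are illustrative stand-ins, not WebRTC APIs):

#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

// Non-owning view over pixel data. The callback is never run for its side
// effects here; its captures are what keep the pixels alive.
struct PlaneView {
  const uint8_t* data = nullptr;
  std::function<void()> keep_alive;
};

PlaneView WrapPlane(const std::shared_ptr<std::vector<uint8_t>>& owner) {
  // Capturing `owner` by value stores a reference inside the view, mirroring
  // the [frame_buffer] {} callbacks passed to WrapI420Buffer above.
  return {owner->data(), [owner] {}};
}

int main() {
  PlaneView view;
  {
    auto pixels = std::make_shared<std::vector<uint8_t>>(16, 0x80);
    view = WrapPlane(pixels);
  }  // `pixels` goes out of scope, but the lambda capture still owns the data.
  return view.data[0] == 0x80 ? 0 : 1;
}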
TaskQueueForTest task_queue("VidProc TQ"); - SetUpAndInitObjects(&task_queue, rate_profiles[0].target_kbps, - rate_profiles[0].input_fps); + bool is_setup_succeeded = SetUpAndInitObjects( + &task_queue, rate_profiles[0].target_kbps, rate_profiles[0].input_fps); + EXPECT_TRUE(is_setup_succeeded); + if (!is_setup_succeeded) { + ReleaseAndCloseObjects(&task_queue); + return; + } + PrintSettings(&task_queue); ProcessAllFrames(&task_queue, rate_profiles); ReleaseAndCloseObjects(&task_queue); @@ -597,7 +610,7 @@ void VideoCodecTestFixtureImpl::VerifyVideoStatistic( } } -void VideoCodecTestFixtureImpl::CreateEncoderAndDecoder() { +bool VideoCodecTestFixtureImpl::CreateEncoderAndDecoder() { SdpVideoFormat::Parameters params; if (config_.codec_settings.codecType == kVideoCodecH264) { const char* packetization_mode = @@ -606,8 +619,8 @@ void VideoCodecTestFixtureImpl::CreateEncoderAndDecoder() { ? "1" : "0"; params = {{cricket::kH264FmtpProfileLevelId, - *H264::ProfileLevelIdToString(H264::ProfileLevelId( - config_.h264_codec_settings.profile, H264::kLevel3_1))}, + *H264ProfileLevelIdToString(H264ProfileLevelId( + config_.h264_codec_settings.profile, H264Level::kLevel3_1))}, {cricket::kH264FmtpPacketizationMode, packetization_mode}}; } else { params = {}; @@ -616,6 +629,9 @@ void VideoCodecTestFixtureImpl::CreateEncoderAndDecoder() { encoder_ = encoder_factory_->CreateVideoEncoder(format); EXPECT_TRUE(encoder_) << "Encoder not successfully created."; + if (encoder_ == nullptr) { + return false; + } const size_t num_simulcast_or_spatial_layers = std::max( config_.NumberOfSimulcastStreams(), config_.NumberOfSpatialLayers()); @@ -626,7 +642,12 @@ void VideoCodecTestFixtureImpl::CreateEncoderAndDecoder() { for (const auto& decoder : decoders_) { EXPECT_TRUE(decoder) << "Decoder not successfully created."; + if (decoder == nullptr) { + return false; + } } + + return true; } void VideoCodecTestFixtureImpl::DestroyEncoderAndDecoder() { @@ -638,7 +659,7 @@ VideoCodecTestStats& VideoCodecTestFixtureImpl::GetStats() { return stats_; } -void VideoCodecTestFixtureImpl::SetUpAndInitObjects( +bool VideoCodecTestFixtureImpl::SetUpAndInitObjects( TaskQueueForTest* task_queue, size_t initial_bitrate_kbps, double initial_framerate_fps) { @@ -661,17 +682,45 @@ void VideoCodecTestFixtureImpl::SetUpAndInitObjects( RTC_DCHECK(encoded_frame_writers_.empty()); RTC_DCHECK(decoded_frame_writers_.empty()); + stats_.Clear(); + + cpu_process_time_.reset(new CpuProcessTime(config_)); + + bool is_codec_created = false; + task_queue->SendTask( + [this, &is_codec_created]() { + is_codec_created = CreateEncoderAndDecoder(); + }, + RTC_FROM_HERE); + + if (!is_codec_created) { + return false; + } + + task_queue->SendTask( + [this]() { + processor_ = std::make_unique( + encoder_.get(), &decoders_, source_frame_reader_.get(), config_, + &stats_, &encoded_frame_writers_, + decoded_frame_writers_.empty() ? 
nullptr : &decoded_frame_writers_); + }, + RTC_FROM_HERE); + if (config_.visualization_params.save_encoded_ivf || config_.visualization_params.save_decoded_y4m) { + std::string encoder_name = GetCodecName(task_queue, /*is_encoder=*/true); + encoder_name = absl::StrReplaceAll(encoder_name, {{":", ""}, {" ", "-"}}); + const size_t num_simulcast_or_spatial_layers = std::max( config_.NumberOfSimulcastStreams(), config_.NumberOfSpatialLayers()); const size_t num_temporal_layers = config_.NumberOfTemporalLayers(); for (size_t simulcast_svc_idx = 0; simulcast_svc_idx < num_simulcast_or_spatial_layers; ++simulcast_svc_idx) { - const std::string output_filename_base = JoinFilename( - config_.output_path, FilenameWithParams(config_) + "_sl" + - std::to_string(simulcast_svc_idx)); + const std::string output_filename_base = + JoinFilename(config_.output_path, + FilenameWithParams(config_) + "_" + encoder_name + + "_sl" + std::to_string(simulcast_svc_idx)); if (config_.visualization_params.save_encoded_ivf) { for (size_t temporal_idx = 0; temporal_idx < num_temporal_layers; @@ -699,19 +748,7 @@ void VideoCodecTestFixtureImpl::SetUpAndInitObjects( } } - stats_.Clear(); - - cpu_process_time_.reset(new CpuProcessTime(config_)); - - task_queue->SendTask( - [this]() { - CreateEncoderAndDecoder(); - processor_ = std::make_unique( - encoder_.get(), &decoders_, source_frame_reader_.get(), config_, - &stats_, &encoded_frame_writers_, - decoded_frame_writers_.empty() ? nullptr : &decoded_frame_writers_); - }, - RTC_FROM_HERE); + return true; } void VideoCodecTestFixtureImpl::ReleaseAndCloseObjects( @@ -737,22 +774,32 @@ void VideoCodecTestFixtureImpl::ReleaseAndCloseObjects( decoded_frame_writers_.clear(); } +std::string VideoCodecTestFixtureImpl::GetCodecName( + TaskQueueForTest* task_queue, + bool is_encoder) const { + std::string codec_name; + task_queue->SendTask( + [this, is_encoder, &codec_name] { + if (is_encoder) { + codec_name = encoder_->GetEncoderInfo().implementation_name; + } else { + codec_name = decoders_.at(0)->ImplementationName(); + } + }, + RTC_FROM_HERE); + return codec_name; +} + void VideoCodecTestFixtureImpl::PrintSettings( TaskQueueForTest* task_queue) const { RTC_LOG(LS_INFO) << "==> Config"; RTC_LOG(LS_INFO) << config_.ToString(); RTC_LOG(LS_INFO) << "==> Codec names"; - std::string encoder_name; - std::string decoder_name; - task_queue->SendTask( - [this, &encoder_name, &decoder_name] { - encoder_name = encoder_->GetEncoderInfo().implementation_name; - decoder_name = decoders_.at(0)->GetDecoderInfo().implementation_name; - }, - RTC_FROM_HERE); - RTC_LOG(LS_INFO) << "enc_impl_name: " << encoder_name; - RTC_LOG(LS_INFO) << "dec_impl_name: " << decoder_name; + RTC_LOG(LS_INFO) << "enc_impl_name: " + << GetCodecName(task_queue, /*is_encoder=*/true); + RTC_LOG(LS_INFO) << "dec_impl_name: " + << GetCodecName(task_queue, /*is_encoder=*/false); } } // namespace test diff --git a/modules/video_coding/codecs/test/videocodec_test_fixture_impl.h b/modules/video_coding/codecs/test/videocodec_test_fixture_impl.h index 3bbe50ecc3..005b7c0a8e 100644 --- a/modules/video_coding/codecs/test/videocodec_test_fixture_impl.h +++ b/modules/video_coding/codecs/test/videocodec_test_fixture_impl.h @@ -59,9 +59,9 @@ class VideoCodecTestFixtureImpl : public VideoCodecTestFixture { private: class CpuProcessTime; - void CreateEncoderAndDecoder(); + bool CreateEncoderAndDecoder(); void DestroyEncoderAndDecoder(); - void SetUpAndInitObjects(TaskQueueForTest* task_queue, + bool SetUpAndInitObjects(TaskQueueForTest* 
task_queue, size_t initial_bitrate_kbps, double initial_framerate_fps); void ReleaseAndCloseObjects(TaskQueueForTest* task_queue); @@ -82,6 +82,7 @@ class VideoCodecTestFixtureImpl : public VideoCodecTestFixture { size_t target_bitrate_kbps, double input_framerate_fps); + std::string GetCodecName(TaskQueueForTest* task_queue, bool is_encoder) const; void PrintSettings(TaskQueueForTest* task_queue) const; // Codecs. diff --git a/modules/video_coding/codecs/test/videocodec_test_libaom.cc b/modules/video_coding/codecs/test/videocodec_test_libaom.cc index 18852e0646..c3263e7134 100644 --- a/modules/video_coding/codecs/test/videocodec_test_libaom.cc +++ b/modules/video_coding/codecs/test/videocodec_test_libaom.cc @@ -42,6 +42,7 @@ TEST(VideoCodecTestLibaom, HighBitrateAV1) { auto config = CreateConfig("foreman_cif"); config.SetCodecSettings(cricket::kAv1CodecName, 1, 1, 1, false, true, true, kCifWidth, kCifHeight); + config.codec_settings.SetScalabilityMode("NONE"); config.num_frames = kNumFramesLong; auto fixture = CreateVideoCodecTestFixture(config); @@ -59,6 +60,7 @@ TEST(VideoCodecTestLibaom, VeryLowBitrateAV1) { auto config = CreateConfig("foreman_cif"); config.SetCodecSettings(cricket::kAv1CodecName, 1, 1, 1, false, true, true, kCifWidth, kCifHeight); + config.codec_settings.SetScalabilityMode("NONE"); auto fixture = CreateVideoCodecTestFixture(config); std::vector rate_profiles = {{50, 30, 0}}; @@ -66,7 +68,7 @@ TEST(VideoCodecTestLibaom, VeryLowBitrateAV1) { std::vector rc_thresholds = { {15, 8, 75, 2, 2, 2, 2, 1}}; - std::vector quality_thresholds = {{28, 25, 0.70, 0.62}}; + std::vector quality_thresholds = {{28, 25, 0.70, 0.60}}; fixture->RunTest(rate_profiles, &rc_thresholds, &quality_thresholds, nullptr); } @@ -78,6 +80,7 @@ TEST(VideoCodecTestLibaom, HdAV1) { auto config = CreateConfig("ConferenceMotion_1280_720_50"); config.SetCodecSettings(cricket::kAv1CodecName, 1, 1, 1, false, true, true, kHdWidth, kHdHeight); + config.codec_settings.SetScalabilityMode("NONE"); config.num_frames = kNumFramesLong; auto fixture = CreateVideoCodecTestFixture(config); @@ -86,7 +89,7 @@ TEST(VideoCodecTestLibaom, HdAV1) { std::vector rc_thresholds = { {13, 3, 0, 1, 0.3, 0.1, 0, 1}}; - std::vector quality_thresholds = {{36, 32, 0.93, 0.87}}; + std::vector quality_thresholds = {{36, 31.7, 0.93, 0.87}}; fixture->RunTest(rate_profiles, &rc_thresholds, &quality_thresholds, nullptr); } diff --git a/modules/video_coding/codecs/test/videocodec_test_libvpx.cc b/modules/video_coding/codecs/test/videocodec_test_libvpx.cc index 8076e40fd4..0eb0d5a284 100644 --- a/modules/video_coding/codecs/test/videocodec_test_libvpx.cc +++ b/modules/video_coding/codecs/test/videocodec_test_libvpx.cc @@ -222,21 +222,6 @@ TEST(VideoCodecTestLibvpx, HighBitrateVP8) { fixture->RunTest(rate_profiles, &rc_thresholds, &quality_thresholds, nullptr); } -// The tests below are currently disabled for Android. For ARM, the encoder -// uses |cpu_speed| = 12, as opposed to default |cpu_speed| <= 6 for x86, -// which leads to significantly different quality. The quality and rate control -// settings in the tests below are defined for encoder speed setting -// |cpu_speed| <= ~6. A number of settings would need to be significantly -// modified for the |cpu_speed| = 12 case. For now, keep the tests below -// disabled on Android. Some quality parameter in the above test has been -// adjusted to also pass for |cpu_speed| <= 12. 
- -// TODO(webrtc:9267): Fails on iOS -#if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS) -#define MAYBE_ChangeBitrateVP8 DISABLED_ChangeBitrateVP8 -#else -#define MAYBE_ChangeBitrateVP8 ChangeBitrateVP8 -#endif TEST(VideoCodecTestLibvpx, MAYBE_ChangeBitrateVP8) { auto config = CreateConfig(); config.SetCodecSettings(cricket::kVp8CodecName, 1, 1, 1, true, true, false, @@ -265,12 +250,6 @@ TEST(VideoCodecTestLibvpx, MAYBE_ChangeBitrateVP8) { fixture->RunTest(rate_profiles, &rc_thresholds, &quality_thresholds, nullptr); } -// TODO(webrtc:9267): Fails on iOS -#if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS) -#define MAYBE_ChangeFramerateVP8 DISABLED_ChangeFramerateVP8 -#else -#define MAYBE_ChangeFramerateVP8 ChangeFramerateVP8 -#endif TEST(VideoCodecTestLibvpx, MAYBE_ChangeFramerateVP8) { auto config = CreateConfig(); config.SetCodecSettings(cricket::kVp8CodecName, 1, 1, 1, true, true, false, @@ -286,7 +265,7 @@ TEST(VideoCodecTestLibvpx, MAYBE_ChangeFramerateVP8) { #if defined(WEBRTC_ARCH_ARM) || defined(WEBRTC_ARCH_ARM64) std::vector rc_thresholds = { - {10, 2, 60, 1, 0.3, 0.3, 0, 1}, + {10, 2.42, 60, 1, 0.3, 0.3, 0, 1}, {10, 2, 30, 1, 0.3, 0.3, 0, 0}, {10, 2, 10, 1, 0.3, 0.2, 0, 0}}; #else @@ -298,10 +277,10 @@ TEST(VideoCodecTestLibvpx, MAYBE_ChangeFramerateVP8) { #if defined(WEBRTC_ARCH_ARM) || defined(WEBRTC_ARCH_ARM64) std::vector quality_thresholds = { - {31, 30, 0.85, 0.84}, {31.5, 30.5, 0.86, 0.84}, {30.5, 29, 0.83, 0.78}}; + {31, 30, 0.85, 0.84}, {31.4, 30.5, 0.86, 0.84}, {30.5, 29, 0.83, 0.78}}; #else std::vector quality_thresholds = { - {31, 30, 0.87, 0.86}, {32, 31, 0.89, 0.86}, {32, 30, 0.87, 0.82}}; + {31, 30, 0.87, 0.85}, {32, 31, 0.88, 0.85}, {32, 30, 0.87, 0.82}}; #endif fixture->RunTest(rate_profiles, &rc_thresholds, &quality_thresholds, nullptr); } diff --git a/modules/video_coding/codecs/test/videocodec_test_mediacodec.cc b/modules/video_coding/codecs/test/videocodec_test_mediacodec.cc index 9f887160a4..978fd8856f 100644 --- a/modules/video_coding/codecs/test/videocodec_test_mediacodec.cc +++ b/modules/video_coding/codecs/test/videocodec_test_mediacodec.cc @@ -95,7 +95,7 @@ TEST(VideoCodecTestMediaCodec, DISABLED_ForemanCif500kbpsH264CHP) { const auto frame_checker = std::make_unique(); - config.h264_codec_settings.profile = H264::kProfileConstrainedHigh; + config.h264_codec_settings.profile = H264Profile::kProfileConstrainedHigh; config.encoded_frame_checker = frame_checker.get(); config.SetCodecSettings(cricket::kH264CodecName, 1, 1, 1, false, false, false, 352, 288); diff --git a/modules/video_coding/codecs/test/videocodec_test_videotoolbox.cc b/modules/video_coding/codecs/test/videocodec_test_videotoolbox.cc index 0f02080f27..6df974362f 100644 --- a/modules/video_coding/codecs/test/videocodec_test_videotoolbox.cc +++ b/modules/video_coding/codecs/test/videocodec_test_videotoolbox.cc @@ -71,7 +71,7 @@ MAYBE_TEST(VideoCodecTestVideoToolbox, ForemanCif500kbpsH264CHP) { const auto frame_checker = std::make_unique(); auto config = CreateConfig(); - config.h264_codec_settings.profile = H264::kProfileConstrainedHigh; + config.h264_codec_settings.profile = H264Profile::kProfileConstrainedHigh; config.SetCodecSettings(cricket::kH264CodecName, 1, 1, 1, false, false, false, 352, 288); config.encoded_frame_checker = frame_checker.get(); diff --git a/modules/video_coding/codecs/test/videoprocessor.cc b/modules/video_coding/codecs/test/videoprocessor.cc index a4918ae73d..23eadfc0db 100644 --- a/modules/video_coding/codecs/test/videoprocessor.cc +++ 
b/modules/video_coding/codecs/test/videoprocessor.cc @@ -650,6 +650,8 @@ const webrtc::EncodedImage* VideoProcessor::BuildAndStoreSuperframe( EncodedImage copied_image = encoded_image; copied_image.SetEncodedData(buffer); + if (base_image.size()) + copied_image._frameType = base_image._frameType; // Replace previous EncodedImage for this spatial layer. merged_encoded_frames_.at(spatial_idx) = std::move(copied_image); diff --git a/modules/video_coding/codecs/test/videoprocessor.h b/modules/video_coding/codecs/test/videoprocessor.h index ba171d6cd9..d9e10f13bf 100644 --- a/modules/video_coding/codecs/test/videoprocessor.h +++ b/modules/video_coding/codecs/test/videoprocessor.h @@ -20,6 +20,7 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/task_queue/queued_task.h" #include "api/task_queue/task_queue_base.h" #include "api/test/videocodec_test_fixture.h" @@ -37,10 +38,8 @@ #include "rtc_base/buffer.h" #include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "test/testsupport/frame_reader.h" #include "test/testsupport/frame_writer.h" @@ -122,7 +121,6 @@ class VideoProcessor { : video_processor_(video_processor), encoded_image_(encoded_image), codec_specific_info_(*codec_specific_info) { - encoded_image_.Retain(); } bool Run() override { diff --git a/modules/video_coding/codecs/vp8/default_temporal_layers.cc b/modules/video_coding/codecs/vp8/default_temporal_layers.cc index b5652593ae..c84d9acb1c 100644 --- a/modules/video_coding/codecs/vp8/default_temporal_layers.cc +++ b/modules/video_coding/codecs/vp8/default_temporal_layers.cc @@ -27,10 +27,12 @@ namespace webrtc { DefaultTemporalLayers::PendingFrame::PendingFrame() = default; DefaultTemporalLayers::PendingFrame::PendingFrame( + uint32_t timestamp, bool expired, uint8_t updated_buffers_mask, const DependencyInfo& dependency_info) - : expired(expired), + : timestamp(timestamp), + expired(expired), updated_buffer_mask(updated_buffers_mask), dependency_info(dependency_info) {} @@ -96,8 +98,24 @@ uint8_t GetUpdatedBuffers(const Vp8FrameConfig& config) { } return flags; } + +size_t BufferToIndex(Vp8BufferReference buffer) { + switch (buffer) { + case Vp8FrameConfig::Vp8BufferReference::kLast: + return 0; + case Vp8FrameConfig::Vp8BufferReference::kGolden: + return 1; + case Vp8FrameConfig::Vp8BufferReference::kAltref: + return 2; + case Vp8FrameConfig::Vp8BufferReference::kNone: + RTC_CHECK_NOTREACHED(); + } +} + } // namespace +constexpr size_t DefaultTemporalLayers::kNumReferenceBuffers; + std::vector DefaultTemporalLayers::GetDependencyInfo(size_t num_layers) { // For indexing in the patterns described below (which temporal layers they @@ -225,11 +243,30 @@ DefaultTemporalLayers::GetDependencyInfo(size_t num_layers) { return {{"", {kNone, kNone, kNone}}}; } +std::bitset +DefaultTemporalLayers::DetermineStaticBuffers( + const std::vector& temporal_pattern) { + std::bitset buffers; + buffers.set(); + for (const DependencyInfo& info : temporal_pattern) { + uint8_t updated_buffers = GetUpdatedBuffers(info.frame_config); + + for (Vp8BufferReference buffer : kAllBuffers) { + if (static_cast(buffer) & updated_buffers) { + buffers.reset(BufferToIndex(buffer)); + } + } + } + return buffers; +} + DefaultTemporalLayers::DefaultTemporalLayers(int number_of_temporal_layers) : num_layers_(std::max(1, 
number_of_temporal_layers)), temporal_ids_(GetTemporalIds(num_layers_)), temporal_pattern_(GetDependencyInfo(num_layers_)), - pattern_idx_(kUninitializedPatternIndex) { + is_static_buffer_(DetermineStaticBuffers(temporal_pattern_)), + pattern_idx_(kUninitializedPatternIndex), + new_bitrates_bps_(std::vector(num_layers_, 0u)) { RTC_CHECK_GE(kMaxTemporalStreams, number_of_temporal_layers); RTC_CHECK_GE(number_of_temporal_layers, 0); RTC_CHECK_LE(number_of_temporal_layers, 4); @@ -238,25 +275,12 @@ DefaultTemporalLayers::DefaultTemporalLayers(int number_of_temporal_layers) // wrap at max(temporal_ids_.size(), temporal_pattern_.size()). RTC_DCHECK_LE(temporal_ids_.size(), temporal_pattern_.size()); -#if RTC_DCHECK_IS_ON - checker_ = TemporalLayersChecker::CreateTemporalLayersChecker( - Vp8TemporalLayersType::kFixedPattern, number_of_temporal_layers); -#endif + RTC_DCHECK( + checker_ = TemporalLayersChecker::CreateTemporalLayersChecker( + Vp8TemporalLayersType::kFixedPattern, number_of_temporal_layers)); // Always need to start with a keyframe, so pre-populate all frame counters. - for (Vp8BufferReference buffer : kAllBuffers) { - frames_since_buffer_refresh_[buffer] = 0; - } - - kf_buffers_ = {kAllBuffers.begin(), kAllBuffers.end()}; - for (const DependencyInfo& info : temporal_pattern_) { - uint8_t updated_buffers = GetUpdatedBuffers(info.frame_config); - - for (Vp8BufferReference buffer : kAllBuffers) { - if (static_cast(buffer) & updated_buffers) - kf_buffers_.erase(buffer); - } - } + frames_since_buffer_refresh_.fill(0); } DefaultTemporalLayers::~DefaultTemporalLayers() = default; @@ -340,12 +364,12 @@ bool DefaultTemporalLayers::IsSyncFrame(const Vp8FrameConfig& config) const { } if ((config.golden_buffer_flags & BufferFlags::kReference) && - kf_buffers_.find(Vp8BufferReference::kGolden) == kf_buffers_.end()) { + !is_static_buffer_[BufferToIndex(Vp8BufferReference::kGolden)]) { // Referencing a golden frame that contains a non-(base layer|key frame). return false; } if ((config.arf_buffer_flags & BufferFlags::kReference) && - kf_buffers_.find(Vp8BufferReference::kAltref) == kf_buffers_.end()) { + !is_static_buffer_[BufferToIndex(Vp8BufferReference::kAltref)]) { // Referencing an altref frame that contains a non-(base layer|key frame). return false; } @@ -372,8 +396,8 @@ Vp8FrameConfig DefaultTemporalLayers::NextFrameConfig(size_t stream_index, // Start of new pattern iteration, set up clear state by invalidating any // pending frames, so that we don't make an invalid reference to a buffer // containing data from a previous iteration. - for (auto& it : pending_frames_) { - it.second.expired = true; + for (auto& frame : pending_frames_) { + frame.expired = true; } } @@ -401,21 +425,19 @@ Vp8FrameConfig DefaultTemporalLayers::NextFrameConfig(size_t stream_index, // To prevent this data spill over into the next iteration, // the |pedning_frames_| map is reset in loops. If delay is constant, // the relative age should still be OK for the search order. - for (Vp8BufferReference buffer : kAllBuffers) { - ++frames_since_buffer_refresh_[buffer]; + for (size_t& n : frames_since_buffer_refresh_) { + ++n; } } // Add frame to set of pending frames, awaiting completion. - pending_frames_[timestamp] = - PendingFrame{false, GetUpdatedBuffers(tl_config), dependency_info}; + pending_frames_.emplace_back(timestamp, false, GetUpdatedBuffers(tl_config), + dependency_info); -#if RTC_DCHECK_IS_ON // Checker does not yet support encoder frame dropping, so validate flags // here before they can be dropped. 
// TODO(sprang): Update checker to support dropping. RTC_DCHECK(checker_->CheckTemporalConfig(first_frame, tl_config)); -#endif return tl_config; } @@ -426,10 +448,8 @@ void DefaultTemporalLayers::ValidateReferences(BufferFlags* flags, // if it also a dynamically updating one (buffers always just containing // keyframes are always safe to reference). if ((*flags & BufferFlags::kReference) && - kf_buffers_.find(ref) == kf_buffers_.end()) { - auto it = frames_since_buffer_refresh_.find(ref); - if (it == frames_since_buffer_refresh_.end() || - it->second >= pattern_idx_) { + !is_static_buffer_[BufferToIndex(ref)]) { + if (NumFramesSinceBufferRefresh(ref) >= pattern_idx_) { // No valid buffer state, or buffer contains frame that is older than the // current pattern. This reference is not valid, so remove it. *flags = static_cast(*flags & ~BufferFlags::kReference); @@ -446,17 +466,17 @@ void DefaultTemporalLayers::UpdateSearchOrder(Vp8FrameConfig* config) { if (config->last_buffer_flags & BufferFlags::kReference) { eligible_buffers.emplace_back( Vp8BufferReference::kLast, - frames_since_buffer_refresh_[Vp8BufferReference::kLast]); + NumFramesSinceBufferRefresh(Vp8BufferReference::kLast)); } if (config->golden_buffer_flags & BufferFlags::kReference) { eligible_buffers.emplace_back( Vp8BufferReference::kGolden, - frames_since_buffer_refresh_[Vp8BufferReference::kGolden]); + NumFramesSinceBufferRefresh(Vp8BufferReference::kGolden)); } if (config->arf_buffer_flags & BufferFlags::kReference) { eligible_buffers.emplace_back( Vp8BufferReference::kAltref, - frames_since_buffer_refresh_[Vp8BufferReference::kAltref]); + NumFramesSinceBufferRefresh(Vp8BufferReference::kAltref)); } std::sort(eligible_buffers.begin(), eligible_buffers.end(), @@ -476,6 +496,23 @@ void DefaultTemporalLayers::UpdateSearchOrder(Vp8FrameConfig* config) { } } +size_t DefaultTemporalLayers::NumFramesSinceBufferRefresh( + Vp8FrameConfig::Vp8BufferReference ref) const { + return frames_since_buffer_refresh_[BufferToIndex(ref)]; +} + +void DefaultTemporalLayers::ResetNumFramesSinceBufferRefresh( + Vp8FrameConfig::Vp8BufferReference ref) { + frames_since_buffer_refresh_[BufferToIndex(ref)] = 0; +} + +void DefaultTemporalLayers::CullPendingFramesBefore(uint32_t timestamp) { + while (!pending_frames_.empty() && + pending_frames_.front().timestamp != timestamp) { + pending_frames_.pop_front(); + } +} + void DefaultTemporalLayers::OnEncodeDone(size_t stream_index, uint32_t rtp_timestamp, size_t size_bytes, @@ -491,17 +528,15 @@ void DefaultTemporalLayers::OnEncodeDone(size_t stream_index, return; } - auto pending_frame = pending_frames_.find(rtp_timestamp); - RTC_DCHECK(pending_frame != pending_frames_.end()); - - PendingFrame& frame = pending_frame->second; + CullPendingFramesBefore(rtp_timestamp); + RTC_CHECK(!pending_frames_.empty()); + PendingFrame& frame = pending_frames_.front(); + RTC_DCHECK_EQ(frame.timestamp, rtp_timestamp); const Vp8FrameConfig& frame_config = frame.dependency_info.frame_config; -#if RTC_DCHECK_IS_ON if (is_keyframe) { // Signal key-frame so checker resets state. RTC_DCHECK(checker_->CheckTemporalConfig(true, frame_config)); } -#endif CodecSpecificInfoVP8& vp8_info = info->codecSpecific.VP8; if (num_layers_ == 1) { @@ -515,10 +550,10 @@ void DefaultTemporalLayers::OnEncodeDone(size_t stream_index, vp8_info.layerSync = true; // Keyframes are always sync frames. 
for (Vp8BufferReference buffer : kAllBuffers) { - if (kf_buffers_.find(buffer) != kf_buffers_.end()) { + if (is_static_buffer_[BufferToIndex(buffer)]) { // Update frame count of all kf-only buffers, regardless of state of // |pending_frames_|. - frames_since_buffer_refresh_[buffer] = 0; + ResetNumFramesSinceBufferRefresh(buffer); } else { // Key-frames update all buffers, this should be reflected when // updating state in FrameEncoded(). @@ -558,8 +593,9 @@ void DefaultTemporalLayers::OnEncodeDone(size_t stream_index, vp8_info.updatedBuffers[vp8_info.updatedBuffersCount++] = i; } - if (references || updates) + if (references || updates) { generic_frame_info.encoder_buffers.emplace_back(i, references, updates); + } } // The templates are always present on keyframes, and then refered to by @@ -578,19 +614,20 @@ void DefaultTemporalLayers::OnEncodeDone(size_t stream_index, if (!frame.expired) { for (Vp8BufferReference buffer : kAllBuffers) { if (frame.updated_buffer_mask & static_cast(buffer)) { - frames_since_buffer_refresh_[buffer] = 0; + ResetNumFramesSinceBufferRefresh(buffer); } } } - pending_frames_.erase(pending_frame); + pending_frames_.pop_front(); } void DefaultTemporalLayers::OnFrameDropped(size_t stream_index, uint32_t rtp_timestamp) { - auto pending_frame = pending_frames_.find(rtp_timestamp); - RTC_DCHECK(pending_frame != pending_frames_.end()); - pending_frames_.erase(pending_frame); + CullPendingFramesBefore(rtp_timestamp); + RTC_CHECK(!pending_frames_.empty()); + RTC_DCHECK_EQ(pending_frames_.front().timestamp, rtp_timestamp); + pending_frames_.pop_front(); } void DefaultTemporalLayers::OnPacketLossRateUpdate(float packet_loss_rate) {} diff --git a/modules/video_coding/codecs/vp8/default_temporal_layers.h b/modules/video_coding/codecs/vp8/default_temporal_layers.h index d127d8056d..bc6574c54c 100644 --- a/modules/video_coding/codecs/vp8/default_temporal_layers.h +++ b/modules/video_coding/codecs/vp8/default_temporal_layers.h @@ -15,8 +15,9 @@ #include #include +#include +#include #include -#include #include #include #include @@ -53,13 +54,15 @@ class DefaultTemporalLayers final : public Vp8FrameBufferController { Vp8EncoderConfig UpdateConfiguration(size_t stream_index) override; + // Callback methods on frame completion. OnEncodeDone() or OnFrameDropped() + // should be called once for each NextFrameConfig() call (using the RTP + // timestamp as ID), and the calls MUST be in the same order. void OnEncodeDone(size_t stream_index, uint32_t rtp_timestamp, size_t size_bytes, bool is_keyframe, int qp, CodecSpecificInfo* info) override; - void OnFrameDropped(size_t stream_index, uint32_t rtp_timestamp) override; void OnPacketLossRateUpdate(float packet_loss_rate) override; @@ -70,6 +73,7 @@ class DefaultTemporalLayers final : public Vp8FrameBufferController { const VideoEncoder::LossNotification& loss_notification) override; private: + static constexpr size_t kNumReferenceBuffers = 3; // Last, golden, altref.
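As an annotation on the new kNumReferenceBuffers / is_static_buffer_ members: DetermineStaticBuffers() in the .cc hunk above walks the temporal pattern once, clearing the bit of every buffer that some frame in the pattern updates, so only the key-frame-only ("static") buffers keep their bit set. A self-contained sketch of that logic with stand-in types (the enum values mirror the bitmask encoding used by Vp8BufferReference, but this is not the real header):

#include <array>
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <vector>

enum class Buffer : uint8_t { kLast = 1, kGolden = 2, kAltref = 4 };
constexpr std::size_t kNumReferenceBuffers = 3;  // Last, golden, altref.
constexpr std::array<Buffer, kNumReferenceBuffers> kAllBuffers = {
    Buffer::kLast, Buffer::kGolden, Buffer::kAltref};

std::size_t BufferToIndex(Buffer buffer) {
  switch (buffer) {
    case Buffer::kLast:
      return 0;
    case Buffer::kGolden:
      return 1;
    case Buffer::kAltref:
      return 2;
  }
  return 0;  // Not reached for the enumerators above.
}

// `updated_buffers_per_frame` holds one bitmask per frame of the temporal
// pattern, describing which buffers that frame refreshes.
std::bitset<kNumReferenceBuffers> DetermineStaticBuffers(
    const std::vector<uint8_t>& updated_buffers_per_frame) {
  std::bitset<kNumReferenceBuffers> is_static;
  is_static.set();  // Start by assuming every buffer is key-frame-only.
  for (uint8_t updated : updated_buffers_per_frame) {
    for (Buffer buffer : kAllBuffers) {
      if (static_cast<uint8_t>(buffer) & updated) {
        is_static.reset(BufferToIndex(buffer));  // Updated mid-pattern.
      }
    }
  }
  return is_static;
}

A bit test against a fixed-size bitset is allocation-free and O(1), which is presumably part of the motivation for replacing the old std::set-based kf_buffers_.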
struct DependencyInfo { DependencyInfo() = default; DependencyInfo(absl::string_view indication_symbols, @@ -81,46 +85,54 @@ class DefaultTemporalLayers final : public Vp8FrameBufferController { absl::InlinedVector decode_target_indications; Vp8FrameConfig frame_config; }; + struct PendingFrame { + PendingFrame(); + PendingFrame(uint32_t timestamp, + bool expired, + uint8_t updated_buffers_mask, + const DependencyInfo& dependency_info); + uint32_t timestamp = 0; + // Flag indicating if this frame has expired, ie it belongs to a previous + // iteration of the temporal pattern. + bool expired = false; + // Bitmask of Vp8BufferReference flags, indicating which buffers this frame + // updates. + uint8_t updated_buffer_mask = 0; + // The frame config returned by NextFrameConfig() for this frame. + DependencyInfo dependency_info; + }; static std::vector GetDependencyInfo(size_t num_layers); + static std::bitset DetermineStaticBuffers( + const std::vector& temporal_pattern); bool IsSyncFrame(const Vp8FrameConfig& config) const; void ValidateReferences(Vp8FrameConfig::BufferFlags* flags, Vp8FrameConfig::Vp8BufferReference ref) const; void UpdateSearchOrder(Vp8FrameConfig* config); + size_t NumFramesSinceBufferRefresh( + Vp8FrameConfig::Vp8BufferReference ref) const; + void ResetNumFramesSinceBufferRefresh(Vp8FrameConfig::Vp8BufferReference ref); + void CullPendingFramesBefore(uint32_t timestamp); const size_t num_layers_; const std::vector temporal_ids_; const std::vector temporal_pattern_; - // Set of buffers that are never updated except by keyframes. - std::set kf_buffers_; + // Per reference buffer flag indicating if it is static, meaning it is only + // updated by key-frames. + const std::bitset is_static_buffer_; FrameDependencyStructure GetTemplateStructure(int num_layers) const; uint8_t pattern_idx_; // Updated cumulative bitrates, per temporal layer. absl::optional> new_bitrates_bps_; - struct PendingFrame { - PendingFrame(); - PendingFrame(bool expired, - uint8_t updated_buffers_mask, - const DependencyInfo& dependency_info); - // Flag indicating if this frame has expired, ie it belongs to a previous - // iteration of the temporal pattern. - bool expired = false; - // Bitmask of Vp8BufferReference flags, indicating which buffers this frame - // updates. - uint8_t updated_buffer_mask = 0; - // The frame config returned by NextFrameConfig() for this frame. - DependencyInfo dependency_info; - }; - // Map from rtp timestamp to pending frame status. Reset on pattern loop. - std::map pending_frames_; + // Status for each pending frame, in order of insertion. + std::deque pending_frames_; - // One counter per Vp8BufferReference, indicating number of frames since last + // One counter per reference buffer, indicating number of frames since last // refresh. For non-base-layer frames (ie golden, altref buffers), this is // reset when the pattern loops. - std::map - frames_since_buffer_refresh_; + std::array frames_since_buffer_refresh_; // Optional utility used to verify reference validity.
std::unique_ptr checker_; diff --git a/modules/video_coding/codecs/vp8/default_temporal_layers_unittest.cc b/modules/video_coding/codecs/vp8/default_temporal_layers_unittest.cc index 64ad40ab76..a18ac40e7d 100644 --- a/modules/video_coding/codecs/vp8/default_temporal_layers_unittest.cc +++ b/modules/video_coding/codecs/vp8/default_temporal_layers_unittest.cc @@ -687,6 +687,25 @@ TEST_F(TemporalLayersTest, KeyFrame) { } } +TEST_F(TemporalLayersTest, SetsTlCountOnFirstConfigUpdate) { + // Create an instance and fetch config update without setting any rate. + constexpr int kNumLayers = 2; + DefaultTemporalLayers tl(kNumLayers); + Vp8EncoderConfig config = tl.UpdateConfiguration(0); + + // Config should indicate correct number of temporal layers, but zero bitrate. + ASSERT_TRUE(config.temporal_layer_config.has_value()); + EXPECT_EQ(config.temporal_layer_config->ts_number_layers, + uint32_t{kNumLayers}); + std::array + kZeroRate = {}; + EXPECT_EQ(config.temporal_layer_config->ts_target_bitrate, kZeroRate); + + // On second call, no new update. + config = tl.UpdateConfiguration(0); + EXPECT_FALSE(config.temporal_layer_config.has_value()); +} + class TemporalLayersReferenceTest : public TemporalLayersTest, public ::testing::WithParamInterface { public: diff --git a/modules/video_coding/codecs/vp8/include/vp8.h b/modules/video_coding/codecs/vp8/include/vp8.h index 44efbeeb3b..d05c3a68d1 100644 --- a/modules/video_coding/codecs/vp8/include/vp8.h +++ b/modules/video_coding/codecs/vp8/include/vp8.h @@ -14,10 +14,10 @@ #include #include +#include "absl/base/attributes.h" #include "api/video_codecs/video_encoder.h" #include "api/video_codecs/vp8_frame_buffer_controller.h" #include "modules/video_coding/include/video_codec_interface.h" -#include "rtc_base/deprecation.h" namespace webrtc { @@ -40,7 +40,8 @@ class VP8Encoder { static std::unique_ptr Create(); static std::unique_ptr Create(Settings settings); - RTC_DEPRECATED static std::unique_ptr Create( + ABSL_DEPRECATED("") + static std::unique_ptr Create( std::unique_ptr frame_buffer_controller_factory); }; diff --git a/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.cc b/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.cc index 979ded9a63..9d6ffdba90 100644 --- a/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.cc +++ b/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.cc @@ -54,13 +54,9 @@ constexpr bool kIsArm = false; #endif absl::optional DefaultDeblockParams() { - if (kIsArm) { - // For ARM, this is only called when deblocking is explicitly enabled, and - // the default strength is set by the ctor. - return LibvpxVp8Decoder::DeblockParams(); - } - // For non-arm, don't use the explicit deblocking settings by default. - return absl::nullopt; + return LibvpxVp8Decoder::DeblockParams(/*max_level=*/8, + /*degrade_qp=*/60, + /*min_qp=*/30); } absl::optional diff --git a/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.h b/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.h index 8d84b67ce3..60295e5d5d 100644 --- a/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.h +++ b/modules/video_coding/codecs/vp8/libvpx_vp8_decoder.h @@ -42,9 +42,12 @@ class LibvpxVp8Decoder : public VideoDecoder { const char* ImplementationName() const override; struct DeblockParams { - int max_level = 6; // Deblocking strength: [0, 16]. - int degrade_qp = 1; // If QP value is below, start lowering |max_level|. - int min_qp = 0; // If QP value is below, turn off deblocking. 
+ DeblockParams() : max_level(6), degrade_qp(1), min_qp(0) {} + DeblockParams(int max_level, int degrade_qp, int min_qp) + : max_level(max_level), degrade_qp(degrade_qp), min_qp(min_qp) {} + int max_level; // Deblocking strength: [0, 16]. + int degrade_qp; // If QP value is below, start lowering |max_level|. + int min_qp; // If QP value is below, turn off deblocking. }; private: diff --git a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc index 7713a0d3d0..a994193031 100644 --- a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc +++ b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/algorithm/container.h" #include "api/scoped_refptr.h" #include "api/video/video_content_type.h" #include "api/video/video_frame_buffer.h" @@ -49,11 +50,6 @@ constexpr char kVP8IosMaxNumberOfThreadFieldTrial[] = constexpr char kVP8IosMaxNumberOfThreadFieldTrialParameter[] = "max_thread"; #endif -constexpr char kVp8GetEncoderInfoOverrideFieldTrial[] = - "WebRTC-VP8-GetEncoderInfoOverride"; -constexpr char kVp8RequestedResolutionAlignmentFieldTrialParameter[] = - "requested_resolution_alignment"; - constexpr char kVp8ForcePartitionResilience[] = "WebRTC-VP8-ForcePartitionResilience"; @@ -165,13 +161,51 @@ void ApplyVp8EncoderConfigToVpxConfig(const Vp8EncoderConfig& encoder_config, } } -absl::optional GetRequestedResolutionAlignmentOverride() { - const std::string trial_string = - field_trial::FindFullName(kVp8GetEncoderInfoOverrideFieldTrial); - FieldTrialOptional requested_resolution_alignment( - kVp8RequestedResolutionAlignmentFieldTrialParameter); - ParseFieldTrial({&requested_resolution_alignment}, trial_string); - return requested_resolution_alignment.GetOptional(); +bool IsCompatibleVideoFrameBufferType(VideoFrameBuffer::Type left, + VideoFrameBuffer::Type right) { + if (left == VideoFrameBuffer::Type::kI420 || + left == VideoFrameBuffer::Type::kI420A) { + // LibvpxVp8Encoder does not care about the alpha channel, I420A and I420 + // are considered compatible. 
+ return right == VideoFrameBuffer::Type::kI420 || + right == VideoFrameBuffer::Type::kI420A; + } + return left == right; +} + +void SetRawImagePlanes(vpx_image_t* raw_image, VideoFrameBuffer* buffer) { + switch (buffer->type()) { + case VideoFrameBuffer::Type::kI420: + case VideoFrameBuffer::Type::kI420A: { + const I420BufferInterface* i420_buffer = buffer->GetI420(); + RTC_DCHECK(i420_buffer); + raw_image->planes[VPX_PLANE_Y] = + const_cast(i420_buffer->DataY()); + raw_image->planes[VPX_PLANE_U] = + const_cast(i420_buffer->DataU()); + raw_image->planes[VPX_PLANE_V] = + const_cast(i420_buffer->DataV()); + raw_image->stride[VPX_PLANE_Y] = i420_buffer->StrideY(); + raw_image->stride[VPX_PLANE_U] = i420_buffer->StrideU(); + raw_image->stride[VPX_PLANE_V] = i420_buffer->StrideV(); + break; + } + case VideoFrameBuffer::Type::kNV12: { + const NV12BufferInterface* nv12_buffer = buffer->GetNV12(); + RTC_DCHECK(nv12_buffer); + raw_image->planes[VPX_PLANE_Y] = + const_cast(nv12_buffer->DataY()); + raw_image->planes[VPX_PLANE_U] = + const_cast(nv12_buffer->DataUV()); + raw_image->planes[VPX_PLANE_V] = raw_image->planes[VPX_PLANE_U] + 1; + raw_image->stride[VPX_PLANE_Y] = nv12_buffer->StrideY(); + raw_image->stride[VPX_PLANE_U] = nv12_buffer->StrideUV(); + raw_image->stride[VPX_PLANE_V] = nv12_buffer->StrideUV(); + break; + } + default: + RTC_NOTREACHED(); + } } } // namespace @@ -230,8 +264,6 @@ LibvpxVp8Encoder::LibvpxVp8Encoder(std::unique_ptr interface, VP8Encoder::Settings settings) : libvpx_(std::move(interface)), rate_control_settings_(RateControlSettings::ParseFromFieldTrials()), - requested_resolution_alignment_override_( - GetRequestedResolutionAlignmentOverride()), frame_buffer_controller_factory_( std::move(settings.frame_buffer_controller_factory)), resolution_bitrate_limits_(std::move(settings.resolution_bitrate_limits)), @@ -945,40 +977,29 @@ int LibvpxVp8Encoder::Encode(const VideoFrame& frame, flags[i] = send_key_frame ? VPX_EFLAG_FORCE_KF : EncodeFlags(tl_configs[i]); } - rtc::scoped_refptr input_image = frame.video_frame_buffer(); - // Since we are extracting raw pointers from |input_image| to - // |raw_images_[0]|, the resolution of these frames must match. - RTC_DCHECK_EQ(input_image->width(), raw_images_[0].d_w); - RTC_DCHECK_EQ(input_image->height(), raw_images_[0].d_h); - switch (input_image->type()) { - case VideoFrameBuffer::Type::kI420: - PrepareI420Image(input_image->GetI420()); - break; - case VideoFrameBuffer::Type::kNV12: - PrepareNV12Image(input_image->GetNV12()); - break; - default: { - rtc::scoped_refptr i420_image = - input_image->ToI420(); - if (!i420_image) { - RTC_LOG(LS_ERROR) << "Failed to convert " - << VideoFrameBufferTypeToString(input_image->type()) - << " image to I420. Can't encode frame."; - return WEBRTC_VIDEO_CODEC_ERROR; - } - input_image = i420_image; - PrepareI420Image(i420_image); - } + // Scale and map buffers and set |raw_images_| to hold pointers to the result. + // Because |raw_images_| are set to hold pointers to the prepared buffers, we + // need to keep these buffers alive through reference counting until after + // encoding is complete. 
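A note on the kNV12 branch of SetRawImagePlanes() above: NV12 stores chroma interleaved as U0 V0 U1 V1 ..., so pointing VPX_PLANE_V one byte past the U pointer, with the same stride, appears intended to let code that addresses U and V separately still land on the correct samples. A tiny standalone illustration of that pointer arithmetic (plain arrays, no libvpx involved):

#include <cstdint>
#include <cstdio>

int main() {
  // One row of NV12 chroma: U and V samples interleaved.
  const uint8_t uv_row[] = {10, 20, 11, 21, 12, 22};  // U0 V0 U1 V1 U2 V2
  const uint8_t* u = uv_row;
  const uint8_t* v = uv_row + 1;  // Same trick as planes[VPX_PLANE_V] = U + 1.
  for (int i = 0; i < 3; ++i) {
    // Stepping two bytes per chroma sample yields U0..U2 from `u` and
    // V0..V2 from `v`.
    std::printf("U%d=%d V%d=%d\n", i, u[2 * i], i, v[2 * i]);
  }
  return 0;
}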
+ std::vector> prepared_buffers = + PrepareBuffers(frame.video_frame_buffer()); + if (prepared_buffers.empty()) { + return WEBRTC_VIDEO_CODEC_ERROR; } struct CleanUpOnExit { - explicit CleanUpOnExit(vpx_image_t& raw_image) : raw_image_(raw_image) {} + explicit CleanUpOnExit( + vpx_image_t* raw_image, + std::vector> prepared_buffers) + : raw_image_(raw_image), + prepared_buffers_(std::move(prepared_buffers)) {} ~CleanUpOnExit() { - raw_image_.planes[VPX_PLANE_Y] = nullptr; - raw_image_.planes[VPX_PLANE_U] = nullptr; - raw_image_.planes[VPX_PLANE_V] = nullptr; + raw_image_->planes[VPX_PLANE_Y] = nullptr; + raw_image_->planes[VPX_PLANE_U] = nullptr; + raw_image_->planes[VPX_PLANE_V] = nullptr; } - vpx_image_t& raw_image_; - } clean_up_on_exit(raw_images_[0]); + vpx_image_t* raw_image_; + std::vector> prepared_buffers_; + } clean_up_on_exit(&raw_images_[0], std::move(prepared_buffers)); if (send_key_frame) { // Adapt the size of the key frame when in screenshare with 1 temporal @@ -1016,7 +1037,7 @@ int LibvpxVp8Encoder::Encode(const VideoFrame& frame, // would like to use the duration of the previous frame. Unfortunately the // rate control seems to be off with that setup. Using the average input // frame rate to calculate an average duration for now. - assert(codec_.maxFramerate > 0); + RTC_DCHECK_GT(codec_.maxFramerate, 0); uint32_t duration = kRtpTicksPerSecond / codec_.maxFramerate; int error = WEBRTC_VIDEO_CODEC_OK; @@ -1053,7 +1074,7 @@ void LibvpxVp8Encoder::PopulateCodecSpecific(CodecSpecificInfo* codec_specific, int stream_idx, int encoder_idx, uint32_t timestamp) { - assert(codec_specific != NULL); + RTC_DCHECK(codec_specific); codec_specific->codecType = kVideoCodecVP8; codec_specific->codecSpecific.VP8.keyIdx = kNoKeyIdx; // TODO(hlundin) populate this @@ -1189,9 +1210,15 @@ VideoEncoder::EncoderInfo LibvpxVp8Encoder::GetEncoderInfo() const { if (!resolution_bitrate_limits_.empty()) { info.resolution_bitrate_limits = resolution_bitrate_limits_; } - if (requested_resolution_alignment_override_) { + if (encoder_info_override_.requested_resolution_alignment()) { info.requested_resolution_alignment = - *requested_resolution_alignment_override_; + *encoder_info_override_.requested_resolution_alignment(); + info.apply_alignment_to_all_simulcast_layers = + encoder_info_override_.apply_alignment_to_all_simulcast_layers(); + } + if (!encoder_info_override_.resolution_bitrate_limits().empty()) { + info.resolution_bitrate_limits = + encoder_info_override_.resolution_bitrate_limits(); } const bool enable_scaling = @@ -1272,61 +1299,121 @@ void LibvpxVp8Encoder::MaybeUpdatePixelFormat(vpx_img_fmt fmt) { } } -void LibvpxVp8Encoder::PrepareI420Image(const I420BufferInterface* frame) { - RTC_DCHECK(!raw_images_.empty()); - MaybeUpdatePixelFormat(VPX_IMG_FMT_I420); - // Image in vpx_image_t format. - // Input image is const. VP8's raw image is not defined as const. 
- raw_images_[0].planes[VPX_PLANE_Y] = const_cast(frame->DataY()); - raw_images_[0].planes[VPX_PLANE_U] = const_cast(frame->DataU()); - raw_images_[0].planes[VPX_PLANE_V] = const_cast(frame->DataV()); - - raw_images_[0].stride[VPX_PLANE_Y] = frame->StrideY(); - raw_images_[0].stride[VPX_PLANE_U] = frame->StrideU(); - raw_images_[0].stride[VPX_PLANE_V] = frame->StrideV(); - - for (size_t i = 1; i < encoders_.size(); ++i) { - // Scale the image down a number of times by downsampling factor - libyuv::I420Scale( - raw_images_[i - 1].planes[VPX_PLANE_Y], - raw_images_[i - 1].stride[VPX_PLANE_Y], - raw_images_[i - 1].planes[VPX_PLANE_U], - raw_images_[i - 1].stride[VPX_PLANE_U], - raw_images_[i - 1].planes[VPX_PLANE_V], - raw_images_[i - 1].stride[VPX_PLANE_V], raw_images_[i - 1].d_w, - raw_images_[i - 1].d_h, raw_images_[i].planes[VPX_PLANE_Y], - raw_images_[i].stride[VPX_PLANE_Y], raw_images_[i].planes[VPX_PLANE_U], - raw_images_[i].stride[VPX_PLANE_U], raw_images_[i].planes[VPX_PLANE_V], - raw_images_[i].stride[VPX_PLANE_V], raw_images_[i].d_w, - raw_images_[i].d_h, libyuv::kFilterBilinear); +std::vector> +LibvpxVp8Encoder::PrepareBuffers(rtc::scoped_refptr buffer) { + RTC_DCHECK_EQ(buffer->width(), raw_images_[0].d_w); + RTC_DCHECK_EQ(buffer->height(), raw_images_[0].d_h); + absl::InlinedVector + supported_formats = {VideoFrameBuffer::Type::kI420, + VideoFrameBuffer::Type::kNV12}; + + rtc::scoped_refptr mapped_buffer; + if (buffer->type() != VideoFrameBuffer::Type::kNative) { + // |buffer| is already mapped. + mapped_buffer = buffer; + } else { + // Attempt to map to one of the supported formats. + mapped_buffer = buffer->GetMappedFrameBuffer(supported_formats); + } + if (!mapped_buffer || + (absl::c_find(supported_formats, mapped_buffer->type()) == + supported_formats.end() && + mapped_buffer->type() != VideoFrameBuffer::Type::kI420A)) { + // Unknown pixel format or unable to map, convert to I420 and prepare that + // buffer instead to ensure Scale() is safe to use. + auto converted_buffer = buffer->ToI420(); + if (!converted_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString(buffer->type()) + << " image to I420. Can't encode frame."; + return {}; + } + // The buffer should now be a mapped I420 or I420A format, but some buffer + // implementations incorrectly return the wrong buffer format, such as + // kNative. As a workaround to this, we perform ToI420() a second time. + // TODO(https://crbug.com/webrtc/12602): When Android buffers have a correct + // ToI420() implementaion, remove his workaround. + if (converted_buffer->type() != VideoFrameBuffer::Type::kI420 && + converted_buffer->type() != VideoFrameBuffer::Type::kI420A) { + converted_buffer = converted_buffer->ToI420(); + if (!converted_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString( + converted_buffer->type()) + << " image to I420. Can't encode frame."; + return {}; + } + RTC_CHECK(converted_buffer->type() == VideoFrameBuffer::Type::kI420 || + converted_buffer->type() == VideoFrameBuffer::Type::kI420A); + } + // Because |buffer| had to be converted, use |converted_buffer| instead... + buffer = mapped_buffer = converted_buffer; } -} -void LibvpxVp8Encoder::PrepareNV12Image(const NV12BufferInterface* frame) { - RTC_DCHECK(!raw_images_.empty()); - MaybeUpdatePixelFormat(VPX_IMG_FMT_NV12); - // Image in vpx_image_t format. - // Input image is const. VP8's raw image is not defined as const. 
- raw_images_[0].planes[VPX_PLANE_Y] = const_cast(frame->DataY()); - raw_images_[0].planes[VPX_PLANE_U] = const_cast(frame->DataUV()); - raw_images_[0].planes[VPX_PLANE_V] = raw_images_[0].planes[VPX_PLANE_U] + 1; - raw_images_[0].stride[VPX_PLANE_Y] = frame->StrideY(); - raw_images_[0].stride[VPX_PLANE_U] = frame->StrideUV(); - raw_images_[0].stride[VPX_PLANE_V] = frame->StrideUV(); + // Maybe update pixel format. + absl::InlinedVector + mapped_type = {mapped_buffer->type()}; + switch (mapped_buffer->type()) { + case VideoFrameBuffer::Type::kI420: + case VideoFrameBuffer::Type::kI420A: + MaybeUpdatePixelFormat(VPX_IMG_FMT_I420); + break; + case VideoFrameBuffer::Type::kNV12: + MaybeUpdatePixelFormat(VPX_IMG_FMT_NV12); + break; + default: + RTC_NOTREACHED(); + } + // Prepare |raw_images_| from |mapped_buffer| and, if simulcast, scaled + // versions of |buffer|. + std::vector> prepared_buffers; + SetRawImagePlanes(&raw_images_[0], mapped_buffer); + prepared_buffers.push_back(mapped_buffer); for (size_t i = 1; i < encoders_.size(); ++i) { - // Scale the image down a number of times by downsampling factor - libyuv::NV12Scale( - raw_images_[i - 1].planes[VPX_PLANE_Y], - raw_images_[i - 1].stride[VPX_PLANE_Y], - raw_images_[i - 1].planes[VPX_PLANE_U], - raw_images_[i - 1].stride[VPX_PLANE_U], raw_images_[i - 1].d_w, - raw_images_[i - 1].d_h, raw_images_[i].planes[VPX_PLANE_Y], - raw_images_[i].stride[VPX_PLANE_Y], raw_images_[i].planes[VPX_PLANE_U], - raw_images_[i].stride[VPX_PLANE_U], raw_images_[i].d_w, - raw_images_[i].d_h, libyuv::kFilterBilinear); - raw_images_[i].planes[VPX_PLANE_V] = raw_images_[i].planes[VPX_PLANE_U] + 1; + // Native buffers should implement optimized scaling and is the preferred + // buffer to scale. But if the buffer isn't native, it should be cheaper to + // scale from the previously prepared buffer which is smaller than |buffer|. + VideoFrameBuffer* buffer_to_scale = + buffer->type() == VideoFrameBuffer::Type::kNative + ? buffer.get() + : prepared_buffers.back().get(); + + auto scaled_buffer = + buffer_to_scale->Scale(raw_images_[i].d_w, raw_images_[i].d_h); + if (scaled_buffer->type() == VideoFrameBuffer::Type::kNative) { + auto mapped_scaled_buffer = + scaled_buffer->GetMappedFrameBuffer(mapped_type); + RTC_DCHECK(mapped_scaled_buffer) << "Unable to map the scaled buffer."; + if (!mapped_scaled_buffer) { + RTC_LOG(LS_ERROR) << "Failed to map scaled " + << VideoFrameBufferTypeToString(scaled_buffer->type()) + << " image to " + << VideoFrameBufferTypeToString(mapped_buffer->type()) + << ". Can't encode frame."; + return {}; + } + scaled_buffer = mapped_scaled_buffer; + } + if (!IsCompatibleVideoFrameBufferType(scaled_buffer->type(), + mapped_buffer->type())) { + RTC_LOG(LS_ERROR) << "When scaling " + << VideoFrameBufferTypeToString(buffer_to_scale->type()) + << ", the image was unexpectedly converted to " + << VideoFrameBufferTypeToString(scaled_buffer->type()) + << " instead of " + << VideoFrameBufferTypeToString(mapped_buffer->type()) + << ". 
Can't encode frame."; + RTC_NOTREACHED() << "Scaled buffer type " + << VideoFrameBufferTypeToString(scaled_buffer->type()) + << " is not compatible with mapped buffer type " + << VideoFrameBufferTypeToString(mapped_buffer->type()); + return {}; + } + SetRawImagePlanes(&raw_images_[i], scaled_buffer); + prepared_buffers.push_back(scaled_buffer); } + return prepared_buffers; } // static diff --git a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h index bfe4275f50..ed80eacab2 100644 --- a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h +++ b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h @@ -26,6 +26,7 @@ #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/utility/framerate_controller.h" #include "rtc_base/experiments/cpu_speed_experiment.h" +#include "rtc_base/experiments/encoder_info_settings.h" #include "rtc_base/experiments/rate_control_settings.h" #include "vpx/vp8cx.h" #include "vpx/vpx_encoder.h" @@ -94,17 +95,19 @@ class LibvpxVp8Encoder : public VideoEncoder { bool UpdateVpxConfiguration(size_t stream_index); void MaybeUpdatePixelFormat(vpx_img_fmt fmt); - void PrepareI420Image(const I420BufferInterface* frame); - void PrepareNV12Image(const NV12BufferInterface* frame); + // Prepares |raw_image_| to reference image data of |buffer|, or of mapped or + // scaled versions of |buffer|. Returns a list of buffers that got referenced + // as a result, allowing the caller to keep references to them until after + // encoding has finished. On failure to convert the buffer, an empty list is + // returned. + std::vector> PrepareBuffers( + rtc::scoped_refptr buffer); const std::unique_ptr libvpx_; const CpuSpeedExperiment experimental_cpu_speed_config_arm_; const RateControlSettings rate_control_settings_; - // EncoderInfo::requested_resolution_alignment override from field trial. 
- const absl::optional requested_resolution_alignment_override_; - EncodedImageCallback* encoded_complete_callback_ = nullptr; VideoCodec codec_; bool inited_ = false; @@ -146,6 +149,8 @@ class LibvpxVp8Encoder : public VideoEncoder { int num_steady_state_frames_ = 0; FecControllerOverride* fec_controller_override_ = nullptr; + + const LibvpxVp8EncoderInfoSettings encoder_info_override_; }; } // namespace webrtc diff --git a/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc b/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc index 94ea1794ef..047bf2acae 100644 --- a/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc +++ b/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc @@ -27,6 +27,7 @@ #include "modules/video_coding/utility/vp8_header_parser.h" #include "rtc_base/time_utils.h" #include "test/field_trial.h" +#include "test/mappable_native_buffer.h" #include "test/video_codec_settings.h" namespace webrtc { @@ -539,6 +540,29 @@ TEST(LibvpxVp8EncoderTest, RequestedResolutionAlignmentFromFieldTrial) { VP8Encoder::Settings()); EXPECT_EQ(encoder.GetEncoderInfo().requested_resolution_alignment, 10); + EXPECT_FALSE( + encoder.GetEncoderInfo().apply_alignment_to_all_simulcast_layers); + EXPECT_TRUE(encoder.GetEncoderInfo().resolution_bitrate_limits.empty()); +} + +TEST(LibvpxVp8EncoderTest, ResolutionBitrateLimitsFromFieldTrial) { + test::ScopedFieldTrials field_trials( + "WebRTC-VP8-GetEncoderInfoOverride/" + "frame_size_pixels:123|456|789," + "min_start_bitrate_bps:11000|22000|33000," + "min_bitrate_bps:44000|55000|66000," + "max_bitrate_bps:77000|88000|99000/"); + + auto* const vpx = new NiceMock(); + LibvpxVp8Encoder encoder((std::unique_ptr(vpx)), + VP8Encoder::Settings()); + + EXPECT_THAT( + encoder.GetEncoderInfo().resolution_bitrate_limits, + ::testing::ElementsAre( + VideoEncoder::ResolutionBitrateLimits{123, 11000, 44000, 77000}, + VideoEncoder::ResolutionBitrateLimits{456, 22000, 55000, 88000}, + VideoEncoder::ResolutionBitrateLimits{789, 33000, 66000, 99000})); } TEST(LibvpxVp8EncoderTest, @@ -692,4 +716,61 @@ TEST_F(TestVp8Impl, GetEncoderInfoFpsAllocationSimulcastVideo) { ::testing::ElementsAreArray(expected_fps_allocation)); } +class TestVp8ImplForPixelFormat + : public TestVp8Impl, + public ::testing::WithParamInterface { + public: + TestVp8ImplForPixelFormat() : TestVp8Impl(), mappable_type_(GetParam()) {} + + protected: + VideoFrameBuffer::Type mappable_type_; +}; + +TEST_P(TestVp8ImplForPixelFormat, EncodeNativeFrameSimulcast) { + EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Release()); + + // Configure simulcast. + codec_settings_.numberOfSimulcastStreams = 3; + codec_settings_.simulcastStream[0] = { + kWidth / 4, kHeight / 4, kFramerateFps, 1, 4000, 3000, 2000, 80, true}; + codec_settings_.simulcastStream[1] = { + kWidth / 2, kHeight / 2, kFramerateFps, 1, 4000, 3000, 2000, 80, true}; + codec_settings_.simulcastStream[2] = { + kWidth, kHeight, kFramerateFps, 1, 4000, 3000, 2000, 80, true}; + EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, + encoder_->InitEncode(&codec_settings_, kSettings)); + + // Create a zero-conversion NV12 frame (calling ToI420 on it crashes). + VideoFrame input_frame = + test::CreateMappableNativeFrame(1, mappable_type_, kWidth, kHeight); + + EncodedImage encoded_frame; + CodecSpecificInfo codec_specific_info; + EncodeAndWaitForFrame(input_frame, &encoded_frame, &codec_specific_info); + + // After encoding, we expect one mapping per simulcast layer. 
+ rtc::scoped_refptr mappable_buffer = + test::GetMappableNativeBufferFromVideoFrame(input_frame); + std::vector> mapped_buffers = + mappable_buffer->GetMappedFramedBuffers(); + ASSERT_EQ(mapped_buffers.size(), 3u); + EXPECT_EQ(mapped_buffers[0]->type(), mappable_type_); + EXPECT_EQ(mapped_buffers[0]->width(), kWidth); + EXPECT_EQ(mapped_buffers[0]->height(), kHeight); + EXPECT_EQ(mapped_buffers[1]->type(), mappable_type_); + EXPECT_EQ(mapped_buffers[1]->width(), kWidth / 2); + EXPECT_EQ(mapped_buffers[1]->height(), kHeight / 2); + EXPECT_EQ(mapped_buffers[2]->type(), mappable_type_); + EXPECT_EQ(mapped_buffers[2]->width(), kWidth / 4); + EXPECT_EQ(mapped_buffers[2]->height(), kHeight / 4); + EXPECT_FALSE(mappable_buffer->DidConvertToI420()); + + EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Release()); +} + +INSTANTIATE_TEST_SUITE_P(All, + TestVp8ImplForPixelFormat, + ::testing::Values(VideoFrameBuffer::Type::kI420, + VideoFrameBuffer::Type::kNV12)); + } // namespace webrtc diff --git a/modules/video_coding/codecs/vp9/include/vp9_globals.h b/modules/video_coding/codecs/vp9/include/vp9_globals.h index 6f9d09933f..34aa0bc6cf 100644 --- a/modules/video_coding/codecs/vp9/include/vp9_globals.h +++ b/modules/video_coding/codecs/vp9/include/vp9_globals.h @@ -18,6 +18,7 @@ #include #include "modules/video_coding/codecs/interface/common_constants.h" +#include "rtc_base/checks.h" namespace webrtc { @@ -131,7 +132,7 @@ struct GofInfoVP9 { pid_diff[7][1] = 2; break; default: - assert(false); + RTC_NOTREACHED(); } } diff --git a/modules/video_coding/codecs/vp9/libvpx_vp9_decoder.cc b/modules/video_coding/codecs/vp9/libvpx_vp9_decoder.cc index 0a99c6a46e..3500ef5919 100644 --- a/modules/video_coding/codecs/vp9/libvpx_vp9_decoder.cc +++ b/modules/video_coding/codecs/vp9/libvpx_vp9_decoder.cc @@ -22,7 +22,6 @@ #include "common_video/include/video_frame_buffer.h" #include "modules/video_coding/utility/vp9_uncompressed_header_parser.h" #include "rtc_base/checks.h" -#include "rtc_base/keep_ref_until_done.h" #include "rtc_base/logging.h" #include "third_party/libyuv/include/libyuv/convert.h" #include "vpx/vp8dx.h" @@ -277,7 +276,7 @@ int LibvpxVp9Decoder::ReturnFrame( // This buffer contains all of |img|'s image data, a reference counted // Vp9FrameBuffer. (libvpx is done with the buffers after a few // vpx_codec_decode calls or vpx_codec_destroy). - Vp9FrameBufferPool::Vp9FrameBuffer* img_buffer = + rtc::scoped_refptr img_buffer = static_cast(img->fb_priv); // The buffer can be used directly by the VideoFrame (without copy) by @@ -312,7 +311,7 @@ int LibvpxVp9Decoder::ReturnFrame( // WrappedI420Buffer's mechanism for allowing the release of its // frame buffer is through a callback function. This is where we // should release |img_buffer|. - rtc::KeepRefUntilDone(img_buffer)); + [img_buffer] {}); } } else if (img->fmt == VPX_IMG_FMT_I444) { img_wrapped_buffer = WrapI444Buffer( @@ -323,7 +322,7 @@ int LibvpxVp9Decoder::ReturnFrame( // WrappedI444Buffer's mechanism for allowing the release of its // frame buffer is through a callback function. This is where we // should release |img_buffer|. 
- rtc::KeepRefUntilDone(img_buffer)); + [img_buffer] {}); } else { RTC_LOG(LS_ERROR) << "Unsupported pixel format produced by the decoder: " @@ -339,7 +338,7 @@ int LibvpxVp9Decoder::ReturnFrame( reinterpret_cast(img->planes[VPX_PLANE_U]), img->stride[VPX_PLANE_U] / 2, reinterpret_cast(img->planes[VPX_PLANE_V]), - img->stride[VPX_PLANE_V] / 2, rtc::KeepRefUntilDone(img_buffer)); + img->stride[VPX_PLANE_V] / 2, [img_buffer] {}); break; default: RTC_LOG(LS_ERROR) << "Unsupported bit depth produced by the decoder: " diff --git a/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.cc b/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.cc index 81223019fd..511e6df585 100644 --- a/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.cc +++ b/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "api/video/color_space.h" @@ -1040,37 +1041,17 @@ int LibvpxVp9Encoder::Encode(const VideoFrame& input_image, // doing this. input_image_ = &input_image; - // Keep reference to buffer until encode completes. - rtc::scoped_refptr video_frame_buffer; + // In case we need to map the buffer, |mapped_buffer| is used to keep it alive + // through reference counting until after encoding has finished. + rtc::scoped_refptr mapped_buffer; const I010BufferInterface* i010_buffer; rtc::scoped_refptr i010_copy; switch (profile_) { case VP9Profile::kProfile0: { - if (input_image.video_frame_buffer()->type() == - VideoFrameBuffer::Type::kNV12) { - const NV12BufferInterface* nv12_buffer = - input_image.video_frame_buffer()->GetNV12(); - video_frame_buffer = nv12_buffer; - MaybeRewrapRawWithFormat(VPX_IMG_FMT_NV12); - raw_->planes[VPX_PLANE_Y] = const_cast(nv12_buffer->DataY()); - raw_->planes[VPX_PLANE_U] = const_cast(nv12_buffer->DataUV()); - raw_->planes[VPX_PLANE_V] = raw_->planes[VPX_PLANE_U] + 1; - raw_->stride[VPX_PLANE_Y] = nv12_buffer->StrideY(); - raw_->stride[VPX_PLANE_U] = nv12_buffer->StrideUV(); - raw_->stride[VPX_PLANE_V] = nv12_buffer->StrideUV(); - } else { - rtc::scoped_refptr i420_buffer = - input_image.video_frame_buffer()->ToI420(); - video_frame_buffer = i420_buffer; - MaybeRewrapRawWithFormat(VPX_IMG_FMT_I420); - // Image in vpx_image_t format. - // Input image is const. VPX's raw image is not defined as const. - raw_->planes[VPX_PLANE_Y] = const_cast(i420_buffer->DataY()); - raw_->planes[VPX_PLANE_U] = const_cast(i420_buffer->DataU()); - raw_->planes[VPX_PLANE_V] = const_cast(i420_buffer->DataV()); - raw_->stride[VPX_PLANE_Y] = i420_buffer->StrideY(); - raw_->stride[VPX_PLANE_U] = i420_buffer->StrideU(); - raw_->stride[VPX_PLANE_V] = i420_buffer->StrideV(); + mapped_buffer = + PrepareBufferForProfile0(input_image.video_frame_buffer()); + if (!mapped_buffer) { + return WEBRTC_VIDEO_CODEC_ERROR; } break; } @@ -1087,8 +1068,15 @@ int LibvpxVp9Encoder::Encode(const VideoFrame& input_image, break; } default: { - i010_copy = - I010Buffer::Copy(*input_image.video_frame_buffer()->ToI420()); + auto i420_buffer = input_image.video_frame_buffer()->ToI420(); + if (!i420_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString( + input_image.video_frame_buffer()->type()) + << " image to I420. 
Can't encode frame."; + return WEBRTC_VIDEO_CODEC_ERROR; + } + i010_copy = I010Buffer::Copy(*i420_buffer); i010_buffer = i010_copy.get(); } } @@ -1167,7 +1155,7 @@ int LibvpxVp9Encoder::Encode(const VideoFrame& input_image, return WEBRTC_VIDEO_CODEC_OK; } -void LibvpxVp9Encoder::PopulateCodecSpecific(CodecSpecificInfo* codec_specific, +bool LibvpxVp9Encoder::PopulateCodecSpecific(CodecSpecificInfo* codec_specific, absl::optional* spatial_idx, const vpx_codec_cx_pkt& pkt, uint32_t timestamp) { @@ -1287,10 +1275,15 @@ void LibvpxVp9Encoder::PopulateCodecSpecific(CodecSpecificInfo* codec_specific, auto it = absl::c_find_if( layer_frames_, [&](const ScalableVideoController::LayerFrameConfig& config) { - return config.SpatialId() == spatial_idx->value_or(0); + return config.SpatialId() == layer_id.spatial_layer_id; }); - RTC_CHECK(it != layer_frames_.end()) - << "Failed to find spatial id " << spatial_idx->value_or(0); + if (it == layer_frames_.end()) { + RTC_LOG(LS_ERROR) << "Encoder produced a frame for layer S" + << layer_id.spatial_layer_id << "T" + << layer_id.temporal_layer_id + << " that wasn't requested."; + return false; + } codec_specific->generic_frame_info = svc_controller_->OnEncodeDone(*it); if (is_key_frame) { codec_specific->template_structure = @@ -1306,6 +1299,7 @@ void LibvpxVp9Encoder::PopulateCodecSpecific(CodecSpecificInfo* codec_specific, } } } + return true; } void LibvpxVp9Encoder::FillReferenceIndices(const vpx_codec_cx_pkt& pkt, @@ -1563,12 +1557,12 @@ vpx_svc_ref_frame_config_t LibvpxVp9Encoder::SetReferences( return ref_config; } -int LibvpxVp9Encoder::GetEncodedLayerFrame(const vpx_codec_cx_pkt* pkt) { +void LibvpxVp9Encoder::GetEncodedLayerFrame(const vpx_codec_cx_pkt* pkt) { RTC_DCHECK_EQ(pkt->kind, VPX_CODEC_CX_FRAME_PKT); if (pkt->data.frame.sz == 0) { // Ignore dropped frame. - return WEBRTC_VIDEO_CODEC_OK; + return; } vpx_svc_layer_id_t layer_id = {0}; @@ -1599,8 +1593,12 @@ int LibvpxVp9Encoder::GetEncodedLayerFrame(const vpx_codec_cx_pkt* pkt) { codec_specific_ = {}; absl::optional spatial_index; - PopulateCodecSpecific(&codec_specific_, &spatial_index, *pkt, - input_image_->timestamp()); + if (!PopulateCodecSpecific(&codec_specific_, &spatial_index, *pkt, + input_image_->timestamp())) { + // Drop the frame. + encoded_image_.set_size(0); + return; + } encoded_image_.SetSpatialIndex(spatial_index); UpdateReferenceBuffers(*pkt, pics_since_key_); @@ -1620,8 +1618,6 @@ int LibvpxVp9Encoder::GetEncodedLayerFrame(const vpx_codec_cx_pkt* pkt) { num_active_spatial_layers_; DeliverBufferedFrame(end_of_picture); } - - return WEBRTC_VIDEO_CODEC_OK; } void LibvpxVp9Encoder::DeliverBufferedFrame(bool end_of_picture) { @@ -1718,6 +1714,10 @@ VideoEncoder::EncoderInfo LibvpxVp9Encoder::GetEncoderInfo() const { VideoFrameBuffer::Type::kNV12}; } } + if (!encoder_info_override_.resolution_bitrate_limits().empty()) { + info.resolution_bitrate_limits = + encoder_info_override_.resolution_bitrate_limits(); + } return info; } @@ -1880,6 +1880,87 @@ void LibvpxVp9Encoder::MaybeRewrapRawWithFormat(const vpx_img_fmt fmt) { // else no-op since the image is already in the right format. } +rtc::scoped_refptr LibvpxVp9Encoder::PrepareBufferForProfile0( + rtc::scoped_refptr buffer) { + absl::InlinedVector + supported_formats = {VideoFrameBuffer::Type::kI420, + VideoFrameBuffer::Type::kNV12}; + + rtc::scoped_refptr mapped_buffer; + if (buffer->type() != VideoFrameBuffer::Type::kNative) { + // |buffer| is already mapped. 
+ mapped_buffer = buffer; + } else { + // Attempt to map to one of the supported formats. + mapped_buffer = buffer->GetMappedFrameBuffer(supported_formats); + } + if (!mapped_buffer || + (absl::c_find(supported_formats, mapped_buffer->type()) == + supported_formats.end() && + mapped_buffer->type() != VideoFrameBuffer::Type::kI420A)) { + // Unknown pixel format or unable to map, convert to I420 and prepare that + // buffer instead to ensure Scale() is safe to use. + auto converted_buffer = buffer->ToI420(); + if (!converted_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString(buffer->type()) + << " image to I420. Can't encode frame."; + return {}; + } + // The buffer should now be a mapped I420 or I420A format, but some buffer + // implementations incorrectly return the wrong buffer format, such as + // kNative. As a workaround to this, we perform ToI420() a second time. + // TODO(https://crbug.com/webrtc/12602): When Android buffers have a correct + // ToI420() implementaion, remove his workaround. + if (converted_buffer->type() != VideoFrameBuffer::Type::kI420 && + converted_buffer->type() != VideoFrameBuffer::Type::kI420A) { + converted_buffer = converted_buffer->ToI420(); + if (!converted_buffer) { + RTC_LOG(LS_ERROR) << "Failed to convert " + << VideoFrameBufferTypeToString(buffer->type()) + << " image to I420. Can't encode frame."; + return {}; + } + RTC_CHECK(converted_buffer->type() == VideoFrameBuffer::Type::kI420 || + converted_buffer->type() == VideoFrameBuffer::Type::kI420A); + } + // Because |buffer| had to be converted, use |converted_buffer| instead. + buffer = mapped_buffer = converted_buffer; + } + + // Prepare |raw_| from |mapped_buffer|. + switch (mapped_buffer->type()) { + case VideoFrameBuffer::Type::kI420: + case VideoFrameBuffer::Type::kI420A: { + MaybeRewrapRawWithFormat(VPX_IMG_FMT_I420); + const I420BufferInterface* i420_buffer = mapped_buffer->GetI420(); + RTC_DCHECK(i420_buffer); + raw_->planes[VPX_PLANE_Y] = const_cast(i420_buffer->DataY()); + raw_->planes[VPX_PLANE_U] = const_cast(i420_buffer->DataU()); + raw_->planes[VPX_PLANE_V] = const_cast(i420_buffer->DataV()); + raw_->stride[VPX_PLANE_Y] = i420_buffer->StrideY(); + raw_->stride[VPX_PLANE_U] = i420_buffer->StrideU(); + raw_->stride[VPX_PLANE_V] = i420_buffer->StrideV(); + break; + } + case VideoFrameBuffer::Type::kNV12: { + MaybeRewrapRawWithFormat(VPX_IMG_FMT_NV12); + const NV12BufferInterface* nv12_buffer = mapped_buffer->GetNV12(); + RTC_DCHECK(nv12_buffer); + raw_->planes[VPX_PLANE_Y] = const_cast(nv12_buffer->DataY()); + raw_->planes[VPX_PLANE_U] = const_cast(nv12_buffer->DataUV()); + raw_->planes[VPX_PLANE_V] = raw_->planes[VPX_PLANE_U] + 1; + raw_->stride[VPX_PLANE_Y] = nv12_buffer->StrideY(); + raw_->stride[VPX_PLANE_U] = nv12_buffer->StrideUV(); + raw_->stride[VPX_PLANE_V] = nv12_buffer->StrideUV(); + break; + } + default: + RTC_NOTREACHED(); + } + return mapped_buffer; +} + } // namespace webrtc #endif // RTC_ENABLE_VP9 diff --git a/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.h b/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.h index 037c760c17..954c044c2c 100644 --- a/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.h +++ b/modules/video_coding/codecs/vp9/libvpx_vp9_encoder.h @@ -21,13 +21,14 @@ #include "api/fec_controller_override.h" #include "api/transport/webrtc_key_value_config.h" #include "api/video_codecs/video_encoder.h" +#include "api/video_codecs/vp9_profile.h" #include "common_video/include/video_frame_buffer_pool.h" -#include 
"media/base/vp9_profile.h" #include "modules/video_coding/codecs/interface/libvpx_interface.h" #include "modules/video_coding/codecs/vp9/include/vp9.h" #include "modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.h" #include "modules/video_coding/svc/scalable_video_controller.h" #include "modules/video_coding/utility/framerate_controller.h" +#include "rtc_base/experiments/encoder_info_settings.h" #include "vpx/vp8cx.h" namespace webrtc { @@ -64,7 +65,7 @@ class LibvpxVp9Encoder : public VP9Encoder { // Call encoder initialize function and set control settings. int InitAndSetControlSettings(const VideoCodec* inst); - void PopulateCodecSpecific(CodecSpecificInfo* codec_specific, + bool PopulateCodecSpecific(CodecSpecificInfo* codec_specific, absl::optional* spatial_idx, const vpx_codec_cx_pkt& pkt, uint32_t timestamp); @@ -81,7 +82,7 @@ class LibvpxVp9Encoder : public VP9Encoder { bool ExplicitlyConfiguredSpatialLayers() const; bool SetSvcRates(const VideoBitrateAllocation& bitrate_allocation); - virtual int GetEncodedLayerFrame(const vpx_codec_cx_pkt* pkt); + void GetEncodedLayerFrame(const vpx_codec_cx_pkt* pkt); // Callback function for outputting packets per spatial layer. static void EncoderOutputCodedPacketCallback(vpx_codec_cx_pkt* pkt, @@ -102,6 +103,12 @@ class LibvpxVp9Encoder : public VP9Encoder { size_t SteadyStateSize(int sid, int tid); void MaybeRewrapRawWithFormat(const vpx_img_fmt fmt); + // Prepares |raw_| to reference image data of |buffer|, or of mapped or scaled + // versions of |buffer|. Returns the buffer that got referenced as a result, + // allowing the caller to keep a reference to it until after encoding has + // finished. On failure to convert the buffer, null is returned. + rtc::scoped_refptr PrepareBufferForProfile0( + rtc::scoped_refptr buffer); const std::unique_ptr libvpx_; EncodedImage encoded_image_; @@ -230,6 +237,8 @@ class LibvpxVp9Encoder : public VP9Encoder { int num_steady_state_frames_; // Only set config when this flag is set. bool config_changed_; + + const LibvpxVp9EncoderInfoSettings encoder_info_override_; }; } // namespace webrtc diff --git a/modules/video_coding/codecs/vp9/test/vp9_impl_unittest.cc b/modules/video_coding/codecs/vp9/test/vp9_impl_unittest.cc index 3d658838ed..e96538427b 100644 --- a/modules/video_coding/codecs/vp9/test/vp9_impl_unittest.cc +++ b/modules/video_coding/codecs/vp9/test/vp9_impl_unittest.cc @@ -15,8 +15,8 @@ #include "api/video/color_space.h" #include "api/video/i420_buffer.h" #include "api/video_codecs/video_encoder.h" +#include "api/video_codecs/vp9_profile.h" #include "common_video/libyuv/include/webrtc_libyuv.h" -#include "media/base/vp9_profile.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "modules/video_coding/codecs/interface/libvpx_interface.h" #include "modules/video_coding/codecs/interface/mock_libvpx_interface.h" @@ -30,6 +30,7 @@ #include "test/field_trial.h" #include "test/gmock.h" #include "test/gtest.h" +#include "test/mappable_native_buffer.h" #include "test/video_codec_settings.h" namespace webrtc { @@ -158,6 +159,31 @@ TEST_P(TestVp9ImplForPixelFormat, EncodeDecode) { color_space.chroma_siting_vertical()); } +TEST_P(TestVp9ImplForPixelFormat, EncodeNativeBuffer) { + VideoFrame input_frame = NextInputFrame(); + // Replace the input frame with a fake native buffer of the same size and + // underlying pixel format. Do not allow ToI420() for non-I420 buffers, + // ensuring zero-conversion. 
+ input_frame = test::CreateMappableNativeFrame( + input_frame.ntp_time_ms(), input_frame.video_frame_buffer()->type(), + input_frame.width(), input_frame.height()); + EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Encode(input_frame, nullptr)); + EncodedImage encoded_frame; + CodecSpecificInfo codec_specific_info; + ASSERT_TRUE(WaitForEncodedFrame(&encoded_frame, &codec_specific_info)); + + // After encoding, we would expect a single mapping to have happened. + rtc::scoped_refptr mappable_buffer = + test::GetMappableNativeBufferFromVideoFrame(input_frame); + std::vector> mapped_buffers = + mappable_buffer->GetMappedFramedBuffers(); + ASSERT_EQ(mapped_buffers.size(), 1u); + EXPECT_EQ(mapped_buffers[0]->type(), mappable_buffer->mappable_type()); + EXPECT_EQ(mapped_buffers[0]->width(), input_frame.width()); + EXPECT_EQ(mapped_buffers[0]->height(), input_frame.height()); + EXPECT_FALSE(mappable_buffer->DidConvertToI420()); +} + TEST_P(TestVp9ImplForPixelFormat, DecodedColorSpaceFromBitstream) { EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Encode(NextInputFrame(), nullptr)); EncodedImage encoded_frame; @@ -522,6 +548,62 @@ TEST(Vp9ImplTest, EnableDisableSpatialLayersWithSvcController) { } } +MATCHER_P2(GenericLayerIs, spatial_id, temporal_id, "") { + if (arg.codec_specific_info.generic_frame_info == absl::nullopt) { + *result_listener << " miss generic_frame_info"; + return false; + } + const auto& layer = *arg.codec_specific_info.generic_frame_info; + if (layer.spatial_id != spatial_id || layer.temporal_id != temporal_id) { + *result_listener << " frame from layer (" << layer.spatial_id << ", " + << layer.temporal_id << ")"; + return false; + } + return true; +} + +TEST(Vp9ImplTest, SpatialUpswitchNotAtGOFBoundary) { + test::ScopedFieldTrials override_field_trials( + "WebRTC-Vp9DependencyDescriptor/Enabled/"); + std::unique_ptr encoder = VP9Encoder::Create(); + VideoCodec codec_settings = DefaultCodecSettings(); + ConfigureSvc(codec_settings, /*num_spatial_layers=*/3, + /*num_temporal_layers=*/3); + codec_settings.VP9()->frameDroppingOn = true; + EXPECT_EQ(encoder->InitEncode(&codec_settings, kSettings), + WEBRTC_VIDEO_CODEC_OK); + + EncodedVideoFrameProducer producer(*encoder); + producer.SetResolution({kWidth, kHeight}); + + // Disable all but spatial_layer = 0; + VideoBitrateAllocation bitrate_allocation; + int layer_bitrate_bps = codec_settings.spatialLayers[0].targetBitrate * 1000; + bitrate_allocation.SetBitrate(0, 0, layer_bitrate_bps); + bitrate_allocation.SetBitrate(0, 1, layer_bitrate_bps); + bitrate_allocation.SetBitrate(0, 2, layer_bitrate_bps); + encoder->SetRates(VideoEncoder::RateControlParameters( + bitrate_allocation, codec_settings.maxFramerate)); + EXPECT_THAT(producer.SetNumInputFrames(3).Encode(), + ElementsAre(GenericLayerIs(0, 0), GenericLayerIs(0, 2), + GenericLayerIs(0, 1))); + + // Upswitch to spatial_layer = 1 + layer_bitrate_bps = codec_settings.spatialLayers[1].targetBitrate * 1000; + bitrate_allocation.SetBitrate(1, 0, layer_bitrate_bps); + bitrate_allocation.SetBitrate(1, 1, layer_bitrate_bps); + bitrate_allocation.SetBitrate(1, 2, layer_bitrate_bps); + encoder->SetRates(VideoEncoder::RateControlParameters( + bitrate_allocation, codec_settings.maxFramerate)); + // Expect upswitch doesn't happen immediately since there is no S1 frame that + // S1T2 frame can reference. + EXPECT_THAT(producer.SetNumInputFrames(1).Encode(), + ElementsAre(GenericLayerIs(0, 2))); + // Expect spatial upswitch happens now, at T0 frame. 
+ EXPECT_THAT(producer.SetNumInputFrames(1).Encode(), + ElementsAre(GenericLayerIs(0, 0), GenericLayerIs(1, 0))); +} + TEST_F(TestVp9Impl, DisableEnableBaseLayerTriggersKeyFrame) { // Configure encoder to produce N spatial layers. Encode frames for all // layers. Then disable all but the last layer. Then reenable all back again. @@ -1636,6 +1718,27 @@ TEST_F(TestVp9Impl, Profile0PreferredPixelFormats) { VideoFrameBuffer::Type::kI420)); } +TEST_F(TestVp9Impl, EncoderInfoWithoutResolutionBitrateLimits) { + EXPECT_TRUE(encoder_->GetEncoderInfo().resolution_bitrate_limits.empty()); +} + +TEST_F(TestVp9Impl, EncoderInfoWithBitrateLimitsFromFieldTrial) { + test::ScopedFieldTrials field_trials( + "WebRTC-VP9-GetEncoderInfoOverride/" + "frame_size_pixels:123|456|789," + "min_start_bitrate_bps:11000|22000|33000," + "min_bitrate_bps:44000|55000|66000," + "max_bitrate_bps:77000|88000|99000/"); + SetUp(); + + EXPECT_THAT( + encoder_->GetEncoderInfo().resolution_bitrate_limits, + ::testing::ElementsAre( + VideoEncoder::ResolutionBitrateLimits{123, 11000, 44000, 77000}, + VideoEncoder::ResolutionBitrateLimits{456, 22000, 55000, 88000}, + VideoEncoder::ResolutionBitrateLimits{789, 33000, 66000, 99000})); +} + TEST_F(TestVp9Impl, EncoderInfoFpsAllocation) { const uint8_t kNumSpatialLayers = 3; const uint8_t kNumTemporalLayers = 3; @@ -2231,7 +2334,7 @@ TEST(Vp9SpeedSettingsTrialsTest, } TEST(Vp9SpeedSettingsTrialsTest, PerLayerFlagsWithSvc) { - // Per-temporal and satial layer speed settings: + // Per-temporal and spatial layer speed settings: // SL0: TL0 = speed 5, TL1/TL2 = speed 8. // SL1/2: TL0 = speed 7, TL1/TL2 = speed 9. // Deblocking-mode per spatial layer: diff --git a/modules/video_coding/codecs/vp9/vp9.cc b/modules/video_coding/codecs/vp9/vp9.cc index 1efb1b4f9f..d9caf0f039 100644 --- a/modules/video_coding/codecs/vp9/vp9.cc +++ b/modules/video_coding/codecs/vp9/vp9.cc @@ -14,7 +14,7 @@ #include "api/transport/field_trial_based_config.h" #include "api/video_codecs/sdp_video_format.h" -#include "media/base/vp9_profile.h" +#include "api/video_codecs/vp9_profile.h" #include "modules/video_coding/codecs/vp9/libvpx_vp9_decoder.h" #include "modules/video_coding/codecs/vp9/libvpx_vp9_encoder.h" #include "rtc_base/checks.h" diff --git a/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.cc b/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.cc index 4d0a6983ac..d1f58b1bb8 100644 --- a/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.cc +++ b/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.cc @@ -15,7 +15,6 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" -#include "rtc_base/ref_counted_object.h" #include "vpx/vpx_codec.h" #include "vpx/vpx_decoder.h" #include "vpx/vpx_frame_buffer.h" @@ -68,7 +67,7 @@ Vp9FrameBufferPool::GetFrameBuffer(size_t min_size) { } // Otherwise create one. 
if (available_buffer == nullptr) { - available_buffer = new rtc::RefCountedObject(); + available_buffer = new Vp9FrameBuffer(); allocated_buffers_.push_back(available_buffer); if (allocated_buffers_.size() > max_num_buffers_) { RTC_LOG(LS_WARNING) diff --git a/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.h b/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.h index d37a9fc0e2..bce10be4d9 100644 --- a/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.h +++ b/modules/video_coding/codecs/vp9/vp9_frame_buffer_pool.h @@ -16,9 +16,9 @@ #include +#include "api/ref_counted_base.h" #include "api/scoped_refptr.h" #include "rtc_base/buffer.h" -#include "rtc_base/ref_count.h" #include "rtc_base/synchronization/mutex.h" struct vpx_codec_ctx; @@ -65,13 +65,14 @@ constexpr size_t kDefaultMaxNumBuffers = 68; // vpx_codec_destroy(decoder_ctx); class Vp9FrameBufferPool { public: - class Vp9FrameBuffer : public rtc::RefCountInterface { + class Vp9FrameBuffer final + : public rtc::RefCountedNonVirtual { public: uint8_t* GetData(); size_t GetDataSize() const; void SetSize(size_t size); - virtual bool HasOneRef() const = 0; + using rtc::RefCountedNonVirtual::HasOneRef; private: // Data as an easily resizable buffer. diff --git a/modules/video_coding/decoder_database.cc b/modules/video_coding/decoder_database.cc index 594ca86553..6aa332eb88 100644 --- a/modules/video_coding/decoder_database.cc +++ b/modules/video_coding/decoder_database.cc @@ -56,7 +56,6 @@ bool VCMDecoderDataBase::DeregisterExternalDecoder(uint8_t payload_type) { // Release it if it was registered and in use. ptr_decoder_.reset(); } - DeregisterReceiveCodec(payload_type); delete it->second; dec_external_map_.erase(it); return true; @@ -73,6 +72,12 @@ void VCMDecoderDataBase::RegisterExternalDecoder(VideoDecoder* external_decoder, dec_external_map_[payload_type] = ext_decoder; } +bool VCMDecoderDataBase::IsExternalDecoderRegistered( + uint8_t payload_type) const { + return payload_type == current_payload_type_ || + FindExternalDecoderItem(payload_type); +} + bool VCMDecoderDataBase::RegisterReceiveCodec(uint8_t payload_type, const VideoCodec* receive_codec, int number_of_cores) { diff --git a/modules/video_coding/decoder_database.h b/modules/video_coding/decoder_database.h index f7c5d70338..81c68e4138 100644 --- a/modules/video_coding/decoder_database.h +++ b/modules/video_coding/decoder_database.h @@ -44,6 +44,7 @@ class VCMDecoderDataBase { bool DeregisterExternalDecoder(uint8_t payload_type); void RegisterExternalDecoder(VideoDecoder* external_decoder, uint8_t payload_type); + bool IsExternalDecoderRegistered(uint8_t payload_type) const; bool RegisterReceiveCodec(uint8_t payload_type, const VideoCodec* receive_codec, diff --git a/modules/video_coding/decoding_state.cc b/modules/video_coding/decoding_state.cc index a951358992..5e405cbd05 100644 --- a/modules/video_coding/decoding_state.cc +++ b/modules/video_coding/decoding_state.cc @@ -55,21 +55,22 @@ uint16_t VCMDecodingState::sequence_num() const { } bool VCMDecodingState::IsOldFrame(const VCMFrameBuffer* frame) const { - assert(frame != NULL); + RTC_DCHECK(frame); if (in_initial_state_) return false; return !IsNewerTimestamp(frame->Timestamp(), time_stamp_); } bool VCMDecodingState::IsOldPacket(const VCMPacket* packet) const { - assert(packet != NULL); + RTC_DCHECK(packet); if (in_initial_state_) return false; return !IsNewerTimestamp(packet->timestamp, time_stamp_); } void VCMDecodingState::SetState(const VCMFrameBuffer* frame) { - assert(frame != NULL && 
frame->GetHighSeqNum() >= 0); + RTC_DCHECK(frame); + RTC_CHECK_GE(frame->GetHighSeqNum(), 0); if (!UsingFlexibleMode(frame)) UpdateSyncState(frame); sequence_num_ = static_cast(frame->GetHighSeqNum()); @@ -150,7 +151,7 @@ bool VCMDecodingState::UpdateEmptyFrame(const VCMFrameBuffer* frame) { } void VCMDecodingState::UpdateOldPacket(const VCMPacket* packet) { - assert(packet != NULL); + RTC_DCHECK(packet); if (packet->timestamp == time_stamp_) { // Late packet belonging to the last decoded frame - make sure we update the // last decoded sequence number. @@ -204,7 +205,7 @@ bool VCMDecodingState::ContinuousFrame(const VCMFrameBuffer* frame) const { // - Sequence numbers. // Return true when in initial state. // Note that when a method is not applicable it will return false. - assert(frame != NULL); + RTC_DCHECK(frame); // A key frame is always considered continuous as it doesn't refer to any // frames and therefore won't introduce any errors even if prior frames are // missing. diff --git a/modules/video_coding/decoding_state.h b/modules/video_coding/decoding_state.h index b87fb2d034..ec972949d8 100644 --- a/modules/video_coding/decoding_state.h +++ b/modules/video_coding/decoding_state.h @@ -11,6 +11,7 @@ #ifndef MODULES_VIDEO_CODING_DECODING_STATE_H_ #define MODULES_VIDEO_CODING_DECODING_STATE_H_ +#include #include #include #include diff --git a/modules/video_coding/deprecated/BUILD.gn b/modules/video_coding/deprecated/BUILD.gn index fd3a5fa5fc..487c0267d5 100644 --- a/modules/video_coding/deprecated/BUILD.gn +++ b/modules/video_coding/deprecated/BUILD.gn @@ -21,7 +21,6 @@ rtc_library("nack_module") { "../../../api/units:timestamp", "../../../rtc_base:checks", "../../../rtc_base:criticalsection", - "../../../rtc_base:deprecation", "../../../rtc_base:logging", "../../../rtc_base:macromagic", "../../../rtc_base:rtc_numerics", @@ -31,4 +30,5 @@ rtc_library("nack_module") { "../../../system_wrappers:field_trial", "../../utility", ] + absl_deps = [ "//third_party/abseil-cpp/absl/base:core_headers" ] } diff --git a/modules/video_coding/deprecated/nack_module.h b/modules/video_coding/deprecated/nack_module.h index f9580ae80c..2fac6ce128 100644 --- a/modules/video_coding/deprecated/nack_module.h +++ b/modules/video_coding/deprecated/nack_module.h @@ -17,11 +17,11 @@ #include #include +#include "absl/base/attributes.h" #include "api/units/time_delta.h" #include "modules/include/module.h" #include "modules/include/module_common_types.h" #include "modules/video_coding/histogram.h" -#include "rtc_base/deprecation.h" #include "rtc_base/numerics/sequence_number_util.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" @@ -125,7 +125,7 @@ class DEPRECATED_NackModule : public Module { const absl::optional backoff_settings_; }; -using NackModule = RTC_DEPRECATED DEPRECATED_NackModule; +using NackModule ABSL_DEPRECATED("") = DEPRECATED_NackModule; } // namespace webrtc diff --git a/modules/video_coding/encoded_frame.cc b/modules/video_coding/encoded_frame.cc index f7d666bea4..637a20cfc9 100644 --- a/modules/video_coding/encoded_frame.cc +++ b/modules/video_coding/encoded_frame.cc @@ -136,6 +136,10 @@ void VCMEncodedFrame::CopyCodecSpecific(const RTPVideoHeader* header) { _codecSpecificInfo.codecType = kVideoCodecH264; break; } + case kVideoCodecAV1: { + _codecSpecificInfo.codecType = kVideoCodecAV1; + break; + } default: { _codecSpecificInfo.codecType = kVideoCodecGeneric; break; diff --git a/modules/video_coding/encoded_frame.h b/modules/video_coding/encoded_frame.h index 
61054ead35..9cc769277d 100644 --- a/modules/video_coding/encoded_frame.h +++ b/modules/video_coding/encoded_frame.h @@ -52,7 +52,6 @@ class RTC_EXPORT VCMEncodedFrame : public EncodedImage { using EncodedImage::GetEncodedData; using EncodedImage::NtpTimeMs; using EncodedImage::PacketInfos; - using EncodedImage::Retain; using EncodedImage::set_size; using EncodedImage::SetColorSpace; using EncodedImage::SetEncodedData; diff --git a/modules/video_coding/frame_buffer.cc b/modules/video_coding/frame_buffer.cc index 0f64ab1449..8f73e73bad 100644 --- a/modules/video_coding/frame_buffer.cc +++ b/modules/video_coding/frame_buffer.cc @@ -75,7 +75,7 @@ VCMFrameBufferEnum VCMFrameBuffer::InsertPacket(const VCMPacket& packet, int64_t timeInMs, const FrameData& frame_data) { TRACE_EVENT0("webrtc", "VCMFrameBuffer::InsertPacket"); - assert(!(NULL == packet.dataPtr && packet.sizeBytes > 0)); + RTC_DCHECK(!(NULL == packet.dataPtr && packet.sizeBytes > 0)); if (packet.dataPtr != NULL) { _payloadType = packet.payloadType; } @@ -230,19 +230,19 @@ void VCMFrameBuffer::SetState(VCMFrameBufferStateEnum state) { switch (state) { case kStateIncomplete: // we can go to this state from state kStateEmpty - assert(_state == kStateEmpty); + RTC_DCHECK_EQ(_state, kStateEmpty); // Do nothing, we received a packet break; case kStateComplete: - assert(_state == kStateEmpty || _state == kStateIncomplete); + RTC_DCHECK(_state == kStateEmpty || _state == kStateIncomplete); break; case kStateEmpty: // Should only be set to empty through Reset(). - assert(false); + RTC_NOTREACHED(); break; } _state = state; diff --git a/modules/video_coding/frame_buffer2.cc b/modules/video_coding/frame_buffer2.cc index c085557e5b..80f9eb1814 100644 --- a/modules/video_coding/frame_buffer2.cc +++ b/modules/video_coding/frame_buffer2.cc @@ -63,7 +63,11 @@ FrameBuffer::FrameBuffer(Clock* clock, last_log_non_decoded_ms_(-kLogNonDecodedIntervalMs), add_rtt_to_playout_delay_( webrtc::field_trial::IsEnabled("WebRTC-AddRttToPlayoutDelay")), - rtt_mult_settings_(RttMultExperiment::GetRttMultValue()) { + rtt_mult_settings_(RttMultExperiment::GetRttMultValue()), + zero_playout_delay_max_decode_queue_size_("max_decode_queue_size", + kMaxFramesBuffered) { + ParseFieldTrial({&zero_playout_delay_max_decode_queue_size_}, + field_trial::FindFullName("WebRTC-ZeroPlayoutDelay")); callback_checker_.Detach(); } @@ -110,6 +114,8 @@ void FrameBuffer::StartWaitForNextFrameOnQueue() { if (!frames_to_decode_.empty()) { // We have frames, deliver! 
frame = absl::WrapUnique(GetNextFrame()); + timing_->SetLastDecodeScheduledTimestamp( + clock_->TimeInMilliseconds()); } else if (clock_->TimeInMilliseconds() < latest_return_time_ms_) { // If there's no frames to decode and there is still time left, it // means that the frame buffer was cleared between creation and @@ -179,8 +185,7 @@ int64_t FrameBuffer::FindNextFrame(int64_t now_ms) { for (size_t i = 0; i < EncodedFrame::kMaxFrameReferences && i < next_frame_it->second.frame->num_references; ++i) { - if (next_frame_it->second.frame->references[i] >= - frame_it->first.picture_id) { + if (next_frame_it->second.frame->references[i] >= frame_it->first) { has_inter_layer_dependency = true; break; } @@ -211,7 +216,11 @@ int64_t FrameBuffer::FindNextFrame(int64_t now_ms) { if (frame->RenderTime() == -1) { frame->SetRenderTime(timing_->RenderTimeMs(frame->Timestamp(), now_ms)); } - wait_ms = timing_->MaxWaitingTime(frame->RenderTime(), now_ms); + bool too_many_frames_queued = + frames_.size() > zero_playout_delay_max_decode_queue_size_ ? true + : false; + wait_ms = timing_->MaxWaitingTime(frame->RenderTime(), now_ms, + too_many_frames_queued); // This will cause the frame buffer to prefer high framerate rather // than high resolution in the case of the decoder not decoding fast @@ -262,11 +271,11 @@ EncodedFrame* FrameBuffer::GetNextFrame() { // Remove decoded frame and all undecoded frames before it. if (stats_callback_) { - unsigned int dropped_frames = std::count_if( - frames_.begin(), frame_it, - [](const std::pair& frame) { - return frame.second.frame != nullptr; - }); + unsigned int dropped_frames = + std::count_if(frames_.begin(), frame_it, + [](const std::pair& frame) { + return frame.second.frame != nullptr; + }); if (dropped_frames > 0) { stats_callback_->OnDroppedFrames(dropped_frames); } @@ -371,7 +380,7 @@ void FrameBuffer::UpdateRtt(int64_t rtt_ms) { bool FrameBuffer::ValidReferences(const EncodedFrame& frame) const { for (size_t i = 0; i < frame.num_references; ++i) { - if (frame.references[i] >= frame.id.picture_id) + if (frame.references[i] >= frame.Id()) return false; for (size_t j = i + 1; j < frame.num_references; ++j) { @@ -397,82 +406,69 @@ int64_t FrameBuffer::InsertFrame(std::unique_ptr frame) { MutexLock lock(&mutex_); - const VideoLayerFrameId& id = frame->id; - int64_t last_continuous_picture_id = - !last_continuous_frame_ ? 
-1 : last_continuous_frame_->picture_id; + int64_t last_continuous_frame_id = last_continuous_frame_.value_or(-1); if (!ValidReferences(*frame)) { - RTC_LOG(LS_WARNING) << "Frame with (picture_id:spatial_id) (" - << id.picture_id << ":" - << static_cast(id.spatial_layer) - << ") has invalid frame references, dropping frame."; - return last_continuous_picture_id; + RTC_LOG(LS_WARNING) << "Frame " << frame->Id() + << " has invalid frame references, dropping frame."; + return last_continuous_frame_id; } if (frames_.size() >= kMaxFramesBuffered) { if (frame->is_keyframe()) { - RTC_LOG(LS_WARNING) << "Inserting keyframe (picture_id:spatial_id) (" - << id.picture_id << ":" - << static_cast(id.spatial_layer) - << ") but buffer is full, clearing" + RTC_LOG(LS_WARNING) << "Inserting keyframe " << frame->Id() + << " but buffer is full, clearing" " buffer and inserting the frame."; ClearFramesAndHistory(); } else { - RTC_LOG(LS_WARNING) << "Frame with (picture_id:spatial_id) (" - << id.picture_id << ":" - << static_cast(id.spatial_layer) - << ") could not be inserted due to the frame " + RTC_LOG(LS_WARNING) << "Frame " << frame->Id() + << " could not be inserted due to the frame " "buffer being full, dropping frame."; - return last_continuous_picture_id; + return last_continuous_frame_id; } } auto last_decoded_frame = decoded_frames_history_.GetLastDecodedFrameId(); auto last_decoded_frame_timestamp = decoded_frames_history_.GetLastDecodedFrameTimestamp(); - if (last_decoded_frame && id <= *last_decoded_frame) { + if (last_decoded_frame && frame->Id() <= *last_decoded_frame) { if (AheadOf(frame->Timestamp(), *last_decoded_frame_timestamp) && frame->is_keyframe()) { - // If this frame has a newer timestamp but an earlier picture id then we - // assume there has been a jump in the picture id due to some encoder + // If this frame has a newer timestamp but an earlier frame id then we + // assume there has been a jump in the frame id due to some encoder // reconfiguration or some other reason. Even though this is not according // to spec we can still continue to decode from this frame if it is a // keyframe. RTC_LOG(LS_WARNING) - << "A jump in picture id was detected, clearing buffer."; + << "A jump in frame id was detected, clearing buffer."; ClearFramesAndHistory(); - last_continuous_picture_id = -1; + last_continuous_frame_id = -1; } else { - RTC_LOG(LS_WARNING) << "Frame with (picture_id:spatial_id) (" - << id.picture_id << ":" - << static_cast(id.spatial_layer) - << ") inserted after frame (" - << last_decoded_frame->picture_id << ":" - << static_cast(last_decoded_frame->spatial_layer) - << ") was handed off for decoding, dropping frame."; - return last_continuous_picture_id; + RTC_LOG(LS_WARNING) << "Frame " << frame->Id() << " inserted after frame " + << *last_decoded_frame + << " was handed off for decoding, dropping frame."; + return last_continuous_frame_id; } } // Test if inserting this frame would cause the order of the frames to become // ambiguous (covering more than half the interval of 2^16). This can happen - // when the picture id make large jumps mid stream. - if (!frames_.empty() && id < frames_.begin()->first && - frames_.rbegin()->first < id) { - RTC_LOG(LS_WARNING) - << "A jump in picture id was detected, clearing buffer."; + // when the frame id make large jumps mid stream. 
+ if (!frames_.empty() && frame->Id() < frames_.begin()->first && + frames_.rbegin()->first < frame->Id()) { + RTC_LOG(LS_WARNING) << "A jump in frame id was detected, clearing buffer."; ClearFramesAndHistory(); - last_continuous_picture_id = -1; + last_continuous_frame_id = -1; } - auto info = frames_.emplace(id, FrameInfo()).first; + auto info = frames_.emplace(frame->Id(), FrameInfo()).first; if (info->second.frame) { - return last_continuous_picture_id; + return last_continuous_frame_id; } if (!UpdateFrameInfoWithIncomingFrame(*frame, info)) - return last_continuous_picture_id; + return last_continuous_frame_id; if (!frame->delayed_by_retransmission()) timing_->IncomingTimestamp(frame->Timestamp(), frame->ReceivedTime()); @@ -489,7 +485,7 @@ int64_t FrameBuffer::InsertFrame(std::unique_ptr frame) { if (info->second.num_missing_continuous == 0) { info->second.continuous = true; PropagateContinuity(info); - last_continuous_picture_id = last_continuous_frame_->picture_id; + last_continuous_frame_id = *last_continuous_frame_; // Since we now have new continuous frames there might be a better frame // to return from NextFrame. @@ -505,7 +501,7 @@ int64_t FrameBuffer::InsertFrame(std::unique_ptr frame) { } } - return last_continuous_picture_id; + return last_continuous_frame_id; } void FrameBuffer::PropagateContinuity(FrameMap::iterator start) { @@ -558,8 +554,6 @@ void FrameBuffer::PropagateDecodability(const FrameInfo& info) { bool FrameBuffer::UpdateFrameInfoWithIncomingFrame(const EncodedFrame& frame, FrameMap::iterator info) { TRACE_EVENT0("webrtc", "FrameBuffer::UpdateFrameInfoWithIncomingFrame"); - const VideoLayerFrameId& id = frame.id; - auto last_decoded_frame = decoded_frames_history_.GetLastDecodedFrameId(); RTC_DCHECK(!last_decoded_frame || *last_decoded_frame < info->first); @@ -572,35 +566,34 @@ bool FrameBuffer::UpdateFrameInfoWithIncomingFrame(const EncodedFrame& frame, // so that |num_missing_continuous| and |num_missing_decodable| can be // decremented as frames become continuous/are decoded. struct Dependency { - VideoLayerFrameId id; + int64_t frame_id; bool continuous; }; std::vector not_yet_fulfilled_dependencies; // Find all dependencies that have not yet been fulfilled. for (size_t i = 0; i < frame.num_references; ++i) { - VideoLayerFrameId ref_key(frame.references[i], frame.id.spatial_layer); // Does |frame| depend on a frame earlier than the last decoded one? - if (last_decoded_frame && ref_key <= *last_decoded_frame) { + if (last_decoded_frame && frame.references[i] <= *last_decoded_frame) { // Was that frame decoded? If not, this |frame| will never become // decodable. 
- if (!decoded_frames_history_.WasDecoded(ref_key)) { + if (!decoded_frames_history_.WasDecoded(frame.references[i])) { int64_t now_ms = clock_->TimeInMilliseconds(); if (last_log_non_decoded_ms_ + kLogNonDecodedIntervalMs < now_ms) { RTC_LOG(LS_WARNING) - << "Frame with (picture_id:spatial_id) (" << id.picture_id << ":" - << static_cast(id.spatial_layer) - << ") depends on a non-decoded frame more previous than" - " the last decoded frame, dropping frame."; + << "Frame " << frame.Id() + << " depends on a non-decoded frame more previous than the last " + "decoded frame, dropping frame."; last_log_non_decoded_ms_ = now_ms; } return false; } } else { - auto ref_info = frames_.find(ref_key); + auto ref_info = frames_.find(frame.references[i]); bool ref_continuous = ref_info != frames_.end() && ref_info->second.continuous; - not_yet_fulfilled_dependencies.push_back({ref_key, ref_continuous}); + not_yet_fulfilled_dependencies.push_back( + {frame.references[i], ref_continuous}); } } @@ -611,7 +604,7 @@ bool FrameBuffer::UpdateFrameInfoWithIncomingFrame(const EncodedFrame& frame, if (dep.continuous) --info->second.num_missing_continuous; - frames_[dep.id].dependent_frames.push_back(id); + frames_[dep.frame_id].dependent_frames.push_back(frame.Id()); } return true; @@ -647,11 +640,11 @@ void FrameBuffer::UpdateTimingFrameInfo() { void FrameBuffer::ClearFramesAndHistory() { TRACE_EVENT0("webrtc", "FrameBuffer::ClearFramesAndHistory"); if (stats_callback_) { - unsigned int dropped_frames = std::count_if( - frames_.begin(), frames_.end(), - [](const std::pair& frame) { - return frame.second.frame != nullptr; - }); + unsigned int dropped_frames = + std::count_if(frames_.begin(), frames_.end(), + [](const std::pair& frame) { + return frame.second.frame != nullptr; + }); if (dropped_frames > 0) { stats_callback_->OnDroppedFrames(dropped_frames); } @@ -683,7 +676,6 @@ EncodedFrame* FrameBuffer::CombineAndDeleteFrames( // Spatial index of combined frame is set equal to spatial index of its top // spatial layer. first_frame->SetSpatialIndex(last_frame->SpatialIndex().value_or(0)); - first_frame->id.spatial_layer = last_frame->id.spatial_layer; first_frame->video_timing_mutable()->network2_timestamp_ms = last_frame->video_timing().network2_timestamp_ms; diff --git a/modules/video_coding/frame_buffer2.h b/modules/video_coding/frame_buffer2.h index 080ce7c10c..c7d8fcd403 100644 --- a/modules/video_coding/frame_buffer2.h +++ b/modules/video_coding/frame_buffer2.h @@ -18,16 +18,17 @@ #include #include "absl/container/inlined_vector.h" +#include "api/sequence_checker.h" #include "api/video/encoded_frame.h" #include "modules/video_coding/include/video_coding_defines.h" #include "modules/video_coding/inter_frame_delay.h" #include "modules/video_coding/jitter_estimator.h" #include "modules/video_coding/utility/decoded_frames_history.h" #include "rtc_base/event.h" +#include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/experiments/rtt_mult_experiment.h" #include "rtc_base/numerics/sequence_number_util.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" #include "rtc_base/task_utils/repeating_task.h" @@ -58,7 +59,6 @@ class FrameBuffer { // Insert a frame into the frame buffer. Returns the picture id // of the last continuous frame or -1 if there is no continuous frame. - // TODO(philipel): Return a VideoLayerFrameId and not only the picture id. 
int64_t InsertFrame(std::unique_ptr frame); // Get the next frame for decoding. Will return at latest after @@ -95,7 +95,7 @@ class FrameBuffer { // Which other frames that have direct unfulfilled dependencies // on this frame. - absl::InlinedVector dependent_frames; + absl::InlinedVector dependent_frames; // A frame is continiuous if it has all its referenced/indirectly // referenced frames. @@ -115,7 +115,7 @@ class FrameBuffer { std::unique_ptr frame; }; - using FrameMap = std::map; + using FrameMap = std::map; // Check that the references of |frame| are valid. bool ValidReferences(const EncodedFrame& frame) const; @@ -178,8 +178,7 @@ class FrameBuffer { VCMJitterEstimator jitter_estimator_ RTC_GUARDED_BY(mutex_); VCMTiming* const timing_ RTC_GUARDED_BY(mutex_); VCMInterFrameDelay inter_frame_delay_ RTC_GUARDED_BY(mutex_); - absl::optional last_continuous_frame_ - RTC_GUARDED_BY(mutex_); + absl::optional last_continuous_frame_ RTC_GUARDED_BY(mutex_); std::vector frames_to_decode_ RTC_GUARDED_BY(mutex_); bool stopped_ RTC_GUARDED_BY(mutex_); VCMVideoProtection protection_mode_ RTC_GUARDED_BY(mutex_); @@ -190,6 +189,13 @@ class FrameBuffer { // rtt_mult experiment settings. const absl::optional rtt_mult_settings_; + + // Maximum number of frames in the decode queue to allow pacing. If the + // queue grows beyond the max limit, pacing will be disabled and frames will + // be pushed to the decoder as soon as possible. This only has an effect + // when the low-latency rendering path is active, which is indicated by + // the frame's render time == 0. + FieldTrialParameter zero_playout_delay_max_decode_queue_size_; }; } // namespace video_coding diff --git a/modules/video_coding/frame_buffer2_unittest.cc b/modules/video_coding/frame_buffer2_unittest.cc index 7ec789533d..f2a0589411 100644 --- a/modules/video_coding/frame_buffer2_unittest.cc +++ b/modules/video_coding/frame_buffer2_unittest.cc @@ -56,7 +56,8 @@ class VCMTimingFake : public VCMTiming { } int64_t MaxWaitingTime(int64_t render_time_ms, - int64_t now_ms) const override { + int64_t now_ms, + bool too_many_frames_queued) const override { return render_time_ms - now_ms - kDecodeTime; } @@ -164,7 +165,7 @@ class TestFrameBuffer2 : public ::testing::Test { {rtc::checked_cast(refs)...}}; auto frame = std::make_unique(); - frame->id.picture_id = picture_id; + frame->SetId(picture_id); frame->SetSpatialIndex(spatial_layer); frame->SetTimestamp(ts_ms * 90); frame->num_references = references.size(); @@ -199,7 +200,7 @@ class TestFrameBuffer2 : public ::testing::Test { time_task_queue_.PostTask([this, max_wait_time, keyframe_required]() { buffer_->NextFrame( max_wait_time, keyframe_required, &time_task_queue_, - [this](std::unique_ptr frame, + [this](std::unique_ptr frame, video_coding::FrameBuffer::ReturnReason reason) { if (reason != FrameBuffer::ReturnReason::kStopped) { frames_.emplace_back(std::move(frame)); @@ -214,7 +215,7 @@ class TestFrameBuffer2 : public ::testing::Test { void CheckFrame(size_t index, int picture_id, int spatial_layer) { ASSERT_LT(index, frames_.size()); ASSERT_TRUE(frames_[index]); - ASSERT_EQ(picture_id, frames_[index]->id.picture_id); + ASSERT_EQ(picture_id, frames_[index]->Id()); ASSERT_EQ(spatial_layer, frames_[index]->SpatialIndex().value_or(0)); } @@ -278,7 +279,7 @@ TEST_F(TestFrameBuffer2, ZeroPlayoutDelay) { new FrameBuffer(time_controller_.GetClock(), &timing, &stats_callback_)); const VideoPlayoutDelay kPlayoutDelayMs = {0, 0}; std::unique_ptr test_frame(new FrameObjectFake()); - 
test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); buffer_->InsertFrame(std::move(test_frame)); ExtractFrame(0, false); @@ -544,7 +545,7 @@ TEST_F(TestFrameBuffer2, StatsCallback) { { std::unique_ptr frame(new FrameObjectFake()); frame->SetEncodedData(EncodedImageBuffer::Create(kFrameSize)); - frame->id.picture_id = pid; + frame->SetId(pid); frame->SetTimestamp(ts); frame->num_references = 0; diff --git a/modules/video_coding/frame_object.cc b/modules/video_coding/frame_object.cc index 25fd23234c..d226dcd013 100644 --- a/modules/video_coding/frame_object.cc +++ b/modules/video_coding/frame_object.cc @@ -19,7 +19,6 @@ #include "rtc_base/checks.h" namespace webrtc { -namespace video_coding { RtpFrameObject::RtpFrameObject( uint16_t first_seq_num, uint16_t last_seq_num, @@ -69,6 +68,7 @@ RtpFrameObject::RtpFrameObject( rotation_ = rotation; SetColorSpace(color_space); + SetVideoFrameTrackingId(rtp_video_header_.video_frame_tracking_id); content_type_ = content_type; if (timing.flags != VideoSendTiming::kInvalid) { // ntp_time_ms_ may be -1 if not estimated yet. This is not a problem, @@ -128,5 +128,4 @@ const RTPVideoHeader& RtpFrameObject::GetRtpVideoHeader() const { return rtp_video_header_; } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/frame_object.h b/modules/video_coding/frame_object.h index d812b8fd2e..c6f069f241 100644 --- a/modules/video_coding/frame_object.h +++ b/modules/video_coding/frame_object.h @@ -15,7 +15,6 @@ #include "api/video/encoded_frame.h" namespace webrtc { -namespace video_coding { class RtpFrameObject : public EncodedFrame { public: @@ -64,7 +63,6 @@ class RtpFrameObject : public EncodedFrame { int times_nacked_; }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_FRAME_OBJECT_H_ diff --git a/modules/video_coding/g3doc/index.md b/modules/video_coding/g3doc/index.md new file mode 100644 index 0000000000..6fdab6eb98 --- /dev/null +++ b/modules/video_coding/g3doc/index.md @@ -0,0 +1,177 @@ + + + +# Video coding in WebRTC + +## Introduction to layered video coding + +[Video coding][video-coding-wiki] is the process of encoding a stream of +uncompressed video frames into a compressed bitstream, whose bitrate is lower +than that of the original stream. + +### Block-based hybrid video coding + +All video codecs in WebRTC are based on the block-based hybrid video coding +paradigm, which entails prediction of the original video frame using either +[information from previously encoded frames][motion-compensation-wiki] or +information from previously encoded portions of the current frame, subtraction +of the prediction from the original video, and +[transform][transform-coding-wiki] and [quantization][quantization-wiki] of the +resulting difference. The output of the quantization process, quantized +transform coefficients, is losslessly [entropy coded][entropy-coding-wiki] along +with other encoder parameters (e.g., those related to the prediction process) +and then a reconstruction is constructed by inverse quantizing and inverse +transforming the quantized transform coefficients and adding the result to the +prediction. Finally, in-loop filtering is applied and the resulting +reconstruction is stored as a reference frame to be used to develop predictions +for future frames. 
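+
+As a rough illustration of the prediction/quantization part of this loop, the
+toy snippet below computes a residual against a prediction, quantizes it with a
+uniform quantizer and reconstructs the block the way a decoder would. It is a
+conceptual sketch only: the transform, entropy coding and in-loop filtering
+steps are omitted, and none of the names map to actual WebRTC or libvpx APIs.
+
+```cpp
+#include <array>
+#include <cstdio>
+#include <cstdlib>
+
+int main() {
+  constexpr int kBlockSize = 4;
+  // "Original" samples of a block and a prediction for them (e.g. taken from a
+  // previously encoded reference frame).
+  std::array<int, kBlockSize> original = {104, 107, 110, 113};
+  std::array<int, kBlockSize> prediction = {100, 100, 110, 110};
+  const int qp = 4;  // Quantization step: larger means fewer bits, more error.
+
+  for (int i = 0; i < kBlockSize; ++i) {
+    int residual = original[i] - prediction[i];  // Prediction error.
+    // Quantize with rounding; only `level` would be entropy coded and sent.
+    int level = (residual + (residual >= 0 ? qp / 2 : -qp / 2)) / qp;
+    // Decoder-side reconstruction: prediction + dequantized residual.
+    int reconstructed = prediction[i] + level * qp;
+    std::printf("orig=%d pred=%d residual=%d level=%d recon=%d error=%d\n",
+                original[i], prediction[i], residual, level, reconstructed,
+                std::abs(original[i] - reconstructed));
+  }
+  return 0;
+}
+```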
+ +### Frame types + +When an encoded frame depends on previously encoded frames (i.e., it has one or +more inter-frame dependencies), the prior frames must be available at the +receiver before the current frame can be decoded. In order for a receiver to +start decoding an encoded bitstream, a frame which has no prior dependencies is +required. Such a frame is called a "key frame". For real-time-communications +encoding, key frames typically compress less efficiently than "delta frames" +(i.e., frames whose predictions are derived from previously encoded frames).
+ +### Single-layer coding + +In 1:1 calls, the encoded bitstream has a single recipient. Using end-to-end +bandwidth estimation, the target bitrate can thus be well tailored for the +intended recipient. The number of key frames can be kept to a minimum and the +compressibility of the stream can be maximized. One way of achieving this is by +using "single-layer coding", where each delta frame only depends on the frame +that was most recently encoded.
+ +### Scalable video coding + +In multiway conferences, on the other hand, the encoded bitstream has multiple +recipients, each of whom may have different downlink bandwidths. In order to +tailor the encoded bitstreams to a heterogeneous network of receivers, +[scalable video coding][svc-wiki] can be used. The idea is to introduce +structure into the dependency graph of the encoded bitstream, such that _layers_ of +the full stream can be decoded using only available lower layers. This structure +allows a [selective forwarding unit][sfu-webrtc-glossary] to discard upper +layers of the bitstream in order to achieve the intended downlink +bandwidth.
+ +There are multiple types of scalability: + +* _Temporal scalability_: layers whose framerate (and bitrate) is lower than that of the upper layer(s) +* _Spatial scalability_: layers whose resolution (and bitrate) is lower than that of the upper layer(s) +* _Quality scalability_: layers whose bitrate is lower than that of the upper layer(s) + +WebRTC supports temporal scalability for `VP8`, `VP9` and `AV1`, and spatial +scalability for `VP9` and `AV1` (see the illustrative sketch below).
+ +### Simulcast + +Simulcast is another approach for multiway conferencing, where multiple +_independent_ bitstreams are produced by the encoder. + +In cases where multiple encodings of the same source are required (e.g., uplink +transmission in a multiway call), spatial scalability with inter-layer +prediction generally offers superior coding efficiency compared with simulcast. +When a single encoding is required (e.g., downlink transmission in any call), +simulcast generally provides better coding efficiency for the upper spatial +layers. The `K-SVC` concept, where spatial inter-layer dependencies are only +used to encode key frames, for which inter-layer prediction is typically +significantly more effective than it is for delta frames, can be seen as a +compromise between full spatial scalability and simulcast.
+ +## Overview of implementation in `modules/video_coding` + +Given the general introduction to video coding above, we now describe some +specifics of the [`modules/video_coding`][modules-video-coding] folder in WebRTC.
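Before walking through the folder contents, the following toy sketch (illustration only, not WebRTC code) makes the temporal-scalability idea from the section above concrete. It assumes a common three-temporal-layer pattern with a four-frame period (an "L1T3"-style structure, as named in the webrtc-svc scalability modes), and shows which frames remain when a receiver or SFU drops the upper temporal layers.

```cpp
// Illustration of temporal layering (not WebRTC code): frames in higher
// temporal layers are never referenced by lower layers, so they can be
// dropped without breaking decodability of the remaining frames.
#include <cstdio>

// Temporal layer id of frame |i| in an assumed L1T3-style structure with a
// 4-frame period: TL0, TL2, TL1, TL2, TL0, ...
int TemporalLayer(int i) {
  switch (i % 4) {
    case 0:
      return 0;
    case 2:
      return 1;
    default:
      return 2;
  }
}

int main() {
  constexpr int kNumFrames = 8;
  for (int target_layer = 2; target_layer >= 0; --target_layer) {
    std::printf("Decoding temporal layers 0..%d: ", target_layer);
    for (int i = 0; i < kNumFrames; ++i) {
      // Frames above |target_layer| are simply discarded, e.g. by an SFU.
      if (TemporalLayer(i) <= target_layer) {
        std::printf("F%d(TL%d) ", i, TemporalLayer(i));
      }
    }
    std::printf("\n");
  }
  // Of the 8 frames, 8 remain at full rate, 4 with TL0+TL1, and 2 with TL0
  // only, i.e. the framerate roughly halves for each temporal layer removed.
  return 0;
}
```

Spatial and quality scalability follow the same principle, except that the discarded layers carry extra resolution or fidelity rather than extra frames.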
+ +### Built-in software codecs in [`modules/video_coding/codecs`][modules-video-coding-codecs] + +This folder contains WebRTC-specific classes that wrap software codec +implementations for different video coding standards: + +* [libaom][libaom-src] for [AV1][av1-spec] +* [libvpx][libvpx-src] for [VP8][vp8-spec] and [VP9][vp9-spec] +* [OpenH264][openh264-src] for [H.264 constrained baseline profile][h264-spec] + +Users of the library can also inject their own codecs, using the +[VideoEncoderFactory][video-encoder-factory-interface] and +[VideoDecoderFactory][video-decoder-factory-interface] interfaces. This is how +platform-supported codecs, such as hardware-backed codecs, are implemented.
+ +### Video codec test framework in [`modules/video_coding/codecs/test`][modules-video-coding-codecs-test] + +This folder contains a test framework that can be used to evaluate the video quality +performance of different video codec implementations.
+ +### SVC helper classes in [`modules/video_coding/svc`][modules-video-coding-svc] + +* [`ScalabilityStructure*`][scalabilitystructure] - different + [standardized scalability structures][scalability-structure-spec] +* [`ScalableVideoController`][scalablevideocontroller] - provides instructions to the video encoder on how + to create a scalable stream +* [`SvcRateAllocator`][svcrateallocator] - bitrate allocation to different spatial and temporal + layers
+ +### Utility classes in [`modules/video_coding/utility`][modules-video-coding-utility] + +* [`FrameDropper`][framedropper] - drops incoming frames when the encoder systematically + overshoots its target bitrate +* [`FramerateController`][frameratecontroller] - drops incoming frames to achieve a target framerate +* [`QpParser`][qpparser] - parses the quantization parameter from a bitstream +* [`QualityScaler`][qualityscaler] - signals when an encoder generates encoded frames whose + quantization parameter is outside the window of acceptable values +* [`SimulcastRateAllocator`][simulcastrateallocator] - bitrate allocation to simulcast layers
+ +### General helper classes in [`modules/video_coding`][modules-video-coding] + +* [`FecControllerDefault`][feccontrollerdefault] - provides a default implementation for rate + allocation to [forward error correction][fec-wiki] +* [`VideoCodecInitializer`][videocodecinitializer] - converts between different encoder configuration + structs
+ +### Receiver buffer classes in [`modules/video_coding`][modules-video-coding] + +* [`PacketBuffer`][packetbuffer] - (re-)combines RTP packets into frames +* [`RtpFrameReferenceFinder`][rtpframereferencefinder] - determines dependencies between frames based on information in the RTP header, payload header and RTP extensions +* [`FrameBuffer`][framebuffer] - orders frames based on their dependencies before they are fed to the decoder
+ +[video-coding-wiki]: https://en.wikipedia.org/wiki/Video_coding_format +[motion-compensation-wiki]: https://en.wikipedia.org/wiki/Motion_compensation +[transform-coding-wiki]: https://en.wikipedia.org/wiki/Transform_coding +[motion-vector-wiki]: https://en.wikipedia.org/wiki/Motion_vector +[mpeg-wiki]: https://en.wikipedia.org/wiki/Moving_Picture_Experts_Group +[svc-wiki]: https://en.wikipedia.org/wiki/Scalable_Video_Coding +[sfu-webrtc-glossary]: https://webrtcglossary.com/sfu/ +[libvpx-src]: https://chromium.googlesource.com/webm/libvpx/ +[libaom-src]: https://aomedia.googlesource.com/aom/ +[openh264-src]: https://github.com/cisco/openh264 +[vp8-spec]: https://tools.ietf.org/html/rfc6386 +[vp9-spec]:
https://storage.googleapis.com/downloads.webmproject.org/docs/vp9/vp9-bitstream-specification-v0.6-20160331-draft.pdf +[av1-spec]: https://aomediacodec.github.io/av1-spec/ +[h264-spec]: https://www.itu.int/rec/T-REC-H.264-201906-I/en +[video-encoder-factory-interface]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video_codecs/video_encoder_factory.h;l=27;drc=afadfb24a5e608da6ae102b20b0add53a083dcf3 +[video-decoder-factory-interface]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video_codecs/video_decoder_factory.h;l=27;drc=49c293f03d8f593aa3aca282577fcb14daa63207 +[scalability-structure-spec]: https://w3c.github.io/webrtc-svc/#scalabilitymodes* +[fec-wiki]: https://en.wikipedia.org/wiki/Error_correction_code#Forward_error_correction +[entropy-coding-wiki]: https://en.wikipedia.org/wiki/Entropy_encoding +[modules-video-coding]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/ +[modules-video-coding-codecs]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/codecs/ +[modules-video-coding-codecs-test]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/codecs/test/ +[modules-video-coding-svc]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/svc/ +[modules-video-coding-utility]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/ +[scalabilitystructure]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/svc/create_scalability_structure.h?q=CreateScalabilityStructure +[scalablevideocontroller]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/svc/scalable_video_controller.h?q=ScalableVideoController +[svcrateallocator]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/svc/svc_rate_allocator.h?q=SvcRateAllocator +[framedropper]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/frame_dropper.h?q=FrameDropper +[frameratecontroller]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/framerate_controller.h?q=FramerateController +[qpparser]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/qp_parser.h?q=QpParser +[qualityscaler]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/quality_scaler.h?q=QualityScaler +[simulcastrateallocator]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/simulcast_rate_allocator.h?q=SimulcastRateAllocator +[feccontrollerdefault]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/fec_controller_default.h?q=FecControllerDefault +[videocodecinitializer]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/include/video_codec_initializer.h?q=VideoCodecInitializer +[packetbuffer]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/packet_buffer.h?q=PacketBuffer +[rtpframereferencefinder]: 
https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/rtp_frame_reference_finder.h?q=RtpFrameReferenceFinder +[framebuffer]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/frame_buffer2.h?q=FrameBuffer +[quantization-wiki]: https://en.wikipedia.org/wiki/Quantization_(signal_processing) diff --git a/modules/video_coding/generic_decoder.cc b/modules/video_coding/generic_decoder.cc index 28c97f08f5..acb4307f3f 100644 --- a/modules/video_coding/generic_decoder.cc +++ b/modules/video_coding/generic_decoder.cc @@ -91,18 +91,30 @@ void VCMDecodedFrameCallback::Decoded(VideoFrame& decodedImage, "timestamp", decodedImage.timestamp()); // TODO(holmer): We should improve this so that we can handle multiple // callbacks from one call to Decode(). - VCMFrameInformation* frameInfo; + absl::optional frameInfo; int timestamp_map_size = 0; + int dropped_frames = 0; { MutexLock lock(&lock_); + int initial_timestamp_map_size = _timestampMap.Size(); frameInfo = _timestampMap.Pop(decodedImage.timestamp()); timestamp_map_size = _timestampMap.Size(); + // _timestampMap.Pop() erases all frames up to the specified timestamp and + returns the frame info for this timestamp if it exists. Thus, the + difference in the _timestampMap size before and after Pop() will show + internally dropped frames. + dropped_frames = + initial_timestamp_map_size - timestamp_map_size - (frameInfo ? 1 : 0); } - if (frameInfo == NULL) { + if (dropped_frames > 0) { + _receiveCallback->OnDroppedFrames(dropped_frames); + } + + if (!frameInfo) { RTC_LOG(LS_WARNING) << "Too many frames backed up in the decoder, dropping " - "this one."; - _receiveCallback->OnDroppedFrames(1); + "frame with timestamp " + << decodedImage.timestamp(); return; } @@ -110,8 +122,7 @@ void VCMDecodedFrameCallback::Decoded(VideoFrame& decodedImage, decodedImage.set_packet_infos(frameInfo->packet_infos); decodedImage.set_rotation(frameInfo->rotation); - if (low_latency_renderer_enabled_ && frameInfo->playout_delay.min_ms == 0 && - frameInfo->playout_delay.max_ms > 0) { + if (low_latency_renderer_enabled_) { absl::optional max_composition_delay_in_frames = _timing->MaxCompositionDelayInFrames(); if (max_composition_delay_in_frames) { @@ -197,18 +208,30 @@ void VCMDecodedFrameCallback::OnDecoderImplementationName( } void VCMDecodedFrameCallback::Map(uint32_t timestamp, - VCMFrameInformation* frameInfo) { - MutexLock lock(&lock_); - _timestampMap.Add(timestamp, frameInfo); + const VCMFrameInformation& frameInfo) { + int dropped_frames = 0; + { + MutexLock lock(&lock_); + int initial_size = _timestampMap.Size(); + _timestampMap.Add(timestamp, frameInfo); + // If no frame is dropped, the new size should be |initial_size| + 1. + dropped_frames = (initial_size + 1) - _timestampMap.Size(); + } + if (dropped_frames > 0) { + _receiveCallback->OnDroppedFrames(dropped_frames); + } } -int32_t VCMDecodedFrameCallback::Pop(uint32_t timestamp) { - MutexLock lock(&lock_); - if (_timestampMap.Pop(timestamp) == NULL) { - return VCM_GENERAL_ERROR; +void VCMDecodedFrameCallback::ClearTimestampMap() { + int dropped_frames = 0; + { + MutexLock lock(&lock_); + dropped_frames = _timestampMap.Size(); + _timestampMap.Clear(); + } + if (dropped_frames > 0) { + _receiveCallback->OnDroppedFrames(dropped_frames); } - _receiveCallback->OnDroppedFrames(1); - return VCM_OK; } VCMGenericDecoder::VCMGenericDecoder(std::unique_ptr decoder) @@ -216,8 +239,6 @@
VCMGenericDecoder::VCMGenericDecoder(std::unique_ptr decoder) VCMGenericDecoder::VCMGenericDecoder(VideoDecoder* decoder, bool isExternal) : _callback(NULL), - _frameInfos(), - _nextFrameInfoIdx(0), decoder_(decoder), _codecType(kVideoCodecGeneric), _isExternal(isExternal), @@ -250,34 +271,32 @@ int32_t VCMGenericDecoder::InitDecode(const VideoCodec* settings, int32_t VCMGenericDecoder::Decode(const VCMEncodedFrame& frame, Timestamp now) { TRACE_EVENT1("webrtc", "VCMGenericDecoder::Decode", "timestamp", frame.Timestamp()); - _frameInfos[_nextFrameInfoIdx].decodeStart = now; - _frameInfos[_nextFrameInfoIdx].renderTimeMs = frame.RenderTimeMs(); - _frameInfos[_nextFrameInfoIdx].rotation = frame.rotation(); - _frameInfos[_nextFrameInfoIdx].playout_delay = frame.PlayoutDelay(); - _frameInfos[_nextFrameInfoIdx].timing = frame.video_timing(); - _frameInfos[_nextFrameInfoIdx].ntp_time_ms = - frame.EncodedImage().ntp_time_ms_; - _frameInfos[_nextFrameInfoIdx].packet_infos = frame.PacketInfos(); + VCMFrameInformation frame_info; + frame_info.decodeStart = now; + frame_info.renderTimeMs = frame.RenderTimeMs(); + frame_info.rotation = frame.rotation(); + frame_info.timing = frame.video_timing(); + frame_info.ntp_time_ms = frame.EncodedImage().ntp_time_ms_; + frame_info.packet_infos = frame.PacketInfos(); // Set correctly only for key frames. Thus, use latest key frame // content type. If the corresponding key frame was lost, decode will fail // and content type will be ignored. if (frame.FrameType() == VideoFrameType::kVideoFrameKey) { - _frameInfos[_nextFrameInfoIdx].content_type = frame.contentType(); + frame_info.content_type = frame.contentType(); _last_keyframe_content_type = frame.contentType(); } else { - _frameInfos[_nextFrameInfoIdx].content_type = _last_keyframe_content_type; + frame_info.content_type = _last_keyframe_content_type; } - _callback->Map(frame.Timestamp(), &_frameInfos[_nextFrameInfoIdx]); + _callback->Map(frame.Timestamp(), frame_info); - _nextFrameInfoIdx = (_nextFrameInfoIdx + 1) % kDecoderFrameMemoryLength; int32_t ret = decoder_->Decode(frame.EncodedImage(), frame.MissingFrame(), frame.RenderTimeMs()); VideoDecoder::DecoderInfo decoder_info = decoder_->GetDecoderInfo(); if (decoder_info != decoder_info_) { RTC_LOG(LS_INFO) << "Changed decoder implementation to: " << decoder_info.ToString(); - + decoder_info_ = decoder_info; _callback->OnDecoderImplementationName( decoder_info.implementation_name.empty() ? "unknown" @@ -286,11 +305,10 @@ int32_t VCMGenericDecoder::Decode(const VCMEncodedFrame& frame, Timestamp now) { if (ret < WEBRTC_VIDEO_CODEC_OK) { RTC_LOG(LS_WARNING) << "Failed to decode frame with timestamp " << frame.Timestamp() << ", error code: " << ret; - _callback->Pop(frame.Timestamp()); - return ret; + _callback->ClearTimestampMap(); } else if (ret == WEBRTC_VIDEO_CODEC_NO_OUTPUT) { - // No output - _callback->Pop(frame.Timestamp()); + // No output. 
+ _callback->ClearTimestampMap(); } return ret; } diff --git a/modules/video_coding/generic_decoder.h b/modules/video_coding/generic_decoder.h index 9524bab99b..8e79cb4e19 100644 --- a/modules/video_coding/generic_decoder.h +++ b/modules/video_coding/generic_decoder.h @@ -14,6 +14,7 @@ #include #include +#include "api/sequence_checker.h" #include "api/units/time_delta.h" #include "modules/video_coding/encoded_frame.h" #include "modules/video_coding/include/video_codec_interface.h" @@ -21,7 +22,6 @@ #include "modules/video_coding/timing.h" #include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -29,19 +29,6 @@ class VCMReceiveCallback; enum { kDecoderFrameMemoryLength = 10 }; -struct VCMFrameInformation { - int64_t renderTimeMs; - absl::optional decodeStart; - void* userData; - VideoRotation rotation; - VideoContentType content_type; - PlayoutDelay playout_delay; - EncodedImage::Timing timing; - int64_t ntp_time_ms; - RtpPacketInfos packet_infos; - // ColorSpace is not stored here, as it might be modified by decoders. -}; - class VCMDecodedFrameCallback : public DecodedImageCallback { public: VCMDecodedFrameCallback(VCMTiming* timing, Clock* clock); @@ -57,11 +44,11 @@ class VCMDecodedFrameCallback : public DecodedImageCallback { void OnDecoderImplementationName(const char* implementation_name); - void Map(uint32_t timestamp, VCMFrameInformation* frameInfo); - int32_t Pop(uint32_t timestamp); + void Map(uint32_t timestamp, const VCMFrameInformation& frameInfo); + void ClearTimestampMap(); private: - rtc::ThreadChecker construction_thread_; + SequenceChecker construction_thread_; // Protect |_timestampMap|. Clock* const _clock; // This callback must be set before the decoder thread starts running @@ -117,8 +104,6 @@ class VCMGenericDecoder { private: VCMDecodedFrameCallback* _callback; - VCMFrameInformation _frameInfos[kDecoderFrameMemoryLength]; - uint32_t _nextFrameInfoIdx; std::unique_ptr decoder_; VideoCodecType _codecType; const bool _isExternal; diff --git a/modules/video_coding/include/video_codec_interface.h b/modules/video_coding/include/video_codec_interface.h index b786be1693..4737dde90f 100644 --- a/modules/video_coding/include/video_codec_interface.h +++ b/modules/video_coding/include/video_codec_interface.h @@ -13,6 +13,7 @@ #include +#include "absl/base/attributes.h" #include "absl/types/optional.h" #include "api/video/video_frame.h" #include "api/video_codecs/video_decoder.h" @@ -21,7 +22,6 @@ #include "modules/video_coding/codecs/h264/include/h264_globals.h" #include "modules/video_coding/codecs/vp9/include/vp9_globals.h" #include "modules/video_coding/include/video_error_codes.h" -#include "rtc_base/deprecation.h" #include "rtc_base/system/rtc_export.h" namespace webrtc { @@ -79,7 +79,7 @@ struct CodecSpecificInfoVP9 { uint8_t num_ref_pics; uint8_t p_diff[kMaxVp9RefPics]; - RTC_DEPRECATED bool end_of_picture; + ABSL_DEPRECATED("") bool end_of_picture; }; static_assert(std::is_pod::value, ""); diff --git a/modules/video_coding/jitter_buffer.cc b/modules/video_coding/jitter_buffer.cc index 772098a738..75142e93ee 100644 --- a/modules/video_coding/jitter_buffer.cc +++ b/modules/video_coding/jitter_buffer.cc @@ -347,7 +347,7 @@ VCMFrameBufferEnum VCMJitterBuffer::GetFrame(const VCMPacket& packet, int64_t VCMJitterBuffer::LastPacketTime(const VCMEncodedFrame* frame, bool* retransmitted) const { - assert(retransmitted); + RTC_DCHECK(retransmitted); MutexLock lock(&mutex_); const 
VCMFrameBuffer* frame_buffer = static_cast(frame); @@ -498,7 +498,7 @@ VCMFrameBufferEnum VCMJitterBuffer::InsertPacket(const VCMPacket& packet, RecycleFrameBuffer(frame); return kFlushIndicator; default: - assert(false); + RTC_NOTREACHED(); } return buffer_state; } @@ -580,8 +580,8 @@ void VCMJitterBuffer::SetNackSettings(size_t max_nack_list_size, int max_packet_age_to_nack, int max_incomplete_time_ms) { MutexLock lock(&mutex_); - assert(max_packet_age_to_nack >= 0); - assert(max_incomplete_time_ms_ >= 0); + RTC_DCHECK_GE(max_packet_age_to_nack, 0); + RTC_DCHECK_GE(max_incomplete_time_ms_, 0); max_nack_list_size_ = max_nack_list_size; max_packet_age_to_nack_ = max_packet_age_to_nack; max_incomplete_time_ms_ = max_incomplete_time_ms; @@ -600,7 +600,7 @@ int VCMJitterBuffer::NonContinuousOrIncompleteDuration() { uint16_t VCMJitterBuffer::EstimatedLowSequenceNumber( const VCMFrameBuffer& frame) const { - assert(frame.GetLowSeqNum() >= 0); + RTC_DCHECK_GE(frame.GetLowSeqNum(), 0); if (frame.HaveFirstPacket()) return frame.GetLowSeqNum(); diff --git a/modules/video_coding/jitter_buffer_unittest.cc b/modules/video_coding/jitter_buffer_unittest.cc index acfee8c6f7..752ceb835e 100644 --- a/modules/video_coding/jitter_buffer_unittest.cc +++ b/modules/video_coding/jitter_buffer_unittest.cc @@ -67,8 +67,7 @@ class TestBasicJitterBuffer : public ::testing::Test { video_header.is_first_packet_in_frame = true; video_header.frame_type = VideoFrameType::kVideoFrameDelta; packet_.reset(new VCMPacket(data_, size_, rtp_header, video_header, - /*ntp_time_ms=*/0, - clock_->TimeInMilliseconds())); + /*ntp_time_ms=*/0, clock_->CurrentTime())); } VCMEncodedFrame* DecodeCompleteFrame() { @@ -541,7 +540,7 @@ TEST_F(TestBasicJitterBuffer, TestReorderingWithPadding) { video_header.codec = kVideoCodecGeneric; video_header.frame_type = VideoFrameType::kEmptyFrame; VCMPacket empty_packet(data_, 0, rtp_header, video_header, - /*ntp_time_ms=*/0, clock_->TimeInMilliseconds()); + /*ntp_time_ms=*/0, clock_->CurrentTime()); EXPECT_EQ(kOldPacket, jitter_buffer_->InsertPacket(empty_packet, &retransmitted)); empty_packet.seqNum += 1; diff --git a/modules/video_coding/jitter_estimator.cc b/modules/video_coding/jitter_estimator.cc index 44e2a9811e..92a298c259 100644 --- a/modules/video_coding/jitter_estimator.cc +++ b/modules/video_coding/jitter_estimator.cc @@ -247,7 +247,7 @@ void VCMJitterEstimator::KalmanEstimateChannel(int64_t frameDelayMS, hMh_sigma = deltaFSBytes * Mh[0] + Mh[1] + sigma; if ((hMh_sigma < 1e-9 && hMh_sigma >= 0) || (hMh_sigma > -1e-9 && hMh_sigma <= 0)) { - assert(false); + RTC_NOTREACHED(); return; } kalmanGain[0] = Mh[0] / hMh_sigma; @@ -276,11 +276,11 @@ void VCMJitterEstimator::KalmanEstimateChannel(int64_t frameDelayMS, kalmanGain[1] * deltaFSBytes * t01; // Covariance matrix, must be positive semi-definite. 
- assert(_thetaCov[0][0] + _thetaCov[1][1] >= 0 && - _thetaCov[0][0] * _thetaCov[1][1] - - _thetaCov[0][1] * _thetaCov[1][0] >= - 0 && - _thetaCov[0][0] >= 0); + RTC_DCHECK(_thetaCov[0][0] + _thetaCov[1][1] >= 0 && + _thetaCov[0][0] * _thetaCov[1][1] - + _thetaCov[0][1] * _thetaCov[1][0] >= + 0 && + _thetaCov[0][0] >= 0); } // Calculate difference in delay between a sample and the expected delay @@ -302,7 +302,7 @@ void VCMJitterEstimator::EstimateRandomJitter(double d_dT, _lastUpdateT = now; if (_alphaCount == 0) { - assert(false); + RTC_NOTREACHED(); return; } double alpha = @@ -428,7 +428,7 @@ double VCMJitterEstimator::GetFrameRate() const { double fps = 1000000.0 / fps_counter_.ComputeMean(); // Sanity check. - assert(fps >= 0.0); + RTC_DCHECK_GE(fps, 0.0); if (fps > kMaxFramerateEstimate) { fps = kMaxFramerateEstimate; } diff --git a/modules/video_coding/loss_notification_controller.h b/modules/video_coding/loss_notification_controller.h index 06e193b557..4d536ba4f9 100644 --- a/modules/video_coding/loss_notification_controller.h +++ b/modules/video_coding/loss_notification_controller.h @@ -17,8 +17,8 @@ #include "absl/types/optional.h" #include "api/array_view.h" +#include "api/sequence_checker.h" #include "modules/include/module_common_types.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { diff --git a/modules/video_coding/media_opt_util.cc b/modules/video_coding/media_opt_util.cc index b47eeb55d3..0136ae8ec9 100644 --- a/modules/video_coding/media_opt_util.cc +++ b/modules/video_coding/media_opt_util.cc @@ -87,10 +87,10 @@ VCMNackFecMethod::VCMNackFecMethod(int64_t lowRttNackThresholdMs, _lowRttNackMs(lowRttNackThresholdMs), _highRttNackMs(highRttNackThresholdMs), _maxFramesFec(1) { - assert(lowRttNackThresholdMs >= -1 && highRttNackThresholdMs >= -1); - assert(highRttNackThresholdMs == -1 || - lowRttNackThresholdMs <= highRttNackThresholdMs); - assert(lowRttNackThresholdMs > -1 || highRttNackThresholdMs == -1); + RTC_DCHECK(lowRttNackThresholdMs >= -1 && highRttNackThresholdMs >= -1); + RTC_DCHECK(highRttNackThresholdMs == -1 || + lowRttNackThresholdMs <= highRttNackThresholdMs); + RTC_DCHECK(lowRttNackThresholdMs > -1 || highRttNackThresholdMs == -1); _type = kNackFec; } @@ -384,7 +384,7 @@ bool VCMFecMethod::ProtectionFactor(const VCMProtectionParameters* parameters) { indexTableKey = VCM_MIN(indexTableKey, kFecRateTableSize); // Check on table index - assert(indexTableKey < kFecRateTableSize); + RTC_DCHECK_LT(indexTableKey, kFecRateTableSize); // Protection factor for I frame codeRateKey = kFecRateTable[indexTableKey]; diff --git a/modules/video_coding/nack_module2.h b/modules/video_coding/nack_module2.h index 89dd082192..f58f886934 100644 --- a/modules/video_coding/nack_module2.h +++ b/modules/video_coding/nack_module2.h @@ -17,11 +17,11 @@ #include #include +#include "api/sequence_checker.h" #include "api/units/time_delta.h" #include "modules/include/module_common_types.h" #include "modules/video_coding/histogram.h" #include "rtc_base/numerics/sequence_number_util.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_queue.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/task_utils/repeating_task.h" diff --git a/modules/video_coding/packet.cc b/modules/video_coding/packet.cc index 0c4a658b8f..324248ab36 100644 --- a/modules/video_coding/packet.cc +++ b/modules/video_coding/packet.cc @@ -34,7 +34,7 @@ VCMPacket::VCMPacket(const uint8_t* ptr, 
const RTPHeader& rtp_header, const RTPVideoHeader& videoHeader, int64_t ntp_time_ms, - int64_t receive_time_ms) + Timestamp receive_time) : payloadType(rtp_header.payloadType), timestamp(rtp_header.timestamp), ntp_time_ms_(ntp_time_ms), @@ -47,7 +47,7 @@ VCMPacket::VCMPacket(const uint8_t* ptr, insertStartCode(videoHeader.codec == kVideoCodecH264 && videoHeader.is_first_packet_in_frame), video_header(videoHeader), - packet_info(rtp_header, receive_time_ms) { + packet_info(rtp_header, receive_time) { if (is_first_packet_in_frame() && markerBit) { completeNALU = kNaluComplete; } else if (is_first_packet_in_frame()) { diff --git a/modules/video_coding/packet.h b/modules/video_coding/packet.h index f157e10898..9aa2d5ce08 100644 --- a/modules/video_coding/packet.h +++ b/modules/video_coding/packet.h @@ -17,6 +17,7 @@ #include "absl/types/optional.h" #include "api/rtp_headers.h" #include "api/rtp_packet_info.h" +#include "api/units/timestamp.h" #include "api/video/video_frame_type.h" #include "modules/rtp_rtcp/source/rtp_generic_frame_descriptor.h" #include "modules/rtp_rtcp/source/rtp_video_header.h" @@ -41,7 +42,7 @@ class VCMPacket { const RTPHeader& rtp_header, const RTPVideoHeader& video_header, int64_t ntp_time_ms, - int64_t receive_time_ms); + Timestamp receive_time); ~VCMPacket(); diff --git a/modules/video_coding/packet_buffer.cc b/modules/video_coding/packet_buffer.cc index d2a2bcfb47..c98ae00389 100644 --- a/modules/video_coding/packet_buffer.cc +++ b/modules/video_coding/packet_buffer.cc @@ -30,34 +30,21 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/mod_ops.h" -#include "system_wrappers/include/clock.h" namespace webrtc { namespace video_coding { PacketBuffer::Packet::Packet(const RtpPacketReceived& rtp_packet, - const RTPVideoHeader& video_header, - int64_t ntp_time_ms, - int64_t receive_time_ms) + const RTPVideoHeader& video_header) : marker_bit(rtp_packet.Marker()), payload_type(rtp_packet.PayloadType()), seq_num(rtp_packet.SequenceNumber()), timestamp(rtp_packet.Timestamp()), - ntp_time_ms(ntp_time_ms), times_nacked(-1), - video_header(video_header), - packet_info(rtp_packet.Ssrc(), - rtp_packet.Csrcs(), - rtp_packet.Timestamp(), - /*audio_level=*/absl::nullopt, - rtp_packet.GetExtension(), - receive_time_ms) {} - -PacketBuffer::PacketBuffer(Clock* clock, - size_t start_buffer_size, - size_t max_buffer_size) - : clock_(clock), - max_size_(max_buffer_size), + video_header(video_header) {} + +PacketBuffer::PacketBuffer(size_t start_buffer_size, size_t max_buffer_size) + : max_size_(max_buffer_size), first_seq_num_(0), first_packet_received_(false), is_cleared_to_first_seq_num_(false), @@ -76,7 +63,6 @@ PacketBuffer::~PacketBuffer() { PacketBuffer::InsertResult PacketBuffer::InsertPacket( std::unique_ptr packet) { PacketBuffer::InsertResult result; - MutexLock lock(&mutex_); uint16_t seq_num = packet->seq_num; size_t index = seq_num % buffer_.size(); @@ -116,14 +102,6 @@ PacketBuffer::InsertResult PacketBuffer::InsertPacket( } } - int64_t now_ms = clock_->TimeInMilliseconds(); - last_received_packet_ms_ = now_ms; - if (packet->video_header.frame_type == VideoFrameType::kVideoFrameKey || - last_received_keyframe_rtp_timestamp_ == packet->timestamp) { - last_received_keyframe_packet_ms_ = now_ms; - last_received_keyframe_rtp_timestamp_ = packet->timestamp; - } - packet->continuous = false; buffer_[index] = std::move(packet); @@ -134,7 +112,6 @@ PacketBuffer::InsertResult PacketBuffer::InsertPacket( } void PacketBuffer::ClearTo(uint16_t 
seq_num) { - MutexLock lock(&mutex_); // We have already cleared past this sequence number, no need to do anything. if (is_cleared_to_first_seq_num_ && AheadOf(first_seq_num_, seq_num)) { @@ -171,30 +148,20 @@ void PacketBuffer::ClearTo(uint16_t seq_num) { } void PacketBuffer::Clear() { - MutexLock lock(&mutex_); ClearInternal(); } PacketBuffer::InsertResult PacketBuffer::InsertPadding(uint16_t seq_num) { PacketBuffer::InsertResult result; - MutexLock lock(&mutex_); UpdateMissingPackets(seq_num); result.packets = FindFrames(static_cast(seq_num + 1)); return result; } -absl::optional PacketBuffer::LastReceivedPacketMs() const { - MutexLock lock(&mutex_); - return last_received_packet_ms_; -} - -absl::optional PacketBuffer::LastReceivedKeyframePacketMs() const { - MutexLock lock(&mutex_); - return last_received_keyframe_packet_ms_; -} void PacketBuffer::ForceSpsPpsIdrIsH264Keyframe() { sps_pps_idr_is_h264_keyframe_ = true; } + void PacketBuffer::ClearInternal() { for (auto& entry : buffer_) { entry = nullptr; @@ -202,8 +169,6 @@ void PacketBuffer::ClearInternal() { first_packet_received_ = false; is_cleared_to_first_seq_num_ = false; - last_received_packet_ms_.reset(); - last_received_keyframe_packet_ms_.reset(); newest_inserted_seq_num_.reset(); missing_packets_.clear(); } diff --git a/modules/video_coding/packet_buffer.h b/modules/video_coding/packet_buffer.h index e34f7040b5..f4dbe31266 100644 --- a/modules/video_coding/packet_buffer.h +++ b/modules/video_coding/packet_buffer.h @@ -18,14 +18,13 @@ #include "absl/base/attributes.h" #include "api/rtp_packet_info.h" +#include "api/units/timestamp.h" #include "api/video/encoded_image.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_video_header.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/numerics/sequence_number_util.h" -#include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "system_wrappers/include/clock.h" namespace webrtc { namespace video_coding { @@ -35,9 +34,7 @@ class PacketBuffer { struct Packet { Packet() = default; Packet(const RtpPacketReceived& rtp_packet, - const RTPVideoHeader& video_header, - int64_t ntp_time_ms, - int64_t receive_time_ms); + const RTPVideoHeader& video_header); Packet(const Packet&) = delete; Packet(Packet&&) = delete; Packet& operator=(const Packet&) = delete; @@ -62,14 +59,10 @@ class PacketBuffer { uint8_t payload_type = 0; uint16_t seq_num = 0; uint32_t timestamp = 0; - // NTP time of the capture time in local timebase in milliseconds. - int64_t ntp_time_ms = -1; int times_nacked = -1; rtc::CopyOnWriteBuffer video_payload; RTPVideoHeader video_header; - - RtpPacketInfo packet_info; }; struct InsertResult { std::vector> packets; @@ -79,72 +72,50 @@ class PacketBuffer { }; // Both |start_buffer_size| and |max_buffer_size| must be a power of 2. - PacketBuffer(Clock* clock, size_t start_buffer_size, size_t max_buffer_size); + PacketBuffer(size_t start_buffer_size, size_t max_buffer_size); ~PacketBuffer(); - ABSL_MUST_USE_RESULT InsertResult InsertPacket(std::unique_ptr packet) - RTC_LOCKS_EXCLUDED(mutex_); - ABSL_MUST_USE_RESULT InsertResult InsertPadding(uint16_t seq_num) - RTC_LOCKS_EXCLUDED(mutex_); - void ClearTo(uint16_t seq_num) RTC_LOCKS_EXCLUDED(mutex_); - void Clear() RTC_LOCKS_EXCLUDED(mutex_); - - // Timestamp (not RTP timestamp) of the last received packet/keyframe packet. 
- absl::optional LastReceivedPacketMs() const - RTC_LOCKS_EXCLUDED(mutex_); - absl::optional LastReceivedKeyframePacketMs() const - RTC_LOCKS_EXCLUDED(mutex_); + ABSL_MUST_USE_RESULT InsertResult + InsertPacket(std::unique_ptr packet); + ABSL_MUST_USE_RESULT InsertResult InsertPadding(uint16_t seq_num); + void ClearTo(uint16_t seq_num); + void Clear(); + void ForceSpsPpsIdrIsH264Keyframe(); private: - Clock* const clock_; - - // Clears with |mutex_| taken. - void ClearInternal() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + void ClearInternal(); // Tries to expand the buffer. - bool ExpandBufferSize() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + bool ExpandBufferSize(); // Test if all previous packets has arrived for the given sequence number. - bool PotentialNewFrame(uint16_t seq_num) const - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + bool PotentialNewFrame(uint16_t seq_num) const; // Test if all packets of a frame has arrived, and if so, returns packets to // create frames. - std::vector> FindFrames(uint16_t seq_num) - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + std::vector> FindFrames(uint16_t seq_num); - void UpdateMissingPackets(uint16_t seq_num) - RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - mutable Mutex mutex_; + void UpdateMissingPackets(uint16_t seq_num); // buffer_.size() and max_size_ must always be a power of two. const size_t max_size_; // The fist sequence number currently in the buffer. - uint16_t first_seq_num_ RTC_GUARDED_BY(mutex_); + uint16_t first_seq_num_; // If the packet buffer has received its first packet. - bool first_packet_received_ RTC_GUARDED_BY(mutex_); + bool first_packet_received_; // If the buffer is cleared to |first_seq_num_|. - bool is_cleared_to_first_seq_num_ RTC_GUARDED_BY(mutex_); + bool is_cleared_to_first_seq_num_; // Buffer that holds the the inserted packets and information needed to // determine continuity between them. - std::vector> buffer_ RTC_GUARDED_BY(mutex_); - - // Timestamp of the last received packet/keyframe packet. - absl::optional last_received_packet_ms_ RTC_GUARDED_BY(mutex_); - absl::optional last_received_keyframe_packet_ms_ - RTC_GUARDED_BY(mutex_); - absl::optional last_received_keyframe_rtp_timestamp_ - RTC_GUARDED_BY(mutex_); - - absl::optional newest_inserted_seq_num_ RTC_GUARDED_BY(mutex_); - std::set> missing_packets_ - RTC_GUARDED_BY(mutex_); + std::vector> buffer_; + + absl::optional newest_inserted_seq_num_; + std::set> missing_packets_; // Indicates if we should require SPS, PPS, and IDR for a particular // RTP timestamp to treat the corresponding frame as a keyframe. 
diff --git a/modules/video_coding/packet_buffer_unittest.cc b/modules/video_coding/packet_buffer_unittest.cc index a01b480398..97012618f3 100644 --- a/modules/video_coding/packet_buffer_unittest.cc +++ b/modules/video_coding/packet_buffer_unittest.cc @@ -19,7 +19,6 @@ #include "common_video/h264/h264_common.h" #include "modules/video_coding/frame_object.h" #include "rtc_base/random.h" -#include "system_wrappers/include/clock.h" #include "test/field_trial.h" #include "test/gmock.h" #include "test/gtest.h" @@ -100,10 +99,7 @@ void PrintTo(const PacketBufferInsertResult& result, std::ostream* os) { class PacketBufferTest : public ::testing::Test { protected: - PacketBufferTest() - : rand_(0x7732213), - clock_(0), - packet_buffer_(&clock_, kStartSize, kMaxSize) {} + PacketBufferTest() : rand_(0x7732213), packet_buffer_(kStartSize, kMaxSize) {} uint16_t Rand() { return rand_.Rand(); } @@ -133,7 +129,6 @@ class PacketBufferTest : public ::testing::Test { } Random rand_; - SimulatedClock clock_; PacketBuffer packet_buffer_; }; @@ -616,67 +611,6 @@ TEST_F(PacketBufferTest, ContinuousSeqNumDoubleMarkerBit) { EXPECT_THAT(Insert(3, kKeyFrame, kNotFirst, kLast).packets, IsEmpty()); } -TEST_F(PacketBufferTest, PacketTimestamps) { - absl::optional packet_ms; - absl::optional packet_keyframe_ms; - - packet_ms = packet_buffer_.LastReceivedPacketMs(); - packet_keyframe_ms = packet_buffer_.LastReceivedKeyframePacketMs(); - EXPECT_FALSE(packet_ms); - EXPECT_FALSE(packet_keyframe_ms); - - int64_t keyframe_ms = clock_.TimeInMilliseconds(); - Insert(100, kKeyFrame, kFirst, kLast, {}, /*timestamp=*/1000); - packet_ms = packet_buffer_.LastReceivedPacketMs(); - packet_keyframe_ms = packet_buffer_.LastReceivedKeyframePacketMs(); - EXPECT_TRUE(packet_ms); - EXPECT_TRUE(packet_keyframe_ms); - EXPECT_EQ(keyframe_ms, *packet_ms); - EXPECT_EQ(keyframe_ms, *packet_keyframe_ms); - - clock_.AdvanceTimeMilliseconds(100); - int64_t delta_ms = clock_.TimeInMilliseconds(); - Insert(101, kDeltaFrame, kFirst, kLast, {}, /*timestamp=*/2000); - packet_ms = packet_buffer_.LastReceivedPacketMs(); - packet_keyframe_ms = packet_buffer_.LastReceivedKeyframePacketMs(); - EXPECT_TRUE(packet_ms); - EXPECT_TRUE(packet_keyframe_ms); - EXPECT_EQ(delta_ms, *packet_ms); - EXPECT_EQ(keyframe_ms, *packet_keyframe_ms); - - packet_buffer_.Clear(); - packet_ms = packet_buffer_.LastReceivedPacketMs(); - packet_keyframe_ms = packet_buffer_.LastReceivedKeyframePacketMs(); - EXPECT_FALSE(packet_ms); - EXPECT_FALSE(packet_keyframe_ms); -} - -TEST_F(PacketBufferTest, - LastReceivedKeyFrameReturnsReceiveTimeOfALastReceivedPacketOfAKeyFrame) { - clock_.AdvanceTimeMilliseconds(100); - Insert(/*seq_num=*/100, kKeyFrame, kFirst, kNotLast, {}, /*timestamp=*/1000); - EXPECT_EQ(packet_buffer_.LastReceivedKeyframePacketMs(), - clock_.TimeInMilliseconds()); - - clock_.AdvanceTimeMilliseconds(100); - Insert(/*seq_num=*/102, kDeltaFrame, kNotFirst, kLast, {}, - /*timestamp=*/1000); - EXPECT_EQ(packet_buffer_.LastReceivedKeyframePacketMs(), - clock_.TimeInMilliseconds()); - - clock_.AdvanceTimeMilliseconds(100); - Insert(/*seq_num=*/101, kDeltaFrame, kNotFirst, kNotLast, {}, - /*timestamp=*/1000); - EXPECT_EQ(packet_buffer_.LastReceivedKeyframePacketMs(), - clock_.TimeInMilliseconds()); - - clock_.AdvanceTimeMilliseconds(100); - Insert(/*seq_num=*/103, kDeltaFrame, kFirst, kNotLast, {}, - /*timestamp=*/2000); - EXPECT_EQ(packet_buffer_.LastReceivedKeyframePacketMs(), - clock_.TimeInMilliseconds() - 100); -} - TEST_F(PacketBufferTest, IncomingCodecChange) { auto 
packet = std::make_unique(); packet->video_header.is_first_packet_in_frame = true; diff --git a/modules/video_coding/receiver.cc b/modules/video_coding/receiver.cc index 6b942fbe57..8e8f0e1ee2 100644 --- a/modules/video_coding/receiver.cc +++ b/modules/video_coding/receiver.cc @@ -141,7 +141,8 @@ VCMEncodedFrame* VCMReceiver::FrameForDecoding(uint16_t max_wait_time_ms, uint16_t new_max_wait_time = static_cast(VCM_MAX(available_wait_time, 0)); uint32_t wait_time_ms = rtc::saturated_cast( - timing_->MaxWaitingTime(render_time_ms, clock_->TimeInMilliseconds())); + timing_->MaxWaitingTime(render_time_ms, clock_->TimeInMilliseconds(), + /*too_many_frames_queued=*/false)); if (new_max_wait_time < wait_time_ms) { // We're not allowed to wait until the frame is supposed to be rendered, // waiting as long as we're allowed to avoid busy looping, and then return diff --git a/modules/video_coding/receiver_unittest.cc b/modules/video_coding/receiver_unittest.cc index 2585056023..b2d5bc6f03 100644 --- a/modules/video_coding/receiver_unittest.cc +++ b/modules/video_coding/receiver_unittest.cc @@ -30,18 +30,14 @@ namespace webrtc { class TestVCMReceiver : public ::testing::Test { protected: TestVCMReceiver() - : clock_(new SimulatedClock(0)), - timing_(clock_.get()), - receiver_(&timing_, clock_.get()) { - stream_generator_.reset( - new StreamGenerator(0, clock_->TimeInMilliseconds())); - } - - virtual void SetUp() {} + : clock_(0), + timing_(&clock_), + receiver_(&timing_, &clock_), + stream_generator_(0, clock_.TimeInMilliseconds()) {} int32_t InsertPacket(int index) { VCMPacket packet; - bool packet_available = stream_generator_->GetPacket(&packet, index); + bool packet_available = stream_generator_.GetPacket(&packet, index); EXPECT_TRUE(packet_available); if (!packet_available) return kGeneralError; // Return here to avoid crashes below. @@ -50,7 +46,7 @@ class TestVCMReceiver : public ::testing::Test { int32_t InsertPacketAndPop(int index) { VCMPacket packet; - bool packet_available = stream_generator_->PopPacket(&packet, index); + bool packet_available = stream_generator_.PopPacket(&packet, index); EXPECT_TRUE(packet_available); if (!packet_available) return kGeneralError; // Return here to avoid crashes below. @@ -59,18 +55,18 @@ class TestVCMReceiver : public ::testing::Test { int32_t InsertFrame(VideoFrameType frame_type, bool complete) { int num_of_packets = complete ? 1 : 2; - stream_generator_->GenerateFrame( + stream_generator_.GenerateFrame( frame_type, (frame_type != VideoFrameType::kEmptyFrame) ? num_of_packets : 0, (frame_type == VideoFrameType::kEmptyFrame) ? 1 : 0, - clock_->TimeInMilliseconds()); + clock_.TimeInMilliseconds()); int32_t ret = InsertPacketAndPop(0); if (!complete) { // Drop the second packet. VCMPacket packet; - stream_generator_->PopPacket(&packet, 0); + stream_generator_.PopPacket(&packet, 0); } - clock_->AdvanceTimeMilliseconds(kDefaultFramePeriodMs); + clock_.AdvanceTimeMilliseconds(kDefaultFramePeriodMs); return ret; } @@ -82,10 +78,10 @@ class TestVCMReceiver : public ::testing::Test { return true; } - std::unique_ptr clock_; + SimulatedClock clock_; VCMTiming timing_; VCMReceiver receiver_; - std::unique_ptr stream_generator_; + StreamGenerator stream_generator_; }; TEST_F(TestVCMReceiver, NonDecodableDuration_Empty) { @@ -97,7 +93,7 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_Empty) { kMaxNonDecodableDuration); EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameKey, true), kNoError); // Advance time until it's time to decode the key frame. 
- clock_->AdvanceTimeMilliseconds(kMinDelayMs); + clock_.AdvanceTimeMilliseconds(kMinDelayMs); EXPECT_TRUE(DecodeNextFrame()); bool request_key_frame = false; std::vector nack_list = receiver_.NackList(&request_key_frame); @@ -129,7 +125,7 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_OneIncomplete) { receiver_.SetNackSettings(kMaxNackListSize, kMaxPacketAgeToNack, kMaxNonDecodableDuration); timing_.set_min_playout_delay(kMinDelayMs); - int64_t key_frame_inserted = clock_->TimeInMilliseconds(); + int64_t key_frame_inserted = clock_.TimeInMilliseconds(); EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameKey, true), kNoError); // Insert an incomplete frame. EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameDelta, false), kNoError); @@ -138,8 +134,8 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_OneIncomplete) { EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameDelta, true), kNoError); } // Advance time until it's time to decode the key frame. - clock_->AdvanceTimeMilliseconds(kMinDelayMs - clock_->TimeInMilliseconds() - - key_frame_inserted); + clock_.AdvanceTimeMilliseconds(kMinDelayMs - clock_.TimeInMilliseconds() - + key_frame_inserted); EXPECT_TRUE(DecodeNextFrame()); // Make sure we get a key frame request. bool request_key_frame = false; @@ -157,7 +153,7 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_NoTrigger) { receiver_.SetNackSettings(kMaxNackListSize, kMaxPacketAgeToNack, kMaxNonDecodableDuration); timing_.set_min_playout_delay(kMinDelayMs); - int64_t key_frame_inserted = clock_->TimeInMilliseconds(); + int64_t key_frame_inserted = clock_.TimeInMilliseconds(); EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameKey, true), kNoError); // Insert an incomplete frame. EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameDelta, false), kNoError); @@ -167,8 +163,8 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_NoTrigger) { EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameDelta, true), kNoError); } // Advance time until it's time to decode the key frame. - clock_->AdvanceTimeMilliseconds(kMinDelayMs - clock_->TimeInMilliseconds() - - key_frame_inserted); + clock_.AdvanceTimeMilliseconds(kMinDelayMs - clock_.TimeInMilliseconds() - + key_frame_inserted); EXPECT_TRUE(DecodeNextFrame()); // Make sure we don't get a key frame request since we haven't generated // enough frames. @@ -187,7 +183,7 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_NoTrigger2) { receiver_.SetNackSettings(kMaxNackListSize, kMaxPacketAgeToNack, kMaxNonDecodableDuration); timing_.set_min_playout_delay(kMinDelayMs); - int64_t key_frame_inserted = clock_->TimeInMilliseconds(); + int64_t key_frame_inserted = clock_.TimeInMilliseconds(); EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameKey, true), kNoError); // Insert enough frames to have too long non-decodable sequence, except that // we don't have any losses. @@ -197,8 +193,8 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_NoTrigger2) { // Insert an incomplete frame. EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameDelta, false), kNoError); // Advance time until it's time to decode the key frame. - clock_->AdvanceTimeMilliseconds(kMinDelayMs - clock_->TimeInMilliseconds() - - key_frame_inserted); + clock_.AdvanceTimeMilliseconds(kMinDelayMs - clock_.TimeInMilliseconds() - + key_frame_inserted); EXPECT_TRUE(DecodeNextFrame()); // Make sure we don't get a key frame request since the non-decodable duration // is only one frame. 
@@ -217,7 +213,7 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_KeyFrameAfterIncompleteFrames) { receiver_.SetNackSettings(kMaxNackListSize, kMaxPacketAgeToNack, kMaxNonDecodableDuration); timing_.set_min_playout_delay(kMinDelayMs); - int64_t key_frame_inserted = clock_->TimeInMilliseconds(); + int64_t key_frame_inserted = clock_.TimeInMilliseconds(); EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameKey, true), kNoError); // Insert an incomplete frame. EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameDelta, false), kNoError); @@ -227,8 +223,8 @@ TEST_F(TestVCMReceiver, NonDecodableDuration_KeyFrameAfterIncompleteFrames) { } EXPECT_GE(InsertFrame(VideoFrameType::kVideoFrameKey, true), kNoError); // Advance time until it's time to decode the key frame. - clock_->AdvanceTimeMilliseconds(kMinDelayMs - clock_->TimeInMilliseconds() - - key_frame_inserted); + clock_.AdvanceTimeMilliseconds(kMinDelayMs - clock_.TimeInMilliseconds() - + key_frame_inserted); EXPECT_TRUE(DecodeNextFrame()); // Make sure we don't get a key frame request since we have a key frame // in the list. @@ -367,7 +363,6 @@ class FrameInjectEvent : public EventWrapper { class VCMReceiverTimingTest : public ::testing::Test { protected: VCMReceiverTimingTest() - : clock_(&stream_generator_, &receiver_), stream_generator_(0, clock_.TimeInMilliseconds()), timing_(&clock_), diff --git a/modules/video_coding/rtp_frame_id_only_ref_finder.cc b/modules/video_coding/rtp_frame_id_only_ref_finder.cc index f2494ec763..9f3d5bb296 100644 --- a/modules/video_coding/rtp_frame_id_only_ref_finder.cc +++ b/modules/video_coding/rtp_frame_id_only_ref_finder.cc @@ -15,20 +15,19 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace video_coding { RtpFrameReferenceFinder::ReturnVector RtpFrameIdOnlyRefFinder::ManageFrame( std::unique_ptr frame, int frame_id) { - frame->id.picture_id = unwrapper_.Unwrap(frame_id & (kFrameIdLength - 1)); + frame->SetSpatialIndex(0); + frame->SetId(unwrapper_.Unwrap(frame_id & (kFrameIdLength - 1))); frame->num_references = frame->frame_type() == VideoFrameType::kVideoFrameKey ? 
0 : 1; - frame->references[0] = frame->id.picture_id - 1; + frame->references[0] = frame->Id() - 1; RtpFrameReferenceFinder::ReturnVector res; res.push_back(std::move(frame)); return res; } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_frame_id_only_ref_finder.h b/modules/video_coding/rtp_frame_id_only_ref_finder.h index 7728ba92bc..1df4870c5b 100644 --- a/modules/video_coding/rtp_frame_id_only_ref_finder.h +++ b/modules/video_coding/rtp_frame_id_only_ref_finder.h @@ -19,7 +19,6 @@ #include "rtc_base/numerics/sequence_number_util.h" namespace webrtc { -namespace video_coding { class RtpFrameIdOnlyRefFinder { public: @@ -34,7 +33,6 @@ class RtpFrameIdOnlyRefFinder { SeqNumUnwrapper unwrapper_; }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_RTP_FRAME_ID_ONLY_REF_FINDER_H_ diff --git a/modules/video_coding/rtp_frame_reference_finder.cc b/modules/video_coding/rtp_frame_reference_finder.cc index 3084b5b2d9..a44b76bf15 100644 --- a/modules/video_coding/rtp_frame_reference_finder.cc +++ b/modules/video_coding/rtp_frame_reference_finder.cc @@ -21,7 +21,6 @@ #include "modules/video_coding/rtp_vp9_ref_finder.h" namespace webrtc { -namespace video_coding { namespace internal { class RtpFrameReferenceFinderImpl { public: @@ -143,31 +142,34 @@ T& RtpFrameReferenceFinderImpl::GetRefFinderAs() { } // namespace internal -RtpFrameReferenceFinder::RtpFrameReferenceFinder( - OnCompleteFrameCallback* frame_callback) - : RtpFrameReferenceFinder(frame_callback, 0) {} +RtpFrameReferenceFinder::RtpFrameReferenceFinder() + : RtpFrameReferenceFinder(0) {} RtpFrameReferenceFinder::RtpFrameReferenceFinder( - OnCompleteFrameCallback* frame_callback, int64_t picture_id_offset) : picture_id_offset_(picture_id_offset), - frame_callback_(frame_callback), impl_(std::make_unique()) {} RtpFrameReferenceFinder::~RtpFrameReferenceFinder() = default; -void RtpFrameReferenceFinder::ManageFrame( +RtpFrameReferenceFinder::ReturnVector RtpFrameReferenceFinder::ManageFrame( std::unique_ptr frame) { // If we have cleared past this frame, drop it. 
if (cleared_to_seq_num_ != -1 && AheadOf(cleared_to_seq_num_, frame->first_seq_num())) { - return; + return {}; } - HandOffFrames(impl_->ManageFrame(std::move(frame))); + + auto frames = impl_->ManageFrame(std::move(frame)); + AddPictureIdOffset(frames); + return frames; } -void RtpFrameReferenceFinder::PaddingReceived(uint16_t seq_num) { - HandOffFrames(impl_->PaddingReceived(seq_num)); +RtpFrameReferenceFinder::ReturnVector RtpFrameReferenceFinder::PaddingReceived( + uint16_t seq_num) { + auto frames = impl_->PaddingReceived(seq_num); + AddPictureIdOffset(frames); + return frames; } void RtpFrameReferenceFinder::ClearTo(uint16_t seq_num) { @@ -175,16 +177,13 @@ void RtpFrameReferenceFinder::ClearTo(uint16_t seq_num) { impl_->ClearTo(seq_num); } -void RtpFrameReferenceFinder::HandOffFrames(ReturnVector frames) { +void RtpFrameReferenceFinder::AddPictureIdOffset(ReturnVector& frames) { for (auto& frame : frames) { - frame->id.picture_id += picture_id_offset_; + frame->SetId(frame->Id() + picture_id_offset_); for (size_t i = 0; i < frame->num_references; ++i) { frame->references[i] += picture_id_offset_; } - - frame_callback_->OnCompleteFrame(std::move(frame)); } } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_frame_reference_finder.h b/modules/video_coding/rtp_frame_reference_finder.h index c7ee07e215..d2447773a3 100644 --- a/modules/video_coding/rtp_frame_reference_finder.h +++ b/modules/video_coding/rtp_frame_reference_finder.h @@ -16,56 +16,45 @@ #include "modules/video_coding/frame_object.h" namespace webrtc { -namespace video_coding { namespace internal { class RtpFrameReferenceFinderImpl; } // namespace internal -// A complete frame is a frame which has received all its packets and all its -// references are known. -class OnCompleteFrameCallback { - public: - virtual ~OnCompleteFrameCallback() {} - virtual void OnCompleteFrame(std::unique_ptr frame) = 0; -}; - class RtpFrameReferenceFinder { public: using ReturnVector = absl::InlinedVector, 3>; - explicit RtpFrameReferenceFinder(OnCompleteFrameCallback* frame_callback); - explicit RtpFrameReferenceFinder(OnCompleteFrameCallback* frame_callback, - int64_t picture_id_offset); + RtpFrameReferenceFinder(); + explicit RtpFrameReferenceFinder(int64_t picture_id_offset); ~RtpFrameReferenceFinder(); - // Manage this frame until: - // - We have all information needed to determine its references, after - // which |frame_callback_| is called with the completed frame, or - // - We have too many stashed frames (determined by |kMaxStashedFrames|) - // so we drop this frame, or - // - It gets cleared by ClearTo, which also means we drop it. - void ManageFrame(std::unique_ptr frame); + // The RtpFrameReferenceFinder will hold onto the frame until: + // - The required information to determine its references has been received, + // in which case it (and possibly other frames) is returned, or + // - There are too many stashed frames (determined by |kMaxStashedFrames|), + // in which case it gets dropped, or + // - It gets cleared by ClearTo, in which case it is dropped, or + // - The frame is old, in which case it also gets dropped. + ReturnVector ManageFrame(std::unique_ptr frame); // Notifies that padding has been received, which the reference finder // might need to calculate the references of a frame. - void PaddingReceived(uint16_t seq_num); + ReturnVector PaddingReceived(uint16_t seq_num); // Clear all stashed frames that include packets older than |seq_num|.
void ClearTo(uint16_t seq_num); private: - void HandOffFrames(ReturnVector frames); + void AddPictureIdOffset(ReturnVector& frames); // How far frames have been cleared out of the buffer by RTP sequence number. // A frame will be cleared if it contains a packet with a sequence number // older than |cleared_to_seq_num_|. int cleared_to_seq_num_ = -1; const int64_t picture_id_offset_; - OnCompleteFrameCallback* frame_callback_; std::unique_ptr impl_; }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_RTP_FRAME_REFERENCE_FINDER_H_ diff --git a/modules/video_coding/rtp_frame_reference_finder_unittest.cc b/modules/video_coding/rtp_frame_reference_finder_unittest.cc index 373e12d226..a5b0fc49ce 100644 --- a/modules/video_coding/rtp_frame_reference_finder_unittest.cc +++ b/modules/video_coding/rtp_frame_reference_finder_unittest.cc @@ -24,7 +24,6 @@ #include "test/gtest.h" namespace webrtc { -namespace video_coding { namespace { std::unique_ptr CreateFrame( @@ -61,28 +60,29 @@ std::unique_ptr CreateFrame( } } // namespace -class TestRtpFrameReferenceFinder : public ::testing::Test, - public OnCompleteFrameCallback { +class TestRtpFrameReferenceFinder : public ::testing::Test { protected: TestRtpFrameReferenceFinder() : rand_(0x8739211), - reference_finder_(new RtpFrameReferenceFinder(this)), + reference_finder_(std::make_unique()), frames_from_callback_(FrameComp()) {} uint16_t Rand() { return rand_.Rand(); } - void OnCompleteFrame(std::unique_ptr frame) override { - int64_t pid = frame->id.picture_id; - uint16_t sidx = frame->id.spatial_layer; - auto frame_it = frames_from_callback_.find(std::make_pair(pid, sidx)); - if (frame_it != frames_from_callback_.end()) { - ADD_FAILURE() << "Already received frame with (pid:sidx): (" << pid << ":" - << sidx << ")"; - return; + void OnCompleteFrames(RtpFrameReferenceFinder::ReturnVector frames) { + for (auto& frame : frames) { + int64_t pid = frame->Id(); + uint16_t sidx = *frame->SpatialIndex(); + auto frame_it = frames_from_callback_.find(std::make_pair(pid, sidx)); + if (frame_it != frames_from_callback_.end()) { + ADD_FAILURE() << "Already received frame with (pid:sidx): (" << pid + << ":" << sidx << ")"; + return; + } + + frames_from_callback_.insert( + std::make_pair(std::make_pair(pid, sidx), std::move(frame))); } - - frames_from_callback_.insert( - std::make_pair(std::make_pair(pid, sidx), std::move(frame))); } void InsertGeneric(uint16_t seq_num_start, @@ -92,33 +92,18 @@ class TestRtpFrameReferenceFinder : public ::testing::Test, CreateFrame(seq_num_start, seq_num_end, keyframe, kVideoCodecGeneric, RTPVideoTypeHeader()); - reference_finder_->ManageFrame(std::move(frame)); - } - - void InsertVp8(uint16_t seq_num_start, - uint16_t seq_num_end, - bool keyframe, - int32_t pid = kNoPictureId, - uint8_t tid = kNoTemporalIdx, - int32_t tl0 = kNoTl0PicIdx, - bool sync = false) { - RTPVideoHeaderVP8 vp8_header{}; - vp8_header.pictureId = pid % (1 << 15); - vp8_header.temporalIdx = tid; - vp8_header.tl0PicIdx = tl0; - vp8_header.layerSync = sync; - - std::unique_ptr frame = CreateFrame( - seq_num_start, seq_num_end, keyframe, kVideoCodecVP8, vp8_header); - - reference_finder_->ManageFrame(std::move(frame)); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } void InsertH264(uint16_t seq_num_start, uint16_t seq_num_end, bool keyframe) { std::unique_ptr frame = CreateFrame(seq_num_start, seq_num_end, keyframe, kVideoCodecH264, RTPVideoTypeHeader()); - reference_finder_->ManageFrame(std::move(frame)); + 
OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); + } + + void InsertPadding(uint16_t seq_num) { + OnCompleteFrames(reference_finder_->PaddingReceived(seq_num)); } // Check if a frame with picture id |pid| and spatial index |sidx| has been @@ -151,11 +136,6 @@ class TestRtpFrameReferenceFinder : public ::testing::Test, CheckReferences(pid, 0, refs...); } - template - void CheckReferencesVp8(int64_t pid, T... refs) const { - CheckReferences(pid, 0, refs...); - } - template void CheckReferencesH264(int64_t pid, T... refs) const { CheckReferences(pid, 0, refs...); @@ -190,7 +170,7 @@ TEST_F(TestRtpFrameReferenceFinder, PaddingPackets) { InsertGeneric(sn, sn, true); InsertGeneric(sn + 2, sn + 2, false); EXPECT_EQ(1UL, frames_from_callback_.size()); - reference_finder_->PaddingReceived(sn + 1); + InsertPadding(sn + 1); EXPECT_EQ(2UL, frames_from_callback_.size()); } @@ -198,8 +178,8 @@ TEST_F(TestRtpFrameReferenceFinder, PaddingPacketsReordered) { uint16_t sn = Rand(); InsertGeneric(sn, sn, true); - reference_finder_->PaddingReceived(sn + 1); - reference_finder_->PaddingReceived(sn + 4); + InsertPadding(sn + 1); + InsertPadding(sn + 4); InsertGeneric(sn + 2, sn + 3, false); EXPECT_EQ(2UL, frames_from_callback_.size()); @@ -211,12 +191,12 @@ TEST_F(TestRtpFrameReferenceFinder, PaddingPacketsReorderedMultipleKeyframes) { uint16_t sn = Rand(); InsertGeneric(sn, sn, true); - reference_finder_->PaddingReceived(sn + 1); - reference_finder_->PaddingReceived(sn + 4); + InsertPadding(sn + 1); + InsertPadding(sn + 4); InsertGeneric(sn + 2, sn + 3, false); InsertGeneric(sn + 5, sn + 5, true); - reference_finder_->PaddingReceived(sn + 6); - reference_finder_->PaddingReceived(sn + 9); + InsertPadding(sn + 6); + InsertPadding(sn + 9); InsertGeneric(sn + 7, sn + 8, false); EXPECT_EQ(4UL, frames_from_callback_.size()); @@ -253,415 +233,6 @@ TEST_F(TestRtpFrameReferenceFinder, ClearTo) { EXPECT_EQ(3UL, frames_from_callback_.size()); } -TEST_F(TestRtpFrameReferenceFinder, Vp8NoPictureId) { - uint16_t sn = Rand(); - - InsertVp8(sn, sn + 2, true); - ASSERT_EQ(1UL, frames_from_callback_.size()); - - InsertVp8(sn + 3, sn + 4, false); - ASSERT_EQ(2UL, frames_from_callback_.size()); - - InsertVp8(sn + 5, sn + 8, false); - ASSERT_EQ(3UL, frames_from_callback_.size()); - - InsertVp8(sn + 9, sn + 9, false); - ASSERT_EQ(4UL, frames_from_callback_.size()); - - InsertVp8(sn + 10, sn + 11, false); - ASSERT_EQ(5UL, frames_from_callback_.size()); - - InsertVp8(sn + 12, sn + 12, true); - ASSERT_EQ(6UL, frames_from_callback_.size()); - - InsertVp8(sn + 13, sn + 17, false); - ASSERT_EQ(7UL, frames_from_callback_.size()); - - InsertVp8(sn + 18, sn + 18, false); - ASSERT_EQ(8UL, frames_from_callback_.size()); - - InsertVp8(sn + 19, sn + 20, false); - ASSERT_EQ(9UL, frames_from_callback_.size()); - - InsertVp8(sn + 21, sn + 21, false); - - ASSERT_EQ(10UL, frames_from_callback_.size()); - CheckReferencesVp8(sn + 2); - CheckReferencesVp8(sn + 4, sn + 2); - CheckReferencesVp8(sn + 8, sn + 4); - CheckReferencesVp8(sn + 9, sn + 8); - CheckReferencesVp8(sn + 11, sn + 9); - CheckReferencesVp8(sn + 12); - CheckReferencesVp8(sn + 17, sn + 12); - CheckReferencesVp8(sn + 18, sn + 17); - CheckReferencesVp8(sn + 20, sn + 18); - CheckReferencesVp8(sn + 21, sn + 20); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8NoPictureIdReordered) { - uint16_t sn = 0xfffa; - - InsertVp8(sn, sn + 2, true); - InsertVp8(sn + 3, sn + 4, false); - InsertVp8(sn + 5, sn + 8, false); - InsertVp8(sn + 9, sn + 9, false); - InsertVp8(sn + 10, sn + 11, 
false); - InsertVp8(sn + 12, sn + 12, true); - InsertVp8(sn + 13, sn + 17, false); - InsertVp8(sn + 18, sn + 18, false); - InsertVp8(sn + 19, sn + 20, false); - InsertVp8(sn + 21, sn + 21, false); - - ASSERT_EQ(10UL, frames_from_callback_.size()); - CheckReferencesVp8(sn + 2); - CheckReferencesVp8(sn + 4, sn + 2); - CheckReferencesVp8(sn + 8, sn + 4); - CheckReferencesVp8(sn + 9, sn + 8); - CheckReferencesVp8(sn + 11, sn + 9); - CheckReferencesVp8(sn + 12); - CheckReferencesVp8(sn + 17, sn + 12); - CheckReferencesVp8(sn + 18, sn + 17); - CheckReferencesVp8(sn + 20, sn + 18); - CheckReferencesVp8(sn + 21, sn + 20); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8KeyFrameReferences) { - uint16_t sn = Rand(); - InsertVp8(sn, sn, true); - - ASSERT_EQ(1UL, frames_from_callback_.size()); - CheckReferencesVp8(sn); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8RepeatedFrame_0) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 1); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 0, 2); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 0, 2); - - ASSERT_EQ(2UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8RepeatedFrameLayerSync_01) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 1); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 1, 1, true); - ASSERT_EQ(2UL, frames_from_callback_.size()); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 1, 1, true); - - ASSERT_EQ(2UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8RepeatedFrame_01) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 1); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 0, 2, true); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 3); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 0, 4); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 0, 4); - - ASSERT_EQ(4UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid + 1); - CheckReferencesVp8(pid + 3, pid + 2); -} - -// Test with 1 temporal layer. -TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayers_0) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 1); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 0, 2); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 3); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 0, 4); - - ASSERT_EQ(4UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid + 1); - CheckReferencesVp8(pid + 3, pid + 2); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8DuplicateTl1Frames) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 0); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 1, 0, true); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 1); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 1, 1); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 1, 1); - InsertVp8(sn + 4, sn + 4, false, pid + 4, 0, 2); - InsertVp8(sn + 5, sn + 5, false, pid + 5, 1, 2); - - ASSERT_EQ(6UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 3, pid + 1, pid + 2); - CheckReferencesVp8(pid + 4, pid + 2); - CheckReferencesVp8(pid + 5, pid + 3, pid + 4); -} - -// Test with 1 temporal layer. 
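The deleted helpers above mask VP8 picture ids to 15 bits (pid % (1 << 15)), and the reordering test deliberately starts at sn = 0xfffa so that sequence numbers wrap; the reference finder therefore compares ids with wrap-aware helpers before unwrapping them to 64-bit ids. A self-contained illustration of that ordering rule follows; the function names are invented for the example and are not WebRTC APIs.

#include <cstdint>

// Illustrative only: wrap-aware ordering over a 15-bit picture id space.
constexpr uint16_t kPicIdModulus = 1 << 15;

constexpr uint16_t ForwardDiff15(uint16_t from, uint16_t to) {
  // Distance travelled going forward from `from` to `to`, modulo 2^15.
  return static_cast<uint16_t>(to - from) % kPicIdModulus;
}

constexpr bool AheadOf15(uint16_t a, uint16_t b) {
  // `a` is "ahead of" `b` if walking forward from `b` reaches `a` in less
  // than half of the id space.
  return a != b && ForwardDiff15(b, a) < kPicIdModulus / 2;
}

static_assert(AheadOf15(2, 0x7FFE), "2 is newer than 0x7FFE after a wrap");
static_assert(!AheadOf15(0x7FFE, 2), "0x7FFE is older than 2 after a wrap");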
-TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayersReordering_0) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 1); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 0, 2); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 0, 4); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 3); - InsertVp8(sn + 5, sn + 5, false, pid + 5, 0, 6); - InsertVp8(sn + 6, sn + 6, false, pid + 6, 0, 7); - InsertVp8(sn + 4, sn + 4, false, pid + 4, 0, 5); - - ASSERT_EQ(7UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid + 1); - CheckReferencesVp8(pid + 3, pid + 2); - CheckReferencesVp8(pid + 4, pid + 3); - CheckReferencesVp8(pid + 5, pid + 4); - CheckReferencesVp8(pid + 6, pid + 5); -} - -// Test with 2 temporal layers in a 01 pattern. -TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayers_01) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 255); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 1, 255, true); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 0); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 1, 0); - - ASSERT_EQ(4UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 3, pid + 1, pid + 2); -} - -// Test with 2 temporal layers in a 01 pattern. -TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayersReordering_01) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn + 1, sn + 1, false, pid + 1, 1, 255, true); - InsertVp8(sn, sn, true, pid, 0, 255); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 1, 0); - InsertVp8(sn + 5, sn + 5, false, pid + 5, 1, 1); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 0); - InsertVp8(sn + 4, sn + 4, false, pid + 4, 0, 1); - InsertVp8(sn + 6, sn + 6, false, pid + 6, 0, 2); - InsertVp8(sn + 7, sn + 7, false, pid + 7, 1, 2); - - ASSERT_EQ(8UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 3, pid + 1, pid + 2); - CheckReferencesVp8(pid + 4, pid + 2); - CheckReferencesVp8(pid + 5, pid + 3, pid + 4); - CheckReferencesVp8(pid + 6, pid + 4); - CheckReferencesVp8(pid + 7, pid + 5, pid + 6); -} - -// Test with 3 temporal layers in a 0212 pattern. 
-TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayers_0212) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 55); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 2, 55, true); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 1, 55, true); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 2, 55); - InsertVp8(sn + 4, sn + 4, false, pid + 4, 0, 56); - InsertVp8(sn + 5, sn + 5, false, pid + 5, 2, 56); - InsertVp8(sn + 6, sn + 6, false, pid + 6, 1, 56); - InsertVp8(sn + 7, sn + 7, false, pid + 7, 2, 56); - InsertVp8(sn + 8, sn + 8, false, pid + 8, 0, 57); - InsertVp8(sn + 9, sn + 9, false, pid + 9, 2, 57, true); - InsertVp8(sn + 10, sn + 10, false, pid + 10, 1, 57, true); - InsertVp8(sn + 11, sn + 11, false, pid + 11, 2, 57); - - ASSERT_EQ(12UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 3, pid, pid + 1, pid + 2); - CheckReferencesVp8(pid + 4, pid); - CheckReferencesVp8(pid + 5, pid + 2, pid + 3, pid + 4); - CheckReferencesVp8(pid + 6, pid + 2, pid + 4); - CheckReferencesVp8(pid + 7, pid + 4, pid + 5, pid + 6); - CheckReferencesVp8(pid + 8, pid + 4); - CheckReferencesVp8(pid + 9, pid + 8); - CheckReferencesVp8(pid + 10, pid + 8); - CheckReferencesVp8(pid + 11, pid + 8, pid + 9, pid + 10); -} - -// Test with 3 temporal layers in a 0212 pattern. -TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayersMissingFrame_0212) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 55, false); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 1, 55, true); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 2, 55, false); - - ASSERT_EQ(2UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 2, pid); -} - -// Test with 3 temporal layers in a 0212 pattern. -TEST_F(TestRtpFrameReferenceFinder, Vp8TemporalLayersReordering_0212) { - uint16_t pid = 126; - uint16_t sn = Rand(); - - InsertVp8(sn + 1, sn + 1, false, pid + 1, 2, 55, true); - InsertVp8(sn, sn, true, pid, 0, 55, false); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 1, 55, true); - InsertVp8(sn + 4, sn + 4, false, pid + 4, 0, 56, false); - InsertVp8(sn + 5, sn + 5, false, pid + 5, 2, 56, false); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 2, 55, false); - InsertVp8(sn + 7, sn + 7, false, pid + 7, 2, 56, false); - InsertVp8(sn + 9, sn + 9, false, pid + 9, 2, 57, true); - InsertVp8(sn + 6, sn + 6, false, pid + 6, 1, 56, false); - InsertVp8(sn + 8, sn + 8, false, pid + 8, 0, 57, false); - InsertVp8(sn + 11, sn + 11, false, pid + 11, 2, 57, false); - InsertVp8(sn + 10, sn + 10, false, pid + 10, 1, 57, true); - - ASSERT_EQ(12UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 3, pid, pid + 1, pid + 2); - CheckReferencesVp8(pid + 4, pid); - CheckReferencesVp8(pid + 5, pid + 2, pid + 3, pid + 4); - CheckReferencesVp8(pid + 6, pid + 2, pid + 4); - CheckReferencesVp8(pid + 7, pid + 4, pid + 5, pid + 6); - CheckReferencesVp8(pid + 8, pid + 4); - CheckReferencesVp8(pid + 9, pid + 8); - CheckReferencesVp8(pid + 10, pid + 8); - CheckReferencesVp8(pid + 11, pid + 8, pid + 9, pid + 10); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8InsertManyFrames_0212) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - const int keyframes_to_insert = 50; - const int frames_per_keyframe = 120; // Should be a multiple of 4. 
- uint8_t tl0 = 128; - - for (int k = 0; k < keyframes_to_insert; ++k) { - InsertVp8(sn, sn, true, pid, 0, tl0, false); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 2, tl0, true); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 1, tl0, true); - InsertVp8(sn + 3, sn + 3, false, pid + 3, 2, tl0, false); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 3, pid, pid + 1, pid + 2); - frames_from_callback_.clear(); - ++tl0; - - for (int f = 4; f < frames_per_keyframe; f += 4) { - uint16_t sf = sn + f; - int64_t pidf = pid + f; - - InsertVp8(sf, sf, false, pidf, 0, tl0, false); - InsertVp8(sf + 1, sf + 1, false, pidf + 1, 2, tl0, false); - InsertVp8(sf + 2, sf + 2, false, pidf + 2, 1, tl0, false); - InsertVp8(sf + 3, sf + 3, false, pidf + 3, 2, tl0, false); - CheckReferencesVp8(pidf, pidf - 4); - CheckReferencesVp8(pidf + 1, pidf, pidf - 1, pidf - 2); - CheckReferencesVp8(pidf + 2, pidf, pidf - 2); - CheckReferencesVp8(pidf + 3, pidf, pidf + 1, pidf + 2); - frames_from_callback_.clear(); - ++tl0; - } - - pid += frames_per_keyframe; - sn += frames_per_keyframe; - } -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8LayerSync) { - uint16_t pid = Rand(); - uint16_t sn = Rand(); - - InsertVp8(sn, sn, true, pid, 0, 0, false); - InsertVp8(sn + 1, sn + 1, false, pid + 1, 1, 0, true); - InsertVp8(sn + 2, sn + 2, false, pid + 2, 0, 1, false); - ASSERT_EQ(3UL, frames_from_callback_.size()); - - InsertVp8(sn + 4, sn + 4, false, pid + 4, 0, 2, false); - InsertVp8(sn + 5, sn + 5, false, pid + 5, 1, 2, true); - InsertVp8(sn + 6, sn + 6, false, pid + 6, 0, 3, false); - InsertVp8(sn + 7, sn + 7, false, pid + 7, 1, 3, false); - - ASSERT_EQ(7UL, frames_from_callback_.size()); - CheckReferencesVp8(pid); - CheckReferencesVp8(pid + 1, pid); - CheckReferencesVp8(pid + 2, pid); - CheckReferencesVp8(pid + 4, pid + 2); - CheckReferencesVp8(pid + 5, pid + 4); - CheckReferencesVp8(pid + 6, pid + 4); - CheckReferencesVp8(pid + 7, pid + 6, pid + 5); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8Tl1SyncFrameAfterTl1Frame) { - InsertVp8(1000, 1000, true, 1, 0, 247, true); - InsertVp8(1001, 1001, false, 3, 0, 248, false); - InsertVp8(1002, 1002, false, 4, 1, 248, false); // Will be dropped - InsertVp8(1003, 1003, false, 5, 1, 248, true); // due to this frame. 
- - ASSERT_EQ(3UL, frames_from_callback_.size()); - CheckReferencesVp8(1); - CheckReferencesVp8(3, 1); - CheckReferencesVp8(5, 3); -} - -TEST_F(TestRtpFrameReferenceFinder, Vp8DetectMissingFrame_0212) { - InsertVp8(1, 1, true, 1, 0, 1, false); - InsertVp8(2, 2, false, 2, 2, 1, true); - InsertVp8(3, 3, false, 3, 1, 1, true); - InsertVp8(4, 4, false, 4, 2, 1, false); - - InsertVp8(6, 6, false, 6, 2, 2, false); - InsertVp8(7, 7, false, 7, 1, 2, false); - InsertVp8(8, 8, false, 8, 2, 2, false); - ASSERT_EQ(4UL, frames_from_callback_.size()); - - InsertVp8(5, 5, false, 5, 0, 2, false); - ASSERT_EQ(8UL, frames_from_callback_.size()); - - CheckReferencesVp8(1); - CheckReferencesVp8(2, 1); - CheckReferencesVp8(3, 1); - CheckReferencesVp8(4, 3, 2, 1); - - CheckReferencesVp8(5, 1); - CheckReferencesVp8(6, 5, 4, 3); - CheckReferencesVp8(7, 5, 3); - CheckReferencesVp8(8, 7, 6, 5); -} - TEST_F(TestRtpFrameReferenceFinder, H264KeyFrameReferences) { uint16_t sn = Rand(); InsertH264(sn, sn, true); @@ -742,11 +313,10 @@ TEST_F(TestRtpFrameReferenceFinder, Av1FrameNoDependencyDescriptor) { CreateFrame(/*seq_num_start=*/sn, /*seq_num_end=*/sn, /*keyframe=*/true, kVideoCodecAV1, RTPVideoTypeHeader()); - reference_finder_->ManageFrame(std::move(frame)); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); ASSERT_EQ(1UL, frames_from_callback_.size()); CheckReferencesGeneric(sn); } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_generic_ref_finder.cc b/modules/video_coding/rtp_generic_ref_finder.cc index f5603e3ca9..87fff9c26f 100644 --- a/modules/video_coding/rtp_generic_ref_finder.cc +++ b/modules/video_coding/rtp_generic_ref_finder.cc @@ -15,14 +15,13 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace video_coding { RtpFrameReferenceFinder::ReturnVector RtpGenericFrameRefFinder::ManageFrame( std::unique_ptr frame, const RTPVideoHeader::GenericDescriptorInfo& descriptor) { // Frame IDs are unwrapped in the RtpVideoStreamReceiver, no need to unwrap // them here. 
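For the generic-descriptor path above, a small hypothetical wrapper shows how a frame carrying RTPVideoHeader::GenericDescriptorInfo is resolved: the descriptor's frame_id and spatial_index are copied straight onto the frame (ids arrive already unwrapped), so no stashing state is involved. Only the types and the ManageFrame() signature come from the patch; the wrapper itself, and the assumption that RtpGenericFrameRefFinder is default-constructible, are illustrative.

#include <memory>
#include <utility>

#include "modules/video_coding/frame_object.h"
#include "modules/video_coding/rtp_generic_ref_finder.h"

namespace webrtc {

// Hypothetical helper, not part of the patch.
RtpFrameReferenceFinder::ReturnVector ResolveWithDescriptor(
    std::unique_ptr<RtpFrameObject> frame,
    const RTPVideoHeader::GenericDescriptorInfo& descriptor) {
  RtpGenericFrameRefFinder ref_finder;
  // The returned frame reports descriptor.frame_id via Id() and
  // descriptor.spatial_index via SpatialIndex().
  return ref_finder.ManageFrame(std::move(frame), descriptor);
}

}  // namespace webrtc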
- frame->id.picture_id = descriptor.frame_id; + frame->SetId(descriptor.frame_id); frame->SetSpatialIndex(descriptor.spatial_index); RtpFrameReferenceFinder::ReturnVector res; @@ -40,5 +39,4 @@ RtpFrameReferenceFinder::ReturnVector RtpGenericFrameRefFinder::ManageFrame( return res; } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_generic_ref_finder.h b/modules/video_coding/rtp_generic_ref_finder.h index 278de2635e..87d7b59406 100644 --- a/modules/video_coding/rtp_generic_ref_finder.h +++ b/modules/video_coding/rtp_generic_ref_finder.h @@ -17,7 +17,6 @@ #include "modules/video_coding/rtp_frame_reference_finder.h" namespace webrtc { -namespace video_coding { class RtpGenericFrameRefFinder { public: @@ -28,7 +27,6 @@ class RtpGenericFrameRefFinder { const RTPVideoHeader::GenericDescriptorInfo& descriptor); }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_RTP_GENERIC_REF_FINDER_H_ diff --git a/modules/video_coding/rtp_seq_num_only_ref_finder.cc b/modules/video_coding/rtp_seq_num_only_ref_finder.cc index 7177a14be3..4381cf0952 100644 --- a/modules/video_coding/rtp_seq_num_only_ref_finder.cc +++ b/modules/video_coding/rtp_seq_num_only_ref_finder.cc @@ -15,7 +15,6 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace video_coding { RtpFrameReferenceFinder::ReturnVector RtpSeqNumOnlyRefFinder::ManageFrame( std::unique_ptr frame) { @@ -86,17 +85,18 @@ RtpSeqNumOnlyRefFinder::ManageFrameInternal(RtpFrameObject* frame) { // Since keyframes can cause reordering we can't simply assign the // picture id according to some incrementing counter. - frame->id.picture_id = frame->last_seq_num(); + frame->SetId(frame->last_seq_num()); frame->num_references = frame->frame_type() == VideoFrameType::kVideoFrameDelta; frame->references[0] = rtp_seq_num_unwrapper_.Unwrap(last_picture_id_gop); - if (AheadOf(frame->id.picture_id, last_picture_id_gop)) { - seq_num_it->second.first = frame->id.picture_id; - seq_num_it->second.second = frame->id.picture_id; + if (AheadOf(frame->Id(), last_picture_id_gop)) { + seq_num_it->second.first = frame->Id(); + seq_num_it->second.second = frame->Id(); } - UpdateLastPictureIdWithPadding(frame->id.picture_id); - frame->id.picture_id = rtp_seq_num_unwrapper_.Unwrap(frame->id.picture_id); + UpdateLastPictureIdWithPadding(frame->Id()); + frame->SetSpatialIndex(0); + frame->SetId(rtp_seq_num_unwrapper_.Unwrap(frame->Id())); return kHandOff; } @@ -183,5 +183,4 @@ void RtpSeqNumOnlyRefFinder::ClearTo(uint16_t seq_num) { } } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_seq_num_only_ref_finder.h b/modules/video_coding/rtp_seq_num_only_ref_finder.h index 1b0cc7722a..ef3c022111 100644 --- a/modules/video_coding/rtp_seq_num_only_ref_finder.h +++ b/modules/video_coding/rtp_seq_num_only_ref_finder.h @@ -23,7 +23,6 @@ #include "rtc_base/numerics/sequence_number_util.h" namespace webrtc { -namespace video_coding { class RtpSeqNumOnlyRefFinder { public: @@ -66,7 +65,6 @@ class RtpSeqNumOnlyRefFinder { SeqNumUnwrapper rtp_seq_num_unwrapper_; }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_RTP_SEQ_NUM_ONLY_REF_FINDER_H_ diff --git a/modules/video_coding/rtp_vp8_ref_finder.cc b/modules/video_coding/rtp_vp8_ref_finder.cc index 341bba90a4..b448b23308 100644 --- a/modules/video_coding/rtp_vp8_ref_finder.cc +++ b/modules/video_coding/rtp_vp8_ref_finder.cc @@ -15,7 +15,6 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace 
video_coding { RtpFrameReferenceFinder::ReturnVector RtpVp8RefFinder::ManageFrame( std::unique_ptr frame) { @@ -49,14 +48,15 @@ RtpVp8RefFinder::FrameDecision RtpVp8RefFinder::ManageFrameInternal( if (codec_header.temporalIdx >= kMaxTemporalLayers) return kDrop; - frame->id.picture_id = codec_header.pictureId & 0x7FFF; + frame->SetSpatialIndex(0); + frame->SetId(codec_header.pictureId & 0x7FFF); if (last_picture_id_ == -1) - last_picture_id_ = frame->id.picture_id; + last_picture_id_ = frame->Id(); // Clean up info about not yet received frames that are too old. uint16_t old_picture_id = - Subtract(frame->id.picture_id, kMaxNotYetReceivedFrames); + Subtract(frame->Id(), kMaxNotYetReceivedFrames); auto clean_frames_to = not_yet_received_frames_.lower_bound(old_picture_id); not_yet_received_frames_.erase(not_yet_received_frames_.begin(), clean_frames_to); @@ -66,12 +66,11 @@ RtpVp8RefFinder::FrameDecision RtpVp8RefFinder::ManageFrameInternal( } // Find if there has been a gap in fully received frames and save the picture // id of those frames in |not_yet_received_frames_|. - if (AheadOf(frame->id.picture_id, - last_picture_id_)) { + if (AheadOf(frame->Id(), last_picture_id_)) { do { last_picture_id_ = Add(last_picture_id_, 1); not_yet_received_frames_.insert(last_picture_id_); - } while (last_picture_id_ != frame->id.picture_id); + } while (last_picture_id_ != frame->Id()); } int64_t unwrapped_tl0 = tl0_unwrapper_.Unwrap(codec_header.tl0PicIdx & 0xFF); @@ -109,8 +108,7 @@ RtpVp8RefFinder::FrameDecision RtpVp8RefFinder::ManageFrameInternal( // Is this an old frame that has already been used to update the state? If // so, drop it. - if (AheadOrAt(last_pid_on_layer, - frame->id.picture_id)) { + if (AheadOrAt(last_pid_on_layer, frame->Id())) { return kDrop; } @@ -127,8 +125,7 @@ RtpVp8RefFinder::FrameDecision RtpVp8RefFinder::ManageFrameInternal( // Is this an old frame that has already been used to update the state? If // so, drop it. if (last_pid_on_layer != -1 && - AheadOrAt(last_pid_on_layer, - frame->id.picture_id)) { + AheadOrAt(last_pid_on_layer, frame->Id())) { return kDrop; } @@ -149,7 +146,7 @@ RtpVp8RefFinder::FrameDecision RtpVp8RefFinder::ManageFrameInternal( // a layer sync frame has been received after this frame for the same // base layer frame, drop this frame. if (AheadOf(layer_info_it->second[layer], - frame->id.picture_id)) { + frame->Id())) { return kDrop; } @@ -158,14 +155,14 @@ RtpVp8RefFinder::FrameDecision RtpVp8RefFinder::ManageFrameInternal( auto not_received_frame_it = not_yet_received_frames_.upper_bound(layer_info_it->second[layer]); if (not_received_frame_it != not_yet_received_frames_.end() && - AheadOf(frame->id.picture_id, + AheadOf(frame->Id(), *not_received_frame_it)) { return kStash; } - if (!(AheadOf(frame->id.picture_id, + if (!(AheadOf(frame->Id(), layer_info_it->second[layer]))) { - RTC_LOG(LS_WARNING) << "Frame with picture id " << frame->id.picture_id + RTC_LOG(LS_WARNING) << "Frame with picture id " << frame->Id() << " and packet range [" << frame->first_seq_num() << ", " << frame->last_seq_num() << "] already received, " @@ -190,17 +187,17 @@ void RtpVp8RefFinder::UpdateLayerInfoVp8(RtpFrameObject* frame, while (layer_info_it != layer_info_.end()) { if (layer_info_it->second[temporal_idx] != -1 && AheadOf(layer_info_it->second[temporal_idx], - frame->id.picture_id)) { + frame->Id())) { // The frame was not newer, then no subsequent layer info have to be // update. 
break; } - layer_info_it->second[temporal_idx] = frame->id.picture_id; + layer_info_it->second[temporal_idx] = frame->Id(); ++unwrapped_tl0; layer_info_it = layer_info_.find(unwrapped_tl0); } - not_yet_received_frames_.erase(frame->id.picture_id); + not_yet_received_frames_.erase(frame->Id()); UnwrapPictureIds(frame); } @@ -232,7 +229,7 @@ void RtpVp8RefFinder::RetryStashedFrames( void RtpVp8RefFinder::UnwrapPictureIds(RtpFrameObject* frame) { for (size_t i = 0; i < frame->num_references; ++i) frame->references[i] = unwrapper_.Unwrap(frame->references[i]); - frame->id.picture_id = unwrapper_.Unwrap(frame->id.picture_id); + frame->SetId(unwrapper_.Unwrap(frame->Id())); } void RtpVp8RefFinder::ClearTo(uint16_t seq_num) { @@ -246,5 +243,4 @@ void RtpVp8RefFinder::ClearTo(uint16_t seq_num) { } } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_vp8_ref_finder.h b/modules/video_coding/rtp_vp8_ref_finder.h index 55d2de921e..0a6cd7e10d 100644 --- a/modules/video_coding/rtp_vp8_ref_finder.h +++ b/modules/video_coding/rtp_vp8_ref_finder.h @@ -22,7 +22,6 @@ #include "rtc_base/numerics/sequence_number_util.h" namespace webrtc { -namespace video_coding { class RtpVp8RefFinder { public: @@ -72,7 +71,6 @@ class RtpVp8RefFinder { SeqNumUnwrapper tl0_unwrapper_; }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_RTP_VP8_REF_FINDER_H_ diff --git a/modules/video_coding/rtp_vp8_ref_finder_unittest.cc b/modules/video_coding/rtp_vp8_ref_finder_unittest.cc new file mode 100644 index 0000000000..a77149a89b --- /dev/null +++ b/modules/video_coding/rtp_vp8_ref_finder_unittest.cc @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/video_coding/rtp_vp8_ref_finder.h" + +#include +#include + +#include "modules/video_coding/frame_object.h" +#include "test/gmock.h" +#include "test/gtest.h" + +using ::testing::Contains; +using ::testing::Eq; +using ::testing::Matcher; +using ::testing::Matches; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAreArray; + +namespace webrtc { +namespace { + +MATCHER_P2(HasIdAndRefs, id, refs, "") { + return Matches(Eq(id))(arg->Id()) && + Matches(UnorderedElementsAreArray(refs))( + rtc::ArrayView(arg->references, arg->num_references)); +} + +Matcher>&> +HasFrameWithIdAndRefs(int64_t frame_id, const std::vector& refs) { + return Contains(HasIdAndRefs(frame_id, refs)); +} + +class Frame { + public: + Frame& AsKeyFrame(bool is_keyframe = true) { + is_keyframe_ = is_keyframe; + return *this; + } + + Frame& Pid(int pid) { + picture_id_ = pid; + return *this; + } + + Frame& Tid(int tid) { + temporal_id_ = tid; + return *this; + } + + Frame& Tl0(int tl0) { + tl0_idx_ = tl0; + return *this; + } + + Frame& AsSync(bool is_sync = true) { + sync = is_sync; + return *this; + } + + operator std::unique_ptr() { + RTPVideoHeaderVP8 vp8_header{}; + vp8_header.pictureId = *picture_id_; + vp8_header.temporalIdx = *temporal_id_; + vp8_header.tl0PicIdx = *tl0_idx_; + vp8_header.layerSync = sync; + + RTPVideoHeader video_header; + video_header.frame_type = is_keyframe_ ? 
VideoFrameType::kVideoFrameKey + : VideoFrameType::kVideoFrameDelta; + video_header.video_type_header = vp8_header; + // clang-format off + return std::make_unique( + /*seq_num_start=*/0, + /*seq_num_end=*/0, + /*markerBit=*/true, + /*times_nacked=*/0, + /*first_packet_received_time=*/0, + /*last_packet_received_time=*/0, + /*rtp_timestamp=*/0, + /*ntp_time_ms=*/0, + VideoSendTiming(), + /*payload_type=*/0, + kVideoCodecVP8, + kVideoRotation_0, + VideoContentType::UNSPECIFIED, + video_header, + /*color_space=*/absl::nullopt, + RtpPacketInfos(), + EncodedImageBuffer::Create(/*size=*/0)); + // clang-format on + } + + private: + bool is_keyframe_ = false; + absl::optional picture_id_; + absl::optional temporal_id_; + absl::optional tl0_idx_; + bool sync = false; +}; + +} // namespace + +class RtpVp8RefFinderTest : public ::testing::Test { + protected: + RtpVp8RefFinderTest() : ref_finder_(std::make_unique()) {} + + void Insert(std::unique_ptr frame) { + for (auto& f : ref_finder_->ManageFrame(std::move(frame))) { + frames_.push_back(std::move(f)); + } + } + + std::unique_ptr ref_finder_; + std::vector> frames_; +}; + +TEST_F(RtpVp8RefFinderTest, Vp8RepeatedFrame_0) { + Insert(Frame().Pid(0).Tid(0).Tl0(1).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(0).Tl0(2)); + Insert(Frame().Pid(1).Tid(0).Tl0(2)); + + EXPECT_THAT(frames_, SizeIs(2)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8RepeatedFrameLayerSync_01) { + Insert(Frame().Pid(0).Tid(0).Tl0(1).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(1).Tl0(1).AsSync()); + Insert(Frame().Pid(1).Tid(1).Tl0(1).AsSync()); + + EXPECT_THAT(frames_, SizeIs(2)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8RepeatedFrame_01) { + Insert(Frame().Pid(0).Tid(0).Tl0(1).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(0).Tl0(2).AsSync()); + Insert(Frame().Pid(2).Tid(0).Tl0(3)); + Insert(Frame().Pid(3).Tid(0).Tl0(4)); + Insert(Frame().Pid(3).Tid(0).Tl0(4)); + + EXPECT_THAT(frames_, SizeIs(4)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {1})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {2})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayers_0) { + Insert(Frame().Pid(0).Tid(0).Tl0(1).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(0).Tl0(2)); + + EXPECT_THAT(frames_, SizeIs(2)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8DuplicateTl1Frames) { + Insert(Frame().Pid(0).Tid(0).Tl0(0).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(1).Tl0(0).AsSync()); + Insert(Frame().Pid(2).Tid(0).Tl0(1)); + Insert(Frame().Pid(3).Tid(1).Tl0(1)); + Insert(Frame().Pid(3).Tid(1).Tl0(1)); + Insert(Frame().Pid(4).Tid(0).Tl0(2)); + Insert(Frame().Pid(5).Tid(1).Tl0(2)); + + EXPECT_THAT(frames_, SizeIs(6)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {1, 2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(4, {2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {3, 4})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayersReordering_0) { + Insert(Frame().Pid(1).Tid(0).Tl0(2)); + Insert(Frame().Pid(0).Tid(0).Tl0(1).AsKeyFrame()); + 
Insert(Frame().Pid(3).Tid(0).Tl0(4)); + Insert(Frame().Pid(2).Tid(0).Tl0(3)); + Insert(Frame().Pid(5).Tid(0).Tl0(6)); + Insert(Frame().Pid(6).Tid(0).Tl0(7)); + Insert(Frame().Pid(4).Tid(0).Tl0(5)); + + EXPECT_THAT(frames_, SizeIs(7)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {1})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(4, {3})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(6, {5})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayers_01) { + Insert(Frame().Pid(0).Tid(0).Tl0(255).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(1).Tl0(255).AsSync()); + Insert(Frame().Pid(2).Tid(0).Tl0(0)); + Insert(Frame().Pid(3).Tid(1).Tl0(0)); + + EXPECT_THAT(frames_, SizeIs(4)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {1, 2})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayersReordering_01) { + Insert(Frame().Pid(1).Tid(1).Tl0(255).AsSync()); + Insert(Frame().Pid(0).Tid(0).Tl0(255).AsKeyFrame()); + Insert(Frame().Pid(3).Tid(1).Tl0(0)); + Insert(Frame().Pid(5).Tid(1).Tl0(1)); + Insert(Frame().Pid(2).Tid(0).Tl0(0)); + Insert(Frame().Pid(4).Tid(0).Tl0(1)); + Insert(Frame().Pid(6).Tid(0).Tl0(2)); + Insert(Frame().Pid(7).Tid(1).Tl0(2)); + + EXPECT_THAT(frames_, SizeIs(8)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {1, 2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(4, {2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {3, 4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(6, {4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(7, {5, 6})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayers_0212) { + Insert(Frame().Pid(0).Tid(0).Tl0(55).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(2).Tl0(55).AsSync()); + Insert(Frame().Pid(2).Tid(1).Tl0(55).AsSync()); + Insert(Frame().Pid(3).Tid(2).Tl0(55)); + Insert(Frame().Pid(4).Tid(0).Tl0(56)); + Insert(Frame().Pid(5).Tid(2).Tl0(56)); + Insert(Frame().Pid(6).Tid(1).Tl0(56)); + Insert(Frame().Pid(7).Tid(2).Tl0(56)); + Insert(Frame().Pid(8).Tid(0).Tl0(57)); + Insert(Frame().Pid(9).Tid(2).Tl0(57).AsSync()); + Insert(Frame().Pid(10).Tid(1).Tl0(57).AsSync()); + Insert(Frame().Pid(11).Tid(2).Tl0(57)); + + EXPECT_THAT(frames_, SizeIs(12)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {0, 1, 2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(4, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {2, 3, 4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(6, {2, 4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(7, {4, 5, 6})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(8, {4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(9, {8})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(10, {8})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(11, {8, 9, 10})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayersMissingFrame_0212) { + Insert(Frame().Pid(0).Tid(0).Tl0(55).AsKeyFrame()); + Insert(Frame().Pid(2).Tid(1).Tl0(55).AsSync()); + 
Insert(Frame().Pid(3).Tid(2).Tl0(55)); + + EXPECT_THAT(frames_, SizeIs(2)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {0})); +} + +// Test with 3 temporal layers in a 0212 pattern. +TEST_F(RtpVp8RefFinderTest, Vp8TemporalLayersReordering_0212) { + Insert(Frame().Pid(127).Tid(2).Tl0(55).AsSync()); + Insert(Frame().Pid(126).Tid(0).Tl0(55).AsKeyFrame()); + Insert(Frame().Pid(128).Tid(1).Tl0(55).AsSync()); + Insert(Frame().Pid(130).Tid(0).Tl0(56)); + Insert(Frame().Pid(131).Tid(2).Tl0(56)); + Insert(Frame().Pid(129).Tid(2).Tl0(55)); + Insert(Frame().Pid(133).Tid(2).Tl0(56)); + Insert(Frame().Pid(135).Tid(2).Tl0(57).AsSync()); + Insert(Frame().Pid(132).Tid(1).Tl0(56)); + Insert(Frame().Pid(134).Tid(0).Tl0(57)); + Insert(Frame().Pid(137).Tid(2).Tl0(57)); + Insert(Frame().Pid(136).Tid(1).Tl0(57).AsSync()); + + EXPECT_THAT(frames_, SizeIs(12)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(126, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(127, {126})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(128, {126})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(129, {126, 127, 128})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(130, {126})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(131, {128, 129, 130})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(132, {128, 130})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(133, {130, 131, 132})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(134, {130})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(135, {134})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(136, {134})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(137, {134, 135, 136})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8LayerSync) { + Insert(Frame().Pid(0).Tid(0).Tl0(0).AsKeyFrame()); + Insert(Frame().Pid(1).Tid(1).Tl0(0).AsSync()); + Insert(Frame().Pid(2).Tid(0).Tl0(1)); + Insert(Frame().Pid(4).Tid(0).Tl0(2)); + Insert(Frame().Pid(5).Tid(1).Tl0(2).AsSync()); + Insert(Frame().Pid(6).Tid(0).Tl0(3)); + Insert(Frame().Pid(7).Tid(1).Tl0(3)); + + EXPECT_THAT(frames_, SizeIs(7)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(0, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {0})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(4, {2})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(6, {4})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(7, {5, 6})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8Tl1SyncFrameAfterTl1Frame) { + Insert(Frame().Pid(1).Tid(0).Tl0(247).AsKeyFrame().AsSync()); + Insert(Frame().Pid(3).Tid(0).Tl0(248)); + Insert(Frame().Pid(4).Tid(1).Tl0(248)); + Insert(Frame().Pid(5).Tid(1).Tl0(248).AsSync()); + + EXPECT_THAT(frames_, SizeIs(3)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {1})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {3})); +} + +TEST_F(RtpVp8RefFinderTest, Vp8DetectMissingFrame_0212) { + Insert(Frame().Pid(1).Tid(0).Tl0(1).AsKeyFrame()); + Insert(Frame().Pid(2).Tid(2).Tl0(1).AsSync()); + Insert(Frame().Pid(3).Tid(1).Tl0(1).AsSync()); + Insert(Frame().Pid(4).Tid(2).Tl0(1)); + Insert(Frame().Pid(6).Tid(2).Tl0(2)); + Insert(Frame().Pid(7).Tid(1).Tl0(2)); + Insert(Frame().Pid(8).Tid(2).Tl0(2)); + Insert(Frame().Pid(5).Tid(0).Tl0(2)); + + EXPECT_THAT(frames_, SizeIs(8)); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(1, {})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(2, {1})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(3, {1})); + EXPECT_THAT(frames_, 
HasFrameWithIdAndRefs(4, {1, 2, 3})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(5, {1})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(6, {3, 4, 5})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(7, {3, 5})); + EXPECT_THAT(frames_, HasFrameWithIdAndRefs(8, {5, 6, 7})); +} + +} // namespace webrtc diff --git a/modules/video_coding/rtp_vp9_ref_finder.cc b/modules/video_coding/rtp_vp9_ref_finder.cc index e1dba9cd0e..b44bb2500d 100644 --- a/modules/video_coding/rtp_vp9_ref_finder.cc +++ b/modules/video_coding/rtp_vp9_ref_finder.cc @@ -16,7 +16,6 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace video_coding { RtpFrameReferenceFinder::ReturnVector RtpVp9RefFinder::ManageFrame( std::unique_ptr frame) { @@ -52,10 +51,10 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( return kDrop; frame->SetSpatialIndex(codec_header.spatial_idx); - frame->id.picture_id = codec_header.picture_id & (kFrameIdLength - 1); + frame->SetId(codec_header.picture_id & (kFrameIdLength - 1)); if (last_picture_id_ == -1) - last_picture_id_ = frame->id.picture_id; + last_picture_id_ = frame->Id(); if (codec_header.flexible_mode) { if (codec_header.num_ref_pics > EncodedFrame::kMaxFrameReferences) { @@ -63,8 +62,8 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( } frame->num_references = codec_header.num_ref_pics; for (size_t i = 0; i < frame->num_references; ++i) { - frame->references[i] = Subtract(frame->id.picture_id, - codec_header.pid_diff[i]); + frame->references[i] = + Subtract(frame->Id(), codec_header.pid_diff[i]); } FlattenFrameIdAndRefs(frame, codec_header.inter_layer_predicted); @@ -104,10 +103,10 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( current_ss_idx_ = Add(current_ss_idx_, 1); scalability_structures_[current_ss_idx_] = gof; - scalability_structures_[current_ss_idx_].pid_start = frame->id.picture_id; - gof_info_.emplace(unwrapped_tl0, - GofInfo(&scalability_structures_[current_ss_idx_], - frame->id.picture_id)); + scalability_structures_[current_ss_idx_].pid_start = frame->Id(); + gof_info_.emplace( + unwrapped_tl0, + GofInfo(&scalability_structures_[current_ss_idx_], frame->Id())); } const auto gof_info_it = gof_info_.find(unwrapped_tl0); @@ -118,7 +117,7 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( if (frame->frame_type() == VideoFrameType::kVideoFrameKey) { frame->num_references = 0; - FrameReceivedVp9(frame->id.picture_id, info); + FrameReceivedVp9(frame->Id(), info); FlattenFrameIdAndRefs(frame, codec_header.inter_layer_predicted); return kHandOff; } @@ -134,7 +133,7 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( info = &gof_info_it->second; frame->num_references = 0; - FrameReceivedVp9(frame->id.picture_id, info); + FrameReceivedVp9(frame->Id(), info); FlattenFrameIdAndRefs(frame, codec_header.inter_layer_predicted); return kHandOff; } else { @@ -147,8 +146,8 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( if (codec_header.temporal_idx == 0) { gof_info_it = gof_info_ - .emplace(unwrapped_tl0, GofInfo(gof_info_it->second.gof, - frame->id.picture_id)) + .emplace(unwrapped_tl0, + GofInfo(gof_info_it->second.gof, frame->Id())) .first; } @@ -160,23 +159,23 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( auto clean_gof_info_to = gof_info_.lower_bound(old_tl0_pic_idx); gof_info_.erase(gof_info_.begin(), clean_gof_info_to); - FrameReceivedVp9(frame->id.picture_id, info); + FrameReceivedVp9(frame->Id(), info); // Make sure we 
don't miss any frame that could potentially have the // up switch flag set. - if (MissingRequiredFrameVp9(frame->id.picture_id, *info)) + if (MissingRequiredFrameVp9(frame->Id(), *info)) return kStash; if (codec_header.temporal_up_switch) - up_switch_.emplace(frame->id.picture_id, codec_header.temporal_idx); + up_switch_.emplace(frame->Id(), codec_header.temporal_idx); // Clean out old info about up switch frames. - uint16_t old_picture_id = Subtract(frame->id.picture_id, 50); + uint16_t old_picture_id = Subtract(frame->Id(), 50); auto up_switch_erase_to = up_switch_.lower_bound(old_picture_id); up_switch_.erase(up_switch_.begin(), up_switch_erase_to); - size_t diff = ForwardDiff(info->gof->pid_start, - frame->id.picture_id); + size_t diff = + ForwardDiff(info->gof->pid_start, frame->Id()); size_t gof_idx = diff % info->gof->num_frames_in_gof; if (info->gof->num_ref_pics[gof_idx] > EncodedFrame::kMaxFrameReferences) { @@ -185,12 +184,12 @@ RtpVp9RefFinder::FrameDecision RtpVp9RefFinder::ManageFrameInternal( // Populate references according to the scalability structure. frame->num_references = info->gof->num_ref_pics[gof_idx]; for (size_t i = 0; i < frame->num_references; ++i) { - frame->references[i] = Subtract( - frame->id.picture_id, info->gof->pid_diff[gof_idx][i]); + frame->references[i] = + Subtract(frame->Id(), info->gof->pid_diff[gof_idx][i]); // If this is a reference to a frame earlier than the last up switch point, // then ignore this reference. - if (UpSwitchInIntervalVp9(frame->id.picture_id, codec_header.temporal_idx, + if (UpSwitchInIntervalVp9(frame->Id(), codec_header.temporal_idx, frame->references[i])) { --frame->num_references; } @@ -330,13 +329,12 @@ void RtpVp9RefFinder::FlattenFrameIdAndRefs(RtpFrameObject* frame, unwrapper_.Unwrap(frame->references[i]) * kMaxSpatialLayers + *frame->SpatialIndex(); } - frame->id.picture_id = - unwrapper_.Unwrap(frame->id.picture_id) * kMaxSpatialLayers + - *frame->SpatialIndex(); + frame->SetId(unwrapper_.Unwrap(frame->Id()) * kMaxSpatialLayers + + *frame->SpatialIndex()); if (inter_layer_predicted && frame->num_references + 1 <= EncodedFrame::kMaxFrameReferences) { - frame->references[frame->num_references] = frame->id.picture_id - 1; + frame->references[frame->num_references] = frame->Id() - 1; ++frame->num_references; } } @@ -352,5 +350,4 @@ void RtpVp9RefFinder::ClearTo(uint16_t seq_num) { } } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/rtp_vp9_ref_finder.h b/modules/video_coding/rtp_vp9_ref_finder.h index 1ccfa3b1ed..81008fea88 100644 --- a/modules/video_coding/rtp_vp9_ref_finder.h +++ b/modules/video_coding/rtp_vp9_ref_finder.h @@ -22,7 +22,6 @@ #include "rtc_base/numerics/sequence_number_util.h" namespace webrtc { -namespace video_coding { class RtpVp9RefFinder { public: @@ -96,7 +95,6 @@ class RtpVp9RefFinder { SeqNumUnwrapper tl0_unwrapper_; }; -} // namespace video_coding } // namespace webrtc #endif // MODULES_VIDEO_CODING_RTP_VP9_REF_FINDER_H_ diff --git a/modules/video_coding/rtp_vp9_ref_finder_unittest.cc b/modules/video_coding/rtp_vp9_ref_finder_unittest.cc index aa883c8508..6de7ce106f 100644 --- a/modules/video_coding/rtp_vp9_ref_finder_unittest.cc +++ b/modules/video_coding/rtp_vp9_ref_finder_unittest.cc @@ -26,7 +26,6 @@ using ::testing::Property; using ::testing::UnorderedElementsAreArray; namespace webrtc { -namespace video_coding { namespace { class Frame { @@ -163,7 +162,7 @@ class HasFrameMatcher : public MatcherInterface { MatchResultListener* result_listener) const 
override { auto it = std::find_if(frames.begin(), frames.end(), [this](const std::unique_ptr& f) { - return f->id.picture_id == frame_id_; + return f->Id() == frame_id_; }); if (it == frames.end()) { if (result_listener->IsInterested()) { @@ -635,7 +634,7 @@ TEST_F(RtpVp9RefFinderTest, WrappingFlexReference) { ASSERT_EQ(1UL, frames_.size()); const EncodedFrame& frame = *frames_[0]; - ASSERT_EQ(frame.id.picture_id - frame.references[0], 5); + ASSERT_EQ(frame.Id() - frame.references[0], 5); } TEST_F(RtpVp9RefFinderTest, GofPidJump) { @@ -703,5 +702,4 @@ TEST_F(RtpVp9RefFinderTest, SpatialIndex) { Contains(Pointee(Property(&EncodedFrame::SpatialIndex, 2)))); } -} // namespace video_coding } // namespace webrtc diff --git a/modules/video_coding/session_info.cc b/modules/video_coding/session_info.cc index 07b9a9d6b5..477bbbe209 100644 --- a/modules/video_coding/session_info.cc +++ b/modules/video_coding/session_info.cc @@ -49,7 +49,7 @@ void VCMSessionInfo::UpdateDataPointers(const uint8_t* old_base_ptr, const uint8_t* new_base_ptr) { for (PacketIterator it = packets_.begin(); it != packets_.end(); ++it) if ((*it).dataPtr != NULL) { - assert(old_base_ptr != NULL && new_base_ptr != NULL); + RTC_DCHECK(old_base_ptr != NULL && new_base_ptr != NULL); (*it).dataPtr = new_base_ptr + ((*it).dataPtr - old_base_ptr); } } @@ -348,7 +348,7 @@ VCMSessionInfo::PacketIterator VCMSessionInfo::FindNextPartitionBeginning( VCMSessionInfo::PacketIterator VCMSessionInfo::FindPartitionEnd( PacketIterator it) const { - assert((*it).codec() == kVideoCodecVP8); + RTC_DCHECK_EQ((*it).codec(), kVideoCodecVP8); PacketIterator prev_it = it; const int partition_id = absl::get((*it).video_header.video_type_header) diff --git a/modules/video_coding/svc/BUILD.gn b/modules/video_coding/svc/BUILD.gn index 3e93b897b4..2eb25025c1 100644 --- a/modules/video_coding/svc/BUILD.gn +++ b/modules/video_coding/svc/BUILD.gn @@ -34,24 +34,10 @@ rtc_source_set("scalability_structures") { "scalability_structure_full_svc.h", "scalability_structure_key_svc.cc", "scalability_structure_key_svc.h", - "scalability_structure_l1t2.cc", - "scalability_structure_l1t2.h", - "scalability_structure_l1t3.cc", - "scalability_structure_l1t3.h", - "scalability_structure_l2t1.cc", - "scalability_structure_l2t1.h", - "scalability_structure_l2t1h.cc", - "scalability_structure_l2t1h.h", - "scalability_structure_l2t2.cc", - "scalability_structure_l2t2.h", "scalability_structure_l2t2_key_shift.cc", "scalability_structure_l2t2_key_shift.h", - "scalability_structure_l3t1.cc", - "scalability_structure_l3t1.h", - "scalability_structure_l3t3.cc", - "scalability_structure_l3t3.h", - "scalability_structure_s2t1.cc", - "scalability_structure_s2t1.h", + "scalability_structure_simulcast.cc", + "scalability_structure_simulcast.h", ] deps = [ ":scalable_video_controller", @@ -89,9 +75,9 @@ if (rtc_include_tests) { rtc_source_set("scalability_structure_tests") { testonly = true sources = [ + "scalability_structure_full_svc_unittest.cc", "scalability_structure_key_svc_unittest.cc", "scalability_structure_l2t2_key_shift_unittest.cc", - "scalability_structure_l3t3_unittest.cc", "scalability_structure_test_helpers.cc", "scalability_structure_test_helpers.h", "scalability_structure_unittest.cc", diff --git a/modules/video_coding/svc/create_scalability_structure.cc b/modules/video_coding/svc/create_scalability_structure.cc index 4b4a23ed24..39710d82ff 100644 --- a/modules/video_coding/svc/create_scalability_structure.cc +++ b/modules/video_coding/svc/create_scalability_structure.cc 
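Regarding the session_info.cc hunk above, a short illustrative comparison of the two assertion styles (the function below is not from the patch): RTC_DCHECK_EQ reports both operand values when it fires and is governed by WebRTC's own DCHECK build flags, which is why it is preferred over a plain assert of the equality expression.

#include "rtc_base/checks.h"

// Illustrative only.
void CheckCodecType(int actual_codec, int expected_codec) {
  // assert(actual_codec == expected_codec);    // old style: no values in the failure output.
  RTC_DCHECK_EQ(actual_codec, expected_codec);  // new style: logs both values on failure.
}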
@@ -12,16 +12,10 @@ #include #include "absl/strings/string_view.h" +#include "modules/video_coding/svc/scalability_structure_full_svc.h" #include "modules/video_coding/svc/scalability_structure_key_svc.h" -#include "modules/video_coding/svc/scalability_structure_l1t2.h" -#include "modules/video_coding/svc/scalability_structure_l1t3.h" -#include "modules/video_coding/svc/scalability_structure_l2t1.h" -#include "modules/video_coding/svc/scalability_structure_l2t1h.h" -#include "modules/video_coding/svc/scalability_structure_l2t2.h" #include "modules/video_coding/svc/scalability_structure_l2t2_key_shift.h" -#include "modules/video_coding/svc/scalability_structure_l3t1.h" -#include "modules/video_coding/svc/scalability_structure_l3t3.h" -#include "modules/video_coding/svc/scalability_structure_s2t1.h" +#include "modules/video_coding/svc/scalability_structure_simulcast.h" #include "modules/video_coding/svc/scalable_video_controller.h" #include "modules/video_coding/svc/scalable_video_controller_no_layering.h" #include "rtc_base/checks.h" @@ -41,20 +35,31 @@ std::unique_ptr Create() { return std::make_unique(); } +template +std::unique_ptr CreateH() { + // 1.5:1 scaling, see https://w3c.github.io/webrtc-svc/#scalabilitymodes* + typename T::ScalingFactor factor; + factor.num = 2; + factor.den = 3; + return std::make_unique(factor); +} + constexpr NamedStructureFactory kFactories[] = { {"NONE", Create}, {"L1T2", Create}, {"L1T3", Create}, {"L2T1", Create}, - {"L2T1h", Create}, + {"L2T1h", CreateH}, {"L2T1_KEY", Create}, {"L2T2", Create}, {"L2T2_KEY", Create}, {"L2T2_KEY_SHIFT", Create}, + {"L2T3_KEY", Create}, {"L3T1", Create}, {"L3T3", Create}, {"L3T3_KEY", Create}, {"S2T1", Create}, + {"S3T3", Create}, }; } // namespace diff --git a/modules/video_coding/svc/scalability_structure_full_svc.cc b/modules/video_coding/svc/scalability_structure_full_svc.cc index 5454622924..b89de99330 100644 --- a/modules/video_coding/svc/scalability_structure_full_svc.cc +++ b/modules/video_coding/svc/scalability_structure_full_svc.cc @@ -19,9 +19,6 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace { -enum : int { kKey, kDelta }; -} // namespace constexpr int ScalabilityStructureFullSvc::kMaxNumSpatialLayers; constexpr int ScalabilityStructureFullSvc::kMaxNumTemporalLayers; @@ -29,9 +26,11 @@ constexpr absl::string_view ScalabilityStructureFullSvc::kFramePatternNames[]; ScalabilityStructureFullSvc::ScalabilityStructureFullSvc( int num_spatial_layers, - int num_temporal_layers) + int num_temporal_layers, + ScalingFactor resolution_factor) : num_spatial_layers_(num_spatial_layers), num_temporal_layers_(num_temporal_layers), + resolution_factor_(resolution_factor), active_decode_targets_( (uint32_t{1} << (num_spatial_layers * num_temporal_layers)) - 1) { RTC_DCHECK_LE(num_spatial_layers, kMaxNumSpatialLayers); @@ -48,8 +47,10 @@ ScalabilityStructureFullSvc::StreamConfig() const { result.scaling_factor_num[num_spatial_layers_ - 1] = 1; result.scaling_factor_den[num_spatial_layers_ - 1] = 1; for (int sid = num_spatial_layers_ - 1; sid > 0; --sid) { - result.scaling_factor_num[sid - 1] = 1; - result.scaling_factor_den[sid - 1] = 2 * result.scaling_factor_den[sid]; + result.scaling_factor_num[sid - 1] = + resolution_factor_.num * result.scaling_factor_num[sid]; + result.scaling_factor_den[sid - 1] = + resolution_factor_.den * result.scaling_factor_den[sid]; } return result; } @@ -98,6 +99,7 @@ ScalabilityStructureFullSvc::FramePattern ScalabilityStructureFullSvc::NextPattern() const { switch (last_pattern_) { 
case kNone: + return kKey; case kDeltaT2B: return kDeltaT0; case kDeltaT2A: @@ -110,6 +112,7 @@ ScalabilityStructureFullSvc::NextPattern() const { return kDeltaT2B; } return kDeltaT0; + case kKey: case kDeltaT0: if (TemporalLayerIsActive(2)) { return kDeltaT2A; @@ -119,6 +122,8 @@ ScalabilityStructureFullSvc::NextPattern() const { } return kDeltaT0; } + RTC_NOTREACHED(); + return kNone; } std::vector @@ -139,6 +144,7 @@ ScalabilityStructureFullSvc::NextFrameConfig(bool restart) { absl::optional spatial_dependency_buffer_id; switch (current_pattern) { case kDeltaT0: + case kKey: // Disallow temporal references cross T0 on higher temporal layers. can_reference_t1_frame_for_spatial_id_.reset(); for (int sid = 0; sid < num_spatial_layers_; ++sid) { @@ -150,11 +156,11 @@ ScalabilityStructureFullSvc::NextFrameConfig(bool restart) { } configs.emplace_back(); ScalableVideoController::LayerFrameConfig& config = configs.back(); - config.Id(last_pattern_ == kNone ? kKey : kDelta).S(sid).T(0); + config.Id(current_pattern).S(sid).T(0); if (spatial_dependency_buffer_id) { config.Reference(*spatial_dependency_buffer_id); - } else if (last_pattern_ == kNone) { + } else if (current_pattern == kKey) { config.Keyframe(); } @@ -178,7 +184,7 @@ ScalabilityStructureFullSvc::NextFrameConfig(bool restart) { } configs.emplace_back(); ScalableVideoController::LayerFrameConfig& config = configs.back(); - config.Id(kDelta).S(sid).T(1); + config.Id(current_pattern).S(sid).T(1); // Temporal reference. config.Reference(BufferIndex(sid, /*tid=*/0)); // Spatial reference unless this is the lowest active spatial layer. @@ -201,7 +207,7 @@ ScalabilityStructureFullSvc::NextFrameConfig(bool restart) { } configs.emplace_back(); ScalableVideoController::LayerFrameConfig& config = configs.back(); - config.Id(kDelta).S(sid).T(2); + config.Id(current_pattern).S(sid).T(2); // Temporal reference. if (current_pattern == kDeltaT2B && can_reference_t1_frame_for_spatial_id_[sid]) { @@ -239,12 +245,16 @@ ScalabilityStructureFullSvc::NextFrameConfig(bool restart) { return NextFrameConfig(/*restart=*/true); } - last_pattern_ = current_pattern; return configs; } GenericFrameInfo ScalabilityStructureFullSvc::OnEncodeDone( const LayerFrameConfig& config) { + // When encoder drops all frames for a temporal unit, it is better to reuse + // old temporal pattern rather than switch to next one, thus switch to next + // pattern defered here from the `NextFrameConfig`. + // In particular creating VP9 references rely on this behavior. 
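A worked example of the new ScalingFactor plumbing, assuming the ScalabilityStructureL3T1 subclass declared later in this patch: StreamConfig() keeps the top spatial layer at 1:1 and multiplies the factor in once more for each lower layer, so the default 1/2 step yields 1/4, 1/2, 1/1 while the 2/3 step used by the "h" modes (see CreateH above) yields 4/9, 2/3, 1/1. The demo function below is not part of the patch.

#include "modules/video_coding/svc/scalability_structure_full_svc.h"
#include "rtc_base/checks.h"

namespace webrtc {

// Demo only: read back the per-layer scaling factors computed by
// ScalabilityStructureFullSvc::StreamConfig(). Index 0 is the lowest
// spatial layer.
void ShowL3ScalingFactors() {
  ScalabilityStructureL3T1 default_step;  // ScalingFactor defaults to 1/2.
  auto config = default_step.StreamConfig();
  RTC_DCHECK_EQ(config.scaling_factor_num[0], 1);
  RTC_DCHECK_EQ(config.scaling_factor_den[0], 4);

  ScalabilityStructureFullSvc::ScalingFactor one_and_a_half;  // 1.5:1, as in CreateH().
  one_and_a_half.num = 2;
  one_and_a_half.den = 3;
  ScalabilityStructureL3T1 h_step(one_and_a_half);
  auto h_config = h_step.StreamConfig();
  RTC_DCHECK_EQ(h_config.scaling_factor_num[0], 4);
  RTC_DCHECK_EQ(h_config.scaling_factor_den[0], 9);
}

}  // namespace webrtc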
+ last_pattern_ = static_cast(config.Id()); if (config.TemporalId() == 1) { can_reference_t1_frame_for_spatial_id_.set(config.SpatialId()); } @@ -285,4 +295,104 @@ void ScalabilityStructureFullSvc::OnRatesUpdated( } } +FrameDependencyStructure ScalabilityStructureL1T2::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 2; + structure.num_chains = 1; + structure.decode_target_protected_by_chain = {0, 0}; + structure.templates.resize(3); + structure.templates[0].T(0).Dtis("SS").ChainDiffs({0}); + structure.templates[1].T(0).Dtis("SS").ChainDiffs({2}).FrameDiffs({2}); + structure.templates[2].T(1).Dtis("-D").ChainDiffs({1}).FrameDiffs({1}); + return structure; +} + +FrameDependencyStructure ScalabilityStructureL1T3::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 3; + structure.num_chains = 1; + structure.decode_target_protected_by_chain = {0, 0, 0}; + structure.templates.resize(5); + structure.templates[0].T(0).Dtis("SSS").ChainDiffs({0}); + structure.templates[1].T(0).Dtis("SSS").ChainDiffs({4}).FrameDiffs({4}); + structure.templates[2].T(1).Dtis("-DS").ChainDiffs({2}).FrameDiffs({2}); + structure.templates[3].T(2).Dtis("--D").ChainDiffs({1}).FrameDiffs({1}); + structure.templates[4].T(2).Dtis("--D").ChainDiffs({3}).FrameDiffs({1}); + return structure; +} + +FrameDependencyStructure ScalabilityStructureL2T1::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 2; + structure.num_chains = 2; + structure.decode_target_protected_by_chain = {0, 1}; + structure.templates.resize(4); + structure.templates[0].S(0).Dtis("SR").ChainDiffs({2, 1}).FrameDiffs({2}); + structure.templates[1].S(0).Dtis("SS").ChainDiffs({0, 0}); + structure.templates[2].S(1).Dtis("-S").ChainDiffs({1, 1}).FrameDiffs({2, 1}); + structure.templates[3].S(1).Dtis("-S").ChainDiffs({1, 1}).FrameDiffs({1}); + return structure; +} + +FrameDependencyStructure ScalabilityStructureL2T2::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 4; + structure.num_chains = 2; + structure.decode_target_protected_by_chain = {0, 0, 1, 1}; + structure.templates.resize(6); + auto& templates = structure.templates; + templates[0].S(0).T(0).Dtis("SSSS").ChainDiffs({0, 0}); + templates[1].S(0).T(0).Dtis("SSRR").ChainDiffs({4, 3}).FrameDiffs({4}); + templates[2].S(0).T(1).Dtis("-D-R").ChainDiffs({2, 1}).FrameDiffs({2}); + templates[3].S(1).T(0).Dtis("--SS").ChainDiffs({1, 1}).FrameDiffs({1}); + templates[4].S(1).T(0).Dtis("--SS").ChainDiffs({1, 1}).FrameDiffs({4, 1}); + templates[5].S(1).T(1).Dtis("---D").ChainDiffs({3, 2}).FrameDiffs({2, 1}); + return structure; +} + +FrameDependencyStructure ScalabilityStructureL3T1::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 3; + structure.num_chains = 3; + structure.decode_target_protected_by_chain = {0, 1, 2}; + auto& templates = structure.templates; + templates.resize(6); + templates[0].S(0).Dtis("SRR").ChainDiffs({3, 2, 1}).FrameDiffs({3}); + templates[1].S(0).Dtis("SSS").ChainDiffs({0, 0, 0}); + templates[2].S(1).Dtis("-SR").ChainDiffs({1, 1, 1}).FrameDiffs({3, 1}); + templates[3].S(1).Dtis("-SS").ChainDiffs({1, 1, 1}).FrameDiffs({1}); + templates[4].S(2).Dtis("--S").ChainDiffs({2, 1, 1}).FrameDiffs({3, 1}); + templates[5].S(2).Dtis("--S").ChainDiffs({2, 1, 1}).FrameDiffs({1}); + return structure; +} + +FrameDependencyStructure 
ScalabilityStructureL3T3::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 9; + structure.num_chains = 3; + structure.decode_target_protected_by_chain = {0, 0, 0, 1, 1, 1, 2, 2, 2}; + auto& t = structure.templates; + t.resize(15); + // Templates are shown in the order frames following them appear in the + // stream, but in `structure.templates` array templates are sorted by + // (`spatial_id`, `temporal_id`) since that is a dependency descriptor + // requirement. Indexes are written in hex for nicer alignment. + t[0x1].S(0).T(0).Dtis("SSSSSSSSS").ChainDiffs({0, 0, 0}); + t[0x6].S(1).T(0).Dtis("---SSSSSS").ChainDiffs({1, 1, 1}).FrameDiffs({1}); + t[0xB].S(2).T(0).Dtis("------SSS").ChainDiffs({2, 1, 1}).FrameDiffs({1}); + t[0x3].S(0).T(2).Dtis("--D--R--R").ChainDiffs({3, 2, 1}).FrameDiffs({3}); + t[0x8].S(1).T(2).Dtis("-----D--R").ChainDiffs({4, 3, 2}).FrameDiffs({3, 1}); + t[0xD].S(2).T(2).Dtis("--------D").ChainDiffs({5, 4, 3}).FrameDiffs({3, 1}); + t[0x2].S(0).T(1).Dtis("-DS-RR-RR").ChainDiffs({6, 5, 4}).FrameDiffs({6}); + t[0x7].S(1).T(1).Dtis("----DS-RR").ChainDiffs({7, 6, 5}).FrameDiffs({6, 1}); + t[0xC].S(2).T(1).Dtis("-------DS").ChainDiffs({8, 7, 6}).FrameDiffs({6, 1}); + t[0x4].S(0).T(2).Dtis("--D--R--R").ChainDiffs({9, 8, 7}).FrameDiffs({3}); + t[0x9].S(1).T(2).Dtis("-----D--R").ChainDiffs({10, 9, 8}).FrameDiffs({3, 1}); + t[0xE].S(2).T(2).Dtis("--------D").ChainDiffs({11, 10, 9}).FrameDiffs({3, 1}); + t[0x0].S(0).T(0).Dtis("SSSRRRRRR").ChainDiffs({12, 11, 10}).FrameDiffs({12}); + t[0x5].S(1).T(0).Dtis("---SSSRRR").ChainDiffs({1, 1, 1}).FrameDiffs({12, 1}); + t[0xA].S(2).T(0).Dtis("------SSS").ChainDiffs({2, 1, 1}).FrameDiffs({12, 1}); + return structure; +} + } // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_full_svc.h b/modules/video_coding/svc/scalability_structure_full_svc.h index d490d6e4a1..a3cad0af8a 100644 --- a/modules/video_coding/svc/scalability_structure_full_svc.h +++ b/modules/video_coding/svc/scalability_structure_full_svc.h @@ -21,7 +21,13 @@ namespace webrtc { class ScalabilityStructureFullSvc : public ScalableVideoController { public: - ScalabilityStructureFullSvc(int num_spatial_layers, int num_temporal_layers); + struct ScalingFactor { + int num = 1; + int den = 2; + }; + ScalabilityStructureFullSvc(int num_spatial_layers, + int num_temporal_layers, + ScalingFactor resolution_factor); ~ScalabilityStructureFullSvc() override; StreamLayersConfig StreamConfig() const override; @@ -33,13 +39,14 @@ class ScalabilityStructureFullSvc : public ScalableVideoController { private: enum FramePattern { kNone, + kKey, kDeltaT2A, kDeltaT1, kDeltaT2B, kDeltaT0, }; static constexpr absl::string_view kFramePatternNames[] = { - "None", "DeltaT2A", "DeltaT1", "DeltaT2B", "DeltaT0"}; + "None", "Key", "DeltaT2A", "DeltaT1", "DeltaT2B", "DeltaT0"}; static constexpr int kMaxNumSpatialLayers = 3; static constexpr int kMaxNumTemporalLayers = 3; @@ -61,6 +68,7 @@ class ScalabilityStructureFullSvc : public ScalableVideoController { const int num_spatial_layers_; const int num_temporal_layers_; + const ScalingFactor resolution_factor_; FramePattern last_pattern_ = kNone; std::bitset can_reference_t0_frame_for_spatial_id_ = 0; @@ -68,6 +76,88 @@ class ScalabilityStructureFullSvc : public ScalableVideoController { std::bitset<32> active_decode_targets_; }; +// T1 0 0 +// / / / ... 
+// T0 0---0---0-- +// Time-> 0 1 2 3 4 +class ScalabilityStructureL1T2 : public ScalabilityStructureFullSvc { + public: + explicit ScalabilityStructureL1T2(ScalingFactor resolution_factor = {}) + : ScalabilityStructureFullSvc(1, 2, resolution_factor) {} + ~ScalabilityStructureL1T2() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +// T2 0 0 0 0 +// | / | / +// T1 / 0 / 0 ... +// |_/ |_/ +// T0 0-------0------ +// Time-> 0 1 2 3 4 5 6 7 +class ScalabilityStructureL1T3 : public ScalabilityStructureFullSvc { + public: + explicit ScalabilityStructureL1T3(ScalingFactor resolution_factor = {}) + : ScalabilityStructureFullSvc(1, 3, resolution_factor) {} + ~ScalabilityStructureL1T3() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +// S1 0--0--0- +// | | | ... +// S0 0--0--0- +class ScalabilityStructureL2T1 : public ScalabilityStructureFullSvc { + public: + explicit ScalabilityStructureL2T1(ScalingFactor resolution_factor = {}) + : ScalabilityStructureFullSvc(2, 1, resolution_factor) {} + ~ScalabilityStructureL2T1() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +// S1T1 0 0 +// /| /| / +// S1T0 0-+-0-+-0 +// | | | | | ... +// S0T1 | 0 | 0 | +// |/ |/ |/ +// S0T0 0---0---0-- +// Time-> 0 1 2 3 4 +class ScalabilityStructureL2T2 : public ScalabilityStructureFullSvc { + public: + explicit ScalabilityStructureL2T2(ScalingFactor resolution_factor = {}) + : ScalabilityStructureFullSvc(2, 2, resolution_factor) {} + ~ScalabilityStructureL2T2() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +// S2 0-0-0- +// | | | +// S1 0-0-0-... +// | | | +// S0 0-0-0- +// Time-> 0 1 2 +class ScalabilityStructureL3T1 : public ScalabilityStructureFullSvc { + public: + explicit ScalabilityStructureL3T1(ScalingFactor resolution_factor = {}) + : ScalabilityStructureFullSvc(3, 1, resolution_factor) {} + ~ScalabilityStructureL3T1() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +// https://www.w3.org/TR/webrtc-svc/#L3T3* +class ScalabilityStructureL3T3 : public ScalabilityStructureFullSvc { + public: + explicit ScalabilityStructureL3T3(ScalingFactor resolution_factor = {}) + : ScalabilityStructureFullSvc(3, 3, resolution_factor) {} + ~ScalabilityStructureL3T3() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + } // namespace webrtc #endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_FULL_SVC_H_ diff --git a/modules/video_coding/svc/scalability_structure_l3t3_unittest.cc b/modules/video_coding/svc/scalability_structure_full_svc_unittest.cc similarity index 74% rename from modules/video_coding/svc/scalability_structure_l3t3_unittest.cc rename to modules/video_coding/svc/scalability_structure_full_svc_unittest.cc index ca66fa8f2b..9ccbe21f75 100644 --- a/modules/video_coding/svc/scalability_structure_l3t3_unittest.cc +++ b/modules/video_coding/svc/scalability_structure_full_svc_unittest.cc @@ -7,7 +7,7 @@ * in the file PATENTS. All contributing project authors may * be found in the AUTHORS file in the root of the source tree. 
*/ -#include "modules/video_coding/svc/scalability_structure_l3t3.h" +#include "modules/video_coding/svc/scalability_structure_full_svc.h" #include @@ -59,15 +59,32 @@ TEST(ScalabilityStructureL3T3Test, SkipT1FrameByEncoderKeepsReferencesValid) { // one more temporal units (T2) wrapper.GenerateFrames(/*num_temporal_units=*/1, frames); - ASSERT_THAT(frames, SizeIs(9)); - EXPECT_EQ(frames[0].temporal_id, 0); - EXPECT_EQ(frames[3].temporal_id, 2); - // T1 frame was dropped by the encoder. - EXPECT_EQ(frames[6].temporal_id, 2); - EXPECT_TRUE(wrapper.FrameReferencesAreValid(frames)); } +TEST(ScalabilityStructureL3T3Test, + SkippingFrameReusePreviousFrameConfiguration) { + std::vector frames; + ScalabilityStructureL3T3 structure; + ScalabilityStructureWrapper wrapper(structure); + + // 1st 2 temporal units (T0 and T2) + wrapper.GenerateFrames(/*num_temporal_units=*/2, frames); + ASSERT_THAT(frames, SizeIs(6)); + ASSERT_EQ(frames[0].temporal_id, 0); + ASSERT_EQ(frames[3].temporal_id, 2); + + // Simulate a frame dropped by the encoder, + // i.e. retrieve config, but skip calling OnEncodeDone. + structure.NextFrameConfig(/*restart=*/false); + // two more temporal unit, expect temporal pattern continues + wrapper.GenerateFrames(/*num_temporal_units=*/2, frames); + ASSERT_THAT(frames, SizeIs(12)); + // Expect temporal pattern continues as if there were no dropped frames. + EXPECT_EQ(frames[6].temporal_id, 1); + EXPECT_EQ(frames[9].temporal_id, 2); +} + TEST(ScalabilityStructureL3T3Test, SwitchSpatialLayerBeforeT1Frame) { ScalabilityStructureL3T3 structure; ScalabilityStructureWrapper wrapper(structure); diff --git a/modules/video_coding/svc/scalability_structure_key_svc.cc b/modules/video_coding/svc/scalability_structure_key_svc.cc index 9399c0cf7e..1cee80e84b 100644 --- a/modules/video_coding/svc/scalability_structure_key_svc.cc +++ b/modules/video_coding/svc/scalability_structure_key_svc.cc @@ -22,28 +22,6 @@ #include "rtc_base/logging.h" namespace webrtc { -namespace { -// Values to use as LayerFrameConfig::Id -enum : int { kKey, kDelta }; - -DecodeTargetIndication -Dti(int sid, int tid, const ScalableVideoController::LayerFrameConfig& config) { - if (config.IsKeyframe() || config.Id() == kKey) { - RTC_DCHECK_EQ(config.TemporalId(), 0); - return sid < config.SpatialId() ? DecodeTargetIndication::kNotPresent - : DecodeTargetIndication::kSwitch; - } - - if (sid != config.SpatialId() || tid < config.TemporalId()) { - return DecodeTargetIndication::kNotPresent; - } - if (tid == config.TemporalId() && tid > 0) { - return DecodeTargetIndication::kDiscardable; - } - return DecodeTargetIndication::kSwitch; -} - -} // namespace constexpr int ScalabilityStructureKeySvc::kMaxNumSpatialLayers; constexpr int ScalabilityStructureKeySvc::kMaxNumTemporalLayers; @@ -88,6 +66,25 @@ bool ScalabilityStructureKeySvc::TemporalLayerIsActive(int tid) const { return false; } +DecodeTargetIndication ScalabilityStructureKeySvc::Dti( + int sid, + int tid, + const LayerFrameConfig& config) { + if (config.IsKeyframe() || config.Id() == kKey) { + RTC_DCHECK_EQ(config.TemporalId(), 0); + return sid < config.SpatialId() ? 
DecodeTargetIndication::kNotPresent + : DecodeTargetIndication::kSwitch; + } + + if (sid != config.SpatialId() || tid < config.TemporalId()) { + return DecodeTargetIndication::kNotPresent; + } + if (tid == config.TemporalId() && tid > 0) { + return DecodeTargetIndication::kDiscardable; + } + return DecodeTargetIndication::kSwitch; +} + std::vector ScalabilityStructureKeySvc::KeyframeConfig() { std::vector configs; @@ -129,7 +126,7 @@ ScalabilityStructureKeySvc::T0Config() { continue; } configs.emplace_back(); - configs.back().Id(kDelta).S(sid).T(0).ReferenceAndUpdate( + configs.back().Id(kDeltaT0).S(sid).T(0).ReferenceAndUpdate( BufferIndex(sid, /*tid=*/0)); } return configs; @@ -145,7 +142,7 @@ ScalabilityStructureKeySvc::T1Config() { } configs.emplace_back(); ScalableVideoController::LayerFrameConfig& config = configs.back(); - config.Id(kDelta).S(sid).T(1).Reference(BufferIndex(sid, /*tid=*/0)); + config.Id(kDeltaT1).S(sid).T(1).Reference(BufferIndex(sid, /*tid=*/0)); if (num_temporal_layers_ > 2) { config.Update(BufferIndex(sid, /*tid=*/1)); } @@ -154,7 +151,7 @@ ScalabilityStructureKeySvc::T1Config() { } std::vector -ScalabilityStructureKeySvc::T2Config() { +ScalabilityStructureKeySvc::T2Config(FramePattern pattern) { std::vector configs; configs.reserve(num_spatial_layers_); for (int sid = 0; sid < num_spatial_layers_; ++sid) { @@ -163,7 +160,7 @@ ScalabilityStructureKeySvc::T2Config() { } configs.emplace_back(); ScalableVideoController::LayerFrameConfig& config = configs.back(); - config.Id(kDelta).S(sid).T(2); + config.Id(pattern).S(sid).T(2); if (can_reference_t1_frame_for_spatial_id_[sid]) { config.Reference(BufferIndex(sid, /*tid=*/1)); } else { @@ -173,6 +170,37 @@ ScalabilityStructureKeySvc::T2Config() { return configs; } +ScalabilityStructureKeySvc::FramePattern +ScalabilityStructureKeySvc::NextPattern(FramePattern last_pattern) const { + switch (last_pattern) { + case kNone: + return kKey; + case kDeltaT2B: + return kDeltaT0; + case kDeltaT2A: + if (TemporalLayerIsActive(1)) { + return kDeltaT1; + } + return kDeltaT0; + case kDeltaT1: + if (TemporalLayerIsActive(2)) { + return kDeltaT2B; + } + return kDeltaT0; + case kDeltaT0: + case kKey: + if (TemporalLayerIsActive(2)) { + return kDeltaT2A; + } + if (TemporalLayerIsActive(1)) { + return kDeltaT1; + } + return kDeltaT0; + } + RTC_NOTREACHED(); + return kNone; +} + std::vector ScalabilityStructureKeySvc::NextFrameConfig(bool restart) { if (active_decode_targets_.none()) { @@ -184,37 +212,19 @@ ScalabilityStructureKeySvc::NextFrameConfig(bool restart) { last_pattern_ = kNone; } - switch (last_pattern_) { - case kNone: - last_pattern_ = kDeltaT0; + FramePattern current_pattern = NextPattern(last_pattern_); + switch (current_pattern) { + case kKey: return KeyframeConfig(); - case kDeltaT2B: - last_pattern_ = kDeltaT0; - return T0Config(); - case kDeltaT2A: - if (TemporalLayerIsActive(1)) { - last_pattern_ = kDeltaT1; - return T1Config(); - } - last_pattern_ = kDeltaT0; - return T0Config(); - case kDeltaT1: - if (TemporalLayerIsActive(2)) { - last_pattern_ = kDeltaT2B; - return T2Config(); - } - last_pattern_ = kDeltaT0; - return T0Config(); case kDeltaT0: - if (TemporalLayerIsActive(2)) { - last_pattern_ = kDeltaT2A; - return T2Config(); - } else if (TemporalLayerIsActive(1)) { - last_pattern_ = kDeltaT1; - return T1Config(); - } - last_pattern_ = kDeltaT0; return T0Config(); + case kDeltaT1: + return T1Config(); + case kDeltaT2A: + case kDeltaT2B: + return T2Config(current_pattern); + case kNone: + break; } RTC_NOTREACHED(); 
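  // Unreachable in practice; an empty config set means no frame is produced
  // for this temporal unit.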
return {}; @@ -222,6 +232,11 @@ ScalabilityStructureKeySvc::NextFrameConfig(bool restart) { GenericFrameInfo ScalabilityStructureKeySvc::OnEncodeDone( const LayerFrameConfig& config) { + // When encoder drops all frames for a temporal unit, it is better to reuse + // old temporal pattern rather than switch to next one, thus switch to next + // pattern defered here from the `NextFrameConfig`. + // In particular creating VP9 references rely on this behavior. + last_pattern_ = static_cast(config.Id()); if (config.TemporalId() == 1) { can_reference_t1_frame_for_spatial_id_.set(config.SpatialId()); } @@ -304,6 +319,29 @@ FrameDependencyStructure ScalabilityStructureL2T2Key::DependencyStructure() return structure; } +ScalabilityStructureL2T3Key::~ScalabilityStructureL2T3Key() = default; + +FrameDependencyStructure ScalabilityStructureL2T3Key::DependencyStructure() + const { + FrameDependencyStructure structure; + structure.num_decode_targets = 6; + structure.num_chains = 2; + structure.decode_target_protected_by_chain = {0, 0, 0, 1, 1, 1}; + auto& templates = structure.templates; + templates.resize(10); + templates[0].S(0).T(0).Dtis("SSSSSS").ChainDiffs({0, 0}); + templates[1].S(0).T(0).Dtis("SSS---").ChainDiffs({8, 7}).FrameDiffs({8}); + templates[2].S(0).T(1).Dtis("-DS---").ChainDiffs({4, 3}).FrameDiffs({4}); + templates[3].S(0).T(2).Dtis("--D---").ChainDiffs({2, 1}).FrameDiffs({2}); + templates[4].S(0).T(2).Dtis("--D---").ChainDiffs({6, 5}).FrameDiffs({2}); + templates[5].S(1).T(0).Dtis("---SSS").ChainDiffs({1, 1}).FrameDiffs({1}); + templates[6].S(1).T(0).Dtis("---SSS").ChainDiffs({1, 8}).FrameDiffs({8}); + templates[7].S(1).T(1).Dtis("----DS").ChainDiffs({5, 4}).FrameDiffs({4}); + templates[8].S(1).T(2).Dtis("-----D").ChainDiffs({3, 2}).FrameDiffs({2}); + templates[9].S(1).T(2).Dtis("-----D").ChainDiffs({7, 6}).FrameDiffs({2}); + return structure; +} + ScalabilityStructureL3T3Key::~ScalabilityStructureL3T3Key() = default; FrameDependencyStructure ScalabilityStructureL3T3Key::DependencyStructure() diff --git a/modules/video_coding/svc/scalability_structure_key_svc.h b/modules/video_coding/svc/scalability_structure_key_svc.h index 1d3277b5cd..b66f6f83e4 100644 --- a/modules/video_coding/svc/scalability_structure_key_svc.h +++ b/modules/video_coding/svc/scalability_structure_key_svc.h @@ -32,8 +32,9 @@ class ScalabilityStructureKeySvc : public ScalableVideoController { void OnRatesUpdated(const VideoBitrateAllocation& bitrates) override; private: - enum FramePattern { + enum FramePattern : int { kNone, + kKey, kDeltaT0, kDeltaT2A, kDeltaT1, @@ -53,10 +54,16 @@ class ScalabilityStructureKeySvc : public ScalableVideoController { active_decode_targets_.set(sid * num_temporal_layers_ + tid, value); } bool TemporalLayerIsActive(int tid) const; + static DecodeTargetIndication Dti(int sid, + int tid, + const LayerFrameConfig& config); + std::vector KeyframeConfig(); std::vector T0Config(); std::vector T1Config(); - std::vector T2Config(); + std::vector T2Config(FramePattern pattern); + + FramePattern NextPattern(FramePattern last_pattern) const; const int num_spatial_layers_; const int num_temporal_layers_; @@ -94,6 +101,14 @@ class ScalabilityStructureL2T2Key : public ScalabilityStructureKeySvc { FrameDependencyStructure DependencyStructure() const override; }; +class ScalabilityStructureL2T3Key : public ScalabilityStructureKeySvc { + public: + ScalabilityStructureL2T3Key() : ScalabilityStructureKeySvc(2, 3) {} + ~ScalabilityStructureL2T3Key() override; + + FrameDependencyStructure 
DependencyStructure() const override; +}; + class ScalabilityStructureL3T3Key : public ScalabilityStructureKeySvc { public: ScalabilityStructureL3T3Key() : ScalabilityStructureKeySvc(3, 3) {} diff --git a/modules/video_coding/svc/scalability_structure_key_svc_unittest.cc b/modules/video_coding/svc/scalability_structure_key_svc_unittest.cc index 34ec74726d..5f923bb487 100644 --- a/modules/video_coding/svc/scalability_structure_key_svc_unittest.cc +++ b/modules/video_coding/svc/scalability_structure_key_svc_unittest.cc @@ -62,14 +62,108 @@ TEST(ScalabilityStructureL3T3KeyTest, // Simulate T1 frame dropped by the encoder, // i.e. retrieve config, but skip calling OnEncodeDone. structure.NextFrameConfig(/*restart=*/false); - // one more temporal units (T2) + // one more temporal unit. wrapper.GenerateFrames(/*num_temporal_units=*/1, frames); - ASSERT_THAT(frames, SizeIs(9)); + EXPECT_THAT(frames, SizeIs(9)); + EXPECT_TRUE(wrapper.FrameReferencesAreValid(frames)); +} + +TEST(ScalabilityStructureL3T3KeyTest, + SkippingFrameReusePreviousFrameConfiguration) { + std::vector frames; + ScalabilityStructureL3T3Key structure; + ScalabilityStructureWrapper wrapper(structure); + + // 1st 2 temporal units (T0 and T2) + wrapper.GenerateFrames(/*num_temporal_units=*/2, frames); + ASSERT_THAT(frames, SizeIs(6)); + ASSERT_EQ(frames[0].temporal_id, 0); + ASSERT_EQ(frames[3].temporal_id, 2); + + // Simulate a frame dropped by the encoder, + // i.e. retrieve config, but skip calling OnEncodeDone. + structure.NextFrameConfig(/*restart=*/false); + // two more temporal unit, expect temporal pattern continues + wrapper.GenerateFrames(/*num_temporal_units=*/2, frames); + ASSERT_THAT(frames, SizeIs(12)); + // Expect temporal pattern continues as if there were no dropped frames. + EXPECT_EQ(frames[6].temporal_id, 1); + EXPECT_EQ(frames[9].temporal_id, 2); +} + +TEST(ScalabilityStructureL3T3KeyTest, SkippingKeyFrameTriggersNewKeyFrame) { + std::vector frames; + ScalabilityStructureL3T3Key structure; + ScalabilityStructureWrapper wrapper(structure); + + // Ask for a key frame config, but do not return any frames + structure.NextFrameConfig(/*restart=*/false); + + // Ask for more frames, expect they start with a key frame. + wrapper.GenerateFrames(/*num_temporal_units=*/2, frames); + ASSERT_THAT(frames, SizeIs(6)); + ASSERT_EQ(frames[0].temporal_id, 0); + ASSERT_EQ(frames[3].temporal_id, 2); + EXPECT_TRUE(wrapper.FrameReferencesAreValid(frames)); +} + +TEST(ScalabilityStructureL3T3KeyTest, + SkippingT2FrameAndDisablingT2LayerProduceT1AsNextFrame) { + std::vector frames; + ScalabilityStructureL3T3Key structure; + ScalabilityStructureWrapper wrapper(structure); + + wrapper.GenerateFrames(/*num_temporal_units=*/1, frames); + // Ask for next (T2) frame config, but do not return any frames + auto config = structure.NextFrameConfig(/*restart=*/false); + ASSERT_THAT(config, Not(IsEmpty())); + ASSERT_EQ(config.front().TemporalId(), 2); + + // Disable T2 layer, + structure.OnRatesUpdated(EnableTemporalLayers(/*s0=*/2, /*s1=*/2, /*s2=*/2)); + // Expect instead of reusing unused config, T1 config is generated. 
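+  // (The unused T2 config was never reported via OnEncodeDone(), so the next
+  // pattern is re-evaluated against the updated set of active layers.)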
+ config = structure.NextFrameConfig(/*restart=*/false); + ASSERT_THAT(config, Not(IsEmpty())); + EXPECT_EQ(config.front().TemporalId(), 1); +} + +TEST(ScalabilityStructureL3T3KeyTest, EnableT2LayerWhileProducingT1Frame) { + std::vector frames; + ScalabilityStructureL3T3Key structure; + ScalabilityStructureWrapper wrapper(structure); + + // Disable T2 layer, + structure.OnRatesUpdated(EnableTemporalLayers(/*s0=*/2, /*s1=*/2, /*s2=*/2)); + + // Generate the key frame. + wrapper.GenerateFrames(/*num_temporal_units=*/1, frames); + ASSERT_THAT(frames, SizeIs(3)); EXPECT_EQ(frames[0].temporal_id, 0); - EXPECT_EQ(frames[3].temporal_id, 2); - // T1 frames were dropped by the encoder. + + // Ask for next (T1) frame config, but do not return any frames yet. + auto config = structure.NextFrameConfig(/*restart=*/false); + ASSERT_THAT(config, Not(IsEmpty())); + ASSERT_EQ(config.front().TemporalId(), 1); + + // Reenable T2 layer. + structure.OnRatesUpdated(EnableTemporalLayers(/*s0=*/3, /*s1=*/3, /*s2=*/3)); + + // Finish encoding previously requested config. + for (auto layer_config : config) { + GenericFrameInfo info = structure.OnEncodeDone(layer_config); + EXPECT_EQ(info.temporal_id, 1); + frames.push_back(info); + } + ASSERT_THAT(frames, SizeIs(6)); + + // Generate more frames, expect T2 pattern resumes. + wrapper.GenerateFrames(/*num_temporal_units=*/4, frames); + ASSERT_THAT(frames, SizeIs(18)); EXPECT_EQ(frames[6].temporal_id, 2); + EXPECT_EQ(frames[9].temporal_id, 0); + EXPECT_EQ(frames[12].temporal_id, 2); + EXPECT_EQ(frames[15].temporal_id, 1); EXPECT_TRUE(wrapper.FrameReferencesAreValid(frames)); } diff --git a/modules/video_coding/svc/scalability_structure_l1t2.cc b/modules/video_coding/svc/scalability_structure_l1t2.cc deleted file mode 100644 index f639e2da6e..0000000000 --- a/modules/video_coding/svc/scalability_structure_l1t2.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#include "modules/video_coding/svc/scalability_structure_l1t2.h" - -#include - -#include "api/transport/rtp/dependency_descriptor.h" - -namespace webrtc { - -ScalabilityStructureL1T2::~ScalabilityStructureL1T2() = default; - -FrameDependencyStructure ScalabilityStructureL1T2::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = 2; - structure.num_chains = 1; - structure.decode_target_protected_by_chain = {0, 0}; - structure.templates.resize(3); - structure.templates[0].T(0).Dtis("SS").ChainDiffs({0}); - structure.templates[1].T(0).Dtis("SS").ChainDiffs({2}).FrameDiffs({2}); - structure.templates[2].T(1).Dtis("-D").ChainDiffs({1}).FrameDiffs({1}); - return structure; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l1t2.h b/modules/video_coding/svc/scalability_structure_l1t2.h deleted file mode 100644 index d2f81aa113..0000000000 --- a/modules/video_coding/svc/scalability_structure_l1t2.h +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. 
An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L1T2_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L1T2_H_ - -#include "api/transport/rtp/dependency_descriptor.h" -#include "modules/video_coding/svc/scalability_structure_full_svc.h" - -namespace webrtc { - -class ScalabilityStructureL1T2 : public ScalabilityStructureFullSvc { - public: - ScalabilityStructureL1T2() : ScalabilityStructureFullSvc(1, 2) {} - ~ScalabilityStructureL1T2() override; - - FrameDependencyStructure DependencyStructure() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L1T2_H_ diff --git a/modules/video_coding/svc/scalability_structure_l1t3.cc b/modules/video_coding/svc/scalability_structure_l1t3.cc deleted file mode 100644 index 17073344c3..0000000000 --- a/modules/video_coding/svc/scalability_structure_l1t3.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#include "modules/video_coding/svc/scalability_structure_l1t3.h" - -#include - -#include "api/transport/rtp/dependency_descriptor.h" - -namespace webrtc { - -ScalabilityStructureL1T3::~ScalabilityStructureL1T3() = default; - -FrameDependencyStructure ScalabilityStructureL1T3::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = 3; - structure.num_chains = 1; - structure.decode_target_protected_by_chain = {0, 0, 0}; - structure.templates.resize(5); - structure.templates[0].T(0).Dtis("SSS").ChainDiffs({0}); - structure.templates[1].T(0).Dtis("SSS").ChainDiffs({4}).FrameDiffs({4}); - structure.templates[2].T(1).Dtis("-DS").ChainDiffs({2}).FrameDiffs({2}); - structure.templates[3].T(2).Dtis("--D").ChainDiffs({1}).FrameDiffs({1}); - structure.templates[4].T(2).Dtis("--D").ChainDiffs({3}).FrameDiffs({1}); - return structure; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l1t3.h b/modules/video_coding/svc/scalability_structure_l1t3.h deleted file mode 100644 index 00e48ccc47..0000000000 --- a/modules/video_coding/svc/scalability_structure_l1t3.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L1T3_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L1T3_H_ - -#include "api/transport/rtp/dependency_descriptor.h" -#include "modules/video_coding/svc/scalability_structure_full_svc.h" - -namespace webrtc { - -// T2 0 0 0 0 -// | / | / -// T1 / 0 / 0 ... 
-// |_/ |_/ -// T0 0-------0------ -// Time-> 0 1 2 3 4 5 6 7 -class ScalabilityStructureL1T3 : public ScalabilityStructureFullSvc { - public: - ScalabilityStructureL1T3() : ScalabilityStructureFullSvc(1, 3) {} - ~ScalabilityStructureL1T3() override; - - FrameDependencyStructure DependencyStructure() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L1T3_H_ diff --git a/modules/video_coding/svc/scalability_structure_l2t1.cc b/modules/video_coding/svc/scalability_structure_l2t1.cc deleted file mode 100644 index efd7516657..0000000000 --- a/modules/video_coding/svc/scalability_structure_l2t1.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#include "modules/video_coding/svc/scalability_structure_l2t1.h" - -#include - -#include "api/transport/rtp/dependency_descriptor.h" - -namespace webrtc { - -ScalabilityStructureL2T1::~ScalabilityStructureL2T1() = default; - -FrameDependencyStructure ScalabilityStructureL2T1::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = 2; - structure.num_chains = 2; - structure.decode_target_protected_by_chain = {0, 1}; - structure.templates.resize(4); - structure.templates[0].S(0).Dtis("SR").ChainDiffs({2, 1}).FrameDiffs({2}); - structure.templates[1].S(0).Dtis("SS").ChainDiffs({0, 0}); - structure.templates[2].S(1).Dtis("-S").ChainDiffs({1, 1}).FrameDiffs({2, 1}); - structure.templates[3].S(1).Dtis("-S").ChainDiffs({1, 1}).FrameDiffs({1}); - return structure; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l2t1.h b/modules/video_coding/svc/scalability_structure_l2t1.h deleted file mode 100644 index 96a0da56df..0000000000 --- a/modules/video_coding/svc/scalability_structure_l2t1.h +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T1_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T1_H_ - -#include "api/transport/rtp/dependency_descriptor.h" -#include "modules/video_coding/svc/scalability_structure_full_svc.h" - -namespace webrtc { - -// S1 0--0--0- -// | | | ... 
-// S0 0--0--0- -class ScalabilityStructureL2T1 : public ScalabilityStructureFullSvc { - public: - ScalabilityStructureL2T1() : ScalabilityStructureFullSvc(2, 1) {} - ~ScalabilityStructureL2T1() override; - - FrameDependencyStructure DependencyStructure() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T1_H_ diff --git a/modules/video_coding/svc/scalability_structure_l2t1h.cc b/modules/video_coding/svc/scalability_structure_l2t1h.cc deleted file mode 100644 index c4682764ae..0000000000 --- a/modules/video_coding/svc/scalability_structure_l2t1h.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#include "modules/video_coding/svc/scalability_structure_l2t1h.h" - -#include -#include - -#include "absl/base/macros.h" -#include "api/transport/rtp/dependency_descriptor.h" -#include "rtc_base/checks.h" -#include "rtc_base/logging.h" - -namespace webrtc { - -ScalabilityStructureL2T1h::~ScalabilityStructureL2T1h() = default; - -ScalableVideoController::StreamLayersConfig -ScalabilityStructureL2T1h::StreamConfig() const { - StreamLayersConfig result; - result.num_spatial_layers = 2; - result.num_temporal_layers = 1; - // 1.5:1 scaling, see https://w3c.github.io/webrtc-svc/#scalabilitymodes* - result.scaling_factor_num[0] = 2; - result.scaling_factor_den[0] = 3; - return result; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l2t1h.h b/modules/video_coding/svc/scalability_structure_l2t1h.h deleted file mode 100644 index 7200a10843..0000000000 --- a/modules/video_coding/svc/scalability_structure_l2t1h.h +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T1H_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T1H_H_ - -#include "modules/video_coding/svc/scalability_structure_l2t1.h" -#include "modules/video_coding/svc/scalable_video_controller.h" - -namespace webrtc { - -class ScalabilityStructureL2T1h : public ScalabilityStructureL2T1 { - public: - ~ScalabilityStructureL2T1h() override; - - StreamLayersConfig StreamConfig() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T1H_H_ diff --git a/modules/video_coding/svc/scalability_structure_l2t2.cc b/modules/video_coding/svc/scalability_structure_l2t2.cc deleted file mode 100644 index a381ad080a..0000000000 --- a/modules/video_coding/svc/scalability_structure_l2t2.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. 
An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#include "modules/video_coding/svc/scalability_structure_l2t2.h" - -#include - -#include "api/transport/rtp/dependency_descriptor.h" - -namespace webrtc { - -ScalabilityStructureL2T2::~ScalabilityStructureL2T2() = default; - -FrameDependencyStructure ScalabilityStructureL2T2::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = 4; - structure.num_chains = 2; - structure.decode_target_protected_by_chain = {0, 0, 1, 1}; - structure.templates.resize(6); - auto& templates = structure.templates; - templates[0].S(0).T(0).Dtis("SSSS").ChainDiffs({0, 0}); - templates[1].S(0).T(0).Dtis("SSRR").ChainDiffs({4, 3}).FrameDiffs({4}); - templates[2].S(0).T(1).Dtis("-D-R").ChainDiffs({2, 1}).FrameDiffs({2}); - templates[3].S(1).T(0).Dtis("--SS").ChainDiffs({1, 1}).FrameDiffs({1}); - templates[4].S(1).T(0).Dtis("--SS").ChainDiffs({1, 1}).FrameDiffs({4, 1}); - templates[5].S(1).T(1).Dtis("---D").ChainDiffs({3, 2}).FrameDiffs({2, 1}); - return structure; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l2t2.h b/modules/video_coding/svc/scalability_structure_l2t2.h deleted file mode 100644 index 781ea7e60d..0000000000 --- a/modules/video_coding/svc/scalability_structure_l2t2.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T2_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T2_H_ - -#include "api/transport/rtp/dependency_descriptor.h" -#include "modules/video_coding/svc/scalability_structure_full_svc.h" - -namespace webrtc { - -// S1T1 0 0 -// /| /| / -// S1T0 0-+-0-+-0 -// | | | | | ... -// S0T1 | 0 | 0 | -// |/ |/ |/ -// S0T0 0---0---0-- -// Time-> 0 1 2 3 4 -class ScalabilityStructureL2T2 : public ScalabilityStructureFullSvc { - public: - ScalabilityStructureL2T2() : ScalabilityStructureFullSvc(2, 2) {} - ~ScalabilityStructureL2T2() override; - - FrameDependencyStructure DependencyStructure() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L2T2_H_ diff --git a/modules/video_coding/svc/scalability_structure_l3t1.cc b/modules/video_coding/svc/scalability_structure_l3t1.cc deleted file mode 100644 index d7a5324465..0000000000 --- a/modules/video_coding/svc/scalability_structure_l3t1.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ -#include "modules/video_coding/svc/scalability_structure_l3t1.h" - -#include - -#include "api/transport/rtp/dependency_descriptor.h" - -namespace webrtc { - -ScalabilityStructureL3T1::~ScalabilityStructureL3T1() = default; - -FrameDependencyStructure ScalabilityStructureL3T1::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = 3; - structure.num_chains = 3; - structure.decode_target_protected_by_chain = {0, 1, 2}; - auto& templates = structure.templates; - templates.resize(6); - templates[0].S(0).Dtis("SRR").ChainDiffs({3, 2, 1}).FrameDiffs({3}); - templates[1].S(0).Dtis("SSS").ChainDiffs({0, 0, 0}); - templates[2].S(1).Dtis("-SR").ChainDiffs({1, 1, 1}).FrameDiffs({3, 1}); - templates[3].S(1).Dtis("-SS").ChainDiffs({1, 1, 1}).FrameDiffs({1}); - templates[4].S(2).Dtis("--S").ChainDiffs({2, 1, 1}).FrameDiffs({3, 1}); - templates[5].S(2).Dtis("--S").ChainDiffs({2, 1, 1}).FrameDiffs({1}); - return structure; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l3t1.h b/modules/video_coding/svc/scalability_structure_l3t1.h deleted file mode 100644 index dea40e96b8..0000000000 --- a/modules/video_coding/svc/scalability_structure_l3t1.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L3T1_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L3T1_H_ - -#include "api/transport/rtp/dependency_descriptor.h" -#include "modules/video_coding/svc/scalability_structure_full_svc.h" - -namespace webrtc { - -// S2 0-0-0- -// | | | -// S1 0-0-0-... -// | | | -// S0 0-0-0- -// Time-> 0 1 2 -class ScalabilityStructureL3T1 : public ScalabilityStructureFullSvc { - public: - ScalabilityStructureL3T1() : ScalabilityStructureFullSvc(3, 1) {} - ~ScalabilityStructureL3T1() override; - - FrameDependencyStructure DependencyStructure() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L3T1_H_ diff --git a/modules/video_coding/svc/scalability_structure_l3t3.cc b/modules/video_coding/svc/scalability_structure_l3t3.cc deleted file mode 100644 index 932056b0d3..0000000000 --- a/modules/video_coding/svc/scalability_structure_l3t3.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ -#include "modules/video_coding/svc/scalability_structure_l3t3.h" - -#include - -#include "api/transport/rtp/dependency_descriptor.h" - -namespace webrtc { - -ScalabilityStructureL3T3::~ScalabilityStructureL3T3() = default; - -FrameDependencyStructure ScalabilityStructureL3T3::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = 9; - structure.num_chains = 3; - structure.decode_target_protected_by_chain = {0, 0, 0, 1, 1, 1, 2, 2, 2}; - auto& t = structure.templates; - t.resize(15); - // Templates are shown in the order frames following them appear in the - // stream, but in `structure.templates` array templates are sorted by - // (`spatial_id`, `temporal_id`) since that is a dependency descriptor - // requirement. Indexes are written in hex for nicer alignment. - t[0x1].S(0).T(0).Dtis("SSSSSSSSS").ChainDiffs({0, 0, 0}); - t[0x6].S(1).T(0).Dtis("---SSSSSS").ChainDiffs({1, 1, 1}).FrameDiffs({1}); - t[0xB].S(2).T(0).Dtis("------SSS").ChainDiffs({2, 1, 1}).FrameDiffs({1}); - t[0x3].S(0).T(2).Dtis("--D--R--R").ChainDiffs({3, 2, 1}).FrameDiffs({3}); - t[0x8].S(1).T(2).Dtis("-----D--R").ChainDiffs({4, 3, 2}).FrameDiffs({3, 1}); - t[0xD].S(2).T(2).Dtis("--------D").ChainDiffs({5, 4, 3}).FrameDiffs({3, 1}); - t[0x2].S(0).T(1).Dtis("-DS-RR-RR").ChainDiffs({6, 5, 4}).FrameDiffs({6}); - t[0x7].S(1).T(1).Dtis("----DS-RR").ChainDiffs({7, 6, 5}).FrameDiffs({6, 1}); - t[0xC].S(2).T(1).Dtis("-------DS").ChainDiffs({8, 7, 6}).FrameDiffs({6, 1}); - t[0x4].S(0).T(2).Dtis("--D--R--R").ChainDiffs({9, 8, 7}).FrameDiffs({3}); - t[0x9].S(1).T(2).Dtis("-----D--R").ChainDiffs({10, 9, 8}).FrameDiffs({3, 1}); - t[0xE].S(2).T(2).Dtis("--------D").ChainDiffs({11, 10, 9}).FrameDiffs({3, 1}); - t[0x0].S(0).T(0).Dtis("SSSRRRRRR").ChainDiffs({12, 11, 10}).FrameDiffs({12}); - t[0x5].S(1).T(0).Dtis("---SSSRRR").ChainDiffs({1, 1, 1}).FrameDiffs({12, 1}); - t[0xA].S(2).T(0).Dtis("------SSS").ChainDiffs({2, 1, 1}).FrameDiffs({12, 1}); - return structure; -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_l3t3.h b/modules/video_coding/svc/scalability_structure_l3t3.h deleted file mode 100644 index 3f42726cc1..0000000000 --- a/modules/video_coding/svc/scalability_structure_l3t3.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L3T3_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L3T3_H_ - -#include "api/transport/rtp/dependency_descriptor.h" -#include "modules/video_coding/svc/scalability_structure_full_svc.h" - -namespace webrtc { - -// https://aomediacodec.github.io/av1-rtp-spec/#a63-l3t3-full-svc -class ScalabilityStructureL3T3 : public ScalabilityStructureFullSvc { - public: - ScalabilityStructureL3T3() : ScalabilityStructureFullSvc(3, 3) {} - ~ScalabilityStructureL3T3() override; - - FrameDependencyStructure DependencyStructure() const override; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_L3T3_H_ diff --git a/modules/video_coding/svc/scalability_structure_s2t1.cc b/modules/video_coding/svc/scalability_structure_s2t1.cc deleted file mode 100644 index 618deb4b37..0000000000 --- a/modules/video_coding/svc/scalability_structure_s2t1.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#include "modules/video_coding/svc/scalability_structure_s2t1.h" - -#include -#include - -#include "absl/base/macros.h" -#include "api/transport/rtp/dependency_descriptor.h" -#include "rtc_base/checks.h" -#include "rtc_base/logging.h" - -namespace webrtc { - -constexpr int ScalabilityStructureS2T1::kNumSpatialLayers; - -ScalabilityStructureS2T1::~ScalabilityStructureS2T1() = default; - -ScalableVideoController::StreamLayersConfig -ScalabilityStructureS2T1::StreamConfig() const { - StreamLayersConfig result; - result.num_spatial_layers = kNumSpatialLayers; - result.num_temporal_layers = 1; - result.scaling_factor_num[0] = 1; - result.scaling_factor_den[0] = 2; - return result; -} - -FrameDependencyStructure ScalabilityStructureS2T1::DependencyStructure() const { - FrameDependencyStructure structure; - structure.num_decode_targets = kNumSpatialLayers; - structure.num_chains = kNumSpatialLayers; - structure.decode_target_protected_by_chain = {0, 1}; - structure.templates.resize(4); - structure.templates[0].S(0).Dtis("S-").ChainDiffs({2, 1}).FrameDiffs({2}); - structure.templates[1].S(0).Dtis("S-").ChainDiffs({0, 0}); - structure.templates[2].S(1).Dtis("-S").ChainDiffs({1, 2}).FrameDiffs({2}); - structure.templates[3].S(1).Dtis("-S").ChainDiffs({1, 0}); - return structure; -} - -std::vector -ScalabilityStructureS2T1::NextFrameConfig(bool restart) { - if (restart) { - can_reference_frame_for_spatial_id_.reset(); - } - std::vector configs; - configs.reserve(kNumSpatialLayers); - for (int sid = 0; sid < kNumSpatialLayers; ++sid) { - if (!active_decode_targets_[sid]) { - can_reference_frame_for_spatial_id_.reset(sid); - continue; - } - configs.emplace_back(); - LayerFrameConfig& config = configs.back().S(sid); - if (can_reference_frame_for_spatial_id_[sid]) { - config.ReferenceAndUpdate(sid); - } else { - config.Keyframe().Update(sid); - can_reference_frame_for_spatial_id_.set(sid); - } - } - - return configs; -} - -GenericFrameInfo ScalabilityStructureS2T1::OnEncodeDone( - const LayerFrameConfig& config) { - GenericFrameInfo frame_info; - frame_info.spatial_id = config.SpatialId(); - frame_info.temporal_id = config.TemporalId(); - 
frame_info.encoder_buffers = config.Buffers(); - frame_info.decode_target_indications = { - config.SpatialId() == 0 ? DecodeTargetIndication::kSwitch - : DecodeTargetIndication::kNotPresent, - config.SpatialId() == 1 ? DecodeTargetIndication::kSwitch - : DecodeTargetIndication::kNotPresent, - }; - frame_info.part_of_chain = {config.SpatialId() == 0, config.SpatialId() == 1}; - frame_info.active_decode_targets = active_decode_targets_; - return frame_info; -} - -void ScalabilityStructureS2T1::OnRatesUpdated( - const VideoBitrateAllocation& bitrates) { - active_decode_targets_.set(0, bitrates.GetBitrate(/*sid=*/0, /*tid=*/0) > 0); - active_decode_targets_.set(1, bitrates.GetBitrate(/*sid=*/1, /*tid=*/0) > 0); -} - -} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_s2t1.h b/modules/video_coding/svc/scalability_structure_s2t1.h deleted file mode 100644 index 0f27e480fa..0000000000 --- a/modules/video_coding/svc/scalability_structure_s2t1.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_S2T1_H_ -#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_S2T1_H_ - -#include - -#include "api/transport/rtp/dependency_descriptor.h" -#include "api/video/video_bitrate_allocation.h" -#include "common_video/generic_frame_descriptor/generic_frame_info.h" -#include "modules/video_coding/svc/scalable_video_controller.h" - -namespace webrtc { - -// S1 0--0--0- -// ... -// S0 0--0--0- -class ScalabilityStructureS2T1 : public ScalableVideoController { - public: - ~ScalabilityStructureS2T1() override; - - StreamLayersConfig StreamConfig() const override; - FrameDependencyStructure DependencyStructure() const override; - - std::vector NextFrameConfig(bool restart) override; - GenericFrameInfo OnEncodeDone(const LayerFrameConfig& config) override; - void OnRatesUpdated(const VideoBitrateAllocation& bitrates) override; - - private: - static constexpr int kNumSpatialLayers = 2; - - std::bitset can_reference_frame_for_spatial_id_; - std::bitset<32> active_decode_targets_ = 0b11; -}; - -} // namespace webrtc - -#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_S2T1_H_ diff --git a/modules/video_coding/svc/scalability_structure_simulcast.cc b/modules/video_coding/svc/scalability_structure_simulcast.cc new file mode 100644 index 0000000000..c236066736 --- /dev/null +++ b/modules/video_coding/svc/scalability_structure_simulcast.cc @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "modules/video_coding/svc/scalability_structure_simulcast.h" + +#include +#include + +#include "absl/base/macros.h" +#include "api/transport/rtp/dependency_descriptor.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace webrtc { +namespace { + +DecodeTargetIndication +Dti(int sid, int tid, const ScalableVideoController::LayerFrameConfig& config) { + if (sid != config.SpatialId() || tid < config.TemporalId()) { + return DecodeTargetIndication::kNotPresent; + } + if (tid == 0) { + RTC_DCHECK_EQ(config.TemporalId(), 0); + return DecodeTargetIndication::kSwitch; + } + if (tid == config.TemporalId()) { + return DecodeTargetIndication::kDiscardable; + } + RTC_DCHECK_GT(tid, config.TemporalId()); + return DecodeTargetIndication::kSwitch; +} + +} // namespace + +constexpr int ScalabilityStructureSimulcast::kMaxNumSpatialLayers; +constexpr int ScalabilityStructureSimulcast::kMaxNumTemporalLayers; + +ScalabilityStructureSimulcast::ScalabilityStructureSimulcast( + int num_spatial_layers, + int num_temporal_layers) + : num_spatial_layers_(num_spatial_layers), + num_temporal_layers_(num_temporal_layers), + active_decode_targets_( + (uint32_t{1} << (num_spatial_layers * num_temporal_layers)) - 1) { + RTC_DCHECK_LE(num_spatial_layers, kMaxNumSpatialLayers); + RTC_DCHECK_LE(num_temporal_layers, kMaxNumTemporalLayers); +} + +ScalabilityStructureSimulcast::~ScalabilityStructureSimulcast() = default; + +ScalableVideoController::StreamLayersConfig +ScalabilityStructureSimulcast::StreamConfig() const { + StreamLayersConfig result; + result.num_spatial_layers = num_spatial_layers_; + result.num_temporal_layers = num_temporal_layers_; + result.scaling_factor_num[num_spatial_layers_ - 1] = 1; + result.scaling_factor_den[num_spatial_layers_ - 1] = 1; + for (int sid = num_spatial_layers_ - 1; sid > 0; --sid) { + result.scaling_factor_num[sid - 1] = 1; + result.scaling_factor_den[sid - 1] = 2 * result.scaling_factor_den[sid]; + } + return result; +} + +bool ScalabilityStructureSimulcast::TemporalLayerIsActive(int tid) const { + if (tid >= num_temporal_layers_) { + return false; + } + for (int sid = 0; sid < num_spatial_layers_; ++sid) { + if (DecodeTargetIsActive(sid, tid)) { + return true; + } + } + return false; +} + +ScalabilityStructureSimulcast::FramePattern +ScalabilityStructureSimulcast::NextPattern() const { + switch (last_pattern_) { + case kNone: + case kDeltaT2B: + return kDeltaT0; + case kDeltaT2A: + if (TemporalLayerIsActive(1)) { + return kDeltaT1; + } + return kDeltaT0; + case kDeltaT1: + if (TemporalLayerIsActive(2)) { + return kDeltaT2B; + } + return kDeltaT0; + case kDeltaT0: + if (TemporalLayerIsActive(2)) { + return kDeltaT2A; + } + if (TemporalLayerIsActive(1)) { + return kDeltaT1; + } + return kDeltaT0; + } + RTC_NOTREACHED(); + return kDeltaT0; +} + +std::vector +ScalabilityStructureSimulcast::NextFrameConfig(bool restart) { + std::vector configs; + if (active_decode_targets_.none()) { + last_pattern_ = kNone; + return configs; + } + configs.reserve(num_spatial_layers_); + + if (last_pattern_ == kNone || restart) { + can_reference_t0_frame_for_spatial_id_.reset(); + last_pattern_ = kNone; + } + FramePattern current_pattern = NextPattern(); + + switch (current_pattern) { + case kDeltaT0: + // Disallow temporal references cross T0 on higher temporal layers. 
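+      // Clearing the T1 bookkeeping forces the next T2 frame of each spatial
+      // layer back onto its T0 buffer until a fresh T1 frame is encoded.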
+ can_reference_t1_frame_for_spatial_id_.reset(); + for (int sid = 0; sid < num_spatial_layers_; ++sid) { + if (!DecodeTargetIsActive(sid, /*tid=*/0)) { + // Next frame from the spatial layer `sid` shouldn't depend on + // potentially old previous frame from the spatial layer `sid`. + can_reference_t0_frame_for_spatial_id_.reset(sid); + continue; + } + configs.emplace_back(); + ScalableVideoController::LayerFrameConfig& config = configs.back(); + config.Id(current_pattern).S(sid).T(0); + + if (can_reference_t0_frame_for_spatial_id_[sid]) { + config.ReferenceAndUpdate(BufferIndex(sid, /*tid=*/0)); + } else { + config.Keyframe().Update(BufferIndex(sid, /*tid=*/0)); + } + can_reference_t0_frame_for_spatial_id_.set(sid); + } + break; + case kDeltaT1: + for (int sid = 0; sid < num_spatial_layers_; ++sid) { + if (!DecodeTargetIsActive(sid, /*tid=*/1) || + !can_reference_t0_frame_for_spatial_id_[sid]) { + continue; + } + configs.emplace_back(); + ScalableVideoController::LayerFrameConfig& config = configs.back(); + config.Id(current_pattern) + .S(sid) + .T(1) + .Reference(BufferIndex(sid, /*tid=*/0)); + // Save frame only if there is a higher temporal layer that may need it. + if (num_temporal_layers_ > 2) { + config.Update(BufferIndex(sid, /*tid=*/1)); + } + } + break; + case kDeltaT2A: + case kDeltaT2B: + for (int sid = 0; sid < num_spatial_layers_; ++sid) { + if (!DecodeTargetIsActive(sid, /*tid=*/2) || + !can_reference_t0_frame_for_spatial_id_[sid]) { + continue; + } + configs.emplace_back(); + ScalableVideoController::LayerFrameConfig& config = configs.back(); + config.Id(current_pattern).S(sid).T(2); + if (can_reference_t1_frame_for_spatial_id_[sid]) { + config.Reference(BufferIndex(sid, /*tid=*/1)); + } else { + config.Reference(BufferIndex(sid, /*tid=*/0)); + } + } + break; + case kNone: + RTC_NOTREACHED(); + break; + } + + return configs; +} + +GenericFrameInfo ScalabilityStructureSimulcast::OnEncodeDone( + const LayerFrameConfig& config) { + last_pattern_ = static_cast(config.Id()); + if (config.TemporalId() == 1) { + can_reference_t1_frame_for_spatial_id_.set(config.SpatialId()); + } + GenericFrameInfo frame_info; + frame_info.spatial_id = config.SpatialId(); + frame_info.temporal_id = config.TemporalId(); + frame_info.encoder_buffers = config.Buffers(); + frame_info.decode_target_indications.reserve(num_spatial_layers_ * + num_temporal_layers_); + for (int sid = 0; sid < num_spatial_layers_; ++sid) { + for (int tid = 0; tid < num_temporal_layers_; ++tid) { + frame_info.decode_target_indications.push_back(Dti(sid, tid, config)); + } + } + frame_info.part_of_chain.assign(num_spatial_layers_, false); + if (config.TemporalId() == 0) { + frame_info.part_of_chain[config.SpatialId()] = true; + } + frame_info.active_decode_targets = active_decode_targets_; + return frame_info; +} + +void ScalabilityStructureSimulcast::OnRatesUpdated( + const VideoBitrateAllocation& bitrates) { + for (int sid = 0; sid < num_spatial_layers_; ++sid) { + // Enable/disable spatial layers independetely. + bool active = true; + for (int tid = 0; tid < num_temporal_layers_; ++tid) { + // To enable temporal layer, require bitrates for lower temporal layers. 
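+      // `active` is carried across iterations, so a zero bitrate at a lower
+      // tid also disables every higher temporal layer of this spatial layer.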
+ active = active && bitrates.GetBitrate(sid, tid) > 0; + SetDecodeTargetIsActive(sid, tid, active); + } + } +} + +FrameDependencyStructure ScalabilityStructureS2T1::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 2; + structure.num_chains = 2; + structure.decode_target_protected_by_chain = {0, 1}; + structure.templates.resize(4); + structure.templates[0].S(0).Dtis("S-").ChainDiffs({2, 1}).FrameDiffs({2}); + structure.templates[1].S(0).Dtis("S-").ChainDiffs({0, 0}); + structure.templates[2].S(1).Dtis("-S").ChainDiffs({1, 2}).FrameDiffs({2}); + structure.templates[3].S(1).Dtis("-S").ChainDiffs({1, 0}); + return structure; +} + +FrameDependencyStructure ScalabilityStructureS3T3::DependencyStructure() const { + FrameDependencyStructure structure; + structure.num_decode_targets = 9; + structure.num_chains = 3; + structure.decode_target_protected_by_chain = {0, 0, 0, 1, 1, 1, 2, 2, 2}; + auto& t = structure.templates; + t.resize(15); + // Templates are shown in the order frames following them appear in the + // stream, but in `structure.templates` array templates are sorted by + // (`spatial_id`, `temporal_id`) since that is a dependency descriptor + // requirement. Indexes are written in hex for nicer alignment. + t[0x1].S(0).T(0).Dtis("SSS------").ChainDiffs({0, 0, 0}); + t[0x6].S(1).T(0).Dtis("---SSS---").ChainDiffs({1, 0, 0}); + t[0xB].S(2).T(0).Dtis("------SSS").ChainDiffs({2, 1, 0}); + t[0x3].S(0).T(2).Dtis("--D------").ChainDiffs({3, 2, 1}).FrameDiffs({3}); + t[0x8].S(1).T(2).Dtis("-----D---").ChainDiffs({4, 3, 2}).FrameDiffs({3}); + t[0xD].S(2).T(2).Dtis("--------D").ChainDiffs({5, 4, 3}).FrameDiffs({3}); + t[0x2].S(0).T(1).Dtis("-DS------").ChainDiffs({6, 5, 4}).FrameDiffs({6}); + t[0x7].S(1).T(1).Dtis("----DS---").ChainDiffs({7, 6, 5}).FrameDiffs({6}); + t[0xC].S(2).T(1).Dtis("-------DS").ChainDiffs({8, 7, 6}).FrameDiffs({6}); + t[0x4].S(0).T(2).Dtis("--D------").ChainDiffs({9, 8, 7}).FrameDiffs({3}); + t[0x9].S(1).T(2).Dtis("-----D---").ChainDiffs({10, 9, 8}).FrameDiffs({3}); + t[0xE].S(2).T(2).Dtis("--------D").ChainDiffs({11, 10, 9}).FrameDiffs({3}); + t[0x0].S(0).T(0).Dtis("SSS------").ChainDiffs({12, 11, 10}).FrameDiffs({12}); + t[0x5].S(1).T(0).Dtis("---SSS---").ChainDiffs({1, 12, 11}).FrameDiffs({12}); + t[0xA].S(2).T(0).Dtis("------SSS").ChainDiffs({2, 1, 12}).FrameDiffs({12}); + return structure; +} + +} // namespace webrtc diff --git a/modules/video_coding/svc/scalability_structure_simulcast.h b/modules/video_coding/svc/scalability_structure_simulcast.h new file mode 100644 index 0000000000..7b57df2985 --- /dev/null +++ b/modules/video_coding/svc/scalability_structure_simulcast.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_SIMULCAST_H_ +#define MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_SIMULCAST_H_ + +#include + +#include "api/transport/rtp/dependency_descriptor.h" +#include "api/video/video_bitrate_allocation.h" +#include "common_video/generic_frame_descriptor/generic_frame_info.h" +#include "modules/video_coding/svc/scalable_video_controller.h" + +namespace webrtc { + +// Scalability structure with multiple independent spatial layers each with the +// same temporal layering. +class ScalabilityStructureSimulcast : public ScalableVideoController { + public: + ScalabilityStructureSimulcast(int num_spatial_layers, + int num_temporal_layers); + ~ScalabilityStructureSimulcast() override; + + StreamLayersConfig StreamConfig() const override; + std::vector NextFrameConfig(bool restart) override; + GenericFrameInfo OnEncodeDone(const LayerFrameConfig& config) override; + void OnRatesUpdated(const VideoBitrateAllocation& bitrates) override; + + private: + enum FramePattern { + kNone, + kDeltaT2A, + kDeltaT1, + kDeltaT2B, + kDeltaT0, + }; + static constexpr int kMaxNumSpatialLayers = 3; + static constexpr int kMaxNumTemporalLayers = 3; + + // Index of the buffer to store last frame for layer (`sid`, `tid`) + int BufferIndex(int sid, int tid) const { + return tid * num_spatial_layers_ + sid; + } + bool DecodeTargetIsActive(int sid, int tid) const { + return active_decode_targets_[sid * num_temporal_layers_ + tid]; + } + void SetDecodeTargetIsActive(int sid, int tid, bool value) { + active_decode_targets_.set(sid * num_temporal_layers_ + tid, value); + } + FramePattern NextPattern() const; + bool TemporalLayerIsActive(int tid) const; + + const int num_spatial_layers_; + const int num_temporal_layers_; + + FramePattern last_pattern_ = kNone; + std::bitset can_reference_t0_frame_for_spatial_id_ = 0; + std::bitset can_reference_t1_frame_for_spatial_id_ = 0; + std::bitset<32> active_decode_targets_; +}; + +// S1 0--0--0- +// ... 
+// S0 0--0--0- +class ScalabilityStructureS2T1 : public ScalabilityStructureSimulcast { + public: + ScalabilityStructureS2T1() : ScalabilityStructureSimulcast(2, 1) {} + ~ScalabilityStructureS2T1() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +class ScalabilityStructureS3T3 : public ScalabilityStructureSimulcast { + public: + ScalabilityStructureS3T3() : ScalabilityStructureSimulcast(3, 3) {} + ~ScalabilityStructureS3T3() override = default; + + FrameDependencyStructure DependencyStructure() const override; +}; + +} // namespace webrtc + +#endif // MODULES_VIDEO_CODING_SVC_SCALABILITY_STRUCTURE_SIMULCAST_H_ diff --git a/modules/video_coding/svc/scalability_structure_unittest.cc b/modules/video_coding/svc/scalability_structure_unittest.cc index d6766f0d50..8bd933be5d 100644 --- a/modules/video_coding/svc/scalability_structure_unittest.cc +++ b/modules/video_coding/svc/scalability_structure_unittest.cc @@ -292,16 +292,19 @@ TEST_P(ScalabilityStructureTest, ProduceNoFrameForDisabledLayers) { INSTANTIATE_TEST_SUITE_P( Svc, ScalabilityStructureTest, - Values(SvcTestParam{"L1T2", /*num_temporal_units=*/4}, + Values(SvcTestParam{"NONE", /*num_temporal_units=*/3}, + SvcTestParam{"L1T2", /*num_temporal_units=*/4}, SvcTestParam{"L1T3", /*num_temporal_units=*/8}, SvcTestParam{"L2T1", /*num_temporal_units=*/3}, SvcTestParam{"L2T1_KEY", /*num_temporal_units=*/3}, SvcTestParam{"L3T1", /*num_temporal_units=*/3}, SvcTestParam{"L3T3", /*num_temporal_units=*/8}, SvcTestParam{"S2T1", /*num_temporal_units=*/3}, + SvcTestParam{"S3T3", /*num_temporal_units=*/8}, SvcTestParam{"L2T2", /*num_temporal_units=*/4}, SvcTestParam{"L2T2_KEY", /*num_temporal_units=*/4}, SvcTestParam{"L2T2_KEY_SHIFT", /*num_temporal_units=*/4}, + SvcTestParam{"L2T3_KEY", /*num_temporal_units=*/8}, SvcTestParam{"L3T3_KEY", /*num_temporal_units=*/8}), [](const testing::TestParamInfo& info) { return info.param.name; diff --git a/modules/video_coding/svc/scalable_video_controller_no_layering.cc b/modules/video_coding/svc/scalable_video_controller_no_layering.cc index 6d8e6e8fc6..3934e57804 100644 --- a/modules/video_coding/svc/scalable_video_controller_no_layering.cc +++ b/modules/video_coding/svc/scalable_video_controller_no_layering.cc @@ -32,14 +32,28 @@ FrameDependencyStructure ScalableVideoControllerNoLayering::DependencyStructure() const { FrameDependencyStructure structure; structure.num_decode_targets = 1; - FrameDependencyTemplate a_template; - a_template.decode_target_indications = {DecodeTargetIndication::kSwitch}; - structure.templates.push_back(a_template); + structure.num_chains = 1; + structure.decode_target_protected_by_chain = {0}; + + FrameDependencyTemplate key_frame; + key_frame.decode_target_indications = {DecodeTargetIndication::kSwitch}; + key_frame.chain_diffs = {0}; + structure.templates.push_back(key_frame); + + FrameDependencyTemplate delta_frame; + delta_frame.decode_target_indications = {DecodeTargetIndication::kSwitch}; + delta_frame.chain_diffs = {1}; + delta_frame.frame_diffs = {1}; + structure.templates.push_back(delta_frame); + return structure; } std::vector ScalableVideoControllerNoLayering::NextFrameConfig(bool restart) { + if (!enabled_) { + return {}; + } std::vector result(1); if (restart || start_) { result[0].Id(0).Keyframe().Update(0); @@ -61,7 +75,13 @@ GenericFrameInfo ScalableVideoControllerNoLayering::OnEncodeDone( } } frame_info.decode_target_indications = {DecodeTargetIndication::kSwitch}; + frame_info.part_of_chain = {true}; return frame_info; } 
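
For orientation, a ScalableVideoController is always driven in the same order: bitrates are pushed in with OnRatesUpdated(), a frame configuration is requested with NextFrameConfig(), and the encode result is reported back through OnEncodeDone(). The sketch below shows that flow and how the enabled_ flag added above surfaces to callers as an empty NextFrameConfig() result. EncodeNextTemporalUnit and the encode step are illustrative assumptions; only the controller calls come from the interface changed in this patch.

```cpp
#include <vector>

#include "api/video/video_bitrate_allocation.h"
#include "common_video/generic_frame_descriptor/generic_frame_info.h"
#include "modules/video_coding/svc/scalable_video_controller.h"

namespace webrtc {

// Hypothetical driver loop; only the ScalableVideoController calls are real.
void EncodeNextTemporalUnit(ScalableVideoController& svc,
                            const VideoBitrateAllocation& allocation) {
  // For ScalableVideoControllerNoLayering, a zero (0, 0) allocation
  // disables the controller.
  svc.OnRatesUpdated(allocation);
  std::vector<ScalableVideoController::LayerFrameConfig> configs =
      svc.NextFrameConfig(/*restart=*/false);
  if (configs.empty()) {
    return;  // Disabled: no frame should be produced for this temporal unit.
  }
  for (const ScalableVideoController::LayerFrameConfig& config : configs) {
    // ... run the encoder for `config` here (omitted) ...
    GenericFrameInfo info = svc.OnEncodeDone(config);
    // info.part_of_chain and info.decode_target_indications feed the
    // dependency descriptor written on the RTP layer.
    (void)info;
  }
}

}  // namespace webrtc
```
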
+void ScalableVideoControllerNoLayering::OnRatesUpdated( + const VideoBitrateAllocation& bitrates) { + enabled_ = bitrates.GetBitrate(0, 0) > 0; +} + } // namespace webrtc diff --git a/modules/video_coding/svc/scalable_video_controller_no_layering.h b/modules/video_coding/svc/scalable_video_controller_no_layering.h index e253ffe841..6d66b61c8b 100644 --- a/modules/video_coding/svc/scalable_video_controller_no_layering.h +++ b/modules/video_coding/svc/scalable_video_controller_no_layering.h @@ -28,10 +28,11 @@ class ScalableVideoControllerNoLayering : public ScalableVideoController { std::vector NextFrameConfig(bool restart) override; GenericFrameInfo OnEncodeDone(const LayerFrameConfig& config) override; - void OnRatesUpdated(const VideoBitrateAllocation& bitrates) override {} + void OnRatesUpdated(const VideoBitrateAllocation& bitrates) override; private: bool start_ = true; + bool enabled_ = true; }; } // namespace webrtc diff --git a/modules/video_coding/timestamp_map.cc b/modules/video_coding/timestamp_map.cc index d79075ff21..f6fb81815a 100644 --- a/modules/video_coding/timestamp_map.cc +++ b/modules/video_coding/timestamp_map.cc @@ -24,7 +24,7 @@ VCMTimestampMap::VCMTimestampMap(size_t capacity) VCMTimestampMap::~VCMTimestampMap() {} -void VCMTimestampMap::Add(uint32_t timestamp, VCMFrameInformation* data) { +void VCMTimestampMap::Add(uint32_t timestamp, const VCMFrameInformation& data) { ring_buffer_[next_add_idx_].timestamp = timestamp; ring_buffer_[next_add_idx_].data = data; next_add_idx_ = (next_add_idx_ + 1) % capacity_; @@ -35,18 +35,18 @@ void VCMTimestampMap::Add(uint32_t timestamp, VCMFrameInformation* data) { } } -VCMFrameInformation* VCMTimestampMap::Pop(uint32_t timestamp) { +absl::optional VCMTimestampMap::Pop(uint32_t timestamp) { while (!IsEmpty()) { if (ring_buffer_[next_pop_idx_].timestamp == timestamp) { // Found start time for this timestamp. - VCMFrameInformation* data = ring_buffer_[next_pop_idx_].data; - ring_buffer_[next_pop_idx_].data = nullptr; + const VCMFrameInformation& data = ring_buffer_[next_pop_idx_].data; + ring_buffer_[next_pop_idx_].timestamp = 0; next_pop_idx_ = (next_pop_idx_ + 1) % capacity_; return data; } else if (IsNewerTimestamp(ring_buffer_[next_pop_idx_].timestamp, timestamp)) { // The timestamp we are looking for is not in the list. - return nullptr; + return absl::nullopt; } // Not in this position, check next (and forget this position). @@ -54,7 +54,7 @@ VCMFrameInformation* VCMTimestampMap::Pop(uint32_t timestamp) { } // Could not find matching timestamp in list. 
- return nullptr; + return absl::nullopt; } bool VCMTimestampMap::IsEmpty() const { @@ -69,4 +69,11 @@ size_t VCMTimestampMap::Size() const { : next_add_idx_ + capacity_ - next_pop_idx_; } +void VCMTimestampMap::Clear() { + while (!IsEmpty()) { + ring_buffer_[next_pop_idx_].timestamp = 0; + next_pop_idx_ = (next_pop_idx_ + 1) % capacity_; + } +} + } // namespace webrtc diff --git a/modules/video_coding/timestamp_map.h b/modules/video_coding/timestamp_map.h index cfa12573ec..dc20a0551c 100644 --- a/modules/video_coding/timestamp_map.h +++ b/modules/video_coding/timestamp_map.h @@ -13,23 +13,42 @@ #include +#include "absl/types/optional.h" +#include "api/rtp_packet_infos.h" +#include "api/units/timestamp.h" +#include "api/video/encoded_image.h" +#include "api/video/video_content_type.h" +#include "api/video/video_rotation.h" +#include "api/video/video_timing.h" + namespace webrtc { -struct VCMFrameInformation; +struct VCMFrameInformation { + int64_t renderTimeMs; + absl::optional decodeStart; + void* userData; + VideoRotation rotation; + VideoContentType content_type; + EncodedImage::Timing timing; + int64_t ntp_time_ms; + RtpPacketInfos packet_infos; + // ColorSpace is not stored here, as it might be modified by decoders. +}; class VCMTimestampMap { public: explicit VCMTimestampMap(size_t capacity); ~VCMTimestampMap(); - void Add(uint32_t timestamp, VCMFrameInformation* data); - VCMFrameInformation* Pop(uint32_t timestamp); + void Add(uint32_t timestamp, const VCMFrameInformation& data); + absl::optional Pop(uint32_t timestamp); size_t Size() const; + void Clear(); private: struct TimestampDataTuple { uint32_t timestamp; - VCMFrameInformation* data; + VCMFrameInformation data; }; bool IsEmpty() const; diff --git a/modules/video_coding/timestamp_map_unittest.cc b/modules/video_coding/timestamp_map_unittest.cc new file mode 100644 index 0000000000..5e90786b95 --- /dev/null +++ b/modules/video_coding/timestamp_map_unittest.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/video_coding/timestamp_map.h" + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace video_coding { +namespace { +constexpr int kTimestampMapSize = 6; +constexpr int kTimestamp1 = 1; +constexpr int kTimestamp2 = 2; +constexpr int kNoExistingTimestamp3 = 3; +constexpr int kTimestamp4 = 4; +constexpr int kTimestamp5 = 5; +constexpr int kTimestamp6 = 6; +constexpr int kTimestamp7 = 7; +constexpr int64_t kRenderTime1 = 1000; +constexpr int64_t kRenderTime2 = 2000; +constexpr int64_t kRenderTime4 = 4000; +constexpr int64_t kRenderTime5 = 5000; +constexpr int64_t kRenderTime6 = 6000; +constexpr int64_t kRenderTime7 = 7000; +} // namespace + +class VcmTimestampMapTest : public ::testing::Test { + protected: + VcmTimestampMapTest() : _timestampMap(kTimestampMapSize) {} + + void SetUp() override { + _timestampMap.Add(kTimestamp1, VCMFrameInformation({kRenderTime1})); + _timestampMap.Add(kTimestamp2, VCMFrameInformation({kRenderTime2})); + _timestampMap.Add(kTimestamp4, VCMFrameInformation({kRenderTime4})); + } + + VCMTimestampMap _timestampMap; +}; + +TEST_F(VcmTimestampMapTest, PopExistingFrameInfo) { + EXPECT_EQ(_timestampMap.Size(), 3u); + auto frameInfo = _timestampMap.Pop(kTimestamp1); + ASSERT_TRUE(frameInfo); + EXPECT_EQ(frameInfo->renderTimeMs, kRenderTime1); + frameInfo = _timestampMap.Pop(kTimestamp2); + ASSERT_TRUE(frameInfo); + EXPECT_EQ(frameInfo->renderTimeMs, kRenderTime2); + frameInfo = _timestampMap.Pop(kTimestamp4); + ASSERT_TRUE(frameInfo); + EXPECT_EQ(frameInfo->renderTimeMs, kRenderTime4); +} + +TEST_F(VcmTimestampMapTest, PopNonexistingClearsOlderFrameInfos) { + auto frameInfo = _timestampMap.Pop(kNoExistingTimestamp3); + EXPECT_FALSE(frameInfo); + EXPECT_EQ(_timestampMap.Size(), 1u); +} + +TEST_F(VcmTimestampMapTest, SizeIsIncrementedWhenAddingNewFrameInfo) { + EXPECT_EQ(_timestampMap.Size(), 3u); + _timestampMap.Add(kTimestamp5, VCMFrameInformation({kRenderTime5})); + EXPECT_EQ(_timestampMap.Size(), 4u); + _timestampMap.Add(kTimestamp6, VCMFrameInformation({kRenderTime6})); + EXPECT_EQ(_timestampMap.Size(), 5u); +} + +TEST_F(VcmTimestampMapTest, SizeIsDecreasedWhenPoppingFrameInfo) { + EXPECT_EQ(_timestampMap.Size(), 3u); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp1)); + EXPECT_EQ(_timestampMap.Size(), 2u); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp2)); + EXPECT_EQ(_timestampMap.Size(), 1u); + EXPECT_FALSE(_timestampMap.Pop(kNoExistingTimestamp3)); + EXPECT_EQ(_timestampMap.Size(), 1u); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp4)); + EXPECT_EQ(_timestampMap.Size(), 0u); +} + +TEST_F(VcmTimestampMapTest, ClearEmptiesMap) { + EXPECT_EQ(_timestampMap.Size(), 3u); + _timestampMap.Clear(); + EXPECT_EQ(_timestampMap.Size(), 0u); + // Clear empty map does nothing. + _timestampMap.Clear(); + EXPECT_EQ(_timestampMap.Size(), 0u); +} + +TEST_F(VcmTimestampMapTest, PopLastAddedClearsMap) { + EXPECT_EQ(_timestampMap.Size(), 3u); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp4)); + EXPECT_EQ(_timestampMap.Size(), 0u); +} + +TEST_F(VcmTimestampMapTest, LastAddedIsDiscardedIfMapGetsFull) { + EXPECT_EQ(_timestampMap.Size(), 3u); + _timestampMap.Add(kTimestamp5, VCMFrameInformation({kRenderTime5})); + EXPECT_EQ(_timestampMap.Size(), 4u); + _timestampMap.Add(kTimestamp6, VCMFrameInformation({kRenderTime6})); + EXPECT_EQ(_timestampMap.Size(), 5u); + _timestampMap.Add(kTimestamp7, VCMFrameInformation({kRenderTime7})); + // Size is not incremented since the oldest element is discarded. 
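+  // It is the oldest entry (kTimestamp1) that gets dropped, as the Pop()
+  // calls below verify.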
+ EXPECT_EQ(_timestampMap.Size(), 5u); + EXPECT_FALSE(_timestampMap.Pop(kTimestamp1)); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp2)); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp4)); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp5)); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp6)); + EXPECT_TRUE(_timestampMap.Pop(kTimestamp7)); + EXPECT_EQ(_timestampMap.Size(), 0u); +} + +} // namespace video_coding +} // namespace webrtc diff --git a/modules/video_coding/timing.cc b/modules/video_coding/timing.cc index eddac4f5de..ea1b59cad7 100644 --- a/modules/video_coding/timing.cc +++ b/modules/video_coding/timing.cc @@ -34,9 +34,13 @@ VCMTiming::VCMTiming(Clock* clock) prev_frame_timestamp_(0), timing_frame_info_(), num_decoded_frames_(0), - low_latency_renderer_enabled_("enabled", true) { + low_latency_renderer_enabled_("enabled", true), + zero_playout_delay_min_pacing_("min_pacing", TimeDelta::Millis(0)), + last_decode_scheduled_ts_(0) { ParseFieldTrial({&low_latency_renderer_enabled_}, field_trial::FindFullName("WebRTC-LowLatencyRenderer")); + ParseFieldTrial({&zero_playout_delay_min_pacing_}, + field_trial::FindFullName("WebRTC-ZeroPlayoutDelay")); } void VCMTiming::Reset() { @@ -153,7 +157,7 @@ void VCMTiming::StopDecodeTimer(uint32_t /*time_stamp*/, void VCMTiming::StopDecodeTimer(int32_t decode_time_ms, int64_t now_ms) { MutexLock lock(&mutex_); codec_timer_->AddTiming(decode_time_ms, now_ms); - assert(decode_time_ms >= 0); + RTC_DCHECK_GE(decode_time_ms, 0); ++num_decoded_frames_; } @@ -168,6 +172,12 @@ int64_t VCMTiming::RenderTimeMs(uint32_t frame_timestamp, return RenderTimeMsInternal(frame_timestamp, now_ms); } +void VCMTiming::SetLastDecodeScheduledTimestamp( + int64_t last_decode_scheduled_ts) { + MutexLock lock(&mutex_); + last_decode_scheduled_ts_ = last_decode_scheduled_ts; +} + int64_t VCMTiming::RenderTimeMsInternal(uint32_t frame_timestamp, int64_t now_ms) const { constexpr int kLowLatencyRendererMaxPlayoutDelayMs = 500; @@ -195,18 +205,33 @@ int64_t VCMTiming::RenderTimeMsInternal(uint32_t frame_timestamp, int VCMTiming::RequiredDecodeTimeMs() const { const int decode_time_ms = codec_timer_->RequiredDecodeTimeMs(); - assert(decode_time_ms >= 0); + RTC_DCHECK_GE(decode_time_ms, 0); return decode_time_ms; } int64_t VCMTiming::MaxWaitingTime(int64_t render_time_ms, - int64_t now_ms) const { + int64_t now_ms, + bool too_many_frames_queued) const { MutexLock lock(&mutex_); - const int64_t max_wait_time_ms = - render_time_ms - now_ms - RequiredDecodeTimeMs() - render_delay_ms_; - - return max_wait_time_ms; + if (render_time_ms == 0 && zero_playout_delay_min_pacing_->us() > 0 && + min_playout_delay_ms_ == 0 && max_playout_delay_ms_ > 0) { + // |render_time_ms| == 0 indicates that the frame should be decoded and + // rendered as soon as possible. However, the decoder can be choked if too + // many frames are sent at once. Therefore, limit the interframe delay to + // |zero_playout_delay_min_pacing_| unless too many frames are queued in + // which case the frames are sent to the decoder at once. + if (too_many_frames_queued) { + return 0; + } + int64_t earliest_next_decode_start_time = + last_decode_scheduled_ts_ + zero_playout_delay_min_pacing_->ms(); + int64_t max_wait_time_ms = now_ms >= earliest_next_decode_start_time + ? 
0 + : earliest_next_decode_start_time - now_ms; + return max_wait_time_ms; + } + return render_time_ms - now_ms - RequiredDecodeTimeMs() - render_delay_ms_; } int VCMTiming::TargetVideoDelay() const { diff --git a/modules/video_coding/timing.h b/modules/video_coding/timing.h index 736b5e9ae4..7f891e4b9b 100644 --- a/modules/video_coding/timing.h +++ b/modules/video_coding/timing.h @@ -14,6 +14,7 @@ #include #include "absl/types/optional.h" +#include "api/units/time_delta.h" #include "api/video/video_timing.h" #include "modules/video_coding/codec_timer.h" #include "rtc_base/experiments/field_trial_parser.h" @@ -81,8 +82,15 @@ class VCMTiming { virtual int64_t RenderTimeMs(uint32_t frame_timestamp, int64_t now_ms) const; // Returns the maximum time in ms that we can wait for a frame to become - // complete before we must pass it to the decoder. - virtual int64_t MaxWaitingTime(int64_t render_time_ms, int64_t now_ms) const; + // complete before we must pass it to the decoder. render_time_ms==0 indicates + // that the frames should be processed as quickly as possible, with possibly + // only a small delay added to make sure that the decoder is not overloaded. + // In this case, the parameter too_many_frames_queued is used to signal that + // the decode queue is full and that the frame should be decoded as soon as + // possible. + virtual int64_t MaxWaitingTime(int64_t render_time_ms, + int64_t now_ms, + bool too_many_frames_queued) const; // Returns the current target delay which is required delay + decode time + // render delay. @@ -104,6 +112,9 @@ class VCMTiming { absl::optional max_composition_delay_in_frames); absl::optional MaxCompositionDelayInFrames() const; + // Updates the last time a frame was scheduled for decoding. + void SetLastDecodeScheduledTimestamp(int64_t last_decode_scheduled_ts); + enum { kDefaultRenderDelayMs = 10 }; enum { kDelayMaxChangeMsPerS = 100 }; @@ -139,6 +150,15 @@ class VCMTiming { FieldTrialParameter low_latency_renderer_enabled_ RTC_GUARDED_BY(mutex_); absl::optional max_composition_delay_in_frames_ RTC_GUARDED_BY(mutex_); + // Set by the field trial WebRTC-ZeroPlayoutDelay. The parameter min_pacing + // determines the minimum delay between frames scheduled for decoding that is + // used when min playout delay=0 and max playout delay>=0. + FieldTrialParameter zero_playout_delay_min_pacing_ + RTC_GUARDED_BY(mutex_); + // Timestamp at which the last frame was scheduled to be sent to the decoder. + // Used only when the RTP header extension playout delay is set to min=0 ms + // which is indicated by a render time set to 0. 
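+  // Updated through SetLastDecodeScheduledTimestamp() each time a frame is
+  // scheduled for decoding.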
+ int64_t last_decode_scheduled_ts_ RTC_GUARDED_BY(mutex_); }; } // namespace webrtc diff --git a/modules/video_coding/timing_unittest.cc b/modules/video_coding/timing_unittest.cc index ee86605fb6..cc87a3b4e0 100644 --- a/modules/video_coding/timing_unittest.cc +++ b/modules/video_coding/timing_unittest.cc @@ -11,6 +11,7 @@ #include "modules/video_coding/timing.h" #include "system_wrappers/include/clock.h" +#include "test/field_trial.h" #include "test/gtest.h" namespace webrtc { @@ -18,7 +19,7 @@ namespace { const int kFps = 25; } // namespace -TEST(ReceiverTiming, Tests) { +TEST(ReceiverTimingTest, JitterDelay) { SimulatedClock clock(0); VCMTiming timing(&clock); timing.Reset(); @@ -35,7 +36,7 @@ TEST(ReceiverTiming, Tests) { timing.set_render_delay(0); uint32_t wait_time_ms = timing.MaxWaitingTime( timing.RenderTimeMs(timestamp, clock.TimeInMilliseconds()), - clock.TimeInMilliseconds()); + clock.TimeInMilliseconds(), /*too_many_frames_queued=*/false); // First update initializes the render time. Since we have no decode delay // we get wait_time_ms = renderTime - now - renderDelay = jitter. EXPECT_EQ(jitter_delay_ms, wait_time_ms); @@ -47,7 +48,7 @@ TEST(ReceiverTiming, Tests) { timing.UpdateCurrentDelay(timestamp); wait_time_ms = timing.MaxWaitingTime( timing.RenderTimeMs(timestamp, clock.TimeInMilliseconds()), - clock.TimeInMilliseconds()); + clock.TimeInMilliseconds(), /*too_many_frames_queued=*/false); // Since we gradually increase the delay we only get 100 ms every second. EXPECT_EQ(jitter_delay_ms - 10, wait_time_ms); @@ -56,7 +57,7 @@ TEST(ReceiverTiming, Tests) { timing.UpdateCurrentDelay(timestamp); wait_time_ms = timing.MaxWaitingTime( timing.RenderTimeMs(timestamp, clock.TimeInMilliseconds()), - clock.TimeInMilliseconds()); + clock.TimeInMilliseconds(), /*too_many_frames_queued=*/false); EXPECT_EQ(jitter_delay_ms, wait_time_ms); // Insert frames without jitter, verify that this gives the exact wait time. @@ -69,7 +70,7 @@ TEST(ReceiverTiming, Tests) { timing.UpdateCurrentDelay(timestamp); wait_time_ms = timing.MaxWaitingTime( timing.RenderTimeMs(timestamp, clock.TimeInMilliseconds()), - clock.TimeInMilliseconds()); + clock.TimeInMilliseconds(), /*too_many_frames_queued=*/false); EXPECT_EQ(jitter_delay_ms, wait_time_ms); // Add decode time estimates for 1 second. @@ -84,7 +85,7 @@ TEST(ReceiverTiming, Tests) { timing.UpdateCurrentDelay(timestamp); wait_time_ms = timing.MaxWaitingTime( timing.RenderTimeMs(timestamp, clock.TimeInMilliseconds()), - clock.TimeInMilliseconds()); + clock.TimeInMilliseconds(), /*too_many_frames_queued=*/false); EXPECT_EQ(jitter_delay_ms, wait_time_ms); const int kMinTotalDelayMs = 200; @@ -96,7 +97,7 @@ TEST(ReceiverTiming, Tests) { timing.set_render_delay(kRenderDelayMs); wait_time_ms = timing.MaxWaitingTime( timing.RenderTimeMs(timestamp, clock.TimeInMilliseconds()), - clock.TimeInMilliseconds()); + clock.TimeInMilliseconds(), /*too_many_frames_queued=*/false); // We should at least have kMinTotalDelayMs - decodeTime (10) - renderTime // (10) to wait. EXPECT_EQ(kMinTotalDelayMs - kDecodeTimeMs - kRenderDelayMs, wait_time_ms); @@ -110,7 +111,7 @@ TEST(ReceiverTiming, Tests) { timing.UpdateCurrentDelay(timestamp); } -TEST(ReceiverTiming, WrapAround) { +TEST(ReceiverTimingTest, TimestampWrapAround) { SimulatedClock clock(0); VCMTiming timing(&clock); // Provoke a wrap-around. The fifth frame will have wrapped at 25 fps. 
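
Before the new tests below, it helps to restate the pacing rule they exercise. The branch added to VCMTiming::MaxWaitingTime() is taken only when render_time_ms == 0 (the RTP playout-delay extension requested min=0), the WebRTC-ZeroPlayoutDelay min_pacing parameter is greater than zero, the minimum playout delay is 0 and the maximum playout delay is greater than 0. A minimal standalone sketch of that arithmetic, with an assumed helper name and parameter list:

```cpp
#include <cstdint>

// Illustrative restatement only; the real logic lives inside
// VCMTiming::MaxWaitingTime() and uses member state held under a mutex.
int64_t ZeroDelayMaxWaitMs(int64_t now_ms,
                           int64_t last_decode_scheduled_ms,
                           int64_t min_pacing_ms,
                           bool too_many_frames_queued) {
  if (too_many_frames_queued) {
    return 0;  // Decode queue is full: release the frame immediately.
  }
  int64_t earliest_next_decode_ms = last_decode_scheduled_ms + min_pacing_ms;
  return now_ms >= earliest_next_decode_ms ? 0
                                           : earliest_next_decode_ms - now_ms;
}
```

With min_pacing:3ms, a frame offered 2 ms after the previous decode was scheduled waits another 1 ms, while one offered 3 ms or more later waits 0 ms; that is what the tests below assert.
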
@@ -127,4 +128,155 @@ TEST(ReceiverTiming, WrapAround) { } } +TEST(ReceiverTimingTest, MaxWaitingTimeIsZeroForZeroRenderTime) { + // This is the default path when the RTP playout delay header extension is set + // to min==0. + constexpr int64_t kStartTimeUs = 3.15e13; // About one year in us. + constexpr int64_t kTimeDeltaMs = 1000.0 / 60.0; + constexpr int64_t kZeroRenderTimeMs = 0; + SimulatedClock clock(kStartTimeUs); + VCMTiming timing(&clock); + timing.Reset(); + for (int i = 0; i < 10; ++i) { + clock.AdvanceTimeMilliseconds(kTimeDeltaMs); + int64_t now_ms = clock.TimeInMilliseconds(); + EXPECT_LT(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); + } + // Another frame submitted at the same time also returns a negative max + // waiting time. + int64_t now_ms = clock.TimeInMilliseconds(); + EXPECT_LT(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); + // MaxWaitingTime should be less than zero even if there's a burst of frames. + EXPECT_LT(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); + EXPECT_LT(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); + EXPECT_LT(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); +} + +TEST(ReceiverTimingTest, MaxWaitingTimeZeroDelayPacingExperiment) { + // The minimum pacing is enabled by a field trial and active if the RTP + // playout delay header extension is set to min==0. + constexpr int64_t kMinPacingMs = 3; + test::ScopedFieldTrials override_field_trials( + "WebRTC-ZeroPlayoutDelay/min_pacing:3ms/"); + constexpr int64_t kStartTimeUs = 3.15e13; // About one year in us. + constexpr int64_t kTimeDeltaMs = 1000.0 / 60.0; + constexpr int64_t kZeroRenderTimeMs = 0; + SimulatedClock clock(kStartTimeUs); + VCMTiming timing(&clock); + timing.Reset(); + // MaxWaitingTime() returns zero for evenly spaced video frames. + for (int i = 0; i < 10; ++i) { + clock.AdvanceTimeMilliseconds(kTimeDeltaMs); + int64_t now_ms = clock.TimeInMilliseconds(); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); + timing.SetLastDecodeScheduledTimestamp(now_ms); + } + // Another frame submitted at the same time is paced according to the field + // trial setting. + int64_t now_ms = clock.TimeInMilliseconds(); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + kMinPacingMs); + // If there's a burst of frames, the wait time is calculated based on next + // decode time. + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + kMinPacingMs); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + kMinPacingMs); + // Allow a few ms to pass, this should be subtracted from the MaxWaitingTime. + constexpr int64_t kTwoMs = 2; + clock.AdvanceTimeMilliseconds(kTwoMs); + now_ms = clock.TimeInMilliseconds(); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + kMinPacingMs - kTwoMs); + // A frame is decoded at the current time, the wait time should be restored to + // pacing delay. 
+ timing.SetLastDecodeScheduledTimestamp(now_ms); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + kMinPacingMs); +} + +TEST(ReceiverTimingTest, DefaultMaxWaitingTimeUnaffectedByPacingExperiment) { + // The minimum pacing is enabled by a field trial but should not have any + // effect if render_time_ms is greater than 0; + test::ScopedFieldTrials override_field_trials( + "WebRTC-ZeroPlayoutDelay/min_pacing:3ms/"); + constexpr int64_t kStartTimeUs = 3.15e13; // About one year in us. + constexpr int64_t kTimeDeltaMs = 1000.0 / 60.0; + SimulatedClock clock(kStartTimeUs); + VCMTiming timing(&clock); + timing.Reset(); + clock.AdvanceTimeMilliseconds(kTimeDeltaMs); + int64_t now_ms = clock.TimeInMilliseconds(); + int64_t render_time_ms = now_ms + 30; + // Estimate the internal processing delay from the first frame. + int64_t estimated_processing_delay = + (render_time_ms - now_ms) - + timing.MaxWaitingTime(render_time_ms, now_ms, + /*too_many_frames_queued=*/false); + EXPECT_GT(estimated_processing_delay, 0); + + // Any other frame submitted at the same time should be scheduled according to + // its render time. + for (int i = 0; i < 5; ++i) { + render_time_ms += kTimeDeltaMs; + EXPECT_EQ(timing.MaxWaitingTime(render_time_ms, now_ms, + /*too_many_frames_queued=*/false), + render_time_ms - now_ms - estimated_processing_delay); + } +} + +TEST(ReceiverTiminTest, MaxWaitingTimeReturnsZeroIfTooManyFramesQueuedIsTrue) { + // The minimum pacing is enabled by a field trial and active if the RTP + // playout delay header extension is set to min==0. + constexpr int64_t kMinPacingMs = 3; + test::ScopedFieldTrials override_field_trials( + "WebRTC-ZeroPlayoutDelay/min_pacing:3ms/"); + constexpr int64_t kStartTimeUs = 3.15e13; // About one year in us. + constexpr int64_t kTimeDeltaMs = 1000.0 / 60.0; + constexpr int64_t kZeroRenderTimeMs = 0; + SimulatedClock clock(kStartTimeUs); + VCMTiming timing(&clock); + timing.Reset(); + // MaxWaitingTime() returns zero for evenly spaced video frames. + for (int i = 0; i < 10; ++i) { + clock.AdvanceTimeMilliseconds(kTimeDeltaMs); + int64_t now_ms = clock.TimeInMilliseconds(); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + 0); + timing.SetLastDecodeScheduledTimestamp(now_ms); + } + // Another frame submitted at the same time is paced according to the field + // trial setting. + int64_t now_ms = clock.TimeInMilliseconds(); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/false), + kMinPacingMs); + // MaxWaitingTime returns 0 even if there's a burst of frames if + // too_many_frames_queued is set to true. 
+ EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/true), + 0); + EXPECT_EQ(timing.MaxWaitingTime(kZeroRenderTimeMs, now_ms, + /*too_many_frames_queued=*/true), + 0); +} + } // namespace webrtc diff --git a/modules/video_coding/utility/decoded_frames_history.cc b/modules/video_coding/utility/decoded_frames_history.cc index d15cf26d8d..005bb26ea6 100644 --- a/modules/video_coding/utility/decoded_frames_history.cc +++ b/modules/video_coding/utility/decoded_frames_history.cc @@ -18,89 +18,63 @@ namespace webrtc { namespace video_coding { -DecodedFramesHistory::LayerHistory::LayerHistory() = default; -DecodedFramesHistory::LayerHistory::~LayerHistory() = default; - DecodedFramesHistory::DecodedFramesHistory(size_t window_size) - : window_size_(window_size) {} + : buffer_(window_size) {} DecodedFramesHistory::~DecodedFramesHistory() = default; -void DecodedFramesHistory::InsertDecoded(const VideoLayerFrameId& frameid, - uint32_t timestamp) { - last_decoded_frame_ = frameid; +void DecodedFramesHistory::InsertDecoded(int64_t frame_id, uint32_t timestamp) { + last_decoded_frame_ = frame_id; last_decoded_frame_timestamp_ = timestamp; - if (static_cast(layers_.size()) < frameid.spatial_layer + 1) { - size_t old_size = layers_.size(); - layers_.resize(frameid.spatial_layer + 1); - - for (size_t i = old_size; i < layers_.size(); ++i) - layers_[i].buffer.resize(window_size_); - - layers_[frameid.spatial_layer].last_picture_id = frameid.picture_id; - layers_[frameid.spatial_layer] - .buffer[PictureIdToIndex(frameid.picture_id)] = true; - return; - } - - int new_index = PictureIdToIndex(frameid.picture_id); - LayerHistory& history = layers_[frameid.spatial_layer]; + int new_index = FrameIdToIndex(frame_id); - RTC_DCHECK(history.last_picture_id < frameid.picture_id); + RTC_DCHECK(last_frame_id_ < frame_id); - // Clears expired values from the cyclic buffer. - if (history.last_picture_id) { - int64_t id_jump = frameid.picture_id - *history.last_picture_id; - int last_index = PictureIdToIndex(*history.last_picture_id); + // Clears expired values from the cyclic buffer_. + if (last_frame_id_) { + int64_t id_jump = frame_id - *last_frame_id_; + int last_index = FrameIdToIndex(*last_frame_id_); - if (id_jump >= window_size_) { - std::fill(history.buffer.begin(), history.buffer.end(), false); + if (id_jump >= static_cast(buffer_.size())) { + std::fill(buffer_.begin(), buffer_.end(), false); } else if (new_index > last_index) { - std::fill(history.buffer.begin() + last_index + 1, - history.buffer.begin() + new_index, false); - } else { - std::fill(history.buffer.begin() + last_index + 1, history.buffer.end(), - false); - std::fill(history.buffer.begin(), history.buffer.begin() + new_index, + std::fill(buffer_.begin() + last_index + 1, buffer_.begin() + new_index, false); + } else { + std::fill(buffer_.begin() + last_index + 1, buffer_.end(), false); + std::fill(buffer_.begin(), buffer_.begin() + new_index, false); } } - history.buffer[new_index] = true; - history.last_picture_id = frameid.picture_id; + buffer_[new_index] = true; + last_frame_id_ = frame_id; } -bool DecodedFramesHistory::WasDecoded(const VideoLayerFrameId& frameid) { - // Unseen before spatial layer. 
- if (static_cast(layers_.size()) < frameid.spatial_layer + 1) - return false; - - LayerHistory& history = layers_[frameid.spatial_layer]; - - if (!history.last_picture_id) +bool DecodedFramesHistory::WasDecoded(int64_t frame_id) { + if (!last_frame_id_) return false; - // Reference to the picture_id out of the stored history should happen. - if (frameid.picture_id <= *history.last_picture_id - window_size_) { - RTC_LOG(LS_WARNING) << "Referencing a frame out of the history window. " + // Reference to the picture_id out of the stored should happen. + if (frame_id <= *last_frame_id_ - static_cast(buffer_.size())) { + RTC_LOG(LS_WARNING) << "Referencing a frame out of the window. " "Assuming it was undecoded to avoid artifacts."; return false; } - if (frameid.picture_id > history.last_picture_id) + if (frame_id > last_frame_id_) return false; - return history.buffer[PictureIdToIndex(frameid.picture_id)]; + return buffer_[FrameIdToIndex(frame_id)]; } void DecodedFramesHistory::Clear() { - layers_.clear(); last_decoded_frame_timestamp_.reset(); last_decoded_frame_.reset(); + std::fill(buffer_.begin(), buffer_.end(), false); + last_frame_id_.reset(); } -absl::optional -DecodedFramesHistory::GetLastDecodedFrameId() { +absl::optional DecodedFramesHistory::GetLastDecodedFrameId() { return last_decoded_frame_; } @@ -108,9 +82,9 @@ absl::optional DecodedFramesHistory::GetLastDecodedFrameTimestamp() { return last_decoded_frame_timestamp_; } -int DecodedFramesHistory::PictureIdToIndex(int64_t frame_id) const { - int m = frame_id % window_size_; - return m >= 0 ? m : m + window_size_; +int DecodedFramesHistory::FrameIdToIndex(int64_t frame_id) const { + int m = frame_id % buffer_.size(); + return m >= 0 ? m : m + buffer_.size(); } } // namespace video_coding diff --git a/modules/video_coding/utility/decoded_frames_history.h b/modules/video_coding/utility/decoded_frames_history.h index 7cbe1f5cfc..06008dc22e 100644 --- a/modules/video_coding/utility/decoded_frames_history.h +++ b/modules/video_coding/utility/decoded_frames_history.h @@ -27,31 +27,23 @@ class DecodedFramesHistory { // window_size - how much frames back to the past are actually remembered. explicit DecodedFramesHistory(size_t window_size); ~DecodedFramesHistory(); - // Called for each decoded frame. Assumes picture id's are non-decreasing. - void InsertDecoded(const VideoLayerFrameId& frameid, uint32_t timestamp); - // Query if the following (picture_id, spatial_id) pair was inserted before. - // Should be at most less by window_size-1 than the last inserted picture id. - bool WasDecoded(const VideoLayerFrameId& frameid); + // Called for each decoded frame. Assumes frame id's are non-decreasing. + void InsertDecoded(int64_t frame_id, uint32_t timestamp); + // Query if the following (frame_id, spatial_id) pair was inserted before. + // Should be at most less by window_size-1 than the last inserted frame id. + bool WasDecoded(int64_t frame_id); void Clear(); - absl::optional GetLastDecodedFrameId(); + absl::optional GetLastDecodedFrameId(); absl::optional GetLastDecodedFrameTimestamp(); private: - struct LayerHistory { - LayerHistory(); - ~LayerHistory(); - // Cyclic bitset buffer. Stores last known |window_size| bits. 
- std::vector buffer; - absl::optional last_picture_id; - }; - - int PictureIdToIndex(int64_t frame_id) const; - - const int window_size_; - std::vector layers_; - absl::optional last_decoded_frame_; + int FrameIdToIndex(int64_t frame_id) const; + + std::vector buffer_; + absl::optional last_frame_id_; + absl::optional last_decoded_frame_; absl::optional last_decoded_frame_timestamp_; }; diff --git a/modules/video_coding/utility/decoded_frames_history_unittest.cc b/modules/video_coding/utility/decoded_frames_history_unittest.cc index ccf393d403..ac09a42053 100644 --- a/modules/video_coding/utility/decoded_frames_history_unittest.cc +++ b/modules/video_coding/utility/decoded_frames_history_unittest.cc @@ -20,125 +20,93 @@ constexpr int kHistorySize = 1 << 13; TEST(DecodedFramesHistory, RequestOnEmptyHistory) { DecodedFramesHistory history(kHistorySize); - EXPECT_EQ(history.WasDecoded({1234, 0}), false); + EXPECT_EQ(history.WasDecoded(1234), false); } TEST(DecodedFramesHistory, FindsLastDecodedFrame) { DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - EXPECT_EQ(history.WasDecoded({1234, 0}), true); + history.InsertDecoded(1234, 0); + EXPECT_EQ(history.WasDecoded(1234), true); } TEST(DecodedFramesHistory, FindsPreviousFrame) { DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1235, 0}, 0); - EXPECT_EQ(history.WasDecoded({1234, 0}), true); + history.InsertDecoded(1234, 0); + history.InsertDecoded(1235, 0); + EXPECT_EQ(history.WasDecoded(1234), true); } TEST(DecodedFramesHistory, ReportsMissingFrame) { DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1236, 0}, 0); - EXPECT_EQ(history.WasDecoded({1235, 0}), false); + history.InsertDecoded(1234, 0); + history.InsertDecoded(1236, 0); + EXPECT_EQ(history.WasDecoded(1235), false); } TEST(DecodedFramesHistory, ClearsHistory) { DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); + history.InsertDecoded(1234, 0); history.Clear(); - EXPECT_EQ(history.WasDecoded({1234, 0}), false); + EXPECT_EQ(history.WasDecoded(1234), false); EXPECT_EQ(history.GetLastDecodedFrameId(), absl::nullopt); EXPECT_EQ(history.GetLastDecodedFrameTimestamp(), absl::nullopt); } -TEST(DecodedFramesHistory, HandlesMultipleLayers) { - DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1234, 1}, 0); - history.InsertDecoded({1235, 0}, 0); - history.InsertDecoded({1236, 0}, 0); - history.InsertDecoded({1236, 1}, 0); - EXPECT_EQ(history.WasDecoded({1235, 0}), true); - EXPECT_EQ(history.WasDecoded({1235, 1}), false); -} - -TEST(DecodedFramesHistory, HandlesNewLayer) { - DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1234, 1}, 0); - history.InsertDecoded({1235, 0}, 0); - history.InsertDecoded({1235, 1}, 0); - history.InsertDecoded({1236, 0}, 0); - history.InsertDecoded({1236, 1}, 0); - EXPECT_EQ(history.WasDecoded({1234, 2}), false); -} - -TEST(DecodedFramesHistory, HandlesSkippedLayer) { - DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1234, 2}, 0); - history.InsertDecoded({1235, 0}, 0); - history.InsertDecoded({1235, 1}, 0); - EXPECT_EQ(history.WasDecoded({1234, 1}), false); - EXPECT_EQ(history.WasDecoded({1235, 1}), true); -} - TEST(DecodedFramesHistory, HandlesBigJumpInPictureId) { DecodedFramesHistory history(kHistorySize); - 
history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1235, 0}, 0); - history.InsertDecoded({1236, 0}, 0); - history.InsertDecoded({1236 + kHistorySize / 2, 0}, 0); - EXPECT_EQ(history.WasDecoded({1234, 0}), true); - EXPECT_EQ(history.WasDecoded({1237, 0}), false); + history.InsertDecoded(1234, 0); + history.InsertDecoded(1235, 0); + history.InsertDecoded(1236, 0); + history.InsertDecoded(1236 + kHistorySize / 2, 0); + EXPECT_EQ(history.WasDecoded(1234), true); + EXPECT_EQ(history.WasDecoded(1237), false); } TEST(DecodedFramesHistory, ForgetsTooOldHistory) { DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({1234, 0}, 0); - history.InsertDecoded({1235, 0}, 0); - history.InsertDecoded({1236, 0}, 0); - history.InsertDecoded({1236 + kHistorySize * 2, 0}, 0); - EXPECT_EQ(history.WasDecoded({1234, 0}), false); - EXPECT_EQ(history.WasDecoded({1237, 0}), false); + history.InsertDecoded(1234, 0); + history.InsertDecoded(1235, 0); + history.InsertDecoded(1236, 0); + history.InsertDecoded(1236 + kHistorySize * 2, 0); + EXPECT_EQ(history.WasDecoded(1234), false); + EXPECT_EQ(history.WasDecoded(1237), false); } TEST(DecodedFramesHistory, ReturnsLastDecodedFrameId) { DecodedFramesHistory history(kHistorySize); EXPECT_EQ(history.GetLastDecodedFrameId(), absl::nullopt); - history.InsertDecoded({1234, 0}, 0); - EXPECT_EQ(history.GetLastDecodedFrameId(), VideoLayerFrameId(1234, 0)); - history.InsertDecoded({1235, 0}, 0); - EXPECT_EQ(history.GetLastDecodedFrameId(), VideoLayerFrameId(1235, 0)); + history.InsertDecoded(1234, 0); + EXPECT_EQ(history.GetLastDecodedFrameId(), 1234); + history.InsertDecoded(1235, 0); + EXPECT_EQ(history.GetLastDecodedFrameId(), 1235); } TEST(DecodedFramesHistory, ReturnsLastDecodedFrameTimestamp) { DecodedFramesHistory history(kHistorySize); EXPECT_EQ(history.GetLastDecodedFrameTimestamp(), absl::nullopt); - history.InsertDecoded({1234, 0}, 12345); + history.InsertDecoded(1234, 12345); EXPECT_EQ(history.GetLastDecodedFrameTimestamp(), 12345u); - history.InsertDecoded({1235, 0}, 12366); + history.InsertDecoded(1235, 12366); EXPECT_EQ(history.GetLastDecodedFrameTimestamp(), 12366u); } TEST(DecodedFramesHistory, NegativePictureIds) { DecodedFramesHistory history(kHistorySize); - history.InsertDecoded({-1234, 0}, 12345); - history.InsertDecoded({-1233, 0}, 12366); - EXPECT_EQ(history.GetLastDecodedFrameId()->picture_id, -1233); + history.InsertDecoded(-1234, 12345); + history.InsertDecoded(-1233, 12366); + EXPECT_EQ(*history.GetLastDecodedFrameId(), -1233); - history.InsertDecoded({-1, 0}, 12377); - history.InsertDecoded({0, 0}, 12388); - EXPECT_EQ(history.GetLastDecodedFrameId()->picture_id, 0); + history.InsertDecoded(-1, 12377); + history.InsertDecoded(0, 12388); + EXPECT_EQ(*history.GetLastDecodedFrameId(), 0); - history.InsertDecoded({1, 0}, 12399); - EXPECT_EQ(history.GetLastDecodedFrameId()->picture_id, 1); + history.InsertDecoded(1, 12399); + EXPECT_EQ(*history.GetLastDecodedFrameId(), 1); - EXPECT_EQ(history.WasDecoded({-1234, 0}), true); - EXPECT_EQ(history.WasDecoded({-1, 0}), true); - EXPECT_EQ(history.WasDecoded({0, 0}), true); - EXPECT_EQ(history.WasDecoded({1, 0}), true); + EXPECT_EQ(history.WasDecoded(-1234), true); + EXPECT_EQ(history.WasDecoded(-1), true); + EXPECT_EQ(history.WasDecoded(0), true); + EXPECT_EQ(history.WasDecoded(1), true); } } // namespace diff --git a/modules/video_coding/utility/frame_dropper.h b/modules/video_coding/utility/frame_dropper.h index 50a8d58e66..014b5dd7aa 100644 --- 
a/modules/video_coding/utility/frame_dropper.h +++ b/modules/video_coding/utility/frame_dropper.h @@ -44,7 +44,7 @@ class FrameDropper { // Input: // - framesize_bytes : The size of the latest frame returned // from the encoder. - // - delta_frame : True if the encoder returned a key frame. + // - delta_frame : True if the encoder returned a delta frame. void Fill(size_t framesize_bytes, bool delta_frame); void Leak(uint32_t input_framerate); diff --git a/modules/video_coding/utility/ivf_file_reader.cc b/modules/video_coding/utility/ivf_file_reader.cc index e3c249947d..f326c8cb53 100644 --- a/modules/video_coding/utility/ivf_file_reader.cc +++ b/modules/video_coding/utility/ivf_file_reader.cc @@ -164,7 +164,7 @@ absl::optional IvfFileReader::NextFrame() { image.SetTimestamp(static_cast(current_timestamp)); } image.SetEncodedData(payload); - image.SetSpatialIndex(static_cast(layer_sizes.size())); + image.SetSpatialIndex(static_cast(layer_sizes.size()) - 1); for (size_t i = 0; i < layer_sizes.size(); ++i) { image.SetSpatialLayerFrameSize(static_cast(i), layer_sizes[i]); } diff --git a/modules/video_coding/utility/ivf_file_reader_unittest.cc b/modules/video_coding/utility/ivf_file_reader_unittest.cc index 58a808840d..c9cf14674b 100644 --- a/modules/video_coding/utility/ivf_file_reader_unittest.cc +++ b/modules/video_coding/utility/ivf_file_reader_unittest.cc @@ -83,7 +83,7 @@ class IvfFileReaderTest : public ::testing::Test { bool use_capture_tims_ms, int spatial_layers_count) { ASSERT_TRUE(frame); - EXPECT_EQ(frame->SpatialIndex(), spatial_layers_count); + EXPECT_EQ(frame->SpatialIndex(), spatial_layers_count - 1); if (use_capture_tims_ms) { EXPECT_EQ(frame->capture_time_ms_, static_cast(frame_index)); EXPECT_EQ(frame->Timestamp(), static_cast(90 * frame_index)); diff --git a/modules/video_coding/utility/qp_parser.cc b/modules/video_coding/utility/qp_parser.cc new file mode 100644 index 0000000000..18f225447d --- /dev/null +++ b/modules/video_coding/utility/qp_parser.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/video_coding/utility/qp_parser.h" + +#include "modules/video_coding/utility/vp8_header_parser.h" +#include "modules/video_coding/utility/vp9_uncompressed_header_parser.h" + +namespace webrtc { + +absl::optional QpParser::Parse(VideoCodecType codec_type, + size_t spatial_idx, + const uint8_t* frame_data, + size_t frame_size) { + if (frame_data == nullptr || frame_size == 0 || + spatial_idx >= kMaxSimulcastStreams) { + return absl::nullopt; + } + + if (codec_type == kVideoCodecVP8) { + int qp = -1; + if (vp8::GetQp(frame_data, frame_size, &qp)) { + return qp; + } + } else if (codec_type == kVideoCodecVP9) { + int qp = -1; + if (vp9::GetQp(frame_data, frame_size, &qp)) { + return qp; + } + } else if (codec_type == kVideoCodecH264) { + return h264_parsers_[spatial_idx].Parse(frame_data, frame_size); + } + + return absl::nullopt; +} + +absl::optional QpParser::H264QpParser::Parse( + const uint8_t* frame_data, + size_t frame_size) { + MutexLock lock(&mutex_); + bitstream_parser_.ParseBitstream( + rtc::ArrayView(frame_data, frame_size)); + return bitstream_parser_.GetLastSliceQp(); +} + +} // namespace webrtc diff --git a/modules/video_coding/utility/qp_parser.h b/modules/video_coding/utility/qp_parser.h new file mode 100644 index 0000000000..f132ff9337 --- /dev/null +++ b/modules/video_coding/utility/qp_parser.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_VIDEO_CODING_UTILITY_QP_PARSER_H_ +#define MODULES_VIDEO_CODING_UTILITY_QP_PARSER_H_ + +#include "absl/types/optional.h" +#include "api/video/video_codec_constants.h" +#include "api/video/video_codec_type.h" +#include "common_video/h264/h264_bitstream_parser.h" +#include "rtc_base/synchronization/mutex.h" + +namespace webrtc { +class QpParser { + public: + absl::optional Parse(VideoCodecType codec_type, + size_t spatial_idx, + const uint8_t* frame_data, + size_t frame_size); + + private: + // A thread safe wrapper for H264 bitstream parser. + class H264QpParser { + public: + absl::optional Parse(const uint8_t* frame_data, + size_t frame_size); + + private: + Mutex mutex_; + H264BitstreamParser bitstream_parser_ RTC_GUARDED_BY(mutex_); + }; + + H264QpParser h264_parsers_[kMaxSimulcastStreams]; +}; + +} // namespace webrtc + +#endif // MODULES_VIDEO_CODING_UTILITY_QP_PARSER_H_ diff --git a/modules/video_coding/utility/qp_parser_unittest.cc b/modules/video_coding/utility/qp_parser_unittest.cc new file mode 100644 index 0000000000..1131288f26 --- /dev/null +++ b/modules/video_coding/utility/qp_parser_unittest.cc @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/video_coding/utility/qp_parser.h" + +#include + +#include "test/gtest.h" + +namespace webrtc { + +namespace { +// ffmpeg -s 16x16 -f rawvideo -pix_fmt rgb24 -r 30 -i /dev/zero -c:v libvpx +// -qmin 20 -qmax 20 -crf 20 -frames:v 1 -y out.ivf +const uint8_t kCodedFrameVp8Qp25[] = { + 0x10, 0x02, 0x00, 0x9d, 0x01, 0x2a, 0x10, 0x00, 0x10, 0x00, + 0x02, 0x47, 0x08, 0x85, 0x85, 0x88, 0x85, 0x84, 0x88, 0x0c, + 0x82, 0x00, 0x0c, 0x0d, 0x60, 0x00, 0xfe, 0xfc, 0x5c, 0xd0}; + +// ffmpeg -s 16x16 -f rawvideo -pix_fmt rgb24 -r 30 -i /dev/zero -c:v libvpx-vp9 +// -qmin 24 -qmax 24 -crf 24 -frames:v 1 -y out.ivf +const uint8_t kCodedFrameVp9Qp96[] = { + 0xa2, 0x49, 0x83, 0x42, 0xe0, 0x00, 0xf0, 0x00, 0xf6, 0x00, + 0x38, 0x24, 0x1c, 0x18, 0xc0, 0x00, 0x00, 0x30, 0x70, 0x00, + 0x00, 0x4a, 0xa7, 0xff, 0xfc, 0xb9, 0x01, 0xbf, 0xff, 0xff, + 0x97, 0x20, 0xdb, 0xff, 0xff, 0xcb, 0x90, 0x5d, 0x40}; + +// ffmpeg -s 16x16 -f rawvideo -pix_fmt yuv420p -r 30 -i /dev/zero -c:v libx264 +// -qmin 38 -qmax 38 -crf 38 -profile:v baseline -frames:v 2 -y out.264 +const uint8_t kCodedFrameH264SpsPpsIdrQp38[] = { + 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0xc0, 0x0a, 0xd9, 0x1e, 0x84, + 0x00, 0x00, 0x03, 0x00, 0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, + 0x48, 0x99, 0x20, 0x00, 0x00, 0x00, 0x01, 0x68, 0xcb, 0x80, 0xc4, + 0xb2, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84, 0xf1, 0x18, 0xa0, 0x00, + 0x20, 0x5b, 0x1c, 0x00, 0x04, 0x07, 0xe3, 0x80, 0x00, 0x80, 0xfe}; + +const uint8_t kCodedFrameH264SpsPpsIdrQp49[] = { + 0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0xc0, 0x0a, 0xd9, 0x1e, 0x84, + 0x00, 0x00, 0x03, 0x00, 0x04, 0x00, 0x00, 0x03, 0x00, 0xf0, 0x3c, + 0x48, 0x99, 0x20, 0x00, 0x00, 0x00, 0x01, 0x68, 0xcb, 0x80, 0x5d, + 0x2c, 0x80, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84, 0xf1, 0x18, 0xa0, + 0x00, 0x5e, 0x38, 0x00, 0x08, 0x03, 0xc7, 0x00, 0x01, 0x00, 0x7c}; + +const uint8_t kCodedFrameH264InterSliceQpDelta0[] = {0x00, 0x00, 0x00, 0x01, + 0x41, 0x9a, 0x39, 0xea}; + +} // namespace + +TEST(QpParserTest, ParseQpVp8) { + QpParser parser; + absl::optional qp = parser.Parse( + kVideoCodecVP8, 0, kCodedFrameVp8Qp25, sizeof(kCodedFrameVp8Qp25)); + EXPECT_EQ(qp, 25u); +} + +TEST(QpParserTest, ParseQpVp9) { + QpParser parser; + absl::optional qp = parser.Parse( + kVideoCodecVP9, 0, kCodedFrameVp9Qp96, sizeof(kCodedFrameVp9Qp96)); + EXPECT_EQ(qp, 96u); +} + +TEST(QpParserTest, ParseQpH264) { + QpParser parser; + absl::optional qp = parser.Parse( + VideoCodecType::kVideoCodecH264, 0, kCodedFrameH264SpsPpsIdrQp38, + sizeof(kCodedFrameH264SpsPpsIdrQp38)); + EXPECT_EQ(qp, 38u); + + qp = parser.Parse(kVideoCodecH264, 1, kCodedFrameH264SpsPpsIdrQp49, + sizeof(kCodedFrameH264SpsPpsIdrQp49)); + EXPECT_EQ(qp, 49u); + + qp = parser.Parse(kVideoCodecH264, 0, kCodedFrameH264InterSliceQpDelta0, + sizeof(kCodedFrameH264InterSliceQpDelta0)); + EXPECT_EQ(qp, 38u); + + qp = parser.Parse(kVideoCodecH264, 1, kCodedFrameH264InterSliceQpDelta0, + sizeof(kCodedFrameH264InterSliceQpDelta0)); + EXPECT_EQ(qp, 49u); +} + +TEST(QpParserTest, ParseQpUnsupportedCodecType) { + QpParser parser; + absl::optional qp = parser.Parse( + kVideoCodecGeneric, 0, kCodedFrameVp8Qp25, sizeof(kCodedFrameVp8Qp25)); + EXPECT_FALSE(qp.has_value()); +} + +TEST(QpParserTest, ParseQpNullData) { + QpParser parser; + absl::optional qp = parser.Parse(kVideoCodecVP8, 0, nullptr, 100); + EXPECT_FALSE(qp.has_value()); +} + +TEST(QpParserTest, ParseQpEmptyData) { + QpParser parser; + absl::optional qp = + parser.Parse(kVideoCodecVP8, 0, kCodedFrameVp8Qp25, 0); + 
EXPECT_FALSE(qp.has_value()); +} + +TEST(QpParserTest, ParseQpSpatialIdxExceedsMax) { + QpParser parser; + absl::optional qp = + parser.Parse(kVideoCodecVP8, kMaxSimulcastStreams, kCodedFrameVp8Qp25, + sizeof(kCodedFrameVp8Qp25)); + EXPECT_FALSE(qp.has_value()); +} + +} // namespace webrtc diff --git a/modules/video_coding/utility/quality_scaler.h b/modules/video_coding/utility/quality_scaler.h index 987d49f1a8..20169a3cee 100644 --- a/modules/video_coding/utility/quality_scaler.h +++ b/modules/video_coding/utility/quality_scaler.h @@ -18,12 +18,12 @@ #include "absl/types/optional.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/video_codecs/video_encoder.h" #include "rtc_base/experiments/quality_scaling_experiment.h" #include "rtc_base/numerics/moving_average.h" #include "rtc_base/ref_count.h" #include "rtc_base/ref_counted_object.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" diff --git a/modules/video_coding/utility/quality_scaler_unittest.cc b/modules/video_coding/utility/quality_scaler_unittest.cc index d5b22a8a29..91911a7696 100644 --- a/modules/video_coding/utility/quality_scaler_unittest.cc +++ b/modules/video_coding/utility/quality_scaler_unittest.cc @@ -116,7 +116,7 @@ INSTANTIATE_TEST_SUITE_P( QualityScalerTest, ::testing::Values( "WebRTC-Video-QualityScaling/Enabled-1,2,3,4,5,6,7,8,0.9,0.99,1/", - "")); + "WebRTC-Video-QualityScaling/Disabled/")); TEST_P(QualityScalerTest, DownscalesAfterContinuousFramedrop) { task_queue_.SendTask([this] { TriggerScale(kScaleDown); }, RTC_FROM_HERE); @@ -171,7 +171,8 @@ TEST_P(QualityScalerTest, DoesNotDownscaleAfterHalfFramedrop) { } TEST_P(QualityScalerTest, DownscalesAfterTwoThirdsIfFieldTrialEnabled) { - const bool kDownScaleExpected = !GetParam().empty(); + const bool kDownScaleExpected = + GetParam().find("Enabled") != std::string::npos; task_queue_.SendTask( [this] { for (int i = 0; i < kFramerate * 5; ++i) { diff --git a/modules/video_coding/utility/simulcast_test_fixture_impl.cc b/modules/video_coding/utility/simulcast_test_fixture_impl.cc index a9af643446..6d3195c32b 100644 --- a/modules/video_coding/utility/simulcast_test_fixture_impl.cc +++ b/modules/video_coding/utility/simulcast_test_fixture_impl.cc @@ -190,7 +190,7 @@ void ConfigureStream(int width, float max_framerate, SpatialLayer* stream, int num_temporal_layers) { - assert(stream); + RTC_DCHECK(stream); stream->width = width; stream->height = height; stream->maxBitrate = max_bitrate; @@ -590,6 +590,7 @@ void SimulcastTestFixtureImpl::SwitchingToOneStream(int width, int height) { settings_.VP8()->numberOfTemporalLayers = 1; temporal_layer_profile = kDefaultTemporalLayerProfile; } else { + settings_.H264()->numberOfTemporalLayers = 1; temporal_layer_profile = kNoTemporalLayerProfile; } settings_.maxBitrate = 100; diff --git a/modules/video_coding/utility/vp9_uncompressed_header_parser.cc b/modules/video_coding/utility/vp9_uncompressed_header_parser.cc index f8ddd4db41..07ba3255c6 100644 --- a/modules/video_coding/utility/vp9_uncompressed_header_parser.cc +++ b/modules/video_coding/utility/vp9_uncompressed_header_parser.cc @@ -9,90 +9,195 @@ */ #include "modules/video_coding/utility/vp9_uncompressed_header_parser.h" +#include "absl/strings/string_view.h" #include "rtc_base/bit_buffer.h" #include "rtc_base/logging.h" namespace webrtc { -#define RETURN_FALSE_IF_ERROR(x) \ - if (!(x)) { \ - return false; \ +// Evaluates x and returns false if false. 
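+// Example: RETURN_IF_FALSE(br->VerifyNextBooleanIs(0, "Reserved bit set."));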
+#define RETURN_IF_FALSE(x) \ + if (!(x)) { \ + return false; \ } +// Evaluates x, which is intended to return an optional. If result is nullopt, +// returns false. Else, calls fun() with the dereferenced optional as parameter. +#define READ_OR_RETURN(x, fun) \ + do { \ + if (auto optional_val = (x)) { \ + fun(*optional_val); \ + } else { \ + return false; \ + } \ + } while (false) + namespace vp9 { namespace { const size_t kVp9NumRefsPerFrame = 3; const size_t kVp9MaxRefLFDeltas = 4; const size_t kVp9MaxModeLFDeltas = 2; +const size_t kVp9MinTileWidthB64 = 4; +const size_t kVp9MaxTileWidthB64 = 64; + +class BitstreamReader { + public: + explicit BitstreamReader(rtc::BitBuffer* buffer) : buffer_(buffer) {} + + // Reads on bit from the input stream and: + // * returns false if bit cannot be read + // * calls f_true() if bit is true, returns return value of that function + // * calls f_else() if bit is false, returns return value of that function + bool IfNextBoolean( + std::function f_true, + std::function f_false = [] { return true; }) { + uint32_t val; + if (!buffer_->ReadBits(1, val)) { + return false; + } + if (val != 0) { + return f_true(); + } + return f_false(); + } -bool Vp9ReadProfile(rtc::BitBuffer* br, uint8_t* profile) { - uint32_t high_bit; - uint32_t low_bit; - RETURN_FALSE_IF_ERROR(br->ReadBits(&low_bit, 1)); - RETURN_FALSE_IF_ERROR(br->ReadBits(&high_bit, 1)); - *profile = (high_bit << 1) + low_bit; - if (*profile > 2) { - uint32_t reserved_bit; - RETURN_FALSE_IF_ERROR(br->ReadBits(&reserved_bit, 1)); - if (reserved_bit) { - RTC_LOG(LS_WARNING) << "Failed to get QP. Unsupported bitstream profile."; + absl::optional ReadBoolean() { + uint32_t val; + if (!buffer_->ReadBits(1, val)) { + return {}; + } + return {val != 0}; + } + + // Reads a bit from the input stream and returns: + // * false if bit cannot be read + // * true if bit matches expected_val + // * false if bit does not match expected_val - in which case |error_msg| is + // logged as warning, if provided. + bool VerifyNextBooleanIs(bool expected_val, absl::string_view error_msg) { + uint32_t val; + if (!buffer_->ReadBits(1, val)) { + return false; + } + if ((val != 0) != expected_val) { + if (!error_msg.empty()) { + RTC_LOG(LS_WARNING) << error_msg; + } return false; } + return true; } - return true; -} -bool Vp9ReadSyncCode(rtc::BitBuffer* br) { - uint32_t sync_code; - RETURN_FALSE_IF_ERROR(br->ReadBits(&sync_code, 24)); - if (sync_code != 0x498342) { - RTC_LOG(LS_WARNING) << "Failed to get QP. Invalid sync code."; - return false; + // Reads |bits| bits from the bitstream and interprets them as an unsigned + // integer that gets cast to the type T before returning. + // Returns nullopt if all bits cannot be read. + // If number of bits matches size of data type, the bits parameter may be + // omitted. Ex: + // ReadUnsigned(2); // Returns uint8_t with 2 LSB populated. + // ReadUnsigned(); // Returns uint8_t with all 8 bits populated. + template + absl::optional ReadUnsigned(int bits = sizeof(T) * 8) { + RTC_DCHECK_LE(bits, 32); + RTC_DCHECK_LE(bits, sizeof(T) * 8); + uint32_t val; + if (!buffer_->ReadBits(bits, val)) { + return {}; + } + return (static_cast(val)); } - return true; -} -bool Vp9ReadColorConfig(rtc::BitBuffer* br, - uint8_t profile, - FrameInfo* frame_info) { - if (profile == 0 || profile == 1) { + // Helper method that reads |num_bits| from the bitstream, returns: + // * false if bits cannot be read. 
+ // * true if |expected_val| matches the read bits + // * false if |expected_val| does not match the read bits, and logs + // |error_msg| as a warning (if provided). + bool VerifyNextUnsignedIs(int num_bits, + uint32_t expected_val, + absl::string_view error_msg) { + uint32_t val; + if (!buffer_->ReadBits(num_bits, val)) { + return false; + } + if (val != expected_val) { + if (!error_msg.empty()) { + RTC_LOG(LS_WARNING) << error_msg; + } + return false; + } + return true; + } + + // Basically the same as ReadUnsigned() - but for signed integers. + // Here |bits| indicates the size of the value - number of bits read from the + // bit buffer is one higher (the sign bit). This is made to matche the spec in + // which eg s(4) = f(1) sign-bit, plus an f(4). + template + absl::optional ReadSigned(int bits = sizeof(T) * 8) { + uint32_t sign; + if (!buffer_->ReadBits(1, sign)) { + return {}; + } + uint32_t val; + if (!buffer_->ReadBits(bits, val)) { + return {}; + } + int64_t sign_val = val; + if (sign != 0) { + sign_val = -sign_val; + } + return {static_cast(sign_val)}; + } + + // Reads |bits| from the bitstream, disregarding their value. + // Returns true if full number of bits were read, false otherwise. + bool ConsumeBits(int bits) { return buffer_->ConsumeBits(bits); } + + private: + rtc::BitBuffer* buffer_; +}; + +bool Vp9ReadColorConfig(BitstreamReader* br, FrameInfo* frame_info) { + if (frame_info->profile == 2 || frame_info->profile == 3) { + READ_OR_RETURN(br->ReadBoolean(), [frame_info](bool ten_or_twelve_bits) { + frame_info->bit_detph = + ten_or_twelve_bits ? BitDept::k12Bit : BitDept::k10Bit; + }); + } else { frame_info->bit_detph = BitDept::k8Bit; - } else if (profile == 2 || profile == 3) { - uint32_t ten_or_twelve_bits; - RETURN_FALSE_IF_ERROR(br->ReadBits(&ten_or_twelve_bits, 1)); - frame_info->bit_detph = - ten_or_twelve_bits ? BitDept::k12Bit : BitDept::k10Bit; } - uint32_t color_space; - RETURN_FALSE_IF_ERROR(br->ReadBits(&color_space, 3)); - frame_info->color_space = static_cast(color_space); - - // SRGB is 7. - if (color_space != 7) { - uint32_t color_range; - RETURN_FALSE_IF_ERROR(br->ReadBits(&color_range, 1)); - frame_info->color_range = - color_range ? ColorRange::kFull : ColorRange::kStudio; - - if (profile == 1 || profile == 3) { - uint32_t subsampling_x; - uint32_t subsampling_y; - RETURN_FALSE_IF_ERROR(br->ReadBits(&subsampling_x, 1)); - RETURN_FALSE_IF_ERROR(br->ReadBits(&subsampling_y, 1)); - if (subsampling_x) { - frame_info->sub_sampling = - subsampling_y ? YuvSubsampling::k420 : YuvSubsampling::k422; - } else { - frame_info->sub_sampling = - subsampling_y ? YuvSubsampling::k440 : YuvSubsampling::k444; - } - uint32_t reserved_bit; - RETURN_FALSE_IF_ERROR(br->ReadBits(&reserved_bit, 1)); - if (reserved_bit) { - RTC_LOG(LS_WARNING) << "Failed to parse header. Reserved bit set."; - return false; - } + READ_OR_RETURN( + br->ReadUnsigned(3), [frame_info](uint8_t color_space) { + frame_info->color_space = static_cast(color_space); + }); + + if (frame_info->color_space != ColorSpace::CS_RGB) { + READ_OR_RETURN(br->ReadBoolean(), [frame_info](bool color_range) { + frame_info->color_range = + color_range ? 
ColorRange::kFull : ColorRange::kStudio; + }); + + if (frame_info->profile == 1 || frame_info->profile == 3) { + READ_OR_RETURN(br->ReadUnsigned(2), + [frame_info](uint8_t subsampling) { + switch (subsampling) { + case 0b00: + frame_info->sub_sampling = YuvSubsampling::k444; + break; + case 0b01: + frame_info->sub_sampling = YuvSubsampling::k440; + break; + case 0b10: + frame_info->sub_sampling = YuvSubsampling::k422; + break; + case 0b11: + frame_info->sub_sampling = YuvSubsampling::k420; + break; + } + }); + + RETURN_IF_FALSE(br->VerifyNextBooleanIs( + 0, "Failed to parse header. Reserved bit set.")); } else { // Profile 0 or 2. frame_info->sub_sampling = YuvSubsampling::k420; @@ -100,14 +205,10 @@ bool Vp9ReadColorConfig(rtc::BitBuffer* br, } else { // SRGB frame_info->color_range = ColorRange::kFull; - if (profile == 1 || profile == 3) { + if (frame_info->profile == 1 || frame_info->profile == 3) { frame_info->sub_sampling = YuvSubsampling::k444; - uint32_t reserved_bit; - RETURN_FALSE_IF_ERROR(br->ReadBits(&reserved_bit, 1)); - if (reserved_bit) { - RTC_LOG(LS_WARNING) << "Failed to parse header. Reserved bit set."; - return false; - } + RETURN_IF_FALSE(br->VerifyNextBooleanIs( + 0, "Failed to parse header. Reserved bit set.")); } else { RTC_LOG(LS_WARNING) << "Failed to parse header. 4:4:4 color not supported" " in profile 0 or 2."; @@ -118,44 +219,45 @@ bool Vp9ReadColorConfig(rtc::BitBuffer* br, return true; } -bool Vp9ReadFrameSize(rtc::BitBuffer* br, FrameInfo* frame_info) { - // 16 bits: frame width - 1. - uint16_t frame_width_minus_one; - RETURN_FALSE_IF_ERROR(br->ReadUInt16(&frame_width_minus_one)); - // 16 bits: frame height - 1. - uint16_t frame_height_minus_one; - RETURN_FALSE_IF_ERROR(br->ReadUInt16(&frame_height_minus_one)); - frame_info->frame_width = frame_width_minus_one + 1; - frame_info->frame_height = frame_height_minus_one + 1; +bool Vp9ReadFrameSize(BitstreamReader* br, FrameInfo* frame_info) { + // 16 bits: frame (width|height) - 1. + READ_OR_RETURN(br->ReadUnsigned(), [frame_info](uint16_t width) { + frame_info->frame_width = width + 1; + }); + READ_OR_RETURN(br->ReadUnsigned(), [frame_info](uint16_t height) { + frame_info->frame_height = height + 1; + }); return true; } -bool Vp9ReadRenderSize(rtc::BitBuffer* br, FrameInfo* frame_info) { - uint32_t render_and_frame_size_different; - RETURN_FALSE_IF_ERROR(br->ReadBits(&render_and_frame_size_different, 1)); - if (render_and_frame_size_different) { - // 16 bits: render width - 1. - uint16_t render_width_minus_one; - RETURN_FALSE_IF_ERROR(br->ReadUInt16(&render_width_minus_one)); - // 16 bits: render height - 1. - uint16_t render_height_minus_one; - RETURN_FALSE_IF_ERROR(br->ReadUInt16(&render_height_minus_one)); - frame_info->render_width = render_width_minus_one + 1; - frame_info->render_height = render_height_minus_one + 1; - } else { - frame_info->render_width = frame_info->frame_width; - frame_info->render_height = frame_info->frame_height; - } - return true; +bool Vp9ReadRenderSize(BitstreamReader* br, FrameInfo* frame_info) { + // render_and_frame_size_different + return br->IfNextBoolean( + [&] { + // 16 bits: render (width|height) - 1. 
+ READ_OR_RETURN(br->ReadUnsigned(), + [frame_info](uint16_t width) { + frame_info->render_width = width + 1; + }); + READ_OR_RETURN(br->ReadUnsigned(), + [frame_info](uint16_t height) { + frame_info->render_height = height + 1; + }); + return true; + }, + /*else*/ + [&] { + frame_info->render_height = frame_info->frame_height; + frame_info->render_width = frame_info->frame_width; + return true; + }); } -bool Vp9ReadFrameSizeFromRefs(rtc::BitBuffer* br, FrameInfo* frame_info) { - uint32_t found_ref = 0; - for (size_t i = 0; i < kVp9NumRefsPerFrame; i++) { +bool Vp9ReadFrameSizeFromRefs(BitstreamReader* br, FrameInfo* frame_info) { + bool found_ref = false; + for (size_t i = 0; !found_ref && i < kVp9NumRefsPerFrame; i++) { // Size in refs. - RETURN_FALSE_IF_ERROR(br->ReadBits(&found_ref, 1)); - if (found_ref) - break; + READ_OR_RETURN(br->ReadBoolean(), [&](bool ref) { found_ref = ref; }); } if (!found_ref) { @@ -166,83 +268,156 @@ bool Vp9ReadFrameSizeFromRefs(rtc::BitBuffer* br, FrameInfo* frame_info) { return Vp9ReadRenderSize(br, frame_info); } -bool Vp9ReadInterpolationFilter(rtc::BitBuffer* br) { - uint32_t bit; - RETURN_FALSE_IF_ERROR(br->ReadBits(&bit, 1)); - if (bit) - return true; - - return br->ConsumeBits(2); -} - -bool Vp9ReadLoopfilter(rtc::BitBuffer* br) { +bool Vp9ReadLoopfilter(BitstreamReader* br) { // 6 bits: filter level. // 3 bits: sharpness level. - RETURN_FALSE_IF_ERROR(br->ConsumeBits(9)); - - uint32_t mode_ref_delta_enabled; - RETURN_FALSE_IF_ERROR(br->ReadBits(&mode_ref_delta_enabled, 1)); - if (mode_ref_delta_enabled) { - uint32_t mode_ref_delta_update; - RETURN_FALSE_IF_ERROR(br->ReadBits(&mode_ref_delta_update, 1)); - if (mode_ref_delta_update) { - uint32_t bit; + RETURN_IF_FALSE(br->ConsumeBits(9)); + + return br->IfNextBoolean([&] { // if mode_ref_delta_enabled + return br->IfNextBoolean([&] { // if mode_ref_delta_update for (size_t i = 0; i < kVp9MaxRefLFDeltas; i++) { - RETURN_FALSE_IF_ERROR(br->ReadBits(&bit, 1)); - if (bit) { - RETURN_FALSE_IF_ERROR(br->ConsumeBits(7)); - } + RETURN_IF_FALSE(br->IfNextBoolean([&] { return br->ConsumeBits(7); })); } for (size_t i = 0; i < kVp9MaxModeLFDeltas; i++) { - RETURN_FALSE_IF_ERROR(br->ReadBits(&bit, 1)); - if (bit) { - RETURN_FALSE_IF_ERROR(br->ConsumeBits(7)); + RETURN_IF_FALSE(br->IfNextBoolean([&] { return br->ConsumeBits(7); })); + } + return true; + }); + }); +} + +bool Vp9ReadQp(BitstreamReader* br, FrameInfo* frame_info) { + READ_OR_RETURN(br->ReadUnsigned(), + [frame_info](uint8_t qp) { frame_info->base_qp = qp; }); + + // yuv offsets + for (int i = 0; i < 3; ++i) { + RETURN_IF_FALSE(br->IfNextBoolean([br] { // if delta_coded + return br->ConsumeBits(5); + })); + } + return true; +} + +bool Vp9ReadSegmentationParams(BitstreamReader* br) { + constexpr int kVp9MaxSegments = 8; + constexpr int kVp9SegLvlMax = 4; + constexpr int kSegmentationFeatureBits[kVp9SegLvlMax] = {8, 6, 2, 0}; + constexpr bool kSegmentationFeatureSigned[kVp9SegLvlMax] = {1, 1, 0, 0}; + + RETURN_IF_FALSE(br->IfNextBoolean([&] { // segmentation_enabled + return br->IfNextBoolean([&] { // update_map + // Consume probs. + for (int i = 0; i < 7; ++i) { + RETURN_IF_FALSE(br->IfNextBoolean([br] { return br->ConsumeBits(7); })); + } + + return br->IfNextBoolean([&] { // temporal_update + // Consume probs. 
+ for (int i = 0; i < 3; ++i) { + RETURN_IF_FALSE( + br->IfNextBoolean([br] { return br->ConsumeBits(7); })); } + return true; + }); + }); + })); + + return br->IfNextBoolean([&] { + RETURN_IF_FALSE(br->ConsumeBits(1)); // abs_or_delta + for (int i = 0; i < kVp9MaxSegments; ++i) { + for (int j = 0; j < kVp9SegLvlMax; ++j) { + RETURN_IF_FALSE(br->IfNextBoolean([&] { // feature_enabled + return br->ConsumeBits(kSegmentationFeatureBits[j] + + kSegmentationFeatureSigned[j]); + })); } } + return true; + }); +} + +bool Vp9ReadTileInfo(BitstreamReader* br, FrameInfo* frame_info) { + size_t mi_cols = (frame_info->frame_width + 7) >> 3; + size_t sb64_cols = (mi_cols + 7) >> 3; + + size_t min_log2 = 0; + while ((kVp9MaxTileWidthB64 << min_log2) < sb64_cols) { + ++min_log2; } - return true; + + size_t max_log2 = 1; + while ((sb64_cols >> max_log2) >= kVp9MinTileWidthB64) { + ++max_log2; + } + --max_log2; + + size_t cols_log2 = min_log2; + bool done = false; + while (!done && cols_log2 < max_log2) { + RETURN_IF_FALSE(br->IfNextBoolean( + [&] { + ++cols_log2; + return true; + }, + [&] { + done = true; + return true; + })); + } + + // rows_log2; + return br->IfNextBoolean([&] { return br->ConsumeBits(1); }); } } // namespace -bool Parse(const uint8_t* buf, size_t length, int* qp, FrameInfo* frame_info) { - rtc::BitBuffer br(buf, length); +bool Parse(const uint8_t* buf, size_t length, FrameInfo* frame_info) { + rtc::BitBuffer bit_buffer(buf, length); + BitstreamReader br(&bit_buffer); // Frame marker. - uint32_t frame_marker; - RETURN_FALSE_IF_ERROR(br.ReadBits(&frame_marker, 2)); - if (frame_marker != 0x2) { - RTC_LOG(LS_WARNING) << "Failed to parse header. Frame marker should be 2."; - return false; + RETURN_IF_FALSE(br.VerifyNextUnsignedIs( + 2, 0x2, "Failed to parse header. Frame marker should be 2.")); + + // Profile has low bit first. + READ_OR_RETURN(br.ReadBoolean(), + [frame_info](bool low) { frame_info->profile = int{low}; }); + READ_OR_RETURN(br.ReadBoolean(), [frame_info](bool high) { + frame_info->profile |= int{high} << 1; + }); + if (frame_info->profile > 2) { + RETURN_IF_FALSE(br.VerifyNextBooleanIs( + false, "Failed to get QP. Unsupported bitstream profile.")); } - // Profile. - uint8_t profile; - if (!Vp9ReadProfile(&br, &profile)) - return false; - frame_info->profile = profile; - // Show existing frame. - uint32_t show_existing_frame; - RETURN_FALSE_IF_ERROR(br.ReadBits(&show_existing_frame, 1)); - if (show_existing_frame) - return false; + RETURN_IF_FALSE(br.IfNextBoolean([&] { + READ_OR_RETURN(br.ReadUnsigned(3), + [frame_info](uint8_t frame_idx) { + frame_info->show_existing_frame = frame_idx; + }); + return true; + })); + if (frame_info->show_existing_frame.has_value()) { + return true; + } - // Frame type: KEY_FRAME(0), INTER_FRAME(1). - uint32_t frame_type; - uint32_t show_frame; - uint32_t error_resilient; - RETURN_FALSE_IF_ERROR(br.ReadBits(&frame_type, 1)); - RETURN_FALSE_IF_ERROR(br.ReadBits(&show_frame, 1)); - RETURN_FALSE_IF_ERROR(br.ReadBits(&error_resilient, 1)); - frame_info->show_frame = show_frame; - frame_info->error_resilient = error_resilient; - - if (frame_type == 0) { - // Key-frame. - if (!Vp9ReadSyncCode(&br)) - return false; - if (!Vp9ReadColorConfig(&br, profile, frame_info)) + READ_OR_RETURN(br.ReadBoolean(), [frame_info](bool frame_type) { + // Frame type: KEY_FRAME(0), INTER_FRAME(1). 
+ frame_info->is_keyframe = frame_type == 0; + }); + READ_OR_RETURN(br.ReadBoolean(), [frame_info](bool show_frame) { + frame_info->show_frame = show_frame; + }); + READ_OR_RETURN(br.ReadBoolean(), [frame_info](bool error_resilient) { + frame_info->error_resilient = error_resilient; + }); + + if (frame_info->is_keyframe) { + RETURN_IF_FALSE(br.VerifyNextUnsignedIs( + 24, 0x498342, "Failed to get QP. Invalid sync code.")); + + if (!Vp9ReadColorConfig(&br, frame_info)) return false; if (!Vp9ReadFrameSize(&br, frame_info)) return false; @@ -250,76 +425,92 @@ bool Parse(const uint8_t* buf, size_t length, int* qp, FrameInfo* frame_info) { return false; } else { // Non-keyframe. - uint32_t intra_only = 0; - if (!show_frame) - RETURN_FALSE_IF_ERROR(br.ReadBits(&intra_only, 1)); - if (!error_resilient) - RETURN_FALSE_IF_ERROR(br.ConsumeBits(2)); // Reset frame context. - - if (intra_only) { - if (!Vp9ReadSyncCode(&br)) - return false; + bool is_intra_only = false; + if (!frame_info->show_frame) { + READ_OR_RETURN(br.ReadBoolean(), + [&](bool intra_only) { is_intra_only = intra_only; }); + } + if (!frame_info->error_resilient) { + RETURN_IF_FALSE(br.ConsumeBits(2)); // Reset frame context. + } + + if (is_intra_only) { + RETURN_IF_FALSE(br.VerifyNextUnsignedIs( + 24, 0x498342, "Failed to get QP. Invalid sync code.")); - if (profile > 0) { - if (!Vp9ReadColorConfig(&br, profile, frame_info)) + if (frame_info->profile > 0) { + if (!Vp9ReadColorConfig(&br, frame_info)) return false; } // Refresh frame flags. - RETURN_FALSE_IF_ERROR(br.ConsumeBits(8)); + RETURN_IF_FALSE(br.ConsumeBits(8)); if (!Vp9ReadFrameSize(&br, frame_info)) return false; if (!Vp9ReadRenderSize(&br, frame_info)) return false; } else { // Refresh frame flags. - RETURN_FALSE_IF_ERROR(br.ConsumeBits(8)); + RETURN_IF_FALSE(br.ConsumeBits(8)); for (size_t i = 0; i < kVp9NumRefsPerFrame; i++) { // 3 bits: Ref frame index. // 1 bit: Ref frame sign biases. - RETURN_FALSE_IF_ERROR(br.ConsumeBits(4)); + RETURN_IF_FALSE(br.ConsumeBits(4)); } if (!Vp9ReadFrameSizeFromRefs(&br, frame_info)) return false; // Allow high precision mv. - RETURN_FALSE_IF_ERROR(br.ConsumeBits(1)); + RETURN_IF_FALSE(br.ConsumeBits(1)); // Interpolation filter. - if (!Vp9ReadInterpolationFilter(&br)) - return false; + RETURN_IF_FALSE(br.IfNextBoolean([] { return true; }, + [&br] { return br.ConsumeBits(2); })); } } - if (!error_resilient) { + if (!frame_info->error_resilient) { // 1 bit: Refresh frame context. // 1 bit: Frame parallel decoding mode. - RETURN_FALSE_IF_ERROR(br.ConsumeBits(2)); + RETURN_IF_FALSE(br.ConsumeBits(2)); } // Frame context index. - RETURN_FALSE_IF_ERROR(br.ConsumeBits(2)); + RETURN_IF_FALSE(br.ConsumeBits(2)); if (!Vp9ReadLoopfilter(&br)) return false; - // Base QP. - uint8_t base_q0; - RETURN_FALSE_IF_ERROR(br.ReadUInt8(&base_q0)); - *qp = base_q0; + // Read base QP. + RETURN_IF_FALSE(Vp9ReadQp(&br, frame_info)); + + const bool kParseFullHeader = false; + if (kParseFullHeader) { + // Currently not used, but will be needed when parsing beyond the + // uncompressed header. 
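+    // (These reads do not populate FrameInfo; they are consumed to keep the
+    // bit offset correct for syntax parsed after the uncompressed header.)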
+ RETURN_IF_FALSE(Vp9ReadSegmentationParams(&br)); + + RETURN_IF_FALSE(Vp9ReadTileInfo(&br, frame_info)); + + RETURN_IF_FALSE(br.ConsumeBits(16)); // header_size_in_bytes + } + return true; } bool GetQp(const uint8_t* buf, size_t length, int* qp) { FrameInfo frame_info; - return Parse(buf, length, qp, &frame_info); + if (!Parse(buf, length, &frame_info)) { + return false; + } + *qp = frame_info.base_qp; + return true; } absl::optional ParseIntraFrameInfo(const uint8_t* buf, size_t length) { - int qp = 0; FrameInfo frame_info; - if (Parse(buf, length, &qp, &frame_info) && frame_info.frame_width > 0) { + if (Parse(buf, length, &frame_info) && frame_info.frame_width > 0) { return frame_info; } return absl::nullopt; diff --git a/modules/video_coding/utility/vp9_uncompressed_header_parser.h b/modules/video_coding/utility/vp9_uncompressed_header_parser.h index a7f04670d2..7a5e2c058b 100644 --- a/modules/video_coding/utility/vp9_uncompressed_header_parser.h +++ b/modules/video_coding/utility/vp9_uncompressed_header_parser.h @@ -65,6 +65,8 @@ enum class YuvSubsampling { struct FrameInfo { int profile = 0; // Profile 0-3 are valid. + absl::optional show_existing_frame; + bool is_keyframe = false; bool show_frame = false; bool error_resilient = false; BitDept bit_detph = BitDept::k8Bit; @@ -75,6 +77,7 @@ struct FrameInfo { int frame_height = 0; int render_width = 0; int render_height = 0; + int base_qp = 0; }; // Parses frame information for a VP9 key-frame or all-intra frame from a diff --git a/modules/video_coding/utility/vp9_uncompressed_header_parser_unittest.cc b/modules/video_coding/utility/vp9_uncompressed_header_parser_unittest.cc new file mode 100644 index 0000000000..b69b45d5c4 --- /dev/null +++ b/modules/video_coding/utility/vp9_uncompressed_header_parser_unittest.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/video_coding/utility/vp9_uncompressed_header_parser.h" + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace vp9 { + +TEST(Vp9UncompressedHeaderParserTest, FrameWithSegmentation) { + // Uncompressed header from a frame generated with libvpx. + // Encoded QVGA frame (SL0 of a VGA frame) that includes a segmentation. + const uint8_t kHeader[] = { + 0x87, 0x01, 0x00, 0x00, 0x02, 0x7e, 0x01, 0xdf, 0x02, 0x7f, 0x01, 0xdf, + 0xc6, 0x87, 0x04, 0x83, 0x83, 0x2e, 0x46, 0x60, 0x20, 0x38, 0x0c, 0x06, + 0x03, 0xcd, 0x80, 0xc0, 0x60, 0x9f, 0xc5, 0x46, 0x00, 0x00, 0x00, 0x00, + 0x2e, 0x73, 0xb7, 0xee, 0x22, 0x06, 0x81, 0x82, 0xd4, 0xef, 0xc3, 0x58, + 0x1f, 0x12, 0xd2, 0x7b, 0x28, 0x1f, 0x80, 0xfc, 0x07, 0xe0, 0x00, 0x00}; + + absl::optional frame_info = + ParseIntraFrameInfo(kHeader, sizeof(kHeader)); + // Segmentation info is not actually populated in FrameInfo struct, but it + // needs to be parsed otherwise we end up on the wrong offset. The check for + // segmentation is thus that we have a valid return value. 
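+  // As an additional, purely illustrative cross-check (not asserted here),
+  // the GetQp() wrapper defined by the parser would report the same base QP:
+  //   int qp = 0;
+  //   EXPECT_TRUE(GetQp(kHeader, sizeof(kHeader), &qp));
+  //   EXPECT_EQ(qp, 185);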
+ ASSERT_TRUE(frame_info.has_value()); + + EXPECT_EQ(frame_info->is_keyframe, false); + EXPECT_EQ(frame_info->error_resilient, true); + EXPECT_EQ(frame_info->show_frame, true); + EXPECT_EQ(frame_info->base_qp, 185); + EXPECT_EQ(frame_info->frame_width, 320); + EXPECT_EQ(frame_info->frame_height, 240); + EXPECT_EQ(frame_info->render_width, 640); + EXPECT_EQ(frame_info->render_height, 480); +} + +} // namespace vp9 +} // namespace webrtc diff --git a/modules/video_coding/video_codec_initializer.cc b/modules/video_coding/video_codec_initializer.cc index 90a02e0c2d..17ea66acb1 100644 --- a/modules/video_coding/video_codec_initializer.cc +++ b/modules/video_coding/video_codec_initializer.cc @@ -262,7 +262,11 @@ VideoCodec VideoCodecInitializer::VideoEncoderConfigToVideoCodec( break; } case kVideoCodecAV1: - if (!SetAv1SvcConfig(video_codec)) { + if (SetAv1SvcConfig(video_codec)) { + for (size_t i = 0; i < config.spatial_layers.size(); ++i) { + video_codec.spatialLayers[i].active = config.spatial_layers[i].active; + } + } else { RTC_LOG(LS_WARNING) << "Failed to configure svc bitrates for av1."; } break; diff --git a/modules/video_coding/video_codec_initializer_unittest.cc b/modules/video_coding/video_codec_initializer_unittest.cc index 1ea145e14f..6c1c2e7a38 100644 --- a/modules/video_coding/video_codec_initializer_unittest.cc +++ b/modules/video_coding/video_codec_initializer_unittest.cc @@ -74,13 +74,13 @@ class VideoCodecInitializerTest : public ::testing::Test { config_.number_of_streams = num_spatial_streams; VideoCodecVP8 vp8_settings = VideoEncoder::GetDefaultVp8Settings(); vp8_settings.numberOfTemporalLayers = num_temporal_streams; - config_.encoder_specific_settings = new rtc::RefCountedObject< + config_.encoder_specific_settings = rtc::make_ref_counted< webrtc::VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8_settings); } else if (type == VideoCodecType::kVideoCodecVP9) { VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); vp9_settings.numberOfSpatialLayers = num_spatial_streams; vp9_settings.numberOfTemporalLayers = num_temporal_streams; - config_.encoder_specific_settings = new rtc::RefCountedObject< + config_.encoder_specific_settings = rtc::make_ref_counted< webrtc::VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); } else if (type != VideoCodecType::kVideoCodecMultiplex) { ADD_FAILURE() << "Unexpected codec type: " << type; @@ -426,4 +426,69 @@ TEST_F(VideoCodecInitializerTest, Vp9DeactivateLayers) { EXPECT_FALSE(codec_out_.spatialLayers[2].active); } +TEST_F(VideoCodecInitializerTest, Av1SingleSpatialLayerBitratesAreConsistent) { + VideoEncoderConfig config; + config.codec_type = VideoCodecType::kVideoCodecAV1; + std::vector streams = {DefaultStream()}; + streams[0].scalability_mode = "L1T2"; + + VideoCodec codec; + EXPECT_TRUE(VideoCodecInitializer::SetupCodec(config, streams, &codec)); + + EXPECT_GE(codec.spatialLayers[0].targetBitrate, + codec.spatialLayers[0].minBitrate); + EXPECT_LE(codec.spatialLayers[0].targetBitrate, + codec.spatialLayers[0].maxBitrate); +} + +TEST_F(VideoCodecInitializerTest, Av1TwoSpatialLayersBitratesAreConsistent) { + VideoEncoderConfig config; + config.codec_type = VideoCodecType::kVideoCodecAV1; + std::vector streams = {DefaultStream()}; + streams[0].scalability_mode = "L2T2"; + + VideoCodec codec; + EXPECT_TRUE(VideoCodecInitializer::SetupCodec(config, streams, &codec)); + + EXPECT_GE(codec.spatialLayers[0].targetBitrate, + codec.spatialLayers[0].minBitrate); + EXPECT_LE(codec.spatialLayers[0].targetBitrate, + 
codec.spatialLayers[0].maxBitrate); + + EXPECT_GE(codec.spatialLayers[1].targetBitrate, + codec.spatialLayers[1].minBitrate); + EXPECT_LE(codec.spatialLayers[1].targetBitrate, + codec.spatialLayers[1].maxBitrate); +} + +TEST_F(VideoCodecInitializerTest, Av1TwoSpatialLayersActiveByDefault) { + VideoEncoderConfig config; + config.codec_type = VideoCodecType::kVideoCodecAV1; + std::vector streams = {DefaultStream()}; + streams[0].scalability_mode = "L2T2"; + config.spatial_layers = {}; + + VideoCodec codec; + EXPECT_TRUE(VideoCodecInitializer::SetupCodec(config, streams, &codec)); + + EXPECT_TRUE(codec.spatialLayers[0].active); + EXPECT_TRUE(codec.spatialLayers[1].active); +} + +TEST_F(VideoCodecInitializerTest, Av1TwoSpatialLayersOneDeactivated) { + VideoEncoderConfig config; + config.codec_type = VideoCodecType::kVideoCodecAV1; + std::vector streams = {DefaultStream()}; + streams[0].scalability_mode = "L2T2"; + config.spatial_layers.resize(2); + config.spatial_layers[0].active = true; + config.spatial_layers[1].active = false; + + VideoCodec codec; + EXPECT_TRUE(VideoCodecInitializer::SetupCodec(config, streams, &codec)); + + EXPECT_TRUE(codec.spatialLayers[0].active); + EXPECT_FALSE(codec.spatialLayers[1].active); +} + } // namespace webrtc diff --git a/modules/video_coding/video_coding_impl.cc b/modules/video_coding/video_coding_impl.cc index 049695d753..f19ea51325 100644 --- a/modules/video_coding/video_coding_impl.cc +++ b/modules/video_coding/video_coding_impl.cc @@ -13,10 +13,10 @@ #include #include +#include "api/sequence_checker.h" #include "api/video/encoded_image.h" #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/timing.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -105,7 +105,7 @@ class VideoCodingModuleImpl : public VideoCodingModule { } private: - rtc::ThreadChecker construction_thread_; + SequenceChecker construction_thread_; const std::unique_ptr timing_; vcm::VideoReceiver receiver_; }; diff --git a/modules/video_coding/video_coding_impl.h b/modules/video_coding/video_coding_impl.h index aee6337e50..d74799460c 100644 --- a/modules/video_coding/video_coding_impl.h +++ b/modules/video_coding/video_coding_impl.h @@ -16,6 +16,7 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "modules/video_coding/decoder_database.h" #include "modules/video_coding/frame_buffer.h" #include "modules/video_coding/generic_decoder.h" @@ -25,9 +26,7 @@ #include "modules/video_coding/timing.h" #include "rtc_base/one_time_event.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -97,9 +96,9 @@ class VideoReceiver : public Module { // In builds where DCHECKs aren't enabled, it will return true. 
bool IsDecoderThreadRunning(); - rtc::ThreadChecker construction_thread_checker_; - rtc::ThreadChecker decoder_thread_checker_; - rtc::ThreadChecker module_thread_checker_; + SequenceChecker construction_thread_checker_; + SequenceChecker decoder_thread_checker_; + SequenceChecker module_thread_checker_; Clock* const clock_; Mutex process_mutex_; VCMTiming* _timing; diff --git a/modules/video_coding/video_receiver.cc b/modules/video_coding/video_receiver.cc index 23c251f59c..43dbc9f0b2 100644 --- a/modules/video_coding/video_receiver.cc +++ b/modules/video_coding/video_receiver.cc @@ -14,6 +14,7 @@ #include #include "api/rtp_headers.h" +#include "api/sequence_checker.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_decoder.h" #include "modules/utility/include/process_thread.h" @@ -33,7 +34,6 @@ #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/one_time_event.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/trace_event.h" #include "system_wrappers/include/clock.h" @@ -279,7 +279,7 @@ int32_t VideoReceiver::IncomingPacket(const uint8_t* incomingPayload, // Callers don't provide any ntp time. const VCMPacket packet(incomingPayload, payloadLength, rtp_header, video_header, /*ntp_time_ms=*/0, - clock_->TimeInMilliseconds()); + clock_->CurrentTime()); int32_t ret = _receiver.InsertPacket(packet); // TODO(holmer): Investigate if this somehow should use the key frame diff --git a/modules/video_coding/video_receiver2.cc b/modules/video_coding/video_receiver2.cc index 6b3cb63679..b893b954bc 100644 --- a/modules/video_coding/video_receiver2.cc +++ b/modules/video_coding/video_receiver2.cc @@ -33,18 +33,18 @@ VideoReceiver2::VideoReceiver2(Clock* clock, VCMTiming* timing) timing_(timing), decodedFrameCallback_(timing_, clock_), codecDataBase_() { - decoder_thread_checker_.Detach(); + decoder_sequence_checker_.Detach(); } VideoReceiver2::~VideoReceiver2() { - RTC_DCHECK_RUN_ON(&construction_thread_checker_); + RTC_DCHECK_RUN_ON(&construction_sequence_checker_); } // Register a receive callback. Will be called whenever there is a new frame // ready for rendering. int32_t VideoReceiver2::RegisterReceiveCallback( VCMReceiveCallback* receiveCallback) { - RTC_DCHECK_RUN_ON(&construction_thread_checker_); + RTC_DCHECK_RUN_ON(&construction_sequence_checker_); RTC_DCHECK(!IsDecoderThreadRunning()); // This value is set before the decoder thread starts and unset after // the decoder thread has been stopped. @@ -52,20 +52,35 @@ int32_t VideoReceiver2::RegisterReceiveCallback( return VCM_OK; } -// Register an externally defined decoder object. +// Register an externally defined decoder object. This may be called on either +// the construction sequence or the decoder sequence to allow for lazy creation +// of video decoders. If called on the decoder sequence |externalDecoder| cannot +// be a nullptr. It's the responsibility of the caller to make sure that the +// access from the two sequences are mutually exclusive. void VideoReceiver2::RegisterExternalDecoder(VideoDecoder* externalDecoder, uint8_t payloadType) { - RTC_DCHECK_RUN_ON(&construction_thread_checker_); - RTC_DCHECK(!IsDecoderThreadRunning()); + if (IsDecoderThreadRunning()) { + RTC_DCHECK_RUN_ON(&decoder_sequence_checker_); + // Don't allow deregistering decoders on the decoder thread. 
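+    // (Per the comment on this function, calls on the decoder sequence are
+    // only expected for lazy creation right before decoding, so a null
+    // decoder here would mean a deregistration racing with Decode().)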
+ RTC_DCHECK(externalDecoder != nullptr); + } else { + RTC_DCHECK_RUN_ON(&construction_sequence_checker_); + } + if (externalDecoder == nullptr) { - RTC_CHECK(codecDataBase_.DeregisterExternalDecoder(payloadType)); + codecDataBase_.DeregisterExternalDecoder(payloadType); return; } codecDataBase_.RegisterExternalDecoder(externalDecoder, payloadType); } +bool VideoReceiver2::IsExternalDecoderRegistered(uint8_t payloadType) const { + RTC_DCHECK_RUN_ON(&decoder_sequence_checker_); + return codecDataBase_.IsExternalDecoderRegistered(payloadType); +} + void VideoReceiver2::DecoderThreadStarting() { - RTC_DCHECK_RUN_ON(&construction_thread_checker_); + RTC_DCHECK_RUN_ON(&construction_sequence_checker_); RTC_DCHECK(!IsDecoderThreadRunning()); #if RTC_DCHECK_IS_ON decoder_thread_is_running_ = true; @@ -73,17 +88,17 @@ void VideoReceiver2::DecoderThreadStarting() { } void VideoReceiver2::DecoderThreadStopped() { - RTC_DCHECK_RUN_ON(&construction_thread_checker_); + RTC_DCHECK_RUN_ON(&construction_sequence_checker_); RTC_DCHECK(IsDecoderThreadRunning()); #if RTC_DCHECK_IS_ON decoder_thread_is_running_ = false; - decoder_thread_checker_.Detach(); + decoder_sequence_checker_.Detach(); #endif } // Must be called from inside the receive side critical section. int32_t VideoReceiver2::Decode(const VCMEncodedFrame* frame) { - RTC_DCHECK_RUN_ON(&decoder_thread_checker_); + RTC_DCHECK_RUN_ON(&decoder_sequence_checker_); TRACE_EVENT0("webrtc", "VideoReceiver2::Decode"); // Change decoder if payload type has changed VCMGenericDecoder* decoder = @@ -98,7 +113,7 @@ int32_t VideoReceiver2::Decode(const VCMEncodedFrame* frame) { int32_t VideoReceiver2::RegisterReceiveCodec(uint8_t payload_type, const VideoCodec* receiveCodec, int32_t numberOfCores) { - RTC_DCHECK_RUN_ON(&construction_thread_checker_); + RTC_DCHECK_RUN_ON(&construction_sequence_checker_); RTC_DCHECK(!IsDecoderThreadRunning()); if (receiveCodec == nullptr) { return VCM_PARAMETER_ERROR; diff --git a/modules/video_coding/video_receiver2.h b/modules/video_coding/video_receiver2.h index c7b7b80b6d..0c3fe1a257 100644 --- a/modules/video_coding/video_receiver2.h +++ b/modules/video_coding/video_receiver2.h @@ -11,11 +11,11 @@ #ifndef MODULES_VIDEO_CODING_VIDEO_RECEIVER2_H_ #define MODULES_VIDEO_CODING_VIDEO_RECEIVER2_H_ +#include "api/sequence_checker.h" #include "modules/video_coding/decoder_database.h" #include "modules/video_coding/encoded_frame.h" #include "modules/video_coding/generic_decoder.h" #include "modules/video_coding/timing.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -36,6 +36,7 @@ class VideoReceiver2 { void RegisterExternalDecoder(VideoDecoder* externalDecoder, uint8_t payloadType); + bool IsExternalDecoderRegistered(uint8_t payloadType) const; int32_t RegisterReceiveCallback(VCMReceiveCallback* receiveCallback); int32_t Decode(const webrtc::VCMEncodedFrame* frame); @@ -54,8 +55,8 @@ class VideoReceiver2 { // In builds where DCHECKs aren't enabled, it will return true. 
bool IsDecoderThreadRunning(); - rtc::ThreadChecker construction_thread_checker_; - rtc::ThreadChecker decoder_thread_checker_; + SequenceChecker construction_sequence_checker_; + SequenceChecker decoder_sequence_checker_; Clock* const clock_; VCMTiming* timing_; VCMDecodedFrameCallback decodedFrameCallback_; diff --git a/net/dcsctp/BUILD.gn b/net/dcsctp/BUILD.gn new file mode 100644 index 0000000000..8b38a65ca1 --- /dev/null +++ b/net/dcsctp/BUILD.gn @@ -0,0 +1,26 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../webrtc.gni") + +if (rtc_include_tests) { + rtc_test("dcsctp_unittests") { + testonly = true + deps = [ + "../../test:test_main", + "common:dcsctp_common_unittests", + "fuzzers:dcsctp_fuzzers_unittests", + "packet:dcsctp_packet_unittests", + "public:dcsctp_public_unittests", + "rx:dcsctp_rx_unittests", + "socket:dcsctp_socket_unittests", + "timer:dcsctp_timer_unittests", + "tx:dcsctp_tx_unittests", + ] + } +} diff --git a/net/dcsctp/OWNERS b/net/dcsctp/OWNERS new file mode 100644 index 0000000000..06a0f86179 --- /dev/null +++ b/net/dcsctp/OWNERS @@ -0,0 +1,2 @@ +boivie@webrtc.org +orphis@webrtc.org diff --git a/net/dcsctp/common/BUILD.gn b/net/dcsctp/common/BUILD.gn new file mode 100644 index 0000000000..6e99cdcef4 --- /dev/null +++ b/net/dcsctp/common/BUILD.gn @@ -0,0 +1,63 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("internal_types") { + deps = [ + "../public:strong_alias", + "../public:types", + ] + sources = [ "internal_types.h" ] +} + +rtc_source_set("math") { + deps = [] + sources = [ "math.h" ] +} + +rtc_source_set("pair_hash") { + deps = [] + sources = [ "pair_hash.h" ] +} + +rtc_source_set("sequence_numbers") { + deps = [ ":internal_types" ] + sources = [ "sequence_numbers.h" ] +} + +rtc_source_set("str_join") { + deps = [ "../../../rtc_base:stringutils" ] + sources = [ "str_join.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_common_unittests") { + testonly = true + + defines = [] + deps = [ + ":math", + ":pair_hash", + ":sequence_numbers", + ":str_join", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + ] + sources = [ + "math_test.cc", + "pair_hash_test.cc", + "sequence_numbers_test.cc", + "str_join_test.cc", + ] + } +} diff --git a/net/dcsctp/common/internal_types.h b/net/dcsctp/common/internal_types.h new file mode 100644 index 0000000000..b651d45d91 --- /dev/null +++ b/net/dcsctp/common/internal_types.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_INTERNAL_TYPES_H_ +#define NET_DCSCTP_COMMON_INTERNAL_TYPES_H_ + +#include + +#include "net/dcsctp/public/strong_alias.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// Stream Sequence Number (SSN) +using SSN = StrongAlias; + +// Message Identifier (MID) +using MID = StrongAlias; + +// Fragment Sequence Number (FSN) +using FSN = StrongAlias; + +// Transmission Sequence Number (TSN) +using TSN = StrongAlias; + +// Reconfiguration Request Sequence Number +using ReconfigRequestSN = StrongAlias; + +// Verification Tag, used for packet validation. +using VerificationTag = StrongAlias; + +// Tie Tag, used as a nonce when connecting. +using TieTag = StrongAlias; + +// Hasher for separated ordered/unordered stream identifiers. +struct UnorderedStreamHash { + size_t operator()(const std::pair& p) const { + return std::hash{}(*p.first) ^ + (std::hash{}(*p.second) << 1); + } +}; + +} // namespace dcsctp +#endif // NET_DCSCTP_COMMON_INTERNAL_TYPES_H_ diff --git a/net/dcsctp/common/math.h b/net/dcsctp/common/math.h new file mode 100644 index 0000000000..12f690ed57 --- /dev/null +++ b/net/dcsctp/common/math.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_MATH_H_ +#define NET_DCSCTP_COMMON_MATH_H_ + +namespace dcsctp { + +// Rounds up `val` to the nearest value that is divisible by four. Frequently +// used to e.g. pad chunks or parameters to an even 32-bit offset. +template +IntType RoundUpTo4(IntType val) { + return (val + 3) & ~3; +} + +// Similarly, rounds down `val` to the nearest value that is divisible by four. +template +IntType RoundDownTo4(IntType val) { + return val & ~3; +} + +// Returns true if `val` is divisible by four. +template +bool IsDivisibleBy4(IntType val) { + return (val & 3) == 0; +} + +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_MATH_H_ diff --git a/net/dcsctp/common/math_test.cc b/net/dcsctp/common/math_test.cc new file mode 100644 index 0000000000..f95dfbdb55 --- /dev/null +++ b/net/dcsctp/common/math_test.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/common/math.h" + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(MathUtilTest, CanRoundUpTo4) { + // Signed numbers + EXPECT_EQ(RoundUpTo4(static_cast(-5)), -4); + EXPECT_EQ(RoundUpTo4(static_cast(-4)), -4); + EXPECT_EQ(RoundUpTo4(static_cast(-3)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(-2)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(-1)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(0)), 0); + EXPECT_EQ(RoundUpTo4(static_cast(1)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(2)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(3)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(4)), 4); + EXPECT_EQ(RoundUpTo4(static_cast(5)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(6)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(7)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(8)), 8); + EXPECT_EQ(RoundUpTo4(static_cast(10000000000)), 10000000000); + EXPECT_EQ(RoundUpTo4(static_cast(10000000001)), 10000000004); + + // Unsigned numbers + EXPECT_EQ(RoundUpTo4(static_cast(0)), 0u); + EXPECT_EQ(RoundUpTo4(static_cast(1)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(2)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(3)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(4)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast(5)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(6)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(7)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(8)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast(10000000000)), 10000000000u); + EXPECT_EQ(RoundUpTo4(static_cast(10000000001)), 10000000004u); +} + +TEST(MathUtilTest, CanRoundDownTo4) { + // Signed numbers + EXPECT_EQ(RoundDownTo4(static_cast(-5)), -8); + EXPECT_EQ(RoundDownTo4(static_cast(-4)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-3)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-2)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(-1)), -4); + EXPECT_EQ(RoundDownTo4(static_cast(0)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(1)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(2)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(3)), 0); + EXPECT_EQ(RoundDownTo4(static_cast(4)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(5)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(6)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(7)), 4); + EXPECT_EQ(RoundDownTo4(static_cast(8)), 8); + EXPECT_EQ(RoundDownTo4(static_cast(10000000000)), 10000000000); + EXPECT_EQ(RoundDownTo4(static_cast(10000000001)), 10000000000); + + // Unsigned numbers + EXPECT_EQ(RoundDownTo4(static_cast(0)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(1)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(2)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(3)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast(4)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(5)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(6)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(7)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast(8)), 8u); + EXPECT_EQ(RoundDownTo4(static_cast(10000000000)), 10000000000u); + EXPECT_EQ(RoundDownTo4(static_cast(10000000001)), 10000000000u); +} + +TEST(MathUtilTest, IsDivisibleBy4) { + // Signed numbers + EXPECT_EQ(IsDivisibleBy4(static_cast(-4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(-3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(-2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(-1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(6)), false); + 
EXPECT_EQ(IsDivisibleBy4(static_cast(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000001)), false); + + // Unsigned numbers + EXPECT_EQ(IsDivisibleBy4(static_cast(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast(10000000001)), false); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/common/pair_hash.h b/net/dcsctp/common/pair_hash.h new file mode 100644 index 0000000000..62af8b4221 --- /dev/null +++ b/net/dcsctp/common/pair_hash.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_PAIR_HASH_H_ +#define NET_DCSCTP_COMMON_PAIR_HASH_H_ + +#include + +#include +#include + +namespace dcsctp { + +// A custom hash function for std::pair, to be able to be used as key in a +// std::unordered_map. If absl::flat_hash_map would ever be used, this is +// unnecessary as it already has a hash function for std::pair. +struct PairHash { + template + size_t operator()(const std::pair& p) const { + return (3 * std::hash{}(p.first)) ^ std::hash{}(p.second); + } +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_PAIR_HASH_H_ diff --git a/net/dcsctp/common/pair_hash_test.cc b/net/dcsctp/common/pair_hash_test.cc new file mode 100644 index 0000000000..bcc3ec86c0 --- /dev/null +++ b/net/dcsctp/common/pair_hash_test.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/common/pair_hash.h" + +#include +#include + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(PairHashTest, CanInsertIntoSet) { + using MyPair = std::pair; + + std::unordered_set pairs; + + pairs.insert({1, 2}); + pairs.insert({3, 4}); + + EXPECT_NE(pairs.find({1, 2}), pairs.end()); + EXPECT_NE(pairs.find({3, 4}), pairs.end()); + EXPECT_EQ(pairs.find({1, 3}), pairs.end()); + EXPECT_EQ(pairs.find({3, 3}), pairs.end()); +} + +TEST(PairHashTest, CanInsertIntoMap) { + using MyPair = std::pair; + + std::unordered_map pairs; + + pairs[{1, 2}] = 99; + pairs[{3, 4}] = 100; + + EXPECT_EQ((pairs[{1, 2}]), 99); + EXPECT_EQ((pairs[{3, 4}]), 100); + EXPECT_EQ(pairs.find({1, 3}), pairs.end()); + EXPECT_EQ(pairs.find({3, 3}), pairs.end()); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/common/sequence_numbers.h b/net/dcsctp/common/sequence_numbers.h new file mode 100644 index 0000000000..52b638b54a --- /dev/null +++ b/net/dcsctp/common/sequence_numbers.h @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_SEQUENCE_NUMBERS_H_ +#define NET_DCSCTP_COMMON_SEQUENCE_NUMBERS_H_ + +#include +#include +#include + +#include "net/dcsctp/common/internal_types.h" + +namespace dcsctp { + +// UnwrappedSequenceNumber handles wrapping sequence numbers and unwraps them to +// an int64_t value space, to allow wrapped sequence numbers to be easily +// compared for ordering. +// +// Sequence numbers are expected to be monotonically increasing, but they do not +// need to be unwrapped in order, as long as the difference to the previous one +// is not larger than half the range of the wrapped sequence number. +// +// The WrappedType must be a StrongAlias type. +template +class UnwrappedSequenceNumber { + public: + static_assert( + !std::numeric_limits::is_signed, + "The wrapped type must be unsigned"); + static_assert( + std::numeric_limits::max() < + std::numeric_limits::max(), + "The wrapped type must be less than the int64_t value space"); + + // The unwrapper is a sort of factory and converts wrapped sequence numbers to + // unwrapped ones. + class Unwrapper { + public: + Unwrapper() : largest_(kValueLimit) {} + Unwrapper(const Unwrapper&) = default; + Unwrapper& operator=(const Unwrapper&) = default; + + // Given a wrapped `value`, and with knowledge of its current last seen + // largest number, will return a value that can be compared using normal + // operators, such as less-than, greater-than etc. + // + // This will also update the Unwrapper's state, to track the last seen + // largest value. + UnwrappedSequenceNumber Unwrap(WrappedType value) { + WrappedType wrapped_largest = + static_cast(largest_ % kValueLimit); + int64_t result = largest_ + Delta(value, wrapped_largest); + if (largest_ < result) { + largest_ = result; + } + return UnwrappedSequenceNumber(result); + } + + // Similar to `Unwrap`, but will not update the Unwrappers's internal state. 
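+    // Example (illustrative): with a 16-bit wrapped type, after
+    // Unwrap(0xFFFF) a subsequent PeekUnwrap(0x0000) compares greater than
+    // the value returned for 0xFFFF, while leaving the tracked largest value
+    // untouched; only Unwrap() ever advances it.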
+ UnwrappedSequenceNumber PeekUnwrap(WrappedType value) const { + WrappedType uint32_largest = + static_cast(largest_ % kValueLimit); + int64_t result = largest_ + Delta(value, uint32_largest); + return UnwrappedSequenceNumber(result); + } + + // Resets the Unwrapper to its pristine state. Used when a sequence number + // is to be reset to zero. + void Reset() { largest_ = kValueLimit; } + + private: + static int64_t Delta(WrappedType value, WrappedType prev_value) { + static constexpr typename WrappedType::UnderlyingType kBreakpoint = + kValueLimit / 2; + typename WrappedType::UnderlyingType diff = *value - *prev_value; + diff %= kValueLimit; + if (diff < kBreakpoint) { + return static_cast(diff); + } + return static_cast(diff) - kValueLimit; + } + + int64_t largest_; + }; + + // Returns the wrapped value this type represents. + WrappedType Wrap() const { + return static_cast(value_ % kValueLimit); + } + + template + friend H AbslHashValue(H state, + const UnwrappedSequenceNumber& hash) { + return H::combine(std::move(state), hash.value_); + } + + bool operator==(const UnwrappedSequenceNumber& other) const { + return value_ == other.value_; + } + bool operator!=(const UnwrappedSequenceNumber& other) const { + return value_ != other.value_; + } + bool operator<(const UnwrappedSequenceNumber& other) const { + return value_ < other.value_; + } + bool operator>(const UnwrappedSequenceNumber& other) const { + return value_ > other.value_; + } + bool operator>=(const UnwrappedSequenceNumber& other) const { + return value_ >= other.value_; + } + bool operator<=(const UnwrappedSequenceNumber& other) const { + return value_ <= other.value_; + } + + // Increments the value. + void Increment() { ++value_; } + + // Returns the next value relative to this sequence number. + UnwrappedSequenceNumber next_value() const { + return UnwrappedSequenceNumber(value_ + 1); + } + + // Returns a new sequence number based on `value`, and adding `delta` (which + // may be negative). + static UnwrappedSequenceNumber AddTo( + UnwrappedSequenceNumber value, + int delta) { + return UnwrappedSequenceNumber(value.value_ + delta); + } + + // Returns the absolute difference between `lhs` and `rhs`. + static typename WrappedType::UnderlyingType Difference( + UnwrappedSequenceNumber lhs, + UnwrappedSequenceNumber rhs) { + return (lhs.value_ > rhs.value_) ? (lhs.value_ - rhs.value_) + : (rhs.value_ - lhs.value_); + } + + private: + explicit UnwrappedSequenceNumber(int64_t value) : value_(value) {} + static constexpr int64_t kValueLimit = + static_cast(1) + << std::numeric_limits::digits; + + int64_t value_; +}; + +// Unwrapped Transmission Sequence Numbers (TSN) +using UnwrappedTSN = UnwrappedSequenceNumber; + +// Unwrapped Stream Sequence Numbers (SSN) +using UnwrappedSSN = UnwrappedSequenceNumber; + +// Unwrapped Message Identifier (MID) +using UnwrappedMID = UnwrappedSequenceNumber; + +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_SEQUENCE_NUMBERS_H_ diff --git a/net/dcsctp/common/sequence_numbers_test.cc b/net/dcsctp/common/sequence_numbers_test.cc new file mode 100644 index 0000000000..f5fa788876 --- /dev/null +++ b/net/dcsctp/common/sequence_numbers_test.cc @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/common/sequence_numbers.h" + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +using Wrapped = StrongAlias; +using TestSequence = UnwrappedSequenceNumber; + +TEST(SequenceNumbersTest, SimpleUnwrapping) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0)); + TestSequence s1 = unwrapper.Unwrap(Wrapped(1)); + TestSequence s2 = unwrapper.Unwrap(Wrapped(2)); + TestSequence s3 = unwrapper.Unwrap(Wrapped(3)); + + EXPECT_LT(s0, s1); + EXPECT_LT(s0, s2); + EXPECT_LT(s0, s3); + EXPECT_LT(s1, s2); + EXPECT_LT(s1, s3); + EXPECT_LT(s2, s3); + + EXPECT_EQ(TestSequence::Difference(s1, s0), 1); + EXPECT_EQ(TestSequence::Difference(s2, s0), 2); + EXPECT_EQ(TestSequence::Difference(s3, s0), 3); + + EXPECT_GT(s1, s0); + EXPECT_GT(s2, s0); + EXPECT_GT(s3, s0); + EXPECT_GT(s2, s1); + EXPECT_GT(s3, s1); + EXPECT_GT(s3, s2); + + s0.Increment(); + EXPECT_EQ(s0, s1); + s1.Increment(); + EXPECT_EQ(s1, s2); + s2.Increment(); + EXPECT_EQ(s2, s3); + + EXPECT_EQ(TestSequence::AddTo(s0, 2), s3); +} + +TEST(SequenceNumbersTest, MidValueUnwrapping) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0x7FFE)); + TestSequence s1 = unwrapper.Unwrap(Wrapped(0x7FFF)); + TestSequence s2 = unwrapper.Unwrap(Wrapped(0x8000)); + TestSequence s3 = unwrapper.Unwrap(Wrapped(0x8001)); + + EXPECT_LT(s0, s1); + EXPECT_LT(s0, s2); + EXPECT_LT(s0, s3); + EXPECT_LT(s1, s2); + EXPECT_LT(s1, s3); + EXPECT_LT(s2, s3); + + EXPECT_EQ(TestSequence::Difference(s1, s0), 1); + EXPECT_EQ(TestSequence::Difference(s2, s0), 2); + EXPECT_EQ(TestSequence::Difference(s3, s0), 3); + + EXPECT_GT(s1, s0); + EXPECT_GT(s2, s0); + EXPECT_GT(s3, s0); + EXPECT_GT(s2, s1); + EXPECT_GT(s3, s1); + EXPECT_GT(s3, s2); + + s0.Increment(); + EXPECT_EQ(s0, s1); + s1.Increment(); + EXPECT_EQ(s1, s2); + s2.Increment(); + EXPECT_EQ(s2, s3); + + EXPECT_EQ(TestSequence::AddTo(s0, 2), s3); +} + +TEST(SequenceNumbersTest, WrappedUnwrapping) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0xFFFE)); + TestSequence s1 = unwrapper.Unwrap(Wrapped(0xFFFF)); + TestSequence s2 = unwrapper.Unwrap(Wrapped(0x0000)); + TestSequence s3 = unwrapper.Unwrap(Wrapped(0x0001)); + + EXPECT_LT(s0, s1); + EXPECT_LT(s0, s2); + EXPECT_LT(s0, s3); + EXPECT_LT(s1, s2); + EXPECT_LT(s1, s3); + EXPECT_LT(s2, s3); + + EXPECT_EQ(TestSequence::Difference(s1, s0), 1); + EXPECT_EQ(TestSequence::Difference(s2, s0), 2); + EXPECT_EQ(TestSequence::Difference(s3, s0), 3); + + EXPECT_GT(s1, s0); + EXPECT_GT(s2, s0); + EXPECT_GT(s3, s0); + EXPECT_GT(s2, s1); + EXPECT_GT(s3, s1); + EXPECT_GT(s3, s2); + + s0.Increment(); + EXPECT_EQ(s0, s1); + s1.Increment(); + EXPECT_EQ(s1, s2); + s2.Increment(); + EXPECT_EQ(s2, s3); + + EXPECT_EQ(TestSequence::AddTo(s0, 2), s3); +} + +TEST(SequenceNumbersTest, WrapAroundAFewTimes) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0)); + TestSequence prev = s0; + + for (uint32_t i = 1; i < 65536 * 3; i++) { + uint16_t wrapped = static_cast(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + EXPECT_LT(s0, si); + EXPECT_LT(prev, si); + prev = si; + } +} + +TEST(SequenceNumbersTest, IncrementIsSameAsWrapped) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0)); + + for (uint32_t i = 1; i < 65536 * 2; i++) { + uint16_t wrapped = static_cast(i); + TestSequence 
si = unwrapper.Unwrap(Wrapped(wrapped)); + + s0.Increment(); + EXPECT_EQ(s0, si); + } +} + +TEST(SequenceNumbersTest, UnwrappingLargerNumberIsAlwaysLarger) { + TestSequence::Unwrapper unwrapper; + + for (uint32_t i = 1; i < 65536 * 2; i++) { + uint16_t wrapped = static_cast(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 1)), si); + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 5)), si); + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 10)), si); + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 100)), si); + } +} + +TEST(SequenceNumbersTest, UnwrappingSmallerNumberIsAlwaysSmaller) { + TestSequence::Unwrapper unwrapper; + + for (uint32_t i = 1; i < 65536 * 2; i++) { + uint16_t wrapped = static_cast(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 1)), si); + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 5)), si); + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 10)), si); + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 100)), si); + } +} + +TEST(SequenceNumbersTest, DifferenceIsAbsolute) { + TestSequence::Unwrapper unwrapper; + + TestSequence this_value = unwrapper.Unwrap(Wrapped(10)); + TestSequence other_value = TestSequence::AddTo(this_value, 100); + + EXPECT_EQ(TestSequence::Difference(this_value, other_value), 100); + EXPECT_EQ(TestSequence::Difference(other_value, this_value), 100); + + TestSequence minus_value = TestSequence::AddTo(this_value, -100); + + EXPECT_EQ(TestSequence::Difference(this_value, minus_value), 100); + EXPECT_EQ(TestSequence::Difference(minus_value, this_value), 100); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/common/str_join.h b/net/dcsctp/common/str_join.h new file mode 100644 index 0000000000..04517827b7 --- /dev/null +++ b/net/dcsctp/common/str_join.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_STR_JOIN_H_ +#define NET_DCSCTP_COMMON_STR_JOIN_H_ + +#include + +#include "absl/strings/string_view.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +template +std::string StrJoin(const Range& seq, absl::string_view delimiter) { + rtc::StringBuilder sb; + int idx = 0; + + for (const typename Range::value_type& elem : seq) { + if (idx > 0) { + sb << delimiter; + } + sb << elem; + + ++idx; + } + return sb.Release(); +} + +template +std::string StrJoin(const Range& seq, + absl::string_view delimiter, + const Functor& fn) { + rtc::StringBuilder sb; + int idx = 0; + + for (const typename Range::value_type& elem : seq) { + if (idx > 0) { + sb << delimiter; + } + fn(sb, elem); + + ++idx; + } + return sb.Release(); +} + +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_STR_JOIN_H_ diff --git a/net/dcsctp/common/str_join_test.cc b/net/dcsctp/common/str_join_test.cc new file mode 100644 index 0000000000..dbfd92c1cf --- /dev/null +++ b/net/dcsctp/common/str_join_test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. 
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/common/str_join.h" + +#include +#include +#include + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(StrJoinTest, CanJoinStringsFromVector) { + std::vector strings = {"Hello", "World"}; + std::string s = StrJoin(strings, " "); + EXPECT_EQ(s, "Hello World"); +} + +TEST(StrJoinTest, CanJoinNumbersFromArray) { + std::array numbers = {1, 2, 3}; + std::string s = StrJoin(numbers, ","); + EXPECT_EQ(s, "1,2,3"); +} + +TEST(StrJoinTest, CanFormatElementsWhileJoining) { + std::vector> pairs = { + {"hello", "world"}, {"foo", "bar"}, {"fum", "gazonk"}}; + std::string s = StrJoin(pairs, ",", + [&](rtc::StringBuilder& sb, + const std::pair& p) { + sb << p.first << "=" << p.second; + }); + EXPECT_EQ(s, "hello=world,foo=bar,fum=gazonk"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/fuzzers/BUILD.gn b/net/dcsctp/fuzzers/BUILD.gn new file mode 100644 index 0000000000..9edbae44d7 --- /dev/null +++ b/net/dcsctp/fuzzers/BUILD.gn @@ -0,0 +1,50 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("dcsctp_fuzzers") { + testonly = true + deps = [ + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:math", + "../packet:chunk", + "../packet:error_cause", + "../packet:parameter", + "../public:socket", + "../public:types", + "../socket:dcsctp_socket", + ] + sources = [ + "dcsctp_fuzzers.cc", + "dcsctp_fuzzers.h", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_fuzzers_unittests") { + testonly = true + + deps = [ + ":dcsctp_fuzzers", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../packet:sctp_packet", + "../public:socket", + "../socket:dcsctp_socket", + "../testing:testing_macros", + ] + sources = [ "dcsctp_fuzzers_test.cc" ] + } +} diff --git a/net/dcsctp/fuzzers/dcsctp_fuzzers.cc b/net/dcsctp/fuzzers/dcsctp_fuzzers.cc new file mode 100644 index 0000000000..b4b6224ec4 --- /dev/null +++ b/net/dcsctp/fuzzers/dcsctp_fuzzers.cc @@ -0,0 +1,460 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
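A usage sketch for str_join.h above (hedged: the DumpStreams helper and its StreamID element type are assumptions for illustration, not code from this change). The second overload hands the shared StringBuilder to the functor, so strongly-typed wrappers can be formatted without building temporary strings:

std::string DumpStreams(const std::vector<StreamID>& streams) {
  return StrJoin(streams, ",", [](rtc::StringBuilder& sb, StreamID stream_id) {
    sb << stream_id.value();
  });
}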
+ */ +#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h" + +#include +#include +#include + +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace dcsctp_fuzzers { +namespace { +static constexpr int kRandomValue = FuzzerCallbacks::kRandomValue; +static constexpr size_t kMinInputLength = 5; +static constexpr size_t kMaxInputLength = 1024; + +// A starting state for the socket, when fuzzing. +enum class StartingState : int { + kConnectNotCalled, + // When socket initiating Connect + kConnectCalled, + kReceivedInitAck, + kReceivedCookieAck, + // When socket initiating Shutdown + kShutdownCalled, + kReceivedShutdownAck, + // When peer socket initiated Connect + kReceivedInit, + kReceivedCookieEcho, + // When peer initiated Shutdown + kReceivedShutdown, + kReceivedShutdownComplete, + kNumberOfStates, +}; + +// State about the current fuzzing iteration +class FuzzState { + public: + explicit FuzzState(rtc::ArrayView data) : data_(data) {} + + uint8_t GetByte() { + uint8_t value = 0; + if (offset_ < data_.size()) { + value = data_[offset_]; + ++offset_; + } + return value; + } + + TSN GetNextTSN() { return TSN(tsn_++); } + MID GetNextMID() { return MID(mid_++); } + + bool empty() const { return offset_ >= data_.size(); } + + private: + uint32_t tsn_ = kRandomValue; + uint32_t mid_ = 0; + rtc::ArrayView data_; + size_t offset_ = 0; +}; + +void SetSocketState(DcSctpSocketInterface& socket, + FuzzerCallbacks& socket_cb, + StartingState state) { + // We'll use another temporary peer socket for the establishment. 
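+  // The packet exchange below replays the SCTP association setup (INIT,
+  // INIT_ACK, COOKIE_ECHO, COOKIE_ACK) and, for the shutdown-related states,
+  // the shutdown sequence, so the fuzzed socket starts from a well-defined
+  // point in the handshake rather than from arbitrary internal state.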
+ FuzzerCallbacks peer_cb; + DcSctpSocket peer("peer", peer_cb, nullptr, {}); + + switch (state) { + case StartingState::kConnectNotCalled: + return; + case StartingState::kConnectCalled: + socket.Connect(); + return; + case StartingState::kReceivedInitAck: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + return; + case StartingState::kReceivedCookieAck: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + return; + case StartingState::kShutdownCalled: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + socket.Shutdown(); + return; + case StartingState::kReceivedShutdownAck: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + socket.Shutdown(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_ACK + return; + case StartingState::kReceivedInit: + peer.Connect(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT + return; + case StartingState::kReceivedCookieEcho: + peer.Connect(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT_ACK + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ECHO + return; + case StartingState::kReceivedShutdown: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + peer.Shutdown(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN + return; + case StartingState::kReceivedShutdownComplete: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + peer.Shutdown(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN_ACK + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_COMPLETE + return; + case StartingState::kNumberOfStates: + RTC_CHECK(false); + return; + } +} + +void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) { + DataChunk::Options options; + options.is_unordered = IsUnordered(state.GetByte() != 0); + options.is_beginning = Data::IsBeginning(state.GetByte() != 0); + options.is_end = Data::IsEnd(state.GetByte() != 0); + b.Add(DataChunk(state.GetNextTSN(), StreamID(state.GetByte()), + SSN(state.GetByte()), PPID(53), std::vector(10), + options)); +} + +void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) { + Parameters::Builder builder; + 
builder.Add(ForwardTsnSupportedParameter()); + + b.Add(InitChunk(VerificationTag(kRandomValue), 10000, 1000, 1000, + TSN(kRandomValue), builder.Build())); +} + +void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) { + Parameters::Builder builder; + builder.Add(ForwardTsnSupportedParameter()); + + uint8_t state_cookie[] = {1, 2, 3, 4, 5}; + Parameters::Builder params_builder = + Parameters::Builder().Add(StateCookieParameter(state_cookie)); + + b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000, + TSN(kRandomValue), builder.Build())); +} + +void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector gap_ack_blocks; + uint16_t last_end = 0; + while (gap_ack_blocks.size() < 20) { + uint8_t delta_start = state.GetByte(); + if (delta_start < 0x80) { + break; + } + uint8_t delta_end = state.GetByte(); + + uint16_t start = last_end + delta_start; + uint16_t end = start + delta_end; + last_end = end; + gap_ack_blocks.emplace_back(start, end); + } + + TSN cum_ack_tsn(kRandomValue + state.GetByte()); + b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gap_ack_blocks), {})); +} + +void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) { + uint8_t info[] = {1, 2, 3, 4, 5}; + b.Add(HeartbeatRequestChunk( + Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build())); +} + +void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector info(8); + b.Add(HeartbeatRequestChunk( + Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build())); +} + +void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(AbortChunk( + /*filled_in_verification_tag=*/true, + Parameters::Builder().Add(UserInitiatedAbortCause("Fuzzing")).Build())); +} + +void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ErrorChunk( + Parameters::Builder().Add(ProtocolViolationCause("Fuzzing")).Build())); +} + +void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector cookie(StateCookie::kCookieSize); + b.Add(CookieEchoChunk(cookie)); +} + +void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(CookieAckChunk()); +} + +void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ShutdownChunk(state.GetNextTSN())); +} + +void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ShutdownAckChunk()); +} + +void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ShutdownCompleteChunk(false)); +} + +void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector streams = {StreamID(state.GetByte())}; + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue), + state.GetNextTSN(), streams)); + b.Add(ReConfigChunk(params_builder.Build())); +} + +void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector skipped_streams; + for (;;) { + uint8_t stream = state.GetByte(); + if (skipped_streams.size() > 20 || stream < 0x80) { + break; + } + skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte())); + } + b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams))); +} + +void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) { + DataChunk::Options options; + options.is_unordered = IsUnordered(state.GetByte() != 0); + options.is_beginning = Data::IsBeginning(state.GetByte() != 0); + options.is_end = Data::IsEnd(state.GetByte() != 0); + 
b.Add(IDataChunk(state.GetNextTSN(), StreamID(state.GetByte()), + state.GetNextMID(), PPID(53), FSN(0), + std::vector(10), options)); +} + +void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector skipped_streams; + for (;;) { + uint8_t stream = state.GetByte(); + if (skipped_streams.size() > 20 || stream < 0x80) { + break; + } + skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte())); + } + b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams))); +} + +class RandomFuzzedChunk : public Chunk { + public: + explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {} + + void SerializeTo(std::vector& out) const override { + size_t bytes = state_.GetByte(); + for (size_t i = 0; i < bytes; ++i) { + out.push_back(state_.GetByte()); + } + } + + std::string ToString() const override { return std::string("RANDOM_FUZZED"); } + + private: + FuzzState& state_; +}; + +void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) { + b.Add(RandomFuzzedChunk(state)); +} + +std::vector GeneratePacket(FuzzState& state) { + DcSctpOptions options; + // Setting a fixed limit to not be dependent on the defaults, which may + // change. + options.mtu = 2048; + SctpPacket::Builder builder(VerificationTag(kRandomValue), options); + + // The largest expected serialized chunk, as created by fuzzers. + static constexpr size_t kMaxChunkSize = 256; + + for (int i = 0; i < 5 && builder.bytes_remaining() > kMaxChunkSize; ++i) { + switch (state.GetByte()) { + case 1: + MakeDataChunk(state, builder); + break; + case 2: + MakeInitChunk(state, builder); + break; + case 3: + MakeInitAckChunk(state, builder); + break; + case 4: + MakeSackChunk(state, builder); + break; + case 5: + MakeHeartbeatRequestChunk(state, builder); + break; + case 6: + MakeHeartbeatAckChunk(state, builder); + break; + case 7: + MakeAbortChunk(state, builder); + break; + case 8: + MakeErrorChunk(state, builder); + break; + case 9: + MakeCookieEchoChunk(state, builder); + break; + case 10: + MakeCookieAckChunk(state, builder); + break; + case 11: + MakeShutdownChunk(state, builder); + break; + case 12: + MakeShutdownAckChunk(state, builder); + break; + case 13: + MakeShutdownCompleteChunk(state, builder); + break; + case 14: + MakeReConfigChunk(state, builder); + break; + case 15: + MakeForwardTsnChunk(state, builder); + break; + case 16: + MakeIDataChunk(state, builder); + break; + case 17: + MakeIForwardTsnChunk(state, builder); + break; + case 18: + MakeChunkWithRandomContent(state, builder); + break; + default: + break; + } + } + std::vector packet = builder.Build(); + return packet; +} +} // namespace + +void FuzzSocket(DcSctpSocketInterface& socket, + FuzzerCallbacks& cb, + rtc::ArrayView data) { + if (data.size() < kMinInputLength || data.size() > kMaxInputLength) { + return; + } + if (data[0] >= static_cast(StartingState::kNumberOfStates)) { + return; + } + + // Set the socket in a specified valid starting state + SetSocketState(socket, cb, static_cast(data[0])); + + FuzzState state(data.subview(1)); + + while (!state.empty()) { + switch (state.GetByte()) { + case 1: + // Generate a valid SCTP packet (based on fuzz data) and "receive it". 
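+        // The remaining command bytes map to socket actions: 1 receives a
+        // generated packet, 2-4 call Connect/Shutdown/Close, 5 resets a
+        // stream, 6 sends a message whose flags and payload size are derived
+        // from the following bytes, and 7 expires an active timeout; any
+        // other byte is a no-op, so every input stays valid.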
+ socket.ReceivePacket(GeneratePacket(state)); + break; + case 2: + socket.Connect(); + break; + case 3: + socket.Shutdown(); + break; + case 4: + socket.Close(); + break; + case 5: { + StreamID streams[] = {StreamID(state.GetByte())}; + socket.ResetStreams(streams); + } break; + case 6: { + uint8_t flags = state.GetByte(); + SendOptions options; + options.unordered = IsUnordered(flags & 0x01); + options.max_retransmissions = + (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt; + size_t payload_exponent = (flags >> 2) % 16; + size_t payload_size = static_cast(1) << payload_exponent; + socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53), + std::vector(payload_size)), + options); + break; + } + case 7: { + // Expire an active timeout/timer. + uint8_t timeout_idx = state.GetByte(); + absl::optional timeout_id = cb.ExpireTimeout(timeout_idx); + if (timeout_id.has_value()) { + socket.HandleTimeout(*timeout_id); + } + break; + } + default: + break; + } + } +} +} // namespace dcsctp_fuzzers +} // namespace dcsctp diff --git a/net/dcsctp/fuzzers/dcsctp_fuzzers.h b/net/dcsctp/fuzzers/dcsctp_fuzzers.h new file mode 100644 index 0000000000..f3de0722f4 --- /dev/null +++ b/net/dcsctp/fuzzers/dcsctp_fuzzers.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ +#define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ + +#include +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/public/dcsctp_socket.h" + +namespace dcsctp { +namespace dcsctp_fuzzers { + +// A fake timeout used during fuzzing. +class FuzzerTimeout : public Timeout { + public: + explicit FuzzerTimeout(std::set& active_timeouts) + : active_timeouts_(active_timeouts) {} + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override { + // Start is only allowed to be called on stopped or expired timeouts. + if (timeout_id_.has_value()) { + // It has been started before, but maybe it expired. Ensure that it's not + // running at least. + RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end()); + } + timeout_id_ = timeout_id; + RTC_DCHECK(active_timeouts_.insert(timeout_id).second); + } + + void Stop() override { + // Stop is only allowed to be called on active timeouts. Not stopped or + // expired. + RTC_DCHECK(timeout_id_.has_value()); + RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1); + timeout_id_ = absl::nullopt; + } + + // A set of all active timeouts, managed by `FuzzerCallbacks`. + std::set& active_timeouts_; + // If present, the timout is active and will expire reported as `timeout_id`. 
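+  // Note that a FuzzerTimeout never fires by itself: the fuzzer picks an
+  // entry from `active_timeouts_` via FuzzerCallbacks::ExpireTimeout() and
+  // passes the returned TimeoutID to DcSctpSocketInterface::HandleTimeout(),
+  // mirroring what an integration does when its platform timer fires.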
+ absl::optional timeout_id_; +}; + +class FuzzerCallbacks : public DcSctpSocketCallbacks { + public: + static constexpr int kRandomValue = 42; + void SendPacket(rtc::ArrayView data) override { + sent_packets_.emplace_back(std::vector(data.begin(), data.end())); + } + std::unique_ptr CreateTimeout() override { + return std::make_unique(active_timeouts_); + } + TimeMs TimeMillis() override { return TimeMs(42); } + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return kRandomValue; + } + void OnMessageReceived(DcSctpMessage message) override {} + void OnError(ErrorKind error, absl::string_view message) override {} + void OnAborted(ErrorKind error, absl::string_view message) override {} + void OnConnected() override {} + void OnClosed() override {} + void OnConnectionRestarted() override {} + void OnStreamsResetFailed(rtc::ArrayView outgoing_streams, + absl::string_view reason) override {} + void OnStreamsResetPerformed( + rtc::ArrayView outgoing_streams) override {} + void OnIncomingStreamsReset( + rtc::ArrayView incoming_streams) override {} + + std::vector ConsumeSentPacket() { + if (sent_packets_.empty()) { + return {}; + } + std::vector ret = sent_packets_.front(); + sent_packets_.pop_front(); + return ret; + } + + // Given an index among the active timeouts, will expire that one. + absl::optional ExpireTimeout(size_t index) { + if (index < active_timeouts_.size()) { + auto it = active_timeouts_.begin(); + std::advance(it, index); + TimeoutID timeout_id = *it; + active_timeouts_.erase(it); + return timeout_id; + } + return absl::nullopt; + } + + private: + // Needs to be ordered, to allow fuzzers to expire timers. + std::set active_timeouts_; + std::deque> sent_packets_; +}; + +// Given some fuzzing `data` will send packets to the socket as well as calling +// API methods. +void FuzzSocket(DcSctpSocketInterface& socket, + FuzzerCallbacks& cb, + rtc::ArrayView data); + +} // namespace dcsctp_fuzzers +} // namespace dcsctp +#endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ diff --git a/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc b/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc new file mode 100644 index 0000000000..c7d2cd7c99 --- /dev/null +++ b/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h" + +#include "api/array_view.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "rtc_base/logging.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace dcsctp_fuzzers { +namespace { + +// This is a testbed where fuzzed data that cause issues can be evaluated and +// crashes reproduced. Use `xxd -i ./crash-abc` to generate `data` below. 
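Before the testbed test below, this is roughly how FuzzSocket() is wired into a fuzzer entry point. The harness is an illustrative sketch (the actual fuzzer target is not part of this section) and only uses calls that appear above:

extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
  dcsctp::dcsctp_fuzzers::FuzzerCallbacks cb;
  dcsctp::DcSctpOptions options;
  options.disable_checksum_verification = true;
  dcsctp::DcSctpSocket socket("fuzzer", cb, nullptr, options);
  dcsctp::dcsctp_fuzzers::FuzzSocket(
      socket, cb, rtc::ArrayView<const uint8_t>(data, size));
  return 0;
}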
+TEST(DcsctpFuzzersTest, PassesTestbed) { + uint8_t data[] = {0x07, 0x09, 0x00, 0x01, 0x11, 0xff, 0xff}; + + FuzzerCallbacks cb; + DcSctpOptions options; + options.disable_checksum_verification = true; + DcSctpSocket socket("A", cb, nullptr, options); + + FuzzSocket(socket, cb, data); +} + +} // namespace +} // namespace dcsctp_fuzzers +} // namespace dcsctp diff --git a/net/dcsctp/packet/BUILD.gn b/net/dcsctp/packet/BUILD.gn new file mode 100644 index 0000000000..9c08ebc80e --- /dev/null +++ b/net/dcsctp/packet/BUILD.gn @@ -0,0 +1,338 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +group("packet") { + deps = [ ":bounded_io" ] +} + +rtc_source_set("bounded_io") { + deps = [ + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + ] + sources = [ + "bounded_byte_reader.h", + "bounded_byte_writer.h", + ] +} + +rtc_library("tlv_trait") { + deps = [ + ":bounded_io", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings:strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + sources = [ + "tlv_trait.cc", + "tlv_trait.h", + ] +} + +rtc_source_set("data") { + deps = [ + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../public:types", + ] + sources = [ "data.h" ] +} + +rtc_library("crc32c") { + deps = [ + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "//third_party/crc32c", + ] + sources = [ + "crc32c.cc", + "crc32c.h", + ] +} + +rtc_library("parameter") { + deps = [ + ":bounded_io", + ":data", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../common:math", + "../common:str_join", + "../public:types", + ] + sources = [ + "parameter/add_incoming_streams_request_parameter.cc", + "parameter/add_incoming_streams_request_parameter.h", + "parameter/add_outgoing_streams_request_parameter.cc", + "parameter/add_outgoing_streams_request_parameter.h", + "parameter/forward_tsn_supported_parameter.cc", + "parameter/forward_tsn_supported_parameter.h", + "parameter/heartbeat_info_parameter.cc", + "parameter/heartbeat_info_parameter.h", + "parameter/incoming_ssn_reset_request_parameter.cc", + "parameter/incoming_ssn_reset_request_parameter.h", + "parameter/outgoing_ssn_reset_request_parameter.cc", + "parameter/outgoing_ssn_reset_request_parameter.h", + "parameter/parameter.cc", + "parameter/parameter.h", + "parameter/reconfiguration_response_parameter.cc", + "parameter/reconfiguration_response_parameter.h", + "parameter/ssn_tsn_reset_request_parameter.cc", + "parameter/ssn_tsn_reset_request_parameter.h", + "parameter/state_cookie_parameter.cc", + "parameter/state_cookie_parameter.h", + "parameter/supported_extensions_parameter.cc", + "parameter/supported_extensions_parameter.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", 
+ "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("error_cause") { + deps = [ + ":data", + ":parameter", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../common:math", + "../common:str_join", + "../packet:bounded_io", + "../public:types", + ] + sources = [ + "error_cause/cookie_received_while_shutting_down_cause.cc", + "error_cause/cookie_received_while_shutting_down_cause.h", + "error_cause/error_cause.cc", + "error_cause/error_cause.h", + "error_cause/invalid_mandatory_parameter_cause.cc", + "error_cause/invalid_mandatory_parameter_cause.h", + "error_cause/invalid_stream_identifier_cause.cc", + "error_cause/invalid_stream_identifier_cause.h", + "error_cause/missing_mandatory_parameter_cause.cc", + "error_cause/missing_mandatory_parameter_cause.h", + "error_cause/no_user_data_cause.cc", + "error_cause/no_user_data_cause.h", + "error_cause/out_of_resource_error_cause.cc", + "error_cause/out_of_resource_error_cause.h", + "error_cause/protocol_violation_cause.cc", + "error_cause/protocol_violation_cause.h", + "error_cause/restart_of_an_association_with_new_address_cause.cc", + "error_cause/restart_of_an_association_with_new_address_cause.h", + "error_cause/stale_cookie_error_cause.cc", + "error_cause/stale_cookie_error_cause.h", + "error_cause/unrecognized_chunk_type_cause.cc", + "error_cause/unrecognized_chunk_type_cause.h", + "error_cause/unrecognized_parameter_cause.cc", + "error_cause/unrecognized_parameter_cause.h", + "error_cause/unresolvable_address_cause.cc", + "error_cause/unresolvable_address_cause.h", + "error_cause/user_initiated_abort_cause.cc", + "error_cause/user_initiated_abort_cause.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("chunk") { + deps = [ + ":data", + ":error_cause", + ":parameter", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:math", + "../common:str_join", + "../packet:bounded_io", + ] + sources = [ + "chunk/abort_chunk.cc", + "chunk/abort_chunk.h", + "chunk/chunk.cc", + "chunk/chunk.h", + "chunk/cookie_ack_chunk.cc", + "chunk/cookie_ack_chunk.h", + "chunk/cookie_echo_chunk.cc", + "chunk/cookie_echo_chunk.h", + "chunk/data_chunk.cc", + "chunk/data_chunk.h", + "chunk/data_common.h", + "chunk/error_chunk.cc", + "chunk/error_chunk.h", + "chunk/forward_tsn_chunk.cc", + "chunk/forward_tsn_chunk.h", + "chunk/forward_tsn_common.h", + "chunk/heartbeat_ack_chunk.cc", + "chunk/heartbeat_ack_chunk.h", + "chunk/heartbeat_request_chunk.cc", + "chunk/heartbeat_request_chunk.h", + "chunk/idata_chunk.cc", + "chunk/idata_chunk.h", + "chunk/iforward_tsn_chunk.cc", + "chunk/iforward_tsn_chunk.h", + "chunk/init_ack_chunk.cc", + "chunk/init_ack_chunk.h", + "chunk/init_chunk.cc", + "chunk/init_chunk.h", + "chunk/reconfig_chunk.cc", + "chunk/reconfig_chunk.h", + "chunk/sack_chunk.cc", + "chunk/sack_chunk.h", + "chunk/shutdown_ack_chunk.cc", + "chunk/shutdown_ack_chunk.h", + "chunk/shutdown_chunk.cc", + "chunk/shutdown_chunk.h", + "chunk/shutdown_complete_chunk.cc", + "chunk/shutdown_complete_chunk.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + 
"//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("chunk_validators") { + deps = [ + ":chunk", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + ] + sources = [ + "chunk_validators.cc", + "chunk_validators.h", + ] +} + +rtc_library("sctp_packet") { + deps = [ + ":bounded_io", + ":chunk", + ":crc32c", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../common:math", + "../public:types", + ] + sources = [ + "sctp_packet.cc", + "sctp_packet.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/memory:memory", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_packet_unittests") { + testonly = true + + deps = [ + ":bounded_io", + ":chunk", + ":chunk_validators", + ":crc32c", + ":error_cause", + ":parameter", + ":sctp_packet", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../common:internal_types", + "../common:math", + "../public:types", + "../testing:testing_macros", + ] + sources = [ + "bounded_byte_reader_test.cc", + "bounded_byte_writer_test.cc", + "chunk/abort_chunk_test.cc", + "chunk/cookie_ack_chunk_test.cc", + "chunk/cookie_echo_chunk_test.cc", + "chunk/data_chunk_test.cc", + "chunk/error_chunk_test.cc", + "chunk/forward_tsn_chunk_test.cc", + "chunk/heartbeat_ack_chunk_test.cc", + "chunk/heartbeat_request_chunk_test.cc", + "chunk/idata_chunk_test.cc", + "chunk/iforward_tsn_chunk_test.cc", + "chunk/init_ack_chunk_test.cc", + "chunk/init_chunk_test.cc", + "chunk/reconfig_chunk_test.cc", + "chunk/sack_chunk_test.cc", + "chunk/shutdown_ack_chunk_test.cc", + "chunk/shutdown_chunk_test.cc", + "chunk/shutdown_complete_chunk_test.cc", + "chunk_validators_test.cc", + "crc32c_test.cc", + "error_cause/cookie_received_while_shutting_down_cause_test.cc", + "error_cause/invalid_mandatory_parameter_cause_test.cc", + "error_cause/invalid_stream_identifier_cause_test.cc", + "error_cause/missing_mandatory_parameter_cause_test.cc", + "error_cause/no_user_data_cause_test.cc", + "error_cause/out_of_resource_error_cause_test.cc", + "error_cause/protocol_violation_cause_test.cc", + "error_cause/restart_of_an_association_with_new_address_cause_test.cc", + "error_cause/stale_cookie_error_cause_test.cc", + "error_cause/unrecognized_chunk_type_cause_test.cc", + "error_cause/unrecognized_parameter_cause_test.cc", + "error_cause/unresolvable_address_cause_test.cc", + "error_cause/user_initiated_abort_cause_test.cc", + "parameter/add_incoming_streams_request_parameter_test.cc", + "parameter/add_outgoing_streams_request_parameter_test.cc", + "parameter/forward_tsn_supported_parameter_test.cc", + "parameter/incoming_ssn_reset_request_parameter_test.cc", + "parameter/outgoing_ssn_reset_request_parameter_test.cc", + "parameter/parameter_test.cc", + "parameter/reconfiguration_response_parameter_test.cc", + "parameter/ssn_tsn_reset_request_parameter_test.cc", + "parameter/state_cookie_parameter_test.cc", + "parameter/supported_extensions_parameter_test.cc", + "sctp_packet_test.cc", + "tlv_trait_test.cc", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + } +} diff --git a/net/dcsctp/packet/bounded_byte_reader.h b/net/dcsctp/packet/bounded_byte_reader.h new file mode 100644 index 0000000000..603ed6ac33 --- /dev/null +++ 
b/net/dcsctp/packet/bounded_byte_reader.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef NET_DCSCTP_PACKET_BOUNDED_BYTE_READER_H_ +#define NET_DCSCTP_PACKET_BOUNDED_BYTE_READER_H_ + +#include + +#include "api/array_view.h" + +namespace dcsctp { + +// TODO(boivie): These generic functions - and possibly this entire class - +// could be a candidate to have added to rtc_base/. They should use compiler +// intrinsics as well. +namespace internal { +// Loads a 8-bit unsigned word at `data`. +inline uint8_t LoadBigEndian8(const uint8_t* data) { + return data[0]; +} + +// Loads a 16-bit unsigned word at `data`. +inline uint16_t LoadBigEndian16(const uint8_t* data) { + return (data[0] << 8) | data[1]; +} + +// Loads a 32-bit unsigned word at `data`. +inline uint32_t LoadBigEndian32(const uint8_t* data) { + return (data[0] << 24) | (data[1] << 16) | (data[2] << 8) | data[3]; +} +} // namespace internal + +// BoundedByteReader wraps an ArrayView and divides it into two parts; A fixed +// size - which is the template parameter - and a variable size, which is what +// remains in `data` after the `FixedSize`. +// +// The BoundedByteReader provides methods to load/read big endian numbers from +// the FixedSize portion of the buffer, and these are read with static bounds +// checking, to avoid out-of-bounds accesses without a run-time penalty. +// +// The variable sized portion can either be used to create sub-readers, which +// themselves would provide compile-time bounds-checking, or the entire variable +// sized portion can be retrieved as an ArrayView. 
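+// As an example, the 4-byte chunk header that is common to all SCTP chunks
+// (type, flags, 16-bit length) can be read with static bounds checking as
+//
+//   BoundedByteReader<4> header(data);
+//   uint8_t type = header.Load8<0>();
+//   uint8_t flags = header.Load8<1>();
+//   uint16_t length = header.Load16<2>();
+//
+// after which header.variable_data() is the chunk's value part, and nested
+// fixed-size structures can keep the same compile-time checking by using
+// sub_reader<N>(offset).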
+template +class BoundedByteReader { + public: + explicit BoundedByteReader(rtc::ArrayView data) : data_(data) { + RTC_CHECK(data.size() >= FixedSize); + } + + template + uint8_t Load8() const { + static_assert(offset + sizeof(uint8_t) <= FixedSize, "Out-of-bounds"); + return internal::LoadBigEndian8(&data_[offset]); + } + + template + uint16_t Load16() const { + static_assert(offset + sizeof(uint16_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint16_t)) == 0, "Unaligned access"); + return internal::LoadBigEndian16(&data_[offset]); + } + + template + uint32_t Load32() const { + static_assert(offset + sizeof(uint32_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint32_t)) == 0, "Unaligned access"); + return internal::LoadBigEndian32(&data_[offset]); + } + + template + BoundedByteReader sub_reader(size_t variable_offset) const { + RTC_CHECK(FixedSize + variable_offset + SubSize <= data_.size()); + + rtc::ArrayView sub_span = + data_.subview(FixedSize + variable_offset, SubSize); + return BoundedByteReader(sub_span); + } + + size_t variable_data_size() const { return data_.size() - FixedSize; } + + rtc::ArrayView variable_data() const { + return data_.subview(FixedSize, data_.size() - FixedSize); + } + + private: + const rtc::ArrayView data_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_BOUNDED_BYTE_READER_H_ diff --git a/net/dcsctp/packet/bounded_byte_reader_test.cc b/net/dcsctp/packet/bounded_byte_reader_test.cc new file mode 100644 index 0000000000..2fb4a86785 --- /dev/null +++ b/net/dcsctp/packet/bounded_byte_reader_test.cc @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "net/dcsctp/packet/bounded_byte_reader.h" + +#include "api/array_view.h" +#include "rtc_base/buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(BoundedByteReaderTest, CanLoadData) { + uint8_t data[14] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}; + + BoundedByteReader<8> reader(data); + EXPECT_EQ(reader.variable_data_size(), 6U); + EXPECT_EQ(reader.Load32<0>(), 0x01020304U); + EXPECT_EQ(reader.Load32<4>(), 0x05060708U); + EXPECT_EQ(reader.Load16<4>(), 0x0506U); + EXPECT_EQ(reader.Load8<4>(), 0x05U); + EXPECT_EQ(reader.Load8<5>(), 0x06U); + + BoundedByteReader<6> sub = reader.sub_reader<6>(0); + EXPECT_EQ(sub.Load16<0>(), 0x0900U); + EXPECT_EQ(sub.Load32<0>(), 0x09000102U); + EXPECT_EQ(sub.Load16<4>(), 0x0304U); + + EXPECT_THAT(reader.variable_data(), ElementsAre(9, 0, 1, 2, 3, 4)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/bounded_byte_writer.h b/net/dcsctp/packet/bounded_byte_writer.h new file mode 100644 index 0000000000..467f26800b --- /dev/null +++ b/net/dcsctp/packet/bounded_byte_writer.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef NET_DCSCTP_PACKET_BOUNDED_BYTE_WRITER_H_ +#define NET_DCSCTP_PACKET_BOUNDED_BYTE_WRITER_H_ + +#include + +#include "api/array_view.h" + +namespace dcsctp { + +// TODO(boivie): These generic functions - and possibly this entire class - +// could be a candidate to have added to rtc_base/. They should use compiler +// intrinsics as well. +namespace internal { +// Stores a 8-bit unsigned word at `data`. +inline void StoreBigEndian8(uint8_t* data, uint8_t val) { + data[0] = val; +} + +// Stores a 16-bit unsigned word at `data`. +inline void StoreBigEndian16(uint8_t* data, uint16_t val) { + data[0] = val >> 8; + data[1] = val; +} + +// Stores a 32-bit unsigned word at `data`. +inline void StoreBigEndian32(uint8_t* data, uint32_t val) { + data[0] = val >> 24; + data[1] = val >> 16; + data[2] = val >> 8; + data[3] = val; +} +} // namespace internal + +// BoundedByteWriter wraps an ArrayView and divides it into two parts; A fixed +// size - which is the template parameter - and a variable size, which is what +// remains in `data` after the `FixedSize`. +// +// The BoundedByteWriter provides methods to write big endian numbers to the +// FixedSize portion of the buffer, and these are written with static bounds +// checking, to avoid out-of-bounds accesses without a run-time penalty. +// +// The variable sized portion can either be used to create sub-writers, which +// themselves would provide compile-time bounds-checking, or data can be copied +// to it. +template +class BoundedByteWriter { + public: + explicit BoundedByteWriter(rtc::ArrayView data) : data_(data) { + RTC_CHECK(data.size() >= FixedSize); + } + + template + void Store8(uint8_t value) { + static_assert(offset + sizeof(uint8_t) <= FixedSize, "Out-of-bounds"); + internal::StoreBigEndian8(&data_[offset], value); + } + + template + void Store16(uint16_t value) { + static_assert(offset + sizeof(uint16_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint16_t)) == 0, "Unaligned access"); + internal::StoreBigEndian16(&data_[offset], value); + } + + template + void Store32(uint32_t value) { + static_assert(offset + sizeof(uint32_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint32_t)) == 0, "Unaligned access"); + internal::StoreBigEndian32(&data_[offset], value); + } + + template + BoundedByteWriter sub_writer(size_t variable_offset) { + RTC_CHECK(FixedSize + variable_offset + SubSize <= data_.size()); + + return BoundedByteWriter( + data_.subview(FixedSize + variable_offset, SubSize)); + } + + void CopyToVariableData(rtc::ArrayView source) { + memcpy(data_.data() + FixedSize, source.data(), + std::min(source.size(), data_.size() - FixedSize)); + } + + private: + rtc::ArrayView data_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_BOUNDED_BYTE_WRITER_H_ diff --git a/net/dcsctp/packet/bounded_byte_writer_test.cc b/net/dcsctp/packet/bounded_byte_writer_test.cc new file mode 100644 index 0000000000..3cea0a2f7c --- /dev/null +++ b/net/dcsctp/packet/bounded_byte_writer_test.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "net/dcsctp/packet/bounded_byte_writer.h" + +#include + +#include "api/array_view.h" +#include "rtc_base/buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(BoundedByteWriterTest, CanWriteData) { + std::vector data(14); + + BoundedByteWriter<8> writer(data); + writer.Store32<0>(0x01020304); + writer.Store16<4>(0x0506); + writer.Store8<6>(0x07); + writer.Store8<7>(0x08); + + uint8_t variable_data[] = {0, 0, 0, 0, 3, 0}; + writer.CopyToVariableData(variable_data); + + BoundedByteWriter<6> sub = writer.sub_writer<6>(0); + sub.Store32<0>(0x09000000); + sub.Store16<2>(0x0102); + + BoundedByteWriter<2> sub2 = writer.sub_writer<2>(4); + sub2.Store8<1>(0x04); + + EXPECT_THAT(data, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/abort_chunk.cc b/net/dcsctp/packet/chunk/abort_chunk.cc new file mode 100644 index 0000000000..8348eb96a9 --- /dev/null +++ b/net/dcsctp/packet/chunk/abort_chunk.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/abort_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.7 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 6 |Reserved |T| Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / zero or more Error Causes / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int AbortChunk::kType; + +absl::optional AbortChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + absl::optional error_causes = + Parameters::Parse(reader->variable_data()); + if (!error_causes.has_value()) { + return absl::nullopt; + } + uint8_t flags = reader->Load8<1>(); + bool filled_in_verification_tag = (flags & (1 << kFlagsBitT)) == 0; + return AbortChunk(filled_in_verification_tag, *std::move(error_causes)); +} + +void AbortChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView error_causes = error_causes_.data(); + BoundedByteWriter writer = AllocateTLV(out, error_causes.size()); + writer.Store8<1>(filled_in_verification_tag_ ? 
0 : (1 << kFlagsBitT)); + writer.CopyToVariableData(error_causes); +} + +std::string AbortChunk::ToString() const { + return "ABORT"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/abort_chunk.h b/net/dcsctp/packet/chunk/abort_chunk.h new file mode 100644 index 0000000000..1408a75e80 --- /dev/null +++ b/net/dcsctp/packet/chunk/abort_chunk.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_ABORT_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_ABORT_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.7 +struct AbortChunkConfig : ChunkConfig { + static constexpr int kType = 6; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class AbortChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = AbortChunkConfig::kType; + + AbortChunk(bool filled_in_verification_tag, Parameters error_causes) + : filled_in_verification_tag_(filled_in_verification_tag), + error_causes_(std::move(error_causes)) {} + + AbortChunk(AbortChunk&& other) = default; + AbortChunk& operator=(AbortChunk&& other) = default; + + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + bool filled_in_verification_tag() const { + return filled_in_verification_tag_; + } + + const Parameters& error_causes() const { return error_causes_; } + + private: + static constexpr int kFlagsBitT = 0; + bool filled_in_verification_tag_; + Parameters error_causes_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_ABORT_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/abort_chunk_test.cc b/net/dcsctp/packet/chunk/abort_chunk_test.cc new file mode 100644 index 0000000000..c1f3a4d5b9 --- /dev/null +++ b/net/dcsctp/packet/chunk/abort_chunk_test.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/abort_chunk.h" + +#include + +#include +#include + +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(AbortChunkTest, FromCapture) { + /* + ABORT chunk + Chunk type: ABORT (6) + Chunk flags: 0x00 + Chunk length: 8 + User initiated ABORT cause + */ + + uint8_t data[] = {0x06, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04}; + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk chunk, AbortChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + chunk.error_causes().get()); + + EXPECT_EQ(cause.upper_layer_abort_reason(), ""); +} + +TEST(AbortChunkTest, SerializeAndDeserialize) { + AbortChunk chunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(UserInitiatedAbortCause("Close called")) + .Build()); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk deserialized, + AbortChunk::Parse(serialized)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + deserialized.error_causes().get()); + + EXPECT_EQ(cause.upper_layer_abort_reason(), "Close called"); +} + +// Validates that AbortChunk doesn't make any alignment assumptions. +TEST(AbortChunkTest, SerializeAndDeserializeOneChar) { + AbortChunk chunk( + /*filled_in_verification_tag=*/true, + Parameters::Builder().Add(UserInitiatedAbortCause("!")).Build()); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk deserialized, + AbortChunk::Parse(serialized)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + deserialized.error_causes().get()); + + EXPECT_EQ(cause.upper_layer_abort_reason(), "!"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/chunk.cc b/net/dcsctp/packet/chunk/chunk.cc new file mode 100644 index 0000000000..832ab82288 --- /dev/null +++ b/net/dcsctp/packet/chunk/chunk.cc @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
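A decoding aid for the ABORT capture in the test above (byte values from the test, interpretation per RFC 4960):

  0x06        ABORT chunk type
  0x00        flags; T bit clear, so filled_in_verification_tag() is true
  0x00 0x08   chunk length = 8
  0x00 0x0c   error cause code 12 (User-Initiated Abort)
  0x00 0x04   cause length = 4, i.e. only the cause header and an empty reason

which is why the test expects upper_layer_abort_reason() to be the empty string.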
+ */ +#include "net/dcsctp/packet/chunk/chunk.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +template +bool ParseAndPrint(uint8_t chunk_type, + rtc::ArrayView data, + rtc::StringBuilder& sb) { + if (chunk_type == Chunk::kType) { + absl::optional c = Chunk::Parse(data); + if (c.has_value()) { + sb << c->ToString(); + } else { + sb << "Failed to parse chunk of type " << chunk_type; + } + return true; + } + return false; +} + +std::string DebugConvertChunkToString(rtc::ArrayView data) { + rtc::StringBuilder sb; + + if (data.empty()) { + sb << "Failed to parse chunk due to empty data"; + } else { + uint8_t chunk_type = data[0]; + if (!ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb) && + !ParseAndPrint(chunk_type, data, sb)) { + sb << "Unhandled chunk type: " << static_cast(chunk_type); + } + } + return sb.Release(); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/chunk.h b/net/dcsctp/packet/chunk/chunk.h new file mode 100644 index 0000000000..687aa1daa1 --- /dev/null +++ b/net/dcsctp/packet/chunk/chunk.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
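DebugConvertChunkToString above is intended for debug logging of raw chunk bytes; a typical call site (illustrative only, `chunk_bytes` is a hypothetical ArrayView over one chunk) would be:

RTC_DLOG(LS_VERBOSE) << "Handling chunk: " << DebugConvertChunkToString(chunk_bytes);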
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_CHUNK_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// Base class for all SCTP chunks +class Chunk { + public: + Chunk() {} + virtual ~Chunk() = default; + + // Chunks can contain data payloads that shouldn't be copied unnecessarily. + Chunk(Chunk&& other) = default; + Chunk& operator=(Chunk&& other) = default; + Chunk(const Chunk&) = delete; + Chunk& operator=(const Chunk&) = delete; + + // Serializes the chunk to `out`, growing it as necessary. + virtual void SerializeTo(std::vector& out) const = 0; + + // Returns a human readable description of this chunk and its parameters. + virtual std::string ToString() const = 0; +}; + +// Introspects the chunk in `data` and returns a human readable textual +// representation of it, to be used in debugging. +std::string DebugConvertChunkToString(rtc::ArrayView data); + +struct ChunkConfig { + static constexpr int kTypeSizeInBytes = 1; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/cookie_ack_chunk.cc b/net/dcsctp/packet/chunk/cookie_ack_chunk.cc new file mode 100644 index 0000000000..4839969ccf --- /dev/null +++ b/net/dcsctp/packet/chunk/cookie_ack_chunk.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" + +#include + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.12 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 11 |Chunk Flags | Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int CookieAckChunk::kType; + +absl::optional CookieAckChunk::Parse( + rtc::ArrayView data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return CookieAckChunk(); +} + +void CookieAckChunk::SerializeTo(std::vector& out) const { + AllocateTLV(out); +} + +std::string CookieAckChunk::ToString() const { + return "COOKIE-ACK"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/cookie_ack_chunk.h b/net/dcsctp/packet/chunk/cookie_ack_chunk.h new file mode 100644 index 0000000000..f7d4a33f7d --- /dev/null +++ b/net/dcsctp/packet/chunk/cookie_ack_chunk.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
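CookieAckChunk above is the smallest instance of the chunk pattern used throughout this directory: a *ChunkConfig struct feeds the chunk type, header size and value alignment to TLVTrait, Parse() validates the TLV envelope via ParseTLV(), and SerializeTo() reserves the header via AllocateTLV(). A new payload-less chunk would follow the same shape (sketch; MyChunk and type 99 are hypothetical):

struct MyChunkConfig : ChunkConfig {
  static constexpr int kType = 99;
  static constexpr size_t kHeaderSize = 4;
  static constexpr size_t kVariableLengthAlignment = 0;
};

class MyChunk : public Chunk, public TLVTrait<MyChunkConfig> {
 public:
  static constexpr int kType = MyChunkConfig::kType;

  static absl::optional<MyChunk> Parse(rtc::ArrayView<const uint8_t> data) {
    if (!ParseTLV(data).has_value()) {
      return absl::nullopt;
    }
    return MyChunk();
  }

  void SerializeTo(std::vector<uint8_t>& out) const override {
    AllocateTLV(out);
  }
  std::string ToString() const override { return "MY-CHUNK"; }
};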
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_COOKIE_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_COOKIE_ACK_CHUNK_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.12 +struct CookieAckChunkConfig : ChunkConfig { + static constexpr int kType = 11; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class CookieAckChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = CookieAckChunkConfig::kType; + + CookieAckChunk() {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_COOKIE_ACK_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/cookie_ack_chunk_test.cc b/net/dcsctp/packet/chunk/cookie_ack_chunk_test.cc new file mode 100644 index 0000000000..3f560c6fef --- /dev/null +++ b/net/dcsctp/packet/chunk/cookie_ack_chunk_test.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(CookieAckChunkTest, FromCapture) { + /* + COOKIE_ACK chunk + Chunk type: COOKIE_ACK (11) + Chunk flags: 0x00 + Chunk length: 4 + */ + + uint8_t data[] = {0x0b, 0x00, 0x00, 0x04}; + + EXPECT_TRUE(CookieAckChunk::Parse(data).has_value()); +} + +TEST(CookieAckChunkTest, SerializeAndDeserialize) { + CookieAckChunk chunk; + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(CookieAckChunk deserialized, + CookieAckChunk::Parse(serialized)); + EXPECT_EQ(deserialized.ToString(), "COOKIE-ACK"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/cookie_echo_chunk.cc b/net/dcsctp/packet/chunk/cookie_echo_chunk.cc new file mode 100644 index 0000000000..a01d0b13c4 --- /dev/null +++ b/net/dcsctp/packet/chunk/cookie_echo_chunk.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
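[Reviewer note, not part of the patch: every chunk in this directory starts with the same 4-byte TLV header described in RFC 4960 section 3.2: one byte type, one byte flags, two bytes length in network byte order. A stand-alone sketch that decodes that header from the COOKIE-ACK capture bytes; ChunkHeader and DecodeChunkHeader are made-up names for illustration.]

#include <cstdint>
#include <cstdio>

// The common chunk header shared by all SCTP chunks.
struct ChunkHeader {
  uint8_t type;
  uint8_t flags;
  uint16_t length;  // Includes the 4 header bytes, excludes padding.
};

ChunkHeader DecodeChunkHeader(const uint8_t* data) {
  ChunkHeader header;
  header.type = data[0];
  header.flags = data[1];
  header.length = static_cast<uint16_t>(data[2]) << 8 | data[3];
  return header;
}

int main() {
  // The COOKIE-ACK bytes from the capture in the test above.
  const uint8_t data[] = {0x0b, 0x00, 0x00, 0x04};
  ChunkHeader header = DecodeChunkHeader(data);
  // Prints "type=11 flags=0 length=4".
  std::printf("type=%d flags=%d length=%d\n", header.type, header.flags,
              header.length);
  return 0;
}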
+ */ +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.11 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 10 |Chunk Flags | Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Cookie / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int CookieEchoChunk::kType; + +absl::optional CookieEchoChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return CookieEchoChunk(reader->variable_data()); +} + +void CookieEchoChunk::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out, cookie_.size()); + writer.CopyToVariableData(cookie_); +} + +std::string CookieEchoChunk::ToString() const { + return "COOKIE-ECHO"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/cookie_echo_chunk.h b/net/dcsctp/packet/chunk/cookie_echo_chunk.h new file mode 100644 index 0000000000..8cb80527f8 --- /dev/null +++ b/net/dcsctp/packet/chunk/cookie_echo_chunk.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_COOKIE_ECHO_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_COOKIE_ECHO_CHUNK_H_ +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.11 +struct CookieEchoChunkConfig : ChunkConfig { + static constexpr int kType = 10; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class CookieEchoChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = CookieEchoChunkConfig::kType; + + explicit CookieEchoChunk(rtc::ArrayView cookie) + : cookie_(cookie.begin(), cookie.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView cookie() const { return cookie_; } + + private: + std::vector cookie_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_COOKIE_ECHO_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/cookie_echo_chunk_test.cc b/net/dcsctp/packet/chunk/cookie_echo_chunk_test.cc new file mode 100644 index 0000000000..d06e0a6439 --- /dev/null +++ b/net/dcsctp/packet/chunk/cookie_echo_chunk_test.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. 
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(CookieEchoChunkTest, FromCapture) { + /* + COOKIE_ECHO chunk (Cookie length: 256 bytes) + Chunk type: COOKIE_ECHO (10) + Chunk flags: 0x00 + Chunk length: 260 + Cookie: 12345678 + */ + + uint8_t data[] = {0x0a, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78}; + + ASSERT_HAS_VALUE_AND_ASSIGN(CookieEchoChunk chunk, + CookieEchoChunk::Parse(data)); + + EXPECT_THAT(chunk.cookie(), ElementsAre(0x12, 0x34, 0x56, 0x78)); +} + +TEST(CookieEchoChunkTest, SerializeAndDeserialize) { + uint8_t cookie[] = {1, 2, 3, 4}; + CookieEchoChunk chunk(cookie); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(CookieEchoChunk deserialized, + CookieEchoChunk::Parse(serialized)); + + EXPECT_THAT(deserialized.cookie(), ElementsAre(1, 2, 3, 4)); + EXPECT_EQ(deserialized.ToString(), "COOKIE-ECHO"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/data_chunk.cc b/net/dcsctp/packet/chunk/data_chunk.cc new file mode 100644 index 0000000000..cf65f53d29 --- /dev/null +++ b/net/dcsctp/packet/chunk/data_chunk.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
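[Reviewer note, not part of the patch: a sketch of the length arithmetic behind kHeaderSize and kVariableLengthAlignment. The length field counts the header plus the unpadded value, while the chunk occupies a multiple of four bytes on the wire (RFC 4960 section 3.2). Helper names are made up; the 256-byte example matches the COOKIE_ECHO capture comment above.]

#include <cstddef>
#include <cstdio>

// Chunk length as carried in the TLV header: header plus cookie, no padding.
constexpr size_t CookieEchoChunkLength(size_t cookie_size) {
  constexpr size_t kHeaderSize = 4;  // CookieEchoChunkConfig::kHeaderSize
  return kHeaderSize + cookie_size;
}

// Bytes occupied on the wire, where chunks are padded to 4-byte multiples.
constexpr size_t PaddedSize(size_t chunk_length) {
  return (chunk_length + 3) & ~size_t{3};
}

int main() {
  // A 256 byte cookie gives chunk length 260, which is already 4-aligned.
  std::printf("length=%zu padded=%zu\n", CookieEchoChunkLength(256),
              PaddedSize(CookieEchoChunkLength(256)));  // 260, 260
  // A 5 byte cookie gives length 9, occupying 12 bytes on the wire.
  std::printf("length=%zu padded=%zu\n", CookieEchoChunkLength(5),
              PaddedSize(CookieEchoChunkLength(5)));  // 9, 12
  return 0;
}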
+ */ +#include "net/dcsctp/packet/chunk/data_chunk.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 0 | Reserved|U|B|E| Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier S | Stream Sequence Number n | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Protocol Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / User Data (seq n of Stream S) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int DataChunk::kType; + +absl::optional DataChunk::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + uint8_t flags = reader->Load8<1>(); + TSN tsn(reader->Load32<4>()); + StreamID stream_identifier(reader->Load16<8>()); + SSN ssn(reader->Load16<10>()); + PPID ppid(reader->Load32<12>()); + + Options options; + options.is_end = Data::IsEnd((flags & (1 << kFlagsBitEnd)) != 0); + options.is_beginning = + Data::IsBeginning((flags & (1 << kFlagsBitBeginning)) != 0); + options.is_unordered = IsUnordered((flags & (1 << kFlagsBitUnordered)) != 0); + options.immediate_ack = + ImmediateAckFlag((flags & (1 << kFlagsBitImmediateAck)) != 0); + + return DataChunk(tsn, stream_identifier, ssn, ppid, + std::vector(reader->variable_data().begin(), + reader->variable_data().end()), + options); +} + +void DataChunk::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out, payload().size()); + + writer.Store8<1>( + (*options().is_end ? (1 << kFlagsBitEnd) : 0) | + (*options().is_beginning ? (1 << kFlagsBitBeginning) : 0) | + (*options().is_unordered ? (1 << kFlagsBitUnordered) : 0) | + (*options().immediate_ack ? (1 << kFlagsBitImmediateAck) : 0)); + writer.Store32<4>(*tsn()); + writer.Store16<8>(*stream_id()); + writer.Store16<10>(*ssn()); + writer.Store32<12>(*ppid()); + + writer.CopyToVariableData(payload()); +} + +std::string DataChunk::ToString() const { + rtc::StringBuilder sb; + sb << "DATA, type=" << (options().is_unordered ? "unordered" : "ordered") + << "::" + << (*options().is_beginning && *options().is_end + ? "complete" + : *options().is_beginning ? "first" + : *options().is_end ? "last" : "middle") + << ", tsn=" << *tsn() << ", stream_id=" << *stream_id() + << ", ppid=" << *ppid() << ", length=" << payload().size(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/data_chunk.h b/net/dcsctp/packet/chunk/data_chunk.h new file mode 100644 index 0000000000..12bb05f2c4 --- /dev/null +++ b/net/dcsctp/packet/chunk/data_chunk.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. 
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_DATA_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_DATA_CHUNK_H_ +#include +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.1 +struct DataChunkConfig : ChunkConfig { + static constexpr int kType = 0; + static constexpr size_t kHeaderSize = 16; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class DataChunk : public AnyDataChunk, public TLVTrait { + public: + static constexpr int kType = DataChunkConfig::kType; + + // Exposed to allow the retransmission queue to make room for the correct + // header size. + static constexpr size_t kHeaderSize = DataChunkConfig::kHeaderSize; + + DataChunk(TSN tsn, + StreamID stream_id, + SSN ssn, + PPID ppid, + std::vector payload, + const Options& options) + : AnyDataChunk(tsn, + stream_id, + ssn, + MID(0), + FSN(0), + ppid, + std::move(payload), + options) {} + + DataChunk(TSN tsn, Data&& data, bool immediate_ack) + : AnyDataChunk(tsn, std::move(data), immediate_ack) {} + + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_DATA_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/data_chunk_test.cc b/net/dcsctp/packet/chunk/data_chunk_test.cc new file mode 100644 index 0000000000..6a5ca82bae --- /dev/null +++ b/net/dcsctp/packet/chunk/data_chunk_test.cc @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
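[Reviewer note, not part of the patch: the first/middle/last/complete wording produced by DataChunk::ToString, and checked in the tests, is a pure function of the B and E bits. A stand-alone sketch; FragmentLabel is an illustrative name.]

#include <cstdio>

// Same naming scheme as DataChunk::ToString and IDataChunk::ToString:
// B and E describe where a fragment sits within the original message.
const char* FragmentLabel(bool is_beginning, bool is_end) {
  if (is_beginning && is_end) return "complete";  // Unfragmented message.
  if (is_beginning) return "first";
  if (is_end) return "last";
  return "middle";
}

int main() {
  std::printf("%s\n", FragmentLabel(true, true));    // complete
  std::printf("%s\n", FragmentLabel(true, false));   // first
  std::printf("%s\n", FragmentLabel(false, false));  // middle
  std::printf("%s\n", FragmentLabel(false, true));   // last
  return 0;
}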
+ */ +#include "net/dcsctp/packet/chunk/data_chunk.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(DataChunkTest, FromCapture) { + /* + DATA chunk(ordered, complete segment, TSN: 1426601532, SID: 2, SSN: 1, + PPID: 53, payload length: 4 bytes) + Chunk type: DATA (0) + Chunk flags: 0x03 + Chunk length: 20 + Transmission sequence number: 1426601532 + Stream identifier: 0x0002 + Stream sequence number: 1 + Payload protocol identifier: WebRTC Binary (53) + */ + + uint8_t data[] = {0x00, 0x03, 0x00, 0x14, 0x55, 0x08, 0x36, 0x3c, 0x00, 0x02, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x35, 0x00, 0x01, 0x02, 0x03}; + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk chunk, DataChunk::Parse(data)); + EXPECT_EQ(*chunk.tsn(), 1426601532u); + EXPECT_EQ(*chunk.stream_id(), 2u); + EXPECT_EQ(*chunk.ssn(), 1u); + EXPECT_EQ(*chunk.ppid(), 53u); + EXPECT_TRUE(*chunk.options().is_beginning); + EXPECT_TRUE(*chunk.options().is_end); + EXPECT_FALSE(*chunk.options().is_unordered); + EXPECT_FALSE(*chunk.options().immediate_ack); + EXPECT_THAT(chunk.payload(), ElementsAre(0x0, 0x1, 0x2, 0x3)); +} + +TEST(DataChunkTest, SerializeAndDeserialize) { + DataChunk chunk(TSN(123), StreamID(456), SSN(789), PPID(9090), + /*payload=*/{1, 2, 3, 4, 5}, + /*options=*/{}); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk deserialized, + DataChunk::Parse(serialized)); + EXPECT_EQ(*chunk.tsn(), 123u); + EXPECT_EQ(*chunk.stream_id(), 456u); + EXPECT_EQ(*chunk.ssn(), 789u); + EXPECT_EQ(*chunk.ppid(), 9090u); + EXPECT_THAT(chunk.payload(), ElementsAre(1, 2, 3, 4, 5)); + + EXPECT_EQ(deserialized.ToString(), + "DATA, type=ordered::middle, tsn=123, stream_id=456, ppid=9090, " + "length=5"); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/data_common.h b/net/dcsctp/packet/chunk/data_common.h new file mode 100644 index 0000000000..b15a034593 --- /dev/null +++ b/net/dcsctp/packet/chunk/data_common.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_DATA_COMMON_H_ +#define NET_DCSCTP_PACKET_CHUNK_DATA_COMMON_H_ +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/data.h" + +namespace dcsctp { + +// Base class for DataChunk and IDataChunk +class AnyDataChunk : public Chunk { + public: + // Represents the "immediate ack" flag on DATA/I-DATA, from RFC7053. + using ImmediateAckFlag = StrongAlias; + + // Data chunk options. 
+ // See https://tools.ietf.org/html/rfc4960#section-3.3.1 + struct Options { + Data::IsEnd is_end = Data::IsEnd(false); + Data::IsBeginning is_beginning = Data::IsBeginning(false); + IsUnordered is_unordered = IsUnordered(false); + ImmediateAckFlag immediate_ack = ImmediateAckFlag(false); + }; + + TSN tsn() const { return tsn_; } + + Options options() const { + Options options; + options.is_end = data_.is_end; + options.is_beginning = data_.is_beginning; + options.is_unordered = data_.is_unordered; + options.immediate_ack = immediate_ack_; + return options; + } + + StreamID stream_id() const { return data_.stream_id; } + SSN ssn() const { return data_.ssn; } + MID message_id() const { return data_.message_id; } + FSN fsn() const { return data_.fsn; } + PPID ppid() const { return data_.ppid; } + rtc::ArrayView payload() const { return data_.payload; } + + // Extracts the Data from the chunk, as a destructive action. + Data extract() && { return std::move(data_); } + + AnyDataChunk(TSN tsn, + StreamID stream_id, + SSN ssn, + MID message_id, + FSN fsn, + PPID ppid, + std::vector payload, + const Options& options) + : tsn_(tsn), + data_(stream_id, + ssn, + message_id, + fsn, + ppid, + std::move(payload), + options.is_beginning, + options.is_end, + options.is_unordered), + immediate_ack_(options.immediate_ack) {} + + AnyDataChunk(TSN tsn, Data data, bool immediate_ack) + : tsn_(tsn), data_(std::move(data)), immediate_ack_(immediate_ack) {} + + protected: + // Bits in `flags` header field. + static constexpr int kFlagsBitEnd = 0; + static constexpr int kFlagsBitBeginning = 1; + static constexpr int kFlagsBitUnordered = 2; + static constexpr int kFlagsBitImmediateAck = 3; + + private: + TSN tsn_; + Data data_; + ImmediateAckFlag immediate_ack_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_DATA_COMMON_H_ diff --git a/net/dcsctp/packet/chunk/error_chunk.cc b/net/dcsctp/packet/chunk/error_chunk.cc new file mode 100644 index 0000000000..baac0c5588 --- /dev/null +++ b/net/dcsctp/packet/chunk/error_chunk.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
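[Reviewer note, not part of the patch: a sketch of how a caller might fill in AnyDataChunk::Options and use the move-only extract() accessor defined above. MakeUnorderedDataChunk and TakeData are illustrative helpers under those assumptions, not part of this change.]

#include <cstdint>
#include <utility>
#include <vector>

#include "net/dcsctp/packet/chunk/data_chunk.h"

namespace dcsctp {

// Builds an unordered, unfragmented DATA chunk that asks the peer to
// acknowledge immediately (the I bit from RFC 7053).
DataChunk MakeUnorderedDataChunk(TSN tsn,
                                 StreamID stream_id,
                                 PPID ppid,
                                 std::vector<uint8_t> payload) {
  AnyDataChunk::Options options;
  options.is_beginning = Data::IsBeginning(true);
  options.is_end = Data::IsEnd(true);
  options.is_unordered = IsUnordered(true);
  options.immediate_ack = AnyDataChunk::ImmediateAckFlag(true);
  // The SSN is not meaningful for unordered delivery; zero is used here.
  return DataChunk(tsn, stream_id, SSN(0), ppid, std::move(payload), options);
}

// Moves the payload-carrying Data out of a chunk without copying it,
// e.g. before handing it to reassembly.
Data TakeData(DataChunk chunk) {
  return std::move(chunk).extract();
}

}  // namespace dcsctp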
+ */ +#include "net/dcsctp/packet/chunk/error_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 9 | Chunk Flags | Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / one or more Error Causes / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ErrorChunk::kType; + +absl::optional ErrorChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + absl::optional error_causes = + Parameters::Parse(reader->variable_data()); + if (!error_causes.has_value()) { + return absl::nullopt; + } + return ErrorChunk(*std::move(error_causes)); +} + +void ErrorChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView error_causes = error_causes_.data(); + BoundedByteWriter writer = AllocateTLV(out, error_causes.size()); + writer.CopyToVariableData(error_causes); +} + +std::string ErrorChunk::ToString() const { + return "ERROR"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/error_chunk.h b/net/dcsctp/packet/chunk/error_chunk.h new file mode 100644 index 0000000000..96122cff6a --- /dev/null +++ b/net/dcsctp/packet/chunk/error_chunk.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_ERROR_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_ERROR_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10 +struct ErrorChunkConfig : ChunkConfig { + static constexpr int kType = 9; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class ErrorChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = ErrorChunkConfig::kType; + + explicit ErrorChunk(Parameters error_causes) + : error_causes_(std::move(error_causes)) {} + + ErrorChunk(ErrorChunk&& other) = default; + ErrorChunk& operator=(ErrorChunk&& other) = default; + + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + const Parameters& error_causes() const { return error_causes_; } + + private: + Parameters error_causes_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_ERROR_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/error_chunk_test.cc b/net/dcsctp/packet/chunk/error_chunk_test.cc new file mode 100644 index 0000000000..f2b8be1edc --- /dev/null +++ b/net/dcsctp/packet/chunk/error_chunk_test.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/error_chunk.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(ErrorChunkTest, FromCapture) { + /* + ERROR chunk + Chunk type: ERROR (9) + Chunk flags: 0x00 + Chunk length: 12 + Unrecognized chunk type cause (Type: 73 (unknown)) + */ + + uint8_t data[] = {0x09, 0x00, 0x00, 0x0c, 0x00, 0x06, + 0x00, 0x08, 0x49, 0x00, 0x00, 0x04}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ErrorChunk chunk, ErrorChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + chunk.error_causes().get()); + + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); +} + +TEST(ErrorChunkTest, SerializeAndDeserialize) { + ErrorChunk chunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause({1, 2, 3, 4})) + .Build()); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ErrorChunk deserialized, + ErrorChunk::Parse(serialized)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + deserialized.error_causes().get()); + + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(1, 2, 3, 4)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/forward_tsn_chunk.cc b/net/dcsctp/packet/chunk/forward_tsn_chunk.cc new file mode 100644 index 0000000000..f01505094d --- /dev/null +++ b/net/dcsctp/packet/chunk/forward_tsn_chunk.cc @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" + +#include +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.2 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 192 | Flags = 0x00 | Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | New Cumulative TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream-1 | Stream Sequence-1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ / +// / \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream-N | Stream Sequence-N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ForwardTsnChunk::kType; + +absl::optional ForwardTsnChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + TSN new_cumulative_tsn(reader->Load32<4>()); + + size_t streams_skipped = + reader->variable_data_size() / kSkippedStreamBufferSize; + + std::vector skipped_streams; + skipped_streams.reserve(streams_skipped); + for (size_t i = 0; i < streams_skipped; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(i * + kSkippedStreamBufferSize); + + StreamID stream_id(sub_reader.Load16<0>()); + SSN ssn(sub_reader.Load16<2>()); + skipped_streams.emplace_back(stream_id, ssn); + } + return ForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)); +} + +void ForwardTsnChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView skipped = skipped_streams(); + size_t variable_size = skipped.size() * kSkippedStreamBufferSize; + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*new_cumulative_tsn()); + for (size_t i = 0; i < skipped.size(); ++i) { + BoundedByteWriter sub_writer = + writer.sub_writer(i * + kSkippedStreamBufferSize); + sub_writer.Store16<0>(*skipped[i].stream_id); + sub_writer.Store16<2>(*skipped[i].ssn); + } +} + +std::string ForwardTsnChunk::ToString() const { + rtc::StringBuilder sb; + sb << "FORWARD-TSN, new_cumulative_tsn=" << *new_cumulative_tsn(); + return sb.str(); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/forward_tsn_chunk.h b/net/dcsctp/packet/chunk/forward_tsn_chunk.h new file mode 100644 index 0000000000..b9ef666f41 --- /dev/null +++ b/net/dcsctp/packet/chunk/forward_tsn_chunk.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.2 +struct ForwardTsnChunkConfig : ChunkConfig { + static constexpr int kType = 192; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class ForwardTsnChunk : public AnyForwardTsnChunk, + public TLVTrait { + public: + static constexpr int kType = ForwardTsnChunkConfig::kType; + + ForwardTsnChunk(TSN new_cumulative_tsn, + std::vector skipped_streams) + : AnyForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + private: + static constexpr size_t kSkippedStreamBufferSize = 4; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/forward_tsn_chunk_test.cc b/net/dcsctp/packet/chunk/forward_tsn_chunk_test.cc new file mode 100644 index 0000000000..9420c1f2ef --- /dev/null +++ b/net/dcsctp/packet/chunk/forward_tsn_chunk_test.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(ForwardTsnChunkTest, FromCapture) { + /* + FORWARD_TSN chunk(Cumulative TSN: 1905748778) + Chunk type: FORWARD_TSN (192) + Chunk flags: 0x00 + Chunk length: 8 + New cumulative TSN: 1905748778 + */ + + uint8_t data[] = {0xc0, 0x00, 0x00, 0x08, 0x71, 0x97, 0x6b, 0x2a}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ForwardTsnChunk chunk, + ForwardTsnChunk::Parse(data)); + EXPECT_EQ(*chunk.new_cumulative_tsn(), 1905748778u); +} + +TEST(ForwardTsnChunkTest, SerializeAndDeserialize) { + ForwardTsnChunk chunk( + TSN(123), {ForwardTsnChunk::SkippedStream(StreamID(1), SSN(23)), + ForwardTsnChunk::SkippedStream(StreamID(42), SSN(99))}); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ForwardTsnChunk deserialized, + ForwardTsnChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.new_cumulative_tsn(), 123u); + EXPECT_THAT( + deserialized.skipped_streams(), + ElementsAre(ForwardTsnChunk::SkippedStream(StreamID(1), SSN(23)), + ForwardTsnChunk::SkippedStream(StreamID(42), SSN(99)))); + + EXPECT_EQ(deserialized.ToString(), "FORWARD-TSN, new_cumulative_tsn=123"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/forward_tsn_common.h b/net/dcsctp/packet/chunk/forward_tsn_common.h new file mode 100644 index 0000000000..37bd2aafff --- /dev/null +++ b/net/dcsctp/packet/chunk/forward_tsn_common.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_COMMON_H_ +#define NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_COMMON_H_ +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" + +namespace dcsctp { + +// Base class for both ForwardTsnChunk and IForwardTsnChunk +class AnyForwardTsnChunk : public Chunk { + public: + struct SkippedStream { + SkippedStream(StreamID stream_id, SSN ssn) + : stream_id(stream_id), ssn(ssn), unordered(false), message_id(0) {} + SkippedStream(IsUnordered unordered, StreamID stream_id, MID message_id) + : stream_id(stream_id), + ssn(0), + unordered(unordered), + message_id(message_id) {} + + StreamID stream_id; + + // Set for FORWARD_TSN + SSN ssn; + + // Set for I-FORWARD_TSN + IsUnordered unordered; + MID message_id; + + bool operator==(const SkippedStream& other) const { + return stream_id == other.stream_id && ssn == other.ssn && + unordered == other.unordered && message_id == other.message_id; + } + }; + + AnyForwardTsnChunk(TSN new_cumulative_tsn, + std::vector skipped_streams) + : new_cumulative_tsn_(new_cumulative_tsn), + skipped_streams_(std::move(skipped_streams)) {} + + TSN new_cumulative_tsn() const { return new_cumulative_tsn_; } + + rtc::ArrayView skipped_streams() const { + return skipped_streams_; + } + + private: + TSN new_cumulative_tsn_; + std::vector skipped_streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_COMMON_H_ diff --git a/net/dcsctp/packet/chunk/heartbeat_ack_chunk.cc b/net/dcsctp/packet/chunk/heartbeat_ack_chunk.cc new file mode 100644 index 0000000000..3cbcd09c75 --- /dev/null +++ b/net/dcsctp/packet/chunk/heartbeat_ack_chunk.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
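[Reviewer note, not part of the patch: a sketch of how the shared AnyForwardTsnChunk base above is meant to be used for both chunk variants, and how the two SkippedStream constructors differ. The helper names are made up for illustration.]

#include <cstddef>

#include "net/dcsctp/packet/chunk/forward_tsn_common.h"

namespace dcsctp {

// Works on both FORWARD-TSN and I-FORWARD-TSN, since both derive from
// AnyForwardTsnChunk and expose the skipped streams the same way.
inline size_t CountSkippedStreams(const AnyForwardTsnChunk& chunk) {
  return chunk.skipped_streams().size();
}

// FORWARD-TSN (RFC 3758) identifies the skipped point by stream id + SSN.
inline AnyForwardTsnChunk::SkippedStream SkippedForForwardTsn() {
  return AnyForwardTsnChunk::SkippedStream(StreamID(1), SSN(23));
}

// I-FORWARD-TSN (RFC 8260) identifies it by the unordered flag,
// stream id and message id instead.
inline AnyForwardTsnChunk::SkippedStream SkippedForIForwardTsn() {
  return AnyForwardTsnChunk::SkippedStream(IsUnordered(true), StreamID(1),
                                           MID(42));
}

}  // namespace dcsctp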
+ */ +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.6 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 5 | Chunk Flags | Heartbeat Ack Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Heartbeat Information TLV (Variable-Length) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int HeartbeatAckChunk::kType; + +absl::optional HeartbeatAckChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + absl::optional parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return HeartbeatAckChunk(*std::move(parameters)); +} + +void HeartbeatAckChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView parameters = parameters_.data(); + BoundedByteWriter writer = AllocateTLV(out, parameters.size()); + writer.CopyToVariableData(parameters); +} + +std::string HeartbeatAckChunk::ToString() const { + return "HEARTBEAT-ACK"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/heartbeat_ack_chunk.h b/net/dcsctp/packet/chunk/heartbeat_ack_chunk.h new file mode 100644 index 0000000000..a6479f78b0 --- /dev/null +++ b/net/dcsctp/packet/chunk/heartbeat_ack_chunk.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_ACK_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.6 +struct HeartbeatAckChunkConfig : ChunkConfig { + static constexpr int kType = 5; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class HeartbeatAckChunk : public Chunk, + public TLVTrait { + public: + static constexpr int kType = HeartbeatAckChunkConfig::kType; + + explicit HeartbeatAckChunk(Parameters parameters) + : parameters_(std::move(parameters)) {} + + HeartbeatAckChunk(HeartbeatAckChunk&& other) = default; + HeartbeatAckChunk& operator=(HeartbeatAckChunk&& other) = default; + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + const Parameters& parameters() const { return parameters_; } + + absl::optional info() const { + return parameters_.get(); + } + + private: + Parameters parameters_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_ACK_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/heartbeat_ack_chunk_test.cc b/net/dcsctp/packet/chunk/heartbeat_ack_chunk_test.cc new file mode 100644 index 0000000000..e4d0dd1489 --- /dev/null +++ b/net/dcsctp/packet/chunk/heartbeat_ack_chunk_test.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(HeartbeatAckChunkTest, FromCapture) { + /* + HEARTBEAT_ACK chunk (Information: 40 bytes) + Chunk type: HEARTBEAT_ACK (5) + Chunk flags: 0x00 + Chunk length: 44 + Heartbeat info parameter (Information: 36 bytes) + Parameter type: Heartbeat info (0x0001) + Parameter length: 40 + Heartbeat information: ad2436603726070000000000000000007b1000000100… + */ + + uint8_t data[] = {0x05, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x28, 0xad, + 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatAckChunk chunk, + HeartbeatAckChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, chunk.info()); + + EXPECT_THAT( + info.info(), + ElementsAre(0xad, 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)); +} + +TEST(HeartbeatAckChunkTest, SerializeAndDeserialize) { + uint8_t info_data[] = {1, 2, 3, 4}; + Parameters parameters = + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build(); + HeartbeatAckChunk chunk(std::move(parameters)); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatAckChunk deserialized, + HeartbeatAckChunk::Parse(serialized)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, deserialized.info()); + + EXPECT_THAT(info.info(), ElementsAre(1, 2, 3, 4)); + + EXPECT_EQ(deserialized.ToString(), "HEARTBEAT-ACK"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/heartbeat_request_chunk.cc b/net/dcsctp/packet/chunk/heartbeat_request_chunk.cc new file mode 100644 index 0000000000..d759d6b16d --- /dev/null +++ b/net/dcsctp/packet/chunk/heartbeat_request_chunk.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.5 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 4 | Chunk Flags | Heartbeat Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Heartbeat Information TLV (Variable-Length) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int HeartbeatRequestChunk::kType; + +absl::optional HeartbeatRequestChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + absl::optional parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return HeartbeatRequestChunk(*std::move(parameters)); +} + +void HeartbeatRequestChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView parameters = parameters_.data(); + BoundedByteWriter writer = AllocateTLV(out, parameters.size()); + writer.CopyToVariableData(parameters); +} + +std::string HeartbeatRequestChunk::ToString() const { + return "HEARTBEAT"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/heartbeat_request_chunk.h b/net/dcsctp/packet/chunk/heartbeat_request_chunk.h new file mode 100644 index 0000000000..fe2ce19504 --- /dev/null +++ b/net/dcsctp/packet/chunk/heartbeat_request_chunk.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_REQUEST_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_REQUEST_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { +// https://tools.ietf.org/html/rfc4960#section-3.3.5 +struct HeartbeatRequestChunkConfig : ChunkConfig { + static constexpr int kType = 4; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class HeartbeatRequestChunk : public Chunk, + public TLVTrait { + public: + static constexpr int kType = HeartbeatRequestChunkConfig::kType; + + explicit HeartbeatRequestChunk(Parameters parameters) + : parameters_(std::move(parameters)) {} + + HeartbeatRequestChunk(HeartbeatRequestChunk&& other) = default; + HeartbeatRequestChunk& operator=(HeartbeatRequestChunk&& other) = default; + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + const Parameters& parameters() const { return parameters_; } + Parameters extract_parameters() && { return std::move(parameters_); } + absl::optional info() const { + return parameters_.get(); + } + + private: + Parameters parameters_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_REQUEST_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/heartbeat_request_chunk_test.cc b/net/dcsctp/packet/chunk/heartbeat_request_chunk_test.cc new file mode 100644 index 0000000000..94911fe28b --- /dev/null +++ b/net/dcsctp/packet/chunk/heartbeat_request_chunk_test.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
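[Reviewer note, not part of the patch: RFC 4960 section 8.3 has the receiver of a HEARTBEAT echo the Heartbeat Info parameter back unchanged in the HEARTBEAT-ACK, which with the classes above is a single move of the parameter list. MakeHeartbeatAck is an illustrative helper, not part of this change.]

#include <utility>

#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"

namespace dcsctp {

// Builds the HEARTBEAT-ACK for a received HEARTBEAT by moving the
// sender's parameters (including the Heartbeat Info) into the ack.
inline HeartbeatAckChunk MakeHeartbeatAck(HeartbeatRequestChunk request) {
  return HeartbeatAckChunk(std::move(request).extract_parameters());
}

}  // namespace dcsctp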
+ */ +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(HeartbeatRequestChunkTest, FromCapture) { + /* + HEARTBEAT chunk (Information: 40 bytes) + Chunk type: HEARTBEAT (4) + Chunk flags: 0x00 + Chunk length: 44 + Heartbeat info parameter (Information: 36 bytes) + Parameter type: Heartbeat info (0x0001) + Parameter length: 40 + Heartbeat information: ad2436603726070000000000000000007b10000001… + */ + + uint8_t data[] = {0x04, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x28, 0xad, + 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatRequestChunk chunk, + HeartbeatRequestChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, chunk.info()); + + EXPECT_THAT( + info.info(), + ElementsAre(0xad, 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)); +} + +TEST(HeartbeatRequestChunkTest, SerializeAndDeserialize) { + uint8_t info_data[] = {1, 2, 3, 4}; + Parameters parameters = + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build(); + HeartbeatRequestChunk chunk(std::move(parameters)); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatRequestChunk deserialized, + HeartbeatRequestChunk::Parse(serialized)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, deserialized.info()); + + EXPECT_THAT(info.info(), ElementsAre(1, 2, 3, 4)); + + EXPECT_EQ(deserialized.ToString(), "HEARTBEAT"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/idata_chunk.cc b/net/dcsctp/packet/chunk/idata_chunk.cc new file mode 100644 index 0000000000..378c527909 --- /dev/null +++ b/net/dcsctp/packet/chunk/idata_chunk.cc @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/idata_chunk.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 64 | Res |I|U|B|E| Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Message Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Protocol Identifier / Fragment Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / User Data / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int IDataChunk::kType; + +absl::optional IDataChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + uint8_t flags = reader->Load8<1>(); + TSN tsn(reader->Load32<4>()); + StreamID stream_identifier(reader->Load16<8>()); + MID message_id(reader->Load32<12>()); + uint32_t ppid_or_fsn = reader->Load32<16>(); + + Options options; + options.is_end = Data::IsEnd((flags & (1 << kFlagsBitEnd)) != 0); + options.is_beginning = + Data::IsBeginning((flags & (1 << kFlagsBitBeginning)) != 0); + options.is_unordered = IsUnordered((flags & (1 << kFlagsBitUnordered)) != 0); + options.immediate_ack = + ImmediateAckFlag((flags & (1 << kFlagsBitImmediateAck)) != 0); + + return IDataChunk(tsn, stream_identifier, message_id, + PPID(options.is_beginning ? ppid_or_fsn : 0), + FSN(options.is_beginning ? 0 : ppid_or_fsn), + std::vector(reader->variable_data().begin(), + reader->variable_data().end()), + options); +} + +void IDataChunk::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out, payload().size()); + + writer.Store8<1>( + (*options().is_end ? (1 << kFlagsBitEnd) : 0) | + (*options().is_beginning ? (1 << kFlagsBitBeginning) : 0) | + (*options().is_unordered ? (1 << kFlagsBitUnordered) : 0) | + (*options().immediate_ack ? (1 << kFlagsBitImmediateAck) : 0)); + writer.Store32<4>(*tsn()); + writer.Store16<8>(*stream_id()); + writer.Store32<12>(*message_id()); + writer.Store32<16>(options().is_beginning ? *ppid() : *fsn()); + writer.CopyToVariableData(payload()); +} + +std::string IDataChunk::ToString() const { + rtc::StringBuilder sb; + sb << "I-DATA, type=" << (options().is_unordered ? "unordered" : "ordered") + << "::" + << (*options().is_beginning && *options().is_end + ? "complete" + : *options().is_beginning ? "first" + : *options().is_end ? 
"last" : "middle") + << ", tsn=" << *tsn() << ", stream_id=" << *stream_id() + << ", message_id=" << *message_id(); + + if (*options().is_beginning) { + sb << ", ppid=" << *ppid(); + } else { + sb << ", fsn=" << *fsn(); + } + sb << ", length=" << payload().size(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/idata_chunk.h b/net/dcsctp/packet/chunk/idata_chunk.h new file mode 100644 index 0000000000..8cdf2a1fc4 --- /dev/null +++ b/net/dcsctp/packet/chunk/idata_chunk.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_IDATA_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_IDATA_CHUNK_H_ +#include +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.1 +struct IDataChunkConfig : ChunkConfig { + static constexpr int kType = 64; + static constexpr size_t kHeaderSize = 20; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class IDataChunk : public AnyDataChunk, public TLVTrait { + public: + static constexpr int kType = IDataChunkConfig::kType; + + // Exposed to allow the retransmission queue to make room for the correct + // header size. + static constexpr size_t kHeaderSize = IDataChunkConfig::kHeaderSize; + IDataChunk(TSN tsn, + StreamID stream_id, + MID message_id, + PPID ppid, + FSN fsn, + std::vector payload, + const Options& options) + : AnyDataChunk(tsn, + stream_id, + SSN(0), + message_id, + fsn, + ppid, + std::move(payload), + options) {} + + explicit IDataChunk(TSN tsn, Data&& data, bool immediate_ack) + : AnyDataChunk(tsn, std::move(data), immediate_ack) {} + + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_IDATA_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/idata_chunk_test.cc b/net/dcsctp/packet/chunk/idata_chunk_test.cc new file mode 100644 index 0000000000..fea492d71e --- /dev/null +++ b/net/dcsctp/packet/chunk/idata_chunk_test.cc @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/idata_chunk.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(IDataChunkTest, AtBeginningFromCapture) { + /* + I_DATA chunk(ordered, first segment, TSN: 2487901653, SID: 1, MID: 0, + payload length: 1180 bytes) + Chunk type: I_DATA (64) + Chunk flags: 0x02 + Chunk length: 1200 + Transmission sequence number: 2487901653 + Stream identifier: 0x0001 + Reserved: 0 + Message identifier: 0 + Payload protocol identifier: WebRTC Binary (53) + Reassembled Message in frame: 39 + */ + + uint8_t data[] = {0x40, 0x02, 0x00, 0x15, 0x94, 0x4a, 0x5d, 0xd5, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x35, 0x01, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk chunk, IDataChunk::Parse(data)); + EXPECT_EQ(*chunk.tsn(), 2487901653); + EXPECT_EQ(*chunk.stream_id(), 1); + EXPECT_EQ(*chunk.message_id(), 0u); + EXPECT_EQ(*chunk.ppid(), 53u); + EXPECT_EQ(*chunk.fsn(), 0u); // Not provided (so set to zero) +} + +TEST(IDataChunkTest, AtBeginningSerializeAndDeserialize) { + IDataChunk::Options options; + options.is_beginning = Data::IsBeginning(true); + IDataChunk chunk(TSN(123), StreamID(456), MID(789), PPID(53), FSN(0), {1}, + options); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk deserialized, + IDataChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.tsn(), 123u); + EXPECT_EQ(*deserialized.stream_id(), 456u); + EXPECT_EQ(*deserialized.message_id(), 789u); + EXPECT_EQ(*deserialized.ppid(), 53u); + EXPECT_EQ(*deserialized.fsn(), 0u); + + EXPECT_EQ(deserialized.ToString(), + "I-DATA, type=ordered::first, tsn=123, stream_id=456, " + "message_id=789, ppid=53, length=1"); +} + +TEST(IDataChunkTest, InMiddleFromCapture) { + /* + I_DATA chunk(ordered, last segment, TSN: 2487901706, SID: 3, MID: 1, + FSN: 8, payload length: 560 bytes) + Chunk type: I_DATA (64) + Chunk flags: 0x01 + Chunk length: 580 + Transmission sequence number: 2487901706 + Stream identifier: 0x0003 + Reserved: 0 + Message identifier: 1 + Fragment sequence number: 8 + Reassembled SCTP Fragments (10000 bytes, 9 fragments): + */ + + uint8_t data[] = {0x40, 0x01, 0x00, 0x15, 0x94, 0x4a, 0x5e, 0x0a, + 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk chunk, IDataChunk::Parse(data)); + EXPECT_EQ(*chunk.tsn(), 2487901706); + EXPECT_EQ(*chunk.stream_id(), 3u); + EXPECT_EQ(*chunk.message_id(), 1u); + EXPECT_EQ(*chunk.ppid(), 0u); // Not provided (so set to zero) + EXPECT_EQ(*chunk.fsn(), 8u); +} + +TEST(IDataChunkTest, InMiddleSerializeAndDeserialize) { + IDataChunk chunk(TSN(123), StreamID(456), MID(789), PPID(0), FSN(101112), + {1, 2, 3}, /*options=*/{}); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk deserialized, + IDataChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.tsn(), 123u); + EXPECT_EQ(*deserialized.stream_id(), 456u); + EXPECT_EQ(*deserialized.message_id(), 789u); + EXPECT_EQ(*deserialized.ppid(), 0u); + EXPECT_EQ(*deserialized.fsn(), 101112u); + EXPECT_THAT(deserialized.payload(), ElementsAre(1, 2, 3)); + + EXPECT_EQ(deserialized.ToString(), + "I-DATA, type=ordered::middle, tsn=123, stream_id=456, " + "message_id=789, fsn=101112, length=3"); +} + +} // namespace +} // 
namespace dcsctp diff --git a/net/dcsctp/packet/chunk/iforward_tsn_chunk.cc b/net/dcsctp/packet/chunk/iforward_tsn_chunk.cc new file mode 100644 index 0000000000..a647a8bf8a --- /dev/null +++ b/net/dcsctp/packet/chunk/iforward_tsn_chunk.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" + +#include +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.3.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 194 | Flags = 0x00 | Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | New Cumulative TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | Reserved |U| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Message Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | Reserved |U| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Message Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int IForwardTsnChunk::kType; + +absl::optional IForwardTsnChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + TSN new_cumulative_tsn(reader->Load32<4>()); + + size_t streams_skipped = + reader->variable_data_size() / kSkippedStreamBufferSize; + std::vector skipped_streams; + skipped_streams.reserve(streams_skipped); + size_t offset = 0; + for (size_t i = 0; i < streams_skipped; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(offset); + + StreamID stream_id(sub_reader.Load16<0>()); + IsUnordered unordered(sub_reader.Load8<3>() & 0x01); + MID message_id(sub_reader.Load32<4>()); + skipped_streams.emplace_back(unordered, stream_id, message_id); + offset += kSkippedStreamBufferSize; + } + RTC_DCHECK(offset == reader->variable_data_size()); + return IForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)); +} + +void IForwardTsnChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView skipped = skipped_streams(); + size_t variable_size = skipped.size() * kSkippedStreamBufferSize; + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*new_cumulative_tsn()); + size_t offset = 0; + for (size_t i = 0; i < skipped.size(); ++i) { + BoundedByteWriter sub_writer = + writer.sub_writer(offset); + + sub_writer.Store16<0>(*skipped[i].stream_id); + sub_writer.Store8<3>(skipped[i].unordered ? 
1 : 0); + sub_writer.Store32<4>(*skipped[i].message_id); + offset += kSkippedStreamBufferSize; + } + RTC_DCHECK(offset == variable_size); +} + +std::string IForwardTsnChunk::ToString() const { + rtc::StringBuilder sb; + sb << "I-FORWARD-TSN, new_cumulative_tsn=" << *new_cumulative_tsn(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/iforward_tsn_chunk.h b/net/dcsctp/packet/chunk/iforward_tsn_chunk.h new file mode 100644 index 0000000000..54d23f7a83 --- /dev/null +++ b/net/dcsctp/packet/chunk/iforward_tsn_chunk.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_IFORWARD_TSN_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_IFORWARD_TSN_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.3.1 +struct IForwardTsnChunkConfig : ChunkConfig { + static constexpr int kType = 194; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 8; +}; + +class IForwardTsnChunk : public AnyForwardTsnChunk, + public TLVTrait { + public: + static constexpr int kType = IForwardTsnChunkConfig::kType; + + IForwardTsnChunk(TSN new_cumulative_tsn, + std::vector skipped_streams) + : AnyForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + private: + static constexpr size_t kSkippedStreamBufferSize = 8; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_IFORWARD_TSN_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/iforward_tsn_chunk_test.cc b/net/dcsctp/packet/chunk/iforward_tsn_chunk_test.cc new file mode 100644 index 0000000000..6a89433be1 --- /dev/null +++ b/net/dcsctp/packet/chunk/iforward_tsn_chunk_test.cc @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(IForwardTsnChunkTest, FromCapture) { + /* + I_FORWARD_TSN chunk(Cumulative TSN: 3094631148) + Chunk type: I_FORWARD_TSN (194) + Chunk flags: 0x00 + Chunk length: 16 + New cumulative TSN: 3094631148 + Stream identifier: 1 + Flags: 0x0000 + Message identifier: 2 + */ + + uint8_t data[] = {0xc2, 0x00, 0x00, 0x10, 0xb8, 0x74, 0x52, 0xec, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}; + + ASSERT_HAS_VALUE_AND_ASSIGN(IForwardTsnChunk chunk, + IForwardTsnChunk::Parse(data)); + EXPECT_EQ(*chunk.new_cumulative_tsn(), 3094631148u); + EXPECT_THAT(chunk.skipped_streams(), + ElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(2)))); +} + +TEST(IForwardTsnChunkTest, SerializeAndDeserialize) { + IForwardTsnChunk chunk( + TSN(123), {IForwardTsnChunk::SkippedStream(IsUnordered(false), + StreamID(1), MID(23)), + IForwardTsnChunk::SkippedStream(IsUnordered(true), + StreamID(42), MID(99))}); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(IForwardTsnChunk deserialized, + IForwardTsnChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.new_cumulative_tsn(), 123u); + EXPECT_THAT(deserialized.skipped_streams(), + ElementsAre(IForwardTsnChunk::SkippedStream(IsUnordered(false), + StreamID(1), MID(23)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(42), MID(99)))); + + EXPECT_EQ(deserialized.ToString(), "I-FORWARD-TSN, new_cumulative_tsn=123"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/init_ack_chunk.cc b/net/dcsctp/packet/chunk/init_ack_chunk.cc new file mode 100644 index 0000000000..c7ef9da1f1 --- /dev/null +++ b/net/dcsctp/packet/chunk/init_ack_chunk.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 2 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initiate Tag | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Advertised Receiver Window Credit | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of Outbound Streams | Number of Inbound Streams | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initial TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Optional/Variable-Length Parameters / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InitAckChunk::kType; + +absl::optional InitAckChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + VerificationTag initiate_tag(reader->Load32<4>()); + uint32_t a_rwnd = reader->Load32<8>(); + uint16_t nbr_outbound_streams = reader->Load16<12>(); + uint16_t nbr_inbound_streams = reader->Load16<14>(); + TSN initial_tsn(reader->Load32<16>()); + absl::optional parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return InitAckChunk(initiate_tag, a_rwnd, nbr_outbound_streams, + nbr_inbound_streams, initial_tsn, *std::move(parameters)); +} + +void InitAckChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView parameters = parameters_.data(); + BoundedByteWriter writer = AllocateTLV(out, parameters.size()); + + writer.Store32<4>(*initiate_tag_); + writer.Store32<8>(a_rwnd_); + writer.Store16<12>(nbr_outbound_streams_); + writer.Store16<14>(nbr_inbound_streams_); + writer.Store32<16>(*initial_tsn_); + writer.CopyToVariableData(parameters); +} + +std::string InitAckChunk::ToString() const { + return rtc::StringFormat("INIT_ACK, initiate_tag=0x%0x, initial_tsn=%u", + *initiate_tag(), *initial_tsn()); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/init_ack_chunk.h b/net/dcsctp/packet/chunk/init_ack_chunk.h new file mode 100644 index 0000000000..6fcf64b2eb --- /dev/null +++ b/net/dcsctp/packet/chunk/init_ack_chunk.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_INIT_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_INIT_ACK_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3 +struct InitAckChunkConfig : ChunkConfig { + static constexpr int kType = 2; + static constexpr size_t kHeaderSize = 20; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class InitAckChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = InitAckChunkConfig::kType; + + InitAckChunk(VerificationTag initiate_tag, + uint32_t a_rwnd, + uint16_t nbr_outbound_streams, + uint16_t nbr_inbound_streams, + TSN initial_tsn, + Parameters parameters) + : initiate_tag_(initiate_tag), + a_rwnd_(a_rwnd), + nbr_outbound_streams_(nbr_outbound_streams), + nbr_inbound_streams_(nbr_inbound_streams), + initial_tsn_(initial_tsn), + parameters_(std::move(parameters)) {} + + InitAckChunk(InitAckChunk&& other) = default; + InitAckChunk& operator=(InitAckChunk&& other) = default; + + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + VerificationTag initiate_tag() const { return initiate_tag_; } + uint32_t a_rwnd() const { return a_rwnd_; } + uint16_t nbr_outbound_streams() const { return nbr_outbound_streams_; } + uint16_t nbr_inbound_streams() const { return nbr_inbound_streams_; } + TSN initial_tsn() const { return initial_tsn_; } + const Parameters& parameters() const { return parameters_; } + + private: + VerificationTag initiate_tag_; + uint32_t a_rwnd_; + uint16_t nbr_outbound_streams_; + uint16_t nbr_inbound_streams_; + TSN initial_tsn_; + Parameters parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_INIT_ACK_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/init_ack_chunk_test.cc b/net/dcsctp/packet/chunk/init_ack_chunk_test.cc new file mode 100644 index 0000000000..184ade747d --- /dev/null +++ b/net/dcsctp/packet/chunk/init_ack_chunk_test.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(InitAckChunkTest, FromCapture) { + /* + INIT_ACK chunk (Outbound streams: 1000, inbound streams: 2048) + Chunk type: INIT_ACK (2) + Chunk flags: 0x00 + Chunk length: 292 + Initiate tag: 0x579c2f98 + Advertised receiver window credit (a_rwnd): 131072 + Number of outbound streams: 1000 + Number of inbound streams: 2048 + Initial TSN: 1670811335 + Forward TSN supported parameter + Parameter type: Forward TSN supported (0xc000) + Parameter length: 4 + Supported Extensions parameter (Supported types: FORWARD_TSN, RE_CONFIG) + Parameter type: Supported Extensions (0x8008) + Parameter length: 6 + Supported chunk type: FORWARD_TSN (192) + Supported chunk type: RE_CONFIG (130) + Parameter padding: 0000 + State cookie parameter (Cookie length: 256 bytes) + Parameter type: State cookie (0x0007) + Parameter length: 260 + State cookie: 4b414d452d42534420312e310000000096b8386000000000… + */ + + uint8_t data[] = { + 0x02, 0x00, 0x01, 0x24, 0x57, 0x9c, 0x2f, 0x98, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x08, 0x00, 0x63, 0x96, 0x8e, 0xc7, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x06, 0xc0, 0x82, 0x00, 0x00, 0x00, 0x07, 0x01, 0x04, + 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x96, 0xb8, 0x38, 0x60, 0x00, 0x00, 0x00, 0x00, + 0x52, 0x5a, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, + 0xb5, 0xaa, 0x19, 0xea, 0x31, 0xef, 0xa4, 0x2b, 0x90, 0x16, 0x7a, 0xde, + 0x57, 0x9c, 0x2f, 0x98, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x5a, 0xde, 0x7a, 0x16, 0x90, + 0x00, 0x02, 0x00, 0x00, 0x03, 0xe8, 0x03, 0xe8, 0x25, 0x0d, 0x37, 0xe8, + 0x80, 0x00, 0x00, 0x04, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, + 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, + 0xab, 0x31, 0x44, 0x62, 0x12, 0x1a, 0x15, 0x13, 0xfd, 0x5a, 0x5f, 0x69, + 0xef, 0xaa, 0x06, 0xe9, 0xab, 0xd7, 0x48, 0xcc, 0x3b, 0xd1, 0x4b, 0x60, + 0xed, 0x7f, 0xa6, 0x44, 0xce, 0x4d, 0xd2, 0xad, 0x80, 0x04, 0x00, 0x06, + 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, + 0x02, 0x00, 0x01, 0x24, 0x57, 0x9c, 0x2f, 0x98, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x08, 0x00, 0x63, 0x96, 0x8e, 0xc7, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x06, 0xc0, 0x82, 0x00, 0x00, 0x51, 0x95, 0x01, 0x88, + 0x0d, 0x80, 0x7b, 0x19, 0xe7, 0xf9, 0xc6, 0x18, 0x5c, 0x4a, 0xbf, 0x39, + 0x32, 0xe5, 0x63, 0x8e}; + + ASSERT_HAS_VALUE_AND_ASSIGN(InitAckChunk chunk, InitAckChunk::Parse(data)); + + EXPECT_EQ(chunk.initiate_tag(), VerificationTag(0x579c2f98u)); + EXPECT_EQ(chunk.a_rwnd(), 131072u); + EXPECT_EQ(chunk.nbr_outbound_streams(), 1000u); + EXPECT_EQ(chunk.nbr_inbound_streams(), 
2048u); + EXPECT_EQ(chunk.initial_tsn(), TSN(1670811335u)); + EXPECT_TRUE( + chunk.parameters().get().has_value()); + EXPECT_TRUE( + chunk.parameters().get().has_value()); + EXPECT_TRUE(chunk.parameters().get().has_value()); +} + +TEST(InitAckChunkTest, SerializeAndDeserialize) { + uint8_t state_cookie[] = {1, 2, 3, 4, 5}; + Parameters parameters = + Parameters::Builder().Add(StateCookieParameter(state_cookie)).Build(); + InitAckChunk chunk(VerificationTag(123), /*a_rwnd=*/456, + /*nbr_outbound_streams=*/65535, + /*nbr_inbound_streams=*/65534, /*initial_tsn=*/TSN(789), + /*parameters=*/std::move(parameters)); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(InitAckChunk deserialized, + InitAckChunk::Parse(serialized)); + + EXPECT_EQ(chunk.initiate_tag(), VerificationTag(123u)); + EXPECT_EQ(chunk.a_rwnd(), 456u); + EXPECT_EQ(chunk.nbr_outbound_streams(), 65535u); + EXPECT_EQ(chunk.nbr_inbound_streams(), 65534u); + EXPECT_EQ(chunk.initial_tsn(), TSN(789u)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + StateCookieParameter cookie, + deserialized.parameters().get()); + EXPECT_THAT(cookie.data(), ElementsAre(1, 2, 3, 4, 5)); + EXPECT_EQ(deserialized.ToString(), + "INIT_ACK, initiate_tag=0x7b, initial_tsn=789"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/init_chunk.cc b/net/dcsctp/packet/chunk/init_chunk.cc new file mode 100644 index 0000000000..8030107072 --- /dev/null +++ b/net/dcsctp/packet/chunk/init_chunk.cc @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/init_chunk.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 1 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initiate Tag | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Advertised Receiver Window Credit (a_rwnd) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of Outbound Streams | Number of Inbound Streams | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initial TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Optional/Variable-Length Parameters / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InitChunk::kType; + +absl::optional InitChunk::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + VerificationTag initiate_tag(reader->Load32<4>()); + uint32_t a_rwnd = reader->Load32<8>(); + uint16_t nbr_outbound_streams = reader->Load16<12>(); + uint16_t nbr_inbound_streams = reader->Load16<14>(); + TSN initial_tsn(reader->Load32<16>()); + + absl::optional parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return InitChunk(initiate_tag, a_rwnd, nbr_outbound_streams, + nbr_inbound_streams, initial_tsn, *std::move(parameters)); +} + +void InitChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView parameters = parameters_.data(); + BoundedByteWriter writer = AllocateTLV(out, parameters.size()); + + writer.Store32<4>(*initiate_tag_); + writer.Store32<8>(a_rwnd_); + writer.Store16<12>(nbr_outbound_streams_); + writer.Store16<14>(nbr_inbound_streams_); + writer.Store32<16>(*initial_tsn_); + + writer.CopyToVariableData(parameters); +} + +std::string InitChunk::ToString() const { + return rtc::StringFormat("INIT, initiate_tag=0x%0x, initial_tsn=%u", + *initiate_tag(), *initial_tsn()); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/init_chunk.h b/net/dcsctp/packet/chunk/init_chunk.h new file mode 100644 index 0000000000..38f9994caa --- /dev/null +++ b/net/dcsctp/packet/chunk/init_chunk.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_INIT_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_INIT_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 +struct InitChunkConfig : ChunkConfig { + static constexpr int kType = 1; + static constexpr size_t kHeaderSize = 20; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class InitChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = InitChunkConfig::kType; + + InitChunk(VerificationTag initiate_tag, + uint32_t a_rwnd, + uint16_t nbr_outbound_streams, + uint16_t nbr_inbound_streams, + TSN initial_tsn, + Parameters parameters) + : initiate_tag_(initiate_tag), + a_rwnd_(a_rwnd), + nbr_outbound_streams_(nbr_outbound_streams), + nbr_inbound_streams_(nbr_inbound_streams), + initial_tsn_(initial_tsn), + parameters_(std::move(parameters)) {} + + InitChunk(InitChunk&& other) = default; + InitChunk& operator=(InitChunk&& other) = default; + + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + VerificationTag initiate_tag() const { return initiate_tag_; } + uint32_t a_rwnd() const { return a_rwnd_; } + uint16_t nbr_outbound_streams() const { return nbr_outbound_streams_; } + uint16_t nbr_inbound_streams() const { return nbr_inbound_streams_; } + TSN initial_tsn() const { return initial_tsn_; } + const Parameters& parameters() const { return parameters_; } + + private: + VerificationTag initiate_tag_; + uint32_t a_rwnd_; + uint16_t nbr_outbound_streams_; + uint16_t nbr_inbound_streams_; + TSN initial_tsn_; + Parameters parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_INIT_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/init_chunk_test.cc b/net/dcsctp/packet/chunk/init_chunk_test.cc new file mode 100644 index 0000000000..bd36d6fdf8 --- /dev/null +++ b/net/dcsctp/packet/chunk/init_chunk_test.cc @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/init_chunk.h" + +#include + +#include +#include + +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(InitChunkTest, FromCapture) { + /* + INIT chunk (Outbound streams: 1000, inbound streams: 1000) + Chunk type: INIT (1) + Chunk flags: 0x00 + Chunk length: 90 + Initiate tag: 0xde7a1690 + Advertised receiver window credit (a_rwnd): 131072 + Number of outbound streams: 1000 + Number of inbound streams: 1000 + Initial TSN: 621623272 + ECN parameter + Parameter type: ECN (0x8000) + Parameter length: 4 + Forward TSN supported parameter + Parameter type: Forward TSN supported (0xc000) + Parameter length: 4 + Supported Extensions parameter (Supported types: FORWARD_TSN, AUTH, + ASCONF, ASCONF_ACK, RE_CONFIG) Parameter type: Supported Extensions (0x8008) + Parameter length: 9 + Supported chunk type: FORWARD_TSN (192) + Supported chunk type: AUTH (15) + Supported chunk type: ASCONF (193) + Supported chunk type: ASCONF_ACK (128) + Supported chunk type: RE_CONFIG (130) + Parameter padding: 000000 + Random parameter + Parameter type: Random (0x8002) + Parameter length: 36 + Random number: ab314462121a1513fd5a5f69efaa06e9abd748cc3bd14b60… + Requested HMAC Algorithm parameter (Supported HMACs: SHA-1) + Parameter type: Requested HMAC Algorithm (0x8004) + Parameter length: 6 + HMAC identifier: SHA-1 (1) + Parameter padding: 0000 + Authenticated Chunk list parameter (Chunk types to be authenticated: + ASCONF_ACK, ASCONF) Parameter type: Authenticated Chunk list (0x8003) + Parameter length: 6 + Chunk type: ASCONF_ACK (128) + Chunk type: ASCONF (193) + */ + + uint8_t data[] = { + 0x01, 0x00, 0x00, 0x5a, 0xde, 0x7a, 0x16, 0x90, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x03, 0xe8, 0x25, 0x0d, 0x37, 0xe8, 0x80, 0x00, 0x00, 0x04, + 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, + 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xab, 0x31, 0x44, 0x62, + 0x12, 0x1a, 0x15, 0x13, 0xfd, 0x5a, 0x5f, 0x69, 0xef, 0xaa, 0x06, 0xe9, + 0xab, 0xd7, 0x48, 0xcc, 0x3b, 0xd1, 0x4b, 0x60, 0xed, 0x7f, 0xa6, 0x44, + 0xce, 0x4d, 0xd2, 0xad, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, + 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk chunk, InitChunk::Parse(data)); + + EXPECT_EQ(chunk.initiate_tag(), VerificationTag(0xde7a1690)); + EXPECT_EQ(chunk.a_rwnd(), 131072u); + EXPECT_EQ(chunk.nbr_outbound_streams(), 1000u); + EXPECT_EQ(chunk.nbr_inbound_streams(), 1000u); + EXPECT_EQ(chunk.initial_tsn(), TSN(621623272u)); + EXPECT_TRUE( + chunk.parameters().get().has_value()); + EXPECT_TRUE( + chunk.parameters().get().has_value()); +} + +TEST(InitChunkTest, SerializeAndDeserialize) { + InitChunk chunk(VerificationTag(123), /*a_rwnd=*/456, + /*nbr_outbound_streams=*/65535, + /*nbr_inbound_streams=*/65534, /*initial_tsn=*/TSN(789), + /*parameters=*/Parameters::Builder().Build()); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk deserialized, + InitChunk::Parse(serialized)); + + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123u)); + EXPECT_EQ(deserialized.a_rwnd(), 456u); + EXPECT_EQ(deserialized.nbr_outbound_streams(), 65535u); + EXPECT_EQ(deserialized.nbr_inbound_streams(), 65534u); + 
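+  // The round trip should preserve every fixed INIT header field; the
+  // initial TSN and the rendered string form are checked below.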
EXPECT_EQ(deserialized.initial_tsn(), TSN(789u)); + EXPECT_EQ(deserialized.ToString(), + "INIT, initiate_tag=0x7b, initial_tsn=789"); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/reconfig_chunk.cc b/net/dcsctp/packet/chunk/reconfig_chunk.cc new file mode 100644 index 0000000000..f39f3b619f --- /dev/null +++ b/net/dcsctp/packet/chunk/reconfig_chunk.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-3.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 130 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Re-configuration Parameter / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Re-configuration Parameter (optional) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ReConfigChunk::kType; + +absl::optional ReConfigChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + absl::optional parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + + return ReConfigChunk(*std::move(parameters)); +} + +void ReConfigChunk::SerializeTo(std::vector& out) const { + rtc::ArrayView parameters = parameters_.data(); + BoundedByteWriter writer = AllocateTLV(out, parameters.size()); + writer.CopyToVariableData(parameters); +} + +std::string ReConfigChunk::ToString() const { + return "RE-CONFIG"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/reconfig_chunk.h b/net/dcsctp/packet/chunk/reconfig_chunk.h new file mode 100644 index 0000000000..9d2539a515 --- /dev/null +++ b/net/dcsctp/packet/chunk/reconfig_chunk.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_RECONFIG_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_RECONFIG_CHUNK_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-3.1 +struct ReConfigChunkConfig : ChunkConfig { + static constexpr int kType = 130; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class ReConfigChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = ReConfigChunkConfig::kType; + + explicit ReConfigChunk(Parameters parameters) + : parameters_(std::move(parameters)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + const Parameters& parameters() const { return parameters_; } + Parameters extract_parameters() { return std::move(parameters_); } + + private: + Parameters parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_RECONFIG_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/reconfig_chunk_test.cc b/net/dcsctp/packet/chunk/reconfig_chunk_test.cc new file mode 100644 index 0000000000..dbf40ff8c0 --- /dev/null +++ b/net/dcsctp/packet/chunk/reconfig_chunk_test.cc @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; + +TEST(ReConfigChunkTest, FromCapture) { + /* + RE_CONFIG chunk + Chunk type: RE_CONFIG (130) + Chunk flags: 0x00 + Chunk length: 22 + Outgoing SSN reset request parameter + Parameter type: Outgoing SSN reset request (0x000d) + Parameter length: 18 + Re-configuration request sequence number: 2270550051 + Re-configuration response sequence number: 1905748638 + Senders last assigned TSN: 2270550066 + Stream Identifier: 6 + Chunk padding: 0000 + */ + + uint8_t data[] = {0x82, 0x00, 0x00, 0x16, 0x00, 0x0d, 0x00, 0x12, + 0x87, 0x55, 0xd8, 0x23, 0x71, 0x97, 0x6a, 0x9e, + 0x87, 0x55, 0xd8, 0x32, 0x00, 0x06, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk chunk, ReConfigChunk::Parse(data)); + + const Parameters& parameters = chunk.parameters(); + EXPECT_THAT(parameters.descriptors(), SizeIs(1)); + ParameterDescriptor desc = parameters.descriptors()[0]; + ASSERT_EQ(desc.type, OutgoingSSNResetRequestParameter::kType); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + OutgoingSSNResetRequestParameter::Parse(desc.data)); + EXPECT_EQ(*req.request_sequence_number(), 2270550051u); + EXPECT_EQ(*req.response_sequence_number(), 1905748638u); + EXPECT_EQ(*req.sender_last_assigned_tsn(), 2270550066u); + EXPECT_THAT(req.stream_ids(), ElementsAre(StreamID(6))); +} + +TEST(ReConfigChunkTest, SerializeAndDeserialize) { + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(123), ReconfigRequestSN(456), TSN(789), + {StreamID(42), StreamID(43)})); + + ReConfigChunk chunk(params_builder.Build()); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk deserialized, + ReConfigChunk::Parse(serialized)); + + const Parameters& parameters = deserialized.parameters(); + EXPECT_THAT(parameters.descriptors(), SizeIs(1)); + ParameterDescriptor desc = parameters.descriptors()[0]; + ASSERT_EQ(desc.type, OutgoingSSNResetRequestParameter::kType); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + OutgoingSSNResetRequestParameter::Parse(desc.data)); + EXPECT_EQ(*req.request_sequence_number(), 123u); + EXPECT_EQ(*req.response_sequence_number(), 456u); + EXPECT_EQ(*req.sender_last_assigned_tsn(), 789u); + EXPECT_THAT(req.stream_ids(), ElementsAre(StreamID(42), StreamID(43))); + + EXPECT_EQ(deserialized.ToString(), "RE-CONFIG"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/sack_chunk.cc b/net/dcsctp/packet/chunk/sack_chunk.cc new file mode 100644 index 0000000000..d80e430082 --- /dev/null +++ b/net/dcsctp/packet/chunk/sack_chunk.cc @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/sack_chunk.h" + +#include + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.4 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 3 |Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cumulative TSN Ack | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Advertised Receiver Window Credit (a_rwnd) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of Gap Ack Blocks = N | Number of Duplicate TSNs = X | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Gap Ack Block #1 Start | Gap Ack Block #1 End | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / / +// \ ... \ +// / / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Gap Ack Block #N Start | Gap Ack Block #N End | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Duplicate TSN 1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / / +// \ ... \ +// / / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Duplicate TSN X | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int SackChunk::kType; + +absl::optional SackChunk::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + TSN tsn_ack(reader->Load32<4>()); + uint32_t a_rwnd = reader->Load32<8>(); + uint16_t nbr_of_gap_blocks = reader->Load16<12>(); + uint16_t nbr_of_dup_tsns = reader->Load16<14>(); + + if (reader->variable_data_size() != nbr_of_gap_blocks * kGapAckBlockSize + + nbr_of_dup_tsns * kDupTsnBlockSize) { + RTC_DLOG(LS_WARNING) << "Invalid number of gap blocks or duplicate TSNs"; + return absl::nullopt; + } + + std::vector gap_ack_blocks; + gap_ack_blocks.reserve(nbr_of_gap_blocks); + size_t offset = 0; + for (int i = 0; i < nbr_of_gap_blocks; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(offset); + + uint16_t start = sub_reader.Load16<0>(); + uint16_t end = sub_reader.Load16<2>(); + gap_ack_blocks.emplace_back(start, end); + offset += kGapAckBlockSize; + } + + std::set duplicate_tsns; + for (int i = 0; i < nbr_of_dup_tsns; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(offset); + + duplicate_tsns.insert(TSN(sub_reader.Load32<0>())); + offset += kDupTsnBlockSize; + } + RTC_DCHECK(offset == reader->variable_data_size()); + + return SackChunk(tsn_ack, a_rwnd, gap_ack_blocks, duplicate_tsns); +} + +void SackChunk::SerializeTo(std::vector& out) const { + int nbr_of_gap_blocks = gap_ack_blocks_.size(); + int nbr_of_dup_tsns = duplicate_tsns_.size(); + size_t variable_size = + nbr_of_gap_blocks * kGapAckBlockSize + nbr_of_dup_tsns * kDupTsnBlockSize; + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + 
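+  // Fixed part of the SACK chunk, following the RFC 4960 layout above:
+  // the Cumulative TSN Ack at offset 4, a_rwnd at offset 8, and the
+  // 16-bit counts of gap ack blocks and duplicate TSNs at offsets 12
+  // and 14. As a size check: with two gap ack blocks and one duplicate
+  // TSN the chunk length is 16 + 2 * 4 + 1 * 4 = 28 bytes, which matches
+  // the captured SACK in sack_chunk_test.cc.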
writer.Store32<4>(*cumulative_tsn_ack_); + writer.Store32<8>(a_rwnd_); + writer.Store16<12>(nbr_of_gap_blocks); + writer.Store16<14>(nbr_of_dup_tsns); + + size_t offset = 0; + for (int i = 0; i < nbr_of_gap_blocks; ++i) { + BoundedByteWriter sub_writer = + writer.sub_writer(offset); + + sub_writer.Store16<0>(gap_ack_blocks_[i].start); + sub_writer.Store16<2>(gap_ack_blocks_[i].end); + offset += kGapAckBlockSize; + } + + for (TSN tsn : duplicate_tsns_) { + BoundedByteWriter sub_writer = + writer.sub_writer(offset); + + sub_writer.Store32<0>(*tsn); + offset += kDupTsnBlockSize; + } + + RTC_DCHECK(offset == variable_size); +} + +std::string SackChunk::ToString() const { + rtc::StringBuilder sb; + sb << "SACK, cum_ack_tsn=" << *cumulative_tsn_ack() + << ", a_rwnd=" << a_rwnd(); + for (const GapAckBlock& gap : gap_ack_blocks_) { + uint32_t first = *cumulative_tsn_ack_ + gap.start; + uint32_t last = *cumulative_tsn_ack_ + gap.end; + sb << ", gap=" << first << "--" << last; + } + if (!duplicate_tsns_.empty()) { + sb << ", dup_tsns=" + << StrJoin(duplicate_tsns(), ",", + [](rtc::StringBuilder& sb, TSN tsn) { sb << *tsn; }); + } + + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/sack_chunk.h b/net/dcsctp/packet/chunk/sack_chunk.h new file mode 100644 index 0000000000..e6758fa332 --- /dev/null +++ b/net/dcsctp/packet/chunk/sack_chunk.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SACK_CHUNK_H_ +#include + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.4 +struct SackChunkConfig : ChunkConfig { + static constexpr int kType = 3; + static constexpr size_t kHeaderSize = 16; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class SackChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = SackChunkConfig::kType; + + struct GapAckBlock { + GapAckBlock(uint16_t start, uint16_t end) : start(start), end(end) {} + + uint16_t start; + uint16_t end; + + bool operator==(const GapAckBlock& other) const { + return start == other.start && end == other.end; + } + }; + + SackChunk(TSN cumulative_tsn_ack, + uint32_t a_rwnd, + std::vector gap_ack_blocks, + std::set duplicate_tsns) + : cumulative_tsn_ack_(cumulative_tsn_ack), + a_rwnd_(a_rwnd), + gap_ack_blocks_(std::move(gap_ack_blocks)), + duplicate_tsns_(std::move(duplicate_tsns)) {} + static absl::optional Parse(rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + TSN cumulative_tsn_ack() const { return cumulative_tsn_ack_; } + uint32_t a_rwnd() const { return a_rwnd_; } + rtc::ArrayView gap_ack_blocks() const { + return gap_ack_blocks_; + } + const std::set& duplicate_tsns() const { return duplicate_tsns_; } + + private: + static constexpr size_t kGapAckBlockSize = 4; + static constexpr size_t kDupTsnBlockSize = 4; + + const TSN cumulative_tsn_ack_; + const uint32_t a_rwnd_; + std::vector gap_ack_blocks_; + std::set duplicate_tsns_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SACK_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/sack_chunk_test.cc b/net/dcsctp/packet/chunk/sack_chunk_test.cc new file mode 100644 index 0000000000..9122945308 --- /dev/null +++ b/net/dcsctp/packet/chunk/sack_chunk_test.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/sack_chunk.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(SackChunkTest, FromCapture) { + /* + SACK chunk (Cumulative TSN: 916312075, a_rwnd: 126323, + gaps: 2, duplicate TSNs: 1) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 28 + Cumulative TSN ACK: 916312075 + Advertised receiver window credit (a_rwnd): 126323 + Number of gap acknowledgement blocks: 2 + Number of duplicated TSNs: 1 + Gap Acknowledgement for TSN 916312077 to 916312081 + Gap Acknowledgement for TSN 916312083 to 916312083 + [Number of TSNs in gap acknowledgement blocks: 6] + Duplicate TSN: 916312081 + + */ + + uint8_t data[] = {0x03, 0x00, 0x00, 0x1c, 0x36, 0x9d, 0xd0, 0x0b, 0x00, 0x01, + 0xed, 0x73, 0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0x00, 0x06, + 0x00, 0x08, 0x00, 0x08, 0x36, 0x9d, 0xd0, 0x11}; + + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk chunk, SackChunk::Parse(data)); + + TSN cum_ack_tsn(916312075); + EXPECT_EQ(chunk.cumulative_tsn_ack(), cum_ack_tsn); + EXPECT_EQ(chunk.a_rwnd(), 126323u); + EXPECT_THAT( + chunk.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock( + static_cast(916312077 - *cum_ack_tsn), + static_cast(916312081 - *cum_ack_tsn)), + SackChunk::GapAckBlock( + static_cast(916312083 - *cum_ack_tsn), + static_cast(916312083 - *cum_ack_tsn)))); + EXPECT_THAT(chunk.duplicate_tsns(), ElementsAre(TSN(916312081))); +} + +TEST(SackChunkTest, SerializeAndDeserialize) { + SackChunk chunk(TSN(123), /*a_rwnd=*/456, {SackChunk::GapAckBlock(2, 3)}, + {TSN(1), TSN(2), TSN(3)}); + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk deserialized, + SackChunk::Parse(serialized)); + + EXPECT_EQ(*deserialized.cumulative_tsn_ack(), 123u); + EXPECT_EQ(deserialized.a_rwnd(), 456u); + EXPECT_THAT(deserialized.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3))); + EXPECT_THAT(deserialized.duplicate_tsns(), + ElementsAre(TSN(1), TSN(2), TSN(3))); + + EXPECT_EQ(deserialized.ToString(), + "SACK, cum_ack_tsn=123, a_rwnd=456, gap=125--126, dup_tsns=1,2,3"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/shutdown_ack_chunk.cc b/net/dcsctp/packet/chunk/shutdown_ack_chunk.cc new file mode 100644 index 0000000000..d42aceead4 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_ack_chunk.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" + +#include + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.9 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 8 |Chunk Flags | Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ShutdownAckChunk::kType; + +absl::optional ShutdownAckChunk::Parse( + rtc::ArrayView data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return ShutdownAckChunk(); +} + +void ShutdownAckChunk::SerializeTo(std::vector& out) const { + AllocateTLV(out); +} + +std::string ShutdownAckChunk::ToString() const { + return "SHUTDOWN-ACK"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/shutdown_ack_chunk.h b/net/dcsctp/packet/chunk/shutdown_ack_chunk.h new file mode 100644 index 0000000000..29c1a98be6 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_ack_chunk.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_ACK_CHUNK_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.9 +struct ShutdownAckChunkConfig : ChunkConfig { + static constexpr int kType = 8; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ShutdownAckChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = ShutdownAckChunkConfig::kType; + + ShutdownAckChunk() {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_ACK_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/shutdown_ack_chunk_test.cc b/net/dcsctp/packet/chunk/shutdown_ack_chunk_test.cc new file mode 100644 index 0000000000..ef04ea9892 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_ack_chunk_test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" + +#include + +#include + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(ShutdownAckChunkTest, FromCapture) { + /* + SHUTDOWN_ACK chunk + Chunk type: SHUTDOWN_ACK (8) + Chunk flags: 0x00 + Chunk length: 4 + */ + + uint8_t data[] = {0x08, 0x00, 0x00, 0x04}; + + EXPECT_TRUE(ShutdownAckChunk::Parse(data).has_value()); +} + +TEST(ShutdownAckChunkTest, SerializeAndDeserialize) { + ShutdownAckChunk chunk; + + std::vector serialized; + chunk.SerializeTo(serialized); + + EXPECT_TRUE(ShutdownAckChunk::Parse(serialized).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/shutdown_chunk.cc b/net/dcsctp/packet/chunk/shutdown_chunk.cc new file mode 100644 index 0000000000..59f806f7f7 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_chunk.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.8 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 7 | Chunk Flags | Length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cumulative TSN Ack | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ShutdownChunk::kType; + +absl::optional ShutdownChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + TSN cumulative_tsn_ack(reader->Load32<4>()); + return ShutdownChunk(cumulative_tsn_ack); +} + +void ShutdownChunk::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store32<4>(*cumulative_tsn_ack_); +} + +std::string ShutdownChunk::ToString() const { + return "SHUTDOWN"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/shutdown_chunk.h b/net/dcsctp/packet/chunk/shutdown_chunk.h new file mode 100644 index 0000000000..8148cca286 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_chunk.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_CHUNK_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.8 +struct ShutdownChunkConfig : ChunkConfig { + static constexpr int kType = 7; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ShutdownChunk : public Chunk, public TLVTrait { + public: + static constexpr int kType = ShutdownChunkConfig::kType; + + explicit ShutdownChunk(TSN cumulative_tsn_ack) + : cumulative_tsn_ack_(cumulative_tsn_ack) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + TSN cumulative_tsn_ack() const { return cumulative_tsn_ack_; } + + private: + TSN cumulative_tsn_ack_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/shutdown_chunk_test.cc b/net/dcsctp/packet/chunk/shutdown_chunk_test.cc new file mode 100644 index 0000000000..16d147ca83 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_chunk_test.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { +TEST(ShutdownChunkTest, FromCapture) { + /* + SHUTDOWN chunk (Cumulative TSN ack: 101831101) + Chunk type: SHUTDOWN (7) + Chunk flags: 0x00 + Chunk length: 8 + Cumulative TSN Ack: 101831101 + */ + + uint8_t data[] = {0x07, 0x00, 0x00, 0x08, 0x06, 0x11, 0xd1, 0xbd}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ShutdownChunk chunk, ShutdownChunk::Parse(data)); + EXPECT_EQ(chunk.cumulative_tsn_ack(), TSN(101831101u)); +} + +TEST(ShutdownChunkTest, SerializeAndDeserialize) { + ShutdownChunk chunk(TSN(12345678)); + + std::vector serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ShutdownChunk deserialized, + ShutdownChunk::Parse(serialized)); + + EXPECT_EQ(deserialized.cumulative_tsn_ack(), TSN(12345678u)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/shutdown_complete_chunk.cc b/net/dcsctp/packet/chunk/shutdown_complete_chunk.cc new file mode 100644 index 0000000000..3f54857437 --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_complete_chunk.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.13 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 14 |Reserved |T| Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ShutdownCompleteChunk::kType; + +absl::optional ShutdownCompleteChunk::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + uint8_t flags = reader->Load8<1>(); + bool tag_reflected = (flags & (1 << kFlagsBitT)) != 0; + return ShutdownCompleteChunk(tag_reflected); +} + +void ShutdownCompleteChunk::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store8<1>(tag_reflected_ ? (1 << kFlagsBitT) : 0); +} + +std::string ShutdownCompleteChunk::ToString() const { + return "SHUTDOWN-COMPLETE"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk/shutdown_complete_chunk.h b/net/dcsctp/packet/chunk/shutdown_complete_chunk.h new file mode 100644 index 0000000000..46d28e88dc --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_complete_chunk.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_COMPLETE_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_COMPLETE_CHUNK_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.13 +struct ShutdownCompleteChunkConfig : ChunkConfig { + static constexpr int kType = 14; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ShutdownCompleteChunk : public Chunk, + public TLVTrait { + public: + static constexpr int kType = ShutdownCompleteChunkConfig::kType; + + explicit ShutdownCompleteChunk(bool tag_reflected) + : tag_reflected_(tag_reflected) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + bool tag_reflected() const { return tag_reflected_; } + + private: + static constexpr int kFlagsBitT = 0; + bool tag_reflected_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_COMPLETE_CHUNK_H_ diff --git a/net/dcsctp/packet/chunk/shutdown_complete_chunk_test.cc b/net/dcsctp/packet/chunk/shutdown_complete_chunk_test.cc new file mode 100644 index 0000000000..253900d5cd --- /dev/null +++ b/net/dcsctp/packet/chunk/shutdown_complete_chunk_test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" + +#include + +#include + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(ShutdownCompleteChunkTest, FromCapture) { + /* + SHUTDOWN_COMPLETE chunk + Chunk type: SHUTDOWN_COMPLETE (14) + Chunk flags: 0x00 + Chunk length: 4 + */ + + uint8_t data[] = {0x0e, 0x00, 0x00, 0x04}; + + EXPECT_TRUE(ShutdownCompleteChunk::Parse(data).has_value()); +} + +TEST(ShutdownCompleteChunkTest, SerializeAndDeserialize) { + ShutdownCompleteChunk chunk(/*tag_reflected=*/false); + + std::vector serialized; + chunk.SerializeTo(serialized); + + EXPECT_TRUE(ShutdownCompleteChunk::Parse(serialized).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk_validators.cc b/net/dcsctp/packet/chunk_validators.cc new file mode 100644 index 0000000000..48d351827e --- /dev/null +++ b/net/dcsctp/packet/chunk_validators.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk_validators.h" + +#include +#include +#include + +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +SackChunk ChunkValidators::Clean(SackChunk&& sack) { + if (Validate(sack)) { + return std::move(sack); + } + + RTC_DLOG(LS_WARNING) << "Received SACK is malformed; cleaning it"; + + std::vector gap_ack_blocks; + gap_ack_blocks.reserve(sack.gap_ack_blocks().size()); + + // First: Only keep blocks that are sane + for (const SackChunk::GapAckBlock& gap_ack_block : sack.gap_ack_blocks()) { + if (gap_ack_block.end > gap_ack_block.start) { + gap_ack_blocks.emplace_back(gap_ack_block); + } + } + + // Not more than at most one remaining? Exit early. + if (gap_ack_blocks.size() <= 1) { + return SackChunk(sack.cumulative_tsn_ack(), sack.a_rwnd(), + std::move(gap_ack_blocks), sack.duplicate_tsns()); + } + + // Sort the intervals by their start value, to aid in the merging below. + absl::c_sort(gap_ack_blocks, [&](const SackChunk::GapAckBlock& a, + const SackChunk::GapAckBlock& b) { + return a.start < b.start; + }); + + // Merge overlapping ranges. 
+ std::vector merged; + merged.reserve(gap_ack_blocks.size()); + merged.push_back(gap_ack_blocks[0]); + + for (size_t i = 1; i < gap_ack_blocks.size(); ++i) { + if (merged.back().end + 1 >= gap_ack_blocks[i].start) { + merged.back().end = std::max(merged.back().end, gap_ack_blocks[i].end); + } else { + merged.push_back(gap_ack_blocks[i]); + } + } + + return SackChunk(sack.cumulative_tsn_ack(), sack.a_rwnd(), std::move(merged), + sack.duplicate_tsns()); +} + +bool ChunkValidators::Validate(const SackChunk& sack) { + if (sack.gap_ack_blocks().empty()) { + return true; + } + + // Ensure that gap-ack-blocks are sorted, has an "end" that is not before + // "start" and are non-overlapping and non-adjacent. + uint16_t prev_end = 0; + for (const SackChunk::GapAckBlock& gap_ack_block : sack.gap_ack_blocks()) { + if (gap_ack_block.end < gap_ack_block.start) { + return false; + } + if (gap_ack_block.start <= (prev_end + 1)) { + return false; + } + prev_end = gap_ack_block.end; + } + return true; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/chunk_validators.h b/net/dcsctp/packet/chunk_validators.h new file mode 100644 index 0000000000..b11848a162 --- /dev/null +++ b/net/dcsctp/packet/chunk_validators.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_VALIDATORS_H_ +#define NET_DCSCTP_PACKET_CHUNK_VALIDATORS_H_ + +#include "net/dcsctp/packet/chunk/sack_chunk.h" + +namespace dcsctp { +// Validates and cleans SCTP chunks. +class ChunkValidators { + public: + // Given a SackChunk, will return `true` if it's valid, and `false` if not. + static bool Validate(const SackChunk& sack); + + // Given a SackChunk, it will return a cleaned and validated variant of it. + // RFC4960 doesn't say anything about validity of SACKs or if the Gap ACK + // blocks must be sorted, and non-overlapping. While they always are in + // well-behaving implementations, this can't be relied on. + // + // This method internally calls `Validate`, which means that you can always + // pass a SackChunk to this method (valid or not), and use the results. + static SackChunk Clean(SackChunk&& sack); +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_VALIDATORS_H_ diff --git a/net/dcsctp/packet/chunk_validators_test.cc b/net/dcsctp/packet/chunk_validators_test.cc new file mode 100644 index 0000000000..d59fd4ec48 --- /dev/null +++ b/net/dcsctp/packet/chunk_validators_test.cc @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
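ChunkValidators::Clean above first drops inverted blocks, sorts the remainder by start, and then merges blocks that overlap or are adjacent (the `end + 1 >= start` test). A standalone sketch of that interval merging, with plain structs standing in for SackChunk::GapAckBlock:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone sketch of the gap-ack-block cleaning above: keep sane blocks,
// sort by start, then merge overlapping or adjacent ranges.
struct Block {
  uint16_t start;
  uint16_t end;
};

std::vector<Block> CleanBlocks(std::vector<Block> blocks) {
  std::vector<Block> sane;
  for (const Block& b : blocks) {
    if (b.end > b.start) sane.push_back(b);  // Drop inverted blocks.
  }
  std::sort(sane.begin(), sane.end(),
            [](const Block& a, const Block& b) { return a.start < b.start; });

  std::vector<Block> merged;
  for (const Block& b : sane) {
    if (!merged.empty() && merged.back().end + 1 >= b.start) {
      merged.back().end = std::max(merged.back().end, b.end);  // Merge.
    } else {
      merged.push_back(b);
    }
  }
  return merged;
}

int main() {
  // {3,7} and {5,9} overlap and are merged; {11,12} stays separate.
  for (const Block& b : CleanBlocks({{5, 9}, {3, 7}, {11, 12}})) {
    std::printf("[%d, %d]\n", b.start, b.end);  // [3, 9] then [11, 12]
  }
}
```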
+ */ +#include "net/dcsctp/packet/chunk_validators.h" + +#include + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(ChunkValidatorsTest, NoGapAckBlocksAreValid) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + /*gap_ack_blocks=*/{}, {}); + + EXPECT_TRUE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT(clean.gap_ack_blocks(), IsEmpty()); +} + +TEST(ChunkValidatorsTest, OneValidAckBlock) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, {SackChunk::GapAckBlock(2, 3)}, {}); + + EXPECT_TRUE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST(ChunkValidatorsTest, TwoValidAckBlocks) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(2, 3), SackChunk::GapAckBlock(5, 6)}, + {}); + + EXPECT_TRUE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT( + clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3), SackChunk::GapAckBlock(5, 6))); +} + +TEST(ChunkValidatorsTest, OneInvalidAckBlock) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, {SackChunk::GapAckBlock(1, 2)}, {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + // It's not strictly valid, but due to the renegable nature of gap ack blocks, + // the cum_ack_tsn can't simply be moved. + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(1, 2))); +} + +TEST(ChunkValidatorsTest, RemovesInvalidGapAckBlockFromSack) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(2, 3), SackChunk::GapAckBlock(6, 4)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST(ChunkValidatorsTest, SortsGapAckBlocksInOrder) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(6, 7), SackChunk::GapAckBlock(3, 4)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT( + clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 4), SackChunk::GapAckBlock(6, 7))); +} + +TEST(ChunkValidatorsTest, MergesAdjacentBlocks) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 4), SackChunk::GapAckBlock(5, 6)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 6))); +} + +TEST(ChunkValidatorsTest, MergesOverlappingByOne) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 4), SackChunk::GapAckBlock(4, 5)}, + {}); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 5))); +} + +TEST(ChunkValidatorsTest, MergesOverlappingByMore) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 10), SackChunk::GapAckBlock(4, 5)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + 
ElementsAre(SackChunk::GapAckBlock(3, 10))); +} + +TEST(ChunkValidatorsTest, MergesBlocksStartingWithSameStartOffset) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 7), SackChunk::GapAckBlock(3, 5), + SackChunk::GapAckBlock(3, 9)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 9))); +} + +TEST(ChunkValidatorsTest, MergesBlocksPartiallyOverlapping) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 7), SackChunk::GapAckBlock(5, 9)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 9))); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/crc32c.cc b/net/dcsctp/packet/crc32c.cc new file mode 100644 index 0000000000..e3f0dc1d19 --- /dev/null +++ b/net/dcsctp/packet/crc32c.cc @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/crc32c.h" + +#include + +#include "third_party/crc32c/src/include/crc32c/crc32c.h" + +namespace dcsctp { + +uint32_t GenerateCrc32C(rtc::ArrayView data) { + uint32_t crc32c = crc32c_value(data.data(), data.size()); + + // Byte swapping for little endian byte order: + uint8_t byte0 = crc32c; + uint8_t byte1 = crc32c >> 8; + uint8_t byte2 = crc32c >> 16; + uint8_t byte3 = crc32c >> 24; + crc32c = ((byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3); + return crc32c; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/crc32c.h b/net/dcsctp/packet/crc32c.h new file mode 100644 index 0000000000..a969e1b26b --- /dev/null +++ b/net/dcsctp/packet/crc32c.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CRC32C_H_ +#define NET_DCSCTP_PACKET_CRC32C_H_ + +#include + +#include "api/array_view.h" + +namespace dcsctp { + +// Generates the CRC32C checksum of `data`. +uint32_t GenerateCrc32C(rtc::ArrayView data); + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CRC32C_H_ diff --git a/net/dcsctp/packet/crc32c_test.cc b/net/dcsctp/packet/crc32c_test.cc new file mode 100644 index 0000000000..0821c4ef75 --- /dev/null +++ b/net/dcsctp/packet/crc32c_test.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
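GenerateCrc32C above delegates the checksum to the bundled crc32c library and then reorders the four result bytes, so the test expectations are the byte-swapped counterparts of the library's native values. A sketch of just that byte swap (the CRC computation itself is omitted):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the byte-order swap applied to the CRC32C result above. Only the
// reordering of the four bytes is demonstrated here.
uint32_t SwapBytes(uint32_t v) {
  uint8_t byte0 = static_cast<uint8_t>(v);
  uint8_t byte1 = static_cast<uint8_t>(v >> 8);
  uint8_t byte2 = static_cast<uint8_t>(v >> 16);
  uint8_t byte3 = static_cast<uint8_t>(v >> 24);
  return (uint32_t{byte0} << 24) | (uint32_t{byte1} << 16) |
         (uint32_t{byte2} << 8) | uint32_t{byte3};
}

int main() {
  // 0x8a9136aa is the well-known CRC-32C check value for 32 zero bytes;
  // after the swap it matches the 0xaa36918a expected in the test above.
  std::printf("0x%08x\n", SwapBytes(0x8a9136aaU));
}
```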
+ */ +#include "net/dcsctp/packet/crc32c.h" + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +constexpr std::array kEmpty = {}; +constexpr std::array kZero = {0}; +constexpr std::array kManyZeros = {0, 0, 0, 0}; +constexpr std::array kShort = {1, 2, 3, 4}; +constexpr std::array kLong = {1, 2, 3, 4, 5, 6, 7, 8}; +// https://tools.ietf.org/html/rfc3720#appendix-B.4 +constexpr std::array k32Zeros = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +constexpr std::array k32Ones = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +constexpr std::array k32Incrementing = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; +constexpr std::array k32Decrementing = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; +constexpr std::array kISCSICommandPDU = { + 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +}; + +TEST(Crc32Test, TestVectors) { + EXPECT_EQ(GenerateCrc32C(kEmpty), 0U); + EXPECT_EQ(GenerateCrc32C(kZero), 0x51537d52U); + EXPECT_EQ(GenerateCrc32C(kManyZeros), 0xc74b6748U); + EXPECT_EQ(GenerateCrc32C(kShort), 0xf48c3029U); + EXPECT_EQ(GenerateCrc32C(kLong), 0x811f8946U); + // https://tools.ietf.org/html/rfc3720#appendix-B.4 + EXPECT_EQ(GenerateCrc32C(k32Zeros), 0xaa36918aU); + EXPECT_EQ(GenerateCrc32C(k32Ones), 0x43aba862U); + EXPECT_EQ(GenerateCrc32C(k32Incrementing), 0x4e79dd46U); + EXPECT_EQ(GenerateCrc32C(k32Decrementing), 0x5cdb3f11U); + EXPECT_EQ(GenerateCrc32C(kISCSICommandPDU), 0x563a96d9U); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/data.h b/net/dcsctp/packet/data.h new file mode 100644 index 0000000000..f2d2e74904 --- /dev/null +++ b/net/dcsctp/packet/data.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_DATA_H_ +#define NET_DCSCTP_PACKET_DATA_H_ + +#include +#include +#include + +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// Represents data that is either received and extracted from a DATA/I-DATA +// chunk, or data that is supposed to be sent, and wrapped in a DATA/I-DATA +// chunk (depending on peer capabilities). +// +// The data wrapped in this structure is actually the same as the DATA/I-DATA +// chunk (actually the union of them), but to avoid having all components be +// aware of the implementation details of the different chunks, this abstraction +// is used instead. 
A notable difference is also that it doesn't carry a +// Transmission Sequence Number (TSN), as that is not known when a chunk is +// created (assigned late, just when sending), and that the TSNs in DATA/I-DATA +// are wrapped numbers, and within the library, unwrapped sequence numbers are +// preferably used. +struct Data { + // Indicates if a chunk is the first in a fragmented message and maps to the + // "beginning" flag in DATA/I-DATA chunk. + using IsBeginning = StrongAlias; + + // Indicates if a chunk is the last in a fragmented message and maps to the + // "end" flag in DATA/I-DATA chunk. + using IsEnd = StrongAlias; + + Data(StreamID stream_id, + SSN ssn, + MID message_id, + FSN fsn, + PPID ppid, + std::vector payload, + IsBeginning is_beginning, + IsEnd is_end, + IsUnordered is_unordered) + : stream_id(stream_id), + ssn(ssn), + message_id(message_id), + fsn(fsn), + ppid(ppid), + payload(std::move(payload)), + is_beginning(is_beginning), + is_end(is_end), + is_unordered(is_unordered) {} + + // Move-only, to avoid accidental copies. + Data(Data&& other) = default; + Data& operator=(Data&& other) = default; + + // Creates a copy of this `Data` object. + Data Clone() const { + return Data(stream_id, ssn, message_id, fsn, ppid, payload, is_beginning, + is_end, is_unordered); + } + + // The size of this data, which translates to the size of its payload. + size_t size() const { return payload.size(); } + + // Stream Identifier. + StreamID stream_id; + + // Stream Sequence Number (SSN), per stream, for ordered chunks. Defined by + // RFC4960 and used only in DATA chunks (not I-DATA). + SSN ssn; + + // Message Identifier (MID) per stream and ordered/unordered. Defined by + // RFC8260, and used together with options.is_unordered and stream_id to + // uniquely identify a message. Used only in I-DATA chunks (not DATA). + MID message_id; + // Fragment Sequence Number (FSN) per stream and ordered/unordered, as above. + FSN fsn; + + // Payload Protocol Identifier (PPID). + PPID ppid; + + // The actual data payload. + std::vector payload; + + // If this data represents the first, last or a middle chunk. + IsBeginning is_beginning; + IsEnd is_end; + // If this data is sent/received unordered. + IsUnordered is_unordered; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_DATA_H_ diff --git a/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.cc b/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.cc new file mode 100644 index 0000000000..ef67c2a49f --- /dev/null +++ b/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
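The Data struct above is move-only; copying happens only through the explicit Clone() call, so large payloads are never copied by accident. A minimal sketch of that pattern, with the field set reduced to a payload vector:

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Sketch of the move-only-with-explicit-Clone pattern used by `Data` above.
// The real struct also carries stream ids, sequence numbers and flags.
struct Payload {
  explicit Payload(std::vector<uint8_t> bytes) : bytes(std::move(bytes)) {}

  Payload(Payload&&) = default;             // Moves are cheap and allowed.
  Payload& operator=(Payload&&) = default;
  Payload(const Payload&) = delete;         // No implicit (accidental) copies.
  Payload& operator=(const Payload&) = delete;

  Payload Clone() const { return Payload(bytes); }  // Copying is opt-in.

  std::vector<uint8_t> bytes;
};

int main() {
  Payload a(std::vector<uint8_t>(1000, 0x42));
  Payload b = std::move(a);  // OK: transfers the buffer.
  Payload c = b.Clone();     // OK: deliberate deep copy.
  // Payload d = b;          // Would not compile: copy constructor deleted.
  std::printf("%zu %zu\n", b.bytes.size(), c.bytes.size());
}
```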
+ */ +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" + +#include + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.10 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=10 | Cause Length=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int CookieReceivedWhileShuttingDownCause::kType; + +absl::optional +CookieReceivedWhileShuttingDownCause::Parse( + rtc::ArrayView data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return CookieReceivedWhileShuttingDownCause(); +} + +void CookieReceivedWhileShuttingDownCause::SerializeTo( + std::vector& out) const { + AllocateTLV(out); +} + +std::string CookieReceivedWhileShuttingDownCause::ToString() const { + return "Cookie Received While Shutting Down"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h b/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h new file mode 100644 index 0000000000..362f181fba --- /dev/null +++ b/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_COOKIE_RECEIVED_WHILE_SHUTTING_DOWN_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_COOKIE_RECEIVED_WHILE_SHUTTING_DOWN_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.10 +struct CookieReceivedWhileShuttingDownCauseConfig : public ParameterConfig { + static constexpr int kType = 10; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class CookieReceivedWhileShuttingDownCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = + CookieReceivedWhileShuttingDownCauseConfig::kType; + + CookieReceivedWhileShuttingDownCause() {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_COOKIE_RECEIVED_WHILE_SHUTTING_DOWN_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause_test.cc b/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause_test.cc new file mode 100644 index 0000000000..afb8364c32 --- /dev/null +++ b/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause_test.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(CookieReceivedWhileShuttingDownCauseTest, SerializeAndDeserialize) { + CookieReceivedWhileShuttingDownCause parameter; + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + CookieReceivedWhileShuttingDownCause deserialized, + CookieReceivedWhileShuttingDownCause::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/error_cause.cc b/net/dcsctp/packet/error_cause/error_cause.cc new file mode 100644 index 0000000000..dcd07472ed --- /dev/null +++ b/net/dcsctp/packet/error_cause/error_cause.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/error_cause.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" +#include "net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h" +#include "net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h" +#include "net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h" +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h" +#include "net/dcsctp/packet/error_cause/stale_cookie_error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h" +#include "net/dcsctp/packet/error_cause/unresolvable_address_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +template +bool ParseAndPrint(ParameterDescriptor descriptor, rtc::StringBuilder& sb) { + if (descriptor.type == ErrorCause::kType) { + absl::optional p = ErrorCause::Parse(descriptor.data); + if (p.has_value()) { + sb << p->ToString(); + } else { + sb << "Failed to parse error cause of type " << ErrorCause::kType; + } + return true; + } + return false; +} + +std::string ErrorCausesToString(const Parameters& parameters) { + rtc::StringBuilder sb; + + std::vector descriptors = parameters.descriptors(); + for (size_t i = 0; i < descriptors.size(); ++i) { + if (i > 0) { + sb << "\n"; + } + + const ParameterDescriptor& d = descriptors[i]; + if (!ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) 
&& + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb) && + !ParseAndPrint(d, sb)) { + sb << "Unhandled parameter of type: " << d.type; + } + } + + return sb.Release(); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/error_cause.h b/net/dcsctp/packet/error_cause/error_cause.h new file mode 100644 index 0000000000..fa2bf81478 --- /dev/null +++ b/net/dcsctp/packet/error_cause/error_cause.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_ERROR_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_ERROR_CAUSE_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// Converts the Error Causes in `parameters` to a human readable string, +// to be used in error reporting and logging. +std::string ErrorCausesToString(const Parameters& parameters); + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_ERROR_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.cc b/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.cc new file mode 100644 index 0000000000..0187544226 --- /dev/null +++ b/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h" + +#include + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.7 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=7 | Cause Length=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InvalidMandatoryParameterCause::kType; + +absl::optional +InvalidMandatoryParameterCause::Parse(rtc::ArrayView data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return InvalidMandatoryParameterCause(); +} + +void InvalidMandatoryParameterCause::SerializeTo( + std::vector& out) const { + AllocateTLV(out); +} + +std::string InvalidMandatoryParameterCause::ToString() const { + return "Invalid Mandatory Parameter"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h b/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h new file mode 100644 index 0000000000..e192b5a42f --- /dev/null +++ b/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. 
All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_MANDATORY_PARAMETER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_MANDATORY_PARAMETER_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.7 +struct InvalidMandatoryParameterCauseConfig : public ParameterConfig { + static constexpr int kType = 7; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class InvalidMandatoryParameterCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = InvalidMandatoryParameterCauseConfig::kType; + + InvalidMandatoryParameterCause() {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_MANDATORY_PARAMETER_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause_test.cc b/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause_test.cc new file mode 100644 index 0000000000..3d532d09b1 --- /dev/null +++ b/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause_test.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(InvalidMandatoryParameterCauseTest, SerializeAndDeserialize) { + InvalidMandatoryParameterCause parameter; + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + InvalidMandatoryParameterCause deserialized, + InvalidMandatoryParameterCause::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.cc b/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.cc new file mode 100644 index 0000000000..b2ddd6f4ef --- /dev/null +++ b/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.1 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=1 | Cause Length=8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | (Reserved) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InvalidStreamIdentifierCause::kType; + +absl::optional +InvalidStreamIdentifierCause::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + StreamID stream_id(reader->Load16<4>()); + return InvalidStreamIdentifierCause(stream_id); +} + +void InvalidStreamIdentifierCause::SerializeTo( + std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + + writer.Store16<4>(*stream_id_); +} + +std::string InvalidStreamIdentifierCause::ToString() const { + rtc::StringBuilder sb; + sb << "Invalid Stream Identifier, stream_id=" << *stream_id_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h b/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h new file mode 100644 index 0000000000..b7dfe177b8 --- /dev/null +++ b/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_STREAM_IDENTIFIER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_STREAM_IDENTIFIER_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.1 +struct InvalidStreamIdentifierCauseConfig : public ParameterConfig { + static constexpr int kType = 1; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class InvalidStreamIdentifierCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = InvalidStreamIdentifierCauseConfig::kType; + + explicit InvalidStreamIdentifierCause(StreamID stream_id) + : stream_id_(stream_id) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + StreamID stream_id() const { return stream_id_; } + + private: + StreamID stream_id_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_STREAM_IDENTIFIER_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause_test.cc b/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause_test.cc new file mode 100644 index 0000000000..a282ce5ee8 --- /dev/null +++ b/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause_test.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(InvalidStreamIdentifierCauseTest, SerializeAndDeserialize) { + InvalidStreamIdentifierCause parameter(StreamID(1)); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(InvalidStreamIdentifierCause deserialized, + InvalidStreamIdentifierCause::Parse(serialized)); + + EXPECT_EQ(*deserialized.stream_id(), 1); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.cc b/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.cc new file mode 100644 index 0000000000..b89f86e43e --- /dev/null +++ b/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.cc @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h" + +#include + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.2 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=2 | Cause Length=8+N*2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of missing params=N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Missing Param Type #1 | Missing Param Type #2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Missing Param Type #N-1 | Missing Param Type #N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int MissingMandatoryParameterCause::kType; + +absl::optional +MissingMandatoryParameterCause::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + uint32_t count = reader->Load32<4>(); + if (reader->variable_data_size() / kMissingParameterSize != count) { + RTC_DLOG(LS_WARNING) << "Invalid number of missing parameters"; + return absl::nullopt; + } + + std::vector missing_parameter_types; + missing_parameter_types.reserve(count); + for (uint32_t i = 0; i < count; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(i * kMissingParameterSize); + + missing_parameter_types.push_back(sub_reader.Load16<0>()); + } + return MissingMandatoryParameterCause(missing_parameter_types); +} + +void MissingMandatoryParameterCause::SerializeTo( + std::vector& out) const { + size_t variable_size = + missing_parameter_types_.size() * kMissingParameterSize; + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(missing_parameter_types_.size()); + + for (size_t i = 0; i < missing_parameter_types_.size(); ++i) { + BoundedByteWriter sub_writer = + writer.sub_writer(i * kMissingParameterSize); + + sub_writer.Store16<0>(missing_parameter_types_[i]); + } +} + +std::string MissingMandatoryParameterCause::ToString() const { + rtc::StringBuilder sb; + sb << "Missing Mandatory Parameter, missing_parameter_types=" + << StrJoin(missing_parameter_types_, ","); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h b/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h new file mode 100644 index 0000000000..4435424295 --- /dev/null +++ b/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_MISSING_MANDATORY_PARAMETER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_MISSING_MANDATORY_PARAMETER_CAUSE_H_ +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.2 +struct MissingMandatoryParameterCauseConfig : public ParameterConfig { + static constexpr int kType = 2; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 2; +}; + +class MissingMandatoryParameterCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = MissingMandatoryParameterCauseConfig::kType; + + explicit MissingMandatoryParameterCause( + rtc::ArrayView missing_parameter_types) + : missing_parameter_types_(missing_parameter_types.begin(), + missing_parameter_types.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView missing_parameter_types() const { + return missing_parameter_types_; + } + + private: + static constexpr size_t kMissingParameterSize = 2; + std::vector missing_parameter_types_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_MISSING_MANDATORY_PARAMETER_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause_test.cc b/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause_test.cc new file mode 100644 index 0000000000..1c526ff0e2 --- /dev/null +++ b/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause_test.cc @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(MissingMandatoryParameterCauseTest, SerializeAndDeserialize) { + uint16_t parameter_types[] = {1, 2, 3}; + MissingMandatoryParameterCause parameter(parameter_types); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + MissingMandatoryParameterCause deserialized, + MissingMandatoryParameterCause::Parse(serialized)); + + EXPECT_THAT(deserialized.missing_parameter_types(), ElementsAre(1, 2, 3)); +} + +TEST(MissingMandatoryParameterCauseTest, HandlesDeserializeZeroParameters) { + uint8_t serialized[] = {0, 2, 0, 8, 0, 0, 0, 0}; + + ASSERT_HAS_VALUE_AND_ASSIGN( + MissingMandatoryParameterCause deserialized, + MissingMandatoryParameterCause::Parse(serialized)); + + EXPECT_THAT(deserialized.missing_parameter_types(), IsEmpty()); +} + +TEST(MissingMandatoryParameterCauseTest, HandlesOverflowParameterCount) { + // 0x80000004 * 2 = 2**32 + 8 -> if overflow, would validate correctly. 
+ uint8_t serialized[] = {0, 2, 0, 8, 0x80, 0x00, 0x00, 0x04}; + + EXPECT_FALSE(MissingMandatoryParameterCause::Parse(serialized).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/no_user_data_cause.cc b/net/dcsctp/packet/error_cause/no_user_data_cause.cc new file mode 100644 index 0000000000..2853915b0c --- /dev/null +++ b/net/dcsctp/packet/error_cause/no_user_data_cause.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.9 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=9 | Cause Length=8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / TSN value / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int NoUserDataCause::kType; + +absl::optional NoUserDataCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + TSN tsn(reader->Load32<4>()); + return NoUserDataCause(tsn); +} + +void NoUserDataCause::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store32<4>(*tsn_); +} + +std::string NoUserDataCause::ToString() const { + rtc::StringBuilder sb; + sb << "No User Data, tsn=" << *tsn_; + return sb.Release(); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/no_user_data_cause.h b/net/dcsctp/packet/error_cause/no_user_data_cause.h new file mode 100644 index 0000000000..1087dcc97c --- /dev/null +++ b/net/dcsctp/packet/error_cause/no_user_data_cause.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
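Throughout these files, raw integers are wrapped in strong alias types such as TSN, StreamID, SSN and PPID: they are constructed explicitly (TSN(123)) and unwrapped with operator* (*tsn_), so a TSN cannot silently be passed where a stream id is expected. A minimal sketch of such a wrapper (a simplified stand-in, not the dcsctp StrongAlias implementation):

```cpp
#include <cstdint>
#include <cstdio>

// Minimal sketch of the strong-alias idiom used for TSN, StreamID, etc.:
// explicit construction and operator* for unwrapping.
template <typename Tag, typename T>
class StrongAlias {
 public:
  explicit StrongAlias(T value) : value_(value) {}
  T operator*() const { return value_; }

 private:
  T value_;
};

using TSN = StrongAlias<struct TSNTag, uint32_t>;
using StreamID = StrongAlias<struct StreamIDTag, uint16_t>;

void AckUpTo(TSN tsn) { std::printf("acked up to %u\n", *tsn); }

int main() {
  AckUpTo(TSN(123));
  // AckUpTo(StreamID(1));  // Would not compile: distinct types.
  // AckUpTo(123);          // Would not compile: constructor is explicit.
}
```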
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_NO_USER_DATA_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_NO_USER_DATA_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.9 +struct NoUserDataCauseConfig : public ParameterConfig { + static constexpr int kType = 9; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class NoUserDataCause : public Parameter, + public TLVTrait { + public: + static constexpr int kType = NoUserDataCauseConfig::kType; + + explicit NoUserDataCause(TSN tsn) : tsn_(tsn) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + TSN tsn() const { return tsn_; } + + private: + TSN tsn_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_NO_USER_DATA_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/no_user_data_cause_test.cc b/net/dcsctp/packet/error_cause/no_user_data_cause_test.cc new file mode 100644 index 0000000000..0a535bf4fa --- /dev/null +++ b/net/dcsctp/packet/error_cause/no_user_data_cause_test.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(NoUserDataCauseTest, SerializeAndDeserialize) { + NoUserDataCause parameter(TSN(123)); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(NoUserDataCause deserialized, + NoUserDataCause::Parse(serialized)); + + EXPECT_EQ(*deserialized.tsn(), 123u); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/out_of_resource_error_cause.cc b/net/dcsctp/packet/error_cause/out_of_resource_error_cause.cc new file mode 100644 index 0000000000..e5c7c0e787 --- /dev/null +++ b/net/dcsctp/packet/error_cause/out_of_resource_error_cause.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" + +#include + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.4 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=4 | Cause Length=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int OutOfResourceErrorCause::kType; + +absl::optional OutOfResourceErrorCause::Parse( + rtc::ArrayView data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return OutOfResourceErrorCause(); +} + +void OutOfResourceErrorCause::SerializeTo(std::vector& out) const { + AllocateTLV(out); +} + +std::string OutOfResourceErrorCause::ToString() const { + return "Out Of Resource"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/out_of_resource_error_cause.h b/net/dcsctp/packet/error_cause/out_of_resource_error_cause.h new file mode 100644 index 0000000000..fc798ca4ac --- /dev/null +++ b/net/dcsctp/packet/error_cause/out_of_resource_error_cause.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_OUT_OF_RESOURCE_ERROR_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_OUT_OF_RESOURCE_ERROR_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.4 +struct OutOfResourceParameterConfig : public ParameterConfig { + static constexpr int kType = 4; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class OutOfResourceErrorCause : public Parameter, + public TLVTrait { + public: + static constexpr int kType = OutOfResourceParameterConfig::kType; + + OutOfResourceErrorCause() {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_OUT_OF_RESOURCE_ERROR_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc b/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc new file mode 100644 index 0000000000..501fc201cd --- /dev/null +++ b/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
diff --git a/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc b/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc
new file mode 100644
index 0000000000..501fc201cd
--- /dev/null
+++ b/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h"
+
+#include <stdint.h>
+
+#include <type_traits>
+#include <vector>
+
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+
+namespace dcsctp {
+namespace {
+
+TEST(OutOfResourceErrorCauseTest, SerializeAndDeserialize) {
+  OutOfResourceErrorCause parameter;
+
+  std::vector<uint8_t> serialized;
+  parameter.SerializeTo(serialized);
+
+  ASSERT_HAS_VALUE_AND_ASSIGN(OutOfResourceErrorCause deserialized,
+                              OutOfResourceErrorCause::Parse(serialized));
+}
+
+}  // namespace
+}  // namespace dcsctp
diff --git a/net/dcsctp/packet/error_cause/protocol_violation_cause.cc b/net/dcsctp/packet/error_cause/protocol_violation_cause.cc
new file mode 100644
index 0000000000..1b8d423afb
--- /dev/null
+++ b/net/dcsctp/packet/error_cause/protocol_violation_cause.cc
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
+
+#include <stdint.h>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/packet/bounded_byte_reader.h"
+#include "net/dcsctp/packet/bounded_byte_writer.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+#include "rtc_base/strings/string_builder.h"
+
+namespace dcsctp {
+
+// https://tools.ietf.org/html/rfc4960#section-3.3.10.13
+
+//  0                   1                   2                   3
+//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+//  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//  |         Cause Code=13         |      Cause Length=Variable    |
+//  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//  /                    Additional Information                    /
+//  \                                                               \
+//  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+constexpr int ProtocolViolationCause::kType;
+
+absl::optional<ProtocolViolationCause> ProtocolViolationCause::Parse(
+    rtc::ArrayView<const uint8_t> data) {
+  absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data);
+  if (!reader.has_value()) {
+    return absl::nullopt;
+  }
+  return ProtocolViolationCause(
+      std::string(reinterpret_cast<const char*>(reader->variable_data().data()),
+                  reader->variable_data().size()));
+}
+
+void ProtocolViolationCause::SerializeTo(std::vector<uint8_t>& out) const {
+  BoundedByteWriter<kHeaderSize> writer =
+      AllocateTLV(out, additional_information_.size());
+  writer.CopyToVariableData(rtc::MakeArrayView(
+      reinterpret_cast<const uint8_t*>(additional_information_.data()),
+      additional_information_.size()));
+}
+
+std::string ProtocolViolationCause::ToString() const {
+  rtc::StringBuilder sb;
+  sb << "Protocol Violation, additional_information="
+     << additional_information_;
+  return sb.Release();
+}
+
+}  // namespace dcsctp
diff --git a/net/dcsctp/packet/error_cause/protocol_violation_cause.h b/net/dcsctp/packet/error_cause/protocol_violation_cause.h
new file mode 100644
index 0000000000..3081e1f28c
--- /dev/null
+++ b/net/dcsctp/packet/error_cause/protocol_violation_cause.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
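Aside (not part of the patch): a variable-length cause such as Protocol Violation is normally carried inside the Parameters list of an ABORT or ERROR chunk. The sketch below mirrors the usage in protocol_violation_cause_test.cc further down in this patch; as in that test, it assumes the Parameters class and its Builder are reachable through the error_cause.h include.

// Sketch only, modelled on the test code in this patch, not a definitive API
// usage. Builds a one-element cause list, re-parses it and returns the textual
// description of the first cause (or "" on failure).
#include <string>

#include "net/dcsctp/packet/error_cause/error_cause.h"
#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"

namespace dcsctp {

std::string DescribeFirstCause() {
  Parameters causes = Parameters::Builder()
                          .Add(ProtocolViolationCause("Reason goes here"))
                          .Build();

  auto reparsed = Parameters::Parse(causes.data());
  if (!reparsed.has_value() || reparsed->descriptors().size() != 1) {
    return "";
  }
  // Each descriptor carries the cause code and the raw TLV bytes, which can
  // be handed back to the matching Parse() method.
  auto cause = ProtocolViolationCause::Parse(reparsed->descriptors()[0].data);
  return cause.has_value() ? cause->ToString() : "";
}

}  // namespace dcsctp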
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_PROTOCOL_VIOLATION_CAUSE_H_
+#define NET_DCSCTP_PACKET_ERROR_CAUSE_PROTOCOL_VIOLATION_CAUSE_H_
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "net/dcsctp/packet/error_cause/error_cause.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+
+namespace dcsctp {
+
+// https://tools.ietf.org/html/rfc4960#section-3.3.10.13
+struct ProtocolViolationCauseConfig : public ParameterConfig {
+  static constexpr int kType = 13;
+  static constexpr size_t kHeaderSize = 4;
+  static constexpr size_t kVariableLengthAlignment = 1;
+};
+
+class ProtocolViolationCause : public Parameter,
+                               public TLVTrait<ProtocolViolationCauseConfig> {
+ public:
+  static constexpr int kType = ProtocolViolationCauseConfig::kType;
+
+  explicit ProtocolViolationCause(absl::string_view additional_information)
+      : additional_information_(additional_information) {}
+
+  static absl::optional<ProtocolViolationCause> Parse(
+      rtc::ArrayView<const uint8_t> data);
+
+  void SerializeTo(std::vector<uint8_t>& out) const override;
+  std::string ToString() const override;
+
+  absl::string_view additional_information() const {
+    return additional_information_;
+  }
+
+ private:
+  std::string additional_information_;
+};
+
+}  // namespace dcsctp
+
+#endif  // NET_DCSCTP_PACKET_ERROR_CAUSE_PROTOCOL_VIOLATION_CAUSE_H_
diff --git a/net/dcsctp/packet/error_cause/protocol_violation_cause_test.cc b/net/dcsctp/packet/error_cause/protocol_violation_cause_test.cc
new file mode 100644
index 0000000000..902d867091
--- /dev/null
+++ b/net/dcsctp/packet/error_cause/protocol_violation_cause_test.cc
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */ +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(ProtocolViolationCauseTest, EmptyReason) { + Parameters causes = + Parameters::Builder().Add(ProtocolViolationCause("")).Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, ProtocolViolationCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ProtocolViolationCause cause, + ProtocolViolationCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.additional_information(), ""); +} + +TEST(ProtocolViolationCauseTest, SetReason) { + Parameters causes = Parameters::Builder() + .Add(ProtocolViolationCause("Reason goes here")) + .Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, ProtocolViolationCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ProtocolViolationCause cause, + ProtocolViolationCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.additional_information(), "Reason goes here"); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.cc b/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.cc new file mode 100644 index 0000000000..abe5de6211 --- /dev/null +++ b/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.11 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=11 | Cause Length=Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / New Address TLVs / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int RestartOfAnAssociationWithNewAddressesCause::kType; + +absl::optional +RestartOfAnAssociationWithNewAddressesCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return RestartOfAnAssociationWithNewAddressesCause(reader->variable_data()); +} + +void RestartOfAnAssociationWithNewAddressesCause::SerializeTo( + std::vector& out) const { + BoundedByteWriter writer = + AllocateTLV(out, new_address_tlvs_.size()); + writer.CopyToVariableData(new_address_tlvs_); +} + +std::string RestartOfAnAssociationWithNewAddressesCause::ToString() const { + return "Restart of an Association with New Addresses"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h b/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h new file mode 100644 index 0000000000..a1cccdc8a1 --- /dev/null +++ b/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESS_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESS_CAUSE_H_ +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.11 +struct RestartOfAnAssociationWithNewAddressesCauseConfig + : public ParameterConfig { + static constexpr int kType = 11; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class RestartOfAnAssociationWithNewAddressesCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = + RestartOfAnAssociationWithNewAddressesCauseConfig::kType; + + explicit RestartOfAnAssociationWithNewAddressesCause( + rtc::ArrayView new_address_tlvs) + : new_address_tlvs_(new_address_tlvs.begin(), new_address_tlvs.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView new_address_tlvs() const { + return new_address_tlvs_; + } + + private: + std::vector new_address_tlvs_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESS_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause_test.cc b/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause_test.cc new file mode 100644 index 0000000000..b8ab8b6803 --- /dev/null +++ b/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause_test.cc @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(RestartOfAnAssociationWithNewAddressesCauseTest, SerializeAndDeserialize) { + uint8_t data[] = {1, 2, 3}; + RestartOfAnAssociationWithNewAddressesCause parameter(data); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + RestartOfAnAssociationWithNewAddressesCause deserialized, + RestartOfAnAssociationWithNewAddressesCause::Parse(serialized)); + + EXPECT_THAT(deserialized.new_address_tlvs(), ElementsAre(1, 2, 3)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/stale_cookie_error_cause.cc b/net/dcsctp/packet/error_cause/stale_cookie_error_cause.cc new file mode 100644 index 0000000000..d77d8488f1 --- /dev/null +++ b/net/dcsctp/packet/error_cause/stale_cookie_error_cause.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/stale_cookie_error_cause.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.3 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=3 | Cause Length=8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Measure of Staleness (usec.) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int StaleCookieErrorCause::kType; + +absl::optional StaleCookieErrorCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + uint32_t staleness_us = reader->Load32<4>(); + return StaleCookieErrorCause(staleness_us); +} + +void StaleCookieErrorCause::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store32<4>(staleness_us_); +} + +std::string StaleCookieErrorCause::ToString() const { + rtc::StringBuilder sb; + sb << "Stale Cookie Error, staleness_us=" << staleness_us_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/stale_cookie_error_cause.h b/net/dcsctp/packet/error_cause/stale_cookie_error_cause.h new file mode 100644 index 0000000000..d8b7b5b5bd --- /dev/null +++ b/net/dcsctp/packet/error_cause/stale_cookie_error_cause.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_STALE_COOKIE_ERROR_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_STALE_COOKIE_ERROR_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.3 +struct StaleCookieParameterConfig : public ParameterConfig { + static constexpr int kType = 3; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class StaleCookieErrorCause : public Parameter, + public TLVTrait { + public: + static constexpr int kType = StaleCookieParameterConfig::kType; + + explicit StaleCookieErrorCause(uint32_t staleness_us) + : staleness_us_(staleness_us) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + uint16_t staleness_us() const { return staleness_us_; } + + private: + uint32_t staleness_us_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_STALE_COOKIE_ERROR_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/stale_cookie_error_cause_test.cc b/net/dcsctp/packet/error_cause/stale_cookie_error_cause_test.cc new file mode 100644 index 0000000000..c0d1ac1c58 --- /dev/null +++ b/net/dcsctp/packet/error_cause/stale_cookie_error_cause_test.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/stale_cookie_error_cause.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(StaleCookieErrorCauseTest, SerializeAndDeserialize) { + StaleCookieErrorCause parameter(123); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(StaleCookieErrorCause deserialized, + StaleCookieErrorCause::Parse(serialized)); + + EXPECT_EQ(deserialized.staleness_us(), 123); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.cc b/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.cc new file mode 100644 index 0000000000..04b960d992 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.6 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=6 | Cause Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Unrecognized Chunk / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UnrecognizedChunkTypeCause::kType; + +absl::optional UnrecognizedChunkTypeCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + std::vector unrecognized_chunk(reader->variable_data().begin(), + reader->variable_data().end()); + return UnrecognizedChunkTypeCause(std::move(unrecognized_chunk)); +} + +void UnrecognizedChunkTypeCause::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = + AllocateTLV(out, unrecognized_chunk_.size()); + writer.CopyToVariableData(unrecognized_chunk_); +} + +std::string UnrecognizedChunkTypeCause::ToString() const { + rtc::StringBuilder sb; + sb << "Unrecognized Chunk Type, chunk_type="; + if (!unrecognized_chunk_.empty()) { + sb << static_cast(unrecognized_chunk_[0]); + } else { + sb << ""; + } + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h b/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h new file mode 100644 index 0000000000..26d3d3b8f9 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_CHUNK_TYPE_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_CHUNK_TYPE_CAUSE_H_ +#include +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.6 +struct UnrecognizedChunkTypeCauseConfig : public ParameterConfig { + static constexpr int kType = 6; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UnrecognizedChunkTypeCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = UnrecognizedChunkTypeCauseConfig::kType; + + explicit UnrecognizedChunkTypeCause(std::vector unrecognized_chunk) + : unrecognized_chunk_(std::move(unrecognized_chunk)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView unrecognized_chunk() const { + return unrecognized_chunk_; + } + + private: + std::vector unrecognized_chunk_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_CHUNK_TYPE_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause_test.cc b/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause_test.cc new file mode 100644 index 0000000000..baff852f40 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause_test.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(UnrecognizedChunkTypeCauseTest, SerializeAndDeserialize) { + UnrecognizedChunkTypeCause parameter({1, 2, 3}); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(UnrecognizedChunkTypeCause deserialized, + UnrecognizedChunkTypeCause::Parse(serialized)); + + EXPECT_THAT(deserialized.unrecognized_chunk(), ElementsAre(1, 2, 3)); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.cc b/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.cc new file mode 100644 index 0000000000..80001a9eae --- /dev/null +++ b/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.8 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=8 | Cause Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Unrecognized Parameters / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UnrecognizedParametersCause::kType; + +absl::optional UnrecognizedParametersCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return UnrecognizedParametersCause(reader->variable_data()); +} + +void UnrecognizedParametersCause::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = + AllocateTLV(out, unrecognized_parameters_.size()); + writer.CopyToVariableData(unrecognized_parameters_); +} + +std::string UnrecognizedParametersCause::ToString() const { + return "Unrecognized Parameters"; +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h b/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h new file mode 100644 index 0000000000..ebec5ed4c3 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_PARAMETER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_PARAMETER_CAUSE_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.8 +struct UnrecognizedParametersCauseConfig : public ParameterConfig { + static constexpr int kType = 8; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UnrecognizedParametersCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = UnrecognizedParametersCauseConfig::kType; + + explicit UnrecognizedParametersCause( + rtc::ArrayView unrecognized_parameters) + : unrecognized_parameters_(unrecognized_parameters.begin(), + unrecognized_parameters.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView unrecognized_parameters() const { + return unrecognized_parameters_; + } + + private: + std::vector unrecognized_parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_PARAMETER_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/unrecognized_parameter_cause_test.cc b/net/dcsctp/packet/error_cause/unrecognized_parameter_cause_test.cc new file mode 100644 index 0000000000..0449599ca6 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unrecognized_parameter_cause_test.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(UnrecognizedParametersCauseTest, SerializeAndDeserialize) { + uint8_t unrecognized_parameters[] = {1, 2, 3}; + UnrecognizedParametersCause parameter(unrecognized_parameters); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(UnrecognizedParametersCause deserialized, + UnrecognizedParametersCause::Parse(serialized)); + + EXPECT_THAT(deserialized.unrecognized_parameters(), ElementsAre(1, 2, 3)); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/unresolvable_address_cause.cc b/net/dcsctp/packet/error_cause/unresolvable_address_cause.cc new file mode 100644 index 0000000000..8108d31aa7 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unresolvable_address_cause.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unresolvable_address_cause.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.5 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=5 | Cause Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Unresolvable Address / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UnresolvableAddressCause::kType; + +absl::optional UnresolvableAddressCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return UnresolvableAddressCause(reader->variable_data()); +} + +void UnresolvableAddressCause::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = + AllocateTLV(out, unresolvable_address_.size()); + writer.CopyToVariableData(unresolvable_address_); +} + +std::string UnresolvableAddressCause::ToString() const { + return "Unresolvable Address"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/unresolvable_address_cause.h b/net/dcsctp/packet/error_cause/unresolvable_address_cause.h new file mode 100644 index 0000000000..c63b3779ef --- /dev/null +++ b/net/dcsctp/packet/error_cause/unresolvable_address_cause.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_UNRESOLVABLE_ADDRESS_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_UNRESOLVABLE_ADDRESS_CAUSE_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.5 +struct UnresolvableAddressCauseConfig : public ParameterConfig { + static constexpr int kType = 5; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UnresolvableAddressCause + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = UnresolvableAddressCauseConfig::kType; + + explicit UnresolvableAddressCause( + rtc::ArrayView unresolvable_address) + : unresolvable_address_(unresolvable_address.begin(), + unresolvable_address.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView unresolvable_address() const { + return unresolvable_address_; + } + + private: + std::vector unresolvable_address_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_UNRESOLVABLE_ADDRESS_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/unresolvable_address_cause_test.cc b/net/dcsctp/packet/error_cause/unresolvable_address_cause_test.cc new file mode 100644 index 0000000000..688730e6b3 --- /dev/null +++ b/net/dcsctp/packet/error_cause/unresolvable_address_cause_test.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unresolvable_address_cause.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(UnresolvableAddressCauseTest, SerializeAndDeserialize) { + uint8_t unresolvable_address[] = {1, 2, 3}; + UnresolvableAddressCause parameter(unresolvable_address); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(UnresolvableAddressCause deserialized, + UnresolvableAddressCause::Parse(serialized)); + + EXPECT_THAT(deserialized.unresolvable_address(), ElementsAre(1, 2, 3)); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/user_initiated_abort_cause.cc b/net/dcsctp/packet/error_cause/user_initiated_abort_cause.cc new file mode 100644 index 0000000000..da99aacbfa --- /dev/null +++ b/net/dcsctp/packet/error_cause/user_initiated_abort_cause.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.12 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=12 | Cause Length=Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Upper Layer Abort Reason / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UserInitiatedAbortCause::kType; + +absl::optional UserInitiatedAbortCause::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + if (reader->variable_data().empty()) { + return UserInitiatedAbortCause(""); + } + return UserInitiatedAbortCause( + std::string(reinterpret_cast(reader->variable_data().data()), + reader->variable_data().size())); +} + +void UserInitiatedAbortCause::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = + AllocateTLV(out, upper_layer_abort_reason_.size()); + writer.CopyToVariableData(rtc::MakeArrayView( + reinterpret_cast(upper_layer_abort_reason_.data()), + upper_layer_abort_reason_.size())); +} + +std::string UserInitiatedAbortCause::ToString() const { + rtc::StringBuilder sb; + sb << "User-Initiated Abort, reason=" << upper_layer_abort_reason_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/error_cause/user_initiated_abort_cause.h b/net/dcsctp/packet/error_cause/user_initiated_abort_cause.h new file mode 100644 index 0000000000..9eb16657b4 --- /dev/null +++ b/net/dcsctp/packet/error_cause/user_initiated_abort_cause.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_USER_INITIATED_ABORT_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_USER_INITIATED_ABORT_CAUSE_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.12 +struct UserInitiatedAbortCauseConfig : public ParameterConfig { + static constexpr int kType = 12; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UserInitiatedAbortCause : public Parameter, + public TLVTrait { + public: + static constexpr int kType = UserInitiatedAbortCauseConfig::kType; + + explicit UserInitiatedAbortCause(absl::string_view upper_layer_abort_reason) + : upper_layer_abort_reason_(upper_layer_abort_reason) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + absl::string_view upper_layer_abort_reason() const { + return upper_layer_abort_reason_; + } + + private: + std::string upper_layer_abort_reason_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_USER_INITIATED_ABORT_CAUSE_H_ diff --git a/net/dcsctp/packet/error_cause/user_initiated_abort_cause_test.cc b/net/dcsctp/packet/error_cause/user_initiated_abort_cause_test.cc new file mode 100644 index 0000000000..250959e3df --- /dev/null +++ b/net/dcsctp/packet/error_cause/user_initiated_abort_cause_test.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(UserInitiatedAbortCauseTest, EmptyReason) { + Parameters causes = + Parameters::Builder().Add(UserInitiatedAbortCause("")).Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, UserInitiatedAbortCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + UserInitiatedAbortCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.upper_layer_abort_reason(), ""); +} + +TEST(UserInitiatedAbortCauseTest, SetReason) { + Parameters causes = Parameters::Builder() + .Add(UserInitiatedAbortCause("User called Close")) + .Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, UserInitiatedAbortCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + UserInitiatedAbortCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.upper_layer_abort_reason(), "User called Close"); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.cc b/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.cc new file mode 100644 index 0000000000..c33e3e11f6 --- /dev/null +++ b/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.6 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 18 | Parameter Length = 12 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of new streams | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int AddIncomingStreamsRequestParameter::kType; + +absl::optional +AddIncomingStreamsRequestParameter::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + uint16_t nbr_of_new_streams = reader->Load16<8>(); + + return AddIncomingStreamsRequestParameter(request_sequence_number, + nbr_of_new_streams); +} + +void AddIncomingStreamsRequestParameter::SerializeTo( + std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store32<4>(*request_sequence_number_); + writer.Store16<8>(nbr_of_new_streams_); +} + +std::string AddIncomingStreamsRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Add Incoming Streams Request, req_seq_nbr=" + << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h b/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h new file mode 100644 index 0000000000..3859eb3f7e --- /dev/null +++ b/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_ADD_INCOMING_STREAMS_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_ADD_INCOMING_STREAMS_REQUEST_PARAMETER_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.6 +struct AddIncomingStreamsRequestParameterConfig : ParameterConfig { + static constexpr int kType = 18; + static constexpr size_t kHeaderSize = 12; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class AddIncomingStreamsRequestParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = AddIncomingStreamsRequestParameterConfig::kType; + + explicit AddIncomingStreamsRequestParameter( + ReconfigRequestSN request_sequence_number, + uint16_t nbr_of_new_streams) + : request_sequence_number_(request_sequence_number), + nbr_of_new_streams_(nbr_of_new_streams) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + uint16_t nbr_of_new_streams() const { return nbr_of_new_streams_; } + + private: + ReconfigRequestSN request_sequence_number_; + uint16_t nbr_of_new_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_ADD_INCOMING_STREAMS_REQUEST_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter_test.cc b/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter_test.cc new file mode 100644 index 0000000000..a29257a8f8 --- /dev/null +++ b/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter_test.cc @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(AddIncomingStreamsRequestParameterTest, SerializeAndDeserialize) { + AddIncomingStreamsRequestParameter parameter(ReconfigRequestSN(1), 2); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + AddIncomingStreamsRequestParameter deserialized, + AddIncomingStreamsRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_EQ(deserialized.nbr_of_new_streams(), 2u); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.cc b/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.cc new file mode 100644 index 0000000000..4787ee9718 --- /dev/null +++ b/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.5 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 17 | Parameter Length = 12 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of new streams | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int AddOutgoingStreamsRequestParameter::kType; + +absl::optional +AddOutgoingStreamsRequestParameter::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + uint16_t nbr_of_new_streams = reader->Load16<8>(); + + return AddOutgoingStreamsRequestParameter(request_sequence_number, + nbr_of_new_streams); +} + +void AddOutgoingStreamsRequestParameter::SerializeTo( + std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store32<4>(*request_sequence_number_); + writer.Store16<8>(nbr_of_new_streams_); +} + +std::string AddOutgoingStreamsRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Add Outgoing Streams Request, req_seq_nbr=" + << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h b/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h new file mode 100644 index 0000000000..01e8f91cfa --- /dev/null +++ b/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_ADD_OUTGOING_STREAMS_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_ADD_OUTGOING_STREAMS_REQUEST_PARAMETER_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.5 +struct AddOutgoingStreamsRequestParameterConfig : ParameterConfig { + static constexpr int kType = 17; + static constexpr size_t kHeaderSize = 12; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class AddOutgoingStreamsRequestParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = AddOutgoingStreamsRequestParameterConfig::kType; + + explicit AddOutgoingStreamsRequestParameter( + ReconfigRequestSN request_sequence_number, + uint16_t nbr_of_new_streams) + : request_sequence_number_(request_sequence_number), + nbr_of_new_streams_(nbr_of_new_streams) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + uint16_t nbr_of_new_streams() const { return nbr_of_new_streams_; } + + private: + ReconfigRequestSN request_sequence_number_; + uint16_t nbr_of_new_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_ADD_OUTGOING_STREAMS_REQUEST_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter_test.cc b/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter_test.cc new file mode 100644 index 0000000000..d0303b1ba8 --- /dev/null +++ b/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter_test.cc @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(AddOutgoingStreamsRequestParameterTest, SerializeAndDeserialize) { + AddOutgoingStreamsRequestParameter parameter(ReconfigRequestSN(1), 2); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + AddOutgoingStreamsRequestParameter deserialized, + AddOutgoingStreamsRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_EQ(deserialized.nbr_of_new_streams(), 2u); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.cc b/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.cc new file mode 100644 index 0000000000..7dd8e1923f --- /dev/null +++ b/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" + +#include + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.1 + +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 49152 | Parameter Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ForwardTsnSupportedParameter::kType; + +absl::optional +ForwardTsnSupportedParameter::Parse(rtc::ArrayView data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return ForwardTsnSupportedParameter(); +} + +void ForwardTsnSupportedParameter::SerializeTo( + std::vector& out) const { + AllocateTLV(out); +} + +std::string ForwardTsnSupportedParameter::ToString() const { + return "Forward TSN Supported"; +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h b/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h new file mode 100644 index 0000000000..d4cff4ac21 --- /dev/null +++ b/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_FORWARD_TSN_SUPPORTED_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_FORWARD_TSN_SUPPORTED_PARAMETER_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.1 +struct ForwardTsnSupportedParameterConfig : ParameterConfig { + static constexpr int kType = 49152; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ForwardTsnSupportedParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = ForwardTsnSupportedParameterConfig::kType; + + ForwardTsnSupportedParameter() {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_FORWARD_TSN_SUPPORTED_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/forward_tsn_supported_parameter_test.cc b/net/dcsctp/packet/parameter/forward_tsn_supported_parameter_test.cc new file mode 100644 index 0000000000..fb4f983fae --- /dev/null +++ b/net/dcsctp/packet/parameter/forward_tsn_supported_parameter_test.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(ForwardTsnSupportedParameterTest, SerializeAndDeserialize) { + ForwardTsnSupportedParameter parameter; + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ForwardTsnSupportedParameter deserialized, + ForwardTsnSupportedParameter::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/heartbeat_info_parameter.cc b/net/dcsctp/packet/parameter/heartbeat_info_parameter.cc new file mode 100644 index 0000000000..918976d305 --- /dev/null +++ b/net/dcsctp/packet/parameter/heartbeat_info_parameter.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.5 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 4 | Chunk Flags | Heartbeat Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Heartbeat Information TLV (Variable-Length) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int HeartbeatInfoParameter::kType; + +absl::optional HeartbeatInfoParameter::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return HeartbeatInfoParameter(reader->variable_data()); +} + +void HeartbeatInfoParameter::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out, info_.size()); + writer.CopyToVariableData(info_); +} + +std::string HeartbeatInfoParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Heartbeat Info parameter (info_length=" << info_.size() << ")"; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/heartbeat_info_parameter.h b/net/dcsctp/packet/parameter/heartbeat_info_parameter.h new file mode 100644 index 0000000000..ec503a94b2 --- /dev/null +++ b/net/dcsctp/packet/parameter/heartbeat_info_parameter.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_HEARTBEAT_INFO_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_HEARTBEAT_INFO_PARAMETER_H_ +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.5 +struct HeartbeatInfoParameterConfig : ParameterConfig { + static constexpr int kType = 1; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class HeartbeatInfoParameter : public Parameter, + public TLVTrait { + public: + static constexpr int kType = HeartbeatInfoParameterConfig::kType; + + explicit HeartbeatInfoParameter(rtc::ArrayView info) + : info_(info.begin(), info.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView info() const { return info_; } + + private: + std::vector info_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_HEARTBEAT_INFO_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.cc b/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.cc new file mode 100644 index 0000000000..6191adfe9d --- /dev/null +++ b/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" + +#include + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.2 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 14 | Parameter Length = 8 + 2 * N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number 1 (optional) | Stream Number 2 (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / ...... 
/ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number N-1 (optional) | Stream Number N (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int IncomingSSNResetRequestParameter::kType; + +absl::optional +IncomingSSNResetRequestParameter::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + + size_t stream_count = reader->variable_data_size() / kStreamIdSize; + std::vector stream_ids; + stream_ids.reserve(stream_count); + for (size_t i = 0; i < stream_count; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(i * kStreamIdSize); + + stream_ids.push_back(StreamID(sub_reader.Load16<0>())); + } + + return IncomingSSNResetRequestParameter(request_sequence_number, + std::move(stream_ids)); +} + +void IncomingSSNResetRequestParameter::SerializeTo( + std::vector& out) const { + size_t variable_size = stream_ids_.size() * kStreamIdSize; + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*request_sequence_number_); + + for (size_t i = 0; i < stream_ids_.size(); ++i) { + BoundedByteWriter sub_writer = + writer.sub_writer(i * kStreamIdSize); + sub_writer.Store16<0>(*stream_ids_[i]); + } +} + +std::string IncomingSSNResetRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Incoming SSN Reset Request, req_seq_nbr=" + << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h b/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h new file mode 100644 index 0000000000..18963efafc --- /dev/null +++ b/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_INCOMING_SSN_RESET_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_INCOMING_SSN_RESET_REQUEST_PARAMETER_H_ +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.2 +struct IncomingSSNResetRequestParameterConfig : ParameterConfig { + static constexpr int kType = 14; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 2; +}; + +class IncomingSSNResetRequestParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = IncomingSSNResetRequestParameterConfig::kType; + + explicit IncomingSSNResetRequestParameter( + ReconfigRequestSN request_sequence_number, + std::vector stream_ids) + : request_sequence_number_(request_sequence_number), + stream_ids_(std::move(stream_ids)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + rtc::ArrayView stream_ids() const { return stream_ids_; } + + private: + static constexpr size_t kStreamIdSize = sizeof(uint16_t); + + ReconfigRequestSN request_sequence_number_; + std::vector stream_ids_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_INCOMING_SSN_RESET_REQUEST_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter_test.cc b/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter_test.cc new file mode 100644 index 0000000000..17793f6638 --- /dev/null +++ b/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter_test.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(IncomingSSNResetRequestParameterTest, SerializeAndDeserialize) { + IncomingSSNResetRequestParameter parameter( + ReconfigRequestSN(1), {StreamID(2), StreamID(3), StreamID(4)}); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + IncomingSSNResetRequestParameter deserialized, + IncomingSSNResetRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_THAT(deserialized.stream_ids(), + ElementsAre(StreamID(2), StreamID(3), StreamID(4))); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.cc b/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.cc new file mode 100644 index 0000000000..c25a2426be --- /dev/null +++ b/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" + +#include + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 13 | Parameter Length = 16 + 2 * N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Response Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sender's Last Assigned TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number 1 (optional) | Stream Number 2 (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / ...... 
/ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number N-1 (optional) | Stream Number N (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int OutgoingSSNResetRequestParameter::kType; + +absl::optional +OutgoingSSNResetRequestParameter::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + ReconfigRequestSN response_sequence_number(reader->Load32<8>()); + TSN sender_last_assigned_tsn(reader->Load32<12>()); + + size_t stream_count = reader->variable_data_size() / kStreamIdSize; + std::vector stream_ids; + stream_ids.reserve(stream_count); + for (size_t i = 0; i < stream_count; ++i) { + BoundedByteReader sub_reader = + reader->sub_reader(i * kStreamIdSize); + + stream_ids.push_back(StreamID(sub_reader.Load16<0>())); + } + + return OutgoingSSNResetRequestParameter( + request_sequence_number, response_sequence_number, + sender_last_assigned_tsn, std::move(stream_ids)); +} + +void OutgoingSSNResetRequestParameter::SerializeTo( + std::vector& out) const { + size_t variable_size = stream_ids_.size() * kStreamIdSize; + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*request_sequence_number_); + writer.Store32<8>(*response_sequence_number_); + writer.Store32<12>(*sender_last_assigned_tsn_); + + for (size_t i = 0; i < stream_ids_.size(); ++i) { + BoundedByteWriter sub_writer = + writer.sub_writer(i * kStreamIdSize); + sub_writer.Store16<0>(*stream_ids_[i]); + } +} + +std::string OutgoingSSNResetRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Outgoing SSN Reset Request, req_seq_nbr=" << *request_sequence_number() + << ", resp_seq_nbr=" << *response_sequence_number() + << ", sender_last_asg_tsn=" << *sender_last_assigned_tsn(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h b/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h new file mode 100644 index 0000000000..6eb44e079f --- /dev/null +++ b/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_OUTGOING_SSN_RESET_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_OUTGOING_SSN_RESET_REQUEST_PARAMETER_H_ +#include +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.1 +struct OutgoingSSNResetRequestParameterConfig : ParameterConfig { + static constexpr int kType = 13; + static constexpr size_t kHeaderSize = 16; + static constexpr size_t kVariableLengthAlignment = 2; +}; + +class OutgoingSSNResetRequestParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = OutgoingSSNResetRequestParameterConfig::kType; + + explicit OutgoingSSNResetRequestParameter( + ReconfigRequestSN request_sequence_number, + ReconfigRequestSN response_sequence_number, + TSN sender_last_assigned_tsn, + std::vector stream_ids) + : request_sequence_number_(request_sequence_number), + response_sequence_number_(response_sequence_number), + sender_last_assigned_tsn_(sender_last_assigned_tsn), + stream_ids_(std::move(stream_ids)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + ReconfigRequestSN response_sequence_number() const { + return response_sequence_number_; + } + TSN sender_last_assigned_tsn() const { return sender_last_assigned_tsn_; } + rtc::ArrayView stream_ids() const { return stream_ids_; } + + private: + static constexpr size_t kStreamIdSize = sizeof(uint16_t); + + ReconfigRequestSN request_sequence_number_; + ReconfigRequestSN response_sequence_number_; + TSN sender_last_assigned_tsn_; + std::vector stream_ids_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_OUTGOING_SSN_RESET_REQUEST_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter_test.cc b/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter_test.cc new file mode 100644 index 0000000000..dae73c2fba --- /dev/null +++ b/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter_test.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(OutgoingSSNResetRequestParameterTest, SerializeAndDeserialize) { + OutgoingSSNResetRequestParameter parameter( + ReconfigRequestSN(1), ReconfigRequestSN(2), TSN(3), + {StreamID(4), StreamID(5), StreamID(6)}); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter deserialized, + OutgoingSSNResetRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_EQ(*deserialized.response_sequence_number(), 2u); + EXPECT_EQ(*deserialized.sender_last_assigned_tsn(), 3u); + EXPECT_THAT(deserialized.stream_ids(), + ElementsAre(StreamID(4), StreamID(5), StreamID(6))); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/parameter.cc b/net/dcsctp/packet/parameter/parameter.cc new file mode 100644 index 0000000000..b3b2bffef7 --- /dev/null +++ b/net/dcsctp/packet/parameter/parameter.cc @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/parameter.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +constexpr size_t kParameterHeaderSize = 4; + +Parameters::Builder& Parameters::Builder::Add(const Parameter& p) { + // https://tools.ietf.org/html/rfc4960#section-3.2.1 + // "If the length of the parameter is not a multiple of 4 bytes, the sender + // pads the parameter at the end (i.e., after the Parameter Value field) with + // all zero bytes." 
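+  // In practice this means that any padding needed after the previously added
+  // parameter is written here, just before the next parameter is serialized,
+  // so that every parameter starts at an offset that is a multiple of four.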
+  if (data_.size() % 4 != 0) {
+    data_.resize(RoundUpTo4(data_.size()));
+  }
+
+  p.SerializeTo(data_);
+  return *this;
+}
+
+std::vector<ParameterDescriptor> Parameters::descriptors() const {
+  rtc::ArrayView<const uint8_t> span(data_);
+  std::vector<ParameterDescriptor> result;
+  while (!span.empty()) {
+    BoundedByteReader<kParameterHeaderSize> header(span);
+    uint16_t type = header.Load16<0>();
+    uint16_t length = header.Load16<2>();
+    result.emplace_back(type, span.subview(0, length));
+    size_t length_with_padding = RoundUpTo4(length);
+    if (length_with_padding > span.size()) {
+      break;
+    }
+    span = span.subview(length_with_padding);
+  }
+  return result;
+}
+
+absl::optional<Parameters> Parameters::Parse(
+    rtc::ArrayView<const uint8_t> data) {
+  // Validate the parameter descriptors
+  rtc::ArrayView<const uint8_t> span(data);
+  while (!span.empty()) {
+    if (span.size() < kParameterHeaderSize) {
+      RTC_DLOG(LS_WARNING) << "Insufficient parameter length";
+      return absl::nullopt;
+    }
+    BoundedByteReader<kParameterHeaderSize> header(span);
+    uint16_t length = header.Load16<2>();
+    if (length < kParameterHeaderSize || length > span.size()) {
+      RTC_DLOG(LS_WARNING) << "Invalid parameter length field";
+      return absl::nullopt;
+    }
+    size_t length_with_padding = RoundUpTo4(length);
+    if (length_with_padding > span.size()) {
+      break;
+    }
+    span = span.subview(length_with_padding);
+  }
+  return Parameters(std::vector<uint8_t>(data.begin(), data.end()));
+}
+}  // namespace dcsctp
diff --git a/net/dcsctp/packet/parameter/parameter.h b/net/dcsctp/packet/parameter/parameter.h
new file mode 100644
index 0000000000..e8fa67c8f7
--- /dev/null
+++ b/net/dcsctp/packet/parameter/parameter.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_PACKET_PARAMETER_PARAMETER_H_
+#define NET_DCSCTP_PACKET_PARAMETER_PARAMETER_H_
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+#include "rtc_base/strings/string_builder.h"
+
+namespace dcsctp {
+
+class Parameter {
+ public:
+  Parameter() {}
+  virtual ~Parameter() = default;
+
+  Parameter(const Parameter& other) = default;
+  Parameter& operator=(const Parameter& other) = default;
+
+  virtual void SerializeTo(std::vector<uint8_t>& out) const = 0;
+  virtual std::string ToString() const = 0;
+};
+
+struct ParameterDescriptor {
+  ParameterDescriptor(uint16_t type, rtc::ArrayView<const uint8_t> data)
+      : type(type), data(data) {}
+  uint16_t type;
+  rtc::ArrayView<const uint8_t> data;
+};
+
+class Parameters {
+ public:
+  class Builder {
+   public:
+    Builder() {}
+    Builder& Add(const Parameter& p);
+    Parameters Build() { return Parameters(std::move(data_)); }
+
+   private:
+    std::vector<uint8_t> data_;
+  };
+
+  static absl::optional<Parameters> Parse(rtc::ArrayView<const uint8_t> data);
+
+  Parameters() {}
+  Parameters(Parameters&& other) = default;
+  Parameters& operator=(Parameters&& other) = default;
+
+  rtc::ArrayView<const uint8_t> data() const { return data_; }
+  std::vector<ParameterDescriptor> descriptors() const;
+
+  template <typename P>
+  absl::optional<P>
get() const { + static_assert(std::is_base_of::value, + "Template parameter not derived from Parameter"); + for (const auto& p : descriptors()) { + if (p.type == P::kType) { + return P::Parse(p.data); + } + } + return absl::nullopt; + } + + private: + explicit Parameters(std::vector data) : data_(std::move(data)) {} + std::vector data_; +}; + +struct ParameterConfig { + static constexpr int kTypeSizeInBytes = 2; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/parameter_test.cc b/net/dcsctp/packet/parameter/parameter_test.cc new file mode 100644 index 0000000000..467e324592 --- /dev/null +++ b/net/dcsctp/packet/parameter/parameter_test.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/parameter.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; + +TEST(ParameterTest, SerializeDeserializeParameter) { + Parameters parameters = + Parameters::Builder() + .Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(123), + ReconfigRequestSN(456), + TSN(789), {StreamID(42)})) + .Build(); + + rtc::ArrayView serialized = parameters.data(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters parsed, Parameters::Parse(serialized)); + auto descriptors = parsed.descriptors(); + ASSERT_THAT(descriptors, SizeIs(1)); + EXPECT_THAT(descriptors[0].type, OutgoingSSNResetRequestParameter::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter parsed_param, + OutgoingSSNResetRequestParameter::Parse(descriptors[0].data)); + EXPECT_EQ(*parsed_param.request_sequence_number(), 123u); + EXPECT_EQ(*parsed_param.response_sequence_number(), 456u); + EXPECT_EQ(*parsed_param.sender_last_assigned_tsn(), 789u); + EXPECT_THAT(parsed_param.stream_ids(), ElementsAre(StreamID(42))); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/reconfiguration_response_parameter.cc b/net/dcsctp/packet/parameter/reconfiguration_response_parameter.cc new file mode 100644 index 0000000000..fafb204acc --- /dev/null +++ b/net/dcsctp/packet/parameter/reconfiguration_response_parameter.cc @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" + +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.4 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 16 | Parameter Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Response Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Result | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sender's Next TSN (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Receiver's Next TSN (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ReconfigurationResponseParameter::kType; + +absl::string_view ToString(ReconfigurationResponseParameter::Result result) { + switch (result) { + case ReconfigurationResponseParameter::Result::kSuccessNothingToDo: + return "Success: nothing to do"; + case ReconfigurationResponseParameter::Result::kSuccessPerformed: + return "Success: performed"; + case ReconfigurationResponseParameter::Result::kDenied: + return "Denied"; + case ReconfigurationResponseParameter::Result::kErrorWrongSSN: + return "Error: wrong ssn"; + case ReconfigurationResponseParameter::Result:: + kErrorRequestAlreadyInProgress: + return "Error: request already in progress"; + case ReconfigurationResponseParameter::Result::kErrorBadSequenceNumber: + return "Error: bad sequence number"; + case ReconfigurationResponseParameter::Result::kInProgress: + return "In progress"; + } +} + +absl::optional +ReconfigurationResponseParameter::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + ReconfigRequestSN response_sequence_number(reader->Load32<4>()); + Result result; + uint32_t result_nbr = reader->Load32<8>(); + switch (result_nbr) { + case 0: + result = ReconfigurationResponseParameter::Result::kSuccessNothingToDo; + break; + case 1: + result = ReconfigurationResponseParameter::Result::kSuccessPerformed; + break; + case 2: + result = ReconfigurationResponseParameter::Result::kDenied; + break; + case 3: + result = ReconfigurationResponseParameter::Result::kErrorWrongSSN; + break; + case 4: + result = ReconfigurationResponseParameter::Result:: + kErrorRequestAlreadyInProgress; + break; + case 5: + result = + ReconfigurationResponseParameter::Result::kErrorBadSequenceNumber; + break; + case 6: + result = ReconfigurationResponseParameter::Result::kInProgress; + break; + default: + RTC_DLOG(LS_WARNING) << "Invalid reconfig response result: " + << result_nbr; + return absl::nullopt; + } + + if (reader->variable_data().empty()) { + return ReconfigurationResponseParameter(response_sequence_number, result); + } else if (reader->variable_data_size() != kNextTsnHeaderSize) { + RTC_DLOG(LS_WARNING) << "Invalid parameter size"; + return absl::nullopt; + } + + BoundedByteReader sub_reader = + reader->sub_reader(0); + + TSN 
sender_next_tsn(sub_reader.Load32<0>()); + TSN receiver_next_tsn(sub_reader.Load32<4>()); + + return ReconfigurationResponseParameter(response_sequence_number, result, + sender_next_tsn, receiver_next_tsn); +} + +void ReconfigurationResponseParameter::SerializeTo( + std::vector& out) const { + size_t variable_size = + (sender_next_tsn().has_value() ? kNextTsnHeaderSize : 0); + BoundedByteWriter writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*response_sequence_number_); + uint32_t result_nbr = + static_cast::type>(result_); + writer.Store32<8>(result_nbr); + + if (sender_next_tsn().has_value()) { + BoundedByteWriter sub_writer = + writer.sub_writer(0); + + sub_writer.Store32<0>(sender_next_tsn_.has_value() ? **sender_next_tsn_ + : 0); + sub_writer.Store32<4>(receiver_next_tsn_.has_value() ? **receiver_next_tsn_ + : 0); + } +} + +std::string ReconfigurationResponseParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Re-configuration Response, resp_seq_nbr=" + << *response_sequence_number(); + return sb.Release(); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/reconfiguration_response_parameter.h b/net/dcsctp/packet/parameter/reconfiguration_response_parameter.h new file mode 100644 index 0000000000..c5a68acb33 --- /dev/null +++ b/net/dcsctp/packet/parameter/reconfiguration_response_parameter.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_RECONFIGURATION_RESPONSE_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_RECONFIGURATION_RESPONSE_PARAMETER_H_ +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.4 +struct ReconfigurationResponseParameterConfig : ParameterConfig { + static constexpr int kType = 16; + static constexpr size_t kHeaderSize = 12; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class ReconfigurationResponseParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = ReconfigurationResponseParameterConfig::kType; + + enum class Result { + kSuccessNothingToDo = 0, + kSuccessPerformed = 1, + kDenied = 2, + kErrorWrongSSN = 3, + kErrorRequestAlreadyInProgress = 4, + kErrorBadSequenceNumber = 5, + kInProgress = 6, + }; + + ReconfigurationResponseParameter(ReconfigRequestSN response_sequence_number, + Result result) + : response_sequence_number_(response_sequence_number), + result_(result), + sender_next_tsn_(absl::nullopt), + receiver_next_tsn_(absl::nullopt) {} + + explicit ReconfigurationResponseParameter( + ReconfigRequestSN response_sequence_number, + Result result, + TSN sender_next_tsn, + TSN receiver_next_tsn) + : response_sequence_number_(response_sequence_number), + result_(result), + sender_next_tsn_(sender_next_tsn), + receiver_next_tsn_(receiver_next_tsn) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + ReconfigRequestSN response_sequence_number() const { + return response_sequence_number_; + } + Result result() const { return result_; } + absl::optional sender_next_tsn() const { return sender_next_tsn_; } + absl::optional receiver_next_tsn() const { return receiver_next_tsn_; } + + private: + static constexpr size_t kNextTsnHeaderSize = 8; + ReconfigRequestSN response_sequence_number_; + Result result_; + absl::optional sender_next_tsn_; + absl::optional receiver_next_tsn_; +}; + +absl::string_view ToString(ReconfigurationResponseParameter::Result result); + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_RECONFIGURATION_RESPONSE_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/reconfiguration_response_parameter_test.cc b/net/dcsctp/packet/parameter/reconfiguration_response_parameter_test.cc new file mode 100644 index 0000000000..8125d93cd0 --- /dev/null +++ b/net/dcsctp/packet/parameter/reconfiguration_response_parameter_test.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" + +#include + +#include +#include + +#include "absl/types/optional.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(ReconfigurationResponseParameterTest, SerializeAndDeserializeFirstForm) { + ReconfigurationResponseParameter parameter( + ReconfigRequestSN(1), + ReconfigurationResponseParameter::Result::kSuccessPerformed); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ReconfigurationResponseParameter deserialized, + ReconfigurationResponseParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.response_sequence_number(), 1u); + EXPECT_EQ(deserialized.result(), + ReconfigurationResponseParameter::Result::kSuccessPerformed); + EXPECT_EQ(deserialized.sender_next_tsn(), absl::nullopt); + EXPECT_EQ(deserialized.receiver_next_tsn(), absl::nullopt); +} + +TEST(ReconfigurationResponseParameterTest, + SerializeAndDeserializeFirstFormSecondForm) { + ReconfigurationResponseParameter parameter( + ReconfigRequestSN(1), + ReconfigurationResponseParameter::Result::kSuccessPerformed, TSN(2), + TSN(3)); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ReconfigurationResponseParameter deserialized, + ReconfigurationResponseParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.response_sequence_number(), 1u); + EXPECT_EQ(deserialized.result(), + ReconfigurationResponseParameter::Result::kSuccessPerformed); + EXPECT_TRUE(deserialized.sender_next_tsn().has_value()); + EXPECT_EQ(**deserialized.sender_next_tsn(), 2u); + EXPECT_TRUE(deserialized.receiver_next_tsn().has_value()); + EXPECT_EQ(**deserialized.receiver_next_tsn(), 3u); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.cc b/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.cc new file mode 100644 index 0000000000..d656e0db8f --- /dev/null +++ b/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.3 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 15 | Parameter Length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int SSNTSNResetRequestParameter::kType; + +absl::optional SSNTSNResetRequestParameter::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + + return SSNTSNResetRequestParameter(request_sequence_number); +} + +void SSNTSNResetRequestParameter::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out); + writer.Store32<4>(*request_sequence_number_); +} + +std::string SSNTSNResetRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "SSN/TSN Reset Request, req_seq_nbr=" << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h b/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h new file mode 100644 index 0000000000..e31d7ebe8f --- /dev/null +++ b/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_SSN_TSN_RESET_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_SSN_TSN_RESET_REQUEST_PARAMETER_H_ +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.3 +struct SSNTSNResetRequestParameterConfig : ParameterConfig { + static constexpr int kType = 15; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class SSNTSNResetRequestParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = SSNTSNResetRequestParameterConfig::kType; + + explicit SSNTSNResetRequestParameter( + ReconfigRequestSN request_sequence_number) + : request_sequence_number_(request_sequence_number) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + + private: + ReconfigRequestSN request_sequence_number_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_SSN_TSN_RESET_REQUEST_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter_test.cc b/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter_test.cc new file mode 100644 index 0000000000..eeb973cbcb --- /dev/null +++ b/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter_test.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" + +#include + +#include +#include + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(SSNTSNResetRequestParameterTest, SerializeAndDeserialize) { + SSNTSNResetRequestParameter parameter(ReconfigRequestSN(1)); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(SSNTSNResetRequestParameter deserialized, + SSNTSNResetRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/state_cookie_parameter.cc b/net/dcsctp/packet/parameter/state_cookie_parameter.cc new file mode 100644 index 0000000000..9777aa6667 --- /dev/null +++ b/net/dcsctp/packet/parameter/state_cookie_parameter.cc @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3.1 + +constexpr int StateCookieParameter::kType; + +absl::optional StateCookieParameter::Parse( + rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return StateCookieParameter(reader->variable_data()); +} + +void StateCookieParameter::SerializeTo(std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out, data_.size()); + writer.CopyToVariableData(data_); +} + +std::string StateCookieParameter::ToString() const { + rtc::StringBuilder sb; + sb << "State Cookie parameter (cookie_length=" << data_.size() << ")"; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/state_cookie_parameter.h b/net/dcsctp/packet/parameter/state_cookie_parameter.h new file mode 100644 index 0000000000..f4355495e2 --- /dev/null +++ b/net/dcsctp/packet/parameter/state_cookie_parameter.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_STATE_COOKIE_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_STATE_COOKIE_PARAMETER_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3.1 +struct StateCookieParameterConfig : ParameterConfig { + static constexpr int kType = 7; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class StateCookieParameter : public Parameter, + public TLVTrait { + public: + static constexpr int kType = StateCookieParameterConfig::kType; + + explicit StateCookieParameter(rtc::ArrayView data) + : data_(data.begin(), data.end()) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + rtc::ArrayView data() const { return data_; } + + private: + std::vector data_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_STATE_COOKIE_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/state_cookie_parameter_test.cc b/net/dcsctp/packet/parameter/state_cookie_parameter_test.cc new file mode 100644 index 0000000000..bcca38b586 --- /dev/null +++ b/net/dcsctp/packet/parameter/state_cookie_parameter_test.cc @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" + +#include + +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(StateCookieParameterTest, SerializeAndDeserialize) { + uint8_t cookie[] = {1, 2, 3}; + StateCookieParameter parameter(cookie); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(StateCookieParameter deserialized, + StateCookieParameter::Parse(serialized)); + + EXPECT_THAT(deserialized.data(), ElementsAre(1, 2, 3)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/supported_extensions_parameter.cc b/net/dcsctp/packet/parameter/supported_extensions_parameter.cc new file mode 100644 index 0000000000..6a8fb214de --- /dev/null +++ b/net/dcsctp/packet/parameter/supported_extensions_parameter.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc5061#section-4.2.7 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 0x8008 | Parameter Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | CHUNK TYPE 1 | CHUNK TYPE 2 | CHUNK TYPE 3 | CHUNK TYPE 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | .... 
| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | CHUNK TYPE N | PAD | PAD | PAD | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int SupportedExtensionsParameter::kType; + +absl::optional +SupportedExtensionsParameter::Parse(rtc::ArrayView data) { + absl::optional> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + std::vector chunk_types(reader->variable_data().begin(), + reader->variable_data().end()); + return SupportedExtensionsParameter(std::move(chunk_types)); +} + +void SupportedExtensionsParameter::SerializeTo( + std::vector& out) const { + BoundedByteWriter writer = AllocateTLV(out, chunk_types_.size()); + writer.CopyToVariableData(chunk_types_); +} + +std::string SupportedExtensionsParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Supported Extensions (" << StrJoin(chunk_types_, ", ") << ")"; + return sb.Release(); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/parameter/supported_extensions_parameter.h b/net/dcsctp/packet/parameter/supported_extensions_parameter.h new file mode 100644 index 0000000000..5689fd8035 --- /dev/null +++ b/net/dcsctp/packet/parameter/supported_extensions_parameter.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_SUPPORTED_EXTENSIONS_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_SUPPORTED_EXTENSIONS_PARAMETER_H_ +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc5061#section-4.2.7 +struct SupportedExtensionsParameterConfig : ParameterConfig { + static constexpr int kType = 0x8008; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class SupportedExtensionsParameter + : public Parameter, + public TLVTrait { + public: + static constexpr int kType = SupportedExtensionsParameterConfig::kType; + + explicit SupportedExtensionsParameter(std::vector chunk_types) + : chunk_types_(std::move(chunk_types)) {} + + static absl::optional Parse( + rtc::ArrayView data); + + void SerializeTo(std::vector& out) const override; + std::string ToString() const override; + + bool supports(uint8_t chunk_type) const { + return std::find(chunk_types_.begin(), chunk_types_.end(), chunk_type) != + chunk_types_.end(); + } + + rtc::ArrayView chunk_types() const { return chunk_types_; } + + private: + std::vector chunk_types_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_SUPPORTED_EXTENSIONS_PARAMETER_H_ diff --git a/net/dcsctp/packet/parameter/supported_extensions_parameter_test.cc b/net/dcsctp/packet/parameter/supported_extensions_parameter_test.cc new file mode 100644 index 0000000000..c870af2e70 --- /dev/null +++ b/net/dcsctp/packet/parameter/supported_extensions_parameter_test.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(SupportedExtensionsParameterTest, SerializeAndDeserialize) { + SupportedExtensionsParameter parameter({1, 2, 3}); + + std::vector serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(SupportedExtensionsParameter deserialized, + SupportedExtensionsParameter::Parse(serialized)); + + EXPECT_THAT(deserialized.chunk_types(), ElementsAre(1, 2, 3)); + EXPECT_TRUE(deserialized.supports(1)); + EXPECT_TRUE(deserialized.supports(2)); + EXPECT_TRUE(deserialized.supports(3)); + EXPECT_FALSE(deserialized.supports(4)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/sctp_packet.cc b/net/dcsctp/packet/sctp_packet.cc new file mode 100644 index 0000000000..3e419c5978 --- /dev/null +++ b/net/dcsctp/packet/sctp_packet.cc @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/sctp_packet.h" + +#include + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/crc32c.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { +namespace { +constexpr size_t kMaxUdpPacketSize = 65535; +constexpr size_t kChunkTlvHeaderSize = 4; +constexpr size_t kExpectedDescriptorCount = 4; +} // namespace + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Port Number | Destination Port Number | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Verification Tag | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Checksum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +SctpPacket::Builder::Builder(VerificationTag verification_tag, + const DcSctpOptions& options) + : verification_tag_(verification_tag), + source_port_(options.local_port), + dest_port_(options.remote_port), + max_packet_size_(RoundDownTo4(options.mtu)) {} + +SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) { + if (out_.empty()) { + out_.reserve(max_packet_size_); + out_.resize(SctpPacket::kHeaderSize); + BoundedByteWriter buffer(out_); + buffer.Store16<0>(source_port_); + buffer.Store16<2>(dest_port_); + buffer.Store32<4>(*verification_tag_); + // Checksum is at offset 8 - written when calling Build(); + } + RTC_DCHECK(IsDivisibleBy4(out_.size())); + + chunk.SerializeTo(out_); + if (out_.size() % 4 != 0) { + out_.resize(RoundUpTo4(out_.size())); + } + + RTC_DCHECK(out_.size() <= max_packet_size_) + << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; + return *this; +} + +size_t SctpPacket::Builder::bytes_remaining() const { + if (out_.empty()) { + // The packet header (CommonHeader) hasn't been written yet: + return max_packet_size_ - kHeaderSize; + } else if (out_.size() > max_packet_size_) { + RTC_NOTREACHED() << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; + return 0; + } + return max_packet_size_ - out_.size(); +} + +std::vector SctpPacket::Builder::Build() { + std::vector out; + out_.swap(out); + + if (!out.empty()) { + uint32_t crc = GenerateCrc32C(out); + BoundedByteWriter(out).Store32<8>(crc); + } + + RTC_DCHECK(out.size() <= max_packet_size_) + << "Exceeded max size, data=" << out.size() + << ", max_size=" << max_packet_size_; + + return out; +} + +absl::optional SctpPacket::Parse( + rtc::ArrayView data, + bool disable_checksum_verification) { + if (data.size() < kHeaderSize + kChunkTlvHeaderSize || + data.size() > kMaxUdpPacketSize) { + RTC_DLOG(LS_WARNING) << "Invalid packet size"; + return absl::nullopt; + } + + BoundedByteReader reader(data); + + CommonHeader common_header; + common_header.source_port = reader.Load16<0>(); + common_header.destination_port = reader.Load16<2>(); + common_header.verification_tag = VerificationTag(reader.Load32<4>()); + common_header.checksum = reader.Load32<8>(); + + // Create a copy of the packet, which will be held by this object. 
+ std::vector data_copy = + std::vector(data.begin(), data.end()); + + // Verify the checksum. The checksum field must be zero when that's done. + BoundedByteWriter(data_copy).Store32<8>(0); + uint32_t calculated_checksum = GenerateCrc32C(data_copy); + if (!disable_checksum_verification && + calculated_checksum != common_header.checksum) { + RTC_DLOG(LS_WARNING) << rtc::StringFormat( + "Invalid packet checksum, packet_checksum=0x%08x, " + "calculated_checksum=0x%08x", + common_header.checksum, calculated_checksum); + return absl::nullopt; + } + // Restore the checksum in the header. + BoundedByteWriter(data_copy).Store32<8>(common_header.checksum); + + // Validate and parse the chunk headers in the message. + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Chunk Type | Chunk Flags | Chunk Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + std::vector descriptors; + descriptors.reserve(kExpectedDescriptorCount); + rtc::ArrayView descriptor_data = + rtc::ArrayView(data_copy).subview(kHeaderSize); + while (!descriptor_data.empty()) { + if (descriptor_data.size() < kChunkTlvHeaderSize) { + RTC_DLOG(LS_WARNING) << "Too small chunk"; + return absl::nullopt; + } + BoundedByteReader chunk_header(descriptor_data); + uint8_t type = chunk_header.Load8<0>(); + uint8_t flags = chunk_header.Load8<1>(); + uint16_t length = chunk_header.Load16<2>(); + uint16_t padded_length = RoundUpTo4(length); + if (padded_length > descriptor_data.size()) { + RTC_DLOG(LS_WARNING) << "Too large chunk. length=" << length + << ", remaining=" << descriptor_data.size(); + return absl::nullopt; + } else if (padded_length < kChunkTlvHeaderSize) { + RTC_DLOG(LS_WARNING) << "Too small chunk. length=" << length; + return absl::nullopt; + } + descriptors.emplace_back(type, flags, + descriptor_data.subview(0, padded_length)); + descriptor_data = descriptor_data.subview(padded_length); + } + + // Note that iterators (and pointer) are guaranteed to be stable when moving a + // std::vector, and `descriptors` have pointers to within `data_copy`. + return SctpPacket(common_header, std::move(data_copy), + std::move(descriptors)); +} +} // namespace dcsctp diff --git a/net/dcsctp/packet/sctp_packet.h b/net/dcsctp/packet/sctp_packet.h new file mode 100644 index 0000000000..2600caf7a9 --- /dev/null +++ b/net/dcsctp/packet/sctp_packet.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_SCTP_PACKET_H_ +#define NET_DCSCTP_PACKET_SCTP_PACKET_H_ + +#include + +#include +#include +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// The "Common Header", which every SCTP packet starts with, and is described in +// https://tools.ietf.org/html/rfc4960#section-3.1. 
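The header below declares both directions of this API. As a usage sketch (the function and namespace names are invented; header paths and template arguments are assumed), a sender assembles a packet with `SctpPacket::Builder` and a receiver parses it back and re-parses each chunk from its descriptor, much like the tests in `sctp_packet_test.cc` further down in this change:

```cpp
#include <cstdint>
#include <vector>

#include "absl/types/optional.h"
#include "net/dcsctp/common/internal_types.h"
#include "net/dcsctp/packet/chunk/data_chunk.h"
#include "net/dcsctp/packet/sctp_packet.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/types.h"

namespace dcsctp_example {

// Builds a packet carrying one DATA chunk, parses it back, and re-parses the
// chunk from its descriptor. All TSN/stream/PPID values are arbitrary.
inline bool RoundTrip() {
  dcsctp::DcSctpOptions options;  // mtu defaults to kMaxSafeMTUSize.
  dcsctp::SctpPacket::Builder builder(dcsctp::VerificationTag(0x12345678),
                                      options);
  builder.Add(dcsctp::DataChunk(dcsctp::TSN(1), dcsctp::StreamID(1),
                                dcsctp::SSN(0), dcsctp::PPID(53),
                                /*payload=*/{1, 2, 3, 4, 5},
                                /*options=*/{}));
  std::vector<uint8_t> serialized = builder.Build();

  absl::optional<dcsctp::SctpPacket> packet =
      dcsctp::SctpPacket::Parse(serialized);
  if (!packet.has_value()) {
    return false;
  }
  for (const dcsctp::SctpPacket::ChunkDescriptor& descriptor :
       packet->descriptors()) {
    if (descriptor.type == dcsctp::DataChunk::kType) {
      return dcsctp::DataChunk::Parse(descriptor.data).has_value();
    }
  }
  return false;
}

}  // namespace dcsctp_example
```

Note that `Build()` clears the builder again, so the same builder can be reused for the next packet.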
+struct CommonHeader { + uint16_t source_port; + uint16_t destination_port; + VerificationTag verification_tag; + uint32_t checksum; +}; + +// Represents an immutable (received or to-be-sent) SCTP packet. +class SctpPacket { + public: + static constexpr size_t kHeaderSize = 12; + + struct ChunkDescriptor { + ChunkDescriptor(uint8_t type, + uint8_t flags, + rtc::ArrayView data) + : type(type), flags(flags), data(data) {} + uint8_t type; + uint8_t flags; + rtc::ArrayView data; + }; + + SctpPacket(SctpPacket&& other) = default; + SctpPacket& operator=(SctpPacket&& other) = default; + SctpPacket(const SctpPacket&) = delete; + SctpPacket& operator=(const SctpPacket&) = delete; + + // Used for building SctpPacket, as those are immutable. + class Builder { + public: + Builder(VerificationTag verification_tag, const DcSctpOptions& options); + + Builder(Builder&& other) = default; + Builder& operator=(Builder&& other) = default; + + // Adds a chunk to the to-be-built SCTP packet. + Builder& Add(const Chunk& chunk); + + // The number of bytes remaining in the packet for chunk storage until the + // packet reaches its maximum size. + size_t bytes_remaining() const; + + // Indicates if any packets have been added to the builder. + bool empty() const { return out_.empty(); } + + // Returns the payload of the build SCTP packet. The Builder will be cleared + // after having called this function, and can be used to build a new packet. + std::vector Build(); + + private: + void WritePacketHeader(); + VerificationTag verification_tag_; + uint16_t source_port_; + uint16_t dest_port_; + // The maximum packet size is always even divisible by four, as chunks are + // always padded to a size even divisible by four. + size_t max_packet_size_; + std::vector out_; + }; + + // Parses `data` as an SCTP packet and returns it if it validates. + static absl::optional Parse( + rtc::ArrayView data, + bool disable_checksum_verification = false); + + // Returns the SCTP common header. + const CommonHeader& common_header() const { return common_header_; } + + // Returns the chunks (types and offsets) within the packet. + rtc::ArrayView descriptors() const { + return descriptors_; + } + + private: + SctpPacket(const CommonHeader& common_header, + std::vector data, + std::vector descriptors) + : common_header_(common_header), + data_(std::move(data)), + descriptors_(std::move(descriptors)) {} + + CommonHeader common_header_; + + // As the `descriptors_` refer to offset within data, and since SctpPacket is + // movable, `data` needs to be pointer stable, which it is according to + // http://www.open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2321 + std::vector data_; + // The chunks and their offsets within `data_ `. + std::vector descriptors_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_SCTP_PACKET_H_ diff --git a/net/dcsctp/packet/sctp_packet_test.cc b/net/dcsctp/packet/sctp_packet_test.cc new file mode 100644 index 0000000000..7438315eec --- /dev/null +++ b/net/dcsctp/packet/sctp_packet_test.cc @@ -0,0 +1,342 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/sctp_packet.h" + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +constexpr VerificationTag kVerificationTag = VerificationTag(0x12345678); + +TEST(SctpPacketTest, DeserializeSimplePacketFromCapture) { + /* + Stream Control Transmission Protocol, Src Port: 5000 (5000), Dst Port: 5000 + (5000) Source port: 5000 Destination port: 5000 Verification tag: 0x00000000 + [Association index: 1] + Checksum: 0xaa019d33 [unverified] + [Checksum Status: Unverified] + INIT chunk (Outbound streams: 1000, inbound streams: 1000) + Chunk type: INIT (1) + Chunk flags: 0x00 + Chunk length: 90 + Initiate tag: 0x0eddca08 + Advertised receiver window credit (a_rwnd): 131072 + Number of outbound streams: 1000 + Number of inbound streams: 1000 + Initial TSN: 1426601527 + ECN parameter + Parameter type: ECN (0x8000) + Parameter length: 4 + Forward TSN supported parameter + Parameter type: Forward TSN supported (0xc000) + Parameter length: 4 + Supported Extensions parameter (Supported types: FORWARD_TSN, AUTH, + ASCONF, ASCONF_ACK, RE_CONFIG) Parameter type: Supported Extensions + (0x8008) Parameter length: 9 Supported chunk type: FORWARD_TSN (192) Supported + chunk type: AUTH (15) Supported chunk type: ASCONF (193) Supported chunk type: + ASCONF_ACK (128) Supported chunk type: RE_CONFIG (130) Parameter padding: + 000000 Random parameter Parameter type: Random (0x8002) Parameter length: 36 + Random number: c5a86155090e6f420050634cc8d6b908dfd53e17c99cb143… + Requested HMAC Algorithm parameter (Supported HMACs: SHA-1) + Parameter type: Requested HMAC Algorithm (0x8004) + Parameter length: 6 + HMAC identifier: SHA-1 (1) + Parameter padding: 0000 + Authenticated Chunk list parameter (Chunk types to be authenticated: + ASCONF_ACK, ASCONF) Parameter type: Authenticated Chunk list + (0x8003) Parameter length: 6 Chunk type: ASCONF_ACK (128) Chunk type: ASCONF + (193) Chunk padding: 0000 + */ + + uint8_t data[] = { + 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xaa, 0x01, 0x9d, 0x33, + 0x01, 0x00, 0x00, 0x5a, 0x0e, 0xdd, 0xca, 0x08, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x03, 0xe8, 0x55, 0x08, 0x36, 0x37, 0x80, 0x00, 0x00, 0x04, + 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, + 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xc5, 0xa8, 0x61, 0x55, + 0x09, 0x0e, 0x6f, 0x42, 0x00, 0x50, 0x63, 0x4c, 0xc8, 0xd6, 0xb9, 0x08, + 0xdf, 0xd5, 0x3e, 0x17, 0xc9, 0x9c, 0xb1, 0x43, 0x28, 0x4e, 0xaf, 0x64, + 0x68, 0x2a, 0xc2, 0x97, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, + 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(data)); + EXPECT_EQ(packet.common_header().source_port, 5000); + 
EXPECT_EQ(packet.common_header().destination_port, 5000); + EXPECT_EQ(packet.common_header().verification_tag, VerificationTag(0)); + EXPECT_EQ(packet.common_header().checksum, 0xaa019d33); + + EXPECT_THAT(packet.descriptors(), SizeIs(1)); + EXPECT_EQ(packet.descriptors()[0].type, InitChunk::kType); + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk init, + InitChunk::Parse(packet.descriptors()[0].data)); + EXPECT_EQ(init.initial_tsn(), TSN(1426601527)); +} + +TEST(SctpPacketTest, DeserializePacketWithTwoChunks) { + /* + Stream Control Transmission Protocol, Src Port: 1234 (1234), + Dst Port: 4321 (4321) + Source port: 1234 + Destination port: 4321 + Verification tag: 0x697e3a4e + [Association index: 3] + Checksum: 0xc06e8b36 [unverified] + [Checksum Status: Unverified] + COOKIE_ACK chunk + Chunk type: COOKIE_ACK (11) + Chunk flags: 0x00 + Chunk length: 4 + SACK chunk (Cumulative TSN: 2930332242, a_rwnd: 131072, + gaps: 0, duplicate TSNs: 0) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 16 + Cumulative TSN ACK: 2930332242 + Advertised receiver window credit (a_rwnd): 131072 + Number of gap acknowledgement blocks: 0 + Number of duplicated TSNs: 0 + */ + + uint8_t data[] = {0x04, 0xd2, 0x10, 0xe1, 0x69, 0x7e, 0x3a, 0x4e, + 0xc0, 0x6e, 0x8b, 0x36, 0x0b, 0x00, 0x00, 0x04, + 0x03, 0x00, 0x00, 0x10, 0xae, 0xa9, 0x52, 0x52, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(data)); + EXPECT_EQ(packet.common_header().source_port, 1234); + EXPECT_EQ(packet.common_header().destination_port, 4321); + EXPECT_EQ(packet.common_header().verification_tag, + VerificationTag(0x697e3a4eu)); + EXPECT_EQ(packet.common_header().checksum, 0xc06e8b36u); + + EXPECT_THAT(packet.descriptors(), SizeIs(2)); + EXPECT_EQ(packet.descriptors()[0].type, CookieAckChunk::kType); + EXPECT_EQ(packet.descriptors()[1].type, SackChunk::kType); + ASSERT_HAS_VALUE_AND_ASSIGN( + CookieAckChunk cookie_ack, + CookieAckChunk::Parse(packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk sack, + SackChunk::Parse(packet.descriptors()[1].data)); +} + +TEST(SctpPacketTest, DeserializePacketWithWrongChecksum) { + /* + Stream Control Transmission Protocol, Src Port: 5000 (5000), + Dst Port: 5000 (5000) + Source port: 5000 + Destination port: 5000 + Verification tag: 0x0eddca08 + [Association index: 1] + Checksum: 0x2a81f531 [unverified] + [Checksum Status: Unverified] + SACK chunk (Cumulative TSN: 1426601536, a_rwnd: 131072, + gaps: 0, duplicate TSNs: 0) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 16 + Cumulative TSN ACK: 1426601536 + Advertised receiver window credit (a_rwnd): 131072 + Number of gap acknowledgement blocks: 0 + Number of duplicated TSNs: 0 + */ + + uint8_t data[] = {0x13, 0x88, 0x13, 0x88, 0x0e, 0xdd, 0xca, 0x08, 0x2a, 0x81, + 0xf5, 0x31, 0x03, 0x00, 0x00, 0x10, 0x55, 0x08, 0x36, 0x40, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + EXPECT_FALSE(SctpPacket::Parse(data).has_value()); +} + +TEST(SctpPacketTest, DeserializePacketDontValidateChecksum) { + /* + Stream Control Transmission Protocol, Src Port: 5000 (5000), + Dst Port: 5000 (5000) + Source port: 5000 + Destination port: 5000 + Verification tag: 0x0eddca08 + [Association index: 1] + Checksum: 0x2a81f531 [unverified] + [Checksum Status: Unverified] + SACK chunk (Cumulative TSN: 1426601536, a_rwnd: 131072, + gaps: 0, duplicate TSNs: 0) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 16 + Cumulative TSN ACK: 1426601536 + Advertised receiver window 
credit (a_rwnd): 131072 + Number of gap acknowledgement blocks: 0 + Number of duplicated TSNs: 0 + */ + + uint8_t data[] = {0x13, 0x88, 0x13, 0x88, 0x0e, 0xdd, 0xca, 0x08, 0x2a, 0x81, + 0xf5, 0x31, 0x03, 0x00, 0x00, 0x10, 0x55, 0x08, 0x36, 0x40, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN( + SctpPacket packet, + SctpPacket::Parse(data, /*disable_checksum_verification=*/true)); + EXPECT_EQ(packet.common_header().source_port, 5000); + EXPECT_EQ(packet.common_header().destination_port, 5000); + EXPECT_EQ(packet.common_header().verification_tag, + VerificationTag(0x0eddca08u)); + EXPECT_EQ(packet.common_header().checksum, 0x2a81f531u); +} + +TEST(SctpPacketTest, SerializeAndDeserializeSingleChunk) { + SctpPacket::Builder b(kVerificationTag, {}); + InitChunk init(/*initiate_tag=*/VerificationTag(123), /*a_rwnd=*/456, + /*nbr_outbound_streams=*/65535, + /*nbr_inbound_streams=*/65534, /*initial_tsn=*/TSN(789), + /*parameters=*/Parameters()); + + b.Add(init); + std::vector serialized = b.Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(serialized)); + + EXPECT_EQ(packet.common_header().verification_tag, kVerificationTag); + + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + EXPECT_EQ(packet.descriptors()[0].type, InitChunk::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk deserialized, + InitChunk::Parse(packet.descriptors()[0].data)); + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123)); + EXPECT_EQ(deserialized.a_rwnd(), 456u); + EXPECT_EQ(deserialized.nbr_outbound_streams(), 65535u); + EXPECT_EQ(deserialized.nbr_inbound_streams(), 65534u); + EXPECT_EQ(deserialized.initial_tsn(), TSN(789)); +} + +TEST(SctpPacketTest, SerializeAndDeserializeThreeChunks) { + SctpPacket::Builder b(kVerificationTag, {}); + b.Add(SackChunk(/*cumulative_tsn_ack=*/TSN(999), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(2, 3)}, + /*duplicate_tsns=*/{TSN(1), TSN(2), TSN(3)})); + b.Add(DataChunk(TSN(123), StreamID(456), SSN(789), PPID(9090), + /*payload=*/{1, 2, 3, 4, 5}, + /*options=*/{})); + b.Add(DataChunk(TSN(124), StreamID(654), SSN(987), PPID(909), + /*payload=*/{5, 4, 3, 3, 1}, + /*options=*/{})); + + std::vector serialized = b.Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(serialized)); + + EXPECT_EQ(packet.common_header().verification_tag, kVerificationTag); + + ASSERT_THAT(packet.descriptors(), SizeIs(3)); + EXPECT_EQ(packet.descriptors()[0].type, SackChunk::kType); + EXPECT_EQ(packet.descriptors()[1].type, DataChunk::kType); + EXPECT_EQ(packet.descriptors()[2].type, DataChunk::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk sack, + SackChunk::Parse(packet.descriptors()[0].data)); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(999)); + EXPECT_EQ(sack.a_rwnd(), 456u); + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk data1, + DataChunk::Parse(packet.descriptors()[1].data)); + EXPECT_EQ(data1.tsn(), TSN(123)); + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk data2, + DataChunk::Parse(packet.descriptors()[2].data)); + EXPECT_EQ(data2.tsn(), TSN(124)); +} + +TEST(SctpPacketTest, ParseAbortWithEmptyCause) { + SctpPacket::Builder b(kVerificationTag, {}); + b.Add(AbortChunk( + /*filled_in_verification_tag=*/true, + Parameters::Builder().Add(UserInitiatedAbortCause("")).Build())); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(b.Build())); + + EXPECT_EQ(packet.common_header().verification_tag, kVerificationTag); + + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + EXPECT_EQ(packet.descriptors()[0].type, 
AbortChunk::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk abort, + AbortChunk::Parse(packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + abort.error_causes().get()); + EXPECT_EQ(cause.upper_layer_abort_reason(), ""); +} + +TEST(SctpPacketTest, DetectPacketWithZeroSizeChunk) { + uint8_t data[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0x0a, 0x0a, 0x0a, 0x5c, + 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x00, 0x00, 0x00}; + + EXPECT_FALSE(SctpPacket::Parse(data, true).has_value()); +} + +TEST(SctpPacketTest, ReturnsCorrectSpaceAvailableToStayWithinMTU) { + DcSctpOptions options; + options.mtu = 1191; + + SctpPacket::Builder builder(VerificationTag(123), options); + + // Chunks will be padded to an even 4 bytes, so the maximum packet size should + // be rounded down. + const size_t kMaxPacketSize = RoundDownTo4(options.mtu); + EXPECT_EQ(kMaxPacketSize, 1188u); + + const size_t kSctpHeaderSize = 12; + EXPECT_EQ(builder.bytes_remaining(), kMaxPacketSize - kSctpHeaderSize); + EXPECT_EQ(builder.bytes_remaining(), 1176u); + + // Add a smaller packet first. + DataChunk::Options data_options; + + std::vector payload1(183); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload1, data_options)); + + size_t chunk1_size = RoundUpTo4(DataChunk::kHeaderSize + payload1.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size); + EXPECT_EQ(builder.bytes_remaining(), 976u); // Hand-calculated. + + std::vector payload2(957); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload2, data_options)); + + size_t chunk2_size = RoundUpTo4(DataChunk::kHeaderSize + payload2.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size - chunk2_size); + EXPECT_EQ(builder.bytes_remaining(), 0u); // Hand-calculated. +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/packet/tlv_trait.cc b/net/dcsctp/packet/tlv_trait.cc new file mode 100644 index 0000000000..493b6a4613 --- /dev/null +++ b/net/dcsctp/packet/tlv_trait.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/packet/tlv_trait.h" + +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace tlv_trait_impl { +void ReportInvalidSize(size_t actual_size, size_t expected_size) { + RTC_DLOG(LS_WARNING) << "Invalid size (" << actual_size + << ", expected minimum " << expected_size << " bytes)"; +} + +void ReportInvalidType(int actual_type, int expected_type) { + RTC_DLOG(LS_WARNING) << "Invalid type (" << actual_type << ", expected " + << expected_type << ")"; +} + +void ReportInvalidFixedLengthField(size_t value, size_t expected) { + RTC_DLOG(LS_WARNING) << "Invalid length field (" << value << ", expected " + << expected << " bytes)"; +} + +void ReportInvalidVariableLengthField(size_t value, size_t available) { + RTC_DLOG(LS_WARNING) << "Invalid length field (" << value << ", available " + << available << " bytes)"; +} + +void ReportInvalidPadding(size_t padding_bytes) { + RTC_DLOG(LS_WARNING) << "Invalid padding (" << padding_bytes << " bytes)"; +} + +void ReportInvalidLengthMultiple(size_t length, size_t alignment) { + RTC_DLOG(LS_WARNING) << "Invalid length field (" << length + << ", expected an even multiple of " << alignment + << " bytes)"; +} +} // namespace tlv_trait_impl +} // namespace dcsctp diff --git a/net/dcsctp/packet/tlv_trait.h b/net/dcsctp/packet/tlv_trait.h new file mode 100644 index 0000000000..a3c728efd7 --- /dev/null +++ b/net/dcsctp/packet/tlv_trait.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_TLV_TRAIT_H_ +#define NET_DCSCTP_PACKET_TLV_TRAIT_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" + +namespace dcsctp { +namespace tlv_trait_impl { +// Logging functions, only to be used by TLVTrait, which is a templated class. +void ReportInvalidSize(size_t actual_size, size_t expected_size); +void ReportInvalidType(int actual_type, int expected_type); +void ReportInvalidFixedLengthField(size_t value, size_t expected); +void ReportInvalidVariableLengthField(size_t value, size_t available); +void ReportInvalidPadding(size_t padding_bytes); +void ReportInvalidLengthMultiple(size_t length, size_t alignment); +} // namespace tlv_trait_impl + +// Various entities in SCTP are padded data blocks, with a type and length +// field at fixed offsets, all stored in a 4-byte header. +// +// See e.g. https://tools.ietf.org/html/rfc4960#section-3.2 and +// https://tools.ietf.org/html/rfc4960#section-3.2.1 +// +// These are helper classes for writing and parsing that data, which in SCTP is +// called Type-Length-Value, or TLV. +// +// This templated class is configurable - a struct passed in as template +// parameter with the following expected members: +// * kType - The type field's value +// * kTypeSizeInBytes - The type field's width in bytes. +// Either 1 or 2. +// * kHeaderSize - The fixed size header +// * kVariableLengthAlignment - The size alignment on the variable data. Set +// to zero (0) if no variable data is used. 
+// +// This class is to be used as a trait +// (https://en.wikipedia.org/wiki/Trait_(computer_programming)) that adds a few +// public and protected members and which a class inherits from when it +// represents a type-length-value object. +template +class TLVTrait { + private: + static constexpr size_t kTlvHeaderSize = 4; + + protected: + static constexpr size_t kHeaderSize = Config::kHeaderSize; + + static_assert(Config::kTypeSizeInBytes == 1 || Config::kTypeSizeInBytes == 2, + "kTypeSizeInBytes must be 1 or 2"); + static_assert(Config::kHeaderSize >= kTlvHeaderSize, + "HeaderSize must be >= 4 bytes"); + static_assert((Config::kHeaderSize % 4 == 0), + "kHeaderSize must be an even multiple of 4 bytes"); + static_assert((Config::kVariableLengthAlignment == 0 || + Config::kVariableLengthAlignment == 1 || + Config::kVariableLengthAlignment == 2 || + Config::kVariableLengthAlignment == 4 || + Config::kVariableLengthAlignment == 8), + "kVariableLengthAlignment must be an allowed value"); + + // Validates the data with regards to size, alignment and type. + // If valid, returns a bounded buffer. + static absl::optional> ParseTLV( + rtc::ArrayView data) { + if (data.size() < Config::kHeaderSize) { + tlv_trait_impl::ReportInvalidSize(data.size(), Config::kHeaderSize); + return absl::nullopt; + } + BoundedByteReader tlv_header(data); + + const int type = (Config::kTypeSizeInBytes == 1) + ? tlv_header.template Load8<0>() + : tlv_header.template Load16<0>(); + + if (type != Config::kType) { + tlv_trait_impl::ReportInvalidType(type, Config::kType); + return absl::nullopt; + } + const uint16_t length = tlv_header.template Load16<2>(); + if (Config::kVariableLengthAlignment == 0) { + // Don't expect any variable length data at all. + if (length != Config::kHeaderSize || data.size() != Config::kHeaderSize) { + tlv_trait_impl::ReportInvalidFixedLengthField(length, + Config::kHeaderSize); + return absl::nullopt; + } + } else { + // Expect variable length data - verify its size alignment. + if (length > data.size() || length < Config::kHeaderSize) { + tlv_trait_impl::ReportInvalidVariableLengthField(length, data.size()); + return absl::nullopt; + } + const size_t padding = data.size() - length; + if (padding > 3) { + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "This padding MUST NOT be more than 3 bytes in total" + tlv_trait_impl::ReportInvalidPadding(padding); + return absl::nullopt; + } + if (!ValidateLengthAlignment(length, Config::kVariableLengthAlignment)) { + tlv_trait_impl::ReportInvalidLengthMultiple( + length, Config::kVariableLengthAlignment); + return absl::nullopt; + } + } + return BoundedByteReader(data.subview(0, length)); + } + + // Allocates space for data with a static header size, as defined by + // `Config::kHeaderSize` and a variable footer, as defined by `variable_size` + // (which may be 0) and writes the type and length in the header. 
+ static BoundedByteWriter AllocateTLV( + std::vector& out, + size_t variable_size = 0) { + const size_t offset = out.size(); + const size_t size = Config::kHeaderSize + variable_size; + out.resize(offset + size); + + BoundedByteWriter tlv_header( + rtc::ArrayView(out.data() + offset, kTlvHeaderSize)); + if (Config::kTypeSizeInBytes == 1) { + tlv_header.template Store8<0>(static_cast(Config::kType)); + } else { + tlv_header.template Store16<0>(Config::kType); + } + tlv_header.template Store16<2>(size); + + return BoundedByteWriter( + rtc::ArrayView(out.data() + offset, size)); + } + + private: + static bool ValidateLengthAlignment(uint16_t length, size_t alignment) { + // This is to avoid MSVC believing there could be a "mod by zero", when it + // certainly can't. + if (alignment == 0) { + return true; + } + return (length % alignment) == 0; + } +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_TLV_TRAIT_H_ diff --git a/net/dcsctp/packet/tlv_trait_test.cc b/net/dcsctp/packet/tlv_trait_test.cc new file mode 100644 index 0000000000..a0dd1a1136 --- /dev/null +++ b/net/dcsctp/packet/tlv_trait_test.cc @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/tlv_trait.h" + +#include + +#include "api/array_view.h" +#include "rtc_base/buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; + +struct OneByteTypeConfig { + static constexpr int kTypeSizeInBytes = 1; + static constexpr int kType = 0x49; + static constexpr size_t kHeaderSize = 12; + static constexpr int kVariableLengthAlignment = 4; +}; + +class OneByteChunk : public TLVTrait { + public: + static constexpr size_t kVariableSize = 4; + + void SerializeTo(std::vector& out) { + BoundedByteWriter writer = + AllocateTLV(out, kVariableSize); + writer.Store32<4>(0x01020304); + writer.Store16<8>(0x0506); + writer.Store16<10>(0x0708); + + uint8_t variable_data[kVariableSize] = {0xDE, 0xAD, 0xBE, 0xEF}; + writer.CopyToVariableData(rtc::ArrayView(variable_data)); + } + + static absl::optional> + Parse(rtc::ArrayView data) { + return ParseTLV(data); + } +}; + +TEST(TlvDataTest, CanWriteOneByteTypeTlvs) { + std::vector out; + OneByteChunk().SerializeTo(out); + + EXPECT_THAT(out, SizeIs(OneByteTypeConfig::kHeaderSize + + OneByteChunk::kVariableSize)); + EXPECT_THAT(out, ElementsAre(0x49, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF)); +} + +TEST(TlvDataTest, CanReadOneByteTypeTlvs) { + uint8_t data[] = {0x49, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF}; + + absl::optional> reader = + OneByteChunk::Parse(data); + ASSERT_TRUE(reader.has_value()); + EXPECT_EQ(reader->Load32<4>(), 0x01020304U); + EXPECT_EQ(reader->Load16<8>(), 0x0506U); + EXPECT_EQ(reader->Load16<10>(), 0x0708U); + EXPECT_THAT(reader->variable_data(), ElementsAre(0xDE, 0xAD, 0xBE, 0xEF)); +} + +struct TwoByteTypeConfig { + static constexpr int kTypeSizeInBytes = 2; + static constexpr int kType = 31337; + static constexpr size_t kHeaderSize = 8; + static 
constexpr int kVariableLengthAlignment = 2; +}; + +class TwoByteChunk : public TLVTrait { + public: + static constexpr size_t kVariableSize = 8; + + void SerializeTo(std::vector& out) { + BoundedByteWriter writer = + AllocateTLV(out, kVariableSize); + writer.Store32<4>(0x01020304U); + + uint8_t variable_data[] = {0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF}; + writer.CopyToVariableData(rtc::ArrayView(variable_data)); + } + + static absl::optional> + Parse(rtc::ArrayView data) { + return ParseTLV(data); + } +}; + +TEST(TlvDataTest, CanWriteTwoByteTypeTlvs) { + std::vector out; + + TwoByteChunk().SerializeTo(out); + + EXPECT_THAT(out, SizeIs(TwoByteTypeConfig::kHeaderSize + + TwoByteChunk::kVariableSize)); + EXPECT_THAT(out, ElementsAre(0x7A, 0x69, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF)); +} + +TEST(TlvDataTest, CanReadTwoByteTypeTlvs) { + uint8_t data[] = {0x7A, 0x69, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF}; + + absl::optional> reader = + TwoByteChunk::Parse(data); + EXPECT_TRUE(reader.has_value()); + EXPECT_EQ(reader->Load32<4>(), 0x01020304U); + EXPECT_THAT(reader->variable_data(), + ElementsAre(0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF)); +} + +TEST(TlvDataTest, CanHandleInvalidLengthSmallerThanFixedSize) { + // Has 'length=6', which is below the kHeaderSize of 8. + uint8_t data[] = {0x7A, 0x69, 0x00, 0x06, 0x01, 0x02, 0x03, 0x04}; + + EXPECT_FALSE(TwoByteChunk::Parse(data).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/public/BUILD.gn b/net/dcsctp/public/BUILD.gn new file mode 100644 index 0000000000..ced94de151 --- /dev/null +++ b/net/dcsctp/public/BUILD.gn @@ -0,0 +1,103 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. 
+ +import("../../../webrtc.gni") + +rtc_source_set("strong_alias") { + sources = [ "strong_alias.h" ] +} + +rtc_source_set("types") { + deps = [ + ":strong_alias", + "../../../api:array_view", + ] + sources = [ + "dcsctp_message.h", + "dcsctp_options.h", + "types.h", + ] +} + +rtc_source_set("socket") { + deps = [ + ":types", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + ] + sources = [ + "dcsctp_socket.h", + "packet_observer.h", + "timeout.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_source_set("factory") { + deps = [ + ":socket", + ":types", + "../socket:dcsctp_socket", + ] + sources = [ + "dcsctp_socket_factory.cc", + "dcsctp_socket_factory.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_source_set("mocks") { + testonly = true + sources = [ "mock_dcsctp_socket.h" ] + deps = [ + ":socket", + "../../../test:test_support", + ] +} + +rtc_source_set("utils") { + deps = [ + ":socket", + ":types", + "../../../api:array_view", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../socket:dcsctp_socket", + ] + sources = [ + "text_pcap_packet_observer.cc", + "text_pcap_packet_observer.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_public_unittests") { + testonly = true + + deps = [ + ":mocks", + ":strong_alias", + ":types", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + ] + sources = [ + "mock_dcsctp_socket_test.cc", + "strong_alias_test.cc", + "types_test.cc", + ] + } +} diff --git a/net/dcsctp/public/dcsctp_message.h b/net/dcsctp/public/dcsctp_message.h new file mode 100644 index 0000000000..38e6763916 --- /dev/null +++ b/net/dcsctp/public/dcsctp_message.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_MESSAGE_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_MESSAGE_H_ + +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// An SCTP message is a group of bytes sent and received as a whole on a +// specified stream identifier (`stream_id`), and with a payload protocol +// identifier (`ppid`). +class DcSctpMessage { + public: + DcSctpMessage(StreamID stream_id, PPID ppid, std::vector payload) + : stream_id_(stream_id), ppid_(ppid), payload_(std::move(payload)) {} + + DcSctpMessage(DcSctpMessage&& other) = default; + DcSctpMessage& operator=(DcSctpMessage&& other) = default; + DcSctpMessage(const DcSctpMessage&) = delete; + DcSctpMessage& operator=(const DcSctpMessage&) = delete; + + // The stream identifier to which the message is sent. + StreamID stream_id() const { return stream_id_; } + + // The payload protocol identifier (ppid) associated with the message. + PPID ppid() const { return ppid_; } + + // The payload of the message. + rtc::ArrayView payload() const { return payload_; } + + // When destructing the message, extracts the payload. 
+ std::vector ReleasePayload() && { return std::move(payload_); } + + private: + StreamID stream_id_; + PPID ppid_; + std::vector payload_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_MESSAGE_H_ diff --git a/net/dcsctp/public/dcsctp_options.h b/net/dcsctp/public/dcsctp_options.h new file mode 100644 index 0000000000..caefcff4f5 --- /dev/null +++ b/net/dcsctp/public/dcsctp_options.h @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_OPTIONS_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_OPTIONS_H_ + +#include +#include + +#include "net/dcsctp/public/types.h" + +namespace dcsctp { +struct DcSctpOptions { + // The largest safe SCTP packet. Starting from the minimum guaranteed MTU + // value of 1280 for IPv6 (which may not support fragmentation), take off 85 + // bytes for DTLS/TURN/TCP/IP and ciphertext overhead. + // + // Additionally, it's possible that TURN adds an additional 4 bytes of + // overhead after a channel has been established, so an additional 4 bytes is + // subtracted + // + // 1280 IPV6 MTU + // -40 IPV6 header + // -8 UDP + // -24 GCM Cipher + // -13 DTLS record header + // -4 TURN ChannelData + // = 1191 bytes. + static constexpr size_t kMaxSafeMTUSize = 1191; + + // The local port for which the socket is supposed to be bound to. Incoming + // packets will be verified that they are sent to this port number and all + // outgoing packets will have this port number as source port. + int local_port = 5000; + + // The remote port to send packets to. All outgoing packets will have this + // port number as destination port. + int remote_port = 5000; + + // The announced maximum number of incoming streams. Note that this value is + // constant and can't be currently increased in run-time as "Add Incoming + // Streams Request" in RFC6525 isn't supported. + // + // The socket implementation doesn't have any per-stream fixed costs, which is + // why the default value is set to be the maximum value. + uint16_t announced_maximum_incoming_streams = 65535; + + // The announced maximum number of outgoing streams. Note that this value is + // constant and can't be currently increased in run-time as "Add Outgoing + // Streams Request" in RFC6525 isn't supported. + // + // The socket implementation doesn't have any per-stream fixed costs, which is + // why the default value is set to be the maximum value. + uint16_t announced_maximum_outgoing_streams = 65535; + + // Maximum SCTP packet size. The library will limit the size of generated + // packets to be less than or equal to this number. This does not include any + // overhead of DTLS, TURN, UDP or IP headers. + size_t mtu = kMaxSafeMTUSize; + + // The largest allowed message payload to be sent. Messages will be rejected + // if their payload is larger than this value. Note that this doesn't affect + // incoming messages, which may larger than this value (but smaller than + // `max_receiver_window_buffer_size`). + size_t max_message_size = 256 * 1024; + + // Maximum received window buffer size. This should be a bit larger than the + // largest sized message you want to be able to receive. 
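A sketch of constructing a tuned options struct (the values and function name are purely illustrative; every field touched here is declared in this struct):

```cpp
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/types.h"

namespace dcsctp_example {

inline dcsctp::DcSctpOptions MakeOptions() {
  dcsctp::DcSctpOptions options;
  // mtu keeps its default (kMaxSafeMTUSize) unless the path MTU is known.
  options.local_port = 5000;
  options.remote_port = 5000;
  // Setting the heartbeat interval to zero disables idle-connection
  // heartbeats entirely.
  options.heartbeat_interval = dcsctp::DurationMs(0);
  return options;
}

}  // namespace dcsctp_example
```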
This essentially + // limits the memory usage on the receive side. Note that memory is allocated + // dynamically, and this represents the maximum amount of buffered data. The + // actual memory usage of the library will be smaller in normal operation, and + // will be larger than this due to other allocations and overhead if the + // buffer is fully utilized. + size_t max_receiver_window_buffer_size = 5 * 1024 * 1024; + + // Maximum send buffer size. It will not be possible to queue more data than + // this before sending it. + size_t max_send_buffer_size = 2'000'000; + + // A threshold that, when the amount of data in the send buffer goes below + // this value, will trigger `DcSctpCallbacks::OnTotalBufferedAmountLow`. + size_t total_buffered_amount_low_threshold = 1'800'000; + + // Max allowed RTT value. When the RTT is measured and it's found to be larger + // than this value, it will be discarded and not used for e.g. any RTO + // calculation. The default value is an extreme maximum but can be adapted + // to better match the environment. + DurationMs rtt_max = DurationMs(8000); + + // Initial RTO value. + DurationMs rto_initial = DurationMs(500); + + // Maximum RTO value. + DurationMs rto_max = DurationMs(800); + + // Minimum RTO value. This must be larger than an expected peer delayed ack + // timeout. + DurationMs rto_min = DurationMs(220); + + // T1-init timeout. + DurationMs t1_init_timeout = DurationMs(1000); + + // T1-cookie timeout. + DurationMs t1_cookie_timeout = DurationMs(1000); + + // T2-shutdown timeout. + DurationMs t2_shutdown_timeout = DurationMs(1000); + + // Hearbeat interval (on idle connections only). Set to zero to disable. + DurationMs heartbeat_interval = DurationMs(30000); + + // The maximum time when a SACK will be sent from the arrival of an + // unacknowledged packet. Whatever is smallest of RTO/2 and this will be used. + DurationMs delayed_ack_max_timeout = DurationMs(200); + + // Do slow start as TCP - double cwnd instead of increasing it by MTU. + bool slow_start_tcp_style = false; + + // The initial congestion window size, in number of MTUs. + // See https://tools.ietf.org/html/rfc4960#section-7.2.1 which defaults at ~3 + // and https://research.google/pubs/pub36640/ which argues for at least ten + // segments. + size_t cwnd_mtus_initial = 10; + + // The minimum congestion window size, in number of MTUs. + // See https://tools.ietf.org/html/rfc4960#section-7.2.3. + size_t cwnd_mtus_min = 4; + + // Maximum Data Retransmit Attempts (per DATA chunk). + int max_retransmissions = 10; + + // Max.Init.Retransmits (https://tools.ietf.org/html/rfc4960#section-15) + int max_init_retransmits = 8; + + // RFC3758 Partial Reliability Extension + bool enable_partial_reliability = true; + + // RFC8260 Stream Schedulers and User Message Interleaving + bool enable_message_interleaving = false; + + // If RTO should be added to heartbeat_interval + bool heartbeat_interval_include_rtt = true; + + // Disables SCTP packet crc32 verification. Useful when running with fuzzers. + bool disable_checksum_verification = false; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_OPTIONS_H_ diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h new file mode 100644 index 0000000000..f07f54e044 --- /dev/null +++ b/net/dcsctp/public/dcsctp_socket.h @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// The socket/association state +enum class SocketState { + // The socket is closed. + kClosed, + // The socket has initiated a connection, which is not yet established. Note + // that for incoming connections and for reconnections when the socket is + // already connected, the socket will not transition to this state. + kConnecting, + // The socket is connected, and the connection is established. + kConnected, + // The socket is shutting down, and the connection is not yet closed. + kShuttingDown, +}; + +// Send options for sending messages +struct SendOptions { + // If the message should be sent with unordered message delivery. + IsUnordered unordered = IsUnordered(false); + + // If set, will discard messages that haven't been correctly sent and + // received before the lifetime has expired. This is only available if the + // peer supports Partial Reliability Extension (RFC3758). + absl::optional lifetime = absl::nullopt; + + // If set, limits the number of retransmissions. This is only available + // if the peer supports Partial Reliability Extension (RFC3758). + absl::optional max_retransmissions = absl::nullopt; +}; + +enum class ErrorKind { + // Indicates that no error has occurred. This will never be the case when + // `OnError` or `OnAborted` is called. + kNoError, + // There have been too many retries or timeouts, and the library has given up. + kTooManyRetries, + // A command was received that is only possible to execute when the socket is + // connected, which it is not. + kNotConnected, + // Parsing of the command or its parameters failed. + kParseFailed, + // Commands are received in the wrong sequence, which indicates a + // synchronisation mismatch between the peers. + kWrongSequence, + // The peer has reported an issue using ERROR or ABORT command. + kPeerReported, + // The peer has performed a protocol violation. + kProtocolViolation, + // The receive or send buffers have been exhausted. + kResourceExhaustion, + // The client has performed an invalid operation. 
+ kUnsupportedOperation, +}; + +inline constexpr absl::string_view ToString(ErrorKind error) { + switch (error) { + case ErrorKind::kNoError: + return "NO_ERROR"; + case ErrorKind::kTooManyRetries: + return "TOO_MANY_RETRIES"; + case ErrorKind::kNotConnected: + return "NOT_CONNECTED"; + case ErrorKind::kParseFailed: + return "PARSE_FAILED"; + case ErrorKind::kWrongSequence: + return "WRONG_SEQUENCE"; + case ErrorKind::kPeerReported: + return "PEER_REPORTED"; + case ErrorKind::kProtocolViolation: + return "PROTOCOL_VIOLATION"; + case ErrorKind::kResourceExhaustion: + return "RESOURCE_EXHAUSTION"; + case ErrorKind::kUnsupportedOperation: + return "UNSUPPORTED_OPERATION"; + } +} + +enum class SendStatus { + // The message was enqueued successfully. As sending the message is done + // asynchronously, this is no guarantee that the message has been actually + // sent. + kSuccess, + // The message was rejected as the payload was empty (which is not allowed in + // SCTP). + kErrorMessageEmpty, + // The message was rejected as the payload was larger than what has been set + // as `DcSctpOptions.max_message_size`. + kErrorMessageTooLarge, + // The message could not be enqueued as the socket is out of resources. This + // mainly indicates that the send queue is full. + kErrorResourceExhaustion, + // The message could not be sent as the socket is shutting down. + kErrorShuttingDown, +}; + +inline constexpr absl::string_view ToString(SendStatus error) { + switch (error) { + case SendStatus::kSuccess: + return "SUCCESS"; + case SendStatus::kErrorMessageEmpty: + return "ERROR_MESSAGE_EMPTY"; + case SendStatus::kErrorMessageTooLarge: + return "ERROR_MESSAGE_TOO_LARGE"; + case SendStatus::kErrorResourceExhaustion: + return "ERROR_RESOURCE_EXHAUSTION"; + case SendStatus::kErrorShuttingDown: + return "ERROR_SHUTTING_DOWN"; + } +} + +// Return value of ResetStreams. +enum class ResetStreamsStatus { + // If the connection is not yet established, this will be returned. + kNotConnected, + // Indicates that ResetStreams operation has been successfully initiated. + kPerformed, + // Indicates that ResetStreams has failed as it's not supported by the peer. + kNotSupported, +}; + +inline constexpr absl::string_view ToString(ResetStreamsStatus error) { + switch (error) { + case ResetStreamsStatus::kNotConnected: + return "NOT_CONNECTED"; + case ResetStreamsStatus::kPerformed: + return "PERFORMED"; + case ResetStreamsStatus::kNotSupported: + return "NOT_SUPPORTED"; + } +} + +// Callbacks that the DcSctpSocket will be done synchronously to the owning +// client. It is allowed to call back into the library from callbacks that start +// with "On". It has been explicitly documented when it's not allowed to call +// back into this library from within a callback. +// +// Theses callbacks are only synchronously triggered as a result of the client +// calling a public method in `DcSctpSocketInterface`. +class DcSctpSocketCallbacks { + public: + virtual ~DcSctpSocketCallbacks() = default; + + // Called when the library wants the packet serialized as `data` to be sent. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual void SendPacket(rtc::ArrayView data) = 0; + + // Called when the library wants to create a Timeout. The callback must return + // an object that implements that interface. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. 
+ virtual std::unique_ptr CreateTimeout() = 0; + + // Returns the current time in milliseconds (from any epoch). + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual TimeMs TimeMillis() = 0; + + // Called when the library needs a random number uniformly distributed between + // `low` (inclusive) and `high` (exclusive). The random numbers used by the + // library are not used for cryptographic purposes. There are no requirements + // that the random number generator must be secure. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual uint32_t GetRandomInt(uint32_t low, uint32_t high) = 0; + + // Triggered when the outgoing message buffer is empty, meaning that there are + // no more queued messages, but there can still be packets in-flight or to be + // retransmitted. (in contrast to SCTP_SENDER_DRY_EVENT). + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + ABSL_DEPRECATED("Use OnTotalBufferedAmountLow instead") + virtual void NotifyOutgoingMessageBufferEmpty() {} + + // Called when the library has received an SCTP message in full and delivers + // it to the upper layer. + // + // It is allowed to call into this library from within this callback. + virtual void OnMessageReceived(DcSctpMessage message) = 0; + + // Triggered when an non-fatal error is reported by either this library or + // from the other peer (by sending an ERROR command). These should be logged, + // but no other action need to be taken as the association is still viable. + // + // It is allowed to call into this library from within this callback. + virtual void OnError(ErrorKind error, absl::string_view message) = 0; + + // Triggered when the socket has aborted - either as decided by this socket + // due to e.g. too many retransmission attempts, or by the peer when + // receiving an ABORT command. No other callbacks will be done after this + // callback, unless reconnecting. + // + // It is allowed to call into this library from within this callback. + virtual void OnAborted(ErrorKind error, absl::string_view message) = 0; + + // Called when calling `Connect` succeeds, but also for incoming successful + // connection attempts. + // + // It is allowed to call into this library from within this callback. + virtual void OnConnected() = 0; + + // Called when the socket is closed in a controlled way. No other + // callbacks will be done after this callback, unless reconnecting. + // + // It is allowed to call into this library from within this callback. + virtual void OnClosed() = 0; + + // On connection restarted (by peer). This is just a notification, and the + // association is expected to work fine after this call, but there could have + // been packet loss as a result of restarting the association. + // + // It is allowed to call into this library from within this callback. + virtual void OnConnectionRestarted() = 0; + + // Indicates that a stream reset request has failed. + // + // It is allowed to call into this library from within this callback. + virtual void OnStreamsResetFailed( + rtc::ArrayView outgoing_streams, + absl::string_view reason) = 0; + + // Indicates that a stream reset request has been performed. + // + // It is allowed to call into this library from within this callback. + virtual void OnStreamsResetPerformed( + rtc::ArrayView outgoing_streams) = 0; + + // When a peer has reset some of its outgoing streams, this will be called. 
An + // empty list indicates that all streams have been reset. + // + // It is allowed to call into this library from within this callback. + virtual void OnIncomingStreamsReset( + rtc::ArrayView incoming_streams) = 0; + + // Will be called when the amount of data buffered to be sent falls to or + // below the threshold set when calling `SetBufferedAmountLowThreshold`. + // + // It is allowed to call into this library from within this callback. + virtual void OnBufferedAmountLow(StreamID stream_id) {} + + // Will be called when the total amount of data buffered (in the entire send + // buffer, for all streams) falls to or below the threshold specified in + // `DcSctpOptions::total_buffered_amount_low_threshold`. + virtual void OnTotalBufferedAmountLow() {} +}; + +// The DcSctpSocket implementation implements the following interface. +class DcSctpSocketInterface { + public: + virtual ~DcSctpSocketInterface() = default; + + // To be called when an incoming SCTP packet is to be processed. + virtual void ReceivePacket(rtc::ArrayView data) = 0; + + // To be called when a timeout has expired. The `timeout_id` is provided + // when the timeout was initiated. + virtual void HandleTimeout(TimeoutID timeout_id) = 0; + + // Connects the socket. This is an asynchronous operation, and + // `DcSctpSocketCallbacks::OnConnected` will be called on success. + virtual void Connect() = 0; + + // Gracefully shutdowns the socket and sends all outstanding data. This is an + // asynchronous operation and `DcSctpSocketCallbacks::OnClosed` will be called + // on success. + virtual void Shutdown() = 0; + + // Closes the connection non-gracefully. Will send ABORT if the connection is + // not already closed. No callbacks will be made after Close() has returned. + virtual void Close() = 0; + + // The socket state. + virtual SocketState state() const = 0; + + // The options it was created with. + virtual const DcSctpOptions& options() const = 0; + + // Update the options max_message_size. + virtual void SetMaxMessageSize(size_t max_message_size) = 0; + + // Sends the message `message` using the provided send options. + // Sending a message is an asynchrous operation, and the `OnError` callback + // may be invoked to indicate any errors in sending the message. + // + // The association does not have to be established before calling this method. + // If it's called before there is an established association, the message will + // be queued. + virtual SendStatus Send(DcSctpMessage message, + const SendOptions& send_options) = 0; + + // Resetting streams is an asynchronous operation and the results will + // be notified using `DcSctpSocketCallbacks::OnStreamsResetDone()` on success + // and `DcSctpSocketCallbacks::OnStreamsResetFailed()` on failure. Note that + // only outgoing streams can be reset. + // + // When it's known that the peer has reset its own outgoing streams, + // `DcSctpSocketCallbacks::OnIncomingStreamReset` is called. + // + // Note that resetting a stream will also remove all queued messages on those + // streams, but will ensure that the currently sent message (if any) is fully + // sent before closing the stream. + // + // Resetting streams can only be done on an established association that + // supports stream resetting. Calling this method on e.g. a closed association + // or streams that don't support resetting will not perform any operation. 
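A usage sketch against this interface (the function names, stream id, PPID and lifetime values are invented; the socket instance is assumed to come from `DcSctpSocketFactory::Create()`, added later in this change):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "api/array_view.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/types.h"

namespace dcsctp_example {

// Sends one unordered message with a 10 second lifetime on stream 1 and
// reports whether it was accepted into the send queue.
inline bool SendExample(dcsctp::DcSctpSocketInterface& socket,
                        std::vector<uint8_t> payload) {
  dcsctp::SendOptions send_options;
  send_options.unordered = dcsctp::IsUnordered(true);
  send_options.lifetime = dcsctp::DurationMs(10000);

  dcsctp::SendStatus status = socket.Send(
      dcsctp::DcSctpMessage(dcsctp::StreamID(1), dcsctp::PPID(53),
                            std::move(payload)),
      send_options);
  return status == dcsctp::SendStatus::kSuccess;
}

// Feeds an incoming (already DTLS-decrypted) packet into the socket; the
// socket reacts by invoking its callbacks synchronously from inside this
// call.
inline void OnPacketFromTransport(dcsctp::DcSctpSocketInterface& socket,
                                  rtc::ArrayView<const uint8_t> packet) {
  socket.ReceivePacket(packet);
}

}  // namespace dcsctp_example
```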
+  virtual ResetStreamsStatus ResetStreams(
+      rtc::ArrayView<const StreamID> outgoing_streams) = 0;
+
+  // Returns the number of bytes of data currently queued to be sent on a given
+  // stream.
+  virtual size_t buffered_amount(StreamID stream_id) const = 0;
+
+  // Returns the number of buffered outgoing bytes that is considered "low" for
+  // a given stream. See `SetBufferedAmountLowThreshold`.
+  virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0;
+
+  // Used to specify the number of bytes of buffered outgoing data that is
+  // considered "low" for a given stream, which will trigger an
+  // OnBufferedAmountLow event. The default value is zero (0).
+  virtual void SetBufferedAmountLowThreshold(StreamID stream_id,
+                                             size_t bytes) = 0;
+};
+}  // namespace dcsctp
+
+#endif  // NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_H_
diff --git a/net/dcsctp/public/dcsctp_socket_factory.cc b/net/dcsctp/public/dcsctp_socket_factory.cc
new file mode 100644
index 0000000000..338d143424
--- /dev/null
+++ b/net/dcsctp/public/dcsctp_socket_factory.cc
@@ -0,0 +1,31 @@
+/*
+ *  Copyright 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "net/dcsctp/public/dcsctp_socket_factory.h"
+
+#include <memory>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/packet_observer.h"
+#include "net/dcsctp/socket/dcsctp_socket.h"
+
+namespace dcsctp {
+std::unique_ptr<DcSctpSocketInterface> DcSctpSocketFactory::Create(
+    absl::string_view log_prefix,
+    DcSctpSocketCallbacks& callbacks,
+    std::unique_ptr<PacketObserver> packet_observer,
+    const DcSctpOptions& options) {
+  return std::make_unique<DcSctpSocket>(log_prefix, callbacks,
+                                        std::move(packet_observer), options);
+}
+}  // namespace dcsctp
diff --git a/net/dcsctp/public/dcsctp_socket_factory.h b/net/dcsctp/public/dcsctp_socket_factory.h
new file mode 100644
index 0000000000..dcc68d9b54
--- /dev/null
+++ b/net/dcsctp/public/dcsctp_socket_factory.h
@@ -0,0 +1,31 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
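For orientation, here is a minimal sketch of how the callback/socket/factory API above is meant to be driven. It is not part of this change: the caller-supplied DcSctpSocketCallbacks implementation is assumed to exist elsewhere, and the max_message_size value and PPID 53 (the WebRTC data channel binary PPID) are only illustrative.

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/dcsctp_socket_factory.h"

std::unique_ptr<dcsctp::DcSctpSocketInterface> CreateAndConnect(
    dcsctp::DcSctpSocketCallbacks& callbacks) {
  dcsctp::DcSctpOptions options;
  options.max_message_size = 256 * 1024;  // Application-chosen limit.
  std::unique_ptr<dcsctp::DcSctpSocketInterface> socket =
      dcsctp::DcSctpSocketFactory().Create(
          "sctp", callbacks, /*packet_observer=*/nullptr, options);
  // OnConnected() fires on `callbacks` when the handshake completes; incoming
  // packets and expired timeouts are fed back through ReceivePacket() and
  // HandleTimeout().
  socket->Connect();
  return socket;
}

void SendUserMessage(dcsctp::DcSctpSocketInterface& socket,
                     std::vector<uint8_t> payload) {
  dcsctp::SendOptions send_options;
  send_options.unordered = dcsctp::IsUnordered(false);
  dcsctp::SendStatus status = socket.Send(
      dcsctp::DcSctpMessage(dcsctp::StreamID(1), dcsctp::PPID(53),
                            std::move(payload)),
      send_options);
  // The message is queued if the association isn't established yet; failures
  // are reported both in `status` and through OnError().
  (void)status;
}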
+ */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_FACTORY_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_FACTORY_H_ + +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" + +namespace dcsctp { +class DcSctpSocketFactory { + public: + std::unique_ptr Create( + absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr packet_observer, + const DcSctpOptions& options); +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_FACTORY_H_ diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h new file mode 100644 index 0000000000..18140642b7 --- /dev/null +++ b/net/dcsctp/public/mock_dcsctp_socket.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_ + +#include "net/dcsctp/public/dcsctp_socket.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockDcSctpSocket : public DcSctpSocketInterface { + public: + MOCK_METHOD(void, + ReceivePacket, + (rtc::ArrayView data), + (override)); + + MOCK_METHOD(void, HandleTimeout, (TimeoutID timeout_id), (override)); + + MOCK_METHOD(void, Connect, (), (override)); + + MOCK_METHOD(void, Shutdown, (), (override)); + + MOCK_METHOD(void, Close, (), (override)); + + MOCK_METHOD(SocketState, state, (), (const, override)); + + MOCK_METHOD(const DcSctpOptions&, options, (), (const, override)); + + MOCK_METHOD(void, SetMaxMessageSize, (size_t max_message_size), (override)); + + MOCK_METHOD(SendStatus, + Send, + (DcSctpMessage message, const SendOptions& send_options), + (override)); + + MOCK_METHOD(ResetStreamsStatus, + ResetStreams, + (rtc::ArrayView outgoing_streams), + (override)); + + MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override)); + + MOCK_METHOD(size_t, + buffered_amount_low_threshold, + (StreamID stream_id), + (const, override)); + + MOCK_METHOD(void, + SetBufferedAmountLowThreshold, + (StreamID stream_id, size_t bytes), + (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_ diff --git a/net/dcsctp/public/mock_dcsctp_socket_test.cc b/net/dcsctp/public/mock_dcsctp_socket_test.cc new file mode 100644 index 0000000000..57013e4ce2 --- /dev/null +++ b/net/dcsctp/public/mock_dcsctp_socket_test.cc @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/mock_dcsctp_socket.h" + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { + +// This test exists to ensure that all methods are mocked correctly, and to +// generate compiler errors if they are not. 
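A hedged example of how the mock above can be combined with GoogleMock actions when unit testing code that talks to a DcSctpSocketInterface. The enumerator names SocketState::kConnected and SendStatus::kSuccess are assumed from the declarations earlier in dcsctp_socket.h.

#include "net/dcsctp/public/mock_dcsctp_socket.h"
#include "rtc_base/gunit.h"
#include "test/gmock.h"

namespace dcsctp {
namespace {
using ::testing::NiceMock;
using ::testing::Return;

TEST(MockDcSctpSocketUsageTest, StubsStateAndSend) {
  NiceMock<MockDcSctpSocket> socket;
  EXPECT_CALL(socket, state()).WillRepeatedly(Return(SocketState::kConnected));
  EXPECT_CALL(socket, Send).WillOnce(Return(SendStatus::kSuccess));

  // Code under test would normally hold the socket through the interface.
  DcSctpSocketInterface& sctp = socket;
  ASSERT_EQ(sctp.state(), SocketState::kConnected);
  EXPECT_EQ(sctp.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}),
                      SendOptions()),
            SendStatus::kSuccess);
}

}  // namespace
}  // namespace dcsctp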
+TEST(MockDcSctpSocketTest, CanInstantiateAndConnect) { + testing::StrictMock socket; + + EXPECT_CALL(socket, Connect); + + socket.Connect(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/public/packet_observer.h b/net/dcsctp/public/packet_observer.h new file mode 100644 index 0000000000..fe7567824f --- /dev/null +++ b/net/dcsctp/public/packet_observer.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_PACKET_OBSERVER_H_ +#define NET_DCSCTP_PUBLIC_PACKET_OBSERVER_H_ + +#include + +#include "api/array_view.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A PacketObserver can be attached to a socket and will be called for +// all sent and received packets. +class PacketObserver { + public: + virtual ~PacketObserver() = default; + // Called when a packet is sent, with the current time (in milliseconds) as + // `now`, and the packet payload as `payload`. + virtual void OnSentPacket(TimeMs now, + rtc::ArrayView payload) = 0; + + // Called when a packet is received, with the current time (in milliseconds) + // as `now`, and the packet payload as `payload`. + virtual void OnReceivedPacket(TimeMs now, + rtc::ArrayView payload) = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_PACKET_OBSERVER_H_ diff --git a/net/dcsctp/public/strong_alias.h b/net/dcsctp/public/strong_alias.h new file mode 100644 index 0000000000..96678442b4 --- /dev/null +++ b/net/dcsctp/public/strong_alias.h @@ -0,0 +1,85 @@ +/* + * Copyright 2019 The Chromium Authors. All rights reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_STRONG_ALIAS_H_ +#define NET_DCSCTP_PUBLIC_STRONG_ALIAS_H_ + +#include +#include + +namespace dcsctp { + +// This is a copy of +// https://source.chromium.org/chromium/chromium/src/+/master:base/types/strong_alias.h +// as the API (and internals) are using type-safe integral identifiers, but this +// library can't depend on that file. The ostream operator has been removed +// per WebRTC library conventions, and the underlying type is exposed. 
+ +template +class StrongAlias { + public: + using UnderlyingType = TheUnderlyingType; + constexpr StrongAlias() = default; + constexpr explicit StrongAlias(const UnderlyingType& v) : value_(v) {} + constexpr explicit StrongAlias(UnderlyingType&& v) noexcept + : value_(std::move(v)) {} + + constexpr UnderlyingType* operator->() { return &value_; } + constexpr const UnderlyingType* operator->() const { return &value_; } + + constexpr UnderlyingType& operator*() & { return value_; } + constexpr const UnderlyingType& operator*() const& { return value_; } + constexpr UnderlyingType&& operator*() && { return std::move(value_); } + constexpr const UnderlyingType&& operator*() const&& { + return std::move(value_); + } + + constexpr UnderlyingType& value() & { return value_; } + constexpr const UnderlyingType& value() const& { return value_; } + constexpr UnderlyingType&& value() && { return std::move(value_); } + constexpr const UnderlyingType&& value() const&& { return std::move(value_); } + + constexpr explicit operator const UnderlyingType&() const& { return value_; } + + constexpr bool operator==(const StrongAlias& other) const { + return value_ == other.value_; + } + constexpr bool operator!=(const StrongAlias& other) const { + return value_ != other.value_; + } + constexpr bool operator<(const StrongAlias& other) const { + return value_ < other.value_; + } + constexpr bool operator<=(const StrongAlias& other) const { + return value_ <= other.value_; + } + constexpr bool operator>(const StrongAlias& other) const { + return value_ > other.value_; + } + constexpr bool operator>=(const StrongAlias& other) const { + return value_ >= other.value_; + } + + // Hasher to use in std::unordered_map, std::unordered_set, etc. + struct Hasher { + using argument_type = StrongAlias; + using result_type = std::size_t; + result_type operator()(const argument_type& id) const { + return std::hash()(id.value()); + } + }; + + protected: + UnderlyingType value_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_STRONG_ALIAS_H_ diff --git a/net/dcsctp/public/strong_alias_test.cc b/net/dcsctp/public/strong_alias_test.cc new file mode 100644 index 0000000000..0c57c6b248 --- /dev/null +++ b/net/dcsctp/public/strong_alias_test.cc @@ -0,0 +1,362 @@ +/* + * Copyright 2019 The Chromium Authors. All rights reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/strong_alias.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +// This is a copy of +// https://source.chromium.org/chromium/chromium/src/+/master:base/types/strong_alias_unittest.cc +// but adapted to use WebRTC's includes, remove unit tests that test the ostream +// operator (it's removed in this port) and other adaptations to pass lint. + +namespace dcsctp { +namespace { + +// For test correctnenss, it's important that these getters return lexically +// incrementing values as |index| grows. 
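A short illustration (not part of the patch) of how the template above is meant to be used. The PortNumber and ChannelId aliases are made-up examples; the dcsctp aliases themselves (StreamID, PPID, TimeoutID, ...) are declared in types.h further down in this change.

#include <cstdint>
#include <string>
#include <unordered_map>

#include "net/dcsctp/public/strong_alias.h"

namespace example {

// Each alias gets its own (possibly incomplete) tag type, so two aliases with
// the same underlying type still don't convert into each other.
using PortNumber = dcsctp::StrongAlias<class PortNumberTag, uint16_t>;
using ChannelId = dcsctp::StrongAlias<class ChannelIdTag, uint16_t>;

void Demo() {
  PortNumber port(5000);
  ChannelId channel(7);

  // PortNumber p = channel;     // Does not compile: different aliases.
  // uint16_t raw = port;        // Does not compile: no implicit conversion.
  uint16_t raw = *port;          // Explicit access to the underlying value.
  static_cast<void>(raw);
  static_cast<void>(channel);

  // The nested Hasher makes an alias usable as an unordered_map key.
  std::unordered_map<ChannelId, std::string, ChannelId::Hasher> labels;
  labels[ChannelId(1)] = "control";
}

}  // namespace example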
+template +T GetExampleValue(int index); + +template <> +int GetExampleValue(int index) { + return 5 + index; +} +template <> +uint64_t GetExampleValue(int index) { + return 500U + index; +} + +template <> +std::string GetExampleValue(int index) { + return std::string('a', index); +} + +} // namespace + +template +class StrongAliasTest : public ::testing::Test {}; + +using TestedTypes = ::testing::Types; +TYPED_TEST_SUITE(StrongAliasTest, TestedTypes); + +TYPED_TEST(StrongAliasTest, ValueAccessesUnderlyingValue) { + using FooAlias = StrongAlias; + + // Const value getter. + const FooAlias const_alias(GetExampleValue(1)); + EXPECT_EQ(GetExampleValue(1), const_alias.value()); + static_assert(std::is_const::type>::value, + "Reference returned by const value getter should be const."); +} + +TYPED_TEST(StrongAliasTest, ExplicitConversionToUnderlyingValue) { + using FooAlias = StrongAlias; + + const FooAlias const_alias(GetExampleValue(1)); + EXPECT_EQ(GetExampleValue(1), static_cast(const_alias)); +} + +TYPED_TEST(StrongAliasTest, CanBeCopyConstructed) { + using FooAlias = StrongAlias; + FooAlias alias(GetExampleValue(0)); + FooAlias copy_constructed = alias; + EXPECT_EQ(copy_constructed, alias); + + FooAlias copy_assigned; + copy_assigned = alias; + EXPECT_EQ(copy_assigned, alias); +} + +TYPED_TEST(StrongAliasTest, CanBeMoveConstructed) { + using FooAlias = StrongAlias; + FooAlias alias(GetExampleValue(0)); + FooAlias move_constructed = std::move(alias); + EXPECT_EQ(move_constructed, FooAlias(GetExampleValue(0))); + + FooAlias alias2(GetExampleValue(2)); + FooAlias move_assigned; + move_assigned = std::move(alias2); + EXPECT_EQ(move_assigned, FooAlias(GetExampleValue(2))); + + // Check that FooAlias is nothrow move constructible. This matters for + // performance when used in std::vectors. + static_assert(std::is_nothrow_move_constructible::value, + "Error: Alias is not nothow move constructible"); +} + +TYPED_TEST(StrongAliasTest, CanBeConstructedFromMoveOnlyType) { + // Note, using a move-only unique_ptr to T: + using FooAlias = StrongAlias>; + + FooAlias a(std::make_unique(GetExampleValue(0))); + EXPECT_EQ(*a.value(), GetExampleValue(0)); + + auto bare_value = std::make_unique(GetExampleValue(1)); + FooAlias b(std::move(bare_value)); + EXPECT_EQ(*b.value(), GetExampleValue(1)); +} + +TYPED_TEST(StrongAliasTest, MutableOperatorArrow) { + // Note, using a move-only unique_ptr to T: + using Ptr = std::unique_ptr; + using FooAlias = StrongAlias; + + FooAlias a(std::make_unique()); + EXPECT_TRUE(a.value()); + + // Check that `a` can be modified through the use of operator->. + a->reset(); + + EXPECT_FALSE(a.value()); +} + +TYPED_TEST(StrongAliasTest, MutableOperatorStar) { + // Note, using a move-only unique_ptr to T: + using Ptr = std::unique_ptr; + using FooAlias = StrongAlias; + + FooAlias a(std::make_unique()); + FooAlias b(std::make_unique()); + EXPECT_TRUE(*a); + EXPECT_TRUE(*b); + + // Check that both the mutable l-value and r-value overloads work and we can + // move out of the aliases. 
+ { Ptr ignore(*std::move(a)); } + { Ptr ignore(std::move(*b)); } + + EXPECT_FALSE(a.value()); + EXPECT_FALSE(b.value()); +} + +TYPED_TEST(StrongAliasTest, MutableValue) { + // Note, using a move-only unique_ptr to T: + using Ptr = std::unique_ptr; + using FooAlias = StrongAlias; + + FooAlias a(std::make_unique()); + FooAlias b(std::make_unique()); + EXPECT_TRUE(a.value()); + EXPECT_TRUE(b.value()); + + // Check that both the mutable l-value and r-value overloads work and we can + // move out of the aliases. + { Ptr ignore(std::move(a).value()); } + { Ptr ignore(std::move(b.value())); } + + EXPECT_FALSE(a.value()); + EXPECT_FALSE(b.value()); +} + +TYPED_TEST(StrongAliasTest, SizeSameAsUnderlyingType) { + using FooAlias = StrongAlias; + static_assert(sizeof(FooAlias) == sizeof(TypeParam), + "StrongAlias should be as large as the underlying type."); +} + +TYPED_TEST(StrongAliasTest, IsDefaultConstructible) { + using FooAlias = StrongAlias; + static_assert(std::is_default_constructible::value, + "Should be possible to default-construct a StrongAlias."); + static_assert( + std::is_trivially_default_constructible::value == + std::is_trivially_default_constructible::value, + "Should be possible to trivially default-construct a StrongAlias iff the " + "underlying type is trivially default constructible."); +} + +TEST(StrongAliasTest, TrivialTypeAliasIsStandardLayout) { + using FooAlias = StrongAlias; + static_assert(std::is_standard_layout::value, + "int-based alias should have standard layout. "); + static_assert(std::is_trivially_copyable::value, + "int-based alias should be trivially copyable. "); +} + +TYPED_TEST(StrongAliasTest, CannotBeCreatedFromDifferentAlias) { + using FooAlias = StrongAlias; + using BarAlias = StrongAlias; + static_assert(!std::is_constructible::value, + "Should be impossible to construct FooAlias from a BarAlias."); + static_assert(!std::is_convertible::value, + "Should be impossible to convert a BarAlias into FooAlias."); +} + +TYPED_TEST(StrongAliasTest, CannotBeImplicitlyConverterToUnderlyingValue) { + using FooAlias = StrongAlias; + static_assert(!std::is_convertible::value, + "Should be impossible to implicitly convert a StrongAlias into " + "an underlying type."); +} + +TYPED_TEST(StrongAliasTest, ComparesEqualToSameValue) { + using FooAlias = StrongAlias; + // Comparison to self: + const FooAlias a = FooAlias(GetExampleValue(0)); + EXPECT_EQ(a, a); + EXPECT_FALSE(a != a); + EXPECT_TRUE(a >= a); + EXPECT_TRUE(a <= a); + EXPECT_FALSE(a > a); + EXPECT_FALSE(a < a); + // Comparison to other equal object: + const FooAlias b = FooAlias(GetExampleValue(0)); + EXPECT_EQ(a, b); + EXPECT_FALSE(a != b); + EXPECT_TRUE(a >= b); + EXPECT_TRUE(a <= b); + EXPECT_FALSE(a > b); + EXPECT_FALSE(a < b); +} + +TYPED_TEST(StrongAliasTest, ComparesCorrectlyToDifferentValue) { + using FooAlias = StrongAlias; + const FooAlias a = FooAlias(GetExampleValue(0)); + const FooAlias b = FooAlias(GetExampleValue(1)); + EXPECT_NE(a, b); + EXPECT_FALSE(a == b); + EXPECT_TRUE(b >= a); + EXPECT_TRUE(a <= b); + EXPECT_TRUE(b > a); + EXPECT_TRUE(a < b); +} + +TEST(StrongAliasTest, CanBeDerivedFrom) { + // Aliases can be enriched by custom operations or validations if needed. + // Ideally, one could go from a 'using' declaration to a derived class to add + // those methods without the need to change any other code. 
+ class CountryCode : public StrongAlias { + public: + explicit CountryCode(const std::string& value) + : StrongAlias::StrongAlias(value) { + if (value_.length() != 2) { + // Country code invalid! + value_.clear(); // is_null() will return true. + } + } + + bool is_null() const { return value_.empty(); } + }; + + CountryCode valid("US"); + EXPECT_FALSE(valid.is_null()); + + CountryCode invalid("United States"); + EXPECT_TRUE(invalid.is_null()); +} + +TEST(StrongAliasTest, CanWrapComplexStructures) { + // A pair of strings implements odering and can, in principle, be used as + // a base of StrongAlias. + using PairOfStrings = std::pair; + using ComplexAlias = StrongAlias; + + ComplexAlias a1{std::make_pair("aaa", "bbb")}; + ComplexAlias a2{std::make_pair("ccc", "ddd")}; + EXPECT_TRUE(a1 < a2); + + EXPECT_TRUE(a1.value() == PairOfStrings("aaa", "bbb")); + + // Note a caveat, an std::pair doesn't have an overload of operator<<, and it + // cannot be easily added since ADL rules would require it to be in the std + // namespace. So we can't print ComplexAlias. +} + +TYPED_TEST(StrongAliasTest, CanBeKeysInStdUnorderedMap) { + using FooAlias = StrongAlias; + std::unordered_map map; + + FooAlias k1(GetExampleValue(0)); + FooAlias k2(GetExampleValue(1)); + + map[k1] = "value1"; + map[k2] = "value2"; + + EXPECT_EQ(map[k1], "value1"); + EXPECT_EQ(map[k2], "value2"); +} + +TYPED_TEST(StrongAliasTest, CanBeKeysInStdMap) { + using FooAlias = StrongAlias; + std::map map; + + FooAlias k1(GetExampleValue(0)); + FooAlias k2(GetExampleValue(1)); + + map[k1] = "value1"; + map[k2] = "value2"; + + EXPECT_EQ(map[k1], "value1"); + EXPECT_EQ(map[k2], "value2"); +} + +TYPED_TEST(StrongAliasTest, CanDifferentiateOverloads) { + using FooAlias = StrongAlias; + using BarAlias = StrongAlias; + class Scope { + public: + static std::string Overload(FooAlias) { return "FooAlias"; } + static std::string Overload(BarAlias) { return "BarAlias"; } + }; + EXPECT_EQ("FooAlias", Scope::Overload(FooAlias())); + EXPECT_EQ("BarAlias", Scope::Overload(BarAlias())); +} + +TEST(StrongAliasTest, EnsureConstexpr) { + using FooAlias = StrongAlias; + + // Check constructors. + static constexpr FooAlias kZero{}; + static constexpr FooAlias kOne(1); + + // Check operator*. + static_assert(*kZero == 0, ""); + static_assert(*kOne == 1, ""); + + // Check value(). + static_assert(kZero.value() == 0, ""); + static_assert(kOne.value() == 1, ""); + + // Check explicit conversions to underlying type. + static_assert(static_cast(kZero) == 0, ""); + static_assert(static_cast(kOne) == 1, ""); + + // Check comparison operations. + static_assert(kZero == kZero, ""); + static_assert(kZero != kOne, ""); + static_assert(kZero < kOne, ""); + static_assert(kZero <= kOne, ""); + static_assert(kOne > kZero, ""); + static_assert(kOne >= kZero, ""); +} + +TEST(StrongAliasTest, BooleansAreEvaluatedAsBooleans) { + using BoolAlias = StrongAlias; + + BoolAlias happy(true); + BoolAlias sad(false); + + EXPECT_TRUE(happy); + EXPECT_FALSE(sad); + EXPECT_TRUE(*happy); + EXPECT_FALSE(*sad); +} +} // namespace dcsctp diff --git a/net/dcsctp/public/text_pcap_packet_observer.cc b/net/dcsctp/public/text_pcap_packet_observer.cc new file mode 100644 index 0000000000..2b13060190 --- /dev/null +++ b/net/dcsctp/public/text_pcap_packet_observer.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. 
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/text_pcap_packet_observer.h" + +#include "api/array_view.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +void TextPcapPacketObserver::OnSentPacket( + dcsctp::TimeMs now, + rtc::ArrayView payload) { + PrintPacket("O ", name_, now, payload); +} + +void TextPcapPacketObserver::OnReceivedPacket( + dcsctp::TimeMs now, + rtc::ArrayView payload) { + PrintPacket("I ", name_, now, payload); +} + +void TextPcapPacketObserver::PrintPacket( + absl::string_view prefix, + absl::string_view socket_name, + dcsctp::TimeMs now, + rtc::ArrayView payload) { + rtc::StringBuilder s; + s << "\n" << prefix; + int64_t remaining = *now % (24 * 60 * 60 * 1000); + int hours = remaining / (60 * 60 * 1000); + remaining = remaining % (60 * 60 * 1000); + int minutes = remaining / (60 * 1000); + remaining = remaining % (60 * 1000); + int seconds = remaining / 1000; + int ms = remaining % 1000; + s.AppendFormat("%02d:%02d:%02d.%03d", hours, minutes, seconds, ms); + s << " 0000"; + for (uint8_t byte : payload) { + s.AppendFormat(" %02x", byte); + } + s << " # SCTP_PACKET " << socket_name; + RTC_LOG(LS_VERBOSE) << s.str(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/public/text_pcap_packet_observer.h b/net/dcsctp/public/text_pcap_packet_observer.h new file mode 100644 index 0000000000..0685771ccf --- /dev/null +++ b/net/dcsctp/public/text_pcap_packet_observer.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_TEXT_PCAP_PACKET_OBSERVER_H_ +#define NET_DCSCTP_PUBLIC_TEXT_PCAP_PACKET_OBSERVER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// Print outs all sent and received packets to the logs, at LS_VERBOSE severity. +class TextPcapPacketObserver : public dcsctp::PacketObserver { + public: + explicit TextPcapPacketObserver(absl::string_view name) : name_(name) {} + + // Implementation of `dcsctp::PacketObserver`. + void OnSentPacket(dcsctp::TimeMs now, + rtc::ArrayView payload) override; + + void OnReceivedPacket(dcsctp::TimeMs now, + rtc::ArrayView payload) override; + + // Prints a packet to the log. Exposed to allow it to be used in compatibility + // tests suites that don't use PacketObserver. + static void PrintPacket(absl::string_view prefix, + absl::string_view socket_name, + dcsctp::TimeMs now, + rtc::ArrayView payload); + + private: + const std::string name_; +}; + +} // namespace dcsctp +#endif // NET_DCSCTP_PUBLIC_TEXT_PCAP_PACKET_OBSERVER_H_ diff --git a/net/dcsctp/public/timeout.h b/net/dcsctp/public/timeout.h new file mode 100644 index 0000000000..64ba351093 --- /dev/null +++ b/net/dcsctp/public/timeout.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_TIMEOUT_H_ +#define NET_DCSCTP_PUBLIC_TIMEOUT_H_ + +#include + +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A very simple timeout that can be started and stopped. When started, +// it will be given a unique `timeout_id` which should be provided to +// `DcSctpSocket::HandleTimeout` when it expires. +class Timeout { + public: + virtual ~Timeout() = default; + + // Called to start time timeout, with the duration in milliseconds as + // `duration` and with the timeout identifier as `timeout_id`, which - if + // the timeout expires - shall be provided to `DcSctpSocket::HandleTimeout`. + // + // `Start` and `Stop` will always be called in pairs. In other words will + // ´Start` never be called twice, without a call to `Stop` in between. + virtual void Start(DurationMs duration, TimeoutID timeout_id) = 0; + + // Called to stop the running timeout. + // + // `Start` and `Stop` will always be called in pairs. In other words will + // ´Start` never be called twice, without a call to `Stop` in between. + // + // `Stop` will always be called prior to releasing this object. + virtual void Stop() = 0; + + // Called to restart an already running timeout, with the `duration` and + // `timeout_id` parameters as described in `Start`. This can be overridden by + // the implementation to restart it more efficiently. + virtual void Restart(DurationMs duration, TimeoutID timeout_id) { + Stop(); + Start(duration, timeout_id); + } +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_TIMEOUT_H_ diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h new file mode 100644 index 0000000000..d516daffe3 --- /dev/null +++ b/net/dcsctp/public/types.h @@ -0,0 +1,110 @@ +/* + * Copyright 2019 The Chromium Authors. All rights reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_TYPES_H_ +#define NET_DCSCTP_PUBLIC_TYPES_H_ + +#include +#include + +#include "net/dcsctp/public/strong_alias.h" + +namespace dcsctp { + +// Stream Identifier +using StreamID = StrongAlias; + +// Payload Protocol Identifier (PPID) +using PPID = StrongAlias; + +// Timeout Identifier +using TimeoutID = StrongAlias; + +// Indicates if a message is allowed to be received out-of-order compared to +// other messages on the same stream. +using IsUnordered = StrongAlias; + +// Duration, as milliseconds. Overflows after 24 days. +class DurationMs : public StrongAlias { + public: + constexpr explicit DurationMs(const UnderlyingType& v) + : StrongAlias(v) {} + + // Convenience methods for working with time. 
+ constexpr DurationMs& operator+=(DurationMs d) { + value_ += d.value_; + return *this; + } + constexpr DurationMs& operator-=(DurationMs d) { + value_ -= d.value_; + return *this; + } + template + constexpr DurationMs& operator*=(T factor) { + value_ *= factor; + return *this; + } +}; + +constexpr inline DurationMs operator+(DurationMs lhs, DurationMs rhs) { + return lhs += rhs; +} +constexpr inline DurationMs operator-(DurationMs lhs, DurationMs rhs) { + return lhs -= rhs; +} +template +constexpr inline DurationMs operator*(DurationMs lhs, T rhs) { + return lhs *= rhs; +} +template +constexpr inline DurationMs operator*(T lhs, DurationMs rhs) { + return rhs *= lhs; +} +constexpr inline int32_t operator/(DurationMs lhs, DurationMs rhs) { + return lhs.value() / rhs.value(); +} + +// Represents time, in milliseconds since a client-defined epoch. +class TimeMs : public StrongAlias { + public: + constexpr explicit TimeMs(const UnderlyingType& v) + : StrongAlias(v) {} + + // Convenience methods for working with time. + constexpr TimeMs& operator+=(DurationMs d) { + value_ += *d; + return *this; + } + constexpr TimeMs& operator-=(DurationMs d) { + value_ -= *d; + return *this; + } + + static constexpr TimeMs InfiniteFuture() { + return TimeMs(std::numeric_limits::max()); + } +}; + +constexpr inline TimeMs operator+(TimeMs lhs, DurationMs rhs) { + return lhs += rhs; +} +constexpr inline TimeMs operator+(DurationMs lhs, TimeMs rhs) { + return rhs += lhs; +} +constexpr inline TimeMs operator-(TimeMs lhs, DurationMs rhs) { + return lhs -= rhs; +} +constexpr inline DurationMs operator-(TimeMs lhs, TimeMs rhs) { + return DurationMs(*lhs - *rhs); +} + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_TYPES_H_ diff --git a/net/dcsctp/public/types_test.cc b/net/dcsctp/public/types_test.cc new file mode 100644 index 0000000000..d3d1240751 --- /dev/null +++ b/net/dcsctp/public/types_test.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/types.h" + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(TypesTest, DurationOperators) { + DurationMs d1(10); + DurationMs d2(25); + EXPECT_EQ(d1 + d2, DurationMs(35)); + EXPECT_EQ(d2 - d1, DurationMs(15)); + + d1 += d2; + EXPECT_EQ(d1, DurationMs(35)); + + d1 -= DurationMs(5); + EXPECT_EQ(d1, DurationMs(30)); + + d1 *= 1.5; + EXPECT_EQ(d1, DurationMs(45)); + + EXPECT_EQ(DurationMs(10) * 2, DurationMs(20)); +} + +TEST(TypesTest, TimeOperators) { + EXPECT_EQ(TimeMs(250) + DurationMs(100), TimeMs(350)); + EXPECT_EQ(DurationMs(250) + TimeMs(100), TimeMs(350)); + EXPECT_EQ(TimeMs(250) - DurationMs(100), TimeMs(150)); + EXPECT_EQ(TimeMs(250) - TimeMs(100), DurationMs(150)); + + TimeMs t1(150); + t1 -= DurationMs(50); + EXPECT_EQ(t1, TimeMs(100)); + t1 += DurationMs(200); + EXPECT_EQ(t1, TimeMs(300)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/rx/BUILD.gn b/net/dcsctp/rx/BUILD.gn new file mode 100644 index 0000000000..fb92513158 --- /dev/null +++ b/net/dcsctp/rx/BUILD.gn @@ -0,0 +1,122 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
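A sketch (not part of this change) of one way to implement the Timeout interface above in an integration that has no event loop and instead polls. The PolledTimeout name, the now_fn_ hook and the polling scheme are made up; only the dcsctp types come from this patch.

#include <functional>
#include <utility>

#include "absl/types/optional.h"
#include "net/dcsctp/public/timeout.h"
#include "net/dcsctp/public/types.h"

namespace example {

class PolledTimeout : public dcsctp::Timeout {
 public:
  explicit PolledTimeout(std::function<dcsctp::TimeMs()> now_fn)
      : now_fn_(std::move(now_fn)) {}

  // `Start` and `Stop` are always called in pairs by the library; the default
  // `Restart` (Stop followed by Start) is sufficient for this scheme.
  void Start(dcsctp::DurationMs duration,
             dcsctp::TimeoutID timeout_id) override {
    timeout_id_ = timeout_id;
    expiry_ = now_fn_() + duration;
  }
  void Stop() override { expiry_ = dcsctp::TimeMs::InfiniteFuture(); }

  // Polled by the owner every tick; when an id is returned, the owner calls
  // DcSctpSocketInterface::HandleTimeout() with it.
  absl::optional<dcsctp::TimeoutID> EvaluateHasExpired(dcsctp::TimeMs now) {
    if (now >= expiry_) {
      expiry_ = dcsctp::TimeMs::InfiniteFuture();
      return timeout_id_;
    }
    return absl::nullopt;
  }

 private:
  const std::function<dcsctp::TimeMs()> now_fn_;
  dcsctp::TimeoutID timeout_id_ = dcsctp::TimeoutID(0);
  dcsctp::TimeMs expiry_ = dcsctp::TimeMs::InfiniteFuture();
};

}  // namespace example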
+# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("data_tracker") { + deps = [ + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../timer", + ] + sources = [ + "data_tracker.cc", + "data_tracker.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_source_set("reassembly_streams") { + deps = [ + "../../../api:array_view", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ "reassembly_streams.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("traditional_reassembly_streams") { + deps = [ + ":reassembly_streams", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ + "traditional_reassembly_streams.cc", + "traditional_reassembly_streams.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("reassembly_queue") { + deps = [ + ":reassembly_streams", + ":traditional_reassembly_streams", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../packet:parameter", + "../public:types", + ] + sources = [ + "reassembly_queue.cc", + "reassembly_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_rx_unittests") { + testonly = true + + deps = [ + ":data_tracker", + ":reassembly_queue", + ":reassembly_streams", + ":traditional_reassembly_streams", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + "../testing:data_generator", + "../timer", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ + "data_tracker_test.cc", + "reassembly_queue_test.cc", + "traditional_reassembly_streams_test.cc", + ] + } +} diff --git a/net/dcsctp/rx/data_tracker.cc b/net/dcsctp/rx/data_tracker.cc new file mode 100644 index 0000000000..5b563a8463 --- /dev/null +++ b/net/dcsctp/rx/data_tracker.cc @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/data_tracker.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +constexpr size_t DataTracker::kMaxDuplicateTsnReported; +constexpr size_t DataTracker::kMaxGapAckBlocksReported; + +bool DataTracker::AdditionalTsnBlocks::Add(UnwrappedTSN tsn) { + // Find any block to expand. It will look for any block that includes (also + // when expanded) the provided `tsn`. It will return the block that is greater + // than, or equal to `tsn`. + auto it = absl::c_lower_bound( + blocks_, tsn, [&](const TsnRange& elem, const UnwrappedTSN& t) { + return elem.last.next_value() < t; + }); + + if (it == blocks_.end()) { + // No matching block found. There is no greater than, or equal block - which + // means that this TSN is greater than any block. It can then be inserted at + // the end. + blocks_.emplace_back(tsn, tsn); + return true; + } + + if (tsn >= it->first && tsn <= it->last) { + // It's already in this block. + return false; + } + + if (it->last.next_value() == tsn) { + // This block can be expanded to the right, or merged with the next. + auto next_it = it + 1; + if (next_it != blocks_.end() && tsn.next_value() == next_it->first) { + // Expanding it would make it adjacent to next block - merge those. + it->last = next_it->last; + blocks_.erase(next_it); + return true; + } + + // Expand to the right + it->last = tsn; + return true; + } + + if (it->first == tsn.next_value()) { + // This block can be expanded to the left. Merging to the left would've been + // covered by the above "merge to the right". Both blocks (expand a + // right-most block to the left and expand a left-most block to the right) + // would match, but the left-most would be returned by std::lower_bound. + RTC_DCHECK(it == blocks_.begin() || (it - 1)->last.next_value() != tsn); + + // Expand to the left. + it->first = tsn; + return true; + } + + // Need to create a new block in the middle. + blocks_.emplace(it, tsn, tsn); + return true; +} + +void DataTracker::AdditionalTsnBlocks::EraseTo(UnwrappedTSN tsn) { + // Find the block that is greater than or equals `tsn`. + auto it = absl::c_lower_bound( + blocks_, tsn, [&](const TsnRange& elem, const UnwrappedTSN& t) { + return elem.last < t; + }); + + // The block that is found is greater or equal (or possibly ::end, when no + // block is greater or equal). All blocks before this block can be safely + // removed. the TSN might be within this block, so possibly truncate it. + bool tsn_is_within_block = it != blocks_.end() && tsn >= it->first; + blocks_.erase(blocks_.begin(), it); + + if (tsn_is_within_block) { + blocks_.front().first = tsn.next_value(); + } +} + +void DataTracker::AdditionalTsnBlocks::PopFront() { + RTC_DCHECK(!blocks_.empty()); + blocks_.erase(blocks_.begin()); +} + +bool DataTracker::IsTSNValid(TSN tsn) const { + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.PeekUnwrap(tsn); + + // Note that this method doesn't return `false` for old DATA chunks, as those + // are actually valid, and receiving those may affect the generated SACK + // response (by setting "duplicate TSNs"). 
+ + uint32_t difference = + UnwrappedTSN::Difference(unwrapped_tsn, last_cumulative_acked_tsn_); + if (difference > kMaxAcceptedOutstandingFragments) { + return false; + } + return true; +} + +void DataTracker::Observe(TSN tsn, + AnyDataChunk::ImmediateAckFlag immediate_ack) { + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); + + // IsTSNValid must be called prior to calling this method. + RTC_DCHECK( + UnwrappedTSN::Difference(unwrapped_tsn, last_cumulative_acked_tsn_) <= + kMaxAcceptedOutstandingFragments); + + // Old chunk already seen before? + if (unwrapped_tsn <= last_cumulative_acked_tsn_) { + if (duplicate_tsns_.size() < kMaxDuplicateTsnReported) { + duplicate_tsns_.insert(unwrapped_tsn.Wrap()); + } + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.2 + // "When a packet arrives with duplicate DATA chunk(s) and with no new DATA + // chunk(s), the endpoint MUST immediately send a SACK with no delay. If a + // packet arrives with duplicate DATA chunk(s) bundled with new DATA chunks, + // the endpoint MAY immediately send a SACK." + UpdateAckState(AckState::kImmediate, "duplicate data"); + } else { + if (unwrapped_tsn == last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = unwrapped_tsn; + // The cumulative acked tsn may be moved even further, if a gap was + // filled. + if (!additional_tsn_blocks_.empty() && + additional_tsn_blocks_.front().first == + last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = additional_tsn_blocks_.front().last; + additional_tsn_blocks_.PopFront(); + } + } else { + bool inserted = additional_tsn_blocks_.Add(unwrapped_tsn); + if (!inserted) { + // Already seen before. + if (duplicate_tsns_.size() < kMaxDuplicateTsnReported) { + duplicate_tsns_.insert(unwrapped_tsn.Wrap()); + } + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.2 + // "When a packet arrives with duplicate DATA chunk(s) and with no new + // DATA chunk(s), the endpoint MUST immediately send a SACK with no + // delay. If a packet arrives with duplicate DATA chunk(s) bundled with + // new DATA chunks, the endpoint MAY immediately send a SACK." + // No need to do this. SACKs are sent immediately on packet loss below. + } + } + } + + // https://tools.ietf.org/html/rfc4960#section-6.7 + // "Upon the reception of a new DATA chunk, an endpoint shall examine the + // continuity of the TSNs received. If the endpoint detects a gap in + // the received DATA chunk sequence, it SHOULD send a SACK with Gap Ack + // Blocks immediately. The data receiver continues sending a SACK after + // receipt of each SCTP packet that doesn't fill the gap." + if (!additional_tsn_blocks_.empty()) { + UpdateAckState(AckState::kImmediate, "packet loss"); + } + + // https://tools.ietf.org/html/rfc7053#section-5.2 + // "Upon receipt of an SCTP packet containing a DATA chunk with the I + // bit set, the receiver SHOULD NOT delay the sending of the corresponding + // SACK chunk, i.e., the receiver SHOULD immediately respond with the + // corresponding SACK chunk." + if (*immediate_ack) { + UpdateAckState(AckState::kImmediate, "immediate-ack bit set"); + } + + if (!seen_packet_) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "After the reception of the first DATA chunk in an association the + // endpoint MUST immediately respond with a SACK to acknowledge the DATA + // chunk." 
+ seen_packet_ = true; + UpdateAckState(AckState::kImmediate, "first DATA chunk"); + } + + // https://tools.ietf.org/html/rfc4960#section-6.2 + // "Specifically, an acknowledgement SHOULD be generated for at least + // every second packet (not every second DATA chunk) received, and SHOULD be + // generated within 200 ms of the arrival of any unacknowledged DATA chunk." + if (ack_state_ == AckState::kIdle) { + UpdateAckState(AckState::kBecomingDelayed, "received DATA when idle"); + } else if (ack_state_ == AckState::kDelayed) { + UpdateAckState(AckState::kImmediate, "received DATA when already delayed"); + } +} + +void DataTracker::HandleForwardTsn(TSN new_cumulative_ack) { + // ForwardTSN is sent to make the receiver (this socket) "forget" about partly + // received (or not received at all) data, up until `new_cumulative_ack`. + + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(new_cumulative_ack); + UnwrappedTSN prev_last_cum_ack_tsn = last_cumulative_acked_tsn_; + + // Old chunk already seen before? + if (unwrapped_tsn <= last_cumulative_acked_tsn_) { + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "Note, if the "New Cumulative TSN" value carried in the arrived + // FORWARD TSN chunk is found to be behind or at the current cumulative TSN + // point, the data receiver MUST treat this FORWARD TSN as out-of-date and + // MUST NOT update its Cumulative TSN. The receiver SHOULD send a SACK to + // its peer (the sender of the FORWARD TSN) since such a duplicate may + // indicate the previous SACK was lost in the network." + UpdateAckState(AckState::kImmediate, + "FORWARD_TSN new_cumulative_tsn was behind"); + return; + } + + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "When a FORWARD TSN chunk arrives, the data receiver MUST first update + // its cumulative TSN point to the value carried in the FORWARD TSN chunk, and + // then MUST further advance its cumulative TSN point locally if possible, as + // shown by the following example..." + + // The `new_cumulative_ack` will become the current + // `last_cumulative_acked_tsn_`, and if there have been prior "gaps" that are + // now overlapping with the new value, remove them. + last_cumulative_acked_tsn_ = unwrapped_tsn; + additional_tsn_blocks_.EraseTo(unwrapped_tsn); + + // See if the `last_cumulative_acked_tsn_` can be moved even further: + if (!additional_tsn_blocks_.empty() && + additional_tsn_blocks_.front().first == + last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = additional_tsn_blocks_.front().last; + additional_tsn_blocks_.PopFront(); + } + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "FORWARD_TSN, cum_ack_tsn=" + << *prev_last_cum_ack_tsn.Wrap() << "->" + << *new_cumulative_ack << "->" + << *last_cumulative_acked_tsn_.Wrap(); + + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "Any time a FORWARD TSN chunk arrives, for the purposes of sending a + // SACK, the receiver MUST follow the same rules as if a DATA chunk had been + // received (i.e., follow the delayed sack rules specified in ..." 
+ if (ack_state_ == AckState::kIdle) { + UpdateAckState(AckState::kBecomingDelayed, + "received FORWARD_TSN when idle"); + } else if (ack_state_ == AckState::kDelayed) { + UpdateAckState(AckState::kImmediate, + "received FORWARD_TSN when already delayed"); + } +} + +SackChunk DataTracker::CreateSelectiveAck(size_t a_rwnd) { + // Note that in SCTP, the receiver side is allowed to discard received data + // and signal that to the sender, but only chunks that have previously been + // reported in the gap-ack-blocks. However, this implementation will never do + // that. So this SACK produced is more like a NR-SACK as explained in + // https://ieeexplore.ieee.org/document/4697037 and which there is an RFC + // draft at https://tools.ietf.org/html/draft-tuexen-tsvwg-sctp-multipath-17. + std::set duplicate_tsns; + duplicate_tsns_.swap(duplicate_tsns); + + return SackChunk(last_cumulative_acked_tsn_.Wrap(), a_rwnd, + CreateGapAckBlocks(), std::move(duplicate_tsns)); +} + +std::vector DataTracker::CreateGapAckBlocks() const { + const auto& blocks = additional_tsn_blocks_.blocks(); + std::vector gap_ack_blocks; + gap_ack_blocks.reserve(std::min(blocks.size(), kMaxGapAckBlocksReported)); + for (size_t i = 0; i < blocks.size() && i < kMaxGapAckBlocksReported; ++i) { + auto start_diff = + UnwrappedTSN::Difference(blocks[i].first, last_cumulative_acked_tsn_); + auto end_diff = + UnwrappedTSN::Difference(blocks[i].last, last_cumulative_acked_tsn_); + gap_ack_blocks.emplace_back(static_cast(start_diff), + static_cast(end_diff)); + } + + return gap_ack_blocks; +} + +bool DataTracker::ShouldSendAck(bool also_if_delayed) { + if (ack_state_ == AckState::kImmediate || + (also_if_delayed && (ack_state_ == AckState::kBecomingDelayed || + ack_state_ == AckState::kDelayed))) { + UpdateAckState(AckState::kIdle, "sending SACK"); + return true; + } + + return false; +} + +bool DataTracker::will_increase_cum_ack_tsn(TSN tsn) const { + UnwrappedTSN unwrapped = tsn_unwrapper_.PeekUnwrap(tsn); + return unwrapped == last_cumulative_acked_tsn_.next_value(); +} + +void DataTracker::ForceImmediateSack() { + ack_state_ = AckState::kImmediate; +} + +void DataTracker::HandleDelayedAckTimerExpiry() { + UpdateAckState(AckState::kImmediate, "delayed ack timer expired"); +} + +void DataTracker::ObservePacketEnd() { + if (ack_state_ == AckState::kBecomingDelayed) { + UpdateAckState(AckState::kDelayed, "packet end"); + } +} + +void DataTracker::UpdateAckState(AckState new_state, absl::string_view reason) { + if (new_state != ack_state_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "State changed from " + << ToString(ack_state_) << " to " + << ToString(new_state) << " due to " << reason; + if (ack_state_ == AckState::kDelayed) { + delayed_ack_timer_.Stop(); + } else if (new_state == AckState::kDelayed) { + delayed_ack_timer_.Start(); + } + ack_state_ = new_state; + } +} + +absl::string_view DataTracker::ToString(AckState ack_state) { + switch (ack_state) { + case AckState::kIdle: + return "IDLE"; + case AckState::kBecomingDelayed: + return "BECOMING_DELAYED"; + case AckState::kDelayed: + return "DELAYED"; + case AckState::kImmediate: + return "IMMEDIATE"; + } +} + +} // namespace dcsctp diff --git a/net/dcsctp/rx/data_tracker.h b/net/dcsctp/rx/data_tracker.h new file mode 100644 index 0000000000..167f5a04e7 --- /dev/null +++ b/net/dcsctp/rx/data_tracker.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
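The gap ack block offsets produced above are relative to the cumulative TSN ack, which is easy to miss when reading SACKs in logs. A small sketch follows (not part of the patch; the start/end field names on GapAckBlock are inferred from the constructor used in the tests, and 32-bit TSN wrap-around is ignored for brevity):

#include <cstdint>
#include <vector>

#include "net/dcsctp/packet/chunk/sack_chunk.h"

namespace example {

// Expands a SACK's gap ack blocks into the concrete TSNs they acknowledge.
std::vector<uint32_t> AckedTsnsAboveCumAck(const dcsctp::SackChunk& sack) {
  std::vector<uint32_t> tsns;
  uint32_t cum_ack = *sack.cumulative_tsn_ack();
  for (const dcsctp::SackChunk::GapAckBlock& block : sack.gap_ack_blocks()) {
    for (uint32_t offset = block.start; offset <= block.end; ++offset) {
      tsns.push_back(cum_ack + offset);
    }
  }
  // With cumulative_tsn_ack=12 and gap blocks (2,3) and (5,5) - the RFC 4960
  // section 3.3.4 example - this returns {14, 15, 17}.
  return tsns;
}

}  // namespace example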
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_DATA_TRACKER_H_ +#define NET_DCSCTP_RX_DATA_TRACKER_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/timer/timer.h" + +namespace dcsctp { + +// Keeps track of received DATA chunks and handles all logic for _when_ to +// create SACKs and also _how_ to generate them. +// +// It only uses TSNs to track delivery and doesn't need to be aware of streams. +// +// SACKs are optimally sent every second packet on connections with no packet +// loss. When packet loss is detected, it's sent for every packet. When SACKs +// are not sent directly, a timer is used to send a SACK delayed (by RTO/2, or +// 200ms, whatever is smallest). +class DataTracker { + public: + // The maximum number of duplicate TSNs that will be reported in a SACK. + static constexpr size_t kMaxDuplicateTsnReported = 20; + // The maximum number of gap-ack-blocks that will be reported in a SACK. + static constexpr size_t kMaxGapAckBlocksReported = 20; + + // The maximum number of accepted in-flight DATA chunks. This indicates the + // maximum difference from this buffer's last cumulative ack TSN, and any + // received data. Data received beyond this limit will be dropped, which will + // force the transmitter to send data that actually increases the last + // cumulative acked TSN. + static constexpr uint32_t kMaxAcceptedOutstandingFragments = 100000; + + explicit DataTracker(absl::string_view log_prefix, + Timer* delayed_ack_timer, + TSN peer_initial_tsn) + : log_prefix_(std::string(log_prefix) + "dtrack: "), + delayed_ack_timer_(*delayed_ack_timer), + last_cumulative_acked_tsn_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))) {} + + // Indicates if the provided TSN is valid. If this return false, the data + // should be dropped and not added to any other buffers, which essentially + // means that there is intentional packet loss. + bool IsTSNValid(TSN tsn) const; + + // Call for every incoming data chunk. + void Observe(TSN tsn, + AnyDataChunk::ImmediateAckFlag immediate_ack = + AnyDataChunk::ImmediateAckFlag(false)); + // Called at the end of processing an SCTP packet. + void ObservePacketEnd(); + + // Called for incoming FORWARD-TSN/I-FORWARD-TSN chunks + void HandleForwardTsn(TSN new_cumulative_ack); + + // Indicates if a SACK should be sent. There may be other reasons to send a + // SACK, but if this function indicates so, it should be sent as soon as + // possible. Calling this function will make it clear a flag so that if it's + // called again, it will probably return false. + // + // If the delayed ack timer is running, this method will return false _unless_ + // `also_if_delayed` is set to true. Then it will return true as well. + bool ShouldSendAck(bool also_if_delayed = false); + + // Returns the last cumulative ack TSN - the last seen data chunk's TSN + // value before any packet loss was detected. 
+ TSN last_cumulative_acked_tsn() const { + return TSN(last_cumulative_acked_tsn_.Wrap()); + } + + // Returns true if the received `tsn` would increase the cumulative ack TSN. + bool will_increase_cum_ack_tsn(TSN tsn) const; + + // Forces `ShouldSendSack` to return true. + void ForceImmediateSack(); + + // Note that this will clear `duplicates_`, so every SackChunk that is + // consumed must be sent. + SackChunk CreateSelectiveAck(size_t a_rwnd); + + void HandleDelayedAckTimerExpiry(); + + private: + enum class AckState { + // No need to send an ACK. + kIdle, + + // Has received data chunks (but not yet end of packet). + kBecomingDelayed, + + // Has received data chunks and the end of a packet. Delayed ack timer is + // running and a SACK will be sent on expiry, or if DATA is sent, or after + // next packet with data. + kDelayed, + + // Send a SACK immediately after handling this packet. + kImmediate, + }; + + // Represents ranges of TSNs that have been received that are not directly + // following the last cumulative acked TSN. This information is returned to + // the sender in the "gap ack blocks" in the SACK chunk. The blocks are always + // non-overlapping and non-adjacent. + class AdditionalTsnBlocks { + public: + // Represents an inclusive range of received TSNs, i.e. [first, last]. + struct TsnRange { + TsnRange(UnwrappedTSN first, UnwrappedTSN last) + : first(first), last(last) {} + UnwrappedTSN first; + UnwrappedTSN last; + }; + + // Adds a TSN to the set. This will try to expand any existing block and + // might merge blocks to ensure that all blocks are non-adjacent. If a + // current block can't be expanded, a new block is created. + // + // The return value indicates if `tsn` was added. If false is returned, the + // `tsn` was already represented in one of the blocks. + bool Add(UnwrappedTSN tsn); + + // Erases all TSNs up to, and including `tsn`. This will remove all blocks + // that are completely below `tsn` and may truncate a block where `tsn` is + // within that block. In that case, the frontmost block's start TSN will be + // the next following tsn after `tsn`. + void EraseTo(UnwrappedTSN tsn); + + // Removes the first block. Must not be called on an empty set. + void PopFront(); + + const std::vector& blocks() const { return blocks_; } + + bool empty() const { return blocks_.empty(); } + + const TsnRange& front() const { return blocks_.front(); } + + private: + // A sorted vector of non-overlapping and non-adjacent blocks. + std::vector blocks_; + }; + + std::vector CreateGapAckBlocks() const; + void UpdateAckState(AckState new_state, absl::string_view reason); + static absl::string_view ToString(AckState ack_state); + + const std::string log_prefix_; + // If a packet has ever been seen. + bool seen_packet_ = false; + Timer& delayed_ack_timer_; + AckState ack_state_ = AckState::kIdle; + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // All TSNs up until (and including) this value have been seen. + UnwrappedTSN last_cumulative_acked_tsn_; + // Received TSNs that are not directly following `last_cumulative_acked_tsn_`. + AdditionalTsnBlocks additional_tsn_blocks_; + std::set duplicate_tsns_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_DATA_TRACKER_H_ diff --git a/net/dcsctp/rx/data_tracker_test.cc b/net/dcsctp/rx/data_tracker_test.cc new file mode 100644 index 0000000000..5c2e56fb2b --- /dev/null +++ b/net/dcsctp/rx/data_tracker_test.cc @@ -0,0 +1,636 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 
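For context, a sketch (not part of this change) of the calling pattern the DataTracker API above is designed for, roughly what the socket implementation does for each received packet. The send_sack callback and the bare TSN list stand in for the real packet handling.

#include <cstdint>
#include <functional>
#include <vector>

#include "net/dcsctp/packet/chunk/sack_chunk.h"
#include "net/dcsctp/rx/data_tracker.h"

namespace example {

void HandleDataChunks(dcsctp::DataTracker& tracker,
                      const std::vector<uint32_t>& packet_tsns,
                      size_t a_rwnd,
                      const std::function<void(dcsctp::SackChunk)>& send_sack) {
  for (uint32_t raw_tsn : packet_tsns) {
    dcsctp::TSN tsn(raw_tsn);
    if (!tracker.IsTSNValid(tsn)) {
      continue;  // Intentional packet loss; the chunk is dropped entirely.
    }
    tracker.Observe(tsn);
    // ... hand the chunk payload to the reassembly queue here ...
  }
  tracker.ObservePacketEnd();

  // SACK immediately when the tracker asks for it (first packet, packet loss,
  // duplicates, the I bit, every second packet). Otherwise the delayed-ack
  // timer ends up in HandleDelayedAckTimerExpiry(), and a later call with
  // ShouldSendAck(/*also_if_delayed=*/true) flushes the pending SACK.
  if (tracker.ShouldSendAck()) {
    send_sack(tracker.CreateSelectiveAck(a_rwnd));
  }
}

}  // namespace example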
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/data_tracker.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr size_t kArwnd = 10000; +constexpr TSN kInitialTSN(11); + +class DataTrackerTest : public testing::Test { + protected: + DataTrackerTest() + : timeout_manager_([this]() { return now_; }), + timer_manager_([this]() { return timeout_manager_.CreateTimeout(); }), + timer_(timer_manager_.CreateTimer( + "test/delayed_ack", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + buf_("log: ", timer_.get(), kInitialTSN) {} + + void Observer(std::initializer_list tsns) { + for (const uint32_t tsn : tsns) { + buf_.Observe(TSN(tsn), AnyDataChunk::ImmediateAckFlag(false)); + } + } + + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager timer_manager_; + std::unique_ptr timer_; + DataTracker buf_; +}; + +TEST_F(DataTrackerTest, Empty) { + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserverSingleInOrderPacket) { + Observer({11}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserverManyInOrderMovesCumulativeTsnAck) { + Observer({11, 12, 13}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(13)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserveOutOfOrderMovesCumulativeTsnAck) { + Observer({12, 13, 14, 11}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, SingleGap) { + Observer({12}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ExampleFromRFC4960Section334) { + Observer({11, 12, 14, 15, 17}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(12)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5))); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, AckAlreadyReceivedChunk) { + Observer({11}); + SackChunk sack1 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack1.gap_ack_blocks(), 
IsEmpty()); + + // Receive old chunk + Observer({8}); + SackChunk sack2 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, DoubleSendRetransmittedChunk) { + Observer({11, 13, 14, 15}); + SackChunk sack1 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack1.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4))); + + // Fill in the hole. + Observer({12, 16, 17, 18}); + SackChunk sack2 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(18)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); + + // Receive chunk 12 again. + Observer({12, 19, 20, 21}); + SackChunk sack3 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack3.cumulative_tsn_ack(), TSN(21)); + EXPECT_THAT(sack3.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ForwardTsnSimple) { + // Messages (11, 12, 13), (14, 15) - first message expires. + Observer({11, 12, 15}); + + buf_.HandleForwardTsn(TSN(13)); + + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(13)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, ForwardTsnSkipsFromGapBlock) { + // Messages (11, 12, 13), (14, 15) - first message expires. + Observer({11, 12, 14}); + + buf_.HandleForwardTsn(TSN(13)); + + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ExampleFromRFC3758) { + buf_.HandleForwardTsn(TSN(102)); + + Observer({102, 104, 105, 107}); + + buf_.HandleForwardTsn(TSN(103)); + + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(105)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, EmptyAllAcks) { + Observer({11, 13, 14, 15}); + + buf_.HandleForwardTsn(TSN(100)); + + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(100)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, SetsArwndCorrectly) { + SackChunk sack1 = buf_.CreateSelectiveAck(/*a_rwnd=*/100); + EXPECT_EQ(sack1.a_rwnd(), 100u); + + SackChunk sack2 = buf_.CreateSelectiveAck(/*a_rwnd=*/101); + EXPECT_EQ(sack2.a_rwnd(), 101u); +} + +TEST_F(DataTrackerTest, WillIncreaseCumAckTsn) { + EXPECT_EQ(buf_.last_cumulative_acked_tsn(), TSN(10)); + EXPECT_FALSE(buf_.will_increase_cum_ack_tsn(TSN(10))); + EXPECT_TRUE(buf_.will_increase_cum_ack_tsn(TSN(11))); + EXPECT_FALSE(buf_.will_increase_cum_ack_tsn(TSN(12))); + + Observer({11, 12, 13, 14, 15}); + EXPECT_EQ(buf_.last_cumulative_acked_tsn(), TSN(15)); + EXPECT_FALSE(buf_.will_increase_cum_ack_tsn(TSN(15))); + EXPECT_TRUE(buf_.will_increase_cum_ack_tsn(TSN(16))); + EXPECT_FALSE(buf_.will_increase_cum_ack_tsn(TSN(17))); +} + +TEST_F(DataTrackerTest, ForceShouldSendSackImmediately) { + EXPECT_FALSE(buf_.ShouldSendAck()); + + buf_.ForceImmediateSack(); + + EXPECT_TRUE(buf_.ShouldSendAck()); +} + +TEST_F(DataTrackerTest, WillAcceptValidTSNs) { + // The initial TSN is always one more than the last, which is our base. 
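+  // With kInitialTSN = 11 the base is TSN(10), and any TSN within
+  // kMaxAcceptedOutstandingFragments of the base, in either direction, is
+  // accepted as valid.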
+ TSN last_tsn = TSN(*kInitialTSN - 1); + int limit = static_cast(DataTracker::kMaxAcceptedOutstandingFragments); + + for (int i = -limit; i <= limit; ++i) { + EXPECT_TRUE(buf_.IsTSNValid(TSN(*last_tsn + i))); + } +} + +TEST_F(DataTrackerTest, WillNotAcceptInvalidTSNs) { + // The initial TSN is always one more than the last, which is our base. + TSN last_tsn = TSN(*kInitialTSN - 1); + + size_t limit = DataTracker::kMaxAcceptedOutstandingFragments; + EXPECT_FALSE(buf_.IsTSNValid(TSN(*last_tsn + limit + 1))); + EXPECT_FALSE(buf_.IsTSNValid(TSN(*last_tsn - (limit + 1)))); + EXPECT_FALSE(buf_.IsTSNValid(TSN(*last_tsn + 0x8000000))); + EXPECT_FALSE(buf_.IsTSNValid(TSN(*last_tsn - 0x8000000))); +} + +TEST_F(DataTrackerTest, ReportSingleDuplicateTsns) { + Observer({11, 12, 11}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(12)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(11))); +} + +TEST_F(DataTrackerTest, ReportMultipleDuplicateTsns) { + Observer({11, 12, 13, 14, 12, 13, 12, 13, 15, 16}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(16)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(12), TSN(13))); +} + +TEST_F(DataTrackerTest, ReportDuplicateTsnsInGapAckBlocks) { + Observer({11, /*12,*/ 13, 14, 13, 14, 15, 16}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 5))); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(13), TSN(14))); +} + +TEST_F(DataTrackerTest, ClearsDuplicateTsnsAfterCreatingSack) { + Observer({11, 12, 13, 14, 12, 13, 12, 13, 15, 16}); + SackChunk sack1 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(16)); + EXPECT_THAT(sack1.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack1.duplicate_tsns(), UnorderedElementsAre(TSN(12), TSN(13))); + + Observer({17}); + SackChunk sack2 = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(17)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack2.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, LimitsNumberOfDuplicatesReported) { + for (size_t i = 0; i < DataTracker::kMaxDuplicateTsnReported + 10; ++i) { + TSN tsn(11 + i); + buf_.Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + buf_.Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + } + + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), + SizeIs(DataTracker::kMaxDuplicateTsnReported)); +} + +TEST_F(DataTrackerTest, LimitsNumberOfGapAckBlocksReported) { + for (size_t i = 0; i < DataTracker::kMaxGapAckBlocksReported + 10; ++i) { + TSN tsn(11 + i * 2); + buf_.Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + } + + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), + SizeIs(DataTracker::kMaxGapAckBlocksReported)); +} + +TEST_F(DataTrackerTest, SendsSackForFirstPacketObserved) { + Observer({11}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackEverySecondPacketWhenThereIsNoPacketLoss) { + Observer({11}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + 
EXPECT_FALSE(timer_->is_running()); + Observer({12}); + buf_.ObservePacketEnd(); + EXPECT_FALSE(buf_.ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({13}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + buf_.ObservePacketEnd(); + EXPECT_FALSE(buf_.ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({15}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackEveryPacketOnPacketLoss) { + Observer({11}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({13}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({15}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({16}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + // Fill the hole. + Observer({12}); + buf_.ObservePacketEnd(); + EXPECT_FALSE(buf_.ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + // Goes back to every second packet + Observer({17}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({18}); + buf_.ObservePacketEnd(); + EXPECT_FALSE(buf_.ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackOnDuplicateDataChunks) { + Observer({11}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({11}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({12}); + buf_.ObservePacketEnd(); + EXPECT_FALSE(buf_.ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + // Goes back to every second packet + Observer({13}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + // Duplicate again + Observer({12}); + buf_.ObservePacketEnd(); + EXPECT_TRUE(buf_.ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, GapAckBlockAddSingleBlock) { + Observer({12}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, GapAckBlockAddsAnother) { + Observer({12}); + Observer({14}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2), + SackChunk::GapAckBlock(4, 4))); +} + +TEST_F(DataTrackerTest, GapAckBlockAddsDuplicate) { + Observer({12}); + Observer({12}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); + EXPECT_THAT(sack.duplicate_tsns(), ElementsAre(TSN(12))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToRight) { + Observer({12}); + Observer({13}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST_F(DataTrackerTest, 
GapAckBlockExpandsToRightWithOther) { + Observer({12}); + Observer({20}); + Observer({30}); + Observer({21}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 11), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLeft) { + Observer({13}); + Observer({12}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLeftWithOther) { + Observer({12}); + Observer({21}); + Observer({30}); + Observer({20}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 11), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLRightAndMerges) { + Observer({12}); + Observer({20}); + Observer({22}); + Observer({30}); + Observer({21}); + SackChunk sack = buf_.CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 12), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockMergesManyBlocksIntoOne) { + Observer({22}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12))); + Observer({30}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(20, 20))); + Observer({24}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(20, 20))); + Observer({28}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(18, 18), // + SackChunk::GapAckBlock(20, 20))); + Observer({26}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 18), // + SackChunk::GapAckBlock(20, 20))); + Observer({29}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 20))); + Observer({23}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 20))); + Observer({27}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 14), // + SackChunk::GapAckBlock(16, 20))); + + Observer({25}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 20))); + Observer({20}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 10), // + SackChunk::GapAckBlock(12, 20))); + Observer({32}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + 
ElementsAre(SackChunk::GapAckBlock(10, 10), // + SackChunk::GapAckBlock(12, 20), // + SackChunk::GapAckBlock(22, 22))); + Observer({21}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 20), // + SackChunk::GapAckBlock(22, 22))); + Observer({31}); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 22))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveBeforeCumAckTsn) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(8)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4), // + SackChunk::GapAckBlock(10, 12), + SackChunk::GapAckBlock(20, 21))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveBeforeFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(11)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtBeginningOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(12)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtMiddleOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + buf_.HandleForwardTsn(TSN(13)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtEndOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + buf_.HandleForwardTsn(TSN(14)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAfterFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(18)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(18)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4), // + SackChunk::GapAckBlock(12, 13))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightBeforeSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(19)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtStartOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(20)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtMiddleOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(21)); + 
EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtEndOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(22)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveeFarAfterAllBlocks) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + buf_.HandleForwardTsn(TSN(40)); + EXPECT_EQ(buf_.CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(40)); + EXPECT_THAT(buf_.CreateSelectiveAck(kArwnd).gap_ack_blocks(), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/rx/reassembly_queue.cc b/net/dcsctp/rx/reassembly_queue.cc new file mode 100644 index 0000000000..581b9fcc49 --- /dev/null +++ b/net/dcsctp/rx/reassembly_queue.cc @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_queue.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/rx/traditional_reassembly_streams.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +ReassemblyQueue::ReassemblyQueue(absl::string_view log_prefix, + TSN peer_initial_tsn, + size_t max_size_bytes) + : log_prefix_(std::string(log_prefix) + "reasm: "), + max_size_bytes_(max_size_bytes), + watermark_bytes_(max_size_bytes * kHighWatermarkLimit), + last_assembled_tsn_watermark_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))), + streams_(std::make_unique( + log_prefix_, + [this](rtc::ArrayView tsns, + DcSctpMessage message) { + AddReassembledMessage(tsns, std::move(message)); + })) {} + +void ReassemblyQueue::Add(TSN tsn, Data data) { + RTC_DCHECK(IsConsistent()); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "added tsn=" << *tsn + << ", stream=" << *data.stream_id << ":" + << *data.message_id << ":" << *data.fsn << ", type=" + << (data.is_beginning && data.is_end + ? "complete" + : data.is_beginning + ? "first" + : data.is_end ? 
"last" : "middle"); + + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); + + if (unwrapped_tsn <= last_assembled_tsn_watermark_ || + delivered_tsns_.find(unwrapped_tsn) != delivered_tsns_.end()) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Chunk has already been delivered - skipping"; + return; + } + + // If a stream reset has been received with a "sender's last assigned tsn" in + // the future, the socket is in "deferred reset processing" mode and must + // buffer chunks until it's exited. + if (deferred_reset_streams_.has_value() && + unwrapped_tsn > + tsn_unwrapper_.Unwrap( + deferred_reset_streams_->req.sender_last_assigned_tsn())) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Deferring chunk with tsn=" << *tsn + << " until cum_ack_tsn=" + << *deferred_reset_streams_->req.sender_last_assigned_tsn(); + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "In this mode, any data arriving with a TSN larger than the + // Sender's Last Assigned TSN for the affected stream(s) MUST be queued + // locally and held until the cumulative acknowledgment point reaches the + // Sender's Last Assigned TSN." + queued_bytes_ += data.size(); + deferred_reset_streams_->deferred_chunks.emplace_back( + std::make_pair(tsn, std::move(data))); + } else { + queued_bytes_ += streams_->Add(unwrapped_tsn, std::move(data)); + } + + // https://tools.ietf.org/html/rfc4960#section-6.9 + // "Note: If the data receiver runs out of buffer space while still + // waiting for more fragments to complete the reassembly of the message, it + // should dispatch part of its inbound message through a partial delivery + // API (see Section 10), freeing some of its receive buffer space so that + // the rest of the message may be received." + + // TODO(boivie): Support EOR flag and partial delivery? + RTC_DCHECK(IsConsistent()); +} + +ReconfigurationResponseParameter::Result ReassemblyQueue::ResetStreams( + const OutgoingSSNResetRequestParameter& req, + TSN cum_tsn_ack) { + RTC_DCHECK(IsConsistent()); + if (deferred_reset_streams_.has_value()) { + // In deferred mode already. + return ReconfigurationResponseParameter::Result::kInProgress; + } else if (req.request_sequence_number() <= + last_completed_reset_req_seq_nbr_) { + // Already performed at some time previously. + return ReconfigurationResponseParameter::Result::kSuccessPerformed; + } + + UnwrappedTSN sla_tsn = tsn_unwrapper_.Unwrap(req.sender_last_assigned_tsn()); + UnwrappedTSN unwrapped_cum_tsn_ack = tsn_unwrapper_.Unwrap(cum_tsn_ack); + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "If the Sender's Last Assigned TSN is greater than the + // cumulative acknowledgment point, then the endpoint MUST enter "deferred + // reset processing"." + if (sla_tsn > unwrapped_cum_tsn_ack) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Entering deferred reset processing mode until cum_tsn_ack=" + << *req.sender_last_assigned_tsn(); + deferred_reset_streams_ = absl::make_optional(req); + return ReconfigurationResponseParameter::Result::kInProgress; + } + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "... streams MUST be reset to 0 as the next expected SSN." 
+ streams_->ResetStreams(req.stream_ids()); + last_completed_reset_req_seq_nbr_ = req.request_sequence_number(); + RTC_DCHECK(IsConsistent()); + return ReconfigurationResponseParameter::Result::kSuccessPerformed; +} + +bool ReassemblyQueue::MaybeResetStreamsDeferred(TSN cum_ack_tsn) { + RTC_DCHECK(IsConsistent()); + if (deferred_reset_streams_.has_value()) { + UnwrappedTSN unwrapped_cum_ack_tsn = tsn_unwrapper_.Unwrap(cum_ack_tsn); + UnwrappedTSN unwrapped_sla_tsn = tsn_unwrapper_.Unwrap( + deferred_reset_streams_->req.sender_last_assigned_tsn()); + if (unwrapped_cum_ack_tsn >= unwrapped_sla_tsn) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Leaving deferred reset processing with tsn=" + << *cum_ack_tsn << ", feeding back " + << deferred_reset_streams_->deferred_chunks.size() + << " chunks"; + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "... streams MUST be reset to 0 as the next expected SSN." + streams_->ResetStreams(deferred_reset_streams_->req.stream_ids()); + std::vector> deferred_chunks = + std::move(deferred_reset_streams_->deferred_chunks); + // The response will not be sent now, but as a reply to the retried + // request, which will come as "in progress" has been sent prior. + last_completed_reset_req_seq_nbr_ = + deferred_reset_streams_->req.request_sequence_number(); + deferred_reset_streams_ = absl::nullopt; + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "Any queued TSNs (queued at step E2) MUST now be released and processed + // normally." + for (auto& p : deferred_chunks) { + const TSN& tsn = p.first; + Data& data = p.second; + queued_bytes_ -= data.size(); + Add(tsn, std::move(data)); + } + + RTC_DCHECK(IsConsistent()); + return true; + } else { + RTC_DLOG(LS_VERBOSE) << "Staying in deferred reset processing. tsn=" + << *cum_ack_tsn; + } + } + + return false; +} + +std::vector ReassemblyQueue::FlushMessages() { + std::vector ret; + reassembled_messages_.swap(ret); + return ret; +} + +void ReassemblyQueue::AddReassembledMessage( + rtc::ArrayView tsns, + DcSctpMessage message) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Assembled message from TSN=[" + << StrJoin(tsns, ",", + [](rtc::StringBuilder& sb, UnwrappedTSN tsn) { + sb << *tsn.Wrap(); + }) + << "], message; stream_id=" << *message.stream_id() + << ", ppid=" << *message.ppid() + << ", payload=" << message.payload().size() << " bytes"; + + for (const UnwrappedTSN tsn : tsns) { + // Update watermark, or insert into delivered_tsns_ + if (tsn == last_assembled_tsn_watermark_.next_value()) { + last_assembled_tsn_watermark_.Increment(); + } else { + delivered_tsns_.insert(tsn); + } + } + + // With new TSNs in delivered_tsns, gaps might be filled. 
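+  // For example, if the watermark is at TSN 12 and `delivered_tsns_` already
+  // contains {14, 15}, assembling a message that includes TSN 13 first moves
+  // the watermark to 13 (in the loop above) and then lets the loop below
+  // advance it to 15, emptying the set.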
+ while (!delivered_tsns_.empty() && + *delivered_tsns_.begin() == + last_assembled_tsn_watermark_.next_value()) { + last_assembled_tsn_watermark_.Increment(); + delivered_tsns_.erase(delivered_tsns_.begin()); + } + + reassembled_messages_.emplace_back(std::move(message)); +} + +void ReassemblyQueue::Handle(const AnyForwardTsnChunk& forward_tsn) { + RTC_DCHECK(IsConsistent()); + UnwrappedTSN tsn = tsn_unwrapper_.Unwrap(forward_tsn.new_cumulative_tsn()); + + last_assembled_tsn_watermark_ = std::max(last_assembled_tsn_watermark_, tsn); + delivered_tsns_.erase(delivered_tsns_.begin(), + delivered_tsns_.upper_bound(tsn)); + + queued_bytes_ -= + streams_->HandleForwardTsn(tsn, forward_tsn.skipped_streams()); + RTC_DCHECK(IsConsistent()); +} + +bool ReassemblyQueue::IsConsistent() const { + // Allow queued_bytes_ to be larger than max_size_bytes, as it's not actively + // enforced in this class. This comparison will still trigger if queued_bytes_ + // became "negative". + return (queued_bytes_ >= 0 && queued_bytes_ <= 2 * max_size_bytes_); +} + +} // namespace dcsctp diff --git a/net/dcsctp/rx/reassembly_queue.h b/net/dcsctp/rx/reassembly_queue.h new file mode 100644 index 0000000000..25cda70c58 --- /dev/null +++ b/net/dcsctp/rx/reassembly_queue.h @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ +#define NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Contains the received DATA chunks that haven't yet been reassembled, and +// reassembles chunks when possible. +// +// The actual assembly is handled by an implementation of the +// `ReassemblyStreams` interface. +// +// Except for reassembling fragmented messages, this class will also handle two +// less common operations; To handle the receiver-side of partial reliability +// (limited number of retransmissions or limited message lifetime) as well as +// stream resetting, which is used when a sender wishes to close a data channel. +// +// Partial reliability is handled when a FORWARD-TSN or I-FORWARD-TSN chunk is +// received, and it will simply delete any chunks matching the parameters in +// that chunk. This is mainly implemented in ReassemblyStreams. +// +// Resetting streams is handled when a RECONFIG chunks is received, with an +// "Outgoing SSN Reset Request" parameter. That parameter will contain a list of +// streams to reset, and a `sender_last_assigned_tsn`. If this TSN is not yet +// seen, the stream cannot be directly reset, and this class will respond that +// the reset is "deferred". 
But if this TSN is known, the stream can be reset
+// immediately.
+//
+// The ReassemblyQueue has a maximum size, as it would otherwise be a DoS
+// attack vector where a peer could consume all memory of the other peer by
+// sending a lot of ordered chunks, but carefully withholding an early one. It
+// also has a watermark limit, which the caller can query to see whether the
+// number of queued bytes is above that limit. This is used by the caller to be
+// selective in what to add to the reassembly queue, so that it's not
+// exhausted. The caller is expected to call `is_full` prior to adding data to
+// the queue and to act accordingly if the queue is full.
+class ReassemblyQueue {
+ public:
+  // When the queue is filled over this fraction (of its maximum size), the
+  // socket should restrict incoming data to avoid filling up the queue.
+  static constexpr float kHighWatermarkLimit = 0.9;
+
+  ReassemblyQueue(absl::string_view log_prefix,
+                  TSN peer_initial_tsn,
+                  size_t max_size_bytes);
+
+  // Adds a data chunk to the queue, with a `tsn` and other parameters in
+  // `data`.
+  void Add(TSN tsn, Data data);
+
+  // Indicates if the reassembly queue has any reassembled messages that can be
+  // retrieved by calling `FlushMessages`.
+  bool HasMessages() const { return !reassembled_messages_.empty(); }
+
+  // Returns any reassembled messages.
+  std::vector<DcSctpMessage> FlushMessages();
+
+  // Handles a ForwardTSN chunk, when the sender has indicated that the
+  // receiver (this class) should forget about some chunks. This is used to
+  // implement partial reliability.
+  void Handle(const AnyForwardTsnChunk& forward_tsn);
+
+  // Given the reset stream request and the current cum_tsn_ack, might either
+  // reset the streams directly (returns kSuccessPerformed), or at a later
+  // time, by entering the "deferred reset processing" mode (returns
+  // kInProgress).
+  ReconfigurationResponseParameter::Result ResetStreams(
+      const OutgoingSSNResetRequestParameter& req,
+      TSN cum_tsn_ack);
+
+  // Given the current (updated) cum_tsn_ack, might leave "deferred reset
+  // processing" mode and reset streams. Returns true if so.
+  bool MaybeResetStreamsDeferred(TSN cum_ack_tsn);
+
+  // The number of payload bytes that have been queued. Note that the actual
+  // memory usage is higher due to additional overhead of tracking received
+  // data.
+  size_t queued_bytes() const { return queued_bytes_; }
+
+  // The remaining bytes until the queue has reached the watermark limit.
+  size_t remaining_bytes() const { return watermark_bytes_ - queued_bytes_; }
+
+  // Indicates if the queue is full. Data should not be added to the queue when
+  // it's full.
+  bool is_full() const { return queued_bytes_ >= max_size_bytes_; }
+
+  // Indicates if the queue is above the watermark limit, which is a certain
+  // percentage of its size.
+  bool is_above_watermark() const { return queued_bytes_ >= watermark_bytes_; }
+
+  // Returns the watermark limit, in bytes.
+ size_t watermark_bytes() const { return watermark_bytes_; } + + private: + bool IsConsistent() const; + void AddReassembledMessage(rtc::ArrayView tsns, + DcSctpMessage message); + + struct DeferredResetStreams { + explicit DeferredResetStreams(OutgoingSSNResetRequestParameter req) + : req(std::move(req)) {} + OutgoingSSNResetRequestParameter req; + std::vector> deferred_chunks; + }; + + const std::string log_prefix_; + const size_t max_size_bytes_; + const size_t watermark_bytes_; + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // Whenever a message has been assembled, either increase + // `last_assembled_tsn_watermark_` or - if there are gaps - add the message's + // TSNs into delivered_tsns_ so that messages are not re-delivered on + // duplicate chunks. + UnwrappedTSN last_assembled_tsn_watermark_; + std::set delivered_tsns_; + // Messages that have been reassembled, and will be returned by + // `FlushMessages`. + std::vector reassembled_messages_; + + // If present, "deferred reset processing" mode is active. + absl::optional deferred_reset_streams_; + + // Contains the last request sequence number of the + // OutgoingSSNResetRequestParameter that was performed. + ReconfigRequestSN last_completed_reset_req_seq_nbr_ = ReconfigRequestSN(0); + + // The number of "payload bytes" that are in this queue, in total. + size_t queued_bytes_ = 0; + + // The actual implementation of ReassemblyStreams. + std::unique_ptr streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ diff --git a/net/dcsctp/rx/reassembly_queue_test.cc b/net/dcsctp/rx/reassembly_queue_test.cc new file mode 100644 index 0000000000..e38372c7d1 --- /dev/null +++ b/net/dcsctp/rx/reassembly_queue_test.cc @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_queue.h" + +#include + +#include +#include +#include +#include +#include + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +// The default maximum size of the Reassembly Queue. 
+static constexpr size_t kBufferSize = 10000; + +static constexpr StreamID kStreamID(1); +static constexpr SSN kSSN(0); +static constexpr MID kMID(0); +static constexpr FSN kFSN(0); +static constexpr PPID kPPID(53); + +static constexpr std::array kShortPayload = {1, 2, 3, 4}; +static constexpr std::array kMessage2Payload = {5, 6, 7, 8}; +static constexpr std::array kLongPayload = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") { + if (arg.stream_id() != stream_id) { + *result_listener << "the stream_id is " << *arg.stream_id(); + return false; + } + + if (arg.ppid() != ppid) { + *result_listener << "the ppid is " << *arg.ppid(); + return false; + } + + if (std::vector(arg.payload().begin(), arg.payload().end()) != + std::vector(expected_payload.begin(), expected_payload.end())) { + *result_listener << "the payload is wrong"; + return false; + } + return true; +} + +class ReassemblyQueueTest : public testing::Test { + protected: + ReassemblyQueueTest() {} + DataGenerator gen_; +}; + +TEST_F(ReassemblyQueueTest, EmptyQueue) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + EXPECT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, SingleUnorderedChunkMessage) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, LargeUnorderedChunkAllPermutations) { + std::vector tsns = {10, 11, 12, 13}; + rtc::ArrayView payload(kLongPayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + + for (size_t i = 0; i < tsns.size(); i++) { + auto span = payload.subview((tsns[i] - 10) * 4, 4); + Data::IsBeginning is_beginning(tsns[i] == 10); + Data::IsEnd is_end(tsns[i] == 13); + + reasm.Add(TSN(tsns[i]), + Data(kStreamID, kSSN, kMID, kFSN, kPPID, + std::vector(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(false))); + if (i < 3) { + EXPECT_FALSE(reasm.HasMessages()); + } else { + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kLongPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } + } + } while (std::next_permutation(std::begin(tsns), std::end(tsns))); +} + +TEST_F(ReassemblyQueueTest, SingleOrderedChunkMessage) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); +} + +TEST_F(ReassemblyQueueTest, ManySmallOrderedMessages) { + std::vector tsns = {10, 11, 12, 13}; + rtc::ArrayView payload(kLongPayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + for (size_t i = 0; i < tsns.size(); i++) { + auto span = payload.subview((tsns[i] - 10) * 4, 4); + Data::IsBeginning is_beginning(true); + Data::IsEnd is_end(true); + + SSN ssn(static_cast(tsns[i] - 10)); + reasm.Add(TSN(tsns[i]), + Data(kStreamID, ssn, kMID, kFSN, kPPID, + std::vector(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(false))); + } + EXPECT_THAT( + reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, payload.subview(0, 4)), + SctpMessageIs(kStreamID, kPPID, 
payload.subview(4, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(8, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(12, 4)))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } while (std::next_permutation(std::begin(tsns), std::end(tsns))); +} + +TEST_F(ReassemblyQueueTest, RetransmissionInLargeOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4})); + reasm.Add(TSN(14), gen_.Ordered({5})); + reasm.Add(TSN(15), gen_.Ordered({6})); + reasm.Add(TSN(16), gen_.Ordered({7})); + reasm.Add(TSN(17), gen_.Ordered({8})); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + // lost and retransmitted + reasm.Add(TSN(11), gen_.Ordered({2})); + reasm.Add(TSN(18), gen_.Ordered({9})); + reasm.Add(TSN(19), gen_.Ordered({10})); + EXPECT_EQ(reasm.queued_bytes(), 10u); + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(20), gen_.Ordered({11, 12, 13, 14, 15, 16}, "E")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kLongPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveUnordered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1}, "B")); + reasm.Add(TSN(12), gen_.Unordered({3})); + reasm.Add(TSN(13), gen_.Unordered({4}, "E")); + + reasm.Add(TSN(14), gen_.Unordered({5}, "B")); + reasm.Add(TSN(15), gen_.Unordered({6})); + reasm.Add(TSN(17), gen_.Unordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 6u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk(TSN(13), {})); + EXPECT_EQ(reasm.queued_bytes(), 3u); + + // The lost chunk comes, but too late. + reasm.Add(TSN(11), gen_.Unordered({2})); + EXPECT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 3u); + + // The second lost chunk comes, message is assembled. + reasm.Add(TSN(16), gen_.Unordered({7})); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + + reasm.Add(TSN(14), gen_.Ordered({5}, "B")); + reasm.Add(TSN(15), gen_.Ordered({6})); + reasm.Add(TSN(16), gen_.Ordered({7})); + reasm.Add(TSN(17), gen_.Ordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk( + TSN(13), {ForwardTsnChunk::SkippedStream(kStreamID, kSSN)})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. 
+ EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveALotOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + + reasm.Add(TSN(15), gen_.Ordered({5}, "B")); + reasm.Add(TSN(16), gen_.Ordered({6})); + reasm.Add(TSN(17), gen_.Ordered({7})); + reasm.Add(TSN(18), gen_.Ordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk( + TSN(13), {ForwardTsnChunk::SkippedStream(kStreamID, kSSN)})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +TEST_F(ReassemblyQueueTest, ShouldntDeliverMessagesBeforeInitialTsn) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(5), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntRedeliverUnorderedMessages) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntRedeliverUnorderedMessagesReallyUnordered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "B")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + EXPECT_TRUE(reasm.HasMessages()); + + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntDeliverBeforeForwardedTsn) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Handle(ForwardTsnChunk(TSN(12), {})); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/rx/reassembly_streams.h b/net/dcsctp/rx/reassembly_streams.h new file mode 100644 index 0000000000..a8b42b5a2d --- /dev/null +++ b/net/dcsctp/rx/reassembly_streams.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+#ifndef NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_
+#define NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <functional>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/sequence_numbers.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
+#include "net/dcsctp/packet/data.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+
+namespace dcsctp {
+
+// Implementations of this interface will be called when data is received, when
+// data should be skipped/forgotten or when sequence numbers should be reset.
+//
+// As a result of these operations - mainly when data is received - the
+// implementations of this interface should notify when a message has been
+// assembled, by calling the provided callback of type `OnAssembledMessage`.
+// How messages are assembled depends on e.g. whether a message was sent on an
+// ordered or unordered stream.
+//
+// Implementations will - for each operation - indicate how much additional
+// memory has been used as a result of performing the operation. This is used
+// to limit the maximum amount of memory used, to prevent out-of-memory
+// situations.
+class ReassemblyStreams {
+ public:
+  // This callback will be provided as an argument to the constructor of the
+  // concrete class implementing this interface and should be called when a
+  // message has been assembled, indicating the TSNs from which this message
+  // was assembled.
+  using OnAssembledMessage =
+      std::function<void(rtc::ArrayView<const UnwrappedTSN> tsns,
+                         DcSctpMessage message)>;
+
+  virtual ~ReassemblyStreams() = default;
+
+  // Adds a data chunk to a stream as identified in `data`.
+  // If it was the last remaining chunk in a message, reassembles one (or
+  // several, in case of ordered chunks) messages.
+  //
+  // Returns the additional number of bytes added to the queue as a result of
+  // performing this operation. If this addition resulted in messages being
+  // assembled and delivered, this may be negative.
+  virtual int Add(UnwrappedTSN tsn, Data data) = 0;
+
+  // Called for incoming FORWARD-TSN/I-FORWARD-TSN chunks - when the sender
+  // wishes the receiver to skip/forget about data up until the provided TSN.
+  // This is used to implement partial reliability, such as limiting the number
+  // of retransmissions or enforcing an expiration duration. As a result of
+  // skipping data, the implementation may be able to assemble messages in
+  // ordered streams.
+  //
+  // Returns the number of bytes removed from the queue as a result of
+  // this operation.
+  virtual size_t HandleForwardTsn(
+      UnwrappedTSN new_cumulative_ack_tsn,
+      rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream>
+          skipped_streams) = 0;
+
+  // Called for incoming (possibly deferred) RE_CONFIG chunks asking for
+  // either a few streams, or all streams (when the list is empty) to be
+  // reset - to have their next SSN or Message ID set to zero.
+  virtual void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) = 0;
+};
+
+}  // namespace dcsctp
+
+#endif  // NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_
diff --git a/net/dcsctp/rx/traditional_reassembly_streams.cc b/net/dcsctp/rx/traditional_reassembly_streams.cc
new file mode 100644
index 0000000000..7cec1150d5
--- /dev/null
+++ b/net/dcsctp/rx/traditional_reassembly_streams.cc
@@ -0,0 +1,290 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree.
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/traditional_reassembly_streams.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { + +// Given a map (`chunks`) and an iterator to within that map (`iter`), this +// function will return an iterator to the first chunk in that message, which +// has the `is_beginning` flag set. If there are any gaps, or if the beginning +// can't be found, `absl::nullopt` is returned. +absl::optional::iterator> FindBeginning( + const std::map& chunks, + std::map::iterator iter) { + UnwrappedTSN prev_tsn = iter->first; + for (;;) { + if (iter->second.is_beginning) { + return iter; + } + if (iter == chunks.begin()) { + return absl::nullopt; + } + --iter; + if (iter->first.next_value() != prev_tsn) { + return absl::nullopt; + } + prev_tsn = iter->first; + } +} + +// Given a map (`chunks`) and an iterator to within that map (`iter`), this +// function will return an iterator to the chunk after the last chunk in that +// message, which has the `is_end` flag set. If there are any gaps, or if the +// end can't be found, `absl::nullopt` is returned. +absl::optional::iterator> FindEnd( + std::map& chunks, + std::map::iterator iter) { + UnwrappedTSN prev_tsn = iter->first; + for (;;) { + if (iter->second.is_end) { + return ++iter; + } + ++iter; + if (iter == chunks.end()) { + return absl::nullopt; + } + if (iter->first != prev_tsn.next_value()) { + return absl::nullopt; + } + prev_tsn = iter->first; + } +} +} // namespace + +int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn, + Data data) { + int queued_bytes = data.size(); + auto p = chunks_.emplace(tsn, std::move(data)); + if (!p.second /* !inserted */) { + return 0; + } + + queued_bytes -= TryToAssembleMessage(p.first); + + return queued_bytes; +} + +size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage( + ChunkMap::iterator iter) { + // TODO(boivie): This method is O(N) with the number of fragments in a + // message, which can be inefficient for very large values of N. This could be + // optimized by e.g. only trying to assemble a message once _any_ beginning + // and _any_ end has been found. 
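+  // For example, with chunks_ holding TSN 10 (is_beginning), 11 and
+  // 12 (is_end), and `iter` pointing at TSN 11, FindBeginning() walks back to
+  // TSN 10 and FindEnd() walks forward past TSN 12, so the message spanning
+  // TSNs 10..12 is assembled and those chunks are erased from the map.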
+ absl::optional start = FindBeginning(chunks_, iter); + if (!start.has_value()) { + return 0; + } + absl::optional end = FindEnd(chunks_, iter); + if (!end.has_value()) { + return 0; + } + + size_t bytes_assembled = AssembleMessage(*start, *end); + chunks_.erase(*start, *end); + return bytes_assembled; +} + +size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage( + const ChunkMap::iterator start, + const ChunkMap::iterator end) { + size_t count = std::distance(start, end); + + if (count == 1) { + // Fast path - zero-copy + const Data& data = start->second; + size_t payload_size = start->second.size(); + UnwrappedTSN tsns[1] = {start->first}; + DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; + } + + // Slow path - will need to concatenate the payload. + std::vector tsns; + std::vector payload; + + size_t payload_size = std::accumulate( + start, end, 0, + [](size_t v, const auto& p) { return v + p.second.size(); }); + + tsns.reserve(count); + payload.reserve(payload_size); + for (auto it = start; it != end; ++it) { + const Data& data = it->second; + tsns.push_back(it->first); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + DcSctpMessage message(start->second.stream_id, start->second.ppid, + std::move(payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + + return payload_size; +} + +size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo( + UnwrappedTSN tsn) { + auto end_iter = chunks_.upper_bound(tsn); + size_t removed_bytes = std::accumulate( + chunks_.begin(), end_iter, 0, + [](size_t r, const auto& p) { return r + p.second.size(); }); + + chunks_.erase(chunks_.begin(), end_iter); + return removed_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() { + if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) { + return 0; + } + + ChunkMap& chunks = chunks_by_ssn_.begin()->second; + + if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) { + return 0; + } + + uint32_t tsn_diff = + UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first); + if (tsn_diff != chunks.size() - 1) { + return 0; + } + + size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end()); + chunks_by_ssn_.erase(chunks_by_ssn_.begin()); + next_ssn_.Increment(); + return assembled_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() { + size_t assembled_bytes = 0; + + for (;;) { + size_t assembled_bytes_this_iter = TryToAssembleMessage(); + if (assembled_bytes_this_iter == 0) { + break; + } + assembled_bytes += assembled_bytes_this_iter; + } + return assembled_bytes; +} + +int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn, + Data data) { + int queued_bytes = data.size(); + + UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn); + auto p = chunks_by_ssn_[ssn].emplace(tsn, std::move(data)); + if (!p.second /* !inserted */) { + return 0; + } + + if (ssn == next_ssn_) { + queued_bytes -= TryToAssembleMessages(); + } + + return queued_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) { + UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn); + + auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn); + size_t removed_bytes = std::accumulate( + chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) { + return r1 + + absl::c_accumulate(p.second, 0, 
[](size_t r2, const auto& q) { + return r2 + q.second.size(); + }); + }); + chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter); + + if (unwrapped_ssn >= next_ssn_) { + unwrapped_ssn.Increment(); + next_ssn_ = unwrapped_ssn; + } + + removed_bytes += TryToAssembleMessages(); + return removed_bytes; +} + +int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { + if (data.is_unordered) { + auto it = unordered_streams_.emplace(data.stream_id, this).first; + return it->second.Add(tsn, std::move(data)); + } + + auto it = ordered_streams_.emplace(data.stream_id, this).first; + return it->second.Add(tsn, std::move(data)); +} + +size_t TraditionalReassemblyStreams::HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView skipped_streams) { + size_t bytes_removed = 0; + // The `skipped_streams` only over ordered messages - need to + // iterate all unordered streams manually to remove those chunks. + for (auto& entry : unordered_streams_) { + bytes_removed += entry.second.EraseTo(new_cumulative_ack_tsn); + } + + for (const auto& skipped_stream : skipped_streams) { + auto it = ordered_streams_.find(skipped_stream.stream_id); + if (it != ordered_streams_.end()) { + bytes_removed += it->second.EraseTo(skipped_stream.ssn); + } + } + + return bytes_removed; +} + +void TraditionalReassemblyStreams::ResetStreams( + rtc::ArrayView stream_ids) { + if (stream_ids.empty()) { + for (auto& entry : ordered_streams_) { + const StreamID& stream_id = entry.first; + OrderedStream& stream = entry.second; + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Resetting implicit stream_id=" << *stream_id; + stream.Reset(); + } + } else { + for (StreamID stream_id : stream_ids) { + auto it = ordered_streams_.find(stream_id); + if (it != ordered_streams_.end()) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Resetting explicit stream_id=" << *stream_id; + it->second.Reset(); + } + } + } +} +} // namespace dcsctp diff --git a/net/dcsctp/rx/traditional_reassembly_streams.h b/net/dcsctp/rx/traditional_reassembly_streams.h new file mode 100644 index 0000000000..12d1d933a4 --- /dev/null +++ b/net/dcsctp/rx/traditional_reassembly_streams.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ +#include +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Handles reassembly of incoming data when interleaved message sending +// is not enabled on the association, i.e. when RFC8260 is not in use and +// RFC4960 is to be followed. 
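+//
+// A minimal usage sketch (the callback body and the `tsn` and `data` variables
+// are illustrative, not part of this class):
+//
+//   TraditionalReassemblyStreams streams(
+//       "log: ", [](rtc::ArrayView<const UnwrappedTSN> tsns,
+//                   DcSctpMessage message) {
+//         // Deliver `message`; `tsns` lists the TSNs it was assembled from.
+//       });
+//   int delta = streams.Add(tsn, std::move(data));
+//   // `delta` is the change in buffered bytes; it may be negative when the
+//   // added chunk completed a message that was then delivered.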
+class TraditionalReassemblyStreams : public ReassemblyStreams { + public: + TraditionalReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message) + : log_prefix_(log_prefix), on_assembled_message_(on_assembled_message) {} + + int Add(UnwrappedTSN tsn, Data data) override; + + size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView skipped_streams) + override; + + void ResetStreams(rtc::ArrayView stream_ids) override; + + private: + using ChunkMap = std::map; + + // Base class for `UnorderedStream` and `OrderedStream`. + class StreamBase { + protected: + explicit StreamBase(TraditionalReassemblyStreams* parent) + : parent_(*parent) {} + + size_t AssembleMessage(const ChunkMap::iterator start, + const ChunkMap::iterator end); + TraditionalReassemblyStreams& parent_; + }; + + // Manages all received data for a specific unordered stream, and assembles + // messages when possible. + class UnorderedStream : StreamBase { + public: + explicit UnorderedStream(TraditionalReassemblyStreams* parent) + : StreamBase(parent) {} + int Add(UnwrappedTSN tsn, Data data); + // Returns the number of bytes removed from the queue. + size_t EraseTo(UnwrappedTSN tsn); + + private: + // Given an iterator to any chunk within the map, try to assemble a message + // into `reassembled_messages` containing it and - if successful - erase + // those chunks from the stream chunks map. + // + // Returns the number of bytes that were assembled. + size_t TryToAssembleMessage(ChunkMap::iterator iter); + + ChunkMap chunks_; + }; + + // Manages all received data for a specific ordered stream, and assembles + // messages when possible. + class OrderedStream : StreamBase { + public: + explicit OrderedStream(TraditionalReassemblyStreams* parent) + : StreamBase(parent), next_ssn_(ssn_unwrapper_.Unwrap(SSN(0))) {} + int Add(UnwrappedTSN tsn, Data data); + size_t EraseTo(SSN ssn); + void Reset() { + ssn_unwrapper_.Reset(); + next_ssn_ = ssn_unwrapper_.Unwrap(SSN(0)); + } + + private: + // Try to assemble one or several messages in order from the stream. + // Returns the number of bytes assembled if a message was assembled. + size_t TryToAssembleMessage(); + size_t TryToAssembleMessages(); + // This must be an ordered container to be able to iterate in SSN order. + std::map chunks_by_ssn_; + UnwrappedSSN::Unwrapper ssn_unwrapper_; + UnwrappedSSN next_ssn_; + }; + + const std::string log_prefix_; + + // Callback for when a message has been assembled. + const OnAssembledMessage on_assembled_message_; + + // All unordered and ordered streams, managing not-yet-assembled data. + std::unordered_map + unordered_streams_; + std::unordered_map + ordered_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ diff --git a/net/dcsctp/rx/traditional_reassembly_streams_test.cc b/net/dcsctp/rx/traditional_reassembly_streams_test.cc new file mode 100644 index 0000000000..30d29a05dc --- /dev/null +++ b/net/dcsctp/rx/traditional_reassembly_streams_test.cc @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/rx/traditional_reassembly_streams.h" + +#include +#include +#include + +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using ::testing::NiceMock; + +class TraditionalReassemblyStreamsTest : public testing::Test { + protected: + UnwrappedTSN tsn(uint32_t value) { return tsn_.Unwrap(TSN(value)); } + + TraditionalReassemblyStreamsTest() {} + DataGenerator gen_; + UnwrappedTSN::Unwrapper tsn_; +}; + +TEST_F(TraditionalReassemblyStreamsTest, + AddUnorderedMessageReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + // Adding the end fragment should make it empty again. + EXPECT_EQ(streams.Add(tsn(4), gen_.Unordered({7}, "E")), -6); +} + +TEST_F(TraditionalReassemblyStreamsTest, + AddSimpleOrderedMessageReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), -6); +} + +TEST_F(TraditionalReassemblyStreamsTest, + AddMoreComplexOrderedMessageReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + Data late = gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + EXPECT_EQ(streams.Add(tsn(2), std::move(late)), -8); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteUnorderedMessageReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), {}), 6u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteSimpleOrderedMessageReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteManyOrderedMessagesReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams 
streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // Expire all three messages + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(2))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(8), skipped), 8u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteOrderedMessageDelivesTwoReturnsCorrectSize) { + NiceMock> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // The first ordered message expire, and the following two are delivered. + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(4), skipped), 8u); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/socket/BUILD.gn b/net/dcsctp/socket/BUILD.gn new file mode 100644 index 0000000000..72ac139acb --- /dev/null +++ b/net/dcsctp/socket/BUILD.gn @@ -0,0 +1,236 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. 
+ +import("../../../webrtc.gni") + +rtc_source_set("context") { + sources = [ "context.h" ] + deps = [ + "../common:internal_types", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("heartbeat_handler") { + deps = [ + ":context", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../packet:bounded_io", + "../packet:chunk", + "../packet:parameter", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "heartbeat_handler.cc", + "heartbeat_handler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("stream_reset_handler") { + deps = [ + ":context", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../common:str_join", + "../packet:chunk", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_queue", + ] + sources = [ + "stream_reset_handler.cc", + "stream_reset_handler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("transmission_control_block") { + deps = [ + ":context", + ":heartbeat_handler", + ":stream_reset_handler", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_error_counter", + "../tx:retransmission_queue", + "../tx:retransmission_timeout", + "../tx:send_queue", + ] + sources = [ + "capabilities.h", + "transmission_control_block.cc", + "transmission_control_block.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("dcsctp_socket") { + deps = [ + ":context", + ":heartbeat_handler", + ":stream_reset_handler", + ":transmission_control_block", + "../../../api:array_view", + "../../../api:refcountedbase", + "../../../api:scoped_refptr", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../packet:bounded_io", + "../packet:chunk", + "../packet:chunk_validators", + "../packet:data", + "../packet:error_cause", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_error_counter", + "../tx:retransmission_queue", + "../tx:retransmission_timeout", + "../tx:rr_send_queue", + "../tx:send_queue", + ] + sources = [ + "callback_deferrer.h", + "dcsctp_socket.cc", + "dcsctp_socket.h", + "state_cookie.cc", + "state_cookie.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_source_set("mock_callbacks") { + testonly = true + sources = [ "mock_dcsctp_socket_callbacks.h" ] + deps = [ + "../../../api:array_view", + "../../../rtc_base:logging", + 
"../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../public:socket", + "../public:types", + "../timer", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_source_set("mock_context") { + testonly = true + sources = [ "mock_context.h" ] + deps = [ + ":context", + ":mock_callbacks", + "../../../test:test_support", + "../common:internal_types", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_library("dcsctp_socket_unittests") { + testonly = true + + deps = [ + ":dcsctp_socket", + ":heartbeat_handler", + ":mock_callbacks", + ":mock_context", + ":stream_reset_handler", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../common:internal_types", + "../packet:chunk", + "../packet:error_cause", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../public:utils", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../testing:data_generator", + "../testing:testing_macros", + "../timer", + "../tx:mock_send_queue", + "../tx:retransmission_queue", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + sources = [ + "dcsctp_socket_test.cc", + "heartbeat_handler_test.cc", + "state_cookie_test.cc", + "stream_reset_handler_test.cc", + ] + } +} diff --git a/net/dcsctp/socket/callback_deferrer.h b/net/dcsctp/socket/callback_deferrer.h new file mode 100644 index 0000000000..197cf434af --- /dev/null +++ b/net/dcsctp/socket/callback_deferrer.h @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ +#define NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "api/ref_counted_base.h" +#include "api/scoped_refptr.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "rtc_base/ref_counted_object.h" + +namespace dcsctp { + +// Defers callbacks until they can be safely triggered. +// +// There are a lot of callbacks from the dcSCTP library to the client, +// such as when messages are received or streams are closed. When the client +// receives these callbacks, the client is expected to be able to call into the +// library - from within the callback. For example, sending a reply message when +// a certain SCTP message has been received, or to reconnect when the connection +// was closed for any reason. 
This means that the dcSCTP library must always be +// in a consistent and stable state when these callbacks are delivered, and to +// ensure that's the case, callbacks are not immediately delivered from where +// they originate, but instead queued (deferred) by this class. At the end of +// any public API method that may result in callbacks, they are triggered and +// then delivered. +// +// There are a number of exceptions, which is clearly annotated in the API. +class CallbackDeferrer : public DcSctpSocketCallbacks { + public: + explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying) + : underlying_(underlying) {} + + void TriggerDeferred() { + // Need to swap here. The client may call into the library from within a + // callback, and that might result in adding new callbacks to this instance, + // and the vector can't be modified while iterated on. + std::vector> deferred; + deferred.swap(deferred_); + + for (auto& cb : deferred) { + cb(underlying_); + } + } + + void SendPacket(rtc::ArrayView data) override { + // Will not be deferred - call directly. + underlying_.SendPacket(data); + } + + std::unique_ptr CreateTimeout() override { + // Will not be deferred - call directly. + return underlying_.CreateTimeout(); + } + + TimeMs TimeMillis() override { + // Will not be deferred - call directly. + return underlying_.TimeMillis(); + } + + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + // Will not be deferred - call directly. + return underlying_.GetRandomInt(low, high); + } + + void OnMessageReceived(DcSctpMessage message) override { + deferred_.emplace_back( + [deliverer = MessageDeliverer(std::move(message))]( + DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); + } + + void OnError(ErrorKind error, absl::string_view message) override { + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnError(error, message); + }); + } + + void OnAborted(ErrorKind error, absl::string_view message) override { + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnAborted(error, message); + }); + } + + void OnConnected() override { + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); + } + + void OnClosed() override { + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); + } + + void OnConnectionRestarted() override { + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); + } + + void OnStreamsResetFailed(rtc::ArrayView outgoing_streams, + absl::string_view reason) override { + deferred_.emplace_back( + [streams = std::vector(outgoing_streams.begin(), + outgoing_streams.end()), + reason = std::string(reason)](DcSctpSocketCallbacks& cb) { + cb.OnStreamsResetFailed(streams, reason); + }); + } + + void OnStreamsResetPerformed( + rtc::ArrayView outgoing_streams) override { + deferred_.emplace_back( + [streams = std::vector(outgoing_streams.begin(), + outgoing_streams.end())]( + DcSctpSocketCallbacks& cb) { + cb.OnStreamsResetPerformed(streams); + }); + } + + void OnIncomingStreamsReset( + rtc::ArrayView incoming_streams) override { + deferred_.emplace_back( + [streams = std::vector(incoming_streams.begin(), + incoming_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); }); + } + + void OnBufferedAmountLow(StreamID stream_id) override { + deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { + cb.OnBufferedAmountLow(stream_id); + }); + } + + void 
OnTotalBufferedAmountLow() override { + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); + } + + private: + // A wrapper around the move-only DcSctpMessage, to let it be captured in a + // lambda. + class MessageDeliverer { + public: + explicit MessageDeliverer(DcSctpMessage&& message) + : state_(rtc::make_ref_counted(std::move(message))) {} + + void Deliver(DcSctpSocketCallbacks& c) { + // Really ensure that it's only called once. + RTC_DCHECK(!state_->has_delivered); + state_->has_delivered = true; + c.OnMessageReceived(std::move(state_->message)); + } + + private: + struct State : public rtc::RefCountInterface { + explicit State(DcSctpMessage&& m) + : has_delivered(false), message(std::move(m)) {} + bool has_delivered; + DcSctpMessage message; + }; + rtc::scoped_refptr state_; + }; + + DcSctpSocketCallbacks& underlying_; + std::vector> deferred_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ diff --git a/net/dcsctp/socket/capabilities.h b/net/dcsctp/socket/capabilities.h new file mode 100644 index 0000000000..c6d3692b2d --- /dev/null +++ b/net/dcsctp/socket/capabilities.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CAPABILITIES_H_ +#define NET_DCSCTP_SOCKET_CAPABILITIES_H_ + +namespace dcsctp { +// Indicates what the association supports, meaning that both parties +// support it and that feature can be used. +struct Capabilities { + // RFC3758 Partial Reliability Extension + bool partial_reliability = false; + // RFC8260 Stream Schedulers and User Message Interleaving + bool message_interleaving = false; + // RFC6525 Stream Reconfiguration + bool reconfig = false; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CAPABILITIES_H_ diff --git a/net/dcsctp/socket/context.h b/net/dcsctp/socket/context.h new file mode 100644 index 0000000000..eca5b9e4fb --- /dev/null +++ b/net/dcsctp/socket/context.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CONTEXT_H_ +#define NET_DCSCTP_SOCKET_CONTEXT_H_ + +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A set of helper methods used by handlers to e.g. send packets. +// +// Implemented by the TransmissionControlBlock. +class Context { + public: + virtual ~Context() = default; + + // Indicates if a connection has been established. + virtual bool is_connection_established() const = 0; + + // Returns this side's initial TSN value. + virtual TSN my_initial_tsn() const = 0; + + // Returns the peer's initial TSN value. 
+ virtual TSN peer_initial_tsn() const = 0; + + // Returns the socket callbacks. + virtual DcSctpSocketCallbacks& callbacks() const = 0; + + // Observes a measured RTT value, in milliseconds. + virtual void ObserveRTT(DurationMs rtt_ms) = 0; + + // Returns the current Retransmission Timeout (rto) value, in milliseconds. + virtual DurationMs current_rto() const = 0; + + // Increments the transmission error counter, given a human readable reason. + virtual bool IncrementTxErrorCounter(absl::string_view reason) = 0; + + // Clears the transmission error counter. + virtual void ClearTxErrorCounter() = 0; + + // Returns true if there have been too many retransmission errors. + virtual bool HasTooManyTxErrors() const = 0; + + // Returns a PacketBuilder, filled in with the correct verification tag. + virtual SctpPacket::Builder PacketBuilder() const = 0; + + // Builds the packet from `builder` and sends it. + virtual void Send(SctpPacket::Builder& builder) = 0; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CONTEXT_H_ diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc new file mode 100644 index 0000000000..71bc98c70d --- /dev/null +++ b/net/dcsctp/socket/dcsctp_socket.cc @@ -0,0 +1,1550 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/dcsctp_socket.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/chunk_validators.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include 
"net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/callback_deferrer.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/heartbeat_handler.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/socket/transmission_control_block.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { +namespace { + +// https://tools.ietf.org/html/rfc4960#section-5.1 +constexpr uint32_t kMinVerificationTag = 1; +constexpr uint32_t kMaxVerificationTag = std::numeric_limits::max(); + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 +constexpr uint32_t kMinInitialTsn = 0; +constexpr uint32_t kMaxInitialTsn = std::numeric_limits::max(); + +Capabilities GetCapabilities(const DcSctpOptions& options, + const Parameters& parameters) { + Capabilities capabilities; + absl::optional supported_extensions = + parameters.get(); + + if (options.enable_partial_reliability) { + capabilities.partial_reliability = + parameters.get().has_value(); + if (supported_extensions.has_value()) { + capabilities.partial_reliability |= + supported_extensions->supports(ForwardTsnChunk::kType); + } + } + + if (options.enable_message_interleaving && supported_extensions.has_value()) { + capabilities.message_interleaving = + supported_extensions->supports(IDataChunk::kType) && + supported_extensions->supports(IForwardTsnChunk::kType); + } + if (supported_extensions.has_value() && + supported_extensions->supports(ReConfigChunk::kType)) { + capabilities.reconfig = true; + } + return capabilities; +} + +void AddCapabilityParameters(const DcSctpOptions& options, + Parameters::Builder& builder) { + std::vector chunk_types = {ReConfigChunk::kType}; + + if (options.enable_partial_reliability) { + builder.Add(ForwardTsnSupportedParameter()); + chunk_types.push_back(ForwardTsnChunk::kType); + } + if (options.enable_message_interleaving) { + chunk_types.push_back(IDataChunk::kType); + chunk_types.push_back(IForwardTsnChunk::kType); + } + builder.Add(SupportedExtensionsParameter(std::move(chunk_types))); +} + +TieTag MakeTieTag(DcSctpSocketCallbacks& cb) { + uint32_t tie_tag_upper = + cb.GetRandomInt(0, std::numeric_limits::max()); + uint32_t tie_tag_lower = + cb.GetRandomInt(1, std::numeric_limits::max()); + return TieTag(static_cast(tie_tag_upper) << 32 | + static_cast(tie_tag_lower)); +} + +} // namespace + +DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr packet_observer, + const DcSctpOptions& options) + : log_prefix_(std::string(log_prefix) + ": "), + packet_observer_(std::move(packet_observer)), + 
options_(options), + callbacks_(callbacks), + timer_manager_([this]() { return callbacks_.CreateTimeout(); }), + t1_init_(timer_manager_.CreateTimer( + "t1-init", + [this]() { return OnInitTimerExpiry(); }, + TimerOptions(options.t1_init_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_init_retransmits))), + t1_cookie_(timer_manager_.CreateTimer( + "t1-cookie", + [this]() { return OnCookieTimerExpiry(); }, + TimerOptions(options.t1_cookie_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_init_retransmits))), + t2_shutdown_(timer_manager_.CreateTimer( + "t2-shutdown", + [this]() { return OnShutdownTimerExpiry(); }, + TimerOptions(options.t2_shutdown_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_retransmissions))), + send_queue_( + log_prefix_, + options_.max_send_buffer_size, + [this](StreamID stream_id) { + callbacks_.OnBufferedAmountLow(stream_id); + }, + options_.total_buffered_amount_low_threshold, + [this]() { callbacks_.OnTotalBufferedAmountLow(); }) {} + +std::string DcSctpSocket::log_prefix() const { + return log_prefix_ + "[" + std::string(ToString(state_)) + "] "; +} + +bool DcSctpSocket::IsConsistent() const { + switch (state_) { + case State::kClosed: + return (tcb_ == nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kCookieWait: + return (tcb_ == nullptr && t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kCookieEchoed: + return (tcb_ != nullptr && !t1_init_->is_running() && + t1_cookie_->is_running() && !t2_shutdown_->is_running() && + tcb_->has_cookie_echo_chunk()); + case State::kEstablished: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownPending: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownSent: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && t2_shutdown_->is_running()); + case State::kShutdownReceived: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownAckSent: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && t2_shutdown_->is_running()); + } +} + +constexpr absl::string_view DcSctpSocket::ToString(DcSctpSocket::State state) { + switch (state) { + case DcSctpSocket::State::kClosed: + return "CLOSED"; + case DcSctpSocket::State::kCookieWait: + return "COOKIE_WAIT"; + case DcSctpSocket::State::kCookieEchoed: + return "COOKIE_ECHOED"; + case DcSctpSocket::State::kEstablished: + return "ESTABLISHED"; + case DcSctpSocket::State::kShutdownPending: + return "SHUTDOWN_PENDING"; + case DcSctpSocket::State::kShutdownSent: + return "SHUTDOWN_SENT"; + case DcSctpSocket::State::kShutdownReceived: + return "SHUTDOWN_RECEIVED"; + case DcSctpSocket::State::kShutdownAckSent: + return "SHUTDOWN_ACK_SENT"; + } +} + +void DcSctpSocket::SetState(State state, absl::string_view reason) { + if (state_ != state) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Socket state changed from " + << ToString(state_) << " to " << ToString(state) + << " due to " << reason; + state_ = state; + } +} + +void DcSctpSocket::SendInit() { + Parameters::Builder params_builder; + AddCapabilityParameters(options_, params_builder); + InitChunk init(/*initiate_tag=*/connect_params_.verification_tag, + 
/*a_rwnd=*/options_.max_receiver_window_buffer_size, + options_.announced_maximum_outgoing_streams, + options_.announced_maximum_incoming_streams, + connect_params_.initial_tsn, params_builder.Build()); + SctpPacket::Builder b(VerificationTag(0), options_); + b.Add(init); + SendPacket(b); +} + +void DcSctpSocket::MakeConnectionParameters() { + VerificationTag new_verification_tag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + TSN initial_tsn(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn)); + connect_params_.initial_tsn = initial_tsn; + connect_params_.verification_tag = new_verification_tag; +} + +void DcSctpSocket::Connect() { + if (state_ == State::kClosed) { + MakeConnectionParameters(); + RTC_DLOG(LS_INFO) + << log_prefix() + << rtc::StringFormat( + "Connecting. my_verification_tag=%08x, my_initial_tsn=%u", + *connect_params_.verification_tag, *connect_params_.initial_tsn); + SendInit(); + t1_init_->Start(); + SetState(State::kCookieWait, "Connect called"); + } else { + RTC_DLOG(LS_WARNING) << log_prefix() + << "Called Connect on a socket that is not closed"; + } + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); +} + +void DcSctpSocket::Shutdown() { + if (tcb_ != nullptr) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon receipt of the SHUTDOWN primitive from its upper layer, the + // endpoint enters the SHUTDOWN-PENDING state and remains there until all + // outstanding data has been acknowledged by its peer." + + // TODO(webrtc:12739): Remove this check, as it just hides the problem that + // the socket can transition from ShutdownSent to ShutdownPending, or + // ShutdownAckSent to ShutdownPending which is illegal. + if (state_ != State::kShutdownSent && state_ != State::kShutdownAckSent) { + SetState(State::kShutdownPending, "Shutdown called"); + t1_init_->Stop(); + t1_cookie_->Stop(); + MaybeSendShutdownOrAck(); + } + } else { + // Connection closed before even starting to connect, or during the initial + // connection phase. There is no outstanding data, so the socket can just + // be closed (stopping any connection timers, if any), as this is the + // client's intention, by calling Shutdown. 
+ InternalClose(ErrorKind::kNoError, ""); + } + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); +} + +void DcSctpSocket::Close() { + if (state_ != State::kClosed) { + if (tcb_ != nullptr) { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(AbortChunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(UserInitiatedAbortCause("Close called")) + .Build())); + SendPacket(b); + } + InternalClose(ErrorKind::kNoError, ""); + } else { + RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket"; + } + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); +} + +void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() { + SendPacket(tcb_->PacketBuilder().Add(AbortChunk( + true, Parameters::Builder() + .Add(UserInitiatedAbortCause("Too many retransmissions")) + .Build()))); + InternalClose(ErrorKind::kTooManyRetries, "Too many retransmissions"); +} + +void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { + if (state_ != State::kClosed) { + t1_init_->Stop(); + t1_cookie_->Stop(); + t2_shutdown_->Stop(); + tcb_ = nullptr; + + if (error == ErrorKind::kNoError) { + callbacks_.OnClosed(); + } else { + callbacks_.OnAborted(error, message); + } + SetState(State::kClosed, message); + } + // This method's purpose is to abort/close and make it consistent by ensuring + // that e.g. all timers really are stopped. + RTC_DCHECK(IsConsistent()); +} + +SendStatus DcSctpSocket::Send(DcSctpMessage message, + const SendOptions& send_options) { + if (message.payload().empty()) { + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Unable to send empty message"); + return SendStatus::kErrorMessageEmpty; + } + if (message.payload().size() > options_.max_message_size) { + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Unable to send too large message"); + return SendStatus::kErrorMessageTooLarge; + } + if (state_ == State::kShutdownPending || state_ == State::kShutdownSent || + state_ == State::kShutdownReceived || state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "An endpoint should reject any new data request from its upper layer + // if it is in the SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, or + // SHUTDOWN-ACK-SENT state." 
+ callbacks_.OnError(ErrorKind::kWrongSequence, + "Unable to send message as the socket is shutting down"); + return SendStatus::kErrorShuttingDown; + } + if (send_queue_.IsFull()) { + callbacks_.OnError(ErrorKind::kResourceExhaustion, + "Unable to send message as the send queue is full"); + return SendStatus::kErrorResourceExhaustion; + } + + TimeMs now = callbacks_.TimeMillis(); + send_queue_.Add(now, std::move(message), send_options); + if (tcb_ != nullptr) { + tcb_->SendBufferedPackets(now); + } + + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); + return SendStatus::kSuccess; +} + +ResetStreamsStatus DcSctpSocket::ResetStreams( + rtc::ArrayView outgoing_streams) { + if (tcb_ == nullptr) { + callbacks_.OnError(ErrorKind::kWrongSequence, + "Can't reset streams as the socket is not connected"); + return ResetStreamsStatus::kNotConnected; + } + if (!tcb_->capabilities().reconfig) { + callbacks_.OnError(ErrorKind::kUnsupportedOperation, + "Can't reset streams as the peer doesn't support it"); + return ResetStreamsStatus::kNotSupported; + } + + tcb_->stream_reset_handler().ResetStreams(outgoing_streams); + absl::optional reconfig = + tcb_->stream_reset_handler().MakeStreamResetRequest(); + if (reconfig.has_value()) { + SctpPacket::Builder builder = tcb_->PacketBuilder(); + builder.Add(*reconfig); + SendPacket(builder); + } + + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); + return ResetStreamsStatus::kPerformed; +} + +SocketState DcSctpSocket::state() const { + switch (state_) { + case State::kClosed: + return SocketState::kClosed; + case State::kCookieWait: + ABSL_FALLTHROUGH_INTENDED; + case State::kCookieEchoed: + return SocketState::kConnecting; + case State::kEstablished: + return SocketState::kConnected; + case State::kShutdownPending: + ABSL_FALLTHROUGH_INTENDED; + case State::kShutdownSent: + ABSL_FALLTHROUGH_INTENDED; + case State::kShutdownReceived: + ABSL_FALLTHROUGH_INTENDED; + case State::kShutdownAckSent: + return SocketState::kShuttingDown; + } +} + +void DcSctpSocket::SetMaxMessageSize(size_t max_message_size) { + options_.max_message_size = max_message_size; +} + +size_t DcSctpSocket::buffered_amount(StreamID stream_id) const { + return send_queue_.buffered_amount(stream_id); +} + +size_t DcSctpSocket::buffered_amount_low_threshold(StreamID stream_id) const { + return send_queue_.buffered_amount_low_threshold(stream_id); +} + +void DcSctpSocket::SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) { + send_queue_.SetBufferedAmountLowThreshold(stream_id, bytes); +} + +void DcSctpSocket::MaybeSendShutdownOnPacketReceived(const SctpPacket& packet) { + if (state_ == State::kShutdownSent) { + bool has_data_chunk = + std::find_if(packet.descriptors().begin(), packet.descriptors().end(), + [](const SctpPacket::ChunkDescriptor& descriptor) { + return descriptor.type == DataChunk::kType; + }) != packet.descriptors().end(); + if (has_data_chunk) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "While in the SHUTDOWN-SENT state, the SHUTDOWN sender MUST immediately + // respond to each received packet containing one or more DATA chunks with + // a SHUTDOWN chunk and restart the T2-shutdown timer."" + SendShutdown(); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); + } + } +} + +bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { + const CommonHeader& header = packet.common_header(); + VerificationTag my_verification_tag = + tcb_ != nullptr ? 
tcb_->my_verification_tag() : VerificationTag(0); + + if (header.verification_tag == VerificationTag(0)) { + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == InitChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "When an endpoint receives an SCTP packet with the Verification Tag + // set to 0, it should verify that the packet contains only an INIT chunk. + // Otherwise, the receiver MUST silently discard the packet."" + return true; + } + callbacks_.OnError( + ErrorKind::kParseFailed, + "Only a single INIT chunk can be present in packets sent on " + "verification_tag = 0"); + return false; + } + + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == AbortChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "The receiver of an ABORT MUST accept the packet if the Verification + // Tag field of the packet matches its own tag and the T bit is not set OR + // if it is set to its peer's tag and the T bit is set in the Chunk Flags. + // Otherwise, the receiver MUST silently discard the packet and take no + // further action." + bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0; + if (t_bit && tcb_ == nullptr) { + // Can't verify the tag - assume it's okey. + return true; + } + if ((!t_bit && header.verification_tag == my_verification_tag) || + (t_bit && header.verification_tag == tcb_->peer_verification_tag())) { + return true; + } + callbacks_.OnError(ErrorKind::kParseFailed, + "ABORT chunk verification tag was wrong"); + return false; + } + + if (packet.descriptors()[0].type == InitAckChunk::kType) { + if (header.verification_tag == connect_params_.verification_tag) { + return true; + } + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Packet has invalid verification tag: %08x, expected %08x", + *header.verification_tag, *connect_params_.verification_tag)); + return false; + } + + if (packet.descriptors()[0].type == CookieEchoChunk::kType) { + // Handled in chunk handler (due to RFC 4960, section 5.2.4). + return true; + } + + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == ShutdownCompleteChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "The receiver of a SHUTDOWN COMPLETE shall accept the packet if the + // Verification Tag field of the packet matches its own tag and the T bit is + // not set OR if it is set to its peer's tag and the T bit is set in the + // Chunk Flags. Otherwise, the receiver MUST silently discard the packet + // and take no further action." + bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0; + if (t_bit && tcb_ == nullptr) { + // Can't verify the tag - assume it's okey. + return true; + } + if ((!t_bit && header.verification_tag == my_verification_tag) || + (t_bit && header.verification_tag == tcb_->peer_verification_tag())) { + return true; + } + callbacks_.OnError(ErrorKind::kParseFailed, + "SHUTDOWN_COMPLETE chunk verification tag was wrong"); + return false; + } + + // https://tools.ietf.org/html/rfc4960#section-8.5 + // "When receiving an SCTP packet, the endpoint MUST ensure that the value + // in the Verification Tag field of the received SCTP packet matches its own + // tag. If the received Verification Tag value does not match the receiver's + // own tag value, the receiver shall silently discard the packet and shall not + // process it any further..." 
+ if (header.verification_tag == my_verification_tag) { + return true; + } + + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Packet has invalid verification tag: %08x, expected %08x", + *header.verification_tag, *my_verification_tag)); + return false; +} + +void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { + timer_manager_.HandleTimeout(timeout_id); + + if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) { + // Tearing down the TCB has to be done outside the handlers. + CloseConnectionBecauseOfTooManyTransmissionErrors(); + } + + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); +} + +void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { + if (packet_observer_ != nullptr) { + packet_observer_->OnReceivedPacket(callbacks_.TimeMillis(), data); + } + + absl::optional packet = + SctpPacket::Parse(data, options_.disable_checksum_verification); + if (!packet.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-6.8 + // "The default procedure for handling invalid SCTP packets is to + // silently discard them." + callbacks_.OnError(ErrorKind::kParseFailed, + "Failed to parse received SCTP packet"); + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); + return; + } + + if (RTC_DLOG_IS_ON) { + for (const auto& descriptor : packet->descriptors()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received " + << DebugConvertChunkToString(descriptor.data); + } + } + + if (!ValidatePacket(*packet)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Packet failed verification tag check - dropping"; + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); + return; + } + + MaybeSendShutdownOnPacketReceived(*packet); + + for (const auto& descriptor : packet->descriptors()) { + if (!Dispatch(packet->common_header(), descriptor)) { + break; + } + } + + if (tcb_ != nullptr) { + tcb_->data_tracker().ObservePacketEnd(); + tcb_->MaybeSendSack(); + } + + RTC_DCHECK(IsConsistent()); + callbacks_.TriggerDeferred(); +} + +void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView payload) { + auto packet = SctpPacket::Parse(payload); + RTC_DCHECK(packet.has_value()); + + for (const auto& desc : packet->descriptors()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Sent " + << DebugConvertChunkToString(desc.data); + } +} + +bool DcSctpSocket::Dispatch(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + switch (descriptor.type) { + case DataChunk::kType: + HandleData(header, descriptor); + break; + case InitChunk::kType: + HandleInit(header, descriptor); + break; + case InitAckChunk::kType: + HandleInitAck(header, descriptor); + break; + case SackChunk::kType: + HandleSack(header, descriptor); + break; + case HeartbeatRequestChunk::kType: + HandleHeartbeatRequest(header, descriptor); + break; + case HeartbeatAckChunk::kType: + HandleHeartbeatAck(header, descriptor); + break; + case AbortChunk::kType: + HandleAbort(header, descriptor); + break; + case ErrorChunk::kType: + HandleError(header, descriptor); + break; + case CookieEchoChunk::kType: + HandleCookieEcho(header, descriptor); + break; + case CookieAckChunk::kType: + HandleCookieAck(header, descriptor); + break; + case ShutdownChunk::kType: + HandleShutdown(header, descriptor); + break; + case ShutdownAckChunk::kType: + HandleShutdownAck(header, descriptor); + break; + case ShutdownCompleteChunk::kType: + HandleShutdownComplete(header, descriptor); + break; + case ReConfigChunk::kType: + HandleReconfig(header, descriptor); + break; + case ForwardTsnChunk::kType: + 
HandleForwardTsn(header, descriptor); + break; + case IDataChunk::kType: + HandleIData(header, descriptor); + break; + case IForwardTsnChunk::kType: + HandleForwardTsn(header, descriptor); + break; + default: + return HandleUnrecognizedChunk(descriptor); + } + return true; +} + +bool DcSctpSocket::HandleUnrecognizedChunk( + const SctpPacket::ChunkDescriptor& descriptor) { + bool report_as_error = (descriptor.type & 0x40) != 0; + bool continue_processing = (descriptor.type & 0x80) != 0; + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received unknown chunk: " + << static_cast(descriptor.type); + if (report_as_error) { + rtc::StringBuilder sb; + sb << "Received unknown chunk of type: " + << static_cast(descriptor.type) << " with report-error bit set"; + callbacks_.OnError(ErrorKind::kParseFailed, sb.str()); + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Unknown chunk, with type indicating it should be reported."; + + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "... report in an ERROR chunk using the 'Unrecognized Chunk Type' + // cause." + if (tcb_ != nullptr) { + // Need TCB - this chunk must be sent with a correct verification tag. + SendPacket(tcb_->PacketBuilder().Add( + ErrorChunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause(std::vector( + descriptor.data.begin(), descriptor.data.end()))) + .Build()))); + } + } + if (!continue_processing) { + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "Stop processing this SCTP packet and discard it, do not process any + // further chunks within it." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Unknown chunk, with type indicating not to " + "process any further chunks"; + } + + return continue_processing; +} + +absl::optional DcSctpSocket::OnInitTimerExpiry() { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_init_->name() + << " has expired: " << t1_init_->expiration_count() + << "/" << t1_init_->options().max_restarts; + RTC_DCHECK(state_ == State::kCookieWait); + + if (t1_init_->is_running()) { + SendInit(); + } else { + InternalClose(ErrorKind::kTooManyRetries, "No INIT_ACK received"); + } + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +absl::optional DcSctpSocket::OnCookieTimerExpiry() { + // https://tools.ietf.org/html/rfc4960#section-4 + // "If the T1-cookie timer expires, the endpoint MUST retransmit COOKIE + // ECHO and restart the T1-cookie timer without changing state. This MUST + // be repeated up to 'Max.Init.Retransmits' times. After that, the endpoint + // MUST abort the initialization process and report the error to the SCTP + // user." + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_cookie_->name() + << " has expired: " << t1_cookie_->expiration_count() + << "/" << t1_cookie_->options().max_restarts; + + RTC_DCHECK(state_ == State::kCookieEchoed); + + if (t1_cookie_->is_running()) { + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + } else { + InternalClose(ErrorKind::kTooManyRetries, "No COOKIE_ACK received"); + } + + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +absl::optional DcSctpSocket::OnShutdownTimerExpiry() { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t2_shutdown_->name() + << " has expired: " << t2_shutdown_->expiration_count() + << "/" << t2_shutdown_->options().max_restarts; + + if (!t2_shutdown_->is_running()) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "An endpoint should limit the number of retransmissions of the SHUTDOWN + // chunk to the protocol parameter 'Association.Max.Retrans'. 
If this + // threshold is exceeded, the endpoint should destroy the TCB..." + + SendPacket(tcb_->PacketBuilder().Add( + AbortChunk(true, Parameters::Builder() + .Add(UserInitiatedAbortCause( + "Too many retransmissions of SHUTDOWN")) + .Build()))); + + InternalClose(ErrorKind::kTooManyRetries, "No SHUTDOWN_ACK received"); + RTC_DCHECK(IsConsistent()); + return absl::nullopt; + } + + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If the timer expires, the endpoint must resend the SHUTDOWN with the + // updated last sequential TSN received from its peer." + SendShutdown(); + RTC_DCHECK(IsConsistent()); + return tcb_->current_rto(); +} + +void DcSctpSocket::SendPacket(SctpPacket::Builder& builder) { + if (builder.empty()) { + return; + } + + std::vector payload = builder.Build(); + + if (RTC_DLOG_IS_ON) { + DebugPrintOutgoing(payload); + } + + // The heartbeat interval timer is restarted for every sent packet, to + // fire when the outgoing channel is inactive. + if (tcb_ != nullptr) { + tcb_->heartbeat_handler().RestartTimer(); + } + + if (packet_observer_ != nullptr) { + packet_observer_->OnSentPacket(callbacks_.TimeMillis(), payload); + } + callbacks_.SendPacket(payload); +} + +bool DcSctpSocket::ValidateHasTCB() { + if (tcb_ != nullptr) { + return true; + } + + callbacks_.OnError( + ErrorKind::kNotConnected, + "Received unexpected commands on socket that is not connected"); + return false; +} + +void DcSctpSocket::ReportFailedToParseChunk(int chunk_type) { + rtc::StringBuilder sb; + sb << "Failed to parse chunk of type: " << chunk_type; + callbacks_.OnError(ErrorKind::kParseFailed, sb.str()); +} + +void DcSctpSocket::HandleData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = DataChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleDataCommon(*chunk); + } +} + +void DcSctpSocket::HandleIData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = IDataChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleDataCommon(*chunk); + } +} + +void DcSctpSocket::HandleDataCommon(AnyDataChunk& chunk) { + TSN tsn = chunk.tsn(); + AnyDataChunk::ImmediateAckFlag immediate_ack = chunk.options().immediate_ack; + Data data = std::move(chunk).extract(); + + if (data.payload.empty()) { + // Empty DATA chunks are illegal. + SendPacket(tcb_->PacketBuilder().Add( + ErrorChunk(Parameters::Builder().Add(NoUserDataCause(tsn)).Build()))); + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Received DATA chunk with no user data"); + return; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Handle DATA, queue_size=" + << tcb_->reassembly_queue().queued_bytes() + << ", water_mark=" + << tcb_->reassembly_queue().watermark_bytes() + << ", full=" << tcb_->reassembly_queue().is_full() + << ", above=" + << tcb_->reassembly_queue().is_above_watermark(); + + if (tcb_->reassembly_queue().is_full()) { + // If the reassembly queue is full, there is nothing that can be done. The + // specification only allows dropping gap-ack-blocks, and that's not + // likely to help as the socket has been trying to fill gaps since the + // watermark was reached. 
+ SendPacket(tcb_->PacketBuilder().Add(AbortChunk( + true, Parameters::Builder().Add(OutOfResourceErrorCause()).Build()))); + InternalClose(ErrorKind::kResourceExhaustion, + "Reassembly Queue is exhausted"); + return; + } + + if (tcb_->reassembly_queue().is_above_watermark()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Is above high watermark"; + // If the reassembly queue is above its high watermark, only accept data + // chunks that increase its cumulative ack tsn in an attempt to fill gaps + // to deliver messages. + if (!tcb_->data_tracker().will_increase_cum_ack_tsn(tsn)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Rejected data because of exceeding watermark"; + tcb_->data_tracker().ForceImmediateSack(); + return; + } + } + + if (!tcb_->data_tracker().IsTSNValid(tsn)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Rejected data because of failing TSN validity"; + return; + } + + tcb_->data_tracker().Observe(tsn, immediate_ack); + tcb_->reassembly_queue().MaybeResetStreamsDeferred( + tcb_->data_tracker().last_cumulative_acked_tsn()); + tcb_->reassembly_queue().Add(tsn, std::move(data)); + DeliverReassembledMessages(); +} + +void DcSctpSocket::HandleInit(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = InitChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (chunk->initiate_tag() == VerificationTag(0) || + chunk->nbr_outbound_streams() == 0 || chunk->nbr_inbound_streams() == 0) { + // https://tools.ietf.org/html/rfc4960#section-3.3.2 + // "If the value of the Initiate Tag in a received INIT chunk is found + // to be 0, the receiver MUST treat it as an error and close the + // association by transmitting an ABORT." + + // "A receiver of an INIT with the OS value set to 0 SHOULD abort the + // association." + + // "A receiver of an INIT with the MIS value of 0 SHOULD abort the + // association." + + SendPacket(SctpPacket::Builder(VerificationTag(0), options_) + .Add(AbortChunk( + /*filled_in_verification_tag=*/false, + Parameters::Builder() + .Add(ProtocolViolationCause("INIT malformed")) + .Build()))); + InternalClose(ErrorKind::kProtocolViolation, "Received invalid INIT"); + return; + } + + if (state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives an + // INIT chunk (e.g., if the SHUTDOWN COMPLETE was lost) with source and + // destination transport addresses (either in the IP addresses or in the + // INIT chunk) that belong to this association, it should discard the INIT + // chunk and retransmit the SHUTDOWN ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating lost ShutdownComplete"; + SendShutdownAck(); + return; + } + + TieTag tie_tag(0); + if (state_ == State::kClosed) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init in closed state (normal)"; + + MakeConnectionParameters(); + } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-5.2.1 + // "This usually indicates an initialization collision, i.e., each + // endpoint is attempting, at about the same time, to establish an + // association with the other endpoint. Upon receipt of an INIT in the + // COOKIE-WAIT state, an endpoint MUST respond with an INIT ACK using the + // same parameters it sent in its original INIT chunk (including its + // Initiate Tag, unchanged). 
When responding, the endpoint MUST send the + // INIT ACK back to the same address that the original INIT (sent by this + // endpoint) was sent." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating simultaneous connections"; + } else { + RTC_DCHECK(tcb_ != nullptr); + // https://tools.ietf.org/html/rfc4960#section-5.2.2 + // "The outbound SCTP packet containing this INIT ACK MUST carry a + // Verification Tag value equal to the Initiate Tag found in the + // unexpected INIT. And the INIT ACK MUST contain a new Initiate Tag + // (randomly generated; see Section 5.3.1). Other parameters for the + // endpoint SHOULD be copied from the existing parameters of the + // association (e.g., number of outbound streams) into the INIT ACK and + // cookie." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating restarted connection"; + // Create a new verification tag - different from the previous one. + for (int tries = 0; tries < 10; ++tries) { + connect_params_.verification_tag = VerificationTag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + if (connect_params_.verification_tag != tcb_->my_verification_tag()) { + break; + } + } + + // Make the initial TSN make a large jump, so that there is no overlap + // with the old and new association. + connect_params_.initial_tsn = + TSN(*tcb_->retransmission_queue().next_tsn() + 1000000); + tie_tag = tcb_->tie_tag(); + } + + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << rtc::StringFormat( + "Proceeding with connection. my_verification_tag=%08x, " + "my_initial_tsn=%u, peer_verification_tag=%08x, " + "peer_initial_tsn=%u", + *connect_params_.verification_tag, *connect_params_.initial_tsn, + *chunk->initiate_tag(), *chunk->initial_tsn()); + + Capabilities capabilities = GetCapabilities(options_, chunk->parameters()); + + SctpPacket::Builder b(chunk->initiate_tag(), options_); + Parameters::Builder params_builder = + Parameters::Builder().Add(StateCookieParameter( + StateCookie(chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), tie_tag, capabilities) + .Serialize())); + AddCapabilityParameters(options_, params_builder); + + InitAckChunk init_ack(/*initiate_tag=*/connect_params_.verification_tag, + options_.max_receiver_window_buffer_size, + options_.announced_maximum_outgoing_streams, + options_.announced_maximum_incoming_streams, + connect_params_.initial_tsn, params_builder.Build()); + b.Add(init_ack); + SendPacket(b); +} + +void DcSctpSocket::HandleInitAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = InitAckChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (state_ != State::kCookieWait) { + // https://tools.ietf.org/html/rfc4960#section-5.2.3 + // "If an INIT ACK is received by an endpoint in any state other than + // the COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk." 
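+ // This also covers INIT ACKs that arrive after the association has already
+ // been established; they are simply ignored.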
+ RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received INIT_ACK in unexpected state"; + return; + } + + auto cookie = chunk->parameters().get(); + if (!cookie.has_value()) { + SendPacket(SctpPacket::Builder(connect_params_.verification_tag, options_) + .Add(AbortChunk( + /*filled_in_verification_tag=*/false, + Parameters::Builder() + .Add(ProtocolViolationCause("INIT-ACK malformed")) + .Build()))); + InternalClose(ErrorKind::kProtocolViolation, + "InitAck chunk doesn't contain a cookie"); + return; + } + Capabilities capabilities = GetCapabilities(options_, chunk->parameters()); + t1_init_->Stop(); + + tcb_ = std::make_unique( + timer_manager_, log_prefix_, options_, capabilities, callbacks_, + send_queue_, connect_params_.verification_tag, + connect_params_.initial_tsn, chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), MakeTieTag(callbacks_), + [this]() { return state_ == State::kEstablished; }, + [this](SctpPacket::Builder& builder) { return SendPacket(builder); }); + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Created peer TCB: " << tcb_->ToString(); + + SetState(State::kCookieEchoed, "INIT_ACK received"); + + // The connection isn't fully established just yet. + tcb_->SetCookieEchoChunk(CookieEchoChunk(cookie->data())); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + t1_cookie_->Start(); +} + +void DcSctpSocket::HandleCookieEcho( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = + CookieEchoChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + absl::optional cookie = + StateCookie::Deserialize(chunk->cookie()); + if (!cookie.has_value()) { + callbacks_.OnError(ErrorKind::kParseFailed, "Failed to parse state cookie"); + return; + } + + if (tcb_ != nullptr) { + if (!HandleCookieEchoWithTCB(header, *cookie)) { + return; + } + } else { + if (header.verification_tag != connect_params_.verification_tag) { + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Received CookieEcho with invalid verification tag: %08x, " + "expected %08x", + *header.verification_tag, *connect_params_.verification_tag)); + return; + } + } + + // The init timer can be running on simultaneous connections. + t1_init_->Stop(); + t1_cookie_->Stop(); + if (state_ != State::kEstablished) { + if (tcb_ != nullptr) { + tcb_->ClearCookieEchoChunk(); + } + SetState(State::kEstablished, "COOKIE_ECHO received"); + callbacks_.OnConnected(); + } + + if (tcb_ == nullptr) { + tcb_ = std::make_unique( + timer_manager_, log_prefix_, options_, cookie->capabilities(), + callbacks_, send_queue_, connect_params_.verification_tag, + connect_params_.initial_tsn, cookie->initiate_tag(), + cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_), + [this]() { return state_ == State::kEstablished; }, + [this](SctpPacket::Builder& builder) { return SendPacket(builder); }); + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Created peer TCB: " << tcb_->ToString(); + } + + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(CookieAckChunk()); + + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "A COOKIE ACK chunk may be bundled with any pending DATA chunks (and/or + // SACK chunks), but the COOKIE ACK chunk MUST be the first chunk in the + // packet." + tcb_->SendBufferedPackets(b, callbacks_.TimeMillis()); +} + +bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, + const StateCookie& cookie) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Handling CookieEchoChunk with TCB. 
local_tag=" + << *tcb_->my_verification_tag() + << ", peer_tag=" << *header.verification_tag + << ", tcb_tag=" << *tcb_->peer_verification_tag() + << ", cookie_tag=" << *cookie.initiate_tag() + << ", local_tie_tag=" << *tcb_->tie_tag() + << ", peer_tie_tag=" << *cookie.tie_tag(); + // https://tools.ietf.org/html/rfc4960#section-5.2.4 + // "Handle a COOKIE ECHO when a TCB Exists" + if (header.verification_tag != tcb_->my_verification_tag() && + tcb_->peer_verification_tag() != cookie.initiate_tag() && + cookie.tie_tag() == tcb_->tie_tag()) { + // "A) In this case, the peer may have restarted." + if (state_ == State::kShutdownAckSent) { + // "If the endpoint is in the SHUTDOWN-ACK-SENT state and recognizes + // that the peer has restarted ... it MUST NOT set up a new association + // but instead resend the SHUTDOWN ACK and send an ERROR chunk with a + // "Cookie Received While Shutting Down" error cause to its peer." + SctpPacket::Builder b(cookie.initiate_tag(), options_); + b.Add(ShutdownAckChunk()); + b.Add(ErrorChunk(Parameters::Builder() + .Add(CookieReceivedWhileShuttingDownCause()) + .Build())); + SendPacket(b); + callbacks_.OnError(ErrorKind::kWrongSequence, + "Received COOKIE-ECHO while shutting down"); + return false; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received COOKIE-ECHO indicating a restarted peer"; + + // If a message was partly sent, and the peer restarted, resend it in + // full by resetting the send queue. + send_queue_.Reset(); + tcb_ = nullptr; + callbacks_.OnConnectionRestarted(); + } else if (header.verification_tag == tcb_->my_verification_tag() && + tcb_->peer_verification_tag() != cookie.initiate_tag()) { + // TODO(boivie): Handle the peer_tag == 0? + // "B) In this case, both sides may be attempting to start an + // association at about the same time, but the peer endpoint started its + // INIT after responding to the local endpoint's INIT." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received COOKIE-ECHO indicating simultaneous connections"; + tcb_ = nullptr; + } else if (header.verification_tag != tcb_->my_verification_tag() && + tcb_->peer_verification_tag() == cookie.initiate_tag() && + cookie.tie_tag() == TieTag(0)) { + // "C) In this case, the local endpoint's cookie has arrived late. + // Before it arrived, the local endpoint sent an INIT and received an + // INIT ACK and finally sent a COOKIE ECHO with the peer's same tag but + // a new tag of its own. The cookie should be silently discarded. The + // endpoint SHOULD NOT change states and should leave any timers + // running." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received COOKIE-ECHO indicating a late COOKIE-ECHO. Discarding"; + return false; + } else if (header.verification_tag == tcb_->my_verification_tag() && + tcb_->peer_verification_tag() == cookie.initiate_tag()) { + // "D) When both local and remote tags match, the endpoint should enter + // the ESTABLISHED state, if it is in the COOKIE-ECHOED state. It + // should stop any cookie timer that may be running and send a COOKIE + // ACK." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received duplicate COOKIE-ECHO, probably because of peer not " + "receiving COOKIE-ACK and retransmitting COOKIE-ECHO. 
Continuing."; + } + return true; +} + +void DcSctpSocket::HandleCookieAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = CookieAckChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (state_ != State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-5.2.5 + // "At any state other than COOKIE-ECHOED, an endpoint should silently + // discard a received COOKIE ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received COOKIE_ACK not in COOKIE_ECHOED state"; + return; + } + + // RFC 4960, Errata ID: 4400 + t1_cookie_->Stop(); + tcb_->ClearCookieEchoChunk(); + SetState(State::kEstablished, "COOKIE_ACK received"); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + callbacks_.OnConnected(); +} + +void DcSctpSocket::DeliverReassembledMessages() { + if (tcb_->reassembly_queue().HasMessages()) { + for (auto& message : tcb_->reassembly_queue().FlushMessages()) { + callbacks_.OnMessageReceived(std::move(message)); + } + } +} + +void DcSctpSocket::HandleSack(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = SackChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + TimeMs now = callbacks_.TimeMillis(); + SackChunk sack = ChunkValidators::Clean(*std::move(chunk)); + + if (tcb_->retransmission_queue().HandleSack(now, sack)) { + MaybeSendShutdownOrAck(); + // Receiving an ACK will decrease outstanding bytes (maybe now below + // cwnd?) or indicate packet loss that may result in sending FORWARD-TSN. + tcb_->SendBufferedPackets(now); + } else { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Dropping out-of-order SACK with TSN " + << *sack.cumulative_tsn_ack(); + } + } +} + +void DcSctpSocket::HandleHeartbeatRequest( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = + HeartbeatRequestChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->heartbeat_handler().HandleHeartbeatRequest(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleHeartbeatAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = + HeartbeatAckChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->heartbeat_handler().HandleHeartbeatAck(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleAbort(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = AbortChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk)) { + std::string error_string = ErrorCausesToString(chunk->error_causes()); + if (tcb_ == nullptr) { + // https://tools.ietf.org/html/rfc4960#section-3.3.7 + // "If an endpoint receives an ABORT with a format error or no TCB is + // found, it MUST silently discard it." + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ABORT (" << error_string + << ") on a connection with no TCB. 
Ignoring"; + return; + } + + RTC_DLOG(LS_WARNING) << log_prefix() << "Received ABORT (" << error_string + << ") - closing connection."; + InternalClose(ErrorKind::kPeerReported, error_string); + } +} + +void DcSctpSocket::HandleError(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = ErrorChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk)) { + std::string error_string = ErrorCausesToString(chunk->error_causes()); + if (tcb_ == nullptr) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ERROR (" << error_string + << ") on a connection with no TCB. Ignoring"; + return; + } + + RTC_DLOG(LS_WARNING) << log_prefix() << "Received ERROR: " << error_string; + callbacks_.OnError(ErrorKind::kPeerReported, + "Peer reported error: " + error_string); + } +} + +void DcSctpSocket::HandleReconfig( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = ReConfigChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->stream_reset_handler().HandleReConfig(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleShutdown( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kClosed) { + return; + } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If a SHUTDOWN is received in the COOKIE-WAIT or COOKIE ECHOED state, + // the SHUTDOWN chunk SHOULD be silently discarded." + } else if (state_ == State::kShutdownSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If an endpoint is in the SHUTDOWN-SENT state and receives a + // SHUTDOWN chunk from its peer, the endpoint shall respond immediately + // with a SHUTDOWN ACK to its peer, and move into the SHUTDOWN-ACK-SENT + // state restarting its T2-shutdown timer." + SendShutdownAck(); + SetState(State::kShutdownAckSent, "SHUTDOWN received"); + } else if (state_ == State::kShutdownAckSent) { + // TODO(webrtc:12739): This condition should be removed and handled by the + // next (state_ != State::kShutdownReceived). + return; + } else if (state_ != State::kShutdownReceived) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received SHUTDOWN - shutting down the socket"; + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon reception of the SHUTDOWN, the peer endpoint shall enter the + // SHUTDOWN-RECEIVED state, stop accepting new data from its SCTP user, + // and verify, by checking the Cumulative TSN Ack field of the chunk, that + // all its outstanding DATA chunks have been received by the SHUTDOWN + // sender." + SetState(State::kShutdownReceived, "SHUTDOWN received"); + MaybeSendShutdownOrAck(); + } +} + +void DcSctpSocket::HandleShutdownAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownAckChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kShutdownSent || state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon the receipt of the SHUTDOWN ACK, the SHUTDOWN sender shall stop + // the T2-shutdown timer, send a SHUTDOWN COMPLETE chunk to its peer, and + // remove all record of the association." 
+ + // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives a + // SHUTDOWN ACK, it shall stop the T2-shutdown timer, send a SHUTDOWN + // COMPLETE chunk to its peer, and remove all record of the association." + + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(ShutdownCompleteChunk(/*tag_reflected=*/false)); + SendPacket(b); + InternalClose(ErrorKind::kNoError, ""); + } else { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "If the receiver is in COOKIE-ECHOED or COOKIE-WAIT state + // the procedures in Section 8.4 SHOULD be followed; in other words, it + // should be treated as an Out Of The Blue packet." + + // https://tools.ietf.org/html/rfc4960#section-8.4 + // "If the packet contains a SHUTDOWN ACK chunk, the receiver + // should respond to the sender of the OOTB packet with a SHUTDOWN + // COMPLETE. When sending the SHUTDOWN COMPLETE, the receiver of the OOTB + // packet must fill in the Verification Tag field of the outbound packet + // with the Verification Tag received in the SHUTDOWN ACK and set the T + // bit in the Chunk Flags to indicate that the Verification Tag is + // reflected." + + SctpPacket::Builder b(header.verification_tag, options_); + b.Add(ShutdownCompleteChunk(/*tag_reflected=*/true)); + SendPacket(b); + } +} + +void DcSctpSocket::HandleShutdownComplete( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownCompleteChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon reception of the SHUTDOWN COMPLETE chunk, the endpoint will + // verify that it is in the SHUTDOWN-ACK-SENT state; if it is not, the + // chunk should be discarded. If the endpoint is in the SHUTDOWN-ACK-SENT + // state, the endpoint should stop the T2-shutdown timer and remove all + // knowledge of the association (and thus the association enters the + // CLOSED state)." + InternalClose(ErrorKind::kNoError, ""); + } +} + +void DcSctpSocket::HandleForwardTsn( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = + ForwardTsnChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleForwardTsnCommon(*chunk); + } +} + +void DcSctpSocket::HandleIForwardTsn( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional chunk = + IForwardTsnChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleForwardTsnCommon(*chunk); + } +} + +void DcSctpSocket::HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk) { + if (!tcb_->capabilities().partial_reliability) { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(AbortChunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(ProtocolViolationCause( + "I-FORWARD-TSN received, but not indicated " + "during connection establishment")) + .Build())); + SendPacket(b); + + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Received a FORWARD_TSN without announced peer support"); + return; + } + tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn()); + tcb_->reassembly_queue().Handle(chunk); + // A forward TSN - for ordered streams - may allow messages to be + // delivered. + DeliverReassembledMessages(); + + // Processing a FORWARD_TSN might result in sending a SACK. 
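+ // The chunk moves the cumulative TSN point forward, so the skipped TSNs are
+ // reported as received in the next SACK.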
+ tcb_->MaybeSendSack(); +} + +void DcSctpSocket::MaybeSendShutdownOrAck() { + if (tcb_->retransmission_queue().outstanding_bytes() != 0) { + return; + } + + if (state_ == State::kShutdownPending) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Once all its outstanding data has been acknowledged, the endpoint + // shall send a SHUTDOWN chunk to its peer including in the Cumulative TSN + // Ack field the last sequential TSN it has received from the peer. It + // shall then start the T2-shutdown timer and enter the SHUTDOWN-SENT + // state."" + + SendShutdown(); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); + SetState(State::kShutdownSent, "No more outstanding data"); + } else if (state_ == State::kShutdownReceived) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If the receiver of the SHUTDOWN has no more outstanding DATA + // chunks, the SHUTDOWN receiver MUST send a SHUTDOWN ACK and start a + // T2-shutdown timer of its own, entering the SHUTDOWN-ACK-SENT state. If + // the timer expires, the endpoint must resend the SHUTDOWN ACK." + + SendShutdownAck(); + SetState(State::kShutdownAckSent, "No more outstanding data"); + } +} + +void DcSctpSocket::SendShutdown() { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(ShutdownChunk(tcb_->data_tracker().last_cumulative_acked_tsn())); + SendPacket(b); +} + +void DcSctpSocket::SendShutdownAck() { + SendPacket(tcb_->PacketBuilder().Add(ShutdownAckChunk())); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h new file mode 100644 index 0000000000..32e89b50d1 --- /dev/null +++ b/net/dcsctp/socket/dcsctp_socket.h @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/callback_deferrer.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "net/dcsctp/socket/transmission_control_block.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_error_counter.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/rr_send_queue.h" + +namespace dcsctp { + +// DcSctpSocket represents a single SCTP socket, to be used over DTLS. +// +// Every dcSCTP is completely isolated from any other socket. +// +// This class manages all packet and chunk dispatching and mainly handles the +// connection sequences (connect, close, shutdown, etc) as well as managing +// the Transmission Control Block (tcb). +// +// This class is thread-compatible. +class DcSctpSocket : public DcSctpSocketInterface { + public: + // Instantiates a DcSctpSocket, which interacts with the world through the + // `callbacks` interface and is configured using `options`. + // + // For debugging, `log_prefix` will prefix all debug logs, and a + // `packet_observer` can be attached to e.g. dump sent and received packets. + DcSctpSocket(absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr packet_observer, + const DcSctpOptions& options); + + DcSctpSocket(const DcSctpSocket&) = delete; + DcSctpSocket& operator=(const DcSctpSocket&) = delete; + + // Implementation of `DcSctpSocketInterface`. 
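+ // Packets received from the network are fed in through `ReceivePacket`, and
+ // time is driven by calling `HandleTimeout` for every timeout created via
+ // the callbacks. Callbacks are queued while a method executes and are
+ // dispatched just before that method returns to the caller.
+ //
+ // A rough usage sketch, for illustration only (`MyCallbacks` stands in for
+ // an application-provided `DcSctpSocketCallbacks` implementation):
+ //
+ //   MyCallbacks callbacks;
+ //   DcSctpSocket socket("a", callbacks, /*packet_observer=*/nullptr,
+ //                       DcSctpOptions());
+ //   socket.Connect();
+ //   socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}),
+ //               SendOptions());
+ //   // When a packet arrives from the transport:
+ //   socket.ReceivePacket(packet);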
+ void ReceivePacket(rtc::ArrayView data) override; + void HandleTimeout(TimeoutID timeout_id) override; + void Connect() override; + void Shutdown() override; + void Close() override; + SendStatus Send(DcSctpMessage message, + const SendOptions& send_options) override; + ResetStreamsStatus ResetStreams( + rtc::ArrayView outgoing_streams) override; + SocketState state() const override; + const DcSctpOptions& options() const override { return options_; } + void SetMaxMessageSize(size_t max_message_size) override; + size_t buffered_amount(StreamID stream_id) const override; + size_t buffered_amount_low_threshold(StreamID stream_id) const override; + void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + // Returns this socket's verification tag, or zero if not yet connected. + VerificationTag verification_tag() const { + return tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0); + } + + private: + // Parameter proposals valid during the connect phase. + struct ConnectParameters { + TSN initial_tsn = TSN(0); + VerificationTag verification_tag = VerificationTag(0); + }; + + // Detailed state (separate from SocketState, which is the public state). + enum class State { + kClosed, + kCookieWait, + // TCB valid in these: + kCookieEchoed, + kEstablished, + kShutdownPending, + kShutdownSent, + kShutdownReceived, + kShutdownAckSent, + }; + + // Returns the log prefix used for debug logging. + std::string log_prefix() const; + + bool IsConsistent() const; + static constexpr absl::string_view ToString(DcSctpSocket::State state); + + // Changes the socket state, given a `reason` (for debugging/logging). + void SetState(State state, absl::string_view reason); + // Fills in `connect_params` with random verification tag and initial TSN. + void MakeConnectionParameters(); + // Closes the association. Note that the TCB will not be valid past this call. + void InternalClose(ErrorKind error, absl::string_view message); + // Closes the association, because of too many retransmission errors. + void CloseConnectionBecauseOfTooManyTransmissionErrors(); + // Timer expiration handlers + absl::optional OnInitTimerExpiry(); + absl::optional OnCookieTimerExpiry(); + absl::optional OnShutdownTimerExpiry(); + // Builds the packet from `builder` and sends it (through callbacks). + void SendPacket(SctpPacket::Builder& builder); + // Sends SHUTDOWN or SHUTDOWN-ACK if the socket is shutting down and if all + // outstanding data has been acknowledged. + void MaybeSendShutdownOrAck(); + // If the socket is shutting down, responds SHUTDOWN to any incoming DATA. + void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet); + // Sends a INIT chunk. + void SendInit(); + // Sends a SHUTDOWN chunk. + void SendShutdown(); + // Sends a SHUTDOWN-ACK chunk. + void SendShutdownAck(); + // Validates the SCTP packet, as a whole - not the validity of individual + // chunks within it, as that's done in the different chunk handlers. + bool ValidatePacket(const SctpPacket& packet); + // Parses `payload`, which is a serialized packet that is just going to be + // sent and prints all chunks. + void DebugPrintOutgoing(rtc::ArrayView payload); + // Called whenever there may be reassembled messages, and delivers those. + void DeliverReassembledMessages(); + // Returns true if there is a TCB, and false otherwise (and reports an error). + bool ValidateHasTCB(); + + // Returns true if the parsing of a chunk of type `T` succeeded. If it didn't, + // it reports an error and returns false. 

+ template <typename T>
+ bool ValidateParseSuccess(const absl::optional<T>& c) {
+ if (c.has_value()) {
+ return true;
+ }
+
+ ReportFailedToParseChunk(T::kType);
+ return false;
+ }
+
+ // Reports failing to have parsed a chunk with the provided `chunk_type`.
+ void ReportFailedToParseChunk(int chunk_type);
+ // Called when unknown chunks are received. May report an error.
+ bool HandleUnrecognizedChunk(const SctpPacket::ChunkDescriptor& descriptor);
+
+ // Will dispatch more specific chunk handlers.
+ bool Dispatch(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming DATA chunks.
+ void HandleData(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming I-DATA chunks.
+ void HandleIData(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Common handler for DATA and I-DATA chunks.
+ void HandleDataCommon(AnyDataChunk& chunk);
+ // Handles incoming INIT chunks.
+ void HandleInit(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming INIT-ACK chunks.
+ void HandleInitAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming SACK chunks.
+ void HandleSack(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming HEARTBEAT chunks.
+ void HandleHeartbeatRequest(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming HEARTBEAT-ACK chunks.
+ void HandleHeartbeatAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming ABORT chunks.
+ void HandleAbort(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming ERROR chunks.
+ void HandleError(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming COOKIE-ECHO chunks.
+ void HandleCookieEcho(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles receiving COOKIE-ECHO when there already is a TCB. The return value
+ // indicates if the processing should continue.
+ bool HandleCookieEchoWithTCB(const CommonHeader& header,
+ const StateCookie& cookie);
+ // Handles incoming COOKIE-ACK chunks.
+ void HandleCookieAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming SHUTDOWN chunks.
+ void HandleShutdown(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming SHUTDOWN-ACK chunks.
+ void HandleShutdownAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming FORWARD-TSN chunks.
+ void HandleForwardTsn(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming I-FORWARD-TSN chunks.
+ void HandleIForwardTsn(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming RE-CONFIG chunks.
+ void HandleReconfig(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Common handler for FORWARD-TSN/I-FORWARD-TSN.
+ void HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk); + // Handles incoming SHUTDOWN-COMPLETE chunks + void HandleShutdownComplete(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + + const std::string log_prefix_; + const std::unique_ptr packet_observer_; + DcSctpOptions options_; + + // Enqueues callbacks and dispatches them just before returning to the caller. + CallbackDeferrer callbacks_; + + TimerManager timer_manager_; + const std::unique_ptr t1_init_; + const std::unique_ptr t1_cookie_; + const std::unique_ptr t2_shutdown_; + + // The actual SendQueue implementation. As data can be sent on a socket before + // the connection is established, this component is not in the TCB. + RRSendQueue send_queue_; + + // Contains verification tag and initial TSN between having sent the INIT + // until the connection is established (there is no TCB at this point). + ConnectParameters connect_params_; + // The socket state. + State state_ = State::kClosed; + // If the connection is established, contains a transmission control block. + std::unique_ptr tcb_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc new file mode 100644 index 0000000000..7ca3d9b399 --- /dev/null +++ b/net/dcsctp/socket/dcsctp_socket_test.cc @@ -0,0 +1,1612 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/socket/dcsctp_socket.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/text_pcap_packet_observer.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture."); + +namespace dcsctp { +namespace { +using ::testing::_; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; + +constexpr SendOptions kSendOptions; +constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20; +static constexpr size_t kSmallMessageSize = 10; + +MATCHER_P(HasDataChunkWithStreamId, stream_id, "") { + absl::optional packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != DataChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional dc = + DataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (dc->stream_id() != stream_id) { + *result_listener << "the stream_id is " << *dc->stream_id(); + return false; + } + + return true; +} + +MATCHER_P(HasDataChunkWithPPID, ppid, "") { + absl::optional packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != DataChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional dc = + DataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (dc->ppid() != ppid) { + *result_listener << "the ppid is " << *dc->ppid(); + return false; + } + + return true; +} + +MATCHER_P(HasDataChunkWithSsn, ssn, "") { + absl::optional packet = SctpPacket::Parse(arg); + if 
(!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != DataChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional dc = + DataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (dc->ssn() != ssn) { + *result_listener << "the ssn is " << *dc->ssn(); + return false; + } + + return true; +} + +MATCHER_P(HasDataChunkWithMid, mid, "") { + absl::optional packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != IDataChunk::kType) { + *result_listener << "the first chunk in the packet is not an i-data chunk"; + return false; + } + + absl::optional dc = + IDataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as an i-data chunk"; + return false; + } + + if (dc->message_id() != mid) { + *result_listener << "the mid is " << *dc->message_id(); + return false; + } + + return true; +} + +MATCHER_P(HasSackWithCumAckTsn, tsn, "") { + absl::optional packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != SackChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional sc = + SackChunk::Parse(packet->descriptors()[0].data); + if (!sc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (sc->cumulative_tsn_ack() != tsn) { + *result_listener << "the cum_ack_tsn is " << *sc->cumulative_tsn_ack(); + return false; + } + + return true; +} + +MATCHER(HasSackWithNoGapAckBlocks, "") { + absl::optional packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != SackChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional sc = + SackChunk::Parse(packet->descriptors()[0].data); + if (!sc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (!sc->gap_ack_blocks().empty()) { + *result_listener << "there are gap ack blocks"; + return false; + } + + return true; +} + +TSN AddTo(TSN tsn, int delta) { + return TSN(*tsn + delta); +} + +DcSctpOptions MakeOptionsForTest(bool enable_message_interleaving) { + DcSctpOptions options; + // To make the interval more predictable in tests. 
+ options.heartbeat_interval_include_rtt = false; + options.enable_message_interleaving = enable_message_interleaving; + return options; +} + +std::unique_ptr GetPacketObserver(absl::string_view name) { + if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) { + return std::make_unique(name); + } + return nullptr; +} + +class DcSctpSocketTest : public testing::Test { + protected: + explicit DcSctpSocketTest(bool enable_message_interleaving = false) + : options_(MakeOptionsForTest(enable_message_interleaving)), + cb_a_("A"), + cb_z_("Z"), + sock_a_("A", cb_a_, GetPacketObserver("A"), options_), + sock_z_("Z", cb_z_, GetPacketObserver("Z"), options_) {} + + void AdvanceTime(DurationMs duration) { + cb_a_.AdvanceTime(duration); + cb_z_.AdvanceTime(duration); + } + + static void ExchangeMessages(DcSctpSocket& sock_a, + MockDcSctpSocketCallbacks& cb_a, + DcSctpSocket& sock_z, + MockDcSctpSocketCallbacks& cb_z) { + bool delivered_packet = false; + do { + delivered_packet = false; + std::vector packet_from_a = cb_a.ConsumeSentPacket(); + if (!packet_from_a.empty()) { + delivered_packet = true; + sock_z.ReceivePacket(std::move(packet_from_a)); + } + std::vector packet_from_z = cb_z.ConsumeSentPacket(); + if (!packet_from_z.empty()) { + delivered_packet = true; + sock_a.ReceivePacket(std::move(packet_from_z)); + } + } while (delivered_packet); + } + + void RunTimers(MockDcSctpSocketCallbacks& cb, DcSctpSocket& socket) { + for (;;) { + absl::optional timeout_id = cb.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + socket.HandleTimeout(*timeout_id); + } + } + + void RunTimers() { + RunTimers(cb_a_, sock_a_); + RunTimers(cb_z_, sock_z_); + } + + // Calls Connect() on `sock_a_` and make the connection established. + void ConnectSockets() { + EXPECT_CALL(cb_a_, OnConnected).Times(1); + EXPECT_CALL(cb_z_, OnConnected).Times(1); + + sock_a_.Connect(); + // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); + } + + const DcSctpOptions options_; + testing::NiceMock cb_a_; + testing::NiceMock cb_z_; + DcSctpSocket sock_a_; + DcSctpSocket sock_z_; +}; + +TEST_F(DcSctpSocketTest, EstablishConnection) { + EXPECT_CALL(cb_a_, OnConnected).Times(1); + EXPECT_CALL(cb_z_, OnConnected).Times(1); + EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); + EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); + + sock_a_.Connect(); + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads COOKIE_ACK. 
+ sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); +} + +TEST_F(DcSctpSocketTest, EstablishConnectionWithSetupCollision) { + EXPECT_CALL(cb_a_, OnConnected).Times(1); + EXPECT_CALL(cb_z_, OnConnected).Times(1); + EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); + EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); + sock_a_.Connect(); + sock_z_.Connect(); + + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); +} + +TEST_F(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) { + EXPECT_CALL(cb_a_, OnConnected).Times(0); + EXPECT_CALL(cb_z_, OnConnected).Times(1); + sock_a_.Connect(); + + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // Drop COOKIE_ACK, just to more easily verify shutdown protocol. + cb_z_.ConsumeSentPacket(); + + // As Socket A has received INIT_ACK, it has a TCB and is connected, while + // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has + // timers running at this point. + EXPECT_EQ(sock_a_.state(), SocketState::kConnecting); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); + + // Socket A is now shut down, which should make it stop those timers. + sock_a_.Shutdown(); + + EXPECT_CALL(cb_a_, OnClosed).Times(1); + EXPECT_CALL(cb_z_, OnClosed).Times(1); + + // Z reads SHUTDOWN, produces SHUTDOWN_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // Z reads SHUTDOWN_COMPLETE. + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + EXPECT_TRUE(cb_a_.ConsumeSentPacket().empty()); + EXPECT_TRUE(cb_z_.ConsumeSentPacket().empty()); + + EXPECT_EQ(sock_a_.state(), SocketState::kClosed); + EXPECT_EQ(sock_z_.state(), SocketState::kClosed); +} + +TEST_F(DcSctpSocketTest, EstablishSimultaneousConnection) { + EXPECT_CALL(cb_a_, OnConnected).Times(1); + EXPECT_CALL(cb_z_, OnConnected).Times(1); + EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); + EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); + sock_a_.Connect(); + + // INIT isn't received by Z, as it wasn't ready yet. + cb_a_.ConsumeSentPacket(); + + sock_z_.Connect(); + + // A reads INIT, produces INIT_ACK + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // Z reads INIT_ACK, sends COOKIE_ECHO + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + // A reads COOKIE_ECHO - establishes connection. + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + + // Proceed with the remaining packets. 
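+ // (A's COOKIE_ACK reaches Z, completing the handshake on that side too.)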
+ ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); +} + +TEST_F(DcSctpSocketTest, EstablishConnectionLostCookieAck) { + EXPECT_CALL(cb_a_, OnConnected).Times(1); + EXPECT_CALL(cb_z_, OnConnected).Times(1); + EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(0); + EXPECT_CALL(cb_z_, OnConnectionRestarted).Times(0); + + sock_a_.Connect(); + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // COOKIE_ACK is lost. + cb_z_.ConsumeSentPacket(); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnecting); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); + + // This will make A re-send the COOKIE_ECHO + AdvanceTime(DurationMs(options_.t1_cookie_timeout)); + RunTimers(); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads COOKIE_ACK. + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); +} + +TEST_F(DcSctpSocketTest, ResendInitAndEstablishConnection) { + sock_a_.Connect(); + // INIT is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType); + + AdvanceTime(options_.t1_init_timeout); + RunTimers(); + + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads COOKIE_ACK. + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); +} + +TEST_F(DcSctpSocketTest, ResendingInitTooManyTimesAborts) { + sock_a_.Connect(); + + // INIT is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType); + + for (int i = 0; i < options_.max_init_retransmits; ++i) { + AdvanceTime(options_.t1_init_timeout * (1 << i)); + RunTimers(); + + // INIT is resent + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(resent_init_packet.descriptors()[0].type, InitChunk::kType); + } + + // Another timeout, after the max init retransmits. + AdvanceTime(options_.t1_init_timeout * (1 << options_.max_init_retransmits)); + EXPECT_CALL(cb_a_, OnAborted).Times(1); + RunTimers(); + + EXPECT_EQ(sock_a_.state(), SocketState::kClosed); +} + +TEST_F(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) { + sock_a_.Connect(); + + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. 
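+ // (Despite the variable name below, the consumed packet is the COOKIE_ECHO,
+ // as the assertion verifies.)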
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType); + + AdvanceTime(options_.t1_init_timeout); + RunTimers(); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads COOKIE_ACK. + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); +} + +TEST_F(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) { + sock_a_.Connect(); + + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType); + + for (int i = 0; i < options_.max_init_retransmits; ++i) { + AdvanceTime(options_.t1_cookie_timeout * (1 << i)); + RunTimers(); + + // COOKIE_ECHO is resent + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(resent_init_packet.descriptors()[0].type, CookieEchoChunk::kType); + } + + // Another timeout, after the max init retransmits. + AdvanceTime(options_.t1_cookie_timeout * + (1 << options_.max_init_retransmits)); + EXPECT_CALL(cb_a_, OnAborted).Times(1); + RunTimers(); + + EXPECT_EQ(sock_a_.state(), SocketState::kClosed); +} + +TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kLargeMessageSize)), + kSendOptions); + sock_a_.Connect(); + + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet1, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_THAT(cookie_echo_packet1.descriptors(), SizeIs(2)); + EXPECT_EQ(cookie_echo_packet1.descriptors()[0].type, CookieEchoChunk::kType); + EXPECT_EQ(cookie_echo_packet1.descriptors()[1].type, DataChunk::kType); + + EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + + // There are DATA chunks in the sent packet (that was lost), which means that + // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it + // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires, + // nothing should be retransmitted. + ASSERT_TRUE(options_.rto_initial < options_.t1_cookie_timeout); + AdvanceTime(options_.rto_initial); + RunTimers(); + EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + + // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present. + AdvanceTime(options_.t1_cookie_timeout - options_.rto_initial); + RunTimers(); + + // And this COOKIE-ECHO and DATA is also lost - never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet2, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_THAT(cookie_echo_packet2.descriptors(), SizeIs(2)); + EXPECT_EQ(cookie_echo_packet2.descriptors()[0].type, CookieEchoChunk::kType); + EXPECT_EQ(cookie_echo_packet2.descriptors()[1].type, DataChunk::kType); + + EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + + // COOKIE_ECHO has exponential backoff. 
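+ // The retransmission timeout doubles on every expiry, so the next
+ // COOKIE_ECHO is not sent until twice the initial T1-cookie timeout.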
+ AdvanceTime(options_.t1_cookie_timeout * 2); + RunTimers(); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads COOKIE_ACK. + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); + + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + EXPECT_THAT(cb_z_.ConsumeReceivedMessage()->payload(), + SizeIs(kLargeMessageSize)); +} + +TEST_F(DcSctpSocketTest, ShutdownConnection) { + ConnectSockets(); + + RTC_LOG(LS_INFO) << "Shutting down"; + + sock_a_.Shutdown(); + // Z reads SHUTDOWN, produces SHUTDOWN_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // Z reads SHUTDOWN_COMPLETE. + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kClosed); + EXPECT_EQ(sock_z_.state(), SocketState::kClosed); +} + +TEST_F(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) { + ConnectSockets(); + + sock_a_.Shutdown(); + // Drop first SHUTDOWN packet. + cb_a_.ConsumeSentPacket(); + + EXPECT_EQ(sock_a_.state(), SocketState::kShuttingDown); + + for (int i = 0; i < options_.max_retransmissions; ++i) { + AdvanceTime(DurationMs(options_.rto_initial * (1 << i))); + RunTimers(); + + // Dropping every shutdown chunk. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(packet.descriptors()[0].type, ShutdownChunk::kType); + EXPECT_TRUE(cb_a_.ConsumeSentPacket().empty()); + } + // The last expiry, makes it abort the connection. + AdvanceTime(options_.rto_initial * (1 << options_.max_retransmissions)); + EXPECT_CALL(cb_a_, OnAborted).Times(1); + RunTimers(); + + EXPECT_EQ(sock_a_.state(), SocketState::kClosed); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(packet.descriptors()[0].type, AbortChunk::kType); + EXPECT_TRUE(cb_a_.ConsumeSentPacket().empty()); +} + +TEST_F(DcSctpSocketTest, EstablishConnectionWhileSendingData) { + sock_a_.Connect(); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + + // Z reads INIT, produces INIT_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // // A reads INIT_ACK, produces COOKIE_ECHO + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + // // Z reads COOKIE_ECHO, produces COOKIE_ACK + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // // A reads COOKIE_ACK. 
+ sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + EXPECT_EQ(sock_a_.state(), SocketState::kConnected); + EXPECT_EQ(sock_z_.state(), SocketState::kConnected); + + absl::optional msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST_F(DcSctpSocketTest, SendMessageAfterEstablished) { + ConnectSockets(); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + absl::optional msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST_F(DcSctpSocketTest, TimeoutResendsPacket) { + ConnectSockets(); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + cb_a_.ConsumeSentPacket(); + + RTC_LOG(LS_INFO) << "Advancing time"; + AdvanceTime(options_.rto_initial); + RunTimers(); + + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + absl::optional msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) { + ConnectSockets(); + + std::vector payload(kLargeMessageSize); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // First DATA + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // Second DATA (lost) + cb_a_.ConsumeSentPacket(); + + // Retransmit and handle the rest + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + absl::optional msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); +} + +TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) { + ConnectSockets(); + + // Inject a HEARTBEAT chunk + SctpPacket::Builder b(sock_a_.verification_tag(), DcSctpOptions()); + uint8_t info[] = {1, 2, 3, 4}; + Parameters::Builder params_builder; + params_builder.Add(HeartbeatInfoParameter(info)); + b.Add(HeartbeatRequestChunk(params_builder.Build())); + sock_a_.ReceivePacket(b.Build()); + + // HEARTBEAT_ACK is sent as a reply. Capture it. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket ack_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + ASSERT_THAT(ack_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatAckChunk ack, + HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info()); + EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4)); +} + +TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) { + ConnectSockets(); + + EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + + AdvanceTime(options_.heartbeat_interval); + RunTimers(); + + std::vector hb_packet_raw = cb_a_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(hb_packet_raw)); + ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk hb, + HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, hb.info()); + + // The info is a single 64-bit number. + EXPECT_THAT(hb.info()->info(), SizeIs(8)); + + // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back. 
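+ // The ack must echo the heartbeat info parameter unchanged (RFC 4960,
+ // section 8.3), which is how A matches the response to its request.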
+ sock_z_.ReceivePacket(hb_packet_raw); + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); +} + +TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) { + ConnectSockets(); + + EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty()); + // Force-close socket Z so that it doesn't interfere from now on. + sock_z_.Close(); + + DurationMs time_to_next_hearbeat = options_.heartbeat_interval; + + for (int i = 0; i < options_.max_retransmissions; ++i) { + RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; + AdvanceTime(time_to_next_hearbeat); + RunTimers(); + + // Dropping every heartbeat. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); + + RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; + AdvanceTime(DurationMs(1000)); + RunTimers(); + + time_to_next_hearbeat = options_.heartbeat_interval - DurationMs(1000); + } + + RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; + AdvanceTime(time_to_next_hearbeat); + RunTimers(); + + // Last heartbeat + EXPECT_THAT(cb_a_.ConsumeSentPacket(), Not(IsEmpty())); + + EXPECT_CALL(cb_a_, OnAborted).Times(1); + // Should suffice as exceeding RTO + AdvanceTime(DurationMs(1000)); + RunTimers(); +} + +TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) { + ConnectSockets(); + + EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty()); + // Force-close socket Z so that it doesn't interfere from now on. + sock_z_.Close(); + + DurationMs time_to_next_hearbeat = options_.heartbeat_interval; + + for (int i = 0; i < options_.max_retransmissions; ++i) { + AdvanceTime(time_to_next_hearbeat); + RunTimers(); + + // Dropping every heartbeat. + cb_a_.ConsumeSentPacket(); + + RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; + AdvanceTime(DurationMs(1000)); + RunTimers(); + + time_to_next_hearbeat = options_.heartbeat_interval - DurationMs(1000); + } + + RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it"; + AdvanceTime(time_to_next_hearbeat); + RunTimers(); + + std::vector hb_packet_raw = cb_a_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(hb_packet_raw)); + ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk hb, + HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data)); + + SctpPacket::Builder b(sock_a_.verification_tag(), options_); + b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters())); + sock_a_.ReceivePacket(b.Build()); + + // Should suffice as exceeding RTO - which will not fire. + EXPECT_CALL(cb_a_, OnAborted).Times(0); + AdvanceTime(DurationMs(1000)); + RunTimers(); + EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty()); + + // Verify that we get new heartbeats again. 
+ RTC_LOG(LS_INFO) << "Expecting a new heartbeat"; + AdvanceTime(time_to_next_hearbeat); + RunTimers(); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket another_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); +} + +TEST_F(DcSctpSocketTest, ResetStream) { + ConnectSockets(); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {}); + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + absl::optional msg = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + + // Handle SACK + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // Reset the outgoing stream. This will directly send a RE-CONFIG. + sock_a_.ResetStreams(std::vector({StreamID(1)})); + + // Receiving the packet will trigger a callback, indicating that A has + // reset its stream. It will also send a RE-CONFIG with a response. + EXPECT_CALL(cb_z_, OnIncomingStreamsReset).Times(1); + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + // Receiving a response will trigger a callback. Streams are now reset. + EXPECT_CALL(cb_a_, OnStreamsResetPerformed).Times(1); + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); +} + +TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) { + ConnectSockets(); + + std::vector payload(options_.mtu - 100); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(packet1); + + auto packet2 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1))); + sock_z_.ReceivePacket(packet2); + + // Handle SACK + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + + absl::optional msg2 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + + // Reset the outgoing stream. This will directly send a RE-CONFIG. 
+ sock_a_.ResetStreams(std::vector({StreamID(1)})); + // RE-CONFIG, req + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // RE-CONFIG, resp + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet3 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(packet3); + + auto packet4 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1))); + sock_z_.ReceivePacket(packet4); + + // Handle SACK + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); +} + +TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) { + ConnectSockets(); + + std::vector payload(options_.mtu - 100); + + // Send two ordered messages on SID 1 + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1))); + EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(packet1); + + auto packet2 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1))); + EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1))); + sock_z_.ReceivePacket(packet2); + + // Handle SACK + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // Do the same, for SID 3 + sock_a_.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + sock_a_.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + auto packet3 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet3, HasDataChunkWithStreamId(StreamID(3))); + EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(packet3); + auto packet4 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet4, HasDataChunkWithStreamId(StreamID(3))); + EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1))); + sock_z_.ReceivePacket(packet4); + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // Receive all messages. + absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + + absl::optional msg2 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + + absl::optional msg3 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg3.has_value()); + EXPECT_EQ(msg3->stream_id(), StreamID(3)); + + absl::optional msg4 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg4.has_value()); + EXPECT_EQ(msg4->stream_id(), StreamID(3)); + + // Reset SID 1. This will directly send a RE-CONFIG. + sock_a_.ResetStreams(std::vector({StreamID(3)})); + // RE-CONFIG, req + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + // RE-CONFIG, resp + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should. + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + sock_a_.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + + auto packet5 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet5, HasDataChunkWithStreamId(StreamID(1))); + EXPECT_THAT(packet5, HasDataChunkWithSsn(SSN(2))); // Unchanged. + sock_z_.ReceivePacket(packet5); + + auto packet6 = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet6, HasDataChunkWithStreamId(StreamID(3))); + EXPECT_THAT(packet6, HasDataChunkWithSsn(SSN(0))); // Reset. 
+  sock_z_.ReceivePacket(packet6);
+
+  // Handle SACK
+  sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket());
+}
+
+TEST_F(DcSctpSocketTest, OnePeerReconnects) {
+  ConnectSockets();
+
+  EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1);
+  // Let's be evil here - reconnect while a fragmented packet was about to be
+  // sent. The receiving side should get it in full.
+  std::vector<uint8_t> payload(kLargeMessageSize);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+  // First DATA
+  sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket());
+
+  // Create a new association, z2 - and don't use z anymore.
+  testing::NiceMock<MockDcSctpSocketCallbacks> cb_z2("Z2");
+  DcSctpSocket sock_z2("Z2", cb_z2, nullptr, options_);
+
+  sock_z2.Connect();
+
+  // Retransmit and handle the rest. As some chunks in-flight will have the
+  // wrong verification tag, those will yield errors.
+  ExchangeMessages(sock_a_, cb_a_, sock_z2, cb_z2);
+
+  absl::optional<DcSctpMessage> msg = cb_z2.ConsumeReceivedMessage();
+  ASSERT_TRUE(msg.has_value());
+  EXPECT_EQ(msg->stream_id(), StreamID(1));
+  EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+}
+
+TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) {
+  ConnectSockets();
+
+  SendOptions send_options;
+  send_options.max_retransmissions = 0;
+  std::vector<uint8_t> payload(options_.mtu - 100);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
+
+  // First DATA
+  sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket());
+  // Second DATA (lost)
+  cb_a_.ConsumeSentPacket();
+  // Third DATA
+  sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket());
+
+  // Handle SACK for first DATA
+  sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket());
+
+  // Handle delayed SACK for third DATA
+  AdvanceTime(options_.delayed_ack_max_timeout);
+  RunTimers();
+
+  // Handle SACK for second DATA
+  sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket());
+
+  // Now the missing data chunk will be marked as nacked, but it might still be
+  // in-flight and the reported gap could be due to out-of-order delivery. So
+  // the RetransmissionQueue will not mark it as "to be retransmitted" until
+  // after the t3-rtx timer has expired.
+  AdvanceTime(options_.rto_initial);
+  RunTimers();
+
+  // The chunk will be marked as retransmitted, and then as abandoned, which
+  // will trigger a FORWARD-TSN to be sent.
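The two tests above use SendOptions to express partial reliability. A short sketch of the two knobs involved; the helper names and the concrete values are illustrative only:

// Sketch only: cap the number of retransmissions (maxRetransmits-style).
SendOptions MakeLossyOptions() {
  SendOptions options;
  options.unordered = IsUnordered(true);
  options.max_retransmissions = 0;  // Abandon a lost chunk instead of resending.
  return options;
}

// Sketch only: cap how long a message may stay queued (maxPacketLifeTime-style).
SendOptions MakeExpiringOptions() {
  SendOptions options;
  options.lifetime = DurationMs(100);  // Abandon if not sent within 100 ms.
  return options;
}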
+ + // FORWARD-TSN (third) + sock_z_.ReceivePacket(cb_a_.ConsumeSentPacket()); + + // Which will trigger a SACK + sock_a_.ReceivePacket(cb_z_.ConsumeSentPacket()); + + absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->ppid(), PPID(51)); + + absl::optional msg2 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->ppid(), PPID(53)); + + absl::optional msg3 = cb_z_.ConsumeReceivedMessage(); + EXPECT_FALSE(msg3.has_value()); +} + +TEST_F(DcSctpSocketTest, SendManyFragmentedMessagesWithLimitedRtx) { + ConnectSockets(); + + SendOptions send_options; + send_options.unordered = IsUnordered(true); + send_options.max_retransmissions = 0; + std::vector payload(options_.mtu * 2 - 100 /* margin */); + // Sending first message + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + // Sending second message + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + // Sending third message + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + // Sending fourth message + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options); + + // First DATA, first fragment + std::vector packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51))); + sock_z_.ReceivePacket(std::move(packet)); + + // First DATA, second fragment (lost) + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51))); + + // Second DATA, first fragment + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52))); + sock_z_.ReceivePacket(std::move(packet)); + + // Second DATA, second fragment (lost) + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + + // Third DATA, first fragment + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(std::move(packet)); + + // Third DATA, second fragment (lost) + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + + // Fourth DATA, first fragment + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(std::move(packet)); + + // Fourth DATA, second fragment + packet = cb_a_.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + sock_z_.ReceivePacket(std::move(packet)); + + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs + AdvanceTime(options_.rto_initial); + RunTimers(); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + absl::optional msg1 = cb_z_.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->ppid(), PPID(54)); + + ASSERT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); +} + +struct FakeChunkConfig : ChunkConfig { + static constexpr int kType = 0x49; + static constexpr size_t kHeaderSize = 4; + static constexpr int kVariableLengthAlignment = 0; +}; + +class FakeChunk : public Chunk, public TLVTrait { + public: + FakeChunk() {} + + FakeChunk(FakeChunk&& other) = default; + FakeChunk& operator=(FakeChunk&& other) = default; + + void SerializeTo(std::vector& out) const 
override { + AllocateTLV(out); + } + std::string ToString() const override { return "FAKE"; } +}; + +TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) { + ConnectSockets(); + + // Inject a FAKE chunk + SctpPacket::Builder b(sock_a_.verification_tag(), DcSctpOptions()); + b.Add(FakeChunk()); + sock_a_.ReceivePacket(b.Build()); + + // ERROR is sent as a reply. Capture it. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply_packet, + SctpPacket::Parse(cb_a_.ConsumeSentPacket())); + ASSERT_THAT(reply_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + error.error_causes().get()); + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); +} + +TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) { + ConnectSockets(); + + // Inject a ERROR chunk + SctpPacket::Builder b(sock_a_.verification_tag(), DcSctpOptions()); + b.Add( + ErrorChunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04})) + .Build())); + + EXPECT_CALL(cb_a_, OnError(ErrorKind::kPeerReported, + HasSubstr("Unrecognized Chunk Type"))); + sock_a_.ReceivePacket(b.Build()); +} + +TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { + // Create a new association, z2 - and don't use z anymore. + testing::NiceMock cb_z2("Z2"); + DcSctpOptions options = options_; + options.max_receiver_window_buffer_size = 100; + DcSctpSocket sock_z2("Z2", cb_z2, nullptr, options); + + EXPECT_CALL(cb_z2, OnClosed).Times(0); + EXPECT_CALL(cb_z2, OnAborted).Times(0); + + sock_a_.Connect(); + std::vector init_data = cb_a_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(init_data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk, + InitChunk::Parse(init_packet.descriptors()[0].data)); + sock_z2.ReceivePacket(init_data); + sock_a_.ReceivePacket(cb_z2.ConsumeSentPacket()); + sock_z2.ReceivePacket(cb_a_.ConsumeSentPacket()); + sock_a_.ReceivePacket(cb_z2.ConsumeSentPacket()); + + // Fill up Z2 to the high watermark limit. + TSN tsn = init_chunk.initial_tsn(); + AnyDataChunk::Options opts; + opts.is_beginning = Data::IsBeginning(true); + sock_z2.ReceivePacket( + SctpPacket::Builder(sock_z2.verification_tag(), options) + .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53), + std::vector( + 100 * ReassemblyQueue::kHighWatermarkLimit + 1), + opts)) + .Build()); + + // First DATA will always trigger a SACK. It's not interesting. + EXPECT_THAT(cb_z2.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(tsn), HasSackWithNoGapAckBlocks())); + + // This DATA should be accepted - it's advancing cum ack tsn. + sock_z2.ReceivePacket(SctpPacket::Builder(sock_z2.verification_tag(), options) + .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), + PPID(53), std::vector(1), + /*options=*/{})) + .Build()); + + // The receiver might have moved into delayed ack mode. + cb_z2.AdvanceTime(options.rto_initial); + RunTimers(cb_z2, sock_z2); + + EXPECT_THAT( + cb_z2.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); + + // This DATA will not be accepted - it's not advancing cum ack tsn. + sock_z2.ReceivePacket(SctpPacket::Builder(sock_z2.verification_tag(), options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), + PPID(53), std::vector(1), + /*options=*/{})) + .Build()); + + // Sack will be sent in IMMEDIATE mode when this is happening. 
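The watermark behaviour exercised in this test is controlled by a single option on the receiving socket; overrunning it makes the socket abort with ErrorKind::kResourceExhaustion, as the rest of the test shows. A small configuration sketch with an illustrative, not recommended, value:

// Sketch only: bound how much undelivered or out-of-order data the receiver
// may buffer before the association is torn down.
DcSctpOptions MakeReceiveLimitedOptions() {
  DcSctpOptions options;
  options.max_receiver_window_buffer_size = 256 * 1024;
  return options;
}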
+ EXPECT_THAT( + cb_z2.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); + + // This DATA will not be accepted either. + sock_z2.ReceivePacket(SctpPacket::Builder(sock_z2.verification_tag(), options) + .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), + PPID(53), std::vector(1), + /*options=*/{})) + .Build()); + + // Sack will be sent in IMMEDIATE mode when this is happening. + EXPECT_THAT( + cb_z2.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); + + // This DATA should be accepted, and it fills the reassembly queue. + sock_z2.ReceivePacket( + SctpPacket::Builder(sock_z2.verification_tag(), options) + .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53), + std::vector(kSmallMessageSize), + /*options=*/{})) + .Build()); + + // The receiver might have moved into delayed ack mode. + cb_z2.AdvanceTime(options.rto_initial); + RunTimers(cb_z2, sock_z2); + + EXPECT_THAT( + cb_z2.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 2)), HasSackWithNoGapAckBlocks())); + + EXPECT_CALL(cb_z2, OnAborted(ErrorKind::kResourceExhaustion, _)); + EXPECT_CALL(cb_z2, OnClosed).Times(0); + + // This DATA will make the connection close. It's too full now. + sock_z2.ReceivePacket( + SctpPacket::Builder(sock_z2.verification_tag(), options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), + std::vector(kSmallMessageSize), + /*options=*/{})) + .Build()); +} + +TEST_F(DcSctpSocketTest, SetMaxMessageSize) { + sock_a_.SetMaxMessageSize(42u); + EXPECT_EQ(sock_a_.options().max_message_size, 42u); +} + +TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) { + ConnectSockets(); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(cb_a_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(cb_z_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Queue a few small messages with low lifetime, both ordered and unordered, + // and validate that all are delivered. + static constexpr int kIterations = 100; + for (int i = 0; i < kIterations; ++i) { + SendOptions send_options; + send_options.unordered = IsUnordered((i % 2) == 0); + send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + } + + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + for (int i = 0; i < kIterations; ++i) { + EXPECT_TRUE(cb_z_.ConsumeReceivedMessage().has_value()); + } + + EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value()); + + // Validate that the sockets really make the time move forward. + EXPECT_GE(*now, kIterations * 2); +} + +TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) { + ConnectSockets(); + + SendOptions lifetime_0; + lifetime_0.unordered = IsUnordered(true); + lifetime_0.lifetime = DurationMs(0); + + SendOptions lifetime_1; + lifetime_1.unordered = IsUnordered(true); + lifetime_1.lifetime = DurationMs(1); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(cb_a_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(cb_z_, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Fill up the send buffer with a large message. 
+  std::vector<uint8_t> payload(kLargeMessageSize);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+  // And queue a few small messages with lifetime=0 or 1 ms - can't be sent.
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1);
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0);
+
+  // Handle everything that was sent, until the congestion window got full.
+  for (;;) {
+    std::vector<uint8_t> packet_from_a = cb_a_.ConsumeSentPacket();
+    if (packet_from_a.empty()) {
+      break;
+    }
+    sock_z_.ReceivePacket(std::move(packet_from_a));
+  }
+
+  // Shouldn't be enough to send that large message.
+  EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value());
+
+  // Exchange the rest of the messages, with the time ever increasing.
+  ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_);
+
+  // The large message should be delivered. It was sent reliably.
+  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, cb_z_.ConsumeReceivedMessage());
+  EXPECT_EQ(m1.stream_id(), StreamID(1));
+  EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize));
+
+  // But none of the smaller messages.
+  EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value());
+}
+
+TEST_F(DcSctpSocketTest, HasReasonableBufferedAmountValues) {
+  ConnectSockets();
+
+  EXPECT_EQ(sock_a_.buffered_amount(StreamID(1)), 0u);
+
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53),
+                             std::vector<uint8_t>(kSmallMessageSize)),
+               kSendOptions);
+  // Sending a small message will directly send it as a single packet, so
+  // nothing is left in the queue.
+  EXPECT_EQ(sock_a_.buffered_amount(StreamID(1)), 0u);
+
+  sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53),
+                             std::vector<uint8_t>(kLargeMessageSize)),
+               kSendOptions);
+
+  // Sending a message will directly start sending a few packets, so the
+  // buffered amount is not the full message size.
+ EXPECT_GT(sock_a_.buffered_amount(StreamID(1)), 0u); + EXPECT_LT(sock_a_.buffered_amount(StreamID(1)), kLargeMessageSize); +} + +TEST_F(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) { + EXPECT_EQ(sock_a_.buffered_amount_low_threshold(StreamID(1)), 0u); +} + +TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowWithDefaultValueZero) { + EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); + ConnectSockets(); + + EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))); + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kSmallMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); +} + +TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10; + + sock_a_.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); + ConnectSockets(); + + EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(0); + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); +} + +TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2; + + sock_a_.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); + ConnectSockets(); + + EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(3); + EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(2))).Times(2); + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + sock_a_.Send( + DcSctpMessage(StreamID(2), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + sock_a_.Send( + DcSctpMessage(StreamID(2), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); + + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), + kSendOptions); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); +} + +TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5; + + sock_a_.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); + ConnectSockets(); + + EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0); + + // Add a few messages to fill up the congestion window. When that is full, + // messages will start to be fully buffered. 
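The buffered-amount tests above and below mirror the pattern a real sender would use to pace itself. A hedged sketch of that pattern; kMaxBufferedBytes, SetupFlowControl and CanSendMore are illustrative names, not part of this change:

// Sketch only: stop handing messages to Send() once the per-stream buffered
// amount passes a local cap, and resume from OnBufferedAmountLow(), which
// fires when the amount drops below the threshold configured here.
constexpr size_t kMaxBufferedBytes = 64 * 1024;

void SetupFlowControl(DcSctpSocket& socket, StreamID stream_id) {
  socket.SetBufferedAmountLowThreshold(stream_id, kMaxBufferedBytes / 2);
}

bool CanSendMore(DcSctpSocket& socket, StreamID stream_id) {
  return socket.buffered_amount(stream_id) < kMaxBufferedBytes;
}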
+ while (sock_a_.buffered_amount(StreamID(1)) == 0) { + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kMessageSize)), + kSendOptions); + } + size_t initial_buffered = sock_a_.buffered_amount(StreamID(1)); + ASSERT_GE(initial_buffered, 0u); + ASSERT_LT(initial_buffered, kMessageSize); + + // Up to kMessageSize (which is below the threshold) + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), + std::vector(kMessageSize - initial_buffered)), + kSendOptions); + EXPECT_EQ(sock_a_.buffered_amount(StreamID(1)), kMessageSize); + + // Up to 2*kMessageSize (which is above the threshold) + sock_a_.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector(kMessageSize)), + kSendOptions); + EXPECT_EQ(sock_a_.buffered_amount(StreamID(1)), 2 * kMessageSize); + + // Start ACKing packets, which will empty the send queue, and trigger the + // callback. + EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(1); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); +} + +TEST_F(DcSctpSocketTest, DoesntTriggerOnTotalBufferAmountLowWhenBelow) { + ConnectSockets(); + + EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0); + + sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kLargeMessageSize)), + kSendOptions); + + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); +} + +TEST_F(DcSctpSocketTest, TriggersOnTotalBufferAmountLowWhenCrossingThreshold) { + ConnectSockets(); + + EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0); + + // Fill up the send queue completely. + for (;;) { + if (sock_a_.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kLargeMessageSize)), + kSendOptions) == SendStatus::kErrorResourceExhaustion) { + break; + } + } + + EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(1); + ExchangeMessages(sock_a_, cb_a_, sock_z_, cb_z_); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/socket/heartbeat_handler.cc b/net/dcsctp/socket/heartbeat_handler.cc new file mode 100644 index 0000000000..78616d1033 --- /dev/null +++ b/net/dcsctp/socket/heartbeat_handler.cc @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/heartbeat_handler.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// This is stored (in serialized form) as HeartbeatInfoParameter sent in +// HeartbeatRequestChunk and received back in HeartbeatAckChunk. It should be +// well understood that this data may be modified by the peer, so it can't +// be trusted. 
+// +// It currently only stores a timestamp, in millisecond precision, to allow for +// RTT measurements. If that would be manipulated by the peer, it would just +// result in incorrect RTT measurements, which isn't an issue. +class HeartbeatInfo { + public: + static constexpr size_t kBufferSize = sizeof(uint64_t); + static_assert(kBufferSize == 8, "Unexpected buffer size"); + + explicit HeartbeatInfo(TimeMs created_at) : created_at_(created_at) {} + + std::vector Serialize() { + uint32_t high_bits = static_cast(*created_at_ >> 32); + uint32_t low_bits = static_cast(*created_at_); + + std::vector data(kBufferSize); + BoundedByteWriter writer(data); + writer.Store32<0>(high_bits); + writer.Store32<4>(low_bits); + return data; + } + + static absl::optional Deserialize( + rtc::ArrayView data) { + if (data.size() != kBufferSize) { + RTC_LOG(LS_WARNING) << "Invalid heartbeat info: " << data.size() + << " bytes"; + return absl::nullopt; + } + + BoundedByteReader reader(data); + uint32_t high_bits = reader.Load32<0>(); + uint32_t low_bits = reader.Load32<4>(); + + uint64_t created_at = static_cast(high_bits) << 32 | low_bits; + return HeartbeatInfo(TimeMs(created_at)); + } + + TimeMs created_at() const { return created_at_; } + + private: + const TimeMs created_at_; +}; + +HeartbeatHandler::HeartbeatHandler(absl::string_view log_prefix, + const DcSctpOptions& options, + Context* context, + TimerManager* timer_manager) + : log_prefix_(std::string(log_prefix) + "heartbeat: "), + ctx_(context), + timer_manager_(timer_manager), + interval_duration_(options.heartbeat_interval), + interval_duration_should_include_rtt_( + options.heartbeat_interval_include_rtt), + interval_timer_(timer_manager_->CreateTimer( + "heartbeat-interval", + [this]() { return OnIntervalTimerExpiry(); }, + TimerOptions(interval_duration_, TimerBackoffAlgorithm::kFixed))), + timeout_timer_(timer_manager_->CreateTimer( + "heartbeat-timeout", + [this]() { return OnTimeoutTimerExpiry(); }, + TimerOptions(options.rto_initial, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0))) { + // The interval timer must always be running as long as the association is up. + RestartTimer(); +} + +void HeartbeatHandler::RestartTimer() { + if (interval_duration_ == DurationMs(0)) { + // Heartbeating has been disabled. + return; + } + + if (interval_duration_should_include_rtt_) { + // The RTT should be used, but it's not easy accessible. The RTO will + // suffice. + interval_timer_->set_duration(interval_duration_ + ctx_->current_rto()); + } else { + interval_timer_->set_duration(interval_duration_); + } + + interval_timer_->Start(); +} + +void HeartbeatHandler::HandleHeartbeatRequest(HeartbeatRequestChunk chunk) { + // https://tools.ietf.org/html/rfc4960#section-8.3 + // "The receiver of the HEARTBEAT should immediately respond with a + // HEARTBEAT ACK that contains the Heartbeat Information TLV, together with + // any other received TLVs, copied unchanged from the received HEARTBEAT + // chunk." 
+ ctx_->Send(ctx_->PacketBuilder().Add( + HeartbeatAckChunk(std::move(chunk).extract_parameters()))); +} + +void HeartbeatHandler::HandleHeartbeatAck(HeartbeatAckChunk chunk) { + timeout_timer_->Stop(); + absl::optional info_param = chunk.info(); + if (!info_param.has_value()) { + ctx_->callbacks().OnError( + ErrorKind::kParseFailed, + "Failed to parse HEARTBEAT-ACK; No Heartbeat Info parameter"); + return; + } + absl::optional info = + HeartbeatInfo::Deserialize(info_param->info()); + if (!info.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse HEARTBEAT-ACK; Failed to " + "deserialized Heartbeat info parameter"); + return; + } + + DurationMs duration(*ctx_->callbacks().TimeMillis() - *info->created_at()); + + ctx_->ObserveRTT(duration); + + // https://tools.ietf.org/html/rfc4960#section-8.1 + // "The counter shall be reset each time ... a HEARTBEAT ACK is received from + // the peer endpoint." + ctx_->ClearTxErrorCounter(); +} + +absl::optional HeartbeatHandler::OnIntervalTimerExpiry() { + if (ctx_->is_connection_established()) { + HeartbeatInfo info(ctx_->callbacks().TimeMillis()); + timeout_timer_->set_duration(ctx_->current_rto()); + timeout_timer_->Start(); + RTC_DLOG(LS_INFO) << log_prefix_ << "Sending HEARTBEAT with timeout " + << *timeout_timer_->duration(); + + Parameters parameters = Parameters::Builder() + .Add(HeartbeatInfoParameter(info.Serialize())) + .Build(); + + ctx_->Send(ctx_->PacketBuilder().Add( + HeartbeatRequestChunk(std::move(parameters)))); + } else { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Will not send HEARTBEAT when connection not established"; + } + return absl::nullopt; +} + +absl::optional HeartbeatHandler::OnTimeoutTimerExpiry() { + // Note that the timeout timer is not restarted. It will be started again when + // the interval timer expires. + RTC_DCHECK(!timeout_timer_->is_running()); + ctx_->IncrementTxErrorCounter("HEARTBEAT timeout"); + return absl::nullopt; +} +} // namespace dcsctp diff --git a/net/dcsctp/socket/heartbeat_handler.h b/net/dcsctp/socket/heartbeat_handler.h new file mode 100644 index 0000000000..14c3109534 --- /dev/null +++ b/net/dcsctp/socket/heartbeat_handler.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ +#define NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" + +namespace dcsctp { + +// HeartbeatHandler handles all logic around sending heartbeats and receiving +// the responses, as well as receiving incoming heartbeat requests. +// +// Heartbeats are sent on idle connections to ensure that the connection is +// still healthy and to measure the RTT. If a number of heartbeats time out, +// the connection will eventually be closed. 
+class HeartbeatHandler { + public: + HeartbeatHandler(absl::string_view log_prefix, + const DcSctpOptions& options, + Context* context, + TimerManager* timer_manager); + + // Called when the heartbeat interval timer should be restarted. This is + // generally done every time data is sent, which makes the timer expire when + // the connection is idle. + void RestartTimer(); + + // Called on received HeartbeatRequestChunk chunks. + void HandleHeartbeatRequest(HeartbeatRequestChunk chunk); + + // Called on received HeartbeatRequestChunk chunks. + void HandleHeartbeatAck(HeartbeatAckChunk chunk); + + private: + absl::optional OnIntervalTimerExpiry(); + absl::optional OnTimeoutTimerExpiry(); + + const std::string log_prefix_; + Context* ctx_; + TimerManager* timer_manager_; + // The time for a connection to be idle before a heartbeat is sent. + const DurationMs interval_duration_; + // Adding RTT to the duration will add some jitter, which is good in + // production, but less good in unit tests, which is why it can be disabled. + const bool interval_duration_should_include_rtt_; + const std::unique_ptr interval_timer_; + const std::unique_ptr timeout_timer_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ diff --git a/net/dcsctp/socket/heartbeat_handler_test.cc b/net/dcsctp/socket/heartbeat_handler_test.cc new file mode 100644 index 0000000000..2c5df9fd92 --- /dev/null +++ b/net/dcsctp/socket/heartbeat_handler_test.cc @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/socket/heartbeat_handler.h" + +#include +#include +#include + +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SizeIs; + +constexpr DurationMs kHeartbeatInterval = DurationMs(30'000); + +DcSctpOptions MakeOptions(DurationMs heartbeat_interval) { + DcSctpOptions options; + options.heartbeat_interval_include_rtt = false; + options.heartbeat_interval = heartbeat_interval; + return options; +} + +class HeartbeatHandlerTestBase : public testing::Test { + protected: + explicit HeartbeatHandlerTestBase(DurationMs heartbeat_interval) + : options_(MakeOptions(heartbeat_interval)), + context_(&callbacks_), + timer_manager_([this]() { return callbacks_.CreateTimeout(); }), + handler_("log: ", options_, &context_, &timer_manager_) {} + + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(duration); + for (;;) { + absl::optional timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); + } + } + + const DcSctpOptions options_; + NiceMock callbacks_; + NiceMock context_; + TimerManager timer_manager_; + HeartbeatHandler handler_; +}; + +class HeartbeatHandlerTest : public HeartbeatHandlerTestBase { + protected: + HeartbeatHandlerTest() : HeartbeatHandlerTestBase(kHeartbeatInterval) {} +}; + +class DisabledHeartbeatHandlerTest : public HeartbeatHandlerTestBase { + protected: + DisabledHeartbeatHandlerTest() : HeartbeatHandlerTestBase(DurationMs(0)) {} +}; + +TEST_F(HeartbeatHandlerTest, HasRunningHeartbeatIntervalTimer) { + AdvanceTime(options_.heartbeat_interval); + + // Validate that a heartbeat request was sent. + std::vector payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk request, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + EXPECT_TRUE(request.info().has_value()); +} + +TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) { + uint8_t info_data[] = {1, 2, 3, 4, 5}; + HeartbeatRequestChunk request( + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build()); + + handler_.HandleHeartbeatRequest(std::move(request)); + + std::vector payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatAckChunk response, + HeartbeatAckChunk::Parse(packet.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatInfoParameter param, + response.parameters().get()); + + EXPECT_THAT(param.info(), ElementsAre(1, 2, 3, 4, 5)); +} + +TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) { + AdvanceTime(options_.heartbeat_interval); + + // Grab the request, and make a response. 
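For reference, the RTT that the handler later reports through ObserveRTT is simply the difference between the fake clock when the HEARTBEAT ACK is handled and the timestamp that was serialized into the heartbeat info when the request went out (see HandleHeartbeatAck above). With the numbers used in the next test, that works out as:

// Request sent when the fake clock reads 30'000 ms (one heartbeat interval);
// the ack is handled 313 ms later.
TimeMs sent_at(30'000);
TimeMs acked_at(30'313);
DurationMs rtt(*acked_at - *sent_at);  // rtt == DurationMs(313)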
+ std::vector payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk req, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + HeartbeatAckChunk ack(std::move(req).extract_parameters()); + + // Respond a while later. This RTT will be measured by the handler + constexpr DurationMs rtt(313); + + EXPECT_CALL(context_, ObserveRTT(rtt)).Times(1); + + callbacks_.AdvanceTime(rtt); + handler_.HandleHeartbeatAck(std::move(ack)); +} + +TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) { + DurationMs rto(105); + EXPECT_CALL(context_, current_rto).WillOnce(Return(rto)); + AdvanceTime(options_.heartbeat_interval); + + // Validate that a request was sent. + EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty())); + + EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1); + AdvanceTime(rto); +} + +TEST_F(DisabledHeartbeatHandlerTest, IsReallyDisabled) { + AdvanceTime(options_.heartbeat_interval); + + // Validate that a request was NOT sent. + EXPECT_THAT(callbacks_.ConsumeSentPacket(), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/socket/mock_context.h b/net/dcsctp/socket/mock_context.h new file mode 100644 index 0000000000..d86b99a20d --- /dev/null +++ b/net/dcsctp/socket/mock_context.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ +#define NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockContext : public Context { + public: + static constexpr TSN MyInitialTsn() { return TSN(990); } + static constexpr TSN PeerInitialTsn() { return TSN(10); } + static constexpr VerificationTag PeerVerificationTag() { + return VerificationTag(0x01234567); + } + + explicit MockContext(MockDcSctpSocketCallbacks* callbacks) + : callbacks_(*callbacks) { + ON_CALL(*this, is_connection_established) + .WillByDefault(testing::Return(true)); + ON_CALL(*this, my_initial_tsn) + .WillByDefault(testing::Return(MyInitialTsn())); + ON_CALL(*this, peer_initial_tsn) + .WillByDefault(testing::Return(PeerInitialTsn())); + ON_CALL(*this, callbacks).WillByDefault(testing::ReturnRef(callbacks_)); + ON_CALL(*this, current_rto).WillByDefault(testing::Return(DurationMs(123))); + ON_CALL(*this, Send).WillByDefault([this](SctpPacket::Builder& builder) { + callbacks_.SendPacket(builder.Build()); + }); + } + + MOCK_METHOD(bool, is_connection_established, (), (const, override)); + MOCK_METHOD(TSN, my_initial_tsn, (), (const, override)); + MOCK_METHOD(TSN, peer_initial_tsn, (), (const, override)); + MOCK_METHOD(DcSctpSocketCallbacks&, callbacks, (), (const, override)); + + MOCK_METHOD(void, ObserveRTT, (DurationMs rtt_ms), (override)); + MOCK_METHOD(DurationMs, current_rto, (), (const, override)); + MOCK_METHOD(bool, + IncrementTxErrorCounter, + (absl::string_view reason), + (override)); + MOCK_METHOD(void, ClearTxErrorCounter, (), (override)); + MOCK_METHOD(bool, HasTooManyTxErrors, (), (const, override)); + SctpPacket::Builder PacketBuilder() const override { + return SctpPacket::Builder(PeerVerificationTag(), options_); + } + MOCK_METHOD(void, Send, (SctpPacket::Builder & builder), (override)); + + DcSctpOptions options_; + MockDcSctpSocketCallbacks& callbacks_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h new file mode 100644 index 0000000000..bcf1bde5b8 --- /dev/null +++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+#ifndef NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_
+#define NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_
+
+#include <cstdint>
+#include <deque>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/timeout.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/timer/fake_timeout.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/random.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+
+namespace internal {
+// It can be argued whether a mocked random number generator should be
+// deterministic or whether it should behave like a "real" random number
+// generator. In this implementation, each instantiation of
+// `MockDcSctpSocketCallbacks` will have its `GetRandomInt` return a different
+// sequence, but a given instantiation will always generate the same sequence
+// of random numbers. This is to make it easier to compare logs from tests,
+// while still letting e.g. two different sockets (used in the same test) get
+// different random numbers, so that they don't start e.g. on the same sequence
+// number. While that isn't an issue in the protocol, it just makes debugging
+// harder as the two sockets would look exactly the same.
+//
+// In a real implementation of `DcSctpSocketCallbacks` the random number
+// generator backing `GetRandomInt` should be seeded externally and correctly.
+inline int GetUniqueSeed() {
+  static int seed = 0;
+  return ++seed;
+}
+}  // namespace internal
+
+class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks {
+ public:
+  explicit MockDcSctpSocketCallbacks(absl::string_view name = "")
+      : log_prefix_(name.empty() ?
"" : std::string(name) + ": "), + random_(internal::GetUniqueSeed()), + timeout_manager_([this]() { return now_; }) { + ON_CALL(*this, SendPacket) + .WillByDefault([this](rtc::ArrayView data) { + sent_packets_.emplace_back( + std::vector(data.begin(), data.end())); + }); + ON_CALL(*this, OnMessageReceived) + .WillByDefault([this](DcSctpMessage message) { + received_messages_.emplace_back(std::move(message)); + }); + + ON_CALL(*this, OnError) + .WillByDefault([this](ErrorKind error, absl::string_view message) { + RTC_LOG(LS_WARNING) + << log_prefix_ << "Socket error: " << ToString(error) << "; " + << message; + }); + ON_CALL(*this, OnAborted) + .WillByDefault([this](ErrorKind error, absl::string_view message) { + RTC_LOG(LS_WARNING) + << log_prefix_ << "Socket abort: " << ToString(error) << "; " + << message; + }); + ON_CALL(*this, TimeMillis).WillByDefault([this]() { return now_; }); + } + MOCK_METHOD(void, + SendPacket, + (rtc::ArrayView data), + (override)); + + std::unique_ptr CreateTimeout() override { + return timeout_manager_.CreateTimeout(); + } + + MOCK_METHOD(TimeMs, TimeMillis, (), (override)); + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return random_.Rand(low, high); + } + + MOCK_METHOD(void, OnMessageReceived, (DcSctpMessage message), (override)); + MOCK_METHOD(void, + OnError, + (ErrorKind error, absl::string_view message), + (override)); + MOCK_METHOD(void, + OnAborted, + (ErrorKind error, absl::string_view message), + (override)); + MOCK_METHOD(void, OnConnected, (), (override)); + MOCK_METHOD(void, OnClosed, (), (override)); + MOCK_METHOD(void, OnConnectionRestarted, (), (override)); + MOCK_METHOD(void, + OnStreamsResetFailed, + (rtc::ArrayView outgoing_streams, + absl::string_view reason), + (override)); + MOCK_METHOD(void, + OnStreamsResetPerformed, + (rtc::ArrayView outgoing_streams), + (override)); + MOCK_METHOD(void, + OnIncomingStreamsReset, + (rtc::ArrayView incoming_streams), + (override)); + MOCK_METHOD(void, OnBufferedAmountLow, (StreamID stream_id), (override)); + MOCK_METHOD(void, OnTotalBufferedAmountLow, (), (override)); + + bool HasPacket() const { return !sent_packets_.empty(); } + + std::vector ConsumeSentPacket() { + if (sent_packets_.empty()) { + return {}; + } + std::vector ret = std::move(sent_packets_.front()); + sent_packets_.pop_front(); + return ret; + } + absl::optional ConsumeReceivedMessage() { + if (received_messages_.empty()) { + return absl::nullopt; + } + DcSctpMessage ret = std::move(received_messages_.front()); + received_messages_.pop_front(); + return ret; + } + + void AdvanceTime(DurationMs duration_ms) { now_ = now_ + duration_ms; } + void SetTime(TimeMs now) { now_ = now; } + + absl::optional GetNextExpiredTimeout() { + return timeout_manager_.GetNextExpiredTimeout(); + } + + private: + const std::string log_prefix_; + TimeMs now_ = TimeMs(0); + webrtc::Random random_; + FakeTimeoutManager timeout_manager_; + std::deque> sent_packets_; + std::deque received_messages_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ diff --git a/net/dcsctp/socket/state_cookie.cc b/net/dcsctp/socket/state_cookie.cc new file mode 100644 index 0000000000..7d04cbb0d7 --- /dev/null +++ b/net/dcsctp/socket/state_cookie.cc @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. 
An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/state_cookie.h" + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/socket/capabilities.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// Magic values, which the state cookie is prefixed with. +constexpr uint32_t kMagic1 = 1684230979; +constexpr uint32_t kMagic2 = 1414541360; +constexpr size_t StateCookie::kCookieSize; + +std::vector StateCookie::Serialize() { + std::vector cookie; + cookie.resize(kCookieSize); + BoundedByteWriter buffer(cookie); + buffer.Store32<0>(kMagic1); + buffer.Store32<4>(kMagic2); + buffer.Store32<8>(*initiate_tag_); + buffer.Store32<12>(*initial_tsn_); + buffer.Store32<16>(a_rwnd_); + buffer.Store32<20>(static_cast(*tie_tag_ >> 32)); + buffer.Store32<24>(static_cast(*tie_tag_)); + buffer.Store8<28>(capabilities_.partial_reliability); + buffer.Store8<29>(capabilities_.message_interleaving); + buffer.Store8<30>(capabilities_.reconfig); + return cookie; +} + +absl::optional StateCookie::Deserialize( + rtc::ArrayView cookie) { + if (cookie.size() != kCookieSize) { + RTC_DLOG(LS_WARNING) << "Invalid state cookie: " << cookie.size() + << " bytes"; + return absl::nullopt; + } + + BoundedByteReader buffer(cookie); + uint32_t magic1 = buffer.Load32<0>(); + uint32_t magic2 = buffer.Load32<4>(); + if (magic1 != kMagic1 || magic2 != kMagic2) { + RTC_DLOG(LS_WARNING) << "Invalid state cookie; wrong magic"; + return absl::nullopt; + } + + VerificationTag verification_tag(buffer.Load32<8>()); + TSN initial_tsn(buffer.Load32<12>()); + uint32_t a_rwnd = buffer.Load32<16>(); + uint32_t tie_tag_upper = buffer.Load32<20>(); + uint32_t tie_tag_lower = buffer.Load32<24>(); + TieTag tie_tag(static_cast(tie_tag_upper) << 32 | + static_cast(tie_tag_lower)); + Capabilities capabilities; + capabilities.partial_reliability = buffer.Load8<28>() != 0; + capabilities.message_interleaving = buffer.Load8<29>() != 0; + capabilities.reconfig = buffer.Load8<30>() != 0; + + return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag, + capabilities); +} + +} // namespace dcsctp diff --git a/net/dcsctp/socket/state_cookie.h b/net/dcsctp/socket/state_cookie.h new file mode 100644 index 0000000000..df4b801397 --- /dev/null +++ b/net/dcsctp/socket/state_cookie.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_STATE_COOKIE_H_ +#define NET_DCSCTP_SOCKET_STATE_COOKIE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/socket/capabilities.h" + +namespace dcsctp { + +// This is serialized as a state cookie and put in INIT_ACK. The client then +// responds with this in COOKIE_ECHO. +// +// NOTE: Expect that the client will modify it to try to exploit the library. 
+// Do not trust anything in it; no pointers or anything like that. +class StateCookie { + public: + static constexpr size_t kCookieSize = 31; + + StateCookie(VerificationTag initiate_tag, + TSN initial_tsn, + uint32_t a_rwnd, + TieTag tie_tag, + Capabilities capabilities) + : initiate_tag_(initiate_tag), + initial_tsn_(initial_tsn), + a_rwnd_(a_rwnd), + tie_tag_(tie_tag), + capabilities_(capabilities) {} + + // Returns a serialized version of this cookie. + std::vector Serialize(); + + // Deserializes the cookie, and returns absl::nullopt if that failed. + static absl::optional Deserialize( + rtc::ArrayView cookie); + + VerificationTag initiate_tag() const { return initiate_tag_; } + TSN initial_tsn() const { return initial_tsn_; } + uint32_t a_rwnd() const { return a_rwnd_; } + TieTag tie_tag() const { return tie_tag_; } + const Capabilities& capabilities() const { return capabilities_; } + + private: + const VerificationTag initiate_tag_; + const TSN initial_tsn_; + const uint32_t a_rwnd_; + const TieTag tie_tag_; + const Capabilities capabilities_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_STATE_COOKIE_H_ diff --git a/net/dcsctp/socket/state_cookie_test.cc b/net/dcsctp/socket/state_cookie_test.cc new file mode 100644 index 0000000000..eab41a7a56 --- /dev/null +++ b/net/dcsctp/socket/state_cookie_test.cc @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/state_cookie.h" + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(StateCookieTest, SerializeAndDeserialize) { + Capabilities capabilities = {/*partial_reliability=*/true, + /*message_interleaving=*/false, + /*reconfig=*/true}; + StateCookie cookie(VerificationTag(123), TSN(456), + /*a_rwnd=*/789, TieTag(101112), capabilities); + std::vector serialized = cookie.Serialize(); + EXPECT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); + ASSERT_HAS_VALUE_AND_ASSIGN(StateCookie deserialized, + StateCookie::Deserialize(serialized)); + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123)); + EXPECT_EQ(deserialized.initial_tsn(), TSN(456)); + EXPECT_EQ(deserialized.a_rwnd(), 789u); + EXPECT_EQ(deserialized.tie_tag(), TieTag(101112)); + EXPECT_TRUE(deserialized.capabilities().partial_reliability); + EXPECT_FALSE(deserialized.capabilities().message_interleaving); + EXPECT_TRUE(deserialized.capabilities().reconfig); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/socket/stream_reset_handler.cc b/net/dcsctp/socket/stream_reset_handler.cc new file mode 100644 index 0000000000..a1f57e6b2b --- /dev/null +++ b/net/dcsctp/socket/stream_reset_handler.cc @@ -0,0 +1,347 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/socket/stream_reset_handler.h" + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { +using ResponseResult = ReconfigurationResponseParameter::Result; + +bool DescriptorsAre(const std::vector& c, + uint16_t e1, + uint16_t e2) { + return (c[0].type == e1 && c[1].type == e2) || + (c[0].type == e2 && c[1].type == e1); +} + +} // namespace + +bool StreamResetHandler::Validate(const ReConfigChunk& chunk) { + const Parameters& parameters = chunk.parameters(); + + // https://tools.ietf.org/html/rfc6525#section-3.1 + // "Note that each RE-CONFIG chunk holds at least one parameter + // and at most two parameters. Only the following combinations are allowed:" + std::vector descriptors = parameters.descriptors(); + if (descriptors.size() == 1) { + if ((descriptors[0].type == OutgoingSSNResetRequestParameter::kType) || + (descriptors[0].type == IncomingSSNResetRequestParameter::kType) || + (descriptors[0].type == SSNTSNResetRequestParameter::kType) || + (descriptors[0].type == AddOutgoingStreamsRequestParameter::kType) || + (descriptors[0].type == AddIncomingStreamsRequestParameter::kType) || + (descriptors[0].type == ReconfigurationResponseParameter::kType)) { + return true; + } + } else if (descriptors.size() == 2) { + if (DescriptorsAre(descriptors, OutgoingSSNResetRequestParameter::kType, + IncomingSSNResetRequestParameter::kType) || + DescriptorsAre(descriptors, AddOutgoingStreamsRequestParameter::kType, + AddIncomingStreamsRequestParameter::kType) || + DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType, + OutgoingSSNResetRequestParameter::kType) || + DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType, + ReconfigurationResponseParameter::kType)) { + return true; + } + } + + RTC_LOG(LS_WARNING) << "Invalid set of RE-CONFIG parameters"; + return false; +} + +absl::optional> +StreamResetHandler::Process(const ReConfigChunk& chunk) { + if (!Validate(chunk)) { + return absl::nullopt; + } + + std::vector responses; + + for (const ParameterDescriptor& desc : chunk.parameters().descriptors()) { + switch (desc.type) { + case OutgoingSSNResetRequestParameter::kType: + HandleResetOutgoing(desc, responses); + break; + + case IncomingSSNResetRequestParameter::kType: + HandleResetIncoming(desc, responses); + break; + + case ReconfigurationResponseParameter::kType: + HandleResponse(desc); + break; + } + } + + return responses; +} + 
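+// Entry point for an incoming RE-CONFIG chunk: it is validated and processed,
+// a kParseFailed error is reported through the callbacks if that fails, and
+// any generated responses are bundled into a single outgoing RE-CONFIG.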
+void StreamResetHandler::HandleReConfig(ReConfigChunk chunk) { + absl::optional> responses = + Process(chunk); + + if (!responses.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse RE-CONFIG command"); + return; + } + + if (!responses->empty()) { + SctpPacket::Builder b = ctx_->PacketBuilder(); + Parameters::Builder params_builder; + for (const auto& response : *responses) { + params_builder.Add(response); + } + b.Add(ReConfigChunk(params_builder.Build())); + ctx_->Send(b); + } +} + +bool StreamResetHandler::ValidateReqSeqNbr( + ReconfigRequestSN req_seq_nbr, + std::vector& responses) { + if (req_seq_nbr == last_processed_req_seq_nbr_) { + // This has already been performed previously. + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr + << " already processed"; + responses.push_back(ReconfigurationResponseParameter( + req_seq_nbr, ResponseResult::kSuccessNothingToDo)); + return false; + } + + if (req_seq_nbr != ReconfigRequestSN(*last_processed_req_seq_nbr_ + 1)) { + // Too old, too new, from wrong association etc. + // This is expected to happen when handing over a RTCPeerConnection from one + // server to another. The client will notice this and may decide to close + // old data channels, which may be sent to the wrong (or both) servers + // during a handover. + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr + << " bad seq_nbr"; + responses.push_back(ReconfigurationResponseParameter( + req_seq_nbr, ResponseResult::kErrorBadSequenceNumber)); + return false; + } + + return true; +} + +void StreamResetHandler::HandleResetOutgoing( + const ParameterDescriptor& descriptor, + std::vector& responses) { + absl::optional req = + OutgoingSSNResetRequestParameter::Parse(descriptor.data); + if (!req.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse Outgoing Reset command"); + return; + } + + if (ValidateReqSeqNbr(req->request_sequence_number(), responses)) { + ResponseResult result; + + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Reset outgoing streams with req_seq_nbr=" + << *req->request_sequence_number(); + + result = reassembly_queue_->ResetStreams( + *req, data_tracker_->last_cumulative_acked_tsn()); + if (result == ResponseResult::kSuccessPerformed) { + last_processed_req_seq_nbr_ = req->request_sequence_number(); + ctx_->callbacks().OnIncomingStreamsReset(req->stream_ids()); + } + responses.push_back(ReconfigurationResponseParameter( + req->request_sequence_number(), result)); + } +} + +void StreamResetHandler::HandleResetIncoming( + const ParameterDescriptor& descriptor, + std::vector& responses) { + absl::optional req = + IncomingSSNResetRequestParameter::Parse(descriptor.data); + if (!req.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse Incoming Reset command"); + return; + } + if (ValidateReqSeqNbr(req->request_sequence_number(), responses)) { + responses.push_back(ReconfigurationResponseParameter( + req->request_sequence_number(), ResponseResult::kSuccessNothingToDo)); + last_processed_req_seq_nbr_ = req->request_sequence_number(); + } +} + +void StreamResetHandler::HandleResponse(const ParameterDescriptor& descriptor) { + absl::optional resp = + ReconfigurationResponseParameter::Parse(descriptor.data); + if (!resp.has_value()) { + ctx_->callbacks().OnError( + ErrorKind::kParseFailed, + "Failed to parse Reconfiguration Response command"); + return; + } + + if (current_request_.has_value() && current_request_->has_been_sent() && + 
resp->response_sequence_number() == current_request_->req_seq_nbr()) { + reconfig_timer_->Stop(); + + switch (resp->result()) { + case ResponseResult::kSuccessNothingToDo: + case ResponseResult::kSuccessPerformed: + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Reset stream success, req_seq_nbr=" + << *current_request_->req_seq_nbr() << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + ctx_->callbacks().OnStreamsResetPerformed(current_request_->streams()); + current_request_ = absl::nullopt; + retransmission_queue_->CommitResetStreams(); + break; + case ResponseResult::kInProgress: + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Reset stream still pending, req_seq_nbr=" + << *current_request_->req_seq_nbr() << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + // Force this request to be sent again, but with new req_seq_nbr. + current_request_->PrepareRetransmission(); + reconfig_timer_->set_duration(ctx_->current_rto()); + reconfig_timer_->Start(); + break; + case ResponseResult::kErrorRequestAlreadyInProgress: + case ResponseResult::kDenied: + case ResponseResult::kErrorWrongSSN: + case ResponseResult::kErrorBadSequenceNumber: + RTC_DLOG(LS_WARNING) + << log_prefix_ << "Reset stream error=" << ToString(resp->result()) + << ", req_seq_nbr=" << *current_request_->req_seq_nbr() + << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + ctx_->callbacks().OnStreamsResetFailed(current_request_->streams(), + ToString(resp->result())); + current_request_ = absl::nullopt; + retransmission_queue_->RollbackResetStreams(); + break; + } + } +} + +absl::optional StreamResetHandler::MakeStreamResetRequest() { + // Only send stream resets if there are streams to reset, and no current + // ongoing request (there can only be one at a time), and if the stream + // can be reset. + if (streams_to_reset_.empty() || current_request_.has_value() || + !retransmission_queue_->CanResetStreams()) { + return absl::nullopt; + } + + std::vector streams_to_reset(streams_to_reset_.begin(), + streams_to_reset_.end()); + current_request_.emplace(TSN(*retransmission_queue_->next_tsn() - 1), + std::move(streams_to_reset)); + streams_to_reset_.clear(); + reconfig_timer_->set_duration(ctx_->current_rto()); + reconfig_timer_->Start(); + return MakeReconfigChunk(); +} + +ReConfigChunk StreamResetHandler::MakeReconfigChunk() { + // The req_seq_nbr will be empty if the request has never been sent before, + // or if it was sent, but the sender responded "in progress", and then the + // req_seq_nbr will be cleared to re-send with a new number. But if the + // request is re-sent due to timeout (reconfig-timer expiring), the same + // req_seq_nbr will be used. 
+ RTC_DCHECK(current_request_.has_value()); + + if (!current_request_->has_been_sent()) { + current_request_->PrepareToSend(next_outgoing_req_seq_nbr_); + next_outgoing_req_seq_nbr_ = + ReconfigRequestSN(*next_outgoing_req_seq_nbr_ + 1); + } + + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + current_request_->req_seq_nbr(), current_request_->req_seq_nbr(), + current_request_->sender_last_assigned_tsn(), + current_request_->streams())); + + return ReConfigChunk(params_builder.Build()); +} + +void StreamResetHandler::ResetStreams( + rtc::ArrayView outgoing_streams) { + // Enqueue streams to be reset - as this may be called multiple times + // while a request is already in progress (and there can only be one). + for (StreamID stream_id : outgoing_streams) { + streams_to_reset_.insert(stream_id); + } + if (current_request_.has_value()) { + // Already an ongoing request - will need to wait for it to finish as + // there can only be one in-flight ReConfig chunk with requests at any + // time. + } else { + retransmission_queue_->PrepareResetStreams(std::vector( + streams_to_reset_.begin(), streams_to_reset_.end())); + } +} + +absl::optional StreamResetHandler::OnReconfigTimerExpiry() { + if (current_request_->has_been_sent()) { + // There is an outstanding request, which timed out while waiting for a + // response. + if (!ctx_->IncrementTxErrorCounter("RECONFIG timeout")) { + // Timed out. The connection will close after processing the timers. + return absl::nullopt; + } + } else { + // There is no outstanding request, but there is a prepared one. This means + // that the receiver has previously responded "in progress", which resulted + // in retrying the request (but with a new req_seq_nbr) after a while. + } + + ctx_->Send(ctx_->PacketBuilder().Add(MakeReconfigChunk())); + return ctx_->current_rto(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/socket/stream_reset_handler.h b/net/dcsctp/socket/stream_reset_handler.h new file mode 100644 index 0000000000..dc0ee5e8cc --- /dev/null +++ b/net/dcsctp/socket/stream_reset_handler.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ +#define NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" + +namespace dcsctp { + +// StreamResetHandler handles sending outgoing stream reset requests (to close +// an SCTP stream, which translates to closing a data channel). +// +// It also handles incoming "outgoing stream reset requests", when the peer +// wants to close its data channel. +// +// Resetting streams is an asynchronous operation where the client will request +// a request a stream to be reset, but then it might not be performed exactly at +// this point. First, the sender might need to discard all messages that have +// been enqueued for this stream, or it may select to wait until all have been +// sent. At least, it must wait for the currently sending fragmented message to +// be fully sent, because a stream can't be reset while having received half a +// message. In the stream reset request, the "sender's last assigned TSN" is +// provided, which is simply the TSN for which the receiver should've received +// all messages before this value, before the stream can be reset. Since +// fragments can get lost or sent out-of-order, the receiver of a request may +// not have received all the data just yet, and then it will respond to the +// sender: "In progress". In other words, try again. The sender will then need +// to start a timer and try the very same request again (but with a new sequence +// number) until the receiver successfully performs the operation. +// +// All this can take some time, and may be driven by timers, so the client will +// ultimately be notified using callbacks. +// +// In this implementation, when a stream is reset, the queued but not-yet-sent +// messages will be discarded, but that may change in the future. RFC8831 allows +// both behaviors. +class StreamResetHandler { + public: + StreamResetHandler(absl::string_view log_prefix, + Context* context, + TimerManager* timer_manager, + DataTracker* data_tracker, + ReassemblyQueue* reassembly_queue, + RetransmissionQueue* retransmission_queue) + : log_prefix_(std::string(log_prefix) + "reset: "), + ctx_(context), + data_tracker_(data_tracker), + reassembly_queue_(reassembly_queue), + retransmission_queue_(retransmission_queue), + reconfig_timer_(timer_manager->CreateTimer( + "re-config", + [this]() { return OnReconfigTimerExpiry(); }, + TimerOptions(DurationMs(0)))), + next_outgoing_req_seq_nbr_(ReconfigRequestSN(*ctx_->my_initial_tsn())), + last_processed_req_seq_nbr_( + ReconfigRequestSN(*ctx_->peer_initial_tsn() - 1)) {} + + // Initiates reset of the provided streams. While there can only be one + // ongoing stream reset request at any time, this method can be called at any + // time and also multiple times. 
It will enqueue requests that can't be + // directly fulfilled, and will asynchronously process them when any ongoing + // request has completed. + void ResetStreams(rtc::ArrayView outgoing_streams); + + // Creates a Reset Streams request that must be sent if returned. Will start + // the reconfig timer. Will return absl::nullopt if there is no need to + // create a request (no streams to reset) or if there already is an ongoing + // stream reset request that hasn't completed yet. + absl::optional MakeStreamResetRequest(); + + // Called when handling and incoming RE-CONFIG chunk. + void HandleReConfig(ReConfigChunk chunk); + + private: + // Represents a stream request operation. There can only be one ongoing at + // any time, and a sent request may either succeed, fail or result in the + // receiver signaling that it can't process it right now, and then it will be + // retried. + class CurrentRequest { + public: + CurrentRequest(TSN sender_last_assigned_tsn, std::vector streams) + : req_seq_nbr_(absl::nullopt), + sender_last_assigned_tsn_(sender_last_assigned_tsn), + streams_(std::move(streams)) {} + + // Returns the current request sequence number, if this request has been + // sent (check `has_been_sent` first). Will return 0 if the request is just + // prepared (or scheduled for retransmission) but not yet sent. + ReconfigRequestSN req_seq_nbr() const { + return req_seq_nbr_.value_or(ReconfigRequestSN(0)); + } + + // The sender's last assigned TSN, from the retransmission queue. The + // receiver uses this to know when all data up to this TSN has been + // received, to know when to safely reset the stream. + TSN sender_last_assigned_tsn() const { return sender_last_assigned_tsn_; } + + // The streams that are to be reset. + const std::vector& streams() const { return streams_; } + + // If this request has been sent yet. If not, then it's either because it + // has only been prepared and not yet sent, or because the received couldn't + // apply the request, and then the exact same request will be retried, but + // with a new sequence number. + bool has_been_sent() const { return req_seq_nbr_.has_value(); } + + // If the receiver can't apply the request yet (and answered "In Progress"), + // this will be called to prepare the request to be retransmitted at a later + // time. + void PrepareRetransmission() { req_seq_nbr_ = absl::nullopt; } + + // If the request hasn't been sent yet, this assigns it a request number. + void PrepareToSend(ReconfigRequestSN new_req_seq_nbr) { + req_seq_nbr_ = new_req_seq_nbr; + } + + private: + // If this is set, this request has been sent. If it's not set, the request + // has been prepared, but has not yet been sent. This is typically used when + // the peer responded "in progress" and the same request (but a different + // request number) must be sent again. + absl::optional req_seq_nbr_; + // The sender's (that's us) last assigned TSN, from the retransmission + // queue. + TSN sender_last_assigned_tsn_; + // The streams that are to be reset in this request. + const std::vector streams_; + }; + + // Called to validate an incoming RE-CONFIG chunk. + bool Validate(const ReConfigChunk& chunk); + + // Processes a stream stream reconfiguration chunk and may either return + // absl::nullopt (on protocol errors), or a list of responses - either 0, 1 + // or 2. + absl::optional> Process( + const ReConfigChunk& chunk); + + // Creates the actual RE-CONFIG chunk. A request (which set `current_request`) + // must have been created prior. 
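+  // The chunk carries a single OutgoingSSNResetRequestParameter; the first
+  // time the request is sent, it is assigned the next outgoing request
+  // sequence number.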
+ ReConfigChunk MakeReconfigChunk(); + + // Called to validate the `req_seq_nbr`, that it's the next in sequence. If it + // fails to validate, and returns false, it will also add a response to + // `responses`. + bool ValidateReqSeqNbr( + ReconfigRequestSN req_seq_nbr, + std::vector& responses); + + // Called when this socket receives an outgoing stream reset request. It might + // either be performed straight away, or have to be deferred, and the result + // of that will be put in `responses`. + void HandleResetOutgoing( + const ParameterDescriptor& descriptor, + std::vector& responses); + + // Called when this socket receives an incoming stream reset request. This + // isn't really supported, but a successful response is put in `responses`. + void HandleResetIncoming( + const ParameterDescriptor& descriptor, + std::vector& responses); + + // Called when receiving a response to an outgoing stream reset request. It + // will either commit the stream resetting, if the operation was successful, + // or will schedule a retry if it was deferred. And if it failed, the + // operation will be rolled back. + void HandleResponse(const ParameterDescriptor& descriptor); + + // Expiration handler for the Reconfig timer. + absl::optional OnReconfigTimerExpiry(); + + const std::string log_prefix_; + Context* ctx_; + DataTracker* data_tracker_; + ReassemblyQueue* reassembly_queue_; + RetransmissionQueue* retransmission_queue_; + const std::unique_ptr reconfig_timer_; + + // Outgoing streams that have been requested to be reset, but hasn't yet + // been included in an outgoing request. + std::unordered_set streams_to_reset_; + + // The next sequence number for outgoing stream requests. + ReconfigRequestSN next_outgoing_req_seq_nbr_; + + // The current stream request operation. + absl::optional current_request_; + + // For incoming requests - last processed request sequence number. + ReconfigRequestSN last_processed_req_seq_nbr_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ diff --git a/net/dcsctp/socket/stream_reset_handler_test.cc b/net/dcsctp/socket/stream_reset_handler_test.cc new file mode 100644 index 0000000000..a8e96fbf20 --- /dev/null +++ b/net/dcsctp/socket/stream_reset_handler_test.cc @@ -0,0 +1,550 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/socket/stream_reset_handler.h" + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::_; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ResponseResult = ReconfigurationResponseParameter::Result; + +constexpr TSN kMyInitialTsn = MockContext::MyInitialTsn(); +constexpr ReconfigRequestSN kMyInitialReqSn = ReconfigRequestSN(*kMyInitialTsn); +constexpr TSN kPeerInitialTsn = MockContext::PeerInitialTsn(); +constexpr ReconfigRequestSN kPeerInitialReqSn = + ReconfigRequestSN(*kPeerInitialTsn); +constexpr uint32_t kArwnd = 131072; +constexpr DurationMs kRto = DurationMs(250); + +constexpr std::array kShortPayload = {1, 2, 3, 4}; + +MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") { + if (arg.stream_id() != stream_id) { + *result_listener << "the stream_id is " << *arg.stream_id(); + return false; + } + + if (arg.ppid() != ppid) { + *result_listener << "the ppid is " << *arg.ppid(); + return false; + } + + if (std::vector(arg.payload().begin(), arg.payload().end()) != + std::vector(expected_payload.begin(), expected_payload.end())) { + *result_listener << "the payload is wrong"; + return false; + } + return true; +} + +TSN AddTo(TSN tsn, int delta) { + return TSN(*tsn + delta); +} + +ReconfigRequestSN AddTo(ReconfigRequestSN req_sn, int delta) { + return ReconfigRequestSN(*req_sn + delta); +} + +class StreamResetHandlerTest : public testing::Test { + protected: + StreamResetHandlerTest() + : ctx_(&callbacks_), + timer_manager_([this]() { return callbacks_.CreateTimeout(); }), + delayed_ack_timer_(timer_manager_.CreateTimer( + "test/delayed_ack", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + t3_rtx_timer_(timer_manager_.CreateTimer( + "test/t3_rtx", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + buf_("log: ", delayed_ack_timer_.get(), kPeerInitialTsn), + reasm_("log: ", kPeerInitialTsn, kArwnd), + retransmission_queue_( + "", + kMyInitialTsn, + kArwnd, + producer_, + [](DurationMs rtt_ms) {}, + []() {}, + *t3_rtx_timer_, + /*options=*/{}), + handler_("log: ", + &ctx_, + &timer_manager_, + &buf_, + &reasm_, + &retransmission_queue_) { + EXPECT_CALL(ctx_, current_rto).WillRepeatedly(Return(kRto)); + } + + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(kRto); + for (;;) { + absl::optional timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + 
timer_manager_.HandleTimeout(*timeout_id); + } + } + + // Handles the passed in RE-CONFIG `chunk` and returns the responses + // that are sent in the response RE-CONFIG. + std::vector HandleAndCatchResponse( + ReConfigChunk chunk) { + handler_.HandleReConfig(std::move(chunk)); + + std::vector payload = callbacks_.ConsumeSentPacket(); + if (payload.empty()) { + EXPECT_TRUE(false); + return {}; + } + + std::vector responses; + absl::optional p = SctpPacket::Parse(payload); + if (!p.has_value()) { + EXPECT_TRUE(false); + return {}; + } + if (p->descriptors().size() != 1) { + EXPECT_TRUE(false); + return {}; + } + absl::optional response_chunk = + ReConfigChunk::Parse(p->descriptors()[0].data); + if (!response_chunk.has_value()) { + EXPECT_TRUE(false); + return {}; + } + for (const auto& desc : response_chunk->parameters().descriptors()) { + if (desc.type == ReconfigurationResponseParameter::kType) { + absl::optional response = + ReconfigurationResponseParameter::Parse(desc.data); + if (!response.has_value()) { + EXPECT_TRUE(false); + return {}; + } + responses.emplace_back(*std::move(response)); + } + } + return responses; + } + + DataGenerator gen_; + NiceMock callbacks_; + NiceMock ctx_; + NiceMock producer_; + TimerManager timer_manager_; + std::unique_ptr delayed_ack_timer_; + std::unique_ptr t3_rtx_timer_; + DataTracker buf_; + ReassemblyQueue reasm_; + RetransmissionQueue retransmission_queue_; + StreamResetHandler handler_; +}; + +TEST_F(StreamResetHandlerTest, ChunkWithNoParametersReturnsError) { + EXPECT_CALL(callbacks_, SendPacket).Times(0); + EXPECT_CALL(callbacks_, OnError).Times(1); + handler_.HandleReConfig(ReConfigChunk(Parameters())); +} + +TEST_F(StreamResetHandlerTest, ChunkWithInvalidParametersReturnsError) { + Parameters::Builder builder; + // Two OutgoingSSNResetRequestParameter in a RE-CONFIG is not valid. 
+ builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(1), + ReconfigRequestSN(10), + kPeerInitialTsn, {StreamID(1)})); + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(2), + ReconfigRequestSN(10), + kPeerInitialTsn, {StreamID(2)})); + + EXPECT_CALL(callbacks_, SendPacket).Times(0); + EXPECT_CALL(callbacks_, OnError).Times(1); + handler_.HandleReConfig(ReConfigChunk(builder.Build())); +} + +TEST_F(StreamResetHandlerTest, FailToDeliverWithoutResettingStream) { + reasm_.Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_.Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + buf_.Observe(kPeerInitialTsn); + buf_.Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + gen_.ResetStream(); + reasm_.Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm_.FlushMessages(), IsEmpty()); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsNotDeferred) { + reasm_.Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_.Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + buf_.Observe(kPeerInitialTsn); + buf_.Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kSuccessPerformed); + + gen_.ResetStream(); + reasm_.Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm_.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsDeferred) { + DataGeneratorOptions opts; + opts.message_id = MID(0); + reasm_.Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + + opts.message_id = MID(1); + reasm_.Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + + buf_.Observe(kPeerInitialTsn); + buf_.Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 3), + {StreamID(1)})); + + std::vector responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kInProgress); + + opts.message_id = MID(1); + opts.ppid = PPID(5); + reasm_.Add(AddTo(kPeerInitialTsn, 5), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_.MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1)); + + opts.message_id = MID(0); + opts.ppid = PPID(4); + reasm_.Add(AddTo(kPeerInitialTsn, 4), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_.MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1)); + + opts.message_id = MID(3); + opts.ppid = PPID(3); + reasm_.Add(AddTo(kPeerInitialTsn, 3), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + 
reasm_.MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1)); + + opts.message_id = MID(2); + opts.ppid = PPID(2); + reasm_.Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_.MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 5)); + + EXPECT_THAT( + reasm_.FlushMessages(), + UnorderedElementsAre(SctpMessageIs(StreamID(1), PPID(2), kShortPayload), + SctpMessageIs(StreamID(1), PPID(3), kShortPayload), + SctpMessageIs(StreamID(1), PPID(4), kShortPayload), + SctpMessageIs(StreamID(1), PPID(5), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingRequestDirectly) { + EXPECT_CALL(producer_, PrepareResetStreams).Times(1); + handler_.ResetStreams(std::vector({StreamID(42)})); + + EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + absl::optional reconfig = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_.next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); +} + +TEST_F(StreamResetHandlerTest, ResetMultipleStreamsInOneRequest) { + EXPECT_CALL(producer_, PrepareResetStreams).Times(3); + handler_.ResetStreams(std::vector({StreamID(42)})); + handler_.ResetStreams( + std::vector({StreamID(43), StreamID(44), StreamID(41)})); + handler_.ResetStreams(std::vector({StreamID(42), StreamID(40)})); + + EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + absl::optional reconfig = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_.next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), + UnorderedElementsAre(StreamID(40), StreamID(41), StreamID(42), + StreamID(43), StreamID(44))); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingRequestDeferred) { + EXPECT_CALL(producer_, PrepareResetStreams).Times(1); + handler_.ResetStreams(std::vector({StreamID(42)})); + + EXPECT_CALL(producer_, CanResetStreams()) + .WillOnce(Return(false)) + .WillOnce(Return(false)) + .WillOnce(Return(true)); + + EXPECT_FALSE(handler_.MakeStreamResetRequest().has_value()); + EXPECT_FALSE(handler_.MakeStreamResetRequest().has_value()); + EXPECT_TRUE(handler_.MakeStreamResetRequest().has_value()); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResettingOnPositiveResponse) { + EXPECT_CALL(producer_, PrepareResetStreams).Times(1); + handler_.ResetStreams(std::vector({StreamID(42)})); + + EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + + absl::optional reconfig = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get()); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req.request_sequence_number(), ResponseResult::kSuccessPerformed)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. 
+ EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacket).Times(0); + handler_.HandleReConfig(std::move(response_reconfig)); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResetRollbackOnError) { + EXPECT_CALL(producer_, PrepareResetStreams).Times(1); + handler_.ResetStreams(std::vector({StreamID(42)})); + + EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + + absl::optional reconfig = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get()); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req.request_sequence_number(), ResponseResult::kErrorBadSequenceNumber)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(0); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(1); + + // Only requests should result in sending responses. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacket).Times(0); + handler_.HandleReConfig(std::move(response_reconfig)); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResetRetransmitOnInProgress) { + static constexpr StreamID kStreamToReset = StreamID(42); + + EXPECT_CALL(producer_, PrepareResetStreams).Times(1); + handler_.ResetStreams(std::vector({kStreamToReset})); + + EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + + absl::optional reconfig1 = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig1.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req1, + reconfig1->parameters().get()); + + // Simulate that the peer responded "In Progress". + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter(req1.request_sequence_number(), + ResponseResult::kInProgress)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(0); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacket).Times(0); + handler_.HandleReConfig(std::move(response_reconfig)); + + // Let some time pass, so that the reconfig timer expires, and retries the + // same request. 
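+  // The retransmitted request is expected to use a new (incremented) request
+  // sequence number, which is verified below.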
+ EXPECT_CALL(callbacks_, SendPacket).Times(1); + AdvanceTime(kRto); + + std::vector payload = callbacks_.ConsumeSentPacket(); + ASSERT_FALSE(payload.empty()); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + ReConfigChunk reconfig2, + ReConfigChunk::Parse(packet.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req2, + reconfig2.parameters().get()); + + EXPECT_EQ(req2.request_sequence_number(), + AddTo(req1.request_sequence_number(), 1)); + EXPECT_THAT(req2.stream_ids(), UnorderedElementsAre(kStreamToReset)); +} + +TEST_F(StreamResetHandlerTest, ResetWhileRequestIsSentWillQueue) { + EXPECT_CALL(producer_, PrepareResetStreams).Times(1); + handler_.ResetStreams(std::vector({StreamID(42)})); + + EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + absl::optional reconfig1 = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig1.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req1, + reconfig1->parameters().get()); + EXPECT_EQ(req1.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req1.sender_last_assigned_tsn(), + AddTo(retransmission_queue_.next_tsn(), -1)); + EXPECT_THAT(req1.stream_ids(), UnorderedElementsAre(StreamID(42))); + + // Streams reset while the request is in-flight will be queued. + StreamID stream_ids[] = {StreamID(41), StreamID(43)}; + handler_.ResetStreams(stream_ids); + EXPECT_EQ(handler_.MakeStreamResetRequest(), absl::nullopt); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req1.request_sequence_number(), ResponseResult::kSuccessPerformed)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacket).Times(0); + handler_.HandleReConfig(std::move(response_reconfig)); + + // Response has been processed. A new request can be sent. 
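+  // It gets the next request sequence number and carries the streams that
+  // were queued while the first request was in flight.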
+ EXPECT_CALL(producer_, CanResetStreams()).WillOnce(Return(true)); + absl::optional reconfig2 = handler_.MakeStreamResetRequest(); + ASSERT_TRUE(reconfig2.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req2, + reconfig2->parameters().get()); + EXPECT_EQ(req2.request_sequence_number(), AddTo(kMyInitialReqSn, 1)); + EXPECT_EQ(req2.sender_last_assigned_tsn(), + TSN(*retransmission_queue_.next_tsn() - 1)); + EXPECT_THAT(req2.stream_ids(), + UnorderedElementsAre(StreamID(41), StreamID(43))); +} + +TEST_F(StreamResetHandlerTest, SendIncomingResetJustReturnsNothingPerformed) { + Parameters::Builder builder; + builder.Add( + IncomingSSNResetRequestParameter(kPeerInitialReqSn, {StreamID(1)})); + + std::vector responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + ASSERT_THAT(responses, SizeIs(1)); + EXPECT_THAT(responses[0].response_sequence_number(), kPeerInitialReqSn); + EXPECT_THAT(responses[0].result(), ResponseResult::kSuccessNothingToDo); +} + +TEST_F(StreamResetHandlerTest, SendSameRequestTwiceReturnsNothingToDo) { + reasm_.Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_.Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + buf_.Observe(kPeerInitialTsn); + buf_.Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder1; + builder1.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector responses1 = + HandleAndCatchResponse(ReConfigChunk(builder1.Build())); + EXPECT_THAT(responses1, SizeIs(1)); + EXPECT_EQ(responses1[0].result(), ResponseResult::kSuccessPerformed); + + Parameters::Builder builder2; + builder2.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector responses2 = + HandleAndCatchResponse(ReConfigChunk(builder2.Build())); + EXPECT_THAT(responses2, SizeIs(1)); + EXPECT_EQ(responses2[0].result(), ResponseResult::kSuccessNothingToDo); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc new file mode 100644 index 0000000000..4fde40cee9 --- /dev/null +++ b/net/dcsctp/socket/transmission_control_block.cc @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/socket/transmission_control_block.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +void TransmissionControlBlock::ObserveRTT(DurationMs rtt) { + DurationMs prev_rto = rto_.rto(); + rto_.ObserveRTT(rtt); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "new rtt=" << *rtt + << ", srtt=" << *rto_.srtt() << ", rto=" << *rto_.rto() + << " (" << *prev_rto << ")"; + t3_rtx_->set_duration(rto_.rto()); + + DurationMs delayed_ack_tmo = + std::min(rto_.rto() * 0.5, options_.delayed_ack_max_timeout); + delayed_ack_timer_->set_duration(delayed_ack_tmo); +} + +absl::optional TransmissionControlBlock::OnRtxTimerExpiry() { + TimeMs now = callbacks_.TimeMillis(); + RTC_DLOG(LS_INFO) << log_prefix_ << "Timer " << t3_rtx_->name() + << " has expired"; + if (cookie_echo_chunk_.has_value()) { + // In the COOKIE_ECHO state, let the T1-COOKIE timer trigger + // retransmissions, to avoid having two timers doing that. + RTC_DLOG(LS_VERBOSE) << "Not retransmitting as T1-cookie is active."; + } else { + if (IncrementTxErrorCounter("t3-rtx expired")) { + retransmission_queue_.HandleT3RtxTimerExpiry(); + SendBufferedPackets(now); + } + } + return absl::nullopt; +} + +absl::optional TransmissionControlBlock::OnDelayedAckTimerExpiry() { + data_tracker_.HandleDelayedAckTimerExpiry(); + MaybeSendSack(); + return absl::nullopt; +} + +void TransmissionControlBlock::MaybeSendSack() { + if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/false)) { + SctpPacket::Builder builder = PacketBuilder(); + builder.Add( + data_tracker_.CreateSelectiveAck(reassembly_queue_.remaining_bytes())); + Send(builder); + } +} + +void TransmissionControlBlock::SendBufferedPackets(SctpPacket::Builder& builder, + TimeMs now) { + for (int packet_idx = 0;; ++packet_idx) { + // Only add control chunks to the first packet that is sent, if sending + // multiple packets in one go (as allowed by the congestion window). + if (packet_idx == 0) { + if (cookie_echo_chunk_.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "The COOKIE ECHO chunk can be bundled with any pending outbound DATA + // chunks, but it MUST be the first chunk in the packet..." + RTC_DCHECK(builder.empty()); + builder.Add(*cookie_echo_chunk_); + } + + // https://tools.ietf.org/html/rfc4960#section-6 + // "Before an endpoint transmits a DATA chunk, if any received DATA + // chunks have not been acknowledged (e.g., due to delayed ack), the + // sender should create a SACK and bundle it with the outbound DATA chunk, + // as long as the size of the final SCTP packet does not exceed the + // current MTU." 
+ if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/true)) { + builder.Add(data_tracker_.CreateSelectiveAck( + reassembly_queue_.remaining_bytes())); + } + if (retransmission_queue_.ShouldSendForwardTsn(now)) { + if (capabilities_.message_interleaving) { + builder.Add(retransmission_queue_.CreateIForwardTsn()); + } else { + builder.Add(retransmission_queue_.CreateForwardTsn()); + } + } + absl::optional reconfig = + stream_reset_handler_.MakeStreamResetRequest(); + if (reconfig.has_value()) { + builder.Add(*reconfig); + } + } + + auto chunks = + retransmission_queue_.GetChunksToSend(now, builder.bytes_remaining()); + for (auto& elem : chunks) { + TSN tsn = elem.first; + Data data = std::move(elem.second); + if (capabilities_.message_interleaving) { + builder.Add(IDataChunk(tsn, std::move(data), false)); + } else { + builder.Add(DataChunk(tsn, std::move(data), false)); + } + } + if (builder.empty()) { + break; + } + Send(builder); + + if (cookie_echo_chunk_.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "... until the COOKIE ACK is returned the sender MUST NOT send any + // other packets to the peer." + break; + } + } +} + +std::string TransmissionControlBlock::ToString() const { + rtc::StringBuilder sb; + + sb.AppendFormat( + "verification_tag=%08x, last_cumulative_ack=%u, capabilities=", + *peer_verification_tag_, *data_tracker_.last_cumulative_acked_tsn()); + + if (capabilities_.partial_reliability) { + sb << "PR,"; + } + if (capabilities_.message_interleaving) { + sb << "IL,"; + } + if (capabilities_.reconfig) { + sb << "Reconfig,"; + } + + return sb.Release(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h new file mode 100644 index 0000000000..172f7c0c08 --- /dev/null +++ b/net/dcsctp/socket/transmission_control_block.h @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ +#define NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/socket/heartbeat_handler.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_error_counter.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The TransmissionControlBlock (TCB) represents an open connection to a peer, +// and holds all the resources for that. If the connection is e.g. shutdown, +// closed or restarted, this object will be deleted and/or replaced. 
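+//
+// It implements the `Context` interface so that sub-components such as the
+// stream reset and heartbeat handlers can reach shared state, e.g. the current
+// RTO, the tx error counter and packet building/sending.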
+class TransmissionControlBlock : public Context { + public: + TransmissionControlBlock(TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + std::function is_connection_established, + std::function send_fn) + : log_prefix_(log_prefix), + options_(options), + timer_manager_(timer_manager), + capabilities_(capabilities), + callbacks_(callbacks), + t3_rtx_(timer_manager_.CreateTimer( + "t3-rtx", + [this]() { return OnRtxTimerExpiry(); }, + TimerOptions(options.rto_initial))), + delayed_ack_timer_(timer_manager_.CreateTimer( + "delayed-ack", + [this]() { return OnDelayedAckTimerExpiry(); }, + TimerOptions(options.delayed_ack_max_timeout, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0))), + my_verification_tag_(my_verification_tag), + my_initial_tsn_(my_initial_tsn), + peer_verification_tag_(peer_verification_tag), + peer_initial_tsn_(peer_initial_tsn), + tie_tag_(tie_tag), + is_connection_established_(std::move(is_connection_established)), + send_fn_(std::move(send_fn)), + rto_(options), + tx_error_counter_(log_prefix, options), + data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn), + reassembly_queue_(log_prefix, + peer_initial_tsn, + options.max_receiver_window_buffer_size), + retransmission_queue_( + log_prefix, + my_initial_tsn, + a_rwnd, + send_queue, + [this](DurationMs rtt) { return ObserveRTT(rtt); }, + [this]() { tx_error_counter_.Clear(); }, + *t3_rtx_, + options, + capabilities.partial_reliability, + capabilities.message_interleaving), + stream_reset_handler_(log_prefix, + this, + &timer_manager, + &data_tracker_, + &reassembly_queue_, + &retransmission_queue_), + heartbeat_handler_(log_prefix, options, this, &timer_manager_) {} + + // Implementation of `Context`. + bool is_connection_established() const override { + return is_connection_established_(); + } + TSN my_initial_tsn() const override { return my_initial_tsn_; } + TSN peer_initial_tsn() const override { return peer_initial_tsn_; } + DcSctpSocketCallbacks& callbacks() const override { return callbacks_; } + void ObserveRTT(DurationMs rtt) override; + DurationMs current_rto() const override { return rto_.rto(); } + bool IncrementTxErrorCounter(absl::string_view reason) override { + return tx_error_counter_.Increment(reason); + } + void ClearTxErrorCounter() override { tx_error_counter_.Clear(); } + SctpPacket::Builder PacketBuilder() const override { + return SctpPacket::Builder(peer_verification_tag_, options_); + } + bool HasTooManyTxErrors() const override { + return tx_error_counter_.IsExhausted(); + } + void Send(SctpPacket::Builder& builder) override { send_fn_(builder); } + + // Other accessors + DataTracker& data_tracker() { return data_tracker_; } + ReassemblyQueue& reassembly_queue() { return reassembly_queue_; } + RetransmissionQueue& retransmission_queue() { return retransmission_queue_; } + StreamResetHandler& stream_reset_handler() { return stream_reset_handler_; } + HeartbeatHandler& heartbeat_handler() { return heartbeat_handler_; } + + // Returns this socket's verification tag, set in all packet headers. + VerificationTag my_verification_tag() const { return my_verification_tag_; } + // Returns the peer's verification tag, which should be in received packets. 
+ VerificationTag peer_verification_tag() const { + return peer_verification_tag_; + } + // All negotiated supported capabilities. + const Capabilities& capabilities() const { return capabilities_; } + // A 64-bit tie-tag, used to e.g. detect reconnections. + TieTag tie_tag() const { return tie_tag_; } + + // Sends a SACK, if there is a need to. + void MaybeSendSack(); + + // Will be set while the socket is in kCookieEcho state. In this state, there + // can only be a single packet outstanding, and it must contain the COOKIE + // ECHO chunk as the first chunk in that packet, until the COOKIE ACK has been + // received, which will make the socket call `ClearCookieEchoChunk`. + void SetCookieEchoChunk(CookieEchoChunk chunk) { + cookie_echo_chunk_ = std::move(chunk); + } + + // Called when the COOKIE ACK chunk has been received, to allow further + // packets to be sent. + void ClearCookieEchoChunk() { cookie_echo_chunk_ = absl::nullopt; } + + bool has_cookie_echo_chunk() const { return cookie_echo_chunk_.has_value(); } + + // Fills `builder` (which may already be filled with control chunks) with + // other control and data chunks, and sends packets as much as can be + // allowed by the congestion control algorithm. + void SendBufferedPackets(SctpPacket::Builder& builder, TimeMs now); + + // As above, but without passing in a builder. If `cookie_echo_chunk_` is + // present, then only one packet will be sent, with this chunk as the first + // chunk. + void SendBufferedPackets(TimeMs now) { + SctpPacket::Builder builder(peer_verification_tag_, options_); + SendBufferedPackets(builder, now); + } + + // Returns a textual representation of this object, for logging. + std::string ToString() const; + + private: + // Will be called when the retransmission timer (t3-rtx) expires. + absl::optional OnRtxTimerExpiry(); + // Will be called when the delayed ack timer expires. + absl::optional OnDelayedAckTimerExpiry(); + + const std::string log_prefix_; + const DcSctpOptions options_; + TimerManager& timer_manager_; + // Negotiated capabilities that both peers support. + const Capabilities capabilities_; + DcSctpSocketCallbacks& callbacks_; + // The data retransmission timer, called t3-rtx in SCTP. + const std::unique_ptr t3_rtx_; + // Delayed ack timer, which triggers when acks should be sent (when delayed). + const std::unique_ptr delayed_ack_timer_; + const VerificationTag my_verification_tag_; + const TSN my_initial_tsn_; + const VerificationTag peer_verification_tag_; + const TSN peer_initial_tsn_; + // Nonce, used to detect reconnections. + const TieTag tie_tag_; + const std::function is_connection_established_; + const std::function send_fn_; + + RetransmissionTimeout rto_; + RetransmissionErrorCounter tx_error_counter_; + DataTracker data_tracker_; + ReassemblyQueue reassembly_queue_; + RetransmissionQueue retransmission_queue_; + StreamResetHandler stream_reset_handler_; + HeartbeatHandler heartbeat_handler_; + + // Only valid when the socket state == State::kCookieEchoed. In this state, + // the socket must wait for COOKIE ACK to continue sending any packets (not + // including a COOKIE ECHO). So if `cookie_echo_chunk_` is present, the + // SendBufferedChunks will always only just send one packet, with this chunk + // as the first chunk in the packet. 
+ absl::optional cookie_echo_chunk_ = absl::nullopt; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ diff --git a/net/dcsctp/testing/BUILD.gn b/net/dcsctp/testing/BUILD.gn new file mode 100644 index 0000000000..5367ef8c6f --- /dev/null +++ b/net/dcsctp/testing/BUILD.gn @@ -0,0 +1,35 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("testing_macros") { + testonly = true + sources = [ "testing_macros.h" ] +} + +rtc_library("data_generator") { + testonly = true + deps = [ + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:internal_types", + "../packet:data", + "../public:types", + ] + sources = [ + "data_generator.cc", + "data_generator.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} diff --git a/net/dcsctp/testing/data_generator.cc b/net/dcsctp/testing/data_generator.cc new file mode 100644 index 0000000000..e4f9f91384 --- /dev/null +++ b/net/dcsctp/testing/data_generator.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/testing/data_generator.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { +constexpr PPID kPpid = PPID(53); + +Data DataGenerator::Ordered(std::vector payload, + absl::string_view flags, + const DataGeneratorOptions opts) { + Data::IsBeginning is_beginning(flags.find('B') != std::string::npos); + Data::IsEnd is_end(flags.find('E') != std::string::npos); + + if (is_beginning) { + fsn_ = FSN(0); + } else { + fsn_ = FSN(*fsn_ + 1); + } + MID message_id = opts.message_id.value_or(message_id_); + Data ret = Data(opts.stream_id, SSN(static_cast(*message_id)), + message_id, fsn_, opts.ppid, std::move(payload), is_beginning, + is_end, IsUnordered(false)); + + if (is_end) { + message_id_ = MID(*message_id + 1); + } + return ret; +} + +Data DataGenerator::Unordered(std::vector payload, + absl::string_view flags, + const DataGeneratorOptions opts) { + Data::IsBeginning is_beginning(flags.find('B') != std::string::npos); + Data::IsEnd is_end(flags.find('E') != std::string::npos); + + if (is_beginning) { + fsn_ = FSN(0); + } else { + fsn_ = FSN(*fsn_ + 1); + } + MID message_id = opts.message_id.value_or(message_id_); + Data ret = Data(opts.stream_id, SSN(0), message_id, fsn_, kPpid, + std::move(payload), is_beginning, is_end, IsUnordered(true)); + if (is_end) { + message_id_ = MID(*message_id + 1); + } + return ret; +} +} // namespace dcsctp diff --git a/net/dcsctp/testing/data_generator.h b/net/dcsctp/testing/data_generator.h new file mode 100644 index 0000000000..859450b1c3 --- /dev/null +++ b/net/dcsctp/testing/data_generator.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TESTING_DATA_GENERATOR_H_ +#define NET_DCSCTP_TESTING_DATA_GENERATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/data.h" + +namespace dcsctp { + +struct DataGeneratorOptions { + StreamID stream_id = StreamID(1); + absl::optional message_id = absl::nullopt; + PPID ppid = PPID(53); +}; + +// Generates Data with correct sequence numbers, and used only in unit tests. +class DataGenerator { + public: + explicit DataGenerator(MID start_message_id = MID(0)) + : message_id_(start_message_id) {} + + // Generates ordered "data" with the provided `payload` and flags, which can + // contain "B" for setting the "is_beginning" flag, and/or "E" for setting the + // "is_end" flag. + Data Ordered(std::vector payload, + absl::string_view flags = "", + const DataGeneratorOptions opts = {}); + + // Generates unordered "data" with the provided `payload` and flags, which can + // contain "B" for setting the "is_beginning" flag, and/or "E" for setting the + // "is_end" flag. + Data Unordered(std::vector payload, + absl::string_view flags = "", + const DataGeneratorOptions opts = {}); + + // Resets the Message ID identifier - simulating a "stream reset". 
+ void ResetStream() { message_id_ = MID(0); } + + private: + MID message_id_; + FSN fsn_ = FSN(0); +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TESTING_DATA_GENERATOR_H_ diff --git a/net/dcsctp/testing/testing_macros.h b/net/dcsctp/testing/testing_macros.h new file mode 100644 index 0000000000..5cbdfffdce --- /dev/null +++ b/net/dcsctp/testing/testing_macros.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TESTING_TESTING_MACROS_H_ +#define NET_DCSCTP_TESTING_TESTING_MACROS_H_ + +#include + +namespace dcsctp { + +#define DCSCTP_CONCAT_INNER_(x, y) x##y +#define DCSCTP_CONCAT_(x, y) DCSCTP_CONCAT_INNER_(x, y) + +// Similar to ASSERT_OK_AND_ASSIGN, this works with an absl::optional<> instead +// of an absl::StatusOr<>. +#define ASSERT_HAS_VALUE_AND_ASSIGN(lhs, rexpr) \ + auto DCSCTP_CONCAT_(tmp_opt_val__, __LINE__) = rexpr; \ + ASSERT_TRUE(DCSCTP_CONCAT_(tmp_opt_val__, __LINE__).has_value()); \ + lhs = *std::move(DCSCTP_CONCAT_(tmp_opt_val__, __LINE__)); + +} // namespace dcsctp + +#endif // NET_DCSCTP_TESTING_TESTING_MACROS_H_ diff --git a/net/dcsctp/timer/BUILD.gn b/net/dcsctp/timer/BUILD.gn new file mode 100644 index 0000000000..a0ba5b030e --- /dev/null +++ b/net/dcsctp/timer/BUILD.gn @@ -0,0 +1,73 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. 
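The two test helpers above, `DataGenerator` and `ASSERT_HAS_VALUE_AND_ASSIGN`, are typically combined as in the following sketch of a unit test before the timer targets continue below. The test itself is illustrative and not part of this patch: the test name, payload bytes and the local `absl::optional` are made up, while the "B"/"E" flags and `Data::size()` follow their definitions and uses elsewhere in this patch.

#include <cstdint>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "net/dcsctp/packet/data.h"
#include "net/dcsctp/testing/data_generator.h"
#include "net/dcsctp/testing/testing_macros.h"
#include "rtc_base/gunit.h"
#include "test/gmock.h"

namespace dcsctp {
namespace {

TEST(DataGeneratorExampleTest, GeneratesFragmentsWithConsistentSequencing) {
  DataGenerator gen;

  // Three fragments of one ordered message: "B" marks the first fragment,
  // no flag a middle fragment, and "E" the last one.
  Data first = gen.Ordered({1, 2, 3}, "B");
  Data middle = gen.Ordered({4, 5, 6});
  Data last = gen.Ordered({7, 8}, "E");
  EXPECT_EQ(first.size(), 3u);
  EXPECT_EQ(middle.size(), 3u);
  EXPECT_EQ(last.size(), 2u);

  // ASSERT_HAS_VALUE_AND_ASSIGN unwraps an absl::optional or fails the test.
  absl::optional<std::vector<uint8_t>> maybe_payload =
      std::vector<uint8_t>{9, 10};
  ASSERT_HAS_VALUE_AND_ASSIGN(std::vector<uint8_t> payload, maybe_payload);
  Data unordered = gen.Unordered(std::move(payload), "BE");
  EXPECT_EQ(unordered.size(), 2u);
}

}  // namespace
}  // namespace dcsctp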
+ +import("../../../webrtc.gni") + +rtc_library("timer") { + deps = [ + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../public:socket", + "../public:strong_alias", + "../public:types", + ] + sources = [ + "fake_timeout.h", + "timer.cc", + "timer.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("task_queue_timeout") { + deps = [ + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../../../rtc_base/task_utils:pending_task_safety_flag", + "../../../rtc_base/task_utils:to_queued_task", + "../public:socket", + "../public:strong_alias", + "../public:types", + ] + sources = [ + "task_queue_timeout.cc", + "task_queue_timeout.h", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_timer_unittests") { + testonly = true + + defines = [] + deps = [ + ":task_queue_timeout", + ":timer", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../../../test/time_controller:time_controller", + "../public:socket", + ] + sources = [ + "task_queue_timeout_test.cc", + "timer_test.cc", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + } +} diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h new file mode 100644 index 0000000000..927e6b2808 --- /dev/null +++ b/net/dcsctp/timer/fake_timeout.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ +#define NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/checks.h" + +namespace dcsctp { + +// A timeout used in tests. +class FakeTimeout : public Timeout { + public: + explicit FakeTimeout(std::function get_time, + std::function on_delete) + : get_time_(std::move(get_time)), on_delete_(std::move(on_delete)) {} + + ~FakeTimeout() override { on_delete_(this); } + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override { + RTC_DCHECK(expiry_ == TimeMs::InfiniteFuture()); + timeout_id_ = timeout_id; + expiry_ = get_time_() + duration_ms; + } + void Stop() override { + RTC_DCHECK(expiry_ != TimeMs::InfiniteFuture()); + expiry_ = TimeMs::InfiniteFuture(); + } + + bool EvaluateHasExpired(TimeMs now) { + if (now >= expiry_) { + expiry_ = TimeMs::InfiniteFuture(); + return true; + } + return false; + } + + TimeoutID timeout_id() const { return timeout_id_; } + + private: + const std::function get_time_; + const std::function on_delete_; + + TimeoutID timeout_id_ = TimeoutID(0); + TimeMs expiry_ = TimeMs::InfiniteFuture(); +}; + +class FakeTimeoutManager { + public: + // The `get_time` function must return the current time, relative to any + // epoch. 
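+  //
+  // Typical test wiring (a sketch that mirrors the fixture in timer_test.cc
+  // later in this patch):
+  //
+  //   TimeMs now = TimeMs(0);
+  //   FakeTimeoutManager timeout_manager([&]() { return now; });
+  //   TimerManager timer_manager(
+  //       [&]() { return timeout_manager.CreateTimeout(); });
+  //   ...
+  //   now = now + DurationMs(1000);
+  //   while (absl::optional<TimeoutID> id =
+  //              timeout_manager.GetNextExpiredTimeout()) {
+  //     timer_manager.HandleTimeout(*id);
+  //   }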
+ explicit FakeTimeoutManager(std::function get_time) + : get_time_(std::move(get_time)) {} + + std::unique_ptr CreateTimeout() { + auto timer = std::make_unique( + get_time_, [this](FakeTimeout* timer) { timers_.erase(timer); }); + timers_.insert(timer.get()); + return timer; + } + + // NOTE: This can't return a vector, as calling EvaluateHasExpired requires + // calling socket->HandleTimeout directly afterwards, as the owning Timer + // still believes it's running, and it needs to be updated to set + // Timer::is_running_ to false before you operate on the Timer or Timeout + // again. + absl::optional GetNextExpiredTimeout() { + TimeMs now = get_time_(); + std::vector expired_timers; + for (auto& timer : timers_) { + if (timer->EvaluateHasExpired(now)) { + return timer->timeout_id(); + } + } + return absl::nullopt; + } + + private: + const std::function get_time_; + std::unordered_set timers_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ diff --git a/net/dcsctp/timer/task_queue_timeout.cc b/net/dcsctp/timer/task_queue_timeout.cc new file mode 100644 index 0000000000..6d3054eeb8 --- /dev/null +++ b/net/dcsctp/timer/task_queue_timeout.cc @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/task_queue_timeout.h" + +#include "rtc_base/logging.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" + +namespace dcsctp { + +TaskQueueTimeoutFactory::TaskQueueTimeout::TaskQueueTimeout( + TaskQueueTimeoutFactory& parent) + : parent_(parent), + pending_task_safety_flag_(webrtc::PendingTaskSafetyFlag::Create()) {} + +TaskQueueTimeoutFactory::TaskQueueTimeout::~TaskQueueTimeout() { + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + pending_task_safety_flag_->SetNotAlive(); +} + +void TaskQueueTimeoutFactory::TaskQueueTimeout::Start(DurationMs duration_ms, + TimeoutID timeout_id) { + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + RTC_DCHECK(timeout_expiration_ == TimeMs::InfiniteFuture()); + timeout_expiration_ = parent_.get_time_() + duration_ms; + timeout_id_ = timeout_id; + + if (timeout_expiration_ >= posted_task_expiration_) { + // There is already a running task, and it's scheduled to expire sooner than + // the new expiration time. Don't do anything; The `timeout_expiration_` has + // already been updated and if the delayed task _does_ expire and the timer + // hasn't been stopped, that will be noticed in the timeout handler, and the + // task will be re-scheduled. Most timers are stopped before they expire. + return; + } + + if (posted_task_expiration_ != TimeMs::InfiniteFuture()) { + RTC_DLOG(LS_VERBOSE) << "New timeout duration is less than scheduled - " + "ghosting old delayed task."; + // There is already a scheduled delayed task, but its expiration time is + // further away than the new expiration, so it can't be used. It will be + // "killed" by replacing the safety flag. This is not expected to happen + // especially often; Mainly when a timer did exponential backoff and + // later recovered. 
+ pending_task_safety_flag_->SetNotAlive(); + pending_task_safety_flag_ = webrtc::PendingTaskSafetyFlag::Create(); + } + + posted_task_expiration_ = timeout_expiration_; + parent_.task_queue_.PostDelayedTask( + webrtc::ToQueuedTask( + pending_task_safety_flag_, + [timeout_id, this]() { + RTC_DLOG(LS_VERBOSE) << "Timout expired: " << timeout_id.value(); + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + RTC_DCHECK(posted_task_expiration_ != TimeMs::InfiniteFuture()); + posted_task_expiration_ = TimeMs::InfiniteFuture(); + + if (timeout_expiration_ == TimeMs::InfiniteFuture()) { + // The timeout was stopped before it expired. Very common. + } else { + // Note that the timeout might have been restarted, which updated + // `timeout_expiration_` but left the scheduled task running. So + // if it's not quite time to trigger the timeout yet, schedule a + // new delayed task with what's remaining and retry at that point + // in time. + DurationMs remaining = timeout_expiration_ - parent_.get_time_(); + timeout_expiration_ = TimeMs::InfiniteFuture(); + if (*remaining > 0) { + Start(remaining, timeout_id_); + } else { + // It has actually triggered. + RTC_DLOG(LS_VERBOSE) + << "Timout triggered: " << timeout_id.value(); + parent_.on_expired_(timeout_id_); + } + } + }), + duration_ms.value()); +} + +void TaskQueueTimeoutFactory::TaskQueueTimeout::Stop() { + // As the TaskQueue doesn't support deleting a posted task, just mark the + // timeout as not running. + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + timeout_expiration_ = TimeMs::InfiniteFuture(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/timer/task_queue_timeout.h b/net/dcsctp/timer/task_queue_timeout.h new file mode 100644 index 0000000000..e8d12df592 --- /dev/null +++ b/net/dcsctp/timer/task_queue_timeout.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ +#define NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ + +#include +#include + +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" + +namespace dcsctp { + +// The TaskQueueTimeoutFactory creates `Timeout` instances, which schedules +// itself to be triggered on the provided `task_queue`, which may be a thread, +// an actual TaskQueue or something else which supports posting a delayed task. +// +// Note that each `DcSctpSocket` must have its own `TaskQueueTimeoutFactory`, +// as the `TimeoutID` are not unique among sockets. +// +// This class must outlive any created Timeout that it has created. Note that +// the `DcSctpSocket` will ensure that all Timeouts are deleted when the socket +// is destructed, so this means that this class must outlive the `DcSctpSocket`. +// +// This class, and the timeouts created it, are not thread safe. +class TaskQueueTimeoutFactory { + public: + // The `get_time` function must return the current time, relative to any + // epoch. Whenever a timeout expires, the `on_expired` callback will be + // triggered, and then the client should provided `timeout_id` to + // `DcSctpSocketInterface::HandleTimeout`. 
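+  //
+  // A minimal wiring sketch (assuming a `socket` variable owned by the caller
+  // that implements DcSctpSocketInterface):
+  //
+  //   TaskQueueTimeoutFactory factory(
+  //       *webrtc::TaskQueueBase::Current(),
+  //       []() { return TimeMs(rtc::TimeMillis()); },
+  //       [&](TimeoutID timeout_id) { socket->HandleTimeout(timeout_id); });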
+ TaskQueueTimeoutFactory(webrtc::TaskQueueBase& task_queue, + std::function get_time, + std::function on_expired) + : task_queue_(task_queue), + get_time_(std::move(get_time)), + on_expired_(std::move(on_expired)) {} + + // Creates an implementation of `Timeout`. + std::unique_ptr CreateTimeout() { + return std::make_unique(*this); + } + + private: + class TaskQueueTimeout : public Timeout { + public: + explicit TaskQueueTimeout(TaskQueueTimeoutFactory& parent); + ~TaskQueueTimeout(); + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override; + void Stop() override; + + private: + TaskQueueTimeoutFactory& parent_; + // A safety flag to ensure that posted tasks to the task queue don't + // reference these object when they go out of scope. Note that this safety + // flag will be re-created if the scheduled-but-not-yet-expired task is not + // to be run. This happens when there is a posted delayed task with an + // expiration time _further away_ than what is now the expected expiration + // time. In this scenario, a new delayed task has to be posted with a + // shorter duration and the old task has to be forgotten. + rtc::scoped_refptr pending_task_safety_flag_; + // The time when the posted delayed task is set to expire. Will be set to + // the infinite future if there is no such task running. + TimeMs posted_task_expiration_ = TimeMs::InfiniteFuture(); + // The time when the timeout expires. It will be set to the infinite future + // if the timeout is not running/not started. + TimeMs timeout_expiration_ = TimeMs::InfiniteFuture(); + // The current timeout ID that will be reported when expired. + TimeoutID timeout_id_ = TimeoutID(0); + }; + + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; + webrtc::TaskQueueBase& task_queue_; + const std::function get_time_; + const std::function on_expired_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ diff --git a/net/dcsctp/timer/task_queue_timeout_test.cc b/net/dcsctp/timer/task_queue_timeout_test.cc new file mode 100644 index 0000000000..9d3846953b --- /dev/null +++ b/net/dcsctp/timer/task_queue_timeout_test.cc @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/timer/task_queue_timeout.h" + +#include + +#include "rtc_base/gunit.h" +#include "test/gmock.h" +#include "test/time_controller/simulated_time_controller.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; + +class TaskQueueTimeoutTest : public testing::Test { + protected: + TaskQueueTimeoutTest() + : time_controller_(webrtc::Timestamp::Millis(1234)), + task_queue_(time_controller_.GetMainThread()), + factory_( + *task_queue_, + [this]() { + return TimeMs(time_controller_.GetClock()->CurrentTime().ms()); + }, + on_expired_.AsStdFunction()) {} + + void AdvanceTime(DurationMs duration) { + time_controller_.AdvanceTime(webrtc::TimeDelta::Millis(*duration)); + } + + MockFunction on_expired_; + webrtc::GlobalSimulatedTimeController time_controller_; + + rtc::Thread* task_queue_; + TaskQueueTimeoutFactory factory_; +}; + +TEST_F(TaskQueueTimeoutTest, StartPostsDelayedTask) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(999)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(1))); + AdvanceTime(DurationMs(1)); +} + +TEST_F(TaskQueueTimeoutTest, StopBeforeExpiringDoesntTrigger) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(999)); + + timeout->Stop(); + + AdvanceTime(DurationMs(1)); + AdvanceTime(DurationMs(1000)); +} + +TEST_F(TaskQueueTimeoutTest, RestartPrologingTimeoutDuration) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout->Restart(DurationMs(1000), TimeoutID(2)); + + AdvanceTime(DurationMs(999)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(2))); + AdvanceTime(DurationMs(1)); +} + +TEST_F(TaskQueueTimeoutTest, RestartWithShorterDurationExpiresWhenExpected) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout->Restart(DurationMs(200), TimeoutID(2)); + + AdvanceTime(DurationMs(199)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(2))); + AdvanceTime(DurationMs(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(1000)); +} + +TEST_F(TaskQueueTimeoutTest, KilledBeforeExpired) { + std::unique_ptr timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout = nullptr; + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(1000)); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/timer/timer.cc b/net/dcsctp/timer/timer.cc new file mode 100644 index 0000000000..593d639fa7 --- /dev/null +++ b/net/dcsctp/timer/timer.cc @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/timer/timer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/checks.h" + +namespace dcsctp { +namespace { +TimeoutID MakeTimeoutId(TimerID timer_id, TimerGeneration generation) { + return TimeoutID(static_cast(*timer_id) << 32 | *generation); +} + +DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm, + DurationMs base_duration, + int expiration_count) { + switch (algorithm) { + case TimerBackoffAlgorithm::kFixed: + return base_duration; + case TimerBackoffAlgorithm::kExponential: { + int32_t duration_ms = *base_duration; + + while (expiration_count > 0 && duration_ms < *Timer::kMaxTimerDuration) { + duration_ms *= 2; + --expiration_count; + } + + return DurationMs(std::min(duration_ms, *Timer::kMaxTimerDuration)); + } + } +} +} // namespace + +constexpr DurationMs Timer::kMaxTimerDuration; + +Timer::Timer(TimerID id, + absl::string_view name, + OnExpired on_expired, + UnregisterHandler unregister_handler, + std::unique_ptr timeout, + const TimerOptions& options) + : id_(id), + name_(name), + options_(options), + on_expired_(std::move(on_expired)), + unregister_handler_(std::move(unregister_handler)), + timeout_(std::move(timeout)), + duration_(options.duration) {} + +Timer::~Timer() { + Stop(); + unregister_handler_(); +} + +void Timer::Start() { + expiration_count_ = 0; + if (!is_running()) { + is_running_ = true; + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration_, MakeTimeoutId(id_, generation_)); + } else { + // Timer was running - stop and restart it, to make it expire in `duration_` + // from now. + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Restart(duration_, MakeTimeoutId(id_, generation_)); + } +} + +void Timer::Stop() { + if (is_running()) { + timeout_->Stop(); + expiration_count_ = 0; + is_running_ = false; + } +} + +void Timer::Trigger(TimerGeneration generation) { + if (is_running_ && generation == generation_) { + ++expiration_count_; + is_running_ = false; + if (options_.max_restarts < 0 || + expiration_count_ <= options_.max_restarts) { + // The timer should still be running after this triggers. Start a new + // timer. Note that it might be very quickly restarted again, if the + // `on_expired_` callback returns a new duration. + is_running_ = true; + DurationMs duration = GetBackoffDuration(options_.backoff_algorithm, + duration_, expiration_count_); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration, MakeTimeoutId(id_, generation_)); + } + + absl::optional new_duration = on_expired_(); + if (new_duration.has_value() && new_duration != duration_) { + duration_ = new_duration.value(); + if (is_running_) { + // Restart it with new duration. 
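+        // (The timeout was started above with a backoff based on the old base
+        // duration, so stop it and start it again to recompute the backoff
+        // from the new base duration.)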
+ timeout_->Stop(); + + DurationMs duration = GetBackoffDuration(options_.backoff_algorithm, + duration_, expiration_count_); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration, MakeTimeoutId(id_, generation_)); + } + } + } +} + +void TimerManager::HandleTimeout(TimeoutID timeout_id) { + TimerID timer_id(*timeout_id >> 32); + TimerGeneration generation(*timeout_id); + auto it = timers_.find(timer_id); + if (it != timers_.end()) { + it->second->Trigger(generation); + } +} + +std::unique_ptr TimerManager::CreateTimer(absl::string_view name, + Timer::OnExpired on_expired, + const TimerOptions& options) { + next_id_ = TimerID(*next_id_ + 1); + TimerID id = next_id_; + // This would overflow after 4 billion timers created, which in SCTP would be + // after 800 million reconnections on a single socket. Ensure this will never + // happen. + RTC_CHECK_NE(*id, std::numeric_limits::max()); + auto timer = absl::WrapUnique(new Timer( + id, name, std::move(on_expired), [this, id]() { timers_.erase(id); }, + create_timeout_(), options)); + timers_[id] = timer.get(); + return timer; +} + +} // namespace dcsctp diff --git a/net/dcsctp/timer/timer.h b/net/dcsctp/timer/timer.h new file mode 100644 index 0000000000..bf923ea4ca --- /dev/null +++ b/net/dcsctp/timer/timer.h @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_TIMER_H_ +#define NET_DCSCTP_TIMER_TIMER_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/public/strong_alias.h" +#include "net/dcsctp/public/timeout.h" + +namespace dcsctp { + +using TimerID = StrongAlias; +using TimerGeneration = StrongAlias; + +enum class TimerBackoffAlgorithm { + // The base duration will be used for any restart. + kFixed, + // An exponential backoff is used for restarts, with a 2x multiplier, meaning + // that every restart will use a duration that is twice as long as the + // previous. + kExponential, +}; + +struct TimerOptions { + explicit TimerOptions(DurationMs duration) + : TimerOptions(duration, TimerBackoffAlgorithm::kExponential) {} + TimerOptions(DurationMs duration, TimerBackoffAlgorithm backoff_algorithm) + : TimerOptions(duration, backoff_algorithm, -1) {} + TimerOptions(DurationMs duration, + TimerBackoffAlgorithm backoff_algorithm, + int max_restarts) + : duration(duration), + backoff_algorithm(backoff_algorithm), + max_restarts(max_restarts) {} + + // The initial timer duration. Can be overridden with `set_duration`. + const DurationMs duration; + // If the duration should be increased (using exponential backoff) when it is + // restarted. If not set, the same duration will be used. + const TimerBackoffAlgorithm backoff_algorithm; + // The maximum number of times that the timer will be automatically restarted. + const int max_restarts; +}; + +// A high-level timer (in contrast to the low-level `Timeout` class). +// +// Timers are started and can be stopped or restarted. When a timer expires, +// the provided `on_expired` callback will be triggered. 
A timer is +// automatically restarted, as long as the number of restarts is below the +// configurable `max_restarts` parameter. The `is_running` property can be +// queried to know if it's still running after having expired. +// +// When a timer is restarted, it will use a configurable `backoff_algorithm` to +// possibly adjust the duration of the next expiry. It is also possible to +// return a new base duration (which is the duration before it's adjusted by the +// backoff algorithm). +class Timer { + public: + // The maximum timer duration - one day. + static constexpr DurationMs kMaxTimerDuration = DurationMs(24 * 3600 * 1000); + + // When expired, the timer handler can optionally return a new duration which + // will be set as `duration` and used as base duration when the timer is + // restarted and as input to the backoff algorithm. + using OnExpired = std::function()>; + + // TimerManager will have pointers to these instances, so they must not move. + Timer(const Timer&) = delete; + Timer& operator=(const Timer&) = delete; + + ~Timer(); + + // Starts the timer if it's stopped or restarts the timer if it's already + // running. The `expiration_count` will be reset. + void Start(); + + // Stops the timer. This can also be called when the timer is already stopped. + // The `expiration_count` will be reset. + void Stop(); + + // Sets the base duration. The actual timer duration may be larger depending + // on the backoff algorithm. + void set_duration(DurationMs duration) { + duration_ = std::min(duration, kMaxTimerDuration); + } + + // Retrieves the base duration. The actual timer duration may be larger + // depending on the backoff algorithm. + DurationMs duration() const { return duration_; } + + // Returns the number of times the timer has expired. + int expiration_count() const { return expiration_count_; } + + // Returns the timer's options. + const TimerOptions& options() const { return options_; } + + // Returns the name of the timer. + absl::string_view name() const { return name_; } + + // Indicates if this timer is currently running. + bool is_running() const { return is_running_; } + + private: + friend class TimerManager; + using UnregisterHandler = std::function; + Timer(TimerID id, + absl::string_view name, + OnExpired on_expired, + UnregisterHandler unregister, + std::unique_ptr timeout, + const TimerOptions& options); + + // Called by TimerManager. Will trigger the callback and increment + // `expiration_count`. The timer will automatically be restarted at the + // duration as decided by the backoff algorithm, unless the + // `TimerOptions::max_restarts` has been reached and then it will be stopped + // and `is_running()` will return false. + void Trigger(TimerGeneration generation); + + const TimerID id_; + const std::string name_; + const TimerOptions options_; + const OnExpired on_expired_; + const UnregisterHandler unregister_handler_; + const std::unique_ptr timeout_; + + DurationMs duration_; + + // Increased on each start, and is matched on Trigger, to avoid races. And by + // race, meaning that a timeout - which may be evaluated/expired on a + // different thread while this thread has stopped that timer already. Note + // that the entire socket is not thread-safe, so `TimerManager::HandleTimeout` + // is never executed concurrently with any timer starting/stopping. 
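+  //
+  // For example, a timeout expiry may already be queued for delivery when this
+  // timer is restarted; the queued TimeoutID then carries the previous
+  // generation and is ignored by `Trigger`.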
+ // + // This will wrap around after 4 billion timer restarts, and if it wraps + // around, it would just trigger _this_ timer in advance (but it's hard to + // restart it 4 billion times within its duration). + TimerGeneration generation_ = TimerGeneration(0); + bool is_running_ = false; + // Incremented each time time has expired and reset when stopped or restarted. + int expiration_count_ = 0; +}; + +// Creates and manages timers. +class TimerManager { + public: + explicit TimerManager( + std::function()> create_timeout) + : create_timeout_(std::move(create_timeout)) {} + + // Creates a timer with name `name` that will expire (when started) after + // `options.duration` and call `on_expired`. There are more `options` that + // affects the behavior. Note that timers are created initially stopped. + std::unique_ptr CreateTimer(absl::string_view name, + Timer::OnExpired on_expired, + const TimerOptions& options); + + void HandleTimeout(TimeoutID timeout_id); + + private: + const std::function()> create_timeout_; + std::unordered_map timers_; + TimerID next_id_ = TimerID(0); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_TIMER_H_ diff --git a/net/dcsctp/timer/timer_test.cc b/net/dcsctp/timer/timer_test.cc new file mode 100644 index 0000000000..a403bb6b4b --- /dev/null +++ b/net/dcsctp/timer/timer_test.cc @@ -0,0 +1,390 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/timer.h" + +#include + +#include "absl/types/optional.h" +#include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::Return; + +class TimerTest : public testing::Test { + protected: + TimerTest() + : timeout_manager_([this]() { return now_; }), + manager_([this]() { return timeout_manager_.CreateTimeout(); }) { + ON_CALL(on_expired_, Call).WillByDefault(Return(absl::nullopt)); + } + + void AdvanceTimeAndRunTimers(DurationMs duration) { + now_ = now_ + duration; + + for (;;) { + absl::optional timeout_id = + timeout_manager_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + manager_.HandleTimeout(*timeout_id); + } + } + + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager manager_; + testing::MockFunction()> on_expired_; +}; + +TEST_F(TimerTest, TimerIsInitiallyStopped) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerExpiresAtGivenTime) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + EXPECT_TRUE(t1->is_running()); + + AdvanceTimeAndRunTimers(DurationMs(4000)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, TimerReschedulesAfterExpiredWithFixedBackoff) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + 
TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + EXPECT_EQ(t1->expiration_count(), 0); + + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + EXPECT_EQ(t1->expiration_count(), 1); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Second time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + EXPECT_EQ(t1->expiration_count(), 2); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Third time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + EXPECT_EQ(t1->expiration_count(), 3); +} + +TEST_F(TimerTest, TimerWithNoRestarts) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed, + /*max_restart=*/0)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + EXPECT_FALSE(t1->is_running()); + + // Second time - shouldn't fire + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(5000)); + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerWithOneRestart) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed, + /*max_restart=*/1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Second time - max restart limit reached. + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_FALSE(t1->is_running()); + + // Third time - should not fire. 
+ EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(5000)); + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerWithTwoRestart) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed, + /*max_restart=*/2)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Second time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Third time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerWithExponentialBackoff) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + // Fire first time at 5 seconds + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(5000)); + + // Second time at 5*2^1 = 10 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(9000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Third time at 5*2^2 = 20 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(19000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Fourth time at 5*2^3 = 40 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(39000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, StartTimerWillStopAndStart) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + AdvanceTimeAndRunTimers(DurationMs(3000)); + + t1->Start(); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(2000)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(3000)); +} + +TEST_F(TimerTest, ExpirationCounterWillResetIfStopped) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + // Fire first time at 5 seconds + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(5000)); + EXPECT_EQ(t1->expiration_count(), 1); + + // Second time at 5*2^1 = 10 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(9000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->expiration_count(), 2); + + t1->Start(); + EXPECT_EQ(t1->expiration_count(), 0); + + // Third time at 5*2^0 = 5 seconds later. 
+ EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->expiration_count(), 1); +} + +TEST_F(TimerTest, StopTimerWillMakeItNotExpire) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + t1->Stop(); + EXPECT_FALSE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, ReturningNewDurationWhenExpired) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + EXPECT_EQ(t1->duration(), DurationMs(5000)); + + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).WillOnce(Return(DurationMs(2000))); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->duration(), DurationMs(2000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Second time + EXPECT_CALL(on_expired_, Call).WillOnce(Return(DurationMs(10000))); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->duration(), DurationMs(10000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(9000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, TimersHaveMaximumDuration) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential)); + + t1->set_duration(DurationMs(2 * *Timer::kMaxTimerDuration)); + EXPECT_EQ(t1->duration(), Timer::kMaxTimerDuration); +} + +TEST_F(TimerTest, TimersHaveMaximumBackoffDuration) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + int max_exponent = static_cast(log2(*Timer::kMaxTimerDuration / 1000)); + for (int i = 0; i < max_exponent; ++i) { + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000 * (1 << i))); + } + + // Reached the maximum duration. + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); +} + +TEST_F(TimerTest, TimerCanBeStartedFromWithinExpirationHandler) { + std::unique_ptr t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kFixed)); + + t1->Start(); + + // Start a timer, but don't return any new duration in callback. 
+ EXPECT_CALL(on_expired_, Call).WillOnce([&]() { + EXPECT_TRUE(t1->is_running()); + t1->set_duration(DurationMs(5000)); + t1->Start(); + return absl::nullopt; + }); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4999)); + + // Start a timer, and return any new duration in callback. + EXPECT_CALL(on_expired_, Call).WillOnce([&]() { + EXPECT_TRUE(t1->is_running()); + t1->set_duration(DurationMs(5000)); + t1->Start(); + return absl::make_optional(DurationMs(8000)); + }); + AdvanceTimeAndRunTimers(DurationMs(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(7999)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/tx/BUILD.gn b/net/dcsctp/tx/BUILD.gn new file mode 100644 index 0000000000..2f0b27afc6 --- /dev/null +++ b/net/dcsctp/tx/BUILD.gn @@ -0,0 +1,141 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("send_queue") { + deps = [ + "../../../api:array_view", + "../common:internal_types", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ "send_queue.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_library("rr_send_queue") { + deps = [ + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:pair_hash", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ + "rr_send_queue.cc", + "rr_send_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("retransmission_error_counter") { + deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../public:types", + ] + sources = [ + "retransmission_error_counter.cc", + "retransmission_error_counter.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("retransmission_timeout") { + deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../public:types", + ] + sources = [ + "retransmission_timeout.cc", + "retransmission_timeout.h", + ] +} + +rtc_library("retransmission_queue") { + deps = [ + ":retransmission_timeout", + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:rtc_base_approved", + "../common:math", + "../common:pair_hash", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../public:types", + "../timer", + ] + sources = [ + "retransmission_queue.cc", + "retransmission_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_source_set("mock_send_queue") { + testonly = true + deps = [ + ":send_queue", + "../../../api:array_view", + 
"../../../test:test_support", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ "mock_send_queue.h" ] + } + + rtc_library("dcsctp_tx_unittests") { + testonly = true + + deps = [ + ":mock_send_queue", + ":retransmission_error_counter", + ":retransmission_queue", + ":retransmission_timeout", + ":rr_send_queue", + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:rtc_base_approved", + "../../../test:test_support", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + "../testing:data_generator", + "../testing:testing_macros", + "../timer", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ + "retransmission_error_counter_test.cc", + "retransmission_queue_test.cc", + "retransmission_timeout_test.cc", + "rr_send_queue_test.cc", + ] + } +} diff --git a/net/dcsctp/tx/mock_send_queue.h b/net/dcsctp/tx/mock_send_queue.h new file mode 100644 index 0000000000..0cf64583ae --- /dev/null +++ b/net/dcsctp/tx/mock_send_queue.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ + +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/tx/send_queue.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockSendQueue : public SendQueue { + public: + MockSendQueue() { + ON_CALL(*this, Produce).WillByDefault([](TimeMs now, size_t max_size) { + return absl::nullopt; + }); + } + + MOCK_METHOD(absl::optional, + Produce, + (TimeMs now, size_t max_size), + (override)); + MOCK_METHOD(bool, + Discard, + (IsUnordered unordered, StreamID stream_id, MID message_id), + (override)); + MOCK_METHOD(void, + PrepareResetStreams, + (rtc::ArrayView streams), + (override)); + MOCK_METHOD(bool, CanResetStreams, (), (const, override)); + MOCK_METHOD(void, CommitResetStreams, (), (override)); + MOCK_METHOD(void, RollbackResetStreams, (), (override)); + MOCK_METHOD(void, Reset, (), (override)); + MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override)); + MOCK_METHOD(size_t, total_buffered_amount, (), (const, override)); + MOCK_METHOD(size_t, + buffered_amount_low_threshold, + (StreamID stream_id), + (const, override)); + MOCK_METHOD(void, + SetBufferedAmountLowThreshold, + (StreamID stream_id, size_t bytes), + (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ diff --git a/net/dcsctp/tx/retransmission_error_counter.cc b/net/dcsctp/tx/retransmission_error_counter.cc new file mode 100644 index 0000000000..111b6efe96 --- /dev/null +++ b/net/dcsctp/tx/retransmission_error_counter.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/tx/retransmission_error_counter.h" + +#include "absl/strings/string_view.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +bool RetransmissionErrorCounter::Increment(absl::string_view reason) { + ++counter_; + if (counter_ > limit_) { + RTC_DLOG(LS_INFO) << log_prefix_ << reason + << ", too many retransmissions, counter=" << counter_; + return false; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << reason << ", new counter=" << counter_ + << ", max=" << limit_; + return true; +} + +void RetransmissionErrorCounter::Clear() { + if (counter_ > 0) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "recovered from counter=" << counter_; + counter_ = 0; + } +} + +} // namespace dcsctp diff --git a/net/dcsctp/tx/retransmission_error_counter.h b/net/dcsctp/tx/retransmission_error_counter.h new file mode 100644 index 0000000000..bb8d1f754d --- /dev/null +++ b/net/dcsctp/tx/retransmission_error_counter.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// The RetransmissionErrorCounter is a simple counter with a limit, and when +// the limit is exceeded, the counter is exhausted and the connection will +// be closed. It's incremented on retransmission errors, such as the T3-RTX +// timer expiring, but also missing heartbeats and stream reset requests. +class RetransmissionErrorCounter { + public: + RetransmissionErrorCounter(absl::string_view log_prefix, + const DcSctpOptions& options) + : log_prefix_(std::string(log_prefix) + "rtx-errors: "), + limit_(options.max_retransmissions) {} + + // Increments the retransmission timer. If the maximum error count has been + // reached, `false` will be returned. + bool Increment(absl::string_view reason); + bool IsExhausted() const { return counter_ > limit_; } + + // Clears the retransmission errors. + void Clear(); + + // Returns its current value + int value() const { return counter_; } + + private: + const std::string log_prefix_; + const int limit_; + int counter_ = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ diff --git a/net/dcsctp/tx/retransmission_error_counter_test.cc b/net/dcsctp/tx/retransmission_error_counter_test.cc new file mode 100644 index 0000000000..61ee82926d --- /dev/null +++ b/net/dcsctp/tx/retransmission_error_counter_test.cc @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/tx/retransmission_error_counter.h" + +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(RetransmissionErrorCounterTest, HasInitialValue) { + DcSctpOptions options; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_EQ(counter.value(), 0); +} + +TEST(RetransmissionErrorCounterTest, ReturnsFalseAtMaximumValue) { + DcSctpOptions options; + options.max_retransmissions = 5; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_TRUE(counter.Increment("test")); // 4 + EXPECT_TRUE(counter.Increment("test")); // 5 + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions +} + +TEST(RetransmissionErrorCounterTest, CanHandleZeroRetransmission) { + DcSctpOptions options; + options.max_retransmissions = 0; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_FALSE(counter.Increment("test")); // One is too many. +} + +TEST(RetransmissionErrorCounterTest, IsExhaustedAtMaximum) { + DcSctpOptions options; + options.max_retransmissions = 3; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions + EXPECT_TRUE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // One after too many + EXPECT_TRUE(counter.IsExhausted()); +} + +TEST(RetransmissionErrorCounterTest, ClearingCounter) { + DcSctpOptions options; + options.max_retransmissions = 3; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + counter.Clear(); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions + EXPECT_TRUE(counter.IsExhausted()); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/tx/retransmission_queue.cc b/net/dcsctp/tx/retransmission_queue.cc new file mode 100644 index 0000000000..51bb65a30c --- /dev/null +++ b/net/dcsctp/tx/retransmission_queue.cc @@ -0,0 +1,909 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/tx/retransmission_queue.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/pair_hash.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { +namespace { + +// The number of times a packet must be NACKed before it's retransmitted. +// See https://tools.ietf.org/html/rfc4960#section-7.2.4 +constexpr size_t kNumberOfNacksForRetransmission = 3; +} // namespace + +RetransmissionQueue::RetransmissionQueue( + absl::string_view log_prefix, + TSN initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function on_new_rtt, + std::function on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability, + bool use_message_interleaving) + : options_(options), + partial_reliability_(supports_partial_reliability), + log_prefix_(std::string(log_prefix) + "tx: "), + data_chunk_header_size_(use_message_interleaving + ? 
IDataChunk::kHeaderSize + : DataChunk::kHeaderSize), + on_new_rtt_(std::move(on_new_rtt)), + on_clear_retransmission_counter_( + std::move(on_clear_retransmission_counter)), + t3_rtx_(t3_rtx), + cwnd_(options_.cwnd_mtus_initial * options_.mtu), + rwnd_(a_rwnd), + // https://tools.ietf.org/html/rfc4960#section-7.2.1 + // "The initial value of ssthresh MAY be arbitrarily high (for + // example, implementations MAY use the size of the receiver advertised + // window)."" + ssthresh_(rwnd_), + next_tsn_(tsn_unwrapper_.Unwrap(initial_tsn)), + last_cumulative_tsn_ack_(tsn_unwrapper_.Unwrap(TSN(*initial_tsn - 1))), + send_queue_(send_queue) {} + +bool RetransmissionQueue::IsConsistent() const { + size_t actual_outstanding_bytes = 0; + + std::set actual_to_be_retransmitted; + for (const auto& elem : outstanding_data_) { + if (elem.second.is_outstanding()) { + actual_outstanding_bytes += GetSerializedChunkSize(elem.second.data()); + } + + if (elem.second.should_be_retransmitted()) { + actual_to_be_retransmitted.insert(elem.first); + } + } + + return actual_outstanding_bytes == outstanding_bytes_ && + actual_to_be_retransmitted == to_be_retransmitted_; +} + +// Returns how large a chunk will be, serialized, carrying the data +size_t RetransmissionQueue::GetSerializedChunkSize(const Data& data) const { + return RoundUpTo4(data_chunk_header_size_ + data.size()); +} + +void RetransmissionQueue::RemoveAcked(UnwrappedTSN cumulative_tsn_ack, + AckInfo& ack_info) { + auto first_unacked = outstanding_data_.upper_bound(cumulative_tsn_ack); + + for (auto it = outstanding_data_.begin(); it != first_unacked; ++it) { + ack_info.bytes_acked_by_cumulative_tsn_ack += it->second.data().size(); + ack_info.acked_tsns.push_back(it->first.Wrap()); + if (it->second.is_outstanding()) { + outstanding_bytes_ -= GetSerializedChunkSize(it->second.data()); + } else if (it->second.should_be_retransmitted()) { + to_be_retransmitted_.erase(it->first); + } + } + + outstanding_data_.erase(outstanding_data_.begin(), first_unacked); +} + +void RetransmissionQueue::AckGapBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView gap_ack_blocks, + AckInfo& ack_info) { + // Mark all non-gaps as ACKED (but they can't be removed) as (from RFC) + // "SCTP considers the information carried in the Gap Ack Blocks in the + // SACK chunk as advisory.". Note that when NR-SACK is supported, this can be + // handled differently. + + for (auto& block : gap_ack_blocks) { + auto start = outstanding_data_.lower_bound( + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.start)); + auto end = outstanding_data_.upper_bound( + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end)); + for (auto iter = start; iter != end; ++iter) { + if (!iter->second.is_acked()) { + ack_info.bytes_acked_by_new_gap_ack_blocks += + iter->second.data().size(); + if (iter->second.is_outstanding()) { + outstanding_bytes_ -= GetSerializedChunkSize(iter->second.data()); + } + if (iter->second.should_be_retransmitted()) { + to_be_retransmitted_.erase(iter->first); + } + iter->second.Ack(); + ack_info.highest_tsn_acked = + std::max(ack_info.highest_tsn_acked, iter->first); + ack_info.acked_tsns.push_back(iter->first.Wrap()); + } + } + } +} + +void RetransmissionQueue::NackBetweenAckBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView gap_ack_blocks, + AckInfo& ack_info) { + // Mark everything between the blocks as NACKED/TO_BE_RETRANSMITTED. + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "Mark the DATA chunk(s) with three miss indications for retransmission." 
+ // "For each incoming SACK, miss indications are incremented only for + // missing TSNs prior to the highest TSN newly acknowledged in the SACK." + // + // What this means is that only when there is a increasing stream of data + // received and there are new packets seen (since last time), packets that are + // in-flight and between gaps should be nacked. This means that SCTP relies on + // the T3-RTX-timer to re-send packets otherwise. + UnwrappedTSN max_tsn_to_nack = ack_info.highest_tsn_acked; + if (is_in_fast_recovery() && cumulative_tsn_ack > last_cumulative_tsn_ack_) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If an endpoint is in Fast Recovery and a SACK arrives that advances + // the Cumulative TSN Ack Point, the miss indications are incremented for + // all TSNs reported missing in the SACK." + max_tsn_to_nack = UnwrappedTSN::AddTo( + cumulative_tsn_ack, + gap_ack_blocks.empty() ? 0 : gap_ack_blocks.rbegin()->end); + } + + UnwrappedTSN prev_block_last_acked = cumulative_tsn_ack; + for (auto& block : gap_ack_blocks) { + UnwrappedTSN cur_block_first_acked = + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.start); + for (auto iter = outstanding_data_.upper_bound(prev_block_last_acked); + iter != outstanding_data_.lower_bound(cur_block_first_acked); ++iter) { + if (iter->first <= max_tsn_to_nack) { + ack_info.has_packet_loss = + NackItem(iter->first, iter->second, /*retransmit_now=*/false); + } + } + prev_block_last_acked = UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end); + } + + // Note that packets are not NACKED which are above the highest gap-ack-block + // (or above the cumulative ack TSN if no gap-ack-blocks) as only packets + // up until the highest_tsn_acked (see above) should be considered when + // NACKing. +} + +void RetransmissionQueue::MaybeExitFastRecovery( + UnwrappedTSN cumulative_tsn_ack) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "When a SACK acknowledges all TSNs up to and including this [fast + // recovery] exit point, Fast Recovery is exited." + if (fast_recovery_exit_tsn_.has_value() && + cumulative_tsn_ack >= *fast_recovery_exit_tsn_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "exit_point=" << *fast_recovery_exit_tsn_->Wrap() + << " reached - exiting fast recovery"; + fast_recovery_exit_tsn_ = absl::nullopt; + } +} + +void RetransmissionQueue::HandleIncreasedCumulativeTsnAck( + size_t outstanding_bytes, + size_t total_bytes_acked) { + // Allow some margin for classifying as fully utilized, due to e.g. that too + // small packets (less than kMinimumFragmentedPayload) are not sent + + // overhead. + bool is_fully_utilized = outstanding_bytes + options_.mtu >= cwnd_; + size_t old_cwnd = cwnd_; + if (phase() == CongestionAlgorithmPhase::kSlowStart) { + if (is_fully_utilized && !is_in_fast_recovery()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.1 + // "Only when these three conditions are met can the cwnd be + // increased; otherwise, the cwnd MUST not be increased. If these + // conditions are met, then cwnd MUST be increased by, at most, the + // lesser of 1) the total size of the previously outstanding DATA + // chunk(s) acknowledged, and 2) the destination's path MTU." 
+ if (options_.slow_start_tcp_style) { + cwnd_ += std::min(total_bytes_acked, cwnd_); + } else { + cwnd_ += std::min(total_bytes_acked, options_.mtu); + } + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "SS increase cwnd=" << cwnd_ + << " (" << old_cwnd << ")"; + } + } else if (phase() == CongestionAlgorithmPhase::kCongestionAvoidance) { + // https://tools.ietf.org/html/rfc4960#section-7.2.2 + // "Whenever cwnd is greater than ssthresh, upon each SACK arrival + // that advances the Cumulative TSN Ack Point, increase + // partial_bytes_acked by the total number of bytes of all new chunks + // acknowledged in that SACK including chunks acknowledged by the new + // Cumulative TSN Ack and by Gap Ack Blocks." + size_t old_pba = partial_bytes_acked_; + partial_bytes_acked_ += total_bytes_acked; + + if (partial_bytes_acked_ >= cwnd_ && is_fully_utilized) { + // https://tools.ietf.org/html/rfc4960#section-7.2.2 + // "When partial_bytes_acked is equal to or greater than cwnd and + // before the arrival of the SACK the sender had cwnd or more bytes of + // data outstanding (i.e., before arrival of the SACK, flightsize was + // greater than or equal to cwnd), increase cwnd by MTU, and reset + // partial_bytes_acked to (partial_bytes_acked - cwnd)." + cwnd_ += options_.mtu; + partial_bytes_acked_ -= cwnd_; + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "CA increase cwnd=" << cwnd_ + << " (" << old_cwnd << ") ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" + << old_pba << ")"; + } else { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "CA unchanged cwnd=" << cwnd_ + << " (" << old_cwnd << ") ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" + << old_pba << ")"; + } + } +} + +void RetransmissionQueue::HandlePacketLoss(UnwrappedTSN highest_tsn_acked) { + if (!is_in_fast_recovery()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If not in Fast Recovery, adjust the ssthresh and cwnd of the + // destination address(es) to which the missing DATA chunks were last + // sent, according to the formula described in Section 7.2.3." + size_t old_cwnd = cwnd_; + size_t old_pba = partial_bytes_acked_; + ssthresh_ = std::max(cwnd_ / 2, options_.cwnd_mtus_min * options_.mtu); + cwnd_ = ssthresh_; + partial_bytes_acked_ = 0; + + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "packet loss detected (not fast recovery). cwnd=" + << cwnd_ << " (" << old_cwnd + << "), ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" << old_pba + << ")"; + + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If not in Fast Recovery, enter Fast Recovery and mark the highest + // outstanding TSN as the Fast Recovery exit point." + fast_recovery_exit_tsn_ = outstanding_data_.empty() + ? last_cumulative_tsn_ack_ + : outstanding_data_.rbegin()->first; + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "fast recovery initiated with exit_point=" + << *fast_recovery_exit_tsn_->Wrap(); + } else { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "While in Fast Recovery, the ssthresh and cwnd SHOULD NOT change for + // any destinations due to a subsequent Fast Recovery event (i.e., one + // SHOULD NOT reduce the cwnd further due to a subsequent Fast Retransmit)." + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "packet loss detected (fast recovery). No changes."; + } +} + +void RetransmissionQueue::UpdateReceiverWindow(uint32_t a_rwnd) { + rwnd_ = outstanding_bytes_ >= a_rwnd ? 
0 : a_rwnd - outstanding_bytes_; +} + +void RetransmissionQueue::StartT3RtxTimerIfOutstandingData() { + // Note: Can't use `outstanding_bytes()` as that one doesn't count chunks to + // be retransmitted. + if (outstanding_data_.empty()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever all outstanding data sent to an address have been + // acknowledged, turn off the T3-rtx timer of that address. + // Note: Already stopped in `StopT3RtxTimerOnIncreasedCumulativeTsnAck`." + } else { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the T3-rtx + // timer for that address with its current RTO (if there is still + // outstanding data on that address)." + // "Whenever a SACK is received missing a TSN that was previously + // acknowledged via a Gap Ack Block, start the T3-rtx for the destination + // address to which the DATA chunk was originally transmitted if it is not + // already running." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + } +} + +bool RetransmissionQueue::IsSackValid(const SackChunk& sack) const { + // https://tools.ietf.org/html/rfc4960#section-6.2.1 + // "If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + // then drop the SACK. Since Cumulative TSN Ack is monotonically increasing, + // a SACK whose Cumulative TSN Ack is less than the Cumulative TSN Ack Point + // indicates an out-of- order SACK." + // + // Note: Important not to drop SACKs with identical TSN to that previously + // received, as the gap ack blocks or dup tsn fields may have changed. + UnwrappedTSN cumulative_tsn_ack = + tsn_unwrapper_.PeekUnwrap(sack.cumulative_tsn_ack()); + if (cumulative_tsn_ack < last_cumulative_tsn_ack_) { + // https://tools.ietf.org/html/rfc4960#section-6.2.1 + // "If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + // then drop the SACK. Since Cumulative TSN Ack is monotonically + // increasing, a SACK whose Cumulative TSN Ack is less than the Cumulative + // TSN Ack Point indicates an out-of- order SACK." + return false; + } else if (outstanding_data_.empty() && + cumulative_tsn_ack > last_cumulative_tsn_ack_) { + // No in-flight data and cum-tsn-ack above what was last ACKed - not valid. + return false; + } else if (!outstanding_data_.empty() && + cumulative_tsn_ack > outstanding_data_.rbegin()->first) { + // There is in-flight data, but the cum-tsn-ack is beyond that - not valid. + return false; + } + return true; +} + +bool RetransmissionQueue::HandleSack(TimeMs now, const SackChunk& sack) { + if (!IsSackValid(sack)) { + return false; + } + + size_t old_outstanding_bytes = outstanding_bytes_; + size_t old_rwnd = rwnd_; + UnwrappedTSN cumulative_tsn_ack = + tsn_unwrapper_.Unwrap(sack.cumulative_tsn_ack()); + + if (sack.gap_ack_blocks().empty()) { + UpdateRTT(now, cumulative_tsn_ack); + } + + AckInfo ack_info(cumulative_tsn_ack); + // Erase all items up to cumulative_tsn_ack. + RemoveAcked(cumulative_tsn_ack, ack_info); + + // ACK packets reported in the gap ack blocks + AckGapBlocks(cumulative_tsn_ack, sack.gap_ack_blocks(), ack_info); + + // NACK and possibly mark for retransmit chunks that weren't acked. + NackBetweenAckBlocks(cumulative_tsn_ack, sack.gap_ack_blocks(), ack_info); + + // Update of outstanding_data_ is now done. Congestion control remains. + UpdateReceiverWindow(sack.a_rwnd()); + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Received SACK. 
Acked TSN: " + << StrJoin(ack_info.acked_tsns, ",", + [](rtc::StringBuilder& sb, TSN tsn) { + sb << *tsn; + }) + << ", cum_tsn_ack=" << *cumulative_tsn_ack.Wrap() << " (" + << *last_cumulative_tsn_ack_.Wrap() + << "), outstanding_bytes=" << outstanding_bytes_ << " (" + << old_outstanding_bytes << "), rwnd=" << rwnd_ << " (" + << old_rwnd << ")"; + + MaybeExitFastRecovery(cumulative_tsn_ack); + + if (cumulative_tsn_ack > last_cumulative_tsn_ack_) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the T3-rtx + // timer for that address with its current RTO (if there is still + // outstanding data on that address)." + // Note: It may be started again in a bit further down. + t3_rtx_.Stop(); + + HandleIncreasedCumulativeTsnAck( + old_outstanding_bytes, ack_info.bytes_acked_by_cumulative_tsn_ack + + ack_info.bytes_acked_by_new_gap_ack_blocks); + } + + if (ack_info.has_packet_loss) { + is_in_fast_retransmit_ = true; + HandlePacketLoss(ack_info.highest_tsn_acked); + } + + // https://tools.ietf.org/html/rfc4960#section-8.2 + // "When an outstanding TSN is acknowledged [...] the endpoint shall clear + // the error counter ..." + if (ack_info.bytes_acked_by_cumulative_tsn_ack > 0 || + ack_info.bytes_acked_by_new_gap_ack_blocks > 0) { + on_clear_retransmission_counter_(); + } + + last_cumulative_tsn_ack_ = cumulative_tsn_ack; + StartT3RtxTimerIfOutstandingData(); + RTC_DCHECK(IsConsistent()); + return true; +} + +void RetransmissionQueue::UpdateRTT(TimeMs now, + UnwrappedTSN cumulative_tsn_ack) { + // RTT updating is flawed in SCTP, as explained in e.g. Pedersen J, Griwodz C, + // Halvorsen P (2006) Considerations of SCTP retransmission delays for thin + // streams. + // Due to delayed acknowledgement, the SACK may be sent much later which + // increases the calculated RTT. + // TODO(boivie): Consider occasionally sending DATA chunks with I-bit set and + // use only those packets for measurement. + + auto it = outstanding_data_.find(cumulative_tsn_ack); + if (it != outstanding_data_.end()) { + if (!it->second.has_been_retransmitted()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.1 + // "Karn's algorithm: RTT measurements MUST NOT be made using + // packets that were retransmitted (and thus for which it is ambiguous + // whether the reply was for the first instance of the chunk or for a + // later instance)" + DurationMs rtt = now - it->second.time_sent(); + on_new_rtt_(rtt); + } + } +} + +void RetransmissionQueue::HandleT3RtxTimerExpiry() { + size_t old_cwnd = cwnd_; + size_t old_outstanding_bytes = outstanding_bytes_; + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "For the destination address for which the timer expires, adjust + // its ssthresh with rules defined in Section 7.2.3 and set the cwnd <- MTU." + ssthresh_ = std::max(cwnd_ / 2, 4 * options_.mtu); + cwnd_ = 1 * options_.mtu; + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "For the destination address for which the timer expires, set RTO + // <- RTO * 2 ("back off the timer"). The maximum value discussed in rule C7 + // above (RTO.max) may be used to provide an upper bound to this doubling + // operation." + + // Already done by the Timer implementation. 
+ + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Determine how many of the earliest (i.e., lowest TSN) outstanding + // DATA chunks for the address for which the T3-rtx has expired will fit into + // a single packet" + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Note: Any DATA chunks that were sent to the address for which the + // T3-rtx timer expired but did not fit in one MTU (rule E3 above) should be + // marked for retransmission and sent as soon as cwnd allows (normally, when a + // SACK arrives)." + for (auto& elem : outstanding_data_) { + UnwrappedTSN tsn = elem.first; + TxData& item = elem.second; + if (!item.is_acked()) { + NackItem(tsn, item, /*retransmit_now=*/true); + } + } + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Start the retransmission timer T3-rtx on the destination address + // to which the retransmission is sent, if rule R1 above indicates to do so." + + // Already done by the Timer implementation. + + RTC_DLOG(LS_INFO) << log_prefix_ << "t3-rtx expired. new cwnd=" << cwnd_ + << " (" << old_cwnd << "), ssthresh=" << ssthresh_ + << ", outstanding_bytes " << outstanding_bytes_ << " (" + << old_outstanding_bytes << ")"; + RTC_DCHECK(IsConsistent()); +} + +bool RetransmissionQueue::NackItem(UnwrappedTSN tsn, + TxData& item, + bool retransmit_now) { + if (item.is_outstanding()) { + outstanding_bytes_ -= GetSerializedChunkSize(item.data()); + } + + switch (item.Nack(retransmit_now)) { + case TxData::NackAction::kNothing: + return false; + case TxData::NackAction::kRetransmit: + to_be_retransmitted_.insert(tsn); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << *tsn.Wrap() + << " marked for retransmission"; + break; + case TxData::NackAction::kAbandon: + AbandonAllFor(item); + break; + } + return true; +} + +std::vector> +RetransmissionQueue::GetChunksToBeRetransmitted(size_t max_size) { + std::vector> result; + + for (auto it = to_be_retransmitted_.begin(); + it != to_be_retransmitted_.end();) { + UnwrappedTSN tsn = *it; + auto elem = outstanding_data_.find(tsn); + RTC_DCHECK(elem != outstanding_data_.end()); + TxData& item = elem->second; + RTC_DCHECK(item.should_be_retransmitted()); + RTC_DCHECK(!item.is_outstanding()); + RTC_DCHECK(!item.is_abandoned()); + RTC_DCHECK(!item.is_acked()); + + size_t serialized_size = GetSerializedChunkSize(item.data()); + if (serialized_size <= max_size) { + item.Retransmit(); + result.emplace_back(tsn.Wrap(), item.data().Clone()); + max_size -= serialized_size; + outstanding_bytes_ += serialized_size; + it = to_be_retransmitted_.erase(it); + } else { + ++it; + } + // No point in continuing if the packet is full. + if (max_size <= data_chunk_header_size_) { + break; + } + } + + return result; +} + +std::vector> RetransmissionQueue::GetChunksToSend( + TimeMs now, + size_t bytes_remaining_in_packet) { + // Chunks are always padded to even divisible by four. + RTC_DCHECK(IsDivisibleBy4(bytes_remaining_in_packet)); + + std::vector> to_be_sent; + size_t old_outstanding_bytes = outstanding_bytes_; + size_t old_rwnd = rwnd_; + if (is_in_fast_retransmit()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks + // marked for retransmission will fit into a single packet ... Retransmit + // those K DATA chunks in a single packet. When a Fast Retransmit is being + // performed, the sender SHOULD ignore the value of cwnd and SHOULD NOT + // delay retransmission for this single packet." 
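The two branches that follow differ mainly in how the byte budget is computed before chunks are picked. A hypothetical condensation of that budget (the real code additionally rounds the normal-send budget down to a multiple of four):

#include <algorithm>
#include <cstddef>

// A fast retransmit may fill the whole remaining packet, ignoring cwnd and
// rwnd; a normal send is limited by packet space, the receiver window and the
// congestion window headroom.
size_t SendBudget(bool fast_retransmit,
                  size_t bytes_remaining_in_packet,
                  size_t rwnd,
                  size_t cwnd,
                  size_t outstanding_bytes) {
  if (fast_retransmit) {
    return bytes_remaining_in_packet;
  }
  size_t cwnd_headroom =
      outstanding_bytes >= cwnd ? 0 : cwnd - outstanding_bytes;
  return std::min({bytes_remaining_in_packet, rwnd, cwnd_headroom});
}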
+ is_in_fast_retransmit_ = false; + to_be_sent = GetChunksToBeRetransmitted(bytes_remaining_in_packet); + size_t to_be_sent_bytes = absl::c_accumulate( + to_be_sent, 0, [&](size_t r, const std::pair& d) { + return r + GetSerializedChunkSize(d.second); + }); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "fast-retransmit: sending " + << to_be_sent.size() << " chunks, " << to_be_sent_bytes + << " bytes"; + } else { + // Normal sending. Calculate the bandwidth budget (how many bytes that is + // allowed to be sent), and fill that up first with chunks that are + // scheduled to be retransmitted. If there is still budget, send new chunks + // (which will have their TSN assigned here.) + size_t remaining_cwnd_bytes = + outstanding_bytes_ >= cwnd_ ? 0 : cwnd_ - outstanding_bytes_; + size_t max_bytes = RoundDownTo4(std::min( + std::min(bytes_remaining_in_packet, rwnd()), remaining_cwnd_bytes)); + + to_be_sent = GetChunksToBeRetransmitted(max_bytes); + max_bytes -= absl::c_accumulate( + to_be_sent, 0, [&](size_t r, const std::pair& d) { + return r + GetSerializedChunkSize(d.second); + }); + + while (max_bytes > data_chunk_header_size_) { + RTC_DCHECK(IsDivisibleBy4(max_bytes)); + absl::optional chunk_opt = + send_queue_.Produce(now, max_bytes - data_chunk_header_size_); + if (!chunk_opt.has_value()) { + break; + } + + UnwrappedTSN tsn = next_tsn_; + next_tsn_.Increment(); + + // All chunks are always padded to be even divisible by 4. + size_t chunk_size = GetSerializedChunkSize(chunk_opt->data); + max_bytes -= chunk_size; + outstanding_bytes_ += chunk_size; + rwnd_ -= chunk_size; + auto item_it = + outstanding_data_ + .emplace(tsn, + RetransmissionQueue::TxData( + chunk_opt->data.Clone(), + partial_reliability_ ? chunk_opt->max_retransmissions + : absl::nullopt, + now, + partial_reliability_ ? chunk_opt->expires_at + : absl::nullopt)) + .first; + + if (item_it->second.has_expired(now)) { + // No need to send it - it was expired when it was in the send + // queue. + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Marking freshly produced chunk " + << *item_it->first.Wrap() << " and message " + << *item_it->second.data().message_id << " as expired"; + AbandonAllFor(item_it->second); + } else { + to_be_sent.emplace_back(tsn.Wrap(), std::move(chunk_opt->data)); + } + } + } + + if (!to_be_sent.empty()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Every time a DATA chunk is sent to any address (including a + // retransmission), if the T3-rtx timer of that address is not running, + // start it running so that it will expire after the RTO of that address." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Sending TSN " + << StrJoin(to_be_sent, ",", + [&](rtc::StringBuilder& sb, + const std::pair& c) { + sb << *c.first; + }) + << " - " + << absl::c_accumulate( + to_be_sent, 0, + [&](size_t r, const std::pair& d) { + return r + GetSerializedChunkSize(d.second); + }) + << " bytes. 
outstanding_bytes=" << outstanding_bytes_ + << " (" << old_outstanding_bytes << "), cwnd=" << cwnd_ + << ", rwnd=" << rwnd_ << " (" << old_rwnd << ")"; + } + RTC_DCHECK(IsConsistent()); + return to_be_sent; +} + +std::vector> +RetransmissionQueue::GetChunkStatesForTesting() const { + std::vector> states; + states.emplace_back(last_cumulative_tsn_ack_.Wrap(), State::kAcked); + for (const auto& elem : outstanding_data_) { + State state; + if (elem.second.is_abandoned()) { + state = State::kAbandoned; + } else if (elem.second.should_be_retransmitted()) { + state = State::kToBeRetransmitted; + } else if (elem.second.is_acked()) { + state = State::kAcked; + } else if (elem.second.is_outstanding()) { + state = State::kInFlight; + } else { + state = State::kNacked; + } + + states.emplace_back(elem.first.Wrap(), state); + } + return states; +} + +bool RetransmissionQueue::ShouldSendForwardTsn(TimeMs now) { + if (!partial_reliability_) { + return false; + } + ExpireOutstandingChunks(now); + if (!outstanding_data_.empty()) { + auto it = outstanding_data_.begin(); + return it->first == last_cumulative_tsn_ack_.next_value() && + it->second.is_abandoned(); + } + RTC_DCHECK(IsConsistent()); + return false; +} + +void RetransmissionQueue::TxData::Ack() { + ack_state_ = AckState::kAcked; + should_be_retransmitted_ = false; +} + +RetransmissionQueue::TxData::NackAction RetransmissionQueue::TxData::Nack( + bool retransmit_now) { + ack_state_ = AckState::kNacked; + ++nack_count_; + if ((retransmit_now || nack_count_ >= kNumberOfNacksForRetransmission) && + !is_abandoned_) { + // Nacked enough times - it's considered lost. + if (!max_retransmissions_.has_value() || + num_retransmissions_ < max_retransmissions_) { + should_be_retransmitted_ = true; + return NackAction::kRetransmit; + } + Abandon(); + return NackAction::kAbandon; + } + return NackAction::kNothing; +} + +void RetransmissionQueue::TxData::Retransmit() { + ack_state_ = AckState::kUnacked; + should_be_retransmitted_ = false; + + nack_count_ = 0; + ++num_retransmissions_; +} + +void RetransmissionQueue::TxData::Abandon() { + is_abandoned_ = true; + should_be_retransmitted_ = false; +} + +bool RetransmissionQueue::TxData::has_expired(TimeMs now) const { + return expires_at_.has_value() && *expires_at_ <= now; +} + +void RetransmissionQueue::ExpireOutstandingChunks(TimeMs now) { + for (const auto& elem : outstanding_data_) { + UnwrappedTSN tsn = elem.first; + const TxData& item = elem.second; + + // Chunks that are nacked can be expired. Care should be taken not to expire + // unacked (in-flight) chunks as they might have been received, but the SACK + // is either delayed or in-flight and may be received later. + if (item.is_abandoned()) { + // Already abandoned. + } else if (item.is_nacked() && item.has_expired(now)) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Marking nacked chunk " + << *tsn.Wrap() << " and message " + << *item.data().message_id << " as expired"; + AbandonAllFor(item); + } else { + // A non-expired chunk. No need to iterate any further. + break; + } + } +} + +void RetransmissionQueue::AbandonAllFor( + const RetransmissionQueue::TxData& item) { + // Erase all remaining chunks from the producer, if any. + if (send_queue_.Discard(item.data().is_unordered, item.data().stream_id, + item.data().message_id)) { + // There were remaining chunks to be produced for this message. 
Since the + // receiver may have already received all chunks (up till now) for this + // message, we can't just FORWARD-TSN to the last fragment in this + // (abandoned) message and start sending a new message, as the receiver will + // then see a new message before the end of the previous one was seen (or + // skipped over). So create a new fragment, representing the end, that the + // received will never see as it is abandoned immediately and used as cum + // TSN in the sent FORWARD-TSN. + UnwrappedTSN tsn = next_tsn_; + next_tsn_.Increment(); + Data message_end(item.data().stream_id, item.data().ssn, + item.data().message_id, item.data().fsn, item.data().ppid, + std::vector(), Data::IsBeginning(false), + Data::IsEnd(true), item.data().is_unordered); + TxData& added_item = + outstanding_data_ + .emplace(tsn, RetransmissionQueue::TxData(std::move(message_end), + absl::nullopt, TimeMs(0), + absl::nullopt)) + .first->second; + // The added chunk shouldn't be included in `outstanding_bytes`, so set it + // as acked. + added_item.Ack(); + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Adding unsent end placeholder for message at tsn=" + << *tsn.Wrap(); + } + for (auto& elem : outstanding_data_) { + UnwrappedTSN tsn = elem.first; + TxData& other = elem.second; + + if (!other.is_abandoned() && + other.data().stream_id == item.data().stream_id && + other.data().is_unordered == item.data().is_unordered && + other.data().message_id == item.data().message_id) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Marking chunk " << *tsn.Wrap() + << " as abandoned"; + if (other.should_be_retransmitted()) { + to_be_retransmitted_.erase(tsn); + } + other.Abandon(); + } + } +} + +ForwardTsnChunk RetransmissionQueue::CreateForwardTsn() const { + std::unordered_map + skipped_per_ordered_stream; + UnwrappedTSN new_cumulative_ack = last_cumulative_tsn_ack_; + + for (const auto& elem : outstanding_data_) { + UnwrappedTSN tsn = elem.first; + const TxData& item = elem.second; + + if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) { + break; + } + new_cumulative_ack = tsn; + if (!item.data().is_unordered && + item.data().ssn > skipped_per_ordered_stream[item.data().stream_id]) { + skipped_per_ordered_stream[item.data().stream_id] = item.data().ssn; + } + } + + std::vector skipped_streams; + skipped_streams.reserve(skipped_per_ordered_stream.size()); + for (const auto& elem : skipped_per_ordered_stream) { + skipped_streams.emplace_back(elem.first, elem.second); + } + return ForwardTsnChunk(new_cumulative_ack.Wrap(), std::move(skipped_streams)); +} + +IForwardTsnChunk RetransmissionQueue::CreateIForwardTsn() const { + std::unordered_map, MID, UnorderedStreamHash> + skipped_per_stream; + UnwrappedTSN new_cumulative_ack = last_cumulative_tsn_ack_; + + for (const auto& elem : outstanding_data_) { + UnwrappedTSN tsn = elem.first; + const TxData& item = elem.second; + + if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) { + break; + } + new_cumulative_ack = tsn; + std::pair stream_id = + std::make_pair(item.data().is_unordered, item.data().stream_id); + + if (item.data().message_id > skipped_per_stream[stream_id]) { + skipped_per_stream[stream_id] = item.data().message_id; + } + } + + std::vector skipped_streams; + skipped_streams.reserve(skipped_per_stream.size()); + for (const auto& elem : skipped_per_stream) { + const std::pair& stream = elem.first; + MID message_id = elem.second; + skipped_streams.emplace_back(stream.first, stream.second, message_id); + } + + return 
IForwardTsnChunk(new_cumulative_ack.Wrap(), + std::move(skipped_streams)); +} + +void RetransmissionQueue::PrepareResetStreams( + rtc::ArrayView streams) { + // TODO(boivie): These calls are now only affecting the send queue. The + // packet buffer can also change behavior - for example draining the chunk + // producer and eagerly assign TSNs so that an "Outgoing SSN Reset Request" + // can be sent quickly, with a known `sender_last_assigned_tsn`. + send_queue_.PrepareResetStreams(streams); +} +bool RetransmissionQueue::CanResetStreams() const { + return send_queue_.CanResetStreams(); +} +void RetransmissionQueue::CommitResetStreams() { + send_queue_.CommitResetStreams(); +} +void RetransmissionQueue::RollbackResetStreams() { + send_queue_.RollbackResetStreams(); +} + +} // namespace dcsctp diff --git a/net/dcsctp/tx/retransmission_queue.h b/net/dcsctp/tx/retransmission_queue.h new file mode 100644 index 0000000000..c5a6a04db8 --- /dev/null +++ b/net/dcsctp/tx/retransmission_queue.h @@ -0,0 +1,387 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The RetransmissionQueue manages all DATA/I-DATA chunks that are in-flight and +// schedules them to be retransmitted if necessary. Chunks are retransmitted +// when they have been lost for a number of consecutive SACKs, or when the +// retransmission timer, `t3_rtx` expires. +// +// As congestion control is tightly connected with the state of transmitted +// packets, that's also managed here to limit the amount of data that is +// in-flight (sent, but not yet acknowledged). +class RetransmissionQueue { + public: + static constexpr size_t kMinimumFragmentedPayload = 10; + // State for DATA chunks (message fragments) in the queue - used in tests. + enum class State { + // The chunk has been sent but not received yet (from the sender's point of + // view, as no SACK has been received yet that reference this chunk). + kInFlight, + // A SACK has been received which explicitly marked this chunk as missing - + // it's now NACKED and may be retransmitted if NACKED enough times. + kNacked, + // A chunk that will be retransmitted when possible. + kToBeRetransmitted, + // A SACK has been received which explicitly marked this chunk as received. + kAcked, + // A chunk whose message has expired or has been retransmitted too many + // times (RFC3758). It will not be retransmitted anymore. 
+ kAbandoned, + }; + + // Creates a RetransmissionQueue which will send data using `initial_tsn` as + // the first TSN to use for sent fragments. It will poll data from + // `send_queue` and call `on_send_queue_empty` when it is empty. When + // SACKs are received, it will estimate the RTT, and call `on_new_rtt`. When + // an outstanding chunk has been ACKed, it will call + // `on_clear_retransmission_counter` and will also use `t3_rtx`, which is the + // SCTP retransmission timer to manage retransmissions. + RetransmissionQueue(absl::string_view log_prefix, + TSN initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function on_new_rtt, + std::function on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability = true, + bool use_message_interleaving = false); + + // Handles a received SACK. Returns true if the `sack` was processed and + // false if it was discarded due to received out-of-order and not relevant. + bool HandleSack(TimeMs now, const SackChunk& sack); + + // Handles an expired retransmission timer. + void HandleT3RtxTimerExpiry(); + + // Returns a list of chunks to send that would fit in one SCTP packet with + // `bytes_remaining_in_packet` bytes available. This may be further limited by + // the congestion control windows. Note that `ShouldSendForwardTSN` must be + // called prior to this method, to abandon expired chunks, as this method will + // not expire any chunks. + std::vector> GetChunksToSend( + TimeMs now, + size_t bytes_remaining_in_packet); + + // Returns the internal state of all queued chunks. This is only used in + // unit-tests. + std::vector> GetChunkStatesForTesting() const; + + // Returns the next TSN that will be allocated for sent DATA chunks. + TSN next_tsn() const { return next_tsn_.Wrap(); } + + // Returns the size of the congestion window, in bytes. This is the number of + // bytes that may be in-flight. + size_t cwnd() const { return cwnd_; } + + // Overrides the current congestion window size. + void set_cwnd(size_t cwnd) { cwnd_ = cwnd; } + + // Returns the current receiver window size. + size_t rwnd() const { return rwnd_; } + + // Returns the number of bytes of packets that are in-flight. + size_t outstanding_bytes() const { return outstanding_bytes_; } + + // Given the current time `now`, it will evaluate if there are chunks that + // have expired and that need to be discarded. It returns true if a + // FORWARD-TSN should be sent. + bool ShouldSendForwardTsn(TimeMs now); + + // Creates a FORWARD-TSN chunk. + ForwardTsnChunk CreateForwardTsn() const; + + // Creates an I-FORWARD-TSN chunk. + IForwardTsnChunk CreateIForwardTsn() const; + + // See the SendQueue for a longer description of these methods related + // to stream resetting. + void PrepareResetStreams(rtc::ArrayView streams); + bool CanResetStreams() const; + void CommitResetStreams(); + void RollbackResetStreams(); + + private: + enum class CongestionAlgorithmPhase { + kSlowStart, + kCongestionAvoidance, + }; + + // A fragmented message's DATA chunk while in the retransmission queue, and + // its associated metadata. 
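Its central transition is the Nack() decision. The following is a condensed, hypothetical restatement of what retransmission_queue.cc implements; the hard-coded three mirrors kNumberOfNacksForRetransmission, and already-abandoned chunks are assumed to have been filtered out by the caller:

#include <cstddef>

#include "absl/types/optional.h"

// After three miss indications (or immediately when the T3-rtx timer expired),
// the chunk is retransmitted unless its retransmission limit has been used up,
// in which case the whole message is abandoned.
enum class NackOutcome { kNothing, kRetransmit, kAbandon };

NackOutcome OnNacked(size_t nack_count,  // including this miss indication
                     bool retransmit_now,
                     size_t num_retransmissions,
                     absl::optional<size_t> max_retransmissions) {
  if (!retransmit_now && nack_count < 3) {
    return NackOutcome::kNothing;
  }
  if (!max_retransmissions.has_value() ||
      num_retransmissions < *max_retransmissions) {
    return NackOutcome::kRetransmit;
  }
  return NackOutcome::kAbandon;
}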
+ class TxData { + public: + enum class NackAction { + kNothing, + kRetransmit, + kAbandon, + }; + + explicit TxData(Data data, + absl::optional max_retransmissions, + TimeMs time_sent, + absl::optional expires_at) + : max_retransmissions_(max_retransmissions), + time_sent_(time_sent), + expires_at_(expires_at), + data_(std::move(data)) {} + + TimeMs time_sent() const { return time_sent_; } + + const Data& data() const { return data_; } + + // Acks an item. + void Ack(); + + // Nacks an item. If it has been nacked enough times, or if `retransmit_now` + // is set, it might be marked for retransmission. If the item has reached + // its max retransmission value, it will instead be abandoned. The action + // performed is indicated as return value. + NackAction Nack(bool retransmit_now = false); + + // Prepares the item to be retransmitted. Sets it as outstanding and + // clears all nack counters. + void Retransmit(); + + // Marks this item as abandoned. + void Abandon(); + + bool is_outstanding() const { return ack_state_ == AckState::kUnacked; } + bool is_acked() const { return ack_state_ == AckState::kAcked; } + bool is_nacked() const { return ack_state_ == AckState::kNacked; } + bool is_abandoned() const { return is_abandoned_; } + + // Indicates if this chunk should be retransmitted. + bool should_be_retransmitted() const { return should_be_retransmitted_; } + // Indicates if this chunk has ever been retransmitted. + bool has_been_retransmitted() const { return num_retransmissions_ > 0; } + + // Given the current time, and the current state of this DATA chunk, it will + // indicate if it has expired (SCTP Partial Reliability Extension). + bool has_expired(TimeMs now) const; + + private: + enum class AckState { + kUnacked, + kAcked, + kNacked, + }; + // Indicates the presence of this chunk, if it's in flight (Unacked), has + // been received (Acked) or is lost (Nacked). + AckState ack_state_ = AckState::kUnacked; + // Indicates if this chunk has been abandoned, which is a terminal state. + bool is_abandoned_ = false; + // Indicates if this chunk should be retransmitted. + bool should_be_retransmitted_ = false; + + // The number of times the DATA chunk has been nacked (by having received a + // SACK which doesn't include it). Will be cleared on retransmissions. + size_t nack_count_ = 0; + // The number of times the DATA chunk has been retransmitted. + size_t num_retransmissions_ = 0; + // If the message was sent with a maximum number of retransmissions, this is + // set to that number. The value zero (0) means that it will never be + // retransmitted. + const absl::optional max_retransmissions_; + // When the packet was sent, and placed in this queue. + const TimeMs time_sent_; + // If the message was sent with an expiration time, this is set. + const absl::optional expires_at_; + // The actual data to send/retransmit. + Data data_; + }; + + // Contains variables scoped to a processing of an incoming SACK. + struct AckInfo { + explicit AckInfo(UnwrappedTSN cumulative_tsn_ack) + : highest_tsn_acked(cumulative_tsn_ack) {} + + // All TSNs that have been acked (for the first time) in this SACK. + std::vector acked_tsns; + + // Bytes acked by increasing cumulative_tsn_ack in this SACK + size_t bytes_acked_by_cumulative_tsn_ack = 0; + + // Bytes acked by gap blocks in this SACK. + size_t bytes_acked_by_new_gap_ack_blocks = 0; + + // Indicates if this SACK indicates that packet loss has occurred. 
Just + // because a packet is missing in the SACK doesn't necessarily mean that + // there is packet loss as that packet might be in-flight and received + // out-of-order. But when it has been reported missing consecutive times, it + // will eventually be considered "lost" and this will be set. + bool has_packet_loss = false; + + // Highest TSN Newly Acknowledged, an SCTP variable. + UnwrappedTSN highest_tsn_acked; + }; + + bool IsConsistent() const; + + // Returns how large a chunk will be, serialized, carrying the data + size_t GetSerializedChunkSize(const Data& data) const; + + // Indicates if the congestion control algorithm is in "fast recovery". + bool is_in_fast_recovery() const { + return fast_recovery_exit_tsn_.has_value(); + } + + // Indicates if the congestion control algorithm is in "fast retransmit". + bool is_in_fast_retransmit() const { return is_in_fast_retransmit_; } + + // Indicates if the provided SACK is valid given what has previously been + // received. If it returns false, the SACK is most likely a duplicate of + // something already seen, so this returning false doesn't necessarily mean + // that the SACK is illegal. + bool IsSackValid(const SackChunk& sack) const; + + // Given a `cumulative_tsn_ack` from an incoming SACK, will remove those items + // in the retransmission queue up until this value and will update `ack_info` + // by setting `bytes_acked_by_cumulative_tsn_ack` and `acked_tsns`. + void RemoveAcked(UnwrappedTSN cumulative_tsn_ack, AckInfo& ack_info); + + // Helper method to nack an item and perform the correct operations given the + // action indicated when nacking an item (e.g. retransmitting or abandoning). + // The return value indicate if an action was performed, meaning that packet + // loss was detected and acted upon. + bool NackItem(UnwrappedTSN cumulative_tsn_ack, + TxData& item, + bool retransmit_now); + + // Will mark the chunks covered by the `gap_ack_blocks` from an incoming SACK + // as "acked" and update `ack_info` by adding new TSNs to `added_tsns`. + void AckGapBlocks(UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView gap_ack_blocks, + AckInfo& ack_info); + + // Mark chunks reported as "missing", as "nacked" or "to be retransmitted" + // depending how many times this has happened. Only packets up until + // `ack_info.highest_tsn_acked` (highest TSN newly acknowledged) are + // nacked/retransmitted. The method will set `ack_info.has_packet_loss`. + void NackBetweenAckBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView gap_ack_blocks, + AckInfo& ack_info); + + // When a SACK chunk is received, this method will be called which _may_ call + // into the `RetransmissionTimeout` to update the RTO. + void UpdateRTT(TimeMs now, UnwrappedTSN cumulative_tsn_ack); + + // If the congestion control is in "fast recovery mode", this may be exited + // now. + void MaybeExitFastRecovery(UnwrappedTSN cumulative_tsn_ack); + + // If chunks have been ACKed, stop the retransmission timer. + void StopT3RtxTimerOnIncreasedCumulativeTsnAck( + UnwrappedTSN cumulative_tsn_ack); + + // Update the congestion control algorithm given as the cumulative ack TSN + // value has increased, as reported in an incoming SACK chunk. + void HandleIncreasedCumulativeTsnAck(size_t outstanding_bytes, + size_t total_bytes_acked); + // Update the congestion control algorithm, given as packet loss has been + // detected, as reported in an incoming SACK chunk. + void HandlePacketLoss(UnwrappedTSN highest_tsn_acked); + // Update the view of the receiver window size. 
+ void UpdateReceiverWindow(uint32_t a_rwnd); + // Given `max_size` of space left in a packet, which chunks can be added to + // it? + std::vector> GetChunksToBeRetransmitted(size_t max_size); + // If there is data sent and not ACKED, ensure that the retransmission timer + // is running. + void StartT3RtxTimerIfOutstandingData(); + + // Given the current time `now_ms`, expire and abandon outstanding (sent at + // least once) chunks that have a limited lifetime. + void ExpireOutstandingChunks(TimeMs now); + // Given that a message fragment, `item` has been abandoned, abandon all other + // fragments that share the same message - both never-before-sent fragments + // that are still in the SendQueue and outstanding chunks. + void AbandonAllFor(const RetransmissionQueue::TxData& item); + + // Returns the current congestion control algorithm phase. + CongestionAlgorithmPhase phase() const { + return (cwnd_ <= ssthresh_) + ? CongestionAlgorithmPhase::kSlowStart + : CongestionAlgorithmPhase::kCongestionAvoidance; + } + + const DcSctpOptions options_; + // If the peer supports RFC3758 - SCTP Partial Reliability Extension. + const bool partial_reliability_; + const std::string log_prefix_; + // The size of the data chunk (DATA/I-DATA) header that is used. + const size_t data_chunk_header_size_; + // Called when a new RTT measurement has been done + const std::function on_new_rtt_; + // Called when a SACK has been seen that cleared the retransmission counter. + const std::function on_clear_retransmission_counter_; + // The retransmission counter. + Timer& t3_rtx_; + // Unwraps TSNs + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // Congestion Window. Number of bytes that may be in-flight (sent, not acked). + size_t cwnd_; + // Receive Window. Number of bytes available in the receiver's RX buffer. + size_t rwnd_; + // Slow Start Threshold. See RFC4960. + size_t ssthresh_; + // Partial Bytes Acked. See RFC4960. + size_t partial_bytes_acked_ = 0; + // If set, fast recovery is enabled until this TSN has been cumulative + // acked. + absl::optional fast_recovery_exit_tsn_ = absl::nullopt; + // Indicates if the congestion algorithm is in fast retransmit. + bool is_in_fast_retransmit_ = false; + + // Next TSN to used. + UnwrappedTSN next_tsn_; + // The last cumulative TSN ack number + UnwrappedTSN last_cumulative_tsn_ack_; + // The send queue. + SendQueue& send_queue_; + // All the outstanding data chunks that are in-flight and that have not been + // cumulative acked. Note that it also contains chunks that have been acked in + // gap ack blocks. + std::map outstanding_data_; + // Data chunks that are to be retransmitted. + std::set to_be_retransmitted_; + // The number of bytes that are in-flight (sent but not yet acked or nacked). + size_t outstanding_bytes_ = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ diff --git a/net/dcsctp/tx/retransmission_queue_test.cc b/net/dcsctp/tx/retransmission_queue_test.cc new file mode 100644 index 0000000000..4aa76d66e5 --- /dev/null +++ b/net/dcsctp/tx/retransmission_queue_test.cc @@ -0,0 +1,1182 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/tx/retransmission_queue.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using State = ::dcsctp::RetransmissionQueue::State; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr uint32_t kArwnd = 100000; +constexpr uint32_t kMaxMtu = 1191; + +class RetransmissionQueueTest : public testing::Test { + protected: + RetransmissionQueueTest() + : gen_(MID(42)), + timeout_manager_([this]() { return now_; }), + timer_manager_([this]() { return timeout_manager_.CreateTimeout(); }), + timer_(timer_manager_.CreateTimer( + "test/t3_rtx", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))) {} + + std::function CreateChunk() { + return [this](TimeMs now, size_t max_size) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }; + } + + std::vector GetSentPacketTSNs(RetransmissionQueue& queue) { + std::vector tsns; + for (const auto& elem : queue.GetChunksToSend(now_, 10000)) { + tsns.push_back(elem.first); + } + return tsns; + } + + RetransmissionQueue CreateQueue(bool supports_partial_reliability = true, + bool use_message_interleaving = false) { + DcSctpOptions options; + options.mtu = kMaxMtu; + return RetransmissionQueue( + "", TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + on_clear_retransmission_counter_.AsStdFunction(), *timer_, options, + supports_partial_reliability, use_message_interleaving); + } + + DataGenerator gen_; + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager timer_manager_; + NiceMock> on_rtt_; + NiceMock> on_clear_retransmission_counter_; + NiceMock producer_; + std::unique_ptr timer_; +}; + +TEST_F(RetransmissionQueueTest, InitialAckedPrevTsn) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, SendOneChunk) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(10))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, SendOneChunkAndAck) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(10))); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, 
{}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, SendThreeChunksAndAckTwo) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12))); + + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, AckWithGapBlocksFromRFC4960Section334) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, ResendPacketsWhenNackedThreeTimes) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Send more chunks, but leave some as gaps to force retransmission after + // three NACKs. 
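Note that the gap ack block bounds in these SACKs are offsets from the cumulative TSN ack, not absolute TSNs. A tiny sketch of the mapping, with a hypothetical helper name:

#include <cstdint>
#include <utility>

// Maps a SACK gap ack block (start/end offsets) to the absolute TSN range it
// acknowledges, relative to the SACK's cumulative TSN ack.
std::pair<uint32_t, uint32_t> GapBlockTsnRange(uint32_t cumulative_tsn_ack,
                                               uint16_t start,
                                               uint16_t end) {
  return {cumulative_tsn_ack + start, cumulative_tsn_ack + end};
}

// GapBlockTsnRange(12, 2, 3) -> {14, 15} and GapBlockTsnRange(12, 5, 5) ->
// {17, 17}, which is why TSNs 13 and 16 end up as kNacked in the test above.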
+ + // Send 18 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(18))); + + // Ack 12, 14-15, 17-18 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 6)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked))); + + // Send 19 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(19))); + + // Ack 12, 14-15, 17-19 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 7)}, + {})); + + // Send 20 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(20))); + + // Ack 12, 14-15, 17-20 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 8)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kToBeRetransmitted), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kToBeRetransmitted), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked), // + Pair(TSN(20), State::kAcked))); + + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. 
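Detecting this loss also adjusts congestion control, roughly as sketched below for the not-yet-in-fast-recovery case of HandlePacketLoss(); the helper name is hypothetical, and cwnd_mtus_min = 4 is an assumed option value used only for the worked numbers:

#include <algorithm>
#include <cstddef>

// Halve ssthresh (but keep at least cwnd_mtus_min MTUs), collapse cwnd onto
// it, and (in the real code) record the highest outstanding TSN as the fast
// recovery exit point.
void OnPacketLossDetected(size_t& cwnd,
                          size_t& ssthresh,
                          size_t mtu,
                          size_t cwnd_mtus_min) {
  ssthresh = std::max(cwnd / 2, cwnd_mtus_min * mtu);
  cwnd = ssthresh;
}

// Example with mtu = 1191 (kMaxMtu in these tests) and cwnd_mtus_min = 4:
// cwnd = 10000 -> ssthresh = max(5000, 4764) = 5000, cwnd = 5000.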
+ EXPECT_CALL(producer_, Produce).Times(0); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(13), TSN(16))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked), // + Pair(TSN(20), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, CanOnlyProduceTwoPacketsButWantsToSendThree) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, RetransmitsOnT3Expiry) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + std::vector> chunks_to_rtx = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_rtx, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, LimitedRetransmissionOnlyWithRfc3758Support) { + RetransmissionQueue queue = + CreateQueue(/*supports_partial_reliability=*/false); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = 0; + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + 
  EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42)))
+      .Times(0);
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+}
+
+TEST_F(RetransmissionQueueTest, LimitsRetransmissionsAsUdp) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE"));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+  std::vector<std::pair<TSN, Data>> chunks_to_send =
+      queue.GetChunksToSend(now_, 1000);
+  EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _)));
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),  //
+                          Pair(TSN(10), State::kInFlight)));
+
+  // Will force chunks to be retransmitted
+  EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42)))
+      .Times(1);
+
+  queue.HandleT3RtxTimerExpiry();
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),  //
+                          Pair(TSN(10), State::kAbandoned)));
+
+  EXPECT_TRUE(queue.ShouldSendForwardTsn(now_));
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),  //
+                          Pair(TSN(10), State::kAbandoned)));
+
+  std::vector<std::pair<TSN, Data>> chunks_to_rtx =
+      queue.GetChunksToSend(now_, 1000);
+  EXPECT_THAT(chunks_to_rtx, testing::IsEmpty());
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),  //
+                          Pair(TSN(10), State::kAbandoned)));
+}
+
+TEST_F(RetransmissionQueueTest, LimitsRetransmissionsToThreeSends) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE"));
+        dts.max_retransmissions = 3;
+        return dts;
+      })
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+  std::vector<std::pair<TSN, Data>> chunks_to_send =
+      queue.GetChunksToSend(now_, 1000);
+  EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _)));
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),  //
+                          Pair(TSN(10), State::kInFlight)));
+
+  EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42)))
+      .Times(0);
+
+  // Retransmission 1
+  queue.HandleT3RtxTimerExpiry();
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+  EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1));
+
+  // Retransmission 2
+  queue.HandleT3RtxTimerExpiry();
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+  EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1));
+
+  // Retransmission 3
+  queue.HandleT3RtxTimerExpiry();
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+  EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1));
+
+  // Retransmission 4 - not allowed.
+ EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + queue.HandleT3RtxTimerExpiry(); + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), IsEmpty()); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); +} + +TEST_F(RetransmissionQueueTest, RetransmitsWhenSendBufferIsFullT3Expiry) { + RetransmissionQueue queue = CreateQueue(); + static constexpr size_t kCwnd = 1200; + queue.set_cwnd(kCwnd); + EXPECT_EQ(queue.cwnd(), kCwnd); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + + std::vector payload(1000); + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), payload.size() + DataChunk::kHeaderSize); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + + std::vector> chunks_to_rtx = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_rtx, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), payload.size() + DataChunk::kHeaderSize); +} + +TEST_F(RetransmissionQueueTest, ProducesValidForwardTsn) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "")); + dts.max_retransmissions = 0; + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + + // Chunk 10 is acked, but the remaining are lost + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + + queue.HandleT3RtxTimerExpiry(); + + // NOTE: The TSN=13 represents the end fragment. 
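The placeholder TSN 13 mentioned in the note above stands in for the end fragment of the abandoned message that was never produced; abandoning it as well lets the FORWARD-TSN cover the whole message. A rough sketch (assumed, simplified types, not the real chunk-state tracking) of how the new cumulative TSN reported by CreateForwardTsn() can be derived:

#include <cstdint>
#include <map>

enum class ChunkStateSketch { kAcked, kNacked, kInFlight, kAbandoned };

// Illustrative only: a FORWARD-TSN advertises the highest TSN in the unbroken
// run of abandoned chunks immediately following the cumulative TSN ack. With
// a cumulative ack of 10 and TSNs 11-13 abandoned, this returns 13, matching
// the expectation that follows.
uint32_t NewCumulativeTsn(uint32_t cum_tsn_ack,
                          const std::map<uint32_t, ChunkStateSketch>& chunks) {
  uint32_t new_cumulative_tsn = cum_tsn_ack;
  for (uint32_t tsn = cum_tsn_ack + 1;; ++tsn) {
    auto it = chunks.find(tsn);
    if (it == chunks.end() || it->second != ChunkStateSketch::kAbandoned) {
      break;
    }
    new_cumulative_tsn = tsn;
  }
  return new_cumulative_tsn;
}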
+ EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + ForwardTsnChunk forward_tsn = queue.CreateForwardTsn(); + EXPECT_EQ(forward_tsn.new_cumulative_tsn(), TSN(13)); + EXPECT_THAT(forward_tsn.skipped_streams(), + UnorderedElementsAre( + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(42)))); +} + +TEST_F(RetransmissionQueueTest, ProducesValidForwardTsnWhenFullySent) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "E")); + dts.max_retransmissions = 0; + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + + // Chunk 10 is acked, but the remaining are lost + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + ForwardTsnChunk forward_tsn = queue.CreateForwardTsn(); + EXPECT_EQ(forward_tsn.new_cumulative_tsn(), TSN(12)); + EXPECT_THAT(forward_tsn.skipped_streams(), + UnorderedElementsAre( + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(42)))); +} + +TEST_F(RetransmissionQueueTest, ProducesValidIForwardTsn) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(1); + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B", opts)); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(2); + SendQueue::DataToSend dts(gen_.Unordered({1, 2, 3, 4}, "B", opts)); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(3); + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "B", opts)); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(4); + SendQueue::DataToSend dts(gen_.Ordered({13, 14, 15, 16}, "B", opts)); + dts.max_retransmissions = 0; + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, 
ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _), Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + // Chunk 13 is acked, but the remaining are lost + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(4, 4)}, {})); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kNacked), // + Pair(TSN(12), State::kNacked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + EXPECT_CALL(producer_, Discard(IsUnordered(true), StreamID(2), MID(42))) + .WillOnce(Return(true)); + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(3), MID(42))) + .WillOnce(Return(true)); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAcked), + // Representing end fragments of stream 1-3 + Pair(TSN(14), State::kAbandoned), // + Pair(TSN(15), State::kAbandoned), // + Pair(TSN(16), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + IForwardTsnChunk forward_tsn1 = queue.CreateIForwardTsn(); + EXPECT_EQ(forward_tsn1.new_cumulative_tsn(), TSN(12)); + EXPECT_THAT( + forward_tsn1.skipped_streams(), + UnorderedElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(2), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(3), MID(42)))); + + // When TSN 13 is acked, the placeholder "end fragments" must be skipped as + // well. + + // A receiver is more likely to ack TSN 13, but do it incrementally. 
+ queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard).Times(0); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAbandoned), // + Pair(TSN(15), State::kAbandoned), // + Pair(TSN(16), State::kAbandoned))); + + IForwardTsnChunk forward_tsn2 = queue.CreateIForwardTsn(); + EXPECT_EQ(forward_tsn2.new_cumulative_tsn(), TSN(16)); + EXPECT_THAT( + forward_tsn2.skipped_streams(), + UnorderedElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(2), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(3), MID(42)))); +} + +TEST_F(RetransmissionQueueTest, MeasureRTT) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = 0; + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + now_ = now_ + DurationMs(123); + + EXPECT_CALL(on_rtt_, Call(DurationMs(123))).Times(1); + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); +} + +TEST_F(RetransmissionQueueTest, ValidateCumTsnAtRest) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(8), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(9), kArwnd, {}, {}))); + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}))); +} + +TEST_F(RetransmissionQueueTest, ValidateCumTsnAckOnInflightData) { + RetransmissionQueue queue = CreateQueue(); + + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(8), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(9), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(14), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(15), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(16), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(17), kArwnd, {}, {}))); + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(18), kArwnd, {}, {}))); +} + +TEST_F(RetransmissionQueueTest, HandleGapAckBlocksMatchingNoInflightData) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + 
.WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Ack 9, 20-25. This is an invalid SACK, but should still be handled. + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(11, 16)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, HandleInvalidGapAckBlocks) { + RetransmissionQueue queue = CreateQueue(); + + // Nothing produced - nothing in retransmission queue + + // Ack 9, 12-13 + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(3, 4)}, {})); + + // Gap ack blocks are just ignore. + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, GapAckBlocksDoNotMoveCumTsnAck) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Ack 9, 10-14. This is actually an invalid ACK as the first gap can't be + // adjacent to the cum-tsn-ack, but it's not strictly forbidden. However, the + // cum-tsn-ack should not move, as the gap-ack-blocks are just advisory. + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(1, 5)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, StaysWithinAvailableSize) { + RetransmissionQueue queue = CreateQueue(); + + // See SctpPacketTest::ReturnsCorrectSpaceAvailableToStayWithinMTU for the + // magic numbers in this test. 
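The sizes asserted in the Produce expectations that follow fall out of simple packet arithmetic: the caller offers 1188 - 12 bytes of chunk space (a 1188-byte packet minus the 12-byte SCTP common header), each DATA chunk spends a 16-byte header, and chunks are padded to 4-byte boundaries. A standalone sketch of that bookkeeping (the constant names are illustrative; the numbers come from the test itself):

#include <cstddef>

// Pads a chunk length to the next 4-byte boundary, as SCTP requires.
constexpr size_t RoundUpTo4(size_t n) {
  return (n + 3) / 4 * 4;
}

constexpr size_t kPacketSize = 1188;
constexpr size_t kCommonHeaderSize = 12;
constexpr size_t kDataChunkHeaderSize = 16;

constexpr size_t kSpace0 = kPacketSize - kCommonHeaderSize;        // 1176
constexpr size_t kOffered0 = kSpace0 - kDataChunkHeaderSize;       // 1160
constexpr size_t kUsed0 = RoundUpTo4(kDataChunkHeaderSize + 183);  // 200
constexpr size_t kSpace1 = kSpace0 - kUsed0;                       // 976
constexpr size_t kOffered1 = kSpace1 - kDataChunkHeaderSize;       // 960

static_assert(kOffered0 == 1176 - 16, "first Produce() is offered 1160 bytes");
static_assert(kOffered1 == 976 - 16, "second Produce() is offered 960 bytes");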
+ EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t size) { + EXPECT_EQ(size, 1176 - DataChunk::kHeaderSize); + + std::vector payload(183); + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillOnce([this](TimeMs, size_t size) { + EXPECT_EQ(size, 976 - DataChunk::kHeaderSize); + + std::vector payload(957); + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1188 - 12); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _))); +} + +TEST_F(RetransmissionQueueTest, AccountsNackedAbandonedChunksAsNotOutstanding) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "")); + dts.max_retransmissions = 0; + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 3u); + + // Mark the message as lost. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + queue.HandleT3RtxTimerExpiry(); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned))); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + + // Now ACK those, one at a time. 
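The assertions around this point rely on outstanding_bytes counting only chunks that are actually in flight: right after sending, the three 4-byte fragments account for (16 + 4) * 3 bytes, and once the message is abandoned the value drops to zero and stays there while the SACKs that follow advance the cumulative TSN. A simplified sketch of that accounting (assumed types, not the real tracking structures):

#include <cstddef>
#include <cstdint>
#include <map>

enum class StateSketch { kInFlight, kNacked, kAcked, kToBeRetransmitted, kAbandoned };

struct ChunkSketch {
  StateSketch state;
  size_t payload_size;
};

// Illustrative only: outstanding bytes cover the header and payload of chunks
// that are still in flight. In this sketch, acked, nacked, abandoned and
// retransmit-pending chunks no longer occupy the congestion window.
size_t OutstandingBytes(const std::map<uint32_t, ChunkSketch>& chunks,
                        size_t data_chunk_header_size) {
  size_t bytes = 0;
  for (const auto& entry : chunks) {
    if (entry.second.state == StateSketch::kInFlight) {
      bytes += data_chunk_header_size + entry.second.payload_size;
    }
  }
  return bytes;
}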
+ queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); +} + +TEST_F(RetransmissionQueueTest, ExpireFromSendQueueWhenPartiallySent) { + RetransmissionQueue queue = CreateQueue(); + DataGeneratorOptions options; + options.stream_id = StreamID(17); + options.message_id = MID(42); + TimeMs test_start = now_; + EXPECT_CALL(producer_, Produce) + .WillOnce([&](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B", options)); + dts.expires_at = TimeMs(test_start + DurationMs(10)); + return dts; + }) + .WillOnce([&](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "", options)); + dts.expires_at = TimeMs(test_start + DurationMs(10)); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 24); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(17), MID(42))) + .WillOnce(Return(true)); + now_ += DurationMs(100); + + EXPECT_THAT(queue.GetChunksToSend(now_, 24), IsEmpty()); + + EXPECT_THAT( + queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // Initial TSN + Pair(TSN(10), State::kAbandoned), // Produced + Pair(TSN(11), State::kAbandoned), // Produced and expired + Pair(TSN(12), State::kAbandoned))); // Placeholder end +} + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsOnlyWhenNackedThreeTimes) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = 0; + return dts; + }) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _), Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 2)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 3)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), 
StreamID(1), MID(42))) + .WillOnce(Return(false)); + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 4)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); +} + +TEST_F(RetransmissionQueueTest, AbandonsRtxLimit2WhenNackedNineTimes) { + // This is a fairly long test. + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = 2; + return dts; + }) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + std::vector> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, + ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), Pair(TSN(12), _), + Pair(TSN(13), _), Pair(TSN(14), _), Pair(TSN(15), _), + Pair(TSN(16), _), Pair(TSN(17), _), Pair(TSN(18), _), + Pair(TSN(19), _))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + // Ack TSN [11 to 13] - three nacks for TSN(10), which will retransmit it. + for (int tsn = 11; tsn <= 13; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), ElementsAre(Pair(TSN(10), _))); + + // Ack TSN [14 to 16] - three more nacks - second and last retransmission. 
+  for (int tsn = 14; tsn <= 16; ++tsn) {
+    queue.HandleSack(
+        now_,
+        SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {}));
+  }
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),               //
+                          Pair(TSN(10), State::kToBeRetransmitted),  //
+                          Pair(TSN(11), State::kAcked),              //
+                          Pair(TSN(12), State::kAcked),              //
+                          Pair(TSN(13), State::kAcked),              //
+                          Pair(TSN(14), State::kAcked),              //
+                          Pair(TSN(15), State::kAcked),              //
+                          Pair(TSN(16), State::kAcked),              //
+                          Pair(TSN(17), State::kInFlight),           //
+                          Pair(TSN(18), State::kInFlight),           //
+                          Pair(TSN(19), State::kInFlight)));
+
+  EXPECT_THAT(queue.GetChunksToSend(now_, 1000), ElementsAre(Pair(TSN(10), _)));
+
+  // Ack TSN [17 to 18]
+  for (int tsn = 17; tsn <= 18; ++tsn) {
+    queue.HandleSack(
+        now_,
+        SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {}));
+  }
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),      //
+                          Pair(TSN(10), State::kNacked),    //
+                          Pair(TSN(11), State::kAcked),     //
+                          Pair(TSN(12), State::kAcked),     //
+                          Pair(TSN(13), State::kAcked),     //
+                          Pair(TSN(14), State::kAcked),     //
+                          Pair(TSN(15), State::kAcked),     //
+                          Pair(TSN(16), State::kAcked),     //
+                          Pair(TSN(17), State::kAcked),     //
+                          Pair(TSN(18), State::kAcked),     //
+                          Pair(TSN(19), State::kInFlight)));
+
+  EXPECT_FALSE(queue.ShouldSendForwardTsn(now_));
+
+  // Ack TSN 19 - three more nacks for TSN 10, no more retransmissions.
+  EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42)))
+      .WillOnce(Return(false));
+  queue.HandleSack(
+      now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 10)}, {}));
+
+  EXPECT_THAT(queue.GetChunksToSend(now_, 1000), IsEmpty());
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),       //
+                          Pair(TSN(10), State::kAbandoned),  //
+                          Pair(TSN(11), State::kAcked),      //
+                          Pair(TSN(12), State::kAcked),      //
+                          Pair(TSN(13), State::kAcked),      //
+                          Pair(TSN(14), State::kAcked),      //
+                          Pair(TSN(15), State::kAcked),      //
+                          Pair(TSN(16), State::kAcked),      //
+                          Pair(TSN(17), State::kAcked),      //
+                          Pair(TSN(18), State::kAcked),      //
+                          Pair(TSN(19), State::kAcked)));
+
+  EXPECT_TRUE(queue.ShouldSendForwardTsn(now_));
+}
+
+}  // namespace
+}  // namespace dcsctp
diff --git a/net/dcsctp/tx/retransmission_timeout.cc b/net/dcsctp/tx/retransmission_timeout.cc
new file mode 100644
index 0000000000..7d545a07d0
--- /dev/null
+++ b/net/dcsctp/tx/retransmission_timeout.cc
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/tx/retransmission_timeout.h"
+
+#include <cmath>
+#include <cstdint>
+
+#include "net/dcsctp/public/dcsctp_options.h"
+
+namespace dcsctp {
+namespace {
+// https://tools.ietf.org/html/rfc4960#section-15
+constexpr double kRtoAlpha = 0.125;
+constexpr double kRtoBeta = 0.25;
+}  // namespace
+
+RetransmissionTimeout::RetransmissionTimeout(const DcSctpOptions& options)
+    : min_rto_(*options.rto_min),
+      max_rto_(*options.rto_max),
+      max_rtt_(*options.rtt_max),
+      rto_(*options.rto_initial) {}
+
+void RetransmissionTimeout::ObserveRTT(DurationMs measured_rtt) {
+  double rtt = *measured_rtt;
+
+  // Unrealistic values will be skipped.
If a wrongly measured (or otherwise + // corrupt) value was processed, it could change the state in a way that would + // take a very long time to recover. + if (rtt < 0.0 || rtt > max_rtt_) { + return; + } + + if (first_measurement_) { + // https://tools.ietf.org/html/rfc4960#section-6.3.1 + // "When the first RTT measurement R is made, set + // SRTT <- R, + // RTTVAR <- R/2, and + // RTO <- SRTT + 4 * RTTVAR." + srtt_ = rtt; + rttvar_ = rtt * 0.5; + rto_ = srtt_ + 4 * rttvar_; + first_measurement_ = false; + } else { + // https://tools.ietf.org/html/rfc4960#section-6.3.1 + // "When a new RTT measurement R' is made, set + // RTTVAR <- (1 - RTO.Beta) * RTTVAR + RTO.Beta * |SRTT - R'| + // SRTT <- (1 - RTO.Alpha) * SRTT + RTO.Alpha * R' + // RTO <- SRTT + 4 * RTTVAR." + rttvar_ = (1 - kRtoBeta) * rttvar_ + kRtoBeta * std::abs(srtt_ - rtt); + srtt_ = (1 - kRtoAlpha) * srtt_ + kRtoAlpha * rtt; + rto_ = srtt_ + 4 * rttvar_; + } + + // If the RTO becomes smaller or equal to RTT, expiration timers will be + // scheduled at the same time as packets are expected. Only happens in + // extremely stable RTTs, i.e. in simulations. + rto_ = std::fmax(rto_, rtt + 1); + + // Clamp RTO between min and max. + rto_ = std::fmin(std::fmax(rto_, min_rto_), max_rto_); +} +} // namespace dcsctp diff --git a/net/dcsctp/tx/retransmission_timeout.h b/net/dcsctp/tx/retransmission_timeout.h new file mode 100644 index 0000000000..0fac33e59c --- /dev/null +++ b/net/dcsctp/tx/retransmission_timeout.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ + +#include +#include + +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// Manages updating of the Retransmission Timeout (RTO) SCTP variable, which is +// used directly as the base timeout for T3-RTX and for other timers, such as +// delayed ack. +// +// When a round-trip-time (RTT) is calculated (outside this class), `Observe` +// is called, which calculates the retransmission timeout (RTO) value. The RTO +// value will become larger if the RTT is high and/or the RTT values are varying +// a lot, which is an indicator of a bad connection. +class RetransmissionTimeout { + public: + explicit RetransmissionTimeout(const DcSctpOptions& options); + + // To be called when a RTT has been measured, to update the RTO value. + void ObserveRTT(DurationMs measured_rtt); + + // Returns the Retransmission Timeout (RTO) value, in milliseconds. + DurationMs rto() const { return DurationMs(rto_); } + + // Returns the smoothed RTT value, in milliseconds. + DurationMs srtt() const { return DurationMs(srtt_); } + + private: + // Note that all intermediate state calculation is done in the floating point + // domain, to maintain precision. 
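To make the arithmetic above concrete: with the options used by the unit tests further down (rto_initial = 200 ms, rto_min = 120 ms, rto_max = 800 ms), a first RTT of 124 ms seeds SRTT = 124 and RTTVAR = 62, giving RTO = 124 + 4 * 62 = 372, and a second RTT of 128 ms smooths that down to 314.5, which the tests observe as 372 and then 314 once converted back to whole milliseconds. A standalone recomputation of the same update (not dcsctp code, plain doubles only):

#include <cmath>
#include <cstdio>

struct RtoSketch {
  bool first = true;
  double srtt = 0.0;
  double rttvar = 0.0;
  double rto = 200.0;  // rto_initial in the tests.
};

// Mirrors the RFC 4960 section 6.3.1 update shown above.
void Observe(RtoSketch& s, double rtt) {
  constexpr double kAlpha = 0.125;  // RTO.Alpha
  constexpr double kBeta = 0.25;    // RTO.Beta
  if (s.first) {
    s.srtt = rtt;
    s.rttvar = rtt / 2;
    s.first = false;
  } else {
    s.rttvar = (1 - kBeta) * s.rttvar + kBeta * std::abs(s.srtt - rtt);
    s.srtt = (1 - kAlpha) * s.srtt + kAlpha * rtt;
  }
  s.rto = s.srtt + 4 * s.rttvar;
  s.rto = std::fmax(s.rto, rtt + 1);                   // Keep RTO above RTT.
  s.rto = std::fmin(std::fmax(s.rto, 120.0), 800.0);   // Clamp to [min, max].
}

int main() {
  RtoSketch s;
  Observe(s, 124);  // SRTT = 124, RTTVAR = 62, RTO = 372.
  Observe(s, 128);  // RTTVAR = 47.5, SRTT = 124.5, RTO = 314.5.
  std::printf("rto = %.1f ms\n", s.rto);
  return 0;
}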
+ const double min_rto_; + const double max_rto_; + const double max_rtt_; + // If this is the first measurement + bool first_measurement_ = true; + // Smoothed Round-Trip Time + double srtt_ = 0.0; + // Round-Trip Time Variation + double rttvar_ = 0.0; + // Retransmission Timeout + double rto_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ diff --git a/net/dcsctp/tx/retransmission_timeout_test.cc b/net/dcsctp/tx/retransmission_timeout_test.cc new file mode 100644 index 0000000000..3b2e3399fe --- /dev/null +++ b/net/dcsctp/tx/retransmission_timeout_test.cc @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_timeout.h" + +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +constexpr DurationMs kMaxRtt = DurationMs(8'000); +constexpr DurationMs kInitialRto = DurationMs(200); +constexpr DurationMs kMaxRto = DurationMs(800); +constexpr DurationMs kMinRto = DurationMs(120); + +DcSctpOptions MakeOptions() { + DcSctpOptions options; + options.rtt_max = kMaxRtt; + options.rto_initial = kInitialRto; + options.rto_max = kMaxRto; + options.rto_min = kMinRto; + return options; +} + +TEST(RetransmissionTimeoutTest, HasValidInitialRto) { + RetransmissionTimeout rto_(MakeOptions()); + EXPECT_EQ(rto_.rto(), kInitialRto); +} + +TEST(RetransmissionTimeoutTest, NegativeValuesDoNotAffectRTO) { + RetransmissionTimeout rto_(MakeOptions()); + // Initial negative value + rto_.ObserveRTT(DurationMs(-10)); + EXPECT_EQ(rto_.rto(), kInitialRto); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + // Subsequent negative value + rto_.ObserveRTT(DurationMs(-10)); + EXPECT_EQ(*rto_.rto(), 372); +} + +TEST(RetransmissionTimeoutTest, TooLargeValuesDoNotAffectRTO) { + RetransmissionTimeout rto_(MakeOptions()); + // Initial too large value + rto_.ObserveRTT(kMaxRtt + DurationMs(100)); + EXPECT_EQ(rto_.rto(), kInitialRto); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + // Subsequent too large value + rto_.ObserveRTT(kMaxRtt + DurationMs(100)); + EXPECT_EQ(*rto_.rto(), 372); +} + +TEST(RetransmissionTimeoutTest, WillNeverGoBelowMinimumRto) { + RetransmissionTimeout rto_(MakeOptions()); + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(1)); + } + EXPECT_GE(rto_.rto(), kMinRto); +} + +TEST(RetransmissionTimeoutTest, WillNeverGoAboveMaximumRto) { + RetransmissionTimeout rto_(MakeOptions()); + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(kMaxRtt - DurationMs(1)); + // Adding jitter, which would make it RTO be well above RTT. 
+ rto_.ObserveRTT(kMaxRtt - DurationMs(100)); + } + EXPECT_LE(rto_.rto(), kMaxRto); +} + +TEST(RetransmissionTimeoutTest, CalculatesRtoForStableRtt) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(128)); + EXPECT_EQ(*rto_.rto(), 314); + rto_.ObserveRTT(DurationMs(123)); + EXPECT_EQ(*rto_.rto(), 268); + rto_.ObserveRTT(DurationMs(125)); + EXPECT_EQ(*rto_.rto(), 233); + rto_.ObserveRTT(DurationMs(127)); + EXPECT_EQ(*rto_.rto(), 208); +} + +TEST(RetransmissionTimeoutTest, CalculatesRtoForUnstableRtt) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(402)); + EXPECT_EQ(*rto_.rto(), 622); + rto_.ObserveRTT(DurationMs(728)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(89)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(126)); + EXPECT_EQ(*rto_.rto(), 800); +} + +TEST(RetransmissionTimeoutTest, WillStabilizeAfterAWhile) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + rto_.ObserveRTT(DurationMs(402)); + rto_.ObserveRTT(DurationMs(728)); + rto_.ObserveRTT(DurationMs(89)); + rto_.ObserveRTT(DurationMs(126)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(122)); + EXPECT_EQ(*rto_.rto(), 709); + rto_.ObserveRTT(DurationMs(123)); + EXPECT_EQ(*rto_.rto(), 630); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 561); + rto_.ObserveRTT(DurationMs(122)); + EXPECT_EQ(*rto_.rto(), 504); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 453); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 409); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 339); +} + +TEST(RetransmissionTimeoutTest, WillAlwaysStayAboveRTT) { + // In simulations, it's quite common to have a very stable RTT, and having an + // RTO at the same value will cause issues as expiry timers will be scheduled + // to be expire exactly when a packet is supposed to arrive. The RTO must be + // larger than the RTT. In non-simulated environments, this is a non-issue as + // any jitter will increase the RTO. + RetransmissionTimeout rto_(MakeOptions()); + + for (int i = 0; i < 100; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_GT(*rto_.rto(), 124); +} + +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc new file mode 100644 index 0000000000..254214e554 --- /dev/null +++ b/net/dcsctp/tx/rr_send_queue.cc @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/tx/rr_send_queue.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +bool RRSendQueue::OutgoingStream::HasDataToSend(TimeMs now) { + while (!items_.empty()) { + RRSendQueue::OutgoingStream::Item& item = items_.front(); + if (item.message_id.has_value()) { + // Already partially sent messages can always continue to be sent. + return true; + } + + // Message has expired. Remove it and inspect the next one. + if (item.expires_at.has_value() && *item.expires_at <= now) { + buffered_amount_.Decrease(item.remaining_size); + total_buffered_amount_.Decrease(item.remaining_size); + items_.pop_front(); + RTC_DCHECK(IsConsistent()); + continue; + } + + if (is_paused_) { + // The stream has paused (and there is no partially sent message). + return false; + } + return true; + } + return false; +} + +bool RRSendQueue::IsConsistent() const { + size_t total_buffered_amount = 0; + for (const auto& stream_entry : streams_) { + total_buffered_amount += stream_entry.second.buffered_amount().value(); + } + + if (previous_message_has_ended_) { + auto it = streams_.find(current_stream_id_); + if (it != streams_.end() && it->second.has_partially_sent_message()) { + RTC_DLOG(LS_ERROR) + << "Previous message has ended, but still partial message in stream"; + return false; + } + } else { + auto it = streams_.find(current_stream_id_); + if (it == streams_.end() || !it->second.has_partially_sent_message()) { + RTC_DLOG(LS_ERROR) + << "Previous message has NOT ended, but there is no partial message"; + return false; + } + } + + return total_buffered_amount == total_buffered_amount_.value(); +} + +bool RRSendQueue::OutgoingStream::IsConsistent() const { + size_t bytes = 0; + for (const auto& item : items_) { + bytes += item.remaining_size; + } + return bytes == buffered_amount_.value(); +} + +void RRSendQueue::ThresholdWatcher::Decrease(size_t bytes) { + RTC_DCHECK(bytes <= value_); + size_t old_value = value_; + value_ -= bytes; + + if (old_value > low_threshold_ && value_ <= low_threshold_) { + on_threshold_reached_(); + } +} + +void RRSendQueue::ThresholdWatcher::SetLowThreshold(size_t low_threshold) { + // Betting on https://github.com/w3c/webrtc-pc/issues/2654 being accepted. + if (low_threshold_ < value_ && low_threshold >= value_) { + on_threshold_reached_(); + } + low_threshold_ = low_threshold; +} + +void RRSendQueue::OutgoingStream::Add(DcSctpMessage message, + absl::optional expires_at, + const SendOptions& send_options) { + buffered_amount_.Increase(message.payload().size()); + total_buffered_amount_.Increase(message.payload().size()); + items_.emplace_back(std::move(message), expires_at, send_options); + + RTC_DCHECK(IsConsistent()); +} + +absl::optional RRSendQueue::OutgoingStream::Produce( + TimeMs now, + size_t max_size) { + RTC_DCHECK(!items_.empty()); + + Item* item = &items_.front(); + DcSctpMessage& message = item->message; + + if (item->remaining_size > max_size && max_size < kMinimumFragmentedPayload) { + RTC_DCHECK(IsConsistent()); + return absl::nullopt; + } + + // Allocate Message ID and SSN when the first fragment is sent. 
+ if (!item->message_id.has_value()) { + MID& mid = + item->send_options.unordered ? next_unordered_mid_ : next_ordered_mid_; + item->message_id = mid; + mid = MID(*mid + 1); + } + if (!item->send_options.unordered && !item->ssn.has_value()) { + item->ssn = next_ssn_; + next_ssn_ = SSN(*next_ssn_ + 1); + } + + // Grab the next `max_size` fragment from this message and calculate flags. + rtc::ArrayView chunk_payload = + item->message.payload().subview(item->remaining_offset, max_size); + rtc::ArrayView message_payload = message.payload(); + Data::IsBeginning is_beginning(chunk_payload.data() == + message_payload.data()); + Data::IsEnd is_end((chunk_payload.data() + chunk_payload.size()) == + (message_payload.data() + message_payload.size())); + + StreamID stream_id = message.stream_id(); + PPID ppid = message.ppid(); + + // Zero-copy the payload if the message fits in a single chunk. + std::vector payload = + is_beginning && is_end + ? std::move(message).ReleasePayload() + : std::vector(chunk_payload.begin(), chunk_payload.end()); + + FSN fsn(item->current_fsn); + item->current_fsn = FSN(*item->current_fsn + 1); + buffered_amount_.Decrease(payload.size()); + total_buffered_amount_.Decrease(payload.size()); + + SendQueue::DataToSend chunk(Data(stream_id, item->ssn.value_or(SSN(0)), + item->message_id.value(), fsn, ppid, + std::move(payload), is_beginning, is_end, + item->send_options.unordered)); + chunk.max_retransmissions = item->send_options.max_retransmissions; + chunk.expires_at = item->expires_at; + + if (is_end) { + // The entire message has been sent, and its last data copied to `chunk`, so + // it can safely be discarded. + items_.pop_front(); + } else { + item->remaining_offset += chunk_payload.size(); + item->remaining_size -= chunk_payload.size(); + RTC_DCHECK(item->remaining_offset + item->remaining_size == + item->message.payload().size()); + RTC_DCHECK(item->remaining_size > 0); + } + RTC_DCHECK(IsConsistent()); + return chunk; +} + +bool RRSendQueue::OutgoingStream::Discard(IsUnordered unordered, + MID message_id) { + bool result = false; + if (!items_.empty()) { + Item& item = items_.front(); + if (item.send_options.unordered == unordered && + item.message_id.has_value() && *item.message_id == message_id) { + buffered_amount_.Decrease(item.remaining_size); + total_buffered_amount_.Decrease(item.remaining_size); + items_.pop_front(); + // As the item still existed, it had unsent data. + result = true; + } + } + RTC_DCHECK(IsConsistent()); + return result; +} + +void RRSendQueue::OutgoingStream::Pause() { + is_paused_ = true; + + // A stream is paused when it's about to be reset. In this implementation, + // it will throw away all non-partially send messages. This is subject to + // change. It will however not discard any partially sent messages - only + // whole messages. Partially delivered messages (at the time of receiving a + // Stream Reset command) will always deliver all the fragments before + // actually resetting the stream. + for (auto it = items_.begin(); it != items_.end();) { + if (it->remaining_offset == 0) { + buffered_amount_.Decrease(it->remaining_size); + total_buffered_amount_.Decrease(it->remaining_size); + it = items_.erase(it); + } else { + ++it; + } + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::OutgoingStream::Reset() { + if (!items_.empty()) { + // If this message has been partially sent, reset it so that it will be + // re-sent. 
+ auto& item = items_.front(); + buffered_amount_.Increase(item.message.payload().size() - + item.remaining_size); + total_buffered_amount_.Increase(item.message.payload().size() - + item.remaining_size); + item.remaining_offset = 0; + item.remaining_size = item.message.payload().size(); + item.message_id = absl::nullopt; + item.ssn = absl::nullopt; + item.current_fsn = FSN(0); + } + is_paused_ = false; + next_ordered_mid_ = MID(0); + next_unordered_mid_ = MID(0); + next_ssn_ = SSN(0); + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::OutgoingStream::has_partially_sent_message() const { + if (items_.empty()) { + return false; + } + return items_.front().message_id.has_value(); +} + +void RRSendQueue::Add(TimeMs now, + DcSctpMessage message, + const SendOptions& send_options) { + RTC_DCHECK(!message.payload().empty()); + // Any limited lifetime should start counting from now - when the message + // has been added to the queue. + absl::optional expires_at = absl::nullopt; + if (send_options.lifetime.has_value()) { + // `expires_at` is the time when it expires. Which is slightly larger than + // the message's lifetime, as the message is alive during its entire + // lifetime (which may be zero). + expires_at = now + *send_options.lifetime + DurationMs(1); + } + GetOrCreateStreamInfo(message.stream_id()) + .Add(std::move(message), expires_at, send_options); + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::IsFull() const { + return total_buffered_amount() >= buffer_size_; +} + +bool RRSendQueue::IsEmpty() const { + return total_buffered_amount() == 0; +} + +std::map::iterator +RRSendQueue::GetNextStream(TimeMs now) { + auto start_it = streams_.lower_bound(StreamID(*current_stream_id_ + 1)); + + for (auto it = start_it; it != streams_.end(); ++it) { + if (it->second.HasDataToSend(now)) { + current_stream_id_ = it->first; + return it; + } + } + + for (auto it = streams_.begin(); it != start_it; ++it) { + if (it->second.HasDataToSend(now)) { + current_stream_id_ = it->first; + return it; + } + } + return streams_.end(); +} + +absl::optional RRSendQueue::Produce(TimeMs now, + size_t max_size) { + std::map::iterator stream_it; + + if (previous_message_has_ended_) { + // Previous message has ended. Round-robin to a different stream, if there + // even is one with data to send. + stream_it = GetNextStream(now); + if (stream_it == streams_.end()) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "There is no stream with data; Can't produce any data."; + return absl::nullopt; + } + } else { + // The previous message has not ended; Continue from the current stream. + stream_it = streams_.find(current_stream_id_); + RTC_DCHECK(stream_it != streams_.end()); + } + + absl::optional data = stream_it->second.Produce(now, max_size); + if (data.has_value()) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Producing DATA, type=" + << (data->data.is_unordered ? "unordered" : "ordered") + << "::" + << (*data->data.is_beginning && *data->data.is_end + ? "complete" + : *data->data.is_beginning + ? "first" + : *data->data.is_end ? 
"last" : "middle") + << ", stream_id=" << *stream_it->first + << ", ppid=" << *data->data.ppid + << ", length=" << data->data.payload.size(); + + previous_message_has_ended_ = *data->data.is_end; + } + + RTC_DCHECK(IsConsistent()); + return data; +} + +bool RRSendQueue::Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) { + bool has_discarded = + GetOrCreateStreamInfo(stream_id).Discard(unordered, message_id); + if (has_discarded) { + // Only partially sent messages are discarded, so if a message was + // discarded, then it was the currently sent message. + previous_message_has_ended_ = true; + } + + return has_discarded; +} + +void RRSendQueue::PrepareResetStreams(rtc::ArrayView streams) { + for (StreamID stream_id : streams) { + GetOrCreateStreamInfo(stream_id).Pause(); + } + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::CanResetStreams() const { + // Streams can be reset if those streams that are paused don't have any + // messages that are partially sent. + for (auto& stream : streams_) { + if (stream.second.is_paused() && + stream.second.has_partially_sent_message()) { + return false; + } + } + return true; +} + +void RRSendQueue::CommitResetStreams() { + for (auto& stream_entry : streams_) { + if (stream_entry.second.is_paused()) { + stream_entry.second.Reset(); + } + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::RollbackResetStreams() { + for (auto& stream_entry : streams_) { + stream_entry.second.Resume(); + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::Reset() { + // Recalculate buffered amount, as partially sent messages may have been put + // fully back in the queue. + for (auto& stream_entry : streams_) { + OutgoingStream& stream = stream_entry.second; + stream.Reset(); + } + previous_message_has_ended_ = true; +} + +size_t RRSendQueue::buffered_amount(StreamID stream_id) const { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return 0; + } + return it->second.buffered_amount().value(); +} + +size_t RRSendQueue::buffered_amount_low_threshold(StreamID stream_id) const { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return 0; + } + return it->second.buffered_amount().low_threshold(); +} + +void RRSendQueue::SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) { + GetOrCreateStreamInfo(stream_id).buffered_amount().SetLowThreshold(bytes); +} + +RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo( + StreamID stream_id) { + auto it = streams_.find(stream_id); + if (it != streams_.end()) { + return it->second; + } + + return streams_ + .emplace(stream_id, + OutgoingStream( + [this, stream_id]() { on_buffered_amount_low_(stream_id); }, + total_buffered_amount_)) + .first->second; +} +} // namespace dcsctp diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h new file mode 100644 index 0000000000..3ec45af17d --- /dev/null +++ b/net/dcsctp/tx/rr_send_queue.h @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef NET_DCSCTP_TX_RR_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_RR_SEND_QUEUE_H_ + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/pair_hash.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The Round Robin SendQueue holds all messages that the client wants to send, +// but that haven't yet been split into chunks and fully sent on the wire. +// +// As defined in https://datatracker.ietf.org/doc/html/rfc8260#section-3.2, +// it will cycle to send messages from different streams. It will send all +// fragments from one message before continuing with a different message on +// possibly a different stream, until support for message interleaving has been +// implemented. +// +// As messages can be (requested to be) sent before the connection is properly +// established, this send queue is always present - even for closed connections. +class RRSendQueue : public SendQueue { + public: + // How small a data chunk's payload may be, if having to fragment a message. + static constexpr size_t kMinimumFragmentedPayload = 10; + + RRSendQueue(absl::string_view log_prefix, + size_t buffer_size, + std::function on_buffered_amount_low, + size_t total_buffered_amount_low_threshold, + std::function on_total_buffered_amount_low) + : log_prefix_(std::string(log_prefix) + "fcfs: "), + buffer_size_(buffer_size), + on_buffered_amount_low_(std::move(on_buffered_amount_low)), + total_buffered_amount_(std::move(on_total_buffered_amount_low)) { + total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); + } + + // Indicates if the buffer is full. Note that it's up to the caller to ensure + // that the buffer is not full prior to adding new items to it. + bool IsFull() const; + // Indicates if the buffer is empty. + bool IsEmpty() const; + + // Adds the message to be sent using the `send_options` provided. The current + // time should be in `now`. Note that it's the responsibility of the caller to + // ensure that the buffer is not full (by calling `IsFull`) before adding + // messages to it. + void Add(TimeMs now, + DcSctpMessage message, + const SendOptions& send_options = {}); + + // Implementation of `SendQueue`. + absl::optional Produce(TimeMs now, size_t max_size) override; + bool Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) override; + void PrepareResetStreams(rtc::ArrayView streams) override; + bool CanResetStreams() const override; + void CommitResetStreams() override; + void RollbackResetStreams() override; + void Reset() override; + size_t buffered_amount(StreamID stream_id) const override; + size_t total_buffered_amount() const override { + return total_buffered_amount_.value(); + } + size_t buffered_amount_low_threshold(StreamID stream_id) const override; + void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + + private: + // Represents a value and a "low threshold" that when the value reaches or + // goes under the "low threshold", will trigger `on_threshold_reached` + // callback. + class ThresholdWatcher { + public: + explicit ThresholdWatcher(std::function on_threshold_reached) + : on_threshold_reached_(std::move(on_threshold_reached)) {} + // Increases the value. 
+ void Increase(size_t bytes) { value_ += bytes; }
+ // Decreases the value and triggers `on_threshold_reached` if it's at or
+ // below `low_threshold()`.
+ void Decrease(size_t bytes);
+
+ size_t value() const { return value_; }
+ size_t low_threshold() const { return low_threshold_; }
+ void SetLowThreshold(size_t low_threshold);
+
+ private:
+ const std::function<void()> on_threshold_reached_;
+ size_t value_ = 0;
+ size_t low_threshold_ = 0;
+ };
+
+ // Per-stream information.
+ class OutgoingStream {
+ public:
+ explicit OutgoingStream(std::function<void()> on_buffered_amount_low,
+ ThresholdWatcher& total_buffered_amount)
+ : buffered_amount_(std::move(on_buffered_amount_low)),
+ total_buffered_amount_(total_buffered_amount) {}
+
+ // Enqueues a message to this stream.
+ void Add(DcSctpMessage message,
+ absl::optional<TimeMs> expires_at,
+ const SendOptions& send_options);
+
+ // Possibly produces a data chunk to send.
+ absl::optional<DataToSend> Produce(TimeMs now, size_t max_size);
+
+ const ThresholdWatcher& buffered_amount() const { return buffered_amount_; }
+ ThresholdWatcher& buffered_amount() { return buffered_amount_; }
+
+ // Discards a partially sent message, see `SendQueue::Discard`.
+ bool Discard(IsUnordered unordered, MID message_id);
+
+ // Pauses this stream, which is used before resetting it.
+ void Pause();
+
+ // Resumes a paused stream.
+ void Resume() { is_paused_ = false; }
+
+ bool is_paused() const { return is_paused_; }
+
+ // Resets this stream, meaning MIDs and SSNs are set to zero.
+ void Reset();
+
+ // Indicates if this stream has a partially sent message in it.
+ bool has_partially_sent_message() const;
+
+ // Indicates if the stream has data to send. It will also try to remove any
+ // expired non-partially sent message.
+ bool HasDataToSend(TimeMs now);
+
+ private:
+ // An enqueued message and metadata.
+ struct Item {
+ explicit Item(DcSctpMessage msg,
+ absl::optional<TimeMs> expires_at,
+ const SendOptions& send_options)
+ : message(std::move(msg)),
+ expires_at(expires_at),
+ send_options(send_options),
+ remaining_offset(0),
+ remaining_size(message.payload().size()) {}
+ DcSctpMessage message;
+ absl::optional<TimeMs> expires_at;
+ SendOptions send_options;
+ // The remaining payload (offset and size) to be sent, when it has been
+ // fragmented.
+ size_t remaining_offset;
+ size_t remaining_size;
+ // If set, an allocated Message ID and SSN. Will be allocated when the
+ // first fragment is sent.
+ absl::optional<MID> message_id = absl::nullopt;
+ absl::optional<SSN> ssn = absl::nullopt;
+ // The current Fragment Sequence Number, incremented for each fragment.
+ FSN current_fsn = FSN(0);
+ };
+
+ bool IsConsistent() const;
+
+ // Streams are paused when they are about to be reset.
+ bool is_paused_ = false;
+ // MIDs are different for unordered and ordered messages sent on a stream.
+ MID next_unordered_mid_ = MID(0);
+ MID next_ordered_mid_ = MID(0);
+
+ SSN next_ssn_ = SSN(0);
+ // Enqueued messages, and metadata.
+ std::deque<Item> items_;
+
+ // The current amount of buffered data.
+ ThresholdWatcher buffered_amount_;
+
+ // Reference to the total buffered amount, which is updated directly by each
+ // stream.
+ ThresholdWatcher& total_buffered_amount_;
+ };
+
+ bool IsConsistent() const;
+ OutgoingStream& GetOrCreateStreamInfo(StreamID stream_id);
+ absl::optional<DataToSend> Produce(
+ std::map<StreamID, OutgoingStream>::iterator it,
+ TimeMs now,
+ size_t max_size);
+
+ // Returns the next stream, in round-robin fashion.
+ std::map::iterator GetNextStream(TimeMs now); + + const std::string log_prefix_; + const size_t buffer_size_; + + // Called when the buffered amount is below what has been set using + // `SetBufferedAmountLowThreshold`. + const std::function on_buffered_amount_low_; + + // Called when the total buffered amount is below what has been set using + // `SetTotalBufferedAmountLowThreshold`. + const std::function on_total_buffered_amount_low_; + + // The total amount of buffer data, for all streams. + ThresholdWatcher total_buffered_amount_; + + // Indicates if the previous fragment sent was the end of a message. For + // non-interleaved sending, this means that the next message may come from a + // different stream. If not true, the next fragment must be produced from the + // same stream as last time. + bool previous_message_has_ended_ = true; + + // The current stream to send chunks from. Modified by `GetNextStream`. + StreamID current_stream_id_ = StreamID(0); + + // All streams, and messages added to those. + std::map streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RR_SEND_QUEUE_H_ diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc new file mode 100644 index 0000000000..425027762d --- /dev/null +++ b/net/dcsctp/tx/rr_send_queue_test.cc @@ -0,0 +1,783 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/rr_send_queue.h" + +#include +#include +#include + +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +constexpr TimeMs kNow = TimeMs(0); +constexpr StreamID kStreamID(1); +constexpr PPID kPPID(53); +constexpr size_t kMaxQueueSize = 1000; +constexpr size_t kBufferedAmountLowThreshold = 500; +constexpr size_t kOneFragmentPacketSize = 100; +constexpr size_t kTwoFragmentPacketSize = 101; + +class RRSendQueueTest : public testing::Test { + protected: + RRSendQueueTest() + : buf_("log: ", + kMaxQueueSize, + on_buffered_amount_low_.AsStdFunction(), + kBufferedAmountLowThreshold, + on_total_buffered_amount_low_.AsStdFunction()) {} + + const DcSctpOptions options_; + testing::NiceMock> + on_buffered_amount_low_; + testing::NiceMock> + on_total_buffered_amount_low_; + RRSendQueue buf_; +}; + +TEST_F(RRSendQueueTest, EmptyBuffer) { + EXPECT_TRUE(buf_.IsEmpty()); + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + EXPECT_FALSE(buf_.IsFull()); +} + +TEST_F(RRSendQueueTest, AddAndGetSingleChunk) { + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 4, 5, 6})); + + EXPECT_FALSE(buf_.IsEmpty()); + EXPECT_FALSE(buf_.IsFull()); + absl::optional chunk_opt = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_opt.has_value()); + EXPECT_TRUE(chunk_opt->data.is_beginning); + EXPECT_TRUE(chunk_opt->data.is_end); +} + +TEST_F(RRSendQueueTest, CarveOutBeginningMiddleAndEnd) { + 
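A rough sketch of constructing an RRSendQueue directly, mirroring the test fixture above; the prefix, sizes and callback bodies are arbitrary illustrations, while the parameter order and the callback signatures (per-stream callback taking a StreamID, total-amount callback taking nothing) follow the constructor declared in rr_send_queue.h:

// Illustrative construction; values are arbitrary.
RRSendQueue queue(
    "a: ",                      // log_prefix
    /*buffer_size=*/64 * 1024,  // Soft limit, checked by the caller via IsFull().
    /*on_buffered_amount_low=*/
    [](StreamID stream_id) {
      // A stream's buffered amount dropped to or below its threshold.
    },
    /*total_buffered_amount_low_threshold=*/32 * 1024,
    /*on_total_buffered_amount_low=*/
    []() {
      // The total buffered amount dropped to or below its threshold.
    });
queue.SetBufferedAmountLowThreshold(StreamID(1), 4 * 1024);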
std::vector payload(60); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional chunk_beg = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_beg.has_value()); + EXPECT_TRUE(chunk_beg->data.is_beginning); + EXPECT_FALSE(chunk_beg->data.is_end); + + absl::optional chunk_mid = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_mid.has_value()); + EXPECT_FALSE(chunk_mid->data.is_beginning); + EXPECT_FALSE(chunk_mid->data.is_end); + + absl::optional chunk_end = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_end.has_value()); + EXPECT_FALSE(chunk_end->data.is_beginning); + EXPECT_TRUE(chunk_end->data.is_end); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); +} + +TEST_F(RRSendQueueTest, GetChunksFromTwoMessages) { + std::vector payload(60); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), PPID(54), payload)); + + absl::optional chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + EXPECT_TRUE(chunk_one->data.is_beginning); + EXPECT_TRUE(chunk_one->data.is_end); + + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ppid, PPID(54)); + EXPECT_TRUE(chunk_two->data.is_beginning); + EXPECT_TRUE(chunk_two->data.is_end); +} + +TEST_F(RRSendQueueTest, BufferBecomesFullAndEmptied) { + std::vector payload(600); + EXPECT_FALSE(buf_.IsFull()); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_FALSE(buf_.IsFull()); + buf_.Add(kNow, DcSctpMessage(StreamID(3), PPID(54), payload)); + EXPECT_TRUE(buf_.IsFull()); + // However, it's still possible to add messages. It's a soft limit, and it + // might be necessary to forcefully add messages due to e.g. external + // fragmentation. 
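The soft limit works out to the following caller-side pattern in normal use (a sketch, not part of the patch; `queue`, `now` and `message` are assumed to exist, and the 1200-byte payload budget per packet is arbitrary): check IsFull() before Add(), and drain with Produce() until it returns nullopt.

// Illustrative caller-side sketch.
if (!queue.IsFull()) {
  queue.Add(now, std::move(message));
}
while (absl::optional<SendQueue::DataToSend> chunk =
           queue.Produce(now, /*max_size=*/1200)) {
  // Pack chunk->data into an outgoing packet here.
}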
+ buf_.Add(kNow, DcSctpMessage(StreamID(5), PPID(55), payload)); + EXPECT_TRUE(buf_.IsFull()); + + absl::optional chunk_one = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + + EXPECT_TRUE(buf_.IsFull()); + + absl::optional chunk_two = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ppid, PPID(54)); + + EXPECT_FALSE(buf_.IsFull()); + EXPECT_FALSE(buf_.IsEmpty()); + + absl::optional chunk_three = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(5)); + EXPECT_EQ(chunk_three->data.ppid, PPID(55)); + + EXPECT_FALSE(buf_.IsFull()); + EXPECT_TRUE(buf_.IsEmpty()); +} + +TEST_F(RRSendQueueTest, WillNotSendTooSmallPacket) { + std::vector payload(RRSendQueue::kMinimumFragmentedPayload + 1); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + // Wouldn't fit enough payload (wouldn't want to fragment) + EXPECT_FALSE( + buf_.Produce(kNow, + /*max_size=*/RRSendQueue::kMinimumFragmentedPayload - 1) + .has_value()); + + // Minimum fragment + absl::optional chunk_one = + buf_.Produce(kNow, + /*max_size=*/RRSendQueue::kMinimumFragmentedPayload); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + + // There is only one byte remaining - it can be fetched as it doesn't require + // additional fragmentation. + absl::optional chunk_two = + buf_.Produce(kNow, /*max_size=*/1); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, kStreamID); + EXPECT_EQ(chunk_two->data.ppid, kPPID); + + EXPECT_TRUE(buf_.IsEmpty()); +} + +TEST_F(RRSendQueueTest, DefaultsToOrderedSend) { + std::vector payload(20); + + // Default is ordered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + absl::optional chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_unordered); + + // Explicitly unordered. + SendOptions opts; + opts.unordered = IsUnordered(true); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), opts); + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_TRUE(chunk_two->data.is_unordered); +} + +TEST_F(RRSendQueueTest, ProduceWithLifetimeExpiry) { + std::vector payload(20); + + // Default is no expiry + TimeMs now = kNow; + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload)); + now += DurationMs(1000000); + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + + SendOptions expires_2_seconds; + expires_2_seconds.lifetime = DurationMs(2000); + + // Add and consume within lifetime + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(2000); + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + + // Add and consume just outside lifetime + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(2001); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); + + // A long time after expiry + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(1000000); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); + + // Expire one message, but produce the second that is not expired. 
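(As the cases above illustrate, expiry is opt-in and comes entirely from SendOptions; a message whose lifetime elapses before any fragment has been produced is silently dropped at Produce() time. A minimal sketch, with `queue`, `now` and `payload` assumed and the values arbitrary:)

// Illustrative only.
SendOptions options;
options.unordered = IsUnordered(true);  // Optional; the default is ordered.
options.lifetime = DurationMs(2000);    // Drop if sending hasn't started within 2 s.
queue.Add(now, DcSctpMessage(StreamID(1), PPID(53), payload), options);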
+ buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + + SendOptions expires_4_seconds; + expires_4_seconds.lifetime = DurationMs(4000); + + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_4_seconds); + now += DurationMs(2001); + + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, DiscardPartialPackets) { + std::vector payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), payload)); + + absl::optional chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_end); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); + + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_FALSE(chunk_two->data.is_end); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(2)); + + absl::optional chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_TRUE(chunk_three->data.is_end); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(2)); + ASSERT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize)); + + // Calling it again shouldn't cause issues. + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); + ASSERT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, PrepareResetStreamsDiscardsStream) { + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 3})); + buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), {1, 2, 3, 4, 5})); + EXPECT_EQ(buf_.total_buffered_amount(), 8u); + + buf_.PrepareResetStreams(std::vector({StreamID(1)})); + EXPECT_EQ(buf_.total_buffered_amount(), 5u); + buf_.CommitResetStreams(); + buf_.PrepareResetStreams(std::vector({StreamID(2)})); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); +} + +TEST_F(RRSendQueueTest, PrepareResetStreamsNotPartialPackets) { + std::vector payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * payload.size() - 50); + + StreamID stream_ids[] = {StreamID(1)}; + buf_.PrepareResetStreams(stream_ids); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size() - 50); +} + +TEST_F(RRSendQueueTest, EnqueuedItemsArePausedDuringStreamReset) { + std::vector payload(50); + + buf_.PrepareResetStreams(std::vector({StreamID(1)})); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + buf_.CommitResetStreams(); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + absl::optional chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); +} + +TEST_F(RRSendQueueTest, CommittingResetsSSN) { + std::vector payload(50); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional 
chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.ssn, SSN(1)); + + StreamID stream_ids[] = {StreamID(1)}; + buf_.PrepareResetStreams(stream_ids); + + // Buffered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + EXPECT_TRUE(buf_.CanResetStreams()); + buf_.CommitResetStreams(); + + absl::optional chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.ssn, SSN(0)); +} + +TEST_F(RRSendQueueTest, CommittingResetsSSNForPausedStreamsOnly) { + std::vector payload(50); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, payload)); + + absl::optional chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, StreamID(1)); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ssn, SSN(0)); + + StreamID stream_ids[] = {StreamID(3)}; + buf_.PrepareResetStreams(stream_ids); + + // Send two more messages - SID 3 will buffer, SID 1 will send. + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, payload)); + + EXPECT_TRUE(buf_.CanResetStreams()); + buf_.CommitResetStreams(); + + absl::optional chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(1)); + EXPECT_EQ(chunk_three->data.ssn, SSN(1)); + + absl::optional chunk_four = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_four.has_value()); + EXPECT_EQ(chunk_four->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_four->data.ssn, SSN(0)); +} + +TEST_F(RRSendQueueTest, RollBackResumesSSN) { + std::vector payload(50); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.ssn, SSN(1)); + + buf_.PrepareResetStreams(std::vector({StreamID(1)})); + + // Buffered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + EXPECT_TRUE(buf_.CanResetStreams()); + buf_.RollbackResetStreams(); + + absl::optional chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.ssn, SSN(2)); +} + +TEST_F(RRSendQueueTest, ReturnsFragmentsForOneMessageBeforeMovingToNext) { + std::vector payload(200); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + + 
ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(2)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(2)); +} + +TEST_F(RRSendQueueTest, ReturnsAlsoSmallFragmentsBeforeMovingToNext) { + std::vector payload(kTwoFragmentPacketSize); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, + SizeIs(kTwoFragmentPacketSize - kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk4.data.payload, + SizeIs(kTwoFragmentPacketSize - kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, WillCycleInRoundRobinFashionBetweenStreams) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(1))); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(2))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(3))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector(4))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector(5))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector(6))); + buf_.Add(kNow, DcSctpMessage(StreamID(4), kPPID, std::vector(7))); + buf_.Add(kNow, DcSctpMessage(StreamID(4), kPPID, std::vector(8))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk2.data.payload, SizeIs(3)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(3)); + EXPECT_THAT(chunk3.data.payload, SizeIs(5)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(4)); + EXPECT_THAT(chunk4.data.payload, SizeIs(7)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk5, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk5.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk5.data.payload, SizeIs(2)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk6, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk6.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk6.data.payload, SizeIs(4)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk7, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk7.data.stream_id, StreamID(3)); + 
EXPECT_THAT(chunk7.data.payload, SizeIs(6)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk8, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk8.data.stream_id, StreamID(4)); + EXPECT_THAT(chunk8.data.payload, SizeIs(8)); +} + +TEST_F(RRSendQueueTest, DoesntTriggerOnBufferedAmountLowWhenSetToZero) { + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 0u); +} + +TEST_F(RRSendQueueTest, TriggersOnBufferedAmountAtZeroLowWhenSent) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(1))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u); + + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); +} + +TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowIfAddingMore) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(1))); + + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(1))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u); + + // Should now trigger again, as buffer_amount went above the threshold. + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(1)); +} + +TEST_F(RRSendQueueTest, OnlyTriggersWhenTransitioningFromAboveToBelowOrEqual) { + buf_.SetBufferedAmountLowThreshold(StreamID(1), 1000); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(10))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 10u); + + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(10)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(20))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 20u); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(20)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); +} + +TEST_F(RRSendQueueTest, WillTriggerOnBufferedAmountLowSetAboveZero) { + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + + buf_.SetBufferedAmountLowThreshold(StreamID(1), 700); + + std::vector payload(1000); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 900u); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, 
StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u); + + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 700u); + + // Doesn't trigger when reducing even further. + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); +} + +TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowSetAboveZero) { + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + + buf_.SetBufferedAmountLowThreshold(StreamID(1), 700); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(1000))); + + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, 400)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(400)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); + + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(200))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u); + + // Will trigger again, as it went above the limit. + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, 200)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(200)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); +} + +TEST_F(RRSendQueueTest, TriggersOnBufferedAmountLowOnThresholdChanged) { + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector(100))); + + // Modifying the threshold, still under buffered_amount, should not trigger. + buf_.SetBufferedAmountLowThreshold(StreamID(1), 50); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 99); + + // When the threshold reaches buffered_amount, it will trigger. + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 100); + + // But not when it's set low again. + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 50); + + // But it will trigger when it overshoots. + EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1))); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 150); + + // But not when it's set low again. + EXPECT_CALL(on_buffered_amount_low_, Call).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 0); +} + +TEST_F(RRSendQueueTest, + OnTotalBufferedAmountLowDoesNotTriggerOnBufferFillingUp) { + EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(0); + std::vector payload(kBufferedAmountLowThreshold - 1); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + // Will not trigger if going above but never below. 
+ buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, + std::vector(kOneFragmentPacketSize))); +} + +TEST_F(RRSendQueueTest, TriggersOnTotalBufferedAmountLowWhenCrossing) { + EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(0); + std::vector payload(kBufferedAmountLowThreshold); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + // Reaches it. + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, std::vector(1))); + + // Drain it a bit - will trigger. + EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(1); + absl::optional chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); +} + +TEST_F(RRSendQueueTest, WillStayInAStreamAsLongAsThatMessageIsSending) { + buf_.Add(kNow, DcSctpMessage(StreamID(5), kPPID, std::vector(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(5)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + // Next, it should pick a different stream. + + buf_.Add(kNow, + DcSctpMessage(StreamID(1), kPPID, + std::vector(kOneFragmentPacketSize * 2))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + + // It should still stay on the Stream1 now, even if might be tempted to switch + // to this stream, as it's the stream following 5. + buf_.Add(kNow, DcSctpMessage(StreamID(6), kPPID, std::vector(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + + // After stream id 1 is complete, it's time to do stream 6. + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(6)); + EXPECT_THAT(chunk4.data.payload, SizeIs(1)); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); +} + +TEST_F(RRSendQueueTest, WillStayInStreamWhenOnlySmallFragmentRemaining) { + buf_.Add(kNow, + DcSctpMessage(StreamID(5), kPPID, + std::vector(kOneFragmentPacketSize * 2))); + buf_.Add(kNow, DcSctpMessage(StreamID(6), kPPID, std::vector(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(5)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + + // Now assume that there will be a lot of previous chunks that need to be + // retransmitted, which fills up the next packet and there is little space + // left in the packet for new chunks. What it should NOT do right now is to + // try to send a message from StreamID 6. And it should not try to send a very + // small fragment from StreamID 5 either. So just skip this one. + EXPECT_FALSE(buf_.Produce(kNow, 8).has_value()); + + // When the next produce request comes with a large buffer to fill, continue + // sending from StreamID 5. + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(5)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + + // Lastly, produce a message on StreamID 6. 
+ ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(6)); + EXPECT_THAT(chunk3.data.payload, SizeIs(1)); + + EXPECT_FALSE(buf_.Produce(kNow, 8).has_value()); +} +} // namespace +} // namespace dcsctp diff --git a/net/dcsctp/tx/send_queue.h b/net/dcsctp/tx/send_queue.h new file mode 100644 index 0000000000..877dbdda59 --- /dev/null +++ b/net/dcsctp/tx/send_queue.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_SEND_QUEUE_H_ + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +class SendQueue { + public: + // Container for a data chunk that is produced by the SendQueue + struct DataToSend { + explicit DataToSend(Data data) : data(std::move(data)) {} + // The data to send, including all parameters. + Data data; + + // Partial reliability - RFC3758 + absl::optional max_retransmissions; + absl::optional expires_at; + }; + + virtual ~SendQueue() = default; + + // TODO(boivie): This interface is obviously missing an "Add" function, but + // that is postponed a bit until the story around how to model message + // prioritization, which is important for any advanced stream scheduler, is + // further clarified. + + // Produce a chunk to be sent. + // + // `max_size` refers to how many payload bytes that may be produced, not + // including any headers. + virtual absl::optional Produce(TimeMs now, size_t max_size) = 0; + + // Discards a partially sent message identified by the parameters `unordered`, + // `stream_id` and `message_id`. The `message_id` comes from the returned + // information when having called `Produce`. A partially sent message means + // that it has had at least one fragment of it returned when `Produce` was + // called prior to calling this method). + // + // This is used when a message has been found to be expired (by the partial + // reliability extension), and the retransmission queue will signal the + // receiver that any partially received message fragments should be skipped. + // This means that any remaining fragments in the Send Queue must be removed + // as well so that they are not sent. + // + // This function returns true if this message had unsent fragments still in + // the queue that were discarded, and false if there were no such fragments. + virtual bool Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) = 0; + + // Prepares the streams to be reset. This is used to close a WebRTC data + // channel and will be signaled to the other side. + // + // Concretely, it discards all whole (not partly sent) messages in the given + // streams and pauses those streams so that future added messages aren't + // produced until `ResumeStreams` is called. + // + // TODO(boivie): Investigate if it really should discard any message at all. 
+ // RFC8831 only mentions that "[RFC6525] also guarantees that all the messages
+ // are delivered (or abandoned) before the stream is reset."
+ //
+ // This method can be called multiple times to add more streams to be
+ // reset, and paused while they are resetting. This is the first part of the
+ // two-phase commit protocol to reset streams, where the caller completes the
+ // procedure by either calling `CommitResetStreams` or `RollbackResetStreams`.
+ virtual void PrepareResetStreams(rtc::ArrayView<const StreamID> streams) = 0;
+
+ // Returns true if all non-discarded messages during `PrepareResetStreams`
+ // (which are those that were partially sent before that method was called)
+ // have been sent.
+ virtual bool CanResetStreams() const = 0;
+
+ // Called to commit to reset the streams provided to `PrepareResetStreams`.
+ // It will reset the stream sequence numbers (SSNs) and message identifiers
+ // (MIDs) and resume the paused streams.
+ virtual void CommitResetStreams() = 0;
+
+ // Called to abort the resetting of streams provided to `PrepareResetStreams`.
+ // Will resume the paused streams without resetting the stream sequence
+ // numbers (SSNs) or message identifiers (MIDs). Note that the non-partial
+ // messages that were discarded when calling `PrepareResetStreams` will not be
+ // recovered, to better match the intention from the sender to "close the
+ // channel".
+ virtual void RollbackResetStreams() = 0;
+
+ // Resets all message identifier counters (MID, SSN) and makes all partially
+ // sent messages ready to be re-sent in full. This is used when the peer has
+ // been detected to have restarted and is used to try to minimize the amount
+ // of data loss. However, data loss cannot be completely guaranteed when a
+ // peer restarts.
+ virtual void Reset() = 0;
+
+ // Returns the amount of buffered data. This doesn't include packets that are
+ // e.g. in flight.
+ virtual size_t buffered_amount(StreamID stream_id) const = 0;
+
+ // Returns the total amount of buffered data, for all streams.
+ virtual size_t total_buffered_amount() const = 0;
+
+ // Returns the limit for the `OnBufferedAmountLow` event. Default value is 0.
+ virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0;
+
+ // Sets a limit for the `OnBufferedAmountLow` event.
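Taken together, the reset-related methods above form a two-phase commit driven by the caller. A schematic of the intended call order, where `send_queue` is a SendQueue& and `streams` an rtc::ArrayView<const StreamID>; the negotiation steps and the `peer_accepted` flag are placeholders, and only the SendQueue calls come from this interface:

// Illustrative caller-side sketch.
send_queue.PrepareResetStreams(streams);  // Discards whole messages, pauses streams.
if (send_queue.CanResetStreams()) {
  // ... send the outgoing stream reset request and await the response ...
  bool peer_accepted = true;              // Placeholder for the negotiated outcome.
  if (peer_accepted) {
    send_queue.CommitResetStreams();      // SSNs and MIDs restart from zero.
  } else {
    send_queue.RollbackResetStreams();    // Resume without resetting counters.
  }
}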
+ virtual void SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_SEND_QUEUE_H_ diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn index d4330ef94a..244bc39092 100644 --- a/p2p/BUILD.gn +++ b/p2p/BUILD.gn @@ -45,8 +45,6 @@ rtc_library("rtc_p2p") { "base/ice_credentials_iterator.h", "base/ice_transport_internal.cc", "base/ice_transport_internal.h", - "base/mdns_message.cc", - "base/mdns_message.h", "base/p2p_constants.cc", "base/p2p_constants.h", "base/p2p_transport_channel.cc", @@ -86,26 +84,38 @@ rtc_library("rtc_p2p") { ] deps = [ + "../api:array_view", + "../api:async_dns_resolver", "../api:libjingle_peerconnection_api", "../api:packet_socket_factory", "../api:rtc_error", "../api:scoped_refptr", + "../api:sequence_checker", "../api/crypto:options", "../api/rtc_event_log", + "../api/task_queue", "../api/transport:enums", "../api/transport:stun_types", "../logging:ice_log", "../rtc_base", + "../rtc_base:async_resolver_interface", + "../rtc_base:async_socket", + "../rtc_base:callback_list", "../rtc_base:checks", + "../rtc_base:ip_address", + "../rtc_base:net_helpers", + "../rtc_base:network_constants", "../rtc_base:rtc_numerics", + "../rtc_base:socket", + "../rtc_base:socket_address", + "../rtc_base:socket_server", + "../rtc_base:threading", "../rtc_base/experiments:field_trial_parser", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", # Needed by pseudo_tcp, which should move to a separate target. "../rtc_base:safe_minmax", "../rtc_base:weak_ptr", - "../rtc_base/memory:fifo_buffer", "../rtc_base/network:sent_packet", "../rtc_base/synchronization:mutex", "../rtc_base/system:rtc_export", @@ -134,6 +144,8 @@ if (rtc_include_tests) { "../api:libjingle_peerconnection_api", "../rtc_base", "../rtc_base:rtc_base_approved", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -148,6 +160,8 @@ if (rtc_include_tests) { deps = [ ":rtc_p2p", "../rtc_base", + "../rtc_base:net_helpers", + "../rtc_base:threading", ] } @@ -170,12 +184,18 @@ if (rtc_include_tests) { ":rtc_p2p", "../api:libjingle_peerconnection_api", "../api:packet_socket_factory", + "../api:sequence_checker", "../api/crypto:options", "../api/transport:stun_types", "../rtc_base", + "../rtc_base:async_resolver_interface", + "../rtc_base:async_socket", "../rtc_base:gunit_helpers", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_base_tests_utils", + "../rtc_base:socket_address", + "../rtc_base:socket_server", + "../rtc_base:threading", "../rtc_base/third_party/sigslot", "../test:test_support", ] @@ -193,7 +213,6 @@ if (rtc_include_tests) { "base/basic_async_resolver_factory_unittest.cc", "base/dtls_transport_unittest.cc", "base/ice_credentials_iterator_unittest.cc", - "base/mdns_message_unittest.cc", "base/p2p_transport_channel_unittest.cc", "base/port_allocator_unittest.cc", "base/port_unittest.cc", @@ -216,20 +235,29 @@ if (rtc_include_tests) { ":p2p_test_utils", ":rtc_p2p", "../api:libjingle_peerconnection_api", + "../api:mock_async_dns_resolver", "../api:packet_socket_factory", "../api:scoped_refptr", "../api/transport:stun_types", "../api/units:time_delta", "../rtc_base", + "../rtc_base:async_socket", "../rtc_base:checks", "../rtc_base:gunit_helpers", + "../rtc_base:ip_address", + "../rtc_base:net_helpers", + "../rtc_base:network_constants", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_base_tests_utils", + 
"../rtc_base:socket", + "../rtc_base:socket_address", "../rtc_base:testclient", + "../rtc_base:threading", "../rtc_base/network:sent_packet", "../rtc_base/third_party/sigslot", "../system_wrappers:metrics", "../test:field_trial", + "../test:rtc_expect_death", "../test:test_support", "//testing/gtest", ] @@ -252,13 +280,20 @@ rtc_library("p2p_server_utils") { deps = [ ":rtc_p2p", "../api:packet_socket_factory", + "../api:sequence_checker", "../api/transport:stun_types", "../rtc_base", "../rtc_base:checks", "../rtc_base:rtc_base_tests_utils", + "../rtc_base:socket_address", + "../rtc_base:threading", + "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/sigslot", ] - absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory", + ] } rtc_library("libstunprober") { @@ -271,10 +306,17 @@ rtc_library("libstunprober") { deps = [ ":rtc_p2p", "../api:packet_socket_factory", + "../api:sequence_checker", "../api/transport:stun_types", "../rtc_base", + "../rtc_base:async_resolver_interface", "../rtc_base:checks", + "../rtc_base:ip_address", + "../rtc_base:socket_address", + "../rtc_base:threading", "../rtc_base/system:rtc_export", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", ] } @@ -290,6 +332,7 @@ if (rtc_include_tests) { "../rtc_base", "../rtc_base:checks", "../rtc_base:gunit_helpers", + "../rtc_base:ip_address", "../rtc_base:rtc_base_tests_utils", "../test:test_support", "//testing/gtest", diff --git a/p2p/base/basic_async_resolver_factory.cc b/p2p/base/basic_async_resolver_factory.cc index 9d8266eaf9..7f26a981ee 100644 --- a/p2p/base/basic_async_resolver_factory.cc +++ b/p2p/base/basic_async_resolver_factory.cc @@ -10,7 +10,13 @@ #include "p2p/base/basic_async_resolver_factory.h" -#include "rtc_base/net_helpers.h" +#include +#include + +#include "absl/memory/memory.h" +#include "api/async_dns_resolver.h" +#include "rtc_base/async_resolver.h" +#include "rtc_base/logging.h" namespace webrtc { @@ -18,4 +24,113 @@ rtc::AsyncResolverInterface* BasicAsyncResolverFactory::Create() { return new rtc::AsyncResolver(); } +class WrappingAsyncDnsResolver; + +class WrappingAsyncDnsResolverResult : public AsyncDnsResolverResult { + public: + explicit WrappingAsyncDnsResolverResult(WrappingAsyncDnsResolver* owner) + : owner_(owner) {} + ~WrappingAsyncDnsResolverResult() {} + + // Note: Inline declaration not possible, since it refers to + // WrappingAsyncDnsResolver. + bool GetResolvedAddress(int family, rtc::SocketAddress* addr) const override; + int GetError() const override; + + private: + WrappingAsyncDnsResolver* const owner_; +}; + +class WrappingAsyncDnsResolver : public AsyncDnsResolverInterface, + public sigslot::has_slots<> { + public: + explicit WrappingAsyncDnsResolver(rtc::AsyncResolverInterface* wrapped) + : wrapped_(absl::WrapUnique(wrapped)), result_(this) {} + + ~WrappingAsyncDnsResolver() override { + // Workaround to get around the fact that sigslot-using objects can't be + // destroyed from within their callback: Alert class users early. + // TODO(bugs.webrtc.org/12651): Delete this class once the sigslot users are + // gone. 
+ RTC_CHECK(!within_resolve_result_); + wrapped_.release()->Destroy(false); + } + + void Start(const rtc::SocketAddress& addr, + std::function callback) override { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK_EQ(State::kNotStarted, state_); + state_ = State::kStarted; + callback_ = callback; + wrapped_->SignalDone.connect(this, + &WrappingAsyncDnsResolver::OnResolveResult); + wrapped_->Start(addr); + } + + const AsyncDnsResolverResult& result() const override { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK_EQ(State::kResolved, state_); + return result_; + } + + private: + enum class State { kNotStarted, kStarted, kResolved }; + + friend class WrappingAsyncDnsResolverResult; + // For use by WrappingAsyncDnsResolverResult + rtc::AsyncResolverInterface* wrapped() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + return wrapped_.get(); + } + + void OnResolveResult(rtc::AsyncResolverInterface* ref) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(state_ == State::kStarted); + RTC_DCHECK_EQ(ref, wrapped_.get()); + state_ = State::kResolved; + within_resolve_result_ = true; + callback_(); + within_resolve_result_ = false; + } + + // The class variables need to be accessed on a single thread. + SequenceChecker sequence_checker_; + std::function callback_ RTC_GUARDED_BY(sequence_checker_); + std::unique_ptr wrapped_ + RTC_GUARDED_BY(sequence_checker_); + State state_ RTC_GUARDED_BY(sequence_checker_) = State::kNotStarted; + WrappingAsyncDnsResolverResult result_ RTC_GUARDED_BY(sequence_checker_); + bool within_resolve_result_ RTC_GUARDED_BY(sequence_checker_) = false; +}; + +bool WrappingAsyncDnsResolverResult::GetResolvedAddress( + int family, + rtc::SocketAddress* addr) const { + if (!owner_->wrapped()) { + return false; + } + return owner_->wrapped()->GetResolvedAddress(family, addr); +} + +int WrappingAsyncDnsResolverResult::GetError() const { + if (!owner_->wrapped()) { + return -1; // FIXME: Find a code that makes sense. + } + return owner_->wrapped()->GetError(); +} + +std::unique_ptr +WrappingAsyncDnsResolverFactory::Create() { + return std::make_unique(wrapped_factory_->Create()); +} + +std::unique_ptr +WrappingAsyncDnsResolverFactory::CreateAndResolve( + const rtc::SocketAddress& addr, + std::function callback) { + std::unique_ptr resolver = Create(); + resolver->Start(addr, callback); + return resolver; +} + } // namespace webrtc diff --git a/p2p/base/basic_async_resolver_factory.h b/p2p/base/basic_async_resolver_factory.h index c4661b448b..c988913068 100644 --- a/p2p/base/basic_async_resolver_factory.h +++ b/p2p/base/basic_async_resolver_factory.h @@ -11,16 +11,47 @@ #ifndef P2P_BASE_BASIC_ASYNC_RESOLVER_FACTORY_H_ #define P2P_BASE_BASIC_ASYNC_RESOLVER_FACTORY_H_ +#include +#include +#include + +#include "api/async_dns_resolver.h" #include "api/async_resolver_factory.h" #include "rtc_base/async_resolver_interface.h" namespace webrtc { -class BasicAsyncResolverFactory : public AsyncResolverFactory { +class BasicAsyncResolverFactory final : public AsyncResolverFactory { public: rtc::AsyncResolverInterface* Create() override; }; +// This class wraps a factory using the older webrtc::AsyncResolverFactory API, +// and produces webrtc::AsyncDnsResolver objects that contain an +// rtc::AsyncResolver object. 
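Roughly, usage of this wrapper looks like the sketch below (the host, port and callback body are placeholders, not taken from the patch; the factory can either own the wrapped factory, as here, or borrow a non-owned pointer via the second constructor):

// Illustrative only.
WrappingAsyncDnsResolverFactory factory(
    std::make_unique<BasicAsyncResolverFactory>());

std::unique_ptr<AsyncDnsResolverInterface> resolver = factory.Create();
rtc::SocketAddress address("example.org", 3478);  // Placeholder address.
resolver->Start(address, [&resolver]() {
  rtc::SocketAddress resolved;
  if (resolver->result().GetResolvedAddress(AF_INET, &resolved)) {
    // Use `resolved` here. Do not destroy `resolver` from inside this
    // callback; the wrapper CHECK-fails on that (see the death test below).
  }
});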
+class WrappingAsyncDnsResolverFactory final + : public AsyncDnsResolverFactoryInterface { + public: + explicit WrappingAsyncDnsResolverFactory( + std::unique_ptr wrapped_factory) + : owned_factory_(std::move(wrapped_factory)), + wrapped_factory_(owned_factory_.get()) {} + + explicit WrappingAsyncDnsResolverFactory( + AsyncResolverFactory* non_owned_factory) + : wrapped_factory_(non_owned_factory) {} + + std::unique_ptr CreateAndResolve( + const rtc::SocketAddress& addr, + std::function callback) override; + + std::unique_ptr Create() override; + + private: + const std::unique_ptr owned_factory_; + AsyncResolverFactory* const wrapped_factory_; +}; + } // namespace webrtc #endif // P2P_BASE_BASIC_ASYNC_RESOLVER_FACTORY_H_ diff --git a/p2p/base/basic_async_resolver_factory_unittest.cc b/p2p/base/basic_async_resolver_factory_unittest.cc index 8242146bae..6706f50d61 100644 --- a/p2p/base/basic_async_resolver_factory_unittest.cc +++ b/p2p/base/basic_async_resolver_factory_unittest.cc @@ -10,10 +10,15 @@ #include "p2p/base/basic_async_resolver_factory.h" +#include "api/test/mock_async_dns_resolver.h" +#include "p2p/base/mock_async_resolver.h" +#include "rtc_base/async_resolver.h" #include "rtc_base/gunit.h" #include "rtc_base/socket_address.h" #include "rtc_base/third_party/sigslot/sigslot.h" +#include "test/gmock.h" #include "test/gtest.h" +#include "test/testsupport/rtc_expect_death.h" namespace webrtc { @@ -47,4 +52,66 @@ TEST_F(BasicAsyncResolverFactoryTest, TestCreate) { TestCreate(); } +TEST(WrappingAsyncDnsResolverFactoryTest, TestCreateAndResolve) { + WrappingAsyncDnsResolverFactory factory( + std::make_unique()); + + std::unique_ptr resolver(factory.Create()); + ASSERT_TRUE(resolver); + + bool address_resolved = false; + rtc::SocketAddress address("", 0); + resolver->Start(address, [&address_resolved]() { address_resolved = true; }); + ASSERT_TRUE_WAIT(address_resolved, 10000 /*ms*/); + resolver.reset(); +} + +TEST(WrappingAsyncDnsResolverFactoryTest, WrapOtherResolver) { + BasicAsyncResolverFactory non_owned_factory; + WrappingAsyncDnsResolverFactory factory(&non_owned_factory); + std::unique_ptr resolver(factory.Create()); + ASSERT_TRUE(resolver); + + bool address_resolved = false; + rtc::SocketAddress address("", 0); + resolver->Start(address, [&address_resolved]() { address_resolved = true; }); + ASSERT_TRUE_WAIT(address_resolved, 10000 /*ms*/); + resolver.reset(); +} + +#if GTEST_HAS_DEATH_TEST && defined(WEBRTC_LINUX) +// Tests that the prohibition against deleting the resolver from the callback +// is enforced. This is required by the use of sigslot in the wrapped resolver. +// Checking the error message fails on a number of platforms, so run this +// test only on the platforms where it works. +void CallResolver(WrappingAsyncDnsResolverFactory& factory) { + rtc::SocketAddress address("", 0); + std::unique_ptr resolver(factory.Create()); + resolver->Start(address, [&resolver]() { resolver.reset(); }); + WAIT(!resolver.get(), 10000 /*ms*/); +} + +TEST(WrappingAsyncDnsResolverFactoryDeathTest, DestroyResolverInCallback) { + // This test requires the main thread to be wrapped. So we defeat the + // workaround in test/test_main_lib.cc by explicitly wrapping the main + // thread here. + auto thread = rtc::Thread::CreateWithSocketServer(); + thread->WrapCurrent(); + // TODO(bugs.webrtc.org/12652): Rewrite as death test in loop style when it + // works. 
+ WrappingAsyncDnsResolverFactory factory( + std::make_unique()); + + // Since EXPECT_DEATH is thread sensitive, and the resolver creates a thread, + // we wrap the whole creation section in EXPECT_DEATH. + RTC_EXPECT_DEATH(CallResolver(factory), + "Check failed: !within_resolve_result_"); + // If we get here, we have to unwrap the thread. + thread->Quit(); + thread->Run(); + thread->UnwrapCurrent(); + thread = nullptr; +} +#endif + } // namespace webrtc diff --git a/p2p/base/basic_packet_socket_factory.cc b/p2p/base/basic_packet_socket_factory.cc index ebc11bbcf7..232e58b546 100644 --- a/p2p/base/basic_packet_socket_factory.cc +++ b/p2p/base/basic_packet_socket_factory.cc @@ -15,6 +15,7 @@ #include #include "p2p/base/async_stun_tcp_socket.h" +#include "rtc_base/async_resolver.h" #include "rtc_base/async_tcp_socket.h" #include "rtc_base/async_udp_socket.h" #include "rtc_base/checks.h" diff --git a/p2p/base/connection.cc b/p2p/base/connection.cc index 8adfeb418d..0aa2bcbeff 100644 --- a/p2p/base/connection.cc +++ b/p2p/base/connection.cc @@ -480,6 +480,7 @@ void Connection::OnReadPacket(const char* data, // If this is a STUN response, then update the writable bit. // Log at LS_INFO if we receive a ping on an unwritable connection. rtc::LoggingSeverity sev = (!writable() ? rtc::LS_INFO : rtc::LS_VERBOSE); + msg->ValidateMessageIntegrity(remote_candidate().password()); switch (msg->type()) { case STUN_BINDING_REQUEST: RTC_LOG_V(sev) << ToString() << ": Received " @@ -505,8 +506,7 @@ void Connection::OnReadPacket(const char* data, // id's match. case STUN_BINDING_RESPONSE: case STUN_BINDING_ERROR_RESPONSE: - if (msg->ValidateMessageIntegrity(data, size, - remote_candidate().password())) { + if (msg->IntegrityOk()) { requests_.CheckResponse(msg.get()); } // Otherwise silently discard the response message. 
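The connection.cc hunks above (and the GOOG_PING one that follows) all move to the same pattern: integrity is validated once against the remote candidate's password when the packet is received, and later checks only consult the cached result. Schematically, using only the calls visible in these hunks (not a verbatim excerpt; surrounding code elided):

// On receipt, after parsing the STUN message:
msg->ValidateMessageIntegrity(remote_candidate().password());
// ... later, wherever the old code re-validated against the raw buffer:
if (msg->IntegrityOk()) {
  requests_.CheckResponse(msg.get());
}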
@@ -523,8 +523,7 @@ void Connection::OnReadPacket(const char* data, break; case GOOG_PING_RESPONSE: case GOOG_PING_ERROR_RESPONSE: - if (msg->ValidateMessageIntegrity32(data, size, - remote_candidate().password())) { + if (msg->IntegrityOk()) { requests_.CheckResponse(msg.get()); } break; diff --git a/p2p/base/default_ice_transport_factory.cc b/p2p/base/default_ice_transport_factory.cc index f4b182efdf..0a7175cfd8 100644 --- a/p2p/base/default_ice_transport_factory.cc +++ b/p2p/base/default_ice_transport_factory.cc @@ -44,10 +44,10 @@ DefaultIceTransportFactory::CreateIceTransport( int component, IceTransportInit init) { BasicIceControllerFactory factory; - return new rtc::RefCountedObject( - std::make_unique( + return rtc::make_ref_counted( + cricket::P2PTransportChannel::Create( transport_name, component, init.port_allocator(), - init.async_resolver_factory(), init.event_log(), &factory)); + init.async_dns_resolver_factory(), init.event_log(), &factory)); } } // namespace webrtc diff --git a/p2p/base/default_ice_transport_factory.h b/p2p/base/default_ice_transport_factory.h index 4834c9ada7..e46680d480 100644 --- a/p2p/base/default_ice_transport_factory.h +++ b/p2p/base/default_ice_transport_factory.h @@ -36,7 +36,7 @@ class DefaultIceTransport : public IceTransportInterface { } private: - const rtc::ThreadChecker thread_checker_{}; + const SequenceChecker thread_checker_{}; std::unique_ptr internal_ RTC_GUARDED_BY(thread_checker_); }; diff --git a/p2p/base/dtls_transport.cc b/p2p/base/dtls_transport.cc index 52fe5c65a2..76b94a8d79 100644 --- a/p2p/base/dtls_transport.cc +++ b/p2p/base/dtls_transport.cc @@ -15,6 +15,7 @@ #include #include "absl/memory/memory.h" +#include "api/dtls_transport_interface.h" #include "api/rtc_event_log/rtc_event_log.h" #include "logging/rtc_event_log/events/rtc_event_dtls_transport_state.h" #include "logging/rtc_event_log/events/rtc_event_dtls_writable_state.h" @@ -134,14 +135,13 @@ void StreamInterfaceChannel::Close() { DtlsTransport::DtlsTransport(IceTransportInternal* ice_transport, const webrtc::CryptoOptions& crypto_options, - webrtc::RtcEventLog* event_log) - : transport_name_(ice_transport->transport_name()), - component_(ice_transport->component()), + webrtc::RtcEventLog* event_log, + rtc::SSLProtocolVersion max_version) + : component_(ice_transport->component()), ice_transport_(ice_transport), downward_(NULL), srtp_ciphers_(crypto_options.GetSupportedDtlsSrtpCryptoSuites()), - ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_12), - crypto_options_(crypto_options), + ssl_max_version_(max_version), event_log_(event_log) { RTC_DCHECK(ice_transport_); ConnectToIceTransport(); @@ -149,16 +149,12 @@ DtlsTransport::DtlsTransport(IceTransportInternal* ice_transport, DtlsTransport::~DtlsTransport() = default; -const webrtc::CryptoOptions& DtlsTransport::crypto_options() const { - return crypto_options_; -} - -DtlsTransportState DtlsTransport::dtls_state() const { +webrtc::DtlsTransportState DtlsTransport::dtls_state() const { return dtls_state_; } const std::string& DtlsTransport::transport_name() const { - return transport_name_; + return ice_transport_->transport_name(); } int DtlsTransport::component() const { @@ -199,17 +195,6 @@ rtc::scoped_refptr DtlsTransport::GetLocalCertificate() return local_certificate_; } -bool DtlsTransport::SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) { - if (dtls_active_) { - RTC_LOG(LS_ERROR) << "Not changing max. 
protocol version " - "while DTLS is negotiating"; - return false; - } - - ssl_max_version_ = version; - return true; -} - bool DtlsTransport::SetDtlsRole(rtc::SSLRole role) { if (dtls_) { RTC_DCHECK(dtls_role_); @@ -234,7 +219,7 @@ bool DtlsTransport::GetDtlsRole(rtc::SSLRole* role) const { } bool DtlsTransport::GetSslCipherSuite(int* cipher) { - if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + if (dtls_state() != webrtc::DtlsTransportState::kConnected) { return false; } @@ -292,7 +277,7 @@ bool DtlsTransport::SetRemoteFingerprint(const std::string& digest_alg, remote_fingerprint_value_.size(), &err)) { RTC_LOG(LS_ERROR) << ToString() << ": Couldn't set DTLS certificate digest."; - set_dtls_state(DTLS_TRANSPORT_FAILED); + set_dtls_state(webrtc::DtlsTransportState::kFailed); // If the error is "verification failed", don't return false, because // this means the fingerprint was formatted correctly but didn't match // the certificate from the DTLS handshake. Thus the DTLS state should go @@ -306,12 +291,12 @@ bool DtlsTransport::SetRemoteFingerprint(const std::string& digest_alg, // create a new one, resetting our state. if (dtls_ && fingerprint_changing) { dtls_.reset(nullptr); - set_dtls_state(DTLS_TRANSPORT_NEW); + set_dtls_state(webrtc::DtlsTransportState::kNew); set_writable(false); } if (!SetupDtls()) { - set_dtls_state(DTLS_TRANSPORT_FAILED); + set_dtls_state(webrtc::DtlsTransportState::kFailed); return false; } @@ -389,7 +374,7 @@ bool DtlsTransport::SetupDtls() { } bool DtlsTransport::GetSrtpCryptoSuite(int* cipher) { - if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + if (dtls_state() != webrtc::DtlsTransportState::kConnected) { return false; } @@ -397,7 +382,7 @@ bool DtlsTransport::GetSrtpCryptoSuite(int* cipher) { } bool DtlsTransport::GetSslVersionBytes(int* version) const { - if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + if (dtls_state() != webrtc::DtlsTransportState::kConnected) { return false; } @@ -415,14 +400,14 @@ int DtlsTransport::SendPacket(const char* data, } switch (dtls_state()) { - case DTLS_TRANSPORT_NEW: + case webrtc::DtlsTransportState::kNew: // Can't send data until the connection is active. // TODO(ekr@rtfm.com): assert here if dtls_ is NULL? return -1; - case DTLS_TRANSPORT_CONNECTING: + case webrtc::DtlsTransportState::kConnecting: // Can't send data until the connection is active. return -1; - case DTLS_TRANSPORT_CONNECTED: + case webrtc::DtlsTransportState::kConnected: if (flags & PF_SRTP_BYPASS) { RTC_DCHECK(!srtp_ciphers_.empty()); if (!IsRtpPacket(data, size)) { @@ -435,17 +420,17 @@ int DtlsTransport::SendPacket(const char* data, ? static_cast(size) : -1; } - case DTLS_TRANSPORT_FAILED: + case webrtc::DtlsTransportState::kFailed: // Can't send anything when we're failed. - RTC_LOG(LS_ERROR) - << ToString() - << ": Couldn't send packet due to DTLS_TRANSPORT_FAILED."; + RTC_LOG(LS_ERROR) << ToString() + << ": Couldn't send packet due to " + "webrtc::DtlsTransportState::kFailed."; return -1; - case DTLS_TRANSPORT_CLOSED: + case webrtc::DtlsTransportState::kClosed: // Can't send anything when we're closed. 
- RTC_LOG(LS_ERROR) - << ToString() - << ": Couldn't send packet due to DTLS_TRANSPORT_CLOSED."; + RTC_LOG(LS_ERROR) << ToString() + << ": Couldn't send packet due to " + "webrtc::DtlsTransportState::kClosed."; return -1; default: RTC_NOTREACHED(); @@ -524,27 +509,30 @@ void DtlsTransport::OnWritableState(rtc::PacketTransportInternal* transport) { } switch (dtls_state()) { - case DTLS_TRANSPORT_NEW: + case webrtc::DtlsTransportState::kNew: MaybeStartDtls(); break; - case DTLS_TRANSPORT_CONNECTED: + case webrtc::DtlsTransportState::kConnected: // Note: SignalWritableState fired by set_writable. set_writable(ice_transport_->writable()); break; - case DTLS_TRANSPORT_CONNECTING: + case webrtc::DtlsTransportState::kConnecting: // Do nothing. break; - case DTLS_TRANSPORT_FAILED: + case webrtc::DtlsTransportState::kFailed: // Should not happen. Do nothing. - RTC_LOG(LS_ERROR) - << ToString() - << ": OnWritableState() called in state DTLS_TRANSPORT_FAILED."; + RTC_LOG(LS_ERROR) << ToString() + << ": OnWritableState() called in state " + "webrtc::DtlsTransportState::kFailed."; break; - case DTLS_TRANSPORT_CLOSED: + case webrtc::DtlsTransportState::kClosed: // Should not happen. Do nothing. - RTC_LOG(LS_ERROR) - << ToString() - << ": OnWritableState() called in state DTLS_TRANSPORT_CLOSED."; + RTC_LOG(LS_ERROR) << ToString() + << ": OnWritableState() called in state " + "webrtc::DtlsTransportState::kClosed."; + break; + case webrtc::DtlsTransportState::kNumValues: + RTC_NOTREACHED(); break; } } @@ -556,7 +544,7 @@ void DtlsTransport::OnReceivingState(rtc::PacketTransportInternal* transport) { << ": ice_transport " "receiving state changed to " << ice_transport_->receiving(); - if (!dtls_active_ || dtls_state() == DTLS_TRANSPORT_CONNECTED) { + if (!dtls_active_ || dtls_state() == webrtc::DtlsTransportState::kConnected) { // Note: SignalReceivingState fired by set_receiving. set_receiving(ice_transport_->receiving()); } @@ -578,7 +566,7 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, } switch (dtls_state()) { - case DTLS_TRANSPORT_NEW: + case webrtc::DtlsTransportState::kNew: if (dtls_) { RTC_LOG(LS_INFO) << ToString() << ": Packet received before DTLS started."; @@ -607,8 +595,8 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, } break; - case DTLS_TRANSPORT_CONNECTING: - case DTLS_TRANSPORT_CONNECTED: + case webrtc::DtlsTransportState::kConnecting: + case webrtc::DtlsTransportState::kConnected: // We should only get DTLS or SRTP packets; STUN's already been demuxed. // Is this potentially a DTLS packet? if (IsDtlsPacket(data, size)) { @@ -618,7 +606,7 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, } } else { // Not a DTLS packet; our handshake should be complete by now. - if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + if (dtls_state() != webrtc::DtlsTransportState::kConnected) { RTC_LOG(LS_ERROR) << ToString() << ": Received non-DTLS packet before DTLS " "complete."; @@ -639,8 +627,9 @@ void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport, SignalReadPacket(this, data, size, packet_time_us, PF_SRTP_BYPASS); } break; - case DTLS_TRANSPORT_FAILED: - case DTLS_TRANSPORT_CLOSED: + case webrtc::DtlsTransportState::kFailed: + case webrtc::DtlsTransportState::kClosed: + case webrtc::DtlsTransportState::kNumValues: // This shouldn't be happening. Drop the packet. 
break; } @@ -668,7 +657,7 @@ void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) { if (dtls_->GetState() == rtc::SS_OPEN) { // The check for OPEN shouldn't be necessary but let's make // sure we don't accidentally frob the state if it's closed. - set_dtls_state(DTLS_TRANSPORT_CONNECTED); + set_dtls_state(webrtc::DtlsTransportState::kConnected); set_writable(true); } } @@ -687,7 +676,7 @@ void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) { // Remote peer shut down the association with no error. RTC_LOG(LS_INFO) << ToString() << ": DTLS transport closed by remote"; set_writable(false); - set_dtls_state(DTLS_TRANSPORT_CLOSED); + set_dtls_state(webrtc::DtlsTransportState::kClosed); SignalClosed(this); } else if (ret == rtc::SR_ERROR) { // Remote peer shut down the association with an error. @@ -696,7 +685,7 @@ void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) { << ": Closed by remote with DTLS transport error, code=" << read_error; set_writable(false); - set_dtls_state(DTLS_TRANSPORT_FAILED); + set_dtls_state(webrtc::DtlsTransportState::kFailed); SignalClosed(this); } } while (ret == rtc::SR_SUCCESS); @@ -706,10 +695,10 @@ void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) { set_writable(false); if (!err) { RTC_LOG(LS_INFO) << ToString() << ": DTLS transport closed"; - set_dtls_state(DTLS_TRANSPORT_CLOSED); + set_dtls_state(webrtc::DtlsTransportState::kClosed); } else { RTC_LOG(LS_INFO) << ToString() << ": DTLS transport error, code=" << err; - set_dtls_state(DTLS_TRANSPORT_FAILED); + set_dtls_state(webrtc::DtlsTransportState::kFailed); } } } @@ -733,11 +722,11 @@ void DtlsTransport::MaybeStartDtls() { // configuration and therefore are our fault. RTC_NOTREACHED() << "StartSSL failed."; RTC_LOG(LS_ERROR) << ToString() << ": Couldn't start DTLS handshake"; - set_dtls_state(DTLS_TRANSPORT_FAILED); + set_dtls_state(webrtc::DtlsTransportState::kFailed); return; } RTC_LOG(LS_INFO) << ToString() << ": DtlsTransport: Started DTLS handshake"; - set_dtls_state(DTLS_TRANSPORT_CONNECTING); + set_dtls_state(webrtc::DtlsTransportState::kConnecting); // Now that the handshake has started, we can process a cached ClientHello // (if one exists). 
if (cached_client_hello_.size()) { @@ -805,22 +794,23 @@ void DtlsTransport::set_writable(bool writable) { SignalWritableState(this); } -void DtlsTransport::set_dtls_state(DtlsTransportState state) { +void DtlsTransport::set_dtls_state(webrtc::DtlsTransportState state) { if (dtls_state_ == state) { return; } if (event_log_) { - event_log_->Log(std::make_unique( - ConvertDtlsTransportState(state))); + event_log_->Log( + std::make_unique(state)); } - RTC_LOG(LS_VERBOSE) << ToString() << ": set_dtls_state from:" << dtls_state_ - << " to " << state; + RTC_LOG(LS_VERBOSE) << ToString() << ": set_dtls_state from:" + << static_cast(dtls_state_) << " to " + << static_cast(state); dtls_state_ = state; - SignalDtlsState(this, state); + SendDtlsState(this, state); } void DtlsTransport::OnDtlsHandshakeError(rtc::SSLHandshakeError error) { - SignalDtlsHandshakeError(error); + SendDtlsHandshakeError(error); } void DtlsTransport::ConfigureHandshakeTimeout() { diff --git a/p2p/base/dtls_transport.h b/p2p/base/dtls_transport.h index 5c8a721d03..0296a742c0 100644 --- a/p2p/base/dtls_transport.h +++ b/p2p/base/dtls_transport.h @@ -16,6 +16,8 @@ #include #include "api/crypto/crypto_options.h" +#include "api/dtls_transport_interface.h" +#include "api/sequence_checker.h" #include "p2p/base/dtls_transport_internal.h" #include "p2p/base/ice_transport_internal.h" #include "rtc_base/buffer.h" @@ -24,9 +26,7 @@ #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/stream.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" -#include "rtc_base/thread_checker.h" namespace rtc { class PacketTransportInternal; @@ -102,14 +102,15 @@ class DtlsTransport : public DtlsTransportInternal { // // |event_log| is an optional RtcEventLog for logging state changes. It should // outlive the DtlsTransport. - explicit DtlsTransport(IceTransportInternal* ice_transport, - const webrtc::CryptoOptions& crypto_options, - webrtc::RtcEventLog* event_log); + DtlsTransport( + IceTransportInternal* ice_transport, + const webrtc::CryptoOptions& crypto_options, + webrtc::RtcEventLog* event_log, + rtc::SSLProtocolVersion max_version = rtc::SSL_PROTOCOL_DTLS_12); ~DtlsTransport() override; - const webrtc::CryptoOptions& crypto_options() const override; - DtlsTransportState dtls_state() const override; + webrtc::DtlsTransportState dtls_state() const override; const std::string& transport_name() const override; int component() const override; @@ -143,8 +144,6 @@ class DtlsTransport : public DtlsTransportInternal { bool GetOption(rtc::Socket::Option opt, int* value) override; - bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override; - // Find out which TLS version was negotiated bool GetSslVersionBytes(int* version) const override; // Find out which DTLS-SRTP cipher was negotiated @@ -192,7 +191,7 @@ class DtlsTransport : public DtlsTransportInternal { const absl::string_view RECEIVING_ABBREV[2] = {"_", "R"}; const absl::string_view WRITABLE_ABBREV[2] = {"_", "W"}; rtc::StringBuilder sb; - sb << "DtlsTransport[" << transport_name_ << "|" << component_ << "|" + sb << "DtlsTransport[" << transport_name() << "|" << component_ << "|" << RECEIVING_ABBREV[receiving()] << WRITABLE_ABBREV[writable()] << "]"; return sb.Release(); } @@ -221,24 +220,22 @@ class DtlsTransport : public DtlsTransportInternal { void set_receiving(bool receiving); void set_writable(bool writable); // Sets the DTLS state, signaling if necessary. 
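Editorial note on the dtls_transport.cc hunk above: the state member now uses the scoped enum webrtc::DtlsTransportState, so the log statement streams it through an explicit static_cast rather than streaming the enum directly, since scoped enumerators do not convert implicitly. A generic illustration follows, using plain std::ostream instead of WebRTC's logging and an invented TransportState enum; it shows the two usual options, an explicit cast or a small operator<<.

#include <iostream>

// Invented scoped enum standing in for webrtc::DtlsTransportState.
enum class TransportState { kNew, kConnecting, kConnected, kClosed, kFailed };

// Optional: a streaming operator gives readable names instead of raw numbers.
std::ostream& operator<<(std::ostream& os, TransportState s) {
  switch (s) {
    case TransportState::kNew:        return os << "new";
    case TransportState::kConnecting: return os << "connecting";
    case TransportState::kConnected:  return os << "connected";
    case TransportState::kClosed:     return os << "closed";
    case TransportState::kFailed:     return os << "failed";
  }
  return os;
}

int main() {
  TransportState state = TransportState::kConnecting;
  // What the hunk above does: log the numeric value via an explicit cast.
  std::cout << "state=" << static_cast<int>(state) << '\n';
  // Alternative: log a symbolic name through the operator<< defined above.
  std::cout << "state=" << state << '\n';
  return 0;
}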
- void set_dtls_state(DtlsTransportState state); + void set_dtls_state(webrtc::DtlsTransportState state); - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; - std::string transport_name_; - int component_; - DtlsTransportState dtls_state_ = DTLS_TRANSPORT_NEW; + const int component_; + webrtc::DtlsTransportState dtls_state_ = webrtc::DtlsTransportState::kNew; // Underlying ice_transport, not owned by this class. - IceTransportInternal* ice_transport_; + IceTransportInternal* const ice_transport_; std::unique_ptr dtls_; // The DTLS stream StreamInterfaceChannel* downward_; // Wrapper for ice_transport_, owned by dtls_. - std::vector srtp_ciphers_; // SRTP ciphers to use with DTLS. + const std::vector srtp_ciphers_; // SRTP ciphers to use with DTLS. bool dtls_active_ = false; rtc::scoped_refptr local_certificate_; absl::optional dtls_role_; - rtc::SSLProtocolVersion ssl_max_version_; - webrtc::CryptoOptions crypto_options_; + const rtc::SSLProtocolVersion ssl_max_version_; rtc::Buffer remote_fingerprint_value_; std::string remote_fingerprint_algorithm_; diff --git a/p2p/base/dtls_transport_factory.h b/p2p/base/dtls_transport_factory.h index 9ad78a7cc2..7c4a24adc8 100644 --- a/p2p/base/dtls_transport_factory.h +++ b/p2p/base/dtls_transport_factory.h @@ -31,7 +31,8 @@ class DtlsTransportFactory { virtual std::unique_ptr CreateDtlsTransport( IceTransportInternal* ice, - const webrtc::CryptoOptions& crypto_options) = 0; + const webrtc::CryptoOptions& crypto_options, + rtc::SSLProtocolVersion max_version) = 0; }; } // namespace cricket diff --git a/p2p/base/dtls_transport_internal.cc b/p2p/base/dtls_transport_internal.cc index dd23b1baa7..6997dbc702 100644 --- a/p2p/base/dtls_transport_internal.cc +++ b/p2p/base/dtls_transport_internal.cc @@ -16,22 +16,4 @@ DtlsTransportInternal::DtlsTransportInternal() = default; DtlsTransportInternal::~DtlsTransportInternal() = default; -webrtc::DtlsTransportState ConvertDtlsTransportState( - cricket::DtlsTransportState cricket_state) { - switch (cricket_state) { - case DtlsTransportState::DTLS_TRANSPORT_NEW: - return webrtc::DtlsTransportState::kNew; - case DtlsTransportState::DTLS_TRANSPORT_CONNECTING: - return webrtc::DtlsTransportState::kConnecting; - case DtlsTransportState::DTLS_TRANSPORT_CONNECTED: - return webrtc::DtlsTransportState::kConnected; - case DtlsTransportState::DTLS_TRANSPORT_CLOSED: - return webrtc::DtlsTransportState::kClosed; - case DtlsTransportState::DTLS_TRANSPORT_FAILED: - return webrtc::DtlsTransportState::kFailed; - } - RTC_NOTREACHED(); - return webrtc::DtlsTransportState::kNew; -} - } // namespace cricket diff --git a/p2p/base/dtls_transport_internal.h b/p2p/base/dtls_transport_internal.h index 4c35d7371f..0b26a7fd7a 100644 --- a/p2p/base/dtls_transport_internal.h +++ b/p2p/base/dtls_transport_internal.h @@ -16,36 +16,22 @@ #include #include +#include +#include "absl/base/attributes.h" #include "api/crypto/crypto_options.h" #include "api/dtls_transport_interface.h" #include "api/scoped_refptr.h" #include "p2p/base/ice_transport_internal.h" #include "p2p/base/packet_transport_internal.h" +#include "rtc_base/callback_list.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_fingerprint.h" #include "rtc_base/ssl_stream_adapter.h" -#include "rtc_base/third_party/sigslot/sigslot.h" namespace cricket { -enum DtlsTransportState { - // Haven't started negotiating. - DTLS_TRANSPORT_NEW = 0, - // Have started negotiating. 
- DTLS_TRANSPORT_CONNECTING, - // Negotiated, and has a secure connection. - DTLS_TRANSPORT_CONNECTED, - // Transport is closed. - DTLS_TRANSPORT_CLOSED, - // Failed due to some error in the handshake process. - DTLS_TRANSPORT_FAILED, -}; - -webrtc::DtlsTransportState ConvertDtlsTransportState( - cricket::DtlsTransportState cricket_state); - enum PacketFlags { PF_NORMAL = 0x00, // A normal packet. PF_SRTP_BYPASS = 0x01, // An encrypted SRTP packet; bypass any additional @@ -62,9 +48,7 @@ class DtlsTransportInternal : public rtc::PacketTransportInternal { public: ~DtlsTransportInternal() override; - virtual const webrtc::CryptoOptions& crypto_options() const = 0; - - virtual DtlsTransportState dtls_state() const = 0; + virtual webrtc::DtlsTransportState dtls_state() const = 0; virtual int component() const = 0; @@ -107,21 +91,55 @@ class DtlsTransportInternal : public rtc::PacketTransportInternal { const uint8_t* digest, size_t digest_len) = 0; - virtual bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) = 0; + ABSL_DEPRECATED("Set the max version via construction.") + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) { + return true; + } // Expose the underneath IceTransport. virtual IceTransportInternal* ice_transport() = 0; - sigslot::signal2 SignalDtlsState; + // F: void(DtlsTransportInternal*, const webrtc::DtlsTransportState) + template + void SubscribeDtlsTransportState(F&& callback) { + dtls_transport_state_callback_list_.AddReceiver(std::forward(callback)); + } + + template + void SubscribeDtlsTransportState(const void* id, F&& callback) { + dtls_transport_state_callback_list_.AddReceiver(id, + std::forward(callback)); + } + // Unsubscribe the subscription with given id. + void UnsubscribeDtlsTransportState(const void* id) { + dtls_transport_state_callback_list_.RemoveReceivers(id); + } + + void SendDtlsState(DtlsTransportInternal* transport, + webrtc::DtlsTransportState state) { + dtls_transport_state_callback_list_.Send(transport, state); + } // Emitted whenever the Dtls handshake failed on some transport channel. 
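Editorial note on the SetSslMaxProtocolVersion() shim above: it keeps old callers compiling as a no-op while the real value now travels through the DtlsTransport constructor (see the max_version parameter with a DTLS 1.2 default in dtls_transport.h). A small, generic sketch of the same deprecate-a-setter-in-favor-of-construction move follows; it uses the standard [[deprecated]] attribute, which is roughly what ABSL_DEPRECATED expands to, and the Transport class is invented for illustration.

#include <cstdio>

// Invented example class; the point is the migration pattern, not WebRTC API.
class Transport {
 public:
  // New style: the maximum protocol version is fixed at construction time.
  explicit Transport(int max_version = 12) : max_version_(max_version) {}

  // Old style, kept only as a no-op so existing callers still build.
  [[deprecated("Set the max version via construction.")]]
  bool SetMaxProtocolVersion(int /*version*/) { return true; }

  int max_version() const { return max_version_; }

 private:
  const int max_version_;  // immutable after construction
};

int main() {
  Transport t(/*max_version=*/10);
  std::printf("max_version=%d\n", t.max_version());
  return 0;
}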
- sigslot::signal1 SignalDtlsHandshakeError; + // F: void(rtc::SSLHandshakeError) + template + void SubscribeDtlsHandshakeError(F&& callback) { + dtls_handshake_error_callback_list_.AddReceiver(std::forward(callback)); + } + + void SendDtlsHandshakeError(rtc::SSLHandshakeError error) { + dtls_handshake_error_callback_list_.Send(error); + } protected: DtlsTransportInternal(); private: RTC_DISALLOW_COPY_AND_ASSIGN(DtlsTransportInternal); + webrtc::CallbackList + dtls_handshake_error_callback_list_; + webrtc::CallbackList + dtls_transport_state_callback_list_; }; } // namespace cricket diff --git a/p2p/base/dtls_transport_unittest.cc b/p2p/base/dtls_transport_unittest.cc index 6822e55be7..f01566d263 100644 --- a/p2p/base/dtls_transport_unittest.cc +++ b/p2p/base/dtls_transport_unittest.cc @@ -15,6 +15,7 @@ #include #include +#include "api/dtls_transport_interface.h" #include "p2p/base/fake_ice_transport.h" #include "p2p/base/packet_transport_internal.h" #include "rtc_base/checks.h" @@ -86,10 +87,9 @@ class DtlsTestClient : public sigslot::has_slots<> { fake_ice_transport_->SignalReadPacket.connect( this, &DtlsTestClient::OnFakeIceTransportReadPacket); - dtls_transport_ = std::make_unique(fake_ice_transport_.get(), - webrtc::CryptoOptions(), - /*event_log=*/nullptr); - dtls_transport_->SetSslMaxProtocolVersion(ssl_max_version_); + dtls_transport_ = std::make_unique( + fake_ice_transport_.get(), webrtc::CryptoOptions(), + /*event_log=*/nullptr, ssl_max_version_); // Note: Certificate may be null here if testing passthrough. dtls_transport_->SetLocalCertificate(certificate_); dtls_transport_->SignalWritableState.connect( @@ -669,18 +669,19 @@ class DtlsEventOrderingTest // Sanity check that the handshake hasn't already finished. EXPECT_FALSE(client1_.dtls_transport()->IsDtlsConnected() || client1_.dtls_transport()->dtls_state() == - DTLS_TRANSPORT_FAILED); + webrtc::DtlsTransportState::kFailed); EXPECT_TRUE_SIMULATED_WAIT( client1_.dtls_transport()->IsDtlsConnected() || client1_.dtls_transport()->dtls_state() == - DTLS_TRANSPORT_FAILED, + webrtc::DtlsTransportState::kFailed, kTimeout, fake_clock_); break; } } - DtlsTransportState expected_final_state = - valid_fingerprint ? DTLS_TRANSPORT_CONNECTED : DTLS_TRANSPORT_FAILED; + webrtc::DtlsTransportState expected_final_state = + valid_fingerprint ? webrtc::DtlsTransportState::kConnected + : webrtc::DtlsTransportState::kFailed; EXPECT_EQ_SIMULATED_WAIT(expected_final_state, client1_.dtls_transport()->dtls_state(), kTimeout, fake_clock_); diff --git a/p2p/base/fake_dtls_transport.h b/p2p/base/fake_dtls_transport.h index 7061ea4b3e..e02755c68f 100644 --- a/p2p/base/fake_dtls_transport.h +++ b/p2p/base/fake_dtls_transport.h @@ -17,6 +17,7 @@ #include #include "api/crypto/crypto_options.h" +#include "api/dtls_transport_interface.h" #include "p2p/base/dtls_transport_internal.h" #include "p2p/base/fake_ice_transport.h" #include "rtc_base/fake_ssl_identity.h" @@ -55,9 +56,15 @@ class FakeDtlsTransport : public DtlsTransportInternal { // If this constructor is called, a new fake ICE transport will be created, // and this FakeDtlsTransport will take the ownership. 
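Editorial note on the dtls_transport_internal.h change above: state and handshake-error notifications move from sigslot signals to webrtc::CallbackList, so consumers now register lambdas via SubscribeDtlsTransportState()/SubscribeDtlsHandshakeError() and the transport fires them with SendDtlsState()/SendDtlsHandshakeError(). The standalone sketch below shows the same subscribe/send/unsubscribe flow with a tiny stand-in; MiniCallbackList is not the real webrtc::CallbackList, it only mimics the AddReceiver/RemoveReceivers/Send surface used above, and the callback signature is simplified to a single argument.

#include <algorithm>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// Minimal stand-in mimicking the AddReceiver/RemoveReceivers/Send surface.
template <typename... Args>
class MiniCallbackList {
 public:
  template <typename F>
  void AddReceiver(const void* id, F&& callback) {
    receivers_.emplace_back(id, std::forward<F>(callback));
  }
  void RemoveReceivers(const void* id) {
    receivers_.erase(
        std::remove_if(receivers_.begin(), receivers_.end(),
                       [id](const Receiver& r) { return r.first == id; }),
        receivers_.end());
  }
  void Send(Args... args) {
    for (auto& r : receivers_) r.second(args...);
  }

 private:
  using Receiver = std::pair<const void*, std::function<void(Args...)>>;
  std::vector<Receiver> receivers_;
};

enum class DtlsState { kNew, kConnecting, kConnected };

int main() {
  MiniCallbackList<DtlsState> state_changes;
  int tag = 0;  // its address serves as the subscription id

  // Analogous to SubscribeDtlsTransportState(&tag, callback).
  state_changes.AddReceiver(&tag, [](DtlsState s) {
    std::printf("state -> %d\n", static_cast<int>(s));
  });

  state_changes.Send(DtlsState::kConnecting);  // like SendDtlsState(...)
  state_changes.RemoveReceivers(&tag);         // like UnsubscribeDtlsTransportState(&tag)
  state_changes.Send(DtlsState::kConnected);   // nothing printed: no receivers left
  return 0;
}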
- explicit FakeDtlsTransport(const std::string& name, int component) + FakeDtlsTransport(const std::string& name, int component) : FakeDtlsTransport(std::make_unique(name, component)) { } + FakeDtlsTransport(const std::string& name, + int component, + rtc::Thread* network_thread) + : FakeDtlsTransport(std::make_unique(name, + component, + network_thread)) {} ~FakeDtlsTransport() override { if (dest_ && dest_->dest_ == this) { @@ -83,9 +90,9 @@ class FakeDtlsTransport : public DtlsTransportInternal { ice_transport_->SetReceiving(receiving); set_receiving(receiving); } - void SetDtlsState(DtlsTransportState state) { + void SetDtlsState(webrtc::DtlsTransportState state) { dtls_state_ = state; - SignalDtlsState(this, dtls_state_); + SendDtlsState(this, dtls_state_); } // Simulates the two DTLS transports connecting to each other. @@ -115,7 +122,7 @@ class FakeDtlsTransport : public DtlsTransportInternal { if (!dtls_role_) { dtls_role_ = std::move(rtc::SSL_CLIENT); } - SetDtlsState(DTLS_TRANSPORT_CONNECTED); + SetDtlsState(webrtc::DtlsTransportState::kConnected); ice_transport_->SetDestination( static_cast(dest->ice_transport()), asymmetric); } else { @@ -127,7 +134,7 @@ class FakeDtlsTransport : public DtlsTransportInternal { } // Fake DtlsTransportInternal implementation. - DtlsTransportState dtls_state() const override { return dtls_state_; } + webrtc::DtlsTransportState dtls_state() const override { return dtls_state_; } const std::string& transport_name() const override { return transport_name_; } int component() const override { return component_; } const rtc::SSLFingerprint& dtls_fingerprint() const { @@ -140,9 +147,6 @@ class FakeDtlsTransport : public DtlsTransportInternal { rtc::SSLFingerprint(alg, rtc::MakeArrayView(digest, digest_len)); return true; } - bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { - return true; - } bool SetDtlsRole(rtc::SSLRole role) override { dtls_role_ = std::move(role); return true; @@ -154,12 +158,6 @@ class FakeDtlsTransport : public DtlsTransportInternal { *role = *dtls_role_; return true; } - const webrtc::CryptoOptions& crypto_options() const override { - return crypto_options_; - } - void SetCryptoOptions(const webrtc::CryptoOptions& crypto_options) { - crypto_options_ = crypto_options; - } bool SetLocalCertificate( const rtc::scoped_refptr& certificate) override { do_dtls_ = true; @@ -297,9 +295,8 @@ class FakeDtlsTransport : public DtlsTransportInternal { absl::optional dtls_role_; int crypto_suite_ = rtc::SRTP_AES128_CM_SHA1_80; absl::optional ssl_cipher_suite_; - webrtc::CryptoOptions crypto_options_; - DtlsTransportState dtls_state_ = DTLS_TRANSPORT_NEW; + webrtc::DtlsTransportState dtls_state_ = webrtc::DtlsTransportState::kNew; bool receiving_ = false; bool writable_ = false; diff --git a/p2p/base/fake_ice_transport.h b/p2p/base/fake_ice_transport.h index edc5730440..f8be8a9835 100644 --- a/p2p/base/fake_ice_transport.h +++ b/p2p/base/fake_ice_transport.h @@ -20,11 +20,15 @@ #include "absl/types/optional.h" #include "api/ice_transport_interface.h" #include "p2p/base/ice_transport_internal.h" -#include "rtc_base/async_invoker.h" #include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" namespace cricket { +// All methods must be called on the network thread (which is either the thread +// calling the constructor, or the separate thread explicitly passed to the +// constructor). 
class FakeIceTransport : public IceTransportInternal { public: explicit FakeIceTransport(const std::string& name, @@ -34,6 +38,8 @@ class FakeIceTransport : public IceTransportInternal { component_(component), network_thread_(network_thread ? network_thread : rtc::Thread::Current()) {} + // Must be called either on the network thread, or after the network thread + // has been shut down. ~FakeIceTransport() override { if (dest_ && dest_->dest_ == this) { dest_->dest_ = nullptr; @@ -42,18 +48,31 @@ class FakeIceTransport : public IceTransportInternal { // If async, will send packets by "Post"-ing to message queue instead of // synchronously "Send"-ing. - void SetAsync(bool async) { async_ = async; } - void SetAsyncDelay(int delay_ms) { async_delay_ms_ = delay_ms; } + void SetAsync(bool async) { + RTC_DCHECK_RUN_ON(network_thread_); + async_ = async; + } + void SetAsyncDelay(int delay_ms) { + RTC_DCHECK_RUN_ON(network_thread_); + async_delay_ms_ = delay_ms; + } // SetWritable, SetReceiving and SetDestination are the main methods that can // be used for testing, to simulate connectivity or lack thereof. - void SetWritable(bool writable) { set_writable(writable); } - void SetReceiving(bool receiving) { set_receiving(receiving); } + void SetWritable(bool writable) { + RTC_DCHECK_RUN_ON(network_thread_); + set_writable(writable); + } + void SetReceiving(bool receiving) { + RTC_DCHECK_RUN_ON(network_thread_); + set_receiving(receiving); + } // Simulates the two transports connecting to each other. // If |asymmetric| is true this method only affects this FakeIceTransport. // If false, it affects |dest| as well. void SetDestination(FakeIceTransport* dest, bool asymmetric = false) { + RTC_DCHECK_RUN_ON(network_thread_); if (dest == dest_) { return; } @@ -75,12 +94,14 @@ class FakeIceTransport : public IceTransportInternal { void SetTransportState(webrtc::IceTransportState state, IceTransportState legacy_state) { + RTC_DCHECK_RUN_ON(network_thread_); transport_state_ = state; legacy_transport_state_ = legacy_state; SignalIceTransportStateChanged(this); } void SetConnectionCount(size_t connection_count) { + RTC_DCHECK_RUN_ON(network_thread_); size_t old_connection_count = connection_count_; connection_count_ = connection_count; if (connection_count) { @@ -94,6 +115,7 @@ class FakeIceTransport : public IceTransportInternal { } void SetCandidatesGatheringComplete() { + RTC_DCHECK_RUN_ON(network_thread_); if (gathering_state_ != kIceGatheringComplete) { gathering_state_ = kIceGatheringComplete; SignalGatheringState(this); @@ -102,16 +124,29 @@ class FakeIceTransport : public IceTransportInternal { // Convenience functions for accessing ICE config and other things. int receiving_timeout() const { + RTC_DCHECK_RUN_ON(network_thread_); return ice_config_.receiving_timeout_or_default(); } - bool gather_continually() const { return ice_config_.gather_continually(); } - const Candidates& remote_candidates() const { return remote_candidates_; } + bool gather_continually() const { + RTC_DCHECK_RUN_ON(network_thread_); + return ice_config_.gather_continually(); + } + const Candidates& remote_candidates() const { + RTC_DCHECK_RUN_ON(network_thread_); + return remote_candidates_; + } // Fake IceTransportInternal implementation. 
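Editorial note on the FakeIceTransport changes above and below: RTC_DCHECK_RUN_ON(network_thread_) is added to each accessor and RTC_GUARDED_BY(network_thread_) to the members, turning the "network thread only" contract in the class comment into something the debug checks and thread annotations can enforce. The sketch below illustrates the same idea with plain standard C++ (std::thread::id plus assert) rather than WebRTC's SequenceChecker machinery; ThreadAffine is an invented example class.

#include <cassert>
#include <thread>

// Invented example: every accessor asserts it runs on the thread that owns
// the object, the same contract the RTC_DCHECK_RUN_ON calls express above.
class ThreadAffine {
 public:
  ThreadAffine() : owner_(std::this_thread::get_id()) {}

  void set_value(int v) {
    assert(std::this_thread::get_id() == owner_);  // like RTC_DCHECK_RUN_ON
    value_ = v;                                    // conceptually RTC_GUARDED_BY
  }
  int value() const {
    assert(std::this_thread::get_id() == owner_);
    return value_;
  }

 private:
  const std::thread::id owner_;
  int value_ = 0;
};

int main() {
  ThreadAffine state;   // "owned" by the constructing thread
  state.set_value(42);  // fine: same thread
  return state.value() == 42 ? 0 : 1;
}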
const std::string& transport_name() const override { return name_; } int component() const override { return component_; } - uint64_t IceTiebreaker() const { return tiebreaker_; } - IceMode remote_ice_mode() const { return remote_ice_mode_; } + uint64_t IceTiebreaker() const { + RTC_DCHECK_RUN_ON(network_thread_); + return tiebreaker_; + } + IceMode remote_ice_mode() const { + RTC_DCHECK_RUN_ON(network_thread_); + return remote_ice_mode_; + } const std::string& ice_ufrag() const { return ice_parameters_.ufrag; } const std::string& ice_pwd() const { return ice_parameters_.pwd; } const std::string& remote_ice_ufrag() const { @@ -126,6 +161,7 @@ class FakeIceTransport : public IceTransportInternal { } IceTransportState GetState() const override { + RTC_DCHECK_RUN_ON(network_thread_); if (legacy_transport_state_) { return *legacy_transport_state_; } @@ -143,6 +179,7 @@ class FakeIceTransport : public IceTransportInternal { } webrtc::IceTransportState GetIceTransportState() const override { + RTC_DCHECK_RUN_ON(network_thread_); if (transport_state_) { return *transport_state_; } @@ -159,21 +196,34 @@ class FakeIceTransport : public IceTransportInternal { return webrtc::IceTransportState::kConnected; } - void SetIceRole(IceRole role) override { role_ = role; } - IceRole GetIceRole() const override { return role_; } + void SetIceRole(IceRole role) override { + RTC_DCHECK_RUN_ON(network_thread_); + role_ = role; + } + IceRole GetIceRole() const override { + RTC_DCHECK_RUN_ON(network_thread_); + return role_; + } void SetIceTiebreaker(uint64_t tiebreaker) override { + RTC_DCHECK_RUN_ON(network_thread_); tiebreaker_ = tiebreaker; } void SetIceParameters(const IceParameters& ice_params) override { + RTC_DCHECK_RUN_ON(network_thread_); ice_parameters_ = ice_params; } void SetRemoteIceParameters(const IceParameters& params) override { + RTC_DCHECK_RUN_ON(network_thread_); remote_ice_parameters_ = params; } - void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; } + void SetRemoteIceMode(IceMode mode) override { + RTC_DCHECK_RUN_ON(network_thread_); + remote_ice_mode_ = mode; + } void MaybeStartGathering() override { + RTC_DCHECK_RUN_ON(network_thread_); if (gathering_state_ == kIceGatheringNew) { gathering_state_ = kIceGatheringGathering; SignalGatheringState(this); @@ -181,15 +231,21 @@ class FakeIceTransport : public IceTransportInternal { } IceGatheringState gathering_state() const override { + RTC_DCHECK_RUN_ON(network_thread_); return gathering_state_; } - void SetIceConfig(const IceConfig& config) override { ice_config_ = config; } + void SetIceConfig(const IceConfig& config) override { + RTC_DCHECK_RUN_ON(network_thread_); + ice_config_ = config; + } void AddRemoteCandidate(const Candidate& candidate) override { + RTC_DCHECK_RUN_ON(network_thread_); remote_candidates_.push_back(candidate); } void RemoveRemoteCandidate(const Candidate& candidate) override { + RTC_DCHECK_RUN_ON(network_thread_); auto it = absl::c_find(remote_candidates_, candidate); if (it == remote_candidates_.end()) { RTC_LOG(LS_INFO) << "Trying to remove a candidate which doesn't exist."; @@ -199,7 +255,10 @@ class FakeIceTransport : public IceTransportInternal { remote_candidates_.erase(it); } - void RemoveAllRemoteCandidates() override { remote_candidates_.clear(); } + void RemoveAllRemoteCandidates() override { + RTC_DCHECK_RUN_ON(network_thread_); + remote_candidates_.clear(); + } bool GetStats(IceTransportStats* ice_transport_stats) override { CandidateStats candidate_stats; @@ -220,17 +279,25 @@ class 
FakeIceTransport : public IceTransportInternal { } // Fake PacketTransportInternal implementation. - bool writable() const override { return writable_; } - bool receiving() const override { return receiving_; } + bool writable() const override { + RTC_DCHECK_RUN_ON(network_thread_); + return writable_; + } + bool receiving() const override { + RTC_DCHECK_RUN_ON(network_thread_); + return receiving_; + } // If combine is enabled, every two consecutive packets to be sent with // "SendPacket" will be combined into one outgoing packet. void combine_outgoing_packets(bool combine) { + RTC_DCHECK_RUN_ON(network_thread_); combine_outgoing_packets_ = combine; } int SendPacket(const char* data, size_t len, const rtc::PacketOptions& options, int flags) override { + RTC_DCHECK_RUN_ON(network_thread_); if (!dest_) { return -1; } @@ -239,9 +306,12 @@ class FakeIceTransport : public IceTransportInternal { if (!combine_outgoing_packets_ || send_packet_.size() > len) { rtc::CopyOnWriteBuffer packet(std::move(send_packet_)); if (async_) { - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, rtc::Thread::Current(), - rtc::Bind(&FakeIceTransport::SendPacketInternal, this, packet), + network_thread_->PostDelayedTask( + ToQueuedTask(task_safety_.flag(), + [this, packet] { + RTC_DCHECK_RUN_ON(network_thread_); + FakeIceTransport::SendPacketInternal(packet); + }), async_delay_ms_); } else { SendPacketInternal(packet); @@ -253,10 +323,12 @@ class FakeIceTransport : public IceTransportInternal { } int SetOption(rtc::Socket::Option opt, int value) override { + RTC_DCHECK_RUN_ON(network_thread_); socket_options_[opt] = value; return true; } bool GetOption(rtc::Socket::Option opt, int* value) override { + RTC_DCHECK_RUN_ON(network_thread_); auto it = socket_options_.find(opt); if (it != socket_options_.end()) { *value = it->second; @@ -268,19 +340,27 @@ class FakeIceTransport : public IceTransportInternal { int GetError() override { return 0; } - rtc::CopyOnWriteBuffer last_sent_packet() { return last_sent_packet_; } + rtc::CopyOnWriteBuffer last_sent_packet() { + RTC_DCHECK_RUN_ON(network_thread_); + return last_sent_packet_; + } absl::optional network_route() const override { + RTC_DCHECK_RUN_ON(network_thread_); return network_route_; } void SetNetworkRoute(absl::optional network_route) { + RTC_DCHECK_RUN_ON(network_thread_); network_route_ = network_route; - network_thread_->Invoke( - RTC_FROM_HERE, [this] { SignalNetworkRouteChanged(network_route_); }); + network_thread_->Invoke(RTC_FROM_HERE, [this] { + RTC_DCHECK_RUN_ON(network_thread_); + SignalNetworkRouteChanged(network_route_); + }); } private: - void set_writable(bool writable) { + void set_writable(bool writable) + RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) { if (writable_ == writable) { return; } @@ -292,7 +372,8 @@ class FakeIceTransport : public IceTransportInternal { SignalWritableState(this); } - void set_receiving(bool receiving) { + void set_receiving(bool receiving) + RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) { if (receiving_ == receiving) { return; } @@ -300,7 +381,8 @@ class FakeIceTransport : public IceTransportInternal { SignalReceivingState(this); } - void SendPacketInternal(const rtc::CopyOnWriteBuffer& packet) { + void SendPacketInternal(const rtc::CopyOnWriteBuffer& packet) + RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) { if (dest_) { last_sent_packet_ = packet; dest_->SignalReadPacket(dest_, packet.data(), packet.size(), @@ -308,32 +390,37 @@ class FakeIceTransport : public IceTransportInternal { } } - rtc::AsyncInvoker invoker_; - 
std::string name_; - int component_; - FakeIceTransport* dest_ = nullptr; - bool async_ = false; - int async_delay_ms_ = 0; - Candidates remote_candidates_; - IceConfig ice_config_; - IceRole role_ = ICEROLE_UNKNOWN; - uint64_t tiebreaker_ = 0; - IceParameters ice_parameters_; - IceParameters remote_ice_parameters_; - IceMode remote_ice_mode_ = ICEMODE_FULL; - size_t connection_count_ = 0; - absl::optional transport_state_; - absl::optional legacy_transport_state_; - IceGatheringState gathering_state_ = kIceGatheringNew; - bool had_connection_ = false; - bool writable_ = false; - bool receiving_ = false; - bool combine_outgoing_packets_ = false; - rtc::CopyOnWriteBuffer send_packet_; - absl::optional network_route_; - std::map socket_options_; - rtc::CopyOnWriteBuffer last_sent_packet_; + const std::string name_; + const int component_; + FakeIceTransport* dest_ RTC_GUARDED_BY(network_thread_) = nullptr; + bool async_ RTC_GUARDED_BY(network_thread_) = false; + int async_delay_ms_ RTC_GUARDED_BY(network_thread_) = 0; + Candidates remote_candidates_ RTC_GUARDED_BY(network_thread_); + IceConfig ice_config_ RTC_GUARDED_BY(network_thread_); + IceRole role_ RTC_GUARDED_BY(network_thread_) = ICEROLE_UNKNOWN; + uint64_t tiebreaker_ RTC_GUARDED_BY(network_thread_) = 0; + IceParameters ice_parameters_ RTC_GUARDED_BY(network_thread_); + IceParameters remote_ice_parameters_ RTC_GUARDED_BY(network_thread_); + IceMode remote_ice_mode_ RTC_GUARDED_BY(network_thread_) = ICEMODE_FULL; + size_t connection_count_ RTC_GUARDED_BY(network_thread_) = 0; + absl::optional transport_state_ + RTC_GUARDED_BY(network_thread_); + absl::optional legacy_transport_state_ + RTC_GUARDED_BY(network_thread_); + IceGatheringState gathering_state_ RTC_GUARDED_BY(network_thread_) = + kIceGatheringNew; + bool had_connection_ RTC_GUARDED_BY(network_thread_) = false; + bool writable_ RTC_GUARDED_BY(network_thread_) = false; + bool receiving_ RTC_GUARDED_BY(network_thread_) = false; + bool combine_outgoing_packets_ RTC_GUARDED_BY(network_thread_) = false; + rtc::CopyOnWriteBuffer send_packet_ RTC_GUARDED_BY(network_thread_); + absl::optional network_route_ + RTC_GUARDED_BY(network_thread_); + std::map socket_options_ + RTC_GUARDED_BY(network_thread_); + rtc::CopyOnWriteBuffer last_sent_packet_ RTC_GUARDED_BY(network_thread_); rtc::Thread* const network_thread_; + webrtc::ScopedTaskSafetyDetached task_safety_; }; class FakeIceTransportWrapper : public webrtc::IceTransportInterface { diff --git a/p2p/base/fake_packet_transport.h b/p2p/base/fake_packet_transport.h index a5e2abb7d6..b69c9b5208 100644 --- a/p2p/base/fake_packet_transport.h +++ b/p2p/base/fake_packet_transport.h @@ -15,7 +15,6 @@ #include #include "p2p/base/packet_transport_internal.h" -#include "rtc_base/async_invoker.h" #include "rtc_base/copy_on_write_buffer.h" namespace rtc { @@ -31,11 +30,6 @@ class FakePacketTransport : public PacketTransportInternal { } } - // If async, will send packets by "Post"-ing to message queue instead of - // synchronously "Send"-ing. - void SetAsync(bool async) { async_ = async; } - void SetAsyncDelay(int delay_ms) { async_delay_ms_ = delay_ms; } - // SetWritable, SetReceiving and SetDestination are the main methods that can // be used for testing, to simulate connectivity or lack thereof. 
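Editorial note on the fake_ice_transport.h hunk above: the rtc::AsyncInvoker member is replaced by posting a delayed task wrapped in ToQueuedTask(task_safety_.flag(), ...), so a packet-send closure still sitting in the queue simply does nothing if the transport has been destroyed in the meantime. Below is a conceptual, self-contained analogue of that safety-flag idea using a shared bool and a plain task queue; it is not the real PendingTaskSafetyFlag/ToQueuedTask implementation.

#include <cstdio>
#include <functional>
#include <memory>
#include <queue>
#include <utility>

// Invented example transport: tasks it posts check a shared "alive" flag
// before touching the object, mirroring ToQueuedTask(task_safety_.flag(), ...).
class FlaggedTransport {
 public:
  ~FlaggedTransport() { *alive_ = false; }

  // Wraps |body| so it becomes a no-op once the transport is gone.
  std::function<void()> MakeSafeTask(std::function<void()> body) {
    auto alive = alive_;
    return [alive, body = std::move(body)] {
      if (*alive) body();
    };
  }

  void SendPacketInternal() { std::printf("packet sent\n"); }

 private:
  std::shared_ptr<bool> alive_ = std::make_shared<bool>(true);
};

int main() {
  std::queue<std::function<void()>> delayed_tasks;  // stand-in for the thread's queue
  {
    FlaggedTransport transport;
    delayed_tasks.push(transport.MakeSafeTask(
        [&transport] { transport.SendPacketInternal(); }));
  }  // transport destroyed before the delayed task runs
  while (!delayed_tasks.empty()) {
    delayed_tasks.front()();  // no-op: the flag was cleared in the destructor
    delayed_tasks.pop();
  }
  return 0;
}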
void SetWritable(bool writable) { set_writable(writable); } @@ -70,14 +64,8 @@ class FakePacketTransport : public PacketTransportInternal { return -1; } CopyOnWriteBuffer packet(data, len); - if (async_) { - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, Thread::Current(), - Bind(&FakePacketTransport::SendPacketInternal, this, packet), - async_delay_ms_); - } else { - SendPacketInternal(packet); - } + SendPacketInternal(packet); + SentPacket sent_packet(options.packet_id, TimeMillis()); SignalSentPacket(this, sent_packet); return static_cast(len); @@ -139,11 +127,8 @@ class FakePacketTransport : public PacketTransportInternal { } CopyOnWriteBuffer last_sent_packet_; - AsyncInvoker invoker_; std::string transport_name_; FakePacketTransport* dest_ = nullptr; - bool async_ = false; - int async_delay_ms_ = 0; bool writable_ = false; bool receiving_ = false; diff --git a/p2p/base/fake_port_allocator.h b/p2p/base/fake_port_allocator.h index 266bb7956b..efe9a53a16 100644 --- a/p2p/base/fake_port_allocator.h +++ b/p2p/base/fake_port_allocator.h @@ -18,7 +18,6 @@ #include "p2p/base/basic_packet_socket_factory.h" #include "p2p/base/port_allocator.h" #include "p2p/base/udp_port.h" -#include "rtc_base/bind.h" #include "rtc_base/net_helpers.h" #include "rtc_base/thread.h" @@ -119,8 +118,8 @@ class FakePortAllocatorSession : public PortAllocatorSession { username(), password(), std::string(), false)); RTC_DCHECK(port_); - port_->SignalDestroyed.connect( - this, &FakePortAllocatorSession::OnPortDestroyed); + port_->SubscribePortDestroyed( + [this](PortInterface* port) { OnPortDestroyed(port); }); AddPort(port_.get()); } ++port_config_count_; @@ -222,9 +221,7 @@ class FakePortAllocator : public cricket::PortAllocator { Initialize(); return; } - network_thread_->Invoke(RTC_FROM_HERE, - rtc::Bind(&PortAllocator::Initialize, - static_cast(this))); + network_thread_->Invoke(RTC_FROM_HERE, [this] { Initialize(); }); } void SetNetworkIgnoreMask(int network_ignore_mask) override {} @@ -241,10 +238,19 @@ class FakePortAllocator : public cricket::PortAllocator { bool initialized() const { return initialized_; } + // For testing: Manipulate MdnsObfuscationEnabled() + bool MdnsObfuscationEnabled() const override { + return mdns_obfuscation_enabled_; + } + void SetMdnsObfuscationEnabledForTesting(bool enabled) { + mdns_obfuscation_enabled_ = enabled; + } + private: rtc::Thread* network_thread_; rtc::PacketSocketFactory* factory_; std::unique_ptr owned_factory_; + bool mdns_obfuscation_enabled_ = false; }; } // namespace cricket diff --git a/p2p/base/ice_controller_interface.h b/p2p/base/ice_controller_interface.h index d5dc29e782..0e77d1dd00 100644 --- a/p2p/base/ice_controller_interface.h +++ b/p2p/base/ice_controller_interface.h @@ -87,7 +87,9 @@ class IceControllerInterface { // This represents the result of a call to SelectConnectionToPing. struct PingResult { PingResult(const Connection* conn, int _recheck_delay_ms) - : connection(conn), recheck_delay_ms(_recheck_delay_ms) {} + : connection(conn ? absl::optional(conn) + : absl::nullopt), + recheck_delay_ms(_recheck_delay_ms) {} // Connection that we should (optionally) ping. 
const absl::optional connection; diff --git a/p2p/base/ice_transport_internal.cc b/p2p/base/ice_transport_internal.cc index 1d5b6e7403..104a95b5af 100644 --- a/p2p/base/ice_transport_internal.cc +++ b/p2p/base/ice_transport_internal.cc @@ -14,6 +14,50 @@ namespace cricket { +using webrtc::RTCError; +using webrtc::RTCErrorType; + +RTCError VerifyCandidate(const Candidate& cand) { + // No address zero. + if (cand.address().IsNil() || cand.address().IsAnyIP()) { + return RTCError(RTCErrorType::INVALID_PARAMETER, + "candidate has address of zero"); + } + + // Disallow all ports below 1024, except for 80 and 443 on public addresses. + int port = cand.address().port(); + if (cand.protocol() == cricket::TCP_PROTOCOL_NAME && + (cand.tcptype() == cricket::TCPTYPE_ACTIVE_STR || port == 0)) { + // Expected for active-only candidates per + // http://tools.ietf.org/html/rfc6544#section-4.5 so no error. + // Libjingle clients emit port 0, in "active" mode. + return RTCError::OK(); + } + if (port < 1024) { + if ((port != 80) && (port != 443)) { + return RTCError(RTCErrorType::INVALID_PARAMETER, + "candidate has port below 1024, but not 80 or 443"); + } + + if (cand.address().IsPrivateIP()) { + return RTCError( + RTCErrorType::INVALID_PARAMETER, + "candidate has port of 80 or 443 with private IP address"); + } + } + + return RTCError::OK(); +} + +RTCError VerifyCandidates(const Candidates& candidates) { + for (const Candidate& candidate : candidates) { + RTCError error = VerifyCandidate(candidate); + if (!error.ok()) + return error; + } + return RTCError::OK(); +} + IceConfig::IceConfig() = default; IceConfig::IceConfig(int receiving_timeout_ms, diff --git a/p2p/base/ice_transport_internal.h b/p2p/base/ice_transport_internal.h index b735a1a742..b3eb2dc9e2 100644 --- a/p2p/base/ice_transport_internal.h +++ b/p2p/base/ice_transport_internal.h @@ -18,6 +18,7 @@ #include "absl/types/optional.h" #include "api/candidate.h" +#include "api/rtc_error.h" #include "api/transport/enums.h" #include "p2p/base/connection.h" #include "p2p/base/packet_transport_internal.h" @@ -74,6 +75,17 @@ enum class NominationMode { // The details are described in P2PTransportChannel. }; +// Utility function that checks if various required Candidate fields are filled in +// and contain valid values. If conditions are not met, an RTCError with the +// appropriate error number and description is returned. If the configuration +// is valid, RTCError::OK() is returned. +webrtc::RTCError VerifyCandidate(const Candidate& cand); + +// Runs through a list of cricket::Candidate instances and calls VerifyCandidate +// for each one, stopping on the first error encountered and returning that error +// value. On success returns RTCError::OK(). +webrtc::RTCError VerifyCandidates(const Candidates& candidates); + // Information about ICE configuration. // TODO(deadbeef): Use absl::optional to represent unset values, instead of // -1. diff --git a/p2p/base/mdns_message.cc b/p2p/base/mdns_message.cc deleted file mode 100644 index 1aa996c4a8..0000000000 --- a/p2p/base/mdns_message.cc +++ /dev/null @@ -1,396 +0,0 @@ -/* - * Copyright 2018 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree.
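Editorial note, to make the port rule in VerifyCandidate() above concrete: for candidates that are not TCP "active" (and have a non-zero port), anything at or above 1024 is accepted, 80 and 443 are accepted only on non-private addresses, and every other low port is rejected. The standalone sketch below restates just that rule with a few asserts; PortAllowed is an illustrative distillation, not the WebRTC function.

#include <cassert>

// Illustrative distillation of the low-port rule from VerifyCandidate().
bool PortAllowed(int port, bool is_private_ip) {
  if (port >= 1024) return true;                // normal case
  if (port != 80 && port != 443) return false;  // other low ports rejected
  return !is_private_ip;                        // 80/443 only on public addresses
}

int main() {
  assert(PortAllowed(50000, /*is_private_ip=*/false));  // ephemeral port
  assert(!PortAllowed(22, /*is_private_ip=*/false));    // below 1024, not 80/443
  assert(PortAllowed(443, /*is_private_ip=*/false));    // public HTTPS port
  assert(!PortAllowed(443, /*is_private_ip=*/true));    // 443 on a private IP
  return 0;
}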
- */ - -#include "p2p/base/mdns_message.h" - -#include "rtc_base/logging.h" -#include "rtc_base/net_helpers.h" -#include "rtc_base/string_encode.h" - -namespace webrtc { - -namespace { -// RFC 1035, Section 4.1.1. -// -// QR bit. -constexpr uint16_t kMdnsFlagMaskQueryOrResponse = 0x8000; -// AA bit. -constexpr uint16_t kMdnsFlagMaskAuthoritative = 0x0400; -// RFC 1035, Section 4.1.2, QCLASS and RFC 6762, Section 18.12, repurposing of -// top bit of QCLASS as the unicast response bit. -constexpr uint16_t kMdnsQClassMaskUnicastResponse = 0x8000; -constexpr size_t kMdnsHeaderSizeBytes = 12; - -bool ReadDomainName(MessageBufferReader* buf, std::string* name) { - size_t name_start_pos = buf->CurrentOffset(); - uint8_t label_length; - if (!buf->ReadUInt8(&label_length)) { - return false; - } - // RFC 1035, Section 4.1.4. - // - // If the first two bits of the length octet are ones, the name is compressed - // and the rest six bits with the next octet denotes its position in the - // message by the offset from the start of the message. - auto is_pointer = [](uint8_t octet) { - return (octet & 0x80) && (octet & 0x40); - }; - while (label_length && !is_pointer(label_length)) { - // RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less. - if (label_length > 63) { - return false; - } - std::string label; - if (!buf->ReadString(&label, label_length)) { - return false; - } - (*name) += label + "."; - if (!buf->ReadUInt8(&label_length)) { - return false; - } - } - if (is_pointer(label_length)) { - uint8_t next_octet; - if (!buf->ReadUInt8(&next_octet)) { - return false; - } - size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet; - // A legitimate pointer only refers to a prior occurrence of the same name, - // and we should only move strictly backward to a prior name field after the - // header. 
- if (pos_jump_to >= name_start_pos || pos_jump_to < kMdnsHeaderSizeBytes) { - return false; - } - MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength()); - if (!new_buf.Consume(pos_jump_to)) { - return false; - } - return ReadDomainName(&new_buf, name); - } - return true; -} - -void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) { - std::vector labels; - rtc::tokenize(name, '.', &labels); - for (const auto& label : labels) { - buf->WriteUInt8(label.length()); - buf->WriteString(label); - } - buf->WriteUInt8(0); -} - -} // namespace - -void MdnsHeader::SetQueryOrResponse(bool is_query) { - if (is_query) { - flags &= ~kMdnsFlagMaskQueryOrResponse; - } else { - flags |= kMdnsFlagMaskQueryOrResponse; - } -} - -void MdnsHeader::SetAuthoritative(bool is_authoritative) { - if (is_authoritative) { - flags |= kMdnsFlagMaskAuthoritative; - } else { - flags &= ~kMdnsFlagMaskAuthoritative; - } -} - -bool MdnsHeader::IsAuthoritative() const { - return flags & kMdnsFlagMaskAuthoritative; -} - -bool MdnsHeader::Read(MessageBufferReader* buf) { - if (!buf->ReadUInt16(&id) || !buf->ReadUInt16(&flags) || - !buf->ReadUInt16(&qdcount) || !buf->ReadUInt16(&ancount) || - !buf->ReadUInt16(&nscount) || !buf->ReadUInt16(&arcount)) { - RTC_LOG(LS_ERROR) << "Invalid mDNS header."; - return false; - } - return true; -} - -void MdnsHeader::Write(rtc::ByteBufferWriter* buf) const { - buf->WriteUInt16(id); - buf->WriteUInt16(flags); - buf->WriteUInt16(qdcount); - buf->WriteUInt16(ancount); - buf->WriteUInt16(nscount); - buf->WriteUInt16(arcount); -} - -bool MdnsHeader::IsQuery() const { - return !(flags & kMdnsFlagMaskQueryOrResponse); -} - -MdnsSectionEntry::MdnsSectionEntry() = default; -MdnsSectionEntry::~MdnsSectionEntry() = default; -MdnsSectionEntry::MdnsSectionEntry(const MdnsSectionEntry& other) = default; - -void MdnsSectionEntry::SetType(SectionEntryType type) { - switch (type) { - case SectionEntryType::kA: - type_ = 1; - return; - case SectionEntryType::kAAAA: - type_ = 28; - return; - default: - RTC_NOTREACHED(); - } -} - -SectionEntryType MdnsSectionEntry::GetType() const { - switch (type_) { - case 1: - return SectionEntryType::kA; - case 28: - return SectionEntryType::kAAAA; - default: - return SectionEntryType::kUnsupported; - } -} - -void MdnsSectionEntry::SetClass(SectionEntryClass cls) { - switch (cls) { - case SectionEntryClass::kIN: - class_ = 1; - return; - default: - RTC_NOTREACHED(); - } -} - -SectionEntryClass MdnsSectionEntry::GetClass() const { - switch (class_) { - case 1: - return SectionEntryClass::kIN; - default: - return SectionEntryClass::kUnsupported; - } -} - -MdnsQuestion::MdnsQuestion() = default; -MdnsQuestion::MdnsQuestion(const MdnsQuestion& other) = default; -MdnsQuestion::~MdnsQuestion() = default; - -bool MdnsQuestion::Read(MessageBufferReader* buf) { - if (!ReadDomainName(buf, &name_)) { - RTC_LOG(LS_ERROR) << "Invalid name."; - return false; - } - if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_)) { - RTC_LOG(LS_ERROR) << "Invalid type and class."; - return false; - } - return true; -} - -bool MdnsQuestion::Write(rtc::ByteBufferWriter* buf) const { - WriteDomainName(buf, name_); - buf->WriteUInt16(type_); - buf->WriteUInt16(class_); - return true; -} - -void MdnsQuestion::SetUnicastResponse(bool should_unicast) { - if (should_unicast) { - class_ |= kMdnsQClassMaskUnicastResponse; - } else { - class_ &= ~kMdnsQClassMaskUnicastResponse; - } -} - -bool MdnsQuestion::ShouldUnicastResponse() const { - return class_ & 
kMdnsQClassMaskUnicastResponse; -} - -MdnsResourceRecord::MdnsResourceRecord() = default; -MdnsResourceRecord::MdnsResourceRecord(const MdnsResourceRecord& other) = - default; -MdnsResourceRecord::~MdnsResourceRecord() = default; - -bool MdnsResourceRecord::Read(MessageBufferReader* buf) { - if (!ReadDomainName(buf, &name_)) { - return false; - } - if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) || - !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) { - return false; - } - - switch (GetType()) { - case SectionEntryType::kA: - return ReadARData(buf); - case SectionEntryType::kAAAA: - return ReadQuadARData(buf); - case SectionEntryType::kUnsupported: - return false; - default: - RTC_NOTREACHED(); - } - return false; -} -bool MdnsResourceRecord::ReadARData(MessageBufferReader* buf) { - // A RDATA contains a 32-bit IPv4 address. - return buf->ReadString(&rdata_, 4); -} - -bool MdnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) { - // AAAA RDATA contains a 128-bit IPv6 address. - return buf->ReadString(&rdata_, 16); -} - -bool MdnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const { - WriteDomainName(buf, name_); - buf->WriteUInt16(type_); - buf->WriteUInt16(class_); - buf->WriteUInt32(ttl_seconds_); - buf->WriteUInt16(rdlength_); - switch (GetType()) { - case SectionEntryType::kA: - WriteARData(buf); - return true; - case SectionEntryType::kAAAA: - WriteQuadARData(buf); - return true; - case SectionEntryType::kUnsupported: - return false; - default: - RTC_NOTREACHED(); - } - return true; -} - -void MdnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const { - buf->WriteString(rdata_); -} - -void MdnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const { - buf->WriteString(rdata_); -} - -bool MdnsResourceRecord::SetIPAddressInRecordData( - const rtc::IPAddress& address) { - int af = address.family(); - if (af != AF_INET && af != AF_INET6) { - return false; - } - char out[16] = {0}; - if (!rtc::inet_pton(af, address.ToString().c_str(), out)) { - return false; - } - rdlength_ = (af == AF_INET) ? 4 : 16; - rdata_ = std::string(out, rdlength_); - return true; -} - -bool MdnsResourceRecord::GetIPAddressFromRecordData( - rtc::IPAddress* address) const { - if (GetType() != SectionEntryType::kA && - GetType() != SectionEntryType::kAAAA) { - return false; - } - if (rdata_.size() != 4 && rdata_.size() != 16) { - return false; - } - char out[INET6_ADDRSTRLEN] = {0}; - int af = (GetType() == SectionEntryType::kA) ? 
AF_INET : AF_INET6; - if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) { - return false; - } - return rtc::IPFromString(std::string(out), address); -} - -MdnsMessage::MdnsMessage() = default; -MdnsMessage::~MdnsMessage() = default; - -bool MdnsMessage::Read(MessageBufferReader* buf) { - RTC_DCHECK_EQ(0u, buf->CurrentOffset()); - if (!header_.Read(buf)) { - return false; - } - - auto read_question = [&buf](std::vector* section, - uint16_t count) { - section->resize(count); - for (auto& question : (*section)) { - if (!question.Read(buf)) { - return false; - } - } - return true; - }; - auto read_rr = [&buf](std::vector* section, - uint16_t count) { - section->resize(count); - for (auto& rr : (*section)) { - if (!rr.Read(buf)) { - return false; - } - } - return true; - }; - - if (!read_question(&question_section_, header_.qdcount) || - !read_rr(&answer_section_, header_.ancount) || - !read_rr(&authority_section_, header_.nscount) || - !read_rr(&additional_section_, header_.arcount)) { - return false; - } - return true; -} - -bool MdnsMessage::Write(rtc::ByteBufferWriter* buf) const { - header_.Write(buf); - - auto write_rr = [&buf](const std::vector& section) { - for (const auto& rr : section) { - if (!rr.Write(buf)) { - return false; - } - } - return true; - }; - - for (const auto& question : question_section_) { - if (!question.Write(buf)) { - return false; - } - } - if (!write_rr(answer_section_) || !write_rr(authority_section_) || - !write_rr(additional_section_)) { - return false; - } - - return true; -} - -bool MdnsMessage::ShouldUnicastResponse() const { - bool should_unicast = false; - for (const auto& question : question_section_) { - should_unicast |= question.ShouldUnicastResponse(); - } - return should_unicast; -} - -void MdnsMessage::AddQuestion(const MdnsQuestion& question) { - question_section_.push_back(question); - header_.qdcount = question_section_.size(); -} - -void MdnsMessage::AddAnswerRecord(const MdnsResourceRecord& answer) { - answer_section_.push_back(answer); - header_.ancount = answer_section_.size(); -} - -} // namespace webrtc diff --git a/p2p/base/mdns_message.h b/p2p/base/mdns_message.h deleted file mode 100644 index 79be5219e4..0000000000 --- a/p2p/base/mdns_message.h +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright 2018 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef P2P_BASE_MDNS_MESSAGE_H_ -#define P2P_BASE_MDNS_MESSAGE_H_ - -// This file contains classes to read and write mDNSs message defined in RFC -// 6762 and RFC 1025 (DNS messages). Note that it is recommended by RFC 6762 to -// use the name compression scheme defined in RFC 1035 whenever possible. We -// currently only implement the capability of reading compressed names in mDNS -// messages in MdnsMessage::Read(); however, the MdnsMessage::Write() does not -// support name compression yet. -// -// Fuzzer tests (test/fuzzers/mdns_parser_fuzzer.cc) MUST always be performed -// after changes made to this file. 
- -#include - -#include -#include - -#include "rtc_base/byte_buffer.h" -#include "rtc_base/ip_address.h" -#include "rtc_base/message_buffer_reader.h" - -namespace webrtc { - -// We use "section entry" to denote either a question or a resource record. -// -// RFC 1035 Section 3.2.2. -enum class SectionEntryType { - kA, - kAAAA, - // Only the above types are processed in the current implementation. - kUnsupported, -}; - -// RFC 1035 Section 3.2.4. -enum class SectionEntryClass { - kIN, - kUnsupported, -}; - -// RFC 1035, Section 4.1.1. -class MdnsHeader final { - public: - bool Read(MessageBufferReader* buf); - void Write(rtc::ByteBufferWriter* buf) const; - - void SetQueryOrResponse(bool is_query); - bool IsQuery() const; - void SetAuthoritative(bool is_authoritative); - bool IsAuthoritative() const; - - uint16_t id = 0; - uint16_t flags = 0; - // Number of entries in the question section. - uint16_t qdcount = 0; - // Number of resource records in the answer section. - uint16_t ancount = 0; - // Number of name server resource records in the authority records section. - uint16_t nscount = 0; - // Number of resource records in the additional records section. - uint16_t arcount = 0; -}; - -// Entries in each section after the header share a common structure. Note that -// this is not a concept defined in RFC 1035. -class MdnsSectionEntry { - public: - MdnsSectionEntry(); - MdnsSectionEntry(const MdnsSectionEntry& other); - virtual ~MdnsSectionEntry(); - virtual bool Read(MessageBufferReader* buf) = 0; - virtual bool Write(rtc::ByteBufferWriter* buf) const = 0; - - void SetName(const std::string& name) { name_ = name; } - // Returns the fully qualified domain name in the section entry, i.e., QNAME - // in a question or NAME in a resource record. - std::string GetName() const { return name_; } - - void SetType(SectionEntryType type); - SectionEntryType GetType() const; - void SetClass(SectionEntryClass cls); - SectionEntryClass GetClass() const; - - protected: - std::string name_; // Fully qualified domain name. - uint16_t type_ = 0; - uint16_t class_ = 0; -}; - -// RFC 1035, Section 4.1.2. -class MdnsQuestion final : public MdnsSectionEntry { - public: - MdnsQuestion(); - MdnsQuestion(const MdnsQuestion& other); - ~MdnsQuestion() override; - - bool Read(MessageBufferReader* buf) override; - bool Write(rtc::ByteBufferWriter* buf) const override; - - void SetUnicastResponse(bool should_unicast); - bool ShouldUnicastResponse() const; -}; - -// RFC 1035, Section 4.1.3. -class MdnsResourceRecord final : public MdnsSectionEntry { - public: - MdnsResourceRecord(); - MdnsResourceRecord(const MdnsResourceRecord& other); - ~MdnsResourceRecord() override; - - bool Read(MessageBufferReader* buf) override; - bool Write(rtc::ByteBufferWriter* buf) const override; - - void SetTtlSeconds(uint32_t ttl_seconds) { ttl_seconds_ = ttl_seconds; } - uint32_t GetTtlSeconds() const { return ttl_seconds_; } - // Returns true if |address| is in the address family AF_INET or AF_INET6 and - // |address| has a valid IPv4 or IPv6 address; false otherwise. - bool SetIPAddressInRecordData(const rtc::IPAddress& address); - // Returns true if the record is of type A or AAAA and the record has a valid - // IPv4 or IPv6 address; false otherwise. Stores the valid IP in |address|. - bool GetIPAddressFromRecordData(rtc::IPAddress* address) const; - - private: - // The list of methods reading and writing rdata can grow as we support more - // types of rdata. 
- bool ReadARData(MessageBufferReader* buf); - void WriteARData(rtc::ByteBufferWriter* buf) const; - - bool ReadQuadARData(MessageBufferReader* buf); - void WriteQuadARData(rtc::ByteBufferWriter* buf) const; - - uint32_t ttl_seconds_ = 0; - uint16_t rdlength_ = 0; - std::string rdata_; -}; - -class MdnsMessage final { - public: - // RFC 1035, Section 4.1. - enum class Section { kQuestion, kAnswer, kAuthority, kAdditional }; - - MdnsMessage(); - ~MdnsMessage(); - // Reads the mDNS message in |buf| and populates the corresponding fields in - // MdnsMessage. - bool Read(MessageBufferReader* buf); - // Write an mDNS message to |buf| based on the fields in MdnsMessage. - // - // TODO(qingsi): Implement name compression when writing mDNS messages. - bool Write(rtc::ByteBufferWriter* buf) const; - - void SetId(uint16_t id) { header_.id = id; } - uint16_t GetId() const { return header_.id; } - - void SetQueryOrResponse(bool is_query) { - header_.SetQueryOrResponse(is_query); - } - bool IsQuery() const { return header_.IsQuery(); } - - void SetAuthoritative(bool is_authoritative) { - header_.SetAuthoritative(is_authoritative); - } - bool IsAuthoritative() const { return header_.IsAuthoritative(); } - - // Returns true if the message is a query and the unicast response is - // preferred. False otherwise. - bool ShouldUnicastResponse() const; - - void AddQuestion(const MdnsQuestion& question); - // TODO(qingsi): Implement AddXRecord for name server and additional records. - void AddAnswerRecord(const MdnsResourceRecord& answer); - - const std::vector& question_section() const { - return question_section_; - } - const std::vector& answer_section() const { - return answer_section_; - } - const std::vector& authority_section() const { - return authority_section_; - } - const std::vector& additional_section() const { - return additional_section_; - } - - private: - MdnsHeader header_; - std::vector question_section_; - std::vector answer_section_; - std::vector authority_section_; - std::vector additional_section_; -}; - -} // namespace webrtc - -#endif // P2P_BASE_MDNS_MESSAGE_H_ diff --git a/p2p/base/mdns_message_unittest.cc b/p2p/base/mdns_message_unittest.cc deleted file mode 100644 index 2f1f74d8e3..0000000000 --- a/p2p/base/mdns_message_unittest.cc +++ /dev/null @@ -1,571 +0,0 @@ -/* - * Copyright 2018 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "p2p/base/mdns_message.h" - -#include -#include -#include - -#include "rtc_base/byte_buffer.h" -#include "rtc_base/gunit.h" -#include "rtc_base/ip_address.h" -#include "rtc_base/socket_address.h" -#include "test/gmock.h" - -#define ReadMdnsMessage(X, Y) ReadMdnsMessageTestCase(X, Y, sizeof(Y)) -#define WriteMdnsMessageAndCompare(X, Y) \ - WriteMdnsMessageAndCompareWithTestCast(X, Y, sizeof(Y)) - -using ::testing::ElementsAre; -using ::testing::Pair; -using ::testing::UnorderedElementsAre; - -namespace webrtc { - -namespace { - -const uint8_t kSingleQuestionForIPv4AddrWithUnicastResponse[] = { - 0x12, 0x34, // ID - 0x00, 0x00, // flags - 0x00, 0x01, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x06, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, // webrtc - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x80, 0x01, // class IN, unicast response -}; - -const uint8_t kTwoQuestionsForIPv4AndIPv6AddrWithMulticastResponse[] = { - 0x12, 0x34, // ID - 0x00, 0x00, // flags - 0x00, 0x02, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x07, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x34, // webrtc4 - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN, multicast response - 0x07, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x36, // webrtc6 - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x1C, // type AAAA Record - 0x00, 0x01, // class IN, multicast response -}; - -const uint8_t - kTwoQuestionsForIPv4AndIPv6AddrWithMulticastResponseAndNameCompression[] = { - 0x12, 0x34, // ID - 0x00, 0x00, // flags - 0x00, 0x02, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x03, 0x77, 0x77, 0x77, // www - 0x06, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, // webrtc - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN, multicast response - 0x04, 0x6d, 0x64, 0x6e, 0x73, // mdns - 0xc0, 0x10, // offset 16, webrtc.org. - 0x00, 0x1C, // type AAAA Record - 0x00, 0x01, // class IN, multicast response -}; - -const uint8_t kThreeQuestionsWithTwoPointersToTheSameNameSuffix[] = { - 0x12, 0x34, // ID - 0x00, 0x00, // flags - 0x00, 0x03, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x03, 0x77, 0x77, 0x77, // www - 0x06, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, // webrtc - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN, multicast response - 0x04, 0x6d, 0x64, 0x6e, 0x73, // mdns - 0xc0, 0x10, // offset 16, webrtc.org. - 0x00, 0x1C, // type AAAA Record - 0x00, 0x01, // class IN, multicast response - 0xc0, 0x10, // offset 16, webrtc.org. 
- 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN, multicast response -}; - -const uint8_t kThreeQuestionsWithPointerToNameSuffixContainingAnotherPointer[] = - { - 0x12, 0x34, // ID - 0x00, 0x00, // flags - 0x00, 0x03, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x03, 0x77, 0x77, 0x77, // www - 0x06, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, // webrtc - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN, multicast response - 0x04, 0x6d, 0x64, 0x6e, 0x73, // mdns - 0xc0, 0x10, // offset 16, webrtc.org. - 0x00, 0x1C, // type AAAA Record - 0x00, 0x01, // class IN, multicast response - 0x03, 0x77, 0x77, 0x77, // www - 0xc0, 0x20, // offset 32, mdns.webrtc.org. - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN, multicast response -}; - -const uint8_t kCorruptedQuestionWithNameCompression1[] = { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x01, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0xc0, 0x0c, // offset 12, - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN -}; - -const uint8_t kCorruptedQuestionWithNameCompression2[] = { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x01, // number of questions - 0x00, 0x00, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x01, 0x77, // w - 0xc0, 0x0c, // offset 12, - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN -}; - -const uint8_t kSingleAuthoritativeAnswerWithIPv4Addr[] = { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x00, // number of questions - 0x00, 0x01, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x06, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, // webrtc - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds - 0x00, 0x04, // rdlength, 32 bits - 0xC0, 0xA8, 0x00, 0x01, // 192.168.0.1 -}; - -const uint8_t kTwoAuthoritativeAnswersWithIPv4AndIPv6Addr[] = { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x00, // number of questions - 0x00, 0x02, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x07, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x34, // webrtc4 - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x3c, // TTL, 60 seconds - 0x00, 0x04, // rdlength, 32 bits - 0xC0, 0xA8, 0x00, 0x01, // 192.168.0.1 - 0x07, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x36, // webrtc6 - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null label - 0x00, 0x1C, // type AAAA Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds - 0x00, 0x10, // rdlength, 128 bits - 0xfd, 0x12, 0x34, 0x56, 0x78, 0x9a, 0x00, 0x01, // fd12:3456:789a:1::1 - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, -}; - -const uint8_t kTwoAuthoritativeAnswersWithIPv4AndIPv6AddrWithNameCompression[] = - { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x00, // number of questions - 0x00, 0x02, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x03, 0x77, 0x77, 0x77, // www - 0x06, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, // webrtc - 0x03, 0x6f, 0x72, 0x67, // org - 0x00, // null 
label - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x3c, // TTL, 60 seconds - 0x00, 0x04, // rdlength, 32 bits - 0xc0, 0xA8, 0x00, 0x01, // 192.168.0.1 - 0xc0, 0x10, // offset 16, webrtc.org. - 0x00, 0x1C, // type AAAA Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x78, // TTL, 120 seconds - 0x00, 0x10, // rdlength, 128 bits - 0xfd, 0x12, 0x34, 0x56, 0x78, 0x9a, 0x00, 0x01, // fd12:3456:789a:1::1 - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, -}; - -const uint8_t kCorruptedAnswerWithNameCompression1[] = { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x00, // number of questions - 0x00, 0x01, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0xc0, 0x0c, // offset 12, - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x3c, // TTL, 60 seconds - 0x00, 0x04, // rdlength, 32 bits - 0xc0, 0xA8, 0x00, 0x01, // 192.168.0.1 -}; - -const uint8_t kCorruptedAnswerWithNameCompression2[] = { - 0x12, 0x34, // ID - 0x84, 0x00, // flags - 0x00, 0x00, // number of questions - 0x00, 0x01, // number of answer rr - 0x00, 0x00, // number of name server rr - 0x00, 0x00, // number of additional rr - 0x01, 0x77, // w - 0xc0, 0x0c, // offset 12, - 0x00, 0x01, // type A Record - 0x00, 0x01, // class IN - 0x00, 0x00, 0x00, 0x3c, // TTL, 60 seconds - 0x00, 0x04, // rdlength, 32 bits - 0xc0, 0xA8, 0x00, 0x01, // 192.168.0.1 -}; - -bool ReadMdnsMessageTestCase(MdnsMessage* msg, - const uint8_t* testcase, - size_t size) { - MessageBufferReader buf(reinterpret_cast(testcase), size); - return msg->Read(&buf); -} - -void WriteMdnsMessageAndCompareWithTestCast(MdnsMessage* msg, - const uint8_t* testcase, - size_t size) { - rtc::ByteBufferWriter out; - EXPECT_TRUE(msg->Write(&out)); - EXPECT_EQ(size, out.Length()); - int len = static_cast(out.Length()); - rtc::ByteBufferReader read_buf(out); - std::string bytes; - read_buf.ReadString(&bytes, len); - std::string testcase_bytes(reinterpret_cast(testcase), size); - EXPECT_EQ(testcase_bytes, bytes); -} - -bool GetQueriedNames(MdnsMessage* msg, std::set* names) { - if (!msg->IsQuery() || msg->question_section().empty()) { - return false; - } - for (const auto& question : msg->question_section()) { - names->insert(question.GetName()); - } - return true; -} - -bool GetResolution(MdnsMessage* msg, - std::map* names) { - if (msg->IsQuery() || msg->answer_section().empty()) { - return false; - } - for (const auto& answer : msg->answer_section()) { - rtc::IPAddress resolved_addr; - if (!answer.GetIPAddressFromRecordData(&resolved_addr)) { - return false; - } - (*names)[answer.GetName()] = resolved_addr; - } - return true; -} - -} // namespace - -TEST(MdnsMessageTest, ReadSingleQuestionForIPv4Address) { - MdnsMessage msg; - ASSERT_TRUE( - ReadMdnsMessage(&msg, kSingleQuestionForIPv4AddrWithUnicastResponse)); - EXPECT_TRUE(msg.IsQuery()); - EXPECT_EQ(0x1234, msg.GetId()); - ASSERT_EQ(1u, msg.question_section().size()); - EXPECT_EQ(0u, msg.answer_section().size()); - EXPECT_EQ(0u, msg.authority_section().size()); - EXPECT_EQ(0u, msg.additional_section().size()); - EXPECT_TRUE(msg.ShouldUnicastResponse()); - - const auto& question = msg.question_section()[0]; - EXPECT_EQ(SectionEntryType::kA, question.GetType()); - - std::set queried_names; - EXPECT_TRUE(GetQueriedNames(&msg, &queried_names)); - EXPECT_THAT(queried_names, ElementsAre("webrtc.org.")); -} - -TEST(MdnsMessageTest, ReadTwoQuestionsForIPv4AndIPv6Addr) { - MdnsMessage msg; - 
ASSERT_TRUE(ReadMdnsMessage( - &msg, kTwoQuestionsForIPv4AndIPv6AddrWithMulticastResponse)); - EXPECT_TRUE(msg.IsQuery()); - EXPECT_EQ(0x1234, msg.GetId()); - ASSERT_EQ(2u, msg.question_section().size()); - EXPECT_EQ(0u, msg.answer_section().size()); - EXPECT_EQ(0u, msg.authority_section().size()); - EXPECT_EQ(0u, msg.additional_section().size()); - - const auto& question1 = msg.question_section()[0]; - const auto& question2 = msg.question_section()[1]; - EXPECT_EQ(SectionEntryType::kA, question1.GetType()); - EXPECT_EQ(SectionEntryType::kAAAA, question2.GetType()); - - std::set queried_names; - EXPECT_TRUE(GetQueriedNames(&msg, &queried_names)); - EXPECT_THAT(queried_names, - UnorderedElementsAre("webrtc4.org.", "webrtc6.org.")); -} - -TEST(MdnsMessageTest, ReadTwoQuestionsForIPv4AndIPv6AddrWithNameCompression) { - MdnsMessage msg; - ASSERT_TRUE(ReadMdnsMessage( - &msg, - kTwoQuestionsForIPv4AndIPv6AddrWithMulticastResponseAndNameCompression)); - - ASSERT_EQ(2u, msg.question_section().size()); - const auto& question1 = msg.question_section()[0]; - const auto& question2 = msg.question_section()[1]; - EXPECT_EQ(SectionEntryType::kA, question1.GetType()); - EXPECT_EQ(SectionEntryType::kAAAA, question2.GetType()); - - std::set queried_names; - EXPECT_TRUE(GetQueriedNames(&msg, &queried_names)); - EXPECT_THAT(queried_names, - UnorderedElementsAre("www.webrtc.org.", "mdns.webrtc.org.")); -} - -TEST(MdnsMessageTest, ReadThreeQuestionsWithTwoPointersToTheSameNameSuffix) { - MdnsMessage msg; - ASSERT_TRUE( - ReadMdnsMessage(&msg, kThreeQuestionsWithTwoPointersToTheSameNameSuffix)); - - ASSERT_EQ(3u, msg.question_section().size()); - const auto& question1 = msg.question_section()[0]; - const auto& question2 = msg.question_section()[1]; - const auto& question3 = msg.question_section()[2]; - EXPECT_EQ(SectionEntryType::kA, question1.GetType()); - EXPECT_EQ(SectionEntryType::kAAAA, question2.GetType()); - EXPECT_EQ(SectionEntryType::kA, question3.GetType()); - - std::set queried_names; - EXPECT_TRUE(GetQueriedNames(&msg, &queried_names)); - EXPECT_THAT(queried_names, - UnorderedElementsAre("www.webrtc.org.", "mdns.webrtc.org.", - "webrtc.org.")); -} - -TEST(MdnsMessageTest, - ReadThreeQuestionsWithPointerToNameSuffixContainingAnotherPointer) { - MdnsMessage msg; - ASSERT_TRUE(ReadMdnsMessage( - &msg, kThreeQuestionsWithPointerToNameSuffixContainingAnotherPointer)); - - ASSERT_EQ(3u, msg.question_section().size()); - const auto& question1 = msg.question_section()[0]; - const auto& question2 = msg.question_section()[1]; - const auto& question3 = msg.question_section()[2]; - EXPECT_EQ(SectionEntryType::kA, question1.GetType()); - EXPECT_EQ(SectionEntryType::kAAAA, question2.GetType()); - EXPECT_EQ(SectionEntryType::kA, question3.GetType()); - - std::set queried_names; - EXPECT_TRUE(GetQueriedNames(&msg, &queried_names)); - EXPECT_THAT(queried_names, - UnorderedElementsAre("www.webrtc.org.", "mdns.webrtc.org.", - "www.mdns.webrtc.org.")); -} - -TEST(MdnsMessageTest, - ReadQuestionWithCorruptedPointerInNameCompressionShouldFail) { - MdnsMessage msg; - EXPECT_FALSE(ReadMdnsMessage(&msg, kCorruptedQuestionWithNameCompression1)); - EXPECT_FALSE(ReadMdnsMessage(&msg, kCorruptedQuestionWithNameCompression2)); -} - -TEST(MdnsMessageTest, ReadSingleAnswerForIPv4Addr) { - MdnsMessage msg; - ASSERT_TRUE(ReadMdnsMessage(&msg, kSingleAuthoritativeAnswerWithIPv4Addr)); - EXPECT_FALSE(msg.IsQuery()); - EXPECT_TRUE(msg.IsAuthoritative()); - EXPECT_EQ(0x1234, msg.GetId()); - EXPECT_EQ(0u, msg.question_section().size()); 
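// The annotated test vector kSingleAuthoritativeAnswerWithIPv4Addr above
// spells out the wire layout this test parses: after the encoded name come
// TYPE (0x0001 for A), CLASS (0x0001 for IN), a 32-bit TTL, a 16-bit RDLENGTH
// of 4, and the raw IPv4 bytes. A standalone encoder for just that tail; the
// names below are illustrative only and not part of the WebRTC API.
#include <cstdint>
#include <vector>

static void AppendIPv4AnswerTail(uint32_t ttl_seconds,
                                 const uint8_t addr[4],
                                 std::vector<uint8_t>* out) {
  auto push_be16 = [out](uint16_t v) {
    out->push_back(static_cast<uint8_t>(v >> 8));
    out->push_back(static_cast<uint8_t>(v & 0xff));
  };
  push_be16(0x0001);  // TYPE: A record.
  push_be16(0x0001);  // CLASS: IN.
  for (int shift = 24; shift >= 0; shift -= 8) {
    out->push_back(static_cast<uint8_t>((ttl_seconds >> shift) & 0xff));  // TTL, big-endian.
  }
  push_be16(4);  // RDLENGTH: four bytes of IPv4 address.
  out->insert(out->end(), addr, addr + 4);
}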
- ASSERT_EQ(1u, msg.answer_section().size()); - EXPECT_EQ(0u, msg.authority_section().size()); - EXPECT_EQ(0u, msg.additional_section().size()); - - const auto& answer = msg.answer_section()[0]; - EXPECT_EQ(SectionEntryType::kA, answer.GetType()); - EXPECT_EQ(120u, answer.GetTtlSeconds()); - - std::map resolution; - EXPECT_TRUE(GetResolution(&msg, &resolution)); - rtc::IPAddress expected_addr(rtc::SocketAddress("192.168.0.1", 0).ipaddr()); - EXPECT_THAT(resolution, ElementsAre(Pair("webrtc.org.", expected_addr))); -} - -TEST(MdnsMessageTest, ReadTwoAnswersForIPv4AndIPv6Addr) { - MdnsMessage msg; - ASSERT_TRUE( - ReadMdnsMessage(&msg, kTwoAuthoritativeAnswersWithIPv4AndIPv6Addr)); - EXPECT_FALSE(msg.IsQuery()); - EXPECT_TRUE(msg.IsAuthoritative()); - EXPECT_EQ(0x1234, msg.GetId()); - EXPECT_EQ(0u, msg.question_section().size()); - ASSERT_EQ(2u, msg.answer_section().size()); - EXPECT_EQ(0u, msg.authority_section().size()); - EXPECT_EQ(0u, msg.additional_section().size()); - - const auto& answer1 = msg.answer_section()[0]; - const auto& answer2 = msg.answer_section()[1]; - EXPECT_EQ(SectionEntryType::kA, answer1.GetType()); - EXPECT_EQ(SectionEntryType::kAAAA, answer2.GetType()); - EXPECT_EQ(60u, answer1.GetTtlSeconds()); - EXPECT_EQ(120u, answer2.GetTtlSeconds()); - - std::map resolution; - EXPECT_TRUE(GetResolution(&msg, &resolution)); - rtc::IPAddress expected_addr_ipv4( - rtc::SocketAddress("192.168.0.1", 0).ipaddr()); - rtc::IPAddress expected_addr_ipv6( - rtc::SocketAddress("fd12:3456:789a:1::1", 0).ipaddr()); - EXPECT_THAT(resolution, - UnorderedElementsAre(Pair("webrtc4.org.", expected_addr_ipv4), - Pair("webrtc6.org.", expected_addr_ipv6))); -} - -TEST(MdnsMessageTest, ReadTwoAnswersForIPv4AndIPv6AddrWithNameCompression) { - MdnsMessage msg; - ASSERT_TRUE(ReadMdnsMessage( - &msg, kTwoAuthoritativeAnswersWithIPv4AndIPv6AddrWithNameCompression)); - - std::map resolution; - EXPECT_TRUE(GetResolution(&msg, &resolution)); - rtc::IPAddress expected_addr_ipv4( - rtc::SocketAddress("192.168.0.1", 0).ipaddr()); - rtc::IPAddress expected_addr_ipv6( - rtc::SocketAddress("fd12:3456:789a:1::1", 0).ipaddr()); - EXPECT_THAT(resolution, - UnorderedElementsAre(Pair("www.webrtc.org.", expected_addr_ipv4), - Pair("webrtc.org.", expected_addr_ipv6))); -} - -TEST(MdnsMessageTest, - ReadAnswerWithCorruptedPointerInNameCompressionShouldFail) { - MdnsMessage msg; - EXPECT_FALSE(ReadMdnsMessage(&msg, kCorruptedAnswerWithNameCompression1)); - EXPECT_FALSE(ReadMdnsMessage(&msg, kCorruptedAnswerWithNameCompression2)); -} - -TEST(MdnsMessageTest, WriteSingleQuestionForIPv4Addr) { - MdnsMessage msg; - msg.SetId(0x1234); - msg.SetQueryOrResponse(true); - - MdnsQuestion question; - question.SetName("webrtc.org."); - question.SetType(SectionEntryType::kA); - question.SetClass(SectionEntryClass::kIN); - question.SetUnicastResponse(true); - msg.AddQuestion(question); - - WriteMdnsMessageAndCompare(&msg, - kSingleQuestionForIPv4AddrWithUnicastResponse); -} - -TEST(MdnsMessageTest, WriteTwoQuestionsForIPv4AndIPv6Addr) { - MdnsMessage msg; - msg.SetId(0x1234); - msg.SetQueryOrResponse(true); - - MdnsQuestion question1; - question1.SetName("webrtc4.org."); - question1.SetType(SectionEntryType::kA); - question1.SetClass(SectionEntryClass::kIN); - msg.AddQuestion(question1); - - MdnsQuestion question2; - question2.SetName("webrtc6.org."); - question2.SetType(SectionEntryType::kAAAA); - question2.SetClass(SectionEntryClass::kIN); - msg.AddQuestion(question2); - - WriteMdnsMessageAndCompare( - &msg, 
kTwoQuestionsForIPv4AndIPv6AddrWithMulticastResponse); -} - -TEST(MdnsMessageTest, WriteSingleAnswerToIPv4Addr) { - MdnsMessage msg; - msg.SetId(0x1234); - msg.SetQueryOrResponse(false); - msg.SetAuthoritative(true); - - MdnsResourceRecord answer; - answer.SetName("webrtc.org."); - answer.SetType(SectionEntryType::kA); - answer.SetClass(SectionEntryClass::kIN); - EXPECT_TRUE(answer.SetIPAddressInRecordData( - rtc::SocketAddress("192.168.0.1", 0).ipaddr())); - answer.SetTtlSeconds(120); - msg.AddAnswerRecord(answer); - - WriteMdnsMessageAndCompare(&msg, kSingleAuthoritativeAnswerWithIPv4Addr); -} - -TEST(MdnsMessageTest, WriteTwoAnswersToIPv4AndIPv6Addr) { - MdnsMessage msg; - msg.SetId(0x1234); - msg.SetQueryOrResponse(false); - msg.SetAuthoritative(true); - - MdnsResourceRecord answer1; - answer1.SetName("webrtc4.org."); - answer1.SetType(SectionEntryType::kA); - answer1.SetClass(SectionEntryClass::kIN); - answer1.SetIPAddressInRecordData( - rtc::SocketAddress("192.168.0.1", 0).ipaddr()); - answer1.SetTtlSeconds(60); - msg.AddAnswerRecord(answer1); - - MdnsResourceRecord answer2; - answer2.SetName("webrtc6.org."); - answer2.SetType(SectionEntryType::kAAAA); - answer2.SetClass(SectionEntryClass::kIN); - answer2.SetIPAddressInRecordData( - rtc::SocketAddress("fd12:3456:789a:1::1", 0).ipaddr()); - answer2.SetTtlSeconds(120); - msg.AddAnswerRecord(answer2); - - WriteMdnsMessageAndCompare(&msg, kTwoAuthoritativeAnswersWithIPv4AndIPv6Addr); -} - -} // namespace webrtc diff --git a/p2p/base/p2p_transport_channel.cc b/p2p/base/p2p_transport_channel.cc index f511fb915a..836721c151 100644 --- a/p2p/base/p2p_transport_channel.cc +++ b/p2p/base/p2p_transport_channel.cc @@ -10,28 +10,40 @@ #include "p2p/base/p2p_transport_channel.h" -#include +#include +#include + +#include +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" +#include "api/async_dns_resolver.h" #include "api/candidate.h" +#include "api/task_queue/queued_task.h" #include "logging/rtc_event_log/ice_logger.h" +#include "p2p/base/basic_async_resolver_factory.h" #include "p2p/base/basic_ice_controller.h" -#include "p2p/base/candidate_pair_interface.h" #include "p2p/base/connection.h" +#include "p2p/base/connection_info.h" #include "p2p/base/port.h" #include "rtc_base/checks.h" #include "rtc_base/crc32.h" #include "rtc_base/experiments/struct_parameters_parser.h" +#include "rtc_base/ip_address.h" #include "rtc_base/logging.h" #include "rtc_base/net_helper.h" -#include "rtc_base/net_helpers.h" +#include "rtc_base/network.h" +#include "rtc_base/network_constants.h" #include "rtc_base/string_encode.h" #include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/time_utils.h" +#include "rtc_base/trace_event.h" #include "system_wrappers/include/field_trial.h" #include "system_wrappers/include/metrics.h" @@ -122,26 +134,50 @@ bool IceCredentialsChanged(const std::string& old_ufrag, return (old_ufrag != new_ufrag) || (old_pwd != new_pwd); } +// static +std::unique_ptr P2PTransportChannel::Create( + const std::string& transport_name, + int component, + PortAllocator* allocator, + webrtc::AsyncDnsResolverFactoryInterface* async_dns_resolver_factory, + webrtc::RtcEventLog* event_log, + IceControllerFactoryInterface* ice_controller_factory) { + return absl::WrapUnique(new P2PTransportChannel( + transport_name, component, allocator, async_dns_resolver_factory, + /* owned_dns_resolver_factory= */ nullptr, 
event_log, + ice_controller_factory)); +} + P2PTransportChannel::P2PTransportChannel(const std::string& transport_name, int component, PortAllocator* allocator) : P2PTransportChannel(transport_name, component, allocator, - nullptr, - nullptr) {} + /* async_dns_resolver_factory= */ nullptr, + /* owned_dns_resolver_factory= */ nullptr, + /* event_log= */ nullptr, + /* ice_controller_factory= */ nullptr) {} +// Private constructor, called from Create() P2PTransportChannel::P2PTransportChannel( const std::string& transport_name, int component, PortAllocator* allocator, - webrtc::AsyncResolverFactory* async_resolver_factory, + webrtc::AsyncDnsResolverFactoryInterface* async_dns_resolver_factory, + std::unique_ptr + owned_dns_resolver_factory, webrtc::RtcEventLog* event_log, IceControllerFactoryInterface* ice_controller_factory) : transport_name_(transport_name), component_(component), allocator_(allocator), - async_resolver_factory_(async_resolver_factory), + // If owned_dns_resolver_factory is given, async_dns_resolver_factory is + // ignored. + async_dns_resolver_factory_(owned_dns_resolver_factory + ? owned_dns_resolver_factory.get() + : async_dns_resolver_factory), + owned_dns_resolver_factory_(std::move(owned_dns_resolver_factory)), network_thread_(rtc::Thread::Current()), incoming_only_(false), error_(0), @@ -158,6 +194,7 @@ P2PTransportChannel::P2PTransportChannel( true /* presume_writable_when_fully_relayed */, REGATHER_ON_FAILED_NETWORKS_INTERVAL, RECEIVING_SWITCHING_DELAY) { + TRACE_EVENT0("webrtc", "P2PTransportChannel::P2PTransportChannel"); RTC_DCHECK(allocator_ != nullptr); weak_ping_interval_ = GetWeakPingIntervalInFieldTrial(); // Validate IceConfig even for mostly built-in constant default values in case @@ -192,15 +229,32 @@ P2PTransportChannel::P2PTransportChannel( } } +// Public constructor, exposed for backwards compatibility. +// Deprecated. 
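// The member initializers in the private constructor above follow a common
// "borrowed-or-owned dependency" pattern: the class always accesses the
// factory through a raw pointer, and an optional unique_ptr decides whether it
// also owns the object. A minimal generic sketch of that pattern; the names
// below are illustrative only and not part of the WebRTC API.
#include <memory>
#include <utility>

class DnsFactory {};  // Stand-in for the injected dependency.

class Holder {
 public:
  // Pass either a factory the caller keeps alive (|borrowed|) or one the
  // holder should own (|owned|); when |owned| is non-null it takes precedence.
  Holder(DnsFactory* borrowed, std::unique_ptr<DnsFactory> owned)
      : factory_(owned ? owned.get() : borrowed),
        owned_factory_(std::move(owned)) {}

  DnsFactory* factory() const { return factory_; }

 private:
  DnsFactory* const factory_;                  // Always used for access.
  std::unique_ptr<DnsFactory> owned_factory_;  // Set only in the owning case.
};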
+P2PTransportChannel::P2PTransportChannel( + const std::string& transport_name, + int component, + PortAllocator* allocator, + webrtc::AsyncResolverFactory* async_resolver_factory, + webrtc::RtcEventLog* event_log, + IceControllerFactoryInterface* ice_controller_factory) + : P2PTransportChannel( + transport_name, + component, + allocator, + nullptr, + std::make_unique( + async_resolver_factory), + event_log, + ice_controller_factory) {} + P2PTransportChannel::~P2PTransportChannel() { + TRACE_EVENT0("webrtc", "P2PTransportChannel::~P2PTransportChannel"); RTC_DCHECK_RUN_ON(network_thread_); std::vector copy(connections().begin(), connections().end()); for (Connection* con : copy) { con->Destroy(); } - for (auto& p : resolvers_) { - p.resolver_->Destroy(false); - } resolvers_.clear(); } @@ -903,7 +957,8 @@ void P2PTransportChannel::OnPortReady(PortAllocatorSession* session, ports_.push_back(port); port->SignalUnknownAddress.connect(this, &P2PTransportChannel::OnUnknownAddress); - port->SignalDestroyed.connect(this, &P2PTransportChannel::OnPortDestroyed); + port->SubscribePortDestroyed( + [this](PortInterface* port) { OnPortDestroyed(port); }); port->SignalRoleConflict.connect(this, &P2PTransportChannel::OnRoleConflict); port->SignalSentPacket.connect(this, &P2PTransportChannel::OnSentPacket); @@ -1163,16 +1218,17 @@ void P2PTransportChannel::OnNominated(Connection* conn) { void P2PTransportChannel::ResolveHostnameCandidate(const Candidate& candidate) { RTC_DCHECK_RUN_ON(network_thread_); - if (!async_resolver_factory_) { + if (!async_dns_resolver_factory_) { RTC_LOG(LS_WARNING) << "Dropping ICE candidate with hostname address " "(no AsyncResolverFactory)"; return; } - rtc::AsyncResolverInterface* resolver = async_resolver_factory_->Create(); - resolvers_.emplace_back(candidate, resolver); - resolver->SignalDone.connect(this, &P2PTransportChannel::OnCandidateResolved); - resolver->Start(candidate.address()); + auto resolver = async_dns_resolver_factory_->Create(); + auto resptr = resolver.get(); + resolvers_.emplace_back(candidate, std::move(resolver)); + resptr->Start(candidate.address(), + [this, resptr]() { OnCandidateResolved(resptr); }); RTC_LOG(LS_INFO) << "Asynchronously resolving ICE candidate hostname " << candidate.address().HostAsSensitiveURIString(); } @@ -1227,38 +1283,44 @@ void P2PTransportChannel::AddRemoteCandidate(const Candidate& candidate) { P2PTransportChannel::CandidateAndResolver::CandidateAndResolver( const Candidate& candidate, - rtc::AsyncResolverInterface* resolver) - : candidate_(candidate), resolver_(resolver) {} + std::unique_ptr&& resolver) + : candidate_(candidate), resolver_(std::move(resolver)) {} P2PTransportChannel::CandidateAndResolver::~CandidateAndResolver() {} void P2PTransportChannel::OnCandidateResolved( - rtc::AsyncResolverInterface* resolver) { + webrtc::AsyncDnsResolverInterface* resolver) { RTC_DCHECK_RUN_ON(network_thread_); auto p = absl::c_find_if(resolvers_, [resolver](const CandidateAndResolver& cr) { - return cr.resolver_ == resolver; + return cr.resolver_.get() == resolver; }); if (p == resolvers_.end()) { - RTC_LOG(LS_ERROR) << "Unexpected AsyncResolver signal"; + RTC_LOG(LS_ERROR) << "Unexpected AsyncDnsResolver return"; RTC_NOTREACHED(); return; } Candidate candidate = p->candidate_; - resolvers_.erase(p); - AddRemoteCandidateWithResolver(candidate, resolver); + AddRemoteCandidateWithResult(candidate, resolver->result()); + // Now we can delete the resolver. 
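// The PostTask call just below keeps the resolver alive until a task queued on
// the network thread has run, by moving its unique_ptr into the closure's
// capture; destroying the closure destroys the resolver. A standalone sketch
// of the idiom, with illustrative names only (std::function would not work
// here, because a closure with a moved-in unique_ptr is not copyable).
#include <cstdio>
#include <memory>
#include <utility>

struct FakeResolver {
  ~FakeResolver() { std::puts("resolver deleted"); }
};

int main() {
  auto resolver = std::make_unique<FakeResolver>();

  // Move ownership into the closure. Nothing is deleted yet; the object now
  // lives exactly as long as the closure does.
  auto keep_alive_task = [owned = std::move(resolver)]() {};

  std::puts("task created, resolver still alive");
  keep_alive_task();  // Running the task does not delete the object.
  std::puts("leaving main");
  // "resolver deleted" prints here, when |keep_alive_task| goes out of scope;
  // in the code below the same happens once the queued task has been executed
  // and dropped by the network thread.
}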
+ // TODO(bugs.webrtc.org/12651): Replace the stuff below with + // resolvers_.erase(p); + std::unique_ptr to_delete = + std::move(p->resolver_); + // Delay the actual deletion of the resolver until the lambda executes. network_thread_->PostTask( - ToQueuedTask([resolver]() { resolver->Destroy(false); })); + ToQueuedTask([delete_this = std::move(to_delete)] {})); + resolvers_.erase(p); } -void P2PTransportChannel::AddRemoteCandidateWithResolver( +void P2PTransportChannel::AddRemoteCandidateWithResult( Candidate candidate, - rtc::AsyncResolverInterface* resolver) { + const webrtc::AsyncDnsResolverResult& result) { RTC_DCHECK_RUN_ON(network_thread_); - if (resolver->GetError()) { + if (result.GetError()) { RTC_LOG(LS_WARNING) << "Failed to resolve ICE candidate hostname " << candidate.address().HostAsSensitiveURIString() - << " with error " << resolver->GetError(); + << " with error " << result.GetError(); return; } @@ -1266,9 +1328,8 @@ void P2PTransportChannel::AddRemoteCandidateWithResolver( // Prefer IPv6 to IPv4 if we have it (see RFC 5245 Section 15.1). // TODO(zstein): This won't work if we only have IPv4 locally but receive an // AAAA DNS record. - bool have_address = - resolver->GetResolvedAddress(AF_INET6, &resolved_address) || - resolver->GetResolvedAddress(AF_INET, &resolved_address); + bool have_address = result.GetResolvedAddress(AF_INET6, &resolved_address) || + result.GetResolvedAddress(AF_INET, &resolved_address); if (!have_address) { RTC_LOG(LS_INFO) << "ICE candidate hostname " << candidate.address().HostAsSensitiveURIString() diff --git a/p2p/base/p2p_transport_channel.h b/p2p/base/p2p_transport_channel.h index 1e93942fe9..462aa105b1 100644 --- a/p2p/base/p2p_transport_channel.h +++ b/p2p/base/p2p_transport_channel.h @@ -20,6 +20,9 @@ #ifndef P2P_BASE_P2P_TRANSPORT_CHANNEL_H_ #define P2P_BASE_P2P_TRANSPORT_CHANNEL_H_ +#include +#include + #include #include #include @@ -27,26 +30,43 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/async_dns_resolver.h" #include "api/async_resolver_factory.h" #include "api/candidate.h" #include "api/rtc_error.h" +#include "api/sequence_checker.h" +#include "api/transport/enums.h" +#include "api/transport/stun.h" #include "logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h" #include "logging/rtc_event_log/ice_logger.h" #include "p2p/base/candidate_pair_interface.h" +#include "p2p/base/connection.h" #include "p2p/base/ice_controller_factory_interface.h" #include "p2p/base/ice_controller_interface.h" #include "p2p/base/ice_transport_internal.h" #include "p2p/base/p2p_constants.h" #include "p2p/base/p2p_transport_channel_ice_field_trials.h" +#include "p2p/base/port.h" #include "p2p/base/port_allocator.h" #include "p2p/base/port_interface.h" #include "p2p/base/regathering_controller.h" +#include "p2p/base/transport_description.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" +#include "rtc_base/dscp.h" +#include "rtc_base/network/sent_packet.h" +#include "rtc_base/network_route.h" +#include "rtc_base/socket.h" +#include "rtc_base/socket_address.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" namespace webrtc { @@ -82,11 +102,19 @@ class RemoteCandidate : public 
Candidate { // two P2P clients connected to each other. class RTC_EXPORT P2PTransportChannel : public IceTransportInternal { public: + static std::unique_ptr Create( + const std::string& transport_name, + int component, + PortAllocator* allocator, + webrtc::AsyncDnsResolverFactoryInterface* async_dns_resolver_factory, + webrtc::RtcEventLog* event_log = nullptr, + IceControllerFactoryInterface* ice_controller_factory = nullptr); // For testing only. - // TODO(zstein): Remove once AsyncResolverFactory is required. + // TODO(zstein): Remove once AsyncDnsResolverFactory is required. P2PTransportChannel(const std::string& transport_name, int component, PortAllocator* allocator); + ABSL_DEPRECATED("bugs.webrtc.org/12598") P2PTransportChannel( const std::string& transport_name, int component, @@ -209,6 +237,18 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal { } private: + P2PTransportChannel( + const std::string& transport_name, + int component, + PortAllocator* allocator, + // DNS resolver factory + webrtc::AsyncDnsResolverFactoryInterface* async_dns_resolver_factory, + // If the P2PTransportChannel has to delete the DNS resolver factory + // on release, this pointer is set. + std::unique_ptr + owned_dns_resolver_factory, + webrtc::RtcEventLog* event_log = nullptr, + IceControllerFactoryInterface* ice_controller_factory = nullptr); bool IsGettingPorts() { RTC_DCHECK_RUN_ON(network_thread_); return allocator_session()->IsGettingPorts(); @@ -363,8 +403,10 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal { std::string transport_name_ RTC_GUARDED_BY(network_thread_); int component_ RTC_GUARDED_BY(network_thread_); PortAllocator* allocator_ RTC_GUARDED_BY(network_thread_); - webrtc::AsyncResolverFactory* async_resolver_factory_ + webrtc::AsyncDnsResolverFactoryInterface* const async_dns_resolver_factory_ RTC_GUARDED_BY(network_thread_); + const std::unique_ptr + owned_dns_resolver_factory_; rtc::Thread* const network_thread_; bool incoming_only_ RTC_GUARDED_BY(network_thread_); int error_ RTC_GUARDED_BY(network_thread_); @@ -426,17 +468,23 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal { RTC_GUARDED_BY(network_thread_); struct CandidateAndResolver final { - CandidateAndResolver(const Candidate& candidate, - rtc::AsyncResolverInterface* resolver); + CandidateAndResolver( + const Candidate& candidate, + std::unique_ptr&& resolver); ~CandidateAndResolver(); + // Moveable, but not copyable. + CandidateAndResolver(CandidateAndResolver&&) = default; + CandidateAndResolver& operator=(CandidateAndResolver&&) = default; + Candidate candidate_; - rtc::AsyncResolverInterface* resolver_; + std::unique_ptr resolver_; }; std::vector resolvers_ RTC_GUARDED_BY(network_thread_); void FinishAddingRemoteCandidate(const Candidate& new_remote_candidate); - void OnCandidateResolved(rtc::AsyncResolverInterface* resolver); - void AddRemoteCandidateWithResolver(Candidate candidate, - rtc::AsyncResolverInterface* resolver); + void OnCandidateResolved(webrtc::AsyncDnsResolverInterface* resolver); + void AddRemoteCandidateWithResult( + Candidate candidate, + const webrtc::AsyncDnsResolverResult& result); // Number of times the selected_connection_ has been modified. 
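// CandidateAndResolver above is movable but not copyable because it holds a
// std::unique_ptr; its user-declared destructor would suppress the implicit
// move operations, so they are defaulted explicitly, which is what
// std::vector needs for emplace_back and erase. A minimal standalone
// equivalent; the names below are illustrative only and not part of the
// WebRTC API.
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct FakeResolver {};

struct PendingResolution {
  PendingResolution(std::string name, std::unique_ptr<FakeResolver> resolver)
      : name(std::move(name)), resolver(std::move(resolver)) {}
  ~PendingResolution() = default;
  // The user-declared destructor suppresses the implicit move operations, so
  // they are defaulted here; copies are implicitly deleted.
  PendingResolution(PendingResolution&&) = default;
  PendingResolution& operator=(PendingResolution&&) = default;

  std::string name;
  std::unique_ptr<FakeResolver> resolver;
};

int main() {
  std::vector<PendingResolution> pending;
  pending.emplace_back("example.local", std::make_unique<FakeResolver>());
  pending.erase(pending.begin());  // erase() move-assigns the tail elements.
}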
uint32_t selected_candidate_pair_changes_ = 0; diff --git a/p2p/base/p2p_transport_channel_unittest.cc b/p2p/base/p2p_transport_channel_unittest.cc index 3ea9ca72ae..b217a74859 100644 --- a/p2p/base/p2p_transport_channel_unittest.cc +++ b/p2p/base/p2p_transport_channel_unittest.cc @@ -14,6 +14,7 @@ #include #include +#include "api/test/mock_async_dns_resolver.h" #include "p2p/base/basic_ice_controller.h" #include "p2p/base/connection.h" #include "p2p/base/fake_port_allocator.h" @@ -51,9 +52,12 @@ using ::testing::Assign; using ::testing::Contains; using ::testing::DoAll; using ::testing::InSequence; +using ::testing::InvokeArgument; using ::testing::InvokeWithoutArgs; using ::testing::NiceMock; using ::testing::Return; +using ::testing::ReturnRef; +using ::testing::SaveArg; using ::testing::SetArgPointee; using ::testing::SizeIs; @@ -187,6 +191,51 @@ class MockIceControllerFactory : public cricket::IceControllerFactoryInterface { MOCK_METHOD(void, RecordIceControllerCreated, ()); }; +// An one-shot resolver factory with default return arguments. +// Resolution is immediate, always succeeds, and returns nonsense. +class ResolverFactoryFixture : public webrtc::MockAsyncDnsResolverFactory { + public: + ResolverFactoryFixture() { + mock_async_dns_resolver_ = std::make_unique(); + ON_CALL(*mock_async_dns_resolver_, Start(_, _)) + .WillByDefault(InvokeArgument<1>()); + EXPECT_CALL(*mock_async_dns_resolver_, result()) + .WillOnce(ReturnRef(mock_async_dns_resolver_result_)); + + // A default action for GetResolvedAddress. Will be overruled + // by SetAddressToReturn. + ON_CALL(mock_async_dns_resolver_result_, GetResolvedAddress(_, _)) + .WillByDefault(Return(true)); + + EXPECT_CALL(mock_async_dns_resolver_result_, GetError()) + .WillOnce(Return(0)); + EXPECT_CALL(*this, Create()).WillOnce([this]() { + return std::move(mock_async_dns_resolver_); + }); + } + + void SetAddressToReturn(rtc::SocketAddress address_to_return) { + EXPECT_CALL(mock_async_dns_resolver_result_, GetResolvedAddress(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(address_to_return), Return(true))); + } + void DelayResolution() { + // This function must be called before Create(). + ASSERT_TRUE(!!mock_async_dns_resolver_); + EXPECT_CALL(*mock_async_dns_resolver_, Start(_, _)) + .WillOnce(SaveArg<1>(&saved_callback_)); + } + void FireDelayedResolution() { + // This function must be called after Create(). 
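// The call order this fixture expects, as exercised by the tests further down
// in this file: construct it before the channel triggers Create(), optionally
// call DelayResolution() if the test needs to control when the answer arrives,
// install it as the endpoint's resolver factory, provide the address with
// SetAddressToReturn(), and finally FireDelayedResolution() if resolution was
// delayed. Sketch of that sequence, assembled from the tests below:
//
//   ResolverFactoryFixture resolver_fixture;
//   resolver_fixture.DelayResolution();  // Optional; must precede Create().
//   GetEndpoint(1)->async_dns_resolver_factory_ = &resolver_fixture;
//   ...  // The channel asks the factory for a resolver and calls Start().
//   resolver_fixture.SetAddressToReturn(local_address);
//   resolver_fixture.FireDelayedResolution();  // Only if delayed above.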
+ ASSERT_TRUE(saved_callback_); + saved_callback_(); + } + + private: + std::unique_ptr mock_async_dns_resolver_; + webrtc::MockAsyncDnsResolverResult mock_async_dns_resolver_result_; + std::function saved_callback_; +}; + } // namespace namespace cricket { @@ -345,7 +394,7 @@ class P2PTransportChannelTestBase : public ::testing::Test, rtc::FakeNetworkManager network_manager_; std::unique_ptr allocator_; - webrtc::AsyncResolverFactory* async_resolver_factory_; + webrtc::AsyncDnsResolverFactoryInterface* async_dns_resolver_factory_; ChannelData cd1_; ChannelData cd2_; IceRole role_; @@ -378,10 +427,10 @@ class P2PTransportChannelTestBase : public ::testing::Test, IceParamsWithRenomination(kIceParams[0], renomination); IceParameters ice_ep2_cd1_ch = IceParamsWithRenomination(kIceParams[1], renomination); - ep1_.cd1_.ch_.reset(CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, - ice_ep1_cd1_ch, ice_ep2_cd1_ch)); - ep2_.cd1_.ch_.reset(CreateChannel(1, ICE_CANDIDATE_COMPONENT_DEFAULT, - ice_ep2_cd1_ch, ice_ep1_cd1_ch)); + ep1_.cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + ice_ep1_cd1_ch, ice_ep2_cd1_ch); + ep2_.cd1_.ch_ = CreateChannel(1, ICE_CANDIDATE_COMPONENT_DEFAULT, + ice_ep2_cd1_ch, ice_ep1_cd1_ch); ep1_.cd1_.ch_->SetIceConfig(ep1_config); ep2_.cd1_.ch_->SetIceConfig(ep2_config); ep1_.cd1_.ch_->MaybeStartGathering(); @@ -397,13 +446,14 @@ class P2PTransportChannelTestBase : public ::testing::Test, CreateChannels(default_config, default_config, false); } - P2PTransportChannel* CreateChannel(int endpoint, - int component, - const IceParameters& local_ice, - const IceParameters& remote_ice) { - P2PTransportChannel* channel = new P2PTransportChannel( + std::unique_ptr CreateChannel( + int endpoint, + int component, + const IceParameters& local_ice, + const IceParameters& remote_ice) { + auto channel = P2PTransportChannel::Create( "test content name", component, GetAllocator(endpoint), - GetEndpoint(endpoint)->async_resolver_factory_); + GetEndpoint(endpoint)->async_dns_resolver_factory_); channel->SignalReadyToSend.connect( this, &P2PTransportChannelTestBase::OnReadyToSend); channel->SignalCandidateGathered.connect( @@ -2079,8 +2129,8 @@ TEST_F(P2PTransportChannelTest, TurnToTurnPresumedWritable) { kDefaultPortAllocatorFlags); // Only configure one channel so we can control when the remote candidate // is added. 
- GetEndpoint(0)->cd1_.ch_.reset(CreateChannel( - 0, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[0], kIceParams[1])); + GetEndpoint(0)->cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[0], kIceParams[1]); IceConfig config; config.presume_writable_when_fully_relayed = true; ep1_ch1()->SetIceConfig(config); @@ -2128,10 +2178,10 @@ TEST_F(P2PTransportChannelTest, TurnToPrflxPresumedWritable) { test_turn_server()->set_enable_permission_checks(false); IceConfig config; config.presume_writable_when_fully_relayed = true; - GetEndpoint(0)->cd1_.ch_.reset(CreateChannel( - 0, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[0], kIceParams[1])); - GetEndpoint(1)->cd1_.ch_.reset(CreateChannel( - 1, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[1], kIceParams[0])); + GetEndpoint(0)->cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[0], kIceParams[1]); + GetEndpoint(1)->cd1_.ch_ = CreateChannel(1, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[1], kIceParams[0]); ep1_ch1()->SetIceConfig(config); ep2_ch1()->SetIceConfig(config); // Don't signal candidates from channel 2, so that channel 1 sees the TURN @@ -2167,10 +2217,10 @@ TEST_F(P2PTransportChannelTest, PresumedWritablePreferredOverUnreliable) { kDefaultPortAllocatorFlags); IceConfig config; config.presume_writable_when_fully_relayed = true; - GetEndpoint(0)->cd1_.ch_.reset(CreateChannel( - 0, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[0], kIceParams[1])); - GetEndpoint(1)->cd1_.ch_.reset(CreateChannel( - 1, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[1], kIceParams[0])); + GetEndpoint(0)->cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[0], kIceParams[1]); + GetEndpoint(1)->cd1_.ch_ = CreateChannel(1, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[1], kIceParams[0]); ep1_ch1()->SetIceConfig(config); ep2_ch1()->SetIceConfig(config); ep1_ch1()->MaybeStartGathering(); @@ -2205,8 +2255,8 @@ TEST_F(P2PTransportChannelTest, SignalReadyToSendWithPresumedWritable) { kDefaultPortAllocatorFlags); // Only test one endpoint, so we can ensure the connection doesn't receive a // binding response and advance beyond being "presumed" writable. - GetEndpoint(0)->cd1_.ch_.reset(CreateChannel( - 0, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[0], kIceParams[1])); + GetEndpoint(0)->cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[0], kIceParams[1]); IceConfig config; config.presume_writable_when_fully_relayed = true; ep1_ch1()->SetIceConfig(config); @@ -2258,10 +2308,10 @@ TEST_F(P2PTransportChannelTest, // to configure the server to accept packets from an address we haven't // explicitly installed permission for. test_turn_server()->set_enable_permission_checks(false); - GetEndpoint(0)->cd1_.ch_.reset(CreateChannel( - 0, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[0], kIceParams[1])); - GetEndpoint(1)->cd1_.ch_.reset(CreateChannel( - 1, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[1], kIceParams[0])); + GetEndpoint(0)->cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[0], kIceParams[1]); + GetEndpoint(1)->cd1_.ch_ = CreateChannel(1, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[1], kIceParams[0]); // Don't signal candidates from channel 2, so that channel 1 sees the TURN // candidate as peer reflexive. PauseCandidates(1); @@ -2887,6 +2937,53 @@ TEST_F(P2PTransportChannelMultihomedTest, TestPingBackupConnectionRate) { DestroyChannels(); } +// Test that the connection is pinged at a rate no faster than +// what was configured when stable and writable. 
+TEST_F(P2PTransportChannelMultihomedTest, TestStableWritableRate) { + AddAddress(0, kPublicAddrs[0]); + // Adding alternate address will make sure |kPublicAddrs| has the higher + // priority than others. This is due to FakeNetwork::AddInterface method. + AddAddress(1, kAlternateAddrs[1]); + AddAddress(1, kPublicAddrs[1]); + + // Use only local ports for simplicity. + SetAllocatorFlags(0, kOnlyLocalPorts); + SetAllocatorFlags(1, kOnlyLocalPorts); + + // Create channels and let them go writable, as usual. + CreateChannels(); + EXPECT_TRUE_WAIT_MARGIN(CheckConnected(ep1_ch1(), ep2_ch1()), 1000, 1000); + // Set a value larger than the default value of 2500 ms + int ping_interval_ms = 3456; + IceConfig config = CreateIceConfig(2 * ping_interval_ms, GATHER_ONCE); + config.stable_writable_connection_ping_interval = ping_interval_ms; + ep2_ch1()->SetIceConfig(config); + // After the state becomes COMPLETED and is stable and writable, the + // connection will be pinged once every |ping_interval_ms| milliseconds. + ASSERT_TRUE_WAIT(ep2_ch1()->GetState() == IceTransportState::STATE_COMPLETED, + 1000); + auto connections = ep2_ch1()->connections(); + ASSERT_EQ(2U, connections.size()); + Connection* conn = connections[0]; + EXPECT_TRUE_WAIT(conn->writable(), kMediumTimeout); + + int64_t last_ping_response_ms; + // Burn through some pings so the connection is stable. + for (int i = 0; i < 5; i++) { + last_ping_response_ms = conn->last_ping_response_received(); + EXPECT_TRUE_WAIT( + last_ping_response_ms < conn->last_ping_response_received(), + kDefaultTimeout); + } + EXPECT_TRUE(conn->stable(last_ping_response_ms)) << "Connection not stable"; + int time_elapsed = + conn->last_ping_response_received() - last_ping_response_ms; + RTC_LOG(LS_INFO) << "Time elapsed: " << time_elapsed; + EXPECT_GE(time_elapsed, ping_interval_ms); + + DestroyChannels(); +} + TEST_F(P2PTransportChannelMultihomedTest, TestGetState) { rtc::ScopedFakeClock clock; AddAddress(0, kAlternateAddrs[0]); @@ -4834,31 +4931,18 @@ TEST_F(P2PTransportChannelMostLikelyToWorkFirstTest, TestTcpTurn) { // when the address is a hostname. The destruction should happen even // if the channel is not destroyed. TEST(P2PTransportChannelResolverTest, HostnameCandidateIsResolved) { - rtc::MockAsyncResolver mock_async_resolver; - EXPECT_CALL(mock_async_resolver, GetError()).WillOnce(Return(0)); - EXPECT_CALL(mock_async_resolver, GetResolvedAddress(_, _)) - .WillOnce(Return(true)); - // Destroy is called asynchronously after the address is resolved, - // so we need a variable to wait on. 
- bool destroy_called = false; - EXPECT_CALL(mock_async_resolver, Destroy(_)) - .WillOnce(Assign(&destroy_called, true)); - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(&mock_async_resolver)); - + ResolverFactoryFixture resolver_fixture; FakePortAllocator allocator(rtc::Thread::Current(), nullptr); - P2PTransportChannel channel("tn", 0, &allocator, - &mock_async_resolver_factory); + auto channel = + P2PTransportChannel::Create("tn", 0, &allocator, &resolver_fixture); Candidate hostname_candidate; SocketAddress hostname_address("fake.test", 1000); hostname_candidate.set_address(hostname_address); - channel.AddRemoteCandidate(hostname_candidate); + channel->AddRemoteCandidate(hostname_candidate); - ASSERT_EQ_WAIT(1u, channel.remote_candidates().size(), kDefaultTimeout); - const RemoteCandidate& candidate = channel.remote_candidates()[0]; + ASSERT_EQ_WAIT(1u, channel->remote_candidates().size(), kDefaultTimeout); + const RemoteCandidate& candidate = channel->remote_candidates()[0]; EXPECT_FALSE(candidate.address().IsUnresolvedIP()); - WAIT(destroy_called, kShortTimeout); } // Test that if we signal a hostname candidate after the remote endpoint @@ -4867,11 +4951,6 @@ TEST(P2PTransportChannelResolverTest, HostnameCandidateIsResolved) { // done. TEST_F(P2PTransportChannelTest, PeerReflexiveCandidateBeforeSignalingWithMdnsName) { - rtc::MockAsyncResolver mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(&mock_async_resolver)); - // ep1 and ep2 will only gather host candidates with addresses // kPublicAddrs[0] and kPublicAddrs[1], respectively. ConfigureEndpoints(OPEN, OPEN, kOnlyLocalPorts, kOnlyLocalPorts); @@ -4879,7 +4958,9 @@ TEST_F(P2PTransportChannelTest, set_remote_ice_parameter_source(FROM_SETICEPARAMETERS); GetEndpoint(0)->network_manager_.set_mdns_responder( std::make_unique(rtc::Thread::Current())); - GetEndpoint(1)->async_resolver_factory_ = &mock_async_resolver_factory; + + ResolverFactoryFixture resolver_fixture; + GetEndpoint(1)->async_dns_resolver_factory_ = &resolver_fixture; CreateChannels(); // Pause sending candidates from both endpoints until we find out what port // number is assgined to ep1's host candidate. @@ -4894,6 +4975,7 @@ TEST_F(P2PTransportChannelTest, // This is the underlying private IP address of the same candidate at ep1. const auto local_address = rtc::SocketAddress( kPublicAddrs[0].ipaddr(), local_candidate.address().port()); + // Let ep2 signal its candidate to ep1. ep1 should form a candidate // pair and start to ping. After receiving the ping, ep2 discovers a prflx // remote candidate and form a candidate pair as well. @@ -4909,19 +4991,7 @@ TEST_F(P2PTransportChannelTest, EXPECT_EQ(kIceUfrag[0], selected_connection->remote_candidate().username()); EXPECT_EQ(kIcePwd[0], selected_connection->remote_candidate().password()); // Set expectation before ep1 signals a hostname candidate. - { - InSequence sequencer; - EXPECT_CALL(mock_async_resolver, Start(_)); - EXPECT_CALL(mock_async_resolver, GetError()).WillOnce(Return(0)); - // Let the mock resolver of ep2 receives the correct resolution. - EXPECT_CALL(mock_async_resolver, GetResolvedAddress(_, _)) - .WillOnce(DoAll(SetArgPointee<1>(local_address), Return(true))); - } - // Destroy is called asynchronously after the address is resolved, - // so we need a variable to wait on. 
- bool destroy_called = false; - EXPECT_CALL(mock_async_resolver, Destroy(_)) - .WillOnce(Assign(&destroy_called, true)); + resolver_fixture.SetAddressToReturn(local_address); ResumeCandidates(0); // Verify ep2's selected connection is updated to use the 'local' candidate. EXPECT_EQ_WAIT(LOCAL_PORT_TYPE, @@ -4929,7 +4999,6 @@ TEST_F(P2PTransportChannelTest, kMediumTimeout); EXPECT_EQ(selected_connection, ep2_ch1()->selected_connection()); - WAIT(destroy_called, kShortTimeout); DestroyChannels(); } @@ -4939,13 +5008,9 @@ TEST_F(P2PTransportChannelTest, // address after the resolution completes. TEST_F(P2PTransportChannelTest, PeerReflexiveCandidateDuringResolvingHostCandidateWithMdnsName) { - auto mock_async_resolver = new NiceMock(); - ON_CALL(*mock_async_resolver, Destroy).WillByDefault([mock_async_resolver] { - delete mock_async_resolver; - }); - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(mock_async_resolver)); + ResolverFactoryFixture resolver_fixture; + // Prevent resolution until triggered by FireDelayedResolution. + resolver_fixture.DelayResolution(); // ep1 and ep2 will only gather host candidates with addresses // kPublicAddrs[0] and kPublicAddrs[1], respectively. @@ -4954,12 +5019,13 @@ TEST_F(P2PTransportChannelTest, set_remote_ice_parameter_source(FROM_SETICEPARAMETERS); GetEndpoint(0)->network_manager_.set_mdns_responder( std::make_unique(rtc::Thread::Current())); - GetEndpoint(1)->async_resolver_factory_ = &mock_async_resolver_factory; + GetEndpoint(1)->async_dns_resolver_factory_ = &resolver_fixture; CreateChannels(); // Pause sending candidates from both endpoints until we find out what port // number is assgined to ep1's host candidate. PauseCandidates(0); PauseCandidates(1); + ASSERT_EQ_WAIT(1u, GetEndpoint(0)->saved_candidates_.size(), kMediumTimeout); ASSERT_EQ(1u, GetEndpoint(0)->saved_candidates_[0]->candidates.size()); const auto& local_candidate = @@ -4969,24 +5035,16 @@ TEST_F(P2PTransportChannelTest, // This is the underlying private IP address of the same candidate at ep1. const auto local_address = rtc::SocketAddress( kPublicAddrs[0].ipaddr(), local_candidate.address().port()); - bool mock_async_resolver_started = false; - // Not signaling done yet, and only make sure we are in the process of - // resolution. - EXPECT_CALL(*mock_async_resolver, Start(_)) - .WillOnce(InvokeWithoutArgs([&mock_async_resolver_started]() { - mock_async_resolver_started = true; - })); // Let ep1 signal its hostname candidate to ep2. ResumeCandidates(0); - ASSERT_TRUE_WAIT(mock_async_resolver_started, kMediumTimeout); // Now that ep2 is in the process of resolving the hostname candidate signaled // by ep1. Let ep2 signal its host candidate with an IP address to ep1, so // that ep1 can form a candidate pair, select it and start to ping ep2. ResumeCandidates(1); ASSERT_TRUE_WAIT(ep1_ch1()->selected_connection() != nullptr, kMediumTimeout); // Let the mock resolver of ep2 receives the correct resolution. - EXPECT_CALL(*mock_async_resolver, GetResolvedAddress(_, _)) - .WillOnce(DoAll(SetArgPointee<1>(local_address), Return(true))); + resolver_fixture.SetAddressToReturn(local_address); + // Upon receiving a ping from ep1, ep2 adds a prflx candidate from the // unknown address and establishes a connection. // @@ -4997,7 +5055,9 @@ TEST_F(P2PTransportChannelTest, ep2_ch1()->selected_connection()->remote_candidate().type()); // ep2 should also be able resolve the hostname candidate. 
The resolved remote // host candidate should be merged with the prflx remote candidate. - mock_async_resolver->SignalDone(mock_async_resolver); + + resolver_fixture.FireDelayedResolution(); + EXPECT_EQ_WAIT(LOCAL_PORT_TYPE, ep2_ch1()->selected_connection()->remote_candidate().type(), kMediumTimeout); @@ -5010,10 +5070,7 @@ TEST_F(P2PTransportChannelTest, // which is obfuscated by an mDNS name, and if the peer can complete the name // resolution with the correct IP address, we can have a p2p connection. TEST_F(P2PTransportChannelTest, CanConnectWithHostCandidateWithMdnsName) { - NiceMock mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(&mock_async_resolver)); + ResolverFactoryFixture resolver_fixture; // ep1 and ep2 will only gather host candidates with addresses // kPublicAddrs[0] and kPublicAddrs[1], respectively. @@ -5022,7 +5079,7 @@ TEST_F(P2PTransportChannelTest, CanConnectWithHostCandidateWithMdnsName) { set_remote_ice_parameter_source(FROM_SETICEPARAMETERS); GetEndpoint(0)->network_manager_.set_mdns_responder( std::make_unique(rtc::Thread::Current())); - GetEndpoint(1)->async_resolver_factory_ = &mock_async_resolver_factory; + GetEndpoint(1)->async_dns_resolver_factory_ = &resolver_fixture; CreateChannels(); // Pause sending candidates from both endpoints until we find out what port // number is assgined to ep1's host candidate. @@ -5039,8 +5096,7 @@ TEST_F(P2PTransportChannelTest, CanConnectWithHostCandidateWithMdnsName) { rtc::SocketAddress resolved_address_ep1(local_candidate_ep1.address()); resolved_address_ep1.SetResolvedIP(kPublicAddrs[0].ipaddr()); - EXPECT_CALL(mock_async_resolver, GetResolvedAddress(_, _)) - .WillOnce(DoAll(SetArgPointee<1>(resolved_address_ep1), Return(true))); + resolver_fixture.SetAddressToReturn(resolved_address_ep1); // Let ep1 signal its hostname candidate to ep2. ResumeCandidates(0); @@ -5064,10 +5120,7 @@ TEST_F(P2PTransportChannelTest, CanConnectWithHostCandidateWithMdnsName) { // this remote host candidate in stats. TEST_F(P2PTransportChannelTest, CandidatesSanitizedInStatsWhenMdnsObfuscationEnabled) { - NiceMock mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(&mock_async_resolver)); + ResolverFactoryFixture resolver_fixture; // ep1 and ep2 will gather host candidates with addresses // kPublicAddrs[0] and kPublicAddrs[1], respectively. ep1 also gathers a srflx @@ -5079,7 +5132,7 @@ TEST_F(P2PTransportChannelTest, set_remote_ice_parameter_source(FROM_SETICEPARAMETERS); GetEndpoint(0)->network_manager_.set_mdns_responder( std::make_unique(rtc::Thread::Current())); - GetEndpoint(1)->async_resolver_factory_ = &mock_async_resolver_factory; + GetEndpoint(1)->async_dns_resolver_factory_ = &resolver_fixture; CreateChannels(); // Pause sending candidates from both endpoints until we find out what port // number is assigned to ep1's host candidate. @@ -5097,9 +5150,7 @@ TEST_F(P2PTransportChannelTest, // and let the mock resolver of ep2 receive the correct resolution. 
rtc::SocketAddress resolved_address_ep1(local_candidate_ep1.address()); resolved_address_ep1.SetResolvedIP(kPublicAddrs[0].ipaddr()); - EXPECT_CALL(mock_async_resolver, GetResolvedAddress(_, _)) - .WillOnce( - DoAll(SetArgPointee<1>(resolved_address_ep1), Return(true))); + resolver_fixture.SetAddressToReturn(resolved_address_ep1); break; } } @@ -5248,10 +5299,7 @@ TEST_F(P2PTransportChannelTest, // when it is queried via GetSelectedCandidatePair. TEST_F(P2PTransportChannelTest, SelectedCandidatePairSanitizedWhenMdnsObfuscationEnabled) { - NiceMock mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(&mock_async_resolver)); + ResolverFactoryFixture resolver_fixture; // ep1 and ep2 will gather host candidates with addresses // kPublicAddrs[0] and kPublicAddrs[1], respectively. @@ -5260,7 +5308,7 @@ TEST_F(P2PTransportChannelTest, set_remote_ice_parameter_source(FROM_SETICEPARAMETERS); GetEndpoint(0)->network_manager_.set_mdns_responder( std::make_unique(rtc::Thread::Current())); - GetEndpoint(1)->async_resolver_factory_ = &mock_async_resolver_factory; + GetEndpoint(1)->async_dns_resolver_factory_ = &resolver_fixture; CreateChannels(); // Pause sending candidates from both endpoints until we find out what port // number is assigned to ep1's host candidate. @@ -5275,8 +5323,8 @@ TEST_F(P2PTransportChannelTest, // and let the mock resolver of ep2 receive the correct resolution. rtc::SocketAddress resolved_address_ep1(local_candidate_ep1.address()); resolved_address_ep1.SetResolvedIP(kPublicAddrs[0].ipaddr()); - EXPECT_CALL(mock_async_resolver, GetResolvedAddress(_, _)) - .WillOnce(DoAll(SetArgPointee<1>(resolved_address_ep1), Return(true))); + resolver_fixture.SetAddressToReturn(resolved_address_ep1); + ResumeCandidates(0); ResumeCandidates(1); @@ -5305,8 +5353,8 @@ TEST_F(P2PTransportChannelTest, // We use one endpoint to test the behavior of adding remote candidates, and // this endpoint only gathers relay candidates. ConfigureEndpoints(OPEN, OPEN, kOnlyRelayPorts, kDefaultPortAllocatorFlags); - GetEndpoint(0)->cd1_.ch_.reset(CreateChannel( - 0, ICE_CANDIDATE_COMPONENT_DEFAULT, kIceParams[0], kIceParams[1])); + GetEndpoint(0)->cd1_.ch_ = CreateChannel(0, ICE_CANDIDATE_COMPONENT_DEFAULT, + kIceParams[0], kIceParams[1]); IceConfig config; // Start gathering and we should have only a single relay port. 
ep1_ch1()->SetIceConfig(config); @@ -5869,21 +5917,21 @@ class ForgetLearnedStateControllerFactory TEST_F(P2PTransportChannelPingTest, TestForgetLearnedState) { ForgetLearnedStateControllerFactory factory; FakePortAllocator pa(rtc::Thread::Current(), nullptr); - P2PTransportChannel ch("ping sufficiently", 1, &pa, nullptr, nullptr, - &factory); - PrepareChannel(&ch); - ch.MaybeStartGathering(); - ch.AddRemoteCandidate(CreateUdpCandidate(LOCAL_PORT_TYPE, "1.1.1.1", 1, 1)); - ch.AddRemoteCandidate(CreateUdpCandidate(LOCAL_PORT_TYPE, "2.2.2.2", 2, 2)); - - Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); - Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + auto ch = P2PTransportChannel::Create("ping sufficiently", 1, &pa, nullptr, + nullptr, &factory); + PrepareChannel(ch.get()); + ch->MaybeStartGathering(); + ch->AddRemoteCandidate(CreateUdpCandidate(LOCAL_PORT_TYPE, "1.1.1.1", 1, 1)); + ch->AddRemoteCandidate(CreateUdpCandidate(LOCAL_PORT_TYPE, "2.2.2.2", 2, 2)); + + Connection* conn1 = WaitForConnectionTo(ch.get(), "1.1.1.1", 1); + Connection* conn2 = WaitForConnectionTo(ch.get(), "2.2.2.2", 2); ASSERT_TRUE(conn1 != nullptr); ASSERT_TRUE(conn2 != nullptr); // Wait for conn1 to be selected. conn1->ReceivedPingResponse(LOW_RTT, "id"); - EXPECT_EQ_WAIT(conn1, ch.selected_connection(), kMediumTimeout); + EXPECT_EQ_WAIT(conn1, ch->selected_connection(), kMediumTimeout); conn2->ReceivedPingResponse(LOW_RTT, "id"); EXPECT_TRUE(conn2->writable()); @@ -5904,23 +5952,23 @@ TEST_F(P2PTransportChannelTest, DisableDnsLookupsWithTransportPolicyRelay) { auto* ep1 = GetEndpoint(0); ep1->allocator_->SetCandidateFilter(CF_RELAY); - rtc::MockAsyncResolver mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; + std::unique_ptr mock_async_resolver = + std::make_unique(); + // This test expects resolution to not be started. + EXPECT_CALL(*mock_async_resolver, Start(_, _)).Times(0); + + webrtc::MockAsyncDnsResolverFactory mock_async_resolver_factory; ON_CALL(mock_async_resolver_factory, Create()) - .WillByDefault(Return(&mock_async_resolver)); - ep1->async_resolver_factory_ = &mock_async_resolver_factory; + .WillByDefault( + [&mock_async_resolver]() { return std::move(mock_async_resolver); }); - bool lookup_started = false; - ON_CALL(mock_async_resolver, Start(_)) - .WillByDefault(Assign(&lookup_started, true)); + ep1->async_dns_resolver_factory_ = &mock_async_resolver_factory; CreateChannels(); ep1_ch1()->AddRemoteCandidate( CreateUdpCandidate(LOCAL_PORT_TYPE, "hostname.test", 1, 100)); - EXPECT_FALSE(lookup_started); - DestroyChannels(); } @@ -5930,23 +5978,23 @@ TEST_F(P2PTransportChannelTest, DisableDnsLookupsWithTransportPolicyNone) { auto* ep1 = GetEndpoint(0); ep1->allocator_->SetCandidateFilter(CF_NONE); - rtc::MockAsyncResolver mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; + std::unique_ptr mock_async_resolver = + std::make_unique(); + // This test expects resolution to not be started. 
+ EXPECT_CALL(*mock_async_resolver, Start(_, _)).Times(0); + + webrtc::MockAsyncDnsResolverFactory mock_async_resolver_factory; ON_CALL(mock_async_resolver_factory, Create()) - .WillByDefault(Return(&mock_async_resolver)); - ep1->async_resolver_factory_ = &mock_async_resolver_factory; + .WillByDefault( + [&mock_async_resolver]() { return std::move(mock_async_resolver); }); - bool lookup_started = false; - ON_CALL(mock_async_resolver, Start(_)) - .WillByDefault(Assign(&lookup_started, true)); + ep1->async_dns_resolver_factory_ = &mock_async_resolver_factory; CreateChannels(); ep1_ch1()->AddRemoteCandidate( CreateUdpCandidate(LOCAL_PORT_TYPE, "hostname.test", 1, 100)); - EXPECT_FALSE(lookup_started); - DestroyChannels(); } @@ -5956,18 +6004,19 @@ TEST_F(P2PTransportChannelTest, EnableDnsLookupsWithTransportPolicyNoHost) { auto* ep1 = GetEndpoint(0); ep1->allocator_->SetCandidateFilter(CF_ALL & ~CF_HOST); - rtc::MockAsyncResolver mock_async_resolver; - webrtc::MockAsyncResolverFactory mock_async_resolver_factory; - EXPECT_CALL(mock_async_resolver_factory, Create()) - .WillOnce(Return(&mock_async_resolver)); - EXPECT_CALL(mock_async_resolver, Destroy(_)); - - ep1->async_resolver_factory_ = &mock_async_resolver_factory; - + std::unique_ptr mock_async_resolver = + std::make_unique(); bool lookup_started = false; - EXPECT_CALL(mock_async_resolver, Start(_)) + EXPECT_CALL(*mock_async_resolver, Start(_, _)) .WillOnce(Assign(&lookup_started, true)); + webrtc::MockAsyncDnsResolverFactory mock_async_resolver_factory; + EXPECT_CALL(mock_async_resolver_factory, Create()) + .WillOnce( + [&mock_async_resolver]() { return std::move(mock_async_resolver); }); + + ep1->async_dns_resolver_factory_ = &mock_async_resolver_factory; + CreateChannels(); ep1_ch1()->AddRemoteCandidate( diff --git a/p2p/base/port.cc b/p2p/base/port.cc index 7b54c11cb8..a03a0d6a66 100644 --- a/p2p/base/port.cc +++ b/p2p/base/port.cc @@ -33,6 +33,7 @@ #include "rtc_base/string_utils.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/third_party/base64/base64.h" +#include "rtc_base/trace_event.h" #include "system_wrappers/include/field_trial.h" namespace { @@ -104,16 +105,6 @@ std::string Port::ComputeFoundation(const std::string& type, return rtc::ToString(rtc::ComputeCrc32(sb.Release())); } -CandidateStats::CandidateStats() = default; - -CandidateStats::CandidateStats(const CandidateStats&) = default; - -CandidateStats::CandidateStats(Candidate candidate) { - this->candidate = candidate; -} - -CandidateStats::~CandidateStats() = default; - Port::Port(rtc::Thread* thread, const std::string& type, rtc::PacketSocketFactory* factory, @@ -137,6 +128,7 @@ Port::Port(rtc::Thread* thread, tiebreaker_(0), shared_socket_(true), weak_factory_(this) { + RTC_DCHECK(factory_ != NULL); Construct(); } @@ -493,7 +485,8 @@ bool Port::GetStunMessage(const char* data, } // If ICE, and the MESSAGE-INTEGRITY is bad, fail with a 401 Unauthorized - if (!stun_msg->ValidateMessageIntegrity(data, size, password_)) { + if (stun_msg->ValidateMessageIntegrity(password_) != + StunMessage::IntegrityStatus::kIntegrityOk) { RTC_LOG(LS_ERROR) << ToString() << ": Received " << StunMethodToString(stun_msg->type()) << " with bad M-I from " << addr.ToSensitiveString() @@ -559,7 +552,8 @@ bool Port::GetStunMessage(const char* data, // No stun attributes will be verified, if it's stun indication message. // Returning from end of the this method. 
} else if (stun_msg->type() == GOOG_PING_REQUEST) { - if (!stun_msg->ValidateMessageIntegrity32(data, size, password_)) { + if (stun_msg->ValidateMessageIntegrity(password_) != + StunMessage::IntegrityStatus::kIntegrityOk) { RTC_LOG(LS_ERROR) << ToString() << ": Received " << StunMethodToString(stun_msg->type()) << " with bad M-I from " << addr.ToSensitiveString() @@ -833,6 +827,7 @@ void Port::Prune() { // Call to stop any currently pending operations from running. void Port::CancelPendingTasks() { + TRACE_EVENT0("webrtc", "Port::CancelPendingTasks"); RTC_DCHECK_RUN_ON(thread_); thread_->Clear(this); } @@ -849,6 +844,14 @@ void Port::OnMessage(rtc::Message* pmsg) { } } +void Port::SubscribePortDestroyed( + std::function callback) { + port_destroyed_callback_list_.AddReceiver(callback); +} + +void Port::SendPortDestroyed(Port* port) { + port_destroyed_callback_list_.Send(port); +} void Port::OnNetworkTypeChanged(const rtc::Network* network) { RTC_DCHECK(network == network_); @@ -913,7 +916,7 @@ void Port::OnConnectionDestroyed(Connection* conn) { void Port::Destroy() { RTC_DCHECK(connections_.empty()); RTC_LOG(LS_INFO) << ToString() << ": Port deleted"; - SignalDestroyed(this); + SendPortDestroyed(this); delete this; } diff --git a/p2p/base/port.h b/p2p/base/port.h index 43196e5c03..2c18f1adeb 100644 --- a/p2p/base/port.h +++ b/p2p/base/port.h @@ -33,6 +33,7 @@ #include "p2p/base/port_interface.h" #include "p2p/base/stun_request.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/callback_list.h" #include "rtc_base/checks.h" #include "rtc_base/net_helper.h" #include "rtc_base/network.h" @@ -98,14 +99,24 @@ class StunStats { // Stats that we can return about a candidate. class CandidateStats { public: - CandidateStats(); - explicit CandidateStats(Candidate candidate); - CandidateStats(const CandidateStats&); - ~CandidateStats(); + CandidateStats() = default; + CandidateStats(const CandidateStats&) = default; + CandidateStats(CandidateStats&&) = default; + CandidateStats(Candidate candidate, + absl::optional stats = absl::nullopt) + : candidate_(std::move(candidate)), stun_stats_(std::move(stats)) {} + ~CandidateStats() = default; - Candidate candidate; + CandidateStats& operator=(const CandidateStats& other) = default; + + const Candidate& candidate() const { return candidate_; } + + const absl::optional& stun_stats() const { return stun_stats_; } + + private: + Candidate candidate_; // STUN port stats if this candidate is a STUN candidate. - absl::optional stun_stats; + absl::optional stun_stats_; }; typedef std::vector CandidateStatsList; @@ -217,9 +228,6 @@ class Port : public PortInterface, // The factory used to create the sockets of this port. rtc::PacketSocketFactory* socket_factory() const { return factory_; } - void set_socket_factory(rtc::PacketSocketFactory* factory) { - factory_ = factory; - } // For debugging purposes. const std::string& content_name() const { return content_name_; } @@ -269,6 +277,9 @@ class Port : public PortInterface, // connection. sigslot::signal1 SignalPortError; + void SubscribePortDestroyed( + std::function callback) override; + void SendPortDestroyed(Port* port); // Returns a map containing all of the connections of this port, keyed by the // remote address. 
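// A minimal sketch (illustration only, not a hunk of this patch; class and
// member names are made up) of the CallbackList-based destruction
// notification that the port.h/port.cc changes above introduce in place of
// sigslot::signal1<PortInterface*> SignalDestroyed.
#include <functional>
#include <utility>
#include "rtc_base/callback_list.h"

class SketchPort {
 public:
  // Mirrors Port::SubscribePortDestroyed(): callers register a closure
  // instead of connecting a sigslot receiver.
  void SubscribePortDestroyed(std::function<void(SketchPort*)> callback) {
    destroyed_callbacks_.AddReceiver(std::move(callback));
  }
  // Mirrors Port::SendPortDestroyed(): fires every registered callback.
  void NotifyDestroyed() { destroyed_callbacks_.Send(this); }

 private:
  webrtc::CallbackList<SketchPort*> destroyed_callbacks_;
};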
typedef std::map AddressMap; @@ -441,7 +452,7 @@ class Port : public PortInterface, void OnNetworkTypeChanged(const rtc::Network* network); rtc::Thread* const thread_; - rtc::PacketSocketFactory* factory_; + rtc::PacketSocketFactory* const factory_; std::string type_; bool send_retransmit_count_attribute_; rtc::Network* network_; @@ -487,6 +498,7 @@ class Port : public PortInterface, bool is_final); friend class Connection; + webrtc::CallbackList port_destroyed_callback_list_; }; } // namespace cricket diff --git a/p2p/base/port_allocator.cc b/p2p/base/port_allocator.cc index b13896c4bc..d8ff637e2c 100644 --- a/p2p/base/port_allocator.cc +++ b/p2p/base/port_allocator.cc @@ -317,7 +317,8 @@ Candidate PortAllocator::SanitizeCandidate(const Candidate& c) const { // For a local host candidate, we need to conceal its IP address candidate if // the mDNS obfuscation is enabled. bool use_hostname_address = - c.type() == LOCAL_PORT_TYPE && MdnsObfuscationEnabled(); + (c.type() == LOCAL_PORT_TYPE || c.type() == PRFLX_PORT_TYPE) && + MdnsObfuscationEnabled(); // If adapter enumeration is disabled or host candidates are disabled, // clear the raddr of STUN candidates to avoid local address leakage. bool filter_stun_related_address = diff --git a/p2p/base/port_allocator.h b/p2p/base/port_allocator.h index 4bbe56c0b5..33a23484f2 100644 --- a/p2p/base/port_allocator.h +++ b/p2p/base/port_allocator.h @@ -16,6 +16,7 @@ #include #include +#include "api/sequence_checker.h" #include "api/transport/enums.h" #include "p2p/base/port.h" #include "p2p/base/port_interface.h" @@ -25,7 +26,6 @@ #include "rtc_base/system/rtc_export.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace webrtc { class TurnCustomizer; @@ -638,7 +638,7 @@ class RTC_EXPORT PortAllocator : public sigslot::has_slots<> { bool allow_tcp_listen_; uint32_t candidate_filter_; std::string origin_; - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; private: ServerAddresses stun_servers_; diff --git a/p2p/base/port_allocator_unittest.cc b/p2p/base/port_allocator_unittest.cc index 70946a3d81..cbac5cccaf 100644 --- a/p2p/base/port_allocator_unittest.cc +++ b/p2p/base/port_allocator_unittest.cc @@ -305,3 +305,56 @@ TEST_F(PortAllocatorTest, RestrictIceCredentialsChange) { credentials[0].pwd)); allocator_->DiscardCandidatePool(); } + +// Constants for testing candidates +const char kIpv4Address[] = "12.34.56.78"; +const char kIpv4AddressWithPort[] = "12.34.56.78:443"; + +TEST_F(PortAllocatorTest, SanitizeEmptyCandidateDefaultConfig) { + cricket::Candidate input; + cricket::Candidate output = allocator_->SanitizeCandidate(input); + EXPECT_EQ("", output.address().ipaddr().ToString()); +} + +TEST_F(PortAllocatorTest, SanitizeIpv4CandidateDefaultConfig) { + cricket::Candidate input(1, "udp", rtc::SocketAddress(kIpv4Address, 443), 1, + "username", "password", cricket::LOCAL_PORT_TYPE, 1, + "foundation", 1, 1); + cricket::Candidate output = allocator_->SanitizeCandidate(input); + EXPECT_EQ(kIpv4AddressWithPort, output.address().ToString()); + EXPECT_EQ(kIpv4Address, output.address().ipaddr().ToString()); +} + +TEST_F(PortAllocatorTest, SanitizeIpv4CandidateMdnsObfuscationEnabled) { + allocator_->SetMdnsObfuscationEnabledForTesting(true); + cricket::Candidate input(1, "udp", rtc::SocketAddress(kIpv4Address, 443), 1, + "username", "password", cricket::LOCAL_PORT_TYPE, 1, + "foundation", 1, 1); + cricket::Candidate output = 
allocator_->SanitizeCandidate(input); + EXPECT_NE(kIpv4AddressWithPort, output.address().ToString()); + EXPECT_EQ("", output.address().ipaddr().ToString()); +} + +TEST_F(PortAllocatorTest, SanitizePrflxCandidateMdnsObfuscationEnabled) { + allocator_->SetMdnsObfuscationEnabledForTesting(true); + // Create the candidate from an IP literal. This populates the hostname. + cricket::Candidate input(1, "udp", rtc::SocketAddress(kIpv4Address, 443), 1, + "username", "password", cricket::PRFLX_PORT_TYPE, 1, + "foundation", 1, 1); + cricket::Candidate output = allocator_->SanitizeCandidate(input); + EXPECT_NE(kIpv4AddressWithPort, output.address().ToString()); + EXPECT_EQ("", output.address().ipaddr().ToString()); +} + +TEST_F(PortAllocatorTest, SanitizeIpv4NonLiteralMdnsObfuscationEnabled) { + // Create the candidate with an empty hostname. + allocator_->SetMdnsObfuscationEnabledForTesting(true); + rtc::IPAddress ip; + EXPECT_TRUE(IPFromString(kIpv4Address, &ip)); + cricket::Candidate input(1, "udp", rtc::SocketAddress(ip, 443), 1, "username", + "password", cricket::LOCAL_PORT_TYPE, 1, + "foundation", 1, 1); + cricket::Candidate output = allocator_->SanitizeCandidate(input); + EXPECT_NE(kIpv4AddressWithPort, output.address().ToString()); + EXPECT_EQ("", output.address().ipaddr().ToString()); +} diff --git a/p2p/base/port_interface.h b/p2p/base/port_interface.h index 39eae18a0d..73c8e36c78 100644 --- a/p2p/base/port_interface.h +++ b/p2p/base/port_interface.h @@ -12,12 +12,14 @@ #define P2P_BASE_PORT_INTERFACE_H_ #include +#include #include #include "absl/types/optional.h" #include "api/candidate.h" #include "p2p/base/transport_description.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/callback_list.h" #include "rtc_base/socket_address.h" namespace rtc { @@ -112,7 +114,8 @@ class PortInterface { // Signaled when this port decides to delete itself because it no longer has // any usefulness. - sigslot::signal1 SignalDestroyed; + virtual void SubscribePortDestroyed( + std::function callback) = 0; // Signaled when Port discovers ice role conflict with the peer. 
sigslot::signal1 SignalRoleConflict; diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc index 0bb378992b..293a8d1f8b 100644 --- a/p2p/base/port_unittest.cc +++ b/p2p/base/port_unittest.cc @@ -270,7 +270,8 @@ class TestChannel : public sigslot::has_slots<> { explicit TestChannel(std::unique_ptr p1) : port_(std::move(p1)) { port_->SignalPortComplete.connect(this, &TestChannel::OnPortComplete); port_->SignalUnknownAddress.connect(this, &TestChannel::OnUnknownAddress); - port_->SignalDestroyed.connect(this, &TestChannel::OnSrcPortDestroyed); + port_->SubscribePortDestroyed( + [this](PortInterface* port) { OnSrcPortDestroyed(port); }); } int complete_count() { return complete_count_; } @@ -777,7 +778,8 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> { bool role_conflict() const { return role_conflict_; } void ConnectToSignalDestroyed(PortInterface* port) { - port->SignalDestroyed.connect(this, &PortTest::OnDestroyed); + port->SubscribePortDestroyed( + [this](PortInterface* port) { OnDestroyed(port); }); } void OnDestroyed(PortInterface* port) { ++ports_destroyed_; } @@ -1724,9 +1726,8 @@ TEST_F(PortTest, TestSendStunMessage) { EXPECT_EQ(kDefaultPrflxPriority, priority_attr->value()); EXPECT_EQ("rfrag:lfrag", username_attr->GetString()); EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( - lport->last_stun_buf()->data(), lport->last_stun_buf()->size(), - "rpass")); + EXPECT_EQ(StunMessage::IntegrityStatus::kIntegrityOk, + msg->ValidateMessageIntegrity("rpass")); const StunUInt64Attribute* ice_controlling_attr = msg->GetUInt64(STUN_ATTR_ICE_CONTROLLING); ASSERT_TRUE(ice_controlling_attr != NULL); @@ -1765,9 +1766,8 @@ TEST_F(PortTest, TestSendStunMessage) { ASSERT_TRUE(addr_attr != NULL); EXPECT_EQ(lport->Candidates()[0].address(), addr_attr->GetAddress()); EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( - rport->last_stun_buf()->data(), rport->last_stun_buf()->size(), - "rpass")); + EXPECT_EQ(StunMessage::IntegrityStatus::kIntegrityOk, + msg->ValidateMessageIntegrity("rpass")); EXPECT_TRUE(msg->GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); EXPECT_TRUE(StunMessage::ValidateFingerprint( lport->last_stun_buf()->data(), lport->last_stun_buf()->size())); @@ -1796,9 +1796,8 @@ TEST_F(PortTest, TestSendStunMessage) { EXPECT_EQ(STUN_ERROR_SERVER_ERROR, error_attr->code()); EXPECT_EQ(std::string(STUN_ERROR_REASON_SERVER_ERROR), error_attr->reason()); EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); - EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( - rport->last_stun_buf()->data(), rport->last_stun_buf()->size(), - "rpass")); + EXPECT_EQ(StunMessage::IntegrityStatus::kIntegrityOk, + msg->ValidateMessageIntegrity("rpass")); EXPECT_TRUE(msg->GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); EXPECT_TRUE(StunMessage::ValidateFingerprint( lport->last_stun_buf()->data(), lport->last_stun_buf()->size())); diff --git a/p2p/base/stun_port.cc b/p2p/base/stun_port.cc index 4e1a1f6a97..7b1a2a83a2 100644 --- a/p2p/base/stun_port.cc +++ b/p2p/base/stun_port.cc @@ -17,11 +17,11 @@ #include "p2p/base/connection.h" #include "p2p/base/p2p_constants.h" #include "p2p/base/port_allocator.h" +#include "rtc_base/async_resolver_interface.h" #include "rtc_base/checks.h" #include "rtc_base/helpers.h" #include "rtc_base/ip_address.h" #include "rtc_base/logging.h" -#include "rtc_base/net_helpers.h" #include 
"rtc_base/strings/string_builder.h" namespace cricket { @@ -306,7 +306,9 @@ int UDPPort::SendTo(const void* data, if (send_error_count_ < kSendErrorLogLimit) { ++send_error_count_; RTC_LOG(LS_ERROR) << ToString() << ": UDP send of " << size - << " bytes failed with error " << error_; + << " bytes to host " << addr.ToSensitiveString() << " (" + << addr.ToResolvedSensitiveString() + << ") failed with error " << error_; } } else { send_error_count_ = 0; @@ -593,7 +595,11 @@ void UDPPort::OnSendPacket(const void* data, size_t size, StunRequest* req) { options.info_signaled_after_sent.packet_type = rtc::PacketType::kStunMessage; CopyPortInformationToPacketInfo(&options.info_signaled_after_sent); if (socket_->SendTo(data, size, sreq->server_addr(), options) < 0) { - RTC_LOG_ERR_EX(LERROR, socket_->GetError()) << "sendto"; + RTC_LOG_ERR_EX(LERROR, socket_->GetError()) + << "UDP send of " << size << " bytes to host " + << sreq->server_addr().ToSensitiveString() << " (" + << sreq->server_addr().ToResolvedSensitiveString() + << ") failed with error " << error_; } stats_.stun_binding_requests_sent++; } diff --git a/p2p/base/stun_request.cc b/p2p/base/stun_request.cc index 44376ced95..2870dcdfc5 100644 --- a/p2p/base/stun_request.cc +++ b/p2p/base/stun_request.cc @@ -120,6 +120,18 @@ bool StunRequestManager::CheckResponse(StunMessage* msg) { } StunRequest* request = iter->second; + + // Now that we know the request, we can see if the response is + // integrity-protected or not. + // For some tests, the message integrity is not set in the request. + // Complain, and then don't check. + bool skip_integrity_checking = false; + if (request->msg()->integrity() == StunMessage::IntegrityStatus::kNotSet) { + skip_integrity_checking = true; + } else { + msg->ValidateMessageIntegrity(request->msg()->password()); + } + if (!msg->GetNonComprehendedAttributes().empty()) { // If a response contains unknown comprehension-required attributes, it's // simply discarded and the transaction is considered failed. 
See RFC5389 @@ -129,6 +141,9 @@ bool StunRequestManager::CheckResponse(StunMessage* msg) { delete request; return false; } else if (msg->type() == GetStunSuccessResponseType(request->type())) { + if (!msg->IntegrityOk() && !skip_integrity_checking) { + return false; + } request->OnResponse(msg); } else if (msg->type() == GetStunErrorResponseType(request->type())) { request->OnErrorResponse(msg); diff --git a/p2p/base/test_turn_server.h b/p2p/base/test_turn_server.h index d438a83301..ecd934861b 100644 --- a/p2p/base/test_turn_server.h +++ b/p2p/base/test_turn_server.h @@ -14,6 +14,7 @@ #include #include +#include "api/sequence_checker.h" #include "api/transport/stun.h" #include "p2p/base/basic_packet_socket_factory.h" #include "p2p/base/turn_server.h" @@ -21,7 +22,6 @@ #include "rtc_base/ssl_adapter.h" #include "rtc_base/ssl_identity.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace cricket { @@ -147,7 +147,7 @@ class TestTurnServer : public TurnAuthInterface { TurnServer server_; rtc::Thread* thread_; - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; }; } // namespace cricket diff --git a/p2p/base/turn_port.cc b/p2p/base/turn_port.cc index 4d39f207b4..33925d43e7 100644 --- a/p2p/base/turn_port.cc +++ b/p2p/base/turn_port.cc @@ -28,6 +28,7 @@ #include "rtc_base/net_helpers.h" #include "rtc_base/socket_address.h" #include "rtc_base/strings/string_builder.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "system_wrappers/include/field_trial.h" namespace cricket { @@ -346,6 +347,15 @@ void TurnPort::PrepareAddress() { server_address_.address.SetPort(TURN_DEFAULT_PORT); } + if (!AllowedTurnPort(server_address_.address.port())) { + // This can only happen after a 300 ALTERNATE SERVER, since the port can't + // be created with a disallowed port number. + RTC_LOG(LS_ERROR) << "Attempt to start allocation with disallowed port# " + << server_address_.address.port(); + OnAllocateError(STUN_ERROR_SERVER_ERROR, + "Attempt to start allocation to a disallowed port"); + return; + } if (server_address_.address.IsUnresolvedIP()) { ResolveTurnAddress(server_address_.address); } else { @@ -715,16 +725,6 @@ bool TurnPort::HandleIncomingPacket(rtc::AsyncPacketSocket* socket, return false; } - // This must be a response for one of our requests. - // Check success responses, but not errors, for MESSAGE-INTEGRITY. - if (IsStunSuccessResponseType(msg_type) && - !StunMessage::ValidateMessageIntegrity(data, size, hash())) { - RTC_LOG(LS_WARNING) << ToString() - << ": Received TURN message with invalid " - "message integrity, msg_type: " - << msg_type; - return true; - } request_manager_.CheckResponse(data, size); return true; @@ -943,6 +943,21 @@ rtc::DiffServCodePoint TurnPort::StunDscpValue() const { return stun_dscp_value_; } +// static +bool TurnPort::AllowedTurnPort(int port) { + // Port 53, 80 and 443 are used for existing deployments. + // Ports above 1024 are assumed to be OK to use. + if (port == 53 || port == 80 || port == 443 || port >= 1024) { + return true; + } + // Allow any port if relevant field trial is set. This allows disabling the + // check. 
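// A minimal sketch (illustration only) that collapses the new
// TurnPort::AllowedTurnPort() policy added around this point into one
// readable function: ports 53, 80, 443 and anything >= 1024 are accepted,
// everything else is rejected unless the WebRTC-Turn-AllowSystemPorts field
// trial is enabled.
#include "system_wrappers/include/field_trial.h"

bool SketchAllowedTurnPort(int port) {
  if (port == 53 || port == 80 || port == 443 || port >= 1024) {
    return true;
  }
  // Escape hatch: the field trial disables the restriction entirely.
  return webrtc::field_trial::IsEnabled("WebRTC-Turn-AllowSystemPorts");
}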
+ if (webrtc::field_trial::IsEnabled("WebRTC-Turn-AllowSystemPorts")) { + return true; + } + return false; +} + void TurnPort::OnMessage(rtc::Message* message) { switch (message->message_id) { case MSG_ALLOCATE_ERROR: @@ -1274,10 +1289,12 @@ void TurnPort::ScheduleEntryDestruction(TurnEntry* entry) { RTC_DCHECK(!entry->destruction_timestamp().has_value()); int64_t timestamp = rtc::TimeMillis(); entry->set_destruction_timestamp(timestamp); - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, thread(), - rtc::Bind(&TurnPort::DestroyEntryIfNotCancelled, this, entry, timestamp), - TURN_PERMISSION_TIMEOUT); + thread()->PostDelayedTask(ToQueuedTask(task_safety_.flag(), + [this, entry, timestamp] { + DestroyEntryIfNotCancelled( + entry, timestamp); + }), + TURN_PERMISSION_TIMEOUT); } bool TurnPort::SetEntryChannelId(const rtc::SocketAddress& address, diff --git a/p2p/base/turn_port.h b/p2p/base/turn_port.h index a9ec434194..55dbda5ece 100644 --- a/p2p/base/turn_port.h +++ b/p2p/base/turn_port.h @@ -23,9 +23,10 @@ #include "absl/memory/memory.h" #include "p2p/base/port.h" #include "p2p/client/basic_port_allocator.h" -#include "rtc_base/async_invoker.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/async_resolver_interface.h" #include "rtc_base/ssl_certificate.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" namespace webrtc { class TurnCustomizer; @@ -65,6 +66,14 @@ class TurnPort : public Port { webrtc::TurnCustomizer* customizer) { // Do basic parameter validation. if (credentials.username.size() > kMaxTurnUsernameLength) { + RTC_LOG(LS_ERROR) << "Attempt to use TURN with a too long username " + << "of length " << credentials.username.size(); + return nullptr; + } + // Do not connect to low-numbered ports. The default STUN port is 3478. + if (!AllowedTurnPort(server_address.address.port())) { + RTC_LOG(LS_ERROR) << "Attempt to use TURN to connect to port " + << server_address.address.port(); return nullptr; } // Using `new` to access a non-public constructor. @@ -110,6 +119,14 @@ class TurnPort : public Port { rtc::SSLCertificateVerifier* tls_cert_verifier = nullptr) { // Do basic parameter validation. if (credentials.username.size() > kMaxTurnUsernameLength) { + RTC_LOG(LS_ERROR) << "Attempt to use TURN with a too long username " + << "of length " << credentials.username.size(); + return nullptr; + } + // Do not connect to low-numbered ports. The default STUN port is 3478. + if (!AllowedTurnPort(server_address.address.port())) { + RTC_LOG(LS_ERROR) << "Attempt to use TURN to connect to port " + << server_address.address.port(); return nullptr; } // Using `new` to access a non-public constructor. @@ -210,9 +227,6 @@ class TurnPort : public Port { rtc::AsyncPacketSocket* socket() const { return socket_; } - // For testing only. - rtc::AsyncInvoker* invoker() { return &invoker_; } - // Signal with resolved server address. // Parameters are port, server address and resolved server address. // This signal will be sent only if server address is resolved successfully. @@ -295,6 +309,7 @@ class TurnPort : public Port { typedef std::map SocketOptionsMap; typedef std::set AttemptedServerSet; + static bool AllowedTurnPort(int port); void OnMessage(rtc::Message* pmsg) override; bool CreateTurnClientSocket(); @@ -397,8 +412,6 @@ class TurnPort : public Port { // The number of retries made due to allocate mismatch error. size_t allocate_mismatch_retries_; - rtc::AsyncInvoker invoker_; - // Optional TurnCustomizer that can modify outgoing messages. 
Once set, this // must outlive the TurnPort's lifetime. webrtc::TurnCustomizer* turn_customizer_ = nullptr; @@ -411,6 +424,8 @@ class TurnPort : public Port { // to be more easy to work with. std::string turn_logging_id_; + webrtc::ScopedTaskSafety task_safety_; + friend class TurnEntry; friend class TurnAllocateRequest; friend class TurnRefreshRequest; diff --git a/p2p/base/turn_port_unittest.cc b/p2p/base/turn_port_unittest.cc index e5f614e2d6..6d396ad520 100644 --- a/p2p/base/turn_port_unittest.cc +++ b/p2p/base/turn_port_unittest.cc @@ -41,6 +41,7 @@ #include "rtc_base/thread.h" #include "rtc_base/time_utils.h" #include "rtc_base/virtual_socket_server.h" +#include "test/field_trial.h" #include "test/gtest.h" using rtc::SocketAddress; @@ -58,6 +59,15 @@ static const SocketAddress kTurnTcpIntAddr("99.99.99.4", static const SocketAddress kTurnUdpExtAddr("99.99.99.5", 0); static const SocketAddress kTurnAlternateIntAddr("99.99.99.6", cricket::TURN_SERVER_PORT); +// Port for redirecting to a TCP Web server. Should not work. +static const SocketAddress kTurnDangerousAddr("99.99.99.7", 81); +// Port 53 (the DNS port); should work. +static const SocketAddress kTurnPort53Addr("99.99.99.7", 53); +// Port 80 (the HTTP port); should work. +static const SocketAddress kTurnPort80Addr("99.99.99.7", 80); +// Port 443 (the HTTPS port); should work. +static const SocketAddress kTurnPort443Addr("99.99.99.7", 443); +// The default TURN server port. static const SocketAddress kTurnIntAddr("99.99.99.7", cricket::TURN_SERVER_PORT); static const SocketAddress kTurnIPv6IntAddr( @@ -94,6 +104,15 @@ static const cricket::ProtocolAddress kTurnTlsProtoAddr(kTurnTcpIntAddr, cricket::PROTO_TLS); static const cricket::ProtocolAddress kTurnUdpIPv6ProtoAddr(kTurnUdpIPv6IntAddr, cricket::PROTO_UDP); +static const cricket::ProtocolAddress kTurnDangerousProtoAddr( + kTurnDangerousAddr, + cricket::PROTO_TCP); +static const cricket::ProtocolAddress kTurnPort53ProtoAddr(kTurnPort53Addr, + cricket::PROTO_TCP); +static const cricket::ProtocolAddress kTurnPort80ProtoAddr(kTurnPort80Addr, + cricket::PROTO_TCP); +static const cricket::ProtocolAddress kTurnPort443ProtoAddr(kTurnPort443Addr, + cricket::PROTO_TCP); static const unsigned int MSG_TESTFINISH = 0; @@ -335,8 +354,8 @@ class TurnPortTest : public ::testing::Test, this, &TurnPortTest::OnTurnRefreshResult); turn_port_->SignalTurnPortClosed.connect(this, &TurnPortTest::OnTurnPortClosed); - turn_port_->SignalDestroyed.connect(this, - &TurnPortTest::OnTurnPortDestroyed); + turn_port_->SubscribePortDestroyed( + [this](PortInterface* port) { OnTurnPortDestroyed(port); }); } void CreateUdpPort() { CreateUdpPort(kLocalAddr2); } @@ -615,6 +634,11 @@ class TurnPortTest : public ::testing::Test, Port::ORIGIN_MESSAGE); Connection* conn2 = turn_port_->CreateConnection(udp_port_->Candidates()[0], Port::ORIGIN_MESSAGE); + + // Increased to 10 minutes, to ensure that the TurnEntry times out before + // the TurnPort. + turn_port_->set_timeout_delay(10 * 60 * 1000); + ASSERT_TRUE(conn2 != NULL); ASSERT_TRUE_SIMULATED_WAIT(turn_create_permission_success_, kSimulatedRtt, fake_clock_); @@ -631,11 +655,11 @@ class TurnPortTest : public ::testing::Test, EXPECT_TRUE_SIMULATED_WAIT(turn_unknown_address_, kSimulatedRtt, fake_clock_); - // Flush all requests in the invoker to destroy the TurnEntry. + // Wait for TurnEntry to expire. Timeout is 5 minutes. // Expect that it still processes an incoming ping and signals the // unknown address. 
turn_unknown_address_ = false; - turn_port_->invoker()->Flush(rtc::Thread::Current()); + fake_clock_.AdvanceTime(webrtc::TimeDelta::Seconds(5 * 60)); conn1->Ping(0); EXPECT_TRUE_SIMULATED_WAIT(turn_unknown_address_, kSimulatedRtt, fake_clock_); @@ -1785,4 +1809,58 @@ TEST_F(TurnPortTest, TestOverlongUsername) { CreateTurnPort(overlong_username, kTurnPassword, kTurnTlsProtoAddr)); } +TEST_F(TurnPortTest, TestTurnDangerousServer) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnDangerousProtoAddr); + ASSERT_FALSE(turn_port_); +} + +TEST_F(TurnPortTest, TestTurnDangerousServerPermits53) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnPort53ProtoAddr); + ASSERT_TRUE(turn_port_); +} + +TEST_F(TurnPortTest, TestTurnDangerousServerPermits80) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnPort80ProtoAddr); + ASSERT_TRUE(turn_port_); +} + +TEST_F(TurnPortTest, TestTurnDangerousServerPermits443) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnPort443ProtoAddr); + ASSERT_TRUE(turn_port_); +} + +TEST_F(TurnPortTest, TestTurnDangerousAlternateServer) { + const ProtocolType protocol_type = PROTO_TCP; + std::vector redirect_addresses; + redirect_addresses.push_back(kTurnDangerousAddr); + + TestTurnRedirector redirector(redirect_addresses); + + turn_server_.AddInternalSocket(kTurnIntAddr, protocol_type); + turn_server_.AddInternalSocket(kTurnDangerousAddr, protocol_type); + turn_server_.set_redirect_hook(&redirector); + CreateTurnPort(kTurnUsername, kTurnPassword, + ProtocolAddress(kTurnIntAddr, protocol_type)); + + // Retrieve the address before we run the state machine. + const SocketAddress old_addr = turn_port_->server_address().address; + + turn_port_->PrepareAddress(); + // This should result in an error event. + EXPECT_TRUE_SIMULATED_WAIT(error_event_.error_code != 0, + TimeToGetAlternateTurnCandidate(protocol_type), + fake_clock_); + // but should NOT result in the port turning ready, and no candidates + // should be gathered. 
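// A minimal sketch (assumption: the test fixture owns an rtc::ScopedFakeClock,
// which this excerpt does not show) of driving time-based cleanup under
// simulated time instead of flushing an AsyncInvoker, as the TurnEntry-expiry
// change above now does.
#include "api/units/time_delta.h"
#include "rtc_base/fake_clock.h"

void AdvancePastTurnPermissionTimeout(rtc::ScopedFakeClock& fake_clock) {
  // The comment above states the entry timeout is 5 minutes; jumping past it
  // makes the delayed destruction posted by ScheduleEntryDestruction due.
  fake_clock.AdvanceTime(webrtc::TimeDelta::Seconds(5 * 60));
}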
+ EXPECT_FALSE(turn_ready_); + ASSERT_EQ(0U, turn_port_->Candidates().size()); +} + +TEST_F(TurnPortTest, TestTurnDangerousServerAllowedWithFieldTrial) { + webrtc::test::ScopedFieldTrials override_field_trials( + "WebRTC-Turn-AllowSystemPorts/Enabled/"); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnDangerousProtoAddr); + ASSERT_TRUE(turn_port_); +} + } // namespace cricket diff --git a/p2p/base/turn_server.cc b/p2p/base/turn_server.cc index 17a49e403d..53f283bc96 100644 --- a/p2p/base/turn_server.cc +++ b/p2p/base/turn_server.cc @@ -15,10 +15,10 @@ #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "api/packet_socket_factory.h" #include "api/transport/stun.h" #include "p2p/base/async_stun_tcp_socket.h" -#include "rtc_base/bind.h" #include "rtc_base/byte_buffer.h" #include "rtc_base/checks.h" #include "rtc_base/helpers.h" @@ -26,6 +26,7 @@ #include "rtc_base/message_digest.h" #include "rtc_base/socket_adapters.h" #include "rtc_base/strings/string_builder.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" namespace cricket { @@ -129,7 +130,7 @@ TurnServer::TurnServer(rtc::Thread* thread) enable_otu_nonce_(false) {} TurnServer::~TurnServer() { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); for (InternalSocketMap::iterator it = server_sockets_.begin(); it != server_sockets_.end(); ++it) { rtc::AsyncPacketSocket* socket = it->first; @@ -145,7 +146,7 @@ TurnServer::~TurnServer() { void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket, ProtocolType proto) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(server_sockets_.end() == server_sockets_.find(socket)); server_sockets_[socket] = proto; socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket); @@ -153,7 +154,7 @@ void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket, void TurnServer::AddInternalServerSocket(rtc::AsyncSocket* socket, ProtocolType proto) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(server_listen_sockets_.end() == server_listen_sockets_.find(socket)); server_listen_sockets_[socket] = proto; @@ -163,20 +164,19 @@ void TurnServer::AddInternalServerSocket(rtc::AsyncSocket* socket, void TurnServer::SetExternalSocketFactory( rtc::PacketSocketFactory* factory, const rtc::SocketAddress& external_addr) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); external_socket_factory_.reset(factory); external_addr_ = external_addr; } void TurnServer::OnNewInternalConnection(rtc::AsyncSocket* socket) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(server_listen_sockets_.find(socket) != server_listen_sockets_.end()); AcceptConnection(socket); } void TurnServer::AcceptConnection(rtc::AsyncSocket* server_socket) { - RTC_DCHECK(thread_checker_.IsCurrent()); // Check if someone is trying to connect to us. 
rtc::SocketAddress accept_addr; rtc::AsyncSocket* accepted_socket = server_socket->Accept(&accept_addr); @@ -193,7 +193,7 @@ void TurnServer::AcceptConnection(rtc::AsyncSocket* server_socket) { void TurnServer::OnInternalSocketClose(rtc::AsyncPacketSocket* socket, int err) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); DestroyInternalSocket(socket); } @@ -202,7 +202,7 @@ void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket, size_t size, const rtc::SocketAddress& addr, const int64_t& /* packet_time_us */) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); // Fail if the packet is too small to even contain a channel header. if (size < TURN_CHANNEL_HEADER_SIZE) { return; @@ -229,7 +229,6 @@ void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket, void TurnServer::HandleStunMessage(TurnServerConnection* conn, const char* data, size_t size) { - RTC_DCHECK(thread_checker_.IsCurrent()); TurnMessage msg; rtc::ByteBufferReader buf(data, size); if (!msg.Read(&buf) || (buf.Length() > 0)) { @@ -295,7 +294,6 @@ void TurnServer::HandleStunMessage(TurnServerConnection* conn, } bool TurnServer::GetKey(const StunMessage* msg, std::string* key) { - RTC_DCHECK(thread_checker_.IsCurrent()); const StunByteStringAttribute* username_attr = msg->GetByteString(STUN_ATTR_USERNAME); if (!username_attr) { @@ -307,11 +305,10 @@ bool TurnServer::GetKey(const StunMessage* msg, std::string* key) { } bool TurnServer::CheckAuthorization(TurnServerConnection* conn, - const StunMessage* msg, + StunMessage* msg, const char* data, size_t size, const std::string& key) { - RTC_DCHECK(thread_checker_.IsCurrent()); // RFC 5389, 10.2.2. RTC_DCHECK(IsStunRequestType(msg->type())); const StunByteStringAttribute* mi_attr = @@ -323,14 +320,14 @@ bool TurnServer::CheckAuthorization(TurnServerConnection* conn, const StunByteStringAttribute* nonce_attr = msg->GetByteString(STUN_ATTR_NONCE); - // Fail if no M-I. + // Fail if no MESSAGE_INTEGRITY. if (!mi_attr) { SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_UNAUTHORIZED, STUN_ERROR_REASON_UNAUTHORIZED); return false; } - // Fail if there is M-I but no username, nonce, or realm. + // Fail if there is MESSAGE_INTEGRITY but no username, nonce, or realm. if (!username_attr || !realm_attr || !nonce_attr) { SendErrorResponse(conn, msg, STUN_ERROR_BAD_REQUEST, STUN_ERROR_REASON_BAD_REQUEST); @@ -344,9 +341,9 @@ bool TurnServer::CheckAuthorization(TurnServerConnection* conn, return false; } - // Fail if bad username or M-I. - // We need |data| and |size| for the call to ValidateMessageIntegrity. - if (key.empty() || !StunMessage::ValidateMessageIntegrity(data, size, key)) { + // Fail if bad MESSAGE_INTEGRITY. + if (key.empty() || msg->ValidateMessageIntegrity(key) != + StunMessage::IntegrityStatus::kIntegrityOk) { SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_UNAUTHORIZED, STUN_ERROR_REASON_UNAUTHORIZED); return false; @@ -370,7 +367,6 @@ bool TurnServer::CheckAuthorization(TurnServerConnection* conn, void TurnServer::HandleBindingRequest(TurnServerConnection* conn, const StunMessage* req) { - RTC_DCHECK(thread_checker_.IsCurrent()); StunMessage response; InitResponse(req, &response); @@ -385,7 +381,6 @@ void TurnServer::HandleBindingRequest(TurnServerConnection* conn, void TurnServer::HandleAllocateRequest(TurnServerConnection* conn, const TurnMessage* msg, const std::string& key) { - RTC_DCHECK(thread_checker_.IsCurrent()); // Check the parameters in the request. 
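// A minimal sketch (illustration only) of the instance-based integrity check
// this patch adopts in port.cc, stun_request.cc and turn_server.cc in place
// of the static StunMessage::ValidateMessageIntegrity(data, size, key)
// overloads: the message validates itself against a password and caches the
// verdict, which later code can read back via integrity() / IntegrityOk().
#include <string>
#include "api/transport/stun.h"

bool SketchIntegrityOk(cricket::StunMessage* msg, const std::string& password) {
  return msg->ValidateMessageIntegrity(password) ==
         cricket::StunMessage::IntegrityStatus::kIntegrityOk;
}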
const StunUInt32Attribute* transport_attr = msg->GetUInt32(STUN_ATTR_REQUESTED_TRANSPORT); @@ -415,7 +410,6 @@ void TurnServer::HandleAllocateRequest(TurnServerConnection* conn, } std::string TurnServer::GenerateNonce(int64_t now) const { - RTC_DCHECK(thread_checker_.IsCurrent()); // Generate a nonce of the form hex(now + HMAC-MD5(nonce_key_, now)) std::string input(reinterpret_cast(&now), sizeof(now)); std::string nonce = rtc::hex_encode(input.c_str(), input.size()); @@ -426,7 +420,6 @@ std::string TurnServer::GenerateNonce(int64_t now) const { } bool TurnServer::ValidateNonce(const std::string& nonce) const { - RTC_DCHECK(thread_checker_.IsCurrent()); // Check the size. if (nonce.size() != kNonceSize) { return false; @@ -453,7 +446,6 @@ bool TurnServer::ValidateNonce(const std::string& nonce) const { } TurnServerAllocation* TurnServer::FindAllocation(TurnServerConnection* conn) { - RTC_DCHECK(thread_checker_.IsCurrent()); AllocationMap::const_iterator it = allocations_.find(*conn); return (it != allocations_.end()) ? it->second.get() : nullptr; } @@ -461,7 +453,6 @@ TurnServerAllocation* TurnServer::FindAllocation(TurnServerConnection* conn) { TurnServerAllocation* TurnServer::CreateAllocation(TurnServerConnection* conn, int proto, const std::string& key) { - RTC_DCHECK(thread_checker_.IsCurrent()); rtc::AsyncPacketSocket* external_socket = (external_socket_factory_) ? external_socket_factory_->CreateUdpSocket(external_addr_, 0, 0) @@ -482,7 +473,7 @@ void TurnServer::SendErrorResponse(TurnServerConnection* conn, const StunMessage* req, int code, const std::string& reason) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); TurnMessage resp; InitErrorResponse(req, code, reason, &resp); RTC_LOG(LS_INFO) << "Sending error response, type=" << resp.type() @@ -494,7 +485,6 @@ void TurnServer::SendErrorResponseWithRealmAndNonce(TurnServerConnection* conn, const StunMessage* msg, int code, const std::string& reason) { - RTC_DCHECK(thread_checker_.IsCurrent()); TurnMessage resp; InitErrorResponse(msg, code, reason, &resp); @@ -514,7 +504,6 @@ void TurnServer::SendErrorResponseWithAlternateServer( TurnServerConnection* conn, const StunMessage* msg, const rtc::SocketAddress& addr) { - RTC_DCHECK(thread_checker_.IsCurrent()); TurnMessage resp; InitErrorResponse(msg, STUN_ERROR_TRY_ALTERNATE, STUN_ERROR_REASON_TRY_ALTERNATE_SERVER, &resp); @@ -524,7 +513,7 @@ void TurnServer::SendErrorResponseWithAlternateServer( } void TurnServer::SendStun(TurnServerConnection* conn, StunMessage* msg) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); rtc::ByteBufferWriter buf; // Add a SOFTWARE attribute if one is set. if (!software_.empty()) { @@ -537,13 +526,12 @@ void TurnServer::SendStun(TurnServerConnection* conn, StunMessage* msg) { void TurnServer::Send(TurnServerConnection* conn, const rtc::ByteBufferWriter& buf) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); rtc::PacketOptions options; conn->socket()->SendTo(buf.Data(), buf.Length(), conn->src(), options); } void TurnServer::OnAllocationDestroyed(TurnServerAllocation* allocation) { - RTC_DCHECK(thread_checker_.IsCurrent()); // Removing the internal socket if the connection is not udp. 
rtc::AsyncPacketSocket* socket = allocation->conn()->socket(); InternalSocketMap::iterator iter = server_sockets_.find(socket); @@ -563,27 +551,21 @@ void TurnServer::OnAllocationDestroyed(TurnServerAllocation* allocation) { } void TurnServer::DestroyInternalSocket(rtc::AsyncPacketSocket* socket) { - RTC_DCHECK(thread_checker_.IsCurrent()); InternalSocketMap::iterator iter = server_sockets_.find(socket); if (iter != server_sockets_.end()) { rtc::AsyncPacketSocket* socket = iter->first; socket->SignalReadPacket.disconnect(this); server_sockets_.erase(iter); + std::unique_ptr socket_to_delete = + absl::WrapUnique(socket); // We must destroy the socket async to avoid invalidating the sigslot // callback list iterator inside a sigslot callback. (In other words, // deleting an object from within a callback from that object). - sockets_to_delete_.push_back( - std::unique_ptr(socket)); - invoker_.AsyncInvoke(RTC_FROM_HERE, rtc::Thread::Current(), - rtc::Bind(&TurnServer::FreeSockets, this)); + thread_->PostTask(webrtc::ToQueuedTask( + [socket_to_delete = std::move(socket_to_delete)] {})); } } -void TurnServer::FreeSockets() { - RTC_DCHECK(thread_checker_.IsCurrent()); - sockets_to_delete_.clear(); -} - TurnServerConnection::TurnServerConnection(const rtc::SocketAddress& src, ProtocolType proto, rtc::AsyncPacketSocket* socket) diff --git a/p2p/base/turn_server.h b/p2p/base/turn_server.h index ca856448b3..f90c3dac0d 100644 --- a/p2p/base/turn_server.h +++ b/p2p/base/turn_server.h @@ -19,13 +19,12 @@ #include #include +#include "api/sequence_checker.h" #include "p2p/base/port_interface.h" -#include "rtc_base/async_invoker.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/socket_address.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace rtc { class ByteBufferWriter; @@ -129,8 +128,8 @@ class TurnServerAllocation : public rtc::MessageHandlerAutoCleanup, void OnChannelDestroyed(Channel* channel); void OnMessage(rtc::Message* msg) override; - TurnServer* server_; - rtc::Thread* thread_; + TurnServer* const server_; + rtc::Thread* const thread_; TurnServerConnection conn_; std::unique_ptr external_socket_; std::string key_; @@ -183,53 +182,53 @@ class TurnServer : public sigslot::has_slots<> { // Gets/sets the realm value to use for the server. const std::string& realm() const { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); return realm_; } void set_realm(const std::string& realm) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); realm_ = realm; } // Gets/sets the value for the SOFTWARE attribute for TURN messages. const std::string& software() const { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); return software_; } void set_software(const std::string& software) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); software_ = software; } const AllocationMap& allocations() const { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); return allocations_; } // Sets the authentication callback; does not take ownership. 
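// A minimal sketch (illustration only) of the deferred-deletion idiom used in
// TurnServer::DestroyInternalSocket() above: instead of parking the socket in
// sockets_to_delete_ and flushing it via AsyncInvoker, ownership is moved into
// an empty queued task, so the socket is destroyed on |thread| only after the
// current sigslot callback has unwound.
#include <memory>
#include <utility>
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread.h"

void SketchDeleteSocketAsync(rtc::Thread* thread,
                             std::unique_ptr<rtc::AsyncPacketSocket> socket) {
  // The lambda body is intentionally empty; destroying the capture (when the
  // task runs, or when the task itself is destroyed) deletes the socket.
  thread->PostTask(webrtc::ToQueuedTask([socket = std::move(socket)] {}));
}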
void set_auth_hook(TurnAuthInterface* auth_hook) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); auth_hook_ = auth_hook; } void set_redirect_hook(TurnRedirectInterface* redirect_hook) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); redirect_hook_ = redirect_hook; } void set_enable_otu_nonce(bool enable) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); enable_otu_nonce_ = enable; } // If set to true, reject CreatePermission requests to RFC1918 addresses. void set_reject_private_addresses(bool filter) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); reject_private_addresses_ = filter; } void set_enable_permission_checks(bool enable) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); enable_permission_checks_ = enable; } @@ -244,18 +243,22 @@ class TurnServer : public sigslot::has_slots<> { const rtc::SocketAddress& address); // For testing only. std::string SetTimestampForNextNonce(int64_t timestamp) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); ts_for_next_nonce_ = timestamp; return GenerateNonce(timestamp); } void SetStunMessageObserver(std::unique_ptr observer) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); stun_message_observer_ = std::move(observer); } private: - std::string GenerateNonce(int64_t now) const; + // All private member functions and variables should have access restricted to + // thread_. But compile-time annotations are missing for members access from + // TurnServerAllocation (via friend declaration), and the On* methods, which + // are called via sigslot. + std::string GenerateNonce(int64_t now) const RTC_RUN_ON(thread_); void OnInternalPacket(rtc::AsyncPacketSocket* socket, const char* data, size_t size, @@ -265,29 +268,32 @@ class TurnServer : public sigslot::has_slots<> { void OnNewInternalConnection(rtc::AsyncSocket* socket); // Accept connections on this server socket. 
- void AcceptConnection(rtc::AsyncSocket* server_socket); + void AcceptConnection(rtc::AsyncSocket* server_socket) RTC_RUN_ON(thread_); void OnInternalSocketClose(rtc::AsyncPacketSocket* socket, int err); void HandleStunMessage(TurnServerConnection* conn, const char* data, - size_t size); - void HandleBindingRequest(TurnServerConnection* conn, const StunMessage* msg); + size_t size) RTC_RUN_ON(thread_); + void HandleBindingRequest(TurnServerConnection* conn, const StunMessage* msg) + RTC_RUN_ON(thread_); void HandleAllocateRequest(TurnServerConnection* conn, const TurnMessage* msg, - const std::string& key); + const std::string& key) RTC_RUN_ON(thread_); - bool GetKey(const StunMessage* msg, std::string* key); + bool GetKey(const StunMessage* msg, std::string* key) RTC_RUN_ON(thread_); bool CheckAuthorization(TurnServerConnection* conn, - const StunMessage* msg, + StunMessage* msg, const char* data, size_t size, - const std::string& key); - bool ValidateNonce(const std::string& nonce) const; + const std::string& key) RTC_RUN_ON(thread_); + bool ValidateNonce(const std::string& nonce) const RTC_RUN_ON(thread_); - TurnServerAllocation* FindAllocation(TurnServerConnection* conn); + TurnServerAllocation* FindAllocation(TurnServerConnection* conn) + RTC_RUN_ON(thread_); TurnServerAllocation* CreateAllocation(TurnServerConnection* conn, int proto, - const std::string& key); + const std::string& key) + RTC_RUN_ON(thread_); void SendErrorResponse(TurnServerConnection* conn, const StunMessage* req, @@ -297,55 +303,53 @@ class TurnServer : public sigslot::has_slots<> { void SendErrorResponseWithRealmAndNonce(TurnServerConnection* conn, const StunMessage* req, int code, - const std::string& reason); + const std::string& reason) + RTC_RUN_ON(thread_); void SendErrorResponseWithAlternateServer(TurnServerConnection* conn, const StunMessage* req, - const rtc::SocketAddress& addr); + const rtc::SocketAddress& addr) + RTC_RUN_ON(thread_); void SendStun(TurnServerConnection* conn, StunMessage* msg); void Send(TurnServerConnection* conn, const rtc::ByteBufferWriter& buf); - void OnAllocationDestroyed(TurnServerAllocation* allocation); - void DestroyInternalSocket(rtc::AsyncPacketSocket* socket); - - // Just clears |sockets_to_delete_|; called asynchronously. - void FreeSockets(); + void OnAllocationDestroyed(TurnServerAllocation* allocation) + RTC_RUN_ON(thread_); + void DestroyInternalSocket(rtc::AsyncPacketSocket* socket) + RTC_RUN_ON(thread_); typedef std::map InternalSocketMap; typedef std::map ServerSocketMap; - rtc::Thread* thread_; - rtc::ThreadChecker thread_checker_; - std::string nonce_key_; - std::string realm_; - std::string software_; - TurnAuthInterface* auth_hook_; - TurnRedirectInterface* redirect_hook_; + rtc::Thread* const thread_; + const std::string nonce_key_; + std::string realm_ RTC_GUARDED_BY(thread_); + std::string software_ RTC_GUARDED_BY(thread_); + TurnAuthInterface* auth_hook_ RTC_GUARDED_BY(thread_); + TurnRedirectInterface* redirect_hook_ RTC_GUARDED_BY(thread_); // otu - one-time-use. Server will respond with 438 if it's // sees the same nonce in next transaction. - bool enable_otu_nonce_; + bool enable_otu_nonce_ RTC_GUARDED_BY(thread_); bool reject_private_addresses_ = false; // Check for permission when receiving an external packet. bool enable_permission_checks_ = true; - InternalSocketMap server_sockets_; - ServerSocketMap server_listen_sockets_; - // Used when we need to delete a socket asynchronously. 
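// A minimal sketch (illustration only, simplified names, assuming the usual
// WebRTC header locations) of the annotation style this patch rolls out
// across TurnServer: members are tagged with RTC_GUARDED_BY(thread_), private
// helpers with RTC_RUN_ON(thread_), and entry points assert the thread at
// runtime with RTC_DCHECK_RUN_ON, so clang's thread-safety analysis can flag
// cross-thread access at compile time.
#include <string>
#include "api/sequence_checker.h"
#include "rtc_base/checks.h"
#include "rtc_base/thread.h"
#include "rtc_base/thread_annotations.h"

class SketchServer {
 public:
  explicit SketchServer(rtc::Thread* thread) : thread_(thread) {}
  void set_realm(const std::string& realm) {
    RTC_DCHECK_RUN_ON(thread_);  // runtime assert + analysis assertion
    realm_ = realm;
  }

 private:
  // Compile-time requirement: may only be called on |thread_|.
  std::string GenerateNonce() const RTC_RUN_ON(thread_) { return realm_; }

  rtc::Thread* const thread_;
  std::string realm_ RTC_GUARDED_BY(thread_);
};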
- std::vector> sockets_to_delete_; - std::unique_ptr external_socket_factory_; - rtc::SocketAddress external_addr_; - - AllocationMap allocations_; + InternalSocketMap server_sockets_ RTC_GUARDED_BY(thread_); + ServerSocketMap server_listen_sockets_ RTC_GUARDED_BY(thread_); + std::unique_ptr external_socket_factory_ + RTC_GUARDED_BY(thread_); + rtc::SocketAddress external_addr_ RTC_GUARDED_BY(thread_); - rtc::AsyncInvoker invoker_; + AllocationMap allocations_ RTC_GUARDED_BY(thread_); // For testing only. If this is non-zero, the next NONCE will be generated // from this value, and it will be reset to 0 after generating the NONCE. - int64_t ts_for_next_nonce_ = 0; + int64_t ts_for_next_nonce_ RTC_GUARDED_BY(thread_) = 0; // For testing only. Used to observe STUN messages received. - std::unique_ptr stun_message_observer_; + std::unique_ptr stun_message_observer_ + RTC_GUARDED_BY(thread_); friend class TurnServerAllocation; }; diff --git a/p2p/client/basic_port_allocator.cc b/p2p/client/basic_port_allocator.cc index bb640d9498..1d38a4c19f 100644 --- a/p2p/client/basic_port_allocator.cc +++ b/p2p/client/basic_port_allocator.cc @@ -12,12 +12,14 @@ #include #include +#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "p2p/base/basic_packet_socket_factory.h" #include "p2p/base/port.h" #include "p2p/base/stun_port.h" @@ -27,6 +29,8 @@ #include "rtc_base/checks.h" #include "rtc_base/helpers.h" #include "rtc_base/logging.h" +#include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/trace_event.h" #include "system_wrappers/include/field_trial.h" #include "system_wrappers/include/metrics.h" @@ -35,15 +39,6 @@ using rtc::CreateRandomId; namespace cricket { namespace { -enum { - MSG_CONFIG_START, - MSG_CONFIG_READY, - MSG_ALLOCATE, - MSG_ALLOCATION_PHASE, - MSG_SEQUENCEOBJECTS_CREATED, - MSG_CONFIG_STOP, -}; - const int PHASE_UDP = 0; const int PHASE_RELAY = 1; const int PHASE_TCP = 2; @@ -268,16 +263,18 @@ BasicPortAllocatorSession::BasicPortAllocatorSession( network_manager_started_(false), allocation_sequences_created_(false), turn_port_prune_policy_(allocator->turn_port_prune_policy()) { + TRACE_EVENT0("webrtc", + "BasicPortAllocatorSession::BasicPortAllocatorSession"); allocator_->network_manager()->SignalNetworksChanged.connect( this, &BasicPortAllocatorSession::OnNetworksChanged); allocator_->network_manager()->StartUpdating(); } BasicPortAllocatorSession::~BasicPortAllocatorSession() { + TRACE_EVENT0("webrtc", + "BasicPortAllocatorSession::~BasicPortAllocatorSession"); RTC_DCHECK_RUN_ON(network_thread_); allocator_->network_manager()->StopUpdating(); - if (network_thread_ != NULL) - network_thread_->Clear(this); for (uint32_t i = 0; i < sequences_.size(); ++i) { // AllocationSequence should clear it's map entry for turn ports before @@ -289,8 +286,7 @@ BasicPortAllocatorSession::~BasicPortAllocatorSession() { for (it = ports_.begin(); it != ports_.end(); it++) delete it->port(); - for (uint32_t i = 0; i < configs_.size(); ++i) - delete configs_[i]; + configs_.clear(); for (uint32_t i = 0; i < sequences_.size(); ++i) delete sequences_[i]; @@ -370,7 +366,8 @@ void BasicPortAllocatorSession::StartGettingPorts() { socket_factory_ = owned_socket_factory_.get(); } - network_thread_->Post(RTC_FROM_HERE, this, MSG_CONFIG_START); + network_thread_->PostTask(webrtc::ToQueuedTask( + network_safety_, [this] { GetPortConfigurations(); })); RTC_LOG(LS_INFO) << "Start getting ports with turn_port_prune_policy " << 
turn_port_prune_policy_; @@ -386,11 +383,12 @@ void BasicPortAllocatorSession::StopGettingPorts() { void BasicPortAllocatorSession::ClearGettingPorts() { RTC_DCHECK_RUN_ON(network_thread_); - network_thread_->Clear(this, MSG_ALLOCATE); + ++allocation_epoch_; for (uint32_t i = 0; i < sequences_.size(); ++i) { sequences_[i]->Stop(); } - network_thread_->Post(RTC_FROM_HERE, this, MSG_CONFIG_STOP); + network_thread_->PostTask( + webrtc::ToQueuedTask(network_safety_, [this] { OnConfigStop(); })); state_ = SessionState::CLEARED; } @@ -489,8 +487,10 @@ void BasicPortAllocatorSession::GetCandidateStatsFromReadyPorts( for (auto* port : ports) { auto candidates = port->Candidates(); for (const auto& candidate : candidates) { - CandidateStats candidate_stats(allocator_->SanitizeCandidate(candidate)); - port->GetStunStats(&candidate_stats.stun_stats); + absl::optional stun_stats; + port->GetStunStats(&stun_stats); + CandidateStats candidate_stats(allocator_->SanitizeCandidate(candidate), + std::move(stun_stats)); candidate_stats_list->push_back(std::move(candidate_stats)); } } @@ -574,28 +574,6 @@ bool BasicPortAllocatorSession::CandidatesAllocationDone() const { ports_, [](const PortData& port) { return port.inprogress(); }); } -void BasicPortAllocatorSession::OnMessage(rtc::Message* message) { - switch (message->message_id) { - case MSG_CONFIG_START: - GetPortConfigurations(); - break; - case MSG_CONFIG_READY: - OnConfigReady(static_cast(message->pdata)); - break; - case MSG_ALLOCATE: - OnAllocate(); - break; - case MSG_SEQUENCEOBJECTS_CREATED: - OnAllocationSequenceObjectsCreated(); - break; - case MSG_CONFIG_STOP: - OnConfigStop(); - break; - default: - RTC_NOTREACHED(); - } -} - void BasicPortAllocatorSession::UpdateIceParametersInternal() { RTC_DCHECK_RUN_ON(network_thread_); for (PortData& port : ports_) { @@ -607,26 +585,35 @@ void BasicPortAllocatorSession::UpdateIceParametersInternal() { void BasicPortAllocatorSession::GetPortConfigurations() { RTC_DCHECK_RUN_ON(network_thread_); - PortConfiguration* config = - new PortConfiguration(allocator_->stun_servers(), username(), password()); + auto config = std::make_unique(allocator_->stun_servers(), + username(), password()); for (const RelayServerConfig& turn_server : allocator_->turn_servers()) { config->AddRelay(turn_server); } - ConfigReady(config); + ConfigReady(std::move(config)); } void BasicPortAllocatorSession::ConfigReady(PortConfiguration* config) { RTC_DCHECK_RUN_ON(network_thread_); - network_thread_->Post(RTC_FROM_HERE, this, MSG_CONFIG_READY, config); + ConfigReady(absl::WrapUnique(config)); +} + +void BasicPortAllocatorSession::ConfigReady( + std::unique_ptr config) { + RTC_DCHECK_RUN_ON(network_thread_); + network_thread_->PostTask(webrtc::ToQueuedTask( + network_safety_, [this, config = std::move(config)]() mutable { + OnConfigReady(std::move(config)); + })); } // Adds a configuration to the list. 
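// A minimal sketch (illustration only) of consuming the refactored
// CandidateStats from the port.h hunk earlier in this patch: the former
// public |candidate| and |stun_stats| fields are now reached through
// candidate() and stun_stats() accessors, and the stats object is built in
// one shot from a sanitized candidate plus optional STUN stats, as
// GetCandidateStatsFromReadyPorts() above now does.
#include "p2p/base/port.h"
#include "rtc_base/logging.h"

void SketchLogCandidateStats(const cricket::CandidateStatsList& stats_list) {
  for (const cricket::CandidateStats& stats : stats_list) {
    RTC_LOG(LS_INFO) << "candidate: " << stats.candidate().ToString();
    if (stats.stun_stats().has_value()) {
      RTC_LOG(LS_INFO) << "  STUN binding requests sent: "
                       << stats.stun_stats()->stun_binding_requests_sent;
    }
  }
}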
-void BasicPortAllocatorSession::OnConfigReady(PortConfiguration* config) { +void BasicPortAllocatorSession::OnConfigReady( + std::unique_ptr config) { RTC_DCHECK_RUN_ON(network_thread_); - if (config) { - configs_.push_back(config); - } + if (config) + configs_.push_back(std::move(config)); AllocatePorts(); } @@ -664,11 +651,16 @@ void BasicPortAllocatorSession::OnConfigStop() { void BasicPortAllocatorSession::AllocatePorts() { RTC_DCHECK_RUN_ON(network_thread_); - network_thread_->Post(RTC_FROM_HERE, this, MSG_ALLOCATE); + network_thread_->PostTask(webrtc::ToQueuedTask( + network_safety_, [this, allocation_epoch = allocation_epoch_] { + OnAllocate(allocation_epoch); + })); } -void BasicPortAllocatorSession::OnAllocate() { +void BasicPortAllocatorSession::OnAllocate(int allocation_epoch) { RTC_DCHECK_RUN_ON(network_thread_); + if (allocation_epoch != allocation_epoch_) + return; if (network_manager_started_ && !IsStopped()) { bool disable_equivalent_phases = true; @@ -774,7 +766,8 @@ void BasicPortAllocatorSession::DoAllocate(bool disable_equivalent) { done_signal_needed = true; } else { RTC_LOG(LS_INFO) << "Allocate ports on " << networks.size() << " networks"; - PortConfiguration* config = configs_.empty() ? nullptr : configs_.back(); + PortConfiguration* config = + configs_.empty() ? nullptr : configs_.back().get(); for (uint32_t i = 0; i < networks.size(); ++i) { uint32_t sequence_flags = flags(); if ((sequence_flags & DISABLE_ALL_PHASES) == DISABLE_ALL_PHASES) { @@ -814,9 +807,11 @@ void BasicPortAllocatorSession::DoAllocate(bool disable_equivalent) { } AllocationSequence* sequence = - new AllocationSequence(this, networks[i], config, sequence_flags); - sequence->SignalPortAllocationComplete.connect( - this, &BasicPortAllocatorSession::OnPortAllocationComplete); + new AllocationSequence(this, networks[i], config, sequence_flags, + [this, safety_flag = network_safety_.flag()] { + if (safety_flag->alive()) + OnPortAllocationComplete(); + }); sequence->Init(); sequence->Start(); sequences_.push_back(sequence); @@ -824,7 +819,8 @@ void BasicPortAllocatorSession::DoAllocate(bool disable_equivalent) { } } if (done_signal_needed) { - network_thread_->Post(RTC_FROM_HERE, this, MSG_SEQUENCEOBJECTS_CREATED); + network_thread_->PostTask(webrtc::ToQueuedTask( + network_safety_, [this] { OnAllocationSequenceObjectsCreated(); })); } } @@ -900,8 +896,9 @@ void BasicPortAllocatorSession::AddAllocatedPort(Port* port, this, &BasicPortAllocatorSession::OnCandidateError); port->SignalPortComplete.connect(this, &BasicPortAllocatorSession::OnPortComplete); - port->SignalDestroyed.connect(this, - &BasicPortAllocatorSession::OnPortDestroyed); + port->SubscribePortDestroyed( + [this](PortInterface* port) { OnPortDestroyed(port); }); + port->SignalPortError.connect(this, &BasicPortAllocatorSession::OnPortError); RTC_LOG(LS_INFO) << port->ToString() << ": Added port to allocator"; @@ -1127,8 +1124,7 @@ bool BasicPortAllocatorSession::CandidatePairable(const Candidate& c, !host_candidates_disabled); } -void BasicPortAllocatorSession::OnPortAllocationComplete( - AllocationSequence* seq) { +void BasicPortAllocatorSession::OnPortAllocationComplete() { RTC_DCHECK_RUN_ON(network_thread_); // Send candidate allocation complete signal if all ports are done. 
MaybeSignalCandidatesAllocationDone(); @@ -1219,10 +1215,12 @@ void BasicPortAllocatorSession::PrunePortsAndRemoveCandidates( // AllocationSequence -AllocationSequence::AllocationSequence(BasicPortAllocatorSession* session, - rtc::Network* network, - PortConfiguration* config, - uint32_t flags) +AllocationSequence::AllocationSequence( + BasicPortAllocatorSession* session, + rtc::Network* network, + PortConfiguration* config, + uint32_t flags, + std::function port_allocation_complete_callback) : session_(session), network_(network), config_(config), @@ -1230,7 +1228,9 @@ AllocationSequence::AllocationSequence(BasicPortAllocatorSession* session, flags_(flags), udp_socket_(), udp_port_(NULL), - phase_(0) {} + phase_(0), + port_allocation_complete_callback_( + std::move(port_allocation_complete_callback)) {} void AllocationSequence::Init() { if (IsFlagSet(PORTALLOCATOR_ENABLE_SHARED_SOCKET)) { @@ -1247,6 +1247,7 @@ void AllocationSequence::Init() { } void AllocationSequence::Clear() { + TRACE_EVENT0("webrtc", "AllocationSequence::Clear"); udp_port_ = NULL; relay_ports_.clear(); } @@ -1258,10 +1259,6 @@ void AllocationSequence::OnNetworkFailed() { Stop(); } -AllocationSequence::~AllocationSequence() { - session_->network_thread()->Clear(this); -} - void AllocationSequence::DisableEquivalentPhases(rtc::Network* network, PortConfiguration* config, uint32_t* flags) { @@ -1336,7 +1333,9 @@ void AllocationSequence::DisableEquivalentPhases(rtc::Network* network, void AllocationSequence::Start() { state_ = kRunning; - session_->network_thread()->Post(RTC_FROM_HERE, this, MSG_ALLOCATION_PHASE); + + session_->network_thread()->PostTask(webrtc::ToQueuedTask( + safety_, [this, epoch = epoch_] { Process(epoch); })); // Take a snapshot of the best IP, so that when DisableEquivalentPhases is // called next time, we enable all phases if the best IP has since changed. previous_best_ip_ = network_->GetBestIP(); @@ -1346,16 +1345,18 @@ void AllocationSequence::Stop() { // If the port is completed, don't set it to stopped. if (state_ == kRunning) { state_ = kStopped; - session_->network_thread()->Clear(this, MSG_ALLOCATION_PHASE); + // Cause further Process calls in the previous epoch to be ignored. + ++epoch_; } } -void AllocationSequence::OnMessage(rtc::Message* msg) { +void AllocationSequence::Process(int epoch) { RTC_DCHECK(rtc::Thread::Current() == session_->network_thread()); - RTC_DCHECK(msg->message_id == MSG_ALLOCATION_PHASE); - const char* const PHASE_NAMES[kNumPhases] = {"Udp", "Relay", "Tcp"}; + if (epoch != epoch_) + return; + // Perform all of the phases in the current step. RTC_LOG(LS_INFO) << network_->ToString() << ": Allocation Phase=" << PHASE_NAMES[phase_]; @@ -1381,14 +1382,16 @@ void AllocationSequence::OnMessage(rtc::Message* msg) { if (state() == kRunning) { ++phase_; - session_->network_thread()->PostDelayed(RTC_FROM_HERE, - session_->allocator()->step_delay(), - this, MSG_ALLOCATION_PHASE); + session_->network_thread()->PostDelayedTask( + webrtc::ToQueuedTask(safety_, + [this, epoch = epoch_] { Process(epoch); }), + session_->allocator()->step_delay()); } else { - // If all phases in AllocationSequence are completed, no allocation - // steps needed further. Canceling pending signal. - session_->network_thread()->Clear(this, MSG_ALLOCATION_PHASE); - SignalPortAllocationComplete(this); + // No allocation steps needed further if all phases in AllocationSequence + // are completed. Cause further Process calls in the previous epoch to be + // ignored. 
+ ++epoch_; + port_allocation_complete_callback_(); } } @@ -1423,7 +1426,8 @@ void AllocationSequence::CreateUDPPorts() { // UDPPort. if (IsFlagSet(PORTALLOCATOR_ENABLE_SHARED_SOCKET)) { udp_port_ = port.get(); - port->SignalDestroyed.connect(this, &AllocationSequence::OnPortDestroyed); + port->SubscribePortDestroyed( + [this](PortInterface* port) { OnPortDestroyed(port); }); // If STUN is not disabled, setting stun server address to port. if (!IsFlagSet(PORTALLOCATOR_DISABLE_STUN)) { @@ -1561,8 +1565,10 @@ void AllocationSequence::CreateTurnPort(const RelayServerConfig& config) { relay_ports_.push_back(port.get()); // Listen to the port destroyed signal, to allow AllocationSequence to - // remove entrt from it's map. - port->SignalDestroyed.connect(this, &AllocationSequence::OnPortDestroyed); + // remove the entry from it's map. + port->SubscribePortDestroyed( + [this](PortInterface* port) { OnPortDestroyed(port); }); + } else { port = session_->allocator()->relay_port_factory()->Create( args, session_->allocator()->min_port(), @@ -1653,8 +1659,6 @@ PortConfiguration::PortConfiguration(const ServerAddresses& stun_servers, webrtc::field_trial::IsDisabled("WebRTC-UseTurnServerAsStunServer"); } -PortConfiguration::~PortConfiguration() = default; - ServerAddresses PortConfiguration::StunServers() { if (!stun_address.IsNil() && stun_servers.find(stun_address) == stun_servers.end()) { diff --git a/p2p/client/basic_port_allocator.h b/p2p/client/basic_port_allocator.h index b27016a1dc..77aceb1e9c 100644 --- a/p2p/client/basic_port_allocator.h +++ b/p2p/client/basic_port_allocator.h @@ -22,7 +22,9 @@ #include "rtc_base/checks.h" #include "rtc_base/network.h" #include "rtc_base/system/rtc_export.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace cricket { @@ -106,8 +108,9 @@ enum class SessionState { // process will be started. }; -class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession, - public rtc::MessageHandler { +// This class is thread-compatible and assumes it's created, operated upon and +// destroyed on the network thread. +class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession { public: BasicPortAllocatorSession(BasicPortAllocator* allocator, const std::string& content_name, @@ -155,10 +158,11 @@ class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession, // Adds a port configuration that is now ready. Once we have one for each // network (or a timeout occurs), we will start allocating ports. - virtual void ConfigReady(PortConfiguration* config); - - // MessageHandler. Can be overriden if message IDs do not conflict. - void OnMessage(rtc::Message* message) override; + void ConfigReady(std::unique_ptr config); + // TODO(bugs.webrtc.org/12840) Remove once unused in downstream projects. 
+ ABSL_DEPRECATED( + "Use ConfigReady(std::unique_ptr) instead!") + void ConfigReady(PortConfiguration* config); private: class PortData { @@ -213,10 +217,10 @@ class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession, State state_ = STATE_INPROGRESS; }; - void OnConfigReady(PortConfiguration* config); + void OnConfigReady(std::unique_ptr config); void OnConfigStop(); void AllocatePorts(); - void OnAllocate(); + void OnAllocate(int allocation_epoch); void DoAllocate(bool disable_equivalent_phases); void OnNetworksChanged(); void OnAllocationSequenceObjectsCreated(); @@ -233,7 +237,7 @@ class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession, void OnProtocolEnabled(AllocationSequence* seq, ProtocolType proto); void OnPortDestroyed(PortInterface* port); void MaybeSignalCandidatesAllocationDone(); - void OnPortAllocationComplete(AllocationSequence* seq); + void OnPortAllocationComplete(); PortData* FindPort(Port* port); std::vector GetNetworks(); std::vector GetFailedNetworks(); @@ -266,7 +270,7 @@ class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession, bool allocation_started_; bool network_manager_started_; bool allocation_sequences_created_; - std::vector configs_; + std::vector> configs_; std::vector sequences_; std::vector ports_; std::vector candidate_error_events_; @@ -274,13 +278,15 @@ class RTC_EXPORT BasicPortAllocatorSession : public PortAllocatorSession, // Policy on how to prune turn ports, taken from the port allocator. webrtc::PortPrunePolicy turn_port_prune_policy_; SessionState state_ = SessionState::CLEARED; + int allocation_epoch_ RTC_GUARDED_BY(network_thread_) = 0; + webrtc::ScopedTaskSafety network_safety_; friend class AllocationSequence; }; // Records configuration information useful in creating ports. // TODO(deadbeef): Rename "relay" to "turn_server" in this struct. -struct RTC_EXPORT PortConfiguration : public rtc::MessageData { +struct RTC_EXPORT PortConfiguration { // TODO(jiayl): remove |stun_address| when Chrome is updated. rtc::SocketAddress stun_address; ServerAddresses stun_servers; @@ -300,8 +306,6 @@ struct RTC_EXPORT PortConfiguration : public rtc::MessageData { const std::string& username, const std::string& password); - ~PortConfiguration() override; - // Returns addresses of both the explicitly configured STUN servers, // and TURN servers that should be used as STUN servers. ServerAddresses StunServers(); @@ -323,8 +327,8 @@ class TurnPort; // Performs the allocation of ports, in a sequenced (timed) manner, for a given // network and IP address. -class AllocationSequence : public rtc::MessageHandler, - public sigslot::has_slots<> { +// This class is thread-compatible. +class AllocationSequence : public sigslot::has_slots<> { public: enum State { kInit, // Initial state. @@ -334,11 +338,18 @@ class AllocationSequence : public rtc::MessageHandler, // kInit --> kRunning --> {kCompleted|kStopped} }; + // |port_allocation_complete_callback| is called when AllocationSequence is + // done allocating ports. The callback is useful when port allocation fails + // and therefore produces no candidates: it gives BasicPortAllocatorSession + // an event on which to send its candidate discovery conclusion signal, + // which it otherwise would not have. The same could be achieved by starting + // a timer in BPAS, but that would be less deterministic.
AllocationSequence(BasicPortAllocatorSession* session, rtc::Network* network, PortConfiguration* config, - uint32_t flags); - ~AllocationSequence() override; + uint32_t flags, + std::function port_allocation_complete_callback); void Init(); void Clear(); void OnNetworkFailed(); @@ -360,17 +371,6 @@ class AllocationSequence : public rtc::MessageHandler, void Start(); void Stop(); - // MessageHandler - void OnMessage(rtc::Message* msg) override; - - // Signal from AllocationSequence, when it's done with allocating ports. - // This signal is useful, when port allocation fails which doesn't result - // in any candidates. Using this signal BasicPortAllocatorSession can send - // its candidate discovery conclusion signal. Without this signal, - // BasicPortAllocatorSession doesn't have any event to trigger signal. This - // can also be achieved by starting timer in BPAS. - sigslot::signal1 SignalPortAllocationComplete; - protected: // For testing. void CreateTurnPort(const RelayServerConfig& config); @@ -378,6 +378,7 @@ private: typedef std::vector ProtocolList; + void Process(int epoch); bool IsFlagSet(uint32_t flag) { return ((flags_ & flag) != 0); } void CreateUDPPorts(); void CreateTCPPorts(); @@ -406,6 +407,12 @@ UDPPort* udp_port_; std::vector relay_ports_; int phase_; + std::function port_allocation_complete_callback_; + // This counter is sampled and passed together with tasks when tasks are + // posted. If the sampled counter doesn't match |epoch_| on reception, the + // posted task is ignored. + int epoch_ = 0; + webrtc::ScopedTaskSafety safety_; }; } // namespace cricket diff --git a/p2p/g3doc/ice.md b/p2p/g3doc/ice.md new file mode 100644 index 0000000000..be81ff9e22 --- /dev/null +++ b/p2p/g3doc/ice.md @@ -0,0 +1,102 @@ +# ICE + + + +## Overview + +ICE ([link](https://developer.mozilla.org/en-US/docs/Glossary/ICE)) provides +unreliable packet transport between two clients (p2p) or between a client and a +server. + +This documentation provides an overview of how ICE is implemented, i.e. how the +following classes interact. + +* [`cricket::IceTransportInternal`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/ice_transport_internal.h;l=225;drc=8cb97062880b0e0a78f9d578370a01aced81a13f) - + is the interface that does ICE (managing ports, candidates, and connections to + send/receive packets). The interface is implemented by + [`cricket::P2PTransportChannel`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/p2p_transport_channel.h;l=103;drc=0ccfbd2de7bc3b237a0f8c30f48666c97b9e5523). + +* [`cricket::PortInterface`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/port_interface.h;l=47;drc=c3a486c41e682cce943f2b20fe987c9421d4b631) + Represents a local communication mechanism that can be used to create + connections to similar mechanisms of the other client.
There are 4 + implementations of `cricket::PortInterface`: + [`cricket::UDPPort`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/stun_port.h;l=33;drc=a4d873786f10eedd72de25ad0d94ad7c53c1f68a), + [`cricket::StunPort`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/stun_port.h;l=265;drc=a4d873786f10eedd72de25ad0d94ad7c53c1f68a), + [`cricket::TcpPort`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/tcp_port.h;l=33;drc=7a284e1614a38286477ed2334ecbdde78e87b79c) + and + [`cricket::TurnPort`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/turn_port.h;l=44;drc=ffb7603b6025fbd6e79f360d293ab49092bded54). + The ports share lots of functionality in a base class, + [`cricket::Port`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/port.h;l=187;drc=3ba7beba29c4e542c4a9bffcc5a47d5e911865be). + +* [`cricket::Candidate`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/candidate.h;l=30;drc=10542f21c8e4e2d60b136fab45338f2b1e132dde) + represents an address discovered by a `cricket::Port`. A candidate can be + local (i.e. discovered by a local port) or remote. Remote candidates are + transported using signaling, i.e. outside of WebRTC. There are 4 types of + candidates: `local`, `stun`, `prflx` or `relay` + ([standard](https://developer.mozilla.org/en-US/docs/Web/API/RTCIceCandidateType)) + +* [`cricket::Connection`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/connection.h) + provides the management of a `cricket::CandidatePair`, i.e. for sending data + between two candidates. It sends STUN Binding requests (aka STUN pings) to + verify that packets can traverse back and forth and to keep connections alive + (both that the NAT binding is kept, and that the remote peer still wants the + connection to remain open). + +* `cricket::P2PTransportChannel` uses a + [`cricket::PortAllocator`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/port_allocator.h;l=335;drc=9438fb3fff97c803d1ead34c0e4f223db168526f) + to create ports and discover local candidates. The `cricket::PortAllocator` + is implemented by + [`cricket::BasicPortAllocator`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/client/basic_port_allocator.h;l=29;drc=e27f3dea8293884701283a54f90f8a429ea99505). + +* `cricket::P2PTransportChannel` uses an + [`cricket::IceControllerInterface`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/ice_controller_interface.h;l=73;drc=9438fb3fff97c803d1ead34c0e4f223db168526f) + to manage a set of connections. The `cricket::IceControllerInterface` + decides which `cricket::Connection` to send data on. + +## Connection establishment + +This section describes a normal sequence of interactions that establishes the ICE state +`completed` +([link](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/ice_transport_internal.h;l=208;drc=9438fb3fff97c803d1ead34c0e4f223db168526f)) +([standard](https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/iceConnectionState)). + +All of these steps are invoked by interactions with `PeerConnection`. + +1.
[`P2PTransportChannel::MaybeStartGathering`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/p2p_transport_channel.cc;l=864;drc=0ccfbd2de7bc3b237a0f8c30f48666c97b9e5523) + This function is invoked as part of `PeerConnection::SetLocalDescription`. + `P2PTransportChannel` will use the `cricket::PortAllocator` to create a + `cricket::PortAllocatorSession`. The `cricket::PortAllocatorSession` will + create local ports as configured, and the ports will start gathering + candidates. + +2. [`IceTransportInternal::SignalCandidateGathered`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/ice_transport_internal.h;l=293;drc=8cb97062880b0e0a78f9d578370a01aced81a13f) + When a port finds a local candidate, it will be added to a list on + `cricket::P2PTransportChannel` and signaled to the application using + `IceTransportInternal::SignalCandidateGathered`. A p2p application can then + send it to the peer using its favorite transport mechanism (see the sketch + after this list), whereas a client-server application will do nothing. + +3. [`P2PTransportChannel::AddRemoteCandidate`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/p2p_transport_channel.cc;l=1233;drc=0ccfbd2de7bc3b237a0f8c30f48666c97b9e5523) + When the application gets a remote candidate, it can add it using + `PeerConnection::AddRemoteCandidate` (after + `PeerConnection::SetRemoteDescription` has been called!); this will trickle + down to `P2PTransportChannel::AddRemoteCandidate`. `P2PTransportChannel` + will combine the remote candidate with all compatible local candidates to + form new `cricket::Connection`(s). Candidates are compatible if it is + possible to send/receive data between them (e.g. IPv4 can only send to IPv4, TCP can only + connect to TCP, etc.). The newly formed `cricket::Connection`(s) will be + added to the `cricket::IceController`, which will decide which + `cricket::Connection` to send STUN pings on. + +4. [`P2PTransportChannel::SignalCandidatePairChanged`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/ice_transport_internal.h;l=310;drc=8cb97062880b0e0a78f9d578370a01aced81a13f) + When a remote connection replies to a STUN ping, `cricket::IceController` + will instruct `P2PTransportChannel` to use the connection. This is signalled + up the stack using `P2PTransportChannel::SignalCandidatePairChanged`. Note + that `cricket::IceController` will continue to send STUN pings on the + selected connection, as well as on other connections. + +5. [`P2PTransportChannel::SignalIceTransportStateChanged`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/ice_transport_internal.h;l=323;drc=8cb97062880b0e0a78f9d578370a01aced81a13f) + The initial selection of a connection makes `P2PTransportChannel` signal up the + stack that its state has changed, which may make [`cricket::DtlsTransportInternal`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/dtls_transport_internal.h;l=63;drc=653bab6790ac92c513b7cf4cd3ad59039c589a95) + initiate a DTLS handshake (depending on the DTLS role).
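+
+### Example: trickling candidates through application signaling
+
+The following sketch is an illustration only (it is not part of this change)
+of how an application typically drives steps 1-3 above through the public
+`PeerConnection` API: local candidates surface in
+`PeerConnectionObserver::OnIceCandidate`, and the remote side feeds them back
+in with `PeerConnectionInterface::AddIceCandidate`. The `SendToPeer` helper is
+a hypothetical placeholder for the application's own signaling, and the exact
+set of observer methods that must be overridden may differ between WebRTC
+revisions.
+
+```cpp
+#include <memory>
+#include <string>
+
+#include "api/peer_connection_interface.h"
+
+class IceObserver : public webrtc::PeerConnectionObserver {
+ public:
+  // Step 2: a local port discovered a candidate; serialize it and forward it
+  // to the remote peer over the application's own signaling channel.
+  void OnIceCandidate(const webrtc::IceCandidateInterface* candidate) override {
+    std::string sdp;
+    candidate->ToString(&sdp);
+    SendToPeer(candidate->sdp_mid(), candidate->sdp_mline_index(), sdp);
+  }
+
+  // Unused observer methods, overridden with no-ops to keep the sketch short.
+  void OnSignalingChange(
+      webrtc::PeerConnectionInterface::SignalingState) override {}
+  void OnIceGatheringChange(
+      webrtc::PeerConnectionInterface::IceGatheringState) override {}
+  void OnDataChannel(
+      rtc::scoped_refptr<webrtc::DataChannelInterface>) override {}
+  void OnRenegotiationNeeded() override {}
+
+ private:
+  // Hypothetical helper: ships {mid, mline_index, sdp} to the peer, e.g. over
+  // a WebSocket. Not a WebRTC API.
+  void SendToPeer(const std::string& mid, int mline_index,
+                  const std::string& sdp) {
+    // Application-specific signaling goes here.
+  }
+};
+
+// Step 3: the remote side parses the received candidate and hands it to its
+// PeerConnection (only after SetRemoteDescription has been applied).
+bool OnRemoteCandidateReceived(webrtc::PeerConnectionInterface* pc,
+                               const std::string& mid,
+                               int mline_index,
+                               const std::string& sdp) {
+  webrtc::SdpParseError error;
+  std::unique_ptr<webrtc::IceCandidateInterface> candidate(
+      webrtc::CreateIceCandidate(mid, mline_index, sdp, &error));
+  if (!candidate)
+    return false;  // Malformed candidate; error.description has details.
+  return pc->AddIceCandidate(candidate.get());
+}
+```
+
+Candidates that arrive before `PeerConnection::SetRemoteDescription` has been
+applied must be buffered by the application, as step 3 points out.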
diff --git a/p2p/stunprober/stun_prober.cc b/p2p/stunprober/stun_prober.cc index f37f24994a..d85d5f27ea 100644 --- a/p2p/stunprober/stun_prober.cc +++ b/p2p/stunprober/stun_prober.cc @@ -20,11 +20,11 @@ #include "api/transport/stun.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/async_resolver_interface.h" -#include "rtc_base/bind.h" #include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/helpers.h" #include "rtc_base/logging.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" #include "rtc_base/time_utils.h" @@ -104,7 +104,7 @@ class StunProber::Requester : public sigslot::has_slots<> { int16_t num_request_sent_ = 0; int16_t num_response_received_ = 0; - rtc::ThreadChecker& thread_checker_; + webrtc::SequenceChecker& thread_checker_; RTC_DISALLOW_COPY_AND_ASSIGN(Requester); }; @@ -262,6 +262,7 @@ StunProber::StunProber(rtc::PacketSocketFactory* socket_factory, networks_(networks) {} StunProber::~StunProber() { + RTC_DCHECK(thread_checker_.IsCurrent()); for (auto* req : requesters_) { if (req) { delete req; @@ -358,9 +359,8 @@ void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) { // Deletion of AsyncResolverInterface can't be done in OnResolveResult which // handles SignalDone. - invoker_.AsyncInvoke( - RTC_FROM_HERE, thread_, - rtc::Bind(&rtc::AsyncResolverInterface::Destroy, resolver, false)); + thread_->PostTask( + webrtc::ToQueuedTask([resolver] { resolver->Destroy(false); })); servers_.pop_back(); if (servers_.size()) { @@ -453,13 +453,14 @@ int StunProber::get_wake_up_interval_ms() { } void StunProber::MaybeScheduleStunRequests() { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(thread_); int64_t now = rtc::TimeMillis(); if (Done()) { - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, thread_, - rtc::Bind(&StunProber::ReportOnFinished, this, SUCCESS), timeout_ms_); + thread_->PostDelayedTask( + webrtc::ToQueuedTask(task_safety_.flag(), + [this] { ReportOnFinished(SUCCESS); }), + timeout_ms_); return; } if (should_send_next_request(now)) { @@ -469,9 +470,9 @@ void StunProber::MaybeScheduleStunRequests() { } next_request_time_ms_ = now + interval_ms_; } - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, thread_, - rtc::Bind(&StunProber::MaybeScheduleStunRequests, this), + thread_->PostDelayedTask( + webrtc::ToQueuedTask(task_safety_.flag(), + [this] { MaybeScheduleStunRequests(); }), get_wake_up_interval_ms()); } diff --git a/p2p/stunprober/stun_prober.h b/p2p/stunprober/stun_prober.h index a739a6c98b..43d84ff806 100644 --- a/p2p/stunprober/stun_prober.h +++ b/p2p/stunprober/stun_prober.h @@ -15,16 +15,15 @@ #include #include -#include "rtc_base/async_invoker.h" +#include "api/sequence_checker.h" #include "rtc_base/byte_buffer.h" -#include "rtc_base/callback.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ip_address.h" #include "rtc_base/network.h" #include "rtc_base/socket_address.h" #include "rtc_base/system/rtc_export.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace rtc { class AsyncPacketSocket; @@ -40,7 +39,7 @@ class StunProber; static const int kMaxUdpBufferSize = 1200; -typedef rtc::Callback2 AsyncCallback; +typedef std::function AsyncCallback; enum NatType { NATTYPE_INVALID, @@ -227,15 +226,13 @@ class RTC_EXPORT StunProber : public sigslot::has_slots<> { // The set of STUN probe sockets and their state. 
std::vector requesters_; - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; // Temporary storage for created sockets. std::vector sockets_; // This tracks how many of the sockets are ready. size_t total_ready_sockets_ = 0; - rtc::AsyncInvoker invoker_; - Observer* observer_ = nullptr; // TODO(guoweis): Remove this once all dependencies move away from // AsyncCallback. @@ -243,6 +240,8 @@ class RTC_EXPORT StunProber : public sigslot::has_slots<> { rtc::NetworkManager::NetworkList networks_; + webrtc::ScopedTaskSafety task_safety_; + RTC_DISALLOW_COPY_AND_ASSIGN(StunProber); }; diff --git a/pc/BUILD.gn b/pc/BUILD.gn index 143ce25c74..460462e54a 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn @@ -19,10 +19,24 @@ group("pc") { config("rtc_pc_config") { defines = [] if (rtc_enable_sctp) { - defines += [ "HAVE_SCTP" ] + defines += [ "WEBRTC_HAVE_SCTP" ] } } +rtc_library("proxy") { + sources = [ + "proxy.cc", + "proxy.h", + ] + deps = [ + "../api:scoped_refptr", + "../api/task_queue", + "../rtc_base:rtc_base_approved", + "../rtc_base:threading", + "../rtc_base/system:rtc_export", + ] +} + rtc_library("rtc_pc_base") { visibility = [ "*" ] defines = [] @@ -32,8 +46,6 @@ rtc_library("rtc_pc_base") { "channel_interface.h", "channel_manager.cc", "channel_manager.h", - "composite_rtp_transport.cc", - "composite_rtp_transport.h", "dtls_srtp_transport.cc", "dtls_srtp_transport.h", "dtls_transport.cc", @@ -44,14 +56,22 @@ rtc_library("rtc_pc_base") { "ice_transport.h", "jsep_transport.cc", "jsep_transport.h", + "jsep_transport_collection.cc", + "jsep_transport_collection.h", "jsep_transport_controller.cc", "jsep_transport_controller.h", "media_session.cc", "media_session.h", + "media_stream_proxy.h", + "media_stream_track_proxy.h", + "peer_connection_factory_proxy.h", + "peer_connection_proxy.h", "rtcp_mux_filter.cc", "rtcp_mux_filter.h", "rtp_media_utils.cc", "rtp_media_utils.h", + "rtp_receiver_proxy.h", + "rtp_sender_proxy.h", "rtp_transport.cc", "rtp_transport.h", "rtp_transport_internal.h", @@ -61,10 +81,6 @@ rtc_library("rtc_pc_base") { "sctp_transport.h", "sctp_utils.cc", "sctp_utils.h", - "session_description.cc", - "session_description.h", - "simulcast_description.cc", - "simulcast_description.h", "srtp_filter.cc", "srtp_filter.h", "srtp_session.cc", @@ -74,51 +90,71 @@ rtc_library("rtc_pc_base") { "transport_stats.cc", "transport_stats.h", "used_ids.h", + "video_track_source_proxy.cc", + "video_track_source_proxy.h", ] deps = [ ":media_protocol_names", + ":proxy", + ":session_description", + ":simulcast_description", "../api:array_view", + "../api:async_dns_resolver", "../api:audio_options_api", "../api:call_api", "../api:function_view", "../api:ice_transport_factory", "../api:libjingle_peerconnection_api", + "../api:media_stream_interface", + "../api:packet_socket_factory", "../api:priority", "../api:rtc_error", "../api:rtp_headers", "../api:rtp_parameters", "../api:rtp_parameters", + "../api:rtp_transceiver_direction", "../api:scoped_refptr", + "../api:sequence_checker", "../api/crypto:options", "../api/rtc_event_log", + "../api/task_queue", "../api/transport:datagram_transport_interface", + "../api/transport:enums", + "../api/transport:sctp_transport_factory_interface", "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_bitrate_allocator_factory", "../api/video:video_frame", "../api/video:video_rtp_headers", + "../api/video_codecs:video_codecs_api", "../call:call_interfaces", "../call:rtp_interfaces", "../call:rtp_receiver", 
"../common_video", "../common_video:common_video", "../logging:ice_log", - "../media:rtc_data", - "../media:rtc_h264_profile_id", + "../media:rtc_data_sctp_transport_internal", "../media:rtc_media_base", "../media:rtc_media_config", + "../media:rtc_sdp_video_format_utils", "../modules/rtp_rtcp:rtp_rtcp", "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:rtc_p2p", "../rtc_base", "../rtc_base:callback_list", "../rtc_base:checks", - "../rtc_base:deprecation", "../rtc_base:rtc_task_queue", + "../rtc_base:socket", + "../rtc_base:socket_address", "../rtc_base:stringutils", + "../rtc_base:threading", + "../rtc_base/network:sent_packet", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:file_wrapper", + "../rtc_base/system:no_unique_address", "../rtc_base/system:rtc_export", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/base64", "../rtc_base/third_party/sigslot", "../system_wrappers:field_trial", @@ -139,6 +175,43 @@ rtc_library("rtc_pc_base") { public_configs = [ ":rtc_pc_config" ] } +rtc_source_set("session_description") { + visibility = [ "*" ] + sources = [ + "session_description.cc", + "session_description.h", + ] + deps = [ + ":media_protocol_names", + ":simulcast_description", + "../api:libjingle_peerconnection_api", + "../api:rtp_parameters", + "../api:rtp_transceiver_direction", + "../media:rtc_media_base", + "../p2p:rtc_p2p", + "../rtc_base:checks", + "../rtc_base:socket_address", + "../rtc_base/system:rtc_export", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory:memory", + ] +} + +rtc_source_set("simulcast_description") { + visibility = [ "*" ] + sources = [ + "simulcast_description.cc", + "simulcast_description.h", + ] + deps = [ + "../rtc_base:checks", + "../rtc_base:socket_address", + "../rtc_base/system:rtc_export", + ] +} + rtc_source_set("rtc_pc") { visibility = [ "*" ] allow_poison = [ "audio_codecs" ] # TODO(bugs.webrtc.org/8396): Remove. 
@@ -180,8 +253,6 @@ rtc_library("peerconnection") { "rtc_stats_collector.h", "rtc_stats_traversal.cc", "rtc_stats_traversal.h", - "rtp_data_channel.cc", - "rtp_data_channel.h", "sctp_data_channel.cc", "sctp_data_channel.h", "sdp_offer_answer.cc", # TODO: Make separate target when not circular @@ -207,11 +278,10 @@ rtc_library("peerconnection") { ":connection_context", ":dtmf_sender", ":jitter_buffer_delay", - ":jitter_buffer_delay_interface", - ":jitter_buffer_delay_proxy", ":media_protocol_names", ":media_stream", ":peer_connection_message_handler", + ":proxy", ":remote_audio_source", ":rtc_pc_base", ":rtp_parameters_conversion", @@ -220,6 +290,8 @@ rtc_library("peerconnection") { ":rtp_transceiver", ":rtp_transmission_manager", ":sdp_state_provider", + ":session_description", + ":simulcast_description", ":stats_collector_interface", ":transceiver_list", ":usage_pattern", @@ -227,6 +299,7 @@ rtc_library("peerconnection") { ":video_track", ":video_track_source", "../api:array_view", + "../api:async_dns_resolver", "../api:audio_options_api", "../api:call_api", "../api:callfactory_api", @@ -245,7 +318,9 @@ rtc_library("peerconnection") { "../api:rtp_parameters", "../api:rtp_transceiver_direction", "../api:scoped_refptr", + "../api:sequence_checker", "../api/adaptation:resource_adaptation_api", + "../api/audio_codecs:audio_codecs_api", "../api/crypto:frame_decryptor_interface", "../api/crypto:options", "../api/neteq:neteq_api", @@ -266,28 +341,34 @@ rtc_library("peerconnection") { "../api/video:video_rtp_headers", "../api/video_codecs:video_codecs_api", "../call:call_interfaces", + "../call:rtp_interfaces", + "../call:rtp_sender", "../common_video", "../logging:ice_log", - "../media:rtc_data", + "../media:rtc_data_sctp_transport_internal", "../media:rtc_media_base", "../media:rtc_media_config", + "../modules/audio_processing:audio_processing_statistics", "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:rtc_p2p", "../rtc_base", "../rtc_base:callback_list", "../rtc_base:checks", - "../rtc_base:deprecation", + "../rtc_base:ip_address", + "../rtc_base:network_constants", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_operations_chain", "../rtc_base:safe_minmax", + "../rtc_base:socket_address", + "../rtc_base:threading", "../rtc_base:weak_ptr", "../rtc_base/experiments:field_trial_parser", "../rtc_base/network:sent_packet", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:file_wrapper", "../rtc_base/system:no_unique_address", "../rtc_base/system:rtc_export", + "../rtc_base/system:unused", "../rtc_base/task_utils:pending_task_safety_flag", "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/base64", @@ -314,16 +395,20 @@ rtc_library("connection_context") { "../api:callfactory_api", "../api:libjingle_peerconnection_api", "../api:media_stream_interface", + "../api:refcountedbase", "../api:scoped_refptr", + "../api:sequence_checker", "../api/neteq:neteq_api", "../api/transport:field_trial_based_config", "../api/transport:sctp_transport_factory_interface", "../api/transport:webrtc_key_value_config", - "../media:rtc_data", + "../media:rtc_data_sctp_transport_factory", "../media:rtc_media_base", "../p2p:rtc_p2p", "../rtc_base", "../rtc_base:checks", + "../rtc_base:threading", + "../rtc_base/task_utils:to_queued_task", ] } @@ -337,8 +422,11 @@ rtc_library("peer_connection_message_handler") { "../api:libjingle_peerconnection_api", "../api:media_stream_interface", "../api:rtc_error", + "../api:scoped_refptr", + 
"../api:sequence_checker", "../rtc_base", - "../rtc_base/synchronization:sequence_checker", + "../rtc_base:checks", + "../rtc_base:threading", ] } @@ -360,14 +448,29 @@ rtc_library("rtp_transceiver") { "rtp_transceiver.h", ] deps = [ + ":proxy", ":rtc_pc_base", ":rtp_parameters_conversion", ":rtp_receiver", ":rtp_sender", + ":session_description", + "../api:array_view", "../api:libjingle_peerconnection_api", + "../api:rtc_error", "../api:rtp_parameters", + "../api:rtp_transceiver_direction", + "../api:scoped_refptr", + "../api:sequence_checker", + "../api/task_queue", + "../media:rtc_media_base", "../rtc_base:checks", "../rtc_base:logging", + "../rtc_base:macromagic", + "../rtc_base:refcount", + "../rtc_base:threading", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", + "../rtc_base/third_party/sigslot", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -397,9 +500,12 @@ rtc_library("rtp_transmission_manager") { "../api:rtp_parameters", "../api:rtp_transceiver_direction", "../api:scoped_refptr", + "../api:sequence_checker", "../media:rtc_media_base", "../rtc_base", "../rtc_base:checks", + "../rtc_base:threading", + "../rtc_base:weak_ptr", "../rtc_base/third_party/sigslot", ] absl_deps = [ @@ -414,7 +520,18 @@ rtc_library("transceiver_list") { "transceiver_list.cc", "transceiver_list.h", ] - deps = [ ":rtp_transceiver" ] + deps = [ + ":rtp_transceiver", + "../api:libjingle_peerconnection_api", + "../api:rtc_error", + "../api:rtp_parameters", + "../api:scoped_refptr", + "../api:sequence_checker", + "../rtc_base:checks", + "../rtc_base:macromagic", + "../rtc_base/system:no_unique_address", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } rtc_library("rtp_receiver") { @@ -424,6 +541,7 @@ rtc_library("rtp_receiver") { ] deps = [ ":media_stream", + ":rtc_pc_base", ":video_track_source", "../api:libjingle_peerconnection_api", "../api:media_stream_interface", @@ -436,6 +554,7 @@ rtc_library("rtp_receiver") { "../rtc_base:logging", "../rtc_base:rtc_base", "../rtc_base:rtc_base_approved", + "../rtc_base:threading", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -452,20 +571,26 @@ rtc_library("audio_rtp_receiver") { deps = [ ":audio_track", ":jitter_buffer_delay", - ":jitter_buffer_delay_interface", - ":jitter_buffer_delay_proxy", ":media_stream", ":remote_audio_source", + ":rtc_pc_base", ":rtp_receiver", + "../api:frame_transformer_interface", "../api:libjingle_peerconnection_api", "../api:media_stream_interface", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api/crypto:frame_decryptor_interface", + "../api/transport/rtp:rtp_source", "../media:rtc_media_base", "../rtc_base", "../rtc_base:checks", "../rtc_base:refcount", + "../rtc_base:threading", + "../rtc_base/system:no_unique_address", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -481,9 +606,8 @@ rtc_library("video_rtp_receiver") { ] deps = [ ":jitter_buffer_delay", - ":jitter_buffer_delay_interface", - ":jitter_buffer_delay_proxy", ":media_stream", + ":rtc_pc_base", ":rtp_receiver", ":video_rtp_track_source", ":video_track", @@ -492,12 +616,17 @@ rtc_library("video_rtp_receiver") { "../api:media_stream_interface", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api/crypto:frame_decryptor_interface", + "../api/transport/rtp:rtp_source", + 
"../api/video:recordable_encoded_frame", "../api/video:video_frame", "../media:rtc_media_base", "../rtc_base", "../rtc_base:checks", "../rtc_base:rtc_base_approved", + "../rtc_base:threading", + "../rtc_base/system:no_unique_address", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -513,8 +642,12 @@ rtc_library("video_rtp_track_source") { ] deps = [ ":video_track_source", + "../api:sequence_checker", + "../api/video:recordable_encoded_frame", + "../api/video:video_frame", "../media:rtc_media_base", "../rtc_base", + "../rtc_base:checks", "../rtc_base/synchronization:mutex", "../rtc_base/system:no_unique_address", ] @@ -528,9 +661,9 @@ rtc_library("audio_track") { deps = [ "../api:media_stream_interface", "../api:scoped_refptr", + "../api:sequence_checker", "../rtc_base:checks", "../rtc_base:refcount", - "../rtc_base:thread_checker", ] } @@ -542,25 +675,14 @@ rtc_library("video_track") { deps = [ "../api:media_stream_interface", "../api:scoped_refptr", + "../api:sequence_checker", "../api/video:video_frame", "../media:rtc_media_base", "../rtc_base", "../rtc_base:checks", "../rtc_base:refcount", "../rtc_base:rtc_base_approved", - ] -} - -rtc_source_set("jitter_buffer_delay_interface") { - sources = [ "jitter_buffer_delay_interface.h" ] - deps = [ - "../media:rtc_media_base", - "../rtc_base:refcount", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/strings", - "//third_party/abseil-cpp/absl/types:optional", + "../rtc_base:threading", ] } @@ -572,33 +694,19 @@ rtc_source_set("sdp_state_provider") { ] } -rtc_source_set("jitter_buffer_delay_proxy") { - sources = [ "jitter_buffer_delay_proxy.h" ] - deps = [ - ":jitter_buffer_delay_interface", - "../api:libjingle_peerconnection_api", - "../media:rtc_media_base", - ] -} - rtc_library("jitter_buffer_delay") { sources = [ "jitter_buffer_delay.cc", "jitter_buffer_delay.h", ] deps = [ - ":jitter_buffer_delay_interface", - "../media:rtc_media_base", - "../rtc_base", + "../api:sequence_checker", "../rtc_base:checks", - "../rtc_base:refcount", + "../rtc_base:safe_conversions", "../rtc_base:safe_minmax", + "../rtc_base/system:no_unique_address", ] - absl_deps = [ - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/strings", - "//third_party/abseil-cpp/absl/types:optional", - ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } rtc_library("remote_audio_source") { @@ -611,12 +719,15 @@ rtc_library("remote_audio_source") { "../api:call_api", "../api:media_stream_interface", "../api:scoped_refptr", + "../api:sequence_checker", + "../media:rtc_media_base", "../rtc_base", "../rtc_base:checks", "../rtc_base:logging", "../rtc_base:rtc_base_approved", "../rtc_base:safe_conversions", "../rtc_base:stringutils", + "../rtc_base:threading", "../rtc_base/synchronization:mutex", ] absl_deps = [ @@ -635,12 +746,20 @@ rtc_library("rtp_sender") { ":dtmf_sender", ":stats_collector_interface", "../api:audio_options_api", + "../api:frame_transformer_interface", "../api:libjingle_peerconnection_api", "../api:media_stream_interface", + "../api:priority", + "../api:rtc_error", + "../api:rtp_parameters", + "../api:scoped_refptr", + "../api/crypto:frame_encryptor_interface", "../media:rtc_media_base", "../rtc_base:checks", "../rtc_base:rtc_base", + "../rtc_base:threading", "../rtc_base/synchronization:mutex", + "../rtc_base/third_party/sigslot", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -656,6 +775,7 @@ 
rtc_library("rtp_parameters_conversion") { ] deps = [ ":rtc_pc_base", + ":session_description", "../api:array_view", "../api:libjingle_peerconnection_api", "../api:rtc_error", @@ -677,9 +797,15 @@ rtc_library("dtmf_sender") { "dtmf_sender.h", ] deps = [ + ":proxy", "../api:libjingle_peerconnection_api", + "../api:scoped_refptr", "../rtc_base:checks", "../rtc_base:rtc_base", + "../rtc_base:threading", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", + "../rtc_base/third_party/sigslot", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -715,12 +841,15 @@ rtc_library("video_track_source") { ] deps = [ "../api:media_stream_interface", + "../api:sequence_checker", + "../api/video:recordable_encoded_frame", "../api/video:video_frame", "../media:rtc_media_base", "../rtc_base:checks", "../rtc_base:rtc_base_approved", "../rtc_base/system:rtc_export", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } rtc_source_set("stats_collector_interface") { @@ -739,14 +868,13 @@ rtc_source_set("libjingle_peerconnection") { ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_test("rtc_pc_unittests") { testonly = true sources = [ "channel_manager_unittest.cc", "channel_unittest.cc", - "composite_rtp_transport_test.cc", "dtls_srtp_transport_unittest.cc", "dtls_transport_unittest.cc", "ice_transport_unittest.cc", @@ -778,6 +906,7 @@ if (rtc_include_tests) { ":peerconnection", ":rtc_pc", ":rtc_pc_base", + ":session_description", ":video_rtp_receiver", "../api:array_view", "../api:audio_options_api", @@ -790,7 +919,7 @@ if (rtc_include_tests) { "../api/video/test:mock_recordable_encoded_frame", "../call:rtp_interfaces", "../call:rtp_receiver", - "../media:rtc_data", + "../media:rtc_data_sctp_transport_internal", "../media:rtc_media_base", "../media:rtc_media_tests_utils", "../modules/rtp_rtcp:rtp_rtcp_format", @@ -803,6 +932,9 @@ if (rtc_include_tests) { "../rtc_base:gunit_helpers", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_base_tests_utils", + "../rtc_base:threading", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", "../rtc_base/third_party/sigslot", "../system_wrappers:metrics", "../test:field_trial", @@ -851,6 +983,8 @@ if (rtc_include_tests) { "../rtc_base:checks", "../rtc_base:gunit_helpers", "../rtc_base:rtc_base_tests_utils", + "../rtc_base:socket_address", + "../rtc_base:threading", "../system_wrappers", "../test:perf_test", "../test:test_support", @@ -881,89 +1015,10 @@ if (rtc_include_tests) { ] } - rtc_library("pc_test_utils") { - testonly = true - sources = [ - "test/fake_audio_capture_module.cc", - "test/fake_audio_capture_module.h", - "test/fake_data_channel_provider.h", - "test/fake_peer_connection_base.h", - "test/fake_peer_connection_for_stats.h", - "test/fake_periodic_video_source.h", - "test/fake_periodic_video_track_source.h", - "test/fake_rtc_certificate_generator.h", - "test/fake_video_track_renderer.h", - "test/fake_video_track_source.h", - "test/frame_generator_capturer_video_track_source.h", - "test/mock_channel_interface.h", - "test/mock_data_channel.h", - "test/mock_delayable.h", - "test/mock_peer_connection_observers.h", - "test/mock_rtp_receiver_internal.h", - "test/mock_rtp_sender_internal.h", - "test/peer_connection_test_wrapper.cc", - "test/peer_connection_test_wrapper.h", - "test/rtc_stats_obtainer.h", - "test/test_sdp_strings.h", - ] - - deps = [ - ":jitter_buffer_delay", - ":jitter_buffer_delay_interface", - 
":libjingle_peerconnection", - ":peerconnection", - ":rtc_pc_base", - ":rtp_receiver", - ":rtp_sender", - ":video_track_source", - "../api:audio_options_api", - "../api:create_frame_generator", - "../api:create_peerconnection_factory", - "../api:libjingle_peerconnection_api", - "../api:media_stream_interface", - "../api:rtc_error", - "../api:rtc_stats_api", - "../api:scoped_refptr", - "../api/audio:audio_mixer_api", - "../api/audio_codecs:audio_codecs_api", - "../api/task_queue", - "../api/task_queue:default_task_queue_factory", - "../api/video:builtin_video_bitrate_allocator_factory", - "../api/video:video_frame", - "../api/video:video_rtp_headers", - "../api/video_codecs:builtin_video_decoder_factory", - "../api/video_codecs:builtin_video_encoder_factory", - "../api/video_codecs:video_codecs_api", - "../call:call_interfaces", - "../media:rtc_data", - "../media:rtc_media", - "../media:rtc_media_base", - "../media:rtc_media_tests_utils", - "../modules/audio_device", - "../modules/audio_processing", - "../modules/audio_processing:api", - "../p2p:fake_port_allocator", - "../p2p:p2p_test_utils", - "../p2p:rtc_p2p", - "../rtc_base", - "../rtc_base:checks", - "../rtc_base:gunit_helpers", - "../rtc_base:rtc_base_approved", - "../rtc_base:rtc_task_queue", - "../rtc_base:task_queue_for_test", - "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", - "../rtc_base/task_utils:repeating_task", - "../rtc_base/third_party/sigslot", - "../test:test_support", - "../test:video_test_common", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] - } - rtc_test("peerconnection_unittests") { testonly = true sources = [ + "data_channel_integrationtest.cc", "data_channel_unittest.cc", "dtmf_sender_unittest.cc", "ice_server_parsing_unittest.cc", @@ -1008,24 +1063,22 @@ if (rtc_include_tests) { "webrtc_sdp_unittest.cc", ] - if (rtc_enable_sctp) { - defines = [ "HAVE_SCTP" ] - } - deps = [ ":audio_rtp_receiver", ":audio_track", ":dtmf_sender", + ":integration_test_helpers", ":jitter_buffer_delay", - ":jitter_buffer_delay_interface", ":media_stream", ":peerconnection", + ":proxy", ":remote_audio_source", ":rtc_pc_base", ":rtp_parameters_conversion", ":rtp_receiver", ":rtp_sender", ":rtp_transceiver", + ":session_description", ":usage_pattern", ":video_rtp_receiver", ":video_rtp_track_source", @@ -1041,7 +1094,9 @@ if (rtc_include_tests) { "../api:libjingle_peerconnection_api", "../api:media_stream_interface", "../api:mock_rtp", + "../api:packet_socket_factory", "../api:rtc_error", + "../api:rtp_transceiver_direction", "../api:scoped_refptr", "../api/audio:audio_mixer_api", "../api/crypto:frame_decryptor_interface", @@ -1049,13 +1104,17 @@ if (rtc_include_tests) { "../api/crypto:options", "../api/rtc_event_log", "../api/rtc_event_log:rtc_event_log_factory", + "../api/task_queue", "../api/task_queue:default_task_queue_factory", "../api/transport:field_trial_based_config", + "../api/transport:webrtc_key_value_config", "../api/transport/rtp:rtp_source", "../api/units:time_delta", "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_rtp_headers", "../call/adaptation:resource_adaptation_test_utilities", "../logging:fake_rtc_event_log", + "../media:rtc_data_sctp_transport_internal", "../media:rtc_media_config", "../media:rtc_media_engine_defaults", "../modules/audio_device:audio_device_api", @@ -1064,17 +1123,23 @@ if (rtc_include_tests) { "../modules/rtp_rtcp:rtp_rtcp_format", "../p2p:fake_ice_transport", "../p2p:fake_port_allocator", + 
"../p2p:p2p_server_utils", "../rtc_base:checks", "../rtc_base:gunit_helpers", + "../rtc_base:ip_address", "../rtc_base:rtc_base_tests_utils", "../rtc_base:rtc_json", + "../rtc_base:socket_address", + "../rtc_base:threading", "../rtc_base/synchronization:mutex", "../rtc_base/third_party/base64", "../rtc_base/third_party/sigslot", + "../system_wrappers:field_trial", "../system_wrappers:metrics", "../test:field_trial", "../test:fileutils", "../test:rtp_test_utils", + "../test:test_common", "../test/pc/sctp:fake_sctp_transport", "./scenario_tests:pc_scenario_tests", "//third_party/abseil-cpp/absl/algorithm:container", @@ -1105,8 +1170,6 @@ if (rtc_include_tests) { "../api/video_codecs:video_codecs_api", "../call:call_interfaces", "../media:rtc_audio_video", - "../media:rtc_data", # TODO(phoglund): AFAIK only used for one sctp - # constant. "../media:rtc_media_base", "../media:rtc_media_tests_utils", "../modules/audio_processing", @@ -1154,4 +1217,187 @@ if (rtc_include_tests) { ] } } + + rtc_library("integration_test_helpers") { + testonly = true + sources = [ + "test/integration_test_helpers.cc", + "test/integration_test_helpers.h", + ] + deps = [ + ":audio_rtp_receiver", + ":audio_track", + ":dtmf_sender", + ":jitter_buffer_delay", + ":media_stream", + ":pc_test_utils", + ":peerconnection", + ":remote_audio_source", + ":rtc_pc_base", + ":rtp_parameters_conversion", + ":rtp_receiver", + ":rtp_sender", + ":rtp_transceiver", + ":session_description", + ":usage_pattern", + ":video_rtp_receiver", + ":video_rtp_track_source", + ":video_track", + ":video_track_source", + "../api:array_view", + "../api:audio_options_api", + "../api:callfactory_api", + "../api:create_peerconnection_factory", + "../api:fake_frame_decryptor", + "../api:fake_frame_encryptor", + "../api:function_view", + "../api:libjingle_logging_api", + "../api:libjingle_peerconnection_api", + "../api:media_stream_interface", + "../api:mock_rtp", + "../api:packet_socket_factory", + "../api:rtc_error", + "../api:rtc_stats_api", + "../api:rtp_parameters", + "../api:rtp_transceiver_direction", + "../api:scoped_refptr", + "../api/audio:audio_mixer_api", + "../api/crypto:frame_decryptor_interface", + "../api/crypto:frame_encryptor_interface", + "../api/crypto:options", + "../api/rtc_event_log", + "../api/rtc_event_log:rtc_event_log_factory", + "../api/task_queue", + "../api/task_queue:default_task_queue_factory", + "../api/transport:field_trial_based_config", + "../api/transport:webrtc_key_value_config", + "../api/transport/rtp:rtp_source", + "../api/units:time_delta", + "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_rtp_headers", + "../api/video_codecs:video_codecs_api", + "../call:call_interfaces", + "../call/adaptation:resource_adaptation_test_utilities", + "../logging:fake_rtc_event_log", + "../media:rtc_audio_video", + "../media:rtc_media_base", + "../media:rtc_media_config", + "../media:rtc_media_engine_defaults", + "../media:rtc_media_tests_utils", + "../modules/audio_device:audio_device_api", + "../modules/audio_processing:api", + "../modules/audio_processing:audio_processing_statistics", + "../modules/audio_processing:audioproc_test_utils", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../p2p:fake_ice_transport", + "../p2p:fake_port_allocator", + "../p2p:p2p_server_utils", + "../p2p:p2p_test_utils", + "../p2p:rtc_p2p", + "../rtc_base", + "../rtc_base:checks", + "../rtc_base:gunit_helpers", + "../rtc_base:ip_address", + "../rtc_base:rtc_base_tests_utils", + "../rtc_base:rtc_json", + 
"../rtc_base:socket_address", + "../rtc_base:threading", + "../rtc_base:timeutils", + "../rtc_base/synchronization:mutex", + "../rtc_base/task_utils:pending_task_safety_flag", + "../rtc_base/task_utils:to_queued_task", + "../rtc_base/third_party/base64", + "../rtc_base/third_party/sigslot", + "../system_wrappers:metrics", + "../test:field_trial", + "../test:fileutils", + "../test:rtp_test_utils", + "../test:test_support", + "../test/pc/sctp:fake_sctp_transport", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_library("pc_test_utils") { + testonly = true + sources = [ + "test/fake_audio_capture_module.cc", + "test/fake_audio_capture_module.h", + "test/fake_data_channel_provider.h", + "test/fake_peer_connection_base.h", + "test/fake_peer_connection_for_stats.h", + "test/fake_periodic_video_source.h", + "test/fake_periodic_video_track_source.h", + "test/fake_rtc_certificate_generator.h", + "test/fake_video_track_renderer.h", + "test/fake_video_track_source.h", + "test/frame_generator_capturer_video_track_source.h", + "test/mock_channel_interface.h", + "test/mock_data_channel.h", + "test/mock_peer_connection_observers.h", + "test/mock_rtp_receiver_internal.h", + "test/mock_rtp_sender_internal.h", + "test/peer_connection_test_wrapper.cc", + "test/peer_connection_test_wrapper.h", + "test/rtc_stats_obtainer.h", + "test/test_sdp_strings.h", + ] + + deps = [ + ":jitter_buffer_delay", + ":libjingle_peerconnection", + ":peerconnection", + ":rtc_pc_base", + ":rtp_receiver", + ":rtp_sender", + ":video_track_source", + "../api:audio_options_api", + "../api:create_frame_generator", + "../api:create_peerconnection_factory", + "../api:libjingle_peerconnection_api", + "../api:media_stream_interface", + "../api:rtc_error", + "../api:rtc_stats_api", + "../api:scoped_refptr", + "../api:sequence_checker", + "../api/audio:audio_mixer_api", + "../api/audio_codecs:audio_codecs_api", + "../api/task_queue", + "../api/task_queue:default_task_queue_factory", + "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_frame", + "../api/video:video_rtp_headers", + "../api/video_codecs:builtin_video_decoder_factory", + "../api/video_codecs:builtin_video_encoder_factory", + "../api/video_codecs:video_codecs_api", + "../call:call_interfaces", + "../media:rtc_media", + "../media:rtc_media_base", + "../media:rtc_media_tests_utils", + "../modules/audio_device", + "../modules/audio_processing", + "../modules/audio_processing:api", + "../p2p:fake_port_allocator", + "../p2p:p2p_test_utils", + "../p2p:rtc_p2p", + "../rtc_base", + "../rtc_base:checks", + "../rtc_base:gunit_helpers", + "../rtc_base:rtc_base_approved", + "../rtc_base:rtc_task_queue", + "../rtc_base:task_queue_for_test", + "../rtc_base:threading", + "../rtc_base/synchronization:mutex", + "../rtc_base/task_utils:repeating_task", + "../rtc_base/third_party/sigslot", + "../test:test_support", + "../test:video_test_common", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + } } diff --git a/pc/audio_rtp_receiver.cc b/pc/audio_rtp_receiver.cc index 8ff685d8e2..4efab24d15 100644 --- a/pc/audio_rtp_receiver.cc +++ b/pc/audio_rtp_receiver.cc @@ -15,42 +15,43 @@ #include #include -#include "api/media_stream_proxy.h" -#include "api/media_stream_track_proxy.h" +#include "api/sequence_checker.h" #include "pc/audio_track.h" -#include "pc/jitter_buffer_delay.h" 
-#include "pc/jitter_buffer_delay_proxy.h" -#include "pc/media_stream.h" +#include "pc/media_stream_track_proxy.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" -#include "rtc_base/trace_event.h" +#include "rtc_base/task_utils/to_queued_task.h" namespace webrtc { AudioRtpReceiver::AudioRtpReceiver(rtc::Thread* worker_thread, std::string receiver_id, - std::vector stream_ids) + std::vector stream_ids, + bool is_unified_plan) : AudioRtpReceiver(worker_thread, receiver_id, - CreateStreamsFromIds(std::move(stream_ids))) {} + CreateStreamsFromIds(std::move(stream_ids)), + is_unified_plan) {} AudioRtpReceiver::AudioRtpReceiver( rtc::Thread* worker_thread, const std::string& receiver_id, - const std::vector>& streams) + const std::vector>& streams, + bool is_unified_plan) : worker_thread_(worker_thread), id_(receiver_id), - source_(new rtc::RefCountedObject(worker_thread)), + source_(rtc::make_ref_counted( + worker_thread, + is_unified_plan + ? RemoteAudioSource::OnAudioChannelGoneAction::kSurvive + : RemoteAudioSource::OnAudioChannelGoneAction::kEnd)), track_(AudioTrackProxyWithInternal::Create( rtc::Thread::Current(), AudioTrack::Create(receiver_id, source_))), cached_track_enabled_(track_->enabled()), attachment_id_(GenerateUniqueId()), - delay_(JitterBufferDelayProxy::Create( - rtc::Thread::Current(), - worker_thread_, - new rtc::RefCountedObject(worker_thread))) { + worker_thread_safety_(PendingTaskSafetyFlag::CreateDetachedInactive()) { RTC_DCHECK(worker_thread_); RTC_DCHECK(track_->GetSource()->remote()); track_->RegisterObserver(this); @@ -59,139 +60,188 @@ AudioRtpReceiver::AudioRtpReceiver( } AudioRtpReceiver::~AudioRtpReceiver() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + RTC_DCHECK(stopped_); + RTC_DCHECK(!media_channel_); + track_->GetSource()->UnregisterAudioObserver(this); track_->UnregisterObserver(this); - Stop(); } void AudioRtpReceiver::OnChanged() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); if (cached_track_enabled_ != track_->enabled()) { cached_track_enabled_ = track_->enabled(); - Reconfigure(); + worker_thread_->PostTask(ToQueuedTask( + worker_thread_safety_, + [this, enabled = cached_track_enabled_, volume = cached_volume_]() { + RTC_DCHECK_RUN_ON(worker_thread_); + Reconfigure(enabled, volume); + })); } } -bool AudioRtpReceiver::SetOutputVolume(double volume) { +// RTC_RUN_ON(worker_thread_) +void AudioRtpReceiver::SetOutputVolume_w(double volume) { RTC_DCHECK_GE(volume, 0.0); RTC_DCHECK_LE(volume, 10.0); - RTC_DCHECK(media_channel_); - RTC_DCHECK(!stopped_); - return worker_thread_->Invoke(RTC_FROM_HERE, [&] { - return ssrc_ ? media_channel_->SetOutputVolume(*ssrc_, volume) - : media_channel_->SetDefaultOutputVolume(volume); - }); + ssrc_ ? media_channel_->SetOutputVolume(*ssrc_, volume) + : media_channel_->SetDefaultOutputVolume(volume); } void AudioRtpReceiver::OnSetVolume(double volume) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RTC_DCHECK_GE(volume, 0); RTC_DCHECK_LE(volume, 10); - cached_volume_ = volume; - if (!media_channel_ || stopped_) { - RTC_LOG(LS_ERROR) - << "AudioRtpReceiver::OnSetVolume: No audio channel exists."; + if (stopped_) return; - } + + cached_volume_ = volume; + // When the track is disabled, the volume of the source, which is the // corresponding WebRtc Voice Engine channel will be 0. So we do not allow // setting the volume to the source when the track is disabled. 
- if (!stopped_ && track_->enabled()) { - if (!SetOutputVolume(cached_volume_)) { - RTC_NOTREACHED(); - } + if (track_->enabled()) { + worker_thread_->PostTask( + ToQueuedTask(worker_thread_safety_, [this, volume = cached_volume_]() { + RTC_DCHECK_RUN_ON(worker_thread_); + SetOutputVolume_w(volume); + })); } } +rtc::scoped_refptr AudioRtpReceiver::dtls_transport() + const { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + return dtls_transport_; +} + std::vector AudioRtpReceiver::stream_ids() const { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); std::vector stream_ids(streams_.size()); for (size_t i = 0; i < streams_.size(); ++i) stream_ids[i] = streams_[i]->id(); return stream_ids; } +std::vector> +AudioRtpReceiver::streams() const { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + return streams_; +} + RtpParameters AudioRtpReceiver::GetParameters() const { - if (!media_channel_ || stopped_) { + RTC_DCHECK_RUN_ON(worker_thread_); + if (!media_channel_) return RtpParameters(); - } - return worker_thread_->Invoke(RTC_FROM_HERE, [&] { - return ssrc_ ? media_channel_->GetRtpReceiveParameters(*ssrc_) - : media_channel_->GetDefaultRtpReceiveParameters(); - }); + return ssrc_ ? media_channel_->GetRtpReceiveParameters(*ssrc_) + : media_channel_->GetDefaultRtpReceiveParameters(); } void AudioRtpReceiver::SetFrameDecryptor( rtc::scoped_refptr frame_decryptor) { + RTC_DCHECK_RUN_ON(worker_thread_); frame_decryptor_ = std::move(frame_decryptor); // Special Case: Set the frame decryptor to any value on any existing channel. - if (media_channel_ && ssrc_.has_value() && !stopped_) { - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - media_channel_->SetFrameDecryptor(*ssrc_, frame_decryptor_); - }); + if (media_channel_ && ssrc_) { + media_channel_->SetFrameDecryptor(*ssrc_, frame_decryptor_); } } rtc::scoped_refptr AudioRtpReceiver::GetFrameDecryptor() const { + RTC_DCHECK_RUN_ON(worker_thread_); return frame_decryptor_; } void AudioRtpReceiver::Stop() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); // TODO(deadbeef): Need to do more here to fully stop receiving packets. - if (stopped_) { - return; - } - if (media_channel_) { - // Allow that SetOutputVolume fail. This is the normal case when the - // underlying media channel has already been deleted. - SetOutputVolume(0.0); + if (!stopped_) { + source_->SetState(MediaSourceInterface::kEnded); + stopped_ = true; } - stopped_ = true; + + worker_thread_->Invoke(RTC_FROM_HERE, [&]() { + RTC_DCHECK_RUN_ON(worker_thread_); + if (media_channel_) + SetOutputVolume_w(0.0); + SetMediaChannel_w(nullptr); + }); } void AudioRtpReceiver::StopAndEndTrack() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); Stop(); track_->internal()->set_ended(); } void AudioRtpReceiver::RestartMediaChannel(absl::optional ssrc) { - RTC_DCHECK(media_channel_); - if (!stopped_ && ssrc_ == ssrc) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + bool ok = worker_thread_->Invoke( + RTC_FROM_HERE, [&, enabled = cached_track_enabled_, + volume = cached_volume_, was_stopped = stopped_]() { + RTC_DCHECK_RUN_ON(worker_thread_); + if (!media_channel_) { + RTC_DCHECK(was_stopped); + return false; // Can't restart. + } + + if (!was_stopped && ssrc_ == ssrc) { + // Already running with that ssrc. 
+ RTC_DCHECK(worker_thread_safety_->alive()); + return true; + } + + if (!was_stopped) { + source_->Stop(media_channel_, ssrc_); + } + + ssrc_ = std::move(ssrc); + source_->Start(media_channel_, ssrc_); + if (ssrc_) { + media_channel_->SetBaseMinimumPlayoutDelayMs(*ssrc_, delay_.GetMs()); + } + + Reconfigure(enabled, volume); + return true; + }); + + if (!ok) return; - } - if (!stopped_) { - source_->Stop(media_channel_, ssrc_); - delay_->OnStop(); - } - ssrc_ = ssrc; stopped_ = false; - source_->Start(media_channel_, ssrc); - delay_->OnStart(media_channel_, ssrc.value_or(0)); - Reconfigure(); } void AudioRtpReceiver::SetupMediaChannel(uint32_t ssrc) { - if (!media_channel_) { - RTC_LOG(LS_ERROR) - << "AudioRtpReceiver::SetupMediaChannel: No audio channel exists."; - return; - } + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RestartMediaChannel(ssrc); } void AudioRtpReceiver::SetupUnsignaledMediaChannel() { - if (!media_channel_) { - RTC_LOG(LS_ERROR) << "AudioRtpReceiver::SetupUnsignaledMediaChannel: No " - "audio channel exists."; - } + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RestartMediaChannel(absl::nullopt); } +uint32_t AudioRtpReceiver::ssrc() const { + RTC_DCHECK_RUN_ON(worker_thread_); + return ssrc_.value_or(0); +} + void AudioRtpReceiver::set_stream_ids(std::vector stream_ids) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); SetStreams(CreateStreamsFromIds(std::move(stream_ids))); } +void AudioRtpReceiver::set_transport( + rtc::scoped_refptr dtls_transport) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + dtls_transport_ = std::move(dtls_transport); +} + void AudioRtpReceiver::SetStreams( const std::vector>& streams) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); // Remove remote track from any streams that are going away. for (const auto& existing_stream : streams_) { bool removed = true; @@ -224,51 +274,42 @@ void AudioRtpReceiver::SetStreams( } std::vector AudioRtpReceiver::GetSources() const { - if (!media_channel_ || !ssrc_ || stopped_) { + RTC_DCHECK_RUN_ON(worker_thread_); + if (!media_channel_ || !ssrc_) { return {}; } - return worker_thread_->Invoke>( - RTC_FROM_HERE, [&] { return media_channel_->GetSources(*ssrc_); }); + return media_channel_->GetSources(*ssrc_); } void AudioRtpReceiver::SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) { - worker_thread_->Invoke( - RTC_FROM_HERE, [this, frame_transformer = std::move(frame_transformer)] { - RTC_DCHECK_RUN_ON(worker_thread_); - frame_transformer_ = frame_transformer; - if (media_channel_ && ssrc_.has_value() && !stopped_) { - media_channel_->SetDepacketizerToDecoderFrameTransformer( - *ssrc_, frame_transformer); - } - }); + RTC_DCHECK_RUN_ON(worker_thread_); + if (media_channel_) { + media_channel_->SetDepacketizerToDecoderFrameTransformer(ssrc_.value_or(0), + frame_transformer); + } + frame_transformer_ = std::move(frame_transformer); } -void AudioRtpReceiver::Reconfigure() { - if (!media_channel_ || stopped_) { - RTC_LOG(LS_ERROR) - << "AudioRtpReceiver::Reconfigure: No audio channel exists."; - return; - } - if (!SetOutputVolume(track_->enabled() ? cached_volume_ : 0)) { - RTC_NOTREACHED(); +// RTC_RUN_ON(worker_thread_) +void AudioRtpReceiver::Reconfigure(bool track_enabled, double volume) { + RTC_DCHECK(media_channel_); + + SetOutputVolume_w(track_enabled ? volume : 0); + + if (ssrc_ && frame_decryptor_) { + // Reattach the frame decryptor if we were reconfigured. 
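// The Invoke() here copies the signaling-thread state it needs (enabled,
// volume, was_stopped) into the lambda's capture list, so the body running on
// the worker thread only touches worker-guarded members plus those snapshots.
// A reduced sketch of that idiom, with invented member names:

#include "rtc_base/location.h"
#include "rtc_base/thread.h"

class CachedStateExample {
 public:
  // Called on the signaling thread; |worker| is the worker thread.
  bool Apply(rtc::Thread* worker) {
    // Snapshot by value here; only the copies are used inside the lambda.
    return worker->Invoke<bool>(
        RTC_FROM_HERE, [enabled = cached_enabled_, volume = cached_volume_]() {
          // Worker-thread work that never reads the signaling-guarded members.
          return enabled && volume >= 0.0;
        });
  }

 private:
  bool cached_enabled_ = true;
  double cached_volume_ = 1.0;
};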
+ media_channel_->SetFrameDecryptor(*ssrc_, frame_decryptor_); } - // Reattach the frame decryptor if we were reconfigured. - MaybeAttachFrameDecryptorToMediaChannel( - ssrc_, worker_thread_, frame_decryptor_, media_channel_, stopped_); - - if (media_channel_ && ssrc_.has_value() && !stopped_) { - worker_thread_->Invoke(RTC_FROM_HERE, [this] { - RTC_DCHECK_RUN_ON(worker_thread_); - if (!frame_transformer_) - return; - media_channel_->SetDepacketizerToDecoderFrameTransformer( - *ssrc_, frame_transformer_); - }); + + if (frame_transformer_) { + media_channel_->SetDepacketizerToDecoderFrameTransformer( + ssrc_.value_or(0), frame_transformer_); } } void AudioRtpReceiver::SetObserver(RtpReceiverObserverInterface* observer) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); observer_ = observer; // Deliver any notifications the observer may have missed by being set late. if (received_first_packet_ && observer_) { @@ -278,16 +319,35 @@ void AudioRtpReceiver::SetObserver(RtpReceiverObserverInterface* observer) { void AudioRtpReceiver::SetJitterBufferMinimumDelay( absl::optional delay_seconds) { - delay_->Set(delay_seconds); + RTC_DCHECK_RUN_ON(worker_thread_); + delay_.Set(delay_seconds); + if (media_channel_ && ssrc_) + media_channel_->SetBaseMinimumPlayoutDelayMs(*ssrc_, delay_.GetMs()); } void AudioRtpReceiver::SetMediaChannel(cricket::MediaChannel* media_channel) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RTC_DCHECK(media_channel == nullptr || media_channel->media_type() == media_type()); + + if (stopped_ && !media_channel) + return; + + worker_thread_->Invoke(RTC_FROM_HERE, [&] { + RTC_DCHECK_RUN_ON(worker_thread_); + SetMediaChannel_w(media_channel); + }); +} + +// RTC_RUN_ON(worker_thread_) +void AudioRtpReceiver::SetMediaChannel_w(cricket::MediaChannel* media_channel) { + media_channel ? 
worker_thread_safety_->SetAlive() + : worker_thread_safety_->SetNotAlive(); media_channel_ = static_cast(media_channel); } void AudioRtpReceiver::NotifyFirstPacketReceived() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); if (observer_) { observer_->OnFirstPacketReceived(media_type()); } diff --git a/pc/audio_rtp_receiver.h b/pc/audio_rtp_receiver.h index f4b821068e..c3468721d8 100644 --- a/pc/audio_rtp_receiver.h +++ b/pc/audio_rtp_receiver.h @@ -18,33 +18,43 @@ #include "absl/types/optional.h" #include "api/crypto/frame_decryptor_interface.h" +#include "api/dtls_transport_interface.h" +#include "api/frame_transformer_interface.h" #include "api/media_stream_interface.h" -#include "api/media_stream_track_proxy.h" #include "api/media_types.h" #include "api/rtp_parameters.h" +#include "api/rtp_receiver_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" +#include "api/transport/rtp/rtp_source.h" #include "media/base/media_channel.h" #include "pc/audio_track.h" -#include "pc/jitter_buffer_delay_interface.h" +#include "pc/jitter_buffer_delay.h" +#include "pc/media_stream_track_proxy.h" #include "pc/remote_audio_source.h" #include "pc/rtp_receiver.h" #include "rtc_base/ref_counted_object.h" +#include "rtc_base/system/no_unique_address.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { class AudioRtpReceiver : public ObserverInterface, public AudioSourceInterface::AudioObserver, - public rtc::RefCountedObject { + public RtpReceiverInternal { public: AudioRtpReceiver(rtc::Thread* worker_thread, std::string receiver_id, - std::vector stream_ids); + std::vector stream_ids, + bool is_unified_plan); // TODO(https://crbug.com/webrtc/9480): Remove this when streams() is removed. 
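// The AudioRtpReceiver constructor earlier in this patch and
// AudioTrack::Create() further below switch from `new
// rtc::RefCountedObject<T>(...)` to rtc::make_ref_counted<T>(...), which does
// the same allocation but returns a scoped_refptr directly. A small sketch
// with a made-up type, assuming make_ref_counted is declared in
// rtc_base/ref_counted_object.h as used by those call sites:

#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/ref_counted_object.h"

class Gain : public rtc::RefCountInterface {
 public:
  explicit Gain(double value) : value_(value) {}
  double value() const { return value_; }

 private:
  double value_;
};

double MakeAndReadGain() {
  // Old spelling: rtc::scoped_refptr<Gain> g(new rtc::RefCountedObject<Gain>(2.0));
  rtc::scoped_refptr<Gain> g = rtc::make_ref_counted<Gain>(2.0);
  return g->value();
}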
AudioRtpReceiver( rtc::Thread* worker_thread, const std::string& receiver_id, - const std::vector>& streams); + const std::vector>& streams, + bool is_unified_plan); virtual ~AudioRtpReceiver(); // ObserverInterface implementation @@ -53,22 +63,16 @@ class AudioRtpReceiver : public ObserverInterface, // AudioSourceInterface::AudioObserver implementation void OnSetVolume(double volume) override; - rtc::scoped_refptr audio_track() const { - return track_.get(); - } + rtc::scoped_refptr audio_track() const { return track_; } // RtpReceiverInterface implementation rtc::scoped_refptr track() const override { - return track_.get(); - } - rtc::scoped_refptr dtls_transport() const override { - return dtls_transport_; + return track_; } + rtc::scoped_refptr dtls_transport() const override; std::vector stream_ids() const override; std::vector> streams() - const override { - return streams_; - } + const override; cricket::MediaType media_type() const override { return cricket::MEDIA_TYPE_AUDIO; @@ -89,13 +93,11 @@ class AudioRtpReceiver : public ObserverInterface, void StopAndEndTrack() override; void SetupMediaChannel(uint32_t ssrc) override; void SetupUnsignaledMediaChannel() override; - uint32_t ssrc() const override { return ssrc_.value_or(0); } + uint32_t ssrc() const override; void NotifyFirstPacketReceived() override; void set_stream_ids(std::vector stream_ids) override; void set_transport( - rtc::scoped_refptr dtls_transport) override { - dtls_transport_ = dtls_transport; - } + rtc::scoped_refptr dtls_transport) override; void SetStreams(const std::vector>& streams) override; void SetObserver(RtpReceiverObserverInterface* observer) override; @@ -113,29 +115,40 @@ class AudioRtpReceiver : public ObserverInterface, private: void RestartMediaChannel(absl::optional ssrc); - void Reconfigure(); - bool SetOutputVolume(double volume); + void Reconfigure(bool track_enabled, double volume) + RTC_RUN_ON(worker_thread_); + void SetOutputVolume_w(double volume) RTC_RUN_ON(worker_thread_); + void SetMediaChannel_w(cricket::MediaChannel* media_channel) + RTC_RUN_ON(worker_thread_); + RTC_NO_UNIQUE_ADDRESS SequenceChecker signaling_thread_checker_; rtc::Thread* const worker_thread_; const std::string id_; const rtc::scoped_refptr source_; const rtc::scoped_refptr> track_; - cricket::VoiceMediaChannel* media_channel_ = nullptr; - absl::optional ssrc_; - std::vector> streams_; - bool cached_track_enabled_; - double cached_volume_ = 1; - bool stopped_ = true; - RtpReceiverObserverInterface* observer_ = nullptr; - bool received_first_packet_ = false; - int attachment_id_ = 0; - rtc::scoped_refptr frame_decryptor_; - rtc::scoped_refptr dtls_transport_; - // Allows to thread safely change playout delay. 
Handles caching cases if + cricket::VoiceMediaChannel* media_channel_ RTC_GUARDED_BY(worker_thread_) = + nullptr; + absl::optional ssrc_ RTC_GUARDED_BY(worker_thread_); + std::vector> streams_ + RTC_GUARDED_BY(&signaling_thread_checker_); + bool cached_track_enabled_ RTC_GUARDED_BY(&signaling_thread_checker_); + double cached_volume_ RTC_GUARDED_BY(&signaling_thread_checker_) = 1.0; + bool stopped_ RTC_GUARDED_BY(&signaling_thread_checker_) = true; + RtpReceiverObserverInterface* observer_ + RTC_GUARDED_BY(&signaling_thread_checker_) = nullptr; + bool received_first_packet_ RTC_GUARDED_BY(&signaling_thread_checker_) = + false; + const int attachment_id_; + rtc::scoped_refptr frame_decryptor_ + RTC_GUARDED_BY(worker_thread_); + rtc::scoped_refptr dtls_transport_ + RTC_GUARDED_BY(&signaling_thread_checker_); + // Stores and updates the playout delay. Handles caching cases if // |SetJitterBufferMinimumDelay| is called before start. - rtc::scoped_refptr delay_; + JitterBufferDelay delay_ RTC_GUARDED_BY(worker_thread_); rtc::scoped_refptr frame_transformer_ RTC_GUARDED_BY(worker_thread_); + const rtc::scoped_refptr worker_thread_safety_; }; } // namespace webrtc diff --git a/pc/audio_track.cc b/pc/audio_track.cc index 4f4c6b4757..be087f693b 100644 --- a/pc/audio_track.cc +++ b/pc/audio_track.cc @@ -19,7 +19,7 @@ namespace webrtc { rtc::scoped_refptr AudioTrack::Create( const std::string& id, const rtc::scoped_refptr& source) { - return new rtc::RefCountedObject(id, source); + return rtc::make_ref_counted(id, source); } AudioTrack::AudioTrack(const std::string& label, @@ -32,7 +32,7 @@ AudioTrack::AudioTrack(const std::string& label, } AudioTrack::~AudioTrack() { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&thread_checker_); set_state(MediaStreamTrackInterface::kEnded); if (audio_source_) audio_source_->UnregisterObserver(this); @@ -43,24 +43,24 @@ std::string AudioTrack::kind() const { } AudioSourceInterface* AudioTrack::GetSource() const { - RTC_DCHECK(thread_checker_.IsCurrent()); + // Callable from any thread. return audio_source_.get(); } void AudioTrack::AddSink(AudioTrackSinkInterface* sink) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&thread_checker_); if (audio_source_) audio_source_->AddSink(sink); } void AudioTrack::RemoveSink(AudioTrackSinkInterface* sink) { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&thread_checker_); if (audio_source_) audio_source_->RemoveSink(sink); } void AudioTrack::OnChanged() { - RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK_RUN_ON(&thread_checker_); if (audio_source_->state() == MediaSourceInterface::kEnded) { set_state(kEnded); } else { diff --git a/pc/audio_track.h b/pc/audio_track.h index 8cff79e8b9..8a705cf8fb 100644 --- a/pc/audio_track.h +++ b/pc/audio_track.h @@ -16,7 +16,7 @@ #include "api/media_stream_interface.h" #include "api/media_stream_track.h" #include "api/scoped_refptr.h" -#include "rtc_base/thread_checker.h" +#include "api/sequence_checker.h" namespace webrtc { @@ -41,19 +41,19 @@ class AudioTrack : public MediaStreamTrack, // MediaStreamTrack implementation. std::string kind() const override; - private: // AudioTrackInterface implementation. AudioSourceInterface* GetSource() const override; void AddSink(AudioTrackSinkInterface* sink) override; void RemoveSink(AudioTrackSinkInterface* sink) override; + private: // ObserverInterface implementation. 
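// The annotations added to AudioRtpReceiver above, and the SequenceChecker
// swap in AudioTrack here, follow one pattern: every member is assigned to a
// thread via RTC_GUARDED_BY(), and each method asserts its calling sequence
// with RTC_DCHECK_RUN_ON() so the thread-safety analysis and the debug checks
// agree. A stripped-down sketch with invented names:

#include "api/sequence_checker.h"
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"

class Meter {
 public:
  Meter() = default;  // The checker attaches to the constructing sequence.

  void SetLevel(int level) {
    RTC_DCHECK_RUN_ON(&signaling_checker_);  // DCHECKs on the wrong sequence.
    level_ = level;
  }

  int level() const {
    RTC_DCHECK_RUN_ON(&signaling_checker_);
    return level_;
  }

 private:
  RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker signaling_checker_;
  int level_ RTC_GUARDED_BY(&signaling_checker_) = 0;
};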
void OnChanged() override; private: const rtc::scoped_refptr audio_source_; - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; }; } // namespace webrtc diff --git a/pc/channel.cc b/pc/channel.cc index 69b2ca1676..8630703be1 100644 --- a/pc/channel.cc +++ b/pc/channel.cc @@ -10,41 +10,39 @@ #include "pc/channel.h" +#include +#include #include +#include #include #include "absl/algorithm/container.h" -#include "absl/memory/memory.h" -#include "api/call/audio_sink.h" -#include "media/base/media_constants.h" +#include "absl/strings/string_view.h" +#include "api/rtp_parameters.h" +#include "api/sequence_checker.h" +#include "api/task_queue/queued_task.h" +#include "media/base/codec.h" +#include "media/base/rid_description.h" #include "media/base/rtp_utils.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "p2p/base/packet_transport_internal.h" -#include "pc/channel_manager.h" #include "pc/rtp_media_utils.h" -#include "rtc_base/bind.h" -#include "rtc_base/byte_order.h" #include "rtc_base/checks.h" #include "rtc_base/copy_on_write_buffer.h" -#include "rtc_base/dscp.h" #include "rtc_base/logging.h" #include "rtc_base/network_route.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/trace_event.h" namespace cricket { -using rtc::Bind; -using rtc::UniqueRandomIdGenerator; -using webrtc::SdpType; - namespace { -struct SendPacketMessageData : public rtc::MessageData { - rtc::CopyOnWriteBuffer packet; - rtc::PacketOptions options; -}; +using ::rtc::UniqueRandomIdGenerator; +using ::webrtc::PendingTaskSafetyFlag; +using ::webrtc::SdpType; +using ::webrtc::ToQueuedTask; // Finds a stream based on target's Primary SSRC or RIDs. // This struct is used in BaseChannel::UpdateLocalStreams_w. @@ -81,14 +79,6 @@ struct StreamFinder { } // namespace -enum { - MSG_SEND_RTP_PACKET = 1, - MSG_SEND_RTCP_PACKET, - MSG_READYTOSENDDATA, - MSG_DATARECEIVED, - MSG_FIRSTPACKETRECEIVED, -}; - static void SafeSetError(const std::string& message, std::string* error_desc) { if (error_desc) { *error_desc = message; @@ -135,6 +125,7 @@ BaseChannel::BaseChannel(rtc::Thread* worker_thread, : worker_thread_(worker_thread), network_thread_(network_thread), signaling_thread_(signaling_thread), + alive_(PendingTaskSafetyFlag::Create()), content_name_(content_name), srtp_required_(srtp_required), crypto_options_(crypto_options), @@ -151,8 +142,7 @@ BaseChannel::~BaseChannel() { RTC_DCHECK_RUN_ON(worker_thread_); // Eats any outstanding messages or packets. - worker_thread_->Clear(&invoker_); - worker_thread_->Clear(this); + alive_->SetNotAlive(); // The media channel is destroyed at the end of the destructor, since it // is a std::unique_ptr. The transport channel (rtp_transport) must outlive // the media channel. @@ -169,11 +159,16 @@ std::string BaseChannel::ToString() const { } bool BaseChannel::ConnectToRtpTransport() { - RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK(rtp_transport_); - // TODO(bugs.webrtc.org/12230): This accesses demuxer_criteria_ on the - // networking thread. - if (!rtp_transport_->RegisterRtpDemuxerSink(demuxer_criteria_, this)) { + RTC_DCHECK(media_channel()); + + // We don't need to call OnDemuxerCriteriaUpdatePending/Complete because + // there's no previous criteria to worry about. 
+ bool result = rtp_transport_->RegisterRtpDemuxerSink(demuxer_criteria_, this); + if (result) { + previous_demuxer_criteria_ = demuxer_criteria_; + } else { + previous_demuxer_criteria_ = {}; RTC_LOG(LS_ERROR) << "Failed to set up demuxing for " << ToString(); return false; } @@ -189,8 +184,8 @@ bool BaseChannel::ConnectToRtpTransport() { } void BaseChannel::DisconnectFromRtpTransport() { - RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK(rtp_transport_); + RTC_DCHECK(media_channel()); rtp_transport_->UnregisterRtpDemuxerSink(this); rtp_transport_->SignalReadyToSend.disconnect(this); rtp_transport_->SignalNetworkRouteChanged.disconnect(this); @@ -201,39 +196,31 @@ void BaseChannel::DisconnectFromRtpTransport() { void BaseChannel::Init_w(webrtc::RtpTransportInternal* rtp_transport) { RTC_DCHECK_RUN_ON(worker_thread()); - network_thread_->Invoke( - RTC_FROM_HERE, [this, rtp_transport] { SetRtpTransport(rtp_transport); }); - - // Both RTP and RTCP channels should be set, we can call SetInterface on - // the media channel and it can set network options. - media_channel_->SetInterface(this); + network_thread_->Invoke(RTC_FROM_HERE, [this, rtp_transport] { + SetRtpTransport(rtp_transport); + // Both RTP and RTCP channels should be set, we can call SetInterface on + // the media channel and it can set network options. + media_channel_->SetInterface(this); + }); } void BaseChannel::Deinit() { RTC_DCHECK_RUN_ON(worker_thread()); - media_channel_->SetInterface(/*iface=*/nullptr); // Packets arrive on the network thread, processing packets calls virtual // functions, so need to stop this process in Deinit that is called in // derived classes destructor. network_thread_->Invoke(RTC_FROM_HERE, [&] { RTC_DCHECK_RUN_ON(network_thread()); - FlushRtcpMessages_n(); + media_channel_->SetInterface(/*iface=*/nullptr); if (rtp_transport_) { DisconnectFromRtpTransport(); } - // Clear pending read packets/messages. - network_thread_->Clear(&invoker_); - network_thread_->Clear(this); }); } bool BaseChannel::SetRtpTransport(webrtc::RtpTransportInternal* rtp_transport) { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke(RTC_FROM_HERE, [this, rtp_transport] { - return SetRtpTransport(rtp_transport); - }); - } + TRACE_EVENT0("webrtc", "BaseChannel::SetRtpTransport"); RTC_DCHECK_RUN_ON(network_thread()); if (rtp_transport == rtp_transport_) { return true; @@ -268,78 +255,59 @@ bool BaseChannel::SetRtpTransport(webrtc::RtpTransportInternal* rtp_transport) { return true; } -bool BaseChannel::Enable(bool enable) { - worker_thread_->Invoke( - RTC_FROM_HERE, - Bind(enable ? &BaseChannel::EnableMedia_w : &BaseChannel::DisableMedia_w, - this)); - return true; +void BaseChannel::Enable(bool enable) { + RTC_DCHECK_RUN_ON(signaling_thread()); + + if (enable == enabled_s_) + return; + + enabled_s_ = enable; + + worker_thread_->PostTask(ToQueuedTask(alive_, [this, enable] { + RTC_DCHECK_RUN_ON(worker_thread()); + // Sanity check to make sure that enabled_ and enabled_s_ + // stay in sync. 
+ RTC_DCHECK_NE(enabled_, enable); + if (enable) { + EnableMedia_w(); + } else { + DisableMedia_w(); + } + })); } bool BaseChannel::SetLocalContent(const MediaContentDescription* content, SdpType type, std::string* error_desc) { + RTC_DCHECK_RUN_ON(worker_thread()); TRACE_EVENT0("webrtc", "BaseChannel::SetLocalContent"); - return InvokeOnWorker( - RTC_FROM_HERE, - Bind(&BaseChannel::SetLocalContent_w, this, content, type, error_desc)); + return SetLocalContent_w(content, type, error_desc); } bool BaseChannel::SetRemoteContent(const MediaContentDescription* content, SdpType type, std::string* error_desc) { + RTC_DCHECK_RUN_ON(worker_thread()); TRACE_EVENT0("webrtc", "BaseChannel::SetRemoteContent"); - return InvokeOnWorker( - RTC_FROM_HERE, - Bind(&BaseChannel::SetRemoteContent_w, this, content, type, error_desc)); + return SetRemoteContent_w(content, type, error_desc); } -void BaseChannel::SetPayloadTypeDemuxingEnabled(bool enabled) { +bool BaseChannel::SetPayloadTypeDemuxingEnabled(bool enabled) { + RTC_DCHECK_RUN_ON(worker_thread()); TRACE_EVENT0("webrtc", "BaseChannel::SetPayloadTypeDemuxingEnabled"); - InvokeOnWorker( - RTC_FROM_HERE, - Bind(&BaseChannel::SetPayloadTypeDemuxingEnabled_w, this, enabled)); -} - -bool BaseChannel::UpdateRtpTransport(std::string* error_desc) { - return network_thread_->Invoke(RTC_FROM_HERE, [this, error_desc] { - RTC_DCHECK_RUN_ON(network_thread()); - RTC_DCHECK(rtp_transport_); - // TODO(bugs.webrtc.org/12230): This accesses demuxer_criteria_ on the - // networking thread. - if (!rtp_transport_->RegisterRtpDemuxerSink(demuxer_criteria_, this)) { - RTC_LOG(LS_ERROR) << "Failed to set up demuxing for " << ToString(); - rtc::StringBuilder desc; - desc << "Failed to set up demuxing for m-section with mid='" - << content_name() << "'."; - SafeSetError(desc.str(), error_desc); - return false; - } - // NOTE: This doesn't take the BUNDLE case in account meaning the RTP header - // extension maps are not merged when BUNDLE is enabled. This is fine - // because the ID for MID should be consistent among all the RTP transports, - // and that's all RtpTransport uses this map for. - // - // TODO(deadbeef): Move this call to JsepTransport, there is no reason - // BaseChannel needs to be involved here. - if (media_type() != cricket::MEDIA_TYPE_DATA) { - rtp_transport_->UpdateRtpHeaderExtensionMap( - receive_rtp_header_extensions_); - } - return true; - }); + return SetPayloadTypeDemuxingEnabled_w(enabled); } bool BaseChannel::IsReadyToReceiveMedia_w() const { // Receive data if we are enabled and have local content, - return enabled() && + return enabled_ && webrtc::RtpTransceiverDirectionHasRecv(local_content_direction_); } bool BaseChannel::IsReadyToSendMedia_w() const { // Send outgoing data if we are enabled, have local and remote content, // and we have had some form of connectivity. 
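// SetFirstPacketReceivedCallback() above replaces the sigslot signal with a
// single std::function that the owner installs once; OnRtpPacket() runs it on
// the first packet and then clears it, so the notification fires at most once
// and no longer needs a post to the signaling thread. A reduced sketch of
// that shape (names invented):

#include <functional>
#include <utility>

class FirstPacketNotifier {
 public:
  // Pass a callable to install it, or an empty std::function to clear it.
  void SetCallback(std::function<void()> callback) {
    callback_ = std::move(callback);
  }

  void OnPacket() {
    if (callback_) {
      callback_();
      callback_ = nullptr;  // Fire once, then drop the callback.
    }
  }

 private:
  std::function<void()> callback_;
};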
- return enabled() && + return enabled_ && webrtc::RtpTransceiverDirectionHasRecv(remote_content_direction_) && webrtc::RtpTransceiverDirectionHasSend(local_content_direction_) && was_ever_writable(); @@ -358,13 +326,7 @@ bool BaseChannel::SendRtcp(rtc::CopyOnWriteBuffer* packet, int BaseChannel::SetOption(SocketType type, rtc::Socket::Option opt, int value) { - return network_thread_->Invoke( - RTC_FROM_HERE, Bind(&BaseChannel::SetOption_n, this, type, opt, value)); -} - -int BaseChannel::SetOption_n(SocketType type, - rtc::Socket::Option opt, - int value) { + RTC_DCHECK_RUN_ON(network_thread()); RTC_DCHECK(rtp_transport_); switch (type) { case ST_RTP: @@ -390,7 +352,7 @@ void BaseChannel::OnWritableState(bool writable) { void BaseChannel::OnNetworkRouteChanged( absl::optional network_route) { - RTC_LOG(LS_INFO) << "Network route for " << ToString() << " was changed."; + RTC_LOG(LS_INFO) << "Network route changed for " << ToString(); RTC_DCHECK_RUN_ON(network_thread()); rtc::NetworkRoute new_route; @@ -401,34 +363,25 @@ void BaseChannel::OnNetworkRouteChanged( // use the same transport name and MediaChannel::OnNetworkRouteChanged cannot // work correctly. Intentionally leave it broken to simplify the code and // encourage the users to stop using non-muxing RTCP. - invoker_.AsyncInvoke(RTC_FROM_HERE, worker_thread_, [=] { - RTC_DCHECK_RUN_ON(worker_thread()); - media_channel_->OnNetworkRouteChanged(transport_name_, new_route); - }); -} - -sigslot::signal1& BaseChannel::SignalFirstPacketReceived() { - RTC_DCHECK_RUN_ON(signaling_thread_); - return SignalFirstPacketReceived_; + media_channel_->OnNetworkRouteChanged(transport_name_, new_route); } -sigslot::signal1& BaseChannel::SignalSentPacket() { - // TODO(bugs.webrtc.org/11994): Uncomment this check once callers have been - // fixed to access this variable from the correct thread. - // RTC_DCHECK_RUN_ON(worker_thread_); - return SignalSentPacket_; +void BaseChannel::SetFirstPacketReceivedCallback( + std::function callback) { + RTC_DCHECK_RUN_ON(network_thread()); + RTC_DCHECK(!on_first_packet_received_ || !callback); + on_first_packet_received_ = std::move(callback); } void BaseChannel::OnTransportReadyToSend(bool ready) { - invoker_.AsyncInvoke(RTC_FROM_HERE, worker_thread_, [=] { - RTC_DCHECK_RUN_ON(worker_thread()); - media_channel_->OnReadyToSend(ready); - }); + RTC_DCHECK_RUN_ON(network_thread()); + media_channel_->OnReadyToSend(ready); } bool BaseChannel::SendPacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options) { + RTC_DCHECK_RUN_ON(network_thread()); // Until all the code is migrated to use RtpPacketType instead of bool. RtpPacketType packet_type = rtcp ? RtpPacketType::kRtcp : RtpPacketType::kRtp; // SendPacket gets called from MediaEngine, on a pacer or an encoder thread. @@ -438,16 +391,6 @@ bool BaseChannel::SendPacket(bool rtcp, // SRTP and the inner workings of the transport channels. // The only downside is that we can't return a proper failure code if // needed. Since UDP is unreliable anyway, this should be a non-issue. - if (!network_thread_->IsCurrent()) { - // Avoid a copy by transferring the ownership of the packet data. - int message_id = rtcp ? 
MSG_SEND_RTCP_PACKET : MSG_SEND_RTP_PACKET; - SendPacketMessageData* data = new SendPacketMessageData; - data->packet = std::move(*packet); - data->options = options; - network_thread_->Post(RTC_FROM_HERE, this, message_id, data); - return true; - } - RTC_DCHECK_RUN_ON(network_thread()); TRACE_EVENT0("webrtc", "BaseChannel::SendPacket"); @@ -495,16 +438,11 @@ bool BaseChannel::SendPacket(bool rtcp, } void BaseChannel::OnRtpPacket(const webrtc::RtpPacketReceived& parsed_packet) { - // Take packet time from the |parsed_packet|. - // RtpPacketReceived.arrival_time_ms = (timestamp_us + 500) / 1000; - int64_t packet_time_us = -1; - if (parsed_packet.arrival_time_ms() > 0) { - packet_time_us = parsed_packet.arrival_time_ms() * 1000; - } + RTC_DCHECK_RUN_ON(network_thread()); - if (!has_received_packet_) { - has_received_packet_ = true; - signaling_thread()->Post(RTC_FROM_HERE, this, MSG_FIRSTPACKETRECEIVED); + if (on_first_packet_received_) { + on_first_packet_received_(); + on_first_packet_received_ = nullptr; } if (!srtp_active() && srtp_required_) { @@ -525,17 +463,50 @@ void BaseChannel::OnRtpPacket(const webrtc::RtpPacketReceived& parsed_packet) { return; } - auto packet_buffer = parsed_packet.Buffer(); + webrtc::Timestamp packet_time = parsed_packet.arrival_time(); + media_channel_->OnPacketReceived( + parsed_packet.Buffer(), + packet_time.IsMinusInfinity() ? -1 : packet_time.us()); +} + +void BaseChannel::UpdateRtpHeaderExtensionMap( + const RtpHeaderExtensions& header_extensions) { + // Update the header extension map on network thread in case there is data + // race. + // + // NOTE: This doesn't take the BUNDLE case in account meaning the RTP header + // extension maps are not merged when BUNDLE is enabled. This is fine because + // the ID for MID should be consistent among all the RTP transports. 
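// The Invoke() that follows is a blocking hop to the network thread, so
// capturing |header_extensions| by reference is safe: the caller cannot
// return (and the referenced vector cannot die) before the closure has run.
// That is the key difference from the posted, non-blocking closures used
// elsewhere in this change. A minimal sketch of the distinction:

#include <vector>

#include "rtc_base/location.h"
#include "rtc_base/thread.h"

void ApplyIdsOnNetworkThread(rtc::Thread* network_thread,
                             const std::vector<int>& ids) {
  // Blocking hop: |ids| outlives the lambda because Invoke() only returns
  // after the closure has finished on |network_thread|.
  network_thread->Invoke<void>(RTC_FROM_HERE, [&ids]() {
    int applied = 0;
    for (int id : ids) {
      applied += id > 0 ? 1 : 0;  // Stand-in for per-extension work.
    }
    (void)applied;
  });
}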
+ network_thread_->Invoke(RTC_FROM_HERE, [this, &header_extensions] { + RTC_DCHECK_RUN_ON(network_thread()); + rtp_transport_->UpdateRtpHeaderExtensionMap(header_extensions); + }); +} - invoker_.AsyncInvoke( - RTC_FROM_HERE, worker_thread_, [this, packet_buffer, packet_time_us] { - RTC_DCHECK_RUN_ON(worker_thread()); - media_channel_->OnPacketReceived(packet_buffer, packet_time_us); +bool BaseChannel::RegisterRtpDemuxerSink_w() { + if (demuxer_criteria_ == previous_demuxer_criteria_) { + return true; + } + media_channel_->OnDemuxerCriteriaUpdatePending(); + // Copy demuxer criteria, since they're a worker-thread variable + // and we want to pass them to the network thread + return network_thread_->Invoke( + RTC_FROM_HERE, [this, demuxer_criteria = demuxer_criteria_] { + RTC_DCHECK_RUN_ON(network_thread()); + RTC_DCHECK(rtp_transport_); + bool result = + rtp_transport_->RegisterRtpDemuxerSink(demuxer_criteria, this); + if (result) { + previous_demuxer_criteria_ = demuxer_criteria; + } else { + previous_demuxer_criteria_ = {}; + } + media_channel_->OnDemuxerCriteriaUpdateComplete(); + return result; }); } void BaseChannel::EnableMedia_w() { - RTC_DCHECK(worker_thread_ == rtc::Thread::Current()); if (enabled_) return; @@ -545,7 +516,6 @@ void BaseChannel::EnableMedia_w() { } void BaseChannel::DisableMedia_w() { - RTC_DCHECK(worker_thread_ == rtc::Thread::Current()); if (!enabled_) return; @@ -555,6 +525,7 @@ void BaseChannel::DisableMedia_w() { } void BaseChannel::UpdateWritableState_n() { + TRACE_EVENT0("webrtc", "BaseChannel::UpdateWritableState_n"); if (rtp_transport_->IsWritable(/*rtcp=*/true) && rtp_transport_->IsWritable(/*rtcp=*/false)) { ChannelWritable_n(); @@ -564,26 +535,27 @@ void BaseChannel::UpdateWritableState_n() { } void BaseChannel::ChannelWritable_n() { + TRACE_EVENT0("webrtc", "BaseChannel::ChannelWritable_n"); if (writable_) { return; } writable_ = true; - RTC_LOG(LS_INFO) << "Channel writable (" << ToString() << ")" << (was_ever_writable_n_ ? "" : " for the first time"); - // We only have to do this AsyncInvoke once, when first transitioning to + // We only have to do this PostTask once, when first transitioning to // writable. if (!was_ever_writable_n_) { - invoker_.AsyncInvoke(RTC_FROM_HERE, worker_thread_, [this] { + worker_thread_->PostTask(ToQueuedTask(alive_, [this] { RTC_DCHECK_RUN_ON(worker_thread()); was_ever_writable_ = true; UpdateMediaSendRecvState_w(); - }); + })); } was_ever_writable_n_ = true; } void BaseChannel::ChannelNotWritable_n() { + TRACE_EVENT0("webrtc", "BaseChannel::ChannelNotWritable_n"); if (!writable_) { return; } @@ -600,13 +572,12 @@ bool BaseChannel::RemoveRecvStream_w(uint32_t ssrc) { } void BaseChannel::ResetUnsignaledRecvStream_w() { - RTC_DCHECK(worker_thread() == rtc::Thread::Current()); media_channel()->ResetUnsignaledRecvStream(); } -void BaseChannel::SetPayloadTypeDemuxingEnabled_w(bool enabled) { +bool BaseChannel::SetPayloadTypeDemuxingEnabled_w(bool enabled) { if (enabled == payload_type_demuxing_enabled_) { - return; + return true; } payload_type_demuxing_enabled_ = enabled; if (!enabled) { @@ -617,10 +588,21 @@ void BaseChannel::SetPayloadTypeDemuxingEnabled_w(bool enabled) { // there is no straightforward way to identify those streams. 
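// RegisterRtpDemuxerSink_w() above skips the network-thread round trip when
// the demuxer criteria have not changed, by caching the criteria it last
// registered successfully; on failure the cache is reset so the next call is
// not wrongly skipped. The same idea in miniature, with a stand-in criteria
// type instead of the real RtpDemuxerCriteria:

#include <set>

struct SketchCriteria {
  std::set<int> payload_types;
  bool operator==(const SketchCriteria& other) const {
    return payload_types == other.payload_types;
  }
};

class SinkRegistration {
 public:
  bool Register(const SketchCriteria& wanted) {
    if (wanted == previous_) {
      return true;  // Unchanged criteria: nothing to re-register.
    }
    const bool ok = DoRegister(wanted);
    previous_ = ok ? wanted : SketchCriteria();
    return ok;
  }

 private:
  bool DoRegister(const SketchCriteria& criteria) {
    // Stand-in for the cross-thread RegisterRtpDemuxerSink() call.
    return !criteria.payload_types.empty();
  }

  SketchCriteria previous_;
};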
media_channel()->ResetUnsignaledRecvStream(); demuxer_criteria_.payload_types.clear(); + if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to disable payload type demuxing for " + << ToString(); + return false; + } } else if (!payload_types_.empty()) { demuxer_criteria_.payload_types.insert(payload_types_.begin(), payload_types_.end()); + if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to enable payload type demuxing for " + << ToString(); + return false; + } } + return true; } bool BaseChannel::UpdateLocalStreams_w(const std::vector& streams, @@ -761,47 +743,21 @@ bool BaseChannel::UpdateRemoteStreams_w( demuxer_criteria_.ssrcs.insert(new_stream.ssrcs.begin(), new_stream.ssrcs.end()); } + // Re-register the sink to update the receiving ssrcs. + if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to set up demuxing for " << ToString(); + ret = false; + } remote_streams_ = streams; return ret; } -RtpHeaderExtensions BaseChannel::GetFilteredRtpHeaderExtensions( +RtpHeaderExtensions BaseChannel::GetDeduplicatedRtpHeaderExtensions( const RtpHeaderExtensions& extensions) { - if (crypto_options_.srtp.enable_encrypted_rtp_header_extensions) { - RtpHeaderExtensions filtered; - absl::c_copy_if(extensions, std::back_inserter(filtered), - [](const webrtc::RtpExtension& extension) { - return !extension.encrypt; - }); - return filtered; - } - - return webrtc::RtpExtension::FilterDuplicateNonEncrypted(extensions); -} - -void BaseChannel::SetReceiveExtensions(const RtpHeaderExtensions& extensions) { - receive_rtp_header_extensions_ = extensions; -} - -void BaseChannel::OnMessage(rtc::Message* pmsg) { - TRACE_EVENT0("webrtc", "BaseChannel::OnMessage"); - switch (pmsg->message_id) { - case MSG_SEND_RTP_PACKET: - case MSG_SEND_RTCP_PACKET: { - RTC_DCHECK_RUN_ON(network_thread()); - SendPacketMessageData* data = - static_cast(pmsg->pdata); - bool rtcp = pmsg->message_id == MSG_SEND_RTCP_PACKET; - SendPacket(rtcp, &data->packet, data->options); - delete data; - break; - } - case MSG_FIRSTPACKETRECEIVED: { - RTC_DCHECK_RUN_ON(signaling_thread_); - SignalFirstPacketReceived_(this); - break; - } - } + return webrtc::RtpExtension::DeduplicateHeaderExtensions( + extensions, crypto_options_.srtp.enable_encrypted_rtp_header_extensions + ? webrtc::RtpExtension::kPreferEncryptedExtension + : webrtc::RtpExtension::kDiscardEncryptedExtension); } void BaseChannel::MaybeAddHandledPayloadType(int payload_type) { @@ -818,37 +774,9 @@ void BaseChannel::ClearHandledPayloadTypes() { payload_types_.clear(); } -void BaseChannel::FlushRtcpMessages_n() { - // Flush all remaining RTCP messages. This should only be called in - // destructor. 
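// GetDeduplicatedRtpHeaderExtensions() above now delegates the encrypted /
// non-encrypted choice to RtpExtension::DeduplicateHeaderExtensions(). A
// usage sketch, assuming the same call signature used there and the expected
// behaviour that the filter keeps either the encrypted or the plain variant
// when both are offered for one URI (the extension list itself is made up):

#include <vector>

#include "api/rtp_parameters.h"

std::vector<webrtc::RtpExtension> DeduplicateForSrtp(bool allow_encrypted) {
  std::vector<webrtc::RtpExtension> extensions = {
      webrtc::RtpExtension(webrtc::RtpExtension::kMidUri, 1),
      webrtc::RtpExtension(webrtc::RtpExtension::kMidUri, 2, /*encrypt=*/true),
  };
  return webrtc::RtpExtension::DeduplicateHeaderExtensions(
      extensions, allow_encrypted
                      ? webrtc::RtpExtension::kPreferEncryptedExtension
                      : webrtc::RtpExtension::kDiscardEncryptedExtension);
}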
- rtc::MessageList rtcp_messages; - network_thread_->Clear(this, MSG_SEND_RTCP_PACKET, &rtcp_messages); - for (const auto& message : rtcp_messages) { - network_thread_->Send(RTC_FROM_HERE, this, MSG_SEND_RTCP_PACKET, - message.pdata); - } -} - void BaseChannel::SignalSentPacket_n(const rtc::SentPacket& sent_packet) { - invoker_.AsyncInvoke(RTC_FROM_HERE, worker_thread_, - [this, sent_packet] { - RTC_DCHECK_RUN_ON(worker_thread()); - SignalSentPacket()(sent_packet); - }); -} - -void BaseChannel::SetNegotiatedHeaderExtensions_w( - const RtpHeaderExtensions& extensions) { - TRACE_EVENT0("webrtc", __func__); - RTC_DCHECK_RUN_ON(worker_thread()); - webrtc::MutexLock lock(&negotiated_header_extensions_lock_); - negotiated_header_extensions_ = extensions; -} - -RtpHeaderExtensions BaseChannel::GetNegotiatedRtpHeaderExtensions() const { - RTC_DCHECK_RUN_ON(signaling_thread()); - webrtc::MutexLock lock(&negotiated_header_extensions_lock_); - return negotiated_header_extensions_; + RTC_DCHECK_RUN_ON(network_thread()); + media_channel()->OnPacketSent(sent_packet); } VoiceChannel::VoiceChannel(rtc::Thread* worker_thread, @@ -875,10 +803,6 @@ VoiceChannel::~VoiceChannel() { Deinit(); } -void VoiceChannel::Init_w(webrtc::RtpTransportInternal* rtp_transport) { - BaseChannel::Init_w(rtp_transport); -} - void VoiceChannel::UpdateMediaSendRecvState_w() { // Render incoming data if we're the active call, and we have the local // content. We receive data on the default channel and multiplexed streams. @@ -902,26 +826,19 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, RTC_DCHECK_RUN_ON(worker_thread()); RTC_LOG(LS_INFO) << "Setting local voice description for " << ToString(); - RTC_DCHECK(content); - if (!content) { - SafeSetError("Can't find audio content in local description.", error_desc); - return false; - } - - const AudioContentDescription* audio = content->as_audio(); - - if (type == SdpType::kAnswer) - SetNegotiatedHeaderExtensions_w(audio->rtp_header_extensions()); - RtpHeaderExtensions rtp_header_extensions = - GetFilteredRtpHeaderExtensions(audio->rtp_header_extensions()); - SetReceiveExtensions(rtp_header_extensions); - media_channel()->SetExtmapAllowMixed(audio->extmap_allow_mixed()); + GetDeduplicatedRtpHeaderExtensions(content->rtp_header_extensions()); + // TODO(tommi): There's a hop to the network thread here. + // some of the below is also network thread related. + UpdateRtpHeaderExtensionMap(rtp_header_extensions); + media_channel()->SetExtmapAllowMixed(content->extmap_allow_mixed()); AudioRecvParameters recv_params = last_recv_params_; RtpParametersFromMediaDescription( - audio, rtp_header_extensions, - webrtc::RtpTransceiverDirectionHasRecv(audio->direction()), &recv_params); + content->as_audio(), rtp_header_extensions, + webrtc::RtpTransceiverDirectionHasRecv(content->direction()), + &recv_params); + if (!media_channel()->SetRecvParameters(recv_params)) { SafeSetError( "Failed to set local audio description recv parameters for m-section " @@ -931,10 +848,15 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, return false; } - if (webrtc::RtpTransceiverDirectionHasRecv(audio->direction())) { - for (const AudioCodec& codec : audio->codecs()) { + if (webrtc::RtpTransceiverDirectionHasRecv(content->direction())) { + for (const AudioCodec& codec : content->as_audio()->codecs()) { MaybeAddHandledPayloadType(codec.id); } + // Need to re-register the sink to update the handled payload. 
+ if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to set up audio demuxing for " << ToString(); + return false; + } } last_recv_params_ = recv_params; @@ -943,7 +865,7 @@ bool VoiceChannel::SetLocalContent_w(const MediaContentDescription* content, // only give it to the media channel once we have a remote // description too (without a remote description, we won't be able // to send them anyway). - if (!UpdateLocalStreams_w(audio->streams(), type, error_desc)) { + if (!UpdateLocalStreams_w(content->as_audio()->streams(), type, error_desc)) { SafeSetError( "Failed to set local audio description streams for m-section with " "mid='" + @@ -964,19 +886,10 @@ bool VoiceChannel::SetRemoteContent_w(const MediaContentDescription* content, RTC_DCHECK_RUN_ON(worker_thread()); RTC_LOG(LS_INFO) << "Setting remote voice description for " << ToString(); - RTC_DCHECK(content); - if (!content) { - SafeSetError("Can't find audio content in remote description.", error_desc); - return false; - } - const AudioContentDescription* audio = content->as_audio(); - if (type == SdpType::kAnswer) - SetNegotiatedHeaderExtensions_w(audio->rtp_header_extensions()); - RtpHeaderExtensions rtp_header_extensions = - GetFilteredRtpHeaderExtensions(audio->rtp_header_extensions()); + GetDeduplicatedRtpHeaderExtensions(audio->rtp_header_extensions()); AudioSendParameters send_params = last_send_params_; RtpSendParametersFromMediaDescription( @@ -1000,6 +913,10 @@ bool VoiceChannel::SetRemoteContent_w(const MediaContentDescription* content, "disable payload type demuxing for " << ToString(); ClearHandledPayloadTypes(); + if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to update audio demuxing for " << ToString(); + return false; + } } // TODO(pthatcher): Move remote streams into AudioRecvParameters, @@ -1059,8 +976,9 @@ void VideoChannel::UpdateMediaSendRecvState_w() { } void VideoChannel::FillBitrateInfo(BandwidthEstimationInfo* bwe_info) { - InvokeOnWorker(RTC_FROM_HERE, Bind(&VideoMediaChannel::FillBitrateInfo, - media_channel(), bwe_info)); + RTC_DCHECK_RUN_ON(worker_thread()); + VideoMediaChannel* mc = media_channel(); + mc->FillBitrateInfo(bwe_info); } bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, @@ -1070,26 +988,17 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, RTC_DCHECK_RUN_ON(worker_thread()); RTC_LOG(LS_INFO) << "Setting local video description for " << ToString(); - RTC_DCHECK(content); - if (!content) { - SafeSetError("Can't find video content in local description.", error_desc); - return false; - } - - const VideoContentDescription* video = content->as_video(); - - if (type == SdpType::kAnswer) - SetNegotiatedHeaderExtensions_w(video->rtp_header_extensions()); - RtpHeaderExtensions rtp_header_extensions = - GetFilteredRtpHeaderExtensions(video->rtp_header_extensions()); - SetReceiveExtensions(rtp_header_extensions); - media_channel()->SetExtmapAllowMixed(video->extmap_allow_mixed()); + GetDeduplicatedRtpHeaderExtensions(content->rtp_header_extensions()); + UpdateRtpHeaderExtensionMap(rtp_header_extensions); + media_channel()->SetExtmapAllowMixed(content->extmap_allow_mixed()); VideoRecvParameters recv_params = last_recv_params_; + RtpParametersFromMediaDescription( - video, rtp_header_extensions, - webrtc::RtpTransceiverDirectionHasRecv(video->direction()), &recv_params); + content->as_video(), rtp_header_extensions, + webrtc::RtpTransceiverDirectionHasRecv(content->direction()), + &recv_params); VideoSendParameters 
send_params = last_send_params_; @@ -1122,10 +1031,15 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, return false; } - if (webrtc::RtpTransceiverDirectionHasRecv(video->direction())) { - for (const VideoCodec& codec : video->codecs()) { + if (webrtc::RtpTransceiverDirectionHasRecv(content->direction())) { + for (const VideoCodec& codec : content->as_video()->codecs()) { MaybeAddHandledPayloadType(codec.id); } + // Need to re-register the sink to update the handled payload. + if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to set up video demuxing for " << ToString(); + return false; + } } last_recv_params_ = recv_params; @@ -1144,7 +1058,7 @@ bool VideoChannel::SetLocalContent_w(const MediaContentDescription* content, // only give it to the media channel once we have a remote // description too (without a remote description, we won't be able // to send them anyway). - if (!UpdateLocalStreams_w(video->streams(), type, error_desc)) { + if (!UpdateLocalStreams_w(content->as_video()->streams(), type, error_desc)) { SafeSetError( "Failed to set local video description streams for m-section with " "mid='" + @@ -1165,19 +1079,10 @@ bool VideoChannel::SetRemoteContent_w(const MediaContentDescription* content, RTC_DCHECK_RUN_ON(worker_thread()); RTC_LOG(LS_INFO) << "Setting remote video description for " << ToString(); - RTC_DCHECK(content); - if (!content) { - SafeSetError("Can't find video content in remote description.", error_desc); - return false; - } - const VideoContentDescription* video = content->as_video(); - if (type == SdpType::kAnswer) - SetNegotiatedHeaderExtensions_w(video->rtp_header_extensions()); - RtpHeaderExtensions rtp_header_extensions = - GetFilteredRtpHeaderExtensions(video->rtp_header_extensions()); + GetDeduplicatedRtpHeaderExtensions(video->rtp_header_extensions()); VideoSendParameters send_params = last_send_params_; RtpSendParametersFromMediaDescription( @@ -1235,6 +1140,10 @@ bool VideoChannel::SetRemoteContent_w(const MediaContentDescription* content, "disable payload type demuxing for " << ToString(); ClearHandledPayloadTypes(); + if (!RegisterRtpDemuxerSink_w()) { + RTC_LOG(LS_ERROR) << "Failed to update video demuxing for " << ToString(); + return false; + } } // TODO(pthatcher): Move remote streams into VideoRecvParameters, @@ -1254,237 +1163,4 @@ bool VideoChannel::SetRemoteContent_w(const MediaContentDescription* content, return true; } -RtpDataChannel::RtpDataChannel(rtc::Thread* worker_thread, - rtc::Thread* network_thread, - rtc::Thread* signaling_thread, - std::unique_ptr media_channel, - const std::string& content_name, - bool srtp_required, - webrtc::CryptoOptions crypto_options, - UniqueRandomIdGenerator* ssrc_generator) - : BaseChannel(worker_thread, - network_thread, - signaling_thread, - std::move(media_channel), - content_name, - srtp_required, - crypto_options, - ssrc_generator) {} - -RtpDataChannel::~RtpDataChannel() { - TRACE_EVENT0("webrtc", "RtpDataChannel::~RtpDataChannel"); - // this can't be done in the base class, since it calls a virtual - DisableMedia_w(); - Deinit(); -} - -void RtpDataChannel::Init_w(webrtc::RtpTransportInternal* rtp_transport) { - BaseChannel::Init_w(rtp_transport); - media_channel()->SignalDataReceived.connect(this, - &RtpDataChannel::OnDataReceived); - media_channel()->SignalReadyToSend.connect( - this, &RtpDataChannel::OnDataChannelReadyToSend); -} - -bool RtpDataChannel::SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* 
result) { - return InvokeOnWorker( - RTC_FROM_HERE, Bind(&DataMediaChannel::SendData, media_channel(), params, - payload, result)); -} - -bool RtpDataChannel::CheckDataChannelTypeFromContent( - const MediaContentDescription* content, - std::string* error_desc) { - if (!content->as_rtp_data()) { - if (content->as_sctp()) { - SafeSetError("Data channel type mismatch. Expected RTP, got SCTP.", - error_desc); - } else { - SafeSetError("Data channel is not RTP or SCTP.", error_desc); - } - return false; - } - return true; -} - -bool RtpDataChannel::SetLocalContent_w(const MediaContentDescription* content, - SdpType type, - std::string* error_desc) { - TRACE_EVENT0("webrtc", "RtpDataChannel::SetLocalContent_w"); - RTC_DCHECK_RUN_ON(worker_thread()); - RTC_LOG(LS_INFO) << "Setting local data description for " << ToString(); - - RTC_DCHECK(content); - if (!content) { - SafeSetError("Can't find data content in local description.", error_desc); - return false; - } - - if (!CheckDataChannelTypeFromContent(content, error_desc)) { - return false; - } - const RtpDataContentDescription* data = content->as_rtp_data(); - - RtpHeaderExtensions rtp_header_extensions = - GetFilteredRtpHeaderExtensions(data->rtp_header_extensions()); - - DataRecvParameters recv_params = last_recv_params_; - RtpParametersFromMediaDescription( - data, rtp_header_extensions, - webrtc::RtpTransceiverDirectionHasRecv(data->direction()), &recv_params); - if (!media_channel()->SetRecvParameters(recv_params)) { - SafeSetError( - "Failed to set remote data description recv parameters for m-section " - "with mid='" + - content_name() + "'.", - error_desc); - return false; - } - for (const DataCodec& codec : data->codecs()) { - MaybeAddHandledPayloadType(codec.id); - } - - last_recv_params_ = recv_params; - - // TODO(pthatcher): Move local streams into DataSendParameters, and - // only give it to the media channel once we have a remote - // description too (without a remote description, we won't be able - // to send them anyway). - if (!UpdateLocalStreams_w(data->streams(), type, error_desc)) { - SafeSetError( - "Failed to set local data description streams for m-section with " - "mid='" + - content_name() + "'.", - error_desc); - return false; - } - - set_local_content_direction(content->direction()); - UpdateMediaSendRecvState_w(); - return true; -} - -bool RtpDataChannel::SetRemoteContent_w(const MediaContentDescription* content, - SdpType type, - std::string* error_desc) { - TRACE_EVENT0("webrtc", "RtpDataChannel::SetRemoteContent_w"); - RTC_DCHECK_RUN_ON(worker_thread()); - RTC_LOG(LS_INFO) << "Setting remote data description for " << ToString(); - - RTC_DCHECK(content); - if (!content) { - SafeSetError("Can't find data content in remote description.", error_desc); - return false; - } - - if (!CheckDataChannelTypeFromContent(content, error_desc)) { - return false; - } - - const RtpDataContentDescription* data = content->as_rtp_data(); - - // If the remote data doesn't have codecs, it must be empty, so ignore it. 
- if (!data->has_codecs()) { - return true; - } - - RtpHeaderExtensions rtp_header_extensions = - GetFilteredRtpHeaderExtensions(data->rtp_header_extensions()); - - RTC_LOG(LS_INFO) << "Setting remote data description for " << ToString(); - DataSendParameters send_params = last_send_params_; - RtpSendParametersFromMediaDescription( - data, rtp_header_extensions, - webrtc::RtpTransceiverDirectionHasRecv(data->direction()), &send_params); - if (!media_channel()->SetSendParameters(send_params)) { - SafeSetError( - "Failed to set remote data description send parameters for m-section " - "with mid='" + - content_name() + "'.", - error_desc); - return false; - } - last_send_params_ = send_params; - - // TODO(pthatcher): Move remote streams into DataRecvParameters, - // and only give it to the media channel once we have a local - // description too (without a local description, we won't be able to - // recv them anyway). - if (!UpdateRemoteStreams_w(data->streams(), type, error_desc)) { - SafeSetError( - "Failed to set remote data description streams for m-section with " - "mid='" + - content_name() + "'.", - error_desc); - return false; - } - - set_remote_content_direction(content->direction()); - UpdateMediaSendRecvState_w(); - return true; -} - -void RtpDataChannel::UpdateMediaSendRecvState_w() { - // Render incoming data if we're the active call, and we have the local - // content. We receive data on the default channel and multiplexed streams. - RTC_DCHECK_RUN_ON(worker_thread()); - bool recv = IsReadyToReceiveMedia_w(); - if (!media_channel()->SetReceive(recv)) { - RTC_LOG(LS_ERROR) << "Failed to SetReceive on data channel: " << ToString(); - } - - // Send outgoing data if we're the active call, we have the remote content, - // and we have had some form of connectivity. - bool send = IsReadyToSendMedia_w(); - if (!media_channel()->SetSend(send)) { - RTC_LOG(LS_ERROR) << "Failed to SetSend on data channel: " << ToString(); - } - - // Trigger SignalReadyToSendData asynchronously. - OnDataChannelReadyToSend(send); - - RTC_LOG(LS_INFO) << "Changing data state, recv=" << recv << " send=" << send - << " for " << ToString(); -} - -void RtpDataChannel::OnMessage(rtc::Message* pmsg) { - switch (pmsg->message_id) { - case MSG_READYTOSENDDATA: { - DataChannelReadyToSendMessageData* data = - static_cast(pmsg->pdata); - ready_to_send_data_ = data->data(); - SignalReadyToSendData(ready_to_send_data_); - delete data; - break; - } - case MSG_DATARECEIVED: { - DataReceivedMessageData* data = - static_cast(pmsg->pdata); - SignalDataReceived(data->params, data->payload); - delete data; - break; - } - default: - BaseChannel::OnMessage(pmsg); - break; - } -} - -void RtpDataChannel::OnDataReceived(const ReceiveDataParams& params, - const char* data, - size_t len) { - DataReceivedMessageData* msg = new DataReceivedMessageData(params, data, len); - signaling_thread()->Post(RTC_FROM_HERE, this, MSG_DATARECEIVED, msg); -} - -void RtpDataChannel::OnDataChannelReadyToSend(bool writable) { - // This is usded for congestion control to indicate that the stream is ready - // to send by the MediaChannel, as opposed to OnReadyToSend, which indicates - // that the transport channel is ready. 
- signaling_thread()->Post(RTC_FROM_HERE, this, MSG_READYTOSENDDATA, - new DataChannelReadyToSendMessageData(writable)); -} - } // namespace cricket diff --git a/pc/channel.h b/pc/channel.h index bbb95d7ea1..d1dbe2cd6c 100644 --- a/pc/channel.h +++ b/pc/channel.h @@ -11,6 +11,9 @@ #ifndef PC_CHANNEL_H_ #define PC_CHANNEL_H_ +#include +#include + #include #include #include @@ -18,30 +21,48 @@ #include #include +#include "absl/types/optional.h" #include "api/call/audio_sink.h" +#include "api/crypto/crypto_options.h" #include "api/function_view.h" #include "api/jsep.h" +#include "api/media_types.h" #include "api/rtp_receiver_interface.h" +#include "api/rtp_transceiver_direction.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/video/video_sink_interface.h" #include "api/video/video_source_interface.h" +#include "call/rtp_demuxer.h" #include "call/rtp_packet_sink_interface.h" #include "media/base/media_channel.h" #include "media/base/media_engine.h" #include "media/base/stream_params.h" +#include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "p2p/base/dtls_transport_internal.h" #include "p2p/base/packet_transport_internal.h" #include "pc/channel_interface.h" #include "pc/dtls_srtp_transport.h" #include "pc/media_session.h" #include "pc/rtp_transport.h" +#include "pc/rtp_transport_internal.h" +#include "pc/session_description.h" #include "pc/srtp_filter.h" #include "pc/srtp_transport.h" -#include "rtc_base/async_invoker.h" +#include "rtc_base/async_packet_socket.h" #include "rtc_base/async_udp_socket.h" +#include "rtc_base/checks.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/location.h" #include "rtc_base/network.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/network/sent_packet.h" +#include "rtc_base/network_route.h" +#include "rtc_base/socket.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" +#include "rtc_base/thread_message.h" #include "rtc_base/unique_id_generator.h" namespace webrtc { @@ -71,8 +92,10 @@ struct CryptoParams; // NetworkInterface. class BaseChannel : public ChannelInterface, - public rtc::MessageHandlerAutoCleanup, + // TODO(tommi): Remove has_slots inheritance. public sigslot::has_slots<>, + // TODO(tommi): Consider implementing these interfaces + // via composition. public MediaChannel::NetworkInterface, public webrtc::RtpPacketSinkInterface { public: @@ -101,8 +124,13 @@ class BaseChannel : public ChannelInterface, rtc::Thread* network_thread() const { return network_thread_; } const std::string& content_name() const override { return content_name_; } // TODO(deadbeef): This is redundant; remove this. - const std::string& transport_name() const override { return transport_name_; } - bool enabled() const override { return enabled_; } + const std::string& transport_name() const override { + RTC_DCHECK_RUN_ON(network_thread()); + if (rtp_transport_) + return rtp_transport_->transport_name(); + // TODO(tommi): Delete this variable. + return transport_name_; + } // This function returns true if using SRTP (DTLS-based keying or SDES). bool srtp_active() const { @@ -110,15 +138,6 @@ class BaseChannel : public ChannelInterface, return rtp_transport_ && rtp_transport_->IsSrtpActive(); } - // Version of the above that can be called from any thread. 
- bool SrtpActiveForTesting() const { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke(RTC_FROM_HERE, - [this] { return srtp_active(); }); - } - RTC_DCHECK_RUN_ON(network_thread()); - return srtp_active(); - } // Set an RTP level transport which could be an RtpTransport without // encryption, an SrtpTransport for SDES or a DtlsSrtpTransport for DTLS-SRTP. // This can be called from any thread and it hops to the network thread @@ -130,18 +149,7 @@ class BaseChannel : public ChannelInterface, return rtp_transport_; } - // Version of the above that can be called from any thread. - webrtc::RtpTransportInternal* RtpTransportForTesting() const { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke( - RTC_FROM_HERE, [this] { return rtp_transport(); }); - } - RTC_DCHECK_RUN_ON(network_thread()); - return rtp_transport(); - } - - // Channel control. Must call UpdateRtpTransport afterwards to apply any - // changes to the RtpTransport on the network thread. + // Channel control bool SetLocalContent(const MediaContentDescription* content, webrtc::SdpType type, std::string* error_desc) override; @@ -156,13 +164,9 @@ class BaseChannel : public ChannelInterface, // This method will also remove any existing streams that were bound to this // channel on the basis of payload type, since one of these streams might // actually belong to a new channel. See: crbug.com/webrtc/11477 - // - // As with SetLocalContent/SetRemoteContent, must call UpdateRtpTransport - // afterwards to apply changes to the RtpTransport on the network thread. - void SetPayloadTypeDemuxingEnabled(bool enabled) override; - bool UpdateRtpTransport(std::string* error_desc) override; + bool SetPayloadTypeDemuxingEnabled(bool enabled) override; - bool Enable(bool enable) override; + void Enable(bool enable) override; const std::vector& local_streams() const override { return local_streams_; @@ -172,28 +176,17 @@ class BaseChannel : public ChannelInterface, } // Used for latency measurements. - sigslot::signal1& SignalFirstPacketReceived() override; - - // Forward SignalSentPacket to worker thread. - sigslot::signal1& SignalSentPacket(); + void SetFirstPacketReceivedCallback(std::function callback) override; // From RtpTransport - public for testing only void OnTransportReadyToSend(bool ready); // Only public for unit tests. Otherwise, consider protected. int SetOption(SocketType type, rtc::Socket::Option o, int val) override; - int SetOption_n(SocketType type, rtc::Socket::Option o, int val) - RTC_RUN_ON(network_thread()); // RtpPacketSinkInterface overrides. void OnRtpPacket(const webrtc::RtpPacketReceived& packet) override; - // Used by the RTCStatsCollector tests to set the transport name without - // creating RtpTransports. 
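With SrtpActiveForTesting() and RtpTransportForTesting() removed, callers that are not already on the network thread are expected to hop there themselves before reading this state. A minimal sketch of that pattern, assuming caller-supplied channel and network_thread pointers (helper names are illustrative, not part of the patch); the updated channel_unittest.cc further down introduces an equivalent IsSrtpActive() helper:

    #include <string>

    #include "pc/channel.h"
    #include "rtc_base/location.h"
    #include "rtc_base/thread.h"

    // Sketch only: query network-thread-owned state from another thread.
    bool IsSrtpActive(rtc::Thread* network_thread, cricket::BaseChannel* channel) {
      return network_thread->Invoke<bool>(
          RTC_FROM_HERE, [&] { return channel->srtp_active(); });
    }

    std::string GetTransportName(rtc::Thread* network_thread,
                                 cricket::BaseChannel* channel) {
      return network_thread->Invoke<std::string>(
          RTC_FROM_HERE, [&] { return channel->transport_name(); });
    }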
- void set_transport_name_for_testing(const std::string& transport_name) { - transport_name_ = transport_name; - } - MediaChannel* media_channel() const override { return media_channel_.get(); } @@ -225,8 +218,6 @@ class BaseChannel : public ChannelInterface, bool IsReadyToSendMedia_w() const RTC_RUN_ON(worker_thread()); rtc::Thread* signaling_thread() const { return signaling_thread_; } - void FlushRtcpMessages_n() RTC_RUN_ON(network_thread()); - // NetworkInterface implementation, called by MediaEngine bool SendPacket(rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options) override; @@ -238,9 +229,6 @@ class BaseChannel : public ChannelInterface, void OnNetworkRouteChanged(absl::optional network_route); - bool PacketIsRtcp(const rtc::PacketTransportInternal* transport, - const char* data, - size_t len); bool SendPacket(bool rtcp, rtc::CopyOnWriteBuffer* packet, const rtc::PacketOptions& options); @@ -258,7 +246,7 @@ class BaseChannel : public ChannelInterface, bool AddRecvStream_w(const StreamParams& sp) RTC_RUN_ON(worker_thread()); bool RemoveRecvStream_w(uint32_t ssrc) RTC_RUN_ON(worker_thread()); void ResetUnsignaledRecvStream_w() RTC_RUN_ON(worker_thread()); - void SetPayloadTypeDemuxingEnabled_w(bool enabled) + bool SetPayloadTypeDemuxingEnabled_w(bool enabled) RTC_RUN_ON(worker_thread()); bool AddSendStream_w(const StreamParams& sp) RTC_RUN_ON(worker_thread()); bool RemoveSendStream_w(uint32_t ssrc) RTC_RUN_ON(worker_thread()); @@ -266,7 +254,7 @@ class BaseChannel : public ChannelInterface, // Should be called whenever the conditions for // IsReadyToReceiveMedia/IsReadyToSendMedia are satisfied (or unsatisfied). // Updates the send/recv state of the media channel. - virtual void UpdateMediaSendRecvState_w() = 0; + virtual void UpdateMediaSendRecvState_w() RTC_RUN_ON(worker_thread()) = 0; bool UpdateLocalStreams_w(const std::vector& streams, webrtc::SdpType type, @@ -278,64 +266,54 @@ class BaseChannel : public ChannelInterface, RTC_RUN_ON(worker_thread()); virtual bool SetLocalContent_w(const MediaContentDescription* content, webrtc::SdpType type, - std::string* error_desc) = 0; + std::string* error_desc) + RTC_RUN_ON(worker_thread()) = 0; virtual bool SetRemoteContent_w(const MediaContentDescription* content, webrtc::SdpType type, - std::string* error_desc) = 0; - // Return a list of RTP header extensions with the non-encrypted extensions - // removed depending on the current crypto_options_ and only if both the - // non-encrypted and encrypted extension is present for the same URI. - RtpHeaderExtensions GetFilteredRtpHeaderExtensions( + std::string* error_desc) + RTC_RUN_ON(worker_thread()) = 0; + + // Returns a list of RTP header extensions where any extension URI is unique. + // Encrypted extensions will be either preferred or discarded, depending on + // the current crypto_options_. + RtpHeaderExtensions GetDeduplicatedRtpHeaderExtensions( const RtpHeaderExtensions& extensions); - // Set a list of RTP extensions we should prepare to receive on the next - // UpdateRtpTransport call. - void SetReceiveExtensions(const RtpHeaderExtensions& extensions); - - // From MessageHandler - void OnMessage(rtc::Message* pmsg) override; - - // Helper function template for invoking methods on the worker thread. - template - T InvokeOnWorker(const rtc::Location& posted_from, - rtc::FunctionView functor) { - return worker_thread_->Invoke(posted_from, functor); - } // Add |payload_type| to |demuxer_criteria_| if payload type demuxing is // enabled. 
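The renamed GetDeduplicatedRtpHeaderExtensions() documents a by-URI uniqueness rule: the encrypted variant of an extension is preferred when crypto options allow it and discarded otherwise. An illustrative stand-alone sketch of that rule follows; it is not the BaseChannel implementation, and the helper name and allow_encrypted flag are assumptions made for the example:

    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    #include "api/rtp_parameters.h"

    // Keep one extension per URI, preferring the encrypted variant when allowed.
    std::vector<webrtc::RtpExtension> DeduplicateByUri(
        const std::vector<webrtc::RtpExtension>& extensions,
        bool allow_encrypted) {
      std::map<std::string, webrtc::RtpExtension> by_uri;
      for (const webrtc::RtpExtension& ext : extensions) {
        if (ext.encrypt && !allow_encrypted)
          continue;  // Drop encrypted variants when encryption is not allowed.
        auto it = by_uri.find(ext.uri);
        if (it == by_uri.end() || (ext.encrypt && !it->second.encrypt))
          by_uri[ext.uri] = ext;  // First entry for this URI, or upgrade to encrypted.
      }
      std::vector<webrtc::RtpExtension> result;
      result.reserve(by_uri.size());
      for (auto& entry : by_uri)
        result.push_back(std::move(entry.second));
      return result;
    }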
void MaybeAddHandledPayloadType(int payload_type) RTC_RUN_ON(worker_thread()); void ClearHandledPayloadTypes() RTC_RUN_ON(worker_thread()); - // Return description of media channel to facilitate logging - std::string ToString() const; - void SetNegotiatedHeaderExtensions_w(const RtpHeaderExtensions& extensions); + void UpdateRtpHeaderExtensionMap( + const RtpHeaderExtensions& header_extensions); - // ChannelInterface overrides - RtpHeaderExtensions GetNegotiatedRtpHeaderExtensions() const override; + bool RegisterRtpDemuxerSink_w() RTC_RUN_ON(worker_thread()); - bool has_received_packet_ = false; + // Return description of media channel to facilitate logging + std::string ToString() const; private: - bool ConnectToRtpTransport(); - void DisconnectFromRtpTransport(); - void SignalSentPacket_n(const rtc::SentPacket& sent_packet) - RTC_RUN_ON(network_thread()); + bool ConnectToRtpTransport() RTC_RUN_ON(network_thread()); + void DisconnectFromRtpTransport() RTC_RUN_ON(network_thread()); + void SignalSentPacket_n(const rtc::SentPacket& sent_packet); rtc::Thread* const worker_thread_; rtc::Thread* const network_thread_; rtc::Thread* const signaling_thread_; - rtc::AsyncInvoker invoker_; - sigslot::signal1 SignalFirstPacketReceived_ - RTC_GUARDED_BY(signaling_thread_); - sigslot::signal1 SignalSentPacket_ - RTC_GUARDED_BY(worker_thread_); + rtc::scoped_refptr alive_; const std::string content_name_; + std::function on_first_packet_received_ + RTC_GUARDED_BY(network_thread()); + // Won't be set when using raw packet transports. SDP-specific thing. // TODO(bugs.webrtc.org/12230): Written on network thread, read on // worker thread (at least). + // TODO(tommi): Remove this variable and instead use rtp_transport_ to + // return the transport name. This variable is currently required for + // "for_test" methods. std::string transport_name_; webrtc::RtpTransportInternal* rtp_transport_ @@ -349,6 +327,24 @@ class BaseChannel : public ChannelInterface, bool was_ever_writable_n_ RTC_GUARDED_BY(network_thread()) = false; bool was_ever_writable_ RTC_GUARDED_BY(worker_thread()) = false; const bool srtp_required_ = true; + + // TODO(tommi): This field shouldn't be necessary. It's a copy of + // PeerConnection::GetCryptoOptions(), which is const state. It's also only + // used to filter header extensions when calling + // `rtp_transport_->UpdateRtpHeaderExtensionMap()` when the local/remote + // content description is updated. Since the transport is actually owned + // by the transport controller that also gets updated whenever the content + // description changes, it seems we have two paths into the transports, along + // with several thread hops via various classes (such as the Channel classes) + // that only serve as additional layers and store duplicate state. The Jsep* + // family of classes already apply session description updates on the network + // thread every time it changes. + // For the Channel classes, we should be able to get rid of: + // * crypto_options (and fewer construction parameters)_ + // * UpdateRtpHeaderExtensionMap + // * GetFilteredRtpHeaderExtensions + // * Blocking thread hop to the network thread for every call to set + // local/remote content is updated. 
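The new alive_ flag replaces the rtc::AsyncInvoker and MessageHandler style of cancellation: tasks posted with webrtc::ToQueuedTask(alive_, ...) quietly become no-ops once the flag is marked not-alive. A minimal sketch of the pattern for an object that lives on a single rtc::Thread (the class and method names are illustrative):

    #include "api/scoped_refptr.h"
    #include "rtc_base/logging.h"
    #include "rtc_base/task_utils/pending_task_safety_flag.h"
    #include "rtc_base/task_utils/to_queued_task.h"
    #include "rtc_base/thread.h"

    class PingSender {
     public:
      ~PingSender() { alive_->SetNotAlive(); }  // Queued pings become no-ops.

      void PostPing() {
        // The flag is checked when the task runs; if the sender is gone by
        // then, the lambda is skipped and |this| is never dereferenced.
        rtc::Thread::Current()->PostTask(webrtc::ToQueuedTask(
            alive_, [this] { RTC_LOG(LS_INFO) << "ping"; }));
      }

     private:
      const rtc::scoped_refptr<webrtc::PendingTaskSafetyFlag> alive_ =
          webrtc::PendingTaskSafetyFlag::Create();
    };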
const webrtc::CryptoOptions crypto_options_; // MediaChannel related members that should be accessed from the worker @@ -357,7 +353,8 @@ class BaseChannel : public ChannelInterface, // Currently the |enabled_| flag is accessed from the signaling thread as // well, but it can be changed only when signaling thread does a synchronous // call to the worker thread, so it should be safe. - bool enabled_ = false; + bool enabled_ RTC_GUARDED_BY(worker_thread()) = false; + bool enabled_s_ RTC_GUARDED_BY(signaling_thread()) = false; bool payload_type_demuxing_enabled_ RTC_GUARDED_BY(worker_thread()) = true; std::vector local_streams_ RTC_GUARDED_BY(worker_thread()); std::vector remote_streams_ RTC_GUARDED_BY(worker_thread()); @@ -371,23 +368,17 @@ class BaseChannel : public ChannelInterface, // Cached list of payload types, used if payload type demuxing is re-enabled. std::set payload_types_ RTC_GUARDED_BY(worker_thread()); - // TODO(bugs.webrtc.org/12239): These two variables are modified on the worker - // thread, accessed on the network thread in UpdateRtpTransport. + // TODO(bugs.webrtc.org/12239): Modified on worker thread, accessed + // on network thread in RegisterRtpDemuxerSink_n (called from Init_w) webrtc::RtpDemuxerCriteria demuxer_criteria_; - RtpHeaderExtensions receive_rtp_header_extensions_; + // Accessed on the worker thread, modified on the network thread from + // RegisterRtpDemuxerSink_w's Invoke. + webrtc::RtpDemuxerCriteria previous_demuxer_criteria_; // This generator is used to generate SSRCs for local streams. // This is needed in cases where SSRCs are not negotiated or set explicitly // like in Simulcast. // This object is not owned by the channel so it must outlive it. rtc::UniqueRandomIdGenerator* const ssrc_generator_; - - // |negotiated_header_extensions_| is read on the signaling thread, but - // written on the worker thread while being sync-invoked from the signal - // thread in SdpOfferAnswerHandler::PushdownMediaDescription(). Hence the lock - // isn't strictly needed, but it's anyway placed here for future safeness. - mutable webrtc::Mutex negotiated_header_extensions_lock_; - RtpHeaderExtensions negotiated_header_extensions_ - RTC_GUARDED_BY(negotiated_header_extensions_lock_); }; // VoiceChannel is a specialization that adds support for early media, DTMF, @@ -412,7 +403,6 @@ class VoiceChannel : public BaseChannel { cricket::MediaType media_type() const override { return cricket::MEDIA_TYPE_AUDIO; } - void Init_w(webrtc::RtpTransportInternal* rtp_transport) override; private: // overrides from BaseChannel @@ -474,104 +464,6 @@ class VideoChannel : public BaseChannel { VideoRecvParameters last_recv_params_; }; -// RtpDataChannel is a specialization for data. -class RtpDataChannel : public BaseChannel { - public: - RtpDataChannel(rtc::Thread* worker_thread, - rtc::Thread* network_thread, - rtc::Thread* signaling_thread, - std::unique_ptr channel, - const std::string& content_name, - bool srtp_required, - webrtc::CryptoOptions crypto_options, - rtc::UniqueRandomIdGenerator* ssrc_generator); - ~RtpDataChannel(); - // TODO(zhihuang): Remove this once the RtpTransport can be shared between - // BaseChannels. 
- void Init_w(DtlsTransportInternal* rtp_dtls_transport, - DtlsTransportInternal* rtcp_dtls_transport, - rtc::PacketTransportInternal* rtp_packet_transport, - rtc::PacketTransportInternal* rtcp_packet_transport); - void Init_w(webrtc::RtpTransportInternal* rtp_transport) override; - - virtual bool SendData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - SendDataResult* result); - - // Should be called on the signaling thread only. - bool ready_to_send_data() const { return ready_to_send_data_; } - - sigslot::signal2 - SignalDataReceived; - // Signal for notifying when the channel becomes ready to send data. - // That occurs when the channel is enabled, the transport is writable, - // both local and remote descriptions are set, and the channel is unblocked. - sigslot::signal1 SignalReadyToSendData; - cricket::MediaType media_type() const override { - return cricket::MEDIA_TYPE_DATA; - } - - protected: - // downcasts a MediaChannel. - DataMediaChannel* media_channel() const override { - return static_cast(BaseChannel::media_channel()); - } - - private: - struct SendDataMessageData : public rtc::MessageData { - SendDataMessageData(const SendDataParams& params, - const rtc::CopyOnWriteBuffer* payload, - SendDataResult* result) - : params(params), payload(payload), result(result), succeeded(false) {} - - const SendDataParams& params; - const rtc::CopyOnWriteBuffer* payload; - SendDataResult* result; - bool succeeded; - }; - - struct DataReceivedMessageData : public rtc::MessageData { - // We copy the data because the data will become invalid after we - // handle DataMediaChannel::SignalDataReceived but before we fire - // SignalDataReceived. - DataReceivedMessageData(const ReceiveDataParams& params, - const char* data, - size_t len) - : params(params), payload(data, len) {} - const ReceiveDataParams params; - const rtc::CopyOnWriteBuffer payload; - }; - - typedef rtc::TypedMessageData DataChannelReadyToSendMessageData; - - // overrides from BaseChannel - // Checks that data channel type is RTP. - bool CheckDataChannelTypeFromContent(const MediaContentDescription* content, - std::string* error_desc); - bool SetLocalContent_w(const MediaContentDescription* content, - webrtc::SdpType type, - std::string* error_desc) override; - bool SetRemoteContent_w(const MediaContentDescription* content, - webrtc::SdpType type, - std::string* error_desc) override; - void UpdateMediaSendRecvState_w() override; - - void OnMessage(rtc::Message* pmsg) override; - void OnDataReceived(const ReceiveDataParams& params, - const char* data, - size_t len); - void OnDataChannelReadyToSend(bool writable); - - bool ready_to_send_data_ = false; - - // Last DataSendParameters sent down to the media_channel() via - // SetSendParameters. - DataSendParameters last_send_params_; - // Last DataRecvParameters sent down to the media_channel() via - // SetRecvParameters. - DataRecvParameters last_recv_params_; -}; - } // namespace cricket #endif // PC_CHANNEL_H_ diff --git a/pc/channel_interface.h b/pc/channel_interface.h index 1937c8f9f6..3b71f0f8b5 100644 --- a/pc/channel_interface.h +++ b/pc/channel_interface.h @@ -37,13 +37,12 @@ class ChannelInterface { virtual const std::string& content_name() const = 0; - virtual bool enabled() const = 0; - // Enables or disables this channel - virtual bool Enable(bool enable) = 0; + virtual void Enable(bool enable) = 0; // Used for latency measurements. 
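SignalFirstPacketReceived() gives way to a plain callback setter on both BaseChannel and ChannelInterface. A sketch of registering it, assuming the callback type is a no-argument std::function and that it may fire on the network thread, so anything touching signaling-side state should be re-posted by the caller (names are illustrative):

    #include "pc/channel_interface.h"
    #include "rtc_base/logging.h"

    void ObserveFirstPacket(cricket::ChannelInterface* channel) {
      // Replaces the old sigslot connect() on SignalFirstPacketReceived().
      channel->SetFirstPacketReceivedCallback(
          [] { RTC_LOG(LS_INFO) << "First RTP packet received."; });
    }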
- virtual sigslot::signal1& SignalFirstPacketReceived() = 0; + virtual void SetFirstPacketReceivedCallback( + std::function callback) = 0; // Channel control virtual bool SetLocalContent(const MediaContentDescription* content, @@ -52,8 +51,7 @@ class ChannelInterface { virtual bool SetRemoteContent(const MediaContentDescription* content, webrtc::SdpType type, std::string* error_desc) = 0; - virtual void SetPayloadTypeDemuxingEnabled(bool enabled) = 0; - virtual bool UpdateRtpTransport(std::string* error_desc) = 0; + virtual bool SetPayloadTypeDemuxingEnabled(bool enabled) = 0; // Access to the local and remote streams that were set on the channel. virtual const std::vector& local_streams() const = 0; @@ -66,9 +64,6 @@ class ChannelInterface { // * A DtlsSrtpTransport for DTLS-SRTP. virtual bool SetRtpTransport(webrtc::RtpTransportInternal* rtp_transport) = 0; - // Returns the last negotiated header extensions. - virtual RtpHeaderExtensions GetNegotiatedRtpHeaderExtensions() const = 0; - protected: virtual ~ChannelInterface() = default; }; diff --git a/pc/channel_manager.cc b/pc/channel_manager.cc index 9d5adcad42..b58830b215 100644 --- a/pc/channel_manager.cc +++ b/pc/channel_manager.cc @@ -10,57 +10,54 @@ #include "pc/channel_manager.h" +#include #include #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" +#include "api/sequence_checker.h" #include "media/base/media_constants.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/trace_event.h" namespace cricket { +// static +std::unique_ptr ChannelManager::Create( + std::unique_ptr media_engine, + bool enable_rtx, + rtc::Thread* worker_thread, + rtc::Thread* network_thread) { + RTC_DCHECK_RUN_ON(worker_thread); + RTC_DCHECK(network_thread); + RTC_DCHECK(worker_thread); + + if (media_engine) + media_engine->Init(); + + return absl::WrapUnique(new ChannelManager( + std::move(media_engine), enable_rtx, worker_thread, network_thread)); +} + ChannelManager::ChannelManager( std::unique_ptr media_engine, - std::unique_ptr data_engine, + bool enable_rtx, rtc::Thread* worker_thread, rtc::Thread* network_thread) : media_engine_(std::move(media_engine)), - data_engine_(std::move(data_engine)), - main_thread_(rtc::Thread::Current()), worker_thread_(worker_thread), - network_thread_(network_thread) { - RTC_DCHECK(data_engine_); + network_thread_(network_thread), + enable_rtx_(enable_rtx) { RTC_DCHECK(worker_thread_); RTC_DCHECK(network_thread_); + RTC_DCHECK_RUN_ON(worker_thread_); } ChannelManager::~ChannelManager() { - if (initialized_) { - Terminate(); - } - // The media engine needs to be deleted on the worker thread for thread safe - // destruction, - worker_thread_->Invoke(RTC_FROM_HERE, [&] { media_engine_.reset(); }); -} - -bool ChannelManager::SetVideoRtxEnabled(bool enable) { - // To be safe, this call is only allowed before initialization. Apps like - // Flute only have a singleton ChannelManager and we don't want this flag to - // be toggled between calls or when there's concurrent calls. We expect apps - // to enable this at startup and retain that setting for the lifetime of the - // app. 
- if (!initialized_) { - enable_rtx_ = enable; - return true; - } else { - RTC_LOG(LS_WARNING) << "Cannot toggle rtx after initialization!"; - return false; - } + RTC_DCHECK_RUN_ON(worker_thread_); } void ChannelManager::GetSupportedAudioSendCodecs( @@ -113,34 +110,6 @@ void ChannelManager::GetSupportedVideoReceiveCodecs( } } -void ChannelManager::GetSupportedDataCodecs( - std::vector* codecs) const { - *codecs = data_engine_->data_codecs(); -} - -bool ChannelManager::Init() { - RTC_DCHECK(!initialized_); - if (initialized_) { - return false; - } - RTC_DCHECK(network_thread_); - RTC_DCHECK(worker_thread_); - if (!network_thread_->IsCurrent()) { - // Do not allow invoking calls to other threads on the network thread. - network_thread_->Invoke( - RTC_FROM_HERE, [&] { network_thread_->DisallowBlockingCalls(); }); - } - - if (media_engine_) { - initialized_ = worker_thread_->Invoke( - RTC_FROM_HERE, [&] { return media_engine_->Init(); }); - RTC_DCHECK(initialized_); - } else { - initialized_ = true; - } - return initialized_; -} - RtpHeaderExtensions ChannelManager::GetDefaultEnabledAudioRtpHeaderExtensions() const { if (!media_engine_) @@ -169,23 +138,9 @@ ChannelManager::GetSupportedVideoRtpHeaderExtensions() const { return media_engine_->video().GetRtpHeaderExtensions(); } -void ChannelManager::Terminate() { - RTC_DCHECK(initialized_); - if (!initialized_) { - return; - } - // Need to destroy the channels on the worker thread. - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - video_channels_.clear(); - voice_channels_.clear(); - data_channels_.clear(); - }); - initialized_ = false; -} - VoiceChannel* ChannelManager::CreateVoiceChannel( webrtc::Call* call, - const cricket::MediaConfig& media_config, + const MediaConfig& media_config, webrtc::RtpTransportInternal* rtp_transport, rtc::Thread* signaling_thread, const std::string& content_name, @@ -193,6 +148,8 @@ VoiceChannel* ChannelManager::CreateVoiceChannel( const webrtc::CryptoOptions& crypto_options, rtc::UniqueRandomIdGenerator* ssrc_generator, const AudioOptions& options) { + RTC_DCHECK(call); + RTC_DCHECK(media_engine_); // TODO(bugs.webrtc.org/11992): Remove this workaround after updates in // PeerConnection and add the expectation that we're already on the right // thread. 
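With Init() and Terminate() gone, creation, use and destruction of the ChannelManager all happen on the worker thread, and RTX support is fixed by the enable_rtx argument at creation time. A sketch of the new call site, assuming the caller starts off the worker thread (the helper name and thread pointers are illustrative):

    #include <memory>
    #include <utility>

    #include "media/base/media_engine.h"
    #include "pc/channel_manager.h"
    #include "rtc_base/location.h"
    #include "rtc_base/thread.h"

    std::unique_ptr<cricket::ChannelManager> CreateChannelManager(
        std::unique_ptr<cricket::MediaEngineInterface> media_engine,
        rtc::Thread* worker_thread,
        rtc::Thread* network_thread) {
      return worker_thread->Invoke<std::unique_ptr<cricket::ChannelManager>>(
          RTC_FROM_HERE, [&] {
            // Create() runs media_engine->Init() and DCHECKs the worker thread.
            return cricket::ChannelManager::Create(std::move(media_engine),
                                                   /*enable_rtx=*/true,
                                                   worker_thread, network_thread);
          });
    }

The destructor DCHECKs the worker thread as well, so owners are expected to reset the pointer from that thread too.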
@@ -205,11 +162,6 @@ VoiceChannel* ChannelManager::CreateVoiceChannel( } RTC_DCHECK_RUN_ON(worker_thread_); - RTC_DCHECK(initialized_); - RTC_DCHECK(call); - if (!media_engine_) { - return nullptr; - } VoiceMediaChannel* media_channel = media_engine_->voice().CreateMediaChannel( call, media_config, options, crypto_options); @@ -231,32 +183,25 @@ VoiceChannel* ChannelManager::CreateVoiceChannel( void ChannelManager::DestroyVoiceChannel(VoiceChannel* voice_channel) { TRACE_EVENT0("webrtc", "ChannelManager::DestroyVoiceChannel"); - if (!voice_channel) { - return; - } + RTC_DCHECK(voice_channel); + if (!worker_thread_->IsCurrent()) { worker_thread_->Invoke(RTC_FROM_HERE, [&] { DestroyVoiceChannel(voice_channel); }); return; } - RTC_DCHECK(initialized_); - - auto it = absl::c_find_if(voice_channels_, - [&](const std::unique_ptr& p) { - return p.get() == voice_channel; - }); - RTC_DCHECK(it != voice_channels_.end()); - if (it == voice_channels_.end()) { - return; - } + RTC_DCHECK_RUN_ON(worker_thread_); - voice_channels_.erase(it); + voice_channels_.erase(absl::c_find_if( + voice_channels_, [&](const std::unique_ptr& p) { + return p.get() == voice_channel; + })); } VideoChannel* ChannelManager::CreateVideoChannel( webrtc::Call* call, - const cricket::MediaConfig& media_config, + const MediaConfig& media_config, webrtc::RtpTransportInternal* rtp_transport, rtc::Thread* signaling_thread, const std::string& content_name, @@ -265,6 +210,8 @@ VideoChannel* ChannelManager::CreateVideoChannel( rtc::UniqueRandomIdGenerator* ssrc_generator, const VideoOptions& options, webrtc::VideoBitrateAllocatorFactory* video_bitrate_allocator_factory) { + RTC_DCHECK(call); + RTC_DCHECK(media_engine_); // TODO(bugs.webrtc.org/11992): Remove this workaround after updates in // PeerConnection and add the expectation that we're already on the right // thread. 
@@ -278,11 +225,6 @@ VideoChannel* ChannelManager::CreateVideoChannel( } RTC_DCHECK_RUN_ON(worker_thread_); - RTC_DCHECK(initialized_); - RTC_DCHECK(call); - if (!media_engine_) { - return nullptr; - } VideoMediaChannel* media_channel = media_engine_->video().CreateMediaChannel( call, media_config, options, crypto_options, @@ -305,101 +247,30 @@ VideoChannel* ChannelManager::CreateVideoChannel( void ChannelManager::DestroyVideoChannel(VideoChannel* video_channel) { TRACE_EVENT0("webrtc", "ChannelManager::DestroyVideoChannel"); - if (!video_channel) { - return; - } + RTC_DCHECK(video_channel); + if (!worker_thread_->IsCurrent()) { worker_thread_->Invoke(RTC_FROM_HERE, [&] { DestroyVideoChannel(video_channel); }); return; } + RTC_DCHECK_RUN_ON(worker_thread_); - RTC_DCHECK(initialized_); - - auto it = absl::c_find_if(video_channels_, - [&](const std::unique_ptr& p) { - return p.get() == video_channel; - }); - RTC_DCHECK(it != video_channels_.end()); - if (it == video_channels_.end()) { - return; - } - - video_channels_.erase(it); -} - -RtpDataChannel* ChannelManager::CreateRtpDataChannel( - const cricket::MediaConfig& media_config, - webrtc::RtpTransportInternal* rtp_transport, - rtc::Thread* signaling_thread, - const std::string& content_name, - bool srtp_required, - const webrtc::CryptoOptions& crypto_options, - rtc::UniqueRandomIdGenerator* ssrc_generator) { - if (!worker_thread_->IsCurrent()) { - return worker_thread_->Invoke(RTC_FROM_HERE, [&] { - return CreateRtpDataChannel(media_config, rtp_transport, signaling_thread, - content_name, srtp_required, crypto_options, - ssrc_generator); - }); - } - - // This is ok to alloc from a thread other than the worker thread. - RTC_DCHECK(initialized_); - DataMediaChannel* media_channel = data_engine_->CreateChannel(media_config); - if (!media_channel) { - RTC_LOG(LS_WARNING) << "Failed to create RTP data channel."; - return nullptr; - } - - auto data_channel = std::make_unique( - worker_thread_, network_thread_, signaling_thread, - absl::WrapUnique(media_channel), content_name, srtp_required, - crypto_options, ssrc_generator); - - // Media Transports are not supported with Rtp Data Channel. 
- data_channel->Init_w(rtp_transport); - - RtpDataChannel* data_channel_ptr = data_channel.get(); - data_channels_.push_back(std::move(data_channel)); - return data_channel_ptr; -} - -void ChannelManager::DestroyRtpDataChannel(RtpDataChannel* data_channel) { - TRACE_EVENT0("webrtc", "ChannelManager::DestroyRtpDataChannel"); - if (!data_channel) { - return; - } - if (!worker_thread_->IsCurrent()) { - worker_thread_->Invoke( - RTC_FROM_HERE, [&] { return DestroyRtpDataChannel(data_channel); }); - return; - } - - RTC_DCHECK(initialized_); - - auto it = absl::c_find_if(data_channels_, - [&](const std::unique_ptr& p) { - return p.get() == data_channel; - }); - RTC_DCHECK(it != data_channels_.end()); - if (it == data_channels_.end()) { - return; - } - - data_channels_.erase(it); + video_channels_.erase(absl::c_find_if( + video_channels_, [&](const std::unique_ptr& p) { + return p.get() == video_channel; + })); } bool ChannelManager::StartAecDump(webrtc::FileWrapper file, int64_t max_size_bytes) { - return worker_thread_->Invoke(RTC_FROM_HERE, [&] { - return media_engine_->voice().StartAecDump(std::move(file), max_size_bytes); - }); + RTC_DCHECK_RUN_ON(worker_thread_); + return media_engine_->voice().StartAecDump(std::move(file), max_size_bytes); } void ChannelManager::StopAecDump() { - worker_thread_->Invoke(RTC_FROM_HERE, - [&] { media_engine_->voice().StopAecDump(); }); + RTC_DCHECK_RUN_ON(worker_thread_); + media_engine_->voice().StopAecDump(); } } // namespace cricket diff --git a/pc/channel_manager.h b/pc/channel_manager.h index ba2c260099..43fa27935f 100644 --- a/pc/channel_manager.h +++ b/pc/channel_manager.h @@ -19,6 +19,8 @@ #include "api/audio_options.h" #include "api/crypto/crypto_options.h" +#include "api/rtp_parameters.h" +#include "api/video/video_bitrate_allocator_factory.h" #include "call/call.h" #include "media/base/codec.h" #include "media/base/media_channel.h" @@ -29,6 +31,7 @@ #include "pc/session_description.h" #include "rtc_base/system/file_wrapper.h" #include "rtc_base/thread.h" +#include "rtc_base/unique_id_generator.h" namespace cricket { @@ -42,32 +45,20 @@ namespace cricket { // using device manager. class ChannelManager final { public: - // Construct a ChannelManager with the specified media engine and data engine. - ChannelManager(std::unique_ptr media_engine, - std::unique_ptr data_engine, - rtc::Thread* worker_thread, - rtc::Thread* network_thread); + // Returns an initialized instance of ChannelManager. + // If media_engine is non-nullptr, then the returned ChannelManager instance + // will own that reference and media engine initialization + static std::unique_ptr Create( + std::unique_ptr media_engine, + bool enable_rtx, + rtc::Thread* worker_thread, + rtc::Thread* network_thread); + + ChannelManager() = delete; ~ChannelManager(); - // Accessors for the worker thread, allowing it to be set after construction, - // but before Init. set_worker_thread will return false if called after Init. rtc::Thread* worker_thread() const { return worker_thread_; } - bool set_worker_thread(rtc::Thread* thread) { - if (initialized_) { - return false; - } - worker_thread_ = thread; - return true; - } rtc::Thread* network_thread() const { return network_thread_; } - bool set_network_thread(rtc::Thread* thread) { - if (initialized_) { - return false; - } - network_thread_ = thread; - return true; - } - MediaEngineInterface* media_engine() { return media_engine_.get(); } // Retrieves the list of supported audio & video codec types. 
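StartAecDump() and StopAecDump() no longer hop threads internally; they DCHECK the worker thread, so the caller does the hop. A sketch, assuming a caller-supplied file path (the helper name and path handling are illustrative, not part of the patch):

    #include <string>

    #include "pc/channel_manager.h"
    #include "rtc_base/location.h"
    #include "rtc_base/system/file_wrapper.h"
    #include "rtc_base/thread.h"

    bool StartAecDumpOnWorker(cricket::ChannelManager* channel_manager,
                              rtc::Thread* worker_thread,
                              const std::string& path) {
      return worker_thread->Invoke<bool>(RTC_FROM_HERE, [&] {
        return channel_manager->StartAecDump(
            webrtc::FileWrapper::OpenWriteOnly(path.c_str()),
            /*max_size_bytes=*/-1);  // <= 0 means no size limit.
      });
    }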
@@ -76,7 +67,6 @@ class ChannelManager final { void GetSupportedAudioReceiveCodecs(std::vector* codecs) const; void GetSupportedVideoSendCodecs(std::vector* codecs) const; void GetSupportedVideoReceiveCodecs(std::vector* codecs) const; - void GetSupportedDataCodecs(std::vector* codecs) const; RtpHeaderExtensions GetDefaultEnabledAudioRtpHeaderExtensions() const; std::vector GetSupportedAudioRtpHeaderExtensions() const; @@ -84,20 +74,13 @@ class ChannelManager final { std::vector GetSupportedVideoRtpHeaderExtensions() const; - // Indicates whether the media engine is started. - bool initialized() const { return initialized_; } - // Starts up the media engine. - bool Init(); - // Shuts down the media engine. - void Terminate(); - // The operations below all occur on the worker thread. // ChannelManager retains ownership of the created channels, so clients should // call the appropriate Destroy*Channel method when done. // Creates a voice channel, to be associated with the specified session. VoiceChannel* CreateVoiceChannel(webrtc::Call* call, - const cricket::MediaConfig& media_config, + const MediaConfig& media_config, webrtc::RtpTransportInternal* rtp_transport, rtc::Thread* signaling_thread, const std::string& content_name, @@ -113,7 +96,7 @@ class ChannelManager final { // Version of the above that takes PacketTransportInternal. VideoChannel* CreateVideoChannel( webrtc::Call* call, - const cricket::MediaConfig& media_config, + const MediaConfig& media_config, webrtc::RtpTransportInternal* rtp_transport, rtc::Thread* signaling_thread, const std::string& content_name, @@ -125,32 +108,6 @@ class ChannelManager final { // Destroys a video channel created by CreateVideoChannel. void DestroyVideoChannel(VideoChannel* video_channel); - RtpDataChannel* CreateRtpDataChannel( - const cricket::MediaConfig& media_config, - webrtc::RtpTransportInternal* rtp_transport, - rtc::Thread* signaling_thread, - const std::string& content_name, - bool srtp_required, - const webrtc::CryptoOptions& crypto_options, - rtc::UniqueRandomIdGenerator* ssrc_generator); - // Destroys a data channel created by CreateRtpDataChannel. - void DestroyRtpDataChannel(RtpDataChannel* data_channel); - - // Indicates whether any channels exist. - bool has_channels() const { - return (!voice_channels_.empty() || !video_channels_.empty() || - !data_channels_.empty()); - } - - // RTX will be enabled/disabled in engines that support it. The supporting - // engines will start offering an RTX codec. Must be called before Init(). - bool SetVideoRtxEnabled(bool enable); - - // Starts/stops the local microphone and enables polling of the input level. - bool capturing() const { return capturing_; } - - // The operations below occur on the main thread. - // Starts AEC dump using existing file, with a specified maximum file size in // bytes. When the limit is reached, logging will stop and the file will be // closed. If max_size_bytes is set to <= 0, no limit will be used. @@ -160,20 +117,22 @@ class ChannelManager final { void StopAecDump(); private: - std::unique_ptr media_engine_; // Nullable. - std::unique_ptr data_engine_; // Non-null. - bool initialized_ = false; - rtc::Thread* main_thread_; - rtc::Thread* worker_thread_; - rtc::Thread* network_thread_; + ChannelManager(std::unique_ptr media_engine, + bool enable_rtx, + rtc::Thread* worker_thread, + rtc::Thread* network_thread); + + const std::unique_ptr media_engine_; // Nullable. 
+ rtc::Thread* const worker_thread_; + rtc::Thread* const network_thread_; // Vector contents are non-null. - std::vector> voice_channels_; - std::vector> video_channels_; - std::vector> data_channels_; + std::vector> voice_channels_ + RTC_GUARDED_BY(worker_thread_); + std::vector> video_channels_ + RTC_GUARDED_BY(worker_thread_); - bool enable_rtx_ = false; - bool capturing_ = false; + const bool enable_rtx_; }; } // namespace cricket diff --git a/pc/channel_manager_unittest.cc b/pc/channel_manager_unittest.cc index 610d7979ab..88de1f6a48 100644 --- a/pc/channel_manager_unittest.cc +++ b/pc/channel_manager_unittest.cc @@ -26,11 +26,9 @@ #include "rtc_base/thread.h" #include "test/gtest.h" +namespace cricket { namespace { const bool kDefaultSrtpRequired = true; -} - -namespace cricket { static const AudioCodec kAudioCodecs[] = { AudioCodec(97, "voice", 1, 2, 3), @@ -43,36 +41,33 @@ static const VideoCodec kVideoCodecs[] = { VideoCodec(96, "rtx"), }; +std::unique_ptr CreateFakeMediaEngine() { + auto fme = std::make_unique(); + fme->SetAudioCodecs(MAKE_VECTOR(kAudioCodecs)); + fme->SetVideoCodecs(MAKE_VECTOR(kVideoCodecs)); + return fme; +} + +} // namespace + class ChannelManagerTest : public ::testing::Test { protected: ChannelManagerTest() : network_(rtc::Thread::CreateWithSocketServer()), - worker_(rtc::Thread::Create()), + worker_(rtc::Thread::Current()), video_bitrate_allocator_factory_( webrtc::CreateBuiltinVideoBitrateAllocatorFactory()), - fme_(new cricket::FakeMediaEngine()), - fdme_(new cricket::FakeDataEngine()), - cm_(new cricket::ChannelManager( - std::unique_ptr(fme_), - std::unique_ptr(fdme_), - rtc::Thread::Current(), - rtc::Thread::Current())), - fake_call_() { - fme_->SetAudioCodecs(MAKE_VECTOR(kAudioCodecs)); - fme_->SetVideoCodecs(MAKE_VECTOR(kVideoCodecs)); - } - - std::unique_ptr CreateDtlsSrtpTransport() { - rtp_dtls_transport_ = std::make_unique( - "fake_dtls_transport", cricket::ICE_CANDIDATE_COMPONENT_RTP); - auto dtls_srtp_transport = std::make_unique( - /*rtcp_mux_required=*/true); - dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport_.get(), - /*rtcp_dtls_transport=*/nullptr); - return dtls_srtp_transport; + cm_(cricket::ChannelManager::Create(CreateFakeMediaEngine(), + false, + worker_, + network_.get())), + fake_call_(worker_, network_.get()) { + network_->SetName("Network", this); + network_->Start(); } void TestCreateDestroyChannels(webrtc::RtpTransportInternal* rtp_transport) { + RTC_DCHECK_RUN_ON(worker_); cricket::VoiceChannel* voice_channel = cm_->CreateVoiceChannel( &fake_call_, cricket::MediaConfig(), rtp_transport, rtc::Thread::Current(), cricket::CN_AUDIO, kDefaultSrtpRequired, @@ -84,59 +79,19 @@ class ChannelManagerTest : public ::testing::Test { webrtc::CryptoOptions(), &ssrc_generator_, VideoOptions(), video_bitrate_allocator_factory_.get()); EXPECT_TRUE(video_channel != nullptr); - cricket::RtpDataChannel* rtp_data_channel = cm_->CreateRtpDataChannel( - cricket::MediaConfig(), rtp_transport, rtc::Thread::Current(), - cricket::CN_DATA, kDefaultSrtpRequired, webrtc::CryptoOptions(), - &ssrc_generator_); - EXPECT_TRUE(rtp_data_channel != nullptr); cm_->DestroyVideoChannel(video_channel); cm_->DestroyVoiceChannel(voice_channel); - cm_->DestroyRtpDataChannel(rtp_data_channel); - cm_->Terminate(); } - std::unique_ptr rtp_dtls_transport_; std::unique_ptr network_; - std::unique_ptr worker_; + rtc::Thread* const worker_; std::unique_ptr video_bitrate_allocator_factory_; - // |fme_| and |fdme_| are actually owned by |cm_|. 
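The ChannelManager members above follow the same discipline used throughout this change: const thread pointers plus RTC_GUARDED_BY on the data, with RTC_DCHECK_RUN_ON at each access so the clang thread-safety analysis can verify the contract. A minimal, self-contained illustration of the pattern (not WebRTC production code; the class is invented for the example):

    #include <vector>

    #include "api/sequence_checker.h"
    #include "rtc_base/thread.h"
    #include "rtc_base/thread_annotations.h"

    class GuardedList {
     public:
      explicit GuardedList(rtc::Thread* worker_thread)
          : worker_thread_(worker_thread) {}

      void Add(int value) {
        RTC_DCHECK_RUN_ON(worker_thread_);  // Asserts and satisfies the analysis.
        values_.push_back(value);
      }

     private:
      rtc::Thread* const worker_thread_;
      std::vector<int> values_ RTC_GUARDED_BY(worker_thread_);
    };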
- cricket::FakeMediaEngine* fme_; - cricket::FakeDataEngine* fdme_; std::unique_ptr cm_; cricket::FakeCall fake_call_; rtc::UniqueRandomIdGenerator ssrc_generator_; }; -// Test that we startup/shutdown properly. -TEST_F(ChannelManagerTest, StartupShutdown) { - EXPECT_FALSE(cm_->initialized()); - EXPECT_EQ(rtc::Thread::Current(), cm_->worker_thread()); - EXPECT_TRUE(cm_->Init()); - EXPECT_TRUE(cm_->initialized()); - cm_->Terminate(); - EXPECT_FALSE(cm_->initialized()); -} - -// Test that we startup/shutdown properly with a worker thread. -TEST_F(ChannelManagerTest, StartupShutdownOnThread) { - network_->Start(); - worker_->Start(); - EXPECT_FALSE(cm_->initialized()); - EXPECT_EQ(rtc::Thread::Current(), cm_->worker_thread()); - EXPECT_TRUE(cm_->set_network_thread(network_.get())); - EXPECT_EQ(network_.get(), cm_->network_thread()); - EXPECT_TRUE(cm_->set_worker_thread(worker_.get())); - EXPECT_EQ(worker_.get(), cm_->worker_thread()); - EXPECT_TRUE(cm_->Init()); - EXPECT_TRUE(cm_->initialized()); - // Setting the network or worker thread while initialized should fail. - EXPECT_FALSE(cm_->set_network_thread(rtc::Thread::Current())); - EXPECT_FALSE(cm_->set_worker_thread(rtc::Thread::Current())); - cm_->Terminate(); - EXPECT_FALSE(cm_->initialized()); -} - TEST_F(ChannelManagerTest, SetVideoRtxEnabled) { std::vector send_codecs; std::vector recv_codecs; @@ -149,47 +104,34 @@ TEST_F(ChannelManagerTest, SetVideoRtxEnabled) { EXPECT_FALSE(ContainsMatchingCodec(recv_codecs, rtx_codec)); // Enable and check. - EXPECT_TRUE(cm_->SetVideoRtxEnabled(true)); + cm_ = cricket::ChannelManager::Create(CreateFakeMediaEngine(), + true, worker_, network_.get()); cm_->GetSupportedVideoSendCodecs(&send_codecs); EXPECT_TRUE(ContainsMatchingCodec(send_codecs, rtx_codec)); cm_->GetSupportedVideoSendCodecs(&recv_codecs); EXPECT_TRUE(ContainsMatchingCodec(recv_codecs, rtx_codec)); // Disable and check. - EXPECT_TRUE(cm_->SetVideoRtxEnabled(false)); + cm_ = cricket::ChannelManager::Create(CreateFakeMediaEngine(), + false, worker_, network_.get()); cm_->GetSupportedVideoSendCodecs(&send_codecs); EXPECT_FALSE(ContainsMatchingCodec(send_codecs, rtx_codec)); cm_->GetSupportedVideoSendCodecs(&recv_codecs); EXPECT_FALSE(ContainsMatchingCodec(recv_codecs, rtx_codec)); - - // Cannot toggle rtx after initialization. - EXPECT_TRUE(cm_->Init()); - EXPECT_FALSE(cm_->SetVideoRtxEnabled(true)); - EXPECT_FALSE(cm_->SetVideoRtxEnabled(false)); - - // Can set again after terminate. 
- cm_->Terminate(); - EXPECT_TRUE(cm_->SetVideoRtxEnabled(true)); - cm_->GetSupportedVideoSendCodecs(&send_codecs); - EXPECT_TRUE(ContainsMatchingCodec(send_codecs, rtx_codec)); - cm_->GetSupportedVideoSendCodecs(&recv_codecs); - EXPECT_TRUE(ContainsMatchingCodec(recv_codecs, rtx_codec)); } TEST_F(ChannelManagerTest, CreateDestroyChannels) { - EXPECT_TRUE(cm_->Init()); - auto rtp_transport = CreateDtlsSrtpTransport(); - TestCreateDestroyChannels(rtp_transport.get()); -} - -TEST_F(ChannelManagerTest, CreateDestroyChannelsOnThread) { - network_->Start(); - worker_->Start(); - EXPECT_TRUE(cm_->set_worker_thread(worker_.get())); - EXPECT_TRUE(cm_->set_network_thread(network_.get())); - EXPECT_TRUE(cm_->Init()); - auto rtp_transport = CreateDtlsSrtpTransport(); - TestCreateDestroyChannels(rtp_transport.get()); + auto rtp_dtls_transport = std::make_unique( + "fake_dtls_transport", cricket::ICE_CANDIDATE_COMPONENT_RTP, + network_.get()); + auto dtls_srtp_transport = std::make_unique( + /*rtcp_mux_required=*/true); + network_->Invoke( + RTC_FROM_HERE, [&rtp_dtls_transport, &dtls_srtp_transport] { + dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport.get(), + /*rtcp_dtls_transport=*/nullptr); + }); + TestCreateDestroyChannels(dtls_srtp_transport.get()); } } // namespace cricket diff --git a/pc/channel_unittest.cc b/pc/channel_unittest.cc index fb62b08df5..581f6de7ac 100644 --- a/pc/channel_unittest.cc +++ b/pc/channel_unittest.cc @@ -35,6 +35,8 @@ #include "rtc_base/checks.h" #include "rtc_base/rtc_certificate.h" #include "rtc_base/ssl_identity.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "test/gmock.h" #include "test/gtest.h" @@ -52,7 +54,6 @@ const cricket::AudioCodec kPcmaCodec(8, "PCMA", 64000, 8000, 1); const cricket::AudioCodec kIsacCodec(103, "ISAC", 40000, 16000, 1); const cricket::VideoCodec kH264Codec(97, "H264"); const cricket::VideoCodec kH264SvcCodec(99, "H264-SVC"); -const cricket::DataCodec kGoogleDataCodec(101, "google-data"); const uint32_t kSsrc1 = 0x1111; const uint32_t kSsrc2 = 0x2222; const uint32_t kSsrc3 = 0x3333; @@ -93,14 +94,7 @@ class VideoTraits : public Traits {}; -class DataTraits : public Traits {}; - -// Base class for Voice/Video/RtpDataChannel tests +// Base class for Voice/Video tests template class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { public: @@ -127,19 +121,30 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { network_thread_keeper_->SetName("Network", nullptr); network_thread_ = network_thread_keeper_.get(); } + RTC_DCHECK(network_thread_); + } + + ~ChannelTest() { + if (network_thread_) { + network_thread_->Invoke( + RTC_FROM_HERE, [this]() { network_thread_safety_->SetNotAlive(); }); + } } void CreateChannels(int flags1, int flags2) { CreateChannels(std::make_unique( - nullptr, typename T::Options()), + nullptr, typename T::Options(), network_thread_), std::make_unique( - nullptr, typename T::Options()), + nullptr, typename T::Options(), network_thread_), flags1, flags2); } void CreateChannels(std::unique_ptr ch1, std::unique_ptr ch2, int flags1, int flags2) { + RTC_DCHECK(!channel1_); + RTC_DCHECK(!channel2_); + // Network thread is started in CreateChannels, to allow the test to // configure a fake clock before any threads are spawned and attempt to // access the time. @@ -151,8 +156,6 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { // channels. 
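The rewritten CreateDestroyChannels test tells the fake DTLS transport which network thread it lives on and performs SetDtlsTransports() from that thread. A stand-alone sketch of the same setup (the function and variable names are illustrative):

    #include <memory>

    #include "p2p/base/fake_dtls_transport.h"
    #include "p2p/base/p2p_constants.h"
    #include "pc/dtls_srtp_transport.h"
    #include "rtc_base/location.h"
    #include "rtc_base/thread.h"

    void WireFakeTransports(rtc::Thread* network_thread) {
      auto rtp_dtls_transport = std::make_unique<cricket::FakeDtlsTransport>(
          "fake_dtls_transport", cricket::ICE_CANDIDATE_COMPONENT_RTP,
          network_thread);
      auto dtls_srtp_transport = std::make_unique<webrtc::DtlsSrtpTransport>(
          /*rtcp_mux_required=*/true);
      network_thread->Invoke<void>(RTC_FROM_HERE, [&] {
        dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport.get(),
                                               /*rtcp_dtls_transport=*/nullptr);
      });
      // ... hand dtls_srtp_transport to the channel or test under construction ...
    }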
RTC_DCHECK_EQ(flags1 & RAW_PACKET_TRANSPORT, flags2 & RAW_PACKET_TRANSPORT); rtc::Thread* worker_thread = rtc::Thread::Current(); - media_channel1_ = ch1.get(); - media_channel2_ = ch2.get(); rtc::PacketTransportInternal* rtp1 = nullptr; rtc::PacketTransportInternal* rtcp1 = nullptr; rtc::PacketTransportInternal* rtp2 = nullptr; @@ -170,11 +173,12 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { } else { // Confirmed to work with KT_RSA and KT_ECDSA. fake_rtp_dtls_transport1_.reset(new cricket::FakeDtlsTransport( - "channel1", cricket::ICE_CANDIDATE_COMPONENT_RTP)); + "channel1", cricket::ICE_CANDIDATE_COMPONENT_RTP, network_thread_)); rtp1 = fake_rtp_dtls_transport1_.get(); if (!(flags1 & RTCP_MUX)) { fake_rtcp_dtls_transport1_.reset(new cricket::FakeDtlsTransport( - "channel1", cricket::ICE_CANDIDATE_COMPONENT_RTCP)); + "channel1", cricket::ICE_CANDIDATE_COMPONENT_RTCP, + network_thread_)); rtcp1 = fake_rtcp_dtls_transport1_.get(); } if (flags1 & DTLS) { @@ -199,11 +203,12 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { } else { // Confirmed to work with KT_RSA and KT_ECDSA. fake_rtp_dtls_transport2_.reset(new cricket::FakeDtlsTransport( - "channel2", cricket::ICE_CANDIDATE_COMPONENT_RTP)); + "channel2", cricket::ICE_CANDIDATE_COMPONENT_RTP, network_thread_)); rtp2 = fake_rtp_dtls_transport2_.get(); if (!(flags2 & RTCP_MUX)) { fake_rtcp_dtls_transport2_.reset(new cricket::FakeDtlsTransport( - "channel2", cricket::ICE_CANDIDATE_COMPONENT_RTCP)); + "channel2", cricket::ICE_CANDIDATE_COMPONENT_RTCP, + network_thread_)); rtcp2 = fake_rtcp_dtls_transport2_.get(); } if (flags2 & DTLS) { @@ -284,10 +289,14 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { auto rtp_transport = std::make_unique( rtcp_packet_transport == nullptr); - rtp_transport->SetRtpPacketTransport(rtp_packet_transport); - if (rtcp_packet_transport) { - rtp_transport->SetRtcpPacketTransport(rtcp_packet_transport); - } + network_thread_->Invoke( + RTC_FROM_HERE, + [&rtp_transport, rtp_packet_transport, rtcp_packet_transport] { + rtp_transport->SetRtpPacketTransport(rtp_packet_transport); + if (rtcp_packet_transport) { + rtp_transport->SetRtcpPacketTransport(rtcp_packet_transport); + } + }); return rtp_transport; } @@ -297,8 +306,12 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { auto dtls_srtp_transport = std::make_unique( rtcp_dtls_transport == nullptr); - dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, - rtcp_dtls_transport); + network_thread_->Invoke( + RTC_FROM_HERE, + [&dtls_srtp_transport, rtp_dtls_transport, rtcp_dtls_transport] { + dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, + rtcp_dtls_transport); + }); return dtls_srtp_transport; } @@ -331,18 +344,16 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { bool SendInitiate() { bool result = channel1_->SetLocalContent(&local_media_content1_, - SdpType::kOffer, NULL) && - channel1_->UpdateRtpTransport(nullptr); + SdpType::kOffer, NULL); if (result) { channel1_->Enable(true); + FlushCurrentThread(); result = channel2_->SetRemoteContent(&remote_media_content1_, - SdpType::kOffer, NULL) && - channel2_->UpdateRtpTransport(nullptr); + SdpType::kOffer, NULL); if (result) { ConnectFakeTransports(); result = channel2_->SetLocalContent(&local_media_content2_, - SdpType::kAnswer, NULL) && - channel2_->UpdateRtpTransport(nullptr); + SdpType::kAnswer, NULL); } } return result; @@ -350,33 +361,29 @@ class ChannelTest : public 
::testing::Test, public sigslot::has_slots<> { bool SendAccept() { channel2_->Enable(true); + FlushCurrentThread(); return channel1_->SetRemoteContent(&remote_media_content2_, - SdpType::kAnswer, NULL) && - channel1_->UpdateRtpTransport(nullptr); + SdpType::kAnswer, NULL); } bool SendOffer() { bool result = channel1_->SetLocalContent(&local_media_content1_, - SdpType::kOffer, NULL) && - channel1_->UpdateRtpTransport(nullptr); + SdpType::kOffer, NULL); if (result) { channel1_->Enable(true); result = channel2_->SetRemoteContent(&remote_media_content1_, - SdpType::kOffer, NULL) && - channel2_->UpdateRtpTransport(nullptr); + SdpType::kOffer, NULL); } return result; } bool SendProvisionalAnswer() { bool result = channel2_->SetLocalContent(&local_media_content2_, - SdpType::kPrAnswer, NULL) && - channel2_->UpdateRtpTransport(nullptr); + SdpType::kPrAnswer, NULL); if (result) { channel2_->Enable(true); result = channel1_->SetRemoteContent(&remote_media_content2_, - SdpType::kPrAnswer, NULL) && - channel1_->UpdateRtpTransport(nullptr); + SdpType::kPrAnswer, NULL); ConnectFakeTransports(); } return result; @@ -384,64 +391,59 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { bool SendFinalAnswer() { bool result = channel2_->SetLocalContent(&local_media_content2_, - SdpType::kAnswer, NULL) && - channel2_->UpdateRtpTransport(nullptr); + SdpType::kAnswer, NULL); if (result) result = channel1_->SetRemoteContent(&remote_media_content2_, - SdpType::kAnswer, NULL) && - channel1_->UpdateRtpTransport(nullptr); + SdpType::kAnswer, NULL); return result; } - bool Terminate() { - channel1_.reset(); - channel2_.reset(); - fake_rtp_dtls_transport1_.reset(); - fake_rtcp_dtls_transport1_.reset(); - fake_rtp_dtls_transport2_.reset(); - fake_rtcp_dtls_transport2_.reset(); - fake_rtp_packet_transport1_.reset(); - fake_rtcp_packet_transport1_.reset(); - fake_rtp_packet_transport2_.reset(); - fake_rtcp_packet_transport2_.reset(); - if (network_thread_keeper_) { - network_thread_keeper_.reset(); - } - return true; + void SendRtp(typename T::MediaChannel* media_channel, rtc::Buffer data) { + network_thread_->PostTask(webrtc::ToQueuedTask( + network_thread_safety_, [media_channel, data = std::move(data)]() { + media_channel->SendRtp(data.data(), data.size(), + rtc::PacketOptions()); + })); } void SendRtp1() { - media_channel1_->SendRtp(rtp_packet_.data(), rtp_packet_.size(), - rtc::PacketOptions()); + SendRtp1(rtc::Buffer(rtp_packet_.data(), rtp_packet_.size())); + } + + void SendRtp1(rtc::Buffer data) { + SendRtp(media_channel1(), std::move(data)); } + void SendRtp2() { - media_channel2_->SendRtp(rtp_packet_.data(), rtp_packet_.size(), - rtc::PacketOptions()); + SendRtp2(rtc::Buffer(rtp_packet_.data(), rtp_packet_.size())); + } + + void SendRtp2(rtc::Buffer data) { + SendRtp(media_channel2(), std::move(data)); } + // Methods to send custom data. 
void SendCustomRtp1(uint32_t ssrc, int sequence_number, int pl_type = -1) { - rtc::Buffer data = CreateRtpData(ssrc, sequence_number, pl_type); - media_channel1_->SendRtp(data.data(), data.size(), rtc::PacketOptions()); + SendRtp1(CreateRtpData(ssrc, sequence_number, pl_type)); } void SendCustomRtp2(uint32_t ssrc, int sequence_number, int pl_type = -1) { - rtc::Buffer data = CreateRtpData(ssrc, sequence_number, pl_type); - media_channel2_->SendRtp(data.data(), data.size(), rtc::PacketOptions()); + SendRtp2(CreateRtpData(ssrc, sequence_number, pl_type)); } bool CheckRtp1() { - return media_channel1_->CheckRtp(rtp_packet_.data(), rtp_packet_.size()); + return media_channel1()->CheckRtp(rtp_packet_.data(), rtp_packet_.size()); } bool CheckRtp2() { - return media_channel2_->CheckRtp(rtp_packet_.data(), rtp_packet_.size()); + return media_channel2()->CheckRtp(rtp_packet_.data(), rtp_packet_.size()); } // Methods to check custom data. bool CheckCustomRtp1(uint32_t ssrc, int sequence_number, int pl_type = -1) { rtc::Buffer data = CreateRtpData(ssrc, sequence_number, pl_type); - return media_channel1_->CheckRtp(data.data(), data.size()); + return media_channel1()->CheckRtp(data.data(), data.size()); } bool CheckCustomRtp2(uint32_t ssrc, int sequence_number, int pl_type = -1) { rtc::Buffer data = CreateRtpData(ssrc, sequence_number, pl_type); - return media_channel2_->CheckRtp(data.data(), data.size()); + return media_channel2()->CheckRtp(data.data(), data.size()); } rtc::Buffer CreateRtpData(uint32_t ssrc, int sequence_number, int pl_type) { rtc::Buffer data(rtp_packet_.data(), rtp_packet_.size()); @@ -454,8 +456,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { return data; } - bool CheckNoRtp1() { return media_channel1_->CheckNoRtp(); } - bool CheckNoRtp2() { return media_channel2_->CheckNoRtp(); } + bool CheckNoRtp1() { return media_channel1()->CheckNoRtp(); } + bool CheckNoRtp2() { return media_channel2()->CheckNoRtp(); } void CreateContent(int flags, const cricket::AudioCodec& audio_codec, @@ -510,19 +512,38 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { // Base implementation. } + // Utility method that calls BaseChannel::srtp_active() on the network thread + // and returns the result. The |srtp_active()| state is maintained on the + // network thread, which callers need to factor in. + bool IsSrtpActive(std::unique_ptr& channel) { + RTC_DCHECK(channel.get()); + return network_thread_->Invoke( + RTC_FROM_HERE, [&] { return channel->srtp_active(); }); + } + + // Returns true iff the transport is set for a channel and rtcp_mux_enabled() + // returns true. + bool IsRtcpMuxEnabled(std::unique_ptr& channel) { + RTC_DCHECK(channel.get()); + return network_thread_->Invoke(RTC_FROM_HERE, [&] { + return channel->rtp_transport() && + channel->rtp_transport()->rtcp_mux_enabled(); + }); + } + // Tests that can be used by derived classes. // Basic sanity check. 
void TestInit() { CreateChannels(0, 0); - EXPECT_FALSE(channel1_->SrtpActiveForTesting()); - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(IsSrtpActive(channel1_)); + EXPECT_FALSE(media_channel1()->sending()); if (verify_playout_) { - EXPECT_FALSE(media_channel1_->playout()); + EXPECT_FALSE(media_channel1()->playout()); } - EXPECT_TRUE(media_channel1_->codecs().empty()); - EXPECT_TRUE(media_channel1_->recv_streams().empty()); - EXPECT_TRUE(media_channel1_->rtp_packets().empty()); + EXPECT_TRUE(media_channel1()->codecs().empty()); + EXPECT_TRUE(media_channel1()->recv_streams().empty()); + EXPECT_TRUE(media_channel1()->rtp_packets().empty()); } // Test that SetLocalContent and SetRemoteContent properly configure @@ -532,11 +553,11 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { typename T::Content content; CreateContent(0, kPcmuCodec, kH264Codec, &content); EXPECT_TRUE(channel1_->SetLocalContent(&content, SdpType::kOffer, NULL)); - EXPECT_EQ(0U, media_channel1_->codecs().size()); + EXPECT_EQ(0U, media_channel1()->codecs().size()); EXPECT_TRUE(channel1_->SetRemoteContent(&content, SdpType::kAnswer, NULL)); - ASSERT_EQ(1U, media_channel1_->codecs().size()); + ASSERT_EQ(1U, media_channel1()->codecs().size()); EXPECT_TRUE( - CodecMatches(content.codecs()[0], media_channel1_->codecs()[0])); + CodecMatches(content.codecs()[0], media_channel1()->codecs()[0])); } // Test that SetLocalContent and SetRemoteContent properly configure @@ -553,7 +574,7 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(channel1_->SetLocalContent(&content, SdpType::kOffer, NULL)); content.set_extmap_allow_mixed_enum(answer_enum); EXPECT_TRUE(channel1_->SetRemoteContent(&content, SdpType::kAnswer, NULL)); - EXPECT_EQ(answer, media_channel1_->ExtmapAllowMixed()); + EXPECT_EQ(answer, media_channel1()->ExtmapAllowMixed()); } void TestSetContentsExtmapAllowMixedCallee(bool offer, bool answer) { // For a callee, SetRemoteContent() is called first with an offer and next @@ -567,7 +588,7 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(channel1_->SetRemoteContent(&content, SdpType::kOffer, NULL)); content.set_extmap_allow_mixed_enum(answer_enum); EXPECT_TRUE(channel1_->SetLocalContent(&content, SdpType::kAnswer, NULL)); - EXPECT_EQ(answer, media_channel1_->ExtmapAllowMixed()); + EXPECT_EQ(answer, media_channel1()->ExtmapAllowMixed()); } // Test that SetLocalContent and SetRemoteContent properly deals @@ -577,11 +598,11 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { typename T::Content content; EXPECT_TRUE(channel1_->SetLocalContent(&content, SdpType::kOffer, NULL)); CreateContent(0, kPcmuCodec, kH264Codec, &content); - EXPECT_EQ(0U, media_channel1_->codecs().size()); + EXPECT_EQ(0U, media_channel1()->codecs().size()); EXPECT_TRUE(channel1_->SetRemoteContent(&content, SdpType::kAnswer, NULL)); - ASSERT_EQ(1U, media_channel1_->codecs().size()); + ASSERT_EQ(1U, media_channel1()->codecs().size()); EXPECT_TRUE( - CodecMatches(content.codecs()[0], media_channel1_->codecs()[0])); + CodecMatches(content.codecs()[0], media_channel1()->codecs()[0])); } // Test that SetLocalContent and SetRemoteContent properly set RTCP @@ -622,25 +643,21 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateContent(0, kPcmuCodec, kH264Codec, &content1); content1.AddStream(stream1); EXPECT_TRUE(channel1_->SetLocalContent(&content1, SdpType::kOffer, NULL)); - 
EXPECT_TRUE(channel1_->UpdateRtpTransport(nullptr)); - EXPECT_TRUE(channel1_->Enable(true)); - EXPECT_EQ(1u, media_channel1_->send_streams().size()); + channel1_->Enable(true); + EXPECT_EQ(1u, media_channel1()->send_streams().size()); EXPECT_TRUE(channel2_->SetRemoteContent(&content1, SdpType::kOffer, NULL)); - EXPECT_TRUE(channel2_->UpdateRtpTransport(nullptr)); - EXPECT_EQ(1u, media_channel2_->recv_streams().size()); + EXPECT_EQ(1u, media_channel2()->recv_streams().size()); ConnectFakeTransports(); // Channel 2 do not send anything. typename T::Content content2; CreateContent(0, kPcmuCodec, kH264Codec, &content2); EXPECT_TRUE(channel1_->SetRemoteContent(&content2, SdpType::kAnswer, NULL)); - EXPECT_TRUE(channel1_->UpdateRtpTransport(nullptr)); - EXPECT_EQ(0u, media_channel1_->recv_streams().size()); + EXPECT_EQ(0u, media_channel1()->recv_streams().size()); EXPECT_TRUE(channel2_->SetLocalContent(&content2, SdpType::kAnswer, NULL)); - EXPECT_TRUE(channel2_->UpdateRtpTransport(nullptr)); - EXPECT_TRUE(channel2_->Enable(true)); - EXPECT_EQ(0u, media_channel2_->send_streams().size()); + channel2_->Enable(true); + EXPECT_EQ(0u, media_channel2()->send_streams().size()); SendCustomRtp1(kSsrc1, 0); WaitForThreads(); @@ -651,25 +668,21 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateContent(0, kPcmuCodec, kH264Codec, &content3); content3.AddStream(stream2); EXPECT_TRUE(channel2_->SetLocalContent(&content3, SdpType::kOffer, NULL)); - EXPECT_TRUE(channel2_->UpdateRtpTransport(nullptr)); - ASSERT_EQ(1u, media_channel2_->send_streams().size()); - EXPECT_EQ(stream2, media_channel2_->send_streams()[0]); + ASSERT_EQ(1u, media_channel2()->send_streams().size()); + EXPECT_EQ(stream2, media_channel2()->send_streams()[0]); EXPECT_TRUE(channel1_->SetRemoteContent(&content3, SdpType::kOffer, NULL)); - EXPECT_TRUE(channel1_->UpdateRtpTransport(nullptr)); - ASSERT_EQ(1u, media_channel1_->recv_streams().size()); - EXPECT_EQ(stream2, media_channel1_->recv_streams()[0]); + ASSERT_EQ(1u, media_channel1()->recv_streams().size()); + EXPECT_EQ(stream2, media_channel1()->recv_streams()[0]); // Channel 1 replies but stop sending stream1. 
typename T::Content content4; CreateContent(0, kPcmuCodec, kH264Codec, &content4); EXPECT_TRUE(channel1_->SetLocalContent(&content4, SdpType::kAnswer, NULL)); - EXPECT_TRUE(channel1_->UpdateRtpTransport(nullptr)); - EXPECT_EQ(0u, media_channel1_->send_streams().size()); + EXPECT_EQ(0u, media_channel1()->send_streams().size()); EXPECT_TRUE(channel2_->SetRemoteContent(&content4, SdpType::kAnswer, NULL)); - EXPECT_TRUE(channel2_->UpdateRtpTransport(nullptr)); - EXPECT_EQ(0u, media_channel2_->recv_streams().size()); + EXPECT_EQ(0u, media_channel2()->recv_streams().size()); SendCustomRtp2(kSsrc2, 0); WaitForThreads(); @@ -680,56 +693,58 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { void TestPlayoutAndSendingStates() { CreateChannels(0, 0); if (verify_playout_) { - EXPECT_FALSE(media_channel1_->playout()); + EXPECT_FALSE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(media_channel1()->sending()); if (verify_playout_) { - EXPECT_FALSE(media_channel2_->playout()); + EXPECT_FALSE(media_channel2()->playout()); } - EXPECT_FALSE(media_channel2_->sending()); - EXPECT_TRUE(channel1_->Enable(true)); + EXPECT_FALSE(media_channel2()->sending()); + channel1_->Enable(true); + FlushCurrentThread(); if (verify_playout_) { - EXPECT_FALSE(media_channel1_->playout()); + EXPECT_FALSE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(media_channel1()->sending()); EXPECT_TRUE(channel1_->SetLocalContent(&local_media_content1_, SdpType::kOffer, NULL)); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(media_channel1()->sending()); EXPECT_TRUE(channel2_->SetRemoteContent(&local_media_content1_, SdpType::kOffer, NULL)); if (verify_playout_) { - EXPECT_FALSE(media_channel2_->playout()); + EXPECT_FALSE(media_channel2()->playout()); } - EXPECT_FALSE(media_channel2_->sending()); + EXPECT_FALSE(media_channel2()->sending()); EXPECT_TRUE(channel2_->SetLocalContent(&local_media_content2_, SdpType::kAnswer, NULL)); if (verify_playout_) { - EXPECT_FALSE(media_channel2_->playout()); + EXPECT_FALSE(media_channel2()->playout()); } - EXPECT_FALSE(media_channel2_->sending()); + EXPECT_FALSE(media_channel2()->sending()); ConnectFakeTransports(); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(media_channel1()->sending()); if (verify_playout_) { - EXPECT_FALSE(media_channel2_->playout()); + EXPECT_FALSE(media_channel2()->playout()); } - EXPECT_FALSE(media_channel2_->sending()); - EXPECT_TRUE(channel2_->Enable(true)); + EXPECT_FALSE(media_channel2()->sending()); + channel2_->Enable(true); + FlushCurrentThread(); if (verify_playout_) { - EXPECT_TRUE(media_channel2_->playout()); + EXPECT_TRUE(media_channel2()->playout()); } - EXPECT_TRUE(media_channel2_->sending()); + EXPECT_TRUE(media_channel2()->sending()); EXPECT_TRUE(channel1_->SetRemoteContent(&local_media_content2_, SdpType::kAnswer, NULL)); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_TRUE(media_channel1_->sending()); + EXPECT_TRUE(media_channel1()->sending()); } // Test that changing the MediaContentDirection in the local and remote @@ -743,16 +758,17 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { // Set |content2| to be 
InActive. content2.set_direction(RtpTransceiverDirection::kInactive); - EXPECT_TRUE(channel1_->Enable(true)); - EXPECT_TRUE(channel2_->Enable(true)); + channel1_->Enable(true); + channel2_->Enable(true); + FlushCurrentThread(); if (verify_playout_) { - EXPECT_FALSE(media_channel1_->playout()); + EXPECT_FALSE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(media_channel1()->sending()); if (verify_playout_) { - EXPECT_FALSE(media_channel2_->playout()); + EXPECT_FALSE(media_channel2()->playout()); } - EXPECT_FALSE(media_channel2_->sending()); + EXPECT_FALSE(media_channel2()->sending()); EXPECT_TRUE(channel1_->SetLocalContent(&content1, SdpType::kOffer, NULL)); EXPECT_TRUE(channel2_->SetRemoteContent(&content1, SdpType::kOffer, NULL)); @@ -763,13 +779,13 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { ConnectFakeTransports(); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); // remote InActive + EXPECT_FALSE(media_channel1()->sending()); // remote InActive if (verify_playout_) { - EXPECT_FALSE(media_channel2_->playout()); // local InActive + EXPECT_FALSE(media_channel2()->playout()); // local InActive } - EXPECT_FALSE(media_channel2_->sending()); // local InActive + EXPECT_FALSE(media_channel2()->sending()); // local InActive // Update |content2| to be RecvOnly. content2.set_direction(RtpTransceiverDirection::kRecvOnly); @@ -779,13 +795,13 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { channel1_->SetRemoteContent(&content2, SdpType::kPrAnswer, NULL)); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_TRUE(media_channel1_->sending()); + EXPECT_TRUE(media_channel1()->sending()); if (verify_playout_) { - EXPECT_TRUE(media_channel2_->playout()); // local RecvOnly + EXPECT_TRUE(media_channel2()->playout()); // local RecvOnly } - EXPECT_FALSE(media_channel2_->sending()); // local RecvOnly + EXPECT_FALSE(media_channel2()->sending()); // local RecvOnly // Update |content2| to be SendRecv. content2.set_direction(RtpTransceiverDirection::kSendRecv); @@ -793,13 +809,13 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { EXPECT_TRUE(channel1_->SetRemoteContent(&content2, SdpType::kAnswer, NULL)); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_TRUE(media_channel1_->sending()); + EXPECT_TRUE(media_channel1()->sending()); if (verify_playout_) { - EXPECT_TRUE(media_channel2_->playout()); + EXPECT_TRUE(media_channel2()->playout()); } - EXPECT_TRUE(media_channel2_->sending()); + EXPECT_TRUE(media_channel2()->sending()); } // Tests that when the transport channel signals a candidate pair change @@ -865,43 +881,21 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { // Test setting up a call. 
void TestCallSetup() { CreateChannels(0, 0); - EXPECT_FALSE(channel1_->SrtpActiveForTesting()); + EXPECT_FALSE(IsSrtpActive(channel1_)); EXPECT_TRUE(SendInitiate()); if (verify_playout_) { - EXPECT_TRUE(media_channel1_->playout()); + EXPECT_TRUE(media_channel1()->playout()); } - EXPECT_FALSE(media_channel1_->sending()); + EXPECT_FALSE(media_channel1()->sending()); EXPECT_TRUE(SendAccept()); - EXPECT_FALSE(channel1_->SrtpActiveForTesting()); - EXPECT_TRUE(media_channel1_->sending()); - EXPECT_EQ(1U, media_channel1_->codecs().size()); + EXPECT_FALSE(IsSrtpActive(channel1_)); + EXPECT_TRUE(media_channel1()->sending()); + EXPECT_EQ(1U, media_channel1()->codecs().size()); if (verify_playout_) { - EXPECT_TRUE(media_channel2_->playout()); + EXPECT_TRUE(media_channel2()->playout()); } - EXPECT_TRUE(media_channel2_->sending()); - EXPECT_EQ(1U, media_channel2_->codecs().size()); - } - - // Test that we don't crash if packets are sent during call teardown - // when RTCP mux is enabled. This is a regression test against a specific - // race condition that would only occur when a RTCP packet was sent during - // teardown of a channel on which RTCP mux was enabled. - void TestCallTeardownRtcpMux() { - class LastWordMediaChannel : public T::MediaChannel { - public: - LastWordMediaChannel() : T::MediaChannel(NULL, typename T::Options()) {} - ~LastWordMediaChannel() { - T::MediaChannel::SendRtp(kPcmuFrame, sizeof(kPcmuFrame), - rtc::PacketOptions()); - T::MediaChannel::SendRtcp(kRtcpReport, sizeof(kRtcpReport)); - } - }; - CreateChannels(std::make_unique(), - std::make_unique(), RTCP_MUX, - RTCP_MUX); - EXPECT_TRUE(SendInitiate()); - EXPECT_TRUE(SendAccept()); - EXPECT_TRUE(Terminate()); + EXPECT_TRUE(media_channel2()->sending()); + EXPECT_EQ(1U, media_channel2()->codecs().size()); } // Send voice RTP data to the other side and ensure it gets there. 
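The rewritten TestCallSetup above reads SRTP state through a helper, IsSrtpActive(channel1_), instead of the removed BaseChannel::SrtpActiveForTesting(), and the RTCP-mux checks in the following hunks use IsRtcpMuxEnabled() in place of RtpTransportForTesting()->rtcp_mux_enabled(). The helper definitions are not part of the hunks shown here; the sketch below is only an illustration of what they could look like, assuming the channel still exposes its webrtc::RtpTransportInternal through an rtp_transport() accessor and that transport state has to be queried on the network thread.

// Hypothetical test helpers -- the names, the rtp_transport() accessor and
// the threading assumption are illustrative, not taken from this CL.
bool IsSrtpActive(std::unique_ptr<typename T::Channel>& channel) {
  RTC_DCHECK(channel.get());
  return network_thread_->Invoke<bool>(RTC_FROM_HERE, [&] {
    return channel->rtp_transport()->IsSrtpActive();
  });
}

bool IsRtcpMuxEnabled(std::unique_ptr<typename T::Channel>& channel) {
  RTC_DCHECK(channel.get());
  return network_thread_->Invoke<bool>(RTC_FROM_HERE, [&] {
    return channel->rtp_transport()->rtcp_mux_enabled();
  });
}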
@@ -909,8 +903,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP_MUX, RTCP_MUX); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_TRUE(channel1_->RtpTransportForTesting()->rtcp_mux_enabled()); - EXPECT_TRUE(channel2_->RtpTransportForTesting()->rtcp_mux_enabled()); + EXPECT_TRUE(IsRtcpMuxEnabled(channel1_)); + EXPECT_TRUE(IsRtcpMuxEnabled(channel2_)); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -933,13 +927,13 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { void SendDtlsSrtpToDtlsSrtp(int flags1, int flags2) { CreateChannels(flags1 | DTLS, flags2 | DTLS); - EXPECT_FALSE(channel1_->SrtpActiveForTesting()); - EXPECT_FALSE(channel2_->SrtpActiveForTesting()); + EXPECT_FALSE(IsSrtpActive(channel1_)); + EXPECT_FALSE(IsSrtpActive(channel2_)); EXPECT_TRUE(SendInitiate()); WaitForThreads(); EXPECT_TRUE(SendAccept()); - EXPECT_TRUE(channel1_->SrtpActiveForTesting()); - EXPECT_TRUE(channel2_->SrtpActiveForTesting()); + EXPECT_TRUE(IsSrtpActive(channel1_)); + EXPECT_TRUE(IsSrtpActive(channel2_)); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -957,10 +951,10 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateChannels(SSRC_MUX | RTCP_MUX | DTLS, SSRC_MUX | RTCP_MUX | DTLS); EXPECT_TRUE(SendOffer()); EXPECT_TRUE(SendProvisionalAnswer()); - EXPECT_TRUE(channel1_->SrtpActiveForTesting()); - EXPECT_TRUE(channel2_->SrtpActiveForTesting()); - EXPECT_TRUE(channel1_->RtpTransportForTesting()->rtcp_mux_enabled()); - EXPECT_TRUE(channel2_->RtpTransportForTesting()->rtcp_mux_enabled()); + EXPECT_TRUE(IsSrtpActive(channel1_)); + EXPECT_TRUE(IsSrtpActive(channel2_)); + EXPECT_TRUE(IsRtcpMuxEnabled(channel1_)); + EXPECT_TRUE(IsRtcpMuxEnabled(channel2_)); WaitForThreads(); // Wait for 'sending' flag go through network thread. SendCustomRtp1(kSsrc1, ++sequence_number1_1); WaitForThreads(); @@ -973,8 +967,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { // Complete call setup and ensure everything is still OK. EXPECT_TRUE(SendFinalAnswer()); - EXPECT_TRUE(channel1_->SrtpActiveForTesting()); - EXPECT_TRUE(channel2_->SrtpActiveForTesting()); + EXPECT_TRUE(IsSrtpActive(channel1_)); + EXPECT_TRUE(IsSrtpActive(channel2_)); SendCustomRtp1(kSsrc1, ++sequence_number1_1); SendCustomRtp2(kSsrc2, ++sequence_number2_2); WaitForThreads(); @@ -1003,8 +997,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateChannels(RTCP_MUX, RTCP_MUX); EXPECT_TRUE(SendInitiate()); EXPECT_TRUE(SendAccept()); - EXPECT_TRUE(channel1_->RtpTransportForTesting()->rtcp_mux_enabled()); - EXPECT_TRUE(channel2_->RtpTransportForTesting()->rtcp_mux_enabled()); + EXPECT_TRUE(IsRtcpMuxEnabled(channel1_)); + EXPECT_TRUE(IsRtcpMuxEnabled(channel2_)); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -1027,7 +1021,7 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { network_thread_->Invoke(RTC_FROM_HERE, [this] { fake_rtp_dtls_transport1_->SetWritable(true); }); - EXPECT_TRUE(media_channel1_->sending()); + EXPECT_TRUE(media_channel1()->sending()); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -1041,7 +1035,7 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { bool asymmetric = true; fake_rtp_dtls_transport1_->SetDestination(nullptr, asymmetric); }); - EXPECT_TRUE(media_channel1_->sending()); + EXPECT_TRUE(media_channel1()->sending()); // Should fail also. 
SendRtp1(); @@ -1057,7 +1051,7 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { fake_rtp_dtls_transport1_->SetDestination(fake_rtp_dtls_transport2_.get(), asymmetric); }); - EXPECT_TRUE(media_channel1_->sending()); + EXPECT_TRUE(media_channel1()->sending()); SendRtp1(); SendRtp2(); WaitForThreads(); @@ -1110,17 +1104,17 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { std::unique_ptr content( CreateMediaContentWithStream(1)); - media_channel1_->set_fail_set_recv_codecs(true); + media_channel1()->set_fail_set_recv_codecs(true); EXPECT_FALSE( channel1_->SetLocalContent(content.get(), SdpType::kOffer, &err)); EXPECT_FALSE( channel1_->SetLocalContent(content.get(), SdpType::kAnswer, &err)); - media_channel1_->set_fail_set_send_codecs(true); + media_channel1()->set_fail_set_send_codecs(true); EXPECT_FALSE( channel1_->SetRemoteContent(content.get(), SdpType::kOffer, &err)); - media_channel1_->set_fail_set_send_codecs(true); + media_channel1()->set_fail_set_send_codecs(true); EXPECT_FALSE( channel1_->SetRemoteContent(content.get(), SdpType::kAnswer, &err)); } @@ -1133,14 +1127,14 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateMediaContentWithStream(1)); EXPECT_TRUE( channel1_->SetLocalContent(content1.get(), SdpType::kOffer, &err)); - EXPECT_TRUE(media_channel1_->HasSendStream(1)); + EXPECT_TRUE(media_channel1()->HasSendStream(1)); std::unique_ptr content2( CreateMediaContentWithStream(2)); EXPECT_TRUE( channel1_->SetLocalContent(content2.get(), SdpType::kOffer, &err)); - EXPECT_FALSE(media_channel1_->HasSendStream(1)); - EXPECT_TRUE(media_channel1_->HasSendStream(2)); + EXPECT_FALSE(media_channel1()->HasSendStream(1)); + EXPECT_TRUE(media_channel1()->HasSendStream(2)); } void TestReceiveTwoOffers() { @@ -1151,14 +1145,14 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateMediaContentWithStream(1)); EXPECT_TRUE( channel1_->SetRemoteContent(content1.get(), SdpType::kOffer, &err)); - EXPECT_TRUE(media_channel1_->HasRecvStream(1)); + EXPECT_TRUE(media_channel1()->HasRecvStream(1)); std::unique_ptr content2( CreateMediaContentWithStream(2)); EXPECT_TRUE( channel1_->SetRemoteContent(content2.get(), SdpType::kOffer, &err)); - EXPECT_FALSE(media_channel1_->HasRecvStream(1)); - EXPECT_TRUE(media_channel1_->HasRecvStream(2)); + EXPECT_FALSE(media_channel1()->HasRecvStream(1)); + EXPECT_TRUE(media_channel1()->HasRecvStream(2)); } void TestSendPrAnswer() { @@ -1170,24 +1164,24 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateMediaContentWithStream(1)); EXPECT_TRUE( channel1_->SetRemoteContent(content1.get(), SdpType::kOffer, &err)); - EXPECT_TRUE(media_channel1_->HasRecvStream(1)); + EXPECT_TRUE(media_channel1()->HasRecvStream(1)); // Send PR answer std::unique_ptr content2( CreateMediaContentWithStream(2)); EXPECT_TRUE( channel1_->SetLocalContent(content2.get(), SdpType::kPrAnswer, &err)); - EXPECT_TRUE(media_channel1_->HasRecvStream(1)); - EXPECT_TRUE(media_channel1_->HasSendStream(2)); + EXPECT_TRUE(media_channel1()->HasRecvStream(1)); + EXPECT_TRUE(media_channel1()->HasSendStream(2)); // Send answer std::unique_ptr content3( CreateMediaContentWithStream(3)); EXPECT_TRUE( channel1_->SetLocalContent(content3.get(), SdpType::kAnswer, &err)); - EXPECT_TRUE(media_channel1_->HasRecvStream(1)); - EXPECT_FALSE(media_channel1_->HasSendStream(2)); - EXPECT_TRUE(media_channel1_->HasSendStream(3)); + EXPECT_TRUE(media_channel1()->HasRecvStream(1)); + 
EXPECT_FALSE(media_channel1()->HasSendStream(2)); + EXPECT_TRUE(media_channel1()->HasSendStream(3)); } void TestReceivePrAnswer() { @@ -1199,37 +1193,39 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateMediaContentWithStream(1)); EXPECT_TRUE( channel1_->SetLocalContent(content1.get(), SdpType::kOffer, &err)); - EXPECT_TRUE(media_channel1_->HasSendStream(1)); + EXPECT_TRUE(media_channel1()->HasSendStream(1)); // Receive PR answer std::unique_ptr content2( CreateMediaContentWithStream(2)); EXPECT_TRUE( channel1_->SetRemoteContent(content2.get(), SdpType::kPrAnswer, &err)); - EXPECT_TRUE(media_channel1_->HasSendStream(1)); - EXPECT_TRUE(media_channel1_->HasRecvStream(2)); + EXPECT_TRUE(media_channel1()->HasSendStream(1)); + EXPECT_TRUE(media_channel1()->HasRecvStream(2)); // Receive answer std::unique_ptr content3( CreateMediaContentWithStream(3)); EXPECT_TRUE( channel1_->SetRemoteContent(content3.get(), SdpType::kAnswer, &err)); - EXPECT_TRUE(media_channel1_->HasSendStream(1)); - EXPECT_FALSE(media_channel1_->HasRecvStream(2)); - EXPECT_TRUE(media_channel1_->HasRecvStream(3)); + EXPECT_TRUE(media_channel1()->HasSendStream(1)); + EXPECT_FALSE(media_channel1()->HasRecvStream(2)); + EXPECT_TRUE(media_channel1()->HasRecvStream(3)); } void TestOnTransportReadyToSend() { CreateChannels(0, 0); - EXPECT_FALSE(media_channel1_->ready_to_send()); + EXPECT_FALSE(media_channel1()->ready_to_send()); - channel1_->OnTransportReadyToSend(true); + network_thread_->PostTask( + RTC_FROM_HERE, [this] { channel1_->OnTransportReadyToSend(true); }); WaitForThreads(); - EXPECT_TRUE(media_channel1_->ready_to_send()); + EXPECT_TRUE(media_channel1()->ready_to_send()); - channel1_->OnTransportReadyToSend(false); + network_thread_->PostTask( + RTC_FROM_HERE, [this] { channel1_->OnTransportReadyToSend(false); }); WaitForThreads(); - EXPECT_FALSE(media_channel1_->ready_to_send()); + EXPECT_FALSE(media_channel1()->ready_to_send()); } bool SetRemoteContentWithBitrateLimit(int remote_limit) { @@ -1257,8 +1253,8 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateChannels(0, 0); EXPECT_TRUE(channel1_->SetLocalContent(&local_media_content1_, SdpType::kOffer, NULL)); - EXPECT_EQ(media_channel1_->max_bps(), -1); - VerifyMaxBitrate(media_channel1_->GetRtpSendParameters(kSsrc1), + EXPECT_EQ(media_channel1()->max_bps(), -1); + VerifyMaxBitrate(media_channel1()->GetRtpSendParameters(kSsrc1), absl::nullopt); } @@ -1275,22 +1271,27 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> { CreateChannels(DTLS, DTLS); - channel1_->SetOption(cricket::BaseChannel::ST_RTP, - rtc::Socket::Option::OPT_SNDBUF, kSndBufSize); - channel2_->SetOption(cricket::BaseChannel::ST_RTP, - rtc::Socket::Option::OPT_RCVBUF, kRcvBufSize); - new_rtp_transport_ = CreateDtlsSrtpTransport( fake_rtp_dtls_transport2_.get(), fake_rtcp_dtls_transport2_.get()); - channel1_->SetRtpTransport(new_rtp_transport_.get()); - int option_val; - ASSERT_TRUE(fake_rtp_dtls_transport2_->GetOption( - rtc::Socket::Option::OPT_SNDBUF, &option_val)); - EXPECT_EQ(kSndBufSize, option_val); - ASSERT_TRUE(fake_rtp_dtls_transport2_->GetOption( - rtc::Socket::Option::OPT_RCVBUF, &option_val)); - EXPECT_EQ(kRcvBufSize, option_val); + bool rcv_success, send_success; + int rcv_buf, send_buf; + network_thread_->Invoke(RTC_FROM_HERE, [&] { + channel1_->SetOption(cricket::BaseChannel::ST_RTP, + rtc::Socket::Option::OPT_SNDBUF, kSndBufSize); + channel2_->SetOption(cricket::BaseChannel::ST_RTP, + 
rtc::Socket::Option::OPT_RCVBUF, kRcvBufSize);
+    channel1_->SetRtpTransport(new_rtp_transport_.get());
+    send_success = fake_rtp_dtls_transport2_->GetOption(
+        rtc::Socket::Option::OPT_SNDBUF, &send_buf);
+    rcv_success = fake_rtp_dtls_transport2_->GetOption(
+        rtc::Socket::Option::OPT_RCVBUF, &rcv_buf);
+  });
+
+  ASSERT_TRUE(send_success);
+  EXPECT_EQ(kSndBufSize, send_buf);
+  ASSERT_TRUE(rcv_success);
+  EXPECT_EQ(kRcvBufSize, rcv_buf);
 }

 void CreateSimulcastContent(const std::vector<std::string>& rids,
@@ -1354,6 +1355,9 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> {
       thread->ProcessMessages(0);
     }
   }
+  static void FlushCurrentThread() {
+    rtc::Thread::Current()->ProcessMessages(0);
+  }
   void WaitForThreads(rtc::ArrayView<rtc::Thread*> threads) {
     // |threads| and current thread post packets to network thread.
     for (rtc::Thread* thread : threads) {
@@ -1369,9 +1373,24 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> {
     // Worker thread = current Thread process received messages.
     ProcessThreadQueue(rtc::Thread::Current());
   }
+
+  typename T::MediaChannel* media_channel1() {
+    RTC_DCHECK(channel1_);
+    RTC_DCHECK(channel1_->media_channel());
+    return static_cast<typename T::MediaChannel*>(channel1_->media_channel());
+  }
+
+  typename T::MediaChannel* media_channel2() {
+    RTC_DCHECK(channel2_);
+    RTC_DCHECK(channel2_->media_channel());
+    return static_cast<typename T::MediaChannel*>(channel2_->media_channel());
+  }
+
   // TODO(pbos): Remove playout from all media channels and let renderers mute
   // themselves.
   const bool verify_playout_;
+  rtc::scoped_refptr<webrtc::PendingTaskSafetyFlag> network_thread_safety_ =
+      webrtc::PendingTaskSafetyFlag::CreateDetached();
   std::unique_ptr<rtc::Thread> network_thread_keeper_;
   rtc::Thread* network_thread_;
   std::unique_ptr<cricket::FakeDtlsTransport> fake_rtp_dtls_transport1_;
@@ -1386,9 +1405,6 @@ class ChannelTest : public ::testing::Test, public sigslot::has_slots<> {
   std::unique_ptr<webrtc::RtpTransportInternal> rtp_transport2_;
   std::unique_ptr<webrtc::RtpTransportInternal> new_rtp_transport_;
   cricket::FakeMediaEngine media_engine_;
-  // The media channels are owned by the voice channel objects below.
- typename T::MediaChannel* media_channel1_ = nullptr; - typename T::MediaChannel* media_channel2_ = nullptr; std::unique_ptr channel1_; std::unique_ptr channel2_; typename T::Content local_media_content1_; @@ -1548,8 +1564,8 @@ class VideoChannelDoubleThreadTest : public ChannelTest { TEST_F(VoiceChannelSingleThreadTest, TestInit) { Base::TestInit(); - EXPECT_FALSE(media_channel1_->IsStreamMuted(0)); - EXPECT_TRUE(media_channel1_->dtmf_info_queue().empty()); + EXPECT_FALSE(media_channel1()->IsStreamMuted(0)); + EXPECT_TRUE(media_channel1()->dtmf_info_queue().empty()); } TEST_F(VoiceChannelSingleThreadTest, TestDeinit) { @@ -1610,10 +1626,6 @@ TEST_F(VoiceChannelSingleThreadTest, TestCallSetup) { Base::TestCallSetup(); } -TEST_F(VoiceChannelSingleThreadTest, TestCallTeardownRtcpMux) { - Base::TestCallTeardownRtcpMux(); -} - TEST_F(VoiceChannelSingleThreadTest, SendRtpToRtp) { Base::SendRtpToRtp(); } @@ -1689,8 +1701,8 @@ TEST_F(VoiceChannelSingleThreadTest, SocketOptionsMergedOnSetTransport) { // VoiceChannelDoubleThreadTest TEST_F(VoiceChannelDoubleThreadTest, TestInit) { Base::TestInit(); - EXPECT_FALSE(media_channel1_->IsStreamMuted(0)); - EXPECT_TRUE(media_channel1_->dtmf_info_queue().empty()); + EXPECT_FALSE(media_channel1()->IsStreamMuted(0)); + EXPECT_TRUE(media_channel1()->dtmf_info_queue().empty()); } TEST_F(VoiceChannelDoubleThreadTest, TestDeinit) { @@ -1751,10 +1763,6 @@ TEST_F(VoiceChannelDoubleThreadTest, TestCallSetup) { Base::TestCallSetup(); } -TEST_F(VoiceChannelDoubleThreadTest, TestCallTeardownRtcpMux) { - Base::TestCallTeardownRtcpMux(); -} - TEST_F(VoiceChannelDoubleThreadTest, SendRtpToRtp) { Base::SendRtpToRtp(); } @@ -1890,10 +1898,6 @@ TEST_F(VideoChannelSingleThreadTest, TestCallSetup) { Base::TestCallSetup(); } -TEST_F(VideoChannelSingleThreadTest, TestCallTeardownRtcpMux) { - Base::TestCallTeardownRtcpMux(); -} - TEST_F(VideoChannelSingleThreadTest, SendRtpToRtp) { Base::SendRtpToRtp(); } @@ -1980,12 +1984,12 @@ TEST_F(VideoChannelSingleThreadTest, TestSetLocalOfferWithPacketization) { CreateChannels(0, 0); EXPECT_TRUE(channel1_->SetLocalContent(&video, SdpType::kOffer, NULL)); - EXPECT_THAT(media_channel1_->send_codecs(), testing::IsEmpty()); - ASSERT_THAT(media_channel1_->recv_codecs(), testing::SizeIs(2)); - EXPECT_TRUE(media_channel1_->recv_codecs()[0].Matches(kVp8Codec)); - EXPECT_EQ(media_channel1_->recv_codecs()[0].packetization, absl::nullopt); - EXPECT_TRUE(media_channel1_->recv_codecs()[1].Matches(vp9_codec)); - EXPECT_EQ(media_channel1_->recv_codecs()[1].packetization, + EXPECT_THAT(media_channel1()->send_codecs(), testing::IsEmpty()); + ASSERT_THAT(media_channel1()->recv_codecs(), testing::SizeIs(2)); + EXPECT_TRUE(media_channel1()->recv_codecs()[0].Matches(kVp8Codec)); + EXPECT_EQ(media_channel1()->recv_codecs()[0].packetization, absl::nullopt); + EXPECT_TRUE(media_channel1()->recv_codecs()[1].Matches(vp9_codec)); + EXPECT_EQ(media_channel1()->recv_codecs()[1].packetization, cricket::kPacketizationParamRaw); } @@ -1999,12 +2003,12 @@ TEST_F(VideoChannelSingleThreadTest, TestSetRemoteOfferWithPacketization) { CreateChannels(0, 0); EXPECT_TRUE(channel1_->SetRemoteContent(&video, SdpType::kOffer, NULL)); - EXPECT_THAT(media_channel1_->recv_codecs(), testing::IsEmpty()); - ASSERT_THAT(media_channel1_->send_codecs(), testing::SizeIs(2)); - EXPECT_TRUE(media_channel1_->send_codecs()[0].Matches(kVp8Codec)); - EXPECT_EQ(media_channel1_->send_codecs()[0].packetization, absl::nullopt); - EXPECT_TRUE(media_channel1_->send_codecs()[1].Matches(vp9_codec)); - 
EXPECT_EQ(media_channel1_->send_codecs()[1].packetization, + EXPECT_THAT(media_channel1()->recv_codecs(), testing::IsEmpty()); + ASSERT_THAT(media_channel1()->send_codecs(), testing::SizeIs(2)); + EXPECT_TRUE(media_channel1()->send_codecs()[0].Matches(kVp8Codec)); + EXPECT_EQ(media_channel1()->send_codecs()[0].packetization, absl::nullopt); + EXPECT_TRUE(media_channel1()->send_codecs()[1].Matches(vp9_codec)); + EXPECT_EQ(media_channel1()->send_codecs()[1].packetization, cricket::kPacketizationParamRaw); } @@ -2019,17 +2023,17 @@ TEST_F(VideoChannelSingleThreadTest, TestSetAnswerWithPacketization) { EXPECT_TRUE(channel1_->SetLocalContent(&video, SdpType::kOffer, NULL)); EXPECT_TRUE(channel1_->SetRemoteContent(&video, SdpType::kAnswer, NULL)); - ASSERT_THAT(media_channel1_->recv_codecs(), testing::SizeIs(2)); - EXPECT_TRUE(media_channel1_->recv_codecs()[0].Matches(kVp8Codec)); - EXPECT_EQ(media_channel1_->recv_codecs()[0].packetization, absl::nullopt); - EXPECT_TRUE(media_channel1_->recv_codecs()[1].Matches(vp9_codec)); - EXPECT_EQ(media_channel1_->recv_codecs()[1].packetization, + ASSERT_THAT(media_channel1()->recv_codecs(), testing::SizeIs(2)); + EXPECT_TRUE(media_channel1()->recv_codecs()[0].Matches(kVp8Codec)); + EXPECT_EQ(media_channel1()->recv_codecs()[0].packetization, absl::nullopt); + EXPECT_TRUE(media_channel1()->recv_codecs()[1].Matches(vp9_codec)); + EXPECT_EQ(media_channel1()->recv_codecs()[1].packetization, cricket::kPacketizationParamRaw); - EXPECT_THAT(media_channel1_->send_codecs(), testing::SizeIs(2)); - EXPECT_TRUE(media_channel1_->send_codecs()[0].Matches(kVp8Codec)); - EXPECT_EQ(media_channel1_->send_codecs()[0].packetization, absl::nullopt); - EXPECT_TRUE(media_channel1_->send_codecs()[1].Matches(vp9_codec)); - EXPECT_EQ(media_channel1_->send_codecs()[1].packetization, + EXPECT_THAT(media_channel1()->send_codecs(), testing::SizeIs(2)); + EXPECT_TRUE(media_channel1()->send_codecs()[0].Matches(kVp8Codec)); + EXPECT_EQ(media_channel1()->send_codecs()[0].packetization, absl::nullopt); + EXPECT_TRUE(media_channel1()->send_codecs()[1].Matches(vp9_codec)); + EXPECT_EQ(media_channel1()->send_codecs()[1].packetization, cricket::kPacketizationParamRaw); } @@ -2047,10 +2051,10 @@ TEST_F(VideoChannelSingleThreadTest, TestSetLocalAnswerWithoutPacketization) { EXPECT_TRUE( channel1_->SetRemoteContent(&remote_video, SdpType::kOffer, NULL)); EXPECT_TRUE(channel1_->SetLocalContent(&local_video, SdpType::kAnswer, NULL)); - ASSERT_THAT(media_channel1_->recv_codecs(), testing::SizeIs(1)); - EXPECT_EQ(media_channel1_->recv_codecs()[0].packetization, absl::nullopt); - ASSERT_THAT(media_channel1_->send_codecs(), testing::SizeIs(1)); - EXPECT_EQ(media_channel1_->send_codecs()[0].packetization, absl::nullopt); + ASSERT_THAT(media_channel1()->recv_codecs(), testing::SizeIs(1)); + EXPECT_EQ(media_channel1()->recv_codecs()[0].packetization, absl::nullopt); + ASSERT_THAT(media_channel1()->send_codecs(), testing::SizeIs(1)); + EXPECT_EQ(media_channel1()->send_codecs()[0].packetization, absl::nullopt); } TEST_F(VideoChannelSingleThreadTest, TestSetRemoteAnswerWithoutPacketization) { @@ -2067,10 +2071,10 @@ TEST_F(VideoChannelSingleThreadTest, TestSetRemoteAnswerWithoutPacketization) { EXPECT_TRUE(channel1_->SetLocalContent(&local_video, SdpType::kOffer, NULL)); EXPECT_TRUE( channel1_->SetRemoteContent(&remote_video, SdpType::kAnswer, NULL)); - ASSERT_THAT(media_channel1_->recv_codecs(), testing::SizeIs(1)); - EXPECT_EQ(media_channel1_->recv_codecs()[0].packetization, absl::nullopt); - 
ASSERT_THAT(media_channel1_->send_codecs(), testing::SizeIs(1)); - EXPECT_EQ(media_channel1_->send_codecs()[0].packetization, absl::nullopt); + ASSERT_THAT(media_channel1()->recv_codecs(), testing::SizeIs(1)); + EXPECT_EQ(media_channel1()->recv_codecs()[0].packetization, absl::nullopt); + ASSERT_THAT(media_channel1()->send_codecs(), testing::SizeIs(1)); + EXPECT_EQ(media_channel1()->send_codecs()[0].packetization, absl::nullopt); } TEST_F(VideoChannelSingleThreadTest, @@ -2089,10 +2093,10 @@ TEST_F(VideoChannelSingleThreadTest, EXPECT_TRUE(channel1_->SetLocalContent(&local_video, SdpType::kOffer, NULL)); EXPECT_FALSE( channel1_->SetRemoteContent(&remote_video, SdpType::kAnswer, NULL)); - ASSERT_THAT(media_channel1_->recv_codecs(), testing::SizeIs(1)); - EXPECT_EQ(media_channel1_->recv_codecs()[0].packetization, + ASSERT_THAT(media_channel1()->recv_codecs(), testing::SizeIs(1)); + EXPECT_EQ(media_channel1()->recv_codecs()[0].packetization, cricket::kPacketizationParamRaw); - EXPECT_THAT(media_channel1_->send_codecs(), testing::IsEmpty()); + EXPECT_THAT(media_channel1()->send_codecs(), testing::IsEmpty()); } TEST_F(VideoChannelSingleThreadTest, @@ -2111,9 +2115,9 @@ TEST_F(VideoChannelSingleThreadTest, channel1_->SetRemoteContent(&remote_video, SdpType::kOffer, NULL)); EXPECT_FALSE( channel1_->SetLocalContent(&local_video, SdpType::kAnswer, NULL)); - EXPECT_THAT(media_channel1_->recv_codecs(), testing::IsEmpty()); - ASSERT_THAT(media_channel1_->send_codecs(), testing::SizeIs(1)); - EXPECT_EQ(media_channel1_->send_codecs()[0].packetization, absl::nullopt); + EXPECT_THAT(media_channel1()->recv_codecs(), testing::IsEmpty()); + ASSERT_THAT(media_channel1()->send_codecs(), testing::SizeIs(1)); + EXPECT_EQ(media_channel1()->send_codecs()[0].packetization, absl::nullopt); } // VideoChannelDoubleThreadTest @@ -2179,10 +2183,6 @@ TEST_F(VideoChannelDoubleThreadTest, TestCallSetup) { Base::TestCallSetup(); } -TEST_F(VideoChannelDoubleThreadTest, TestCallTeardownRtcpMux) { - Base::TestCallTeardownRtcpMux(); -} - TEST_F(VideoChannelDoubleThreadTest, SendRtpToRtp) { Base::SendRtpToRtp(); } @@ -2255,220 +2255,5 @@ TEST_F(VideoChannelDoubleThreadTest, SocketOptionsMergedOnSetTransport) { Base::SocketOptionsMergedOnSetTransport(); } -// RtpDataChannelSingleThreadTest -class RtpDataChannelSingleThreadTest : public ChannelTest { - public: - typedef ChannelTest Base; - RtpDataChannelSingleThreadTest() - : Base(true, kDataPacket, kRtcpReport, NetworkIsWorker::Yes) {} -}; - -// RtpDataChannelDoubleThreadTest -class RtpDataChannelDoubleThreadTest : public ChannelTest { - public: - typedef ChannelTest Base; - RtpDataChannelDoubleThreadTest() - : Base(true, kDataPacket, kRtcpReport, NetworkIsWorker::No) {} -}; - -// Override to avoid engine channel parameter. 
-template <> -std::unique_ptr ChannelTest::CreateChannel( - rtc::Thread* worker_thread, - rtc::Thread* network_thread, - std::unique_ptr ch, - webrtc::RtpTransportInternal* rtp_transport, - int flags) { - rtc::Thread* signaling_thread = rtc::Thread::Current(); - auto channel = std::make_unique( - worker_thread, network_thread, signaling_thread, std::move(ch), - cricket::CN_DATA, (flags & DTLS) != 0, webrtc::CryptoOptions(), - &ssrc_generator_); - channel->Init_w(rtp_transport); - return channel; -} - -template <> -void ChannelTest::CreateContent( - int flags, - const cricket::AudioCodec& audio_codec, - const cricket::VideoCodec& video_codec, - cricket::RtpDataContentDescription* data) { - data->AddCodec(kGoogleDataCodec); - data->set_rtcp_mux((flags & RTCP_MUX) != 0); -} - -template <> -void ChannelTest::CopyContent( - const cricket::RtpDataContentDescription& source, - cricket::RtpDataContentDescription* data) { - *data = source; -} - -template <> -bool ChannelTest::CodecMatches(const cricket::DataCodec& c1, - const cricket::DataCodec& c2) { - return c1.name == c2.name; -} - -template <> -void ChannelTest::AddLegacyStreamInContent( - uint32_t ssrc, - int flags, - cricket::RtpDataContentDescription* data) { - data->AddLegacyStream(ssrc); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestInit) { - Base::TestInit(); - EXPECT_FALSE(media_channel1_->IsStreamMuted(0)); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestDeinit) { - Base::TestDeinit(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestSetContents) { - Base::TestSetContents(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestSetContentsNullOffer) { - Base::TestSetContentsNullOffer(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestSetContentsRtcpMux) { - Base::TestSetContentsRtcpMux(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestChangeStreamParamsInContent) { - Base::TestChangeStreamParamsInContent(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestPlayoutAndSendingStates) { - Base::TestPlayoutAndSendingStates(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestMediaContentDirection) { - Base::TestMediaContentDirection(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestCallSetup) { - Base::TestCallSetup(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestCallTeardownRtcpMux) { - Base::TestCallTeardownRtcpMux(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestOnTransportReadyToSend) { - Base::TestOnTransportReadyToSend(); -} - -TEST_F(RtpDataChannelSingleThreadTest, SendRtpToRtp) { - Base::SendRtpToRtp(); -} - -TEST_F(RtpDataChannelSingleThreadTest, SendRtpToRtpOnThread) { - Base::SendRtpToRtpOnThread(); -} - -TEST_F(RtpDataChannelSingleThreadTest, SendWithWritabilityLoss) { - Base::SendWithWritabilityLoss(); -} - -TEST_F(RtpDataChannelSingleThreadTest, SocketOptionsMergedOnSetTransport) { - Base::SocketOptionsMergedOnSetTransport(); -} - -TEST_F(RtpDataChannelSingleThreadTest, TestSendData) { - CreateChannels(0, 0); - EXPECT_TRUE(SendInitiate()); - EXPECT_TRUE(SendAccept()); - - cricket::SendDataParams params; - params.ssrc = 42; - unsigned char data[] = {'f', 'o', 'o'}; - rtc::CopyOnWriteBuffer payload(data, 3); - cricket::SendDataResult result; - ASSERT_TRUE(media_channel1_->SendData(params, payload, &result)); - EXPECT_EQ(params.ssrc, media_channel1_->last_sent_data_params().ssrc); - EXPECT_EQ("foo", media_channel1_->last_sent_data()); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestInit) { - Base::TestInit(); - EXPECT_FALSE(media_channel1_->IsStreamMuted(0)); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestDeinit) { - 
Base::TestDeinit(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestSetContents) { - Base::TestSetContents(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestSetContentsNullOffer) { - Base::TestSetContentsNullOffer(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestSetContentsRtcpMux) { - Base::TestSetContentsRtcpMux(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestChangeStreamParamsInContent) { - Base::TestChangeStreamParamsInContent(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestPlayoutAndSendingStates) { - Base::TestPlayoutAndSendingStates(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestMediaContentDirection) { - Base::TestMediaContentDirection(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestCallSetup) { - Base::TestCallSetup(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestCallTeardownRtcpMux) { - Base::TestCallTeardownRtcpMux(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestOnTransportReadyToSend) { - Base::TestOnTransportReadyToSend(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, SendRtpToRtp) { - Base::SendRtpToRtp(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, SendRtpToRtpOnThread) { - Base::SendRtpToRtpOnThread(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, SendWithWritabilityLoss) { - Base::SendWithWritabilityLoss(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, SocketOptionsMergedOnSetTransport) { - Base::SocketOptionsMergedOnSetTransport(); -} - -TEST_F(RtpDataChannelDoubleThreadTest, TestSendData) { - CreateChannels(0, 0); - EXPECT_TRUE(SendInitiate()); - EXPECT_TRUE(SendAccept()); - - cricket::SendDataParams params; - params.ssrc = 42; - unsigned char data[] = {'f', 'o', 'o'}; - rtc::CopyOnWriteBuffer payload(data, 3); - cricket::SendDataResult result; - ASSERT_TRUE(media_channel1_->SendData(params, payload, &result)); - EXPECT_EQ(params.ssrc, media_channel1_->last_sent_data_params().ssrc); - EXPECT_EQ("foo", media_channel1_->last_sent_data()); -} // TODO(pthatcher): TestSetReceiver? diff --git a/pc/composite_rtp_transport.cc b/pc/composite_rtp_transport.cc deleted file mode 100644 index 641d1d0fab..0000000000 --- a/pc/composite_rtp_transport.cc +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "pc/composite_rtp_transport.h" - -#include -#include - -#include "absl/memory/memory.h" -#include "p2p/base/packet_transport_internal.h" - -namespace webrtc { - -CompositeRtpTransport::CompositeRtpTransport( - std::vector transports) - : transports_(std::move(transports)) { - RTC_DCHECK(!transports_.empty()) << "Cannot have an empty composite"; - std::vector rtp_transports; - std::vector rtcp_transports; - for (RtpTransportInternal* transport : transports_) { - RTC_DCHECK_EQ(transport->rtcp_mux_enabled(), rtcp_mux_enabled()) - << "Either all or none of the transports in a composite must enable " - "rtcp mux"; - RTC_DCHECK_EQ(transport->transport_name(), transport_name()) - << "All transports in a composite must have the same transport name"; - - transport->SignalNetworkRouteChanged.connect( - this, &CompositeRtpTransport::OnNetworkRouteChanged); - transport->SignalRtcpPacketReceived.connect( - this, &CompositeRtpTransport::OnRtcpPacketReceived); - } -} - -void CompositeRtpTransport::SetSendTransport( - RtpTransportInternal* send_transport) { - if (send_transport_ == send_transport) { - return; - } - - RTC_DCHECK(absl::c_linear_search(transports_, send_transport)) - << "Cannot set a send transport that isn't part of the composite"; - - if (send_transport_) { - send_transport_->SignalReadyToSend.disconnect(this); - send_transport_->SignalWritableState.disconnect(this); - send_transport_->SignalSentPacket.disconnect(this); - } - - send_transport_ = send_transport; - send_transport_->SignalReadyToSend.connect( - this, &CompositeRtpTransport::OnReadyToSend); - send_transport_->SignalWritableState.connect( - this, &CompositeRtpTransport::OnWritableState); - send_transport_->SignalSentPacket.connect( - this, &CompositeRtpTransport::OnSentPacket); - - SignalWritableState(send_transport_->IsWritable(/*rtcp=*/true) && - send_transport_->IsWritable(/*rtcp=*/false)); - if (send_transport_->IsReadyToSend()) { - SignalReadyToSend(true); - } -} - -void CompositeRtpTransport::RemoveTransport(RtpTransportInternal* transport) { - RTC_DCHECK(transport != send_transport_) << "Cannot remove send transport"; - - auto it = absl::c_find(transports_, transport); - if (it == transports_.end()) { - return; - } - - transport->SignalNetworkRouteChanged.disconnect(this); - transport->SignalRtcpPacketReceived.disconnect(this); - for (auto sink : rtp_demuxer_sinks_) { - transport->UnregisterRtpDemuxerSink(sink); - } - - transports_.erase(it); -} - -const std::string& CompositeRtpTransport::transport_name() const { - return transports_.front()->transport_name(); -} - -int CompositeRtpTransport::SetRtpOption(rtc::Socket::Option opt, int value) { - int result = 0; - for (auto transport : transports_) { - result |= transport->SetRtpOption(opt, value); - } - return result; -} - -int CompositeRtpTransport::SetRtcpOption(rtc::Socket::Option opt, int value) { - int result = 0; - for (auto transport : transports_) { - result |= transport->SetRtcpOption(opt, value); - } - return result; -} - -bool CompositeRtpTransport::rtcp_mux_enabled() const { - return transports_.front()->rtcp_mux_enabled(); -} - -void CompositeRtpTransport::SetRtcpMuxEnabled(bool enabled) { - for (auto transport : transports_) { - transport->SetRtcpMuxEnabled(enabled); - } -} - -bool CompositeRtpTransport::IsReadyToSend() const { - return send_transport_ && send_transport_->IsReadyToSend(); -} - -bool CompositeRtpTransport::IsWritable(bool rtcp) const { - return send_transport_ && send_transport_->IsWritable(rtcp); -} - -bool 
CompositeRtpTransport::SendRtpPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) { - if (!send_transport_) { - return false; - } - return send_transport_->SendRtpPacket(packet, options, flags); -} - -bool CompositeRtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) { - if (!send_transport_) { - return false; - } - return send_transport_->SendRtcpPacket(packet, options, flags); -} - -void CompositeRtpTransport::UpdateRtpHeaderExtensionMap( - const cricket::RtpHeaderExtensions& header_extensions) { - for (RtpTransportInternal* transport : transports_) { - transport->UpdateRtpHeaderExtensionMap(header_extensions); - } -} - -bool CompositeRtpTransport::IsSrtpActive() const { - bool active = true; - for (RtpTransportInternal* transport : transports_) { - active &= transport->IsSrtpActive(); - } - return active; -} - -bool CompositeRtpTransport::RegisterRtpDemuxerSink( - const RtpDemuxerCriteria& criteria, - RtpPacketSinkInterface* sink) { - for (RtpTransportInternal* transport : transports_) { - transport->RegisterRtpDemuxerSink(criteria, sink); - } - rtp_demuxer_sinks_.insert(sink); - return true; -} - -bool CompositeRtpTransport::UnregisterRtpDemuxerSink( - RtpPacketSinkInterface* sink) { - for (RtpTransportInternal* transport : transports_) { - transport->UnregisterRtpDemuxerSink(sink); - } - rtp_demuxer_sinks_.erase(sink); - return true; -} - -void CompositeRtpTransport::OnNetworkRouteChanged( - absl::optional route) { - SignalNetworkRouteChanged(route); -} - -void CompositeRtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - int64_t packet_time_us) { - SignalRtcpPacketReceived(packet, packet_time_us); -} - -void CompositeRtpTransport::OnWritableState(bool writable) { - SignalWritableState(writable); -} - -void CompositeRtpTransport::OnReadyToSend(bool ready_to_send) { - SignalReadyToSend(ready_to_send); -} - -void CompositeRtpTransport::OnSentPacket(const rtc::SentPacket& packet) { - SignalSentPacket(packet); -} - -} // namespace webrtc diff --git a/pc/composite_rtp_transport.h b/pc/composite_rtp_transport.h deleted file mode 100644 index 35f9382571..0000000000 --- a/pc/composite_rtp_transport.h +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef PC_COMPOSITE_RTP_TRANSPORT_H_ -#define PC_COMPOSITE_RTP_TRANSPORT_H_ - -#include -#include -#include -#include - -#include "call/rtp_demuxer.h" -#include "call/rtp_packet_sink_interface.h" -#include "pc/rtp_transport_internal.h" -#include "pc/session_description.h" -#include "rtc_base/async_packet_socket.h" -#include "rtc_base/copy_on_write_buffer.h" - -namespace webrtc { - -// Composite RTP transport capable of receiving from multiple sub-transports. -// -// CompositeRtpTransport is receive-only until the caller explicitly chooses -// which transport will be used to send and calls |SetSendTransport|. This -// choice must be made as part of the SDP negotiation process, based on receipt -// of a provisional answer. |CompositeRtpTransport| does not become writable or -// ready to send until |SetSendTransport| is called. 
-// -// When a full answer is received, the user should replace the composite -// transport with the single, chosen RTP transport, then delete the composite -// and all non-chosen transports. -class CompositeRtpTransport : public RtpTransportInternal { - public: - // Constructs a composite out of the given |transports|. |transports| must - // not be empty. All |transports| must outlive the composite. - explicit CompositeRtpTransport(std::vector transports); - - // Sets which transport will be used for sending packets. Once called, - // |IsReadyToSend|, |IsWritable|, and the associated signals will reflect the - // state of |send_tranpsort|. - void SetSendTransport(RtpTransportInternal* send_transport); - - // Removes |transport| from the composite. No-op if |transport| is null or - // not found in the composite. Removing a transport disconnects all signals - // and RTP demux sinks from that transport. The send transport may not be - // removed. - void RemoveTransport(RtpTransportInternal* transport); - - // All transports within a composite must have the same name. - const std::string& transport_name() const override; - - int SetRtpOption(rtc::Socket::Option opt, int value) override; - int SetRtcpOption(rtc::Socket::Option opt, int value) override; - - // All transports within a composite must either enable or disable RTCP mux. - bool rtcp_mux_enabled() const override; - - // Enables or disables RTCP mux for all component transports. - void SetRtcpMuxEnabled(bool enabled) override; - - // The composite is ready to send if |send_transport_| is set and ready to - // send. - bool IsReadyToSend() const override; - - // The composite is writable if |send_transport_| is set and writable. - bool IsWritable(bool rtcp) const override; - - // Sends an RTP packet. May only be called after |send_transport_| is set. - bool SendRtpPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) override; - - // Sends an RTCP packet. May only be called after |send_transport_| is set. - bool SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, - const rtc::PacketOptions& options, - int flags) override; - - // Updates the mapping of RTP header extensions for all component transports. - void UpdateRtpHeaderExtensionMap( - const cricket::RtpHeaderExtensions& header_extensions) override; - - // SRTP is only active for a composite if it is active for all component - // transports. - bool IsSrtpActive() const override; - - // Registers an RTP demux sink with all component transports. - bool RegisterRtpDemuxerSink(const RtpDemuxerCriteria& criteria, - RtpPacketSinkInterface* sink) override; - bool UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) override; - - private: - // Receive-side signals. - void OnNetworkRouteChanged(absl::optional route); - void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet, - int64_t packet_time_us); - - // Send-side signals. - void OnWritableState(bool writable); - void OnReadyToSend(bool ready_to_send); - void OnSentPacket(const rtc::SentPacket& packet); - - std::vector transports_; - RtpTransportInternal* send_transport_ = nullptr; - - // Record of registered RTP demuxer sinks. Used to unregister sinks when a - // transport is removed. 
- std::set rtp_demuxer_sinks_; -}; - -} // namespace webrtc - -#endif // PC_COMPOSITE_RTP_TRANSPORT_H_ diff --git a/pc/composite_rtp_transport_test.cc b/pc/composite_rtp_transport_test.cc deleted file mode 100644 index fee8c215b2..0000000000 --- a/pc/composite_rtp_transport_test.cc +++ /dev/null @@ -1,389 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "pc/composite_rtp_transport.h" - -#include - -#include "modules/rtp_rtcp/source/rtp_packet_received.h" -#include "p2p/base/fake_packet_transport.h" -#include "pc/rtp_transport.h" -#include "test/gtest.h" - -namespace webrtc { -namespace { - -constexpr char kTransportName[] = "test-transport"; -constexpr char kRtcpTransportName[] = "test-transport-rtcp"; -constexpr uint8_t kRtpPayloadType = 100; - -constexpr uint8_t kRtcpPacket[] = {0x80, 73, 0, 0}; -constexpr uint8_t kRtpPacket[] = {0x80, 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - -class CompositeRtpTransportTest : public ::testing::Test, - public sigslot::has_slots<>, - public RtpPacketSinkInterface { - public: - CompositeRtpTransportTest() - : packet_transport_1_( - std::make_unique(kTransportName)), - packet_transport_2_( - std::make_unique(kTransportName)), - rtcp_transport_1_( - std::make_unique(kRtcpTransportName)), - rtcp_transport_2_( - std::make_unique(kRtcpTransportName)) {} - - void SetupRtpTransports(bool rtcp_mux) { - transport_1_ = std::make_unique(rtcp_mux); - transport_2_ = std::make_unique(rtcp_mux); - - transport_1_->SetRtpPacketTransport(packet_transport_1_.get()); - transport_2_->SetRtpPacketTransport(packet_transport_2_.get()); - if (!rtcp_mux) { - transport_1_->SetRtcpPacketTransport(rtcp_transport_1_.get()); - transport_2_->SetRtcpPacketTransport(rtcp_transport_2_.get()); - } - - composite_ = std::make_unique( - std::vector{transport_1_.get(), - transport_2_.get()}); - - composite_->SignalReadyToSend.connect( - this, &CompositeRtpTransportTest::OnReadyToSend); - composite_->SignalWritableState.connect( - this, &CompositeRtpTransportTest::OnWritableState); - composite_->SignalSentPacket.connect( - this, &CompositeRtpTransportTest::OnSentPacket); - composite_->SignalNetworkRouteChanged.connect( - this, &CompositeRtpTransportTest::OnNetworkRouteChanged); - composite_->SignalRtcpPacketReceived.connect( - this, &CompositeRtpTransportTest::OnRtcpPacketReceived); - - RtpDemuxerCriteria criteria; - criteria.payload_types.insert(kRtpPayloadType); - composite_->RegisterRtpDemuxerSink(criteria, this); - } - - void TearDown() override { composite_->UnregisterRtpDemuxerSink(this); } - - void OnReadyToSend(bool ready) { ++ready_to_send_count_; } - void OnWritableState(bool writable) { ++writable_state_count_; } - void OnSentPacket(const rtc::SentPacket& packet) { ++sent_packet_count_; } - void OnNetworkRouteChanged(absl::optional route) { - ++network_route_count_; - last_network_route_ = route; - } - void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* buffer, - int64_t packet_time_us) { - ++rtcp_packet_count_; - last_packet_ = *buffer; - } - void OnRtpPacket(const RtpPacketReceived& packet) { - ++rtp_packet_count_; - last_packet_ = packet.Buffer(); - } - - protected: - std::unique_ptr packet_transport_1_; - 
std::unique_ptr packet_transport_2_; - std::unique_ptr rtcp_transport_1_; - std::unique_ptr rtcp_transport_2_; - std::unique_ptr transport_1_; - std::unique_ptr transport_2_; - std::unique_ptr composite_; - - int ready_to_send_count_ = 0; - int writable_state_count_ = 0; - int sent_packet_count_ = 0; - int network_route_count_ = 0; - int rtcp_packet_count_ = 0; - int rtp_packet_count_ = 0; - - absl::optional last_network_route_; - rtc::CopyOnWriteBuffer last_packet_; -}; - -TEST_F(CompositeRtpTransportTest, EnableRtcpMux) { - SetupRtpTransports(/*rtcp_mux=*/false); - EXPECT_FALSE(composite_->rtcp_mux_enabled()); - EXPECT_FALSE(transport_1_->rtcp_mux_enabled()); - EXPECT_FALSE(transport_2_->rtcp_mux_enabled()); - - composite_->SetRtcpMuxEnabled(true); - EXPECT_TRUE(composite_->rtcp_mux_enabled()); - EXPECT_TRUE(transport_1_->rtcp_mux_enabled()); - EXPECT_TRUE(transport_2_->rtcp_mux_enabled()); -} - -TEST_F(CompositeRtpTransportTest, DisableRtcpMux) { - SetupRtpTransports(/*rtcp_mux=*/true); - EXPECT_TRUE(composite_->rtcp_mux_enabled()); - EXPECT_TRUE(transport_1_->rtcp_mux_enabled()); - EXPECT_TRUE(transport_2_->rtcp_mux_enabled()); - - // If the component transports didn't have an RTCP transport before, they need - // to be set independently before disabling RTCP mux. There's no other sane - // way to do this, as the interface only allows sending a single RTCP - // transport, and we need one for each component. - transport_1_->SetRtcpPacketTransport(rtcp_transport_1_.get()); - transport_2_->SetRtcpPacketTransport(rtcp_transport_2_.get()); - - composite_->SetRtcpMuxEnabled(false); - EXPECT_FALSE(composite_->rtcp_mux_enabled()); - EXPECT_FALSE(transport_1_->rtcp_mux_enabled()); - EXPECT_FALSE(transport_2_->rtcp_mux_enabled()); -} - -TEST_F(CompositeRtpTransportTest, SetRtpOption) { - SetupRtpTransports(/*rtcp_mux=*/true); - EXPECT_EQ(0, composite_->SetRtpOption(rtc::Socket::OPT_DSCP, 2)); - - int value = 0; - EXPECT_TRUE(packet_transport_1_->GetOption(rtc::Socket::OPT_DSCP, &value)); - EXPECT_EQ(value, 2); - - EXPECT_TRUE(packet_transport_2_->GetOption(rtc::Socket::OPT_DSCP, &value)); - EXPECT_EQ(value, 2); -} - -TEST_F(CompositeRtpTransportTest, SetRtcpOption) { - SetupRtpTransports(/*rtcp_mux=*/false); - EXPECT_EQ(0, composite_->SetRtcpOption(rtc::Socket::OPT_DSCP, 2)); - - int value = 0; - EXPECT_TRUE(rtcp_transport_1_->GetOption(rtc::Socket::OPT_DSCP, &value)); - EXPECT_EQ(value, 2); - - EXPECT_TRUE(rtcp_transport_2_->GetOption(rtc::Socket::OPT_DSCP, &value)); - EXPECT_EQ(value, 2); -} - -TEST_F(CompositeRtpTransportTest, NeverWritableWithoutSendTransport) { - SetupRtpTransports(/*rtcp_mux=*/true); - - packet_transport_1_->SetWritable(true); - packet_transport_2_->SetWritable(true); - - EXPECT_FALSE(composite_->IsWritable(false)); - EXPECT_FALSE(composite_->IsWritable(true)); - EXPECT_FALSE(composite_->IsReadyToSend()); - EXPECT_EQ(0, ready_to_send_count_); - EXPECT_EQ(0, writable_state_count_); -} - -TEST_F(CompositeRtpTransportTest, WritableWhenSendTransportBecomesWritable) { - SetupRtpTransports(/*rtcp_mux=*/true); - - composite_->SetSendTransport(transport_1_.get()); - - EXPECT_FALSE(composite_->IsWritable(false)); - EXPECT_FALSE(composite_->IsWritable(true)); - EXPECT_FALSE(composite_->IsReadyToSend()); - EXPECT_EQ(0, ready_to_send_count_); - EXPECT_EQ(1, writable_state_count_); - - packet_transport_2_->SetWritable(true); - - EXPECT_FALSE(composite_->IsWritable(false)); - EXPECT_FALSE(composite_->IsWritable(true)); - EXPECT_FALSE(composite_->IsReadyToSend()); - EXPECT_EQ(0, 
ready_to_send_count_); - EXPECT_EQ(1, writable_state_count_); - - packet_transport_1_->SetWritable(true); - - EXPECT_TRUE(composite_->IsWritable(false)); - EXPECT_TRUE(composite_->IsWritable(true)); - EXPECT_TRUE(composite_->IsReadyToSend()); - EXPECT_EQ(1, ready_to_send_count_); - EXPECT_EQ(2, writable_state_count_); -} - -TEST_F(CompositeRtpTransportTest, SendTransportAlreadyWritable) { - SetupRtpTransports(/*rtcp_mux=*/true); - packet_transport_1_->SetWritable(true); - - composite_->SetSendTransport(transport_1_.get()); - - EXPECT_TRUE(composite_->IsWritable(false)); - EXPECT_TRUE(composite_->IsWritable(true)); - EXPECT_TRUE(composite_->IsReadyToSend()); - EXPECT_EQ(1, ready_to_send_count_); - EXPECT_EQ(1, writable_state_count_); -} - -TEST_F(CompositeRtpTransportTest, IsSrtpActive) { - SetupRtpTransports(/*rtcp_mux=*/true); - EXPECT_FALSE(composite_->IsSrtpActive()); -} - -TEST_F(CompositeRtpTransportTest, NetworkRouteChange) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::NetworkRoute route; - route.local = rtc::RouteEndpoint::CreateWithNetworkId(7); - packet_transport_1_->SetNetworkRoute(route); - - EXPECT_EQ(1, network_route_count_); - EXPECT_EQ(7, last_network_route_->local.network_id()); - - route.local = rtc::RouteEndpoint::CreateWithNetworkId(8); - packet_transport_2_->SetNetworkRoute(route); - - EXPECT_EQ(2, network_route_count_); - EXPECT_EQ(8, last_network_route_->local.network_id()); -} - -TEST_F(CompositeRtpTransportTest, RemoveTransport) { - SetupRtpTransports(/*rtcp_mux=*/true); - - composite_->RemoveTransport(transport_1_.get()); - - // Check that signals are disconnected. - rtc::NetworkRoute route; - route.local = rtc::RouteEndpoint::CreateWithNetworkId(7); - packet_transport_1_->SetNetworkRoute(route); - - EXPECT_EQ(0, network_route_count_); -} - -TEST_F(CompositeRtpTransportTest, SendRtcpBeforeSendTransportSet) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_1_.get(), false); - - rtc::CopyOnWriteBuffer packet(kRtcpPacket); - EXPECT_FALSE(composite_->SendRtcpPacket(&packet, rtc::PacketOptions(), 0)); - EXPECT_EQ(0, sent_packet_count_); -} - -TEST_F(CompositeRtpTransportTest, SendRtcpOn1) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_1_.get(), false); - composite_->SetSendTransport(transport_1_.get()); - - rtc::CopyOnWriteBuffer packet(kRtcpPacket); - EXPECT_TRUE(composite_->SendRtcpPacket(&packet, rtc::PacketOptions(), 0)); - EXPECT_EQ(1, sent_packet_count_); - EXPECT_EQ(packet, *packet_transport_1_->last_sent_packet()); -} - -TEST_F(CompositeRtpTransportTest, SendRtcpOn2) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_2_.get(), false); - composite_->SetSendTransport(transport_2_.get()); - - rtc::CopyOnWriteBuffer packet(kRtcpPacket); - EXPECT_TRUE(composite_->SendRtcpPacket(&packet, rtc::PacketOptions(), 0)); - EXPECT_EQ(1, sent_packet_count_); - EXPECT_EQ(packet, *packet_transport_2_->last_sent_packet()); -} - -TEST_F(CompositeRtpTransportTest, SendRtpBeforeSendTransportSet) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_1_.get(), false); - - rtc::CopyOnWriteBuffer packet(kRtpPacket); - EXPECT_FALSE(composite_->SendRtpPacket(&packet, rtc::PacketOptions(), 0)); - EXPECT_EQ(0, sent_packet_count_); -} - 
-TEST_F(CompositeRtpTransportTest, SendRtpOn1) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_1_.get(), false); - composite_->SetSendTransport(transport_1_.get()); - - rtc::CopyOnWriteBuffer packet(kRtpPacket); - EXPECT_TRUE(composite_->SendRtpPacket(&packet, rtc::PacketOptions(), 0)); - EXPECT_EQ(1, sent_packet_count_); - EXPECT_EQ(packet, *packet_transport_1_->last_sent_packet()); -} - -TEST_F(CompositeRtpTransportTest, SendRtpOn2) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_2_.get(), false); - composite_->SetSendTransport(transport_2_.get()); - - rtc::CopyOnWriteBuffer packet(kRtpPacket); - EXPECT_TRUE(composite_->SendRtpPacket(&packet, rtc::PacketOptions(), 0)); - EXPECT_EQ(1, sent_packet_count_); - EXPECT_EQ(packet, *packet_transport_2_->last_sent_packet()); -} - -TEST_F(CompositeRtpTransportTest, ReceiveRtcpFrom1) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_1_.get(), false); - - rtc::CopyOnWriteBuffer packet(kRtcpPacket); - remote.SendPacket(packet.cdata(), packet.size(), rtc::PacketOptions(), - 0); - - EXPECT_EQ(1, rtcp_packet_count_); - EXPECT_EQ(packet, last_packet_); -} - -TEST_F(CompositeRtpTransportTest, ReceiveRtcpFrom2) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_2_.get(), false); - - rtc::CopyOnWriteBuffer packet(kRtcpPacket); - remote.SendPacket(packet.cdata(), packet.size(), rtc::PacketOptions(), - 0); - - EXPECT_EQ(1, rtcp_packet_count_); - EXPECT_EQ(packet, last_packet_); -} - -TEST_F(CompositeRtpTransportTest, ReceiveRtpFrom1) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_1_.get(), false); - - rtc::CopyOnWriteBuffer packet(kRtpPacket); - remote.SendPacket(packet.cdata(), packet.size(), rtc::PacketOptions(), - 0); - - EXPECT_EQ(1, rtp_packet_count_); - EXPECT_EQ(packet, last_packet_); -} - -TEST_F(CompositeRtpTransportTest, ReceiveRtpFrom2) { - SetupRtpTransports(/*rtcp_mux=*/true); - - rtc::FakePacketTransport remote("remote"); - remote.SetDestination(packet_transport_2_.get(), false); - - rtc::CopyOnWriteBuffer packet(kRtpPacket); - remote.SendPacket(packet.cdata(), packet.size(), rtc::PacketOptions(), - 0); - - EXPECT_EQ(1, rtp_packet_count_); - EXPECT_EQ(packet, last_packet_); -} - -} // namespace -} // namespace webrtc diff --git a/pc/connection_context.cc b/pc/connection_context.cc index 727fbd6542..1bb7908f5c 100644 --- a/pc/connection_context.cc +++ b/pc/connection_context.cc @@ -15,9 +15,9 @@ #include #include "api/transport/field_trial_based_config.h" -#include "media/base/rtp_data_engine.h" +#include "media/sctp/sctp_transport_factory.h" #include "rtc_base/helpers.h" -#include "rtc_base/ref_counted_object.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/time_utils.h" namespace webrtc { @@ -63,7 +63,7 @@ std::unique_ptr MaybeCreateSctpFactory( if (factory) { return factory; } -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP return std::make_unique(network_thread); #else return nullptr; @@ -75,11 +75,7 @@ std::unique_ptr MaybeCreateSctpFactory( // Static rtc::scoped_refptr ConnectionContext::Create( PeerConnectionFactoryDependencies* dependencies) { - auto context = new rtc::RefCountedObject(dependencies); - if 
(!context->channel_manager_->Init()) { - return nullptr; - } - return context; + return new ConnectionContext(dependencies); } ConnectionContext::ConnectionContext( @@ -97,7 +93,6 @@ ConnectionContext::ConnectionContext( network_monitor_factory_( std::move(dependencies->network_monitor_factory)), call_factory_(std::move(dependencies->call_factory)), - media_engine_(std::move(dependencies->media_engine)), sctp_factory_( MaybeCreateSctpFactory(std::move(dependencies->sctp_factory), network_thread())), @@ -107,7 +102,16 @@ ConnectionContext::ConnectionContext( signaling_thread_->AllowInvokesToThread(worker_thread_); signaling_thread_->AllowInvokesToThread(network_thread_); worker_thread_->AllowInvokesToThread(network_thread_); - network_thread_->DisallowAllInvokes(); + if (network_thread_->IsCurrent()) { + // TODO(https://crbug.com/webrtc/12802) switch to DisallowAllInvokes + network_thread_->AllowInvokesToThread(network_thread_); + } else { + network_thread_->PostTask(ToQueuedTask([thread = network_thread_] { + thread->DisallowBlockingCalls(); + // TODO(https://crbug.com/webrtc/12802) switch to DisallowAllInvokes + thread->AllowInvokesToThread(thread); + })); + } RTC_DCHECK_RUN_ON(signaling_thread_); rtc::InitRandom(rtc::Time32()); @@ -120,16 +124,26 @@ ConnectionContext::ConnectionContext( default_socket_factory_ = std::make_unique(network_thread()); - channel_manager_ = std::make_unique( - std::move(media_engine_), std::make_unique(), - worker_thread(), network_thread()); - - channel_manager_->SetVideoRtxEnabled(true); + worker_thread_->Invoke(RTC_FROM_HERE, [&]() { + channel_manager_ = cricket::ChannelManager::Create( + std::move(dependencies->media_engine), + /*enable_rtx=*/true, worker_thread(), network_thread()); + }); + + // Set warning levels on the threads, to give warnings when response + // may be slower than is expected of the thread. + // Since some of the threads may be the same, start with the least + // restrictive limits and end with the least permissive ones. + // This will give warnings for all cases. + signaling_thread_->SetDispatchWarningMs(100); + worker_thread_->SetDispatchWarningMs(30); + network_thread_->SetDispatchWarningMs(10); } ConnectionContext::~ConnectionContext() { RTC_DCHECK_RUN_ON(signaling_thread_); - channel_manager_.reset(nullptr); + worker_thread_->Invoke(RTC_FROM_HERE, + [&]() { channel_manager_.reset(nullptr); }); // Make sure |worker_thread()| and |signaling_thread()| outlive // |default_socket_factory_| and |default_network_manager_|. 
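The new constructor code above does three things: it posts the invocation restrictions onto the network thread itself when that thread is not the current one, it creates the channel manager with a blocking Invoke on the worker thread, and it gives each thread a dispatch-warning budget (100 ms signaling, 30 ms worker, 10 ms network) so that slow task dispatch is surfaced in logs. The sketch below is purely illustrative and assumes nothing about rtc::Thread's internals; it only shows the general idea behind such a warning threshold: timestamp a task when it is posted and complain if it waited longer than the budget before running.

// Illustrative sketch only: a toy single-threaded task runner that warns when
// a posted task waits longer than a configured threshold. Not rtc::Thread.
#include <chrono>
#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

class WarningTaskRunner {
 public:
  explicit WarningTaskRunner(std::chrono::milliseconds warn_after)
      : warn_after_(warn_after), worker_([this] { Run(); }) {}
  ~WarningTaskRunner() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }
  void PostTask(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back({std::chrono::steady_clock::now(), std::move(task)});
    }
    cv_.notify_one();
  }

 private:
  struct Pending {
    std::chrono::steady_clock::time_point posted;
    std::function<void()> task;
  };
  void Run() {
    for (;;) {
      Pending next;
      {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return done_ || !queue_.empty(); });
        if (queue_.empty()) return;  // done_ is set and nothing is left to run.
        next = std::move(queue_.front());
        queue_.pop_front();
      }
      // Warn if the task sat in the queue longer than the configured budget.
      const auto waited = std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::steady_clock::now() - next.posted);
      if (waited > warn_after_) {
        std::cerr << "Dispatch took " << waited.count() << " ms\n";
      }
      next.task();
    }
  }
  const std::chrono::milliseconds warn_after_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<Pending> queue_;
  bool done_ = false;
  std::thread worker_;  // Declared last so Run() sees initialized members.
};

int main() {
  WarningTaskRunner network_thread(std::chrono::milliseconds(10));
  network_thread.PostTask([] { std::cout << "runs on the worker\n"; });
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  return 0;
}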
diff --git a/pc/connection_context.h b/pc/connection_context.h
index 02d08a191e..8fad13c10c 100644
--- a/pc/connection_context.h
+++ b/pc/connection_context.h
@@ -17,19 +17,18 @@
 #include "api/call/call_factory_interface.h"
 #include "api/media_stream_interface.h"
 #include "api/peer_connection_interface.h"
+#include "api/ref_counted_base.h"
 #include "api/scoped_refptr.h"
+#include "api/sequence_checker.h"
 #include "api/transport/sctp_transport_factory_interface.h"
 #include "api/transport/webrtc_key_value_config.h"
 #include "media/base/media_engine.h"
-#include "media/sctp/sctp_transport_internal.h"
 #include "p2p/base/basic_packet_socket_factory.h"
 #include "pc/channel_manager.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/network.h"
 #include "rtc_base/network_monitor_factory.h"
-#include "rtc_base/ref_count.h"
 #include "rtc_base/rtc_certificate_generator.h"
-#include "rtc_base/synchronization/sequence_checker.h"
 #include "rtc_base/thread.h"
 #include "rtc_base/thread_annotations.h"
@@ -48,7 +47,8 @@ class RtcEventLog;
 // interferes with the operation of other PeerConnections.
 //
 // This class must be created and destroyed on the signaling thread.
-class ConnectionContext : public rtc::RefCountInterface {
+class ConnectionContext final
+    : public rtc::RefCountedNonVirtual<ConnectionContext> {
  public:
   // Creates a ConnectionContext. May return null if initialization fails.
   // The Dependencies class allows simple management of all new dependencies
@@ -62,7 +62,6 @@ class ConnectionContext : public rtc::RefCountInterface {
   // Functions called from PeerConnection and friends
   SctpTransportFactoryInterface* sctp_transport_factory() const {
-    RTC_DCHECK_RUN_ON(signaling_thread_);
     return sctp_factory_.get();
   }
@@ -94,7 +93,8 @@ class ConnectionContext : public rtc::RefCountInterface {
  protected:
   explicit ConnectionContext(PeerConnectionFactoryDependencies* dependencies);
-  virtual ~ConnectionContext();
+  friend class rtc::RefCountedNonVirtual<ConnectionContext>;
+  ~ConnectionContext();

  private:
  // The following three variables are used to communicate between the
@@ -121,10 +121,7 @@ class ConnectionContext : public rtc::RefCountInterface {
   std::unique_ptr<rtc::BasicPacketSocketFactory> default_socket_factory_
       RTC_GUARDED_BY(signaling_thread_);
-  std::unique_ptr<cricket::MediaEngineInterface> media_engine_
-      RTC_GUARDED_BY(signaling_thread_);
-  std::unique_ptr<SctpTransportFactoryInterface> const sctp_factory_
-      RTC_GUARDED_BY(signaling_thread_);
+  std::unique_ptr<SctpTransportFactoryInterface> const sctp_factory_;

   // Accessed both on signaling thread and worker thread.
std::unique_ptr const trials_; }; diff --git a/pc/data_channel_controller.cc b/pc/data_channel_controller.cc index 9fabe13cc7..7a6fd3c168 100644 --- a/pc/data_channel_controller.cc +++ b/pc/data_channel_controller.cc @@ -10,57 +10,36 @@ #include "pc/data_channel_controller.h" +#include #include +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/peer_connection_interface.h" +#include "api/rtc_error.h" #include "pc/peer_connection.h" #include "pc/sctp_utils.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" +#include "rtc_base/task_utils/to_queued_task.h" namespace webrtc { bool DataChannelController::HasDataChannels() const { RTC_DCHECK_RUN_ON(signaling_thread()); - return !rtp_data_channels_.empty() || !sctp_data_channels_.empty(); + return !sctp_data_channels_.empty(); } -bool DataChannelController::SendData(const cricket::SendDataParams& params, +bool DataChannelController::SendData(int sid, + const SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) { if (data_channel_transport()) - return DataChannelSendData(params, payload, result); - if (rtp_data_channel()) - return rtp_data_channel()->SendData(params, payload, result); + return DataChannelSendData(sid, params, payload, result); RTC_LOG(LS_ERROR) << "SendData called before transport is ready"; return false; } -bool DataChannelController::ConnectDataChannel( - RtpDataChannel* webrtc_data_channel) { - RTC_DCHECK_RUN_ON(signaling_thread()); - if (!rtp_data_channel()) { - // Don't log an error here, because DataChannels are expected to call - // ConnectDataChannel in this state. It's the only way to initially tell - // whether or not the underlying transport is ready. - return false; - } - rtp_data_channel()->SignalReadyToSendData.connect( - webrtc_data_channel, &RtpDataChannel::OnChannelReady); - rtp_data_channel()->SignalDataReceived.connect( - webrtc_data_channel, &RtpDataChannel::OnDataReceived); - return true; -} - -void DataChannelController::DisconnectDataChannel( - RtpDataChannel* webrtc_data_channel) { - RTC_DCHECK_RUN_ON(signaling_thread()); - if (!rtp_data_channel()) { - RTC_LOG(LS_ERROR) - << "DisconnectDataChannel called when rtp_data_channel_ is NULL."; - return; - } - rtp_data_channel()->SignalReadyToSendData.disconnect(webrtc_data_channel); - rtp_data_channel()->SignalDataReceived.disconnect(webrtc_data_channel); -} - bool DataChannelController::ConnectDataChannel( SctpDataChannel* webrtc_data_channel) { RTC_DCHECK_RUN_ON(signaling_thread()); @@ -117,8 +96,7 @@ void DataChannelController::RemoveSctpDataStream(int sid) { bool DataChannelController::ReadyToSendData() const { RTC_DCHECK_RUN_ON(signaling_thread()); - return (rtp_data_channel() && rtp_data_channel()->ready_to_send_data()) || - (data_channel_transport() && data_channel_transport_ready_to_send_); + return (data_channel_transport() && data_channel_transport_ready_to_send_); } void DataChannelController::OnDataReceived( @@ -128,60 +106,70 @@ void DataChannelController::OnDataReceived( RTC_DCHECK_RUN_ON(network_thread()); cricket::ReceiveDataParams params; params.sid = channel_id; - params.type = ToCricketDataMessageType(type); - data_channel_transport_invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread(), [this, params, buffer] { - RTC_DCHECK_RUN_ON(signaling_thread()); - // TODO(bugs.webrtc.org/11547): The data being received should be - // delivered on the network thread. 
The way HandleOpenMessage_s works - // right now is that it's called for all types of buffers and operates - // as a selector function. Change this so that it's only called for - // buffers that it should be able to handle. Once we do that, we can - // deliver all other buffers on the network thread (change - // SignalDataChannelTransportReceivedData_s to - // SignalDataChannelTransportReceivedData_n). - if (!HandleOpenMessage_s(params, buffer)) { - SignalDataChannelTransportReceivedData_s(params, buffer); + params.type = type; + signaling_thread()->PostTask( + ToQueuedTask([self = weak_factory_.GetWeakPtr(), params, buffer] { + if (self) { + RTC_DCHECK_RUN_ON(self->signaling_thread()); + // TODO(bugs.webrtc.org/11547): The data being received should be + // delivered on the network thread. The way HandleOpenMessage_s works + // right now is that it's called for all types of buffers and operates + // as a selector function. Change this so that it's only called for + // buffers that it should be able to handle. Once we do that, we can + // deliver all other buffers on the network thread (change + // SignalDataChannelTransportReceivedData_s to + // SignalDataChannelTransportReceivedData_n). + if (!self->HandleOpenMessage_s(params, buffer)) { + self->SignalDataChannelTransportReceivedData_s(params, buffer); + } } - }); + })); } void DataChannelController::OnChannelClosing(int channel_id) { RTC_DCHECK_RUN_ON(network_thread()); - data_channel_transport_invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread(), [this, channel_id] { - RTC_DCHECK_RUN_ON(signaling_thread()); - SignalDataChannelTransportChannelClosing_s(channel_id); - }); + signaling_thread()->PostTask( + ToQueuedTask([self = weak_factory_.GetWeakPtr(), channel_id] { + if (self) { + RTC_DCHECK_RUN_ON(self->signaling_thread()); + self->SignalDataChannelTransportChannelClosing_s(channel_id); + } + })); } void DataChannelController::OnChannelClosed(int channel_id) { RTC_DCHECK_RUN_ON(network_thread()); - data_channel_transport_invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread(), [this, channel_id] { - RTC_DCHECK_RUN_ON(signaling_thread()); - SignalDataChannelTransportChannelClosed_s(channel_id); - }); + signaling_thread()->PostTask( + ToQueuedTask([self = weak_factory_.GetWeakPtr(), channel_id] { + if (self) { + RTC_DCHECK_RUN_ON(self->signaling_thread()); + self->SignalDataChannelTransportChannelClosed_s(channel_id); + } + })); } void DataChannelController::OnReadyToSend() { RTC_DCHECK_RUN_ON(network_thread()); - data_channel_transport_invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread(), [this] { - RTC_DCHECK_RUN_ON(signaling_thread()); - data_channel_transport_ready_to_send_ = true; - SignalDataChannelTransportWritable_s( - data_channel_transport_ready_to_send_); - }); + signaling_thread()->PostTask( + ToQueuedTask([self = weak_factory_.GetWeakPtr()] { + if (self) { + RTC_DCHECK_RUN_ON(self->signaling_thread()); + self->data_channel_transport_ready_to_send_ = true; + self->SignalDataChannelTransportWritable_s( + self->data_channel_transport_ready_to_send_); + } + })); } -void DataChannelController::OnTransportClosed() { +void DataChannelController::OnTransportClosed(RTCError error) { RTC_DCHECK_RUN_ON(network_thread()); - data_channel_transport_invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread(), [this] { - RTC_DCHECK_RUN_ON(signaling_thread()); - OnTransportChannelClosed(); - }); + signaling_thread()->PostTask( + ToQueuedTask([self = weak_factory_.GetWeakPtr(), error] { + if (self) { + 
RTC_DCHECK_RUN_ON(self->signaling_thread()); + self->OnTransportChannelClosed(error); + } + })); } void DataChannelController::SetupDataChannelTransport_n() { @@ -234,15 +222,15 @@ std::vector DataChannelController::GetDataChannelStats() bool DataChannelController::HandleOpenMessage_s( const cricket::ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& buffer) { - if (params.type == cricket::DMT_CONTROL && IsOpenMessage(buffer)) { + if (params.type == DataMessageType::kControl && IsOpenMessage(buffer)) { // Received OPEN message; parse and signal that a new data channel should // be created. std::string label; InternalDataChannelInit config; - config.id = params.ssrc; + config.id = params.sid; if (!ParseDataChannelOpenMessage(buffer, &label, &config)) { - RTC_LOG(LS_WARNING) << "Failed to parse the OPEN message for ssrc " - << params.ssrc; + RTC_LOG(LS_WARNING) << "Failed to parse the OPEN message for sid " + << params.sid; return true; } config.open_handshake_role = InternalDataChannelInit::kAcker; @@ -274,49 +262,16 @@ DataChannelController::InternalCreateDataChannelWithProxy( if (pc_->IsClosed()) { return nullptr; } - if (data_channel_type_ == cricket::DCT_NONE) { - RTC_LOG(LS_ERROR) - << "InternalCreateDataChannel: Data is not supported in this call."; - return nullptr; - } - if (IsSctpLike(data_channel_type())) { - rtc::scoped_refptr channel = - InternalCreateSctpDataChannel(label, config); - if (channel) { - return SctpDataChannel::CreateProxy(channel); - } - } else if (data_channel_type() == cricket::DCT_RTP) { - rtc::scoped_refptr channel = - InternalCreateRtpDataChannel(label, config); - if (channel) { - return RtpDataChannel::CreateProxy(channel); - } + + rtc::scoped_refptr channel = + InternalCreateSctpDataChannel(label, config); + if (channel) { + return SctpDataChannel::CreateProxy(channel); } return nullptr; } -rtc::scoped_refptr -DataChannelController::InternalCreateRtpDataChannel( - const std::string& label, - const DataChannelInit* config) { - RTC_DCHECK_RUN_ON(signaling_thread()); - DataChannelInit new_config = config ? (*config) : DataChannelInit(); - rtc::scoped_refptr channel( - RtpDataChannel::Create(this, label, new_config, signaling_thread())); - if (!channel) { - return nullptr; - } - if (rtp_data_channels_.find(channel->label()) != rtp_data_channels_.end()) { - RTC_LOG(LS_ERROR) << "DataChannel with label " << channel->label() - << " already exists."; - return nullptr; - } - rtp_data_channels_[channel->label()] = channel; - SignalRtpDataChannelCreated_(channel.get()); - return channel; -} - rtc::scoped_refptr DataChannelController::InternalCreateSctpDataChannel( const std::string& label, @@ -384,31 +339,25 @@ void DataChannelController::OnSctpDataChannelClosed(SctpDataChannel* channel) { sctp_data_channels_to_free_.push_back(*it); sctp_data_channels_.erase(it); signaling_thread()->PostTask( - RTC_FROM_HERE, [self = weak_factory_.GetWeakPtr()] { + ToQueuedTask([self = weak_factory_.GetWeakPtr()] { if (self) { RTC_DCHECK_RUN_ON(self->signaling_thread()); self->sctp_data_channels_to_free_.clear(); } - }); + })); return; } } } -void DataChannelController::OnTransportChannelClosed() { +void DataChannelController::OnTransportChannelClosed(RTCError error) { RTC_DCHECK_RUN_ON(signaling_thread()); - // Use a temporary copy of the RTP/SCTP DataChannel list because the + // Use a temporary copy of the SCTP DataChannel list because the // DataChannel may callback to us and try to modify the list. 
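The callbacks rewritten above all follow the same shape: they fire on the network thread, capture `weak_factory_.GetWeakPtr()`, and post a task to the signaling thread that proceeds only if the controller still exists, replacing the old AsyncInvoker plumbing. A minimal sketch of that lifetime-guarded hand-off is shown below, using std::weak_ptr and a plain task list instead of WebRTC's WeakPtrFactory/ToQueuedTask helpers; the Controller class and queue here are hypothetical stand-ins, not WebRTC types.

// Illustrative sketch only: a cross-thread callback that bails out when the
// target object has already been destroyed, mirroring the `if (self)` checks.
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for "post this closure to the signaling thread"; tasks are simply
// collected so the example stays single-threaded and deterministic.
using TaskQueue = std::vector<std::function<void()>>;

class Controller : public std::enable_shared_from_this<Controller> {
 public:
  // Called on the "network thread": don't touch signaling-thread state here,
  // only post a task that re-checks liveness on the other side.
  void OnChannelClosed(int channel_id, TaskQueue& signaling_queue) {
    std::weak_ptr<Controller> weak_self = weak_from_this();
    signaling_queue.push_back([weak_self, channel_id] {
      if (auto self = weak_self.lock()) {
        self->HandleChannelClosedOnSignalingThread(channel_id);
      }
      // If the controller is gone, the task silently becomes a no-op.
    });
  }

 private:
  void HandleChannelClosedOnSignalingThread(int channel_id) {
    std::cout << "channel " << channel_id << " closed\n";
  }
};

int main() {
  TaskQueue signaling_queue;
  auto controller = std::make_shared<Controller>();
  controller->OnChannelClosed(7, signaling_queue);
  controller.reset();  // Destroyed before the queue drains.
  for (auto& task : signaling_queue) task();  // Prints nothing: weak lock fails.
  return 0;
}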
- std::map> temp_rtp_dcs; - temp_rtp_dcs.swap(rtp_data_channels_); - for (const auto& kv : temp_rtp_dcs) { - kv.second->OnTransportChannelClosed(); - } - std::vector> temp_sctp_dcs; temp_sctp_dcs.swap(sctp_data_channels_); for (const auto& channel : temp_sctp_dcs) { - channel->OnTransportChannelClosed(); + channel->OnTransportChannelClosed(error); } } @@ -422,70 +371,6 @@ SctpDataChannel* DataChannelController::FindDataChannelBySid(int sid) const { return nullptr; } -void DataChannelController::UpdateLocalRtpDataChannels( - const cricket::StreamParamsVec& streams) { - std::vector existing_channels; - - RTC_DCHECK_RUN_ON(signaling_thread()); - // Find new and active data channels. - for (const cricket::StreamParams& params : streams) { - // |it->sync_label| is actually the data channel label. The reason is that - // we use the same naming of data channels as we do for - // MediaStreams and Tracks. - // For MediaStreams, the sync_label is the MediaStream label and the - // track label is the same as |streamid|. - const std::string& channel_label = params.first_stream_id(); - auto data_channel_it = rtp_data_channels()->find(channel_label); - if (data_channel_it == rtp_data_channels()->end()) { - RTC_LOG(LS_ERROR) << "channel label not found"; - continue; - } - // Set the SSRC the data channel should use for sending. - data_channel_it->second->SetSendSsrc(params.first_ssrc()); - existing_channels.push_back(data_channel_it->first); - } - - UpdateClosingRtpDataChannels(existing_channels, true); -} - -void DataChannelController::UpdateRemoteRtpDataChannels( - const cricket::StreamParamsVec& streams) { - RTC_DCHECK_RUN_ON(signaling_thread()); - - std::vector existing_channels; - - // Find new and active data channels. - for (const cricket::StreamParams& params : streams) { - // The data channel label is either the mslabel or the SSRC if the mslabel - // does not exist. Ex a=ssrc:444330170 mslabel:test1. - std::string label = params.first_stream_id().empty() - ? rtc::ToString(params.first_ssrc()) - : params.first_stream_id(); - auto data_channel_it = rtp_data_channels()->find(label); - if (data_channel_it == rtp_data_channels()->end()) { - // This is a new data channel. - CreateRemoteRtpDataChannel(label, params.first_ssrc()); - } else { - data_channel_it->second->SetReceiveSsrc(params.first_ssrc()); - } - existing_channels.push_back(label); - } - - UpdateClosingRtpDataChannels(existing_channels, false); -} - -cricket::DataChannelType DataChannelController::data_channel_type() const { - // TODO(bugs.webrtc.org/9987): Should be restricted to the signaling thread. 
- // RTC_DCHECK_RUN_ON(signaling_thread()); - return data_channel_type_; -} - -void DataChannelController::set_data_channel_type( - cricket::DataChannelType type) { - RTC_DCHECK_RUN_ON(signaling_thread()); - data_channel_type_ = type; -} - DataChannelTransportInterface* DataChannelController::data_channel_transport() const { // TODO(bugs.webrtc.org/11547): Only allow this accessor to be called on the @@ -500,58 +385,9 @@ void DataChannelController::set_data_channel_transport( data_channel_transport_ = transport; } -const std::map>* -DataChannelController::rtp_data_channels() const { - RTC_DCHECK_RUN_ON(signaling_thread()); - return &rtp_data_channels_; -} - -void DataChannelController::UpdateClosingRtpDataChannels( - const std::vector& active_channels, - bool is_local_update) { - auto it = rtp_data_channels_.begin(); - while (it != rtp_data_channels_.end()) { - RtpDataChannel* data_channel = it->second; - if (absl::c_linear_search(active_channels, data_channel->label())) { - ++it; - continue; - } - - if (is_local_update) { - data_channel->SetSendSsrc(0); - } else { - data_channel->RemotePeerRequestClose(); - } - - if (data_channel->state() == RtpDataChannel::kClosed) { - rtp_data_channels_.erase(it); - it = rtp_data_channels_.begin(); - } else { - ++it; - } - } -} - -void DataChannelController::CreateRemoteRtpDataChannel(const std::string& label, - uint32_t remote_ssrc) { - if (data_channel_type() != cricket::DCT_RTP) { - return; - } - rtc::scoped_refptr channel( - InternalCreateRtpDataChannel(label, nullptr)); - if (!channel.get()) { - RTC_LOG(LS_WARNING) << "Remote peer requested a DataChannel but" - "CreateDataChannel failed."; - return; - } - channel->SetReceiveSsrc(remote_ssrc); - rtc::scoped_refptr proxy_channel = - RtpDataChannel::CreateProxy(std::move(channel)); - pc_->Observer()->OnDataChannel(std::move(proxy_channel)); -} - bool DataChannelController::DataChannelSendData( - const cricket::SendDataParams& params, + int sid, + const SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) { // TODO(bugs.webrtc.org/11547): Expect method to be called on the network @@ -560,19 +396,9 @@ bool DataChannelController::DataChannelSendData( RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(data_channel_transport()); - SendDataParams send_params; - send_params.type = ToWebrtcDataMessageType(params.type); - send_params.ordered = params.ordered; - if (params.max_rtx_count >= 0) { - send_params.max_rtx_count = params.max_rtx_count; - } else if (params.max_rtx_ms >= 0) { - send_params.max_rtx_ms = params.max_rtx_ms; - } - RTCError error = network_thread()->Invoke( - RTC_FROM_HERE, [this, params, send_params, payload] { - return data_channel_transport()->SendData(params.sid, send_params, - payload); + RTC_FROM_HERE, [this, sid, params, payload] { + return data_channel_transport()->SendData(sid, params, payload); }); if (error.ok()) { @@ -590,13 +416,15 @@ bool DataChannelController::DataChannelSendData( void DataChannelController::NotifyDataChannelsOfTransportCreated() { RTC_DCHECK_RUN_ON(network_thread()); - data_channel_transport_invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread(), [this] { - RTC_DCHECK_RUN_ON(signaling_thread()); - for (const auto& channel : sctp_data_channels_) { - channel->OnTransportChannelCreated(); + signaling_thread()->PostTask( + ToQueuedTask([self = weak_factory_.GetWeakPtr()] { + if (self) { + RTC_DCHECK_RUN_ON(self->signaling_thread()); + for (const auto& channel : self->sctp_data_channels_) { + 
channel->OnTransportChannelCreated(); + } } - }); + })); } rtc::Thread* DataChannelController::network_thread() const { diff --git a/pc/data_channel_controller.h b/pc/data_channel_controller.h index 6759288825..7b1ff26690 100644 --- a/pc/data_channel_controller.h +++ b/pc/data_channel_controller.h @@ -11,22 +11,36 @@ #ifndef PC_DATA_CHANNEL_CONTROLLER_H_ #define PC_DATA_CHANNEL_CONTROLLER_H_ +#include + #include #include #include #include +#include "api/data_channel_interface.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" +#include "api/transport/data_channel_transport_interface.h" +#include "media/base/media_channel.h" +#include "media/base/media_engine.h" +#include "media/base/stream_params.h" #include "pc/channel.h" -#include "pc/rtp_data_channel.h" +#include "pc/data_channel_utils.h" #include "pc/sctp_data_channel.h" +#include "rtc_base/checks.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" #include "rtc_base/weak_ptr.h" namespace webrtc { class PeerConnection; -class DataChannelController : public RtpDataChannelProviderInterface, - public SctpDataChannelProviderInterface, +class DataChannelController : public SctpDataChannelProviderInterface, public DataChannelSink { public: explicit DataChannelController(PeerConnection* pc) : pc_(pc) {} @@ -37,13 +51,12 @@ class DataChannelController : public RtpDataChannelProviderInterface, DataChannelController(DataChannelController&&) = delete; DataChannelController& operator=(DataChannelController&& other) = delete; - // Implements RtpDataChannelProviderInterface/ + // Implements // SctpDataChannelProviderInterface. - bool SendData(const cricket::SendDataParams& params, + bool SendData(int sid, + const SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) override; - bool ConnectDataChannel(RtpDataChannel* webrtc_data_channel) override; - void DisconnectDataChannel(RtpDataChannel* webrtc_data_channel) override; bool ConnectDataChannel(SctpDataChannel* webrtc_data_channel) override; void DisconnectDataChannel(SctpDataChannel* webrtc_data_channel) override; void AddSctpDataStream(int sid) override; @@ -57,7 +70,7 @@ class DataChannelController : public RtpDataChannelProviderInterface, void OnChannelClosing(int channel_id) override; void OnChannelClosed(int channel_id) override; void OnReadyToSend() override; - void OnTransportClosed() override; + void OnTransportClosed(RTCError error) override; // Called from PeerConnection::SetupDataChannelTransport_n void SetupDataChannelTransport_n(); @@ -88,46 +101,21 @@ class DataChannelController : public RtpDataChannelProviderInterface, RTC_DCHECK_RUN_ON(signaling_thread()); return !sctp_data_channels_.empty(); } - bool HasRtpDataChannels() const { - RTC_DCHECK_RUN_ON(signaling_thread()); - return !rtp_data_channels_.empty(); - } - - void UpdateLocalRtpDataChannels(const cricket::StreamParamsVec& streams); - void UpdateRemoteRtpDataChannels(const cricket::StreamParamsVec& streams); // Accessors - cricket::DataChannelType data_channel_type() const; - void set_data_channel_type(cricket::DataChannelType type); - cricket::RtpDataChannel* rtp_data_channel() const { - return rtp_data_channel_; - } - void set_rtp_data_channel(cricket::RtpDataChannel* channel) { - rtp_data_channel_ = channel; - } DataChannelTransportInterface* data_channel_transport() const; void 
set_data_channel_transport(DataChannelTransportInterface* transport); - const std::map>* - rtp_data_channels() const; - sigslot::signal1& SignalRtpDataChannelCreated() { - RTC_DCHECK_RUN_ON(signaling_thread()); - return SignalRtpDataChannelCreated_; - } sigslot::signal1& SignalSctpDataChannelCreated() { RTC_DCHECK_RUN_ON(signaling_thread()); return SignalSctpDataChannelCreated_; } // Called when the transport for the data channels is closed or destroyed. - void OnTransportChannelClosed(); + void OnTransportChannelClosed(RTCError error); void OnSctpDataChannelClosed(SctpDataChannel* channel); private: - rtc::scoped_refptr InternalCreateRtpDataChannel( - const std::string& label, - const DataChannelInit* config) /* RTC_RUN_ON(signaling_thread()) */; - rtc::scoped_refptr InternalCreateSctpDataChannel( const std::string& label, const InternalDataChannelInit* @@ -143,16 +131,9 @@ class DataChannelController : public RtpDataChannelProviderInterface, const InternalDataChannelInit& config) RTC_RUN_ON(signaling_thread()); - void CreateRemoteRtpDataChannel(const std::string& label, - uint32_t remote_ssrc) - RTC_RUN_ON(signaling_thread()); - - void UpdateClosingRtpDataChannels( - const std::vector& active_channels, - bool is_local_update) RTC_RUN_ON(signaling_thread()); - // Called from SendData when data_channel_transport() is true. - bool DataChannelSendData(const cricket::SendDataParams& params, + bool DataChannelSendData(int sid, + const SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result); @@ -163,17 +144,6 @@ class DataChannelController : public RtpDataChannelProviderInterface, rtc::Thread* network_thread() const; rtc::Thread* signaling_thread() const; - // Specifies which kind of data channel is allowed. This is controlled - // by the chrome command-line flag and constraints: - // 1. If chrome command-line switch 'enable-sctp-data-channels' is enabled, - // constraint kEnableDtlsSrtp is true, and constaint kEnableRtpDataChannels is - // not set or false, SCTP is allowed (DCT_SCTP); - // 2. If constraint kEnableRtpDataChannels is true, RTP is allowed (DCT_RTP); - // 3. If both 1&2 are false, data channel is not allowed (DCT_NONE). - cricket::DataChannelType data_channel_type_ = - cricket::DCT_NONE; // TODO(bugs.webrtc.org/9987): Accessed on both - // signaling and network thread. - // Plugin transport used for data channels. Pointer may be accessed and // checked from any thread, but the object may only be touched on the // network thread. @@ -185,22 +155,12 @@ class DataChannelController : public RtpDataChannelProviderInterface, bool data_channel_transport_ready_to_send_ RTC_GUARDED_BY(signaling_thread()) = false; - // |rtp_data_channel_| is used if in RTP data channel mode, - // |data_channel_transport_| when using SCTP. - cricket::RtpDataChannel* rtp_data_channel_ = nullptr; - // TODO(bugs.webrtc.org/9987): Accessed on both - // signaling and some other thread. - SctpSidAllocator sid_allocator_ /* RTC_GUARDED_BY(signaling_thread()) */; std::vector> sctp_data_channels_ RTC_GUARDED_BY(signaling_thread()); std::vector> sctp_data_channels_to_free_ RTC_GUARDED_BY(signaling_thread()); - // Map of label -> DataChannel - std::map> rtp_data_channels_ - RTC_GUARDED_BY(signaling_thread()); - // Signals from |data_channel_transport_|. These are invoked on the // signaling thread. 
// TODO(bugs.webrtc.org/11547): These '_s' signals likely all belong on the @@ -216,18 +176,13 @@ class DataChannelController : public RtpDataChannelProviderInterface, sigslot::signal1 SignalDataChannelTransportChannelClosed_s RTC_GUARDED_BY(signaling_thread()); - sigslot::signal1 SignalRtpDataChannelCreated_ - RTC_GUARDED_BY(signaling_thread()); sigslot::signal1 SignalSctpDataChannelCreated_ RTC_GUARDED_BY(signaling_thread()); - // Used from the network thread to invoke data channel transport signals on - // the signaling thread. - rtc::AsyncInvoker data_channel_transport_invoker_ - RTC_GUARDED_BY(network_thread()); - // Owning PeerConnection. PeerConnection* const pc_; + // The weak pointers must be dereferenced and invalidated on the signalling + // thread only. rtc::WeakPtrFactory weak_factory_{this}; }; diff --git a/pc/data_channel_integrationtest.cc b/pc/data_channel_integrationtest.cc new file mode 100644 index 0000000000..47ea74a4b2 --- /dev/null +++ b/pc/data_channel_integrationtest.cc @@ -0,0 +1,845 @@ +/* + * Copyright 2012 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "api/data_channel_interface.h" +#include "api/dtmf_sender_interface.h" +#include "api/peer_connection_interface.h" +#include "api/scoped_refptr.h" +#include "api/units/time_delta.h" +#include "pc/test/integration_test_helpers.h" +#include "pc/test/mock_peer_connection_observers.h" +#include "rtc_base/fake_clock.h" +#include "rtc_base/gunit.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/virtual_socket_server.h" +#include "system_wrappers/include/field_trial.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { + +// All tests in this file require SCTP support. +#ifdef WEBRTC_HAVE_SCTP + +class DataChannelIntegrationTest : public PeerConnectionIntegrationBaseTest, + public ::testing::WithParamInterface< + std::tuple> { + protected: + DataChannelIntegrationTest() + : PeerConnectionIntegrationBaseTest(std::get<0>(GetParam()), + std::get<1>(GetParam())) {} +}; + +// Fake clock must be set before threads are started to prevent race on +// Set/GetClockForTesting(). +// To achieve that, multiple inheritance is used as a mixin pattern +// where order of construction is finely controlled. +// This also ensures peerconnection is closed before switching back to non-fake +// clock, avoiding other races and DCHECK failures such as in rtp_sender.cc. +class FakeClockForTest : public rtc::ScopedFakeClock { + protected: + FakeClockForTest() { + // Some things use a time of "0" as a special value, so we need to start out + // the fake clock at a nonzero time. + // TODO(deadbeef): Fix this. + AdvanceTime(webrtc::TimeDelta::Seconds(1)); + } + + // Explicit handle. 
+ ScopedFakeClock& FakeClock() { return *this; } +}; + +class DataChannelIntegrationTestPlanB + : public PeerConnectionIntegrationBaseTest { + protected: + DataChannelIntegrationTestPlanB() + : PeerConnectionIntegrationBaseTest(SdpSemantics::kPlanB) {} +}; + +class DataChannelIntegrationTestUnifiedPlan + : public PeerConnectionIntegrationBaseTest { + protected: + DataChannelIntegrationTestUnifiedPlan() + : PeerConnectionIntegrationBaseTest(SdpSemantics::kUnifiedPlan) {} +}; + +// This test causes a PeerConnection to enter Disconnected state, and +// sends data on a DataChannel while disconnected. +// The data should be surfaced when the connection reestablishes. +TEST_P(DataChannelIntegrationTest, DataChannelWhileDisconnected) { + CreatePeerConnectionWrappers(); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); + std::string data1 = "hello first"; + caller()->data_channel()->Send(DataBuffer(data1)); + EXPECT_EQ_WAIT(data1, callee()->data_observer()->last_message(), + kDefaultTimeout); + // Cause a network outage + virtual_socket_server()->set_drop_probability(1.0); + EXPECT_EQ_WAIT(PeerConnectionInterface::kIceConnectionDisconnected, + caller()->standardized_ice_connection_state(), + kDefaultTimeout); + std::string data2 = "hello second"; + caller()->data_channel()->Send(DataBuffer(data2)); + // Remove the network outage. The connection should reestablish. + virtual_socket_server()->set_drop_probability(0.0); + EXPECT_EQ_WAIT(data2, callee()->data_observer()->last_message(), + kDefaultTimeout); +} + +// This test causes a PeerConnection to enter Disconnected state, +// sends data on a DataChannel while disconnected, and then triggers +// an ICE restart. +// The data should be surfaced when the connection reestablishes. +TEST_P(DataChannelIntegrationTest, DataChannelWhileDisconnectedIceRestart) { + CreatePeerConnectionWrappers(); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); + std::string data1 = "hello first"; + caller()->data_channel()->Send(DataBuffer(data1)); + EXPECT_EQ_WAIT(data1, callee()->data_observer()->last_message(), + kDefaultTimeout); + // Cause a network outage + virtual_socket_server()->set_drop_probability(1.0); + ASSERT_EQ_WAIT(PeerConnectionInterface::kIceConnectionDisconnected, + caller()->standardized_ice_connection_state(), + kDefaultTimeout); + std::string data2 = "hello second"; + caller()->data_channel()->Send(DataBuffer(data2)); + + // Trigger an ICE restart. The signaling channel is not affected by + // the network outage. + caller()->SetOfferAnswerOptions(IceRestartOfferAnswerOptions()); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Remove the network outage. The connection should reestablish. + virtual_socket_server()->set_drop_probability(0.0); + EXPECT_EQ_WAIT(data2, callee()->data_observer()->last_message(), + kDefaultTimeout); +} + +// This test sets up a call between two parties with audio, video and an SCTP +// data channel. 
+TEST_P(DataChannelIntegrationTest, EndToEndCallWithSctpDataChannel) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + // Expect that data channel created on caller side will show up for callee as + // well. + caller()->CreateDataChannel(); + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Ensure the existence of the SCTP data channel didn't impede audio/video. + MediaExpectations media_expectations; + media_expectations.ExpectBidirectionalAudioAndVideo(); + ASSERT_TRUE(ExpectNewFrames(media_expectations)); + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Ensure data can be sent in both directions. + std::string data = "hello world"; + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); +} + +// This test sets up a call between two parties with an SCTP +// data channel only, and sends messages of various sizes. +TEST_P(DataChannelIntegrationTest, + EndToEndCallWithSctpDataChannelVariousSizes) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + // Expect that data channel created on caller side will show up for callee as + // well. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + for (int message_size = 1; message_size < 100000; message_size *= 2) { + std::string data(message_size, 'a'); + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); + } + // Specifically probe the area around the MTU size. 
+ for (int message_size = 1100; message_size < 1300; message_size += 1) { + std::string data(message_size, 'a'); + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); + } +} + +// This test sets up a call between two parties with an SCTP +// data channel only, and sends empty messages +TEST_P(DataChannelIntegrationTest, + EndToEndCallWithSctpDataChannelEmptyMessages) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + // Expect that data channel created on caller side will show up for callee as + // well. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Ensure data can be sent in both directions. + // Sending empty string data + std::string data = ""; + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(1u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); + EXPECT_TRUE(callee()->data_observer()->last_message().empty()); + EXPECT_FALSE(callee()->data_observer()->messages().back().binary); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(1u, caller()->data_observer()->received_message_count(), + kDefaultTimeout); + EXPECT_TRUE(caller()->data_observer()->last_message().empty()); + EXPECT_FALSE(caller()->data_observer()->messages().back().binary); + + // Sending empty binary data + rtc::CopyOnWriteBuffer empty_buffer; + caller()->data_channel()->Send(DataBuffer(empty_buffer, true)); + EXPECT_EQ_WAIT(2u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); + EXPECT_TRUE(callee()->data_observer()->last_message().empty()); + EXPECT_TRUE(callee()->data_observer()->messages().back().binary); + callee()->data_channel()->Send(DataBuffer(empty_buffer, true)); + EXPECT_EQ_WAIT(2u, caller()->data_observer()->received_message_count(), + kDefaultTimeout); + EXPECT_TRUE(caller()->data_observer()->last_message().empty()); + EXPECT_TRUE(caller()->data_observer()->messages().back().binary); +} + +TEST_P(DataChannelIntegrationTest, + EndToEndCallWithSctpDataChannelLowestSafeMtu) { + // The lowest payload size limit that's tested and found safe for this + // application. Note that this is not the safe limit under all conditions; + // in particular, the default is not the largest DTLS signature, and + // this test does not use TURN. + const size_t kLowestSafePayloadSizeLimit = 1225; + + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + // Expect that data channel created on caller side will show up for callee as + // well. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. 
+ ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + virtual_socket_server()->set_max_udp_payload(kLowestSafePayloadSizeLimit); + for (int message_size = 1140; message_size < 1240; message_size += 1) { + std::string data(message_size, 'a'); + caller()->data_channel()->Send(DataBuffer(data)); + ASSERT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + ASSERT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); + } +} + +// This test verifies that lowering the MTU of the connection will cause +// the datachannel to not transmit reliably. +// The purpose of this test is to ensure that we know how a too-small MTU +// error manifests itself. +TEST_P(DataChannelIntegrationTest, EndToEndCallWithSctpDataChannelHarmfulMtu) { + // The lowest payload size limit that's tested and found safe for this + // application in this configuration (see test above). + const size_t kLowestSafePayloadSizeLimit = 1225; + // The size of the smallest message that fails to be delivered. + const size_t kMessageSizeThatIsNotDelivered = 1157; + + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + virtual_socket_server()->set_max_udp_payload(kLowestSafePayloadSizeLimit - 1); + // Probe for an undelivered or slowly delivered message. The exact + // size limit seems to be dependent on the message history, so make the + // code easily able to find the current value. + bool failure_seen = false; + for (size_t message_size = 1110; message_size < 1400; message_size++) { + const size_t message_count = + callee()->data_observer()->received_message_count(); + const std::string data(message_size, 'a'); + caller()->data_channel()->Send(DataBuffer(data)); + // Wait a very short time for the message to be delivered. + // Note: Waiting only 10 ms is too short for Windows bots; they will + // flakily fail at a random frame. + WAIT(callee()->data_observer()->received_message_count() > message_count, + 100); + if (callee()->data_observer()->received_message_count() == message_count) { + ASSERT_EQ(kMessageSizeThatIsNotDelivered, message_size); + failure_seen = true; + break; + } + } + ASSERT_TRUE(failure_seen); +} + +// Ensure that when the callee closes an SCTP data channel, the closing +// procedure results in the data channel being closed for the caller as well. +TEST_P(DataChannelIntegrationTest, CalleeClosesSctpDataChannel) { + // Same procedure as above test. 
+ ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Close the data channel on the callee side, and wait for it to reach the + // "closed" state on both sides. + callee()->data_channel()->Close(); + + DataChannelInterface::DataState expected_states[] = { + DataChannelInterface::DataState::kConnecting, + DataChannelInterface::DataState::kOpen, + DataChannelInterface::DataState::kClosing, + DataChannelInterface::DataState::kClosed}; + + EXPECT_EQ_WAIT(DataChannelInterface::DataState::kClosed, + caller()->data_observer()->state(), kDefaultTimeout); + EXPECT_THAT(caller()->data_observer()->states(), + ::testing::ElementsAreArray(expected_states)); + + EXPECT_EQ_WAIT(DataChannelInterface::DataState::kClosed, + callee()->data_observer()->state(), kDefaultTimeout); + EXPECT_THAT(callee()->data_observer()->states(), + ::testing::ElementsAreArray(expected_states)); +} + +TEST_P(DataChannelIntegrationTest, SctpDataChannelConfigSentToOtherSide) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + webrtc::DataChannelInit init; + init.id = 53; + init.maxRetransmits = 52; + caller()->CreateDataChannel("data-channel", &init); + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + // Since "negotiated" is false, the "id" parameter should be ignored. + EXPECT_NE(init.id, callee()->data_channel()->id()); + EXPECT_EQ("data-channel", callee()->data_channel()->label()); + EXPECT_EQ(init.maxRetransmits, callee()->data_channel()->maxRetransmits()); + EXPECT_FALSE(callee()->data_channel()->negotiated()); +} + +// Test usrsctp's ability to process unordered data stream, where data actually +// arrives out of order using simulated delays. Previously there have been some +// bugs in this area. +TEST_P(DataChannelIntegrationTest, StressTestUnorderedSctpDataChannel) { + // Introduce random network delays. + // Otherwise it's not a true "unordered" test. + virtual_socket_server()->set_delay_mean(20); + virtual_socket_server()->set_delay_stddev(5); + virtual_socket_server()->UpdateDelayDistribution(); + // Normal procedure, but with unordered data channel config. + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + webrtc::DataChannelInit init; + init.ordered = false; + caller()->CreateDataChannel(&init); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + static constexpr int kNumMessages = 100; + // Deliberately chosen to be larger than the MTU so messages get fragmented. 
+ static constexpr size_t kMaxMessageSize = 4096; + // Create and send random messages. + std::vector sent_messages; + for (int i = 0; i < kNumMessages; ++i) { + size_t length = + (rand() % kMaxMessageSize) + 1; // NOLINT (rand_r instead of rand) + std::string message; + ASSERT_TRUE(rtc::CreateRandomString(length, &message)); + caller()->data_channel()->Send(DataBuffer(message)); + callee()->data_channel()->Send(DataBuffer(message)); + sent_messages.push_back(message); + } + + // Wait for all messages to be received. + EXPECT_EQ_WAIT(rtc::checked_cast(kNumMessages), + caller()->data_observer()->received_message_count(), + kDefaultTimeout); + EXPECT_EQ_WAIT(rtc::checked_cast(kNumMessages), + callee()->data_observer()->received_message_count(), + kDefaultTimeout); + + // Sort and compare to make sure none of the messages were corrupted. + std::vector caller_received_messages; + absl::c_transform(caller()->data_observer()->messages(), + std::back_inserter(caller_received_messages), + [](const auto& a) { return a.data; }); + + std::vector callee_received_messages; + absl::c_transform(callee()->data_observer()->messages(), + std::back_inserter(callee_received_messages), + [](const auto& a) { return a.data; }); + + absl::c_sort(sent_messages); + absl::c_sort(caller_received_messages); + absl::c_sort(callee_received_messages); + EXPECT_EQ(sent_messages, caller_received_messages); + EXPECT_EQ(sent_messages, callee_received_messages); +} + +// This test sets up a call between two parties with audio, and video. When +// audio and video are setup and flowing, an SCTP data channel is negotiated. +TEST_P(DataChannelIntegrationTest, AddSctpDataChannelInSubsequentOffer) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + // Do initial offer/answer with audio/video. + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Create data channel and do new offer and answer. + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Caller data channel should already exist (it created one). Callee data + // channel may not exist yet, since negotiation happens in-band, not in SDP. + ASSERT_NE(nullptr, caller()->data_channel()); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + // Ensure data can be sent in both directions. + std::string data = "hello world"; + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); +} + +// Set up a connection initially just using SCTP data channels, later upgrading +// to audio/video, ensuring frames are received end-to-end. Effectively the +// inverse of the test above. +// This was broken in M57; see https://crbug.com/711243 +TEST_P(DataChannelIntegrationTest, SctpDataChannelToAudioVideoUpgrade) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + // Do initial offer/answer with just data channel. 
+ caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + // Wait until data can be sent over the data channel. + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Do subsequent offer/answer with two-way audio and video. Audio and video + // should end up bundled on the DTLS/ICE transport already used for data. + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + MediaExpectations media_expectations; + media_expectations.ExpectBidirectionalAudioAndVideo(); + ASSERT_TRUE(ExpectNewFrames(media_expectations)); +} + +static void MakeSpecCompliantSctpOffer(cricket::SessionDescription* desc) { + cricket::SctpDataContentDescription* dcd_offer = + GetFirstSctpDataContentDescription(desc); + // See https://crbug.com/webrtc/11211 - this function is a no-op + ASSERT_TRUE(dcd_offer); + dcd_offer->set_use_sctpmap(false); + dcd_offer->set_protocol("UDP/DTLS/SCTP"); +} + +// Test that the data channel works when a spec-compliant SCTP m= section is +// offered (using "a=sctp-port" instead of "a=sctpmap", and using +// "UDP/DTLS/SCTP" as the protocol). +TEST_P(DataChannelIntegrationTest, + DataChannelWorksWhenSpecCompliantSctpOfferReceived) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->SetGeneratedSdpMunger(MakeSpecCompliantSctpOffer); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); + EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); + EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + + // Ensure data can be sent in both directions. + std::string data = "hello world"; + caller()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), + kDefaultTimeout); + callee()->data_channel()->Send(DataBuffer(data)); + EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), + kDefaultTimeout); +} + +// Test that after closing PeerConnections, they stop sending any packets (ICE, +// DTLS, RTP...). +TEST_P(DataChannelIntegrationTest, ClosingConnectionStopsPacketFlow) { + // Set up audio/video/data, wait for some frames to be received. + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->AddAudioVideoTracks(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + MediaExpectations media_expectations; + media_expectations.CalleeExpectsSomeAudioAndVideo(); + ASSERT_TRUE(ExpectNewFrames(media_expectations)); + // Close PeerConnections. + ClosePeerConnections(); + // Pump messages for a second, and ensure no new packets end up sent. + uint32_t sent_packets_a = virtual_socket_server()->sent_packets(); + WAIT(false, 1000); + uint32_t sent_packets_b = virtual_socket_server()->sent_packets(); + EXPECT_EQ(sent_packets_a, sent_packets_b); +} + +// Test that transport stats are generated by the RTCStatsCollector for a +// connection that only involves data channels. This is a regression test for +// crbug.com/826972. 
+TEST_P(DataChannelIntegrationTest, + TransportStatsReportedForDataChannelOnlyConnection) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + + auto caller_report = caller()->NewGetStats(); + EXPECT_EQ(1u, caller_report->GetStatsOfType().size()); + auto callee_report = callee()->NewGetStats(); + EXPECT_EQ(1u, callee_report->GetStatsOfType().size()); +} + +TEST_P(DataChannelIntegrationTest, QueuedPacketsGetDeliveredInReliableMode) { + CreatePeerConnectionWrappers(); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + + caller()->data_channel()->Send(DataBuffer("hello first")); + ASSERT_EQ_WAIT(1u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); + // Cause a temporary network outage + virtual_socket_server()->set_drop_probability(1.0); + for (int i = 1; i <= 10; i++) { + caller()->data_channel()->Send(DataBuffer("Sent while blocked")); + } + // Nothing should be delivered during outage. Short wait. + EXPECT_EQ_WAIT(1u, callee()->data_observer()->received_message_count(), 10); + // Reverse outage + virtual_socket_server()->set_drop_probability(0.0); + // All packets should be delivered. + EXPECT_EQ_WAIT(11u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); +} + +TEST_P(DataChannelIntegrationTest, QueuedPacketsGetDroppedInUnreliableMode) { + CreatePeerConnectionWrappers(); + ConnectFakeSignaling(); + DataChannelInit init; + init.maxRetransmits = 0; + init.ordered = false; + caller()->CreateDataChannel(&init); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + caller()->data_channel()->Send(DataBuffer("hello first")); + ASSERT_EQ_WAIT(1u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); + // Cause a temporary network outage + virtual_socket_server()->set_drop_probability(1.0); + // Send a few packets. Note that all get dropped only when all packets + // fit into the receiver receive window/congestion window, so that they + // actually get sent. + for (int i = 1; i <= 10; i++) { + caller()->data_channel()->Send(DataBuffer("Sent while blocked")); + } + // Nothing should be delivered during outage. + // We do a short wait to verify that delivery count is still 1. + WAIT(false, 10); + EXPECT_EQ(1u, callee()->data_observer()->received_message_count()); + // Reverse the network outage. + virtual_socket_server()->set_drop_probability(0.0); + // Send a new packet, and wait for it to be delivered. + caller()->data_channel()->Send(DataBuffer("After block")); + EXPECT_EQ_WAIT("After block", callee()->data_observer()->last_message(), + kDefaultTimeout); + // Some messages should be lost, but first and last message should have + // been delivered. + // First, check that the protocol guarantee is preserved. 
+ EXPECT_GT(11u, callee()->data_observer()->received_message_count()); + EXPECT_LE(2u, callee()->data_observer()->received_message_count()); + // Then, check that observed behavior (lose all messages) has not changed + EXPECT_EQ(2u, callee()->data_observer()->received_message_count()); +} + +TEST_P(DataChannelIntegrationTest, + QueuedPacketsGetDroppedInLifetimeLimitedMode) { + CreatePeerConnectionWrappers(); + ConnectFakeSignaling(); + DataChannelInit init; + init.maxRetransmitTime = 1; + init.ordered = false; + caller()->CreateDataChannel(&init); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + caller()->data_channel()->Send(DataBuffer("hello first")); + ASSERT_EQ_WAIT(1u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); + // Cause a temporary network outage + virtual_socket_server()->set_drop_probability(1.0); + for (int i = 1; i <= 200; i++) { + caller()->data_channel()->Send(DataBuffer("Sent while blocked")); + } + // Nothing should be delivered during outage. + // We do a short wait to verify that delivery count is still 1, + // and to make sure max packet lifetime (which is in ms) is exceeded. + WAIT(false, 10); + EXPECT_EQ(1u, callee()->data_observer()->received_message_count()); + // Reverse the network outage. + virtual_socket_server()->set_drop_probability(0.0); + // Send a new packet, and wait for it to be delivered. + caller()->data_channel()->Send(DataBuffer("After block")); + EXPECT_EQ_WAIT("After block", callee()->data_observer()->last_message(), + kDefaultTimeout); + // Some messages should be lost, but first and last message should have + // been delivered. + // First, check that the protocol guarantee is preserved. + EXPECT_GT(202u, callee()->data_observer()->received_message_count()); + EXPECT_LE(2u, callee()->data_observer()->received_message_count()); + // Then, check that observed behavior (lose some messages) has not changed + if (webrtc::field_trial::IsEnabled("WebRTC-DataChannel-Dcsctp")) { + // DcSctp loses all messages. This is correct. + EXPECT_EQ(2u, callee()->data_observer()->received_message_count()); + } else { + // Usrsctp loses some messages, but keeps messages not attempted. + // THIS IS THE WRONG BEHAVIOR. According to discussion in + // https://github.com/sctplab/usrsctp/issues/584, all these packets + // should be discarded. + // TODO(bugs.webrtc.org/12731): Fix this. 
+ EXPECT_EQ(90u, callee()->data_observer()->received_message_count()); + } +} + +TEST_P(DataChannelIntegrationTest, + SomeQueuedPacketsGetDroppedInMaxRetransmitsMode) { + CreatePeerConnectionWrappers(); + ConnectFakeSignaling(); + DataChannelInit init; + init.maxRetransmits = 0; + init.ordered = false; + caller()->CreateDataChannel(&init); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + caller()->data_channel()->Send(DataBuffer("hello first")); + ASSERT_EQ_WAIT(1u, callee()->data_observer()->received_message_count(), + kDefaultTimeout); + // Cause a temporary network outage + virtual_socket_server()->set_drop_probability(1.0); + // Fill the buffer until queued data starts to build + size_t packet_counter = 0; + while (caller()->data_channel()->buffered_amount() < 1 && + packet_counter < 10000) { + packet_counter++; + caller()->data_channel()->Send(DataBuffer("Sent while blocked")); + } + if (caller()->data_channel()->buffered_amount()) { + RTC_LOG(LS_INFO) << "Buffered data after " << packet_counter << " packets"; + } else { + RTC_LOG(LS_INFO) << "No buffered data after " << packet_counter + << " packets"; + } + // Nothing should be delivered during outage. + // We do a short wait to verify that delivery count is still 1. + WAIT(false, 10); + EXPECT_EQ(1u, callee()->data_observer()->received_message_count()); + // Reverse the network outage. + virtual_socket_server()->set_drop_probability(0.0); + // Send a new packet, and wait for it to be delivered. + caller()->data_channel()->Send(DataBuffer("After block")); + EXPECT_EQ_WAIT("After block", callee()->data_observer()->last_message(), + kDefaultTimeout); + // Some messages should be lost, but first and last message should have + // been delivered. + // Due to the fact that retransmissions are only counted when the packet + // goes on the wire, NOT when they are stalled in queue due to + // congestion, we expect some of the packets to be delivered, because + // congestion prevented them from being sent. + // Citation: https://tools.ietf.org/html/rfc7496#section-3.1 + + // First, check that the protocol guarantee is preserved. + EXPECT_GT(packet_counter, + callee()->data_observer()->received_message_count()); + EXPECT_LE(2u, callee()->data_observer()->received_message_count()); + // Then, check that observed behavior (lose between 100 and 200 messages) + // has not changed. + // Usrsctp behavior is different on Android (177) and other platforms (122). + // Dcsctp loses 432 packets. 
+ EXPECT_GT(2 + packet_counter - 100, + callee()->data_observer()->received_message_count()); + EXPECT_LT(2 + packet_counter - 500, + callee()->data_observer()->received_message_count()); +} + +INSTANTIATE_TEST_SUITE_P( + DataChannelIntegrationTest, + DataChannelIntegrationTest, + Combine(Values(SdpSemantics::kPlanB, SdpSemantics::kUnifiedPlan), + Values("WebRTC-DataChannel-Dcsctp/Enabled/", + "WebRTC-DataChannel-Dcsctp/Disabled/"))); + +TEST_F(DataChannelIntegrationTestUnifiedPlan, + EndToEndCallWithBundledSctpDataChannel) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(caller()->pc()->GetSctpTransport(), kDefaultTimeout); + ASSERT_EQ_WAIT(SctpTransportState::kConnected, + caller()->pc()->GetSctpTransport()->Information().state(), + kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); +} + +TEST_F(DataChannelIntegrationTestUnifiedPlan, + EndToEndCallWithDataChannelOnlyConnects) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + ASSERT_TRUE(caller()->data_observer()->IsOpen()); +} + +TEST_F(DataChannelIntegrationTestUnifiedPlan, DataChannelClosesWhenClosed) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + caller()->data_channel()->Close(); + ASSERT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); +} + +TEST_F(DataChannelIntegrationTestUnifiedPlan, + DataChannelClosesWhenClosedReverse) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + callee()->data_channel()->Close(); + ASSERT_TRUE_WAIT(!caller()->data_observer()->IsOpen(), kDefaultTimeout); +} + +TEST_F(DataChannelIntegrationTestUnifiedPlan, + DataChannelClosesWhenPeerConnectionClosed) { + ASSERT_TRUE(CreatePeerConnectionWrappers()); + ConnectFakeSignaling(); + caller()->CreateDataChannel(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); + ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); + caller()->pc()->Close(); + ASSERT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); +} + +#endif // WEBRTC_HAVE_SCTP + +} // namespace + +} // namespace webrtc diff --git a/pc/data_channel_unittest.cc b/pc/data_channel_unittest.cc index 7601c80b08..770892cbe1 100644 --- a/pc/data_channel_unittest.cc +++ b/pc/data_channel_unittest.cc @@ -13,6 +13,7 
@@ #include #include +#include "media/sctp/sctp_transport_internal.h" #include "pc/sctp_data_channel.h" #include "pc/sctp_utils.h" #include "pc/test/fake_data_channel_provider.h" @@ -286,9 +287,9 @@ TEST_F(SctpDataChannelTest, OpenMessageSent) { SetChannelReady(); EXPECT_GE(webrtc_data_channel_->id(), 0); - EXPECT_EQ(cricket::DMT_CONTROL, provider_->last_send_data_params().type); - EXPECT_EQ(provider_->last_send_data_params().ssrc, - static_cast(webrtc_data_channel_->id())); + EXPECT_EQ(webrtc::DataMessageType::kControl, + provider_->last_send_data_params().type); + EXPECT_EQ(provider_->last_sid(), webrtc_data_channel_->id()); } TEST_F(SctpDataChannelTest, QueuedOpenMessageSent) { @@ -296,9 +297,9 @@ TEST_F(SctpDataChannelTest, QueuedOpenMessageSent) { SetChannelReady(); provider_->set_send_blocked(false); - EXPECT_EQ(cricket::DMT_CONTROL, provider_->last_send_data_params().type); - EXPECT_EQ(provider_->last_send_data_params().ssrc, - static_cast(webrtc_data_channel_->id())); + EXPECT_EQ(webrtc::DataMessageType::kControl, + provider_->last_send_data_params().type); + EXPECT_EQ(provider_->last_sid(), webrtc_data_channel_->id()); } // Tests that the DataChannel created after transport gets ready can enter OPEN @@ -334,8 +335,8 @@ TEST_F(SctpDataChannelTest, SendUnorderedAfterReceivesOpenAck) { // Emulates receiving an OPEN_ACK message. cricket::ReceiveDataParams params; - params.ssrc = init.id; - params.type = cricket::DMT_CONTROL; + params.sid = init.id; + params.type = webrtc::DataMessageType::kControl; rtc::CopyOnWriteBuffer payload; webrtc::WriteDataChannelOpenAckMessage(&payload); dc->OnDataReceived(params, payload); @@ -360,8 +361,8 @@ TEST_F(SctpDataChannelTest, SendUnorderedAfterReceiveData) { // Emulates receiving a DATA message. cricket::ReceiveDataParams params; - params.ssrc = init.id; - params.type = cricket::DMT_TEXT; + params.sid = init.id; + params.type = webrtc::DataMessageType::kText; webrtc::DataBuffer buffer("data"); dc->OnDataReceived(params, buffer.data); @@ -382,7 +383,8 @@ TEST_F(SctpDataChannelTest, OpenWaitsForOpenMesssage) { provider_->set_send_blocked(false); EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, webrtc_data_channel_->state(), 1000); - EXPECT_EQ(cricket::DMT_CONTROL, provider_->last_send_data_params().type); + EXPECT_EQ(webrtc::DataMessageType::kControl, + provider_->last_send_data_params().type); } // Tests that close first makes sure all queued data gets sent. @@ -403,42 +405,43 @@ TEST_F(SctpDataChannelTest, QueuedCloseFlushes) { EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kClosed, webrtc_data_channel_->state(), 1000); EXPECT_TRUE(webrtc_data_channel_->error().ok()); - EXPECT_EQ(cricket::DMT_TEXT, provider_->last_send_data_params().type); + EXPECT_EQ(webrtc::DataMessageType::kText, + provider_->last_send_data_params().type); } -// Tests that messages are sent with the right ssrc. -TEST_F(SctpDataChannelTest, SendDataSsrc) { +// Tests that messages are sent with the right id. +TEST_F(SctpDataChannelTest, SendDataId) { webrtc_data_channel_->SetSctpSid(1); SetChannelReady(); webrtc::DataBuffer buffer("data"); EXPECT_TRUE(webrtc_data_channel_->Send(buffer)); - EXPECT_EQ(1U, provider_->last_send_data_params().ssrc); + EXPECT_EQ(1, provider_->last_sid()); } -// Tests that the incoming messages with wrong ssrcs are rejected. -TEST_F(SctpDataChannelTest, ReceiveDataWithInvalidSsrc) { +// Tests that the incoming messages with wrong ids are rejected. 
+TEST_F(SctpDataChannelTest, ReceiveDataWithInvalidId) { webrtc_data_channel_->SetSctpSid(1); SetChannelReady(); AddObserver(); cricket::ReceiveDataParams params; - params.ssrc = 0; + params.sid = 0; webrtc::DataBuffer buffer("abcd"); webrtc_data_channel_->OnDataReceived(params, buffer.data); EXPECT_EQ(0U, observer_->messages_received()); } -// Tests that the incoming messages with right ssrcs are acceted. -TEST_F(SctpDataChannelTest, ReceiveDataWithValidSsrc) { +// Tests that the incoming messages with right ids are accepted. +TEST_F(SctpDataChannelTest, ReceiveDataWithValidId) { webrtc_data_channel_->SetSctpSid(1); SetChannelReady(); AddObserver(); cricket::ReceiveDataParams params; - params.ssrc = 1; + params.sid = 1; webrtc::DataBuffer buffer("abcd"); webrtc_data_channel_->OnDataReceived(params, buffer.data); @@ -459,7 +462,7 @@ TEST_F(SctpDataChannelTest, NoMsgSentIfNegotiatedAndNotFromOpenMsg) { rtc::Thread::Current(), rtc::Thread::Current()); EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, dc->state(), 1000); - EXPECT_EQ(0U, provider_->last_send_data_params().ssrc); + EXPECT_EQ(0, provider_->last_sid()); } // Tests that DataChannel::messages_received() and DataChannel::bytes_received() @@ -477,7 +480,7 @@ TEST_F(SctpDataChannelTest, VerifyMessagesAndBytesReceived) { webrtc_data_channel_->SetSctpSid(1); cricket::ReceiveDataParams params; - params.ssrc = 1; + params.sid = 1; // Default values. EXPECT_EQ(0U, webrtc_data_channel_->messages_received()); @@ -524,9 +527,9 @@ TEST_F(SctpDataChannelTest, OpenAckSentIfCreatedFromOpenMessage) { EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kOpen, dc->state(), 1000); - EXPECT_EQ(static_cast(config.id), - provider_->last_send_data_params().ssrc); - EXPECT_EQ(cricket::DMT_CONTROL, provider_->last_send_data_params().type); + EXPECT_EQ(config.id, provider_->last_sid()); + EXPECT_EQ(webrtc::DataMessageType::kControl, + provider_->last_send_data_params().type); } // Tests the OPEN_ACK role assigned by InternalDataChannelInit. @@ -584,7 +587,7 @@ TEST_F(SctpDataChannelTest, ClosedWhenReceivedBufferFull) { memset(buffer.MutableData(), 0, buffer.size()); cricket::ReceiveDataParams params; - params.ssrc = 0; + params.sid = 0; // Receiving data without having an observer will overflow the buffer. for (size_t i = 0; i < 16 * 1024 + 1; ++i) { @@ -633,7 +636,9 @@ TEST_F(SctpDataChannelTest, TransportDestroyedWhileDataBuffered) { // Tell the data channel that its transport is being destroyed. // It should then stop using the transport (allowing us to delete it) and // transition to the "closed" state. - webrtc_data_channel_->OnTransportChannelClosed(); + webrtc::RTCError error(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, ""); + error.set_error_detail(webrtc::RTCErrorDetailType::SCTP_FAILURE); + webrtc_data_channel_->OnTransportChannelClosed(error); provider_.reset(nullptr); EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kClosed, webrtc_data_channel_->state(), kDefaultTimeout); @@ -644,6 +649,31 @@ TEST_F(SctpDataChannelTest, TransportDestroyedWhileDataBuffered) { webrtc_data_channel_->error().error_detail()); } +TEST_F(SctpDataChannelTest, TransportGotErrorCode) { + SetChannelReady(); + + // Tell the data channel that its transport is being destroyed with an + // error code. + // It should then report that error code. 
+ webrtc::RTCError error(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, + "Transport channel closed"); + error.set_error_detail(webrtc::RTCErrorDetailType::SCTP_FAILURE); + error.set_sctp_cause_code( + static_cast(cricket::SctpErrorCauseCode::kProtocolViolation)); + webrtc_data_channel_->OnTransportChannelClosed(error); + provider_.reset(nullptr); + EXPECT_EQ_WAIT(webrtc::DataChannelInterface::kClosed, + webrtc_data_channel_->state(), kDefaultTimeout); + EXPECT_FALSE(webrtc_data_channel_->error().ok()); + EXPECT_EQ(webrtc::RTCErrorType::OPERATION_ERROR_WITH_DATA, + webrtc_data_channel_->error().type()); + EXPECT_EQ(webrtc::RTCErrorDetailType::SCTP_FAILURE, + webrtc_data_channel_->error().error_detail()); + EXPECT_EQ( + static_cast(cricket::SctpErrorCauseCode::kProtocolViolation), + webrtc_data_channel_->error().sctp_cause_code()); +} + class SctpSidAllocatorTest : public ::testing::Test { protected: SctpSidAllocator allocator_; diff --git a/pc/data_channel_utils.cc b/pc/data_channel_utils.cc index 51d6af941f..a772241c3e 100644 --- a/pc/data_channel_utils.cc +++ b/pc/data_channel_utils.cc @@ -10,6 +10,10 @@ #include "pc/data_channel_utils.h" +#include + +#include "rtc_base/checks.h" + namespace webrtc { bool PacketQueue::Empty() const { @@ -47,8 +51,4 @@ void PacketQueue::Swap(PacketQueue* other) { other->packets_.swap(packets_); } -bool IsSctpLike(cricket::DataChannelType type) { - return type == cricket::DCT_SCTP; -} - } // namespace webrtc diff --git a/pc/data_channel_utils.h b/pc/data_channel_utils.h index 13c6620cd8..85cacdb563 100644 --- a/pc/data_channel_utils.h +++ b/pc/data_channel_utils.h @@ -11,6 +11,8 @@ #ifndef PC_DATA_CHANNEL_UTILS_H_ #define PC_DATA_CHANNEL_UTILS_H_ +#include +#include #include #include #include @@ -55,8 +57,6 @@ struct DataChannelStats { uint64_t bytes_received; }; -bool IsSctpLike(cricket::DataChannelType type); - } // namespace webrtc #endif // PC_DATA_CHANNEL_UTILS_H_ diff --git a/pc/dtls_srtp_transport.cc b/pc/dtls_srtp_transport.cc index dacbcb411d..ac091c6131 100644 --- a/pc/dtls_srtp_transport.cc +++ b/pc/dtls_srtp_transport.cc @@ -15,6 +15,7 @@ #include #include +#include "api/dtls_transport_interface.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/ssl_stream_adapter.h" @@ -114,10 +115,9 @@ bool DtlsSrtpTransport::IsDtlsConnected() { auto rtcp_dtls_transport = rtcp_mux_enabled() ? 
nullptr : rtcp_dtls_transport_; return (rtp_dtls_transport_ && - rtp_dtls_transport_->dtls_state() == - cricket::DTLS_TRANSPORT_CONNECTED && + rtp_dtls_transport_->dtls_state() == DtlsTransportState::kConnected && (!rtcp_dtls_transport || rtcp_dtls_transport->dtls_state() == - cricket::DTLS_TRANSPORT_CONNECTED)); + DtlsTransportState::kConnected)); } bool DtlsSrtpTransport::IsDtlsWritable() { @@ -166,7 +166,6 @@ void DtlsSrtpTransport::SetupRtpDtlsSrtp() { static_cast(send_key.size()), send_extension_ids, selected_crypto_suite, &recv_key[0], static_cast(recv_key.size()), recv_extension_ids)) { - SignalDtlsSrtpSetupFailure(this, /*rtcp=*/false); RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTP failed"; } } @@ -198,7 +197,6 @@ void DtlsSrtpTransport::SetupRtcpDtlsSrtp() { selected_crypto_suite, &rtcp_recv_key[0], static_cast(rtcp_recv_key.size()), recv_extension_ids)) { - SignalDtlsSrtpSetupFailure(this, /*rtcp=*/true); RTC_LOG(LS_WARNING) << "DTLS-SRTP key installation for RTCP failed"; } } @@ -277,14 +275,16 @@ void DtlsSrtpTransport::SetDtlsTransport( } if (*old_dtls_transport) { - (*old_dtls_transport)->SignalDtlsState.disconnect(this); + (*old_dtls_transport)->UnsubscribeDtlsTransportState(this); } *old_dtls_transport = new_dtls_transport; if (new_dtls_transport) { - new_dtls_transport->SignalDtlsState.connect( - this, &DtlsSrtpTransport::OnDtlsState); + new_dtls_transport->SubscribeDtlsTransportState( + this, + [this](cricket::DtlsTransportInternal* transport, + DtlsTransportState state) { OnDtlsState(transport, state); }); } } @@ -299,13 +299,15 @@ void DtlsSrtpTransport::SetRtcpDtlsTransport( } void DtlsSrtpTransport::OnDtlsState(cricket::DtlsTransportInternal* transport, - cricket::DtlsTransportState state) { + DtlsTransportState state) { RTC_DCHECK(transport == rtp_dtls_transport_ || transport == rtcp_dtls_transport_); - SignalDtlsStateChange(); + if (on_dtls_state_change_) { + on_dtls_state_change_(); + } - if (state != cricket::DTLS_TRANSPORT_CONNECTED) { + if (state != DtlsTransportState::kConnected) { ResetParams(); return; } @@ -318,4 +320,8 @@ void DtlsSrtpTransport::OnWritableState( MaybeSetupDtlsSrtp(); } +void DtlsSrtpTransport::SetOnDtlsStateChange( + std::function callback) { + on_dtls_state_change_ = std::move(callback); +} } // namespace webrtc diff --git a/pc/dtls_srtp_transport.h b/pc/dtls_srtp_transport.h index c63a3ca5dd..9c52dcf809 100644 --- a/pc/dtls_srtp_transport.h +++ b/pc/dtls_srtp_transport.h @@ -11,10 +11,12 @@ #ifndef PC_DTLS_SRTP_TRANSPORT_H_ #define PC_DTLS_SRTP_TRANSPORT_H_ +#include #include #include "absl/types/optional.h" #include "api/crypto_params.h" +#include "api/dtls_transport_interface.h" #include "api/rtc_error.h" #include "p2p/base/dtls_transport_internal.h" #include "p2p/base/packet_transport_internal.h" @@ -45,8 +47,7 @@ class DtlsSrtpTransport : public SrtpTransport { void UpdateRecvEncryptedHeaderExtensionIds( const std::vector& recv_extension_ids); - sigslot::signal SignalDtlsSrtpSetupFailure; - sigslot::signal<> SignalDtlsStateChange; + void SetOnDtlsStateChange(std::function callback); RTCError SetSrtpSendKey(const cricket::CryptoParams& params) override { return RTCError(RTCErrorType::UNSUPPORTED_OPERATION, @@ -82,7 +83,7 @@ class DtlsSrtpTransport : public SrtpTransport { cricket::DtlsTransportInternal* rtcp_dtls_transport); void OnDtlsState(cricket::DtlsTransportInternal* dtls_transport, - cricket::DtlsTransportState state); + DtlsTransportState state); // Override the SrtpTransport::OnWritableState. 
void OnWritableState(rtc::PacketTransportInternal* packet_transport) override; @@ -96,6 +97,7 @@ class DtlsSrtpTransport : public SrtpTransport { absl::optional> recv_extension_ids_; bool active_reset_srtp_params_ = false; + std::function on_dtls_state_change_; }; } // namespace webrtc diff --git a/pc/dtls_transport.cc b/pc/dtls_transport.cc index 550ede790d..074f44e22b 100644 --- a/pc/dtls_transport.cc +++ b/pc/dtls_transport.cc @@ -12,41 +12,31 @@ #include +#include "absl/types/optional.h" +#include "api/dtls_transport_interface.h" +#include "api/sequence_checker.h" #include "pc/ice_transport.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/ssl_certificate.h" namespace webrtc { -namespace { - -DtlsTransportState TranslateState(cricket::DtlsTransportState internal_state) { - switch (internal_state) { - case cricket::DTLS_TRANSPORT_NEW: - return DtlsTransportState::kNew; - case cricket::DTLS_TRANSPORT_CONNECTING: - return DtlsTransportState::kConnecting; - case cricket::DTLS_TRANSPORT_CONNECTED: - return DtlsTransportState::kConnected; - case cricket::DTLS_TRANSPORT_CLOSED: - return DtlsTransportState::kClosed; - case cricket::DTLS_TRANSPORT_FAILED: - return DtlsTransportState::kFailed; - } - RTC_CHECK_NOTREACHED(); -} - -} // namespace - // Implementation of DtlsTransportInterface DtlsTransport::DtlsTransport( std::unique_ptr internal) : owner_thread_(rtc::Thread::Current()), info_(DtlsTransportState::kNew), internal_dtls_transport_(std::move(internal)), - ice_transport_(new rtc::RefCountedObject( + ice_transport_(rtc::make_ref_counted( internal_dtls_transport_->ice_transport())) { RTC_DCHECK(internal_dtls_transport_.get()); - internal_dtls_transport_->SignalDtlsState.connect( - this, &DtlsTransport::OnInternalDtlsState); + internal_dtls_transport_->SubscribeDtlsTransportState( + [this](cricket::DtlsTransportInternal* transport, + DtlsTransportState state) { + OnInternalDtlsState(transport, state); + }); UpdateInformation(); } @@ -81,7 +71,7 @@ void DtlsTransport::Clear() { RTC_DCHECK_RUN_ON(owner_thread_); RTC_DCHECK(internal()); bool must_send_event = - (internal()->dtls_state() != cricket::DTLS_TRANSPORT_CLOSED); + (internal()->dtls_state() != DtlsTransportState::kClosed); // The destructor of cricket::DtlsTransportInternal calls back // into DtlsTransport, so we can't hold the lock while releasing. 
std::unique_ptr transport_to_release; @@ -98,7 +88,7 @@ void DtlsTransport::Clear() { void DtlsTransport::OnInternalDtlsState( cricket::DtlsTransportInternal* transport, - cricket::DtlsTransportState state) { + DtlsTransportState state) { RTC_DCHECK_RUN_ON(owner_thread_); RTC_DCHECK(transport == internal()); RTC_DCHECK(state == internal()->dtls_state()); @@ -113,7 +103,7 @@ void DtlsTransport::UpdateInformation() { MutexLock lock(&lock_); if (internal_dtls_transport_) { if (internal_dtls_transport_->dtls_state() == - cricket::DTLS_TRANSPORT_CONNECTED) { + DtlsTransportState::kConnected) { bool success = true; int ssl_cipher_suite; int tls_version; @@ -123,20 +113,19 @@ void DtlsTransport::UpdateInformation() { success &= internal_dtls_transport_->GetSrtpCryptoSuite(&srtp_cipher); if (success) { info_ = DtlsTransportInformation( - TranslateState(internal_dtls_transport_->dtls_state()), tls_version, + internal_dtls_transport_->dtls_state(), tls_version, ssl_cipher_suite, srtp_cipher, internal_dtls_transport_->GetRemoteSSLCertChain()); } else { RTC_LOG(LS_ERROR) << "DtlsTransport in connected state has incomplete " "TLS information"; info_ = DtlsTransportInformation( - TranslateState(internal_dtls_transport_->dtls_state()), - absl::nullopt, absl::nullopt, absl::nullopt, + internal_dtls_transport_->dtls_state(), absl::nullopt, + absl::nullopt, absl::nullopt, internal_dtls_transport_->GetRemoteSSLCertChain()); } } else { - info_ = DtlsTransportInformation( - TranslateState(internal_dtls_transport_->dtls_state())); + info_ = DtlsTransportInformation(internal_dtls_transport_->dtls_state()); } } else { info_ = DtlsTransportInformation(DtlsTransportState::kClosed); diff --git a/pc/dtls_transport.h b/pc/dtls_transport.h index ff8108ca90..cca4cc980a 100644 --- a/pc/dtls_transport.h +++ b/pc/dtls_transport.h @@ -17,7 +17,11 @@ #include "api/ice_transport_interface.h" #include "api/scoped_refptr.h" #include "p2p/base/dtls_transport.h" +#include "p2p/base/dtls_transport_internal.h" +#include "pc/ice_transport.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { @@ -25,8 +29,7 @@ class IceTransportWithPointer; // This implementation wraps a cricket::DtlsTransport, and takes // ownership of it. 
-class DtlsTransport : public DtlsTransportInterface, - public sigslot::has_slots<> { +class DtlsTransport : public DtlsTransportInterface { public: // This object must be constructed and updated on a consistent thread, // the same thread as the one the cricket::DtlsTransportInternal object @@ -57,7 +60,7 @@ class DtlsTransport : public DtlsTransportInterface, private: void OnInternalDtlsState(cricket::DtlsTransportInternal* transport, - cricket::DtlsTransportState state); + DtlsTransportState state); void UpdateInformation(); DtlsTransportObserverInterface* observer_ = nullptr; diff --git a/pc/dtls_transport_unittest.cc b/pc/dtls_transport_unittest.cc index a3f0a7ce8b..f80d99b05e 100644 --- a/pc/dtls_transport_unittest.cc +++ b/pc/dtls_transport_unittest.cc @@ -63,7 +63,7 @@ class DtlsTransportTest : public ::testing::Test { } cricket_transport->SetSslCipherSuite(kNonsenseCipherSuite); transport_ = - new rtc::RefCountedObject(std::move(cricket_transport)); + rtc::make_ref_counted(std::move(cricket_transport)); } void CompleteDtlsHandshake() { @@ -86,8 +86,8 @@ class DtlsTransportTest : public ::testing::Test { TEST_F(DtlsTransportTest, CreateClearDelete) { auto cricket_transport = std::make_unique( "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP); - rtc::scoped_refptr webrtc_transport = - new rtc::RefCountedObject(std::move(cricket_transport)); + auto webrtc_transport = + rtc::make_ref_counted(std::move(cricket_transport)); ASSERT_TRUE(webrtc_transport->internal()); ASSERT_EQ(DtlsTransportState::kNew, webrtc_transport->Information().state()); webrtc_transport->Clear(); diff --git a/pc/dtmf_sender.cc b/pc/dtmf_sender.cc index 10378028c8..67c3fac134 100644 --- a/pc/dtmf_sender.cc +++ b/pc/dtmf_sender.cc @@ -18,6 +18,7 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" namespace webrtc { @@ -64,9 +65,7 @@ rtc::scoped_refptr DtmfSender::Create( if (!signaling_thread) { return nullptr; } - rtc::scoped_refptr dtmf_sender( - new rtc::RefCountedObject(signaling_thread, provider)); - return dtmf_sender; + return rtc::make_ref_counted(signaling_thread, provider); } DtmfSender::DtmfSender(rtc::Thread* signaling_thread, @@ -86,19 +85,22 @@ DtmfSender::DtmfSender(rtc::Thread* signaling_thread, } DtmfSender::~DtmfSender() { + RTC_DCHECK_RUN_ON(signaling_thread_); StopSending(); } void DtmfSender::RegisterObserver(DtmfSenderObserverInterface* observer) { + RTC_DCHECK_RUN_ON(signaling_thread_); observer_ = observer; } void DtmfSender::UnregisterObserver() { + RTC_DCHECK_RUN_ON(signaling_thread_); observer_ = nullptr; } bool DtmfSender::CanInsertDtmf() { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); if (!provider_) { return false; } @@ -109,7 +111,7 @@ bool DtmfSender::InsertDtmf(const std::string& tones, int duration, int inter_tone_gap, int comma_delay) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); if (duration > kDtmfMaxDurationMs || duration < kDtmfMinDurationMs || inter_tone_gap < kDtmfMinGapMs || comma_delay < kDtmfMinGapMs) { @@ -132,38 +134,49 @@ bool DtmfSender::InsertDtmf(const std::string& tones, duration_ = duration; inter_tone_gap_ = inter_tone_gap; comma_delay_ = comma_delay; - // Clear the previous queue. - dtmf_driver_.Clear(); - // Kick off a new DTMF task queue. + + // Cancel any remaining tasks for previous tones. 
+ if (safety_flag_) { + safety_flag_->SetNotAlive(); + } + safety_flag_ = PendingTaskSafetyFlag::Create(); + // Kick off a new DTMF task. QueueInsertDtmf(RTC_FROM_HERE, 1 /*ms*/); return true; } std::string DtmfSender::tones() const { + RTC_DCHECK_RUN_ON(signaling_thread_); return tones_; } int DtmfSender::duration() const { + RTC_DCHECK_RUN_ON(signaling_thread_); return duration_; } int DtmfSender::inter_tone_gap() const { + RTC_DCHECK_RUN_ON(signaling_thread_); return inter_tone_gap_; } int DtmfSender::comma_delay() const { + RTC_DCHECK_RUN_ON(signaling_thread_); return comma_delay_; } void DtmfSender::QueueInsertDtmf(const rtc::Location& posted_from, uint32_t delay_ms) { - dtmf_driver_.AsyncInvokeDelayed( - posted_from, signaling_thread_, [this] { DoInsertDtmf(); }, delay_ms); + signaling_thread_->PostDelayedTask( + ToQueuedTask(safety_flag_, + [this] { + RTC_DCHECK_RUN_ON(signaling_thread_); + DoInsertDtmf(); + }), + delay_ms); } void DtmfSender::DoInsertDtmf() { - RTC_DCHECK(signaling_thread_->IsCurrent()); - // Get the first DTMF tone from the tone buffer. Unrecognized characters will // be ignored and skipped. size_t first_tone_pos = tones_.find_first_of(kDtmfValidTones); @@ -222,13 +235,17 @@ void DtmfSender::DoInsertDtmf() { } void DtmfSender::OnProviderDestroyed() { + RTC_DCHECK_RUN_ON(signaling_thread_); + RTC_LOG(LS_INFO) << "The Dtmf provider is deleted. Clear the sending queue."; StopSending(); provider_ = nullptr; } void DtmfSender::StopSending() { - dtmf_driver_.Clear(); + if (safety_flag_) { + safety_flag_->SetNotAlive(); + } } } // namespace webrtc diff --git a/pc/dtmf_sender.h b/pc/dtmf_sender.h index e332a7ef58..b64b50e09c 100644 --- a/pc/dtmf_sender.h +++ b/pc/dtmf_sender.h @@ -11,13 +11,18 @@ #ifndef PC_DTMF_SENDER_H_ #define PC_DTMF_SENDER_H_ +#include + #include #include "api/dtmf_sender_interface.h" -#include "api/proxy.h" -#include "rtc_base/async_invoker.h" +#include "api/scoped_refptr.h" +#include "pc/proxy.h" #include "rtc_base/constructor_magic.h" +#include "rtc_base/location.h" #include "rtc_base/ref_count.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" // DtmfSender is the native implementation of the RTCDTMFSender defined by @@ -70,32 +75,34 @@ class DtmfSender : public DtmfSenderInterface, public sigslot::has_slots<> { private: DtmfSender(); - void QueueInsertDtmf(const rtc::Location& posted_from, uint32_t delay_ms); + void QueueInsertDtmf(const rtc::Location& posted_from, uint32_t delay_ms) + RTC_RUN_ON(signaling_thread_); // The DTMF sending task. - void DoInsertDtmf(); + void DoInsertDtmf() RTC_RUN_ON(signaling_thread_); void OnProviderDestroyed(); - void StopSending(); + void StopSending() RTC_RUN_ON(signaling_thread_); - DtmfSenderObserverInterface* observer_; + DtmfSenderObserverInterface* observer_ RTC_GUARDED_BY(signaling_thread_); rtc::Thread* signaling_thread_; - DtmfProviderInterface* provider_; - std::string tones_; - int duration_; - int inter_tone_gap_; - int comma_delay_; - // Invoker for running delayed tasks which feed the DTMF provider one tone at - // a time. 
- rtc::AsyncInvoker dtmf_driver_; + DtmfProviderInterface* provider_ RTC_GUARDED_BY(signaling_thread_); + std::string tones_ RTC_GUARDED_BY(signaling_thread_); + int duration_ RTC_GUARDED_BY(signaling_thread_); + int inter_tone_gap_ RTC_GUARDED_BY(signaling_thread_); + int comma_delay_ RTC_GUARDED_BY(signaling_thread_); + + // For cancelling the tasks which feed the DTMF provider one tone at a time. + rtc::scoped_refptr safety_flag_ RTC_GUARDED_BY( + signaling_thread_) RTC_PT_GUARDED_BY(signaling_thread_) = nullptr; RTC_DISALLOW_COPY_AND_ASSIGN(DtmfSender); }; // Define proxy for DtmfSenderInterface. -BEGIN_SIGNALING_PROXY_MAP(DtmfSender) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +BEGIN_PRIMARY_PROXY_MAP(DtmfSender) +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_METHOD1(void, RegisterObserver, DtmfSenderObserverInterface*) PROXY_METHOD0(void, UnregisterObserver) PROXY_METHOD0(bool, CanInsertDtmf) @@ -104,7 +111,7 @@ PROXY_CONSTMETHOD0(std::string, tones) PROXY_CONSTMETHOD0(int, duration) PROXY_CONSTMETHOD0(int, inter_tone_gap) PROXY_CONSTMETHOD0(int, comma_delay) -END_PROXY_MAP() +END_PROXY_MAP(DtmfSender) // Get DTMF code from the DTMF event character. bool GetDtmfCode(char tone, int* code); diff --git a/pc/dtmf_sender_unittest.cc b/pc/dtmf_sender_unittest.cc index f7f229a887..261cbd0303 100644 --- a/pc/dtmf_sender_unittest.cc +++ b/pc/dtmf_sender_unittest.cc @@ -18,7 +18,6 @@ #include "rtc_base/fake_clock.h" #include "rtc_base/gunit.h" -#include "rtc_base/ref_counted_object.h" #include "rtc_base/time_utils.h" #include "test/gtest.h" @@ -118,8 +117,7 @@ class FakeDtmfProvider : public DtmfProviderInterface { class DtmfSenderTest : public ::testing::Test { protected: DtmfSenderTest() - : observer_(new rtc::RefCountedObject()), - provider_(new FakeDtmfProvider()) { + : observer_(new FakeDtmfObserver()), provider_(new FakeDtmfProvider()) { provider_->SetCanInsertDtmf(true); dtmf_ = DtmfSender::Create(rtc::Thread::Current(), provider_.get()); dtmf_->RegisterObserver(observer_.get()); diff --git a/pc/g3doc/dtls_transport.md b/pc/g3doc/dtls_transport.md new file mode 100644 index 0000000000..65206dff5d --- /dev/null +++ b/pc/g3doc/dtls_transport.md @@ -0,0 +1,53 @@ + + + +## Overview + +WebRTC uses DTLS in two ways: + +* to negotiate keys for SRTP encryption using + [DTLS-SRTP](https://www.rfc-editor.org/info/rfc5763) +* as a transport for SCTP which is used by the Datachannel API + +The W3C WebRTC API represents this as the +[DtlsTransport](https://w3c.github.io/webrtc-pc/#rtcdtlstransport-interface). + +The DTLS handshake happens after the ICE transport becomes writable and has +found a valid pair. It results in a set of keys being derived for DTLS-SRTP as +well as a fingerprint of the remote certificate which is compared to the one +given in the SDP `a=fingerprint:` line. + +This documentation provides an overview of how DTLS is implemented, i.e how the +following classes interact. + +## webrtc::DtlsTransport + +The [`webrtc::DtlsTransport`][1] class is a wrapper around the +`cricket::DtlsTransportInternal` and allows registering observers implementing +the `webrtc::DtlsTransportObserverInterface`. The +[`webrtc::DtlsTransportObserverInterface`][2] will provide updates to the +observers, passing around a snapshot of the transports state such as the +connection state, the remote certificate(s) and the SRTP ciphers as +[`DtlsTransportInformation`][3]. + +## cricket::DtlsTransportInternal + +The [`cricket::DtlsTransportInternal`][4] class is an interface. 
Its
+implementation is [`cricket::DtlsTransport`][5]. The `cricket::DtlsTransport`
+sends and receives network packets via an ICE transport. It also demultiplexes
+DTLS packets and SRTP packets according to the scheme described in
+[RFC 5764](https://tools.ietf.org/html/rfc5764#section-5.1.2).
+
+## webrtc::DtlsSrtpTransport
+
+The [`webrtc::DtlsSrtpTransport`][6] class is responsible for extracting the
+SRTP keys after the DTLS handshake as well as protection and unprotection of
+SRTP packets via its [`cricket::SrtpSession`][7].
+
+[1]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/pc/dtls_transport.h;l=32;drc=6a55e7307b78edb50f94a1ff1ef8393d58218369
+[2]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/dtls_transport_interface.h;l=76;drc=34437d5660a80393d631657329ef74c6538be25a
+[3]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/dtls_transport_interface.h;l=41;drc=34437d5660a80393d631657329ef74c6538be25a
+[4]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/dtls_transport_internal.h;l=63;drc=34437d5660a80393d631657329ef74c6538be25a
+[5]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/p2p/base/dtls_transport.h;l=94;drc=653bab6790ac92c513b7cf4cd3ad59039c589a95
+[6]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/pc/dtls_srtp_transport.h;l=31;drc=c32f00ea9ddf3267257fe6b45d4d79c6f6bcb829
+[7]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=33;drc=be66d95ab7f9428028806bbf66cb83800bda9241
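To make the observer flow above concrete, here is a minimal sketch of a state-change observer. It assumes the `DtlsTransportObserverInterface` callbacks (`OnStateChange`, `OnError`) and the `DtlsTransportInformation::state()` accessor referenced above; the `LoggingDtlsObserver` and `ObserveDtls` names are illustrative, not part of the API.

```cpp
#include "api/dtls_transport_interface.h"
#include "api/rtc_error.h"
#include "api/scoped_refptr.h"
#include "rtc_base/logging.h"

// Illustrative observer that logs DTLS state transitions.
class LoggingDtlsObserver : public webrtc::DtlsTransportObserverInterface {
 public:
  void OnStateChange(webrtc::DtlsTransportInformation info) override {
    // Once the state reaches kConnected, |info| also carries the negotiated
    // TLS version, cipher suites and the remote certificate chain.
    RTC_LOG(LS_INFO) << "DTLS state changed, connected="
                     << (info.state() ==
                         webrtc::DtlsTransportState::kConnected);
  }
  void OnError(webrtc::RTCError error) override {
    RTC_LOG(LS_ERROR) << "DTLS transport error: " << error.message();
  }
};

// |transport| is typically obtained from an RtpSender/RtpReceiver or from the
// SctpTransport; the observer must stay alive until it is unregistered.
void ObserveDtls(rtc::scoped_refptr<webrtc::DtlsTransportInterface> transport,
                 LoggingDtlsObserver* observer) {
  transport->RegisterObserver(observer);
}
```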
diff --git a/pc/g3doc/peer_connection.md b/pc/g3doc/peer_connection.md
new file mode 100644
index 0000000000..1eae135991
--- /dev/null
+++ b/pc/g3doc/peer_connection.md
@@ -0,0 +1,59 @@
+
+
+
+# PeerConnection and friends
+
+The PeerConnection is the C++-level implementation of the JavaScript
+object "RTCPeerConnection" from the
+[WebRTC specification](https://w3c.github.io/webrtc-pc/).
+
+Like many objects in WebRTC, the PeerConnection is used via a factory and an
+observer:
+
+ * PeerConnectionFactory, which is created via a static Create method and takes
+   a PeerConnectionFactoryDependencies structure listing such things as
+   non-default threads and factories for use by all PeerConnections using
+   the same factory. (Using more than one factory should be avoided, since
+   it takes more resources.)
+ * PeerConnection itself, which is created by the method called
+   PeerConnectionFactory::CreatePeerConnectionOrError, and takes a
+   PeerConnectionInterface::RTCConfiguration argument, as well as
+   a PeerConnectionDependencies (even more factories, plus other stuff).
+ * PeerConnectionObserver (a member of PeerConnectionDependencies), which
+   contains the functions that will be called on events in the PeerConnection.
+
+These types are visible in the API.
+
+## Internal structure of PeerConnection and friends
+
+The PeerConnection is, to a large extent, a "God object" - most things
+that are done in WebRTC require a PeerConnection.
+
+Internally, it is divided into several objects, each with its own
+responsibilities, all of which are owned by the PeerConnection and live
+as long as the PeerConnection:
+
+ * SdpOfferAnswerHandler takes care of negotiating configurations with
+   a remote peer, using SDP-formatted descriptions.
+ * RtpTransmissionManager takes care of the lists of RtpSenders,
+   RtpReceivers and RtpTransceivers that form the heart of the transmission
+   service.
+ * DataChannelController takes care of managing the PeerConnection's
+   DataChannels and its SctpTransport.
+ * JsepTransportController takes care of configuring the details of senders
+   and receivers.
+ * Call manages overall call state.
+ * RtcStatsCollector (and its obsolete sibling, StatsCollector) collects
+   statistics from all the objects comprising the PeerConnection when
+   requested.
+
+There are a number of other smaller objects that are also owned by
+the PeerConnection, but it would take too much space to describe them
+all here; please consult the .h files.
+
+PeerConnectionFactory owns an object called ConnectionContext, and a
+reference to this is passed to each PeerConnection. It is referenced
+via an rtc::scoped_refptr, which means that it is guaranteed to be
+alive as long as either the factory or one of the PeerConnections
+is using it.
+
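A minimal sketch of the creation flow described above (dependencies, factory, configuration, observer). It assumes that a default-constructed `PeerConnectionFactoryDependencies` is acceptable for a media-less, data-channel-style factory; a real application would also populate threads, codec factories and an audio device module. The `SimpleObserver` and `CreateExamplePeerConnection` names are illustrative.

```cpp
#include <utility>

#include "api/peer_connection_interface.h"
#include "api/scoped_refptr.h"
#include "rtc_base/logging.h"

namespace {

// Minimal observer: only the pure-virtual callbacks are overridden.
class SimpleObserver : public webrtc::PeerConnectionObserver {
 public:
  void OnSignalingChange(
      webrtc::PeerConnectionInterface::SignalingState new_state) override {}
  void OnDataChannel(
      rtc::scoped_refptr<webrtc::DataChannelInterface> channel) override {}
  void OnIceGatheringChange(
      webrtc::PeerConnectionInterface::IceGatheringState new_state) override {}
  void OnIceCandidate(
      const webrtc::IceCandidateInterface* candidate) override {}
};

rtc::scoped_refptr<webrtc::PeerConnectionInterface>
CreateExamplePeerConnection(webrtc::PeerConnectionObserver* observer) {
  // One factory per application, configured through the dependencies struct.
  webrtc::PeerConnectionFactoryDependencies factory_deps;
  rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface> factory =
      webrtc::CreateModularPeerConnectionFactory(std::move(factory_deps));

  // Per-connection configuration and dependencies.
  webrtc::PeerConnectionInterface::RTCConfiguration config;
  config.sdp_semantics = webrtc::SdpSemantics::kUnifiedPlan;
  webrtc::PeerConnectionDependencies pc_deps(observer);

  auto result =
      factory->CreatePeerConnectionOrError(config, std::move(pc_deps));
  if (!result.ok()) {
    RTC_LOG(LS_ERROR) << "CreatePeerConnectionOrError failed: "
                      << result.error().message();
    return nullptr;
  }
  return result.MoveValue();
}

}  // namespace
```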
diff --git a/pc/g3doc/rtp.md b/pc/g3doc/rtp.md
new file mode 100644
index 0000000000..38c1702ad3
--- /dev/null
+++ b/pc/g3doc/rtp.md
@@ -0,0 +1,56 @@
+
+
+
+# RTP in WebRTC
+
+WebRTC uses the RTP protocol described in
+[RFC3550](https://datatracker.ietf.org/doc/html/rfc3550) for transporting audio
+and video. Media is encrypted using [SRTP](./srtp.md).
+
+## Allocation of payload types
+
+RTP packets have a payload type field that describes which media codec can be
+used to handle a packet. For some (older) codecs like PCMU the payload type is
+assigned statically as described in
+[RFC3551](https://datatracker.ietf.org/doc/html/rfc3551). For others, it is
+assigned dynamically through the SDP. **Note:** there are no guarantees on the
+stability of a payload type assignment.
+
+For this allocation, the range from 96 to 127 is used. When this range is
+exhausted, the allocation falls back to the range from 35 to 63 as permitted by
+[section 5.1 of RFC3550][1]. Note that older versions of WebRTC failed to
+recognize payload types in the lower range. Newer codecs (such as flexfec-03 and
+AV1) will by default be allocated in that range.
+
+Payload types in the range 64 to 95 are not used to avoid confusion with RTCP as
+described in [RFC5761](https://datatracker.ietf.org/doc/html/rfc5761).
+
+## Allocation of audio payload types
+
+Audio payload types are assigned from a table by the [PayloadTypeMapper][2]
+class. New audio codecs should be allocated in the lower dynamic range [35,63],
+starting at 63, to reduce collisions with payload types.
+
+## Allocation of video payload types
+
+Video payload types are allocated by the
+[GetPayloadTypesAndDefaultCodecs method][3]. The set of codecs depends on the
+platform, in particular for H264 codecs and their different profiles. Payload
+numbers are assigned ascending from 96 for video codecs and their
+[associated retransmission format](https://datatracker.ietf.org/doc/html/rfc4588).
+Some codecs like flexfec-03 and AV1 are assigned to the lower range [35,63] for
+reasons explained above. When the upper range [96,127] is exhausted, payload
+types are assigned to the lower range [35,63], starting at 35.
+
+## Handling of payload type collisions
+
+Due to the requirement that payload types must be uniquely identifiable when
+using [BUNDLE](https://datatracker.ietf.org/doc/html/rfc8829), collisions between
+the assignments of the audio and video payload types may arise. These are
+resolved by the [UsedPayloadTypes][4] class, which will reassign payload type
+numbers descending from 127.
+
+[1]: https://datatracker.ietf.org/doc/html/rfc3550#section-5.1
+[2]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/media/engine/payload_type_mapper.cc;l=25;drc=4f26a3c7e8e20e0e0ca4ca67a6ebdf3f5543dc3f
+[3]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/media/engine/webrtc_video_engine.cc;l=119;drc=b412efdb780c86e6530493afa403783d14985347
+[4]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/used_ids.h;l=94;drc=b412efdb780c86e6530493afa403783d14985347
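The allocation policy described above (ascending from 96 in [96,127], then falling back to [35,63], never touching [64,95]) can be illustrated with a small self-contained helper. This is only a sketch of the documented policy, not the actual `PayloadTypeMapper` or `UsedPayloadTypes` implementation.

```cpp
#include <set>

// Illustration of the dynamic payload type policy described above: allocate
// ascending in [96,127]; when that range is exhausted, fall back to [35,63];
// never hand out [64,95] (reserved to avoid confusion with RTCP). Returns -1
// when both ranges are exhausted.
int AllocateDynamicPayloadType(std::set<int>& used) {
  for (int pt = 96; pt <= 127; ++pt) {
    if (used.insert(pt).second) return pt;
  }
  for (int pt = 35; pt <= 63; ++pt) {
    if (used.insert(pt).second) return pt;
  }
  return -1;  // No dynamic payload type left.
}

// Usage: std::set<int> used; AllocateDynamicPayloadType(used) returns 96,
// then 97, ..., 127, then 35, 36, ... until the ranges are exhausted.
```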
diff --git a/pc/g3doc/sctp_transport.md b/pc/g3doc/sctp_transport.md
new file mode 100644
index 0000000000..254e264b0b
--- /dev/null
+++ b/pc/g3doc/sctp_transport.md
@@ -0,0 +1,44 @@
+
+
+
+
+# SctpTransport
+
+## webrtc::SctpTransport
+
+The [`webrtc::SctpTransport`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/pc/sctp_transport.h;l=33?q=class%20webrtc::SctpTransport) class encapsulates an SCTP association, and exposes a
+few properties of this association to the WebRTC user (such as Chrome).
+
+The SctpTransport is used to support data channels, as described in the [WebRTC
+specification for the Peer-to-peer Data
+API](https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api).
+
+The public interface ([`webrtc::SctpTransportInterface`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/sctp_transport_interface.h?q=webrtc::SctpTransportInterface)) exposes an observer
+interface where the user can define a callback to be called whenever the state
+of an SctpTransport changes; this callback is called on the network thread (as
+set during PeerConnectionFactory initialization).
+
+The implementation of this object lives in pc/sctp_transport.{h,cc}, and is
+basically a wrapper around a `cricket::SctpTransportInternal`, hiding its
+implementation details and APIs that shouldn't be accessed from the user.
+
+The `webrtc::SctpTransport` is a ref-counted object; it should be regarded
+as owned by the PeerConnection, and will be closed when the PeerConnection
+closes, but the object itself may survive longer than the PeerConnection.
+
+## cricket::SctpTransportInternal
+
+[`cricket::SctpTransportInternal`](https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/media/sctp/sctp_transport_internal.h?q=cricket::SctpTransportInternal) owns two objects: the SCTP association object (currently
+implemented by wrapping the usrsctp library) and the DTLS transport, which is
+the object used to send and receive messages as emitted from or consumed by the
+usrsctp library.
+
+It communicates state changes and events using sigslot.
+
+See header files for details.
+
+
+
+
+
+
diff --git a/pc/g3doc/srtp.md b/pc/g3doc/srtp.md
new file mode 100644
index 0000000000..47446157c9
--- /dev/null
+++ b/pc/g3doc/srtp.md
@@ -0,0 +1,72 @@
+
+
+
+# SRTP in WebRTC
+
+WebRTC mandates encryption of media by means of the Secure Real-time Transport
+Protocol, or SRTP, which is described in
+[RFC 3711](https://datatracker.ietf.org/doc/html/rfc3711).
+
+The key negotiation in WebRTC happens using DTLS-SRTP, which is described in
+[RFC 5764](https://datatracker.ietf.org/doc/html/rfc5764). The older
+[SDES protocol](https://datatracker.ietf.org/doc/html/rfc4568) is implemented
+but not enabled by default.
+
+Unencrypted RTP can be enabled for debugging purposes by setting the
+PeerConnection's [`disable_encryption`][1] option to true.
+
+## Supported cipher suites
+
+The implementation supports the following cipher suites:
+
+* SRTP_AES128_CM_HMAC_SHA1_80
+* SRTP_AEAD_AES_128_GCM
+* SRTP_AEAD_AES_256_GCM
+
+The SRTP_AES128_CM_HMAC_SHA1_32 cipher suite is accepted for audio-only
+connections if offered by the other side. It is not actively supported; see
+[SelectCrypto][2] for details.
+
+The cipher suite ordering allows a non-WebRTC peer to prefer GCM cipher suites;
+however, they are not selected by default between two instances of the WebRTC library.
+
+## cricket::SrtpSession
+
+The [`cricket::SrtpSession`][3] provides encryption and decryption of SRTP
+packets using [`libsrtp`](https://github.com/cisco/libsrtp). Keys are
+provided by `SrtpTransport` or `DtlsSrtpTransport` in the [`SetSend`][4] and
+[`SetRecv`][5] methods.
+
+Encryption and decryption happen in-place in the [`ProtectRtp`][6],
+[`ProtectRtcp`][7], [`UnprotectRtp`][8] and [`UnprotectRtcp`][9] methods. The
+`SrtpSession` class also takes care of initializing and deinitializing `libsrtp`
+by keeping track of how many instances are being used.
+
+## webrtc::SrtpTransport and webrtc::DtlsSrtpTransport
+
+The [`webrtc::SrtpTransport`][10] class controls the `SrtpSession`
+instances for RTP and RTCP. When
+[rtcp-mux](https://datatracker.ietf.org/doc/html/rfc5761) is used, the
+`SrtpSession` for RTCP is not needed.
+
+[`webrtc::DtlsSrtpTransport`][11] is a subclass of the `SrtpTransport` that
+extracts the keying material when the DTLS handshake completes and configures it
+in its base class. It also becomes writable only once the DTLS handshake is
+done.
+
+## cricket::SrtpFilter
+
+The [`cricket::SrtpFilter`][12] class is used to negotiate SDES.
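This patch also replaces the sigslot `SignalDtlsStateChange` on `DtlsSrtpTransport` with a callback setter (see the `pc/dtls_srtp_transport.*` hunks earlier in the diff). Below is a minimal sketch of how an owning object might wire it up, assuming the `std::function<void(void)>` signature whose template arguments are elided in that hunk; the `TransportOwner` class and method names are illustrative only.

```cpp
#include <functional>

#include "pc/dtls_srtp_transport.h"
#include "rtc_base/logging.h"

// Illustrative owner of a DtlsSrtpTransport. With the sigslot signal removed,
// the owner registers a plain std::function callback instead of calling
// SignalDtlsStateChange.connect(...).
class TransportOwner {
 public:
  explicit TransportOwner(webrtc::DtlsSrtpTransport* transport)
      : transport_(transport) {
    // Assumed signature: void SetOnDtlsStateChange(std::function<void(void)>).
    transport_->SetOnDtlsStateChange([this] { OnDtlsStateChanged(); });
  }

 private:
  void OnDtlsStateChanged() {
    // Invoked whenever the underlying DTLS transport reports a state change;
    // the transport resets its SRTP parameters itself when DTLS leaves the
    // connected state.
    RTC_LOG(LS_INFO) << "DTLS state changed; SRTP params may have been reset.";
  }

  webrtc::DtlsSrtpTransport* const transport_;
};
```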
+ +[1]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/api/peer_connection_interface.h;l=1413;drc=f467b445631189557d44de86a77ca6a0c3e2108d +[2]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/media_session.cc;l=297;drc=3ac73bd0aa5322abee98f1ff8705af64a184bf61 +[3]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=33;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[4]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=40;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[5]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=51;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[6]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=62;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[7]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=69;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[8]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=72;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[9]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_session.h;l=73;drc=be66d95ab7f9428028806bbf66cb83800bda9241 +[10]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_transport.h;l=37;drc=a4d873786f10eedd72de25ad0d94ad7c53c1f68a +[11]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/dtls_srtp_transport.h;l=31;drc=2f8e0536eb97ce2131e7a74e3ca06077aa0b64b3 +[12]: https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/pc/srtp_filter.h;drc=d15a575ec3528c252419149d35977e55269d8a41 diff --git a/pc/ice_server_parsing.cc b/pc/ice_server_parsing.cc index 47641375de..0daf8e445d 100644 --- a/pc/ice_server_parsing.cc +++ b/pc/ice_server_parsing.cc @@ -12,7 +12,9 @@ #include +#include #include // For std::isdigit. 
+#include #include #include "p2p/base/port_interface.h" @@ -21,6 +23,7 @@ #include "rtc_base/ip_address.h" #include "rtc_base/logging.h" #include "rtc_base/socket_address.h" +#include "rtc_base/string_encode.h" namespace webrtc { diff --git a/pc/ice_transport.cc b/pc/ice_transport.cc index ccc5ecd7f2..205846755d 100644 --- a/pc/ice_transport.cc +++ b/pc/ice_transport.cc @@ -10,8 +10,7 @@ #include "pc/ice_transport.h" -#include -#include +#include "api/sequence_checker.h" namespace webrtc { diff --git a/pc/ice_transport.h b/pc/ice_transport.h index c1529de6b7..11f3de5d27 100644 --- a/pc/ice_transport.h +++ b/pc/ice_transport.h @@ -12,8 +12,10 @@ #define PC_ICE_TRANSPORT_H_ #include "api/ice_transport_interface.h" +#include "api/sequence_checker.h" +#include "rtc_base/checks.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { diff --git a/pc/ice_transport_unittest.cc b/pc/ice_transport_unittest.cc index 3711a86d5d..ebb46cb5d5 100644 --- a/pc/ice_transport_unittest.cc +++ b/pc/ice_transport_unittest.cc @@ -28,9 +28,8 @@ class IceTransportTest : public ::testing::Test {}; TEST_F(IceTransportTest, CreateNonSelfDeletingTransport) { auto cricket_transport = std::make_unique("name", 0, nullptr); - rtc::scoped_refptr ice_transport = - new rtc::RefCountedObject( - cricket_transport.get()); + auto ice_transport = + rtc::make_ref_counted(cricket_transport.get()); EXPECT_EQ(ice_transport->internal(), cricket_transport.get()); ice_transport->Clear(); EXPECT_NE(ice_transport->internal(), cricket_transport.get()); diff --git a/pc/jitter_buffer_delay.cc b/pc/jitter_buffer_delay.cc index c9506b3c59..801cef7215 100644 --- a/pc/jitter_buffer_delay.cc +++ b/pc/jitter_buffer_delay.cc @@ -10,13 +10,10 @@ #include "pc/jitter_buffer_delay.h" +#include "api/sequence_checker.h" #include "rtc_base/checks.h" -#include "rtc_base/location.h" -#include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/numerics/safe_minmax.h" -#include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace { constexpr int kDefaultDelay = 0; @@ -25,43 +22,21 @@ constexpr int kMaximumDelayMs = 10000; namespace webrtc { -JitterBufferDelay::JitterBufferDelay(rtc::Thread* worker_thread) - : signaling_thread_(rtc::Thread::Current()), worker_thread_(worker_thread) { - RTC_DCHECK(worker_thread_); -} - -void JitterBufferDelay::OnStart(cricket::Delayable* media_channel, - uint32_t ssrc) { - RTC_DCHECK_RUN_ON(signaling_thread_); - - media_channel_ = media_channel; - ssrc_ = ssrc; - - // Trying to apply cached delay for the audio stream. - if (cached_delay_seconds_) { - Set(cached_delay_seconds_.value()); - } -} - -void JitterBufferDelay::OnStop() { - RTC_DCHECK_RUN_ON(signaling_thread_); - // Assume that audio stream is no longer present. - media_channel_ = nullptr; - ssrc_ = absl::nullopt; +JitterBufferDelay::JitterBufferDelay() { + worker_thread_checker_.Detach(); } void JitterBufferDelay::Set(absl::optional delay_seconds) { - RTC_DCHECK_RUN_ON(worker_thread_); - - // TODO(kuddai) propagate absl::optional deeper down as default preference. 
- int delay_ms = - rtc::saturated_cast(delay_seconds.value_or(kDefaultDelay) * 1000); - delay_ms = rtc::SafeClamp(delay_ms, 0, kMaximumDelayMs); - + RTC_DCHECK_RUN_ON(&worker_thread_checker_); cached_delay_seconds_ = delay_seconds; - if (media_channel_ && ssrc_) { - media_channel_->SetBaseMinimumPlayoutDelayMs(ssrc_.value(), delay_ms); - } +} + +int JitterBufferDelay::GetMs() const { + RTC_DCHECK_RUN_ON(&worker_thread_checker_); + return rtc::SafeClamp( + rtc::saturated_cast(cached_delay_seconds_.value_or(kDefaultDelay) * + 1000), + 0, kMaximumDelayMs); } } // namespace webrtc diff --git a/pc/jitter_buffer_delay.h b/pc/jitter_buffer_delay.h index 8edfc6ce20..dc10e3d2ba 100644 --- a/pc/jitter_buffer_delay.h +++ b/pc/jitter_buffer_delay.h @@ -14,36 +14,25 @@ #include #include "absl/types/optional.h" -#include "media/base/delayable.h" -#include "pc/jitter_buffer_delay_interface.h" -#include "rtc_base/thread.h" +#include "api/sequence_checker.h" +#include "rtc_base/system/no_unique_address.h" namespace webrtc { // JitterBufferDelay converts delay from seconds to milliseconds for the // underlying media channel. It also handles cases when user sets delay before -// the start of media_channel by caching its request. Note, this class is not -// thread safe. Its thread safe version is defined in -// pc/jitter_buffer_delay_proxy.h -class JitterBufferDelay : public JitterBufferDelayInterface { +// the start of media_channel by caching its request. +class JitterBufferDelay { public: - // Must be called on signaling thread. - explicit JitterBufferDelay(rtc::Thread* worker_thread); + JitterBufferDelay(); - void OnStart(cricket::Delayable* media_channel, uint32_t ssrc) override; - - void OnStop() override; - - void Set(absl::optional delay_seconds) override; + void Set(absl::optional delay_seconds); + int GetMs() const; private: - // Throughout webrtc source, sometimes it is also called as |main_thread_|. - rtc::Thread* const signaling_thread_; - rtc::Thread* const worker_thread_; - // Media channel and ssrc together uniqely identify audio stream. - cricket::Delayable* media_channel_ = nullptr; - absl::optional ssrc_; - absl::optional cached_delay_seconds_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker worker_thread_checker_; + absl::optional cached_delay_seconds_ + RTC_GUARDED_BY(&worker_thread_checker_); }; } // namespace webrtc diff --git a/pc/jitter_buffer_delay_interface.h b/pc/jitter_buffer_delay_interface.h deleted file mode 100644 index f2132d318d..0000000000 --- a/pc/jitter_buffer_delay_interface.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef PC_JITTER_BUFFER_DELAY_INTERFACE_H_ -#define PC_JITTER_BUFFER_DELAY_INTERFACE_H_ - -#include - -#include "absl/types/optional.h" -#include "media/base/delayable.h" -#include "rtc_base/ref_count.h" - -namespace webrtc { - -// JitterBufferDelay delivers user's queries to the underlying media channel. It -// can describe either video or audio delay for receiving stream. 
"Interface" -// suffix in the interface name is required to be compatible with api/proxy.cc -class JitterBufferDelayInterface : public rtc::RefCountInterface { - public: - // OnStart allows to uniqely identify to which receiving stream playout - // delay must correpond through |media_channel| and |ssrc| pair. - virtual void OnStart(cricket::Delayable* media_channel, uint32_t ssrc) = 0; - - // Indicates that underlying receiving stream is stopped. - virtual void OnStop() = 0; - - virtual void Set(absl::optional delay_seconds) = 0; -}; - -} // namespace webrtc - -#endif // PC_JITTER_BUFFER_DELAY_INTERFACE_H_ diff --git a/pc/jitter_buffer_delay_proxy.h b/pc/jitter_buffer_delay_proxy.h deleted file mode 100644 index b3380fd258..0000000000 --- a/pc/jitter_buffer_delay_proxy.h +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef PC_JITTER_BUFFER_DELAY_PROXY_H_ -#define PC_JITTER_BUFFER_DELAY_PROXY_H_ - -#include - -#include "api/proxy.h" -#include "media/base/delayable.h" -#include "pc/jitter_buffer_delay_interface.h" - -namespace webrtc { - -BEGIN_PROXY_MAP(JitterBufferDelay) -PROXY_SIGNALING_THREAD_DESTRUCTOR() -PROXY_METHOD2(void, OnStart, cricket::Delayable*, uint32_t) -PROXY_METHOD0(void, OnStop) -PROXY_WORKER_METHOD1(void, Set, absl::optional) -END_PROXY_MAP() - -} // namespace webrtc - -#endif // PC_JITTER_BUFFER_DELAY_PROXY_H_ diff --git a/pc/jitter_buffer_delay_unittest.cc b/pc/jitter_buffer_delay_unittest.cc index 7edd09acd2..b00075ceb5 100644 --- a/pc/jitter_buffer_delay_unittest.cc +++ b/pc/jitter_buffer_delay_unittest.cc @@ -13,79 +13,47 @@ #include #include "absl/types/optional.h" -#include "api/scoped_refptr.h" -#include "pc/test/mock_delayable.h" -#include "rtc_base/ref_counted_object.h" -#include "rtc_base/thread.h" -#include "test/gmock.h" #include "test/gtest.h" -using ::testing::Return; - -namespace { -constexpr int kSsrc = 1234; -} // namespace - namespace webrtc { class JitterBufferDelayTest : public ::testing::Test { public: - JitterBufferDelayTest() - : delay_(new rtc::RefCountedObject( - rtc::Thread::Current())) {} + JitterBufferDelayTest() {} protected: - rtc::scoped_refptr delay_; - MockDelayable delayable_; + JitterBufferDelay delay_; }; TEST_F(JitterBufferDelayTest, Set) { - delay_->OnStart(&delayable_, kSsrc); - - EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 3000)) - .WillOnce(Return(true)); - // Delay in seconds. - delay_->Set(3.0); + delay_.Set(3.0); + EXPECT_EQ(delay_.GetMs(), 3000); } -TEST_F(JitterBufferDelayTest, Caching) { - // Check that value is cached before start. - delay_->Set(4.0); - - // Check that cached value applied on the start. - EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 4000)) - .WillOnce(Return(true)); - delay_->OnStart(&delayable_, kSsrc); +TEST_F(JitterBufferDelayTest, DefaultValue) { + EXPECT_EQ(delay_.GetMs(), 0); // Default value is 0ms. } TEST_F(JitterBufferDelayTest, Clamping) { - delay_->OnStart(&delayable_, kSsrc); - // In current Jitter Buffer implementation (Audio or Video) maximum supported // value is 10000 milliseconds. 
- EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 10000)) - .WillOnce(Return(true)); - delay_->Set(10.5); + delay_.Set(10.5); + EXPECT_EQ(delay_.GetMs(), 10000); // Test int overflow. - EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 10000)) - .WillOnce(Return(true)); - delay_->Set(21474836470.0); + delay_.Set(21474836470.0); + EXPECT_EQ(delay_.GetMs(), 10000); - EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 0)) - .WillOnce(Return(true)); - delay_->Set(-21474836470.0); + delay_.Set(-21474836470.0); + EXPECT_EQ(delay_.GetMs(), 0); // Boundary value in seconds to milliseconds conversion. - EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 0)) - .WillOnce(Return(true)); - delay_->Set(0.0009); - - EXPECT_CALL(delayable_, SetBaseMinimumPlayoutDelayMs(kSsrc, 0)) - .WillOnce(Return(true)); + delay_.Set(0.0009); + EXPECT_EQ(delay_.GetMs(), 0); - delay_->Set(-2.0); + delay_.Set(-2.0); + EXPECT_EQ(delay_.GetMs(), 0); } } // namespace webrtc diff --git a/pc/jsep_transport.cc b/pc/jsep_transport.cc index 2f7615ab3b..e72088885f 100644 --- a/pc/jsep_transport.cc +++ b/pc/jsep_transport.cc @@ -14,7 +14,6 @@ #include #include -#include #include // for std::pair #include "api/array_view.h" @@ -25,7 +24,9 @@ #include "rtc_base/checks.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/logging.h" +#include "rtc_base/ref_counted_object.h" #include "rtc_base/strings/string_builder.h" +#include "rtc_base/trace_event.h" using webrtc::SdpType; @@ -77,7 +78,6 @@ JsepTransport::JsepTransport( std::unique_ptr unencrypted_rtp_transport, std::unique_ptr sdes_transport, std::unique_ptr dtls_srtp_transport, - std::unique_ptr datagram_rtp_transport, std::unique_ptr rtp_dtls_transport, std::unique_ptr rtcp_dtls_transport, std::unique_ptr sctp_transport) @@ -89,23 +89,23 @@ JsepTransport::JsepTransport( unencrypted_rtp_transport_(std::move(unencrypted_rtp_transport)), sdes_transport_(std::move(sdes_transport)), dtls_srtp_transport_(std::move(dtls_srtp_transport)), - rtp_dtls_transport_( - rtp_dtls_transport ? new rtc::RefCountedObject( - std::move(rtp_dtls_transport)) - : nullptr), - rtcp_dtls_transport_( - rtcp_dtls_transport - ? new rtc::RefCountedObject( - std::move(rtcp_dtls_transport)) - : nullptr), + rtp_dtls_transport_(rtp_dtls_transport + ? rtc::make_ref_counted( + std::move(rtp_dtls_transport)) + : nullptr), + rtcp_dtls_transport_(rtcp_dtls_transport + ? rtc::make_ref_counted( + std::move(rtcp_dtls_transport)) + : nullptr), sctp_data_channel_transport_( sctp_transport ? std::make_unique( sctp_transport.get()) : nullptr), sctp_transport_(sctp_transport - ? new rtc::RefCountedObject( + ? 
rtc::make_ref_counted( std::move(sctp_transport)) : nullptr) { + TRACE_EVENT0("webrtc", "JsepTransport::JsepTransport"); RTC_DCHECK(ice_transport_); RTC_DCHECK(rtp_dtls_transport_); // |rtcp_ice_transport_| must be present iff |rtcp_dtls_transport_| is @@ -128,15 +128,10 @@ JsepTransport::JsepTransport( if (sctp_transport_) { sctp_transport_->SetDtlsTransport(rtp_dtls_transport_); } - - if (datagram_rtp_transport_ && default_rtp_transport()) { - composite_rtp_transport_ = std::make_unique( - std::vector{ - datagram_rtp_transport_.get(), default_rtp_transport()}); - } } JsepTransport::~JsepTransport() { + TRACE_EVENT0("webrtc", "JsepTransport::~JsepTransport"); if (sctp_transport_) { sctp_transport_->Clear(); } @@ -155,7 +150,7 @@ webrtc::RTCError JsepTransport::SetLocalJsepTransportDescription( const JsepTransportDescription& jsep_description, SdpType type) { webrtc::RTCError error; - + TRACE_EVENT0("webrtc", "JsepTransport::SetLocalJsepTransportDescription"); RTC_DCHECK_RUN_ON(network_thread_); IceParameters ice_parameters = @@ -175,23 +170,20 @@ webrtc::RTCError JsepTransport::SetLocalJsepTransportDescription( } // If doing SDES, setup the SDES crypto parameters. - { - rtc::CritScope scope(&accessor_lock_); - if (sdes_transport_) { - RTC_DCHECK(!unencrypted_rtp_transport_); - RTC_DCHECK(!dtls_srtp_transport_); - if (!SetSdes(jsep_description.cryptos, - jsep_description.encrypted_header_extension_ids, type, - ContentSource::CS_LOCAL)) { - return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, - "Failed to setup SDES crypto parameters."); - } - } else if (dtls_srtp_transport_) { - RTC_DCHECK(!unencrypted_rtp_transport_); - RTC_DCHECK(!sdes_transport_); - dtls_srtp_transport_->UpdateRecvEncryptedHeaderExtensionIds( - jsep_description.encrypted_header_extension_ids); + if (sdes_transport_) { + RTC_DCHECK(!unencrypted_rtp_transport_); + RTC_DCHECK(!dtls_srtp_transport_); + if (!SetSdes(jsep_description.cryptos, + jsep_description.encrypted_header_extension_ids, type, + ContentSource::CS_LOCAL)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to setup SDES crypto parameters."); } + } else if (dtls_srtp_transport_) { + RTC_DCHECK(!unencrypted_rtp_transport_); + RTC_DCHECK(!sdes_transport_); + dtls_srtp_transport_->UpdateRecvEncryptedHeaderExtensionIds( + jsep_description.encrypted_header_extension_ids); } bool ice_restarting = local_description_ != nullptr && @@ -212,18 +204,17 @@ webrtc::RTCError JsepTransport::SetLocalJsepTransportDescription( return error; } } - { - rtc::CritScope scope(&accessor_lock_); RTC_DCHECK(rtp_dtls_transport_->internal()); rtp_dtls_transport_->internal()->ice_transport()->SetIceParameters( ice_parameters); - if (rtcp_dtls_transport_) { - RTC_DCHECK(rtcp_dtls_transport_->internal()); - rtcp_dtls_transport_->internal()->ice_transport()->SetIceParameters( - ice_parameters); + { + if (rtcp_dtls_transport_) { + RTC_DCHECK(rtcp_dtls_transport_->internal()); + rtcp_dtls_transport_->internal()->ice_transport()->SetIceParameters( + ice_parameters); + } } - } // If PRANSWER/ANSWER is set, we should decide transport protocol type. 
if (type == SdpType::kPrAnswer || type == SdpType::kAnswer) { error = NegotiateAndSetDtlsParameters(type); @@ -232,13 +223,11 @@ webrtc::RTCError JsepTransport::SetLocalJsepTransportDescription( local_description_.reset(); return error; } - { - rtc::CritScope scope(&accessor_lock_); - if (needs_ice_restart_ && ice_restarting) { - needs_ice_restart_ = false; - RTC_LOG(LS_VERBOSE) << "needs-ice-restart flag cleared for transport " - << mid(); - } + + if (needs_ice_restart_ && ice_restarting) { + needs_ice_restart_ = false; + RTC_LOG(LS_VERBOSE) << "needs-ice-restart flag cleared for transport " + << mid(); } return webrtc::RTCError::OK(); @@ -247,6 +236,7 @@ webrtc::RTCError JsepTransport::SetLocalJsepTransportDescription( webrtc::RTCError JsepTransport::SetRemoteJsepTransportDescription( const JsepTransportDescription& jsep_description, webrtc::SdpType type) { + TRACE_EVENT0("webrtc", "JsepTransport::SetLocalJsepTransportDescription"); webrtc::RTCError error; RTC_DCHECK_RUN_ON(network_thread_); @@ -269,27 +259,24 @@ webrtc::RTCError JsepTransport::SetRemoteJsepTransportDescription( } // If doing SDES, setup the SDES crypto parameters. - { - rtc::CritScope lock(&accessor_lock_); - if (sdes_transport_) { - RTC_DCHECK(!unencrypted_rtp_transport_); - RTC_DCHECK(!dtls_srtp_transport_); - if (!SetSdes(jsep_description.cryptos, - jsep_description.encrypted_header_extension_ids, type, - ContentSource::CS_REMOTE)) { - return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, - "Failed to setup SDES crypto parameters."); - } - sdes_transport_->CacheRtpAbsSendTimeHeaderExtension( - jsep_description.rtp_abs_sendtime_extn_id); - } else if (dtls_srtp_transport_) { - RTC_DCHECK(!unencrypted_rtp_transport_); - RTC_DCHECK(!sdes_transport_); - dtls_srtp_transport_->UpdateSendEncryptedHeaderExtensionIds( - jsep_description.encrypted_header_extension_ids); - dtls_srtp_transport_->CacheRtpAbsSendTimeHeaderExtension( - jsep_description.rtp_abs_sendtime_extn_id); + if (sdes_transport_) { + RTC_DCHECK(!unencrypted_rtp_transport_); + RTC_DCHECK(!dtls_srtp_transport_); + if (!SetSdes(jsep_description.cryptos, + jsep_description.encrypted_header_extension_ids, type, + ContentSource::CS_REMOTE)) { + return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, + "Failed to setup SDES crypto parameters."); } + sdes_transport_->CacheRtpAbsSendTimeHeaderExtension( + jsep_description.rtp_abs_sendtime_extn_id); + } else if (dtls_srtp_transport_) { + RTC_DCHECK(!unencrypted_rtp_transport_); + RTC_DCHECK(!sdes_transport_); + dtls_srtp_transport_->UpdateSendEncryptedHeaderExtensionIds( + jsep_description.encrypted_header_extension_ids); + dtls_srtp_transport_->CacheRtpAbsSendTimeHeaderExtension( + jsep_description.rtp_abs_sendtime_extn_id); } remote_description_.reset(new JsepTransportDescription(jsep_description)); @@ -341,7 +328,7 @@ webrtc::RTCError JsepTransport::AddRemoteCandidates( } void JsepTransport::SetNeedsIceRestartFlag() { - rtc::CritScope scope(&accessor_lock_); + RTC_DCHECK_RUN_ON(network_thread_); if (!needs_ice_restart_) { needs_ice_restart_ = true; RTC_LOG(LS_VERBOSE) << "needs-ice-restart flag set for transport " << mid(); @@ -350,7 +337,6 @@ void JsepTransport::SetNeedsIceRestartFlag() { absl::optional JsepTransport::GetDtlsRole() const { RTC_DCHECK_RUN_ON(network_thread_); - rtc::CritScope scope(&accessor_lock_); RTC_DCHECK(rtp_dtls_transport_); RTC_DCHECK(rtp_dtls_transport_->internal()); rtc::SSLRole dtls_role; @@ -362,15 +348,18 @@ absl::optional JsepTransport::GetDtlsRole() const { } bool 
JsepTransport::GetStats(TransportStats* stats) { + TRACE_EVENT0("webrtc", "JsepTransport::GetStats"); RTC_DCHECK_RUN_ON(network_thread_); - rtc::CritScope scope(&accessor_lock_); stats->transport_name = mid(); stats->channel_stats.clear(); RTC_DCHECK(rtp_dtls_transport_->internal()); - bool ret = GetTransportStats(rtp_dtls_transport_->internal(), stats); + bool ret = GetTransportStats(rtp_dtls_transport_->internal(), + ICE_CANDIDATE_COMPONENT_RTP, stats); + if (rtcp_dtls_transport_) { RTC_DCHECK(rtcp_dtls_transport_->internal()); - ret &= GetTransportStats(rtcp_dtls_transport_->internal(), stats); + ret &= GetTransportStats(rtcp_dtls_transport_->internal(), + ICE_CANDIDATE_COMPONENT_RTCP, stats); } return ret; } @@ -378,6 +367,7 @@ bool JsepTransport::GetStats(TransportStats* stats) { webrtc::RTCError JsepTransport::VerifyCertificateFingerprint( const rtc::RTCCertificate* certificate, const rtc::SSLFingerprint* fingerprint) const { + TRACE_EVENT0("webrtc", "JsepTransport::VerifyCertificateFingerprint"); RTC_DCHECK_RUN_ON(network_thread_); if (!fingerprint) { return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, @@ -405,7 +395,6 @@ webrtc::RTCError JsepTransport::VerifyCertificateFingerprint( void JsepTransport::SetActiveResetSrtpParams(bool active_reset_srtp_params) { RTC_DCHECK_RUN_ON(network_thread_); - rtc::CritScope scope(&accessor_lock_); if (dtls_srtp_transport_) { RTC_LOG(INFO) << "Setting active_reset_srtp_params of DtlsSrtpTransport to: " @@ -417,6 +406,7 @@ void JsepTransport::SetActiveResetSrtpParams(bool active_reset_srtp_params) { void JsepTransport::SetRemoteIceParameters( const IceParameters& ice_parameters, IceTransportInternal* ice_transport) { + TRACE_EVENT0("webrtc", "JsepTransport::SetRemoteIceParameters"); RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK(ice_transport); RTC_DCHECK(remote_description_); @@ -480,31 +470,22 @@ bool JsepTransport::SetRtcpMux(bool enable, } void JsepTransport::ActivateRtcpMux() { - { - // Don't hold the network_thread_ lock while calling other functions, - // since they might call other functions that call RTC_DCHECK_RUN_ON. - // TODO(https://crbug.com/webrtc/10318): Simplify when possible. - RTC_DCHECK_RUN_ON(network_thread_); - } - { - rtc::CritScope scope(&accessor_lock_); - if (unencrypted_rtp_transport_) { - RTC_DCHECK(!sdes_transport_); - RTC_DCHECK(!dtls_srtp_transport_); - unencrypted_rtp_transport_->SetRtcpPacketTransport(nullptr); - } else if (sdes_transport_) { - RTC_DCHECK(!unencrypted_rtp_transport_); - RTC_DCHECK(!dtls_srtp_transport_); - sdes_transport_->SetRtcpPacketTransport(nullptr); - } else if (dtls_srtp_transport_) { - RTC_DCHECK(dtls_srtp_transport_); - RTC_DCHECK(!unencrypted_rtp_transport_); - RTC_DCHECK(!sdes_transport_); - dtls_srtp_transport_->SetDtlsTransports(rtp_dtls_transport_locked(), - /*rtcp_dtls_transport=*/nullptr); - } - rtcp_dtls_transport_ = nullptr; // Destroy this reference. 
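GetStats(), VerifyCertificateFingerprint() and SetRemoteIceParameters() above each gain a TRACE_EVENT0 marker. A minimal sketch of that idiom, assuming rtc_base/trace_event.h as included by this patch; the class and method names are hypothetical.

#include "rtc_base/trace_event.h"

class ExampleTransport {
 public:
  void Negotiate() {
    // Scoped trace event: it starts here and ends when the enclosing scope
    // exits, so the whole call shows up as one slice under the "webrtc"
    // category in trace tooling.
    TRACE_EVENT0("webrtc", "ExampleTransport::Negotiate");
    // ... work to be measured ...
  }
};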
+ if (unencrypted_rtp_transport_) { + RTC_DCHECK(!sdes_transport_); + RTC_DCHECK(!dtls_srtp_transport_); + unencrypted_rtp_transport_->SetRtcpPacketTransport(nullptr); + } else if (sdes_transport_) { + RTC_DCHECK(!unencrypted_rtp_transport_); + RTC_DCHECK(!dtls_srtp_transport_); + sdes_transport_->SetRtcpPacketTransport(nullptr); + } else if (dtls_srtp_transport_) { + RTC_DCHECK(dtls_srtp_transport_); + RTC_DCHECK(!unencrypted_rtp_transport_); + RTC_DCHECK(!sdes_transport_); + dtls_srtp_transport_->SetDtlsTransports(rtp_dtls_transport(), + /*rtcp_dtls_transport=*/nullptr); } + rtcp_dtls_transport_ = nullptr; // Destroy this reference. // Notify the JsepTransportController to update the aggregate states. SignalRtcpMuxActive(); } @@ -696,17 +677,12 @@ webrtc::RTCError JsepTransport::NegotiateDtlsRole( } bool JsepTransport::GetTransportStats(DtlsTransportInternal* dtls_transport, + int component, TransportStats* stats) { RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK(dtls_transport); TransportChannelStats substats; - if (rtcp_dtls_transport_) { - substats.component = dtls_transport == rtcp_dtls_transport_->internal() - ? ICE_CANDIDATE_COMPONENT_RTCP - : ICE_CANDIDATE_COMPONENT_RTP; - } else { - substats.component = ICE_CANDIDATE_COMPONENT_RTP; - } + substats.component = component; dtls_transport->GetSslVersionBytes(&substats.ssl_version_bytes); dtls_transport->GetSrtpCryptoSuite(&substats.srtp_crypto_suite); dtls_transport->GetSslCipherSuite(&substats.ssl_cipher_suite); diff --git a/pc/jsep_transport.h b/pc/jsep_transport.h index 11c8168d9e..5e8cae0ecf 100644 --- a/pc/jsep_transport.h +++ b/pc/jsep_transport.h @@ -18,28 +18,38 @@ #include "absl/types/optional.h" #include "api/candidate.h" +#include "api/crypto_params.h" #include "api/ice_transport_interface.h" #include "api/jsep.h" +#include "api/rtc_error.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/transport/data_channel_transport_interface.h" #include "media/sctp/sctp_transport_internal.h" #include "p2p/base/dtls_transport.h" +#include "p2p/base/dtls_transport_internal.h" +#include "p2p/base/ice_transport_internal.h" #include "p2p/base/p2p_constants.h" +#include "p2p/base/transport_description.h" #include "p2p/base/transport_info.h" -#include "pc/composite_rtp_transport.h" #include "pc/dtls_srtp_transport.h" #include "pc/dtls_transport.h" #include "pc/rtcp_mux_filter.h" #include "pc/rtp_transport.h" +#include "pc/rtp_transport_internal.h" #include "pc/sctp_transport.h" #include "pc/session_description.h" #include "pc/srtp_filter.h" #include "pc/srtp_transport.h" #include "pc/transport_stats.h" +#include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/rtc_certificate.h" +#include "rtc_base/ssl_fingerprint.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/third_party/sigslot/sigslot.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace cricket { @@ -89,7 +99,6 @@ class JsepTransport : public sigslot::has_slots<> { std::unique_ptr unencrypted_rtp_transport, std::unique_ptr sdes_transport, std::unique_ptr dtls_srtp_transport, - std::unique_ptr datagram_rtp_transport, std::unique_ptr rtp_dtls_transport, std::unique_ptr rtcp_dtls_transport, std::unique_ptr sctp_transport); @@ -115,38 +124,36 @@ class JsepTransport : public sigslot::has_slots<> { webrtc::RTCError SetLocalJsepTransportDescription( const JsepTransportDescription& jsep_description, - webrtc::SdpType type) 
RTC_LOCKS_EXCLUDED(accessor_lock_); + webrtc::SdpType type); // Set the remote TransportDescription to be used by DTLS and ICE channels // that are part of this Transport. webrtc::RTCError SetRemoteJsepTransportDescription( const JsepTransportDescription& jsep_description, - webrtc::SdpType type) RTC_LOCKS_EXCLUDED(accessor_lock_); - webrtc::RTCError AddRemoteCandidates(const Candidates& candidates) - RTC_LOCKS_EXCLUDED(accessor_lock_); + webrtc::SdpType type); + webrtc::RTCError AddRemoteCandidates(const Candidates& candidates); // Set the "needs-ice-restart" flag as described in JSEP. After the flag is // set, offers should generate new ufrags/passwords until an ICE restart // occurs. // - // This and the below method can be called safely from any thread as long as - // SetXTransportDescription is not in progress. - void SetNeedsIceRestartFlag() RTC_LOCKS_EXCLUDED(accessor_lock_); + // This and |needs_ice_restart()| must be called on the network thread. + void SetNeedsIceRestartFlag(); + // Returns true if the ICE restart flag above was set, and no ICE restart has // occurred yet for this transport (by applying a local description with // changed ufrag/password). - bool needs_ice_restart() const RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + bool needs_ice_restart() const { + RTC_DCHECK_RUN_ON(network_thread_); return needs_ice_restart_; } // Returns role if negotiated, or empty absl::optional if it hasn't been // negotiated yet. - absl::optional GetDtlsRole() const - RTC_LOCKS_EXCLUDED(accessor_lock_); + absl::optional GetDtlsRole() const; // TODO(deadbeef): Make this const. See comment in transportcontroller.h. - bool GetStats(TransportStats* stats) RTC_LOCKS_EXCLUDED(accessor_lock_); + bool GetStats(TransportStats* stats); const JsepTransportDescription* local_description() const { RTC_DCHECK_RUN_ON(network_thread_); @@ -158,71 +165,61 @@ class JsepTransport : public sigslot::has_slots<> { return remote_description_.get(); } - webrtc::RtpTransportInternal* rtp_transport() const - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); - if (composite_rtp_transport_) { - return composite_rtp_transport_.get(); - } else if (datagram_rtp_transport_) { - return datagram_rtp_transport_.get(); - } else { - return default_rtp_transport(); + // Returns the rtp transport, if any. 
+ webrtc::RtpTransportInternal* rtp_transport() const { + if (dtls_srtp_transport_) { + return dtls_srtp_transport_.get(); } + if (sdes_transport_) { + return sdes_transport_.get(); + } + if (unencrypted_rtp_transport_) { + return unencrypted_rtp_transport_.get(); + } + return nullptr; } - const DtlsTransportInternal* rtp_dtls_transport() const - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + const DtlsTransportInternal* rtp_dtls_transport() const { if (rtp_dtls_transport_) { return rtp_dtls_transport_->internal(); - } else { - return nullptr; } + return nullptr; } - DtlsTransportInternal* rtp_dtls_transport() - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); - return rtp_dtls_transport_locked(); + DtlsTransportInternal* rtp_dtls_transport() { + if (rtp_dtls_transport_) { + return rtp_dtls_transport_->internal(); + } + return nullptr; } - const DtlsTransportInternal* rtcp_dtls_transport() const - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + const DtlsTransportInternal* rtcp_dtls_transport() const { + RTC_DCHECK_RUN_ON(network_thread_); if (rtcp_dtls_transport_) { return rtcp_dtls_transport_->internal(); - } else { - return nullptr; } + return nullptr; } - DtlsTransportInternal* rtcp_dtls_transport() - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + DtlsTransportInternal* rtcp_dtls_transport() { + RTC_DCHECK_RUN_ON(network_thread_); if (rtcp_dtls_transport_) { return rtcp_dtls_transport_->internal(); - } else { - return nullptr; } + return nullptr; } - rtc::scoped_refptr RtpDtlsTransport() - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + rtc::scoped_refptr RtpDtlsTransport() { return rtp_dtls_transport_; } - rtc::scoped_refptr SctpTransport() const - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + rtc::scoped_refptr SctpTransport() const { return sctp_transport_; } // TODO(bugs.webrtc.org/9719): Delete method, update callers to use // SctpTransport() instead. 
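The accessor rewrites above drop the old accessor_lock_ critical section in favor of single-thread ownership: run-time checks via RTC_DCHECK_RUN_ON plus compile-time annotations via RTC_GUARDED_BY. A hedged sketch of that pattern with a hypothetical flag as the guarded state.

#include "api/sequence_checker.h"
#include "rtc_base/thread.h"
#include "rtc_base/thread_annotations.h"

class ExampleState {
 public:
  explicit ExampleState(rtc::Thread* network_thread)
      : network_thread_(network_thread) {}

  void set_flag(bool flag) {
    RTC_DCHECK_RUN_ON(network_thread_);  // Replaces taking accessor_lock_.
    flag_ = flag;
  }

  bool flag() const {
    RTC_DCHECK_RUN_ON(network_thread_);
    return flag_;
  }

 private:
  rtc::Thread* const network_thread_;
  bool flag_ RTC_GUARDED_BY(network_thread_) = false;
};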
- webrtc::DataChannelTransportInterface* data_channel_transport() const - RTC_LOCKS_EXCLUDED(accessor_lock_) { - rtc::CritScope scope(&accessor_lock_); + webrtc::DataChannelTransportInterface* data_channel_transport() const { if (sctp_data_channel_transport_) { return sctp_data_channel_transport_.get(); } @@ -247,24 +244,14 @@ class JsepTransport : public sigslot::has_slots<> { void SetActiveResetSrtpParams(bool active_reset_srtp_params); private: - DtlsTransportInternal* rtp_dtls_transport_locked() - RTC_EXCLUSIVE_LOCKS_REQUIRED(accessor_lock_) { - if (rtp_dtls_transport_) { - return rtp_dtls_transport_->internal(); - } else { - return nullptr; - } - } - bool SetRtcpMux(bool enable, webrtc::SdpType type, ContentSource source); - void ActivateRtcpMux(); + void ActivateRtcpMux() RTC_RUN_ON(network_thread_); bool SetSdes(const std::vector& cryptos, const std::vector& encrypted_extension_ids, webrtc::SdpType type, - ContentSource source) - RTC_EXCLUSIVE_LOCKS_REQUIRED(accessor_lock_); + ContentSource source); // Negotiates and sets the DTLS parameters based on the current local and // remote transport description, such as the DTLS role to use, and whether @@ -281,8 +268,7 @@ class JsepTransport : public sigslot::has_slots<> { webrtc::SdpType local_description_type, ConnectionRole local_connection_role, ConnectionRole remote_connection_role, - absl::optional* negotiated_dtls_role) - RTC_LOCKS_EXCLUDED(accessor_lock_); + absl::optional* negotiated_dtls_role); // Pushes down the ICE parameters from the remote description. void SetRemoteIceParameters(const IceParameters& ice_parameters, @@ -295,31 +281,14 @@ class JsepTransport : public sigslot::has_slots<> { rtc::SSLFingerprint* remote_fingerprint); bool GetTransportStats(DtlsTransportInternal* dtls_transport, - TransportStats* stats) - RTC_EXCLUSIVE_LOCKS_REQUIRED(accessor_lock_); - - // Returns the default (non-datagram) rtp transport, if any. - webrtc::RtpTransportInternal* default_rtp_transport() const - RTC_EXCLUSIVE_LOCKS_REQUIRED(accessor_lock_) { - if (dtls_srtp_transport_) { - return dtls_srtp_transport_.get(); - } else if (sdes_transport_) { - return sdes_transport_.get(); - } else if (unencrypted_rtp_transport_) { - return unencrypted_rtp_transport_.get(); - } else { - return nullptr; - } - } + int component, + TransportStats* stats); // Owning thread, for safety checks const rtc::Thread* const network_thread_; - // Critical scope for fields accessed off-thread - // TODO(https://bugs.webrtc.org/10300): Stop doing this. - rtc::RecursiveCriticalSection accessor_lock_; const std::string mid_; // needs-ice-restart bit as described in JSEP. - bool needs_ice_restart_ RTC_GUARDED_BY(accessor_lock_) = false; + bool needs_ice_restart_ RTC_GUARDED_BY(network_thread_) = false; rtc::scoped_refptr local_certificate_ RTC_GUARDED_BY(network_thread_); std::unique_ptr local_description_ @@ -334,31 +303,19 @@ class JsepTransport : public sigslot::has_slots<> { // To avoid downcasting and make it type safe, keep three unique pointers for // different SRTP mode and only one of these is non-nullptr. - std::unique_ptr unencrypted_rtp_transport_ - RTC_GUARDED_BY(accessor_lock_); - std::unique_ptr sdes_transport_ - RTC_GUARDED_BY(accessor_lock_); - std::unique_ptr dtls_srtp_transport_ - RTC_GUARDED_BY(accessor_lock_); - - // If multiple RTP transports are in use, |composite_rtp_transport_| will be - // passed to callers. 
This is only valid for offer-only, receive-only - // scenarios, as it is not possible for the composite to correctly choose - // which transport to use for sending. - std::unique_ptr composite_rtp_transport_ - RTC_GUARDED_BY(accessor_lock_); - - rtc::scoped_refptr rtp_dtls_transport_ - RTC_GUARDED_BY(accessor_lock_); + const std::unique_ptr unencrypted_rtp_transport_; + const std::unique_ptr sdes_transport_; + const std::unique_ptr dtls_srtp_transport_; + + const rtc::scoped_refptr rtp_dtls_transport_; + // The RTCP transport is const for all usages, except that it is cleared + // when RTCP multiplexing is turned on; this happens on the network thread. rtc::scoped_refptr rtcp_dtls_transport_ - RTC_GUARDED_BY(accessor_lock_); - rtc::scoped_refptr datagram_dtls_transport_ - RTC_GUARDED_BY(accessor_lock_); + RTC_GUARDED_BY(network_thread_); - std::unique_ptr - sctp_data_channel_transport_ RTC_GUARDED_BY(accessor_lock_); - rtc::scoped_refptr sctp_transport_ - RTC_GUARDED_BY(accessor_lock_); + const std::unique_ptr + sctp_data_channel_transport_; + const rtc::scoped_refptr sctp_transport_; SrtpFilter sdes_negotiator_ RTC_GUARDED_BY(network_thread_); RtcpMuxFilter rtcp_mux_negotiator_ RTC_GUARDED_BY(network_thread_); @@ -369,9 +326,6 @@ class JsepTransport : public sigslot::has_slots<> { absl::optional> recv_extension_ids_ RTC_GUARDED_BY(network_thread_); - std::unique_ptr datagram_rtp_transport_ - RTC_GUARDED_BY(accessor_lock_); - RTC_DISALLOW_COPY_AND_ASSIGN(JsepTransport); }; diff --git a/pc/jsep_transport_collection.cc b/pc/jsep_transport_collection.cc new file mode 100644 index 0000000000..ce068d99fc --- /dev/null +++ b/pc/jsep_transport_collection.cc @@ -0,0 +1,255 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "pc/jsep_transport_collection.h" + +#include +#include +#include +#include + +#include "p2p/base/p2p_constants.h" +#include "rtc_base/logging.h" + +namespace webrtc { + +void BundleManager::Update(const cricket::SessionDescription* description) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + bundle_groups_.clear(); + for (const cricket::ContentGroup* new_bundle_group : + description->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE)) { + bundle_groups_.push_back( + std::make_unique(*new_bundle_group)); + RTC_DLOG(LS_VERBOSE) << "Establishing bundle group " + << new_bundle_group->ToString(); + } + established_bundle_groups_by_mid_.clear(); + for (const auto& bundle_group : bundle_groups_) { + for (const std::string& content_name : bundle_group->content_names()) { + established_bundle_groups_by_mid_[content_name] = bundle_group.get(); + } + } +} + +const cricket::ContentGroup* BundleManager::LookupGroupByMid( + const std::string& mid) const { + auto it = established_bundle_groups_by_mid_.find(mid); + return it != established_bundle_groups_by_mid_.end() ? 
it->second : nullptr; +} +bool BundleManager::IsFirstMidInGroup(const std::string& mid) const { + auto group = LookupGroupByMid(mid); + if (!group) { + return true; // Unbundled MIDs are considered group leaders + } + return mid == *(group->FirstContentName()); +} + +cricket::ContentGroup* BundleManager::LookupGroupByMid(const std::string& mid) { + auto it = established_bundle_groups_by_mid_.find(mid); + return it != established_bundle_groups_by_mid_.end() ? it->second : nullptr; +} + +void BundleManager::DeleteMid(const cricket::ContentGroup* bundle_group, + const std::string& mid) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_LOG(LS_VERBOSE) << "Deleting mid " << mid << " from bundle group " + << bundle_group->ToString(); + // Remove the rejected content from the |bundle_group|. + // The const pointer arg is used to identify the group, we verify + // it before we use it to make a modification. + auto bundle_group_it = std::find_if( + bundle_groups_.begin(), bundle_groups_.end(), + [bundle_group](std::unique_ptr& group) { + return bundle_group == group.get(); + }); + RTC_DCHECK(bundle_group_it != bundle_groups_.end()); + (*bundle_group_it)->RemoveContentName(mid); + established_bundle_groups_by_mid_.erase( + established_bundle_groups_by_mid_.find(mid)); +} + +void BundleManager::DeleteGroup(const cricket::ContentGroup* bundle_group) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DLOG(LS_VERBOSE) << "Deleting bundle group " << bundle_group->ToString(); + + auto bundle_group_it = std::find_if( + bundle_groups_.begin(), bundle_groups_.end(), + [bundle_group](std::unique_ptr& group) { + return bundle_group == group.get(); + }); + RTC_DCHECK(bundle_group_it != bundle_groups_.end()); + auto mid_list = (*bundle_group_it)->content_names(); + for (const auto& content_name : mid_list) { + DeleteMid(bundle_group, content_name); + } + bundle_groups_.erase(bundle_group_it); +} + +void JsepTransportCollection::RegisterTransport( + const std::string& mid, + std::unique_ptr transport) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + SetTransportForMid(mid, transport.get()); + jsep_transports_by_name_[mid] = std::move(transport); + RTC_DCHECK(IsConsistent()); +} + +std::vector JsepTransportCollection::Transports() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + std::vector result; + for (auto& kv : jsep_transports_by_name_) { + result.push_back(kv.second.get()); + } + return result; +} + +void JsepTransportCollection::DestroyAllTransports() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + for (const auto& jsep_transport : jsep_transports_by_name_) { + map_change_callback_(jsep_transport.first, nullptr); + } + jsep_transports_by_name_.clear(); + RTC_DCHECK(IsConsistent()); +} + +const cricket::JsepTransport* JsepTransportCollection::GetTransportByName( + const std::string& transport_name) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + auto it = jsep_transports_by_name_.find(transport_name); + return (it == jsep_transports_by_name_.end()) ? nullptr : it->second.get(); +} + +cricket::JsepTransport* JsepTransportCollection::GetTransportByName( + const std::string& transport_name) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + auto it = jsep_transports_by_name_.find(transport_name); + return (it == jsep_transports_by_name_.end()) ? nullptr : it->second.get(); +} + +cricket::JsepTransport* JsepTransportCollection::GetTransportForMid( + const std::string& mid) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + auto it = mid_to_transport_.find(mid); + return it == mid_to_transport_.end() ? 
nullptr : it->second; +} + +const cricket::JsepTransport* JsepTransportCollection::GetTransportForMid( + const std::string& mid) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + auto it = mid_to_transport_.find(mid); + return it == mid_to_transport_.end() ? nullptr : it->second; +} + +bool JsepTransportCollection::SetTransportForMid( + const std::string& mid, + cricket::JsepTransport* jsep_transport) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(jsep_transport); + + auto it = mid_to_transport_.find(mid); + if (it != mid_to_transport_.end() && it->second == jsep_transport) + return true; + + pending_mids_.push_back(mid); + + // The map_change_callback must be called before destroying the + // transport, because it removes references to the transport + // in the RTP demuxer. + bool result = map_change_callback_(mid, jsep_transport); + + if (it == mid_to_transport_.end()) { + mid_to_transport_.insert(std::make_pair(mid, jsep_transport)); + } else { + auto old_transport = it->second; + it->second = jsep_transport; + MaybeDestroyJsepTransport(old_transport); + } + RTC_DCHECK(IsConsistent()); + return result; +} + +void JsepTransportCollection::RemoveTransportForMid(const std::string& mid) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(IsConsistent()); + bool ret = map_change_callback_(mid, nullptr); + // Calling OnTransportChanged with nullptr should always succeed, since it is + // only expected to fail when adding media to a transport (not removing). + RTC_DCHECK(ret); + + auto old_transport = GetTransportForMid(mid); + if (old_transport) { + mid_to_transport_.erase(mid); + MaybeDestroyJsepTransport(old_transport); + } + RTC_DCHECK(IsConsistent()); +} + +void JsepTransportCollection::RollbackTransports() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + for (auto&& mid : pending_mids_) { + RemoveTransportForMid(mid); + } + pending_mids_.clear(); +} + +void JsepTransportCollection::CommitTransports() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + pending_mids_.clear(); +} + +bool JsepTransportCollection::TransportInUse( + cricket::JsepTransport* jsep_transport) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + for (const auto& kv : mid_to_transport_) { + if (kv.second == jsep_transport) { + return true; + } + } + return false; +} + +void JsepTransportCollection::MaybeDestroyJsepTransport( + cricket::JsepTransport* transport) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + // Don't destroy the JsepTransport if there are still media sections referring + // to it. + if (TransportInUse(transport)) { + return; + } + for (const auto& it : jsep_transports_by_name_) { + if (it.second.get() == transport) { + jsep_transports_by_name_.erase(it.first); + state_change_callback_(); + break; + } + } + RTC_DCHECK(IsConsistent()); +} + +bool JsepTransportCollection::IsConsistent() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + for (const auto& it : jsep_transports_by_name_) { + if (!TransportInUse(it.second.get())) { + RTC_LOG(LS_ERROR) << "Transport registered with mid " << it.first + << " is not in use, transport " << it.second.get(); + return false; + } + const auto& lookup = mid_to_transport_.find(it.first); + if (lookup->second != it.second.get()) { + // Not an error, but unusual. 
+ RTC_DLOG(LS_INFO) << "Note: Mid " << it.first << " was registered to " + << it.second.get() << " but currently maps to " + << lookup->second; + } + } + return true; +} + +} // namespace webrtc diff --git a/pc/jsep_transport_collection.h b/pc/jsep_transport_collection.h new file mode 100644 index 0000000000..0dd528d348 --- /dev/null +++ b/pc/jsep_transport_collection.h @@ -0,0 +1,145 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef PC_JSEP_TRANSPORT_COLLECTION_H_ +#define PC_JSEP_TRANSPORT_COLLECTION_H_ + +#include +#include +#include +#include +#include +#include + +#include "api/sequence_checker.h" +#include "pc/jsep_transport.h" +#include "pc/session_description.h" +#include "rtc_base/checks.h" +#include "rtc_base/system/no_unique_address.h" +#include "rtc_base/thread_annotations.h" + +namespace webrtc { + +// This class manages information about RFC 8843 BUNDLE bundles +// in SDP descriptions. + +// This is a work-in-progress. Planned steps: +// 1) Move all Bundle-related data structures from JsepTransport +// into this class. +// 2) Move all Bundle-related functions into this class. +// 3) Move remaining Bundle-related logic into this class. +// Make data members private. +// 4) Refine interface to have comprehensible semantics. +// 5) Add unit tests. +// 6) Change the logic to do what's right. +class BundleManager { + public: + BundleManager() { + // Allow constructor to be called on a different thread. + sequence_checker_.Detach(); + } + const std::vector>& bundle_groups() + const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + return bundle_groups_; + } + // Lookup a bundle group by a member mid name. + const cricket::ContentGroup* LookupGroupByMid(const std::string& mid) const; + cricket::ContentGroup* LookupGroupByMid(const std::string& mid); + // Returns true if the MID is the first item of a group, or if + // the MID is not a member of a group. + bool IsFirstMidInGroup(const std::string& mid) const; + // Update the groups description. This completely replaces the group + // description with the one from the SessionDescription. + void Update(const cricket::SessionDescription* description); + // Delete a MID from the group that contains it. + void DeleteMid(const cricket::ContentGroup* bundle_group, + const std::string& mid); + // Delete a group. + void DeleteGroup(const cricket::ContentGroup* bundle_group); + + private: + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; + std::vector> bundle_groups_ + RTC_GUARDED_BY(sequence_checker_); + std::map + established_bundle_groups_by_mid_; +}; + +// This class keeps the mapping of MIDs to transports. +// It is pulled out here because a lot of the code that deals with +// bundles end up modifying this map, and the two need to be consistent; +// the managers may merge. +class JsepTransportCollection { + public: + JsepTransportCollection(std::function + map_change_callback, + std::function state_change_callback) + : map_change_callback_(map_change_callback), + state_change_callback_(state_change_callback) { + // Allow constructor to be called on a different thread. 
+ sequence_checker_.Detach(); + } + + void RegisterTransport(const std::string& mid, + std::unique_ptr transport); + std::vector Transports(); + void DestroyAllTransports(); + // Lookup a JsepTransport by the MID that was used to register it. + cricket::JsepTransport* GetTransportByName(const std::string& mid); + const cricket::JsepTransport* GetTransportByName( + const std::string& mid) const; + // Lookup a JsepTransport by any MID that refers to it. + cricket::JsepTransport* GetTransportForMid(const std::string& mid); + const cricket::JsepTransport* GetTransportForMid( + const std::string& mid) const; + // Set transport for a MID. This may destroy a transport if it is no + // longer in use. + bool SetTransportForMid(const std::string& mid, + cricket::JsepTransport* jsep_transport); + // Remove a transport for a MID. This may destroy a transport if it is + // no longer in use. + void RemoveTransportForMid(const std::string& mid); + // Roll back pending mid-to-transport mappings. + void RollbackTransports(); + // Commit pending mid-transport mappings (rollback is no longer possible). + void CommitTransports(); + // Returns true if any mid currently maps to this transport. + bool TransportInUse(cricket::JsepTransport* jsep_transport) const; + + private: + // Destroy a transport if it's no longer in use. + void MaybeDestroyJsepTransport(cricket::JsepTransport* transport); + + bool IsConsistent(); // For testing only: Verify internal structure. + + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; + // This member owns the JSEP transports. + std::map> + jsep_transports_by_name_ RTC_GUARDED_BY(sequence_checker_); + + // This keeps track of the mapping between media section + // (BaseChannel/SctpTransport) and the JsepTransport underneath. + std::map mid_to_transport_ + RTC_GUARDED_BY(sequence_checker_); + // Keep track of mids that have been mapped to transports. Used for rollback. + std::vector pending_mids_ RTC_GUARDED_BY(sequence_checker_); + // Callback used to inform subscribers of altered transports. + const std::function + map_change_callback_; + // Callback used to inform subscribers of possibly altered state. + const std::function state_change_callback_; +}; + +} // namespace webrtc + +#endif // PC_JSEP_TRANSPORT_COLLECTION_H_ diff --git a/pc/jsep_transport_controller.cc b/pc/jsep_transport_controller.cc index 4999f2ab04..f0e377e048 100644 --- a/pc/jsep_transport_controller.cc +++ b/pc/jsep_transport_controller.cc @@ -10,102 +10,76 @@ #include "pc/jsep_transport_controller.h" +#include + +#include +#include #include +#include #include #include "absl/algorithm/container.h" -#include "api/ice_transport_factory.h" +#include "api/dtls_transport_interface.h" +#include "api/rtp_parameters.h" +#include "api/sequence_checker.h" +#include "api/transport/enums.h" +#include "media/sctp/sctp_transport_internal.h" +#include "p2p/base/dtls_transport.h" #include "p2p/base/ice_transport_internal.h" +#include "p2p/base/p2p_constants.h" #include "p2p/base/port.h" -#include "pc/srtp_filter.h" -#include "rtc_base/bind.h" #include "rtc_base/checks.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" #include "rtc_base/thread.h" +#include "rtc_base/trace_event.h" using webrtc::SdpType; -namespace { - -webrtc::RTCError VerifyCandidate(const cricket::Candidate& cand) { - // No address zero. 
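Both constructors above detach their SequenceChecker so the object can be built on one thread (for example the signaling thread) and then used on another; the first RTC_DCHECK_RUN_ON after Detach() binds the checker to the calling sequence. A minimal sketch of that idiom; the LazyBoundCache class is hypothetical.

#include "api/sequence_checker.h"
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"

class LazyBoundCache {
 public:
  LazyBoundCache() {
    // Allow construction on any thread; re-attach on first use.
    sequence_checker_.Detach();
  }

  void Put(int value) {
    RTC_DCHECK_RUN_ON(&sequence_checker_);  // First call binds the checker.
    last_value_ = value;
  }

  int Get() const {
    RTC_DCHECK_RUN_ON(&sequence_checker_);
    return last_value_;
  }

 private:
  RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker sequence_checker_;
  int last_value_ RTC_GUARDED_BY(sequence_checker_) = 0;
};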
- if (cand.address().IsNil() || cand.address().IsAnyIP()) { - return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, - "candidate has address of zero"); - } - - // Disallow all ports below 1024, except for 80 and 443 on public addresses. - int port = cand.address().port(); - if (cand.protocol() == cricket::TCP_PROTOCOL_NAME && - (cand.tcptype() == cricket::TCPTYPE_ACTIVE_STR || port == 0)) { - // Expected for active-only candidates per - // http://tools.ietf.org/html/rfc6544#section-4.5 so no error. - // Libjingle clients emit port 0, in "active" mode. - return webrtc::RTCError::OK(); - } - if (port < 1024) { - if ((port != 80) && (port != 443)) { - return webrtc::RTCError( - webrtc::RTCErrorType::INVALID_PARAMETER, - "candidate has port below 1024, but not 80 or 443"); - } - - if (cand.address().IsPrivateIP()) { - return webrtc::RTCError( - webrtc::RTCErrorType::INVALID_PARAMETER, - "candidate has port of 80 or 443 with private IP address"); - } - } - - return webrtc::RTCError::OK(); -} - -webrtc::RTCError VerifyCandidates(const cricket::Candidates& candidates) { - for (const cricket::Candidate& candidate : candidates) { - webrtc::RTCError error = VerifyCandidate(candidate); - if (!error.ok()) { - return error; - } - } - return webrtc::RTCError::OK(); -} - -} // namespace - namespace webrtc { JsepTransportController::JsepTransportController( - rtc::Thread* signaling_thread, rtc::Thread* network_thread, cricket::PortAllocator* port_allocator, - AsyncResolverFactory* async_resolver_factory, + AsyncDnsResolverFactoryInterface* async_dns_resolver_factory, Config config) - : signaling_thread_(signaling_thread), - network_thread_(network_thread), + : network_thread_(network_thread), port_allocator_(port_allocator), - async_resolver_factory_(async_resolver_factory), - config_(config) { + async_dns_resolver_factory_(async_dns_resolver_factory), + transports_( + [this](const std::string& mid, cricket::JsepTransport* transport) { + return OnTransportChanged(mid, transport); + }, + [this]() { + RTC_DCHECK_RUN_ON(network_thread_); + UpdateAggregateStates_n(); + }), + config_(config), + active_reset_srtp_params_(config.active_reset_srtp_params) { // The |transport_observer| is assumed to be non-null. RTC_DCHECK(config_.transport_observer); RTC_DCHECK(config_.rtcp_handler); RTC_DCHECK(config_.ice_transport_factory); + RTC_DCHECK(config_.on_dtls_handshake_error_); } JsepTransportController::~JsepTransportController() { // Channel destructors may try to send packets, so this needs to happen on // the network thread. 
- network_thread_->Invoke( - RTC_FROM_HERE, - rtc::Bind(&JsepTransportController::DestroyAllJsepTransports_n, this)); + RTC_DCHECK_RUN_ON(network_thread_); + DestroyAllJsepTransports_n(); } RTCError JsepTransportController::SetLocalDescription( SdpType type, const cricket::SessionDescription* description) { + TRACE_EVENT0("webrtc", "JsepTransportController::SetLocalDescription"); if (!network_thread_->IsCurrent()) { return network_thread_->Invoke( RTC_FROM_HERE, [=] { return SetLocalDescription(type, description); }); } + RTC_DCHECK_RUN_ON(network_thread_); if (!initial_offerer_.has_value()) { initial_offerer_.emplace(type == SdpType::kOffer); if (*initial_offerer_) { @@ -120,16 +94,19 @@ RTCError JsepTransportController::SetLocalDescription( RTCError JsepTransportController::SetRemoteDescription( SdpType type, const cricket::SessionDescription* description) { + TRACE_EVENT0("webrtc", "JsepTransportController::SetRemoteDescription"); if (!network_thread_->IsCurrent()) { return network_thread_->Invoke( RTC_FROM_HERE, [=] { return SetRemoteDescription(type, description); }); } + RTC_DCHECK_RUN_ON(network_thread_); return ApplyDescription_n(/*local=*/false, type, description); } RtpTransportInternal* JsepTransportController::GetRtpTransport( const std::string& mid) const { + RTC_DCHECK_RUN_ON(network_thread_); auto jsep_transport = GetJsepTransportForMid(mid); if (!jsep_transport) { return nullptr; @@ -139,6 +116,7 @@ RtpTransportInternal* JsepTransportController::GetRtpTransport( DataChannelTransportInterface* JsepTransportController::GetDataChannelTransport( const std::string& mid) const { + RTC_DCHECK_RUN_ON(network_thread_); auto jsep_transport = GetJsepTransportForMid(mid); if (!jsep_transport) { return nullptr; @@ -148,6 +126,7 @@ DataChannelTransportInterface* JsepTransportController::GetDataChannelTransport( cricket::DtlsTransportInternal* JsepTransportController::GetDtlsTransport( const std::string& mid) { + RTC_DCHECK_RUN_ON(network_thread_); auto jsep_transport = GetJsepTransportForMid(mid); if (!jsep_transport) { return nullptr; @@ -157,6 +136,7 @@ cricket::DtlsTransportInternal* JsepTransportController::GetDtlsTransport( const cricket::DtlsTransportInternal* JsepTransportController::GetRtcpDtlsTransport(const std::string& mid) const { + RTC_DCHECK_RUN_ON(network_thread_); auto jsep_transport = GetJsepTransportForMid(mid); if (!jsep_transport) { return nullptr; @@ -166,6 +146,7 @@ JsepTransportController::GetRtcpDtlsTransport(const std::string& mid) const { rtc::scoped_refptr JsepTransportController::LookupDtlsTransportByMid(const std::string& mid) { + RTC_DCHECK_RUN_ON(network_thread_); auto jsep_transport = GetJsepTransportForMid(mid); if (!jsep_transport) { return nullptr; @@ -175,6 +156,7 @@ JsepTransportController::LookupDtlsTransportByMid(const std::string& mid) { rtc::scoped_refptr JsepTransportController::GetSctpTransport( const std::string& mid) const { + RTC_DCHECK_RUN_ON(network_thread_); auto jsep_transport = GetJsepTransportForMid(mid); if (!jsep_transport) { return nullptr; @@ -183,11 +165,7 @@ rtc::scoped_refptr JsepTransportController::GetSctpTransport( } void JsepTransportController::SetIceConfig(const cricket::IceConfig& config) { - if (!network_thread_->IsCurrent()) { - network_thread_->Invoke(RTC_FROM_HERE, [&] { SetIceConfig(config); }); - return; - } - + RTC_DCHECK_RUN_ON(network_thread_); ice_config_ = config; for (auto& dtls : GetDtlsTransports()) { dtls->ice_transport()->SetIceConfig(ice_config_); @@ -195,13 +173,16 @@ void 
JsepTransportController::SetIceConfig(const cricket::IceConfig& config) { } void JsepTransportController::SetNeedsIceRestartFlag() { - for (auto& kv : jsep_transports_by_name_) { - kv.second->SetNeedsIceRestartFlag(); + RTC_DCHECK_RUN_ON(network_thread_); + for (auto& transport : transports_.Transports()) { + transport->SetNeedsIceRestartFlag(); } } bool JsepTransportController::NeedsIceRestart( const std::string& transport_name) const { + RTC_DCHECK_RUN_ON(network_thread_); + const cricket::JsepTransport* transport = GetJsepTransportByName(transport_name); if (!transport) { @@ -212,11 +193,16 @@ bool JsepTransportController::NeedsIceRestart( absl::optional JsepTransportController::GetDtlsRole( const std::string& mid) const { + // TODO(tommi): Remove this hop. Currently it's called from the signaling + // thread during negotiations, potentially multiple times. + // WebRtcSessionDescriptionFactory::InternalCreateAnswer is one example. if (!network_thread_->IsCurrent()) { return network_thread_->Invoke>( RTC_FROM_HERE, [&] { return GetDtlsRole(mid); }); } + RTC_DCHECK_RUN_ON(network_thread_); + const cricket::JsepTransport* t = GetJsepTransportForMid(mid); if (!t) { return absl::optional(); @@ -231,6 +217,8 @@ bool JsepTransportController::SetLocalCertificate( RTC_FROM_HERE, [&] { return SetLocalCertificate(certificate); }); } + RTC_DCHECK_RUN_ON(network_thread_); + // Can't change a certificate, or set a null certificate. if (certificate_ || !certificate) { return false; @@ -240,8 +228,8 @@ bool JsepTransportController::SetLocalCertificate( // Set certificate for JsepTransport, which verifies it matches the // fingerprint in SDP, and DTLS transport. // Fallback from DTLS to SDES is not supported. - for (auto& kv : jsep_transports_by_name_) { - kv.second->SetLocalCertificate(certificate_); + for (auto& transport : transports_.Transports()) { + transport->SetLocalCertificate(certificate_); } for (auto& dtls : GetDtlsTransports()) { bool set_cert_success = dtls->SetLocalCertificate(certificate_); @@ -253,10 +241,7 @@ bool JsepTransportController::SetLocalCertificate( rtc::scoped_refptr JsepTransportController::GetLocalCertificate( const std::string& transport_name) const { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke>( - RTC_FROM_HERE, [&] { return GetLocalCertificate(transport_name); }); - } + RTC_DCHECK_RUN_ON(network_thread_); const cricket::JsepTransport* t = GetJsepTransportByName(transport_name); if (!t) { @@ -268,10 +253,7 @@ JsepTransportController::GetLocalCertificate( std::unique_ptr JsepTransportController::GetRemoteSSLCertChain( const std::string& transport_name) const { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke>( - RTC_FROM_HERE, [&] { return GetRemoteSSLCertChain(transport_name); }); - } + RTC_DCHECK_RUN_ON(network_thread_); // Get the certificate from the RTP transport's DTLS handshake. Should be // identical to the RTCP transport's, since they were given the same remote @@ -303,17 +285,8 @@ void JsepTransportController::MaybeStartGathering() { RTCError JsepTransportController::AddRemoteCandidates( const std::string& transport_name, const cricket::Candidates& candidates) { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke(RTC_FROM_HERE, [&] { - return AddRemoteCandidates(transport_name, candidates); - }); - } - - // Verify each candidate before passing down to the transport layer. 
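SetIceConfig(), GetStats() and several other methods above lose their "hop to the network thread via Invoke" preamble and instead assert that the caller is already there. A hedged sketch of the two shapes; GetValueWithHop and GetValueOnNetworkThread stand in for the real members.

#include "api/sequence_checker.h"
#include "rtc_base/location.h"
#include "rtc_base/thread.h"

// Old shape: marshal to the network thread when called from elsewhere.
int GetValueWithHop(rtc::Thread* network_thread) {
  if (!network_thread->IsCurrent()) {
    return network_thread->Invoke<int>(
        RTC_FROM_HERE, [&] { return GetValueWithHop(network_thread); });
  }
  return 42;  // ... read state owned by the network thread ...
}

// New shape: require the precondition; callers do any hopping themselves.
int GetValueOnNetworkThread(rtc::Thread* network_thread) {
  RTC_DCHECK_RUN_ON(network_thread);
  return 42;  // ... read state owned by the network thread ...
}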
- RTCError error = VerifyCandidates(candidates); - if (!error.ok()) { - return error; - } + RTC_DCHECK_RUN_ON(network_thread_); + RTC_DCHECK(VerifyCandidates(candidates).ok()); auto jsep_transport = GetJsepTransportByName(transport_name); if (!jsep_transport) { RTC_LOG(LS_WARNING) << "Not adding candidate because the JsepTransport " @@ -330,6 +303,8 @@ RTCError JsepTransportController::RemoveRemoteCandidates( RTC_FROM_HERE, [&] { return RemoveRemoteCandidates(candidates); }); } + RTC_DCHECK_RUN_ON(network_thread_); + // Verify each candidate before passing down to the transport layer. RTCError error = VerifyCandidates(candidates); if (!error.ok()) { @@ -372,10 +347,7 @@ RTCError JsepTransportController::RemoveRemoteCandidates( bool JsepTransportController::GetStats(const std::string& transport_name, cricket::TransportStats* stats) { - if (!network_thread_->IsCurrent()) { - return network_thread_->Invoke( - RTC_FROM_HERE, [=] { return GetStats(transport_name, stats); }); - } + RTC_DCHECK_RUN_ON(network_thread_); cricket::JsepTransport* transport = GetJsepTransportByName(transport_name); if (!transport) { @@ -392,13 +364,13 @@ void JsepTransportController::SetActiveResetSrtpParams( }); return; } - + RTC_DCHECK_RUN_ON(network_thread_); RTC_LOG(INFO) << "Updating the active_reset_srtp_params for JsepTransportController: " << active_reset_srtp_params; - config_.active_reset_srtp_params = active_reset_srtp_params; - for (auto& kv : jsep_transports_by_name_) { - kv.second->SetActiveResetSrtpParams(active_reset_srtp_params); + active_reset_srtp_params_ = active_reset_srtp_params; + for (auto& transport : transports_.Transports()) { + transport->SetActiveResetSrtpParams(active_reset_srtp_params); } } @@ -408,13 +380,7 @@ void JsepTransportController::RollbackTransports() { return; } RTC_DCHECK_RUN_ON(network_thread_); - for (auto&& mid : pending_mids_) { - RemoveTransportForMid(mid); - } - for (auto&& mid : pending_mids_) { - MaybeDestroyJsepTransport(mid); - } - pending_mids_.clear(); + transports_.RollbackTransports(); } rtc::scoped_refptr @@ -425,7 +391,7 @@ JsepTransportController::CreateIceTransport(const std::string& transport_name, IceTransportInit init; init.set_port_allocator(port_allocator_); - init.set_async_resolver_factory(async_resolver_factory_); + init.set_async_dns_resolver_factory(async_dns_resolver_factory_); init.set_event_log(config_.event_log); return config_.ice_transport_factory->CreateIceTransport( transport_name, component, std::move(init)); @@ -435,20 +401,20 @@ std::unique_ptr JsepTransportController::CreateDtlsTransport( const cricket::ContentInfo& content_info, cricket::IceTransportInternal* ice) { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); std::unique_ptr dtls; if (config_.dtls_transport_factory) { dtls = config_.dtls_transport_factory->CreateDtlsTransport( - ice, config_.crypto_options); + ice, config_.crypto_options, config_.ssl_max_version); } else { dtls = std::make_unique(ice, config_.crypto_options, - config_.event_log); + config_.event_log, + config_.ssl_max_version); } RTC_DCHECK(dtls); - dtls->SetSslMaxProtocolVersion(config_.ssl_max_version); dtls->ice_transport()->SetIceRole(ice_role_); dtls->ice_transport()->SetIceTiebreaker(ice_tiebreaker_); dtls->ice_transport()->SetIceConfig(ice_config_); @@ -462,8 +428,6 @@ JsepTransportController::CreateDtlsTransport( this, &JsepTransportController::OnTransportWritableState_n); dtls->SignalReceivingState.connect( this, &JsepTransportController::OnTransportReceivingState_n); - 
dtls->SignalDtlsHandshakeError.connect( - this, &JsepTransportController::OnDtlsHandshakeError); dtls->ice_transport()->SignalGatheringState.connect( this, &JsepTransportController::OnTransportGatheringState_n); dtls->ice_transport()->SignalCandidateGathered.connect( @@ -480,6 +444,9 @@ JsepTransportController::CreateDtlsTransport( this, &JsepTransportController::OnTransportStateChanged_n); dtls->ice_transport()->SignalCandidatePairChanged.connect( this, &JsepTransportController::OnTransportCandidatePairChanged_n); + + dtls->SubscribeDtlsHandshakeError( + [this](rtc::SSLHandshakeError error) { OnDtlsHandshakeError(error); }); return dtls; } @@ -488,7 +455,7 @@ JsepTransportController::CreateUnencryptedRtpTransport( const std::string& transport_name, rtc::PacketTransportInternal* rtp_packet_transport, rtc::PacketTransportInternal* rtcp_packet_transport) { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); auto unencrypted_rtp_transport = std::make_unique(rtcp_packet_transport == nullptr); unencrypted_rtp_transport->SetRtpPacketTransport(rtp_packet_transport); @@ -503,7 +470,7 @@ JsepTransportController::CreateSdesTransport( const std::string& transport_name, cricket::DtlsTransportInternal* rtp_dtls_transport, cricket::DtlsTransportInternal* rtcp_dtls_transport) { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); auto srtp_transport = std::make_unique(rtcp_dtls_transport == nullptr); RTC_DCHECK(rtp_dtls_transport); @@ -522,7 +489,7 @@ JsepTransportController::CreateDtlsSrtpTransport( const std::string& transport_name, cricket::DtlsTransportInternal* rtp_dtls_transport, cricket::DtlsTransportInternal* rtcp_dtls_transport) { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); auto dtls_srtp_transport = std::make_unique( rtcp_dtls_transport == nullptr); if (config_.enable_external_auth) { @@ -531,19 +498,21 @@ JsepTransportController::CreateDtlsSrtpTransport( dtls_srtp_transport->SetDtlsTransports(rtp_dtls_transport, rtcp_dtls_transport); - dtls_srtp_transport->SetActiveResetSrtpParams( - config_.active_reset_srtp_params); - dtls_srtp_transport->SignalDtlsStateChange.connect( - this, &JsepTransportController::UpdateAggregateStates_n); + dtls_srtp_transport->SetActiveResetSrtpParams(active_reset_srtp_params_); + // Capturing this in the callback because JsepTransportController will always + // outlive the DtlsSrtpTransport. 
+ dtls_srtp_transport->SetOnDtlsStateChange([this]() { + RTC_DCHECK_RUN_ON(this->network_thread_); + this->UpdateAggregateStates_n(); + }); return dtls_srtp_transport; } std::vector JsepTransportController::GetDtlsTransports() { + RTC_DCHECK_RUN_ON(network_thread_); std::vector dtls_transports; - for (auto it = jsep_transports_by_name_.begin(); - it != jsep_transports_by_name_.end(); ++it) { - auto jsep_transport = it->second.get(); + for (auto jsep_transport : transports_.Transports()) { RTC_DCHECK(jsep_transport); if (jsep_transport->rtp_dtls_transport()) { dtls_transports.push_back(jsep_transport->rtp_dtls_transport()); @@ -560,7 +529,7 @@ RTCError JsepTransportController::ApplyDescription_n( bool local, SdpType type, const cricket::SessionDescription* description) { - RTC_DCHECK_RUN_ON(network_thread_); + TRACE_EVENT0("webrtc", "JsepTransportController::ApplyDescription_n"); RTC_DCHECK(description); if (local) { @@ -570,21 +539,22 @@ RTCError JsepTransportController::ApplyDescription_n( } RTCError error; - error = ValidateAndMaybeUpdateBundleGroup(local, type, description); + error = ValidateAndMaybeUpdateBundleGroups(local, type, description); if (!error.ok()) { return error; } - std::vector merged_encrypted_extension_ids; - if (bundle_group_) { - merged_encrypted_extension_ids = - MergeEncryptedHeaderExtensionIdsForBundle(description); + std::map> + merged_encrypted_extension_ids_by_bundle; + if (!bundles_.bundle_groups().empty()) { + merged_encrypted_extension_ids_by_bundle = + MergeEncryptedHeaderExtensionIdsForBundles(description); } for (const cricket::ContentInfo& content_info : description->contents()) { - // Don't create transports for rejected m-lines and bundled m-lines." + // Don't create transports for rejected m-lines and bundled m-lines. if (content_info.rejected || - (IsBundled(content_info.name) && content_info.name != *bundled_mid())) { + !bundles_.IsFirstMidInGroup(content_info.name)) { continue; } error = MaybeCreateJsepTransport(local, content_info, *description); @@ -600,14 +570,22 @@ RTCError JsepTransportController::ApplyDescription_n( const cricket::TransportInfo& transport_info = description->transport_infos()[i]; if (content_info.rejected) { - HandleRejectedContent(content_info, description); + // This may cause groups to be removed from |bundles_.bundle_groups()|. + HandleRejectedContent(content_info); continue; } - if (IsBundled(content_info.name) && content_info.name != *bundled_mid()) { - if (!HandleBundledContent(content_info)) { + const cricket::ContentGroup* established_bundle_group = + bundles_.LookupGroupByMid(content_info.name); + + // For bundle members that are not BUNDLE-tagged (not first in the group), + // configure their transport to be the same as the BUNDLE-tagged transport. + if (established_bundle_group && + content_info.name != *established_bundle_group->FirstContentName()) { + if (!HandleBundledContent(content_info, *established_bundle_group)) { return RTCError(RTCErrorType::INVALID_PARAMETER, - "Failed to process the bundled m= section with mid='" + + "Failed to process the bundled m= section with " + "mid='" + content_info.name + "'."); } continue; @@ -619,8 +597,13 @@ RTCError JsepTransportController::ApplyDescription_n( } std::vector extension_ids; - if (bundled_mid() && content_info.name == *bundled_mid()) { - extension_ids = merged_encrypted_extension_ids; + // Is BUNDLE-tagged (first in the group)? 
+ if (established_bundle_group && + content_info.name == *established_bundle_group->FirstContentName()) { + auto it = merged_encrypted_extension_ids_by_bundle.find( + established_bundle_group); + RTC_DCHECK(it != merged_encrypted_extension_ids_by_bundle.end()); + extension_ids = it->second; } else { extension_ids = GetEncryptedHeaderExtensionIds(content_info); } @@ -653,56 +636,103 @@ RTCError JsepTransportController::ApplyDescription_n( } } if (type == SdpType::kAnswer) { - pending_mids_.clear(); + transports_.CommitTransports(); } return RTCError::OK(); } -RTCError JsepTransportController::ValidateAndMaybeUpdateBundleGroup( +RTCError JsepTransportController::ValidateAndMaybeUpdateBundleGroups( bool local, SdpType type, const cricket::SessionDescription* description) { RTC_DCHECK(description); - const cricket::ContentGroup* new_bundle_group = - description->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); - // The BUNDLE group containing a MID that no m= section has is invalid. - if (new_bundle_group) { + std::vector new_bundle_groups = + description->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + // Verify |new_bundle_groups|. + std::map new_bundle_groups_by_mid; + for (const cricket::ContentGroup* new_bundle_group : new_bundle_groups) { for (const std::string& content_name : new_bundle_group->content_names()) { + // The BUNDLE group must not contain a MID that is a member of a different + // BUNDLE group, or that contains the same MID multiple times. + if (new_bundle_groups_by_mid.find(content_name) != + new_bundle_groups_by_mid.end()) { + return RTCError(RTCErrorType::INVALID_PARAMETER, + "A BUNDLE group contains a MID='" + content_name + + "' that is already in a BUNDLE group."); + } + new_bundle_groups_by_mid.insert( + std::make_pair(content_name, new_bundle_group)); + // The BUNDLE group must not contain a MID that no m= section has. if (!description->GetContentByName(content_name)) { return RTCError(RTCErrorType::INVALID_PARAMETER, - "The BUNDLE group contains MID='" + content_name + + "A BUNDLE group contains a MID='" + content_name + "' matching no m= section."); } } } if (type == SdpType::kAnswer) { - const cricket::ContentGroup* offered_bundle_group = - local ? remote_desc_->GetGroupByName(cricket::GROUP_TYPE_BUNDLE) - : local_desc_->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); + std::vector offered_bundle_groups = + local ? remote_desc_->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE) + : local_desc_->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + + std::map + offered_bundle_groups_by_mid; + for (const cricket::ContentGroup* offered_bundle_group : + offered_bundle_groups) { + for (const std::string& content_name : + offered_bundle_group->content_names()) { + offered_bundle_groups_by_mid[content_name] = offered_bundle_group; + } + } - if (new_bundle_group) { - // The BUNDLE group in answer should be a subset of offered group. + std::map + new_bundle_groups_by_offered_bundle_groups; + for (const cricket::ContentGroup* new_bundle_group : new_bundle_groups) { + if (!new_bundle_group->FirstContentName()) { + // Empty groups could be a subset of any group. + continue; + } + // The group in the answer (new_bundle_group) must have a corresponding + // group in the offer (original_group), because the answer groups may only + // be subsets of the offer groups. 
+ auto it = offered_bundle_groups_by_mid.find( + *new_bundle_group->FirstContentName()); + if (it == offered_bundle_groups_by_mid.end()) { + return RTCError(RTCErrorType::INVALID_PARAMETER, + "A BUNDLE group was added in the answer that did not " + "exist in the offer."); + } + const cricket::ContentGroup* offered_bundle_group = it->second; + if (new_bundle_groups_by_offered_bundle_groups.find( + offered_bundle_group) != + new_bundle_groups_by_offered_bundle_groups.end()) { + return RTCError(RTCErrorType::INVALID_PARAMETER, + "A MID in the answer has changed group."); + } + new_bundle_groups_by_offered_bundle_groups.insert( + std::make_pair(offered_bundle_group, new_bundle_group)); for (const std::string& content_name : new_bundle_group->content_names()) { - if (!offered_bundle_group || - !offered_bundle_group->HasContentName(content_name)) { + it = offered_bundle_groups_by_mid.find(content_name); + // The BUNDLE group in answer should be a subset of offered group. + if (it == offered_bundle_groups_by_mid.end() || + it->second != offered_bundle_group) { return RTCError(RTCErrorType::INVALID_PARAMETER, - "The BUNDLE group in answer contains a MID='" + + "A BUNDLE group in answer contains a MID='" + content_name + - "' that was " - "not in the offered group."); + "' that was not in the offered group."); } } } - if (bundle_group_) { - for (const std::string& content_name : bundle_group_->content_names()) { + for (const auto& bundle_group : bundles_.bundle_groups()) { + for (const std::string& content_name : bundle_group->content_names()) { // An answer that removes m= sections from pre-negotiated BUNDLE group // without rejecting it, is invalid. - if (!new_bundle_group || - !new_bundle_group->HasContentName(content_name)) { + auto it = new_bundle_groups_by_mid.find(content_name); + if (it == new_bundle_groups_by_mid.end()) { auto* content_info = description->GetContentByName(content_name); if (!content_info || !content_info->rejected) { return RTCError(RTCErrorType::INVALID_PARAMETER, @@ -723,33 +753,35 @@ RTCError JsepTransportController::ValidateAndMaybeUpdateBundleGroup( } if (ShouldUpdateBundleGroup(type, description)) { - bundle_group_ = *new_bundle_group; + bundles_.Update(description); } - if (!bundled_mid()) { - return RTCError::OK(); - } + for (const auto& bundle_group : bundles_.bundle_groups()) { + if (!bundle_group->FirstContentName()) + continue; - auto bundled_content = description->GetContentByName(*bundled_mid()); - if (!bundled_content) { - return RTCError( - RTCErrorType::INVALID_PARAMETER, - "An m= section associated with the BUNDLE-tag doesn't exist."); - } + // The first MID in a BUNDLE group is BUNDLE-tagged. + auto bundled_content = + description->GetContentByName(*bundle_group->FirstContentName()); + if (!bundled_content) { + return RTCError( + RTCErrorType::INVALID_PARAMETER, + "An m= section associated with the BUNDLE-tag doesn't exist."); + } - // If the |bundled_content| is rejected, other contents in the bundle group - // should be rejected. - if (bundled_content->rejected) { - for (const auto& content_name : bundle_group_->content_names()) { - auto other_content = description->GetContentByName(content_name); - if (!other_content->rejected) { - return RTCError(RTCErrorType::INVALID_PARAMETER, - "The m= section with mid='" + content_name + - "' should be rejected."); + // If the |bundled_content| is rejected, other contents in the bundle group + // must also be rejected. 
+ if (bundled_content->rejected) { + for (const auto& content_name : bundle_group->content_names()) { + auto other_content = description->GetContentByName(content_name); + if (!other_content->rejected) { + return RTCError(RTCErrorType::INVALID_PARAMETER, + "The m= section with mid='" + content_name + + "' should be rejected."); + } } } } - return RTCError::OK(); } @@ -768,68 +800,46 @@ RTCError JsepTransportController::ValidateContent( } void JsepTransportController::HandleRejectedContent( - const cricket::ContentInfo& content_info, - const cricket::SessionDescription* description) { + const cricket::ContentInfo& content_info) { // If the content is rejected, let the // BaseChannel/SctpTransport change the RtpTransport/DtlsTransport first, // then destroy the cricket::JsepTransport. - RemoveTransportForMid(content_info.name); - if (content_info.name == bundled_mid()) { - for (const auto& content_name : bundle_group_->content_names()) { - RemoveTransportForMid(content_name); + cricket::ContentGroup* bundle_group = + bundles_.LookupGroupByMid(content_info.name); + if (bundle_group && !bundle_group->content_names().empty() && + content_info.name == *bundle_group->FirstContentName()) { + // Rejecting a BUNDLE group's first mid means we are rejecting the entire + // group. + for (const auto& content_name : bundle_group->content_names()) { + transports_.RemoveTransportForMid(content_name); } - bundle_group_.reset(); - } else if (IsBundled(content_info.name)) { - // Remove the rejected content from the |bundle_group_|. - bundle_group_->RemoveContentName(content_info.name); - // Reset the bundle group if nothing left. - if (!bundle_group_->FirstContentName()) { - bundle_group_.reset(); + // Delete the BUNDLE group. + bundles_.DeleteGroup(bundle_group); + } else { + transports_.RemoveTransportForMid(content_info.name); + if (bundle_group) { + // Remove the rejected content from the |bundle_group|. + bundles_.DeleteMid(bundle_group, content_info.name); } } - MaybeDestroyJsepTransport(content_info.name); } bool JsepTransportController::HandleBundledContent( - const cricket::ContentInfo& content_info) { - auto jsep_transport = GetJsepTransportByName(*bundled_mid()); + const cricket::ContentInfo& content_info, + const cricket::ContentGroup& bundle_group) { + TRACE_EVENT0("webrtc", "JsepTransportController::HandleBundledContent"); + RTC_DCHECK(bundle_group.FirstContentName()); + auto jsep_transport = + GetJsepTransportByName(*bundle_group.FirstContentName()); RTC_DCHECK(jsep_transport); // If the content is bundled, let the // BaseChannel/SctpTransport change the RtpTransport/DtlsTransport first, // then destroy the cricket::JsepTransport. - if (SetTransportForMid(content_info.name, jsep_transport)) { - // TODO(bugs.webrtc.org/9719) For media transport this is far from ideal, - // because it means that we first create media transport and start - // connecting it, and then we destroy it. We will need to address it before - // video path is enabled. 
- MaybeDestroyJsepTransport(content_info.name); - return true; - } - return false; -} - -bool JsepTransportController::SetTransportForMid( - const std::string& mid, - cricket::JsepTransport* jsep_transport) { - RTC_DCHECK(jsep_transport); - if (mid_to_transport_[mid] == jsep_transport) { - return true; - } - RTC_DCHECK_RUN_ON(network_thread_); - pending_mids_.push_back(mid); - mid_to_transport_[mid] = jsep_transport; - return config_.transport_observer->OnTransportChanged( - mid, jsep_transport->rtp_transport(), jsep_transport->RtpDtlsTransport(), - jsep_transport->data_channel_transport()); -} - -void JsepTransportController::RemoveTransportForMid(const std::string& mid) { - bool ret = config_.transport_observer->OnTransportChanged(mid, nullptr, - nullptr, nullptr); - // Calling OnTransportChanged with nullptr should always succeed, since it is - // only expected to fail when adding media to a transport (not removing). - RTC_DCHECK(ret); - mid_to_transport_.erase(mid); + // TODO(bugs.webrtc.org/9719) For media transport this is far from ideal, + // because it means that we first create media transport and start + // connecting it, and then we destroy it. We will need to address it before + // video path is enabled. + return transports_.SetTransportForMid(content_info.name, jsep_transport); } cricket::JsepTransportDescription @@ -838,6 +848,8 @@ JsepTransportController::CreateJsepTransportDescription( const cricket::TransportInfo& transport_info, const std::vector& encrypted_extension_ids, int rtp_abs_sendtime_extn_id) { + TRACE_EVENT0("webrtc", + "JsepTransportController::CreateJsepTransportDescription"); const cricket::MediaContentDescription* content_desc = content_info.media_description(); RTC_DCHECK(content_desc); @@ -863,11 +875,11 @@ bool JsepTransportController::ShouldUpdateBundleGroup( } RTC_DCHECK(local_desc_ && remote_desc_); - const cricket::ContentGroup* local_bundle = - local_desc_->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); - const cricket::ContentGroup* remote_bundle = - remote_desc_->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); - return local_bundle && remote_bundle; + std::vector local_bundles = + local_desc_->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + std::vector remote_bundles = + remote_desc_->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + return !local_bundles.empty() && !remote_bundles.empty(); } std::vector JsepTransportController::GetEncryptedHeaderExtensionIds( @@ -891,26 +903,31 @@ std::vector JsepTransportController::GetEncryptedHeaderExtensionIds( return encrypted_header_extension_ids; } -std::vector -JsepTransportController::MergeEncryptedHeaderExtensionIdsForBundle( +std::map> +JsepTransportController::MergeEncryptedHeaderExtensionIdsForBundles( const cricket::SessionDescription* description) { RTC_DCHECK(description); - RTC_DCHECK(bundle_group_); - - std::vector merged_ids; + RTC_DCHECK(!bundles_.bundle_groups().empty()); + std::map> + merged_encrypted_extension_ids_by_bundle; // Union the encrypted header IDs in the group when bundle is enabled. for (const cricket::ContentInfo& content_info : description->contents()) { - if (bundle_group_->HasContentName(content_info.name)) { - std::vector extension_ids = - GetEncryptedHeaderExtensionIds(content_info); - for (int id : extension_ids) { - if (!absl::c_linear_search(merged_ids, id)) { - merged_ids.push_back(id); - } + auto group = bundles_.LookupGroupByMid(content_info.name); + if (!group) + continue; + // Get or create list of IDs for the BUNDLE group. 
+ std::vector& merged_ids = + merged_encrypted_extension_ids_by_bundle[group]; + // Add IDs not already in the list. + std::vector extension_ids = + GetEncryptedHeaderExtensionIds(content_info); + for (int id : extension_ids) { + if (!absl::c_linear_search(merged_ids, id)) { + merged_ids.push_back(id); } } } - return merged_ids; + return merged_encrypted_extension_ids_by_bundle; } int JsepTransportController::GetRtpAbsSendTimeHeaderExtensionId( @@ -925,39 +942,37 @@ int JsepTransportController::GetRtpAbsSendTimeHeaderExtensionId( const webrtc::RtpExtension* send_time_extension = webrtc::RtpExtension::FindHeaderExtensionByUri( content_desc->rtp_header_extensions(), - webrtc::RtpExtension::kAbsSendTimeUri); + webrtc::RtpExtension::kAbsSendTimeUri, + config_.crypto_options.srtp.enable_encrypted_rtp_header_extensions + ? webrtc::RtpExtension::kPreferEncryptedExtension + : webrtc::RtpExtension::kDiscardEncryptedExtension); return send_time_extension ? send_time_extension->id : -1; } const cricket::JsepTransport* JsepTransportController::GetJsepTransportForMid( const std::string& mid) const { - auto it = mid_to_transport_.find(mid); - return it == mid_to_transport_.end() ? nullptr : it->second; + return transports_.GetTransportForMid(mid); } cricket::JsepTransport* JsepTransportController::GetJsepTransportForMid( const std::string& mid) { - auto it = mid_to_transport_.find(mid); - return it == mid_to_transport_.end() ? nullptr : it->second; + return transports_.GetTransportForMid(mid); } const cricket::JsepTransport* JsepTransportController::GetJsepTransportByName( const std::string& transport_name) const { - auto it = jsep_transports_by_name_.find(transport_name); - return (it == jsep_transports_by_name_.end()) ? nullptr : it->second.get(); + return transports_.GetTransportByName(transport_name); } cricket::JsepTransport* JsepTransportController::GetJsepTransportByName( const std::string& transport_name) { - auto it = jsep_transports_by_name_.find(transport_name); - return (it == jsep_transports_by_name_.end()) ? 
nullptr : it->second.get(); + return transports_.GetTransportByName(transport_name); } RTCError JsepTransportController::MaybeCreateJsepTransport( bool local, const cricket::ContentInfo& content_info, const cricket::SessionDescription& description) { - RTC_DCHECK(network_thread_->IsCurrent()); cricket::JsepTransport* transport = GetJsepTransportByName(content_info.name); if (transport) { return RTCError::OK(); @@ -981,7 +996,6 @@ RTCError JsepTransportController::MaybeCreateJsepTransport( std::unique_ptr unencrypted_rtp_transport; std::unique_ptr sdes_transport; std::unique_ptr dtls_srtp_transport; - std::unique_ptr datagram_rtp_transport; rtc::scoped_refptr rtcp_ice; if (config_.rtcp_mux_policy != @@ -1017,57 +1031,27 @@ RTCError JsepTransportController::MaybeCreateJsepTransport( std::make_unique( content_info.name, certificate_, std::move(ice), std::move(rtcp_ice), std::move(unencrypted_rtp_transport), std::move(sdes_transport), - std::move(dtls_srtp_transport), std::move(datagram_rtp_transport), - std::move(rtp_dtls_transport), std::move(rtcp_dtls_transport), - std::move(sctp_transport)); + std::move(dtls_srtp_transport), std::move(rtp_dtls_transport), + std::move(rtcp_dtls_transport), std::move(sctp_transport)); jsep_transport->rtp_transport()->SignalRtcpPacketReceived.connect( this, &JsepTransportController::OnRtcpPacketReceived_n); jsep_transport->SignalRtcpMuxActive.connect( this, &JsepTransportController::UpdateAggregateStates_n); - SetTransportForMid(content_info.name, jsep_transport.get()); - - jsep_transports_by_name_[content_info.name] = std::move(jsep_transport); + transports_.RegisterTransport(content_info.name, std::move(jsep_transport)); UpdateAggregateStates_n(); return RTCError::OK(); } -void JsepTransportController::MaybeDestroyJsepTransport( - const std::string& mid) { - auto jsep_transport = GetJsepTransportByName(mid); - if (!jsep_transport) { - return; - } - - // Don't destroy the JsepTransport if there are still media sections referring - // to it. 
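The bookkeeping removed here, including the old rule that a JsepTransport is only destroyed once no media section refers to it any more, now lives behind transports_, a JsepTransportCollection declared in pc/jsep_transport_collection.h (not part of this diff). A rough, simplified sketch of that idea, using purely illustrative names rather than the real class:

#include <map>
#include <memory>
#include <string>

struct FakeTransport {};  // Stand-in for cricket::JsepTransport.

// Owns transports by name and tracks which mids currently use them; an
// owned transport is destroyed only when no mid references it any more.
class MidTransportBookkeeping {
 public:
  void Register(const std::string& name, std::unique_ptr<FakeTransport> t) {
    by_name_[name] = std::move(t);
  }
  void SetForMid(const std::string& mid, FakeTransport* transport) {
    by_mid_[mid] = transport;
  }
  void RemoveForMid(const std::string& mid) {
    auto it = by_mid_.find(mid);
    if (it == by_mid_.end())
      return;
    FakeTransport* transport = it->second;
    by_mid_.erase(it);
    for (const auto& kv : by_mid_) {
      if (kv.second == transport)
        return;  // Still referenced by another mid; keep it alive.
    }
    for (auto owned = by_name_.begin(); owned != by_name_.end(); ++owned) {
      if (owned->second.get() == transport) {
        by_name_.erase(owned);
        return;
      }
    }
  }

 private:
  std::map<std::string, std::unique_ptr<FakeTransport>> by_name_;
  std::map<std::string, FakeTransport*> by_mid_;
};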
- for (const auto& kv : mid_to_transport_) { - if (kv.second == jsep_transport) { - return; - } - } - - jsep_transports_by_name_.erase(mid); - UpdateAggregateStates_n(); -} - void JsepTransportController::DestroyAllJsepTransports_n() { - RTC_DCHECK(network_thread_->IsCurrent()); - - for (const auto& jsep_transport : jsep_transports_by_name_) { - config_.transport_observer->OnTransportChanged(jsep_transport.first, - nullptr, nullptr, nullptr); - } - - jsep_transports_by_name_.clear(); + transports_.DestroyAllTransports(); } void JsepTransportController::SetIceRole_n(cricket::IceRole ice_role) { - RTC_DCHECK(network_thread_->IsCurrent()); - ice_role_ = ice_role; - for (auto& dtls : GetDtlsTransports()) { + auto dtls_transports = GetDtlsTransports(); + for (auto& dtls : dtls_transports) { dtls->ice_transport()->SetIceRole(ice_role_); } } @@ -1122,7 +1106,6 @@ cricket::IceRole JsepTransportController::DetermineIceRole( void JsepTransportController::OnTransportWritableState_n( rtc::PacketTransportInternal* transport) { - RTC_DCHECK(network_thread_->IsCurrent()); RTC_LOG(LS_INFO) << " Transport " << transport->transport_name() << " writability changed to " << transport->writable() << "."; @@ -1131,58 +1114,44 @@ void JsepTransportController::OnTransportWritableState_n( void JsepTransportController::OnTransportReceivingState_n( rtc::PacketTransportInternal* transport) { - RTC_DCHECK(network_thread_->IsCurrent()); UpdateAggregateStates_n(); } void JsepTransportController::OnTransportGatheringState_n( cricket::IceTransportInternal* transport) { - RTC_DCHECK(network_thread_->IsCurrent()); UpdateAggregateStates_n(); } void JsepTransportController::OnTransportCandidateGathered_n( cricket::IceTransportInternal* transport, const cricket::Candidate& candidate) { - RTC_DCHECK(network_thread_->IsCurrent()); - // We should never signal peer-reflexive candidates. if (candidate.type() == cricket::PRFLX_PORT_TYPE) { RTC_NOTREACHED(); return; } - std::string transport_name = transport->transport_name(); - invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread_, [this, transport_name, candidate] { - SignalIceCandidatesGathered(transport_name, {candidate}); - }); + + signal_ice_candidates_gathered_.Send( + transport->transport_name(), std::vector{candidate}); } void JsepTransportController::OnTransportCandidateError_n( cricket::IceTransportInternal* transport, const cricket::IceCandidateErrorEvent& event) { - RTC_DCHECK(network_thread_->IsCurrent()); - - invoker_.AsyncInvoke(RTC_FROM_HERE, signaling_thread_, - [this, event] { SignalIceCandidateError(event); }); + signal_ice_candidate_error_.Send(event); } void JsepTransportController::OnTransportCandidatesRemoved_n( cricket::IceTransportInternal* transport, const cricket::Candidates& candidates) { - invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread_, - [this, candidates] { SignalIceCandidatesRemoved(candidates); }); + signal_ice_candidates_removed_.Send(candidates); } void JsepTransportController::OnTransportCandidatePairChanged_n( const cricket::CandidatePairChangeEvent& event) { - invoker_.AsyncInvoke(RTC_FROM_HERE, signaling_thread_, [this, event] { - SignalIceCandidatePairChanged(event); - }); + signal_ice_candidate_pair_changed_.Send(event); } void JsepTransportController::OnTransportRoleConflict_n( cricket::IceTransportInternal* transport) { - RTC_DCHECK(network_thread_->IsCurrent()); // Note: since the role conflict is handled entirely on the network thread, // we don't need to worry about role conflicts occurring on two ports at // once. 
The first one encountered should immediately reverse the role. @@ -1199,7 +1168,6 @@ void JsepTransportController::OnTransportRoleConflict_n( void JsepTransportController::OnTransportStateChanged_n( cricket::IceTransportInternal* transport) { - RTC_DCHECK(network_thread_->IsCurrent()); RTC_LOG(LS_INFO) << transport->transport_name() << " Transport " << transport->component() << " state changed. Check if state is complete."; @@ -1207,8 +1175,7 @@ void JsepTransportController::OnTransportStateChanged_n( } void JsepTransportController::UpdateAggregateStates_n() { - RTC_DCHECK(network_thread_->IsCurrent()); - + TRACE_EVENT0("webrtc", "JsepTransportController::UpdateAggregateStates_n"); auto dtls_transports = GetDtlsTransports(); cricket::IceConnectionState new_connection_state = cricket::kIceConnectionConnecting; @@ -1224,7 +1191,7 @@ void JsepTransportController::UpdateAggregateStates_n() { bool all_done_gathering = !dtls_transports.empty(); std::map ice_state_counts; - std::map dtls_state_counts; + std::map dtls_state_counts; for (const auto& dtls : dtls_transports) { any_failed = any_failed || dtls->ice_transport()->GetState() == @@ -1257,10 +1224,7 @@ void JsepTransportController::UpdateAggregateStates_n() { if (ice_connection_state_ != new_connection_state) { ice_connection_state_ = new_connection_state; - invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread_, [this, new_connection_state] { - SignalIceConnectionState.Send(new_connection_state); - }); + signal_ice_connection_state_.Send(new_connection_state); } // Compute the current RTCIceConnectionState as described in @@ -1316,16 +1280,11 @@ void JsepTransportController::UpdateAggregateStates_n() { new_ice_connection_state == PeerConnectionInterface::kIceConnectionCompleted) { // Ensure that we never skip over the "connected" state. - invoker_.AsyncInvoke(RTC_FROM_HERE, signaling_thread_, [this] { - SignalStandardizedIceConnectionState( - PeerConnectionInterface::kIceConnectionConnected); - }); + signal_standardized_ice_connection_state_.Send( + PeerConnectionInterface::kIceConnectionConnected); } standardized_ice_connection_state_ = new_ice_connection_state; - invoker_.AsyncInvoke( - RTC_FROM_HERE, signaling_thread_, [this, new_ice_connection_state] { - SignalStandardizedIceConnectionState(new_ice_connection_state); - }); + signal_standardized_ice_connection_state_.Send(new_ice_connection_state); } // Compute the current RTCPeerConnectionState as described in @@ -1334,16 +1293,15 @@ void JsepTransportController::UpdateAggregateStates_n() { // Note that "connecting" is only a valid state for DTLS transports while // "checking", "completed" and "disconnected" are only valid for ICE // transports. 
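The dispatch changes above (for example signal_ice_connection_state_.Send replacing the invoker_.AsyncInvoke round-trip to the signaling thread) all follow the CallbackList pattern from rtc_base/callback_list.h: receivers are added once, and Send notifies them synchronously on the calling thread. A minimal sketch of that pattern in isolation; the state type here is just an example:

#include "p2p/base/ice_transport_internal.h"
#include "rtc_base/callback_list.h"
#include "rtc_base/logging.h"

webrtc::CallbackList<cricket::IceGatheringState> gathering_callbacks;

void Demo() {
  // Subscribers register plain lambdas...
  gathering_callbacks.AddReceiver([](cricket::IceGatheringState state) {
    RTC_LOG(LS_INFO) << "gathering state: " << state;
  });
  // ...and the owner notifies them synchronously, on the calling thread.
  gathering_callbacks.Send(cricket::kIceGatheringComplete);
}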
- int total_connected = total_ice_connected + - dtls_state_counts[cricket::DTLS_TRANSPORT_CONNECTED]; + int total_connected = + total_ice_connected + dtls_state_counts[DtlsTransportState::kConnected]; int total_dtls_connecting = - dtls_state_counts[cricket::DTLS_TRANSPORT_CONNECTING]; + dtls_state_counts[DtlsTransportState::kConnecting]; int total_failed = - total_ice_failed + dtls_state_counts[cricket::DTLS_TRANSPORT_FAILED]; + total_ice_failed + dtls_state_counts[DtlsTransportState::kFailed]; int total_closed = - total_ice_closed + dtls_state_counts[cricket::DTLS_TRANSPORT_CLOSED]; - int total_new = - total_ice_new + dtls_state_counts[cricket::DTLS_TRANSPORT_NEW]; + total_ice_closed + dtls_state_counts[DtlsTransportState::kClosed]; + int total_new = total_ice_new + dtls_state_counts[DtlsTransportState::kNew]; int total_transports = total_ice * 2; if (total_failed > 0) { @@ -1376,10 +1334,7 @@ void JsepTransportController::UpdateAggregateStates_n() { if (combined_connection_state_ != new_combined_state) { combined_connection_state_ = new_combined_state; - invoker_.AsyncInvoke(RTC_FROM_HERE, signaling_thread_, - [this, new_combined_state] { - SignalConnectionState(new_combined_state); - }); + signal_connection_state_.Send(new_combined_state); } // Compute the gathering state. @@ -1392,10 +1347,7 @@ void JsepTransportController::UpdateAggregateStates_n() { } if (ice_gathering_state_ != new_gathering_state) { ice_gathering_state_ = new_gathering_state; - invoker_.AsyncInvoke(RTC_FROM_HERE, signaling_thread_, - [this, new_gathering_state] { - SignalIceGatheringState(new_gathering_state); - }); + signal_ice_gathering_state_.Send(new_gathering_state); } } @@ -1408,7 +1360,24 @@ void JsepTransportController::OnRtcpPacketReceived_n( void JsepTransportController::OnDtlsHandshakeError( rtc::SSLHandshakeError error) { - SignalDtlsHandshakeError(error); + config_.on_dtls_handshake_error_(error); +} + +bool JsepTransportController::OnTransportChanged( + const std::string& mid, + cricket::JsepTransport* jsep_transport) { + if (config_.transport_observer) { + if (jsep_transport) { + return config_.transport_observer->OnTransportChanged( + mid, jsep_transport->rtp_transport(), + jsep_transport->RtpDtlsTransport(), + jsep_transport->data_channel_transport()); + } else { + return config_.transport_observer->OnTransportChanged(mid, nullptr, + nullptr, nullptr); + } + } + return false; } } // namespace webrtc diff --git a/pc/jsep_transport_controller.h b/pc/jsep_transport_controller.h index f0adeedf26..71b01bffb2 100644 --- a/pc/jsep_transport_controller.h +++ b/pc/jsep_transport_controller.h @@ -11,32 +11,63 @@ #ifndef PC_JSEP_TRANSPORT_CONTROLLER_H_ #define PC_JSEP_TRANSPORT_CONTROLLER_H_ +#include + +#include #include #include #include #include #include +#include "absl/types/optional.h" +#include "api/async_dns_resolver.h" #include "api/candidate.h" #include "api/crypto/crypto_options.h" #include "api/ice_transport_factory.h" +#include "api/ice_transport_interface.h" +#include "api/jsep.h" #include "api/peer_connection_interface.h" +#include "api/rtc_error.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" +#include "api/transport/data_channel_transport_interface.h" +#include "api/transport/sctp_transport_factory_interface.h" #include "media/sctp/sctp_transport_internal.h" #include "p2p/base/dtls_transport.h" #include "p2p/base/dtls_transport_factory.h" +#include "p2p/base/dtls_transport_internal.h" +#include 
"p2p/base/ice_transport_internal.h" #include "p2p/base/p2p_transport_channel.h" +#include "p2p/base/packet_transport_internal.h" +#include "p2p/base/port.h" +#include "p2p/base/port_allocator.h" +#include "p2p/base/transport_description.h" +#include "p2p/base/transport_info.h" #include "pc/channel.h" #include "pc/dtls_srtp_transport.h" #include "pc/dtls_transport.h" #include "pc/jsep_transport.h" +#include "pc/jsep_transport_collection.h" #include "pc/rtp_transport.h" +#include "pc/rtp_transport_internal.h" +#include "pc/sctp_transport.h" +#include "pc/session_description.h" #include "pc/srtp_transport.h" -#include "rtc_base/async_invoker.h" +#include "pc/transport_stats.h" +#include "rtc_base/callback_list.h" +#include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/helpers.h" #include "rtc_base/ref_counted_object.h" -#include "rtc_base/callback_list.h" +#include "rtc_base/rtc_certificate.h" +#include "rtc_base/ssl_certificate.h" +#include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace rtc { class Thread; @@ -98,20 +129,25 @@ class JsepTransportController : public sigslot::has_slots<> { std::function rtcp_handler; + // Initial value for whether DtlsTransport reset causes a reset + // of SRTP parameters. bool active_reset_srtp_params = false; RtcEventLog* event_log = nullptr; // Factory for SCTP transports. SctpTransportFactoryInterface* sctp_factory = nullptr; + std::function on_dtls_handshake_error_; }; - // The ICE related events are signaled on the |signaling_thread|. - // All the transport related methods are called on the |network_thread|. - JsepTransportController(rtc::Thread* signaling_thread, - rtc::Thread* network_thread, - cricket::PortAllocator* port_allocator, - AsyncResolverFactory* async_resolver_factory, - Config config); + // The ICE related events are fired on the |network_thread|. + // All the transport related methods are called on the |network_thread| + // and destruction of the JsepTransportController must occur on the + // |network_thread|. + JsepTransportController( + rtc::Thread* network_thread, + cricket::PortAllocator* port_allocator, + AsyncDnsResolverFactoryInterface* async_dns_resolver_factory, + Config config); virtual ~JsepTransportController(); // The main method to be called; applies a description at the transport @@ -192,57 +228,113 @@ class JsepTransportController : public sigslot::has_slots<> { // and deletes unused transports, but doesn't consider anything more complex. void RollbackTransports(); - // All of these signals are fired on the signaling thread. 
+ // F: void(const std::string&, const std::vector&) + template + void SubscribeIceCandidateGathered(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_ice_candidates_gathered_.AddReceiver(std::forward(callback)); + } + + // F: void(cricket::IceConnectionState) + template + void SubscribeIceConnectionState(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_ice_connection_state_.AddReceiver(std::forward(callback)); + } + + // F: void(PeerConnectionInterface::PeerConnectionState) + template + void SubscribeConnectionState(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_connection_state_.AddReceiver(std::forward(callback)); + } + + // F: void(PeerConnectionInterface::IceConnectionState) + template + void SubscribeStandardizedIceConnectionState(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_standardized_ice_connection_state_.AddReceiver( + std::forward(callback)); + } + + // F: void(cricket::IceGatheringState) + template + void SubscribeIceGatheringState(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_ice_gathering_state_.AddReceiver(std::forward(callback)); + } + + // F: void(const cricket::IceCandidateErrorEvent&) + template + void SubscribeIceCandidateError(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_ice_candidate_error_.AddReceiver(std::forward(callback)); + } + + // F: void(const std::vector&) + template + void SubscribeIceCandidatesRemoved(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_ice_candidates_removed_.AddReceiver(std::forward(callback)); + } + + // F: void(const cricket::CandidatePairChangeEvent&) + template + void SubscribeIceCandidatePairChanged(F&& callback) { + RTC_DCHECK_RUN_ON(network_thread_); + signal_ice_candidate_pair_changed_.AddReceiver(std::forward(callback)); + } + + private: + // All of these callbacks are fired on the network thread. 
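These Subscribe* helpers replace the public sigslot signals that used to follow; the corresponding CallbackList members are now private and guarded by the network thread, so registration happens there as well. A small usage sketch mirroring the unit-test wiring further down (controller and network_thread are placeholders for the embedder's objects):

network_thread->Invoke<void>(RTC_FROM_HERE, [&] {
  controller->SubscribeIceConnectionState([](cricket::IceConnectionState s) {
    RTC_LOG(LS_INFO) << "ICE connection state: " << s;
  });
  controller->SubscribeIceCandidateGathered(
      [](const std::string& mid,
         const std::vector<cricket::Candidate>& candidates) {
        RTC_LOG(LS_INFO) << candidates.size()
                         << " candidate(s) gathered for mid=" << mid;
      });
});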
// If any transport failed => failed, // Else if all completed => completed, // Else if all connected => connected, // Else => connecting - CallbackList SignalIceConnectionState; + CallbackList signal_ice_connection_state_ + RTC_GUARDED_BY(network_thread_); - sigslot::signal1 - SignalConnectionState; + CallbackList + signal_connection_state_ RTC_GUARDED_BY(network_thread_); - sigslot::signal1 - SignalStandardizedIceConnectionState; + CallbackList + signal_standardized_ice_connection_state_ RTC_GUARDED_BY(network_thread_); // If all transports done gathering => complete, // Else if any are gathering => gathering, // Else => new - sigslot::signal1 SignalIceGatheringState; + CallbackList signal_ice_gathering_state_ + RTC_GUARDED_BY(network_thread_); - // (mid, candidates) - sigslot::signal2&> - SignalIceCandidatesGathered; + // [mid, candidates] + CallbackList&> + signal_ice_candidates_gathered_ RTC_GUARDED_BY(network_thread_); - sigslot::signal1 - SignalIceCandidateError; + CallbackList + signal_ice_candidate_error_ RTC_GUARDED_BY(network_thread_); - sigslot::signal1&> - SignalIceCandidatesRemoved; + CallbackList&> + signal_ice_candidates_removed_ RTC_GUARDED_BY(network_thread_); - sigslot::signal1 - SignalIceCandidatePairChanged; + CallbackList + signal_ice_candidate_pair_changed_ RTC_GUARDED_BY(network_thread_); - sigslot::signal1 SignalDtlsHandshakeError; - - private: RTCError ApplyDescription_n(bool local, SdpType type, - const cricket::SessionDescription* description); - RTCError ValidateAndMaybeUpdateBundleGroup( + const cricket::SessionDescription* description) + RTC_RUN_ON(network_thread_); + RTCError ValidateAndMaybeUpdateBundleGroups( bool local, SdpType type, const cricket::SessionDescription* description); RTCError ValidateContent(const cricket::ContentInfo& content_info); - void HandleRejectedContent(const cricket::ContentInfo& content_info, - const cricket::SessionDescription* description); - bool HandleBundledContent(const cricket::ContentInfo& content_info); - - bool SetTransportForMid(const std::string& mid, - cricket::JsepTransport* jsep_transport); - void RemoveTransportForMid(const std::string& mid); + void HandleRejectedContent(const cricket::ContentInfo& content_info) + RTC_RUN_ON(network_thread_); + bool HandleBundledContent(const cricket::ContentInfo& content_info, + const cricket::ContentGroup& bundle_group) + RTC_RUN_ON(network_thread_); cricket::JsepTransportDescription CreateJsepTransportDescription( const cricket::ContentInfo& content_info, @@ -250,22 +342,11 @@ class JsepTransportController : public sigslot::has_slots<> { const std::vector& encrypted_extension_ids, int rtp_abs_sendtime_extn_id); - absl::optional bundled_mid() const { - absl::optional bundled_mid; - if (bundle_group_ && bundle_group_->FirstContentName()) { - bundled_mid = *(bundle_group_->FirstContentName()); - } - return bundled_mid; - } - - bool IsBundled(const std::string& mid) const { - return bundle_group_ && bundle_group_->HasContentName(mid); - } - bool ShouldUpdateBundleGroup(SdpType type, const cricket::SessionDescription* description); - std::vector MergeEncryptedHeaderExtensionIdsForBundle( + std::map> + MergeEncryptedHeaderExtensionIdsForBundles( const cricket::SessionDescription* description); std::vector GetEncryptedHeaderExtensionIds( const cricket::ContentInfo& content_info); @@ -278,15 +359,16 @@ class JsepTransportController : public sigslot::has_slots<> { // transports are bundled on (In current implementation, it is the first // content in the BUNDLE group). 
const cricket::JsepTransport* GetJsepTransportForMid( - const std::string& mid) const; - cricket::JsepTransport* GetJsepTransportForMid(const std::string& mid); + const std::string& mid) const RTC_RUN_ON(network_thread_); + cricket::JsepTransport* GetJsepTransportForMid(const std::string& mid) + RTC_RUN_ON(network_thread_); // Get the JsepTransport without considering the BUNDLE group. Return nullptr // if the JsepTransport is destroyed. const cricket::JsepTransport* GetJsepTransportByName( - const std::string& transport_name) const; + const std::string& transport_name) const RTC_RUN_ON(network_thread_); cricket::JsepTransport* GetJsepTransportByName( - const std::string& transport_name); + const std::string& transport_name) RTC_RUN_ON(network_thread_); // Creates jsep transport. Noop if transport is already created. // Transport is created either during SetLocalDescription (|local| == true) or @@ -295,12 +377,12 @@ class JsepTransportController : public sigslot::has_slots<> { RTCError MaybeCreateJsepTransport( bool local, const cricket::ContentInfo& content_info, - const cricket::SessionDescription& description); + const cricket::SessionDescription& description) + RTC_RUN_ON(network_thread_); - void MaybeDestroyJsepTransport(const std::string& mid); - void DestroyAllJsepTransports_n(); + void DestroyAllJsepTransports_n() RTC_RUN_ON(network_thread_); - void SetIceRole_n(cricket::IceRole ice_role); + void SetIceRole_n(cricket::IceRole ice_role) RTC_RUN_ON(network_thread_); cricket::IceRole DetermineIceRole( cricket::JsepTransport* jsep_transport, @@ -334,39 +416,44 @@ class JsepTransportController : public sigslot::has_slots<> { std::vector GetDtlsTransports(); // Handlers for signals from Transport. - void OnTransportWritableState_n(rtc::PacketTransportInternal* transport); - void OnTransportReceivingState_n(rtc::PacketTransportInternal* transport); - void OnTransportGatheringState_n(cricket::IceTransportInternal* transport); + void OnTransportWritableState_n(rtc::PacketTransportInternal* transport) + RTC_RUN_ON(network_thread_); + void OnTransportReceivingState_n(rtc::PacketTransportInternal* transport) + RTC_RUN_ON(network_thread_); + void OnTransportGatheringState_n(cricket::IceTransportInternal* transport) + RTC_RUN_ON(network_thread_); void OnTransportCandidateGathered_n(cricket::IceTransportInternal* transport, - const cricket::Candidate& candidate); - void OnTransportCandidateError_n( - cricket::IceTransportInternal* transport, - const cricket::IceCandidateErrorEvent& event); + const cricket::Candidate& candidate) + RTC_RUN_ON(network_thread_); + void OnTransportCandidateError_n(cricket::IceTransportInternal* transport, + const cricket::IceCandidateErrorEvent& event) + RTC_RUN_ON(network_thread_); void OnTransportCandidatesRemoved_n(cricket::IceTransportInternal* transport, - const cricket::Candidates& candidates); - void OnTransportRoleConflict_n(cricket::IceTransportInternal* transport); - void OnTransportStateChanged_n(cricket::IceTransportInternal* transport); + const cricket::Candidates& candidates) + RTC_RUN_ON(network_thread_); + void OnTransportRoleConflict_n(cricket::IceTransportInternal* transport) + RTC_RUN_ON(network_thread_); + void OnTransportStateChanged_n(cricket::IceTransportInternal* transport) + RTC_RUN_ON(network_thread_); void OnTransportCandidatePairChanged_n( - const cricket::CandidatePairChangeEvent& event); - void UpdateAggregateStates_n(); + const cricket::CandidatePairChangeEvent& event) + RTC_RUN_ON(network_thread_); + void UpdateAggregateStates_n() 
RTC_RUN_ON(network_thread_); void OnRtcpPacketReceived_n(rtc::CopyOnWriteBuffer* packet, - int64_t packet_time_us); + int64_t packet_time_us) + RTC_RUN_ON(network_thread_); void OnDtlsHandshakeError(rtc::SSLHandshakeError error); - rtc::Thread* const signaling_thread_ = nullptr; + bool OnTransportChanged(const std::string& mid, + cricket::JsepTransport* transport); + rtc::Thread* const network_thread_ = nullptr; cricket::PortAllocator* const port_allocator_ = nullptr; - AsyncResolverFactory* const async_resolver_factory_ = nullptr; - - std::map> - jsep_transports_by_name_; - // This keeps track of the mapping between media section - // (BaseChannel/SctpTransport) and the JsepTransport underneath. - std::map mid_to_transport_; - // Keep track of mids that have been mapped to transports. Used for rollback. - std::vector pending_mids_ RTC_GUARDED_BY(network_thread_); + AsyncDnsResolverFactoryInterface* const async_dns_resolver_factory_ = nullptr; + + JsepTransportCollection transports_ RTC_GUARDED_BY(network_thread_); // Aggregate states for Transports. // standardized_ice_connection_state_ is intended to replace // ice_connection_state, see bugs.webrtc.org/9308 @@ -379,19 +466,19 @@ class JsepTransportController : public sigslot::has_slots<> { PeerConnectionInterface::PeerConnectionState::kNew; cricket::IceGatheringState ice_gathering_state_ = cricket::kIceGatheringNew; - Config config_; + const Config config_; + bool active_reset_srtp_params_ RTC_GUARDED_BY(network_thread_); const cricket::SessionDescription* local_desc_ = nullptr; const cricket::SessionDescription* remote_desc_ = nullptr; absl::optional initial_offerer_; - absl::optional bundle_group_; - cricket::IceConfig ice_config_; cricket::IceRole ice_role_ = cricket::ICEROLE_CONTROLLING; uint64_t ice_tiebreaker_ = rtc::CreateRandomId64(); rtc::scoped_refptr certificate_; - rtc::AsyncInvoker invoker_; + + BundleManager bundles_; RTC_DISALLOW_COPY_AND_ASSIGN(JsepTransportController); }; diff --git a/pc/jsep_transport_controller_unittest.cc b/pc/jsep_transport_controller_unittest.cc index 40dc23e535..2b261c83c8 100644 --- a/pc/jsep_transport_controller_unittest.cc +++ b/pc/jsep_transport_controller_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "api/dtls_transport_interface.h" #include "p2p/base/dtls_transport_factory.h" #include "p2p/base/fake_dtls_transport.h" #include "p2p/base/fake_ice_transport.h" @@ -33,6 +34,8 @@ static const char kIceUfrag2[] = "u0002"; static const char kIcePwd2[] = "TESTICEPWD00000000000002"; static const char kIceUfrag3[] = "u0003"; static const char kIcePwd3[] = "TESTICEPWD00000000000003"; +static const char kIceUfrag4[] = "u0004"; +static const char kIcePwd4[] = "TESTICEPWD00000000000004"; static const char kAudioMid1[] = "audio1"; static const char kAudioMid2[] = "audio2"; static const char kVideoMid1[] = "video1"; @@ -48,7 +51,7 @@ class FakeIceTransportFactory : public webrtc::IceTransportFactory { const std::string& transport_name, int component, IceTransportInit init) override { - return new rtc::RefCountedObject( + return rtc::make_ref_counted( std::make_unique(transport_name, component)); } }; @@ -57,7 +60,8 @@ class FakeDtlsTransportFactory : public cricket::DtlsTransportFactory { public: std::unique_ptr CreateDtlsTransport( cricket::IceTransportInternal* ice, - const webrtc::CryptoOptions& crypto_options) override { + const webrtc::CryptoOptions& crypto_options, + rtc::SSLProtocolVersion max_version) override { return std::make_unique( static_cast(ice)); } @@ -74,7 +78,6 @@ class 
JsepTransportControllerTest : public JsepTransportController::Observer, void CreateJsepTransportController( JsepTransportController::Config config, - rtc::Thread* signaling_thread = rtc::Thread::Current(), rtc::Thread* network_thread = rtc::Thread::Current(), cricket::PortAllocator* port_allocator = nullptr) { config.transport_observer = this; @@ -82,25 +85,37 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, int64_t packet_time_us) { RTC_NOTREACHED(); }; config.ice_transport_factory = fake_ice_transport_factory_.get(); config.dtls_transport_factory = fake_dtls_transport_factory_.get(); + config.on_dtls_handshake_error_ = [](rtc::SSLHandshakeError s) {}; transport_controller_ = std::make_unique( - signaling_thread, network_thread, port_allocator, - nullptr /* async_resolver_factory */, config); - ConnectTransportControllerSignals(); + network_thread, port_allocator, nullptr /* async_resolver_factory */, + config); + network_thread->Invoke(RTC_FROM_HERE, + [&] { ConnectTransportControllerSignals(); }); } void ConnectTransportControllerSignals() { - transport_controller_->SignalIceConnectionState.AddReceiver( + transport_controller_->SubscribeIceConnectionState( [this](cricket::IceConnectionState s) { JsepTransportControllerTest::OnConnectionState(s); }); - transport_controller_->SignalStandardizedIceConnectionState.connect( - this, &JsepTransportControllerTest::OnStandardizedIceConnectionState); - transport_controller_->SignalConnectionState.connect( - this, &JsepTransportControllerTest::OnCombinedConnectionState); - transport_controller_->SignalIceGatheringState.connect( - this, &JsepTransportControllerTest::OnGatheringState); - transport_controller_->SignalIceCandidatesGathered.connect( - this, &JsepTransportControllerTest::OnCandidatesGathered); + transport_controller_->SubscribeConnectionState( + [this](PeerConnectionInterface::PeerConnectionState s) { + JsepTransportControllerTest::OnCombinedConnectionState(s); + }); + transport_controller_->SubscribeStandardizedIceConnectionState( + [this](PeerConnectionInterface::IceConnectionState s) { + JsepTransportControllerTest::OnStandardizedIceConnectionState(s); + }); + transport_controller_->SubscribeIceGatheringState( + [this](cricket::IceGatheringState s) { + JsepTransportControllerTest::OnGatheringState(s); + }); + transport_controller_->SubscribeIceCandidateGathered( + [this](const std::string& transport, + const std::vector& candidates) { + JsepTransportControllerTest::OnCandidatesGathered(transport, + candidates); + }); } std::unique_ptr @@ -265,18 +280,14 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, protected: void OnConnectionState(cricket::IceConnectionState state) { - if (!signaling_thread_->IsCurrent()) { - signaled_on_non_signaling_thread_ = true; - } + ice_signaled_on_thread_ = rtc::Thread::Current(); connection_state_ = state; ++connection_state_signal_count_; } void OnStandardizedIceConnectionState( PeerConnectionInterface::IceConnectionState state) { - if (!signaling_thread_->IsCurrent()) { - signaled_on_non_signaling_thread_ = true; - } + ice_signaled_on_thread_ = rtc::Thread::Current(); ice_connection_state_ = state; ++ice_connection_state_signal_count_; } @@ -285,26 +296,20 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, PeerConnectionInterface::PeerConnectionState state) { RTC_LOG(LS_INFO) << "OnCombinedConnectionState: " << static_cast(state); - if (!signaling_thread_->IsCurrent()) { - signaled_on_non_signaling_thread_ = 
true; - } + ice_signaled_on_thread_ = rtc::Thread::Current(); combined_connection_state_ = state; ++combined_connection_state_signal_count_; } void OnGatheringState(cricket::IceGatheringState state) { - if (!signaling_thread_->IsCurrent()) { - signaled_on_non_signaling_thread_ = true; - } + ice_signaled_on_thread_ = rtc::Thread::Current(); gathering_state_ = state; ++gathering_state_signal_count_; } void OnCandidatesGathered(const std::string& transport_name, const Candidates& candidates) { - if (!signaling_thread_->IsCurrent()) { - signaled_on_non_signaling_thread_ = true; - } + ice_signaled_on_thread_ = rtc::Thread::Current(); candidates_[transport_name].insert(candidates_[transport_name].end(), candidates.begin(), candidates.end()); ++candidates_signal_count_; @@ -349,7 +354,7 @@ class JsepTransportControllerTest : public JsepTransportController::Observer, std::unique_ptr fake_ice_transport_factory_; std::unique_ptr fake_dtls_transport_factory_; rtc::Thread* const signaling_thread_ = nullptr; - bool signaled_on_non_signaling_thread_ = false; + rtc::Thread* ice_signaled_on_thread_ = nullptr; // Used to verify the SignalRtpTransportChanged/SignalDtlsTransportChanged are // signaled correctly. std::map changed_rtp_transport_by_mid_; @@ -689,8 +694,8 @@ TEST_F(JsepTransportControllerTest, combined_connection_state_, kTimeout); EXPECT_EQ(2, combined_connection_state_signal_count_); - fake_audio_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CONNECTED); - fake_video_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CONNECTED); + fake_audio_dtls->SetDtlsState(DtlsTransportState::kConnected); + fake_video_dtls->SetDtlsState(DtlsTransportState::kConnected); // Set the connection count to be 2 and the cricket::FakeIceTransport will set // the transport state to be STATE_CONNECTING. fake_video_dtls->fake_ice_transport()->SetConnectionCount(2); @@ -746,8 +751,8 @@ TEST_F(JsepTransportControllerTest, SignalConnectionStateComplete) { combined_connection_state_, kTimeout); EXPECT_EQ(2, combined_connection_state_signal_count_); - fake_audio_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CONNECTED); - fake_video_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CONNECTED); + fake_audio_dtls->SetDtlsState(DtlsTransportState::kConnected); + fake_video_dtls->SetDtlsState(DtlsTransportState::kConnected); // Set the connection count to be 1 and the cricket::FakeIceTransport will set // the transport state to be STATE_COMPLETED. fake_video_dtls->fake_ice_transport()->SetTransportState( @@ -835,7 +840,7 @@ TEST_F(JsepTransportControllerTest, fake_audio_dtls->SetWritable(true); fake_audio_dtls->fake_ice_transport()->SetCandidatesGatheringComplete(); fake_audio_dtls->fake_ice_transport()->SetConnectionCount(1); - fake_audio_dtls->SetDtlsState(cricket::DTLS_TRANSPORT_CONNECTED); + fake_audio_dtls->SetDtlsState(DtlsTransportState::kConnected); EXPECT_EQ(1, gathering_state_signal_count_); // Set the remote description and enable the bundle. 
@@ -872,11 +877,12 @@ TEST_F(JsepTransportControllerTest, SignalCandidatesGathered) { EXPECT_EQ(1u, candidates_[kAudioMid1].size()); } -TEST_F(JsepTransportControllerTest, IceSignalingOccursOnSignalingThread) { +TEST_F(JsepTransportControllerTest, IceSignalingOccursOnNetworkThread) { network_thread_ = rtc::Thread::CreateWithSocketServer(); network_thread_->Start(); + EXPECT_EQ(ice_signaled_on_thread_, nullptr); CreateJsepTransportController(JsepTransportController::Config(), - signaling_thread_, network_thread_.get(), + network_thread_.get(), /*port_allocator=*/nullptr); CreateLocalDescriptionAndCompleteConnectionOnNetworkThread(); @@ -892,7 +898,10 @@ TEST_F(JsepTransportControllerTest, IceSignalingOccursOnSignalingThread) { EXPECT_EQ_WAIT(1u, candidates_[kVideoMid1].size(), kTimeout); EXPECT_EQ(2, candidates_signal_count_); - EXPECT_TRUE(!signaled_on_non_signaling_thread_); + EXPECT_EQ(ice_signaled_on_thread_, network_thread_.get()); + + network_thread_->Invoke(RTC_FROM_HERE, + [&] { transport_controller_.reset(); }); } // Test that if the TransportController was created with the @@ -1093,6 +1102,512 @@ TEST_F(JsepTransportControllerTest, MultipleMediaSectionsOfSameTypeWithBundle) { ASSERT_TRUE(it2 != changed_dtls_transport_by_mid_.end()); } +TEST_F(JsepTransportControllerTest, MultipleBundleGroups) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Video[] = "2_video"; + static const char kMid3Audio[] = "3_audio"; + static const char kMid4Video[] = "4_video"; + + CreateJsepTransportController(JsepTransportController::Config()); + cricket::ContentGroup bundle_group1(cricket::GROUP_TYPE_BUNDLE); + bundle_group1.AddContentName(kMid1Audio); + bundle_group1.AddContentName(kMid2Video); + cricket::ContentGroup bundle_group2(cricket::GROUP_TYPE_BUNDLE); + bundle_group2.AddContentName(kMid3Audio); + bundle_group2.AddContentName(kMid4Video); + + auto local_offer = std::make_unique(); + AddAudioSection(local_offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + local_offer->AddGroup(bundle_group1); + local_offer->AddGroup(bundle_group2); + + auto remote_answer = std::make_unique(); + AddAudioSection(remote_answer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + remote_answer->AddGroup(bundle_group1); + remote_answer->AddGroup(bundle_group2); + + EXPECT_TRUE(transport_controller_ + ->SetLocalDescription(SdpType::kOffer, local_offer.get()) + .ok()); + EXPECT_TRUE(transport_controller_ + ->SetRemoteDescription(SdpType::kAnswer, remote_answer.get()) + .ok()); + + // Verify that (kMid1Audio,kMid2Video) and 
(kMid3Audio,kMid4Video) form two + // distinct bundled groups. + auto mid1_transport = transport_controller_->GetRtpTransport(kMid1Audio); + auto mid2_transport = transport_controller_->GetRtpTransport(kMid2Video); + auto mid3_transport = transport_controller_->GetRtpTransport(kMid3Audio); + auto mid4_transport = transport_controller_->GetRtpTransport(kMid4Video); + EXPECT_EQ(mid1_transport, mid2_transport); + EXPECT_EQ(mid3_transport, mid4_transport); + EXPECT_NE(mid1_transport, mid3_transport); + + auto it = changed_rtp_transport_by_mid_.find(kMid1Audio); + ASSERT_TRUE(it != changed_rtp_transport_by_mid_.end()); + EXPECT_EQ(it->second, mid1_transport); + + it = changed_rtp_transport_by_mid_.find(kMid2Video); + ASSERT_TRUE(it != changed_rtp_transport_by_mid_.end()); + EXPECT_EQ(it->second, mid2_transport); + + it = changed_rtp_transport_by_mid_.find(kMid3Audio); + ASSERT_TRUE(it != changed_rtp_transport_by_mid_.end()); + EXPECT_EQ(it->second, mid3_transport); + + it = changed_rtp_transport_by_mid_.find(kMid4Video); + ASSERT_TRUE(it != changed_rtp_transport_by_mid_.end()); + EXPECT_EQ(it->second, mid4_transport); +} + +TEST_F(JsepTransportControllerTest, + MultipleBundleGroupsInOfferButOnlyASingleGroupInAnswer) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Video[] = "2_video"; + static const char kMid3Audio[] = "3_audio"; + static const char kMid4Video[] = "4_video"; + + CreateJsepTransportController(JsepTransportController::Config()); + cricket::ContentGroup bundle_group1(cricket::GROUP_TYPE_BUNDLE); + bundle_group1.AddContentName(kMid1Audio); + bundle_group1.AddContentName(kMid2Video); + cricket::ContentGroup bundle_group2(cricket::GROUP_TYPE_BUNDLE); + bundle_group2.AddContentName(kMid3Audio); + bundle_group2.AddContentName(kMid4Video); + + auto local_offer = std::make_unique(); + AddAudioSection(local_offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + // The offer has both groups. + local_offer->AddGroup(bundle_group1); + local_offer->AddGroup(bundle_group2); + + auto remote_answer = std::make_unique(); + AddAudioSection(remote_answer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + // The answer only has a single group! This is what happens when talking to an + // endpoint that does not have support for multiple BUNDLE groups. 
+ remote_answer->AddGroup(bundle_group1); + + EXPECT_TRUE(transport_controller_ + ->SetLocalDescription(SdpType::kOffer, local_offer.get()) + .ok()); + EXPECT_TRUE(transport_controller_ + ->SetRemoteDescription(SdpType::kAnswer, remote_answer.get()) + .ok()); + + // Verify that (kMid1Audio,kMid2Video) form a bundle group, but that + // kMid3Audio and kMid4Video are unbundled. + auto mid1_transport = transport_controller_->GetRtpTransport(kMid1Audio); + auto mid2_transport = transport_controller_->GetRtpTransport(kMid2Video); + auto mid3_transport = transport_controller_->GetRtpTransport(kMid3Audio); + auto mid4_transport = transport_controller_->GetRtpTransport(kMid4Video); + EXPECT_EQ(mid1_transport, mid2_transport); + EXPECT_NE(mid3_transport, mid4_transport); + EXPECT_NE(mid1_transport, mid3_transport); + EXPECT_NE(mid1_transport, mid4_transport); +} + +TEST_F(JsepTransportControllerTest, MultipleBundleGroupsIllegallyChangeGroup) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Video[] = "2_video"; + static const char kMid3Audio[] = "3_audio"; + static const char kMid4Video[] = "4_video"; + + CreateJsepTransportController(JsepTransportController::Config()); + // Offer groups (kMid1Audio,kMid2Video) and (kMid3Audio,kMid4Video). + cricket::ContentGroup offer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group1.AddContentName(kMid1Audio); + offer_bundle_group1.AddContentName(kMid2Video); + cricket::ContentGroup offer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group2.AddContentName(kMid3Audio); + offer_bundle_group2.AddContentName(kMid4Video); + // Answer groups (kMid1Audio,kMid4Video) and (kMid3Audio,kMid2Video), i.e. the + // second group members have switched places. This should get rejected. + cricket::ContentGroup answer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group1.AddContentName(kMid1Audio); + answer_bundle_group1.AddContentName(kMid4Video); + cricket::ContentGroup answer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group2.AddContentName(kMid3Audio); + answer_bundle_group2.AddContentName(kMid2Video); + + auto local_offer = std::make_unique(); + AddAudioSection(local_offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + local_offer->AddGroup(offer_bundle_group1); + local_offer->AddGroup(offer_bundle_group2); + + auto remote_answer = std::make_unique(); + AddAudioSection(remote_answer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + remote_answer->AddGroup(answer_bundle_group1); + remote_answer->AddGroup(answer_bundle_group2); + + // Accept 
offer. + EXPECT_TRUE(transport_controller_ + ->SetLocalDescription(SdpType::kOffer, local_offer.get()) + .ok()); + // Reject answer! + EXPECT_FALSE(transport_controller_ + ->SetRemoteDescription(SdpType::kAnswer, remote_answer.get()) + .ok()); +} + +TEST_F(JsepTransportControllerTest, MultipleBundleGroupsInvalidSubsets) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Video[] = "2_video"; + static const char kMid3Audio[] = "3_audio"; + static const char kMid4Video[] = "4_video"; + + CreateJsepTransportController(JsepTransportController::Config()); + // Offer groups (kMid1Audio,kMid2Video) and (kMid3Audio,kMid4Video). + cricket::ContentGroup offer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group1.AddContentName(kMid1Audio); + offer_bundle_group1.AddContentName(kMid2Video); + cricket::ContentGroup offer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group2.AddContentName(kMid3Audio); + offer_bundle_group2.AddContentName(kMid4Video); + // Answer groups (kMid1Audio) and (kMid2Video), i.e. the second group was + // moved from the first group. This should get rejected. + cricket::ContentGroup answer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group1.AddContentName(kMid1Audio); + cricket::ContentGroup answer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group2.AddContentName(kMid2Video); + + auto local_offer = std::make_unique(); + AddAudioSection(local_offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + local_offer->AddGroup(offer_bundle_group1); + local_offer->AddGroup(offer_bundle_group2); + + auto remote_answer = std::make_unique(); + AddAudioSection(remote_answer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid4Video, kIceUfrag4, kIcePwd4, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + remote_answer->AddGroup(answer_bundle_group1); + remote_answer->AddGroup(answer_bundle_group2); + + // Accept offer. + EXPECT_TRUE(transport_controller_ + ->SetLocalDescription(SdpType::kOffer, local_offer.get()) + .ok()); + // Reject answer! + EXPECT_FALSE(transport_controller_ + ->SetRemoteDescription(SdpType::kAnswer, remote_answer.get()) + .ok()); +} + +TEST_F(JsepTransportControllerTest, MultipleBundleGroupsInvalidOverlap) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Video[] = "2_video"; + static const char kMid3Audio[] = "3_audio"; + + CreateJsepTransportController(JsepTransportController::Config()); + // Offer groups (kMid1Audio,kMid3Audio) and (kMid2Video,kMid3Audio), i.e. + // kMid3Audio is in both groups - this is illegal. 
+ cricket::ContentGroup offer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group1.AddContentName(kMid1Audio); + offer_bundle_group1.AddContentName(kMid3Audio); + cricket::ContentGroup offer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group2.AddContentName(kMid2Video); + offer_bundle_group2.AddContentName(kMid3Audio); + + auto offer = std::make_unique(); + AddAudioSection(offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(offer.get(), kMid2Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(offer.get(), kMid3Audio, kIceUfrag3, kIcePwd3, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + offer->AddGroup(offer_bundle_group1); + offer->AddGroup(offer_bundle_group2); + + // Reject offer, both if set as local or remote. + EXPECT_FALSE( + transport_controller_->SetLocalDescription(SdpType::kOffer, offer.get()) + .ok()); + EXPECT_FALSE( + transport_controller_->SetRemoteDescription(SdpType::kOffer, offer.get()) + .ok()); +} + +TEST_F(JsepTransportControllerTest, MultipleBundleGroupsUnbundleFirstMid) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Audio[] = "2_audio"; + static const char kMid3Audio[] = "3_audio"; + static const char kMid4Video[] = "4_video"; + static const char kMid5Video[] = "5_video"; + static const char kMid6Video[] = "6_video"; + + CreateJsepTransportController(JsepTransportController::Config()); + // Offer groups (kMid1Audio,kMid2Audio,kMid3Audio) and + // (kMid4Video,kMid5Video,kMid6Video). + cricket::ContentGroup offer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group1.AddContentName(kMid1Audio); + offer_bundle_group1.AddContentName(kMid2Audio); + offer_bundle_group1.AddContentName(kMid3Audio); + cricket::ContentGroup offer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group2.AddContentName(kMid4Video); + offer_bundle_group2.AddContentName(kMid5Video); + offer_bundle_group2.AddContentName(kMid6Video); + // Answer groups (kMid2Audio,kMid3Audio) and (kMid5Video,kMid6Video), i.e. + // we've moved the first MIDs out of the groups. 
+ cricket::ContentGroup answer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group1.AddContentName(kMid2Audio); + answer_bundle_group1.AddContentName(kMid3Audio); + cricket::ContentGroup answer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group2.AddContentName(kMid5Video); + answer_bundle_group2.AddContentName(kMid6Video); + + auto local_offer = std::make_unique(); + AddAudioSection(local_offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid2Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid3Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid4Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid5Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid6Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + local_offer->AddGroup(offer_bundle_group1); + local_offer->AddGroup(offer_bundle_group2); + + auto remote_answer = std::make_unique(); + AddAudioSection(remote_answer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid2Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid3Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid4Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid5Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid6Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + remote_answer->AddGroup(answer_bundle_group1); + remote_answer->AddGroup(answer_bundle_group2); + + EXPECT_TRUE(transport_controller_ + ->SetLocalDescription(SdpType::kOffer, local_offer.get()) + .ok()); + EXPECT_TRUE(transport_controller_ + ->SetRemoteDescription(SdpType::kAnswer, remote_answer.get()) + .ok()); + + auto mid1_transport = transport_controller_->GetRtpTransport(kMid1Audio); + auto mid2_transport = transport_controller_->GetRtpTransport(kMid2Audio); + auto mid3_transport = transport_controller_->GetRtpTransport(kMid3Audio); + auto mid4_transport = transport_controller_->GetRtpTransport(kMid4Video); + auto mid5_transport = transport_controller_->GetRtpTransport(kMid5Video); + auto mid6_transport = transport_controller_->GetRtpTransport(kMid6Video); + EXPECT_NE(mid1_transport, mid2_transport); + EXPECT_EQ(mid2_transport, mid3_transport); + EXPECT_NE(mid4_transport, mid5_transport); + EXPECT_EQ(mid5_transport, mid6_transport); + EXPECT_NE(mid1_transport, mid4_transport); + EXPECT_NE(mid2_transport, mid5_transport); +} + +TEST_F(JsepTransportControllerTest, MultipleBundleGroupsChangeFirstMid) { + static const char kMid1Audio[] = "1_audio"; + static const char kMid2Audio[] = "2_audio"; + static const char kMid3Audio[] = "3_audio"; + static const char kMid4Video[] = "4_video"; + static 
const char kMid5Video[] = "5_video"; + static const char kMid6Video[] = "6_video"; + + CreateJsepTransportController(JsepTransportController::Config()); + // Offer groups (kMid1Audio,kMid2Audio,kMid3Audio) and + // (kMid4Video,kMid5Video,kMid6Video). + cricket::ContentGroup offer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group1.AddContentName(kMid1Audio); + offer_bundle_group1.AddContentName(kMid2Audio); + offer_bundle_group1.AddContentName(kMid3Audio); + cricket::ContentGroup offer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + offer_bundle_group2.AddContentName(kMid4Video); + offer_bundle_group2.AddContentName(kMid5Video); + offer_bundle_group2.AddContentName(kMid6Video); + // Answer groups (kMid2Audio,kMid1Audio,kMid3Audio) and + // (kMid5Video,kMid6Video,kMid4Video), i.e. we've changed which MID is first + // but accept the whole group. + cricket::ContentGroup answer_bundle_group1(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group1.AddContentName(kMid2Audio); + answer_bundle_group1.AddContentName(kMid1Audio); + answer_bundle_group1.AddContentName(kMid3Audio); + cricket::ContentGroup answer_bundle_group2(cricket::GROUP_TYPE_BUNDLE); + answer_bundle_group2.AddContentName(kMid5Video); + answer_bundle_group2.AddContentName(kMid6Video); + answer_bundle_group2.AddContentName(kMid4Video); + + auto local_offer = std::make_unique(); + AddAudioSection(local_offer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid2Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(local_offer.get(), kMid3Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid4Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid5Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(local_offer.get(), kMid6Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + local_offer->AddGroup(offer_bundle_group1); + local_offer->AddGroup(offer_bundle_group2); + + auto remote_answer = std::make_unique(); + AddAudioSection(remote_answer.get(), kMid1Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid2Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddAudioSection(remote_answer.get(), kMid3Audio, kIceUfrag1, kIcePwd1, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid4Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid5Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + AddVideoSection(remote_answer.get(), kMid6Video, kIceUfrag2, kIcePwd2, + cricket::ICEMODE_FULL, cricket::CONNECTIONROLE_ACTPASS, + nullptr); + remote_answer->AddGroup(answer_bundle_group1); + remote_answer->AddGroup(answer_bundle_group2); + + EXPECT_TRUE(transport_controller_ + ->SetLocalDescription(SdpType::kOffer, local_offer.get()) + .ok()); + + // The fact that we accept this answer is actually a bug. 
If we accept the + // first MID to be in the group, we should also accept that it is the tagged + // one. + // TODO(https://crbug.com/webrtc/12699): When this issue is fixed, change this + // to EXPECT_FALSE and remove the below expectations about transports. + EXPECT_TRUE(transport_controller_ + ->SetRemoteDescription(SdpType::kAnswer, remote_answer.get()) + .ok()); + auto mid1_transport = transport_controller_->GetRtpTransport(kMid1Audio); + auto mid2_transport = transport_controller_->GetRtpTransport(kMid2Audio); + auto mid3_transport = transport_controller_->GetRtpTransport(kMid3Audio); + auto mid4_transport = transport_controller_->GetRtpTransport(kMid4Video); + auto mid5_transport = transport_controller_->GetRtpTransport(kMid5Video); + auto mid6_transport = transport_controller_->GetRtpTransport(kMid6Video); + EXPECT_NE(mid1_transport, mid4_transport); + EXPECT_EQ(mid1_transport, mid2_transport); + EXPECT_EQ(mid2_transport, mid3_transport); + EXPECT_EQ(mid4_transport, mid5_transport); + EXPECT_EQ(mid5_transport, mid6_transport); +} + // Tests that only a subset of all the m= sections are bundled. TEST_F(JsepTransportControllerTest, BundleSubsetOfMediaSections) { CreateJsepTransportController(JsepTransportController::Config()); diff --git a/pc/jsep_transport_unittest.cc b/pc/jsep_transport_unittest.cc index d8f2fff621..5f4334068a 100644 --- a/pc/jsep_transport_unittest.cc +++ b/pc/jsep_transport_unittest.cc @@ -48,8 +48,7 @@ rtc::scoped_refptr CreateIceTransport( return nullptr; } - return new rtc::RefCountedObject( - std::move(internal)); + return rtc::make_ref_counted(std::move(internal)); } class JsepTransport2Test : public ::testing::Test, public sigslot::has_slots<> { @@ -118,8 +117,7 @@ class JsepTransport2Test : public ::testing::Test, public sigslot::has_slots<> { kTransportName, /*local_certificate=*/nullptr, std::move(ice), std::move(rtcp_ice), std::move(unencrypted_rtp_transport), std::move(sdes_transport), std::move(dtls_srtp_transport), - /*datagram_rtp_transport=*/nullptr, std::move(rtp_dtls_transport), - std::move(rtcp_dtls_transport), + std::move(rtp_dtls_transport), std::move(rtcp_dtls_transport), /*sctp_transport=*/nullptr); signal_rtcp_mux_active_received_ = false; diff --git a/pc/local_audio_source.cc b/pc/local_audio_source.cc index 22ab1c39c3..3fcad50a1d 100644 --- a/pc/local_audio_source.cc +++ b/pc/local_audio_source.cc @@ -18,8 +18,7 @@ namespace webrtc { rtc::scoped_refptr LocalAudioSource::Create( const cricket::AudioOptions* audio_options) { - rtc::scoped_refptr source( - new rtc::RefCountedObject()); + auto source = rtc::make_ref_counted(); source->Initialize(audio_options); return source; } diff --git a/pc/media_protocol_names.cc b/pc/media_protocol_names.cc index 3def3f0f20..ae4fcf3391 100644 --- a/pc/media_protocol_names.cc +++ b/pc/media_protocol_names.cc @@ -10,6 +10,9 @@ #include "pc/media_protocol_names.h" +#include +#include + namespace cricket { // There are multiple variants of the RTP protocol stack, including diff --git a/pc/media_session.cc b/pc/media_session.cc index 4fd3efa521..3c73ddf535 100644 --- a/pc/media_session.cc +++ b/pc/media_session.cc @@ -10,8 +10,9 @@ #include "pc/media_session.h" +#include + #include -#include #include #include #include @@ -20,20 +21,24 @@ #include "absl/algorithm/container.h" #include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "api/crypto_params.h" -#include "media/base/h264_profile_level_id.h" +#include "api/video_codecs/h264_profile_level_id.h" 
+#include "media/base/codec.h" #include "media/base/media_constants.h" +#include "media/base/sdp_video_format_utils.h" #include "media/sctp/sctp_transport_internal.h" #include "p2p/base/p2p_constants.h" #include "pc/channel_manager.h" #include "pc/media_protocol_names.h" #include "pc/rtp_media_utils.h" -#include "pc/srtp_filter.h" #include "pc/used_ids.h" #include "rtc_base/checks.h" #include "rtc_base/helpers.h" #include "rtc_base/logging.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/string_encode.h" #include "rtc_base/third_party/base64/base64.h" #include "rtc_base/unique_id_generator.h" #include "system_wrappers/include/field_trial.h" @@ -789,10 +794,16 @@ static void NegotiateCodecs(const std::vector& local_codecs, // FindMatchingCodec shouldn't return something with no apt value. RTC_DCHECK(apt_it != theirs.params.end()); negotiated.SetParam(kCodecParamAssociatedPayloadType, apt_it->second); + + // We support parsing the declarative rtx-time parameter. + const auto rtx_time_it = theirs.params.find(kCodecParamRtxTime); + if (rtx_time_it != theirs.params.end()) { + negotiated.SetParam(kCodecParamRtxTime, rtx_time_it->second); + } } if (absl::EqualsIgnoreCase(ours.name, kH264CodecName)) { - webrtc::H264::GenerateProfileLevelIdForAnswer( - ours.params, theirs.params, &negotiated.params); + webrtc::H264GenerateProfileLevelIdForAnswer(ours.params, theirs.params, + &negotiated.params); } negotiated.id = theirs.id; negotiated.name = theirs.name; @@ -978,68 +989,6 @@ static Codecs MatchCodecPreference( return filtered_codecs; } -static bool FindByUriAndEncryption(const RtpHeaderExtensions& extensions, - const webrtc::RtpExtension& ext_to_match, - webrtc::RtpExtension* found_extension) { - auto it = absl::c_find_if( - extensions, [&ext_to_match](const webrtc::RtpExtension& extension) { - // We assume that all URIs are given in a canonical - // format. - return extension.uri == ext_to_match.uri && - extension.encrypt == ext_to_match.encrypt; - }); - if (it == extensions.end()) { - return false; - } - if (found_extension) { - *found_extension = *it; - } - return true; -} - -static bool FindByUri(const RtpHeaderExtensions& extensions, - const webrtc::RtpExtension& ext_to_match, - webrtc::RtpExtension* found_extension) { - // We assume that all URIs are given in a canonical format. - const webrtc::RtpExtension* found = - webrtc::RtpExtension::FindHeaderExtensionByUri(extensions, - ext_to_match.uri); - if (!found) { - return false; - } - if (found_extension) { - *found_extension = *found; - } - return true; -} - -static bool FindByUriWithEncryptionPreference( - const RtpHeaderExtensions& extensions, - absl::string_view uri_to_match, - bool encryption_preference, - webrtc::RtpExtension* found_extension) { - const webrtc::RtpExtension* unencrypted_extension = nullptr; - for (const webrtc::RtpExtension& extension : extensions) { - // We assume that all URIs are given in a canonical format. - if (extension.uri == uri_to_match) { - if (!encryption_preference || extension.encrypt) { - if (found_extension) { - *found_extension = extension; - } - return true; - } - unencrypted_extension = &extension; - } - } - if (unencrypted_extension) { - if (found_extension) { - *found_extension = *unencrypted_extension; - } - return true; - } - return false; -} - // Adds all extensions from |reference_extensions| to |offered_extensions| that // don't already exist in |offered_extensions| and ensure the IDs don't // collide. 
If an extension is added, it's also added to |regular_extensions| or @@ -1054,22 +1003,28 @@ static void MergeRtpHdrExts(const RtpHeaderExtensions& reference_extensions, RtpHeaderExtensions* encrypted_extensions, UsedRtpHeaderExtensionIds* used_ids) { for (auto reference_extension : reference_extensions) { - if (!FindByUriAndEncryption(*offered_extensions, reference_extension, - nullptr)) { - webrtc::RtpExtension existing; + if (!webrtc::RtpExtension::FindHeaderExtensionByUriAndEncryption( + *offered_extensions, reference_extension.uri, + reference_extension.encrypt)) { if (reference_extension.encrypt) { - if (FindByUriAndEncryption(*encrypted_extensions, reference_extension, - &existing)) { - offered_extensions->push_back(existing); + const webrtc::RtpExtension* existing = + webrtc::RtpExtension::FindHeaderExtensionByUriAndEncryption( + *encrypted_extensions, reference_extension.uri, + reference_extension.encrypt); + if (existing) { + offered_extensions->push_back(*existing); } else { used_ids->FindAndSetIdUsed(&reference_extension); encrypted_extensions->push_back(reference_extension); offered_extensions->push_back(reference_extension); } } else { - if (FindByUriAndEncryption(*regular_extensions, reference_extension, - &existing)) { - offered_extensions->push_back(existing); + const webrtc::RtpExtension* existing = + webrtc::RtpExtension::FindHeaderExtensionByUriAndEncryption( + *regular_extensions, reference_extension.uri, + reference_extension.encrypt); + if (existing) { + offered_extensions->push_back(*existing); } else { used_ids->FindAndSetIdUsed(&reference_extension); regular_extensions->push_back(reference_extension); @@ -1080,41 +1035,86 @@ static void MergeRtpHdrExts(const RtpHeaderExtensions& reference_extensions, } } -static void AddEncryptedVersionsOfHdrExts(RtpHeaderExtensions* extensions, - RtpHeaderExtensions* all_extensions, - UsedRtpHeaderExtensionIds* used_ids) { - RtpHeaderExtensions encrypted_extensions; - for (const webrtc::RtpExtension& extension : *extensions) { - webrtc::RtpExtension existing; - // Don't add encrypted extensions again that were already included in a - // previous offer or regular extensions that are also included as encrypted - // extensions. - if (extension.encrypt || - !webrtc::RtpExtension::IsEncryptionSupported(extension.uri) || - (FindByUriWithEncryptionPreference(*extensions, extension.uri, true, - &existing) && - existing.encrypt)) { +static void AddEncryptedVersionsOfHdrExts( + RtpHeaderExtensions* offered_extensions, + RtpHeaderExtensions* encrypted_extensions, + UsedRtpHeaderExtensionIds* used_ids) { + RtpHeaderExtensions encrypted_extensions_to_add; + for (const auto& extension : *offered_extensions) { + // Skip existing encrypted offered extension + if (extension.encrypt) { continue; } - if (FindByUri(*all_extensions, extension, &existing)) { - encrypted_extensions.push_back(existing); - } else { - webrtc::RtpExtension encrypted(extension); - encrypted.encrypt = true; - used_ids->FindAndSetIdUsed(&encrypted); - all_extensions->push_back(encrypted); - encrypted_extensions.push_back(encrypted); + // Skip if we cannot encrypt the extension + if (!webrtc::RtpExtension::IsEncryptionSupported(extension.uri)) { + continue; + } + + // Skip if an encrypted extension with that URI already exists in the + // offered extensions. 
+ const bool have_encrypted_extension = + webrtc::RtpExtension::FindHeaderExtensionByUriAndEncryption( + *offered_extensions, extension.uri, true); + if (have_encrypted_extension) { + continue; } + + // Determine if a shared encrypted extension with that URI already exists. + const webrtc::RtpExtension* shared_encrypted_extension = + webrtc::RtpExtension::FindHeaderExtensionByUriAndEncryption( + *encrypted_extensions, extension.uri, true); + if (shared_encrypted_extension) { + // Re-use the shared encrypted extension + encrypted_extensions_to_add.push_back(*shared_encrypted_extension); + continue; + } + + // None exists. Create a new shared encrypted extension from the + // non-encrypted one. + webrtc::RtpExtension new_encrypted_extension(extension); + new_encrypted_extension.encrypt = true; + used_ids->FindAndSetIdUsed(&new_encrypted_extension); + encrypted_extensions->push_back(new_encrypted_extension); + encrypted_extensions_to_add.push_back(new_encrypted_extension); } - extensions->insert(extensions->end(), encrypted_extensions.begin(), - encrypted_extensions.end()); + + // Append the additional encrypted extensions to be offered + offered_extensions->insert(offered_extensions->end(), + encrypted_extensions_to_add.begin(), + encrypted_extensions_to_add.end()); +} + +// Mostly identical to RtpExtension::FindHeaderExtensionByUri but discards any +// encrypted extensions that this implementation cannot encrypt. +static const webrtc::RtpExtension* FindHeaderExtensionByUriDiscardUnsupported( + const std::vector& extensions, + absl::string_view uri, + webrtc::RtpExtension::Filter filter) { + // Note: While it's technically possible to decrypt extensions that we don't + // encrypt, the symmetric API of libsrtp does not allow us to supply + // different IDs for encryption/decryption of header extensions depending on + // whether the packet is inbound or outbound. Thereby, we are limited to + // what we can send in encrypted form. + if (!webrtc::RtpExtension::IsEncryptionSupported(uri)) { + // If there's no encryption support and we only want encrypted extensions, + // there's no point in continuing the search here. + if (filter == webrtc::RtpExtension::kRequireEncryptedExtension) { + return nullptr; + } + + // Instruct to only return non-encrypted extensions + filter = webrtc::RtpExtension::Filter::kDiscardEncryptedExtension; + } + + return webrtc::RtpExtension::FindHeaderExtensionByUri(extensions, uri, + filter); } static void NegotiateRtpHeaderExtensions( const RtpHeaderExtensions& local_extensions, const RtpHeaderExtensions& offered_extensions, - bool enable_encrypted_rtp_header_extensions, + webrtc::RtpExtension::Filter filter, RtpHeaderExtensions* negotiated_extensions) { // TransportSequenceNumberV2 is not offered by default. The special logic for // the TransportSequenceNumber extensions works as follows: @@ -1123,9 +1123,9 @@ static void NegotiateRtpHeaderExtensions( // V1 and V2 V2 regardless of local_extensions. // V2 V2 regardless of local_extensions. 
const webrtc::RtpExtension* transport_sequence_number_v2_offer = - webrtc::RtpExtension::FindHeaderExtensionByUri( + FindHeaderExtensionByUriDiscardUnsupported( offered_extensions, - webrtc::RtpExtension::kTransportSequenceNumberV2Uri); + webrtc::RtpExtension::kTransportSequenceNumberV2Uri, filter); bool frame_descriptor_in_local = false; bool dependency_descriptor_in_local = false; @@ -1138,10 +1138,10 @@ static void NegotiateRtpHeaderExtensions( dependency_descriptor_in_local = true; else if (ours.uri == webrtc::RtpExtension::kAbsoluteCaptureTimeUri) abs_capture_time_in_local = true; - webrtc::RtpExtension theirs; - if (FindByUriWithEncryptionPreference( - offered_extensions, ours.uri, - enable_encrypted_rtp_header_extensions, &theirs)) { + const webrtc::RtpExtension* theirs = + FindHeaderExtensionByUriDiscardUnsupported(offered_extensions, ours.uri, + filter); + if (theirs) { if (transport_sequence_number_v2_offer && ours.uri == webrtc::RtpExtension::kTransportSequenceNumberUri) { // Don't respond to @@ -1151,7 +1151,7 @@ static void NegotiateRtpHeaderExtensions( continue; } else { // We respond with their RTP header extension id. - negotiated_extensions->push_back(theirs); + negotiated_extensions->push_back(*theirs); } } } @@ -1163,28 +1163,35 @@ static void NegotiateRtpHeaderExtensions( // Frame descriptors support. If the extension is not present locally, but is // in the offer, we add it to the list. - webrtc::RtpExtension theirs; - if (!dependency_descriptor_in_local && - FindByUriWithEncryptionPreference( - offered_extensions, webrtc::RtpExtension::kDependencyDescriptorUri, - enable_encrypted_rtp_header_extensions, &theirs)) { - negotiated_extensions->push_back(theirs); - } - if (!frame_descriptor_in_local && - FindByUriWithEncryptionPreference( - offered_extensions, - webrtc::RtpExtension::kGenericFrameDescriptorUri00, - enable_encrypted_rtp_header_extensions, &theirs)) { - negotiated_extensions->push_back(theirs); + if (!dependency_descriptor_in_local) { + const webrtc::RtpExtension* theirs = + FindHeaderExtensionByUriDiscardUnsupported( + offered_extensions, webrtc::RtpExtension::kDependencyDescriptorUri, + filter); + if (theirs) { + negotiated_extensions->push_back(*theirs); + } + } + if (!frame_descriptor_in_local) { + const webrtc::RtpExtension* theirs = + FindHeaderExtensionByUriDiscardUnsupported( + offered_extensions, + webrtc::RtpExtension::kGenericFrameDescriptorUri00, filter); + if (theirs) { + negotiated_extensions->push_back(*theirs); + } } // Absolute capture time support. If the extension is not present locally, but // is in the offer, we add it to the list. - if (!abs_capture_time_in_local && - FindByUriWithEncryptionPreference( - offered_extensions, webrtc::RtpExtension::kAbsoluteCaptureTimeUri, - enable_encrypted_rtp_header_extensions, &theirs)) { - negotiated_extensions->push_back(theirs); + if (!abs_capture_time_in_local) { + const webrtc::RtpExtension* theirs = + FindHeaderExtensionByUriDiscardUnsupported( + offered_extensions, webrtc::RtpExtension::kAbsoluteCaptureTimeUri, + filter); + if (theirs) { + negotiated_extensions->push_back(*theirs); + } } } @@ -1239,10 +1246,14 @@ static bool CreateMediaContentAnswer( bool bundle_enabled, MediaContentDescription* answer) { answer->set_extmap_allow_mixed_enum(offer->extmap_allow_mixed_enum()); + const webrtc::RtpExtension::Filter extensions_filter = + enable_encrypted_rtp_header_extensions + ? 
webrtc::RtpExtension::Filter::kPreferEncryptedExtension + : webrtc::RtpExtension::Filter::kDiscardEncryptedExtension; RtpHeaderExtensions negotiated_rtp_extensions; - NegotiateRtpHeaderExtensions( - local_rtp_extensions, offer->rtp_header_extensions(), - enable_encrypted_rtp_header_extensions, &negotiated_rtp_extensions); + NegotiateRtpHeaderExtensions(local_rtp_extensions, + offer->rtp_header_extensions(), + extensions_filter, &negotiated_rtp_extensions); answer->set_rtp_header_extensions(negotiated_rtp_extensions); answer->set_rtcp_mux(session_options.rtcp_mux_enabled && offer->rtcp_mux()); @@ -1370,14 +1381,6 @@ void MediaDescriptionOptions::AddVideoSender( num_sim_layers); } -void MediaDescriptionOptions::AddRtpDataChannel(const std::string& track_id, - const std::string& stream_id) { - RTC_DCHECK(type == MEDIA_TYPE_DATA); - // TODO(steveanton): Is it the case that RtpDataChannel will never have more - // than one stream? - AddSenderInternal(track_id, {stream_id}, {}, SimulcastLayerList(), 1); -} - void MediaDescriptionOptions::AddSenderInternal( const std::string& track_id, const std::vector& stream_ids, @@ -1418,7 +1421,6 @@ MediaSessionDescriptionFactory::MediaSessionDescriptionFactory( channel_manager->GetSupportedAudioReceiveCodecs(&audio_recv_codecs_); channel_manager->GetSupportedVideoSendCodecs(&video_send_codecs_); channel_manager->GetSupportedVideoReceiveCodecs(&video_recv_codecs_); - channel_manager->GetSupportedDataCodecs(&rtp_data_codecs_); ComputeAudioCodecsIntersectionAndUnion(); ComputeVideoCodecsIntersectionAndUnion(); } @@ -1511,16 +1513,8 @@ std::unique_ptr MediaSessionDescriptionFactory::CreateOffer( AudioCodecs offer_audio_codecs; VideoCodecs offer_video_codecs; - RtpDataCodecs offer_rtp_data_codecs; - GetCodecsForOffer( - current_active_contents, &offer_audio_codecs, &offer_video_codecs, - session_options.data_channel_type == DataChannelType::DCT_SCTP - ? nullptr - : &offer_rtp_data_codecs); - if (!session_options.vad_enabled) { - // If application doesn't want CN codecs in offer. - StripCNCodecs(&offer_audio_codecs); - } + GetCodecsForOffer(current_active_contents, &offer_audio_codecs, + &offer_video_codecs); AudioVideoRtpHeaderExtensions extensions_with_ids = GetOfferedRtpHeaderExtensionsWithIds( current_active_contents, session_options.offer_extmap_allow_mixed, @@ -1564,8 +1558,8 @@ std::unique_ptr MediaSessionDescriptionFactory::CreateOffer( case MEDIA_TYPE_DATA: if (!AddDataContentForOffer(media_description_options, session_options, current_content, current_description, - offer_rtp_data_codecs, ¤t_streams, - offer.get(), &ice_credentials)) { + ¤t_streams, offer.get(), + &ice_credentials)) { return nullptr; } break; @@ -1663,23 +1657,26 @@ MediaSessionDescriptionFactory::CreateAnswer( // sections. AudioCodecs answer_audio_codecs; VideoCodecs answer_video_codecs; - RtpDataCodecs answer_rtp_data_codecs; GetCodecsForAnswer(current_active_contents, *offer, &answer_audio_codecs, - &answer_video_codecs, &answer_rtp_data_codecs); - - if (!session_options.vad_enabled) { - // If application doesn't want CN codecs in answer. - StripCNCodecs(&answer_audio_codecs); - } + &answer_video_codecs); auto answer = std::make_unique(); // If the offer supports BUNDLE, and we want to use it too, create a BUNDLE // group in the answer with the appropriate content names. - const ContentGroup* offer_bundle = offer->GetGroupByName(GROUP_TYPE_BUNDLE); - ContentGroup answer_bundle(GROUP_TYPE_BUNDLE); - // Transport info shared by the bundle group. 
- std::unique_ptr bundle_transport; + std::vector offer_bundles = + offer->GetGroupsByName(GROUP_TYPE_BUNDLE); + // There are as many answer BUNDLE groups as offer BUNDLE groups (even if + // rejected, we respond with an empty group). |offer_bundles|, + // |answer_bundles| and |bundle_transports| share the same size and indices. + std::vector answer_bundles; + std::vector> bundle_transports; + answer_bundles.reserve(offer_bundles.size()); + bundle_transports.reserve(offer_bundles.size()); + for (size_t i = 0; i < offer_bundles.size(); ++i) { + answer_bundles.emplace_back(GROUP_TYPE_BUNDLE); + bundle_transports.emplace_back(nullptr); + } answer->set_extmap_allow_mixed(offer->extmap_allow_mixed()); @@ -1694,6 +1691,18 @@ MediaSessionDescriptionFactory::CreateAnswer( RTC_DCHECK( IsMediaContentOfType(offer_content, media_description_options.type)); RTC_DCHECK(media_description_options.mid == offer_content->name); + // Get the index of the BUNDLE group that this MID belongs to, if any. + absl::optional bundle_index; + for (size_t i = 0; i < offer_bundles.size(); ++i) { + if (offer_bundles[i]->HasContentName(media_description_options.mid)) { + bundle_index = i; + break; + } + } + TransportInfo* bundle_transport = + bundle_index.has_value() ? bundle_transports[bundle_index.value()].get() + : nullptr; + const ContentInfo* current_content = nullptr; if (current_description && msection_index < current_description->contents().size()) { @@ -1706,26 +1715,25 @@ MediaSessionDescriptionFactory::CreateAnswer( case MEDIA_TYPE_AUDIO: if (!AddAudioContentForAnswer( media_description_options, session_options, offer_content, - offer, current_content, current_description, - bundle_transport.get(), answer_audio_codecs, header_extensions, - ¤t_streams, answer.get(), &ice_credentials)) { + offer, current_content, current_description, bundle_transport, + answer_audio_codecs, header_extensions, ¤t_streams, + answer.get(), &ice_credentials)) { return nullptr; } break; case MEDIA_TYPE_VIDEO: if (!AddVideoContentForAnswer( media_description_options, session_options, offer_content, - offer, current_content, current_description, - bundle_transport.get(), answer_video_codecs, header_extensions, - ¤t_streams, answer.get(), &ice_credentials)) { + offer, current_content, current_description, bundle_transport, + answer_video_codecs, header_extensions, ¤t_streams, + answer.get(), &ice_credentials)) { return nullptr; } break; case MEDIA_TYPE_DATA: if (!AddDataContentForAnswer( media_description_options, session_options, offer_content, - offer, current_content, current_description, - bundle_transport.get(), answer_rtp_data_codecs, + offer, current_content, current_description, bundle_transport, ¤t_streams, answer.get(), &ice_credentials)) { return nullptr; } @@ -1733,8 +1741,8 @@ MediaSessionDescriptionFactory::CreateAnswer( case MEDIA_TYPE_UNSUPPORTED: if (!AddUnsupportedContentForAnswer( media_description_options, session_options, offer_content, - offer, current_content, current_description, - bundle_transport.get(), answer.get(), &ice_credentials)) { + offer, current_content, current_description, bundle_transport, + answer.get(), &ice_credentials)) { return nullptr; } break; @@ -1745,37 +1753,41 @@ MediaSessionDescriptionFactory::CreateAnswer( // See if we can add the newly generated m= section to the BUNDLE group in // the answer. 
ContentInfo& added = answer->contents().back(); - if (!added.rejected && session_options.bundle_enabled && offer_bundle && - offer_bundle->HasContentName(added.name)) { - answer_bundle.AddContentName(added.name); - bundle_transport.reset( + if (!added.rejected && session_options.bundle_enabled && + bundle_index.has_value()) { + // The |bundle_index| is for |media_description_options.mid|. + RTC_DCHECK_EQ(media_description_options.mid, added.name); + answer_bundles[bundle_index.value()].AddContentName(added.name); + bundle_transports[bundle_index.value()].reset( new TransportInfo(*answer->GetTransportInfoByName(added.name))); } } - // If a BUNDLE group was offered, put a BUNDLE group in the answer even if - // it's empty. RFC5888 says: + // If BUNDLE group(s) were offered, put the same number of BUNDLE groups in + // the answer even if they're empty. RFC5888 says: // // A SIP entity that receives an offer that contains an "a=group" line // with semantics that are understood MUST return an answer that // contains an "a=group" line with the same semantics. - if (offer_bundle) { - answer->AddGroup(answer_bundle); - } - - if (answer_bundle.FirstContentName()) { - // Share the same ICE credentials and crypto params across all contents, - // as BUNDLE requires. - if (!UpdateTransportInfoForBundle(answer_bundle, answer.get())) { - RTC_LOG(LS_ERROR) - << "CreateAnswer failed to UpdateTransportInfoForBundle."; - return NULL; - } + if (!offer_bundles.empty()) { + for (const ContentGroup& answer_bundle : answer_bundles) { + answer->AddGroup(answer_bundle); + + if (answer_bundle.FirstContentName()) { + // Share the same ICE credentials and crypto params across all contents, + // as BUNDLE requires. + if (!UpdateTransportInfoForBundle(answer_bundle, answer.get())) { + RTC_LOG(LS_ERROR) + << "CreateAnswer failed to UpdateTransportInfoForBundle."; + return NULL; + } - if (!UpdateCryptoParamsForBundle(answer_bundle, answer.get())) { - RTC_LOG(LS_ERROR) - << "CreateAnswer failed to UpdateCryptoParamsForBundle."; - return NULL; + if (!UpdateCryptoParamsForBundle(answer_bundle, answer.get())) { + RTC_LOG(LS_ERROR) + << "CreateAnswer failed to UpdateCryptoParamsForBundle."; + return NULL; + } + } } } @@ -1889,7 +1901,6 @@ void MergeCodecsFromDescription( const std::vector& current_active_contents, AudioCodecs* audio_codecs, VideoCodecs* video_codecs, - RtpDataCodecs* rtp_data_codecs, UsedPayloadTypes* used_pltypes) { for (const ContentInfo* content : current_active_contents) { if (IsMediaContentOfType(content, MEDIA_TYPE_AUDIO)) { @@ -1900,14 +1911,6 @@ void MergeCodecsFromDescription( const VideoContentDescription* video = content->media_description()->as_video(); MergeCodecs(video->codecs(), video_codecs, used_pltypes); - } else if (IsMediaContentOfType(content, MEDIA_TYPE_DATA)) { - const RtpDataContentDescription* data = - content->media_description()->as_rtp_data(); - if (data) { - // Only relevant for RTP datachannels - MergeCodecs(data->codecs(), rtp_data_codecs, - used_pltypes); - } } } } @@ -1921,22 +1924,17 @@ void MergeCodecsFromDescription( void MediaSessionDescriptionFactory::GetCodecsForOffer( const std::vector& current_active_contents, AudioCodecs* audio_codecs, - VideoCodecs* video_codecs, - RtpDataCodecs* rtp_data_codecs) const { + VideoCodecs* video_codecs) const { // First - get all codecs from the current description if the media type // is used. Add them to |used_pltypes| so the payload type is not reused if a // new media type is added. 
UsedPayloadTypes used_pltypes; MergeCodecsFromDescription(current_active_contents, audio_codecs, - video_codecs, rtp_data_codecs, &used_pltypes); + video_codecs, &used_pltypes); // Add our codecs that are not in the current description. MergeCodecs(all_audio_codecs_, audio_codecs, &used_pltypes); MergeCodecs(all_video_codecs_, video_codecs, &used_pltypes); - // Only allocate a payload type for rtp datachannels when using rtp data - // channels. - if (rtp_data_codecs) - MergeCodecs(rtp_data_codecs_, rtp_data_codecs, &used_pltypes); } // Getting codecs for an answer involves these steps: @@ -1950,19 +1948,17 @@ void MediaSessionDescriptionFactory::GetCodecsForAnswer( const std::vector& current_active_contents, const SessionDescription& remote_offer, AudioCodecs* audio_codecs, - VideoCodecs* video_codecs, - RtpDataCodecs* rtp_data_codecs) const { + VideoCodecs* video_codecs) const { // First - get all codecs from the current description if the media type // is used. Add them to |used_pltypes| so the payload type is not reused if a // new media type is added. UsedPayloadTypes used_pltypes; MergeCodecsFromDescription(current_active_contents, audio_codecs, - video_codecs, rtp_data_codecs, &used_pltypes); + video_codecs, &used_pltypes); // Second - filter out codecs that we don't support at all and should ignore. AudioCodecs filtered_offered_audio_codecs; VideoCodecs filtered_offered_video_codecs; - RtpDataCodecs filtered_offered_rtp_data_codecs; for (const ContentInfo& content : remote_offer.contents()) { if (IsMediaContentOfType(&content, MEDIA_TYPE_AUDIO)) { const AudioContentDescription* audio = @@ -1988,22 +1984,6 @@ void MediaSessionDescriptionFactory::GetCodecsForAnswer( filtered_offered_video_codecs.push_back(offered_video_codec); } } - } else if (IsMediaContentOfType(&content, MEDIA_TYPE_DATA)) { - const RtpDataContentDescription* data = - content.media_description()->as_rtp_data(); - if (data) { - // RTP data. This part is inactive for SCTP data. - for (const RtpDataCodec& offered_rtp_data_codec : data->codecs()) { - if (!FindMatchingCodec( - data->codecs(), filtered_offered_rtp_data_codecs, - offered_rtp_data_codec, nullptr) && - FindMatchingCodec(data->codecs(), rtp_data_codecs_, - offered_rtp_data_codec, - nullptr)) { - filtered_offered_rtp_data_codecs.push_back(offered_rtp_data_codec); - } - } - } } } @@ -2013,8 +1993,6 @@ void MediaSessionDescriptionFactory::GetCodecsForAnswer( &used_pltypes); MergeCodecs(filtered_offered_video_codecs, video_codecs, &used_pltypes); - MergeCodecs(filtered_offered_rtp_data_codecs, rtp_data_codecs, - &used_pltypes); } MediaSessionDescriptionFactory::AudioVideoRtpHeaderExtensions @@ -2196,6 +2174,10 @@ bool MediaSessionDescriptionFactory::AddAudioContentForOffer( } } } + if (!session_options.vad_enabled) { + // If application doesn't want CN codecs in offer. + StripCNCodecs(&filtered_codecs); + } cricket::SecurePolicy sdes_policy = IsDtlsActive(current_content, current_description) ? 
cricket::SEC_DISABLED @@ -2323,7 +2305,7 @@ bool MediaSessionDescriptionFactory::AddVideoContentForOffer( return true; } -bool MediaSessionDescriptionFactory::AddSctpDataContentForOffer( +bool MediaSessionDescriptionFactory::AddDataContentForOffer( const MediaDescriptionOptions& media_description_options, const MediaSessionOptions& session_options, const ContentInfo* current_content, @@ -2368,73 +2350,6 @@ bool MediaSessionDescriptionFactory::AddSctpDataContentForOffer( return true; } -bool MediaSessionDescriptionFactory::AddRtpDataContentForOffer( - const MediaDescriptionOptions& media_description_options, - const MediaSessionOptions& session_options, - const ContentInfo* current_content, - const SessionDescription* current_description, - const RtpDataCodecs& rtp_data_codecs, - StreamParamsVec* current_streams, - SessionDescription* desc, - IceCredentialsIterator* ice_credentials) const { - auto data = std::make_unique(); - bool secure_transport = (transport_desc_factory_->secure() != SEC_DISABLED); - - cricket::SecurePolicy sdes_policy = - IsDtlsActive(current_content, current_description) ? cricket::SEC_DISABLED - : secure(); - std::vector crypto_suites; - GetSupportedDataSdesCryptoSuiteNames(session_options.crypto_options, - &crypto_suites); - if (!CreateMediaContentOffer(media_description_options, session_options, - rtp_data_codecs, sdes_policy, - GetCryptos(current_content), crypto_suites, - RtpHeaderExtensions(), ssrc_generator_, - current_streams, data.get())) { - return false; - } - - data->set_bandwidth(kRtpDataMaxBandwidth); - SetMediaProtocol(secure_transport, data.get()); - desc->AddContent(media_description_options.mid, MediaProtocolType::kRtp, - media_description_options.stopped, std::move(data)); - if (!AddTransportOffer(media_description_options.mid, - media_description_options.transport_options, - current_description, desc, ice_credentials)) { - return false; - } - return true; -} - -bool MediaSessionDescriptionFactory::AddDataContentForOffer( - const MediaDescriptionOptions& media_description_options, - const MediaSessionOptions& session_options, - const ContentInfo* current_content, - const SessionDescription* current_description, - const RtpDataCodecs& rtp_data_codecs, - StreamParamsVec* current_streams, - SessionDescription* desc, - IceCredentialsIterator* ice_credentials) const { - bool is_sctp = (session_options.data_channel_type == DCT_SCTP); - // If the DataChannel type is not specified, use the DataChannel type in - // the current description. - if (session_options.data_channel_type == DCT_NONE && current_content) { - RTC_CHECK(IsMediaContentOfType(current_content, MEDIA_TYPE_DATA)); - is_sctp = (current_content->media_description()->protocol() == - kMediaProtocolSctp); - } - if (is_sctp) { - return AddSctpDataContentForOffer( - media_description_options, session_options, current_content, - current_description, current_streams, desc, ice_credentials); - } else { - return AddRtpDataContentForOffer(media_description_options, session_options, - current_content, current_description, - rtp_data_codecs, current_streams, desc, - ice_credentials); - } -} - bool MediaSessionDescriptionFactory::AddUnsupportedContentForOffer( const MediaDescriptionOptions& media_description_options, const MediaSessionOptions& session_options, @@ -2538,6 +2453,10 @@ bool MediaSessionDescriptionFactory::AddAudioContentForAnswer( } } } + if (!session_options.vad_enabled) { + // If application doesn't want CN codecs in answer. 
+ StripCNCodecs(&filtered_codecs); + } bool bundle_enabled = offer_description->HasGroup(GROUP_TYPE_BUNDLE) && session_options.bundle_enabled; @@ -2708,7 +2627,6 @@ bool MediaSessionDescriptionFactory::AddDataContentForAnswer( const ContentInfo* current_content, const SessionDescription* current_description, const TransportInfo* bundle_transport, - const RtpDataCodecs& rtp_data_codecs, StreamParamsVec* current_streams, SessionDescription* answer, IceCredentialsIterator* ice_credentials) const { @@ -2756,32 +2674,13 @@ bool MediaSessionDescriptionFactory::AddDataContentForAnswer( bool offer_uses_sctpmap = offer_data_description->use_sctpmap(); data_answer->as_sctp()->set_use_sctpmap(offer_uses_sctpmap); } else { - // RTP offer - data_answer = std::make_unique(); - - const RtpDataContentDescription* offer_data_description = - offer_content->media_description()->as_rtp_data(); - RTC_CHECK(offer_data_description); - if (!SetCodecsInAnswer(offer_data_description, rtp_data_codecs, - media_description_options, session_options, - ssrc_generator_, current_streams, - data_answer->as_rtp_data())) { - return false; - } - if (!CreateMediaContentAnswer( - offer_data_description, media_description_options, session_options, - sdes_policy, GetCryptos(current_content), RtpHeaderExtensions(), - ssrc_generator_, enable_encrypted_rtp_header_extensions_, - current_streams, bundle_enabled, data_answer.get())) { - return false; // Fails the session setup. - } + RTC_NOTREACHED() << "Non-SCTP data content found"; } bool secure = bundle_transport ? bundle_transport->description.secure() : data_transport->secure(); - bool rejected = session_options.data_channel_type == DCT_NONE || - media_description_options.stopped || + bool rejected = media_description_options.stopped || offer_content->rejected || !IsMediaProtocolSupported(MEDIA_TYPE_DATA, data_answer->protocol(), secure); @@ -2790,13 +2689,6 @@ bool MediaSessionDescriptionFactory::AddDataContentForAnswer( return false; } - if (!rejected && session_options.data_channel_type == DCT_RTP) { - data_answer->set_bandwidth(kRtpDataMaxBandwidth); - } else { - // RFC 3264 - // The answer MUST contain the same number of m-lines as the offer. - RTC_LOG(LS_INFO) << "Data is not supported in the answer."; - } answer->AddContent(media_description_options.mid, offer_content->type, rejected, std::move(data_answer)); return true; @@ -2981,12 +2873,6 @@ const VideoContentDescription* GetFirstVideoContentDescription( return desc ? desc->as_video() : nullptr; } -const RtpDataContentDescription* GetFirstRtpDataContentDescription( - const SessionDescription* sdesc) { - auto desc = GetFirstMediaContentDescription(sdesc, MEDIA_TYPE_DATA); - return desc ? desc->as_rtp_data() : nullptr; -} - const SctpDataContentDescription* GetFirstSctpDataContentDescription( const SessionDescription* sdesc) { auto desc = GetFirstMediaContentDescription(sdesc, MEDIA_TYPE_DATA); @@ -3059,12 +2945,6 @@ VideoContentDescription* GetFirstVideoContentDescription( return desc ? desc->as_video() : nullptr; } -RtpDataContentDescription* GetFirstRtpDataContentDescription( - SessionDescription* sdesc) { - auto desc = GetFirstMediaContentDescription(sdesc, MEDIA_TYPE_DATA); - return desc ? 
desc->as_rtp_data() : nullptr; -} - SctpDataContentDescription* GetFirstSctpDataContentDescription( SessionDescription* sdesc) { auto desc = GetFirstMediaContentDescription(sdesc, MEDIA_TYPE_DATA); diff --git a/pc/media_session.h b/pc/media_session.h index 58a31a2ab2..d4c8025bc0 100644 --- a/pc/media_session.h +++ b/pc/media_session.h @@ -18,14 +18,21 @@ #include #include +#include "api/crypto/crypto_options.h" #include "api/media_types.h" +#include "api/rtp_parameters.h" +#include "api/rtp_transceiver_direction.h" #include "media/base/media_constants.h" -#include "media/base/media_engine.h" // For DataChannelType +#include "media/base/rid_description.h" +#include "media/base/stream_params.h" #include "p2p/base/ice_credentials_iterator.h" +#include "p2p/base/transport_description.h" #include "p2p/base/transport_description_factory.h" +#include "p2p/base/transport_info.h" #include "pc/jsep_transport.h" #include "pc/media_protocol_names.h" #include "pc/session_description.h" +#include "pc/simulcast_description.h" #include "rtc_base/unique_id_generator.h" namespace cricket { @@ -65,10 +72,6 @@ struct MediaDescriptionOptions { const SimulcastLayerList& simulcast_layers, int num_sim_layers); - // Internally just uses sender_options. - void AddRtpDataChannel(const std::string& track_id, - const std::string& stream_id); - MediaType type; std::string mid; webrtc::RtpTransceiverDirection direction; @@ -102,7 +105,6 @@ struct MediaSessionOptions { bool HasMediaDescription(MediaType type) const; - DataChannelType data_channel_type = DCT_NONE; bool vad_enabled = true; // When disabled, removes all CN codecs from SDP. bool rtcp_mux_enabled = true; bool bundle_enabled = false; @@ -154,10 +156,6 @@ class MediaSessionDescriptionFactory { const VideoCodecs& recv_codecs); RtpHeaderExtensions filtered_rtp_header_extensions( RtpHeaderExtensions extensions) const; - const RtpDataCodecs& rtp_data_codecs() const { return rtp_data_codecs_; } - void set_rtp_data_codecs(const RtpDataCodecs& codecs) { - rtp_data_codecs_ = codecs; - } SecurePolicy secure() const { return secure_; } void set_secure(SecurePolicy s) { secure_ = s; } @@ -196,14 +194,12 @@ class MediaSessionDescriptionFactory { void GetCodecsForOffer( const std::vector& current_active_contents, AudioCodecs* audio_codecs, - VideoCodecs* video_codecs, - RtpDataCodecs* rtp_data_codecs) const; + VideoCodecs* video_codecs) const; void GetCodecsForAnswer( const std::vector& current_active_contents, const SessionDescription& remote_offer, AudioCodecs* audio_codecs, - VideoCodecs* video_codecs, - RtpDataCodecs* rtp_data_codecs) const; + VideoCodecs* video_codecs) const; AudioVideoRtpHeaderExtensions GetOfferedRtpHeaderExtensionsWithIds( const std::vector& current_active_contents, bool extmap_allow_mixed, @@ -253,32 +249,11 @@ class MediaSessionDescriptionFactory { SessionDescription* desc, IceCredentialsIterator* ice_credentials) const; - bool AddSctpDataContentForOffer( - const MediaDescriptionOptions& media_description_options, - const MediaSessionOptions& session_options, - const ContentInfo* current_content, - const SessionDescription* current_description, - StreamParamsVec* current_streams, - SessionDescription* desc, - IceCredentialsIterator* ice_credentials) const; - bool AddRtpDataContentForOffer( - const MediaDescriptionOptions& media_description_options, - const MediaSessionOptions& session_options, - const ContentInfo* current_content, - const SessionDescription* current_description, - const RtpDataCodecs& rtp_data_codecs, - StreamParamsVec* 
current_streams, - SessionDescription* desc, - IceCredentialsIterator* ice_credentials) const; - // This function calls either AddRtpDataContentForOffer or - // AddSctpDataContentForOffer depending on protocol. - // The codecs argument is ignored for SCTP. bool AddDataContentForOffer( const MediaDescriptionOptions& media_description_options, const MediaSessionOptions& session_options, const ContentInfo* current_content, const SessionDescription* current_description, - const RtpDataCodecs& rtp_data_codecs, StreamParamsVec* current_streams, SessionDescription* desc, IceCredentialsIterator* ice_credentials) const; @@ -327,7 +302,6 @@ class MediaSessionDescriptionFactory { const ContentInfo* current_content, const SessionDescription* current_description, const TransportInfo* bundle_transport, - const RtpDataCodecs& rtp_data_codecs, StreamParamsVec* current_streams, SessionDescription* answer, IceCredentialsIterator* ice_credentials) const; @@ -360,7 +334,6 @@ class MediaSessionDescriptionFactory { VideoCodecs video_sendrecv_codecs_; // Union of send and recv. VideoCodecs all_video_codecs_; - RtpDataCodecs rtp_data_codecs_; // This object is not owned by the channel so it must outlive it. rtc::UniqueRandomIdGenerator* const ssrc_generator_; bool enable_encrypted_rtp_header_extensions_ = false; @@ -390,8 +363,6 @@ const AudioContentDescription* GetFirstAudioContentDescription( const SessionDescription* sdesc); const VideoContentDescription* GetFirstVideoContentDescription( const SessionDescription* sdesc); -const RtpDataContentDescription* GetFirstRtpDataContentDescription( - const SessionDescription* sdesc); const SctpDataContentDescription* GetFirstSctpDataContentDescription( const SessionDescription* sdesc); // Non-const versions of the above functions. 
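Illustrative sketch only, not part of the patch (it assumes a configured cricket::MediaSessionDescriptionFactory named factory): with data_channel_type and the rtp_data_codecs accessors removed from this header, a data m= section is requested purely through its MediaDescriptionOptions entry and the generated content is always SCTP-based.

// Sketch: requesting a data section after this change; there is no DCT_* flag
// and no RTP data codec list left to configure.
cricket::MediaSessionOptions opts;
opts.media_description_options.push_back(cricket::MediaDescriptionOptions(
    cricket::MEDIA_TYPE_DATA, "data",
    webrtc::RtpTransceiverDirection::kSendRecv, /*stopped=*/false));
std::unique_ptr<cricket::SessionDescription> offer =
    factory.CreateOffer(opts, /*current_description=*/nullptr);
// The resulting "data" content carries an SCTP description; a non-SCTP data
// description reaching AddDataContentForAnswer now trips RTC_NOTREACHED().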
@@ -409,8 +380,6 @@ AudioContentDescription* GetFirstAudioContentDescription( SessionDescription* sdesc); VideoContentDescription* GetFirstVideoContentDescription( SessionDescription* sdesc); -RtpDataContentDescription* GetFirstRtpDataContentDescription( - SessionDescription* sdesc); SctpDataContentDescription* GetFirstSctpDataContentDescription( SessionDescription* sdesc); diff --git a/pc/media_session_unittest.cc b/pc/media_session_unittest.cc index d8cb1591a9..c7c07fc527 100644 --- a/pc/media_session_unittest.cc +++ b/pc/media_session_unittest.cc @@ -50,7 +50,6 @@ using cricket::CryptoParamsVec; using cricket::GetFirstAudioContent; using cricket::GetFirstAudioContentDescription; using cricket::GetFirstDataContent; -using cricket::GetFirstRtpDataContentDescription; using cricket::GetFirstVideoContent; using cricket::GetFirstVideoContentDescription; using cricket::kAutoBandwidth; @@ -65,8 +64,6 @@ using cricket::MediaSessionOptions; using cricket::MediaType; using cricket::RidDescription; using cricket::RidDirection; -using cricket::RtpDataCodec; -using cricket::RtpDataContentDescription; using cricket::SctpDataContentDescription; using cricket::SEC_DISABLED; using cricket::SEC_ENABLED; @@ -133,15 +130,6 @@ static const VideoCodec kVideoCodecs2[] = {VideoCodec(126, "H264"), static const VideoCodec kVideoCodecsAnswer[] = {VideoCodec(97, "H264")}; -static const RtpDataCodec kDataCodecs1[] = {RtpDataCodec(98, "binary-data"), - RtpDataCodec(99, "utf8-text")}; - -static const RtpDataCodec kDataCodecs2[] = {RtpDataCodec(126, "binary-data"), - RtpDataCodec(127, "utf8-text")}; - -static const RtpDataCodec kDataCodecsAnswer[] = { - RtpDataCodec(98, "binary-data"), RtpDataCodec(99, "utf8-text")}; - static const RtpExtension kAudioRtpExtension1[] = { RtpExtension("urn:ietf:params:rtp-hdrext:ssrc-audio-level", 8), RtpExtension("http://google.com/testing/audio_something", 10), @@ -151,6 +139,7 @@ static const RtpExtension kAudioRtpExtensionEncrypted1[] = { RtpExtension("urn:ietf:params:rtp-hdrext:ssrc-audio-level", 8), RtpExtension("http://google.com/testing/audio_something", 10), RtpExtension("urn:ietf:params:rtp-hdrext:ssrc-audio-level", 12, true), + RtpExtension("http://google.com/testing/audio_something", 11, true), }; static const RtpExtension kAudioRtpExtension2[] = { @@ -173,7 +162,15 @@ static const RtpExtension kAudioRtpExtension3ForEncryption[] = { static const RtpExtension kAudioRtpExtension3ForEncryptionOffer[] = { RtpExtension("http://google.com/testing/audio_something", 2), RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 3), - RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 14, true), + RtpExtension("http://google.com/testing/audio_something", 14, true), + RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 13, true), +}; + +static const RtpExtension kVideoRtpExtension3ForEncryptionOffer[] = { + RtpExtension("http://google.com/testing/video_something", 4), + RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 3), + RtpExtension("http://google.com/testing/video_something", 12, true), + RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 13, true), }; static const RtpExtension kAudioRtpExtensionAnswer[] = { @@ -192,7 +189,8 @@ static const RtpExtension kVideoRtpExtension1[] = { static const RtpExtension kVideoRtpExtensionEncrypted1[] = { RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 14), RtpExtension("http://google.com/testing/video_something", 13), - RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 11, true), + RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 9, true), + 
RtpExtension("http://google.com/testing/video_something", 7, true), }; static const RtpExtension kVideoRtpExtension2[] = { @@ -217,7 +215,7 @@ static const RtpExtension kVideoRtpExtensionAnswer[] = { }; static const RtpExtension kVideoRtpExtensionEncryptedAnswer[] = { - RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 11, true), + RtpExtension("urn:ietf:params:rtp-hdrext:toffset", 9, true), }; static const RtpExtension kRtpExtensionTransportSequenceNumber01[] = { @@ -260,9 +258,6 @@ static const char kVideoTrack2[] = "video_2"; static const char kAudioTrack1[] = "audio_1"; static const char kAudioTrack2[] = "audio_2"; static const char kAudioTrack3[] = "audio_3"; -static const char kDataTrack1[] = "data_1"; -static const char kDataTrack2[] = "data_2"; -static const char kDataTrack3[] = "data_3"; static const char* kMediaProtocols[] = {"RTP/AVP", "RTP/SAVP", "RTP/AVPF", "RTP/SAVPF"}; @@ -344,10 +339,8 @@ static void AddAudioVideoSections(RtpTransceiverDirection direction, opts); } -static void AddDataSection(cricket::DataChannelType dct, - RtpTransceiverDirection direction, +static void AddDataSection(RtpTransceiverDirection direction, MediaSessionOptions* opts) { - opts->data_channel_type = dct; AddMediaDescriptionOptions(MEDIA_TYPE_DATA, "data", direction, kActive, opts); } @@ -369,10 +362,6 @@ static void AttachSenderToMediaDescriptionOptions( it->AddVideoSender(track_id, stream_ids, rids, simulcast_layers, num_sim_layer); break; - case MEDIA_TYPE_DATA: - RTC_CHECK(stream_ids.size() == 1U); - it->AddRtpDataChannel(track_id, stream_ids[0]); - break; default: RTC_NOTREACHED(); } @@ -437,12 +426,10 @@ class MediaSessionDescriptionFactoryTest : public ::testing::Test { MAKE_VECTOR(kAudioCodecs1)); f1_.set_video_codecs(MAKE_VECTOR(kVideoCodecs1), MAKE_VECTOR(kVideoCodecs1)); - f1_.set_rtp_data_codecs(MAKE_VECTOR(kDataCodecs1)); f2_.set_audio_codecs(MAKE_VECTOR(kAudioCodecs2), MAKE_VECTOR(kAudioCodecs2)); f2_.set_video_codecs(MAKE_VECTOR(kVideoCodecs2), MAKE_VECTOR(kVideoCodecs2)); - f2_.set_rtp_data_codecs(MAKE_VECTOR(kDataCodecs2)); tdf1_.set_certificate(rtc::RTCCertificate::Create( std::unique_ptr(new rtc::FakeSSLIdentity("id1")))); tdf2_.set_certificate(rtc::RTCCertificate::Create( @@ -604,8 +591,6 @@ class MediaSessionDescriptionFactoryTest : public ::testing::Test { f1_.set_secure(SEC_ENABLED); MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); std::unique_ptr ref_desc; std::unique_ptr desc; if (offer) { @@ -862,30 +847,21 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoOffer) { TEST_F(MediaSessionDescriptionFactoryTest, TestBundleOfferWithSameCodecPlType) { const VideoCodec& offered_video_codec = f2_.video_sendrecv_codecs()[0]; const AudioCodec& offered_audio_codec = f2_.audio_sendrecv_codecs()[0]; - const RtpDataCodec& offered_data_codec = f2_.rtp_data_codecs()[0]; ASSERT_EQ(offered_video_codec.id, offered_audio_codec.id); - ASSERT_EQ(offered_video_codec.id, offered_data_codec.id); MediaSessionOptions opts; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); opts.bundle_enabled = true; std::unique_ptr offer = f2_.CreateOffer(opts, NULL); const VideoContentDescription* vcd = GetFirstVideoContentDescription(offer.get()); const AudioContentDescription* acd = GetFirstAudioContentDescription(offer.get()); - const RtpDataContentDescription* dcd = - 
GetFirstRtpDataContentDescription(offer.get()); ASSERT_TRUE(NULL != vcd); ASSERT_TRUE(NULL != acd); - ASSERT_TRUE(NULL != dcd); EXPECT_NE(vcd->codecs()[0].id, acd->codecs()[0].id); - EXPECT_NE(vcd->codecs()[0].id, dcd->codecs()[0].id); - EXPECT_NE(acd->codecs()[0].id, dcd->codecs()[0].id); EXPECT_EQ(vcd->codecs()[0].name, offered_video_codec.name); EXPECT_EQ(acd->codecs()[0].name, offered_audio_codec.name); - EXPECT_EQ(dcd->codecs()[0].name, offered_data_codec.name); } // Test creating an updated offer with bundle, audio, video and data @@ -901,7 +877,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, AddMediaDescriptionOptions(MEDIA_TYPE_VIDEO, "video", RtpTransceiverDirection::kInactive, kStopped, &opts); - opts.data_channel_type = cricket::DCT_NONE; opts.bundle_enabled = true; std::unique_ptr offer = f1_.CreateOffer(opts, NULL); std::unique_ptr answer = @@ -909,8 +884,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, MediaSessionOptions updated_opts; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &updated_opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &updated_opts); updated_opts.bundle_enabled = true; std::unique_ptr updated_offer( f1_.CreateOffer(updated_opts, answer.get())); @@ -919,58 +892,20 @@ TEST_F(MediaSessionDescriptionFactoryTest, GetFirstAudioContentDescription(updated_offer.get()); const VideoContentDescription* vcd = GetFirstVideoContentDescription(updated_offer.get()); - const RtpDataContentDescription* dcd = - GetFirstRtpDataContentDescription(updated_offer.get()); EXPECT_TRUE(NULL != vcd); EXPECT_TRUE(NULL != acd); - EXPECT_TRUE(NULL != dcd); ASSERT_CRYPTO(acd, 1U, kDefaultSrtpCryptoSuite); EXPECT_EQ(cricket::kMediaProtocolSavpf, acd->protocol()); ASSERT_CRYPTO(vcd, 1U, kDefaultSrtpCryptoSuite); EXPECT_EQ(cricket::kMediaProtocolSavpf, vcd->protocol()); - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_EQ(cricket::kMediaProtocolSavpf, dcd->protocol()); -} - -// Create a RTP data offer, and ensure it matches what we expect. -TEST_F(MediaSessionDescriptionFactoryTest, TestCreateRtpDataOffer) { - MediaSessionOptions opts; - AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); - f1_.set_secure(SEC_ENABLED); - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - ASSERT_TRUE(offer.get() != NULL); - const ContentInfo* ac = offer->GetContentByName("audio"); - const ContentInfo* dc = offer->GetContentByName("data"); - ASSERT_TRUE(ac != NULL); - ASSERT_TRUE(dc != NULL); - EXPECT_EQ(MediaProtocolType::kRtp, ac->type); - EXPECT_EQ(MediaProtocolType::kRtp, dc->type); - const AudioContentDescription* acd = ac->media_description()->as_audio(); - const RtpDataContentDescription* dcd = dc->media_description()->as_rtp_data(); - EXPECT_EQ(MEDIA_TYPE_AUDIO, acd->type()); - EXPECT_EQ(f1_.audio_sendrecv_codecs(), acd->codecs()); - EXPECT_EQ(0U, acd->first_ssrc()); // no sender is attched. - EXPECT_EQ(kAutoBandwidth, acd->bandwidth()); // default bandwidth (auto) - EXPECT_TRUE(acd->rtcp_mux()); // rtcp-mux defaults on - ASSERT_CRYPTO(acd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_EQ(cricket::kMediaProtocolSavpf, acd->protocol()); - EXPECT_EQ(MEDIA_TYPE_DATA, dcd->type()); - EXPECT_EQ(f1_.rtp_data_codecs(), dcd->codecs()); - EXPECT_EQ(0U, dcd->first_ssrc()); // no sender is attached. 
- EXPECT_EQ(cricket::kRtpDataMaxBandwidth, - dcd->bandwidth()); // default bandwidth (auto) - EXPECT_TRUE(dcd->rtcp_mux()); // rtcp-mux defaults on - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_EQ(cricket::kMediaProtocolSavpf, dcd->protocol()); } // Create an SCTP data offer with bundle without error. TEST_F(MediaSessionDescriptionFactoryTest, TestCreateSctpDataOffer) { MediaSessionOptions opts; opts.bundle_enabled = true; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); f1_.set_secure(SEC_ENABLED); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); EXPECT_TRUE(offer.get() != NULL); @@ -985,7 +920,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateSctpDataOffer) { TEST_F(MediaSessionDescriptionFactoryTest, TestCreateSecureSctpDataOffer) { MediaSessionOptions opts; opts.bundle_enabled = true; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); f1_.set_secure(SEC_ENABLED); tdf1_.set_secure(SEC_ENABLED); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); @@ -1001,7 +936,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateSecureSctpDataOffer) { TEST_F(MediaSessionDescriptionFactoryTest, TestCreateImplicitSctpDataOffer) { MediaSessionOptions opts; opts.bundle_enabled = true; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); f1_.set_secure(SEC_ENABLED); std::unique_ptr offer1(f1_.CreateOffer(opts, NULL)); ASSERT_TRUE(offer1.get() != NULL); @@ -1009,10 +944,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateImplicitSctpDataOffer) { ASSERT_TRUE(data != NULL); ASSERT_EQ(cricket::kMediaProtocolSctp, data->media_description()->protocol()); - // Now set data_channel_type to 'none' (default) and make sure that the - // datachannel type that gets generated from the previous offer, is of the - // same type. - opts.data_channel_type = cricket::DCT_NONE; std::unique_ptr offer2( f1_.CreateOffer(opts, offer1.get())); data = offer2->GetContentByName("data"); @@ -1115,6 +1046,66 @@ TEST_F(MediaSessionDescriptionFactoryTest, ReAnswerChangedBundleOffererTagged) { EXPECT_TRUE(bundle_group->HasContentName("video")); } +TEST_F(MediaSessionDescriptionFactoryTest, + CreateAnswerForOfferWithMultipleBundleGroups) { + // Create an offer with 4 m= sections, initially without BUNDLE groups. + MediaSessionOptions opts; + opts.bundle_enabled = false; + AddMediaDescriptionOptions(MEDIA_TYPE_AUDIO, "1", + RtpTransceiverDirection::kSendRecv, kActive, + &opts); + AddMediaDescriptionOptions(MEDIA_TYPE_AUDIO, "2", + RtpTransceiverDirection::kSendRecv, kActive, + &opts); + AddMediaDescriptionOptions(MEDIA_TYPE_AUDIO, "3", + RtpTransceiverDirection::kSendRecv, kActive, + &opts); + AddMediaDescriptionOptions(MEDIA_TYPE_AUDIO, "4", + RtpTransceiverDirection::kSendRecv, kActive, + &opts); + std::unique_ptr offer = f1_.CreateOffer(opts, nullptr); + ASSERT_TRUE(offer->groups().empty()); + + // Munge the offer to have two groups. Offers like these cannot be generated + // without munging, but it is valid to receive such offers from remote + // endpoints. 
+ cricket::ContentGroup bundle_group1(cricket::GROUP_TYPE_BUNDLE); + bundle_group1.AddContentName("1"); + bundle_group1.AddContentName("2"); + cricket::ContentGroup bundle_group2(cricket::GROUP_TYPE_BUNDLE); + bundle_group2.AddContentName("3"); + bundle_group2.AddContentName("4"); + offer->AddGroup(bundle_group1); + offer->AddGroup(bundle_group2); + + // If BUNDLE is enabled, the answer to this offer should accept both BUNDLE + // groups. + opts.bundle_enabled = true; + std::unique_ptr answer = + f2_.CreateAnswer(offer.get(), opts, nullptr); + + std::vector answer_groups = + answer->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + ASSERT_EQ(answer_groups.size(), 2u); + EXPECT_EQ(answer_groups[0]->content_names().size(), 2u); + EXPECT_TRUE(answer_groups[0]->HasContentName("1")); + EXPECT_TRUE(answer_groups[0]->HasContentName("2")); + EXPECT_EQ(answer_groups[1]->content_names().size(), 2u); + EXPECT_TRUE(answer_groups[1]->HasContentName("3")); + EXPECT_TRUE(answer_groups[1]->HasContentName("4")); + + // If BUNDLE is disabled, the answer to this offer should reject both BUNDLE + // groups. + opts.bundle_enabled = false; + answer = f2_.CreateAnswer(offer.get(), opts, nullptr); + + answer_groups = answer->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + // Rejected groups are still listed, but they are empty. + ASSERT_EQ(answer_groups.size(), 2u); + EXPECT_TRUE(answer_groups[0]->content_names().empty()); + EXPECT_TRUE(answer_groups[1]->content_names().empty()); +} + // Test that if the BUNDLE offerer-tagged media section is changed in a reoffer // and there is still a non-rejected media section that was in the initial // offer, then the ICE credentials do not change in the reoffer offerer-tagged @@ -1216,7 +1207,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateSendOnlyOffer) { // SessionDescription is preserved in the new SessionDescription. 
TEST_F(MediaSessionDescriptionFactoryTest, TestCreateOfferContentOrder) { MediaSessionOptions opts; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer1(f1_.CreateOffer(opts, NULL)); ASSERT_TRUE(offer1.get() != NULL); @@ -1350,79 +1341,11 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoAnswerGcmAnswer) { TestVideoGcmCipher(false, true); } -TEST_F(MediaSessionDescriptionFactoryTest, TestCreateDataAnswer) { - MediaSessionOptions opts = CreatePlanBMediaSessionOptions(); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); - f1_.set_secure(SEC_ENABLED); - f2_.set_secure(SEC_ENABLED); - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr answer = - f2_.CreateAnswer(offer.get(), opts, NULL); - const ContentInfo* ac = answer->GetContentByName("audio"); - const ContentInfo* dc = answer->GetContentByName("data"); - ASSERT_TRUE(ac != NULL); - ASSERT_TRUE(dc != NULL); - EXPECT_EQ(MediaProtocolType::kRtp, ac->type); - EXPECT_EQ(MediaProtocolType::kRtp, dc->type); - const AudioContentDescription* acd = ac->media_description()->as_audio(); - const RtpDataContentDescription* dcd = dc->media_description()->as_rtp_data(); - EXPECT_EQ(MEDIA_TYPE_AUDIO, acd->type()); - EXPECT_THAT(acd->codecs(), ElementsAreArray(kAudioCodecsAnswer)); - EXPECT_EQ(kAutoBandwidth, acd->bandwidth()); // negotiated auto bw - EXPECT_EQ(0U, acd->first_ssrc()); // no sender is attached - EXPECT_TRUE(acd->rtcp_mux()); // negotiated rtcp-mux - ASSERT_CRYPTO(acd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_EQ(MEDIA_TYPE_DATA, dcd->type()); - EXPECT_THAT(dcd->codecs(), ElementsAreArray(kDataCodecsAnswer)); - EXPECT_EQ(0U, dcd->first_ssrc()); // no sender is attached - EXPECT_TRUE(dcd->rtcp_mux()); // negotiated rtcp-mux - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_EQ(cricket::kMediaProtocolSavpf, dcd->protocol()); -} - -TEST_F(MediaSessionDescriptionFactoryTest, TestCreateDataAnswerGcm) { - MediaSessionOptions opts = CreatePlanBMediaSessionOptions(); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); - opts.crypto_options.srtp.enable_gcm_crypto_suites = true; - f1_.set_secure(SEC_ENABLED); - f2_.set_secure(SEC_ENABLED); - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - ASSERT_TRUE(offer.get() != NULL); - for (cricket::ContentInfo& content : offer->contents()) { - auto cryptos = content.media_description()->cryptos(); - PreferGcmCryptoParameters(&cryptos); - content.media_description()->set_cryptos(cryptos); - } - std::unique_ptr answer = - f2_.CreateAnswer(offer.get(), opts, NULL); - const ContentInfo* ac = answer->GetContentByName("audio"); - const ContentInfo* dc = answer->GetContentByName("data"); - ASSERT_TRUE(ac != NULL); - ASSERT_TRUE(dc != NULL); - EXPECT_EQ(MediaProtocolType::kRtp, ac->type); - EXPECT_EQ(MediaProtocolType::kRtp, dc->type); - const AudioContentDescription* acd = ac->media_description()->as_audio(); - const RtpDataContentDescription* dcd = dc->media_description()->as_rtp_data(); - EXPECT_EQ(MEDIA_TYPE_AUDIO, acd->type()); - EXPECT_THAT(acd->codecs(), ElementsAreArray(kAudioCodecsAnswer)); - EXPECT_EQ(kAutoBandwidth, acd->bandwidth()); // negotiated auto bw - EXPECT_EQ(0U, acd->first_ssrc()); // no sender is attached - EXPECT_TRUE(acd->rtcp_mux()); // negotiated rtcp-mux - ASSERT_CRYPTO(acd, 1U, kDefaultSrtpCryptoSuiteGcm); - EXPECT_EQ(MEDIA_TYPE_DATA, dcd->type()); - 
EXPECT_THAT(dcd->codecs(), ElementsAreArray(kDataCodecsAnswer)); - EXPECT_EQ(0U, dcd->first_ssrc()); // no sender is attached - EXPECT_TRUE(dcd->rtcp_mux()); // negotiated rtcp-mux - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuiteGcm); - EXPECT_EQ(cricket::kMediaProtocolSavpf, dcd->protocol()); -} - // The use_sctpmap flag should be set in an Sctp DataContentDescription by // default. The answer's use_sctpmap flag should match the offer's. TEST_F(MediaSessionDescriptionFactoryTest, TestCreateDataAnswerUsesSctpmap) { MediaSessionOptions opts; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); ASSERT_TRUE(offer.get() != NULL); ContentInfo* dc_offer = offer->GetContentByName("data"); @@ -1443,7 +1366,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateDataAnswerUsesSctpmap) { // The answer's use_sctpmap flag should match the offer's. TEST_F(MediaSessionDescriptionFactoryTest, TestCreateDataAnswerWithoutSctpmap) { MediaSessionOptions opts; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); ASSERT_TRUE(offer.get() != NULL); ContentInfo* dc_offer = offer->GetContentByName("data"); @@ -1473,7 +1396,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, tdf2_.set_secure(SEC_ENABLED); MediaSessionOptions opts; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, nullptr); ASSERT_TRUE(offer.get() != nullptr); ContentInfo* dc_offer = offer->GetContentByName("data"); @@ -1507,7 +1430,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, tdf2_.set_secure(SEC_ENABLED); MediaSessionOptions opts; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, nullptr); ASSERT_TRUE(offer.get() != nullptr); ContentInfo* dc_offer = offer->GetContentByName("data"); @@ -1536,7 +1459,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, tdf2_.set_secure(SEC_ENABLED); MediaSessionOptions opts; - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, nullptr); ASSERT_TRUE(offer.get() != nullptr); ContentInfo* dc_offer = offer->GetContentByName("data"); @@ -1561,7 +1484,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateAnswerContentOrder) { MediaSessionOptions opts; // Creates a data only offer. - AddDataSection(cricket::DCT_SCTP, RtpTransceiverDirection::kSendRecv, &opts); + AddDataSection(RtpTransceiverDirection::kSendRecv, &opts); std::unique_ptr offer1(f1_.CreateOffer(opts, NULL)); ASSERT_TRUE(offer1.get() != NULL); @@ -1621,35 +1544,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, CreateAnswerToInactiveOffer) { RtpTransceiverDirection::kInactive); } -// Test that a data content with an unknown protocol is rejected in an answer. 
-TEST_F(MediaSessionDescriptionFactoryTest, - CreateDataAnswerToOfferWithUnknownProtocol) { - MediaSessionOptions opts; - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); - f1_.set_secure(SEC_ENABLED); - f2_.set_secure(SEC_ENABLED); - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - ContentInfo* dc_offer = offer->GetContentByName("data"); - ASSERT_TRUE(dc_offer != NULL); - RtpDataContentDescription* dcd_offer = - dc_offer->media_description()->as_rtp_data(); - ASSERT_TRUE(dcd_offer != NULL); - // Offer must be acceptable as an RTP protocol in order to be set. - std::string protocol = "RTP/a weird unknown protocol"; - dcd_offer->set_protocol(protocol); - - std::unique_ptr answer = - f2_.CreateAnswer(offer.get(), opts, NULL); - - const ContentInfo* dc_answer = answer->GetContentByName("data"); - ASSERT_TRUE(dc_answer != NULL); - EXPECT_TRUE(dc_answer->rejected); - const RtpDataContentDescription* dcd_answer = - dc_answer->media_description()->as_rtp_data(); - ASSERT_TRUE(dcd_answer != NULL); - EXPECT_EQ(protocol, dcd_answer->protocol()); -} - // Test that the media protocol is RTP/AVPF if DTLS and SDES are disabled. TEST_F(MediaSessionDescriptionFactoryTest, AudioOfferAnswerWithCryptoDisabled) { MediaSessionOptions opts = CreatePlanBMediaSessionOptions(); @@ -2169,36 +2063,28 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateAnswerWithoutLegacyStreams) { MediaSessionOptions opts; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); ASSERT_TRUE(offer.get() != NULL); std::unique_ptr answer = f2_.CreateAnswer(offer.get(), opts, NULL); const ContentInfo* ac = answer->GetContentByName("audio"); const ContentInfo* vc = answer->GetContentByName("video"); - const ContentInfo* dc = answer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); const AudioContentDescription* acd = ac->media_description()->as_audio(); const VideoContentDescription* vcd = vc->media_description()->as_video(); - const RtpDataContentDescription* dcd = dc->media_description()->as_rtp_data(); EXPECT_FALSE(acd->has_ssrcs()); // No StreamParams. EXPECT_FALSE(vcd->has_ssrcs()); // No StreamParams. - EXPECT_FALSE(dcd->has_ssrcs()); // No StreamParams. } // Create a typical video answer, and ensure it matches what we expect. 
TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoAnswerRtcpMux) { MediaSessionOptions offer_opts; AddAudioVideoSections(RtpTransceiverDirection::kSendRecv, &offer_opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kSendRecv, - &offer_opts); MediaSessionOptions answer_opts; AddAudioVideoSections(RtpTransceiverDirection::kSendRecv, &answer_opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kSendRecv, - &answer_opts); std::unique_ptr offer; std::unique_ptr answer; @@ -2209,16 +2095,12 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoAnswerRtcpMux) { answer = f2_.CreateAnswer(offer.get(), answer_opts, NULL); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(offer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(answer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(answer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(answer.get())); EXPECT_TRUE(GetFirstAudioContentDescription(offer.get())->rtcp_mux()); EXPECT_TRUE(GetFirstVideoContentDescription(offer.get())->rtcp_mux()); - EXPECT_TRUE(GetFirstRtpDataContentDescription(offer.get())->rtcp_mux()); EXPECT_TRUE(GetFirstAudioContentDescription(answer.get())->rtcp_mux()); EXPECT_TRUE(GetFirstVideoContentDescription(answer.get())->rtcp_mux()); - EXPECT_TRUE(GetFirstRtpDataContentDescription(answer.get())->rtcp_mux()); offer_opts.rtcp_mux_enabled = true; answer_opts.rtcp_mux_enabled = false; @@ -2226,16 +2108,12 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoAnswerRtcpMux) { answer = f2_.CreateAnswer(offer.get(), answer_opts, NULL); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(offer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(answer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(answer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(answer.get())); EXPECT_TRUE(GetFirstAudioContentDescription(offer.get())->rtcp_mux()); EXPECT_TRUE(GetFirstVideoContentDescription(offer.get())->rtcp_mux()); - EXPECT_TRUE(GetFirstRtpDataContentDescription(offer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstAudioContentDescription(answer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstVideoContentDescription(answer.get())->rtcp_mux()); - EXPECT_FALSE(GetFirstRtpDataContentDescription(answer.get())->rtcp_mux()); offer_opts.rtcp_mux_enabled = false; answer_opts.rtcp_mux_enabled = true; @@ -2243,16 +2121,12 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoAnswerRtcpMux) { answer = f2_.CreateAnswer(offer.get(), answer_opts, NULL); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(offer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(answer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(answer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(answer.get())); EXPECT_FALSE(GetFirstAudioContentDescription(offer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstVideoContentDescription(offer.get())->rtcp_mux()); - EXPECT_FALSE(GetFirstRtpDataContentDescription(offer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstAudioContentDescription(answer.get())->rtcp_mux()); 
EXPECT_FALSE(GetFirstVideoContentDescription(answer.get())->rtcp_mux()); - EXPECT_FALSE(GetFirstRtpDataContentDescription(answer.get())->rtcp_mux()); offer_opts.rtcp_mux_enabled = false; answer_opts.rtcp_mux_enabled = false; @@ -2260,16 +2134,12 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateVideoAnswerRtcpMux) { answer = f2_.CreateAnswer(offer.get(), answer_opts, NULL); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(offer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(offer.get())); ASSERT_TRUE(NULL != GetFirstAudioContentDescription(answer.get())); ASSERT_TRUE(NULL != GetFirstVideoContentDescription(answer.get())); - ASSERT_TRUE(NULL != GetFirstRtpDataContentDescription(answer.get())); EXPECT_FALSE(GetFirstAudioContentDescription(offer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstVideoContentDescription(offer.get())->rtcp_mux()); - EXPECT_FALSE(GetFirstRtpDataContentDescription(offer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstAudioContentDescription(answer.get())->rtcp_mux()); EXPECT_FALSE(GetFirstVideoContentDescription(answer.get())->rtcp_mux()); - EXPECT_FALSE(GetFirstRtpDataContentDescription(answer.get())->rtcp_mux()); } // Create an audio-only answer to a video offer. @@ -2295,122 +2165,141 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateAudioAnswerToVideo) { EXPECT_TRUE(vc->rejected); } -// Create an audio-only answer to an offer with data. -TEST_F(MediaSessionDescriptionFactoryTest, TestCreateNoDataAnswerToDataOffer) { - MediaSessionOptions opts = CreatePlanBMediaSessionOptions(); - opts.data_channel_type = cricket::DCT_RTP; - AddMediaDescriptionOptions(MEDIA_TYPE_DATA, "data", - RtpTransceiverDirection::kRecvOnly, kActive, - &opts); - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - ASSERT_TRUE(offer.get() != NULL); - - opts.media_description_options[1].stopped = true; - std::unique_ptr answer = - f2_.CreateAnswer(offer.get(), opts, NULL); - const ContentInfo* ac = answer->GetContentByName("audio"); - const ContentInfo* dc = answer->GetContentByName("data"); - ASSERT_TRUE(ac != NULL); - ASSERT_TRUE(dc != NULL); - ASSERT_TRUE(dc->media_description() != NULL); - EXPECT_TRUE(dc->rejected); -} - // Create an answer that rejects the contents which are rejected in the offer. 
TEST_F(MediaSessionDescriptionFactoryTest, CreateAnswerToOfferWithRejectedMedia) { MediaSessionOptions opts; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, &opts); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); ASSERT_TRUE(offer.get() != NULL); ContentInfo* ac = offer->GetContentByName("audio"); ContentInfo* vc = offer->GetContentByName("video"); - ContentInfo* dc = offer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); - ASSERT_TRUE(dc != NULL); ac->rejected = true; vc->rejected = true; - dc->rejected = true; std::unique_ptr answer = f2_.CreateAnswer(offer.get(), opts, NULL); ac = answer->GetContentByName("audio"); vc = answer->GetContentByName("video"); - dc = answer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); - ASSERT_TRUE(dc != NULL); EXPECT_TRUE(ac->rejected); EXPECT_TRUE(vc->rejected); - EXPECT_TRUE(dc->rejected); } TEST_F(MediaSessionDescriptionFactoryTest, - CreateAnswerSupportsMixedOneAndTwoByteHeaderExtensions) { + OfferAndAnswerDoesNotHaveMixedByteSessionAttribute) { MediaSessionOptions opts; - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - // Offer without request of mixed one- and two-byte header extensions. + std::unique_ptr offer = + f1_.CreateOffer(opts, /*current_description=*/nullptr); offer->set_extmap_allow_mixed(false); - ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr answer_no_support( - f2_.CreateAnswer(offer.get(), opts, NULL)); - EXPECT_FALSE(answer_no_support->extmap_allow_mixed()); - // Offer with request of mixed one- and two-byte header extensions. + std::unique_ptr answer( + f2_.CreateAnswer(offer.get(), opts, /*current_description=*/nullptr)); + + EXPECT_FALSE(answer->extmap_allow_mixed()); +} + +TEST_F(MediaSessionDescriptionFactoryTest, + OfferAndAnswerHaveMixedByteSessionAttribute) { + MediaSessionOptions opts; + std::unique_ptr offer = + f1_.CreateOffer(opts, /*current_description=*/nullptr); offer->set_extmap_allow_mixed(true); - ASSERT_TRUE(offer.get() != NULL); + std::unique_ptr answer_support( - f2_.CreateAnswer(offer.get(), opts, NULL)); + f2_.CreateAnswer(offer.get(), opts, /*current_description=*/nullptr)); + EXPECT_TRUE(answer_support->extmap_allow_mixed()); } TEST_F(MediaSessionDescriptionFactoryTest, - CreateAnswerSupportsMixedOneAndTwoByteHeaderExtensionsOnMediaLevel) { + OfferAndAnswerDoesNotHaveMixedByteMediaAttributes) { MediaSessionOptions opts; AddAudioVideoSections(RtpTransceiverDirection::kSendRecv, &opts); - std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - MediaContentDescription* video_offer = - offer->GetContentDescriptionByName("video"); - ASSERT_TRUE(video_offer); + std::unique_ptr offer = + f1_.CreateOffer(opts, /*current_description=*/nullptr); + offer->set_extmap_allow_mixed(false); MediaContentDescription* audio_offer = offer->GetContentDescriptionByName("audio"); - ASSERT_TRUE(audio_offer); + MediaContentDescription* video_offer = + offer->GetContentDescriptionByName("video"); + ASSERT_EQ(MediaContentDescription::kNo, + audio_offer->extmap_allow_mixed_enum()); + ASSERT_EQ(MediaContentDescription::kNo, + video_offer->extmap_allow_mixed_enum()); - // Explicit disable of mixed one-two byte header support in offer. 
- video_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kNo); - audio_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kNo); + std::unique_ptr answer( + f2_.CreateAnswer(offer.get(), opts, /*current_description=*/nullptr)); - ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr answer_no_support( - f2_.CreateAnswer(offer.get(), opts, NULL)); - MediaContentDescription* video_answer = - answer_no_support->GetContentDescriptionByName("video"); MediaContentDescription* audio_answer = - answer_no_support->GetContentDescriptionByName("audio"); - EXPECT_EQ(MediaContentDescription::kNo, - video_answer->extmap_allow_mixed_enum()); + answer->GetContentDescriptionByName("audio"); + MediaContentDescription* video_answer = + answer->GetContentDescriptionByName("video"); EXPECT_EQ(MediaContentDescription::kNo, audio_answer->extmap_allow_mixed_enum()); + EXPECT_EQ(MediaContentDescription::kNo, + video_answer->extmap_allow_mixed_enum()); +} - // Enable mixed one-two byte header support in offer. - video_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kMedia); +TEST_F(MediaSessionDescriptionFactoryTest, + OfferAndAnswerHaveSameMixedByteMediaAttributes) { + MediaSessionOptions opts; + AddAudioVideoSections(RtpTransceiverDirection::kSendRecv, &opts); + std::unique_ptr offer = + f1_.CreateOffer(opts, /*current_description=*/nullptr); + offer->set_extmap_allow_mixed(false); + MediaContentDescription* audio_offer = + offer->GetContentDescriptionByName("audio"); audio_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kMedia); - ASSERT_TRUE(offer.get() != NULL); - std::unique_ptr answer_support( - f2_.CreateAnswer(offer.get(), opts, NULL)); - video_answer = answer_support->GetContentDescriptionByName("video"); - audio_answer = answer_support->GetContentDescriptionByName("audio"); + MediaContentDescription* video_offer = + offer->GetContentDescriptionByName("video"); + video_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kMedia); + + std::unique_ptr answer( + f2_.CreateAnswer(offer.get(), opts, /*current_description=*/nullptr)); + + MediaContentDescription* audio_answer = + answer->GetContentDescriptionByName("audio"); + MediaContentDescription* video_answer = + answer->GetContentDescriptionByName("video"); EXPECT_EQ(MediaContentDescription::kMedia, - video_answer->extmap_allow_mixed_enum()); + audio_answer->extmap_allow_mixed_enum()); EXPECT_EQ(MediaContentDescription::kMedia, + video_answer->extmap_allow_mixed_enum()); +} + +TEST_F(MediaSessionDescriptionFactoryTest, + OfferAndAnswerHaveDifferentMixedByteMediaAttributes) { + MediaSessionOptions opts; + AddAudioVideoSections(RtpTransceiverDirection::kSendRecv, &opts); + std::unique_ptr offer = + f1_.CreateOffer(opts, /*current_description=*/nullptr); + offer->set_extmap_allow_mixed(false); + MediaContentDescription* audio_offer = + offer->GetContentDescriptionByName("audio"); + audio_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kNo); + MediaContentDescription* video_offer = + offer->GetContentDescriptionByName("video"); + video_offer->set_extmap_allow_mixed_enum(MediaContentDescription::kMedia); + + std::unique_ptr answer( + f2_.CreateAnswer(offer.get(), opts, /*current_description=*/nullptr)); + + MediaContentDescription* audio_answer = + answer->GetContentDescriptionByName("audio"); + MediaContentDescription* video_answer = + answer->GetContentDescriptionByName("video"); + EXPECT_EQ(MediaContentDescription::kNo, audio_answer->extmap_allow_mixed_enum()); + 
EXPECT_EQ(MediaContentDescription::kMedia, + video_answer->extmap_allow_mixed_enum()); } // Create an audio and video offer with: // - one video track // - two audio tracks -// - two data tracks // and ensure it matches what we expect. Also updates the initial offer by // adding a new video track and replaces one of the audio tracks. TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoOffer) { @@ -2423,25 +2312,16 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoOffer) { AttachSenderToMediaDescriptionOptions("audio", MEDIA_TYPE_AUDIO, kAudioTrack2, {kMediaStream1}, 1, &opts); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kSendRecv, &opts); - AttachSenderToMediaDescriptionOptions("data", MEDIA_TYPE_DATA, kDataTrack1, - {kMediaStream1}, 1, &opts); - AttachSenderToMediaDescriptionOptions("data", MEDIA_TYPE_DATA, kDataTrack2, - {kMediaStream1}, 1, &opts); - f1_.set_secure(SEC_ENABLED); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); ASSERT_TRUE(offer.get() != NULL); const ContentInfo* ac = offer->GetContentByName("audio"); const ContentInfo* vc = offer->GetContentByName("video"); - const ContentInfo* dc = offer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); - ASSERT_TRUE(dc != NULL); const AudioContentDescription* acd = ac->media_description()->as_audio(); const VideoContentDescription* vcd = vc->media_description()->as_video(); - const RtpDataContentDescription* dcd = dc->media_description()->as_rtp_data(); EXPECT_EQ(MEDIA_TYPE_AUDIO, acd->type()); EXPECT_EQ(f1_.audio_sendrecv_codecs(), acd->codecs()); @@ -2470,25 +2350,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoOffer) { EXPECT_EQ(kAutoBandwidth, vcd->bandwidth()); // default bandwidth (auto) EXPECT_TRUE(vcd->rtcp_mux()); // rtcp-mux defaults on - EXPECT_EQ(MEDIA_TYPE_DATA, dcd->type()); - EXPECT_EQ(f1_.rtp_data_codecs(), dcd->codecs()); - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuite); - - const StreamParamsVec& data_streams = dcd->streams(); - ASSERT_EQ(2U, data_streams.size()); - EXPECT_EQ(data_streams[0].cname, data_streams[1].cname); - EXPECT_EQ(kDataTrack1, data_streams[0].id); - ASSERT_EQ(1U, data_streams[0].ssrcs.size()); - EXPECT_NE(0U, data_streams[0].ssrcs[0]); - EXPECT_EQ(kDataTrack2, data_streams[1].id); - ASSERT_EQ(1U, data_streams[1].ssrcs.size()); - EXPECT_NE(0U, data_streams[1].ssrcs[0]); - - EXPECT_EQ(cricket::kRtpDataMaxBandwidth, - dcd->bandwidth()); // default bandwidth (auto) - EXPECT_TRUE(dcd->rtcp_mux()); // rtcp-mux defaults on - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuite); - // Update the offer. Add a new video track that is not synched to the // other tracks and replace audio track 2 with audio track 3. 
AttachSenderToMediaDescriptionOptions("video", MEDIA_TYPE_VIDEO, kVideoTrack2, @@ -2496,38 +2357,27 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoOffer) { DetachSenderFromMediaSection("audio", kAudioTrack2, &opts); AttachSenderToMediaDescriptionOptions("audio", MEDIA_TYPE_AUDIO, kAudioTrack3, {kMediaStream1}, 1, &opts); - DetachSenderFromMediaSection("data", kDataTrack2, &opts); - AttachSenderToMediaDescriptionOptions("data", MEDIA_TYPE_DATA, kDataTrack3, - {kMediaStream1}, 1, &opts); std::unique_ptr updated_offer( f1_.CreateOffer(opts, offer.get())); ASSERT_TRUE(updated_offer.get() != NULL); ac = updated_offer->GetContentByName("audio"); vc = updated_offer->GetContentByName("video"); - dc = updated_offer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); - ASSERT_TRUE(dc != NULL); const AudioContentDescription* updated_acd = ac->media_description()->as_audio(); const VideoContentDescription* updated_vcd = vc->media_description()->as_video(); - const RtpDataContentDescription* updated_dcd = - dc->media_description()->as_rtp_data(); EXPECT_EQ(acd->type(), updated_acd->type()); EXPECT_EQ(acd->codecs(), updated_acd->codecs()); EXPECT_EQ(vcd->type(), updated_vcd->type()); EXPECT_EQ(vcd->codecs(), updated_vcd->codecs()); - EXPECT_EQ(dcd->type(), updated_dcd->type()); - EXPECT_EQ(dcd->codecs(), updated_dcd->codecs()); ASSERT_CRYPTO(updated_acd, 1U, kDefaultSrtpCryptoSuite); EXPECT_TRUE(CompareCryptoParams(acd->cryptos(), updated_acd->cryptos())); ASSERT_CRYPTO(updated_vcd, 1U, kDefaultSrtpCryptoSuite); EXPECT_TRUE(CompareCryptoParams(vcd->cryptos(), updated_vcd->cryptos())); - ASSERT_CRYPTO(updated_dcd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_TRUE(CompareCryptoParams(dcd->cryptos(), updated_dcd->cryptos())); const StreamParamsVec& updated_audio_streams = updated_acd->streams(); ASSERT_EQ(2U, updated_audio_streams.size()); @@ -2543,18 +2393,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoOffer) { EXPECT_EQ(kVideoTrack2, updated_video_streams[1].id); // All the media streams in one PeerConnection share one RTCP CNAME. EXPECT_EQ(updated_video_streams[1].cname, updated_video_streams[0].cname); - - const StreamParamsVec& updated_data_streams = updated_dcd->streams(); - ASSERT_EQ(2U, updated_data_streams.size()); - EXPECT_EQ(data_streams[0], updated_data_streams[0]); - EXPECT_EQ(kDataTrack3, updated_data_streams[1].id); // New data track. - ASSERT_EQ(1U, updated_data_streams[1].ssrcs.size()); - EXPECT_NE(0U, updated_data_streams[1].ssrcs[0]); - EXPECT_EQ(updated_data_streams[0].cname, updated_data_streams[1].cname); - // The stream correctly got the CNAME from the MediaSessionOptions. - // The Expected RTCP CNAME is the default one as we are using the default - // MediaSessionOptions. - EXPECT_EQ(updated_data_streams[0].cname, cricket::kDefaultRtcpCname); } // Create an offer with simulcast video stream. 
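A hedged sketch of the stream setup that remains once the data branch is gone (helper and constant names are the ones already used in this test file, e.g. AttachSenderToMediaDescriptionOptions and kMediaStream1): senders, and therefore StreamParams and SSRCs, are only ever attached to audio and video sections, and a MEDIA_TYPE_DATA sender now falls through to RTC_NOTREACHED().

// Sketch, mirroring the updated multi-stream tests above.
MediaSessionOptions opts;
AddAudioVideoSections(RtpTransceiverDirection::kSendRecv, &opts);
AttachSenderToMediaDescriptionOptions("audio", MEDIA_TYPE_AUDIO, kAudioTrack1,
                                      {kMediaStream1}, 1, &opts);
AttachSenderToMediaDescriptionOptions("video", MEDIA_TYPE_VIDEO, kVideoTrack1,
                                      {kMediaStream1}, 1, &opts);
// No "data" variant remains; SSRCs and the shared RTCP CNAME are generated
// only for the audio and video StreamParams above.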
@@ -2757,10 +2595,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoAnswer) { AddMediaDescriptionOptions(MEDIA_TYPE_VIDEO, "video", RtpTransceiverDirection::kRecvOnly, kActive, &offer_opts); - offer_opts.data_channel_type = cricket::DCT_RTP; - AddMediaDescriptionOptions(MEDIA_TYPE_DATA, "data", - RtpTransceiverDirection::kRecvOnly, kActive, - &offer_opts); f1_.set_secure(SEC_ENABLED); f2_.set_secure(SEC_ENABLED); std::unique_ptr offer = f1_.CreateOffer(offer_opts, NULL); @@ -2779,31 +2613,18 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoAnswer) { AttachSenderToMediaDescriptionOptions("audio", MEDIA_TYPE_AUDIO, kAudioTrack2, {kMediaStream1}, 1, &answer_opts); - AddMediaDescriptionOptions(MEDIA_TYPE_DATA, "data", - RtpTransceiverDirection::kSendRecv, kActive, - &answer_opts); - AttachSenderToMediaDescriptionOptions("data", MEDIA_TYPE_DATA, kDataTrack1, - {kMediaStream1}, 1, &answer_opts); - AttachSenderToMediaDescriptionOptions("data", MEDIA_TYPE_DATA, kDataTrack2, - {kMediaStream1}, 1, &answer_opts); - answer_opts.data_channel_type = cricket::DCT_RTP; - std::unique_ptr answer = f2_.CreateAnswer(offer.get(), answer_opts, NULL); ASSERT_TRUE(answer.get() != NULL); const ContentInfo* ac = answer->GetContentByName("audio"); const ContentInfo* vc = answer->GetContentByName("video"); - const ContentInfo* dc = answer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); - ASSERT_TRUE(dc != NULL); const AudioContentDescription* acd = ac->media_description()->as_audio(); const VideoContentDescription* vcd = vc->media_description()->as_video(); - const RtpDataContentDescription* dcd = dc->media_description()->as_rtp_data(); ASSERT_CRYPTO(acd, 1U, kDefaultSrtpCryptoSuite); ASSERT_CRYPTO(vcd, 1U, kDefaultSrtpCryptoSuite); - ASSERT_CRYPTO(dcd, 1U, kDefaultSrtpCryptoSuite); EXPECT_EQ(MEDIA_TYPE_AUDIO, acd->type()); EXPECT_THAT(acd->codecs(), ElementsAreArray(kAudioCodecsAnswer)); @@ -2831,59 +2652,33 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoAnswer) { EXPECT_EQ(kAutoBandwidth, vcd->bandwidth()); // default bandwidth (auto) EXPECT_TRUE(vcd->rtcp_mux()); // rtcp-mux defaults on - EXPECT_EQ(MEDIA_TYPE_DATA, dcd->type()); - EXPECT_THAT(dcd->codecs(), ElementsAreArray(kDataCodecsAnswer)); - - const StreamParamsVec& data_streams = dcd->streams(); - ASSERT_EQ(2U, data_streams.size()); - EXPECT_TRUE(data_streams[0].cname == data_streams[1].cname); - EXPECT_EQ(kDataTrack1, data_streams[0].id); - ASSERT_EQ(1U, data_streams[0].ssrcs.size()); - EXPECT_NE(0U, data_streams[0].ssrcs[0]); - EXPECT_EQ(kDataTrack2, data_streams[1].id); - ASSERT_EQ(1U, data_streams[1].ssrcs.size()); - EXPECT_NE(0U, data_streams[1].ssrcs[0]); - - EXPECT_EQ(cricket::kRtpDataMaxBandwidth, - dcd->bandwidth()); // default bandwidth (auto) - EXPECT_TRUE(dcd->rtcp_mux()); // rtcp-mux defaults on - // Update the answer. Add a new video track that is not synched to the // other tracks and remove 1 audio track. 
AttachSenderToMediaDescriptionOptions("video", MEDIA_TYPE_VIDEO, kVideoTrack2, {kMediaStream2}, 1, &answer_opts); DetachSenderFromMediaSection("audio", kAudioTrack2, &answer_opts); - DetachSenderFromMediaSection("data", kDataTrack2, &answer_opts); std::unique_ptr updated_answer( f2_.CreateAnswer(offer.get(), answer_opts, answer.get())); ASSERT_TRUE(updated_answer.get() != NULL); ac = updated_answer->GetContentByName("audio"); vc = updated_answer->GetContentByName("video"); - dc = updated_answer->GetContentByName("data"); ASSERT_TRUE(ac != NULL); ASSERT_TRUE(vc != NULL); - ASSERT_TRUE(dc != NULL); const AudioContentDescription* updated_acd = ac->media_description()->as_audio(); const VideoContentDescription* updated_vcd = vc->media_description()->as_video(); - const RtpDataContentDescription* updated_dcd = - dc->media_description()->as_rtp_data(); ASSERT_CRYPTO(updated_acd, 1U, kDefaultSrtpCryptoSuite); EXPECT_TRUE(CompareCryptoParams(acd->cryptos(), updated_acd->cryptos())); ASSERT_CRYPTO(updated_vcd, 1U, kDefaultSrtpCryptoSuite); EXPECT_TRUE(CompareCryptoParams(vcd->cryptos(), updated_vcd->cryptos())); - ASSERT_CRYPTO(updated_dcd, 1U, kDefaultSrtpCryptoSuite); - EXPECT_TRUE(CompareCryptoParams(dcd->cryptos(), updated_dcd->cryptos())); EXPECT_EQ(acd->type(), updated_acd->type()); EXPECT_EQ(acd->codecs(), updated_acd->codecs()); EXPECT_EQ(vcd->type(), updated_vcd->type()); EXPECT_EQ(vcd->codecs(), updated_vcd->codecs()); - EXPECT_EQ(dcd->type(), updated_dcd->type()); - EXPECT_EQ(dcd->codecs(), updated_dcd->codecs()); const StreamParamsVec& updated_audio_streams = updated_acd->streams(); ASSERT_EQ(1U, updated_audio_streams.size()); @@ -2895,10 +2690,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCreateMultiStreamVideoAnswer) { EXPECT_EQ(kVideoTrack2, updated_video_streams[1].id); // All media streams in one PeerConnection share one CNAME. EXPECT_EQ(updated_video_streams[1].cname, updated_video_streams[0].cname); - - const StreamParamsVec& updated_data_streams = updated_dcd->streams(); - ASSERT_EQ(1U, updated_data_streams.size()); - EXPECT_TRUE(data_streams[0] == updated_data_streams[0]); } // Create an updated offer after creating an answer to the original offer and @@ -3650,19 +3441,11 @@ TEST_F(MediaSessionDescriptionFactoryTest, RtpExtensionIdReusedEncrypted) { MAKE_VECTOR(kVideoRtpExtension3ForEncryption), &opts); std::unique_ptr offer = f1_.CreateOffer(opts, NULL); - // The extensions that are shared between audio and video should use the same - // id. 
- const RtpExtension kExpectedVideoRtpExtension[] = { - kVideoRtpExtension3ForEncryption[0], - kAudioRtpExtension3ForEncryptionOffer[1], - kAudioRtpExtension3ForEncryptionOffer[2], - }; - EXPECT_EQ( MAKE_VECTOR(kAudioRtpExtension3ForEncryptionOffer), GetFirstAudioContentDescription(offer.get())->rtp_header_extensions()); EXPECT_EQ( - MAKE_VECTOR(kExpectedVideoRtpExtension), + MAKE_VECTOR(kVideoRtpExtension3ForEncryptionOffer), GetFirstVideoContentDescription(offer.get())->rtp_header_extensions()); // Nothing should change when creating a new offer @@ -3672,7 +3455,7 @@ TEST_F(MediaSessionDescriptionFactoryTest, RtpExtensionIdReusedEncrypted) { EXPECT_EQ(MAKE_VECTOR(kAudioRtpExtension3ForEncryptionOffer), GetFirstAudioContentDescription(updated_offer.get()) ->rtp_header_extensions()); - EXPECT_EQ(MAKE_VECTOR(kExpectedVideoRtpExtension), + EXPECT_EQ(MAKE_VECTOR(kVideoRtpExtension3ForEncryptionOffer), GetFirstVideoContentDescription(updated_offer.get()) ->rtp_header_extensions()); } @@ -3742,8 +3525,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoOfferAudioCurrent) { TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoOfferMultimedia) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); TestTransportInfo(true, options, false); } @@ -3751,16 +3532,12 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoOfferMultimediaCurrent) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); TestTransportInfo(true, options, true); } TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoOfferBundle) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); options.bundle_enabled = true; TestTransportInfo(true, options, false); } @@ -3769,8 +3546,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoOfferBundleCurrent) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); options.bundle_enabled = true; TestTransportInfo(true, options, true); } @@ -3806,8 +3581,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoAnswerMultimedia) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); TestTransportInfo(false, options, false); } @@ -3815,16 +3588,12 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoAnswerMultimediaCurrent) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); TestTransportInfo(false, options, true); } TEST_F(MediaSessionDescriptionFactoryTest, TestTransportInfoAnswerBundle) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); options.bundle_enabled = true; TestTransportInfo(false, options, false); } @@ -3833,8 +3602,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, 
TestTransportInfoAnswerBundleCurrent) { MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); options.bundle_enabled = true; TestTransportInfo(false, options, true); } @@ -4024,8 +3791,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCryptoOfferDtlsButNotSdes) { tdf2_.set_secure(SEC_ENABLED); MediaSessionOptions options; AddAudioVideoSections(RtpTransceiverDirection::kRecvOnly, &options); - AddDataSection(cricket::DCT_RTP, RtpTransceiverDirection::kRecvOnly, - &options); // Generate an offer with DTLS but without SDES. std::unique_ptr offer = f1_.CreateOffer(options, NULL); @@ -4037,9 +3802,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCryptoOfferDtlsButNotSdes) { const VideoContentDescription* video_offer = GetFirstVideoContentDescription(offer.get()); ASSERT_TRUE(video_offer->cryptos().empty()); - const RtpDataContentDescription* data_offer = - GetFirstRtpDataContentDescription(offer.get()); - ASSERT_TRUE(data_offer->cryptos().empty()); const cricket::TransportDescription* audio_offer_trans_desc = offer->GetTransportDescriptionByName("audio"); @@ -4047,9 +3809,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCryptoOfferDtlsButNotSdes) { const cricket::TransportDescription* video_offer_trans_desc = offer->GetTransportDescriptionByName("video"); ASSERT_TRUE(video_offer_trans_desc->identity_fingerprint.get() != NULL); - const cricket::TransportDescription* data_offer_trans_desc = - offer->GetTransportDescriptionByName("data"); - ASSERT_TRUE(data_offer_trans_desc->identity_fingerprint.get() != NULL); // Generate an answer with DTLS. std::unique_ptr answer = @@ -4062,9 +3821,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestCryptoOfferDtlsButNotSdes) { const cricket::TransportDescription* video_answer_trans_desc = answer->GetTransportDescriptionByName("video"); EXPECT_TRUE(video_answer_trans_desc->identity_fingerprint.get() != NULL); - const cricket::TransportDescription* data_answer_trans_desc = - answer->GetTransportDescriptionByName("data"); - EXPECT_TRUE(data_answer_trans_desc->identity_fingerprint.get() != NULL); } // Verifies if vad_enabled option is set to false, CN codecs are not present in @@ -4098,7 +3854,6 @@ TEST_F(MediaSessionDescriptionFactoryTest, TestMIDsMatchesExistingOffer) { AddMediaDescriptionOptions(MEDIA_TYPE_VIDEO, "video_modified", RtpTransceiverDirection::kRecvOnly, kActive, &opts); - opts.data_channel_type = cricket::DCT_SCTP; AddMediaDescriptionOptions(MEDIA_TYPE_DATA, "data_modified", RtpTransceiverDirection::kSendRecv, kActive, &opts); @@ -4570,12 +4325,10 @@ class MediaProtocolTest : public ::testing::TestWithParam { MAKE_VECTOR(kAudioCodecs1)); f1_.set_video_codecs(MAKE_VECTOR(kVideoCodecs1), MAKE_VECTOR(kVideoCodecs1)); - f1_.set_rtp_data_codecs(MAKE_VECTOR(kDataCodecs1)); f2_.set_audio_codecs(MAKE_VECTOR(kAudioCodecs2), MAKE_VECTOR(kAudioCodecs2)); f2_.set_video_codecs(MAKE_VECTOR(kVideoCodecs2), MAKE_VECTOR(kVideoCodecs2)); - f2_.set_rtp_data_codecs(MAKE_VECTOR(kDataCodecs2)); f1_.set_secure(SEC_ENABLED); f2_.set_secure(SEC_ENABLED); tdf1_.set_certificate(rtc::RTCCertificate::Create( diff --git a/pc/media_stream.cc b/pc/media_stream.cc index 00f491b3cb..08a2a723d0 100644 --- a/pc/media_stream.cc +++ b/pc/media_stream.cc @@ -31,9 +31,7 @@ static typename V::iterator FindTrack(V* vector, const std::string& track_id) { } rtc::scoped_refptr MediaStream::Create(const std::string& id) { - rtc::RefCountedObject* stream = - new 
rtc::RefCountedObject(id); - return stream; + return rtc::make_ref_counted(id); } MediaStream::MediaStream(const std::string& id) : id_(id) {} diff --git a/api/media_stream_proxy.h b/pc/media_stream_proxy.h similarity index 78% rename from api/media_stream_proxy.h rename to pc/media_stream_proxy.h index 8ee33ca0ee..36069a4369 100644 --- a/api/media_stream_proxy.h +++ b/pc/media_stream_proxy.h @@ -8,20 +8,20 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef API_MEDIA_STREAM_PROXY_H_ -#define API_MEDIA_STREAM_PROXY_H_ +#ifndef PC_MEDIA_STREAM_PROXY_H_ +#define PC_MEDIA_STREAM_PROXY_H_ #include #include "api/media_stream_interface.h" -#include "api/proxy.h" +#include "pc/proxy.h" namespace webrtc { -// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. -BEGIN_SIGNALING_PROXY_MAP(MediaStream) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +// TODO(deadbeef): Move this to a .cc file. What threads methods are called on +// is an implementation detail. +BEGIN_PRIMARY_PROXY_MAP(MediaStream) +PROXY_PRIMARY_THREAD_DESTRUCTOR() BYPASS_PROXY_CONSTMETHOD0(std::string, id) PROXY_METHOD0(AudioTrackVector, GetAudioTracks) PROXY_METHOD0(VideoTrackVector, GetVideoTracks) @@ -37,8 +37,8 @@ PROXY_METHOD1(bool, RemoveTrack, AudioTrackInterface*) PROXY_METHOD1(bool, RemoveTrack, VideoTrackInterface*) PROXY_METHOD1(void, RegisterObserver, ObserverInterface*) PROXY_METHOD1(void, UnregisterObserver, ObserverInterface*) -END_PROXY_MAP() +END_PROXY_MAP(MediaStream) } // namespace webrtc -#endif // API_MEDIA_STREAM_PROXY_H_ +#endif // PC_MEDIA_STREAM_PROXY_H_ diff --git a/api/media_stream_track_proxy.h b/pc/media_stream_track_proxy.h similarity index 57% rename from api/media_stream_track_proxy.h rename to pc/media_stream_track_proxy.h index 59dcb77244..f563137c77 100644 --- a/api/media_stream_track_proxy.h +++ b/pc/media_stream_track_proxy.h @@ -11,26 +11,25 @@ // This file includes proxy classes for tracks. The purpose is // to make sure tracks are only accessed from the signaling thread. -#ifndef API_MEDIA_STREAM_TRACK_PROXY_H_ -#define API_MEDIA_STREAM_TRACK_PROXY_H_ +#ifndef PC_MEDIA_STREAM_TRACK_PROXY_H_ +#define PC_MEDIA_STREAM_TRACK_PROXY_H_ #include #include "api/media_stream_interface.h" -#include "api/proxy.h" +#include "pc/proxy.h" namespace webrtc { -// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. - -BEGIN_SIGNALING_PROXY_MAP(AudioTrack) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +// TODO(deadbeef): Move this to .cc file. What threads methods are called on is +// an implementation detail. 
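Annotation (not part of the patch): the pc/media_stream.cc hunk above is one instance of a wider migration from explicit new rtc::RefCountedObject<T>(...) to the rtc::make_ref_counted<T>(...) helper. A minimal sketch of the pattern, assuming the helper and header locations available at this revision; the Counter class is invented for illustration.

// Illustrative sketch only: construct a ref-counted object in one expression.
#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/ref_counted_object.h"

class Counter : public rtc::RefCountInterface {
 public:
  explicit Counter(int start) : value_(start) {}
  int value() const { return value_; }

 private:
  const int value_;
};

rtc::scoped_refptr<Counter> CreateCounter(int start) {
  // Old style: rtc::scoped_refptr<Counter> c(new rtc::RefCountedObject<Counter>(start));
  // New style: the factory picks the ref-counting wrapper and returns a
  // scoped_refptr directly.
  return rtc::make_ref_counted<Counter>(start);
}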
+BEGIN_PRIMARY_PROXY_MAP(AudioTrack) +PROXY_PRIMARY_THREAD_DESTRUCTOR() BYPASS_PROXY_CONSTMETHOD0(std::string, kind) BYPASS_PROXY_CONSTMETHOD0(std::string, id) PROXY_CONSTMETHOD0(TrackState, state) PROXY_CONSTMETHOD0(bool, enabled) -PROXY_CONSTMETHOD0(AudioSourceInterface*, GetSource) +BYPASS_PROXY_CONSTMETHOD0(AudioSourceInterface*, GetSource) PROXY_METHOD1(void, AddSink, AudioTrackSinkInterface*) PROXY_METHOD1(void, RemoveSink, AudioTrackSinkInterface*) PROXY_METHOD1(bool, GetSignalLevel, int*) @@ -38,28 +37,28 @@ PROXY_METHOD0(rtc::scoped_refptr, GetAudioProcessor) PROXY_METHOD1(bool, set_enabled, bool) PROXY_METHOD1(void, RegisterObserver, ObserverInterface*) PROXY_METHOD1(void, UnregisterObserver, ObserverInterface*) -END_PROXY_MAP() +END_PROXY_MAP(AudioTrack) BEGIN_PROXY_MAP(VideoTrack) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +PROXY_PRIMARY_THREAD_DESTRUCTOR() BYPASS_PROXY_CONSTMETHOD0(std::string, kind) BYPASS_PROXY_CONSTMETHOD0(std::string, id) -PROXY_CONSTMETHOD0(TrackState, state) -PROXY_CONSTMETHOD0(bool, enabled) -PROXY_METHOD1(bool, set_enabled, bool) -PROXY_CONSTMETHOD0(ContentHint, content_hint) -PROXY_METHOD1(void, set_content_hint, ContentHint) -PROXY_WORKER_METHOD2(void, - AddOrUpdateSink, - rtc::VideoSinkInterface*, - const rtc::VideoSinkWants&) -PROXY_WORKER_METHOD1(void, RemoveSink, rtc::VideoSinkInterface*) -PROXY_CONSTMETHOD0(VideoTrackSourceInterface*, GetSource) +PROXY_SECONDARY_CONSTMETHOD0(TrackState, state) +PROXY_SECONDARY_CONSTMETHOD0(bool, enabled) +PROXY_SECONDARY_METHOD1(bool, set_enabled, bool) +PROXY_SECONDARY_CONSTMETHOD0(ContentHint, content_hint) +PROXY_SECONDARY_METHOD1(void, set_content_hint, ContentHint) +PROXY_SECONDARY_METHOD2(void, + AddOrUpdateSink, + rtc::VideoSinkInterface*, + const rtc::VideoSinkWants&) +PROXY_SECONDARY_METHOD1(void, RemoveSink, rtc::VideoSinkInterface*) +BYPASS_PROXY_CONSTMETHOD0(VideoTrackSourceInterface*, GetSource) PROXY_METHOD1(void, RegisterObserver, ObserverInterface*) PROXY_METHOD1(void, UnregisterObserver, ObserverInterface*) -END_PROXY_MAP() +END_PROXY_MAP(VideoTrack) } // namespace webrtc -#endif // API_MEDIA_STREAM_TRACK_PROXY_H_ +#endif // PC_MEDIA_STREAM_TRACK_PROXY_H_ diff --git a/pc/peer_connection.cc b/pc/peer_connection.cc index 9ba7daefa1..276af1787d 100644 --- a/pc/peer_connection.cc +++ b/pc/peer_connection.cc @@ -12,6 +12,7 @@ #include #include + #include #include #include @@ -33,26 +34,27 @@ #include "media/base/rid_description.h" #include "media/base/stream_params.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "p2p/base/basic_async_resolver_factory.h" #include "p2p/base/connection.h" #include "p2p/base/connection_info.h" #include "p2p/base/dtls_transport_internal.h" #include "p2p/base/p2p_constants.h" #include "p2p/base/p2p_transport_channel.h" #include "p2p/base/transport_info.h" +#include "pc/channel.h" #include "pc/ice_server_parsing.h" #include "pc/rtp_receiver.h" #include "pc/rtp_sender.h" #include "pc/sctp_transport.h" #include "pc/simulcast_description.h" #include "pc/webrtc_session_description_factory.h" -#include "rtc_base/bind.h" #include "rtc_base/helpers.h" #include "rtc_base/ip_address.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/net_helper.h" #include "rtc_base/network_constants.h" -#include "rtc_base/callback_list.h" +#include "rtc_base/ref_counted_object.h" #include "rtc_base/socket_address.h" #include "rtc_base/string_encode.h" #include "rtc_base/task_utils/to_queued_task.h" @@ -88,7 +90,6 @@ const char 
kSimulcastNumberOfEncodings[] = static const int REPORT_USAGE_PATTERN_DELAY_MS = 60000; - uint32_t ConvertIceTransportTypeToCandidateFilter( PeerConnectionInterface::IceTransportsType type) { switch (type) { @@ -179,7 +180,6 @@ IceCandidatePairType GetIceCandidatePairCounter( return kIceCandidatePairMax; } - absl::optional RTCConfigurationToIceConfigOptionalInt( int rtc_configuration_parameter) { if (rtc_configuration_parameter == @@ -247,6 +247,8 @@ cricket::IceConfig ParseIceConfig( ice_config.ice_inactive_timeout = config.ice_inactive_timeout; ice_config.stun_keepalive_interval = config.stun_candidate_keepalive_interval; ice_config.network_preference = config.network_preference; + ice_config.stable_writable_connection_ping_interval = + config.stable_writable_connection_ping_interval_ms; return ice_config; } @@ -264,6 +266,20 @@ bool HasRtcpMuxEnabled(const cricket::ContentInfo* content) { return content->media_description()->rtcp_mux(); } +bool DtlsEnabled(const PeerConnectionInterface::RTCConfiguration& configuration, + const PeerConnectionFactoryInterface::Options& options, + const PeerConnectionDependencies& dependencies) { + if (options.disable_encryption) + return false; + + // Enable DTLS by default if we have an identity store or a certificate. + bool default_enabled = + (dependencies.cert_generator || !configuration.certificates.empty()); + + // The |configuration| can override the default value. + return configuration.enable_dtls_srtp.value_or(default_enabled); +} + } // namespace bool PeerConnectionInterface::RTCConfiguration::operator==( @@ -319,6 +335,7 @@ bool PeerConnectionInterface::RTCConfiguration::operator==( bool enable_implicit_rollback; absl::optional allow_codec_switching; absl::optional report_usage_pattern_delay_ms; + absl::optional stable_writable_connection_ping_interval_ms; }; static_assert(sizeof(stuff_being_tested_for_equality) == sizeof(*this), "Did you add something to RTCConfiguration and forget to " @@ -347,7 +364,6 @@ bool PeerConnectionInterface::RTCConfiguration::operator==( disable_ipv6_on_wifi == o.disable_ipv6_on_wifi && max_ipv6_networks == o.max_ipv6_networks && disable_link_local_networks == o.disable_link_local_networks && - enable_rtp_data_channel == o.enable_rtp_data_channel && screencast_min_bitrate == o.screencast_min_bitrate && combined_audio_video_bwe == o.combined_audio_video_bwe && enable_dtls_srtp == o.enable_dtls_srtp && @@ -379,7 +395,9 @@ bool PeerConnectionInterface::RTCConfiguration::operator==( turn_logging_id == o.turn_logging_id && enable_implicit_rollback == o.enable_implicit_rollback && allow_codec_switching == o.allow_codec_switching && - report_usage_pattern_delay_ms == o.report_usage_pattern_delay_ms; + report_usage_pattern_delay_ms == o.report_usage_pattern_delay_ms && + stable_writable_connection_ping_interval_ms == + o.stable_writable_connection_ping_interval_ms; } bool PeerConnectionInterface::RTCConfiguration::operator!=( @@ -421,11 +439,35 @@ RTCErrorOr> PeerConnection::Create( bool is_unified_plan = configuration.sdp_semantics == SdpSemantics::kUnifiedPlan; + bool dtls_enabled = DtlsEnabled(configuration, options, dependencies); + + // Interim code: If an AsyncResolverFactory is given, but not an + // AsyncDnsResolverFactory, wrap it in a WrappingAsyncDnsResolverFactory + // If neither is given, create a WrappingAsyncDnsResolverFactory wrapping + // a BasicAsyncResolver. + // TODO(bugs.webrtc.org/12598): Remove code once all callers pass a + // AsyncDnsResolverFactory. 
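Annotation (not part of the patch): the "Interim code" comment above describes wrapping the legacy AsyncResolverFactory dependency so that the rest of PeerConnection only sees the new AsyncDnsResolverFactory interface; the actual wrapping follows just below. A generic sketch of that adapter strategy, with every name here (LegacyFactory, NewFactory, WrappingFactory, Product) invented for illustration.

// Illustrative adapter sketch: keep accepting the legacy factory, expose it
// through the new interface during the migration window.
#include <memory>
#include <utility>

class Product {};

class LegacyFactory {
 public:
  virtual ~LegacyFactory() = default;
  virtual std::unique_ptr<Product> Create() = 0;
};

class NewFactory {
 public:
  virtual ~NewFactory() = default;
  virtual std::unique_ptr<Product> CreateProduct() = 0;
};

// Owns the legacy factory and forwards through the new interface.
class WrappingFactory : public NewFactory {
 public:
  explicit WrappingFactory(std::unique_ptr<LegacyFactory> legacy)
      : legacy_(std::move(legacy)) {}
  std::unique_ptr<Product> CreateProduct() override { return legacy_->Create(); }

 private:
  std::unique_ptr<LegacyFactory> legacy_;
};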
+ if (dependencies.async_dns_resolver_factory && + dependencies.async_resolver_factory) { + RTC_LOG(LS_ERROR) + << "Attempt to set both old and new type of DNS resolver factory"; + return RTCError(RTCErrorType::INVALID_PARAMETER, + "Both old and new type of DNS resolver given"); + } + if (dependencies.async_resolver_factory) { + dependencies.async_dns_resolver_factory = + std::make_unique( + std::move(dependencies.async_resolver_factory)); + } else { + dependencies.async_dns_resolver_factory = + std::make_unique( + std::make_unique()); + } + // The PeerConnection constructor consumes some, but not all, dependencies. - rtc::scoped_refptr pc( - new rtc::RefCountedObject( - context, options, is_unified_plan, std::move(event_log), - std::move(call), dependencies)); + auto pc = rtc::make_ref_counted( + context, options, is_unified_plan, std::move(event_log), std::move(call), + dependencies, dtls_enabled); RTCError init_error = pc->Initialize(configuration, std::move(dependencies)); if (!init_error.ok()) { RTC_LOG(LS_ERROR) << "PeerConnection initialization failed"; @@ -440,21 +482,37 @@ PeerConnection::PeerConnection( bool is_unified_plan, std::unique_ptr event_log, std::unique_ptr call, - PeerConnectionDependencies& dependencies) + PeerConnectionDependencies& dependencies, + bool dtls_enabled) : context_(context), options_(options), observer_(dependencies.observer), is_unified_plan_(is_unified_plan), event_log_(std::move(event_log)), event_log_ptr_(event_log_.get()), - async_resolver_factory_(std::move(dependencies.async_resolver_factory)), + async_dns_resolver_factory_( + std::move(dependencies.async_dns_resolver_factory)), port_allocator_(std::move(dependencies.allocator)), ice_transport_factory_(std::move(dependencies.ice_transport_factory)), tls_cert_verifier_(std::move(dependencies.tls_cert_verifier)), call_(std::move(call)), call_ptr_(call_.get()), + // RFC 3264: The numeric value of the session id and version in the + // o line MUST be representable with a "64 bit signed integer". + // Due to this constraint session id |session_id_| is max limited to + // LLONG_MAX. + session_id_(rtc::ToString(rtc::CreateRandomId64() & LLONG_MAX)), + dtls_enabled_(dtls_enabled), data_channel_controller_(this), - message_handler_(signaling_thread()) {} + message_handler_(signaling_thread()), + weak_factory_(this) { + worker_thread()->Invoke(RTC_FROM_HERE, [this] { + RTC_DCHECK_RUN_ON(worker_thread()); + worker_thread_safety_ = PendingTaskSafetyFlag::Create(); + if (!call_) + worker_thread_safety_->SetNotAlive(); + }); +} PeerConnection::~PeerConnection() { TRACE_EVENT0("webrtc", "PeerConnection::~PeerConnection"); @@ -488,17 +546,22 @@ PeerConnection::~PeerConnection() { sdp_handler_->ResetSessionDescFactory(); } - transport_controller_.reset(); - // port_allocator_ lives on the network thread and should be destroyed there. + // port_allocator_ and transport_controller_ live on the network thread and + // should be destroyed there. network_thread()->Invoke(RTC_FROM_HERE, [this] { RTC_DCHECK_RUN_ON(network_thread()); + TeardownDataChannelTransport_n(); + transport_controller_.reset(); port_allocator_.reset(); + if (network_thread_safety_) + network_thread_safety_->SetNotAlive(); }); + // call_ and event_log_ must be destroyed on the worker thread. worker_thread()->Invoke(RTC_FROM_HERE, [this] { RTC_DCHECK_RUN_ON(worker_thread()); - call_safety_.reset(); + worker_thread_safety_->SetNotAlive(); call_.reset(); // The event log must outlive call (and any other object that uses it). 
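Annotation (not part of the patch): the constructor and destructor changes above introduce per-thread PendingTaskSafetyFlag instances that are created on the thread they guard and flipped to not-alive during teardown, so queued tasks become no-ops once the owner is gone. A minimal sketch of that lifecycle; the Worker class is invented for illustration.

// Illustrative sketch of the safety-flag lifecycle used above.
#include "api/scoped_refptr.h"
#include "rtc_base/location.h"
#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread.h"

class Worker {
 public:
  explicit Worker(rtc::Thread* thread) : thread_(thread) {
    // Create the flag on the thread it guards.
    thread_->Invoke<void>(RTC_FROM_HERE, [this] {
      safety_ = webrtc::PendingTaskSafetyFlag::Create();
    });
  }

  ~Worker() {
    // Invalidate any still-pending tasks before the object goes away.
    thread_->Invoke<void>(RTC_FROM_HERE, [this] { safety_->SetNotAlive(); });
  }

  void PostWork() {
    // Tasks wrapped with the flag run only while the flag is alive.
    thread_->PostTask(webrtc::ToQueuedTask(safety_, [this] { DoWork(); }));
  }

 private:
  void DoWork() {}

  rtc::Thread* const thread_;
  rtc::scoped_refptr<webrtc::PendingTaskSafetyFlag> safety_;
};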
event_log_.reset(); @@ -525,14 +588,6 @@ RTCError PeerConnection::Initialize( turn_server.turn_logging_id = configuration.turn_logging_id; } - // The port allocator lives on the network thread and should be initialized - // there. - const auto pa_result = - network_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::InitializePortAllocator_n, this, - stun_servers, turn_servers, configuration)); - // Note if STUN or TURN servers were supplied. if (!stun_servers.empty()) { NoteUsageEvent(UsageEvent::STUN_SERVER_ADDED); @@ -541,98 +596,24 @@ RTCError PeerConnection::Initialize( NoteUsageEvent(UsageEvent::TURN_SERVER_ADDED); } - // Send information about IPv4/IPv6 status. - PeerConnectionAddressFamilyCounter address_family; - if (pa_result.enable_ipv6) { - address_family = kPeerConnection_IPv6; - } else { - address_family = kPeerConnection_IPv4; - } - RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.IPMetrics", address_family, - kPeerConnectionAddressFamilyCounter_Max); - - // RFC 3264: The numeric value of the session id and version in the - // o line MUST be representable with a "64 bit signed integer". - // Due to this constraint session id |session_id_| is max limited to - // LLONG_MAX. - session_id_ = rtc::ToString(rtc::CreateRandomId64() & LLONG_MAX); - JsepTransportController::Config config; - config.redetermine_role_on_ice_restart = - configuration.redetermine_role_on_ice_restart; - config.ssl_max_version = options_.ssl_max_version; - config.disable_encryption = options_.disable_encryption; - config.bundle_policy = configuration.bundle_policy; - config.rtcp_mux_policy = configuration.rtcp_mux_policy; - // TODO(bugs.webrtc.org/9891) - Remove options_.crypto_options then remove - // this stub. - config.crypto_options = configuration.crypto_options.has_value() - ? *configuration.crypto_options - : options_.crypto_options; - config.transport_observer = this; - config.rtcp_handler = InitializeRtcpCallback(); - config.event_log = event_log_ptr_; -#if defined(ENABLE_EXTERNAL_AUTH) - config.enable_external_auth = true; -#endif - config.active_reset_srtp_params = configuration.active_reset_srtp_params; - - if (options_.disable_encryption) { - dtls_enabled_ = false; - } else { - // Enable DTLS by default if we have an identity store or a certificate. - dtls_enabled_ = - (dependencies.cert_generator || !configuration.certificates.empty()); - // |configuration| can override the default |dtls_enabled_| value. - if (configuration.enable_dtls_srtp) { - dtls_enabled_ = *(configuration.enable_dtls_srtp); - } - } - - if (configuration.enable_rtp_data_channel) { - // Enable creation of RTP data channels if the kEnableRtpDataChannels is - // set. It takes precendence over the disable_sctp_data_channels - // PeerConnectionFactoryInterface::Options. - data_channel_controller_.set_data_channel_type(cricket::DCT_RTP); - } else { - // DTLS has to be enabled to use SCTP. 
- if (!options_.disable_sctp_data_channels && dtls_enabled_) { - data_channel_controller_.set_data_channel_type(cricket::DCT_SCTP); - config.sctp_factory = context_->sctp_transport_factory(); - } - } - - config.ice_transport_factory = ice_transport_factory_.get(); - - transport_controller_.reset(new JsepTransportController( - signaling_thread(), network_thread(), port_allocator_.get(), - async_resolver_factory_.get(), config)); - transport_controller_->SignalStandardizedIceConnectionState.connect( - this, &PeerConnection::SetStandardizedIceConnectionState); - transport_controller_->SignalConnectionState.connect( - this, &PeerConnection::SetConnectionState); - transport_controller_->SignalIceGatheringState.connect( - this, &PeerConnection::OnTransportControllerGatheringState); - transport_controller_->SignalIceCandidatesGathered.connect( - this, &PeerConnection::OnTransportControllerCandidatesGathered); - transport_controller_->SignalIceCandidateError.connect( - this, &PeerConnection::OnTransportControllerCandidateError); - transport_controller_->SignalIceCandidatesRemoved.connect( - this, &PeerConnection::OnTransportControllerCandidatesRemoved); - transport_controller_->SignalDtlsHandshakeError.connect( - this, &PeerConnection::OnTransportControllerDtlsHandshakeError); - transport_controller_->SignalIceCandidatePairChanged.connect( - this, &PeerConnection::OnTransportControllerCandidateChanged); - - transport_controller_->SignalIceConnectionState.AddReceiver( - [this](cricket::IceConnectionState s) { - RTC_DCHECK_RUN_ON(signaling_thread()); - OnTransportControllerConnectionState(s); - }); + // Network thread initialization. + network_thread()->Invoke(RTC_FROM_HERE, [this, &stun_servers, + &turn_servers, &configuration, + &dependencies] { + RTC_DCHECK_RUN_ON(network_thread()); + network_thread_safety_ = PendingTaskSafetyFlag::Create(); + InitializePortAllocatorResult pa_result = + InitializePortAllocator_n(stun_servers, turn_servers, configuration); + // Send information about IPv4/IPv6 status. + PeerConnectionAddressFamilyCounter address_family = + pa_result.enable_ipv6 ? 
kPeerConnection_IPv6 : kPeerConnection_IPv4; + RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.IPMetrics", address_family, + kPeerConnectionAddressFamilyCounter_Max); + InitializeTransportController_n(configuration, dependencies); + }); configuration_ = configuration; - transport_controller_->SetIceConfig(ParseIceConfig(configuration)); - stats_ = std::make_unique(this); stats_collector_ = RTCStatsCollector::Create(this); @@ -650,10 +631,12 @@ RTCError PeerConnection::Initialize( if (!IsUnifiedPlan()) { rtp_manager()->transceivers()->Add( RtpTransceiverProxyWithInternal::Create( - signaling_thread(), new RtpTransceiver(cricket::MEDIA_TYPE_AUDIO))); + signaling_thread(), + new RtpTransceiver(cricket::MEDIA_TYPE_AUDIO, channel_manager()))); rtp_manager()->transceivers()->Add( RtpTransceiverProxyWithInternal::Create( - signaling_thread(), new RtpTransceiver(cricket::MEDIA_TYPE_VIDEO))); + signaling_thread(), + new RtpTransceiver(cricket::MEDIA_TYPE_VIDEO, channel_manager()))); } int delay_ms = configuration.report_usage_pattern_delay_ms @@ -669,6 +652,127 @@ RTCError PeerConnection::Initialize( return RTCError::OK(); } +void PeerConnection::InitializeTransportController_n( + const RTCConfiguration& configuration, + const PeerConnectionDependencies& dependencies) { + JsepTransportController::Config config; + config.redetermine_role_on_ice_restart = + configuration.redetermine_role_on_ice_restart; + config.ssl_max_version = options_.ssl_max_version; + config.disable_encryption = options_.disable_encryption; + config.bundle_policy = configuration.bundle_policy; + config.rtcp_mux_policy = configuration.rtcp_mux_policy; + // TODO(bugs.webrtc.org/9891) - Remove options_.crypto_options then remove + // this stub. + config.crypto_options = configuration.crypto_options.has_value() + ? *configuration.crypto_options + : options_.crypto_options; + config.transport_observer = this; + config.rtcp_handler = InitializeRtcpCallback(); + config.event_log = event_log_ptr_; +#if defined(ENABLE_EXTERNAL_AUTH) + config.enable_external_auth = true; +#endif + config.active_reset_srtp_params = configuration.active_reset_srtp_params; + + // DTLS has to be enabled to use SCTP. 
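Annotation (not part of the patch): the Subscribe* callbacks registered just below run on the network thread and re-post the user-visible part of each event to the signaling thread under a safety flag, so nothing fires after teardown. A minimal sketch of that forwarding pattern; Forwarder and its methods are invented, and it is assumed to be created and destroyed on the signaling thread.

// Illustrative sketch of the network-to-signaling forwarding pattern.
#include <string>
#include <utility>

#include "api/sequence_checker.h"
#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread.h"

class Forwarder {
 public:
  // Constructed and destroyed on the signaling thread, which binds the
  // ScopedTaskSafety flag below to that thread.
  Forwarder(rtc::Thread* network, rtc::Thread* signaling)
      : network_(network), signaling_(signaling) {}

  // Called on the network thread by some event source.
  void OnNetworkEvent(std::string description) {
    RTC_DCHECK_RUN_ON(network_);
    signaling_->PostTask(webrtc::ToQueuedTask(
        safety_.flag(), [this, d = std::move(description)] {
          RTC_DCHECK_RUN_ON(signaling_);
          NotifyObserver(d);
        }));
  }

 private:
  void NotifyObserver(const std::string&) {}

  rtc::Thread* const network_;
  rtc::Thread* const signaling_;
  webrtc::ScopedTaskSafety safety_;  // Invalidated when the Forwarder dies.
};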
+ if (dtls_enabled_) { + config.sctp_factory = context_->sctp_transport_factory(); + } + + config.ice_transport_factory = ice_transport_factory_.get(); + config.on_dtls_handshake_error_ = + [weak_ptr = weak_factory_.GetWeakPtr()](rtc::SSLHandshakeError s) { + if (weak_ptr) { + weak_ptr->OnTransportControllerDtlsHandshakeError(s); + } + }; + + transport_controller_.reset( + new JsepTransportController(network_thread(), port_allocator_.get(), + async_dns_resolver_factory_.get(), config)); + + transport_controller_->SubscribeIceConnectionState( + [this](cricket::IceConnectionState s) { + RTC_DCHECK_RUN_ON(network_thread()); + if (s == cricket::kIceConnectionConnected) { + ReportTransportStats(); + } + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), [this, s]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + OnTransportControllerConnectionState(s); + })); + }); + transport_controller_->SubscribeConnectionState( + [this](PeerConnectionInterface::PeerConnectionState s) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), [this, s]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + SetConnectionState(s); + })); + }); + transport_controller_->SubscribeStandardizedIceConnectionState( + [this](PeerConnectionInterface::IceConnectionState s) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), [this, s]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + SetStandardizedIceConnectionState(s); + })); + }); + transport_controller_->SubscribeIceGatheringState( + [this](cricket::IceGatheringState s) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), [this, s]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + OnTransportControllerGatheringState(s); + })); + }); + transport_controller_->SubscribeIceCandidateGathered( + [this](const std::string& transport, + const std::vector& candidates) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), + [this, t = transport, c = candidates]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + OnTransportControllerCandidatesGathered(t, c); + })); + }); + transport_controller_->SubscribeIceCandidateError( + [this](const cricket::IceCandidateErrorEvent& event) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask(ToQueuedTask( + signaling_thread_safety_.flag(), [this, event = event]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + OnTransportControllerCandidateError(event); + })); + }); + transport_controller_->SubscribeIceCandidatesRemoved( + [this](const std::vector& c) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), [this, c = c]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + OnTransportControllerCandidatesRemoved(c); + })); + }); + transport_controller_->SubscribeIceCandidatePairChanged( + [this](const cricket::CandidatePairChangeEvent& event) { + RTC_DCHECK_RUN_ON(network_thread()); + signaling_thread()->PostTask(ToQueuedTask( + signaling_thread_safety_.flag(), [this, event = event]() { + RTC_DCHECK_RUN_ON(signaling_thread()); + OnTransportControllerCandidateChanged(event); + })); + }); + + transport_controller_->SetIceConfig(ParseIceConfig(configuration)); +} + rtc::scoped_refptr PeerConnection::local_streams() { RTC_DCHECK_RUN_ON(signaling_thread()); RTC_CHECK(!IsUnifiedPlan()) << "local_streams is 
not available with Unified " @@ -792,6 +896,16 @@ PeerConnection::AddTransceiver( return AddTransceiver(track, RtpTransceiverInit()); } +RtpTransportInternal* PeerConnection::GetRtpTransport(const std::string& mid) { + RTC_DCHECK_RUN_ON(signaling_thread()); + return network_thread()->Invoke( + RTC_FROM_HERE, [this, &mid] { + auto rtp_transport = transport_controller_->GetRtpTransport(mid); + RTC_DCHECK(rtp_transport); + return rtp_transport; + }); +} + RTCErrorOr> PeerConnection::AddTransceiver( rtc::scoped_refptr track, @@ -883,9 +997,11 @@ PeerConnection::AddTransceiver( parameters.encodings = init.send_encodings; // Encodings are dropped from the tail if too many are provided. - if (parameters.encodings.size() > kMaxSimulcastStreams) { + size_t max_simulcast_streams = + media_type == cricket::MEDIA_TYPE_VIDEO ? kMaxSimulcastStreams : 1u; + if (parameters.encodings.size() > max_simulcast_streams) { parameters.encodings.erase( - parameters.encodings.begin() + kMaxSimulcastStreams, + parameters.encodings.begin() + max_simulcast_streams, parameters.encodings.end()); } @@ -1033,6 +1149,8 @@ bool PeerConnection::GetStats(StatsObserver* observer, return false; } + RTC_LOG_THREAD_BLOCK_COUNT(); + stats_->UpdateStats(level); // The StatsCollector is used to tell if a track is valid because it may // remember tracks that the PeerConnection previously removed. @@ -1042,6 +1160,7 @@ bool PeerConnection::GetStats(StatsObserver* observer, return false; } message_handler_.PostGetStats(observer, stats_.get(), track); + return true; } @@ -1050,6 +1169,7 @@ void PeerConnection::GetStats(RTCStatsCollectorCallback* callback) { RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(stats_collector_); RTC_DCHECK(callback); + RTC_LOG_THREAD_BLOCK_COUNT(); stats_collector_->GetStatsReport(callback); } @@ -1159,9 +1279,9 @@ absl::optional PeerConnection::can_trickle_ice_candidates() { "trickle"); } -rtc::scoped_refptr PeerConnection::CreateDataChannel( - const std::string& label, - const DataChannelInit* config) { +RTCErrorOr> +PeerConnection::CreateDataChannelOrError(const std::string& label, + const DataChannelInit* config) { RTC_DCHECK_RUN_ON(signaling_thread()); TRACE_EVENT0("webrtc", "PeerConnection::CreateDataChannel"); @@ -1171,16 +1291,18 @@ rtc::scoped_refptr PeerConnection::CreateDataChannel( if (config) { internal_config.reset(new InternalDataChannelInit(*config)); } + // TODO(bugs.webrtc.org/12796): Return a more specific error. rtc::scoped_refptr channel( data_channel_controller_.InternalCreateDataChannelWithProxy( label, internal_config.get())); if (!channel.get()) { - return nullptr; + return RTCError(RTCErrorType::INTERNAL_ERROR, + "Data channel creation failed"); } - // Trigger the onRenegotiationNeeded event for every new RTP DataChannel, or + // Trigger the onRenegotiationNeeded event for // the first SCTP DataChannel. 
- if (data_channel_type() == cricket::DCT_RTP || first_datachannel) { + if (first_datachannel) { sdp_handler_->UpdateNegotiationNeeded(); } NoteUsageEvent(UsageEvent::DATA_ADDED); @@ -1305,6 +1427,8 @@ RTCError PeerConnection::SetConfiguration( configuration.active_reset_srtp_params; modified_config.turn_logging_id = configuration.turn_logging_id; modified_config.allow_codec_switching = configuration.allow_codec_switching; + modified_config.stable_writable_connection_ping_interval_ms = + configuration.stable_writable_connection_ping_interval_ms; if (configuration != modified_config) { LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_MODIFICATION, "Modifying the configuration in an unsupported way."); @@ -1344,36 +1468,46 @@ RTCError PeerConnection::SetConfiguration( NoteUsageEvent(UsageEvent::TURN_SERVER_ADDED); } - // In theory this shouldn't fail. - if (!network_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::ReconfigurePortAllocator_n, this, - stun_servers, turn_servers, modified_config.type, - modified_config.ice_candidate_pool_size, - modified_config.GetTurnPortPrunePolicy(), - modified_config.turn_customizer, - modified_config.stun_candidate_keepalive_interval, - static_cast(local_description())))) { - LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, - "Failed to apply configuration to PortAllocator."); - } + const bool has_local_description = local_description() != nullptr; - // As described in JSEP, calling setConfiguration with new ICE servers or - // candidate policy must set a "needs-ice-restart" bit so that the next offer - // triggers an ICE restart which will pick up the changes. - if (modified_config.servers != configuration_.servers || + const bool needs_ice_restart = + modified_config.servers != configuration_.servers || NeedIceRestart( configuration_.surface_ice_candidates_on_ice_transport_type_changed, configuration_.type, modified_config.type) || modified_config.GetTurnPortPrunePolicy() != - configuration_.GetTurnPortPrunePolicy()) { - transport_controller_->SetNeedsIceRestartFlag(); - } + configuration_.GetTurnPortPrunePolicy(); + cricket::IceConfig ice_config = ParseIceConfig(modified_config); - transport_controller_->SetIceConfig(ParseIceConfig(modified_config)); + // Apply part of the configuration on the network thread. In theory this + // shouldn't fail. + if (!network_thread()->Invoke( + RTC_FROM_HERE, + [this, needs_ice_restart, &ice_config, &stun_servers, &turn_servers, + &modified_config, has_local_description] { + // As described in JSEP, calling setConfiguration with new ICE + // servers or candidate policy must set a "needs-ice-restart" bit so + // that the next offer triggers an ICE restart which will pick up + // the changes. + if (needs_ice_restart) + transport_controller_->SetNeedsIceRestartFlag(); + + transport_controller_->SetIceConfig(ice_config); + return ReconfigurePortAllocator_n( + stun_servers, turn_servers, modified_config.type, + modified_config.ice_candidate_pool_size, + modified_config.GetTurnPortPrunePolicy(), + modified_config.turn_customizer, + modified_config.stun_candidate_keepalive_interval, + has_local_description); + })) { + LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, + "Failed to apply configuration to PortAllocator."); + } if (configuration_.active_reset_srtp_params != modified_config.active_reset_srtp_params) { + // TODO(tommi): move to the network thread - this hides an invoke. 
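Annotation (not part of the patch): with the hunk above, data channel creation reports failures through RTCErrorOr instead of a bare nullptr. A caller-side sketch, assuming the public interface exposes the same CreateDataChannelOrError() entry point overridden here; the CreateChannel helper and the "chat" label are invented for illustration.

// Illustrative caller-side handling of CreateDataChannelOrError().
#include <string>

#include "api/data_channel_interface.h"
#include "api/peer_connection_interface.h"
#include "api/rtc_error.h"
#include "api/scoped_refptr.h"
#include "rtc_base/logging.h"

rtc::scoped_refptr<webrtc::DataChannelInterface> CreateChannel(
    webrtc::PeerConnectionInterface* pc) {
  webrtc::DataChannelInit config;
  config.ordered = true;
  auto result = pc->CreateDataChannelOrError("chat", &config);
  if (!result.ok()) {
    RTC_LOG(LS_ERROR) << "CreateDataChannel failed: "
                      << result.error().message();
    return nullptr;
  }
  return result.MoveValue();
}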
transport_controller_->SetActiveResetSrtpParams( modified_config.active_reset_srtp_params); } @@ -1468,8 +1602,7 @@ RTCError PeerConnection::SetBitrate(const BitrateSettings& bitrate) { void PeerConnection::SetAudioPlayout(bool playout) { if (!worker_thread()->IsCurrent()) { worker_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::SetAudioPlayout, this, playout)); + RTC_FROM_HERE, [this, playout] { SetAudioPlayout(playout); }); return; } auto audio_state = @@ -1480,8 +1613,7 @@ void PeerConnection::SetAudioPlayout(bool playout) { void PeerConnection::SetAudioRecording(bool recording) { if (!worker_thread()->IsCurrent()) { worker_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::SetAudioRecording, this, recording)); + RTC_FROM_HERE, [this, recording] { SetAudioRecording(recording); }); return; } auto audio_state = @@ -1524,13 +1656,12 @@ bool PeerConnection::StartRtcEventLog( } void PeerConnection::StopRtcEventLog() { - worker_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&PeerConnection::StopRtcEventLog_w, this)); + worker_thread()->Invoke(RTC_FROM_HERE, [this] { StopRtcEventLog_w(); }); } rtc::scoped_refptr PeerConnection::LookupDtlsTransportByMid(const std::string& mid) { - RTC_DCHECK_RUN_ON(signaling_thread()); + RTC_DCHECK_RUN_ON(network_thread()); return transport_controller_->LookupDtlsTransportByMid(mid); } @@ -1542,11 +1673,11 @@ PeerConnection::LookupDtlsTransportByMidInternal(const std::string& mid) { rtc::scoped_refptr PeerConnection::GetSctpTransport() const { - RTC_DCHECK_RUN_ON(signaling_thread()); - if (!sctp_mid_s_) { + RTC_DCHECK_RUN_ON(network_thread()); + if (!sctp_mid_n_) return nullptr; - } - return transport_controller_->GetSctpTransport(*sctp_mid_s_); + + return transport_controller_->GetSctpTransport(*sctp_mid_n_); } const SessionDescriptionInterface* PeerConnection::local_description() const { @@ -1587,6 +1718,8 @@ void PeerConnection::Close() { RTC_DCHECK_RUN_ON(signaling_thread()); TRACE_EVENT0("webrtc", "PeerConnection::Close"); + RTC_LOG_THREAD_BLOCK_COUNT(); + if (IsClosed()) { return; } @@ -1627,16 +1760,24 @@ void PeerConnection::Close() { // WebRTC session description factory, the session description factory would // call the transport controller. sdp_handler_->ResetSessionDescFactory(); - transport_controller_.reset(); rtp_manager_->Close(); - network_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&cricket::PortAllocator::DiscardCandidatePool, - port_allocator_.get())); + network_thread()->Invoke(RTC_FROM_HERE, [this] { + // Data channels will already have been unset via the DestroyAllChannels() + // call above, which triggers a call to TeardownDataChannelTransport_n(). + // TODO(tommi): ^^ That's not exactly optimal since this is yet another + // blocking hop to the network thread during Close(). Further still, the + // voice/video/data channels will be cleared on the worker thread. + transport_controller_.reset(); + port_allocator_->DiscardCandidatePool(); + if (network_thread_safety_) { + network_thread_safety_->SetNotAlive(); + } + }); worker_thread()->Invoke(RTC_FROM_HERE, [this] { RTC_DCHECK_RUN_ON(worker_thread()); - call_safety_.reset(); + worker_thread_safety_->SetNotAlive(); call_.reset(); // The event log must outlive call (and any other object that uses it). event_log_.reset(); @@ -1645,6 +1786,10 @@ void PeerConnection::Close() { // The .h file says that observer can be discarded after close() returns. // Make sure this is true. observer_ = nullptr; + + // Signal shutdown to the sdp handler. 
This invalidates weak pointers for + // internal pending callbacks. + sdp_handler_->PrepareForShutdown(); } void PeerConnection::SetIceConnectionState(IceConnectionState new_state) { @@ -1693,6 +1838,62 @@ void PeerConnection::SetConnectionState( return; connection_state_ = new_state; Observer()->OnConnectionChange(new_state); + + if (new_state == PeerConnectionState::kConnected && !was_ever_connected_) { + was_ever_connected_ = true; + + // The first connection state change to connected happens once per + // connection which makes it a good point to report metrics. + // Record bundle-policy from configuration. Done here from + // connectionStateChange to limit to actually established connections. + BundlePolicyUsage policy = kBundlePolicyUsageMax; + switch (configuration_.bundle_policy) { + case kBundlePolicyBalanced: + policy = kBundlePolicyUsageBalanced; + break; + case kBundlePolicyMaxBundle: + policy = kBundlePolicyUsageMaxBundle; + break; + case kBundlePolicyMaxCompat: + policy = kBundlePolicyUsageMaxCompat; + break; + } + RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.BundlePolicy", policy, + kBundlePolicyUsageMax); + + // Record configured ice candidate pool size depending on the + // BUNDLE policy. See + // https://w3c.github.io/webrtc-pc/#dom-rtcconfiguration-icecandidatepoolsize + // The ICE candidate pool size is an optimization and it may be desirable + // to restrict the maximum size of the pre-gathered candidates. + switch (configuration_.bundle_policy) { + case kBundlePolicyBalanced: + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.PeerConnection.CandidatePoolUsage.Balanced", + configuration_.ice_candidate_pool_size, 0, 255, 256); + break; + case kBundlePolicyMaxBundle: + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.PeerConnection.CandidatePoolUsage.MaxBundle", + configuration_.ice_candidate_pool_size, 0, 255, 256); + break; + case kBundlePolicyMaxCompat: + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.PeerConnection.CandidatePoolUsage.MaxCompat", + configuration_.ice_candidate_pool_size, 0, 255, 256); + break; + } + + // Record whether there was a local or remote provisional answer. 
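Annotation (not part of the patch): the connection-state block above records one-shot UMA metrics, mapping an internal setting onto a histogram enum whose *Max member serves as the bucket count. A minimal sketch of that RTC_HISTOGRAM_ENUMERATION pattern; the ExampleUsage enum and histogram name are invented for illustration.

// Illustrative sketch of the one-shot histogram-enumeration pattern.
#include "system_wrappers/include/metrics.h"

enum ExampleUsage {
  kExampleUsageOff = 0,
  kExampleUsageOn = 1,
  kExampleUsageMax  // Bucket-count sentinel, never recorded.
};

void ReportExampleUsageOnce(bool enabled, bool* already_reported) {
  if (*already_reported)
    return;
  *already_reported = true;
  RTC_HISTOGRAM_ENUMERATION("WebRTC.Example.Usage",
                            enabled ? kExampleUsageOn : kExampleUsageOff,
                            kExampleUsageMax);
}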
+ ProvisionalAnswerUsage pranswer = kProvisionalAnswerNotUsed; + if (local_description()->GetType() == SdpType::kPrAnswer) { + pranswer = kProvisionalAnswerLocal; + } else if (remote_description()->GetType() == SdpType::kPrAnswer) { + pranswer = kProvisionalAnswerRemote; + } + RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.ProvisionalAnswer", + pranswer, kProvisionalAnswerMax); + } } void PeerConnection::OnIceGatheringChange( @@ -1752,17 +1953,18 @@ void PeerConnection::OnSelectedCandidatePairChanged( absl::optional PeerConnection::GetDataMid() const { RTC_DCHECK_RUN_ON(signaling_thread()); - switch (data_channel_type()) { - case cricket::DCT_RTP: - if (!data_channel_controller_.rtp_data_channel()) { - return absl::nullopt; - } - return data_channel_controller_.rtp_data_channel()->content_name(); - case cricket::DCT_SCTP: - return sctp_mid_s_; - default: - return absl::nullopt; - } + return sctp_mid_s_; +} + +void PeerConnection::SetSctpDataMid(const std::string& mid) { + RTC_DCHECK_RUN_ON(signaling_thread()); + sctp_mid_s_ = mid; +} + +void PeerConnection::ResetSctpDataMid() { + RTC_DCHECK_RUN_ON(signaling_thread()); + sctp_mid_s_.reset(); + sctp_transport_name_s_.clear(); } void PeerConnection::OnSctpDataChannelClosed(DataChannelInterface* channel) { @@ -1897,16 +2099,12 @@ void PeerConnection::StopRtcEventLog_w() { cricket::ChannelInterface* PeerConnection::GetChannel( const std::string& content_name) { - for (const auto& transceiver : rtp_manager()->transceivers()->List()) { + for (const auto& transceiver : rtp_manager()->transceivers()->UnsafeList()) { cricket::ChannelInterface* channel = transceiver->internal()->channel(); if (channel && channel->content_name() == content_name) { return channel; } } - if (rtp_data_channel() && - rtp_data_channel()->content_name() == content_name) { - return rtp_data_channel(); - } return nullptr; } @@ -1978,59 +2176,34 @@ std::vector PeerConnection::GetDataChannelStats() const { absl::optional PeerConnection::sctp_transport_name() const { RTC_DCHECK_RUN_ON(signaling_thread()); - if (sctp_mid_s_ && transport_controller_) { - auto dtls_transport = transport_controller_->GetDtlsTransport(*sctp_mid_s_); - if (dtls_transport) { - return dtls_transport->transport_name(); - } - return absl::optional(); - } + if (sctp_mid_s_ && transport_controller_) + return sctp_transport_name_s_; return absl::optional(); } +absl::optional PeerConnection::sctp_mid() const { + RTC_DCHECK_RUN_ON(signaling_thread()); + return sctp_mid_s_; +} + cricket::CandidateStatsList PeerConnection::GetPooledCandidateStats() const { + RTC_DCHECK_RUN_ON(network_thread()); + if (!network_thread_safety_->alive()) + return {}; cricket::CandidateStatsList candidate_states_list; - network_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&cricket::PortAllocator::GetCandidateStatsFromPooledSessions, - port_allocator_.get(), &candidate_states_list)); + port_allocator_->GetCandidateStatsFromPooledSessions(&candidate_states_list); return candidate_states_list; } -std::map PeerConnection::GetTransportNamesByMid() - const { - RTC_DCHECK_RUN_ON(signaling_thread()); - std::map transport_names_by_mid; - for (const auto& transceiver : rtp_manager()->transceivers()->List()) { - cricket::ChannelInterface* channel = transceiver->internal()->channel(); - if (channel) { - transport_names_by_mid[channel->content_name()] = - channel->transport_name(); - } - } - if (data_channel_controller_.rtp_data_channel()) { - transport_names_by_mid[data_channel_controller_.rtp_data_channel() - ->content_name()] = - 
data_channel_controller_.rtp_data_channel()->transport_name(); - } - if (data_channel_controller_.data_channel_transport()) { - absl::optional transport_name = sctp_transport_name(); - RTC_DCHECK(transport_name); - transport_names_by_mid[*sctp_mid_s_] = *transport_name; - } - return transport_names_by_mid; -} - std::map PeerConnection::GetTransportStatsByNames( const std::set& transport_names) { - if (!network_thread()->IsCurrent()) { - return network_thread() - ->Invoke>( - RTC_FROM_HERE, - [&] { return GetTransportStatsByNames(transport_names); }); - } + TRACE_EVENT0("webrtc", "PeerConnection::GetTransportStatsByNames"); RTC_DCHECK_RUN_ON(network_thread()); + if (!network_thread_safety_->alive()) + return {}; + + rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; std::map transport_stats_by_name; for (const std::string& transport_name : transport_names) { cricket::TransportStats transport_stats; @@ -2049,7 +2222,8 @@ PeerConnection::GetTransportStatsByNames( bool PeerConnection::GetLocalCertificate( const std::string& transport_name, rtc::scoped_refptr* certificate) { - if (!certificate) { + RTC_DCHECK_RUN_ON(network_thread()); + if (!network_thread_safety_->alive() || !certificate) { return false; } *certificate = transport_controller_->GetLocalCertificate(transport_name); @@ -2058,20 +2232,20 @@ bool PeerConnection::GetLocalCertificate( std::unique_ptr PeerConnection::GetRemoteSSLCertChain( const std::string& transport_name) { + RTC_DCHECK_RUN_ON(network_thread()); return transport_controller_->GetRemoteSSLCertChain(transport_name); } -cricket::DataChannelType PeerConnection::data_channel_type() const { - return data_channel_controller_.data_channel_type(); -} - bool PeerConnection::IceRestartPending(const std::string& content_name) const { RTC_DCHECK_RUN_ON(signaling_thread()); return sdp_handler_->IceRestartPending(content_name); } bool PeerConnection::NeedsIceRestart(const std::string& content_name) const { - return transport_controller_->NeedsIceRestart(content_name); + return network_thread()->Invoke(RTC_FROM_HERE, [this, &content_name] { + RTC_DCHECK_RUN_ON(network_thread()); + return transport_controller_->NeedsIceRestart(content_name); + }); } void PeerConnection::OnTransportControllerConnectionState( @@ -2111,8 +2285,8 @@ void PeerConnection::OnTransportControllerConnectionState( SetIceConnectionState(PeerConnectionInterface::kIceConnectionConnected); } SetIceConnectionState(PeerConnectionInterface::kIceConnectionCompleted); + NoteUsageEvent(UsageEvent::ICE_STATE_CONNECTED); - ReportTransportStats(); break; default: RTC_NOTREACHED(); @@ -2122,6 +2296,8 @@ void PeerConnection::OnTransportControllerConnectionState( void PeerConnection::OnTransportControllerCandidatesGathered( const std::string& transport_name, const cricket::Candidates& candidates) { + // TODO(bugs.webrtc.org/12427): Expect this to come in on the network thread + // (not signaling as it currently does), handle appropriately. 
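Annotation (not part of the patch): GetTransportStatsByNames() above now runs directly on the network thread and installs rtc::Thread::ScopedDisallowBlockingCalls so that no code path quietly reintroduces a blocking Invoke. A minimal sketch of the guard; CollectStats is invented for illustration.

// Illustrative sketch of the blocking-call guard used above.
#include "rtc_base/thread.h"

int CollectStats() {
  // While this guard is alive, a blocking rtc::Thread::Invoke() issued from
  // this scope (directly or indirectly) trips a DCHECK in debug builds.
  rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls;
  int stats = 0;
  // ... gather data using only non-blocking calls ...
  return stats;
}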
int sdp_mline_index; if (!GetLocalCandidateMediaIndex(transport_name, &sdp_mline_index)) { RTC_LOG(LS_ERROR) @@ -2196,7 +2372,7 @@ bool PeerConnection::GetLocalCandidateMediaIndex( Call::Stats PeerConnection::GetCallStats() { if (!worker_thread()->IsCurrent()) { return worker_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&PeerConnection::GetCallStats, this)); + RTC_FROM_HERE, [this] { return GetCallStats(); }); } RTC_DCHECK_RUN_ON(worker_thread()); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; @@ -2221,6 +2397,16 @@ bool PeerConnection::SetupDataChannelTransport_n(const std::string& mid) { data_channel_controller_.set_data_channel_transport(transport); data_channel_controller_.SetupDataChannelTransport_n(); sctp_mid_n_ = mid; + cricket::DtlsTransportInternal* dtls_transport = + transport_controller_->GetDtlsTransport(mid); + if (dtls_transport) { + signaling_thread()->PostTask( + ToQueuedTask(signaling_thread_safety_.flag(), + [this, name = dtls_transport->transport_name()] { + RTC_DCHECK_RUN_ON(signaling_thread()); + sctp_transport_name_s_ = std::move(name); + })); + } // Note: setting the data sink and checking initial state must be done last, // after setting up the data channel. Setting the data sink may trigger @@ -2231,34 +2417,32 @@ bool PeerConnection::SetupDataChannelTransport_n(const std::string& mid) { } void PeerConnection::TeardownDataChannelTransport_n() { - if (!sctp_mid_n_ && !data_channel_controller_.data_channel_transport()) { - return; + if (sctp_mid_n_) { + // |sctp_mid_| may still be active through an SCTP transport. If not, unset + // it. + RTC_LOG(LS_INFO) << "Tearing down data channel transport for mid=" + << *sctp_mid_n_; + sctp_mid_n_.reset(); } - RTC_LOG(LS_INFO) << "Tearing down data channel transport for mid=" - << *sctp_mid_n_; - // |sctp_mid_| may still be active through an SCTP transport. If not, unset - // it. - sctp_mid_n_.reset(); data_channel_controller_.TeardownDataChannelTransport_n(); } // Returns false if bundle is enabled and rtcp_mux is disabled. 
-bool PeerConnection::ValidateBundleSettings(const SessionDescription* desc) { - bool bundle_enabled = desc->HasGroup(cricket::GROUP_TYPE_BUNDLE); - if (!bundle_enabled) +bool PeerConnection::ValidateBundleSettings( + const SessionDescription* desc, + const std::map& + bundle_groups_by_mid) { + if (bundle_groups_by_mid.empty()) return true; - const cricket::ContentGroup* bundle_group = - desc->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); - RTC_DCHECK(bundle_group != NULL); - const cricket::ContentInfos& contents = desc->contents(); for (cricket::ContentInfos::const_iterator citer = contents.begin(); citer != contents.end(); ++citer) { const cricket::ContentInfo* content = (&*citer); RTC_DCHECK(content != NULL); - if (bundle_group->HasContentName(content->name) && !content->rejected && + auto it = bundle_groups_by_mid.find(content->name); + if (it != bundle_groups_by_mid.end() && !content->rejected && content->type == MediaProtocolType::kRtp) { if (!HasRtcpMuxEnabled(content)) return false; @@ -2269,12 +2453,13 @@ bool PeerConnection::ValidateBundleSettings(const SessionDescription* desc) { } void PeerConnection::ReportSdpFormatReceived( - const SessionDescriptionInterface& remote_offer) { + const SessionDescriptionInterface& remote_description) { int num_audio_mlines = 0; int num_video_mlines = 0; int num_audio_tracks = 0; int num_video_tracks = 0; - for (const ContentInfo& content : remote_offer.description()->contents()) { + for (const ContentInfo& content : + remote_description.description()->contents()) { cricket::MediaType media_type = content.media_description()->type(); int num_tracks = std::max( 1, static_cast(content.media_description()->streams().size())); @@ -2294,7 +2479,7 @@ void PeerConnection::ReportSdpFormatReceived( } else if (num_audio_tracks > 0 || num_video_tracks > 0) { format = kSdpFormatReceivedSimple; } - switch (remote_offer.GetType()) { + switch (remote_description.GetType()) { case SdpType::kOffer: // Historically only offers were counted. RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.SdpFormatReceived", @@ -2306,11 +2491,57 @@ void PeerConnection::ReportSdpFormatReceived( break; default: RTC_LOG(LS_ERROR) << "Can not report SdpFormatReceived for " - << SdpTypeToString(remote_offer.GetType()); + << SdpTypeToString(remote_description.GetType()); break; } } +void PeerConnection::ReportSdpBundleUsage( + const SessionDescriptionInterface& remote_description) { + RTC_DCHECK_RUN_ON(signaling_thread()); + + bool using_bundle = + remote_description.description()->HasGroup(cricket::GROUP_TYPE_BUNDLE); + int num_audio_mlines = 0; + int num_video_mlines = 0; + int num_data_mlines = 0; + for (const ContentInfo& content : + remote_description.description()->contents()) { + cricket::MediaType media_type = content.media_description()->type(); + if (media_type == cricket::MEDIA_TYPE_AUDIO) { + num_audio_mlines += 1; + } else if (media_type == cricket::MEDIA_TYPE_VIDEO) { + num_video_mlines += 1; + } else if (media_type == cricket::MEDIA_TYPE_DATA) { + num_data_mlines += 1; + } + } + bool simple = num_audio_mlines <= 1 && num_video_mlines <= 1; + BundleUsage usage = kBundleUsageMax; + if (num_audio_mlines == 0 && num_video_mlines == 0) { + if (num_data_mlines > 0) { + usage = using_bundle ? kBundleUsageBundleDatachannelOnly + : kBundleUsageNoBundleDatachannelOnly; + } else { + usage = kBundleUsageEmpty; + } + } else if (configuration_.sdp_semantics == SdpSemantics::kPlanB) { + // In plan-b, simple/complex usage will not show up in the number of + // m-lines or BUNDLE. 
+ usage = using_bundle ? kBundleUsageBundlePlanB : kBundleUsageNoBundlePlanB; + } else { + if (simple) { + usage = + using_bundle ? kBundleUsageBundleSimple : kBundleUsageNoBundleSimple; + } else { + usage = using_bundle ? kBundleUsageBundleComplex + : kBundleUsageNoBundleComplex; + } + } + RTC_HISTOGRAM_ENUMERATION("WebRTC.PeerConnection.BundleUsage", usage, + kBundleUsageMax); +} + void PeerConnection::ReportIceCandidateCollected( const cricket::Candidate& candidate) { NoteUsageEvent(UsageEvent::CANDIDATE_COLLECTED); @@ -2330,11 +2561,70 @@ void PeerConnection::NoteUsageEvent(UsageEvent event) { usage_pattern_.NoteUsageEvent(event); } +// Asynchronously adds remote candidates on the network thread. +void PeerConnection::AddRemoteCandidate(const std::string& mid, + const cricket::Candidate& candidate) { + RTC_DCHECK_RUN_ON(signaling_thread()); + + network_thread()->PostTask(ToQueuedTask( + network_thread_safety_, [this, mid = mid, candidate = candidate] { + RTC_DCHECK_RUN_ON(network_thread()); + std::vector candidates = {candidate}; + RTCError error = + transport_controller_->AddRemoteCandidates(mid, candidates); + if (error.ok()) { + signaling_thread()->PostTask(ToQueuedTask( + signaling_thread_safety_.flag(), + [this, candidate = std::move(candidate)] { + ReportRemoteIceCandidateAdded(candidate); + // Candidates successfully submitted for checking. + if (ice_connection_state() == + PeerConnectionInterface::kIceConnectionNew || + ice_connection_state() == + PeerConnectionInterface::kIceConnectionDisconnected) { + // If state is New, then the session has just gotten its first + // remote ICE candidates, so go to Checking. If state is + // Disconnected, the session is re-using old candidates or + // receiving additional ones, so go to Checking. If state is + // Connected, stay Connected. + // TODO(bemasc): If state is Connected, and the new candidates + // are for a newly added transport, then the state actually + // _should_ move to checking. Add a way to distinguish that + // case. + SetIceConnectionState( + PeerConnectionInterface::kIceConnectionChecking); + } + // TODO(bemasc): If state is Completed, go back to Connected. + })); + } else { + RTC_LOG(LS_WARNING) << error.message(); + } + })); +} + void PeerConnection::ReportUsagePattern() const { usage_pattern_.ReportUsagePattern(observer_); } +void PeerConnection::ReportRemoteIceCandidateAdded( + const cricket::Candidate& candidate) { + RTC_DCHECK_RUN_ON(signaling_thread()); + + NoteUsageEvent(UsageEvent::REMOTE_CANDIDATE_ADDED); + + if (candidate.address().IsPrivateIP()) { + NoteUsageEvent(UsageEvent::REMOTE_PRIVATE_CANDIDATE_ADDED); + } + if (candidate.address().IsUnresolvedIP()) { + NoteUsageEvent(UsageEvent::REMOTE_MDNS_CANDIDATE_ADDED); + } + if (candidate.address().family() == AF_INET6) { + NoteUsageEvent(UsageEvent::REMOTE_IPV6_CANDIDATE_ADDED); + } +} + bool PeerConnection::SrtpRequired() const { + RTC_DCHECK_RUN_ON(signaling_thread()); return (dtls_enabled_ || sdp_handler_->webrtc_session_desc_factory()->SdesPolicy() == cricket::SEC_REQUIRED); @@ -2355,10 +2645,13 @@ void PeerConnection::OnTransportControllerGatheringState( } } +// Runs on network_thread(). 
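Annotation (not part of the patch): AddRemoteCandidate() above posts its work to the network thread with by-value captures, so the task owns its own copies, and then posts the outcome back to the signaling thread for reporting. A simplified sketch of that two-hop dispatch with lifetime guarding omitted (the real code additionally wraps both hops in ToQueuedTask with safety flags); Processor and its methods are invented for illustration.

// Illustrative sketch of the signaling -> network -> signaling round trip.
#include <string>

#include "rtc_base/location.h"
#include "rtc_base/thread.h"

class Processor {
 public:
  Processor(rtc::Thread* network, rtc::Thread* signaling)
      : network_(network), signaling_(signaling) {}

  // Called on the signaling thread.
  void Submit(const std::string& item) {
    // Capture |item| by value; the caller's string may be gone by the time
    // the task runs.
    network_->PostTask(RTC_FROM_HERE, [this, item] {
      const bool ok = ApplyOnNetwork(item);
      signaling_->PostTask(RTC_FROM_HERE,
                           [this, item, ok] { ReportOnSignaling(item, ok); });
    });
  }

 private:
  bool ApplyOnNetwork(const std::string&) { return true; }
  void ReportOnSignaling(const std::string&, bool) {}

  rtc::Thread* const network_;
  rtc::Thread* const signaling_;
};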
void PeerConnection::ReportTransportStats() { + TRACE_EVENT0("webrtc", "PeerConnection::ReportTransportStats"); + rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; std::map> media_types_by_transport_name; - for (const auto& transceiver : rtp_manager()->transceivers()->List()) { + for (const auto& transceiver : rtp_manager()->transceivers()->UnsafeList()) { if (transceiver->internal()->channel()) { const std::string& transport_name = transceiver->internal()->channel()->transport_name(); @@ -2366,15 +2659,14 @@ void PeerConnection::ReportTransportStats() { transceiver->media_type()); } } - if (rtp_data_channel()) { - media_types_by_transport_name[rtp_data_channel()->transport_name()].insert( - cricket::MEDIA_TYPE_DATA); - } - absl::optional transport_name = sctp_transport_name(); - if (transport_name) { - media_types_by_transport_name[*transport_name].insert( - cricket::MEDIA_TYPE_DATA); + if (sctp_mid_n_) { + cricket::DtlsTransportInternal* dtls_transport = + transport_controller_->GetDtlsTransport(*sctp_mid_n_); + if (dtls_transport) { + media_types_by_transport_name[dtls_transport->transport_name()].insert( + cricket::MEDIA_TYPE_DATA); + } } for (const auto& entry : media_types_by_transport_name) { @@ -2383,12 +2675,14 @@ void PeerConnection::ReportTransportStats() { cricket::TransportStats stats; if (transport_controller_->GetStats(transport_name, &stats)) { ReportBestConnectionState(stats); - ReportNegotiatedCiphers(stats, media_types); + ReportNegotiatedCiphers(dtls_enabled_, stats, media_types); } } } + // Walk through the ConnectionInfos to gather best connection usage // for IPv4 and IPv6. +// static (no member state required) void PeerConnection::ReportBestConnectionState( const cricket::TransportStats& stats) { for (const cricket::TransportChannelStats& channel_stats : @@ -2436,10 +2730,12 @@ void PeerConnection::ReportBestConnectionState( } } +// static void PeerConnection::ReportNegotiatedCiphers( + bool dtls_enabled, const cricket::TransportStats& stats, const std::set& media_types) { - if (!dtls_enabled_ || stats.channel_stats.empty()) { + if (!dtls_enabled || stats.channel_stats.empty()) { return; } @@ -2501,12 +2797,6 @@ void PeerConnection::ReportNegotiatedCiphers( } } -void PeerConnection::OnSentPacket_w(const rtc::SentPacket& sent_packet) { - RTC_DCHECK_RUN_ON(worker_thread()); - RTC_DCHECK(call_); - call_->OnSentPacket(sent_packet); -} - bool PeerConnection::OnTransportChanged( const std::string& mid, RtpTransportInternal* rtp_transport, @@ -2518,9 +2808,19 @@ bool PeerConnection::OnTransportChanged( if (base_channel) { ret = base_channel->SetRtpTransport(rtp_transport); } + if (mid == sctp_mid_n_) { data_channel_controller_.OnTransportChanged(data_channel_transport); + if (dtls_transport) { + signaling_thread()->PostTask(ToQueuedTask( + signaling_thread_safety_.flag(), + [this, name = dtls_transport->internal()->transport_name()] { + RTC_DCHECK_RUN_ON(signaling_thread()); + sctp_transport_name_s_ = std::move(name); + })); + } } + return ret; } @@ -2530,6 +2830,23 @@ PeerConnectionObserver* PeerConnection::Observer() const { return observer_; } +void PeerConnection::StartSctpTransport(int local_port, + int remote_port, + int max_message_size) { + RTC_DCHECK_RUN_ON(signaling_thread()); + if (!sctp_mid_s_) + return; + + network_thread()->PostTask(ToQueuedTask( + network_thread_safety_, + [this, mid = *sctp_mid_s_, local_port, remote_port, max_message_size] { + rtc::scoped_refptr sctp_transport = + transport_controller()->GetSctpTransport(mid); + if 
(sctp_transport) + sctp_transport->Start(local_port, remote_port, max_message_size); + })); +} + CryptoOptions PeerConnection::GetCryptoOptions() { RTC_DCHECK_RUN_ON(signaling_thread()); // TODO(bugs.webrtc.org/9891) - Remove PeerConnectionFactory::CryptoOptions @@ -2563,34 +2880,11 @@ void PeerConnection::RequestUsagePatternReportForTesting() { std::function PeerConnection::InitializeRtcpCallback() { - RTC_DCHECK_RUN_ON(signaling_thread()); - - auto flag = - worker_thread()->Invoke>( - RTC_FROM_HERE, [this] { - RTC_DCHECK_RUN_ON(worker_thread()); - if (!call_) - return rtc::scoped_refptr(); - if (!call_safety_) - call_safety_.reset(new ScopedTaskSafety()); - return call_safety_->flag(); - }); - - if (!flag) - return [](const rtc::CopyOnWriteBuffer&, int64_t) {}; - - return [this, flag = std::move(flag)](const rtc::CopyOnWriteBuffer& packet, - int64_t packet_time_us) { + RTC_DCHECK_RUN_ON(network_thread()); + return [this](const rtc::CopyOnWriteBuffer& packet, int64_t packet_time_us) { RTC_DCHECK_RUN_ON(network_thread()); - // TODO(bugs.webrtc.org/11993): We should actually be delivering this call - // directly to the Call class somehow directly on the network thread and not - // incur this hop here. The DeliverPacket() method will eventually just have - // to hop back over to the network thread. - worker_thread()->PostTask(ToQueuedTask(flag, [this, packet, - packet_time_us] { - RTC_DCHECK_RUN_ON(worker_thread()); - call_->Receiver()->DeliverPacket(MediaType::ANY, packet, packet_time_us); - })); + call_ptr_->Receiver()->DeliverPacket(MediaType::ANY, packet, + packet_time_us); }; } diff --git a/pc/peer_connection.h b/pc/peer_connection.h index 8768ebb133..4476c5d8e1 100644 --- a/pc/peer_connection.h +++ b/pc/peer_connection.h @@ -12,6 +12,7 @@ #define PC_PEER_CONNECTION_H_ #include + #include #include #include @@ -22,6 +23,7 @@ #include "absl/types/optional.h" #include "api/adaptation/resource.h" +#include "api/async_dns_resolver.h" #include "api/async_resolver_factory.h" #include "api/audio_options.h" #include "api/candidate.h" @@ -43,6 +45,7 @@ #include "api/rtp_transceiver_interface.h" #include "api/scoped_refptr.h" #include "api/sctp_transport_interface.h" +#include "api/sequence_checker.h" #include "api/set_local_description_observer_interface.h" #include "api/set_remote_description_observer_interface.h" #include "api/stats/rtc_stats_collector_callback.h" @@ -69,7 +72,6 @@ #include "pc/peer_connection_internal.h" #include "pc/peer_connection_message_handler.h" #include "pc/rtc_stats_collector.h" -#include "pc/rtp_data_channel.h" #include "pc/rtp_receiver.h" #include "pc/rtp_sender.h" #include "pc/rtp_transceiver.h" @@ -86,17 +88,16 @@ #include "pc/usage_pattern.h" #include "rtc_base/checks.h" #include "rtc_base/copy_on_write_buffer.h" -#include "rtc_base/deprecation.h" #include "rtc_base/network/sent_packet.h" #include "rtc_base/rtc_certificate.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_stream_adapter.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" #include "rtc_base/unique_id_generator.h" +#include "rtc_base/weak_ptr.h" namespace webrtc { @@ -166,7 +167,7 @@ class PeerConnection : public PeerConnectionInternal, std::vector> GetTransceivers() const override; - rtc::scoped_refptr CreateDataChannel( + RTCErrorOr> CreateDataChannelOrError( const std::string& label, const 
DataChannelInit* config) override; // WARNING: LEGACY. See peerconnectioninterface.h @@ -271,7 +272,6 @@ class PeerConnection : public PeerConnectionInternal, rtc::Thread* worker_thread() const final { return context_->worker_thread(); } std::string session_id() const override { - RTC_DCHECK_RUN_ON(signaling_thread()); return session_id_; } @@ -287,24 +287,16 @@ class PeerConnection : public PeerConnectionInternal, return rtp_manager()->transceivers()->List(); } - sigslot::signal1& SignalRtpDataChannelCreated() override { - return data_channel_controller_.SignalRtpDataChannelCreated(); - } - sigslot::signal1& SignalSctpDataChannelCreated() override { return data_channel_controller_.SignalSctpDataChannelCreated(); } - cricket::RtpDataChannel* rtp_data_channel() const override { - return data_channel_controller_.rtp_data_channel(); - } - std::vector GetDataChannelStats() const override; absl::optional sctp_transport_name() const override; + absl::optional sctp_mid() const override; cricket::CandidateStatsList GetPooledCandidateStats() const override; - std::map GetTransportNamesByMid() const override; std::map GetTransportStatsByNames( const std::set& transport_names) override; Call::Stats GetCallStats() override; @@ -324,7 +316,8 @@ class PeerConnection : public PeerConnectionInternal, PeerConnectionObserver* Observer() const; bool IsClosed() const { RTC_DCHECK_RUN_ON(signaling_thread()); - return sdp_handler_->signaling_state() == PeerConnectionInterface::kClosed; + return !sdp_handler_ || + sdp_handler_->signaling_state() == PeerConnectionInterface::kClosed; } // Get current SSL role used by SCTP's underlying transport. bool GetSctpSslRole(rtc::SSLRole* role); @@ -350,10 +343,6 @@ class PeerConnection : public PeerConnectionInternal, RTC_DCHECK_RUN_ON(signaling_thread()); return &configuration_; } - absl::optional sctp_mid() { - RTC_DCHECK_RUN_ON(signaling_thread()); - return sctp_mid_s_; - } PeerConnectionMessageHandler* message_handler() { RTC_DCHECK_RUN_ON(signaling_thread()); return &message_handler_; @@ -375,12 +364,20 @@ class PeerConnection : public PeerConnectionInternal, const PeerConnectionFactoryInterface::Options* options() const { return &options_; } - cricket::DataChannelType data_channel_type() const; void SetIceConnectionState(IceConnectionState new_state); void NoteUsageEvent(UsageEvent event); - // Report the UMA metric SdpFormatReceived for the given remote offer. - void ReportSdpFormatReceived(const SessionDescriptionInterface& remote_offer); + // Asynchronously adds a remote candidate on the network thread. + void AddRemoteCandidate(const std::string& mid, + const cricket::Candidate& candidate); + + // Report the UMA metric SdpFormatReceived for the given remote description. + void ReportSdpFormatReceived( + const SessionDescriptionInterface& remote_description); + + // Report the UMA metric BundleUsage for the given remote description. 
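Editor's note: the hunk above replaces the scoped_refptr-returning CreateDataChannel() with CreateDataChannelOrError(), which returns an RTCErrorOr. The sketch below is not part of the patch; it only shows how a caller would consume such a result, assuming the element type is rtc::scoped_refptr<DataChannelInterface> and that the virtual is exposed on PeerConnectionInterface (as the `override` suggests). OpenChatChannel() is a hypothetical helper.

#include "api/peer_connection_interface.h"
#include "rtc_base/logging.h"

rtc::scoped_refptr<webrtc::DataChannelInterface> OpenChatChannel(
    webrtc::PeerConnectionInterface* pc) {
  webrtc::DataChannelInit config;
  config.ordered = true;
  webrtc::RTCErrorOr<rtc::scoped_refptr<webrtc::DataChannelInterface>> result =
      pc->CreateDataChannelOrError("chat", &config);
  if (!result.ok()) {
    // The error object carries the failure reason instead of a bare nullptr.
    RTC_LOG(LS_ERROR) << "CreateDataChannelOrError failed: "
                      << result.error().message();
    return nullptr;
  }
  return result.MoveValue();
}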
+ void ReportSdpBundleUsage( + const SessionDescriptionInterface& remote_description); // Returns true if the PeerConnection is configured to use Unified Plan // semantics for creating offers/answers and setting local/remote @@ -392,21 +389,25 @@ class PeerConnection : public PeerConnectionInternal, RTC_DCHECK_RUN_ON(signaling_thread()); return is_unified_plan_; } - bool ValidateBundleSettings(const cricket::SessionDescription* desc); + bool ValidateBundleSettings( + const cricket::SessionDescription* desc, + const std::map& + bundle_groups_by_mid); - // Returns the MID for the data section associated with either the - // RtpDataChannel or SCTP data channel, if it has been set. If no data + // Returns the MID for the data section associated with the + // SCTP data channel, if it has been set. If no data // channels are configured this will return nullopt. absl::optional GetDataMid() const; - void SetSctpDataMid(const std::string& mid) { - RTC_DCHECK_RUN_ON(signaling_thread()); - sctp_mid_s_ = mid; - } - void ResetSctpDataMid() { - RTC_DCHECK_RUN_ON(signaling_thread()); - sctp_mid_s_.reset(); - } + void SetSctpDataMid(const std::string& mid); + + void ResetSctpDataMid(); + + // Asynchronously calls SctpTransport::Start() on the network thread for + // |sctp_mid()| if set. Called as part of setting the local description. + void StartSctpTransport(int local_port, + int remote_port, + int max_message_size); // Returns the CryptoOptions for this PeerConnection. This will always // return the RTCConfiguration.crypto_options if set and will only default @@ -422,23 +423,17 @@ class PeerConnection : public PeerConnectionInternal, bool fire_callback = true); // Returns rtp transport, result can not be nullptr. - RtpTransportInternal* GetRtpTransport(const std::string& mid) { - RTC_DCHECK_RUN_ON(signaling_thread()); - auto rtp_transport = transport_controller_->GetRtpTransport(mid); - RTC_DCHECK(rtp_transport); - return rtp_transport; - } + RtpTransportInternal* GetRtpTransport(const std::string& mid); // Returns true if SRTP (either using DTLS-SRTP or SDES) is required by // this session. - bool SrtpRequired() const RTC_RUN_ON(signaling_thread()); - - void OnSentPacket_w(const rtc::SentPacket& sent_packet); + bool SrtpRequired() const; bool SetupDataChannelTransport_n(const std::string& mid) RTC_RUN_ON(network_thread()); void TeardownDataChannelTransport_n() RTC_RUN_ON(network_thread()); - cricket::ChannelInterface* GetChannel(const std::string& content_name); + cricket::ChannelInterface* GetChannel(const std::string& content_name) + RTC_RUN_ON(network_thread()); // Functions made public for testing. 
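Editor's note: the header hunks above lean heavily on the RTC_RUN_ON and RTC_GUARDED_BY annotations. As a quick reference, here is a minimal self-contained sketch of that pattern; it is not taken from the patch, StatsCache and its members are illustrative names only, and the header paths follow the includes used elsewhere in this change.

#include <string>

#include "api/sequence_checker.h"
#include "rtc_base/thread.h"
#include "rtc_base/thread_annotations.h"

class StatsCache {
 public:
  explicit StatsCache(rtc::Thread* network_thread)
      : network_thread_(network_thread) {}

  // RTC_RUN_ON lets clang's thread-safety analysis verify that every caller
  // already runs on the network thread.
  void set_transport_name(const std::string& name)
      RTC_RUN_ON(network_thread_) {
    transport_name_ = name;
  }

  std::string transport_name() const RTC_RUN_ON(network_thread_) {
    return transport_name_;
  }

 private:
  rtc::Thread* const network_thread_;
  // RTC_GUARDED_BY ties the member to that thread, so an unannotated access
  // is flagged at compile time when -Wthread-safety is enabled.
  std::string transport_name_ RTC_GUARDED_BY(network_thread_);
};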
void ReturnHistogramVeryQuicklyForTesting() { @@ -454,7 +449,8 @@ class PeerConnection : public PeerConnectionInternal, bool is_unified_plan, std::unique_ptr event_log, std::unique_ptr call, - PeerConnectionDependencies& dependencies); + PeerConnectionDependencies& dependencies, + bool dtls_enabled); ~PeerConnection() override; @@ -462,6 +458,10 @@ class PeerConnection : public PeerConnectionInternal, RTCError Initialize( const PeerConnectionInterface::RTCConfiguration& configuration, PeerConnectionDependencies dependencies); + void InitializeTransportController_n( + const RTCConfiguration& configuration, + const PeerConnectionDependencies& dependencies) + RTC_RUN_ON(network_thread()); rtc::scoped_refptr> FindTransceiverBySender(rtc::scoped_refptr sender) @@ -495,10 +495,8 @@ class PeerConnection : public PeerConnectionInternal, const cricket::CandidatePairChangeEvent& event) RTC_RUN_ON(signaling_thread()); - void OnNegotiationNeeded(); - // Returns the specified SCTP DataChannel in sctp_data_channels_, // or nullptr if not found. SctpDataChannel* FindDataChannelBySid(int sid) const @@ -569,19 +567,22 @@ class PeerConnection : public PeerConnectionInternal, // Invoked when TransportController connection completion is signaled. // Reports stats for all transports in use. - void ReportTransportStats() RTC_RUN_ON(signaling_thread()); + void ReportTransportStats() RTC_RUN_ON(network_thread()); // Gather the usage of IPv4/IPv6 as best connection. - void ReportBestConnectionState(const cricket::TransportStats& stats); + static void ReportBestConnectionState(const cricket::TransportStats& stats); - void ReportNegotiatedCiphers(const cricket::TransportStats& stats, - const std::set& media_types) - RTC_RUN_ON(signaling_thread()); + static void ReportNegotiatedCiphers( + bool dtls_enabled, + const cricket::TransportStats& stats, + const std::set& media_types); void ReportIceCandidateCollected(const cricket::Candidate& candidate) RTC_RUN_ON(signaling_thread()); void ReportUsagePattern() const RTC_RUN_ON(signaling_thread()); + void ReportRemoteIceCandidateAdded(const cricket::Candidate& candidate); + // JsepTransportController::Observer override. // // Called by |transport_controller_| when processing transport information @@ -624,10 +625,8 @@ class PeerConnection : public PeerConnectionInternal, PeerConnectionInterface::RTCConfiguration configuration_ RTC_GUARDED_BY(signaling_thread()); - // TODO(zstein): |async_resolver_factory_| can currently be nullptr if it - // is not injected. It should be required once chromium supplies it. - const std::unique_ptr async_resolver_factory_ - RTC_GUARDED_BY(signaling_thread()); + const std::unique_ptr + async_dns_resolver_factory_; std::unique_ptr port_allocator_; // TODO(bugs.webrtc.org/9987): Accessed on both // signaling and network thread. @@ -643,8 +642,9 @@ class PeerConnection : public PeerConnectionInternal, // The unique_ptr belongs to the worker thread, but the Call object manages // its own thread safety. std::unique_ptr call_ RTC_GUARDED_BY(worker_thread()); - std::unique_ptr call_safety_ - RTC_GUARDED_BY(worker_thread()); + ScopedTaskSafety signaling_thread_safety_; + rtc::scoped_refptr network_thread_safety_; + rtc::scoped_refptr worker_thread_safety_; // Points to the same thing as `call_`. Since it's const, we may read the // pointer from any thread. 
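Editor's note: the member changes above replace the ad-hoc call_safety_ pointer with ScopedTaskSafety / PendingTaskSafetyFlag members, matching the PostTask(ToQueuedTask(...)) calls in the .cc hunks earlier in this patch. Below is a minimal sketch of that cancellation pattern, not part of the patch; ExampleObserver and target_thread are illustrative names.

#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread.h"

class ExampleObserver {
 public:
  explicit ExampleObserver(rtc::Thread* target_thread)
      : target_thread_(target_thread) {}

  void PostUpdate(int value) {
    // The flag is invalidated when `safety_` (and therefore `this`) is
    // destroyed, so a task that is still queued is simply dropped instead of
    // dereferencing a dead object.
    target_thread_->PostTask(webrtc::ToQueuedTask(
        safety_.flag(), [this, value]() { last_value_ = value; }));
  }

 private:
  rtc::Thread* const target_thread_;
  int last_value_ = 0;
  // Declared last so it is destroyed first, invalidating pending tasks
  // before the rest of the object goes away.
  webrtc::ScopedTaskSafety safety_;
};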
@@ -657,7 +657,7 @@ class PeerConnection : public PeerConnectionInternal, rtc::scoped_refptr stats_collector_ RTC_GUARDED_BY(signaling_thread()); - std::string session_id_ RTC_GUARDED_BY(signaling_thread()); + const std::string session_id_; std::unique_ptr transport_controller_; // TODO(bugs.webrtc.org/9987): Accessed on both @@ -672,12 +672,13 @@ class PeerConnection : public PeerConnectionInternal, // thread, but applied first on the networking thread via an invoke(). absl::optional sctp_mid_s_ RTC_GUARDED_BY(signaling_thread()); absl::optional sctp_mid_n_ RTC_GUARDED_BY(network_thread()); + std::string sctp_transport_name_s_ RTC_GUARDED_BY(signaling_thread()); // The machinery for handling offers and answers. Const after initialization. std::unique_ptr sdp_handler_ RTC_GUARDED_BY(signaling_thread()); - bool dtls_enabled_ RTC_GUARDED_BY(signaling_thread()) = false; + const bool dtls_enabled_; UsagePattern usage_pattern_ RTC_GUARDED_BY(signaling_thread()); bool return_histogram_very_quickly_ RTC_GUARDED_BY(signaling_thread()) = @@ -691,6 +692,12 @@ class PeerConnection : public PeerConnectionInternal, // Administration of senders, receivers and transceivers // Accessed on both signaling and network thread. Const after Initialize(). std::unique_ptr rtp_manager_; + + rtc::WeakPtrFactory weak_factory_; + + // Did the connectionState ever change to `connected`? + // Used to gather metrics only the first such state change. + bool was_ever_connected_ RTC_GUARDED_BY(signaling_thread()) = false; }; } // namespace webrtc diff --git a/pc/peer_connection_adaptation_integrationtest.cc b/pc/peer_connection_adaptation_integrationtest.cc index 71d054eb90..dfb12971b4 100644 --- a/pc/peer_connection_adaptation_integrationtest.cc +++ b/pc/peer_connection_adaptation_integrationtest.cc @@ -50,7 +50,7 @@ TrackWithPeriodicSource CreateTrackWithPeriodicSource( periodic_track_source_config.frame_interval_ms = 100; periodic_track_source_config.timestamp_offset_ms = rtc::TimeMillis(); rtc::scoped_refptr periodic_track_source = - new rtc::RefCountedObject( + rtc::make_ref_counted( periodic_track_source_config, /* remote */ false); TrackWithPeriodicSource track_with_source; track_with_source.track = @@ -83,7 +83,7 @@ class PeerConnectionAdaptationIntegrationTest : public ::testing::Test { rtc::scoped_refptr CreatePcWrapper( const char* name) { rtc::scoped_refptr pc_wrapper = - new rtc::RefCountedObject( + rtc::make_ref_counted( name, network_thread_.get(), worker_thread_.get()); PeerConnectionInterface::RTCConfiguration config; config.sdp_semantics = SdpSemantics::kUnifiedPlan; diff --git a/pc/peer_connection_bundle_unittest.cc b/pc/peer_connection_bundle_unittest.cc index c544db396f..08754c6820 100644 --- a/pc/peer_connection_bundle_unittest.cc +++ b/pc/peer_connection_bundle_unittest.cc @@ -13,7 +13,6 @@ #include "api/audio_codecs/builtin_audio_decoder_factory.h" #include "api/audio_codecs/builtin_audio_encoder_factory.h" #include "api/create_peerconnection_factory.h" -#include "api/peer_connection_proxy.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" #include "p2p/base/fake_port_allocator.h" @@ -21,6 +20,7 @@ #include "p2p/client/basic_port_allocator.h" #include "pc/media_session.h" #include "pc/peer_connection.h" +#include "pc/peer_connection_proxy.h" #include "pc/peer_connection_wrapper.h" #include "pc/sdp_utils.h" #ifdef WEBRTC_ANDROID @@ -753,11 +753,9 @@ TEST_P(PeerConnectionBundleTest, RejectDescriptionChangingBundleTag) { // This 
tests that removing contents from BUNDLE group and reject the whole // BUNDLE group could work. This is a regression test for // (https://bugs.chromium.org/p/chromium/issues/detail?id=827917) +#ifdef HAVE_SCTP TEST_P(PeerConnectionBundleTest, RemovingContentAndRejectBundleGroup) { RTCConfiguration config; -#ifndef HAVE_SCTP - config.enable_rtp_data_channel = true; -#endif config.bundle_policy = BundlePolicy::kBundlePolicyMaxBundle; auto caller = CreatePeerConnectionWithAudioVideo(config); caller->CreateDataChannel("dc"); @@ -782,6 +780,7 @@ TEST_P(PeerConnectionBundleTest, RemovingContentAndRejectBundleGroup) { EXPECT_TRUE(caller->SetLocalDescription(std::move(re_offer))); } +#endif // This tests that the BUNDLE group in answer should be a subset of the offered // group. @@ -887,4 +886,56 @@ TEST_F(PeerConnectionBundleTestUnifiedPlan, EXPECT_TRUE(bundle_group->content_names().empty()); } +TEST_F(PeerConnectionBundleTestUnifiedPlan, MultipleBundleGroups) { + auto caller = CreatePeerConnection(); + caller->AddAudioTrack("0_audio"); + caller->AddAudioTrack("1_audio"); + caller->AddVideoTrack("2_audio"); + caller->AddVideoTrack("3_audio"); + auto callee = CreatePeerConnection(); + + auto offer = caller->CreateOffer(RTCOfferAnswerOptions()); + // Modify the GROUP to have two BUNDLEs. We know that the MIDs will be 0,1,2,4 + // because our implementation has predictable MIDs. + offer->description()->RemoveGroupByName(cricket::GROUP_TYPE_BUNDLE); + cricket::ContentGroup bundle_group1(cricket::GROUP_TYPE_BUNDLE); + bundle_group1.AddContentName("0"); + bundle_group1.AddContentName("1"); + cricket::ContentGroup bundle_group2(cricket::GROUP_TYPE_BUNDLE); + bundle_group2.AddContentName("2"); + bundle_group2.AddContentName("3"); + offer->description()->AddGroup(bundle_group1); + offer->description()->AddGroup(bundle_group2); + + EXPECT_TRUE( + caller->SetLocalDescription(CloneSessionDescription(offer.get()))); + callee->SetRemoteDescription(std::move(offer)); + auto answer = callee->CreateAnswer(); + EXPECT_TRUE( + callee->SetLocalDescription(CloneSessionDescription(answer.get()))); + caller->SetRemoteDescription(std::move(answer)); + + // Verify bundling on sender side. + auto senders = caller->pc()->GetSenders(); + ASSERT_EQ(senders.size(), 4u); + auto sender0_transport = senders[0]->dtls_transport(); + auto sender1_transport = senders[1]->dtls_transport(); + auto sender2_transport = senders[2]->dtls_transport(); + auto sender3_transport = senders[3]->dtls_transport(); + EXPECT_EQ(sender0_transport, sender1_transport); + EXPECT_EQ(sender2_transport, sender3_transport); + EXPECT_NE(sender0_transport, sender2_transport); + + // Verify bundling on receiver side. 
+ auto receivers = callee->pc()->GetReceivers(); + ASSERT_EQ(receivers.size(), 4u); + auto receiver0_transport = receivers[0]->dtls_transport(); + auto receiver1_transport = receivers[1]->dtls_transport(); + auto receiver2_transport = receivers[2]->dtls_transport(); + auto receiver3_transport = receivers[3]->dtls_transport(); + EXPECT_EQ(receiver0_transport, receiver1_transport); + EXPECT_EQ(receiver2_transport, receiver3_transport); + EXPECT_NE(receiver0_transport, receiver2_transport); +} + } // namespace webrtc diff --git a/pc/peer_connection_crypto_unittest.cc b/pc/peer_connection_crypto_unittest.cc index 32e8cbd74c..394203cb02 100644 --- a/pc/peer_connection_crypto_unittest.cc +++ b/pc/peer_connection_crypto_unittest.cc @@ -631,7 +631,7 @@ TEST_P(PeerConnectionCryptoDtlsCertGenTest, TestCertificateGeneration) { observers; for (size_t i = 0; i < concurrent_calls_; i++) { rtc::scoped_refptr observer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); observers.push_back(observer); if (sdp_type_ == SdpType::kOffer) { pc->pc()->CreateOffer(observer, diff --git a/pc/peer_connection_data_channel_unittest.cc b/pc/peer_connection_data_channel_unittest.cc index 6c51f01594..2544473536 100644 --- a/pc/peer_connection_data_channel_unittest.cc +++ b/pc/peer_connection_data_channel_unittest.cc @@ -19,7 +19,6 @@ #include "api/jsep.h" #include "api/media_types.h" #include "api/peer_connection_interface.h" -#include "api/peer_connection_proxy.h" #include "api/scoped_refptr.h" #include "api/task_queue/default_task_queue_factory.h" #include "media/base/codec.h" @@ -32,6 +31,7 @@ #include "pc/media_session.h" #include "pc/peer_connection.h" #include "pc/peer_connection_factory.h" +#include "pc/peer_connection_proxy.h" #include "pc/peer_connection_wrapper.h" #include "pc/sdp_utils.h" #include "pc/session_description.h" @@ -193,28 +193,6 @@ class PeerConnectionDataChannelUnifiedPlanTest : PeerConnectionDataChannelBaseTest(SdpSemantics::kUnifiedPlan) {} }; -TEST_P(PeerConnectionDataChannelTest, - NoSctpTransportCreatedIfRtpDataChannelEnabled) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - auto caller = CreatePeerConnectionWithDataChannel(config); - - ASSERT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); - EXPECT_FALSE(caller->sctp_transport_factory()->last_fake_sctp_transport()); -} - -TEST_P(PeerConnectionDataChannelTest, - RtpDataChannelCreatedEvenIfSctpAvailable) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - PeerConnectionFactoryInterface::Options options; - options.disable_sctp_data_channels = false; - auto caller = CreatePeerConnectionWithDataChannel(config, options); - - ASSERT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); - EXPECT_FALSE(caller->sctp_transport_factory()->last_fake_sctp_transport()); -} - TEST_P(PeerConnectionDataChannelTest, InternalSctpTransportDeletedOnTeardown) { auto caller = CreatePeerConnectionWithDataChannel(); @@ -311,34 +289,6 @@ TEST_P(PeerConnectionDataChannelTest, EXPECT_TRUE(caller->pc()->CreateDataChannel("dc", nullptr)); } -TEST_P(PeerConnectionDataChannelTest, CreateDataChannelWithSctpDisabledFails) { - PeerConnectionFactoryInterface::Options options; - options.disable_sctp_data_channels = true; - auto caller = CreatePeerConnection(RTCConfiguration(), options); - - EXPECT_FALSE(caller->pc()->CreateDataChannel("dc", nullptr)); -} - -// Test that if a callee has SCTP disabled and receives an offer with an SCTP -// data channel, the data section is rejected and no SCTP transport is created 
-// on the callee. -TEST_P(PeerConnectionDataChannelTest, - DataSectionRejectedIfCalleeHasSctpDisabled) { - auto caller = CreatePeerConnectionWithDataChannel(); - PeerConnectionFactoryInterface::Options options; - options.disable_sctp_data_channels = true; - auto callee = CreatePeerConnection(RTCConfiguration(), options); - - ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); - - EXPECT_FALSE(callee->sctp_transport_factory()->last_fake_sctp_transport()); - - auto answer = callee->CreateAnswer(); - auto* data_content = cricket::GetFirstDataContent(answer->description()); - ASSERT_TRUE(data_content); - EXPECT_TRUE(data_content->rejected); -} - TEST_P(PeerConnectionDataChannelTest, SctpPortPropagatedFromSdpToTransport) { constexpr int kNewSendPort = 9998; constexpr int kNewRecvPort = 7775; @@ -352,8 +302,9 @@ TEST_P(PeerConnectionDataChannelTest, SctpPortPropagatedFromSdpToTransport) { auto answer = callee->CreateAnswer(); ChangeSctpPortOnDescription(answer->description(), kNewRecvPort); + std::string sdp; + answer->ToString(&sdp); ASSERT_TRUE(callee->SetLocalDescription(std::move(answer))); - auto* callee_transport = callee->sctp_transport_factory()->last_fake_sctp_transport(); ASSERT_TRUE(callee_transport); @@ -392,28 +343,4 @@ INSTANTIATE_TEST_SUITE_P(PeerConnectionDataChannelTest, Values(SdpSemantics::kPlanB, SdpSemantics::kUnifiedPlan)); -TEST_F(PeerConnectionDataChannelUnifiedPlanTest, - ReOfferAfterPeerRejectsDataChannel) { - auto caller = CreatePeerConnectionWithDataChannel(); - PeerConnectionFactoryInterface::Options options; - options.disable_sctp_data_channels = true; - auto callee = CreatePeerConnection(RTCConfiguration(), options); - - ASSERT_TRUE(caller->ExchangeOfferAnswerWith(callee.get())); - - auto offer = caller->CreateOffer(); - ASSERT_TRUE(offer); - const auto& contents = offer->description()->contents(); - ASSERT_EQ(1u, contents.size()); - EXPECT_TRUE(contents[0].rejected); - - ASSERT_TRUE( - caller->SetLocalDescription(CloneSessionDescription(offer.get()))); - ASSERT_TRUE(callee->SetRemoteDescription(std::move(offer))); - - auto answer = callee->CreateAnswerAndSetAsLocal(); - ASSERT_TRUE(answer); - EXPECT_TRUE(caller->SetRemoteDescription(std::move(answer))); -} - } // namespace webrtc diff --git a/pc/peer_connection_end_to_end_unittest.cc b/pc/peer_connection_end_to_end_unittest.cc index 24ef69c111..b29371c59b 100644 --- a/pc/peer_connection_end_to_end_unittest.cc +++ b/pc/peer_connection_end_to_end_unittest.cc @@ -465,7 +465,7 @@ TEST_P(PeerConnectionEndToEndTest, CallWithCustomCodec) { EXPECT_NE(encoder_id1, encoder_id2); } -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP // Verifies that a DataChannel created before the negotiation can transition to // "OPEN" and transfer data. 
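Editor's note: the surrounding hunks only rename the HAVE_SCTP guard to WEBRTC_HAVE_SCTP. The short sketch below is not from the patch; MaybeCreateDataChannel is a hypothetical helper that just illustrates the guard's intended use in code that may be built without SCTP support.

#include "api/peer_connection_interface.h"

rtc::scoped_refptr<webrtc::DataChannelInterface> MaybeCreateDataChannel(
    webrtc::PeerConnectionInterface* pc) {
#ifdef WEBRTC_HAVE_SCTP
  // SCTP data channels are compiled in, so negotiating one is fine.
  return pc->CreateDataChannel("probe", nullptr);
#else
  // Built without SCTP support; the data-channel paths are compiled out.
  return nullptr;
#endif
}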
TEST_P(PeerConnectionEndToEndTest, CreateDataChannelBeforeNegotiate) { @@ -735,7 +735,7 @@ TEST_P(PeerConnectionEndToEndTest, TooManyDataChannelsOpenedBeforeConnecting) { channels[cricket::kMaxSctpStreams / 2]->state()); } -#endif // HAVE_SCTP +#endif // WEBRTC_HAVE_SCTP TEST_P(PeerConnectionEndToEndTest, CanRestartIce) { rtc::scoped_refptr real_decoder_factory = diff --git a/pc/peer_connection_factory.cc b/pc/peer_connection_factory.cc index f4f72c75f8..50755a38c7 100644 --- a/pc/peer_connection_factory.cc +++ b/pc/peer_connection_factory.cc @@ -10,9 +10,7 @@ #include "pc/peer_connection_factory.h" -#include #include -#include #include #include "absl/strings/match.h" @@ -20,29 +18,31 @@ #include "api/call/call_factory_interface.h" #include "api/fec_controller.h" #include "api/ice_transport_interface.h" -#include "api/media_stream_proxy.h" -#include "api/media_stream_track_proxy.h" #include "api/network_state_predictor.h" #include "api/packet_socket_factory.h" -#include "api/peer_connection_factory_proxy.h" -#include "api/peer_connection_proxy.h" #include "api/rtc_event_log/rtc_event_log.h" +#include "api/sequence_checker.h" #include "api/transport/bitrate_settings.h" #include "api/units/data_rate.h" #include "call/audio_state.h" +#include "call/rtp_transport_controller_send_factory.h" #include "media/base/media_engine.h" #include "p2p/base/basic_async_resolver_factory.h" #include "p2p/base/basic_packet_socket_factory.h" #include "p2p/base/default_ice_transport_factory.h" +#include "p2p/base/port_allocator.h" #include "p2p/client/basic_port_allocator.h" #include "pc/audio_track.h" #include "pc/local_audio_source.h" #include "pc/media_stream.h" +#include "pc/media_stream_proxy.h" +#include "pc/media_stream_track_proxy.h" #include "pc/peer_connection.h" +#include "pc/peer_connection_factory_proxy.h" +#include "pc/peer_connection_proxy.h" #include "pc/rtp_parameters_conversion.h" #include "pc/session_description.h" #include "pc/video_track.h" -#include "rtc_base/bind.h" #include "rtc_base/checks.h" #include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/experiments/field_trial_units.h" @@ -50,7 +50,7 @@ #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/ref_counted_object.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/rtc_certificate_generator.h" #include "rtc_base/system/file_wrapper.h" namespace webrtc { @@ -76,8 +76,8 @@ CreateModularPeerConnectionFactory( // Verify that the invocation and the initialization ended up agreeing on the // thread. RTC_DCHECK_RUN_ON(pc_factory->signaling_thread()); - return PeerConnectionFactoryProxy::Create(pc_factory->signaling_thread(), - pc_factory); + return PeerConnectionFactoryProxy::Create( + pc_factory->signaling_thread(), pc_factory->worker_thread(), pc_factory); } // Static @@ -87,8 +87,7 @@ rtc::scoped_refptr PeerConnectionFactory::Create( if (!context) { return nullptr; } - return new rtc::RefCountedObject(context, - &dependencies); + return rtc::make_ref_counted(context, &dependencies); } PeerConnectionFactory::PeerConnectionFactory( @@ -102,7 +101,11 @@ PeerConnectionFactory::PeerConnectionFactory( std::move(dependencies->network_state_predictor_factory)), injected_network_controller_factory_( std::move(dependencies->network_controller_factory)), - neteq_factory_(std::move(dependencies->neteq_factory)) {} + neteq_factory_(std::move(dependencies->neteq_factory)), + transport_controller_send_factory_( + (dependencies->transport_controller_send_factory) + ? 
std::move(dependencies->transport_controller_send_factory) + : std::make_unique()) {} PeerConnectionFactory::PeerConnectionFactory( PeerConnectionFactoryDependencies dependencies) @@ -141,6 +144,7 @@ RtpCapabilities PeerConnectionFactory::GetRtpSenderCapabilities( case cricket::MEDIA_TYPE_UNSUPPORTED: return RtpCapabilities(); } + RTC_DLOG(LS_ERROR) << "Got unexpected MediaType " << kind; RTC_CHECK_NOTREACHED(); } @@ -167,6 +171,7 @@ RtpCapabilities PeerConnectionFactory::GetRtpReceiverCapabilities( case cricket::MEDIA_TYPE_UNSUPPORTED: return RtpCapabilities(); } + RTC_DLOG(LS_ERROR) << "Got unexpected MediaType " << kind; RTC_CHECK_NOTREACHED(); } @@ -179,42 +184,15 @@ PeerConnectionFactory::CreateAudioSource(const cricket::AudioOptions& options) { } bool PeerConnectionFactory::StartAecDump(FILE* file, int64_t max_size_bytes) { - RTC_DCHECK(signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread()); return channel_manager()->StartAecDump(FileWrapper(file), max_size_bytes); } void PeerConnectionFactory::StopAecDump() { - RTC_DCHECK(signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(worker_thread()); channel_manager()->StopAecDump(); } -rtc::scoped_refptr -PeerConnectionFactory::CreatePeerConnection( - const PeerConnectionInterface::RTCConfiguration& configuration, - std::unique_ptr allocator, - std::unique_ptr cert_generator, - PeerConnectionObserver* observer) { - // Convert the legacy API into the new dependency structure. - PeerConnectionDependencies dependencies(observer); - dependencies.allocator = std::move(allocator); - dependencies.cert_generator = std::move(cert_generator); - // Pass that into the new API. - return CreatePeerConnection(configuration, std::move(dependencies)); -} - -rtc::scoped_refptr -PeerConnectionFactory::CreatePeerConnection( - const PeerConnectionInterface::RTCConfiguration& configuration, - PeerConnectionDependencies dependencies) { - auto result = - CreatePeerConnectionOrError(configuration, std::move(dependencies)); - if (result.ok()) { - return result.MoveValue(); - } else { - return nullptr; - } -} - RTCErrorOr> PeerConnectionFactory::CreatePeerConnectionOrError( const PeerConnectionInterface::RTCConfiguration& configuration, @@ -256,12 +234,11 @@ PeerConnectionFactory::CreatePeerConnectionOrError( std::unique_ptr event_log = worker_thread()->Invoke>( - RTC_FROM_HERE, - rtc::Bind(&PeerConnectionFactory::CreateRtcEventLog_w, this)); + RTC_FROM_HERE, [this] { return CreateRtcEventLog_w(); }); std::unique_ptr call = worker_thread()->Invoke>( RTC_FROM_HERE, - rtc::Bind(&PeerConnectionFactory::CreateCall_w, this, event_log.get())); + [this, &event_log] { return CreateCall_w(event_log.get()); }); auto result = PeerConnection::Create(context_, options_, std::move(event_log), std::move(call), configuration, @@ -269,8 +246,15 @@ PeerConnectionFactory::CreatePeerConnectionOrError( if (!result.ok()) { return result.MoveError(); } + // We configure the proxy with a pointer to the network thread for methods + // that need to be invoked there rather than on the signaling thread. + // Internally, the proxy object has a member variable named |worker_thread_| + // which will point to the network thread (and not the factory's + // worker_thread()). All such methods have thread checks though, so the code + // should still be clear (outside of macro expansion). 
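Editor's note: the factory hunks above replace rtc::Bind(&Class::Method, this) with plain lambdas passed to Thread::Invoke. A minimal sketch of that blocking-invoke-with-lambda pattern follows; it is not part of the patch, and ComputeOnWorker / worker are illustrative names.

#include <string>

#include "rtc_base/location.h"
#include "rtc_base/thread.h"

std::string ComputeOnWorker(rtc::Thread* worker) {
  // Invoke<T> blocks the calling thread until the functor has run on
  // `worker` and then returns its result; the lambda capture replaces the
  // removed rtc::Bind spelling.
  return worker->Invoke<std::string>(RTC_FROM_HERE,
                                     [] { return std::string("done"); });
}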
rtc::scoped_refptr result_proxy = - PeerConnectionProxy::Create(signaling_thread(), result.MoveValue()); + PeerConnectionProxy::Create(signaling_thread(), network_thread(), + result.MoveValue()); return result_proxy; } @@ -317,7 +301,7 @@ std::unique_ptr PeerConnectionFactory::CreateCall_w( RtcEventLog* event_log) { RTC_DCHECK_RUN_ON(worker_thread()); - webrtc::Call::Config call_config(event_log); + webrtc::Call::Config call_config(event_log, network_thread()); if (!channel_manager()->media_engine() || !context_->call_factory()) { return nullptr; } @@ -355,7 +339,8 @@ std::unique_ptr PeerConnectionFactory::CreateCall_w( } call_config.trials = &trials(); - + call_config.rtp_transport_controller_send_factory = + transport_controller_send_factory_.get(); return std::unique_ptr( context_->call_factory()->CreateCall(call_config)); } diff --git a/pc/peer_connection_factory.h b/pc/peer_connection_factory.h index 9c4a2b0526..4946ec6ea2 100644 --- a/pc/peer_connection_factory.h +++ b/pc/peer_connection_factory.h @@ -14,6 +14,7 @@ #include #include + #include #include @@ -30,17 +31,20 @@ #include "api/rtc_event_log/rtc_event_log_factory_interface.h" #include "api/rtp_parameters.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/transport/network_control.h" #include "api/transport/sctp_transport_factory_interface.h" #include "api/transport/webrtc_key_value_config.h" #include "call/call.h" -#include "media/sctp/sctp_transport_internal.h" +#include "call/rtp_transport_controller_send_factory_interface.h" #include "p2p/base/port_allocator.h" #include "pc/channel_manager.h" #include "pc/connection_context.h" +#include "rtc_base/checks.h" #include "rtc_base/rtc_certificate_generator.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace rtc { class BasicNetworkManager; @@ -63,16 +67,6 @@ class PeerConnectionFactory : public PeerConnectionFactoryInterface { void SetOptions(const Options& options) override; - rtc::scoped_refptr CreatePeerConnection( - const PeerConnectionInterface::RTCConfiguration& configuration, - std::unique_ptr allocator, - std::unique_ptr cert_generator, - PeerConnectionObserver* observer) override; - - rtc::scoped_refptr CreatePeerConnection( - const PeerConnectionInterface::RTCConfiguration& configuration, - PeerConnectionDependencies dependencies) override; - RTCErrorOr> CreatePeerConnectionOrError( const PeerConnectionInterface::RTCConfiguration& configuration, @@ -113,6 +107,8 @@ class PeerConnectionFactory : public PeerConnectionFactoryInterface { return context_->signaling_thread(); } + rtc::Thread* worker_thread() const { return context_->worker_thread(); } + const Options& options() const { RTC_DCHECK_RUN_ON(signaling_thread()); return options_; @@ -133,7 +129,6 @@ class PeerConnectionFactory : public PeerConnectionFactoryInterface { virtual ~PeerConnectionFactory(); private: - rtc::Thread* worker_thread() const { return context_->worker_thread(); } rtc::Thread* network_thread() const { return context_->network_thread(); } bool IsTrialEnabled(absl::string_view key) const; @@ -155,6 +150,8 @@ class PeerConnectionFactory : public PeerConnectionFactoryInterface { std::unique_ptr injected_network_controller_factory_; std::unique_ptr neteq_factory_; + const std::unique_ptr + transport_controller_send_factory_; }; } // namespace webrtc diff --git a/api/peer_connection_factory_proxy.h b/pc/peer_connection_factory_proxy.h similarity index 61% rename from 
api/peer_connection_factory_proxy.h rename to pc/peer_connection_factory_proxy.h index be098e34d8..59e373db7b 100644 --- a/api/peer_connection_factory_proxy.h +++ b/pc/peer_connection_factory_proxy.h @@ -8,34 +8,23 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef API_PEER_CONNECTION_FACTORY_PROXY_H_ -#define API_PEER_CONNECTION_FACTORY_PROXY_H_ +#ifndef PC_PEER_CONNECTION_FACTORY_PROXY_H_ +#define PC_PEER_CONNECTION_FACTORY_PROXY_H_ #include #include #include #include "api/peer_connection_interface.h" -#include "api/proxy.h" -#include "rtc_base/bind.h" +#include "pc/proxy.h" namespace webrtc { -// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. -BEGIN_SIGNALING_PROXY_MAP(PeerConnectionFactory) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +// TODO(deadbeef): Move this to .cc file. What threads methods are called on is +// an implementation detail. +BEGIN_PROXY_MAP(PeerConnectionFactory) +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_METHOD1(void, SetOptions, const Options&) -PROXY_METHOD4(rtc::scoped_refptr, - CreatePeerConnection, - const PeerConnectionInterface::RTCConfiguration&, - std::unique_ptr, - std::unique_ptr, - PeerConnectionObserver*) -PROXY_METHOD2(rtc::scoped_refptr, - CreatePeerConnection, - const PeerConnectionInterface::RTCConfiguration&, - PeerConnectionDependencies) PROXY_METHOD2(RTCErrorOr>, CreatePeerConnectionOrError, const PeerConnectionInterface::RTCConfiguration&, @@ -60,10 +49,10 @@ PROXY_METHOD2(rtc::scoped_refptr, CreateAudioTrack, const std::string&, AudioSourceInterface*) -PROXY_METHOD2(bool, StartAecDump, FILE*, int64_t) -PROXY_METHOD0(void, StopAecDump) -END_PROXY_MAP() +PROXY_SECONDARY_METHOD2(bool, StartAecDump, FILE*, int64_t) +PROXY_SECONDARY_METHOD0(void, StopAecDump) +END_PROXY_MAP(PeerConnectionFactory) } // namespace webrtc -#endif // API_PEER_CONNECTION_FACTORY_PROXY_H_ +#endif // PC_PEER_CONNECTION_FACTORY_PROXY_H_ diff --git a/pc/peer_connection_histogram_unittest.cc b/pc/peer_connection_histogram_unittest.cc index 39b9a73a46..fa46ce9802 100644 --- a/pc/peer_connection_histogram_unittest.cc +++ b/pc/peer_connection_histogram_unittest.cc @@ -19,7 +19,6 @@ #include "api/jsep.h" #include "api/jsep_session_description.h" #include "api/peer_connection_interface.h" -#include "api/peer_connection_proxy.h" #include "api/rtc_error.h" #include "api/scoped_refptr.h" #include "api/task_queue/default_task_queue_factory.h" @@ -29,6 +28,7 @@ #include "p2p/client/basic_port_allocator.h" #include "pc/peer_connection.h" #include "pc/peer_connection_factory.h" +#include "pc/peer_connection_proxy.h" #include "pc/peer_connection_wrapper.h" #include "pc/sdp_utils.h" #include "pc/test/mock_peer_connection_observers.h" @@ -497,7 +497,7 @@ TEST_F(PeerConnectionUsageHistogramTest, FingerprintWithMdnsCallee) { expected_fingerprint_callee)); } -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP TEST_F(PeerConnectionUsageHistogramTest, FingerprintDataOnly) { auto caller = CreatePeerConnection(); auto callee = CreatePeerConnection(); @@ -521,7 +521,7 @@ TEST_F(PeerConnectionUsageHistogramTest, FingerprintDataOnly) { expected_fingerprint | static_cast(UsageEvent::PRIVATE_CANDIDATE_COLLECTED)) == 2); } -#endif // HAVE_SCTP +#endif // WEBRTC_HAVE_SCTP #endif // WEBRTC_ANDROID TEST_F(PeerConnectionUsageHistogramTest, FingerprintStunTurn) { @@ -628,7 +628,7 @@ TEST_F(PeerConnectionUsageHistogramTest, FingerprintWithPrivateIpv6Callee) { } #ifndef WEBRTC_ANDROID -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP 
// Test that the usage pattern bits for adding remote (private IPv6) candidates // are set when the remote candidates are retrieved from the Offer SDP instead // of trickled ICE messages. diff --git a/pc/peer_connection_ice_unittest.cc b/pc/peer_connection_ice_unittest.cc index 8c1a764398..7971547ffa 100644 --- a/pc/peer_connection_ice_unittest.cc +++ b/pc/peer_connection_ice_unittest.cc @@ -23,10 +23,10 @@ #include "api/audio_codecs/builtin_audio_decoder_factory.h" #include "api/audio_codecs/builtin_audio_encoder_factory.h" #include "api/create_peerconnection_factory.h" -#include "api/peer_connection_proxy.h" #include "api/uma_metrics.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" +#include "pc/peer_connection_proxy.h" #include "pc/test/fake_audio_capture_module.h" #include "pc/test/mock_peer_connection_observers.h" #include "rtc_base/fake_network.h" @@ -497,6 +497,24 @@ TEST_P(PeerConnectionIceTest, DuplicateIceCandidateIgnoredWhenAdded) { EXPECT_EQ(1u, caller->GetIceCandidatesFromRemoteDescription().size()); } +// TODO(tommi): Re-enable after updating RTCPeerConnection-blockedPorts.html in +// Chromium (the test needs setRemoteDescription to succeed for an invalid +// candidate). +TEST_P(PeerConnectionIceTest, DISABLED_ErrorOnInvalidRemoteIceCandidateAdded) { + auto caller = CreatePeerConnectionWithAudioVideo(); + auto callee = CreatePeerConnectionWithAudioVideo(); + ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); + // Add a candidate to the remote description with a candidate that has an + // invalid address (port number == 2). + auto answer = callee->CreateAnswerAndSetAsLocal(); + cricket::Candidate bad_candidate = + CreateLocalUdpCandidate(SocketAddress("2.2.2.2", 2)); + RTC_LOG(LS_INFO) << "Bad candidate: " << bad_candidate.ToString(); + AddCandidateToFirstTransport(&bad_candidate, answer.get()); + // Now the call to SetRemoteDescription should fail. + EXPECT_FALSE(caller->SetRemoteDescription(std::move(answer))); +} + TEST_P(PeerConnectionIceTest, CannotRemoveIceCandidatesWhenPeerConnectionClosed) { const SocketAddress kCalleeAddress("1.1.1.1", 1111); @@ -750,8 +768,8 @@ TEST_P(PeerConnectionIceTest, ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); // Chain an operation that will block AddIceCandidate() from executing. - rtc::scoped_refptr answer_observer( - new rtc::RefCountedObject()); + auto answer_observer = + rtc::make_ref_counted(); callee->pc()->CreateAnswer(answer_observer, RTCOfferAnswerOptions()); auto jsep_candidate = @@ -798,8 +816,8 @@ TEST_P(PeerConnectionIceTest, ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); // Chain an operation that will block AddIceCandidate() from executing. 
- rtc::scoped_refptr answer_observer( - new rtc::RefCountedObject()); + auto answer_observer = + rtc::make_ref_counted(); callee->pc()->CreateAnswer(answer_observer, RTCOfferAnswerOptions()); auto jsep_candidate = @@ -1388,6 +1406,36 @@ TEST_F(PeerConnectionIceConfigTest, SetStunCandidateKeepaliveInterval) { EXPECT_EQ(actual_stun_keepalive_interval.value_or(-1), 321); } +TEST_F(PeerConnectionIceConfigTest, SetStableWritableConnectionInterval) { + RTCConfiguration config; + config.stable_writable_connection_ping_interval_ms = 3500; + CreatePeerConnection(config); + EXPECT_TRUE(pc_->SetConfiguration(config).ok()); + EXPECT_EQ(pc_->GetConfiguration().stable_writable_connection_ping_interval_ms, + config.stable_writable_connection_ping_interval_ms); +} + +TEST_F(PeerConnectionIceConfigTest, + SetStableWritableConnectionInterval_FailsValidation) { + RTCConfiguration config; + CreatePeerConnection(config); + ASSERT_TRUE(pc_->SetConfiguration(config).ok()); + config.stable_writable_connection_ping_interval_ms = 5000; + config.ice_check_interval_strong_connectivity = 7500; + EXPECT_FALSE(pc_->SetConfiguration(config).ok()); +} + +TEST_F(PeerConnectionIceConfigTest, + SetStableWritableConnectionInterval_DefaultValue_FailsValidation) { + RTCConfiguration config; + CreatePeerConnection(config); + ASSERT_TRUE(pc_->SetConfiguration(config).ok()); + config.ice_check_interval_strong_connectivity = 2500; + EXPECT_TRUE(pc_->SetConfiguration(config).ok()); + config.ice_check_interval_strong_connectivity = 2501; + EXPECT_FALSE(pc_->SetConfiguration(config).ok()); +} + TEST_P(PeerConnectionIceTest, IceCredentialsCreateOffer) { RTCConfiguration config; config.ice_candidate_pool_size = 1; @@ -1434,4 +1482,24 @@ TEST_P(PeerConnectionIceTest, CloseDoesNotTransitionGatheringStateToComplete) { pc->pc()->ice_gathering_state()); } +TEST_P(PeerConnectionIceTest, PrefersMidOverMLineIndex) { + const SocketAddress kCalleeAddress("1.1.1.1", 1111); + + auto caller = CreatePeerConnectionWithAudioVideo(); + auto callee = CreatePeerConnectionWithAudioVideo(); + + ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); + ASSERT_TRUE( + caller->SetRemoteDescription(callee->CreateAnswerAndSetAsLocal())); + + // |candidate.transport_name()| is empty. + cricket::Candidate candidate = CreateLocalUdpCandidate(kCalleeAddress); + auto* audio_content = cricket::GetFirstAudioContent( + caller->pc()->local_description()->description()); + std::unique_ptr ice_candidate = + CreateIceCandidate(audio_content->name, 65535, candidate); + EXPECT_TRUE(caller->pc()->AddIceCandidate(ice_candidate.get())); + EXPECT_TRUE(caller->pc()->RemoveIceCandidates({candidate})); +} + } // namespace webrtc diff --git a/pc/peer_connection_integrationtest.cc b/pc/peer_connection_integrationtest.cc index 32bfd1aff6..dfceacd777 100644 --- a/pc/peer_connection_integrationtest.cc +++ b/pc/peer_connection_integrationtest.cc @@ -8,1771 +8,88 @@ * be found in the AUTHORS file in the root of the source tree. */ -// Disable for TSan v2, see -// https://code.google.com/p/webrtc/issues/detail?id=1205 for details. 
-#if !defined(THREAD_SANITIZER) - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "api/media_stream_interface.h" -#include "api/peer_connection_interface.h" -#include "api/peer_connection_proxy.h" -#include "api/rtc_event_log/rtc_event_log_factory.h" -#include "api/rtp_receiver_interface.h" -#include "api/task_queue/default_task_queue_factory.h" -#include "api/transport/field_trial_based_config.h" -#include "api/uma_metrics.h" -#include "api/video_codecs/sdp_video_format.h" -#include "call/call.h" -#include "logging/rtc_event_log/fake_rtc_event_log_factory.h" -#include "media/engine/fake_webrtc_video_engine.h" -#include "media/engine/webrtc_media_engine.h" -#include "media/engine/webrtc_media_engine_defaults.h" -#include "modules/audio_processing/test/audio_processing_builder_for_testing.h" -#include "p2p/base/fake_ice_transport.h" -#include "p2p/base/mock_async_resolver.h" -#include "p2p/base/p2p_constants.h" -#include "p2p/base/port_interface.h" -#include "p2p/base/test_stun_server.h" -#include "p2p/base/test_turn_customizer.h" -#include "p2p/base/test_turn_server.h" -#include "p2p/client/basic_port_allocator.h" -#include "pc/dtmf_sender.h" -#include "pc/local_audio_source.h" -#include "pc/media_session.h" -#include "pc/peer_connection.h" -#include "pc/peer_connection_factory.h" -#include "pc/rtp_media_utils.h" -#include "pc/session_description.h" -#include "pc/test/fake_audio_capture_module.h" -#include "pc/test/fake_periodic_video_track_source.h" -#include "pc/test/fake_rtc_certificate_generator.h" -#include "pc/test/fake_video_track_renderer.h" -#include "pc/test/mock_peer_connection_observers.h" -#include "rtc_base/fake_clock.h" -#include "rtc_base/fake_mdns_responder.h" -#include "rtc_base/fake_network.h" -#include "rtc_base/firewall_socket_server.h" -#include "rtc_base/gunit.h" -#include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/test_certificate_verifier.h" -#include "rtc_base/time_utils.h" -#include "rtc_base/virtual_socket_server.h" -#include "system_wrappers/include/metrics.h" -#include "test/field_trial.h" -#include "test/gmock.h" - -namespace webrtc { -namespace { - -using ::cricket::ContentInfo; -using ::cricket::StreamParams; -using ::rtc::SocketAddress; -using ::testing::_; -using ::testing::Combine; -using ::testing::Contains; -using ::testing::DoAll; -using ::testing::ElementsAre; -using ::testing::NiceMock; -using ::testing::Return; -using ::testing::SetArgPointee; -using ::testing::UnorderedElementsAreArray; -using ::testing::Values; -using RTCConfiguration = PeerConnectionInterface::RTCConfiguration; - -static const int kDefaultTimeout = 10000; -static const int kMaxWaitForStatsMs = 3000; -static const int kMaxWaitForActivationMs = 5000; -static const int kMaxWaitForFramesMs = 10000; -// Default number of audio/video frames to wait for before considering a test -// successful. -static const int kDefaultExpectedAudioFrameCount = 3; -static const int kDefaultExpectedVideoFrameCount = 3; - -static const char kDataChannelLabel[] = "data_channel"; - -// SRTP cipher name negotiated by the tests. This must be updated if the -// default changes. -static const int kDefaultSrtpCryptoSuite = rtc::SRTP_AES128_CM_SHA1_80; -static const int kDefaultSrtpCryptoSuiteGcm = rtc::SRTP_AEAD_AES_256_GCM; - -static const SocketAddress kDefaultLocalAddress("192.168.1.1", 0); - -// Helper function for constructing offer/answer options to initiate an ICE -// restart. 
-PeerConnectionInterface::RTCOfferAnswerOptions IceRestartOfferAnswerOptions() { - PeerConnectionInterface::RTCOfferAnswerOptions options; - options.ice_restart = true; - return options; -} - -// Remove all stream information (SSRCs, track IDs, etc.) and "msid-semantic" -// attribute from received SDP, simulating a legacy endpoint. -void RemoveSsrcsAndMsids(cricket::SessionDescription* desc) { - for (ContentInfo& content : desc->contents()) { - content.media_description()->mutable_streams().clear(); - } - desc->set_msid_supported(false); - desc->set_msid_signaling(0); -} - -// Removes all stream information besides the stream ids, simulating an -// endpoint that only signals a=msid lines to convey stream_ids. -void RemoveSsrcsAndKeepMsids(cricket::SessionDescription* desc) { - for (ContentInfo& content : desc->contents()) { - std::string track_id; - std::vector stream_ids; - if (!content.media_description()->streams().empty()) { - const StreamParams& first_stream = - content.media_description()->streams()[0]; - track_id = first_stream.id; - stream_ids = first_stream.stream_ids(); - } - content.media_description()->mutable_streams().clear(); - StreamParams new_stream; - new_stream.id = track_id; - new_stream.set_stream_ids(stream_ids); - content.media_description()->AddStream(new_stream); - } -} - -int FindFirstMediaStatsIndexByKind( - const std::string& kind, - const std::vector& - media_stats_vec) { - for (size_t i = 0; i < media_stats_vec.size(); i++) { - if (media_stats_vec[i]->kind.ValueToString() == kind) { - return i; - } - } - return -1; -} - -class SignalingMessageReceiver { - public: - virtual void ReceiveSdpMessage(SdpType type, const std::string& msg) = 0; - virtual void ReceiveIceMessage(const std::string& sdp_mid, - int sdp_mline_index, - const std::string& msg) = 0; - - protected: - SignalingMessageReceiver() {} - virtual ~SignalingMessageReceiver() {} -}; - -class MockRtpReceiverObserver : public webrtc::RtpReceiverObserverInterface { - public: - explicit MockRtpReceiverObserver(cricket::MediaType media_type) - : expected_media_type_(media_type) {} - - void OnFirstPacketReceived(cricket::MediaType media_type) override { - ASSERT_EQ(expected_media_type_, media_type); - first_packet_received_ = true; - } - - bool first_packet_received() const { return first_packet_received_; } - - virtual ~MockRtpReceiverObserver() {} - - private: - bool first_packet_received_ = false; - cricket::MediaType expected_media_type_; -}; - -// Helper class that wraps a peer connection, observes it, and can accept -// signaling messages from another wrapper. -// -// Uses a fake network, fake A/V capture, and optionally fake -// encoders/decoders, though they aren't used by default since they don't -// advertise support of any codecs. -// TODO(steveanton): See how this could become a subclass of -// PeerConnectionWrapper defined in peerconnectionwrapper.h. -class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, - public SignalingMessageReceiver { - public: - // Different factory methods for convenience. - // TODO(deadbeef): Could use the pattern of: - // - // PeerConnectionWrapper = - // WrapperBuilder.WithConfig(...).WithOptions(...).build(); - // - // To reduce some code duplication. 
- static PeerConnectionWrapper* CreateWithDtlsIdentityStore( - const std::string& debug_name, - std::unique_ptr cert_generator, - rtc::Thread* network_thread, - rtc::Thread* worker_thread) { - PeerConnectionWrapper* client(new PeerConnectionWrapper(debug_name)); - webrtc::PeerConnectionDependencies dependencies(nullptr); - dependencies.cert_generator = std::move(cert_generator); - if (!client->Init(nullptr, nullptr, std::move(dependencies), network_thread, - worker_thread, nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false)) { - delete client; - return nullptr; - } - return client; - } - - webrtc::PeerConnectionFactoryInterface* pc_factory() const { - return peer_connection_factory_.get(); - } - - webrtc::PeerConnectionInterface* pc() const { return peer_connection_.get(); } - - // If a signaling message receiver is set (via ConnectFakeSignaling), this - // will set the whole offer/answer exchange in motion. Just need to wait for - // the signaling state to reach "stable". - void CreateAndSetAndSignalOffer() { - auto offer = CreateOfferAndWait(); - ASSERT_NE(nullptr, offer); - EXPECT_TRUE(SetLocalDescriptionAndSendSdpMessage(std::move(offer))); - } - - // Sets the options to be used when CreateAndSetAndSignalOffer is called, or - // when a remote offer is received (via fake signaling) and an answer is - // generated. By default, uses default options. - void SetOfferAnswerOptions( - const PeerConnectionInterface::RTCOfferAnswerOptions& options) { - offer_answer_options_ = options; - } - - // Set a callback to be invoked when SDP is received via the fake signaling - // channel, which provides an opportunity to munge (modify) the SDP. This is - // used to test SDP being applied that a PeerConnection would normally not - // generate, but a non-JSEP endpoint might. - void SetReceivedSdpMunger( - std::function munger) { - received_sdp_munger_ = std::move(munger); - } - - // Similar to the above, but this is run on SDP immediately after it's - // generated. - void SetGeneratedSdpMunger( - std::function munger) { - generated_sdp_munger_ = std::move(munger); - } - - // Set a callback to be invoked when a remote offer is received via the fake - // signaling channel. This provides an opportunity to change the - // PeerConnection state before an answer is created and sent to the caller. - void SetRemoteOfferHandler(std::function handler) { - remote_offer_handler_ = std::move(handler); - } - - void SetRemoteAsyncResolver(rtc::MockAsyncResolver* resolver) { - remote_async_resolver_ = resolver; - } - - // Every ICE connection state in order that has been seen by the observer. - std::vector - ice_connection_state_history() const { - return ice_connection_state_history_; - } - void clear_ice_connection_state_history() { - ice_connection_state_history_.clear(); - } - - // Every standardized ICE connection state in order that has been seen by the - // observer. - std::vector - standardized_ice_connection_state_history() const { - return standardized_ice_connection_state_history_; - } - - // Every PeerConnection state in order that has been seen by the observer. - std::vector - peer_connection_state_history() const { - return peer_connection_state_history_; - } - - // Every ICE gathering state in order that has been seen by the observer. 
- std::vector - ice_gathering_state_history() const { - return ice_gathering_state_history_; - } - std::vector - ice_candidate_pair_change_history() const { - return ice_candidate_pair_change_history_; - } - - // Every PeerConnection signaling state in order that has been seen by the - // observer. - std::vector - peer_connection_signaling_state_history() const { - return peer_connection_signaling_state_history_; - } - - void AddAudioVideoTracks() { - AddAudioTrack(); - AddVideoTrack(); - } - - rtc::scoped_refptr AddAudioTrack() { - return AddTrack(CreateLocalAudioTrack()); - } - - rtc::scoped_refptr AddVideoTrack() { - return AddTrack(CreateLocalVideoTrack()); - } - - rtc::scoped_refptr CreateLocalAudioTrack() { - cricket::AudioOptions options; - // Disable highpass filter so that we can get all the test audio frames. - options.highpass_filter = false; - rtc::scoped_refptr source = - peer_connection_factory_->CreateAudioSource(options); - // TODO(perkj): Test audio source when it is implemented. Currently audio - // always use the default input. - return peer_connection_factory_->CreateAudioTrack(rtc::CreateRandomUuid(), - source); - } - - rtc::scoped_refptr CreateLocalVideoTrack() { - webrtc::FakePeriodicVideoSource::Config config; - config.timestamp_offset_ms = rtc::TimeMillis(); - return CreateLocalVideoTrackInternal(config); - } - - rtc::scoped_refptr - CreateLocalVideoTrackWithConfig( - webrtc::FakePeriodicVideoSource::Config config) { - return CreateLocalVideoTrackInternal(config); - } - - rtc::scoped_refptr - CreateLocalVideoTrackWithRotation(webrtc::VideoRotation rotation) { - webrtc::FakePeriodicVideoSource::Config config; - config.rotation = rotation; - config.timestamp_offset_ms = rtc::TimeMillis(); - return CreateLocalVideoTrackInternal(config); - } - - rtc::scoped_refptr AddTrack( - rtc::scoped_refptr track, - const std::vector& stream_ids = {}) { - auto result = pc()->AddTrack(track, stream_ids); - EXPECT_EQ(RTCErrorType::NONE, result.error().type()); - return result.MoveValue(); - } - - std::vector> GetReceiversOfType( - cricket::MediaType media_type) { - std::vector> receivers; - for (const auto& receiver : pc()->GetReceivers()) { - if (receiver->media_type() == media_type) { - receivers.push_back(receiver); - } - } - return receivers; - } - - rtc::scoped_refptr GetFirstTransceiverOfType( - cricket::MediaType media_type) { - for (auto transceiver : pc()->GetTransceivers()) { - if (transceiver->receiver()->media_type() == media_type) { - return transceiver; - } - } - return nullptr; - } - - bool SignalingStateStable() { - return pc()->signaling_state() == webrtc::PeerConnectionInterface::kStable; - } - - void CreateDataChannel() { CreateDataChannel(nullptr); } - - void CreateDataChannel(const webrtc::DataChannelInit* init) { - CreateDataChannel(kDataChannelLabel, init); - } - - void CreateDataChannel(const std::string& label, - const webrtc::DataChannelInit* init) { - data_channel_ = pc()->CreateDataChannel(label, init); - ASSERT_TRUE(data_channel_.get() != nullptr); - data_observer_.reset(new MockDataChannelObserver(data_channel_)); - } - - DataChannelInterface* data_channel() { return data_channel_; } - const MockDataChannelObserver* data_observer() const { - return data_observer_.get(); - } - - int audio_frames_received() const { - return fake_audio_capture_module_->frames_received(); - } - - // Takes minimum of video frames received for each track. 
- // - // Can be used like: - // EXPECT_GE(expected_frames, min_video_frames_received_per_track()); - // - // To ensure that all video tracks received at least a certain number of - // frames. - int min_video_frames_received_per_track() const { - int min_frames = INT_MAX; - if (fake_video_renderers_.empty()) { - return 0; - } - - for (const auto& pair : fake_video_renderers_) { - min_frames = std::min(min_frames, pair.second->num_rendered_frames()); - } - return min_frames; - } - - // Returns a MockStatsObserver in a state after stats gathering finished, - // which can be used to access the gathered stats. - rtc::scoped_refptr OldGetStatsForTrack( - webrtc::MediaStreamTrackInterface* track) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); - EXPECT_TRUE(peer_connection_->GetStats( - observer, nullptr, PeerConnectionInterface::kStatsOutputLevelStandard)); - EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); - return observer; - } - - // Version that doesn't take a track "filter", and gathers all stats. - rtc::scoped_refptr OldGetStats() { - return OldGetStatsForTrack(nullptr); - } - - // Synchronously gets stats and returns them. If it times out, fails the test - // and returns null. - rtc::scoped_refptr NewGetStats() { - rtc::scoped_refptr callback( - new rtc::RefCountedObject()); - peer_connection_->GetStats(callback); - EXPECT_TRUE_WAIT(callback->called(), kDefaultTimeout); - return callback->report(); - } - - int rendered_width() { - EXPECT_FALSE(fake_video_renderers_.empty()); - return fake_video_renderers_.empty() - ? 0 - : fake_video_renderers_.begin()->second->width(); - } - - int rendered_height() { - EXPECT_FALSE(fake_video_renderers_.empty()); - return fake_video_renderers_.empty() - ? 0 - : fake_video_renderers_.begin()->second->height(); - } - - double rendered_aspect_ratio() { - if (rendered_height() == 0) { - return 0.0; - } - return static_cast(rendered_width()) / rendered_height(); - } - - webrtc::VideoRotation rendered_rotation() { - EXPECT_FALSE(fake_video_renderers_.empty()); - return fake_video_renderers_.empty() - ? webrtc::kVideoRotation_0 - : fake_video_renderers_.begin()->second->rotation(); - } - - int local_rendered_width() { - return local_video_renderer_ ? local_video_renderer_->width() : 0; - } - - int local_rendered_height() { - return local_video_renderer_ ? 
local_video_renderer_->height() : 0; - } - - double local_rendered_aspect_ratio() { - if (local_rendered_height() == 0) { - return 0.0; - } - return static_cast(local_rendered_width()) / - local_rendered_height(); - } - - size_t number_of_remote_streams() { - if (!pc()) { - return 0; - } - return pc()->remote_streams()->count(); - } - - StreamCollectionInterface* remote_streams() const { - if (!pc()) { - ADD_FAILURE(); - return nullptr; - } - return pc()->remote_streams(); - } - - StreamCollectionInterface* local_streams() { - if (!pc()) { - ADD_FAILURE(); - return nullptr; - } - return pc()->local_streams(); - } - - webrtc::PeerConnectionInterface::SignalingState signaling_state() { - return pc()->signaling_state(); - } - - webrtc::PeerConnectionInterface::IceConnectionState ice_connection_state() { - return pc()->ice_connection_state(); - } - - webrtc::PeerConnectionInterface::IceConnectionState - standardized_ice_connection_state() { - return pc()->standardized_ice_connection_state(); - } - - webrtc::PeerConnectionInterface::IceGatheringState ice_gathering_state() { - return pc()->ice_gathering_state(); - } - - // Returns a MockRtpReceiverObserver for each RtpReceiver returned by - // GetReceivers. They're updated automatically when a remote offer/answer - // from the fake signaling channel is applied, or when - // ResetRtpReceiverObservers below is called. - const std::vector>& - rtp_receiver_observers() { - return rtp_receiver_observers_; - } - - void ResetRtpReceiverObservers() { - rtp_receiver_observers_.clear(); - for (const rtc::scoped_refptr& receiver : - pc()->GetReceivers()) { - std::unique_ptr observer( - new MockRtpReceiverObserver(receiver->media_type())); - receiver->SetObserver(observer.get()); - rtp_receiver_observers_.push_back(std::move(observer)); - } - } - - rtc::FakeNetworkManager* network_manager() const { - return fake_network_manager_.get(); - } - cricket::PortAllocator* port_allocator() const { return port_allocator_; } - - webrtc::FakeRtcEventLogFactory* event_log_factory() const { - return event_log_factory_; - } - - const cricket::Candidate& last_candidate_gathered() const { - return last_candidate_gathered_; - } - const cricket::IceCandidateErrorEvent& error_event() const { - return error_event_; - } - - // Sets the mDNS responder for the owned fake network manager and keeps a - // reference to the responder. - void SetMdnsResponder( - std::unique_ptr mdns_responder) { - RTC_DCHECK(mdns_responder != nullptr); - mdns_responder_ = mdns_responder.get(); - network_manager()->set_mdns_responder(std::move(mdns_responder)); - } - - // Returns null on failure. - std::unique_ptr CreateOfferAndWait() { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); - pc()->CreateOffer(observer, offer_answer_options_); - return WaitForDescriptionFromObserver(observer); - } - bool Rollback() { - return SetRemoteDescription( - webrtc::CreateSessionDescription(SdpType::kRollback, "")); - } - - // Functions for querying stats. - void StartWatchingDelayStats() { - // Get the baseline numbers for audio_packets and audio_delay. 
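    // (UpdateDelayStats() later subtracts these baselines and divides the new
    // relative_packet_arrival_delay by the newly received packets, i.e. it
    // checks the average per-packet delay added during the renegotiation.)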
- auto received_stats = NewGetStats(); - auto track_stats = - received_stats->GetStatsOfType()[0]; - ASSERT_TRUE(track_stats->relative_packet_arrival_delay.is_defined()); - auto rtp_stats = - received_stats->GetStatsOfType()[0]; - ASSERT_TRUE(rtp_stats->packets_received.is_defined()); - ASSERT_TRUE(rtp_stats->track_id.is_defined()); - audio_track_stats_id_ = track_stats->id(); - ASSERT_TRUE(received_stats->Get(audio_track_stats_id_)); - rtp_stats_id_ = rtp_stats->id(); - ASSERT_EQ(audio_track_stats_id_, *rtp_stats->track_id); - audio_packets_stat_ = *rtp_stats->packets_received; - audio_delay_stat_ = *track_stats->relative_packet_arrival_delay; - } - - void UpdateDelayStats(std::string tag, int desc_size) { - auto report = NewGetStats(); - auto track_stats = - report->GetAs(audio_track_stats_id_); - ASSERT_TRUE(track_stats); - auto rtp_stats = - report->GetAs(rtp_stats_id_); - ASSERT_TRUE(rtp_stats); - auto delta_packets = *rtp_stats->packets_received - audio_packets_stat_; - auto delta_rpad = - *track_stats->relative_packet_arrival_delay - audio_delay_stat_; - auto recent_delay = delta_packets > 0 ? delta_rpad / delta_packets : -1; - // An average relative packet arrival delay over the renegotiation of - // > 100 ms indicates that something is dramatically wrong, and will impact - // quality for sure. - ASSERT_GT(0.1, recent_delay) << tag << " size " << desc_size; - // Increment trailing counters - audio_packets_stat_ = *rtp_stats->packets_received; - audio_delay_stat_ = *track_stats->relative_packet_arrival_delay; - } - - private: - explicit PeerConnectionWrapper(const std::string& debug_name) - : debug_name_(debug_name) {} - - bool Init( - const PeerConnectionFactory::Options* options, - const PeerConnectionInterface::RTCConfiguration* config, - webrtc::PeerConnectionDependencies dependencies, - rtc::Thread* network_thread, - rtc::Thread* worker_thread, - std::unique_ptr event_log_factory, - bool reset_encoder_factory, - bool reset_decoder_factory) { - // There's an error in this test code if Init ends up being called twice. - RTC_DCHECK(!peer_connection_); - RTC_DCHECK(!peer_connection_factory_); - - fake_network_manager_.reset(new rtc::FakeNetworkManager()); - fake_network_manager_->AddInterface(kDefaultLocalAddress); - - std::unique_ptr port_allocator( - new cricket::BasicPortAllocator(fake_network_manager_.get())); - port_allocator_ = port_allocator.get(); - fake_audio_capture_module_ = FakeAudioCaptureModule::Create(); - if (!fake_audio_capture_module_) { - return false; - } - rtc::Thread* const signaling_thread = rtc::Thread::Current(); - - webrtc::PeerConnectionFactoryDependencies pc_factory_dependencies; - pc_factory_dependencies.network_thread = network_thread; - pc_factory_dependencies.worker_thread = worker_thread; - pc_factory_dependencies.signaling_thread = signaling_thread; - pc_factory_dependencies.task_queue_factory = - webrtc::CreateDefaultTaskQueueFactory(); - pc_factory_dependencies.trials = std::make_unique(); - cricket::MediaEngineDependencies media_deps; - media_deps.task_queue_factory = - pc_factory_dependencies.task_queue_factory.get(); - media_deps.adm = fake_audio_capture_module_; - webrtc::SetMediaEngineDefaults(&media_deps); - - if (reset_encoder_factory) { - media_deps.video_encoder_factory.reset(); - } - if (reset_decoder_factory) { - media_deps.video_decoder_factory.reset(); - } - - if (!media_deps.audio_processing) { - // If the standard Creation method for APM returns a null pointer, instead - // use the builder for testing to create an APM object. 
- media_deps.audio_processing = AudioProcessingBuilderForTesting().Create(); - } - - media_deps.trials = pc_factory_dependencies.trials.get(); - - pc_factory_dependencies.media_engine = - cricket::CreateMediaEngine(std::move(media_deps)); - pc_factory_dependencies.call_factory = webrtc::CreateCallFactory(); - if (event_log_factory) { - event_log_factory_ = event_log_factory.get(); - pc_factory_dependencies.event_log_factory = std::move(event_log_factory); - } else { - pc_factory_dependencies.event_log_factory = - std::make_unique( - pc_factory_dependencies.task_queue_factory.get()); - } - peer_connection_factory_ = webrtc::CreateModularPeerConnectionFactory( - std::move(pc_factory_dependencies)); - - if (!peer_connection_factory_) { - return false; - } - if (options) { - peer_connection_factory_->SetOptions(*options); - } - if (config) { - sdp_semantics_ = config->sdp_semantics; - } - - dependencies.allocator = std::move(port_allocator); - peer_connection_ = CreatePeerConnection(config, std::move(dependencies)); - return peer_connection_.get() != nullptr; - } - - rtc::scoped_refptr CreatePeerConnection( - const PeerConnectionInterface::RTCConfiguration* config, - webrtc::PeerConnectionDependencies dependencies) { - PeerConnectionInterface::RTCConfiguration modified_config; - // If |config| is null, this will result in a default configuration being - // used. - if (config) { - modified_config = *config; - } - // Disable resolution adaptation; we don't want it interfering with the - // test results. - // TODO(deadbeef): Do something more robust. Since we're testing for aspect - // ratios and not specific resolutions, is this even necessary? - modified_config.set_cpu_adaptation(false); - - dependencies.observer = this; - return peer_connection_factory_->CreatePeerConnection( - modified_config, std::move(dependencies)); - } - - void set_signaling_message_receiver( - SignalingMessageReceiver* signaling_message_receiver) { - signaling_message_receiver_ = signaling_message_receiver; - } - - void set_signaling_delay_ms(int delay_ms) { signaling_delay_ms_ = delay_ms; } - - void set_signal_ice_candidates(bool signal) { - signal_ice_candidates_ = signal; - } - - rtc::scoped_refptr CreateLocalVideoTrackInternal( - webrtc::FakePeriodicVideoSource::Config config) { - // Set max frame rate to 10fps to reduce the risk of test flakiness. - // TODO(deadbeef): Do something more robust. - config.frame_interval_ms = 100; - - video_track_sources_.emplace_back( - new rtc::RefCountedObject( - config, false /* remote */)); - rtc::scoped_refptr track( - peer_connection_factory_->CreateVideoTrack( - rtc::CreateRandomUuid(), video_track_sources_.back())); - if (!local_video_renderer_) { - local_video_renderer_.reset(new webrtc::FakeVideoTrackRenderer(track)); - } - return track; - } - - void HandleIncomingOffer(const std::string& msg) { - RTC_LOG(LS_INFO) << debug_name_ << ": HandleIncomingOffer"; - std::unique_ptr desc = - webrtc::CreateSessionDescription(SdpType::kOffer, msg); - if (received_sdp_munger_) { - received_sdp_munger_(desc->description()); - } - - EXPECT_TRUE(SetRemoteDescription(std::move(desc))); - // Setting a remote description may have changed the number of receivers, - // so reset the receiver observers. 
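    // (The observers are reset before the answer is created and sent below,
    // so they always track the receivers that exist after this renegotiation.)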
- ResetRtpReceiverObservers(); - if (remote_offer_handler_) { - remote_offer_handler_(); - } - auto answer = CreateAnswer(); - ASSERT_NE(nullptr, answer); - EXPECT_TRUE(SetLocalDescriptionAndSendSdpMessage(std::move(answer))); - } - - void HandleIncomingAnswer(const std::string& msg) { - RTC_LOG(LS_INFO) << debug_name_ << ": HandleIncomingAnswer"; - std::unique_ptr desc = - webrtc::CreateSessionDescription(SdpType::kAnswer, msg); - if (received_sdp_munger_) { - received_sdp_munger_(desc->description()); - } - - EXPECT_TRUE(SetRemoteDescription(std::move(desc))); - // Set the RtpReceiverObserver after receivers are created. - ResetRtpReceiverObservers(); - } - - // Returns null on failure. - std::unique_ptr CreateAnswer() { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); - pc()->CreateAnswer(observer, offer_answer_options_); - return WaitForDescriptionFromObserver(observer); - } - - std::unique_ptr WaitForDescriptionFromObserver( - MockCreateSessionDescriptionObserver* observer) { - EXPECT_EQ_WAIT(true, observer->called(), kDefaultTimeout); - if (!observer->result()) { - return nullptr; - } - auto description = observer->MoveDescription(); - if (generated_sdp_munger_) { - generated_sdp_munger_(description->description()); - } - return description; - } - - // Setting the local description and sending the SDP message over the fake - // signaling channel are combined into the same method because the SDP - // message needs to be sent as soon as SetLocalDescription finishes, without - // waiting for the observer to be called. This ensures that ICE candidates - // don't outrace the description. - bool SetLocalDescriptionAndSendSdpMessage( - std::unique_ptr desc) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); - RTC_LOG(LS_INFO) << debug_name_ << ": SetLocalDescriptionAndSendSdpMessage"; - SdpType type = desc->GetType(); - std::string sdp; - EXPECT_TRUE(desc->ToString(&sdp)); - RTC_LOG(LS_INFO) << debug_name_ << ": local SDP contents=\n" << sdp; - pc()->SetLocalDescription(observer, desc.release()); - RemoveUnusedVideoRenderers(); - // As mentioned above, we need to send the message immediately after - // SetLocalDescription. - SendSdpMessage(type, sdp); - EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); - return true; - } - - bool SetRemoteDescription(std::unique_ptr desc) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); - RTC_LOG(LS_INFO) << debug_name_ << ": SetRemoteDescription"; - pc()->SetRemoteDescription(observer, desc.release()); - RemoveUnusedVideoRenderers(); - EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); - return observer->result(); - } - - // This is a work around to remove unused fake_video_renderers from - // transceivers that have either stopped or are no longer receiving. - void RemoveUnusedVideoRenderers() { - if (sdp_semantics_ != SdpSemantics::kUnifiedPlan) { - return; - } - auto transceivers = pc()->GetTransceivers(); - std::set active_renderers; - for (auto& transceiver : transceivers) { - // Note - we don't check for direction here. This function is called - // before direction is set, and in that case, we should not remove - // the renderer. - if (transceiver->receiver()->media_type() == cricket::MEDIA_TYPE_VIDEO) { - active_renderers.insert(transceiver->receiver()->track()->id()); - } - } - for (auto it = fake_video_renderers_.begin(); - it != fake_video_renderers_.end();) { - // Remove fake video renderers belonging to any non-active transceivers. 
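      // (erase() returns an iterator to the next element, so the loop only
      // advances via ++it when the current entry is kept.)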
- if (!active_renderers.count(it->first)) { - it = fake_video_renderers_.erase(it); - } else { - it++; - } - } - } - - // Simulate sending a blob of SDP with delay |signaling_delay_ms_| (0 by - // default). - void SendSdpMessage(SdpType type, const std::string& msg) { - if (signaling_delay_ms_ == 0) { - RelaySdpMessageIfReceiverExists(type, msg); - } else { - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, rtc::Thread::Current(), - rtc::Bind(&PeerConnectionWrapper::RelaySdpMessageIfReceiverExists, - this, type, msg), - signaling_delay_ms_); - } - } - - void RelaySdpMessageIfReceiverExists(SdpType type, const std::string& msg) { - if (signaling_message_receiver_) { - signaling_message_receiver_->ReceiveSdpMessage(type, msg); - } - } - - // Simulate trickling an ICE candidate with delay |signaling_delay_ms_| (0 by - // default). - void SendIceMessage(const std::string& sdp_mid, - int sdp_mline_index, - const std::string& msg) { - if (signaling_delay_ms_ == 0) { - RelayIceMessageIfReceiverExists(sdp_mid, sdp_mline_index, msg); - } else { - invoker_.AsyncInvokeDelayed( - RTC_FROM_HERE, rtc::Thread::Current(), - rtc::Bind(&PeerConnectionWrapper::RelayIceMessageIfReceiverExists, - this, sdp_mid, sdp_mline_index, msg), - signaling_delay_ms_); - } - } - - void RelayIceMessageIfReceiverExists(const std::string& sdp_mid, - int sdp_mline_index, - const std::string& msg) { - if (signaling_message_receiver_) { - signaling_message_receiver_->ReceiveIceMessage(sdp_mid, sdp_mline_index, - msg); - } - } - - // SignalingMessageReceiver callbacks. - void ReceiveSdpMessage(SdpType type, const std::string& msg) override { - if (type == SdpType::kOffer) { - HandleIncomingOffer(msg); - } else { - HandleIncomingAnswer(msg); - } - } - - void ReceiveIceMessage(const std::string& sdp_mid, - int sdp_mline_index, - const std::string& msg) override { - RTC_LOG(LS_INFO) << debug_name_ << ": ReceiveIceMessage"; - std::unique_ptr candidate( - webrtc::CreateIceCandidate(sdp_mid, sdp_mline_index, msg, nullptr)); - EXPECT_TRUE(pc()->AddIceCandidate(candidate.get())); - } - - // PeerConnectionObserver callbacks. 
- void OnSignalingChange( - webrtc::PeerConnectionInterface::SignalingState new_state) override { - EXPECT_EQ(pc()->signaling_state(), new_state); - peer_connection_signaling_state_history_.push_back(new_state); - } - void OnAddTrack(rtc::scoped_refptr receiver, - const std::vector>& - streams) override { - if (receiver->media_type() == cricket::MEDIA_TYPE_VIDEO) { - rtc::scoped_refptr video_track( - static_cast(receiver->track().get())); - ASSERT_TRUE(fake_video_renderers_.find(video_track->id()) == - fake_video_renderers_.end()); - fake_video_renderers_[video_track->id()] = - std::make_unique(video_track); - } - } - void OnRemoveTrack( - rtc::scoped_refptr receiver) override { - if (receiver->media_type() == cricket::MEDIA_TYPE_VIDEO) { - auto it = fake_video_renderers_.find(receiver->track()->id()); - if (it != fake_video_renderers_.end()) { - fake_video_renderers_.erase(it); - } else { - RTC_LOG(LS_ERROR) << "OnRemoveTrack called for non-active renderer"; - } - } - } - void OnRenegotiationNeeded() override {} - void OnIceConnectionChange( - webrtc::PeerConnectionInterface::IceConnectionState new_state) override { - EXPECT_EQ(pc()->ice_connection_state(), new_state); - ice_connection_state_history_.push_back(new_state); - } - void OnStandardizedIceConnectionChange( - webrtc::PeerConnectionInterface::IceConnectionState new_state) override { - standardized_ice_connection_state_history_.push_back(new_state); - } - void OnConnectionChange( - webrtc::PeerConnectionInterface::PeerConnectionState new_state) override { - peer_connection_state_history_.push_back(new_state); - } - - void OnIceGatheringChange( - webrtc::PeerConnectionInterface::IceGatheringState new_state) override { - EXPECT_EQ(pc()->ice_gathering_state(), new_state); - ice_gathering_state_history_.push_back(new_state); - } - - void OnIceSelectedCandidatePairChanged( - const cricket::CandidatePairChangeEvent& event) { - ice_candidate_pair_change_history_.push_back(event); - } - - void OnIceCandidate(const webrtc::IceCandidateInterface* candidate) override { - RTC_LOG(LS_INFO) << debug_name_ << ": OnIceCandidate"; - - if (remote_async_resolver_) { - const auto& local_candidate = candidate->candidate(); - if (local_candidate.address().IsUnresolvedIP()) { - RTC_DCHECK(local_candidate.type() == cricket::LOCAL_PORT_TYPE); - rtc::SocketAddress resolved_addr(local_candidate.address()); - const auto resolved_ip = mdns_responder_->GetMappedAddressForName( - local_candidate.address().hostname()); - RTC_DCHECK(!resolved_ip.IsNil()); - resolved_addr.SetResolvedIP(resolved_ip); - EXPECT_CALL(*remote_async_resolver_, GetResolvedAddress(_, _)) - .WillOnce(DoAll(SetArgPointee<1>(resolved_addr), Return(true))); - EXPECT_CALL(*remote_async_resolver_, Destroy(_)); - } - } - - std::string ice_sdp; - EXPECT_TRUE(candidate->ToString(&ice_sdp)); - if (signaling_message_receiver_ == nullptr || !signal_ice_candidates_) { - // Remote party may be deleted. 
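      // (In tests that tear one side down this wrapper can outlive its peer,
      // so locally gathered candidates are dropped instead of relayed.)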
- return; - } - SendIceMessage(candidate->sdp_mid(), candidate->sdp_mline_index(), ice_sdp); - last_candidate_gathered_ = candidate->candidate(); - } - void OnIceCandidateError(const std::string& address, - int port, - const std::string& url, - int error_code, - const std::string& error_text) override { - error_event_ = cricket::IceCandidateErrorEvent(address, port, url, - error_code, error_text); - } - void OnDataChannel( - rtc::scoped_refptr data_channel) override { - RTC_LOG(LS_INFO) << debug_name_ << ": OnDataChannel"; - data_channel_ = data_channel; - data_observer_.reset(new MockDataChannelObserver(data_channel)); - } - - std::string debug_name_; - - std::unique_ptr fake_network_manager_; - // Reference to the mDNS responder owned by |fake_network_manager_| after set. - webrtc::FakeMdnsResponder* mdns_responder_ = nullptr; - - rtc::scoped_refptr peer_connection_; - rtc::scoped_refptr - peer_connection_factory_; - - cricket::PortAllocator* port_allocator_; - // Needed to keep track of number of frames sent. - rtc::scoped_refptr fake_audio_capture_module_; - // Needed to keep track of number of frames received. - std::map> - fake_video_renderers_; - // Needed to ensure frames aren't received for removed tracks. - std::vector> - removed_fake_video_renderers_; - - // For remote peer communication. - SignalingMessageReceiver* signaling_message_receiver_ = nullptr; - int signaling_delay_ms_ = 0; - bool signal_ice_candidates_ = true; - cricket::Candidate last_candidate_gathered_; - cricket::IceCandidateErrorEvent error_event_; - - // Store references to the video sources we've created, so that we can stop - // them, if required. - std::vector> - video_track_sources_; - // |local_video_renderer_| attached to the first created local video track. - std::unique_ptr local_video_renderer_; - - SdpSemantics sdp_semantics_; - PeerConnectionInterface::RTCOfferAnswerOptions offer_answer_options_; - std::function received_sdp_munger_; - std::function generated_sdp_munger_; - std::function remote_offer_handler_; - rtc::MockAsyncResolver* remote_async_resolver_ = nullptr; - rtc::scoped_refptr data_channel_; - std::unique_ptr data_observer_; - - std::vector> rtp_receiver_observers_; - - std::vector - ice_connection_state_history_; - std::vector - standardized_ice_connection_state_history_; - std::vector - peer_connection_state_history_; - std::vector - ice_gathering_state_history_; - std::vector - ice_candidate_pair_change_history_; - std::vector - peer_connection_signaling_state_history_; - webrtc::FakeRtcEventLogFactory* event_log_factory_; - - // Variables for tracking delay stats on an audio track - int audio_packets_stat_ = 0; - double audio_delay_stat_ = 0.0; - std::string rtp_stats_id_; - std::string audio_track_stats_id_; - - rtc::AsyncInvoker invoker_; - - friend class PeerConnectionIntegrationBaseTest; -}; - -class MockRtcEventLogOutput : public webrtc::RtcEventLogOutput { - public: - virtual ~MockRtcEventLogOutput() = default; - MOCK_METHOD(bool, IsActive, (), (const, override)); - MOCK_METHOD(bool, Write, (const std::string&), (override)); -}; - -// This helper object is used for both specifying how many audio/video frames -// are expected to be received for a caller/callee. It provides helper functions -// to specify these expectations. The object initially starts in a state of no -// expectations. 
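// A simplified standalone sketch of the expectation scheme that the
// MediaExpectations class below implements and that ExpectNewFrames() later
// consumes (illustrative names only, standard C++, not part of this patch):
// each stream is tri-state, and the target frame count is the current count
// plus the newly expected frames, or left unchanged when no new frames are
// allowed.
enum class FrameExpectation { kNoExpectation, kExpectSomeFrames, kExpectNoFrames };

inline int TargetFrameCount(int current_frames,
                            FrameExpectation expectation,
                            int expected_new_frames) {
  // Only "expect some frames" raises the target; "no expectation" and
  // "expect no frames" both keep it at the current count. For the latter,
  // the caller additionally checks equality after waiting.
  return expectation == FrameExpectation::kExpectSomeFrames
             ? current_frames + expected_new_frames
             : current_frames;
}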
-class MediaExpectations { - public: - enum ExpectFrames { - kExpectSomeFrames, - kExpectNoFrames, - kNoExpectation, - }; - - void ExpectBidirectionalAudioAndVideo() { - ExpectBidirectionalAudio(); - ExpectBidirectionalVideo(); - } - - void ExpectBidirectionalAudio() { - CallerExpectsSomeAudio(); - CalleeExpectsSomeAudio(); - } - - void ExpectNoAudio() { - CallerExpectsNoAudio(); - CalleeExpectsNoAudio(); - } - - void ExpectBidirectionalVideo() { - CallerExpectsSomeVideo(); - CalleeExpectsSomeVideo(); - } - - void ExpectNoVideo() { - CallerExpectsNoVideo(); - CalleeExpectsNoVideo(); - } - - void CallerExpectsSomeAudioAndVideo() { - CallerExpectsSomeAudio(); - CallerExpectsSomeVideo(); - } - - void CalleeExpectsSomeAudioAndVideo() { - CalleeExpectsSomeAudio(); - CalleeExpectsSomeVideo(); - } - - // Caller's audio functions. - void CallerExpectsSomeAudio( - int expected_audio_frames = kDefaultExpectedAudioFrameCount) { - caller_audio_expectation_ = kExpectSomeFrames; - caller_audio_frames_expected_ = expected_audio_frames; - } - - void CallerExpectsNoAudio() { - caller_audio_expectation_ = kExpectNoFrames; - caller_audio_frames_expected_ = 0; - } - - // Caller's video functions. - void CallerExpectsSomeVideo( - int expected_video_frames = kDefaultExpectedVideoFrameCount) { - caller_video_expectation_ = kExpectSomeFrames; - caller_video_frames_expected_ = expected_video_frames; - } - - void CallerExpectsNoVideo() { - caller_video_expectation_ = kExpectNoFrames; - caller_video_frames_expected_ = 0; - } - - // Callee's audio functions. - void CalleeExpectsSomeAudio( - int expected_audio_frames = kDefaultExpectedAudioFrameCount) { - callee_audio_expectation_ = kExpectSomeFrames; - callee_audio_frames_expected_ = expected_audio_frames; - } - - void CalleeExpectsNoAudio() { - callee_audio_expectation_ = kExpectNoFrames; - callee_audio_frames_expected_ = 0; - } - - // Callee's video functions. 
- void CalleeExpectsSomeVideo( - int expected_video_frames = kDefaultExpectedVideoFrameCount) { - callee_video_expectation_ = kExpectSomeFrames; - callee_video_frames_expected_ = expected_video_frames; - } - - void CalleeExpectsNoVideo() { - callee_video_expectation_ = kExpectNoFrames; - callee_video_frames_expected_ = 0; - } - - ExpectFrames caller_audio_expectation_ = kNoExpectation; - ExpectFrames caller_video_expectation_ = kNoExpectation; - ExpectFrames callee_audio_expectation_ = kNoExpectation; - ExpectFrames callee_video_expectation_ = kNoExpectation; - int caller_audio_frames_expected_ = 0; - int caller_video_frames_expected_ = 0; - int callee_audio_frames_expected_ = 0; - int callee_video_frames_expected_ = 0; -}; - -class MockIceTransport : public webrtc::IceTransportInterface { - public: - MockIceTransport(const std::string& name, int component) - : internal_(std::make_unique( - name, - component, - nullptr /* network_thread */)) {} - ~MockIceTransport() = default; - cricket::IceTransportInternal* internal() { return internal_.get(); } - - private: - std::unique_ptr internal_; -}; - -class MockIceTransportFactory : public IceTransportFactory { - public: - ~MockIceTransportFactory() override = default; - rtc::scoped_refptr CreateIceTransport( - const std::string& transport_name, - int component, - IceTransportInit init) { - RecordIceTransportCreated(); - return new rtc::RefCountedObject(transport_name, - component); - } - MOCK_METHOD(void, RecordIceTransportCreated, ()); -}; - -// Tests two PeerConnections connecting to each other end-to-end, using a -// virtual network, fake A/V capture and fake encoder/decoders. The -// PeerConnections share the threads/socket servers, but use separate versions -// of everything else (including "PeerConnectionFactory"s). -class PeerConnectionIntegrationBaseTest : public ::testing::Test { - public: - explicit PeerConnectionIntegrationBaseTest(SdpSemantics sdp_semantics) - : sdp_semantics_(sdp_semantics), - ss_(new rtc::VirtualSocketServer()), - fss_(new rtc::FirewallSocketServer(ss_.get())), - network_thread_(new rtc::Thread(fss_.get())), - worker_thread_(rtc::Thread::Create()) { - network_thread_->SetName("PCNetworkThread", this); - worker_thread_->SetName("PCWorkerThread", this); - RTC_CHECK(network_thread_->Start()); - RTC_CHECK(worker_thread_->Start()); - webrtc::metrics::Reset(); - } - - ~PeerConnectionIntegrationBaseTest() { - // The PeerConnections should be deleted before the TurnCustomizers. - // A TurnPort is created with a raw pointer to a TurnCustomizer. The - // TurnPort has the same lifetime as the PeerConnection, so it's expected - // that the TurnCustomizer outlives the life of the PeerConnection or else - // when Send() is called it will hit a seg fault. - if (caller_) { - caller_->set_signaling_message_receiver(nullptr); - delete SetCallerPcWrapperAndReturnCurrent(nullptr); - } - if (callee_) { - callee_->set_signaling_message_receiver(nullptr); - delete SetCalleePcWrapperAndReturnCurrent(nullptr); - } - - // If turn servers were created for the test they need to be destroyed on - // the network thread. - network_thread()->Invoke(RTC_FROM_HERE, [this] { - turn_servers_.clear(); - turn_customizers_.clear(); - }); - } - - bool SignalingStateStable() { - return caller_->SignalingStateStable() && callee_->SignalingStateStable(); - } - - bool DtlsConnected() { - // TODO(deadbeef): kIceConnectionConnected currently means both ICE and DTLS - // are connected. This is an important distinction. 
Once we have separate - // ICE and DTLS state, this check needs to use the DTLS state. - return (callee()->ice_connection_state() == - webrtc::PeerConnectionInterface::kIceConnectionConnected || - callee()->ice_connection_state() == - webrtc::PeerConnectionInterface::kIceConnectionCompleted) && - (caller()->ice_connection_state() == - webrtc::PeerConnectionInterface::kIceConnectionConnected || - caller()->ice_connection_state() == - webrtc::PeerConnectionInterface::kIceConnectionCompleted); - } - - // When |event_log_factory| is null, the default implementation of the event - // log factory will be used. - std::unique_ptr CreatePeerConnectionWrapper( - const std::string& debug_name, - const PeerConnectionFactory::Options* options, - const RTCConfiguration* config, - webrtc::PeerConnectionDependencies dependencies, - std::unique_ptr event_log_factory, - bool reset_encoder_factory, - bool reset_decoder_factory) { - RTCConfiguration modified_config; - if (config) { - modified_config = *config; - } - modified_config.sdp_semantics = sdp_semantics_; - if (!dependencies.cert_generator) { - dependencies.cert_generator = - std::make_unique(); - } - std::unique_ptr client( - new PeerConnectionWrapper(debug_name)); - - if (!client->Init(options, &modified_config, std::move(dependencies), - network_thread_.get(), worker_thread_.get(), - std::move(event_log_factory), reset_encoder_factory, - reset_decoder_factory)) { - return nullptr; - } - return client; - } - - std::unique_ptr - CreatePeerConnectionWrapperWithFakeRtcEventLog( - const std::string& debug_name, - const PeerConnectionFactory::Options* options, - const RTCConfiguration* config, - webrtc::PeerConnectionDependencies dependencies) { - std::unique_ptr event_log_factory( - new webrtc::FakeRtcEventLogFactory(rtc::Thread::Current())); - return CreatePeerConnectionWrapper(debug_name, options, config, - std::move(dependencies), - std::move(event_log_factory), - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - } - - bool CreatePeerConnectionWrappers() { - return CreatePeerConnectionWrappersWithConfig( - PeerConnectionInterface::RTCConfiguration(), - PeerConnectionInterface::RTCConfiguration()); - } - - bool CreatePeerConnectionWrappersWithSdpSemantics( - SdpSemantics caller_semantics, - SdpSemantics callee_semantics) { - // Can't specify the sdp_semantics in the passed-in configuration since it - // will be overwritten by CreatePeerConnectionWrapper with whatever is - // stored in sdp_semantics_. So get around this by modifying the instance - // variable before calling CreatePeerConnectionWrapper for the caller and - // callee PeerConnections. 
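    // (The original value is saved below and restored once both wrappers have
    // been created, so other tests are unaffected.)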
- SdpSemantics original_semantics = sdp_semantics_; - sdp_semantics_ = caller_semantics; - caller_ = CreatePeerConnectionWrapper( - "Caller", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - sdp_semantics_ = callee_semantics; - callee_ = CreatePeerConnectionWrapper( - "Callee", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - sdp_semantics_ = original_semantics; - return caller_ && callee_; - } - - bool CreatePeerConnectionWrappersWithConfig( - const PeerConnectionInterface::RTCConfiguration& caller_config, - const PeerConnectionInterface::RTCConfiguration& callee_config) { - caller_ = CreatePeerConnectionWrapper( - "Caller", nullptr, &caller_config, - webrtc::PeerConnectionDependencies(nullptr), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - callee_ = CreatePeerConnectionWrapper( - "Callee", nullptr, &callee_config, - webrtc::PeerConnectionDependencies(nullptr), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - return caller_ && callee_; - } - - bool CreatePeerConnectionWrappersWithConfigAndDeps( - const PeerConnectionInterface::RTCConfiguration& caller_config, - webrtc::PeerConnectionDependencies caller_dependencies, - const PeerConnectionInterface::RTCConfiguration& callee_config, - webrtc::PeerConnectionDependencies callee_dependencies) { - caller_ = - CreatePeerConnectionWrapper("Caller", nullptr, &caller_config, - std::move(caller_dependencies), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - callee_ = - CreatePeerConnectionWrapper("Callee", nullptr, &callee_config, - std::move(callee_dependencies), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - return caller_ && callee_; - } - - bool CreatePeerConnectionWrappersWithOptions( - const PeerConnectionFactory::Options& caller_options, - const PeerConnectionFactory::Options& callee_options) { - caller_ = CreatePeerConnectionWrapper( - "Caller", &caller_options, nullptr, - webrtc::PeerConnectionDependencies(nullptr), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - callee_ = CreatePeerConnectionWrapper( - "Callee", &callee_options, nullptr, - webrtc::PeerConnectionDependencies(nullptr), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - return caller_ && callee_; - } - - bool CreatePeerConnectionWrappersWithFakeRtcEventLog() { - PeerConnectionInterface::RTCConfiguration default_config; - caller_ = CreatePeerConnectionWrapperWithFakeRtcEventLog( - "Caller", nullptr, &default_config, - webrtc::PeerConnectionDependencies(nullptr)); - callee_ = CreatePeerConnectionWrapperWithFakeRtcEventLog( - "Callee", nullptr, &default_config, - webrtc::PeerConnectionDependencies(nullptr)); - return caller_ && callee_; - } - - std::unique_ptr - CreatePeerConnectionWrapperWithAlternateKey() { - std::unique_ptr cert_generator( - new FakeRTCCertificateGenerator()); - cert_generator->use_alternate_key(); - - webrtc::PeerConnectionDependencies dependencies(nullptr); - dependencies.cert_generator = std::move(cert_generator); - return CreatePeerConnectionWrapper("New Peer", nullptr, nullptr, - std::move(dependencies), nullptr, - /*reset_encoder_factory=*/false, - /*reset_decoder_factory=*/false); - } - - bool CreateOneDirectionalPeerConnectionWrappers(bool caller_to_callee) { - 
caller_ = CreatePeerConnectionWrapper( - "Caller", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr, - /*reset_encoder_factory=*/!caller_to_callee, - /*reset_decoder_factory=*/caller_to_callee); - callee_ = CreatePeerConnectionWrapper( - "Callee", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), - nullptr, - /*reset_encoder_factory=*/caller_to_callee, - /*reset_decoder_factory=*/!caller_to_callee); - return caller_ && callee_; - } - - cricket::TestTurnServer* CreateTurnServer( - rtc::SocketAddress internal_address, - rtc::SocketAddress external_address, - cricket::ProtocolType type = cricket::ProtocolType::PROTO_UDP, - const std::string& common_name = "test turn server") { - rtc::Thread* thread = network_thread(); - std::unique_ptr turn_server = - network_thread()->Invoke>( - RTC_FROM_HERE, - [thread, internal_address, external_address, type, common_name] { - return std::make_unique( - thread, internal_address, external_address, type, - /*ignore_bad_certs=*/true, common_name); - }); - turn_servers_.push_back(std::move(turn_server)); - // Interactions with the turn server should be done on the network thread. - return turn_servers_.back().get(); - } - - cricket::TestTurnCustomizer* CreateTurnCustomizer() { - std::unique_ptr turn_customizer = - network_thread()->Invoke>( - RTC_FROM_HERE, - [] { return std::make_unique(); }); - turn_customizers_.push_back(std::move(turn_customizer)); - // Interactions with the turn customizer should be done on the network - // thread. - return turn_customizers_.back().get(); - } - - // Checks that the function counters for a TestTurnCustomizer are greater than - // 0. - void ExpectTurnCustomizerCountersIncremented( - cricket::TestTurnCustomizer* turn_customizer) { - unsigned int allow_channel_data_counter = - network_thread()->Invoke( - RTC_FROM_HERE, [turn_customizer] { - return turn_customizer->allow_channel_data_cnt_; - }); - EXPECT_GT(allow_channel_data_counter, 0u); - unsigned int modify_counter = network_thread()->Invoke( - RTC_FROM_HERE, - [turn_customizer] { return turn_customizer->modify_cnt_; }); - EXPECT_GT(modify_counter, 0u); - } - - // Once called, SDP blobs and ICE candidates will be automatically signaled - // between PeerConnections. - void ConnectFakeSignaling() { - caller_->set_signaling_message_receiver(callee_.get()); - callee_->set_signaling_message_receiver(caller_.get()); - } - - // Once called, SDP blobs will be automatically signaled between - // PeerConnections. Note that ICE candidates will not be signaled unless they - // are in the exchanged SDP blobs. - void ConnectFakeSignalingForSdpOnly() { - ConnectFakeSignaling(); - SetSignalIceCandidates(false); - } - - void SetSignalingDelayMs(int delay_ms) { - caller_->set_signaling_delay_ms(delay_ms); - callee_->set_signaling_delay_ms(delay_ms); - } - - void SetSignalIceCandidates(bool signal) { - caller_->set_signal_ice_candidates(signal); - callee_->set_signal_ice_candidates(signal); - } - - // Messages may get lost on the unreliable DataChannel, so we send multiple - // times to avoid test flakiness. 
- void SendRtpDataWithRetries(webrtc::DataChannelInterface* dc, - const std::string& data, - int retries) { - for (int i = 0; i < retries; ++i) { - dc->Send(DataBuffer(data)); - } - } - - rtc::Thread* network_thread() { return network_thread_.get(); } - - rtc::VirtualSocketServer* virtual_socket_server() { return ss_.get(); } - - PeerConnectionWrapper* caller() { return caller_.get(); } - - // Set the |caller_| to the |wrapper| passed in and return the - // original |caller_|. - PeerConnectionWrapper* SetCallerPcWrapperAndReturnCurrent( - PeerConnectionWrapper* wrapper) { - PeerConnectionWrapper* old = caller_.release(); - caller_.reset(wrapper); - return old; - } - - PeerConnectionWrapper* callee() { return callee_.get(); } - - // Set the |callee_| to the |wrapper| passed in and return the - // original |callee_|. - PeerConnectionWrapper* SetCalleePcWrapperAndReturnCurrent( - PeerConnectionWrapper* wrapper) { - PeerConnectionWrapper* old = callee_.release(); - callee_.reset(wrapper); - return old; - } - - void SetPortAllocatorFlags(uint32_t caller_flags, uint32_t callee_flags) { - network_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&cricket::PortAllocator::set_flags, - caller()->port_allocator(), caller_flags)); - network_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&cricket::PortAllocator::set_flags, - callee()->port_allocator(), callee_flags)); - } - - rtc::FirewallSocketServer* firewall() const { return fss_.get(); } - - // Expects the provided number of new frames to be received within - // kMaxWaitForFramesMs. The new expected frames are specified in - // |media_expectations|. Returns false if any of the expectations were - // not met. - bool ExpectNewFrames(const MediaExpectations& media_expectations) { - // Make sure there are no bogus tracks confusing the issue. - caller()->RemoveUnusedVideoRenderers(); - callee()->RemoveUnusedVideoRenderers(); - // First initialize the expected frame counts based upon the current - // frame count. - int total_caller_audio_frames_expected = caller()->audio_frames_received(); - if (media_expectations.caller_audio_expectation_ == - MediaExpectations::kExpectSomeFrames) { - total_caller_audio_frames_expected += - media_expectations.caller_audio_frames_expected_; - } - int total_caller_video_frames_expected = - caller()->min_video_frames_received_per_track(); - if (media_expectations.caller_video_expectation_ == - MediaExpectations::kExpectSomeFrames) { - total_caller_video_frames_expected += - media_expectations.caller_video_frames_expected_; - } - int total_callee_audio_frames_expected = callee()->audio_frames_received(); - if (media_expectations.callee_audio_expectation_ == - MediaExpectations::kExpectSomeFrames) { - total_callee_audio_frames_expected += - media_expectations.callee_audio_frames_expected_; - } - int total_callee_video_frames_expected = - callee()->min_video_frames_received_per_track(); - if (media_expectations.callee_video_expectation_ == - MediaExpectations::kExpectSomeFrames) { - total_callee_video_frames_expected += - media_expectations.callee_video_frames_expected_; - } - - // Wait for the expected frames. 
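    // (For streams marked "expect no frames" the target equals the starting
    // count, so the equality checks after the wait catch anything unexpected.)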
- EXPECT_TRUE_WAIT(caller()->audio_frames_received() >= - total_caller_audio_frames_expected && - caller()->min_video_frames_received_per_track() >= - total_caller_video_frames_expected && - callee()->audio_frames_received() >= - total_callee_audio_frames_expected && - callee()->min_video_frames_received_per_track() >= - total_callee_video_frames_expected, - kMaxWaitForFramesMs); - bool expectations_correct = - caller()->audio_frames_received() >= - total_caller_audio_frames_expected && - caller()->min_video_frames_received_per_track() >= - total_caller_video_frames_expected && - callee()->audio_frames_received() >= - total_callee_audio_frames_expected && - callee()->min_video_frames_received_per_track() >= - total_callee_video_frames_expected; - - // After the combined wait, print out a more detailed message upon - // failure. - EXPECT_GE(caller()->audio_frames_received(), - total_caller_audio_frames_expected); - EXPECT_GE(caller()->min_video_frames_received_per_track(), - total_caller_video_frames_expected); - EXPECT_GE(callee()->audio_frames_received(), - total_callee_audio_frames_expected); - EXPECT_GE(callee()->min_video_frames_received_per_track(), - total_callee_video_frames_expected); - - // We want to make sure nothing unexpected was received. - if (media_expectations.caller_audio_expectation_ == - MediaExpectations::kExpectNoFrames) { - EXPECT_EQ(caller()->audio_frames_received(), - total_caller_audio_frames_expected); - if (caller()->audio_frames_received() != - total_caller_audio_frames_expected) { - expectations_correct = false; - } - } - if (media_expectations.caller_video_expectation_ == - MediaExpectations::kExpectNoFrames) { - EXPECT_EQ(caller()->min_video_frames_received_per_track(), - total_caller_video_frames_expected); - if (caller()->min_video_frames_received_per_track() != - total_caller_video_frames_expected) { - expectations_correct = false; - } - } - if (media_expectations.callee_audio_expectation_ == - MediaExpectations::kExpectNoFrames) { - EXPECT_EQ(callee()->audio_frames_received(), - total_callee_audio_frames_expected); - if (callee()->audio_frames_received() != - total_callee_audio_frames_expected) { - expectations_correct = false; - } - } - if (media_expectations.callee_video_expectation_ == - MediaExpectations::kExpectNoFrames) { - EXPECT_EQ(callee()->min_video_frames_received_per_track(), - total_callee_video_frames_expected); - if (callee()->min_video_frames_received_per_track() != - total_callee_video_frames_expected) { - expectations_correct = false; - } - } - return expectations_correct; - } +#include - void ClosePeerConnections() { - caller()->pc()->Close(); - callee()->pc()->Close(); - } - - void TestNegotiatedCipherSuite( - const PeerConnectionFactory::Options& caller_options, - const PeerConnectionFactory::Options& callee_options, - int expected_cipher_suite) { - ASSERT_TRUE(CreatePeerConnectionWrappersWithOptions(caller_options, - callee_options)); - ConnectFakeSignaling(); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(DtlsConnected(), kDefaultTimeout); - EXPECT_EQ_WAIT(rtc::SrtpCryptoSuiteToName(expected_cipher_suite), - caller()->OldGetStats()->SrtpCipher(), kDefaultTimeout); - // TODO(bugs.webrtc.org/9456): Fix it. 
- EXPECT_METRIC_EQ(1, webrtc::metrics::NumEvents( - "WebRTC.PeerConnection.SrtpCryptoSuite.Audio", - expected_cipher_suite)); - } +#include +#include +#include +#include +#include +#include - void TestGcmNegotiationUsesCipherSuite(bool local_gcm_enabled, - bool remote_gcm_enabled, - bool aes_ctr_enabled, - int expected_cipher_suite) { - PeerConnectionFactory::Options caller_options; - caller_options.crypto_options.srtp.enable_gcm_crypto_suites = - local_gcm_enabled; - caller_options.crypto_options.srtp.enable_aes128_sha1_80_crypto_cipher = - aes_ctr_enabled; - PeerConnectionFactory::Options callee_options; - callee_options.crypto_options.srtp.enable_gcm_crypto_suites = - remote_gcm_enabled; - callee_options.crypto_options.srtp.enable_aes128_sha1_80_crypto_cipher = - aes_ctr_enabled; - TestNegotiatedCipherSuite(caller_options, callee_options, - expected_cipher_suite); - } +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/async_resolver_factory.h" +#include "api/candidate.h" +#include "api/crypto/crypto_options.h" +#include "api/dtmf_sender_interface.h" +#include "api/ice_transport_interface.h" +#include "api/jsep.h" +#include "api/media_stream_interface.h" +#include "api/media_types.h" +#include "api/peer_connection_interface.h" +#include "api/rtc_error.h" +#include "api/rtc_event_log/rtc_event.h" +#include "api/rtc_event_log/rtc_event_log.h" +#include "api/rtc_event_log_output.h" +#include "api/rtp_parameters.h" +#include "api/rtp_receiver_interface.h" +#include "api/rtp_sender_interface.h" +#include "api/rtp_transceiver_direction.h" +#include "api/rtp_transceiver_interface.h" +#include "api/scoped_refptr.h" +#include "api/stats/rtc_stats.h" +#include "api/stats/rtc_stats_report.h" +#include "api/stats/rtcstats_objects.h" +#include "api/transport/rtp/rtp_source.h" +#include "api/uma_metrics.h" +#include "api/units/time_delta.h" +#include "api/video/video_rotation.h" +#include "logging/rtc_event_log/fake_rtc_event_log.h" +#include "logging/rtc_event_log/fake_rtc_event_log_factory.h" +#include "media/base/codec.h" +#include "media/base/media_constants.h" +#include "media/base/stream_params.h" +#include "p2p/base/mock_async_resolver.h" +#include "p2p/base/port.h" +#include "p2p/base/port_allocator.h" +#include "p2p/base/port_interface.h" +#include "p2p/base/stun_server.h" +#include "p2p/base/test_stun_server.h" +#include "p2p/base/test_turn_customizer.h" +#include "p2p/base/test_turn_server.h" +#include "p2p/base/transport_description.h" +#include "p2p/base/transport_info.h" +#include "pc/media_session.h" +#include "pc/peer_connection.h" +#include "pc/peer_connection_factory.h" +#include "pc/session_description.h" +#include "pc/test/fake_periodic_video_source.h" +#include "pc/test/integration_test_helpers.h" +#include "pc/test/mock_peer_connection_observers.h" +#include "rtc_base/fake_clock.h" +#include "rtc_base/fake_mdns_responder.h" +#include "rtc_base/fake_network.h" +#include "rtc_base/firewall_socket_server.h" +#include "rtc_base/gunit.h" +#include "rtc_base/helpers.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/ssl_certificate.h" +#include "rtc_base/ssl_fingerprint.h" +#include "rtc_base/ssl_identity.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/test_certificate_verifier.h" +#include "rtc_base/thread.h" +#include "rtc_base/time_utils.h" +#include "rtc_base/virtual_socket_server.h" +#include 
"system_wrappers/include/metrics.h" - protected: - SdpSemantics sdp_semantics_; +namespace webrtc { - private: - // |ss_| is used by |network_thread_| so it must be destroyed later. - std::unique_ptr ss_; - std::unique_ptr fss_; - // |network_thread_| and |worker_thread_| are used by both - // |caller_| and |callee_| so they must be destroyed - // later. - std::unique_ptr network_thread_; - std::unique_ptr worker_thread_; - // The turn servers and turn customizers should be accessed & deleted on the - // network thread to avoid a race with the socket read/write that occurs - // on the network thread. - std::vector> turn_servers_; - std::vector> turn_customizers_; - std::unique_ptr caller_; - std::unique_ptr callee_; -}; +namespace { class PeerConnectionIntegrationTest : public PeerConnectionIntegrationBaseTest, @@ -1888,8 +205,8 @@ class DummyDtmfObserver : public DtmfSenderObserverInterface { // Assumes |sender| already has an audio track added and the offer/answer // exchange is done. -void TestDtmfFromSenderToReceiver(PeerConnectionWrapper* sender, - PeerConnectionWrapper* receiver) { +void TestDtmfFromSenderToReceiver(PeerConnectionIntegrationWrapper* sender, + PeerConnectionIntegrationWrapper* receiver) { // We should be able to get a DTMF sender from the local sender. rtc::scoped_refptr dtmf_sender = sender->pc()->GetSenders().at(0)->GetDtmfSender(); @@ -2302,7 +619,7 @@ TEST_P(PeerConnectionIntegrationTest, CallTransferredForCallee) { // Keep the original peer around which will still send packets to the // receiving client. These SRTP packets will be dropped. - std::unique_ptr original_peer( + std::unique_ptr original_peer( SetCallerPcWrapperAndReturnCurrent( CreatePeerConnectionWrapperWithAlternateKey().release())); // TODO(deadbeef): Why do we call Close here? That goes against the comment @@ -2331,7 +648,7 @@ TEST_P(PeerConnectionIntegrationTest, CallTransferredForCaller) { // Keep the original peer around which will still send packets to the // receiving client. These SRTP packets will be dropped. - std::unique_ptr original_peer( + std::unique_ptr original_peer( SetCalleePcWrapperAndReturnCurrent( CreatePeerConnectionWrapperWithAlternateKey().release())); // TODO(deadbeef): Why do we call Close here? That goes against the comment @@ -3515,424 +1832,6 @@ TEST_P(PeerConnectionIntegrationTest, EndToEndCallWithGcmCipher) { ASSERT_TRUE(ExpectNewFrames(media_expectations)); } -// This test sets up a call between two parties with audio, video and an RTP -// data channel. -TEST_P(PeerConnectionIntegrationTest, EndToEndCallWithRtpDataChannel) { - PeerConnectionInterface::RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - rtc_config.enable_dtls_srtp = false; - ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(rtc_config, rtc_config)); - ConnectFakeSignaling(); - // Expect that data channel created on caller side will show up for callee as - // well. - caller()->CreateDataChannel(); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - // Ensure the existence of the RTP data channel didn't impede audio/video. 
- MediaExpectations media_expectations; - media_expectations.ExpectBidirectionalAudioAndVideo(); - ASSERT_TRUE(ExpectNewFrames(media_expectations)); - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_NE(nullptr, callee()->data_channel()); - EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - // Ensure data can be sent in both directions. - std::string data = "hello world"; - SendRtpDataWithRetries(caller()->data_channel(), data, 5); - EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), - kDefaultTimeout); - SendRtpDataWithRetries(callee()->data_channel(), data, 5); - EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), - kDefaultTimeout); -} - -TEST_P(PeerConnectionIntegrationTest, RtpDataChannelWorksAfterRollback) { - PeerConnectionInterface::RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - rtc_config.enable_dtls_srtp = false; - ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(rtc_config, rtc_config)); - ConnectFakeSignaling(); - auto data_channel = caller()->pc()->CreateDataChannel("label_1", nullptr); - ASSERT_TRUE(data_channel.get() != nullptr); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - - caller()->CreateDataChannel("label_2", nullptr); - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); - caller()->pc()->SetLocalDescription(observer, - caller()->CreateOfferAndWait().release()); - EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); - caller()->Rollback(); - - std::string data = "hello world"; - SendRtpDataWithRetries(data_channel, data, 5); - EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), - kDefaultTimeout); -} - -// Ensure that an RTP data channel is signaled as closed for the caller when -// the callee rejects it in a subsequent offer. -TEST_P(PeerConnectionIntegrationTest, - RtpDataChannelSignaledClosedInCalleeOffer) { - // Same procedure as above test. - PeerConnectionInterface::RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - rtc_config.enable_dtls_srtp = false; - ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(rtc_config, rtc_config)); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_NE(nullptr, callee()->data_channel()); - ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - // Close the data channel on the callee, and do an updated offer/answer. - callee()->data_channel()->Close(); - callee()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - EXPECT_FALSE(caller()->data_observer()->IsOpen()); - EXPECT_FALSE(callee()->data_observer()->IsOpen()); -} - -// Tests that data is buffered in an RTP data channel until an observer is -// registered for it. -// -// NOTE: RTP data channels can receive data before the underlying -// transport has detected that a channel is writable and thus data can be -// received before the data channel state changes to open. That is hard to test -// but the same buffering is expected to be used in that case. 
-//
-// Use fake clock and simulated network delay so that we predictably can wait
-// until an SCTP message has been delivered without "sleep()"ing.
-TEST_P(PeerConnectionIntegrationTestWithFakeClock,
-       DataBufferedUntilRtpDataChannelObserverRegistered) {
-  virtual_socket_server()->set_delay_mean(5);  // 5 ms per hop.
-  virtual_socket_server()->UpdateDelayDistribution();
-
-  PeerConnectionInterface::RTCConfiguration rtc_config;
-  rtc_config.enable_rtp_data_channel = true;
-  rtc_config.enable_dtls_srtp = false;
-  ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(rtc_config, rtc_config));
-  ConnectFakeSignaling();
-  caller()->CreateDataChannel();
-  caller()->CreateAndSetAndSignalOffer();
-  ASSERT_TRUE(caller()->data_channel() != nullptr);
-  ASSERT_TRUE_SIMULATED_WAIT(callee()->data_channel() != nullptr,
-                             kDefaultTimeout, FakeClock());
-  ASSERT_TRUE_SIMULATED_WAIT(caller()->data_observer()->IsOpen(),
-                             kDefaultTimeout, FakeClock());
-  ASSERT_EQ_SIMULATED_WAIT(DataChannelInterface::kOpen,
-                           callee()->data_channel()->state(), kDefaultTimeout,
-                           FakeClock());
-
-  // Unregister the observer which is normally automatically registered.
-  callee()->data_channel()->UnregisterObserver();
-  // Send data and advance fake clock until it should have been received.
-  std::string data = "hello world";
-  caller()->data_channel()->Send(DataBuffer(data));
-  SIMULATED_WAIT(false, 50, FakeClock());
-
-  // Attach data channel and expect data to be received immediately. Note that
-  // EXPECT_EQ_WAIT is used, such that the simulated clock is not advanced any
-  // further, but data can be received even if the callback is asynchronous.
-  MockDataChannelObserver new_observer(callee()->data_channel());
-  EXPECT_EQ_SIMULATED_WAIT(data, new_observer.last_message(), kDefaultTimeout,
-                           FakeClock());
-}
-
-// This test sets up a call between two parties with audio and video, but only
-// the caller client supports RTP data channels.
-TEST_P(PeerConnectionIntegrationTest, RtpDataChannelsRejectedByCallee) {
-  PeerConnectionInterface::RTCConfiguration rtc_config_1;
-  rtc_config_1.enable_rtp_data_channel = true;
-  // Must disable DTLS to make negotiation succeed.
-  rtc_config_1.enable_dtls_srtp = false;
-  PeerConnectionInterface::RTCConfiguration rtc_config_2;
-  rtc_config_2.enable_dtls_srtp = false;
-  rtc_config_2.enable_dtls_srtp = false;
-  ASSERT_TRUE(
-      CreatePeerConnectionWrappersWithConfig(rtc_config_1, rtc_config_2));
-  ConnectFakeSignaling();
-  caller()->CreateDataChannel();
-  ASSERT_TRUE(caller()->data_channel() != nullptr);
-  caller()->AddAudioVideoTracks();
-  callee()->AddAudioVideoTracks();
-  caller()->CreateAndSetAndSignalOffer();
-  ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout);
-  // The caller should still have a data channel, but it should be closed, and
-  // none should ever have been created for the callee.
-  EXPECT_TRUE(caller()->data_channel() != nullptr);
-  EXPECT_FALSE(caller()->data_observer()->IsOpen());
-  EXPECT_EQ(nullptr, callee()->data_channel());
-}
-
-// This test sets up a call between two parties with audio and video. When
-// audio and video are set up and flowing, an RTP data channel is negotiated.
-TEST_P(PeerConnectionIntegrationTest, AddRtpDataChannelInSubsequentOffer) {
-  PeerConnectionInterface::RTCConfiguration rtc_config;
-  rtc_config.enable_rtp_data_channel = true;
-  rtc_config.enable_dtls_srtp = false;
-  ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(rtc_config, rtc_config));
-  ConnectFakeSignaling();
-  // Do initial offer/answer with audio/video.
- caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - // Create data channel and do new offer and answer. - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_NE(nullptr, callee()->data_channel()); - EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - // Ensure data can be sent in both directions. - std::string data = "hello world"; - SendRtpDataWithRetries(caller()->data_channel(), data, 5); - EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), - kDefaultTimeout); - SendRtpDataWithRetries(callee()->data_channel(), data, 5); - EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), - kDefaultTimeout); -} - -#ifdef HAVE_SCTP - -// This test sets up a call between two parties with audio, video and an SCTP -// data channel. -TEST_P(PeerConnectionIntegrationTest, EndToEndCallWithSctpDataChannel) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - // Expect that data channel created on caller side will show up for callee as - // well. - caller()->CreateDataChannel(); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - // Ensure the existence of the SCTP data channel didn't impede audio/video. - MediaExpectations media_expectations; - media_expectations.ExpectBidirectionalAudioAndVideo(); - ASSERT_TRUE(ExpectNewFrames(media_expectations)); - // Caller data channel should already exist (it created one). Callee data - // channel may not exist yet, since negotiation happens in-band, not in SDP. - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - // Ensure data can be sent in both directions. - std::string data = "hello world"; - caller()->data_channel()->Send(DataBuffer(data)); - EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), - kDefaultTimeout); - callee()->data_channel()->Send(DataBuffer(data)); - EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), - kDefaultTimeout); -} - -// Ensure that when the callee closes an SCTP data channel, the closing -// procedure results in the data channel being closed for the caller as well. -TEST_P(PeerConnectionIntegrationTest, CalleeClosesSctpDataChannel) { - // Same procedure as above test. - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - // Close the data channel on the callee side, and wait for it to reach the - // "closed" state on both sides. 
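  // (The closure is signaled in-band through the SCTP stream-reset procedure,
  // so the caller observes it without any further SDP exchange.)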
- callee()->data_channel()->Close(); - EXPECT_TRUE_WAIT(!caller()->data_observer()->IsOpen(), kDefaultTimeout); - EXPECT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); -} - -TEST_P(PeerConnectionIntegrationTest, SctpDataChannelConfigSentToOtherSide) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - webrtc::DataChannelInit init; - init.id = 53; - init.maxRetransmits = 52; - caller()->CreateDataChannel("data-channel", &init); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - // Since "negotiated" is false, the "id" parameter should be ignored. - EXPECT_NE(init.id, callee()->data_channel()->id()); - EXPECT_EQ("data-channel", callee()->data_channel()->label()); - EXPECT_EQ(init.maxRetransmits, callee()->data_channel()->maxRetransmits()); - EXPECT_FALSE(callee()->data_channel()->negotiated()); -} - -// Test usrsctp's ability to process unordered data stream, where data actually -// arrives out of order using simulated delays. Previously there have been some -// bugs in this area. -TEST_P(PeerConnectionIntegrationTest, StressTestUnorderedSctpDataChannel) { - // Introduce random network delays. - // Otherwise it's not a true "unordered" test. - virtual_socket_server()->set_delay_mean(20); - virtual_socket_server()->set_delay_stddev(5); - virtual_socket_server()->UpdateDelayDistribution(); - // Normal procedure, but with unordered data channel config. - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - webrtc::DataChannelInit init; - init.ordered = false; - caller()->CreateDataChannel(&init); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - static constexpr int kNumMessages = 100; - // Deliberately chosen to be larger than the MTU so messages get fragmented. - static constexpr size_t kMaxMessageSize = 4096; - // Create and send random messages. - std::vector sent_messages; - for (int i = 0; i < kNumMessages; ++i) { - size_t length = - (rand() % kMaxMessageSize) + 1; // NOLINT (rand_r instead of rand) - std::string message; - ASSERT_TRUE(rtc::CreateRandomString(length, &message)); - caller()->data_channel()->Send(DataBuffer(message)); - callee()->data_channel()->Send(DataBuffer(message)); - sent_messages.push_back(message); - } - - // Wait for all messages to be received. - EXPECT_EQ_WAIT(rtc::checked_cast(kNumMessages), - caller()->data_observer()->received_message_count(), - kDefaultTimeout); - EXPECT_EQ_WAIT(rtc::checked_cast(kNumMessages), - callee()->data_observer()->received_message_count(), - kDefaultTimeout); - - // Sort and compare to make sure none of the messages were corrupted. 
- std::vector caller_received_messages = - caller()->data_observer()->messages(); - std::vector callee_received_messages = - callee()->data_observer()->messages(); - absl::c_sort(sent_messages); - absl::c_sort(caller_received_messages); - absl::c_sort(callee_received_messages); - EXPECT_EQ(sent_messages, caller_received_messages); - EXPECT_EQ(sent_messages, callee_received_messages); -} - -// This test sets up a call between two parties with audio, and video. When -// audio and video are setup and flowing, an SCTP data channel is negotiated. -TEST_P(PeerConnectionIntegrationTest, AddSctpDataChannelInSubsequentOffer) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - // Do initial offer/answer with audio/video. - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - // Create data channel and do new offer and answer. - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - // Caller data channel should already exist (it created one). Callee data - // channel may not exist yet, since negotiation happens in-band, not in SDP. - ASSERT_NE(nullptr, caller()->data_channel()); - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - // Ensure data can be sent in both directions. - std::string data = "hello world"; - caller()->data_channel()->Send(DataBuffer(data)); - EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), - kDefaultTimeout); - callee()->data_channel()->Send(DataBuffer(data)); - EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), - kDefaultTimeout); -} - -// Set up a connection initially just using SCTP data channels, later upgrading -// to audio/video, ensuring frames are received end-to-end. Effectively the -// inverse of the test above. -// This was broken in M57; see https://crbug.com/711243 -TEST_P(PeerConnectionIntegrationTest, SctpDataChannelToAudioVideoUpgrade) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - // Do initial offer/answer with just data channel. - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - // Wait until data can be sent over the data channel. - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - ASSERT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - // Do subsequent offer/answer with two-way audio and video. Audio and video - // should end up bundled on the DTLS/ICE transport already used for data. 
- caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - MediaExpectations media_expectations; - media_expectations.ExpectBidirectionalAudioAndVideo(); - ASSERT_TRUE(ExpectNewFrames(media_expectations)); -} - -static void MakeSpecCompliantSctpOffer(cricket::SessionDescription* desc) { - cricket::SctpDataContentDescription* dcd_offer = - GetFirstSctpDataContentDescription(desc); - // See https://crbug.com/webrtc/11211 - this function is a no-op - ASSERT_TRUE(dcd_offer); - dcd_offer->set_use_sctpmap(false); - dcd_offer->set_protocol("UDP/DTLS/SCTP"); -} - -// Test that the data channel works when a spec-compliant SCTP m= section is -// offered (using "a=sctp-port" instead of "a=sctpmap", and using -// "UDP/DTLS/SCTP" as the protocol). -TEST_P(PeerConnectionIntegrationTest, - DataChannelWorksWhenSpecCompliantSctpOfferReceived) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->SetGeneratedSdpMunger(MakeSpecCompliantSctpOffer); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_channel() != nullptr, kDefaultTimeout); - EXPECT_TRUE_WAIT(caller()->data_observer()->IsOpen(), kDefaultTimeout); - EXPECT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - - // Ensure data can be sent in both directions. - std::string data = "hello world"; - caller()->data_channel()->Send(DataBuffer(data)); - EXPECT_EQ_WAIT(data, callee()->data_observer()->last_message(), - kDefaultTimeout); - callee()->data_channel()->Send(DataBuffer(data)); - EXPECT_EQ_WAIT(data, caller()->data_observer()->last_message(), - kDefaultTimeout); -} - -#endif // HAVE_SCTP - // Test that the ICE connection and gathering states eventually reach // "complete". TEST_P(PeerConnectionIntegrationTest, IceStatesReachCompletion) { @@ -3958,14 +1857,25 @@ TEST_P(PeerConnectionIntegrationTest, IceStatesReachCompletion) { callee()->ice_connection_state(), kDefaultTimeout); } +#if !defined(THREAD_SANITIZER) +// This test provokes TSAN errors. See bugs.webrtc.org/3608 + constexpr int kOnlyLocalPorts = cricket::PORTALLOCATOR_DISABLE_STUN | cricket::PORTALLOCATOR_DISABLE_RELAY | cricket::PORTALLOCATOR_DISABLE_TCP; // Use a mock resolver to resolve the hostname back to the original IP on both // sides and check that the ICE connection connects. +// TODO(bugs.webrtc.org/12590): Flaky on Windows and on Linux MSAN. +#if defined(WEBRTC_WIN) || defined(WEBRTC_LINUX) +#define MAYBE_IceStatesReachCompletionWithRemoteHostname \ + DISABLED_IceStatesReachCompletionWithRemoteHostname +#else +#define MAYBE_IceStatesReachCompletionWithRemoteHostname \ + IceStatesReachCompletionWithRemoteHostname +#endif TEST_P(PeerConnectionIntegrationTest, - IceStatesReachCompletionWithRemoteHostname) { + MAYBE_IceStatesReachCompletionWithRemoteHostname) { auto caller_resolver_factory = std::make_unique>(); auto callee_resolver_factory = @@ -4018,6 +1928,8 @@ TEST_P(PeerConnectionIntegrationTest, webrtc::kIceCandidatePairHostNameHostName)); } +#endif // !defined(THREAD_SANITIZER) + // Test that firewalling the ICE connection causes the clients to identify the // disconnected state and then removing the firewall causes them to reconnect. 
class PeerConnectionIntegrationIceStatesTest @@ -4086,6 +1998,9 @@ class PeerConnectionIntegrationIceStatesTestWithFakeClock : public FakeClockForTest, public PeerConnectionIntegrationIceStatesTest {}; +#if !defined(THREAD_SANITIZER) +// This test provokes TSAN errors. bugs.webrtc.org/11282 + // Tests that the PeerConnection goes through all the ICE gathering/connection // states over the duration of the call. This includes Disconnected and Failed // states, induced by putting a firewall between the peers and waiting for them @@ -4212,9 +2127,17 @@ TEST_P(PeerConnectionIntegrationIceStatesTestWithFakeClock, kConsentTimeout, FakeClock()); } +#endif // !defined(THREAD_SANITIZER) + // Tests that the best connection is set to the appropriate IPv4/IPv6 connection // and that the statistics in the metric observers are updated correctly. -TEST_P(PeerConnectionIntegrationIceStatesTest, VerifyBestConnection) { +// TODO(bugs.webrtc.org/12591): Flaky on Windows. +#if defined(WEBRTC_WIN) +#define MAYBE_VerifyBestConnection DISABLED_VerifyBestConnection +#else +#define MAYBE_VerifyBestConnection VerifyBestConnection +#endif +TEST_P(PeerConnectionIntegrationIceStatesTest, MAYBE_VerifyBestConnection) { ASSERT_TRUE(CreatePeerConnectionWrappers()); ConnectFakeSignaling(); SetPortAllocatorFlags(); @@ -4388,8 +2311,16 @@ TEST_P(PeerConnectionIntegrationTest, EndToEndCallWithIceRenomination) { // With a max bundle policy and RTCP muxing, adding a new media description to // the connection should not affect ICE at all because the new media will use // the existing connection. +// TODO(bugs.webrtc.org/12538): Fails on tsan. +#if defined(THREAD_SANITIZER) +#define MAYBE_AddMediaToConnectedBundleDoesNotRestartIce \ + DISABLED_AddMediaToConnectedBundleDoesNotRestartIce +#else +#define MAYBE_AddMediaToConnectedBundleDoesNotRestartIce \ + AddMediaToConnectedBundleDoesNotRestartIce +#endif TEST_P(PeerConnectionIntegrationTest, - AddMediaToConnectedBundleDoesNotRestartIce) { + MAYBE_AddMediaToConnectedBundleDoesNotRestartIce) { PeerConnectionInterface::RTCConfiguration config; config.bundle_policy = PeerConnectionInterface::kBundlePolicyMaxBundle; config.rtcp_mux_policy = PeerConnectionInterface::kRtcpMuxPolicyRequire; @@ -4568,6 +2499,9 @@ TEST_F(PeerConnectionIntegrationTestPlanB, CanSendRemoteVideoTrack) { ASSERT_TRUE(ExpectNewFrames(media_expectations)); } +#if !defined(THREAD_SANITIZER) +// This test provokes TSAN errors. bugs.webrtc.org/11282 + // Test that we achieve the expected end-to-end connection time, using a // fake clock and simulated latency on the media and signaling paths. // We use a TURN<->TURN connection because this is usually the quickest to @@ -4658,6 +2592,8 @@ TEST_P(PeerConnectionIntegrationTestWithFakeClock, ClosePeerConnections(); } +#endif // !defined(THREAD_SANITIZER) + // Verify that a TurnCustomizer passed in through RTCConfiguration // is actually used by the underlying TURN candidate pair. // Note that turnport_unittest.cc contains more detailed, lower-level tests. 
@@ -4897,8 +2833,7 @@ TEST_P(PeerConnectionIntegrationTest, IceTransportFactoryUsedForConnections) { /*reset_decoder_factory=*/false); ASSERT_TRUE(wrapper); wrapper->CreateDataChannel(); - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); wrapper->pc()->SetLocalDescription(observer, wrapper->CreateOfferAndWait().release()); } @@ -5080,6 +3015,9 @@ TEST_P(PeerConnectionIntegrationTest, MediaFlowsWhenCandidatesSetOnlyInSdp) { ASSERT_TRUE(ExpectNewFrames(media_expectations)); } +#if !defined(THREAD_SANITIZER) +// These tests provokes TSAN errors. See bugs.webrtc.org/11305. + // Test that SetAudioPlayout can be used to disable audio playout from the // start, then later enable it. This may be useful, for example, if the caller // needs to play a local ringtone until some event occurs, after which it @@ -5111,7 +3049,7 @@ TEST_P(PeerConnectionIntegrationTest, DisableAndEnableAudioPlayout) { ASSERT_TRUE(ExpectNewFrames(media_expectations)); } -double GetAudioEnergyStat(PeerConnectionWrapper* pc) { +double GetAudioEnergyStat(PeerConnectionIntegrationWrapper* pc) { auto report = pc->NewGetStats(); auto track_stats_list = report->GetStatsOfType(); @@ -5150,6 +3088,8 @@ TEST_P(PeerConnectionIntegrationTest, EXPECT_TRUE_WAIT(GetAudioEnergyStat(caller()) > 0, kMaxWaitForFramesMs); } +#endif // !defined(THREAD_SANITIZER) + // Test that SetAudioRecording can be used to disable audio recording from the // start, then later enable it. This may be useful, for example, if the caller // wants to ensure that no audio resources are active before a certain state @@ -5181,51 +3121,6 @@ TEST_P(PeerConnectionIntegrationTest, DisableAndEnableAudioRecording) { ASSERT_TRUE(ExpectNewFrames(media_expectations)); } -// Test that after closing PeerConnections, they stop sending any packets (ICE, -// DTLS, RTP...). -TEST_P(PeerConnectionIntegrationTest, ClosingConnectionStopsPacketFlow) { - // Set up audio/video/data, wait for some frames to be received. - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->AddAudioVideoTracks(); -#ifdef HAVE_SCTP - caller()->CreateDataChannel(); -#endif - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - MediaExpectations media_expectations; - media_expectations.CalleeExpectsSomeAudioAndVideo(); - ASSERT_TRUE(ExpectNewFrames(media_expectations)); - // Close PeerConnections. - ClosePeerConnections(); - // Pump messages for a second, and ensure no new packets end up sent. - uint32_t sent_packets_a = virtual_socket_server()->sent_packets(); - WAIT(false, 1000); - uint32_t sent_packets_b = virtual_socket_server()->sent_packets(); - EXPECT_EQ(sent_packets_a, sent_packets_b); -} - -// Test that transport stats are generated by the RTCStatsCollector for a -// connection that only involves data channels. This is a regression test for -// crbug.com/826972. 
-#ifdef HAVE_SCTP -TEST_P(PeerConnectionIntegrationTest, - TransportStatsReportedForDataChannelOnlyConnection) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); - - auto caller_report = caller()->NewGetStats(); - EXPECT_EQ(1u, caller_report->GetStatsOfType().size()); - auto callee_report = callee()->NewGetStats(); - EXPECT_EQ(1u, callee_report->GetStatsOfType().size()); -} -#endif // HAVE_SCTP - TEST_P(PeerConnectionIntegrationTest, IceEventsGeneratedAndLoggedInRtcEventLog) { ASSERT_TRUE(CreatePeerConnectionWrappersWithFakeRtcEventLog()); @@ -5238,11 +3133,9 @@ TEST_P(PeerConnectionIntegrationTest, ASSERT_NE(nullptr, caller()->event_log_factory()); ASSERT_NE(nullptr, callee()->event_log_factory()); webrtc::FakeRtcEventLog* caller_event_log = - static_cast( - caller()->event_log_factory()->last_log_created()); + caller()->event_log_factory()->last_log_created(); webrtc::FakeRtcEventLog* callee_event_log = - static_cast( - callee()->event_log_factory()->last_log_created()); + callee()->event_log_factory()->last_log_created(); ASSERT_NE(nullptr, caller_event_log); ASSERT_NE(nullptr, callee_event_log); int caller_ice_config_count = caller_event_log->GetEventCount( @@ -5425,8 +3318,7 @@ TEST_F(PeerConnectionIntegrationTestUnifiedPlan, SetSignalIceCandidates(false); // Workaround candidate outrace sdp. caller()->AddVideoTrack(); callee()->AddVideoTrack(); - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); callee()->pc()->SetLocalDescription(observer, callee()->CreateOfferAndWait().release()); EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); @@ -5443,15 +3335,15 @@ TEST_F(PeerConnectionIntegrationTestUnifiedPlan, ASSERT_TRUE(CreatePeerConnectionWrappersWithConfig(config, config)); - rtc::scoped_refptr sld_observer( - new rtc::RefCountedObject()); + auto sld_observer = + rtc::make_ref_counted(); callee()->pc()->SetLocalDescription(sld_observer, callee()->CreateOfferAndWait().release()); EXPECT_TRUE_WAIT(sld_observer->called(), kDefaultTimeout); EXPECT_EQ(sld_observer->error(), ""); - rtc::scoped_refptr srd_observer( - new rtc::RefCountedObject()); + auto srd_observer = + rtc::make_ref_counted(); callee()->pc()->SetRemoteDescription( srd_observer, caller()->CreateOfferAndWait().release()); EXPECT_TRUE_WAIT(srd_observer->called(), kDefaultTimeout); @@ -5520,7 +3412,8 @@ TEST_F(PeerConnectionIntegrationTestUnifiedPlan, // Add more tracks until we get close to having issues. 
// Issues have been seen at: // - 32 tracks on android_arm64_rel and android_arm_dbg bots - while (current_size < 16) { + // - 16 tracks on android_arm_dbg (flaky) + while (current_size < 8) { // Double the number of tracks for (int i = 0; i < current_size; i++) { caller()->pc()->AddTransceiver(cricket::MEDIA_TYPE_AUDIO); @@ -5555,7 +3448,8 @@ TEST_F(PeerConnectionIntegrationTestUnifiedPlan, // - 96 on a Linux workstation // - 64 at win_x86_more_configs and win_x64_msvc_dbg // - 32 on android_arm64_rel and linux_dbg bots - while (current_size < 16) { + // - 16 on Android 64 (Nexus 5x) + while (current_size < 8) { // Double the number of tracks for (int i = 0; i < current_size; i++) { caller()->pc()->AddTransceiver(cricket::MEDIA_TYPE_VIDEO); @@ -5597,7 +3491,7 @@ TEST_F(PeerConnectionIntegrationTestUnifiedPlan, int current_size = caller()->pc()->GetTransceivers().size(); // Add more tracks until we get close to having issues. // Making this number very large makes the test very slow. - while (current_size < 32) { + while (current_size < 16) { // Double the number of tracks for (int i = 0; i < current_size; i++) { caller()->pc()->AddTransceiver(cricket::MEDIA_TYPE_VIDEO); @@ -5638,7 +3532,7 @@ class PeerConnectionIntegrationInteropTest protected: // Setting the SdpSemantics for the base test to kDefault does not matter // because we specify not to use the test semantics when creating - // PeerConnectionWrappers. + // PeerConnectionIntegrationWrappers. PeerConnectionIntegrationInteropTest() : PeerConnectionIntegrationBaseTest(SdpSemantics::kPlanB), caller_semantics_(std::get<0>(GetParam())), @@ -5909,77 +3803,6 @@ TEST_F(PeerConnectionIntegrationTestUnifiedPlan, callee_track->state()); } -#ifdef HAVE_SCTP - -TEST_F(PeerConnectionIntegrationTestUnifiedPlan, - EndToEndCallWithBundledSctpDataChannel) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->AddAudioVideoTracks(); - callee()->AddAudioVideoTracks(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_EQ_WAIT(SctpTransportState::kConnected, - caller()->pc()->GetSctpTransport()->Information().state(), - kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); -} - -TEST_F(PeerConnectionIntegrationTestUnifiedPlan, - EndToEndCallWithDataChannelOnlyConnects) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_channel(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - ASSERT_TRUE(caller()->data_observer()->IsOpen()); -} - -TEST_F(PeerConnectionIntegrationTestUnifiedPlan, DataChannelClosesWhenClosed) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - caller()->data_channel()->Close(); - ASSERT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); -} - -TEST_F(PeerConnectionIntegrationTestUnifiedPlan, - DataChannelClosesWhenClosedReverse) { - 
ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - callee()->data_channel()->Close(); - ASSERT_TRUE_WAIT(!caller()->data_observer()->IsOpen(), kDefaultTimeout); -} - -TEST_F(PeerConnectionIntegrationTestUnifiedPlan, - DataChannelClosesWhenPeerConnectionClosed) { - ASSERT_TRUE(CreatePeerConnectionWrappers()); - ConnectFakeSignaling(); - caller()->CreateDataChannel(); - caller()->CreateAndSetAndSignalOffer(); - ASSERT_TRUE_WAIT(SignalingStateStable(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer(), kDefaultTimeout); - ASSERT_TRUE_WAIT(callee()->data_observer()->IsOpen(), kDefaultTimeout); - caller()->pc()->Close(); - ASSERT_TRUE_WAIT(!callee()->data_observer()->IsOpen(), kDefaultTimeout); -} - -#endif // HAVE_SCTP - } // namespace -} // namespace webrtc -#endif // if !defined(THREAD_SANITIZER) +} // namespace webrtc diff --git a/pc/peer_connection_interface_unittest.cc b/pc/peer_connection_interface_unittest.cc index abedf48688..fcea842b22 100644 --- a/pc/peer_connection_interface_unittest.cc +++ b/pc/peer_connection_interface_unittest.cc @@ -661,7 +661,7 @@ class PeerConnectionFactoryForTest : public webrtc::PeerConnectionFactory { dependencies.event_log_factory = std::make_unique( dependencies.task_queue_factory.get()); - return new rtc::RefCountedObject( + return rtc::make_ref_counted( std::move(dependencies)); } @@ -683,7 +683,7 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { #endif } - virtual void SetUp() { + void SetUp() override { // Use fake audio capture module since we're only testing the interface // level, and using a real one could make tests flaky when run in parallel. fake_audio_capture_module_ = FakeAudioCaptureModule::Create(); @@ -701,6 +701,11 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { PeerConnectionFactoryForTest::CreatePeerConnectionFactoryForTest(); } + void TearDown() override { + if (pc_) + pc_->Close(); + } + void CreatePeerConnection() { CreatePeerConnection(PeerConnectionInterface::RTCConfiguration()); } @@ -734,6 +739,10 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { } void CreatePeerConnection(const RTCConfiguration& config) { + if (pc_) { + pc_->Close(); + pc_ = nullptr; + } std::unique_ptr port_allocator( new cricket::FakePortAllocator(rtc::Thread::Current(), nullptr)); port_allocator_ = port_allocator.get(); @@ -870,8 +879,8 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { bool DoCreateOfferAnswer(std::unique_ptr* desc, const RTCOfferAnswerOptions* options, bool offer) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = + rtc::make_ref_counted(); if (offer) { pc_->CreateOffer(observer, options ? 
*options : RTCOfferAnswerOptions()); } else { @@ -895,8 +904,7 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { bool DoSetSessionDescription( std::unique_ptr desc, bool local) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); if (local) { pc_->SetLocalDescription(observer, desc.release()); } else { @@ -922,8 +930,7 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { // It does not verify the values in the StatReports since a RTCP packet might // be required. bool DoGetStats(MediaStreamTrackInterface* track) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); if (!pc_->GetStats(observer, track, PeerConnectionInterface::kStatsOutputLevelStandard)) return false; @@ -933,8 +940,8 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { // Call the standards-compliant GetStats function. bool DoGetRTCStats() { - rtc::scoped_refptr callback( - new rtc::RefCountedObject()); + auto callback = + rtc::make_ref_counted(); pc_->GetStats(callback); EXPECT_TRUE_WAIT(callback->called(), kTimeout); return callback->called(); @@ -1189,8 +1196,8 @@ class PeerConnectionInterfaceBaseTest : public ::testing::Test { std::unique_ptr CreateOfferWithOptions( const RTCOfferAnswerOptions& offer_answer_options) { RTC_DCHECK(pc_); - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = + rtc::make_ref_counted(); pc_->CreateOffer(observer, offer_answer_options); EXPECT_EQ_WAIT(true, observer->called(), kTimeout); return observer->MoveDescription(); @@ -1892,179 +1899,6 @@ TEST_P(PeerConnectionInterfaceTest, GetRTCStatsBeforeAndAfterCalling) { EXPECT_TRUE(DoGetRTCStats()); } -// This test setup two RTP data channels in loop back. 
-TEST_P(PeerConnectionInterfaceTest, TestDataChannel) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - config.enable_dtls_srtp = false; - CreatePeerConnection(config); - rtc::scoped_refptr data1 = - pc_->CreateDataChannel("test1", NULL); - rtc::scoped_refptr data2 = - pc_->CreateDataChannel("test2", NULL); - ASSERT_TRUE(data1 != NULL); - std::unique_ptr observer1( - new MockDataChannelObserver(data1)); - std::unique_ptr observer2( - new MockDataChannelObserver(data2)); - - EXPECT_EQ(DataChannelInterface::kConnecting, data1->state()); - EXPECT_EQ(DataChannelInterface::kConnecting, data2->state()); - std::string data_to_send1 = "testing testing"; - std::string data_to_send2 = "testing something else"; - EXPECT_FALSE(data1->Send(DataBuffer(data_to_send1))); - - CreateOfferReceiveAnswer(); - EXPECT_TRUE_WAIT(observer1->IsOpen(), kTimeout); - EXPECT_TRUE_WAIT(observer2->IsOpen(), kTimeout); - - EXPECT_EQ(DataChannelInterface::kOpen, data1->state()); - EXPECT_EQ(DataChannelInterface::kOpen, data2->state()); - EXPECT_TRUE(data1->Send(DataBuffer(data_to_send1))); - EXPECT_TRUE(data2->Send(DataBuffer(data_to_send2))); - - EXPECT_EQ_WAIT(data_to_send1, observer1->last_message(), kTimeout); - EXPECT_EQ_WAIT(data_to_send2, observer2->last_message(), kTimeout); - - data1->Close(); - EXPECT_EQ(DataChannelInterface::kClosing, data1->state()); - CreateOfferReceiveAnswer(); - EXPECT_FALSE(observer1->IsOpen()); - EXPECT_EQ(DataChannelInterface::kClosed, data1->state()); - EXPECT_TRUE(observer2->IsOpen()); - - data_to_send2 = "testing something else again"; - EXPECT_TRUE(data2->Send(DataBuffer(data_to_send2))); - - EXPECT_EQ_WAIT(data_to_send2, observer2->last_message(), kTimeout); -} - -// This test verifies that sendnig binary data over RTP data channels should -// fail. -TEST_P(PeerConnectionInterfaceTest, TestSendBinaryOnRtpDataChannel) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - config.enable_dtls_srtp = false; - CreatePeerConnection(config); - rtc::scoped_refptr data1 = - pc_->CreateDataChannel("test1", NULL); - rtc::scoped_refptr data2 = - pc_->CreateDataChannel("test2", NULL); - ASSERT_TRUE(data1 != NULL); - std::unique_ptr observer1( - new MockDataChannelObserver(data1)); - std::unique_ptr observer2( - new MockDataChannelObserver(data2)); - - EXPECT_EQ(DataChannelInterface::kConnecting, data1->state()); - EXPECT_EQ(DataChannelInterface::kConnecting, data2->state()); - - CreateOfferReceiveAnswer(); - EXPECT_TRUE_WAIT(observer1->IsOpen(), kTimeout); - EXPECT_TRUE_WAIT(observer2->IsOpen(), kTimeout); - - EXPECT_EQ(DataChannelInterface::kOpen, data1->state()); - EXPECT_EQ(DataChannelInterface::kOpen, data2->state()); - - rtc::CopyOnWriteBuffer buffer("test", 4); - EXPECT_FALSE(data1->Send(DataBuffer(buffer, true))); -} - -// This test setup a RTP data channels in loop back and test that a channel is -// opened even if the remote end answer with a zero SSRC. 
-TEST_P(PeerConnectionInterfaceTest, TestSendOnlyDataChannel) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - config.enable_dtls_srtp = false; - CreatePeerConnection(config); - rtc::scoped_refptr data1 = - pc_->CreateDataChannel("test1", NULL); - std::unique_ptr observer1( - new MockDataChannelObserver(data1)); - - CreateOfferReceiveAnswerWithoutSsrc(); - - EXPECT_TRUE_WAIT(observer1->IsOpen(), kTimeout); - - data1->Close(); - EXPECT_EQ(DataChannelInterface::kClosing, data1->state()); - CreateOfferReceiveAnswerWithoutSsrc(); - EXPECT_EQ(DataChannelInterface::kClosed, data1->state()); - EXPECT_FALSE(observer1->IsOpen()); -} - -// This test that if a data channel is added in an answer a receive only channel -// channel is created. -TEST_P(PeerConnectionInterfaceTest, TestReceiveOnlyDataChannel) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - config.enable_dtls_srtp = false; - - CreatePeerConnection(config); - - std::string offer_label = "offer_channel"; - rtc::scoped_refptr offer_channel = - pc_->CreateDataChannel(offer_label, NULL); - - CreateOfferAsLocalDescription(); - - // Replace the data channel label in the offer and apply it as an answer. - std::string receive_label = "answer_channel"; - std::string sdp; - EXPECT_TRUE(pc_->local_description()->ToString(&sdp)); - absl::StrReplaceAll({{offer_label, receive_label}}, &sdp); - CreateAnswerAsRemoteDescription(sdp); - - // Verify that a new incoming data channel has been created and that - // it is open but can't we written to. - ASSERT_TRUE(observer_.last_datachannel_ != NULL); - DataChannelInterface* received_channel = observer_.last_datachannel_; - EXPECT_EQ(DataChannelInterface::kConnecting, received_channel->state()); - EXPECT_EQ(receive_label, received_channel->label()); - EXPECT_FALSE(received_channel->Send(DataBuffer("something"))); - - // Verify that the channel we initially offered has been rejected. - EXPECT_EQ(DataChannelInterface::kClosed, offer_channel->state()); - - // Do another offer / answer exchange and verify that the data channel is - // opened. - CreateOfferReceiveAnswer(); - EXPECT_EQ_WAIT(DataChannelInterface::kOpen, received_channel->state(), - kTimeout); -} - -// This test that no data channel is returned if a reliable channel is -// requested. -// TODO(perkj): Remove this test once reliable channels are implemented. -TEST_P(PeerConnectionInterfaceTest, CreateReliableRtpDataChannelShouldFail) { - RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - CreatePeerConnection(rtc_config); - - std::string label = "test"; - webrtc::DataChannelInit config; - config.reliable = true; - rtc::scoped_refptr channel = - pc_->CreateDataChannel(label, &config); - EXPECT_TRUE(channel == NULL); -} - -// Verifies that duplicated label is not allowed for RTP data channel. -TEST_P(PeerConnectionInterfaceTest, RtpDuplicatedLabelNotAllowed) { - RTCConfiguration config; - config.enable_rtp_data_channel = true; - CreatePeerConnection(config); - - std::string label = "test"; - rtc::scoped_refptr channel = - pc_->CreateDataChannel(label, nullptr); - EXPECT_NE(channel, nullptr); - - rtc::scoped_refptr dup_channel = - pc_->CreateDataChannel(label, nullptr); - EXPECT_EQ(dup_channel, nullptr); -} - // This tests that a SCTP data channel is returned using different // DataChannelInit configurations. 
TEST_P(PeerConnectionInterfaceTest, CreateSctpDataChannel) { @@ -2182,80 +2016,8 @@ TEST_P(PeerConnectionInterfaceTest, SctpDuplicatedLabelAllowed) { EXPECT_NE(dup_channel, nullptr); } -// This test verifies that OnRenegotiationNeeded is fired for every new RTP -// DataChannel. -TEST_P(PeerConnectionInterfaceTest, RenegotiationNeededForNewRtpDataChannel) { - RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - rtc_config.enable_dtls_srtp = false; - CreatePeerConnection(rtc_config); - - rtc::scoped_refptr dc1 = - pc_->CreateDataChannel("test1", NULL); - EXPECT_TRUE(observer_.renegotiation_needed_); - observer_.renegotiation_needed_ = false; - - CreateOfferReceiveAnswer(); - - rtc::scoped_refptr dc2 = - pc_->CreateDataChannel("test2", NULL); - EXPECT_EQ(observer_.renegotiation_needed_, - GetParam() == SdpSemantics::kPlanB); -} - -// This test that a data channel closes when a PeerConnection is deleted/closed. -TEST_P(PeerConnectionInterfaceTest, DataChannelCloseWhenPeerConnectionClose) { - RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - rtc_config.enable_dtls_srtp = false; - CreatePeerConnection(rtc_config); - - rtc::scoped_refptr data1 = - pc_->CreateDataChannel("test1", NULL); - rtc::scoped_refptr data2 = - pc_->CreateDataChannel("test2", NULL); - ASSERT_TRUE(data1 != NULL); - std::unique_ptr observer1( - new MockDataChannelObserver(data1)); - std::unique_ptr observer2( - new MockDataChannelObserver(data2)); - CreateOfferReceiveAnswer(); - EXPECT_TRUE_WAIT(observer1->IsOpen(), kTimeout); - EXPECT_TRUE_WAIT(observer2->IsOpen(), kTimeout); - - ReleasePeerConnection(); - EXPECT_EQ(DataChannelInterface::kClosed, data1->state()); - EXPECT_EQ(DataChannelInterface::kClosed, data2->state()); -} - -// This tests that RTP data channels can be rejected in an answer. -TEST_P(PeerConnectionInterfaceTest, TestRejectRtpDataChannelInAnswer) { - RTCConfiguration rtc_config; - rtc_config.enable_rtp_data_channel = true; - rtc_config.enable_dtls_srtp = false; - CreatePeerConnection(rtc_config); - - rtc::scoped_refptr offer_channel( - pc_->CreateDataChannel("offer_channel", NULL)); - - CreateOfferAsLocalDescription(); - - // Create an answer where the m-line for data channels are rejected. - std::string sdp; - EXPECT_TRUE(pc_->local_description()->ToString(&sdp)); - std::unique_ptr answer( - webrtc::CreateSessionDescription(SdpType::kAnswer, sdp)); - ASSERT_TRUE(answer); - cricket::ContentInfo* data_info = - cricket::GetFirstDataContent(answer->description()); - data_info->rejected = true; - - DoSetRemoteDescription(std::move(answer)); - EXPECT_EQ(DataChannelInterface::kClosed, offer_channel->state()); -} - -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP // This tests that SCTP data channels can be rejected in an answer. TEST_P(PeerConnectionInterfaceTest, TestRejectSctpDataChannelInAnswer) #else @@ -2310,7 +2072,7 @@ TEST_P(PeerConnectionInterfaceTest, ReceiveFireFoxOffer) { cricket::GetFirstVideoContent(pc_->local_description()->description()); ASSERT_TRUE(content != NULL); EXPECT_FALSE(content->rejected); -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP content = cricket::GetFirstDataContent(pc_->local_description()->description()); ASSERT_TRUE(content != NULL); @@ -3593,12 +3355,12 @@ TEST_F(PeerConnectionInterfaceTestPlanB, // Test that negotiation can succeed with a data channel only, and with the max // bundle policy. Previously there was a bug that prevented this. 
-#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP TEST_P(PeerConnectionInterfaceTest, DataChannelOnlyOfferWithMaxBundlePolicy) { #else TEST_P(PeerConnectionInterfaceTest, DISABLED_DataChannelOnlyOfferWithMaxBundlePolicy) { -#endif // HAVE_SCTP +#endif // WEBRTC_HAVE_SCTP PeerConnectionInterface::RTCConfiguration config; config.bundle_policy = PeerConnectionInterface::kBundlePolicyMaxBundle; CreatePeerConnection(config); @@ -3900,17 +3662,17 @@ TEST_P(PeerConnectionInterfaceTest, TEST_P(PeerConnectionInterfaceTest, ExtmapAllowMixedIsConfigurable) { RTCConfiguration config; - // Default behavior is false. + // Default behavior is true. CreatePeerConnection(config); std::unique_ptr offer; ASSERT_TRUE(DoCreateOffer(&offer, nullptr)); - EXPECT_FALSE(offer->description()->extmap_allow_mixed()); - // Possible to set to true. - config.offer_extmap_allow_mixed = true; + EXPECT_TRUE(offer->description()->extmap_allow_mixed()); + // Possible to set to false. + config.offer_extmap_allow_mixed = false; CreatePeerConnection(config); - offer.release(); + offer = nullptr; ASSERT_TRUE(DoCreateOffer(&offer, nullptr)); - EXPECT_TRUE(offer->description()->extmap_allow_mixed()); + EXPECT_FALSE(offer->description()->extmap_allow_mixed()); } INSTANTIATE_TEST_SUITE_P(PeerConnectionInterfaceTest, diff --git a/pc/peer_connection_internal.h b/pc/peer_connection_internal.h index 029febab2d..6f97612914 100644 --- a/pc/peer_connection_internal.h +++ b/pc/peer_connection_internal.h @@ -19,7 +19,6 @@ #include "api/peer_connection_interface.h" #include "call/call.h" -#include "pc/rtp_data_channel.h" #include "pc/rtp_transceiver.h" #include "pc/sctp_data_channel.h" @@ -41,13 +40,9 @@ class PeerConnectionInternal : public PeerConnectionInterface { rtc::scoped_refptr>> GetTransceiversInternal() const = 0; - virtual sigslot::signal1& SignalRtpDataChannelCreated() = 0; virtual sigslot::signal1& SignalSctpDataChannelCreated() = 0; - // Only valid when using deprecated RTP data channels. - virtual cricket::RtpDataChannel* rtp_data_channel() const = 0; - // Call on the network thread to fetch stats for all the data channels. // TODO(tommi): Make pure virtual after downstream updates. virtual std::vector GetDataChannelStats() const { @@ -55,14 +50,13 @@ class PeerConnectionInternal : public PeerConnectionInterface { } virtual absl::optional sctp_transport_name() const = 0; + virtual absl::optional sctp_mid() const = 0; virtual cricket::CandidateStatsList GetPooledCandidateStats() const = 0; - // Returns a map from MID to transport name for all active media sections. - virtual std::map GetTransportNamesByMid() const = 0; - // Returns a map from transport name to transport stats for all given // transport names. + // Must be called on the network thread. 
virtual std::map GetTransportStatsByNames(const std::set& transport_names) = 0; diff --git a/pc/peer_connection_jsep_unittest.cc b/pc/peer_connection_jsep_unittest.cc index c3e093617b..4713068a15 100644 --- a/pc/peer_connection_jsep_unittest.cc +++ b/pc/peer_connection_jsep_unittest.cc @@ -1915,6 +1915,68 @@ TEST_F(PeerConnectionJsepTest, RollbackRestoresMid) { EXPECT_TRUE(callee->SetLocalDescription(std::move(offer))); } +TEST_F(PeerConnectionJsepTest, RollbackRestoresInitSendEncodings) { + auto caller = CreatePeerConnection(); + RtpTransceiverInit init; + init.direction = RtpTransceiverDirection::kSendRecv; + RtpEncodingParameters encoding; + encoding.rid = "hi"; + init.send_encodings.push_back(encoding); + encoding.rid = "mid"; + init.send_encodings.push_back(encoding); + encoding.rid = "lo"; + init.send_encodings.push_back(encoding); + caller->AddTransceiver(cricket::MEDIA_TYPE_VIDEO, init); + auto encodings = + caller->pc()->GetTransceivers()[0]->sender()->init_send_encodings(); + EXPECT_TRUE(caller->SetLocalDescription(caller->CreateOffer())); + EXPECT_NE(caller->pc()->GetTransceivers()[0]->sender()->init_send_encodings(), + encodings); + EXPECT_TRUE(caller->SetLocalDescription(caller->CreateRollback())); + EXPECT_EQ(caller->pc()->GetTransceivers()[0]->sender()->init_send_encodings(), + encodings); +} + +TEST_F(PeerConnectionJsepTest, RollbackDoesNotAffectSendEncodings) { + auto caller = CreatePeerConnection(); + auto callee = CreatePeerConnection(); + RtpTransceiverInit init; + init.direction = RtpTransceiverDirection::kSendOnly; + RtpEncodingParameters encoding; + encoding.rid = "hi"; + init.send_encodings.push_back(encoding); + encoding.rid = "mid"; + init.send_encodings.push_back(encoding); + encoding.rid = "lo"; + init.send_encodings.push_back(encoding); + caller->AddTransceiver(cricket::MEDIA_TYPE_VIDEO, init); + callee->AddTransceiver(cricket::MEDIA_TYPE_VIDEO); + callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal()); + caller->SetRemoteDescription(callee->CreateAnswerAndSetAsLocal()); + auto params = caller->pc()->GetTransceivers()[0]->sender()->GetParameters(); + EXPECT_TRUE(params.encodings[0].active); + params.encodings[0].active = false; + caller->pc()->GetTransceivers()[0]->sender()->SetParameters(params); + auto offer = caller->CreateOffer(); + std::string offer_string; + EXPECT_TRUE(offer.get()->ToString(&offer_string)); + std::string simulcast_line = + offer_string.substr(offer_string.find("a=simulcast")); + EXPECT_FALSE(simulcast_line.empty()); + EXPECT_TRUE(caller->SetLocalDescription(std::move(offer))); + EXPECT_TRUE(caller->SetLocalDescription(caller->CreateRollback())); + EXPECT_FALSE(caller->pc() + ->GetTransceivers()[0] + ->sender() + ->GetParameters() + .encodings[0] + .active); + offer = caller->CreateOffer(); + EXPECT_TRUE(offer.get()->ToString(&offer_string)); + EXPECT_EQ(offer_string.substr(offer_string.find("a=simulcast")), + simulcast_line); +} + TEST_F(PeerConnectionJsepTest, RollbackRestoresMidAndRemovesTransceiver) { auto callee = CreatePeerConnection(); callee->AddVideoTrack("a"); @@ -2204,16 +2266,4 @@ TEST_F(PeerConnectionJsepTest, EXPECT_TRUE(callee->CreateOfferAndSetAsLocal()); } -TEST_F(PeerConnectionJsepTest, RollbackRtpDataChannel) { - RTCConfiguration config; - config.sdp_semantics = SdpSemantics::kUnifiedPlan; - config.enable_rtp_data_channel = true; - auto pc = CreatePeerConnection(config); - pc->CreateDataChannel("dummy"); - auto offer = pc->CreateOffer(); - EXPECT_TRUE(pc->CreateOfferAndSetAsLocal()); - 
EXPECT_TRUE(pc->SetRemoteDescription(pc->CreateRollback())); - EXPECT_TRUE(pc->SetLocalDescription(std::move(offer))); -} - } // namespace webrtc diff --git a/pc/peer_connection_media_unittest.cc b/pc/peer_connection_media_unittest.cc index f078144d4f..d5d0b926b7 100644 --- a/pc/peer_connection_media_unittest.cc +++ b/pc/peer_connection_media_unittest.cc @@ -848,8 +848,9 @@ bool HasAnyComfortNoiseCodecs(const cricket::SessionDescription* desc) { TEST_P(PeerConnectionMediaTest, CreateOfferWithNoVoiceActivityDetectionIncludesNoComfortNoiseCodecs) { - auto caller = CreatePeerConnectionWithAudioVideo(); - AddComfortNoiseCodecsToSend(caller->media_engine()); + auto fake_engine = std::make_unique(); + AddComfortNoiseCodecsToSend(fake_engine.get()); + auto caller = CreatePeerConnectionWithAudioVideo(std::move(fake_engine)); RTCOfferAnswerOptions options; options.voice_activity_detection = false; @@ -859,11 +860,47 @@ TEST_P(PeerConnectionMediaTest, } TEST_P(PeerConnectionMediaTest, - CreateAnswerWithNoVoiceActivityDetectionIncludesNoComfortNoiseCodecs) { + CreateOfferWithVoiceActivityDetectionIncludesComfortNoiseCodecs) { + auto fake_engine = std::make_unique(); + AddComfortNoiseCodecsToSend(fake_engine.get()); + auto caller = CreatePeerConnectionWithAudioVideo(std::move(fake_engine)); + + RTCOfferAnswerOptions options; + options.voice_activity_detection = true; + auto offer = caller->CreateOffer(options); + + EXPECT_TRUE(HasAnyComfortNoiseCodecs(offer->description())); +} + +TEST_P(PeerConnectionMediaTest, + CreateAnswerWithVoiceActivityDetectionIncludesNoComfortNoiseCodecs) { auto caller = CreatePeerConnectionWithAudioVideo(); - AddComfortNoiseCodecsToSend(caller->media_engine()); - auto callee = CreatePeerConnectionWithAudioVideo(); - AddComfortNoiseCodecsToSend(callee->media_engine()); + + auto callee_fake_engine = std::make_unique(); + AddComfortNoiseCodecsToSend(callee_fake_engine.get()); + auto callee = + CreatePeerConnectionWithAudioVideo(std::move(callee_fake_engine)); + + ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); + + RTCOfferAnswerOptions options; + options.voice_activity_detection = true; + auto answer = callee->CreateAnswer(options); + + EXPECT_FALSE(HasAnyComfortNoiseCodecs(answer->description())); +} + +TEST_P(PeerConnectionMediaTest, + CreateAnswerWithNoVoiceActivityDetectionIncludesNoComfortNoiseCodecs) { + auto caller_fake_engine = std::make_unique(); + AddComfortNoiseCodecsToSend(caller_fake_engine.get()); + auto caller = + CreatePeerConnectionWithAudioVideo(std::move(caller_fake_engine)); + + auto callee_fake_engine = std::make_unique(); + AddComfortNoiseCodecsToSend(callee_fake_engine.get()); + auto callee = + CreatePeerConnectionWithAudioVideo(std::move(callee_fake_engine)); ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal())); @@ -1736,6 +1773,26 @@ TEST_F(PeerConnectionMediaTestUnifiedPlan, EXPECT_TRUE(CompareCodecs(video_codecs_vpx_reverse, recv_codecs)); } +TEST_F(PeerConnectionMediaTestUnifiedPlan, + SetCodecPreferencesVoiceActivityDetection) { + auto fake_engine = std::make_unique(); + AddComfortNoiseCodecsToSend(fake_engine.get()); + auto caller = CreatePeerConnectionWithAudio(std::move(fake_engine)); + + RTCOfferAnswerOptions options; + auto offer = caller->CreateOffer(options); + EXPECT_TRUE(HasAnyComfortNoiseCodecs(offer->description())); + + auto transceiver = caller->pc()->GetTransceivers().front(); + auto capabilities = caller->pc_factory()->GetRtpSenderCapabilities( + 
cricket::MediaType::MEDIA_TYPE_AUDIO); + EXPECT_TRUE(transceiver->SetCodecPreferences(capabilities.codecs).ok()); + + options.voice_activity_detection = false; + offer = caller->CreateOffer(options); + EXPECT_FALSE(HasAnyComfortNoiseCodecs(offer->description())); +} + INSTANTIATE_TEST_SUITE_P(PeerConnectionMediaTest, PeerConnectionMediaTest, Values(SdpSemantics::kPlanB, diff --git a/pc/peer_connection_message_handler.cc b/pc/peer_connection_message_handler.cc index b3ffcf888d..4b7913d678 100644 --- a/pc/peer_connection_message_handler.cc +++ b/pc/peer_connection_message_handler.cc @@ -15,8 +15,12 @@ #include "api/jsep.h" #include "api/media_stream_interface.h" #include "api/peer_connection_interface.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" +#include "api/stats_types.h" #include "pc/stats_collector_interface.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/checks.h" +#include "rtc_base/location.h" namespace webrtc { diff --git a/pc/peer_connection_message_handler.h b/pc/peer_connection_message_handler.h index 027fbea6c3..c19f5a4e50 100644 --- a/pc/peer_connection_message_handler.h +++ b/pc/peer_connection_message_handler.h @@ -11,10 +11,17 @@ #ifndef PC_PEER_CONNECTION_MESSAGE_HANDLER_H_ #define PC_PEER_CONNECTION_MESSAGE_HANDLER_H_ +#include + +#include "api/jsep.h" +#include "api/media_stream_interface.h" +#include "api/peer_connection_interface.h" #include "api/rtc_error.h" #include "api/stats_types.h" +#include "pc/stats_collector_interface.h" #include "rtc_base/message_handler.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_message.h" namespace webrtc { diff --git a/api/peer_connection_proxy.h b/pc/peer_connection_proxy.h similarity index 83% rename from api/peer_connection_proxy.h rename to pc/peer_connection_proxy.h index 2d4cb5cad0..7601c9d053 100644 --- a/api/peer_connection_proxy.h +++ b/pc/peer_connection_proxy.h @@ -8,22 +8,25 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef API_PEER_CONNECTION_PROXY_H_ -#define API_PEER_CONNECTION_PROXY_H_ +#ifndef PC_PEER_CONNECTION_PROXY_H_ +#define PC_PEER_CONNECTION_PROXY_H_ #include #include #include #include "api/peer_connection_interface.h" -#include "api/proxy.h" +#include "pc/proxy.h" namespace webrtc { -// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. -BEGIN_SIGNALING_PROXY_MAP(PeerConnection) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +// PeerConnection proxy objects will be constructed with two thread pointers, +// signaling and network. The proxy macros don't have 'network' specific macros +// and support for a secondary thread is provided via 'SECONDARY' macros. +// TODO(deadbeef): Move this to .cc file. What threads methods are called on is +// an implementation detail. 
+BEGIN_PROXY_MAP(PeerConnection) +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_METHOD0(rtc::scoped_refptr, local_streams) PROXY_METHOD0(rtc::scoped_refptr, remote_streams) PROXY_METHOD1(bool, AddStream, MediaStreamInterface*) @@ -73,8 +76,8 @@ PROXY_METHOD2(void, rtc::scoped_refptr, rtc::scoped_refptr) PROXY_METHOD0(void, ClearStatsCache) -PROXY_METHOD2(rtc::scoped_refptr, - CreateDataChannel, +PROXY_METHOD2(RTCErrorOr>, + CreateDataChannelOrError, const std::string&, const DataChannelInit*) PROXY_CONSTMETHOD0(const SessionDescriptionInterface*, local_description) @@ -130,10 +133,15 @@ PROXY_METHOD1(bool, RemoveIceCandidates, const std::vector&) PROXY_METHOD1(RTCError, SetBitrate, const BitrateSettings&) PROXY_METHOD1(void, SetAudioPlayout, bool) PROXY_METHOD1(void, SetAudioRecording, bool) -PROXY_METHOD1(rtc::scoped_refptr, - LookupDtlsTransportByMid, - const std::string&) -PROXY_CONSTMETHOD0(rtc::scoped_refptr, GetSctpTransport) +// This method will be invoked on the network thread. See +// PeerConnectionFactory::CreatePeerConnectionOrError for more details. +PROXY_SECONDARY_METHOD1(rtc::scoped_refptr, + LookupDtlsTransportByMid, + const std::string&) +// This method will be invoked on the network thread. See +// PeerConnectionFactory::CreatePeerConnectionOrError for more details. +PROXY_SECONDARY_CONSTMETHOD0(rtc::scoped_refptr, + GetSctpTransport) PROXY_METHOD0(SignalingState, signaling_state) PROXY_METHOD0(IceConnectionState, ice_connection_state) PROXY_METHOD0(IceConnectionState, standardized_ice_connection_state) @@ -149,8 +157,8 @@ PROXY_METHOD1(bool, StartRtcEventLog, std::unique_ptr) PROXY_METHOD0(void, StopRtcEventLog) PROXY_METHOD0(void, Close) BYPASS_PROXY_CONSTMETHOD0(rtc::Thread*, signaling_thread) -END_PROXY_MAP() +END_PROXY_MAP(PeerConnection) } // namespace webrtc -#endif // API_PEER_CONNECTION_PROXY_H_ +#endif // PC_PEER_CONNECTION_PROXY_H_ diff --git a/pc/peer_connection_rampup_tests.cc b/pc/peer_connection_rampup_tests.cc index cf3b0a27f5..d50d488125 100644 --- a/pc/peer_connection_rampup_tests.cc +++ b/pc/peer_connection_rampup_tests.cc @@ -120,7 +120,7 @@ class PeerConnectionWrapperForRampUpTest : public PeerConnectionWrapper { FrameGeneratorCapturerVideoTrackSource::Config config, Clock* clock) { video_track_sources_.emplace_back( - new rtc::RefCountedObject( + rtc::make_ref_counted( config, clock, /*is_screencast=*/false)); video_track_sources_.back()->Start(); return rtc::scoped_refptr( @@ -192,14 +192,14 @@ class PeerConnectionRampUpTest : public ::testing::Test { dependencies.tls_cert_verifier = std::make_unique(); - auto pc = - pc_factory_->CreatePeerConnection(config, std::move(dependencies)); - if (!pc) { + auto result = pc_factory_->CreatePeerConnectionOrError( + config, std::move(dependencies)); + if (!result.ok()) { return nullptr; } return std::make_unique( - pc_factory_, pc, std::move(observer)); + pc_factory_, result.MoveValue(), std::move(observer)); } void SetupOneWayCall() { diff --git a/pc/peer_connection_rtp_unittest.cc b/pc/peer_connection_rtp_unittest.cc index 4d6da66943..2822854a2d 100644 --- a/pc/peer_connection_rtp_unittest.cc +++ b/pc/peer_connection_rtp_unittest.cc @@ -779,6 +779,56 @@ TEST_F(PeerConnectionRtpTestUnifiedPlan, UnsignaledSsrcCreatesReceiverStreams) { EXPECT_EQ(receivers[0]->streams()[0]->id(), kStreamId1); EXPECT_EQ(receivers[0]->streams()[1]->id(), kStreamId2); } +TEST_F(PeerConnectionRtpTestUnifiedPlan, TracksDoNotEndWhenSsrcChanges) { + constexpr uint32_t kFirstMungedSsrc = 1337u; + + auto caller = CreatePeerConnection(); + auto 
callee = CreatePeerConnection();
+
+  // Caller offers to receive audio and video.
+  RtpTransceiverInit init;
+  init.direction = RtpTransceiverDirection::kRecvOnly;
+  caller->AddTransceiver(cricket::MEDIA_TYPE_AUDIO, init);
+  caller->AddTransceiver(cricket::MEDIA_TYPE_VIDEO, init);
+
+  // Callee wants to send audio and video tracks.
+  callee->AddTrack(callee->CreateAudioTrack("audio_track"), {});
+  callee->AddTrack(callee->CreateVideoTrack("video_track"), {});
+
+  // Do initial offer/answer exchange.
+  ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal()));
+  ASSERT_TRUE(
+      caller->SetRemoteDescription(callee->CreateAnswerAndSetAsLocal()));
+  ASSERT_EQ(caller->observer()->add_track_events_.size(), 2u);
+  ASSERT_EQ(caller->pc()->GetReceivers().size(), 2u);
+
+  // Do a follow-up offer/answer exchange where the SSRCs are modified.
+  ASSERT_TRUE(callee->SetRemoteDescription(caller->CreateOfferAndSetAsLocal()));
+  auto answer = callee->CreateAnswer();
+  auto& contents = answer->description()->contents();
+  ASSERT_TRUE(!contents.empty());
+  for (size_t i = 0; i < contents.size(); ++i) {
+    auto& mutable_streams = contents[i].media_description()->mutable_streams();
+    ASSERT_EQ(mutable_streams.size(), 1u);
+    mutable_streams[0].ssrcs = {kFirstMungedSsrc + static_cast<uint32_t>(i)};
+  }
+  ASSERT_TRUE(
+      callee->SetLocalDescription(CloneSessionDescription(answer.get())));
+  ASSERT_TRUE(
+      caller->SetRemoteDescription(CloneSessionDescription(answer.get())));
+
+  // No further track events should fire because we never changed direction,
+  // only SSRCs.
+  ASSERT_EQ(caller->observer()->add_track_events_.size(), 2u);
+  // We should have the same number of receivers as before.
+  auto receivers = caller->pc()->GetReceivers();
+  ASSERT_EQ(receivers.size(), 2u);
+  // The tracks are still alive.
+ EXPECT_EQ(receivers[0]->track()->state(), + MediaStreamTrackInterface::TrackState::kLive); + EXPECT_EQ(receivers[1]->track()->state(), + MediaStreamTrackInterface::TrackState::kLive); +} // Tests that with Unified Plan if the the stream id changes for a track when // when setting a new remote description, that the media stream is updated @@ -869,7 +919,7 @@ TEST_P(PeerConnectionRtpTest, auto callee = CreatePeerConnection(); rtc::scoped_refptr observer = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); auto offer = caller->CreateOfferAndSetAsLocal(); callee->pc()->SetRemoteDescription(observer, offer.release()); @@ -1844,7 +1894,7 @@ TEST_F(PeerConnectionMsidSignalingTest, PureUnifiedPlanToUs) { class SdpFormatReceivedTest : public PeerConnectionRtpTestUnifiedPlan {}; -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP TEST_F(SdpFormatReceivedTest, DataChannelOnlyIsReportedAsNoTracks) { auto caller = CreatePeerConnectionWithUnifiedPlan(); caller->CreateDataChannel("dc"); @@ -1856,7 +1906,7 @@ TEST_F(SdpFormatReceivedTest, DataChannelOnlyIsReportedAsNoTracks) { metrics::Samples("WebRTC.PeerConnection.SdpFormatReceived"), ElementsAre(Pair(kSdpFormatReceivedNoTracks, 1))); } -#endif // HAVE_SCTP +#endif // WEBRTC_HAVE_SCTP TEST_F(SdpFormatReceivedTest, SimpleUnifiedPlanIsReportedAsSimple) { auto caller = CreatePeerConnectionWithUnifiedPlan(); diff --git a/pc/peer_connection_signaling_unittest.cc b/pc/peer_connection_signaling_unittest.cc index 605a1338c6..1c94570ec7 100644 --- a/pc/peer_connection_signaling_unittest.cc +++ b/pc/peer_connection_signaling_unittest.cc @@ -11,6 +11,7 @@ // This file contains tests that check the PeerConnection's signaling state // machine, as well as tests that check basic, media-agnostic aspects of SDP. +#include #include #include @@ -18,10 +19,10 @@ #include "api/audio_codecs/builtin_audio_encoder_factory.h" #include "api/create_peerconnection_factory.h" #include "api/jsep_session_description.h" -#include "api/peer_connection_proxy.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" #include "pc/peer_connection.h" +#include "pc/peer_connection_proxy.h" #include "pc/peer_connection_wrapper.h" #include "pc/sdp_utils.h" #include "pc/webrtc_sdp.h" @@ -537,8 +538,7 @@ TEST_P(PeerConnectionSignalingTest, CreateOffersAndShutdown) { rtc::scoped_refptr observers[100]; for (auto& observer : observers) { - observer = - new rtc::RefCountedObject(); + observer = rtc::make_ref_counted(); caller->pc()->CreateOffer(observer, options); } @@ -559,8 +559,7 @@ TEST_P(PeerConnectionSignalingTest, CreateOffersAndShutdown) { // the WebRtcSessionDescriptionFactory is responsible for it. TEST_P(PeerConnectionSignalingTest, CloseCreateOfferAndShutdown) { auto caller = CreatePeerConnection(); - rtc::scoped_refptr observer = - new rtc::RefCountedObject(); + auto observer = rtc::make_ref_counted(); caller->pc()->Close(); caller->pc()->CreateOffer(observer, RTCOfferAnswerOptions()); caller.reset(nullptr); @@ -687,8 +686,8 @@ TEST_P(PeerConnectionSignalingTest, CreateOfferBlocksSetRemoteDescription) { auto offer = caller->CreateOffer(RTCOfferAnswerOptions()); EXPECT_EQ(0u, callee->pc()->GetReceivers().size()); - rtc::scoped_refptr offer_observer( - new rtc::RefCountedObject()); + auto offer_observer = + rtc::make_ref_counted(); // Synchronously invoke CreateOffer() and SetRemoteDescription(). The // SetRemoteDescription() operation should be chained to be executed // asynchronously, when CreateOffer() completes. 
@@ -901,6 +900,137 @@ TEST_P(PeerConnectionSignalingTest, UnsupportedContentType) { EXPECT_TRUE(caller->SetLocalDescription(std::move(offer))); } +TEST_P(PeerConnectionSignalingTest, ReceiveFlexFec) { + auto caller = CreatePeerConnection(); + + std::string sdp = + "v=0\r\n" + "o=- 8403615332048243445 2 IN IP4 127.0.0.1\r\n" + "s=-\r\n" + "t=0 0\r\n" + "a=group:BUNDLE 0\r\n" + "m=video 9 UDP/TLS/RTP/SAVPF 102 122\r\n" + "c=IN IP4 0.0.0.0\r\n" + "a=rtcp:9 IN IP4 0.0.0.0\r\n" + "a=ice-ufrag:IZeV\r\n" + "a=ice-pwd:uaZhQD4rYM/Tta2qWBT1Bbt4\r\n" + "a=ice-options:trickle\r\n" + "a=fingerprint:sha-256 " + "D8:6C:3D:FA:23:E2:2C:63:11:2D:D0:86:BE:C4:D0:65:F9:42:F7:1C:06:04:27:E6:" + "1C:2C:74:01:8D:50:67:23\r\n" + "a=setup:actpass\r\n" + "a=mid:0\r\n" + "a=sendrecv\r\n" + "a=msid:stream track\r\n" + "a=rtcp-mux\r\n" + "a=rtcp-rsize\r\n" + "a=rtpmap:102 VP8/90000\r\n" + "a=rtcp-fb:102 goog-remb\r\n" + "a=rtcp-fb:102 transport-cc\r\n" + "a=rtcp-fb:102 ccm fir\r\n" + "a=rtcp-fb:102 nack\r\n" + "a=rtcp-fb:102 nack pli\r\n" + "a=rtpmap:122 flexfec-03/90000\r\n" + "a=fmtp:122 repair-window=10000000\r\n" + "a=ssrc-group:FEC-FR 1224551896 1953032773\r\n" + "a=ssrc:1224551896 cname:/exJcmhSLpyu9FgV\r\n" + "a=ssrc:1953032773 cname:/exJcmhSLpyu9FgV\r\n"; + std::unique_ptr remote_description = + webrtc::CreateSessionDescription(SdpType::kOffer, sdp, nullptr); + + EXPECT_TRUE(caller->SetRemoteDescription(std::move(remote_description))); + + auto answer = caller->CreateAnswer(); + ASSERT_EQ(answer->description()->contents().size(), 1u); + ASSERT_NE( + answer->description()->contents()[0].media_description()->as_video(), + nullptr); + auto codecs = answer->description() + ->contents()[0] + .media_description() + ->as_video() + ->codecs(); + ASSERT_EQ(codecs.size(), 2u); + EXPECT_EQ(codecs[1].name, "flexfec-03"); + + EXPECT_TRUE(caller->SetLocalDescription(std::move(answer))); +} + +TEST_P(PeerConnectionSignalingTest, ReceiveFlexFecReoffer) { + auto caller = CreatePeerConnection(); + + std::string sdp = + "v=0\r\n" + "o=- 8403615332048243445 2 IN IP4 127.0.0.1\r\n" + "s=-\r\n" + "t=0 0\r\n" + "a=group:BUNDLE 0\r\n" + "m=video 9 UDP/TLS/RTP/SAVPF 102 35\r\n" + "c=IN IP4 0.0.0.0\r\n" + "a=rtcp:9 IN IP4 0.0.0.0\r\n" + "a=ice-ufrag:IZeV\r\n" + "a=ice-pwd:uaZhQD4rYM/Tta2qWBT1Bbt4\r\n" + "a=ice-options:trickle\r\n" + "a=fingerprint:sha-256 " + "D8:6C:3D:FA:23:E2:2C:63:11:2D:D0:86:BE:C4:D0:65:F9:42:F7:1C:06:04:27:E6:" + "1C:2C:74:01:8D:50:67:23\r\n" + "a=setup:actpass\r\n" + "a=mid:0\r\n" + "a=sendrecv\r\n" + "a=msid:stream track\r\n" + "a=rtcp-mux\r\n" + "a=rtcp-rsize\r\n" + "a=rtpmap:102 VP8/90000\r\n" + "a=rtcp-fb:102 goog-remb\r\n" + "a=rtcp-fb:102 transport-cc\r\n" + "a=rtcp-fb:102 ccm fir\r\n" + "a=rtcp-fb:102 nack\r\n" + "a=rtcp-fb:102 nack pli\r\n" + "a=rtpmap:35 flexfec-03/90000\r\n" + "a=fmtp:35 repair-window=10000000\r\n" + "a=ssrc-group:FEC-FR 1224551896 1953032773\r\n" + "a=ssrc:1224551896 cname:/exJcmhSLpyu9FgV\r\n" + "a=ssrc:1953032773 cname:/exJcmhSLpyu9FgV\r\n"; + std::unique_ptr remote_description = + webrtc::CreateSessionDescription(SdpType::kOffer, sdp, nullptr); + + EXPECT_TRUE(caller->SetRemoteDescription(std::move(remote_description))); + + auto answer = caller->CreateAnswer(); + ASSERT_EQ(answer->description()->contents().size(), 1u); + ASSERT_NE( + answer->description()->contents()[0].media_description()->as_video(), + nullptr); + auto codecs = answer->description() + ->contents()[0] + .media_description() + ->as_video() + ->codecs(); + ASSERT_EQ(codecs.size(), 2u); + EXPECT_EQ(codecs[1].name, 
"flexfec-03"); + EXPECT_EQ(codecs[1].id, 35); + + EXPECT_TRUE(caller->SetLocalDescription(std::move(answer))); + + // This generates a collision for AV1 which needs to be remapped. + auto offer = caller->CreateOffer(RTCOfferAnswerOptions()); + auto offer_codecs = offer->description() + ->contents()[0] + .media_description() + ->as_video() + ->codecs(); + auto flexfec_it = std::find_if( + offer_codecs.begin(), offer_codecs.end(), + [](const cricket::Codec& codec) { return codec.name == "flexfec-03"; }); + ASSERT_EQ(flexfec_it->id, 35); + auto av1_it = std::find_if( + offer_codecs.begin(), offer_codecs.end(), + [](const cricket::Codec& codec) { return codec.name == "AV1X"; }); + if (av1_it != offer_codecs.end()) { + ASSERT_NE(av1_it->id, 35); + } +} + INSTANTIATE_TEST_SUITE_P(PeerConnectionSignalingTest, PeerConnectionSignalingTest, Values(SdpSemantics::kPlanB, @@ -929,7 +1059,7 @@ TEST_F(PeerConnectionSignalingUnifiedPlanTest, // waiting for it would not ensure synchronicity. RTC_DCHECK(!caller->pc()->GetTransceivers()[0]->mid().has_value()); caller->pc()->SetLocalDescription( - new rtc::RefCountedObject(), + rtc::make_ref_counted(), offer.release()); EXPECT_TRUE(caller->pc()->GetTransceivers()[0]->mid().has_value()); } @@ -957,9 +1087,8 @@ TEST_F(PeerConnectionSignalingUnifiedPlanTest, // This offer will cause transceiver mids to get assigned. auto offer = caller->CreateOffer(RTCOfferAnswerOptions()); - rtc::scoped_refptr - offer_observer(new rtc::RefCountedObject< - ExecuteFunctionOnCreateSessionDescriptionObserver>( + auto offer_observer = + rtc::make_ref_counted( [pc = caller->pc()](SessionDescriptionInterface* desc) { // By not waiting for the observer's callback we can verify that the // operation executed immediately. @@ -968,7 +1097,7 @@ TEST_F(PeerConnectionSignalingUnifiedPlanTest, new rtc::RefCountedObject(), desc); EXPECT_TRUE(pc->GetTransceivers()[0]->mid().has_value()); - })); + }); caller->pc()->CreateOffer(offer_observer, RTCOfferAnswerOptions()); EXPECT_TRUE_WAIT(offer_observer->was_called(), kWaitTimeout); } @@ -1055,8 +1184,7 @@ TEST_F(PeerConnectionSignalingUnifiedPlanTest, caller->AddTransceiver(cricket::MEDIA_TYPE_AUDIO, RtpTransceiverInit()); EXPECT_TRUE(caller->observer()->has_negotiation_needed_event()); - rtc::scoped_refptr observer = - new rtc::RefCountedObject(); + auto observer = rtc::make_ref_counted(); caller->pc()->CreateOffer(observer, RTCOfferAnswerOptions()); // For this test to work, the operation has to be pending, i.e. the observer // has not yet been invoked. 
diff --git a/pc/peer_connection_simulcast_unittest.cc b/pc/peer_connection_simulcast_unittest.cc index 8822a980f7..31385754b7 100644 --- a/pc/peer_connection_simulcast_unittest.cc +++ b/pc/peer_connection_simulcast_unittest.cc @@ -157,9 +157,10 @@ class PeerConnectionSimulcastTests : public ::testing::Test { rtc::scoped_refptr AddTransceiver( PeerConnectionWrapper* pc, - const std::vector& layers) { + const std::vector& layers, + cricket::MediaType media_type = cricket::MEDIA_TYPE_VIDEO) { auto init = CreateTransceiverInit(layers); - return pc->AddTransceiver(cricket::MEDIA_TYPE_VIDEO, init); + return pc->AddTransceiver(media_type, init); } SimulcastDescription RemoveSimulcast(SessionDescriptionInterface* sd) { @@ -556,6 +557,25 @@ TEST_F(PeerConnectionSimulcastTests, NegotiationDoesNotHaveRidExtension) { ValidateTransceiverParameters(transceiver, expected_layers); } +TEST_F(PeerConnectionSimulcastTests, SimulcastAudioRejected) { + auto local = CreatePeerConnectionWrapper(); + auto remote = CreatePeerConnectionWrapper(); + auto layers = CreateLayers({"1", "2", "3", "4"}, true); + auto transceiver = + AddTransceiver(local.get(), layers, cricket::MEDIA_TYPE_AUDIO); + // Should only have the first layer. + auto parameters = transceiver->sender()->GetParameters(); + EXPECT_EQ(1u, parameters.encodings.size()); + EXPECT_THAT(parameters.encodings, + ElementsAre(Field("rid", &RtpEncodingParameters::rid, Eq("")))); + ExchangeOfferAnswer(local.get(), remote.get(), {}); + // Still have a single layer after negotiation + parameters = transceiver->sender()->GetParameters(); + EXPECT_EQ(1u, parameters.encodings.size()); + EXPECT_THAT(parameters.encodings, + ElementsAre(Field("rid", &RtpEncodingParameters::rid, Eq("")))); +} + #if RTC_METRICS_ENABLED // // Checks the logged metrics when simulcast is not used. 
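For context on the SimulcastAudioRejected test above: AddTransceiver is handed an init whose send_encodings presumably carry one rid per simulcast layer (via the test's CreateTransceiverInit helper), and the test then expects an audio sender to collapse that to a single rid-less encoding both before and after negotiation. The following is a rough sketch of assembling such an init from the public API; BuildRidInit is an illustrative name, not the helper used in the test.

#include <string>
#include <vector>

#include "api/rtp_parameters.h"
#include "api/rtp_transceiver_interface.h"

// Builds a simulcast-style transceiver init with one rid-tagged encoding per
// layer. Per the test above, audio transceivers are expected to reject this
// and keep only a single encoding with an empty rid.
webrtc::RtpTransceiverInit BuildRidInit(const std::vector<std::string>& rids) {
  webrtc::RtpTransceiverInit init;
  for (const std::string& rid : rids) {
    webrtc::RtpEncodingParameters encoding;
    encoding.rid = rid;
    init.send_encodings.push_back(encoding);
  }
  return init;
}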
diff --git a/pc/peer_connection_wrapper.cc b/pc/peer_connection_wrapper.cc index 328f5795e2..3b4d28f0d9 100644 --- a/pc/peer_connection_wrapper.cc +++ b/pc/peer_connection_wrapper.cc @@ -48,7 +48,10 @@ PeerConnectionWrapper::PeerConnectionWrapper( observer_->SetPeerConnectionInterface(pc_.get()); } -PeerConnectionWrapper::~PeerConnectionWrapper() = default; +PeerConnectionWrapper::~PeerConnectionWrapper() { + if (pc_) + pc_->Close(); +} PeerConnectionFactoryInterface* PeerConnectionWrapper::pc_factory() { return pc_factory_.get(); @@ -133,8 +136,7 @@ PeerConnectionWrapper::CreateRollback() { std::unique_ptr PeerConnectionWrapper::CreateSdp( rtc::FunctionView fn, std::string* error_out) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); fn(observer); EXPECT_EQ_WAIT(true, observer->called(), kDefaultTimeout); if (error_out && !observer->result()) { @@ -179,8 +181,7 @@ bool PeerConnectionWrapper::SetRemoteDescription( bool PeerConnectionWrapper::SetSdp( rtc::FunctionView fn, std::string* error_out) { - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); fn(observer); EXPECT_EQ_WAIT(true, observer->called(), kDefaultTimeout); if (error_out && !observer->result()) { @@ -305,7 +306,14 @@ rtc::scoped_refptr PeerConnectionWrapper::AddVideoTrack( rtc::scoped_refptr PeerConnectionWrapper::CreateDataChannel(const std::string& label) { - return pc()->CreateDataChannel(label, nullptr); + auto result = pc()->CreateDataChannelOrError(label, nullptr); + if (!result.ok()) { + RTC_LOG(LS_ERROR) << "CreateDataChannel failed: " + << ToString(result.error().type()) << " " + << result.error().message(); + return nullptr; + } + return result.MoveValue(); } PeerConnectionInterface::SignalingState @@ -323,8 +331,7 @@ bool PeerConnectionWrapper::IsIceConnected() { rtc::scoped_refptr PeerConnectionWrapper::GetStats() { - rtc::scoped_refptr callback( - new rtc::RefCountedObject()); + auto callback = rtc::make_ref_counted(); pc()->GetStats(callback); EXPECT_TRUE_WAIT(callback->called(), kDefaultTimeout); return callback->report(); diff --git a/pc/proxy.cc b/pc/proxy.cc new file mode 100644 index 0000000000..5f4e0b8832 --- /dev/null +++ b/pc/proxy.cc @@ -0,0 +1,25 @@ +/* + * Copyright 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "pc/proxy.h" + +#include "rtc_base/trace_event.h" + +namespace webrtc { +namespace proxy_internal { +ScopedTrace::ScopedTrace(const char* class_and_method_name) + : class_and_method_name_(class_and_method_name) { + TRACE_EVENT_BEGIN0("webrtc", class_and_method_name_); +} +ScopedTrace::~ScopedTrace() { + TRACE_EVENT_END0("webrtc", class_and_method_name_); +} +} // namespace proxy_internal +} // namespace webrtc diff --git a/api/proxy.h b/pc/proxy.h similarity index 63% rename from api/proxy.h rename to pc/proxy.h index 05f7414bc0..565ae80175 100644 --- a/api/proxy.h +++ b/pc/proxy.h @@ -12,6 +12,13 @@ // PeerConnection classes. // TODO(deadbeef): Move this to pc/; this is part of the implementation. 
+// The proxied objects are initialized with either one or two thread +// objects that operations can be proxied to: The primary and secondary +// threads. +// In common usage, the primary thread will be the PeerConnection's +// signaling thread, and the secondary thread will be either the +// PeerConnection's worker thread or the PeerConnection's network thread. + // // Example usage: // @@ -29,28 +36,28 @@ // }; // // BEGIN_PROXY_MAP(Test) -// PROXY_SIGNALING_THREAD_DESTRUCTOR() +// PROXY_PRIMARY_THREAD_DESTRUCTOR() // PROXY_METHOD0(std::string, FooA) // PROXY_CONSTMETHOD1(std::string, FooB, arg1) -// PROXY_WORKER_METHOD1(std::string, FooC, arg1) +// PROXY_SECONDARY_METHOD1(std::string, FooC, arg1) // END_PROXY_MAP() // -// Where the destructor and first two methods are invoked on the signaling -// thread, and the third is invoked on the worker thread. +// Where the destructor and first two methods are invoked on the primary +// thread, and the third is invoked on the secondary thread. // // The proxy can be created using // // TestProxy::Create(Thread* signaling_thread, Thread* worker_thread, // TestInterface*). // -// The variant defined with BEGIN_SIGNALING_PROXY_MAP is unaware of -// the worker thread, and invokes all methods on the signaling thread. +// The variant defined with BEGIN_PRIMARY_PROXY_MAP is unaware of +// the secondary thread, and invokes all methods on the primary thread. // // The variant defined with BEGIN_OWNED_PROXY_MAP does not use // refcounting, and instead just takes ownership of the object being proxied. -#ifndef API_PROXY_H_ -#define API_PROXY_H_ +#ifndef PC_PROXY_H_ +#define PC_PROXY_H_ #include #include @@ -64,14 +71,31 @@ #include "rtc_base/event.h" #include "rtc_base/message_handler.h" #include "rtc_base/ref_counted_object.h" +#include "rtc_base/string_utils.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/thread.h" +#if !defined(RTC_DISABLE_PROXY_TRACE_EVENTS) && !defined(WEBRTC_CHROMIUM_BUILD) +#define RTC_DISABLE_PROXY_TRACE_EVENTS +#endif + namespace rtc { class Location; } namespace webrtc { +namespace proxy_internal { + +// Class for tracing the lifetime of MethodCall::Marshal. +class ScopedTrace { + public: + explicit ScopedTrace(const char* class_and_method_name); + ~ScopedTrace(); + + private: + const char* const class_and_method_name_; +}; +} // namespace proxy_internal template class ReturnType { @@ -174,6 +198,9 @@ class ConstMethodCall : public QueuedTask { rtc::Event event_; }; +#define PROXY_STRINGIZE_IMPL(x) #x +#define PROXY_STRINGIZE(x) PROXY_STRINGIZE_IMPL(x) + // Helper macros to reduce code duplication. 
#define PROXY_MAP_BOILERPLATE(c) \ template \ @@ -182,6 +209,7 @@ class ConstMethodCall : public QueuedTask { template \ class c##ProxyWithInternal : public c##Interface { \ protected: \ + static constexpr char proxy_name_[] = #c "Proxy"; \ typedef c##Interface C; \ \ public: \ @@ -191,29 +219,31 @@ class ConstMethodCall : public QueuedTask { // clang-format off // clang-format would put the semicolon alone, // leading to a presubmit error (cpplint.py) -#define END_PROXY_MAP() \ - }; +#define END_PROXY_MAP(c) \ + }; \ + template \ + constexpr char c##ProxyWithInternal::proxy_name_[]; // clang-format on -#define SIGNALING_PROXY_MAP_BOILERPLATE(c) \ +#define PRIMARY_PROXY_MAP_BOILERPLATE(c) \ + protected: \ + c##ProxyWithInternal(rtc::Thread* primary_thread, INTERNAL_CLASS* c) \ + : primary_thread_(primary_thread), c_(c) {} \ + \ + private: \ + mutable rtc::Thread* primary_thread_; + +#define SECONDARY_PROXY_MAP_BOILERPLATE(c) \ protected: \ - c##ProxyWithInternal(rtc::Thread* signaling_thread, INTERNAL_CLASS* c) \ - : signaling_thread_(signaling_thread), c_(c) {} \ + c##ProxyWithInternal(rtc::Thread* primary_thread, \ + rtc::Thread* secondary_thread, INTERNAL_CLASS* c) \ + : primary_thread_(primary_thread), \ + secondary_thread_(secondary_thread), \ + c_(c) {} \ \ private: \ - mutable rtc::Thread* signaling_thread_; - -#define WORKER_PROXY_MAP_BOILERPLATE(c) \ - protected: \ - c##ProxyWithInternal(rtc::Thread* signaling_thread, \ - rtc::Thread* worker_thread, INTERNAL_CLASS* c) \ - : signaling_thread_(signaling_thread), \ - worker_thread_(worker_thread), \ - c_(c) {} \ - \ - private: \ - mutable rtc::Thread* signaling_thread_; \ - mutable rtc::Thread* worker_thread_; + mutable rtc::Thread* primary_thread_; \ + mutable rtc::Thread* secondary_thread_; // Note that the destructor is protected so that the proxy can only be // destroyed via RefCountInterface. 
@@ -246,172 +276,198 @@ class ConstMethodCall : public QueuedTask { void DestroyInternal() { delete c_; } \ INTERNAL_CLASS* c_; -#define BEGIN_SIGNALING_PROXY_MAP(c) \ +#define BEGIN_PRIMARY_PROXY_MAP(c) \ + PROXY_MAP_BOILERPLATE(c) \ + PRIMARY_PROXY_MAP_BOILERPLATE(c) \ + REFCOUNTED_PROXY_MAP_BOILERPLATE(c) \ + public: \ + static rtc::scoped_refptr Create( \ + rtc::Thread* primary_thread, INTERNAL_CLASS* c) { \ + return rtc::make_ref_counted(primary_thread, c); \ + } + +#define BEGIN_PROXY_MAP(c) \ PROXY_MAP_BOILERPLATE(c) \ - SIGNALING_PROXY_MAP_BOILERPLATE(c) \ + SECONDARY_PROXY_MAP_BOILERPLATE(c) \ REFCOUNTED_PROXY_MAP_BOILERPLATE(c) \ public: \ static rtc::scoped_refptr Create( \ - rtc::Thread* signaling_thread, INTERNAL_CLASS* c) { \ - return new rtc::RefCountedObject(signaling_thread, \ - c); \ - } - -#define BEGIN_PROXY_MAP(c) \ - PROXY_MAP_BOILERPLATE(c) \ - WORKER_PROXY_MAP_BOILERPLATE(c) \ - REFCOUNTED_PROXY_MAP_BOILERPLATE(c) \ - public: \ - static rtc::scoped_refptr Create( \ - rtc::Thread* signaling_thread, rtc::Thread* worker_thread, \ - INTERNAL_CLASS* c) { \ - return new rtc::RefCountedObject(signaling_thread, \ - worker_thread, c); \ + rtc::Thread* primary_thread, rtc::Thread* secondary_thread, \ + INTERNAL_CLASS* c) { \ + return rtc::make_ref_counted(primary_thread, \ + secondary_thread, c); \ } #define BEGIN_OWNED_PROXY_MAP(c) \ PROXY_MAP_BOILERPLATE(c) \ - WORKER_PROXY_MAP_BOILERPLATE(c) \ + SECONDARY_PROXY_MAP_BOILERPLATE(c) \ OWNED_PROXY_MAP_BOILERPLATE(c) \ public: \ static std::unique_ptr Create( \ - rtc::Thread* signaling_thread, rtc::Thread* worker_thread, \ + rtc::Thread* primary_thread, rtc::Thread* secondary_thread, \ std::unique_ptr c) { \ return std::unique_ptr(new c##ProxyWithInternal( \ - signaling_thread, worker_thread, c.release())); \ + primary_thread, secondary_thread, c.release())); \ } -#define PROXY_SIGNALING_THREAD_DESTRUCTOR() \ - private: \ - rtc::Thread* destructor_thread() const { return signaling_thread_; } \ - \ +#define PROXY_PRIMARY_THREAD_DESTRUCTOR() \ + private: \ + rtc::Thread* destructor_thread() const { return primary_thread_; } \ + \ public: // NOLINTNEXTLINE -#define PROXY_WORKER_THREAD_DESTRUCTOR() \ - private: \ - rtc::Thread* destructor_thread() const { return worker_thread_; } \ - \ +#define PROXY_SECONDARY_THREAD_DESTRUCTOR() \ + private: \ + rtc::Thread* destructor_thread() const { return secondary_thread_; } \ + \ public: // NOLINTNEXTLINE -#define PROXY_METHOD0(r, method) \ - r method() override { \ - MethodCall call(c_, &C::method); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ +#if defined(RTC_DISABLE_PROXY_TRACE_EVENTS) +#define TRACE_BOILERPLATE(method) \ + do { \ + } while (0) +#else // if defined(RTC_DISABLE_PROXY_TRACE_EVENTS) +#define TRACE_BOILERPLATE(method) \ + static constexpr auto class_and_method_name = \ + rtc::MakeCompileTimeString(proxy_name_) \ + .Concat(rtc::MakeCompileTimeString("::")) \ + .Concat(rtc::MakeCompileTimeString(#method)); \ + proxy_internal::ScopedTrace scoped_trace(class_and_method_name.string) + +#endif // if defined(RTC_DISABLE_PROXY_TRACE_EVENTS) + +#define PROXY_METHOD0(r, method) \ + r method() override { \ + TRACE_BOILERPLATE(method); \ + MethodCall call(c_, &C::method); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } -#define PROXY_CONSTMETHOD0(r, method) \ - r method() const override { \ - ConstMethodCall call(c_, &C::method); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ +#define PROXY_CONSTMETHOD0(r, method) \ + r method() const override { 
\ + TRACE_BOILERPLATE(method); \ + ConstMethodCall call(c_, &C::method); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } #define PROXY_METHOD1(r, method, t1) \ r method(t1 a1) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1)); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } #define PROXY_CONSTMETHOD1(r, method, t1) \ r method(t1 a1) const override { \ + TRACE_BOILERPLATE(method); \ ConstMethodCall call(c_, &C::method, std::move(a1)); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } #define PROXY_METHOD2(r, method, t1, t2) \ r method(t1 a1, t2 a2) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1), \ std::move(a2)); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } #define PROXY_METHOD3(r, method, t1, t2, t3) \ r method(t1 a1, t2 a2, t3 a3) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1), \ std::move(a2), std::move(a3)); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } #define PROXY_METHOD4(r, method, t1, t2, t3, t4) \ r method(t1 a1, t2 a2, t3 a3, t4 a4) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1), \ std::move(a2), std::move(a3), \ std::move(a4)); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } #define PROXY_METHOD5(r, method, t1, t2, t3, t4, t5) \ r method(t1 a1, t2 a2, t3 a3, t4 a4, t5 a5) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1), \ std::move(a2), std::move(a3), \ std::move(a4), std::move(a5)); \ - return call.Marshal(RTC_FROM_HERE, signaling_thread_); \ + return call.Marshal(RTC_FROM_HERE, primary_thread_); \ } -// Define methods which should be invoked on the worker thread. -#define PROXY_WORKER_METHOD0(r, method) \ - r method() override { \ - MethodCall call(c_, &C::method); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ +// Define methods which should be invoked on the secondary thread. 
+#define PROXY_SECONDARY_METHOD0(r, method) \ + r method() override { \ + TRACE_BOILERPLATE(method); \ + MethodCall call(c_, &C::method); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_CONSTMETHOD0(r, method) \ - r method() const override { \ - ConstMethodCall call(c_, &C::method); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ +#define PROXY_SECONDARY_CONSTMETHOD0(r, method) \ + r method() const override { \ + TRACE_BOILERPLATE(method); \ + ConstMethodCall call(c_, &C::method); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_METHOD1(r, method, t1) \ +#define PROXY_SECONDARY_METHOD1(r, method, t1) \ r method(t1 a1) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1)); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_CONSTMETHOD1(r, method, t1) \ +#define PROXY_SECONDARY_CONSTMETHOD1(r, method, t1) \ r method(t1 a1) const override { \ + TRACE_BOILERPLATE(method); \ ConstMethodCall call(c_, &C::method, std::move(a1)); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_METHOD2(r, method, t1, t2) \ +#define PROXY_SECONDARY_METHOD2(r, method, t1, t2) \ r method(t1 a1, t2 a2) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1), \ std::move(a2)); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_CONSTMETHOD2(r, method, t1, t2) \ +#define PROXY_SECONDARY_CONSTMETHOD2(r, method, t1, t2) \ r method(t1 a1, t2 a2) const override { \ + TRACE_BOILERPLATE(method); \ ConstMethodCall call(c_, &C::method, std::move(a1), \ std::move(a2)); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_METHOD3(r, method, t1, t2, t3) \ +#define PROXY_SECONDARY_METHOD3(r, method, t1, t2, t3) \ r method(t1 a1, t2 a2, t3 a3) override { \ + TRACE_BOILERPLATE(method); \ MethodCall call(c_, &C::method, std::move(a1), \ std::move(a2), std::move(a3)); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } -#define PROXY_WORKER_CONSTMETHOD3(r, method, t1, t2) \ +#define PROXY_SECONDARY_CONSTMETHOD3(r, method, t1, t2) \ r method(t1 a1, t2 a2, t3 a3) const override { \ + TRACE_BOILERPLATE(method); \ ConstMethodCall call(c_, &C::method, std::move(a1), \ std::move(a2), std::move(a3)); \ - return call.Marshal(RTC_FROM_HERE, worker_thread_); \ + return call.Marshal(RTC_FROM_HERE, secondary_thread_); \ } // For use when returning purely const state (set during construction). // Use with caution. This method should only be used when the return value will // always be the same. 
-#define BYPASS_PROXY_CONSTMETHOD0(r, method) \ - r method() const override { \ - static_assert( \ - std::is_same::value || !std::is_pointer::value, \ - "Type is a pointer"); \ - static_assert(!std::is_reference::value, "Type is a reference"); \ - return c_->method(); \ +#define BYPASS_PROXY_CONSTMETHOD0(r, method) \ + r method() const override { \ + TRACE_BOILERPLATE(method); \ + return c_->method(); \ } } // namespace webrtc -#endif // API_PROXY_H_ +#endif // PC_PROXY_H_ diff --git a/pc/proxy_unittest.cc b/pc/proxy_unittest.cc index 500828a03e..ef3d97eddc 100644 --- a/pc/proxy_unittest.cc +++ b/pc/proxy_unittest.cc @@ -8,7 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "api/proxy.h" +#include "pc/proxy.h" #include #include @@ -43,7 +43,7 @@ class FakeInterface : public rtc::RefCountInterface { class Fake : public FakeInterface { public: static rtc::scoped_refptr Create() { - return new rtc::RefCountedObject(); + return rtc::make_ref_counted(); } // Used to verify destructor is called on the correct thread. MOCK_METHOD(void, Destroy, ()); @@ -64,27 +64,27 @@ class Fake : public FakeInterface { // Proxies for the test interface. BEGIN_PROXY_MAP(Fake) -PROXY_WORKER_THREAD_DESTRUCTOR() +PROXY_SECONDARY_THREAD_DESTRUCTOR() PROXY_METHOD0(void, VoidMethod0) PROXY_METHOD0(std::string, Method0) PROXY_CONSTMETHOD0(std::string, ConstMethod0) -PROXY_WORKER_METHOD1(std::string, Method1, std::string) +PROXY_SECONDARY_METHOD1(std::string, Method1, std::string) PROXY_CONSTMETHOD1(std::string, ConstMethod1, std::string) -PROXY_WORKER_METHOD2(std::string, Method2, std::string, std::string) -END_PROXY_MAP() +PROXY_SECONDARY_METHOD2(std::string, Method2, std::string, std::string) +END_PROXY_MAP(Fake) // Preprocessor hack to get a proxy class a name different than FakeProxy. 
#define FakeProxy FakeSignalingProxy #define FakeProxyWithInternal FakeSignalingProxyWithInternal -BEGIN_SIGNALING_PROXY_MAP(Fake) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +BEGIN_PRIMARY_PROXY_MAP(Fake) +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_METHOD0(void, VoidMethod0) PROXY_METHOD0(std::string, Method0) PROXY_CONSTMETHOD0(std::string, ConstMethod0) PROXY_METHOD1(std::string, Method1, std::string) PROXY_CONSTMETHOD1(std::string, ConstMethod1, std::string) PROXY_METHOD2(std::string, Method2, std::string, std::string) -END_PROXY_MAP() +END_PROXY_MAP(Fake) #undef FakeProxy class SignalingProxyTest : public ::testing::Test { @@ -270,9 +270,9 @@ class Foo : public FooInterface { }; BEGIN_OWNED_PROXY_MAP(Foo) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_METHOD0(void, Bar) -END_PROXY_MAP() +END_PROXY_MAP(Foo) class OwnedProxyTest : public ::testing::Test { public: diff --git a/pc/remote_audio_source.cc b/pc/remote_audio_source.cc index 8ae0612541..dc890e737c 100644 --- a/pc/remote_audio_source.cc +++ b/pc/remote_audio_source.cc @@ -13,17 +13,15 @@ #include #include -#include #include "absl/algorithm/container.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" -#include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/strings/string_format.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -51,54 +49,63 @@ class RemoteAudioSource::AudioDataProxy : public AudioSinkInterface { const rtc::scoped_refptr source_; }; -RemoteAudioSource::RemoteAudioSource(rtc::Thread* worker_thread) +RemoteAudioSource::RemoteAudioSource( + rtc::Thread* worker_thread, + OnAudioChannelGoneAction on_audio_channel_gone_action) : main_thread_(rtc::Thread::Current()), worker_thread_(worker_thread), + on_audio_channel_gone_action_(on_audio_channel_gone_action), state_(MediaSourceInterface::kLive) { RTC_DCHECK(main_thread_); RTC_DCHECK(worker_thread_); } RemoteAudioSource::~RemoteAudioSource() { - RTC_DCHECK(main_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(main_thread_); RTC_DCHECK(audio_observers_.empty()); - RTC_DCHECK(sinks_.empty()); + if (!sinks_.empty()) { + RTC_LOG(LS_WARNING) + << "RemoteAudioSource destroyed while sinks_ is non-empty."; + } } void RemoteAudioSource::Start(cricket::VoiceMediaChannel* media_channel, absl::optional ssrc) { - RTC_DCHECK_RUN_ON(main_thread_); - RTC_DCHECK(media_channel); + RTC_DCHECK_RUN_ON(worker_thread_); // Register for callbacks immediately before AddSink so that we always get // notified when a channel goes out of scope (signaled when "AudioDataProxy" // is destroyed). - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - ssrc ? media_channel->SetRawAudioSink( - *ssrc, std::make_unique(this)) - : media_channel->SetDefaultRawAudioSink( - std::make_unique(this)); - }); + RTC_DCHECK(media_channel); + ssrc ? media_channel->SetRawAudioSink(*ssrc, + std::make_unique(this)) + : media_channel->SetDefaultRawAudioSink( + std::make_unique(this)); } void RemoteAudioSource::Stop(cricket::VoiceMediaChannel* media_channel, absl::optional ssrc) { - RTC_DCHECK_RUN_ON(main_thread_); + RTC_DCHECK_RUN_ON(worker_thread_); RTC_DCHECK(media_channel); + ssrc ? media_channel->SetRawAudioSink(*ssrc, nullptr) + : media_channel->SetDefaultRawAudioSink(nullptr); +} - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - ssrc ? 
media_channel->SetRawAudioSink(*ssrc, nullptr) - : media_channel->SetDefaultRawAudioSink(nullptr); - }); +void RemoteAudioSource::SetState(SourceState new_state) { + RTC_DCHECK_RUN_ON(main_thread_); + if (state_ != new_state) { + state_ = new_state; + FireOnChanged(); + } } MediaSourceInterface::SourceState RemoteAudioSource::state() const { - RTC_DCHECK(main_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(main_thread_); return state_; } bool RemoteAudioSource::remote() const { - RTC_DCHECK(main_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(main_thread_); return true; } @@ -124,7 +131,7 @@ void RemoteAudioSource::UnregisterAudioObserver(AudioObserver* observer) { } void RemoteAudioSource::AddSink(AudioTrackSinkInterface* sink) { - RTC_DCHECK(main_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(main_thread_); RTC_DCHECK(sink); if (state_ != MediaSourceInterface::kLive) { @@ -138,7 +145,7 @@ void RemoteAudioSource::AddSink(AudioTrackSinkInterface* sink) { } void RemoteAudioSource::RemoveSink(AudioTrackSinkInterface* sink) { - RTC_DCHECK(main_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(main_thread_); RTC_DCHECK(sink); MutexLock lock(&sink_lock_); @@ -158,6 +165,9 @@ void RemoteAudioSource::OnData(const AudioSinkInterface::Data& audio) { } void RemoteAudioSource::OnAudioChannelGone() { + if (on_audio_channel_gone_action_ != OnAudioChannelGoneAction::kEnd) { + return; + } // Called when the audio channel is deleted. It may be the worker thread // in libjingle or may be a different worker thread. // This object needs to live long enough for the cleanup logic in OnMessage to @@ -170,10 +180,9 @@ void RemoteAudioSource::OnAudioChannelGone() { } void RemoteAudioSource::OnMessage(rtc::Message* msg) { - RTC_DCHECK(main_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(main_thread_); sinks_.clear(); - state_ = MediaSourceInterface::kEnded; - FireOnChanged(); + SetState(MediaSourceInterface::kEnded); // Will possibly delete this RemoteAudioSource since it is reference counted // in the message. delete msg->pdata; diff --git a/pc/remote_audio_source.h b/pc/remote_audio_source.h index 9ec09165cf..2eae073272 100644 --- a/pc/remote_audio_source.h +++ b/pc/remote_audio_source.h @@ -11,15 +11,21 @@ #ifndef PC_REMOTE_AUDIO_SOURCE_H_ #define PC_REMOTE_AUDIO_SOURCE_H_ +#include + #include #include #include "absl/types/optional.h" #include "api/call/audio_sink.h" +#include "api/media_stream_interface.h" #include "api/notifier.h" +#include "media/base/media_channel.h" #include "pc/channel.h" #include "rtc_base/message_handler.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_message.h" namespace rtc { struct Message; @@ -34,7 +40,21 @@ namespace webrtc { class RemoteAudioSource : public Notifier, rtc::MessageHandler { public: - explicit RemoteAudioSource(rtc::Thread* worker_thread); + // In Unified Plan, receivers map to m= sections and their tracks and sources + // survive SSRCs being reconfigured. The life cycle of the remote audio source + // is associated with the life cycle of the m= section, and thus even if an + // audio channel is destroyed the RemoteAudioSource should kSurvive. + // + // In Plan B however, remote audio sources map 1:1 with an SSRCs and if an + // audio channel is destroyed, the RemoteAudioSource should kEnd. 
+ enum class OnAudioChannelGoneAction { + kSurvive, + kEnd, + }; + + explicit RemoteAudioSource( + rtc::Thread* worker_thread, + OnAudioChannelGoneAction on_audio_channel_gone_action); // Register and unregister remote audio source with the underlying media // engine. @@ -42,6 +62,7 @@ class RemoteAudioSource : public Notifier, absl::optional ssrc); void Stop(cricket::VoiceMediaChannel* media_channel, absl::optional ssrc); + void SetState(SourceState new_state); // MediaSourceInterface implementation. MediaSourceInterface::SourceState state() const override; @@ -61,6 +82,7 @@ class RemoteAudioSource : public Notifier, private: // These are callbacks from the media engine. class AudioDataProxy; + void OnData(const AudioSinkInterface::Data& audio); void OnAudioChannelGone(); @@ -68,6 +90,7 @@ class RemoteAudioSource : public Notifier, rtc::Thread* const main_thread_; rtc::Thread* const worker_thread_; + const OnAudioChannelGoneAction on_audio_channel_gone_action_; std::list audio_observers_; Mutex sink_lock_; std::list sinks_; diff --git a/pc/rtc_stats_collector.cc b/pc/rtc_stats_collector.cc index 529200894d..6599d0ef49 100644 --- a/pc/rtc_stats_collector.cc +++ b/pc/rtc_stats_collector.cc @@ -10,23 +10,52 @@ #include "pc/rtc_stats_collector.h" +#include + +#include +#include #include #include #include #include #include +#include "api/array_view.h" #include "api/candidate.h" #include "api/media_stream_interface.h" -#include "api/peer_connection_interface.h" +#include "api/rtp_parameters.h" +#include "api/rtp_receiver_interface.h" +#include "api/rtp_sender_interface.h" +#include "api/sequence_checker.h" +#include "api/stats/rtc_stats.h" +#include "api/stats/rtcstats_objects.h" +#include "api/task_queue/queued_task.h" #include "api/video/video_content_type.h" +#include "common_video/include/quality_limitation_reason.h" #include "media/base/media_channel.h" +#include "modules/audio_processing/include/audio_processing_statistics.h" +#include "modules/rtp_rtcp/include/report_block_data.h" +#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "p2p/base/connection_info.h" +#include "p2p/base/dtls_transport_internal.h" +#include "p2p/base/ice_transport_internal.h" #include "p2p/base/p2p_constants.h" #include "p2p/base/port.h" -#include "pc/peer_connection.h" +#include "pc/channel.h" +#include "pc/channel_interface.h" +#include "pc/data_channel_utils.h" #include "pc/rtc_stats_traversal.h" #include "pc/webrtc_sdp.h" #include "rtc_base/checks.h" +#include "rtc_base/ip_address.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" +#include "rtc_base/network_constants.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/rtc_certificate.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/string_encode.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" @@ -80,17 +109,23 @@ std::string RTCTransportStatsIDFromTransportChannel( return sb.str(); } -std::string RTCInboundRTPStreamStatsIDFromSSRC(bool audio, uint32_t ssrc) { +std::string RTCInboundRTPStreamStatsIDFromSSRC(cricket::MediaType media_type, + uint32_t ssrc) { char buf[1024]; rtc::SimpleStringBuilder sb(buf); - sb << "RTCInboundRTP" << (audio ? "Audio" : "Video") << "Stream_" << ssrc; + sb << "RTCInboundRTP" + << (media_type == cricket::MEDIA_TYPE_AUDIO ? 
"Audio" : "Video") + << "Stream_" << ssrc; return sb.str(); } -std::string RTCOutboundRTPStreamStatsIDFromSSRC(bool audio, uint32_t ssrc) { +std::string RTCOutboundRTPStreamStatsIDFromSSRC(cricket::MediaType media_type, + uint32_t ssrc) { char buf[1024]; rtc::SimpleStringBuilder sb(buf); - sb << "RTCOutboundRTP" << (audio ? "Audio" : "Video") << "Stream_" << ssrc; + sb << "RTCOutboundRTP" + << (media_type == cricket::MEDIA_TYPE_AUDIO ? "Audio" : "Video") + << "Stream_" << ssrc; return sb.str(); } @@ -105,6 +140,17 @@ std::string RTCRemoteInboundRtpStreamStatsIdFromSourceSsrc( return sb.str(); } +std::string RTCRemoteOutboundRTPStreamStatsIDFromSSRC( + cricket::MediaType media_type, + uint32_t source_ssrc) { + char buf[1024]; + rtc::SimpleStringBuilder sb(buf); + sb << "RTCRemoteOutboundRTP" + << (media_type == cricket::MEDIA_TYPE_AUDIO ? "Audio" : "Video") + << "Stream_" << source_ssrc; + return sb.str(); +} + std::string RTCMediaSourceStatsIDFromKindAndAttachment( cricket::MediaType media_type, int attachment_id) { @@ -163,20 +209,20 @@ const char* IceCandidatePairStateToRTCStatsIceCandidatePairState( } const char* DtlsTransportStateToRTCDtlsTransportState( - cricket::DtlsTransportState state) { + DtlsTransportState state) { switch (state) { - case cricket::DTLS_TRANSPORT_NEW: + case DtlsTransportState::kNew: return RTCDtlsTransportState::kNew; - case cricket::DTLS_TRANSPORT_CONNECTING: + case DtlsTransportState::kConnecting: return RTCDtlsTransportState::kConnecting; - case cricket::DTLS_TRANSPORT_CONNECTED: + case DtlsTransportState::kConnected: return RTCDtlsTransportState::kConnected; - case cricket::DTLS_TRANSPORT_CLOSED: + case DtlsTransportState::kClosed: return RTCDtlsTransportState::kClosed; - case cricket::DTLS_TRANSPORT_FAILED: + case DtlsTransportState::kFailed: return RTCDtlsTransportState::kFailed; default: - RTC_NOTREACHED(); + RTC_CHECK_NOTREACHED(); return nullptr; } } @@ -219,6 +265,17 @@ const char* QualityLimitationReasonToRTCQualityLimitationReason( RTC_CHECK_NOTREACHED(); } +std::map +QualityLimitationDurationToRTCQualityLimitationDuration( + std::map durations_ms) { + std::map result; + for (const auto& elem : durations_ms) { + result[QualityLimitationReasonToRTCQualityLimitationReason(elem.first)] = + elem.second; + } + return result; +} + double DoubleAudioLevelFromIntAudioLevel(int audio_level) { RTC_DCHECK_GE(audio_level, 0); RTC_DCHECK_LE(audio_level, 32767); @@ -268,8 +325,6 @@ void SetInboundRTPStreamStatsFromMediaReceiverInfo( RTCInboundRTPStreamStats* inbound_stats) { RTC_DCHECK(inbound_stats); inbound_stats->ssrc = media_receiver_info.ssrc(); - // TODO(hbos): Support the remote case. 
https://crbug.com/657855 - inbound_stats->is_remote = false; inbound_stats->packets_received = static_cast(media_receiver_info.packets_rcvd); inbound_stats->bytes_received = @@ -278,26 +333,33 @@ void SetInboundRTPStreamStatsFromMediaReceiverInfo( static_cast(media_receiver_info.header_and_padding_bytes_rcvd); inbound_stats->packets_lost = static_cast(media_receiver_info.packets_lost); + inbound_stats->jitter_buffer_delay = + media_receiver_info.jitter_buffer_delay_seconds; + inbound_stats->jitter_buffer_emitted_count = + media_receiver_info.jitter_buffer_emitted_count; + if (media_receiver_info.nacks_sent) { + inbound_stats->nack_count = *media_receiver_info.nacks_sent; + } } -void SetInboundRTPStreamStatsFromVoiceReceiverInfo( - const std::string& mid, +std::unique_ptr CreateInboundAudioStreamStats( const cricket::VoiceReceiverInfo& voice_receiver_info, - RTCInboundRTPStreamStats* inbound_audio) { + const std::string& mid, + int64_t timestamp_us) { + auto inbound_audio = std::make_unique( + /*id=*/RTCInboundRTPStreamStatsIDFromSSRC(cricket::MEDIA_TYPE_AUDIO, + voice_receiver_info.ssrc()), + timestamp_us); SetInboundRTPStreamStatsFromMediaReceiverInfo(voice_receiver_info, - inbound_audio); + inbound_audio.get()); inbound_audio->media_type = "audio"; inbound_audio->kind = "audio"; if (voice_receiver_info.codec_payload_type) { inbound_audio->codec_id = RTCCodecStatsIDFromMidDirectionAndPayload( - mid, true, *voice_receiver_info.codec_payload_type); + mid, /*inbound=*/true, *voice_receiver_info.codec_payload_type); } inbound_audio->jitter = static_cast(voice_receiver_info.jitter_ms) / rtc::kNumMillisecsPerSec; - inbound_audio->jitter_buffer_delay = - voice_receiver_info.jitter_buffer_delay_seconds; - inbound_audio->jitter_buffer_emitted_count = - voice_receiver_info.jitter_buffer_emitted_count; inbound_audio->total_samples_received = voice_receiver_info.total_samples_received; inbound_audio->concealed_samples = voice_receiver_info.concealed_samples; @@ -318,12 +380,11 @@ void SetInboundRTPStreamStatsFromVoiceReceiverInfo( // |fir_count|, |pli_count| and |sli_count| are only valid for video and are // purposefully left undefined for audio. if (voice_receiver_info.last_packet_received_timestamp_ms) { - inbound_audio->last_packet_received_timestamp = - static_cast( - *voice_receiver_info.last_packet_received_timestamp_ms) / - rtc::kNumMillisecsPerSec; + inbound_audio->last_packet_received_timestamp = static_cast( + *voice_receiver_info.last_packet_received_timestamp_ms); } if (voice_receiver_info.estimated_playout_ntp_timestamp_ms) { + // TODO(bugs.webrtc.org/10529): Fix time origin. inbound_audio->estimated_playout_timestamp = static_cast( *voice_receiver_info.estimated_playout_ntp_timestamp_ms); } @@ -331,6 +392,51 @@ void SetInboundRTPStreamStatsFromVoiceReceiverInfo( voice_receiver_info.fec_packets_received; inbound_audio->fec_packets_discarded = voice_receiver_info.fec_packets_discarded; + return inbound_audio; +} + +std::unique_ptr +CreateRemoteOutboundAudioStreamStats( + const cricket::VoiceReceiverInfo& voice_receiver_info, + const std::string& mid, + const std::string& inbound_audio_id, + const std::string& transport_id) { + if (!voice_receiver_info.last_sender_report_timestamp_ms.has_value()) { + // Cannot create `RTCRemoteOutboundRtpStreamStats` when the RTCP SR arrival + // timestamp is not available - i.e., until the first sender report is + // received. + return nullptr; + } + RTC_DCHECK_GT(voice_receiver_info.sender_reports_reports_count, 0); + + // Create. 
+ auto stats = std::make_unique( + /*id=*/RTCRemoteOutboundRTPStreamStatsIDFromSSRC( + cricket::MEDIA_TYPE_AUDIO, voice_receiver_info.ssrc()), + /*timestamp_us=*/rtc::kNumMicrosecsPerMillisec * + voice_receiver_info.last_sender_report_timestamp_ms.value()); + + // Populate. + // - RTCRtpStreamStats. + stats->ssrc = voice_receiver_info.ssrc(); + stats->kind = "audio"; + stats->transport_id = transport_id; + stats->codec_id = RTCCodecStatsIDFromMidDirectionAndPayload( + mid, + /*inbound=*/true, // Remote-outbound same as local-inbound. + *voice_receiver_info.codec_payload_type); + // - RTCSentRtpStreamStats. + stats->packets_sent = voice_receiver_info.sender_reports_packets_sent; + stats->bytes_sent = voice_receiver_info.sender_reports_bytes_sent; + // - RTCRemoteOutboundRtpStreamStats. + stats->local_id = inbound_audio_id; + RTC_DCHECK( + voice_receiver_info.last_sender_report_remote_timestamp_ms.has_value()); + stats->remote_timestamp = static_cast( + voice_receiver_info.last_sender_report_remote_timestamp_ms.value()); + stats->reports_sent = voice_receiver_info.sender_reports_reports_count; + + return stats; } void SetInboundRTPStreamStatsFromVideoReceiverInfo( @@ -343,14 +449,14 @@ void SetInboundRTPStreamStatsFromVideoReceiverInfo( inbound_video->kind = "video"; if (video_receiver_info.codec_payload_type) { inbound_video->codec_id = RTCCodecStatsIDFromMidDirectionAndPayload( - mid, true, *video_receiver_info.codec_payload_type); + mid, /*inbound=*/true, *video_receiver_info.codec_payload_type); } + inbound_video->jitter = static_cast(video_receiver_info.jitter_ms) / + rtc::kNumMillisecsPerSec; inbound_video->fir_count = static_cast(video_receiver_info.firs_sent); inbound_video->pli_count = static_cast(video_receiver_info.plis_sent); - inbound_video->nack_count = - static_cast(video_receiver_info.nacks_sent); inbound_video->frames_received = video_receiver_info.frames_received; inbound_video->frames_decoded = video_receiver_info.frames_decoded; inbound_video->frames_dropped = video_receiver_info.frames_dropped; @@ -376,17 +482,16 @@ void SetInboundRTPStreamStatsFromVideoReceiverInfo( inbound_video->total_squared_inter_frame_delay = video_receiver_info.total_squared_inter_frame_delay; if (video_receiver_info.last_packet_received_timestamp_ms) { - inbound_video->last_packet_received_timestamp = - static_cast( - *video_receiver_info.last_packet_received_timestamp_ms) / - rtc::kNumMillisecsPerSec; + inbound_video->last_packet_received_timestamp = static_cast( + *video_receiver_info.last_packet_received_timestamp_ms); } if (video_receiver_info.estimated_playout_ntp_timestamp_ms) { + // TODO(bugs.webrtc.org/10529): Fix time origin if needed. inbound_video->estimated_playout_timestamp = static_cast( *video_receiver_info.estimated_playout_ntp_timestamp_ms); } - // TODO(https://crbug.com/webrtc/10529): When info's |content_info| is - // optional, support the "unspecified" value. + // TODO(bugs.webrtc.org/10529): When info's |content_info| is optional + // support the "unspecified" value. if (video_receiver_info.content_type == VideoContentType::SCREENSHARE) inbound_video->content_type = RTCContentType::kScreenshare; if (!video_receiver_info.decoder_implementation_name.empty()) { @@ -401,8 +506,6 @@ void SetOutboundRTPStreamStatsFromMediaSenderInfo( RTCOutboundRTPStreamStats* outbound_stats) { RTC_DCHECK(outbound_stats); outbound_stats->ssrc = media_sender_info.ssrc(); - // TODO(hbos): Support the remote case. 
https://crbug.com/657856 - outbound_stats->is_remote = false; outbound_stats->packets_sent = static_cast(media_sender_info.packets_sent); outbound_stats->retransmitted_packets_sent = @@ -413,6 +516,7 @@ void SetOutboundRTPStreamStatsFromMediaSenderInfo( static_cast(media_sender_info.header_and_padding_bytes_sent); outbound_stats->retransmitted_bytes_sent = media_sender_info.retransmitted_bytes_sent; + outbound_stats->nack_count = media_sender_info.nacks_rcvd; } void SetOutboundRTPStreamStatsFromVoiceSenderInfo( @@ -425,7 +529,7 @@ void SetOutboundRTPStreamStatsFromVoiceSenderInfo( outbound_audio->kind = "audio"; if (voice_sender_info.codec_payload_type) { outbound_audio->codec_id = RTCCodecStatsIDFromMidDirectionAndPayload( - mid, false, *voice_sender_info.codec_payload_type); + mid, /*inbound=*/false, *voice_sender_info.codec_payload_type); } // |fir_count|, |pli_count| and |sli_count| are only valid for video and are // purposefully left undefined for audio. @@ -441,14 +545,12 @@ void SetOutboundRTPStreamStatsFromVideoSenderInfo( outbound_video->kind = "video"; if (video_sender_info.codec_payload_type) { outbound_video->codec_id = RTCCodecStatsIDFromMidDirectionAndPayload( - mid, false, *video_sender_info.codec_payload_type); + mid, /*inbound=*/false, *video_sender_info.codec_payload_type); } outbound_video->fir_count = static_cast(video_sender_info.firs_rcvd); outbound_video->pli_count = static_cast(video_sender_info.plis_rcvd); - outbound_video->nack_count = - static_cast(video_sender_info.nacks_rcvd); if (video_sender_info.qp_sum) outbound_video->qp_sum = *video_sender_info.qp_sum; outbound_video->frames_encoded = video_sender_info.frames_encoded; @@ -477,6 +579,9 @@ void SetOutboundRTPStreamStatsFromVideoSenderInfo( outbound_video->quality_limitation_reason = QualityLimitationReasonToRTCQualityLimitationReason( video_sender_info.quality_limitation_reason); + outbound_video->quality_limitation_durations = + QualityLimitationDurationToRTCQualityLimitationDuration( + video_sender_info.quality_limitation_durations_ms); outbound_video->quality_limitation_resolution_changes = video_sender_info.quality_limitation_resolution_changes; // TODO(https://crbug.com/webrtc/10529): When info's |content_info| is @@ -510,12 +615,19 @@ ProduceRemoteInboundRtpStreamStatsFromReportBlockData( remote_inbound->kind = media_type == cricket::MEDIA_TYPE_AUDIO ? "audio" : "video"; remote_inbound->packets_lost = report_block.packets_lost; + remote_inbound->fraction_lost = + static_cast(report_block.fraction_lost) / (1 << 8); remote_inbound->round_trip_time = static_cast(report_block_data.last_rtt_ms()) / rtc::kNumMillisecsPerSec; + remote_inbound->total_round_trip_time = + static_cast(report_block_data.sum_rtt_ms()) / + rtc::kNumMillisecsPerSec; + remote_inbound->round_trip_time_measurements = + report_block_data.num_rtts(); - std::string local_id = RTCOutboundRTPStreamStatsIDFromSSRC( - media_type == cricket::MEDIA_TYPE_AUDIO, report_block.source_ssrc); + std::string local_id = + RTCOutboundRTPStreamStatsIDFromSSRC(media_type, report_block.source_ssrc); // Look up local stat from |outbound_rtps| where the pointers are non-const. 
auto local_id_it = outbound_rtps.find(local_id); if (local_id_it != outbound_rtps.end()) { @@ -616,6 +728,7 @@ const std::string& ProduceIceCandidateStats(int64_t timestamp_us, RTC_DCHECK_EQ(rtc::ADAPTER_TYPE_UNKNOWN, candidate.network_type()); } candidate_stats->ip = candidate.address().ipaddr().ToString(); + candidate_stats->address = candidate.address().ipaddr().ToString(); candidate_stats->port = static_cast(candidate.address().port()); candidate_stats->protocol = candidate.protocol(); candidate_stats->candidate_type = @@ -630,10 +743,22 @@ const std::string& ProduceIceCandidateStats(int64_t timestamp_us, return stats->id(); } +template +void SetAudioProcessingStats(StatsType* stats, + const AudioProcessingStats& apm_stats) { + if (apm_stats.echo_return_loss) { + stats->echo_return_loss = *apm_stats.echo_return_loss; + } + if (apm_stats.echo_return_loss_enhancement) { + stats->echo_return_loss_enhancement = + *apm_stats.echo_return_loss_enhancement; + } +} + std::unique_ptr ProduceMediaStreamTrackStatsFromVoiceSenderInfo( int64_t timestamp_us, - const AudioTrackInterface& audio_track, + AudioTrackInterface& audio_track, const cricket::VoiceSenderInfo& voice_sender_info, int attachment_id) { std::unique_ptr audio_track_stats( @@ -648,13 +773,17 @@ ProduceMediaStreamTrackStatsFromVoiceSenderInfo( attachment_id); audio_track_stats->remote_source = false; audio_track_stats->detached = false; - if (voice_sender_info.apm_statistics.echo_return_loss) { - audio_track_stats->echo_return_loss = - *voice_sender_info.apm_statistics.echo_return_loss; - } - if (voice_sender_info.apm_statistics.echo_return_loss_enhancement) { - audio_track_stats->echo_return_loss_enhancement = - *voice_sender_info.apm_statistics.echo_return_loss_enhancement; + // Audio processor may be attached to either the track or the send + // stream, so look in both places. + SetAudioProcessingStats(audio_track_stats.get(), + voice_sender_info.apm_statistics); + auto audio_processor(audio_track.GetAudioProcessor()); + if (audio_processor.get()) { + // The |has_remote_tracks| argument is obsolete; makes no difference if it's + // set to true or false. + AudioProcessorInterface::AudioProcessorStatistics ap_stats = + audio_processor->GetStats(/*has_remote_tracks=*/false); + SetAudioProcessingStats(audio_track_stats.get(), ap_stats.apm_statistics); } return audio_track_stats; } @@ -998,8 +1127,7 @@ RTCStatsCollector::RequestInfo::RequestInfo( rtc::scoped_refptr RTCStatsCollector::Create( PeerConnectionInternal* pc, int64_t cache_lifetime_us) { - return rtc::scoped_refptr( - new rtc::RefCountedObject(pc, cache_lifetime_us)); + return rtc::make_ref_counted(pc, cache_lifetime_us); } RTCStatsCollector::RTCStatsCollector(PeerConnectionInternal* pc, @@ -1019,8 +1147,6 @@ RTCStatsCollector::RTCStatsCollector(PeerConnectionInternal* pc, RTC_DCHECK(worker_thread_); RTC_DCHECK(network_thread_); RTC_DCHECK_GE(cache_lifetime_us_, 0); - pc_->SignalRtpDataChannelCreated().connect( - this, &RTCStatsCollector::OnRtpDataChannelCreated); pc_->SignalSctpDataChannelCreated().connect( this, &RTCStatsCollector::OnSctpDataChannelCreated); } @@ -1048,7 +1174,7 @@ void RTCStatsCollector::GetStatsReport( void RTCStatsCollector::GetStatsReportInternal( RTCStatsCollector::RequestInfo request) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); requests_.push_back(std::move(request)); // "Now" using a monotonically increasing timer. @@ -1060,9 +1186,30 @@ void RTCStatsCollector::GetStatsReportInternal( // reentrancy problems. 
std::vector requests; requests.swap(requests_); - signaling_thread_->PostTask( - RTC_FROM_HERE, rtc::Bind(&RTCStatsCollector::DeliverCachedReport, this, - cached_report_, std::move(requests))); + + // Task subclass to take ownership of the requests. + // TODO(nisse): Delete when we can use C++14, and do lambda capture with + // std::move. + class DeliveryTask : public QueuedTask { + public: + DeliveryTask(rtc::scoped_refptr collector, + rtc::scoped_refptr cached_report, + std::vector requests) + : collector_(collector), + cached_report_(cached_report), + requests_(std::move(requests)) {} + bool Run() override { + collector_->DeliverCachedReport(cached_report_, std::move(requests_)); + return true; + } + + private: + rtc::scoped_refptr collector_; + rtc::scoped_refptr cached_report_; + std::vector requests_; + }; + signaling_thread_->PostTask(std::make_unique( + this, cached_report_, std::move(requests))); } else if (!num_pending_partial_reports_) { // Only start gathering stats if we're not already gathering stats. In the // case of already gathering stats, |callback_| will be invoked when there @@ -1079,30 +1226,30 @@ void RTCStatsCollector::GetStatsReportInternal( // Prepare |transceiver_stats_infos_| and |call_stats_| for use in // |ProducePartialResultsOnNetworkThread| and // |ProducePartialResultsOnSignalingThread|. - PrepareTransceiverStatsInfosAndCallStats_s_w(); - // Prepare |transport_names_| for use in - // |ProducePartialResultsOnNetworkThread|. - transport_names_ = PrepareTransportNames_s(); - + PrepareTransceiverStatsInfosAndCallStats_s_w_n(); // Don't touch |network_report_| on the signaling thread until // ProducePartialResultsOnNetworkThread() has signaled the // |network_report_event_|. network_report_event_.Reset(); + rtc::scoped_refptr collector(this); network_thread_->PostTask( RTC_FROM_HERE, - rtc::Bind(&RTCStatsCollector::ProducePartialResultsOnNetworkThread, - this, timestamp_us)); + [collector, sctp_transport_name = pc_->sctp_transport_name(), + timestamp_us]() mutable { + collector->ProducePartialResultsOnNetworkThread( + timestamp_us, std::move(sctp_transport_name)); + }); ProducePartialResultsOnSignalingThread(timestamp_us); } } void RTCStatsCollector::ClearCachedStatsReport() { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); cached_report_ = nullptr; } void RTCStatsCollector::WaitForPendingRequest() { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); // If a request is pending, blocks until the |network_report_event_| is // signaled and then delivers the result. Otherwise this is a NO-OP. 
MergeNetworkReport_s(); @@ -1110,7 +1257,7 @@ void RTCStatsCollector::WaitForPendingRequest() { void RTCStatsCollector::ProducePartialResultsOnSignalingThread( int64_t timestamp_us) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; partial_report_ = RTCStatsReport::Create(timestamp_us); @@ -1129,7 +1276,7 @@ void RTCStatsCollector::ProducePartialResultsOnSignalingThread( void RTCStatsCollector::ProducePartialResultsOnSignalingThreadImpl( int64_t timestamp_us, RTCStatsReport* partial_report) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; ProduceDataChannelStats_s(timestamp_us, partial_report); @@ -1140,16 +1287,29 @@ void RTCStatsCollector::ProducePartialResultsOnSignalingThreadImpl( } void RTCStatsCollector::ProducePartialResultsOnNetworkThread( - int64_t timestamp_us) { - RTC_DCHECK(network_thread_->IsCurrent()); + int64_t timestamp_us, + absl::optional sctp_transport_name) { + TRACE_EVENT0("webrtc", + "RTCStatsCollector::ProducePartialResultsOnNetworkThread"); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; // Touching |network_report_| on this thread is safe by this method because // |network_report_event_| is reset before this method is invoked. network_report_ = RTCStatsReport::Create(timestamp_us); + std::set transport_names; + if (sctp_transport_name) { + transport_names.emplace(std::move(*sctp_transport_name)); + } + + for (const auto& info : transceiver_stats_infos_) { + if (info.transport_name) + transport_names.insert(*info.transport_name); + } + std::map transport_stats_by_name = - pc_->GetTransportStatsByNames(transport_names_); + pc_->GetTransportStatsByNames(transport_names); std::map transport_cert_stats = PrepareTransportCertificateStats_n(transport_stats_by_name); @@ -1160,8 +1320,9 @@ void RTCStatsCollector::ProducePartialResultsOnNetworkThread( // Signal that it is now safe to touch |network_report_| on the signaling // thread, and post a task to merge it into the final results. network_report_event_.Set(); + rtc::scoped_refptr collector(this); signaling_thread_->PostTask( - RTC_FROM_HERE, rtc::Bind(&RTCStatsCollector::MergeNetworkReport_s, this)); + RTC_FROM_HERE, [collector] { collector->MergeNetworkReport_s(); }); } void RTCStatsCollector::ProducePartialResultsOnNetworkThreadImpl( @@ -1170,7 +1331,7 @@ void RTCStatsCollector::ProducePartialResultsOnNetworkThreadImpl( transport_stats_by_name, const std::map& transport_cert_stats, RTCStatsReport* partial_report) { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; ProduceCertificateStats_n(timestamp_us, transport_cert_stats, partial_report); @@ -1184,7 +1345,7 @@ void RTCStatsCollector::ProducePartialResultsOnNetworkThreadImpl( } void RTCStatsCollector::MergeNetworkReport_s() { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); // The |network_report_event_| must be signaled for it to be safe to touch // |network_report_|. 
This is normally not blocking, but if // WaitForPendingRequest() is called while a request is pending, we might have @@ -1227,7 +1388,7 @@ void RTCStatsCollector::MergeNetworkReport_s() { void RTCStatsCollector::DeliverCachedReport( rtc::scoped_refptr cached_report, std::vector requests) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); RTC_DCHECK(!requests.empty()); RTC_DCHECK(cached_report); @@ -1258,7 +1419,7 @@ void RTCStatsCollector::ProduceCertificateStats_n( int64_t timestamp_us, const std::map& transport_cert_stats, RTCStatsReport* report) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const auto& transport_cert_stats_pair : transport_cert_stats) { @@ -1277,7 +1438,7 @@ void RTCStatsCollector::ProduceCodecStats_n( int64_t timestamp_us, const std::vector& transceiver_stats_infos, RTCStatsReport* report) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const auto& stats : transceiver_stats_infos) { @@ -1349,7 +1510,7 @@ void RTCStatsCollector::ProduceIceCandidateAndPairStats_n( transport_stats_by_name, const Call::Stats& call_stats, RTCStatsReport* report) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const auto& entry : transport_stats_by_name) { @@ -1431,7 +1592,7 @@ void RTCStatsCollector::ProduceIceCandidateAndPairStats_n( void RTCStatsCollector::ProduceMediaStreamStats_s( int64_t timestamp_us, RTCStatsReport* report) const { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; std::map> track_ids; @@ -1468,7 +1629,7 @@ void RTCStatsCollector::ProduceMediaStreamStats_s( void RTCStatsCollector::ProduceMediaStreamTrackStats_s( int64_t timestamp_us, RTCStatsReport* report) const { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const RtpTransceiverStatsInfo& stats : transceiver_stats_infos_) { @@ -1491,7 +1652,7 @@ void RTCStatsCollector::ProduceMediaStreamTrackStats_s( void RTCStatsCollector::ProduceMediaSourceStats_s( int64_t timestamp_us, RTCStatsReport* report) const { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const RtpTransceiverStatsInfo& transceiver_stats_info : @@ -1512,6 +1673,8 @@ void RTCStatsCollector::ProduceMediaSourceStats_s( // create separate media source stats objects on a per-attachment basis. std::unique_ptr media_source_stats; if (track->kind() == MediaStreamTrackInterface::kAudioKind) { + AudioTrackInterface* audio_track = + static_cast(track.get()); auto audio_source_stats = std::make_unique( RTCMediaSourceStatsIDFromKindAndAttachment( cricket::MEDIA_TYPE_AUDIO, sender_internal->AttachmentId()), @@ -1532,8 +1695,21 @@ void RTCStatsCollector::ProduceMediaSourceStats_s( voice_sender_info->total_input_energy; audio_source_stats->total_samples_duration = voice_sender_info->total_input_duration; + SetAudioProcessingStats(audio_source_stats.get(), + voice_sender_info->apm_statistics); } } + // Audio processor may be attached to either the track or the send + // stream, so look in both places. 
+ auto audio_processor(audio_track->GetAudioProcessor()); + if (audio_processor.get()) { + // The |has_remote_tracks| argument is obsolete; makes no difference + // if it's set to true or false. + AudioProcessorInterface::AudioProcessorStatistics ap_stats = + audio_processor->GetStats(/*has_remote_tracks=*/false); + SetAudioProcessingStats(audio_source_stats.get(), + ap_stats.apm_statistics); + } media_source_stats = std::move(audio_source_stats); } else { RTC_DCHECK_EQ(MediaStreamTrackInterface::kVideoKind, track->kind()); @@ -1560,6 +1736,7 @@ void RTCStatsCollector::ProduceMediaSourceStats_s( if (video_sender_info) { video_source_stats->frames_per_second = video_sender_info->framerate_input; + video_source_stats->frames = video_sender_info->frames; } } media_source_stats = std::move(video_source_stats); @@ -1574,7 +1751,7 @@ void RTCStatsCollector::ProduceMediaSourceStats_s( void RTCStatsCollector::ProducePeerConnectionStats_s( int64_t timestamp_us, RTCStatsReport* report) const { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; std::unique_ptr stats( @@ -1588,7 +1765,7 @@ void RTCStatsCollector::ProduceRTPStreamStats_n( int64_t timestamp_us, const std::vector& transceiver_stats_infos, RTCStatsReport* report) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const RtpTransceiverStatsInfo& stats : transceiver_stats_infos) { @@ -1606,7 +1783,7 @@ void RTCStatsCollector::ProduceAudioRTPStreamStats_n( int64_t timestamp_us, const RtpTransceiverStatsInfo& stats, RTCStatsReport* report) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; if (!stats.mid || !stats.transport_name) { @@ -1618,16 +1795,16 @@ void RTCStatsCollector::ProduceAudioRTPStreamStats_n( std::string mid = *stats.mid; std::string transport_id = RTCTransportStatsIDFromTransportChannel( *stats.transport_name, cricket::ICE_CANDIDATE_COMPONENT_RTP); - // Inbound + // Inbound and remote-outbound. + // The remote-outbound stats are based on RTCP sender reports sent from the + // remote endpoint providing metrics about the remote outbound streams. for (const cricket::VoiceReceiverInfo& voice_receiver_info : track_media_info_map.voice_media_info()->receivers) { if (!voice_receiver_info.connected()) continue; - auto inbound_audio = std::make_unique( - RTCInboundRTPStreamStatsIDFromSSRC(true, voice_receiver_info.ssrc()), - timestamp_us); - SetInboundRTPStreamStatsFromVoiceReceiverInfo(mid, voice_receiver_info, - inbound_audio.get()); + // Inbound. + auto inbound_audio = + CreateInboundAudioStreamStats(voice_receiver_info, mid, timestamp_us); // TODO(hta): This lookup should look for the sender, not the track. rtc::scoped_refptr audio_track = track_media_info_map.GetAudioTrack(voice_receiver_info); @@ -1638,16 +1815,27 @@ void RTCStatsCollector::ProduceAudioRTPStreamStats_n( track_media_info_map.GetAttachmentIdByTrack(audio_track).value()); } inbound_audio->transport_id = transport_id; + // Remote-outbound. + auto remote_outbound_audio = CreateRemoteOutboundAudioStreamStats( + voice_receiver_info, mid, inbound_audio->id(), transport_id); + // Add stats. + if (remote_outbound_audio) { + // When the remote outbound stats are available, the remote ID for the + // local inbound stats is set. 
+      inbound_audio->remote_id = remote_outbound_audio->id();
+      report->AddStats(std::move(remote_outbound_audio));
+    }
     report->AddStats(std::move(inbound_audio));
   }
-  // Outbound
+  // Outbound.
   std::map<std::string, RTCOutboundRTPStreamStats*> audio_outbound_rtps;
   for (const cricket::VoiceSenderInfo& voice_sender_info :
        track_media_info_map.voice_media_info()->senders) {
     if (!voice_sender_info.connected())
       continue;
     auto outbound_audio = std::make_unique<RTCOutboundRTPStreamStats>(
-        RTCOutboundRTPStreamStatsIDFromSSRC(true, voice_sender_info.ssrc()),
+        RTCOutboundRTPStreamStatsIDFromSSRC(cricket::MEDIA_TYPE_AUDIO,
+                                            voice_sender_info.ssrc()),
         timestamp_us);
     SetOutboundRTPStreamStatsFromVoiceSenderInfo(mid, voice_sender_info,
                                                  outbound_audio.get());
@@ -1668,7 +1856,7 @@
         std::make_pair(outbound_audio->id(), outbound_audio.get()));
     report->AddStats(std::move(outbound_audio));
   }
-  // Remote-inbound
+  // Remote-inbound.
   // These are Report Block-based, information sent from the remote endpoint,
   // providing metrics about our Outbound streams. We take advantage of the fact
   // that RTCOutboundRtpStreamStats, RTCCodecStats and RTCTransport have already
@@ -1687,7 +1875,7 @@ void RTCStatsCollector::ProduceVideoRTPStreamStats_n(
     int64_t timestamp_us,
     const RtpTransceiverStatsInfo& stats,
     RTCStatsReport* report) const {
-  RTC_DCHECK(network_thread_->IsCurrent());
+  RTC_DCHECK_RUN_ON(network_thread_);
   rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls;
 
   if (!stats.mid || !stats.transport_name) {
@@ -1705,7 +1893,8 @@ void RTCStatsCollector::ProduceVideoRTPStreamStats_n(
     if (!video_receiver_info.connected())
       continue;
     auto inbound_video = std::make_unique<RTCInboundRTPStreamStats>(
-        RTCInboundRTPStreamStatsIDFromSSRC(false, video_receiver_info.ssrc()),
+        RTCInboundRTPStreamStatsIDFromSSRC(cricket::MEDIA_TYPE_VIDEO,
+                                           video_receiver_info.ssrc()),
         timestamp_us);
     SetInboundRTPStreamStatsFromVideoReceiverInfo(mid, video_receiver_info,
                                                   inbound_video.get());
@@ -1719,6 +1908,7 @@
     }
     inbound_video->transport_id = transport_id;
     report->AddStats(std::move(inbound_video));
+    // TODO(crbug.com/webrtc/12529): Add remote-outbound stats.
} // Outbound std::map video_outbound_rtps; @@ -1727,7 +1917,8 @@ void RTCStatsCollector::ProduceVideoRTPStreamStats_n( if (!video_sender_info.connected()) continue; auto outbound_video = std::make_unique( - RTCOutboundRTPStreamStatsIDFromSSRC(false, video_sender_info.ssrc()), + RTCOutboundRTPStreamStatsIDFromSSRC(cricket::MEDIA_TYPE_VIDEO, + video_sender_info.ssrc()), timestamp_us); SetOutboundRTPStreamStatsFromVideoSenderInfo(mid, video_sender_info, outbound_video.get()); @@ -1769,7 +1960,7 @@ void RTCStatsCollector::ProduceTransportStats_n( transport_stats_by_name, const std::map& transport_cert_stats, RTCStatsReport* report) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (const auto& entry : transport_stats_by_name) { @@ -1867,7 +2058,7 @@ std::map RTCStatsCollector::PrepareTransportCertificateStats_n( const std::map& transport_stats_by_name) const { - RTC_DCHECK(network_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(network_thread_); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; std::map transport_cert_stats; @@ -1893,8 +2084,8 @@ RTCStatsCollector::PrepareTransportCertificateStats_n( return transport_cert_stats; } -void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { - RTC_DCHECK(signaling_thread_->IsCurrent()); +void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w_n() { + RTC_DCHECK_RUN_ON(signaling_thread_); transceiver_stats_infos_.clear(); // These are used to invoke GetStats for all the media channels together in @@ -1906,20 +2097,26 @@ void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { std::unique_ptr> video_stats; - { + auto transceivers = pc_->GetTransceiversInternal(); + + // TODO(tommi): See if we can avoid synchronously blocking the signaling + // thread while we do this (or avoid the Invoke at all). + network_thread_->Invoke(RTC_FROM_HERE, [this, &transceivers, + &voice_stats, &video_stats] { rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; - for (const auto& transceiver : pc_->GetTransceiversInternal()) { + for (const auto& transceiver_proxy : transceivers) { + RtpTransceiver* transceiver = transceiver_proxy->internal(); cricket::MediaType media_type = transceiver->media_type(); // Prepare stats entry. The TrackMediaInfoMap will be filled in after the // stats have been fetched on the worker thread. transceiver_stats_infos_.emplace_back(); RtpTransceiverStatsInfo& stats = transceiver_stats_infos_.back(); - stats.transceiver = transceiver->internal(); + stats.transceiver = transceiver; stats.media_type = media_type; - cricket::ChannelInterface* channel = transceiver->internal()->channel(); + cricket::ChannelInterface* channel = transceiver->channel(); if (!channel) { // The remaining fields require a BaseChannel. continue; @@ -1944,7 +2141,7 @@ void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { RTC_NOTREACHED(); } } - } + }); // We jump to the worker thread and call GetStats() on each media channel as // well as GetCallStats(). 
At the same time we construct the @@ -2003,38 +2200,13 @@ void RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() { }); } -std::set RTCStatsCollector::PrepareTransportNames_s() const { - RTC_DCHECK(signaling_thread_->IsCurrent()); - rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; - - std::set transport_names; - for (const auto& transceiver : pc_->GetTransceiversInternal()) { - if (transceiver->internal()->channel()) { - transport_names.insert( - transceiver->internal()->channel()->transport_name()); - } - } - if (pc_->rtp_data_channel()) { - transport_names.insert(pc_->rtp_data_channel()->transport_name()); - } - if (pc_->sctp_transport_name()) { - transport_names.insert(*pc_->sctp_transport_name()); - } - return transport_names; -} - -void RTCStatsCollector::OnRtpDataChannelCreated(RtpDataChannel* channel) { - channel->SignalOpened.connect(this, &RTCStatsCollector::OnDataChannelOpened); - channel->SignalClosed.connect(this, &RTCStatsCollector::OnDataChannelClosed); -} - void RTCStatsCollector::OnSctpDataChannelCreated(SctpDataChannel* channel) { channel->SignalOpened.connect(this, &RTCStatsCollector::OnDataChannelOpened); channel->SignalClosed.connect(this, &RTCStatsCollector::OnDataChannelClosed); } void RTCStatsCollector::OnDataChannelOpened(DataChannelInterface* channel) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); bool result = internal_record_.opened_data_channels .insert(reinterpret_cast(channel)) .second; @@ -2043,7 +2215,7 @@ void RTCStatsCollector::OnDataChannelOpened(DataChannelInterface* channel) { } void RTCStatsCollector::OnDataChannelClosed(DataChannelInterface* channel) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); // Only channels that have been fully opened (and have increased the // |data_channels_opened_| counter) increase the closed counter. if (internal_record_.opened_data_channels.erase( diff --git a/pc/rtc_stats_collector.h b/pc/rtc_stats_collector.h index 35576e91d8..5f13f54d26 100644 --- a/pc/rtc_stats_collector.h +++ b/pc/rtc_stats_collector.h @@ -11,6 +11,7 @@ #ifndef PC_RTC_STATS_COLLECTOR_H_ #define PC_RTC_STATS_COLLECTOR_H_ +#include #include #include #include @@ -18,6 +19,8 @@ #include #include "absl/types/optional.h" +#include "api/data_channel_interface.h" +#include "api/media_types.h" #include "api/scoped_refptr.h" #include "api/stats/rtc_stats_collector_callback.h" #include "api/stats/rtc_stats_report.h" @@ -26,11 +29,19 @@ #include "media/base/media_channel.h" #include "pc/data_channel_utils.h" #include "pc/peer_connection_internal.h" +#include "pc/rtp_receiver.h" +#include "pc/rtp_sender.h" +#include "pc/rtp_transceiver.h" +#include "pc/sctp_data_channel.h" #include "pc/track_media_info_map.h" +#include "pc/transport_stats.h" +#include "rtc_base/checks.h" #include "rtc_base/event.h" #include "rtc_base/ref_count.h" +#include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_identity.h" #include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" #include "rtc_base/time_utils.h" namespace webrtc { @@ -42,7 +53,7 @@ class RtpReceiverInternal; // Stats are gathered on the signaling, worker and network threads // asynchronously. The callback is invoked on the signaling thread. Resulting // reports are cached for |cache_lifetime_| ms. 
-class RTCStatsCollector : public virtual rtc::RefCountInterface, +class RTCStatsCollector : public rtc::RefCountInterface, public sigslot::has_slots<> { public: static rtc::scoped_refptr Create( @@ -216,18 +227,18 @@ class RTCStatsCollector : public virtual rtc::RefCountInterface, const std::map& transport_stats_by_name) const; // The results are stored in |transceiver_stats_infos_| and |call_stats_|. - void PrepareTransceiverStatsInfosAndCallStats_s_w(); - std::set PrepareTransportNames_s() const; + void PrepareTransceiverStatsInfosAndCallStats_s_w_n(); // Stats gathering on a particular thread. void ProducePartialResultsOnSignalingThread(int64_t timestamp_us); - void ProducePartialResultsOnNetworkThread(int64_t timestamp_us); + void ProducePartialResultsOnNetworkThread( + int64_t timestamp_us, + absl::optional sctp_transport_name); // Merges |network_report_| into |partial_report_| and completes the request. // This is a NO-OP if |network_report_| is null. void MergeNetworkReport_s(); // Slots for signals (sigslot) that are wired up to |pc_|. - void OnRtpDataChannelCreated(RtpDataChannel* channel); void OnSctpDataChannelCreated(SctpDataChannel* channel); // Slots for signals (sigslot) that are wired up to |channel|. void OnDataChannelOpened(DataChannelInterface* channel); @@ -256,12 +267,16 @@ class RTCStatsCollector : public virtual rtc::RefCountInterface, // has updated the value of |network_report_|. rtc::Event network_report_event_; - // Set in |GetStatsReport|, read in |ProducePartialResultsOnNetworkThread| and - // |ProducePartialResultsOnSignalingThread|, reset after work is complete. Not - // passed as arguments to avoid copies. This is thread safe - when we - // set/reset we know there are no pending stats requests in progress. + // Cleared and set in `PrepareTransceiverStatsInfosAndCallStats_s_w_n`, + // starting out on the signaling thread, then network. Later read on the + // network and signaling threads as part of collecting stats and finally + // reset when the work is done. Initially this variable was added and not + // passed around as an arguments to avoid copies. This is thread safe due to + // how operations are sequenced and we don't start the stats collection + // sequence if one is in progress. As a future improvement though, we could + // now get rid of the variable and keep the data scoped within a stats + // collection sequence. std::vector transceiver_stats_infos_; - std::set transport_names_; Call::Stats call_stats_; diff --git a/pc/rtc_stats_collector_unittest.cc b/pc/rtc_stats_collector_unittest.cc index 73579ff259..2ac0737715 100644 --- a/pc/rtc_stats_collector_unittest.cc +++ b/pc/rtc_stats_collector_unittest.cc @@ -22,6 +22,7 @@ #include "absl/memory/memory.h" #include "absl/strings/str_replace.h" +#include "api/dtls_transport_interface.h" #include "api/media_stream_track.h" #include "api/rtp_parameters.h" #include "api/stats/rtc_stats_report.h" @@ -47,6 +48,7 @@ #include "rtc_base/synchronization/mutex.h" #include "rtc_base/time_utils.h" +using ::testing::_; using ::testing::AtLeast; using ::testing::Invoke; using ::testing::Return; @@ -118,6 +120,14 @@ namespace { const int64_t kGetStatsReportTimeoutMs = 1000; +// Fake data used by `SetupExampleStatsVoiceGraph()` to fill in remote outbound +// stats. 
+constexpr int64_t kRemoteOutboundStatsTimestampMs = 123; +constexpr int64_t kRemoteOutboundStatsRemoteTimestampMs = 456; +constexpr uint32_t kRemoteOutboundStatsPacketsSent = 7u; +constexpr uint64_t kRemoteOutboundStatsBytesSent = 8u; +constexpr uint64_t kRemoteOutboundStatsReportsCount = 9u; + struct CertificateInfo { rtc::scoped_refptr certificate; std::vector ders; @@ -190,14 +200,34 @@ std::unique_ptr CreateFakeCandidate( return candidate; } +class FakeAudioProcessor : public AudioProcessorInterface { + public: + FakeAudioProcessor() {} + ~FakeAudioProcessor() {} + + private: + AudioProcessorInterface::AudioProcessorStatistics GetStats( + bool has_recv_streams) override { + AudioProcessorStatistics stats; + stats.apm_statistics.echo_return_loss = 2.0; + stats.apm_statistics.echo_return_loss_enhancement = 3.0; + return stats; + } +}; + class FakeAudioTrackForStats : public MediaStreamTrack { public: static rtc::scoped_refptr Create( const std::string& id, - MediaStreamTrackInterface::TrackState state) { + MediaStreamTrackInterface::TrackState state, + bool create_fake_audio_processor) { rtc::scoped_refptr audio_track_stats( new rtc::RefCountedObject(id)); audio_track_stats->set_state(state); + if (create_fake_audio_processor) { + audio_track_stats->processor_ = + rtc::make_ref_counted(); + } return audio_track_stats; } @@ -212,8 +242,11 @@ class FakeAudioTrackForStats : public MediaStreamTrack { void RemoveSink(webrtc::AudioTrackSinkInterface* sink) override {} bool GetSignalLevel(int* level) override { return false; } rtc::scoped_refptr GetAudioProcessor() override { - return nullptr; + return processor_; } + + private: + rtc::scoped_refptr processor_; }; class FakeVideoTrackSourceForStats : public VideoTrackSourceInterface { @@ -298,9 +331,11 @@ class FakeVideoTrackForStats : public MediaStreamTrack { rtc::scoped_refptr CreateFakeTrack( cricket::MediaType media_type, const std::string& track_id, - MediaStreamTrackInterface::TrackState track_state) { + MediaStreamTrackInterface::TrackState track_state, + bool create_fake_audio_processor = false) { if (media_type == cricket::MEDIA_TYPE_AUDIO) { - return FakeAudioTrackForStats::Create(track_id, track_state); + return FakeAudioTrackForStats::Create(track_id, track_state, + create_fake_audio_processor); } else { RTC_DCHECK_EQ(media_type, cricket::MEDIA_TYPE_VIDEO); return FakeVideoTrackForStats::Create(track_id, track_state, nullptr); @@ -331,6 +366,8 @@ rtc::scoped_refptr CreateMockSender( })); EXPECT_CALL(*sender, AttachmentId()).WillRepeatedly(Return(attachment_id)); EXPECT_CALL(*sender, stream_ids()).WillRepeatedly(Return(local_stream_ids)); + EXPECT_CALL(*sender, SetTransceiverAsStopped()); + EXPECT_CALL(*sender, Stop()); return sender; } @@ -357,6 +394,7 @@ rtc::scoped_refptr CreateMockReceiver( return params; })); EXPECT_CALL(*receiver, AttachmentId()).WillRepeatedly(Return(attachment_id)); + EXPECT_CALL(*receiver, StopAndEndTrack()); return receiver; } @@ -498,6 +536,7 @@ class RTCStatsCollectorWrapper { rtc::scoped_refptr(local_audio_track), voice_sender_info.local_stats[0].ssrc, voice_sender_info.local_stats[0].ssrc + 10, local_stream_ids); + EXPECT_CALL(*rtp_sender, SetMediaChannel(_)); pc_->AddSender(rtp_sender); } @@ -516,6 +555,7 @@ class RTCStatsCollectorWrapper { voice_receiver_info.local_stats[0].ssrc + 10); EXPECT_CALL(*rtp_receiver, streams()) .WillRepeatedly(Return(remote_streams)); + EXPECT_CALL(*rtp_receiver, SetMediaChannel(_)); pc_->AddReceiver(rtp_receiver); } @@ -533,6 +573,7 @@ class RTCStatsCollectorWrapper { 
rtc::scoped_refptr(local_video_track), video_sender_info.local_stats[0].ssrc, video_sender_info.local_stats[0].ssrc + 10, local_stream_ids); + EXPECT_CALL(*rtp_sender, SetMediaChannel(_)); pc_->AddSender(rtp_sender); } @@ -551,6 +592,7 @@ class RTCStatsCollectorWrapper { video_receiver_info.local_stats[0].ssrc + 10); EXPECT_CALL(*rtp_receiver, streams()) .WillRepeatedly(Return(remote_streams)); + EXPECT_CALL(*rtp_receiver, SetMediaChannel(_)); pc_->AddReceiver(rtp_receiver); } @@ -567,6 +609,11 @@ class RTCStatsCollectorWrapper { EXPECT_TRUE_WAIT(callback->report(), kGetStatsReportTimeoutMs); int64_t after = rtc::TimeUTCMicros(); for (const RTCStats& stats : *callback->report()) { + if (stats.type() == RTCRemoteInboundRtpStreamStats::kType || + stats.type() == RTCRemoteOutboundRtpStreamStats::kType) { + // Ignore remote timestamps. + continue; + } EXPECT_LE(stats.timestamp_us(), after); } return callback->report(); @@ -611,6 +658,7 @@ class RTCStatsCollectorTest : public ::testing::Test { std::string recv_codec_id; std::string outbound_rtp_id; std::string inbound_rtp_id; + std::string remote_outbound_rtp_id; std::string transport_id; std::string sender_track_id; std::string receiver_track_id; @@ -619,9 +667,9 @@ class RTCStatsCollectorTest : public ::testing::Test { std::string media_source_id; }; - // Sets up the example stats graph (see ASCII art below) used for testing the - // stats selection algorithm, - // https://w3c.github.io/webrtc-pc/#dfn-stats-selection-algorithm. + // Sets up the example stats graph (see ASCII art below) for a video only + // call. The graph is used for testing the stats selection algorithm (see + // https://w3c.github.io/webrtc-pc/#dfn-stats-selection-algorithm). // These tests test the integration of the stats traversal algorithm inside of // RTCStatsCollector. See rtcstatstraveral_unittest.cc for more stats // traversal tests. @@ -723,6 +771,125 @@ class RTCStatsCollectorTest : public ::testing::Test { return graph; } + // Sets up an example stats graph (see ASCII art below) for an audio only call + // and checks that the expected stats are generated. 
+ ExampleStatsGraph SetupExampleStatsVoiceGraph( + bool add_remote_outbound_stats) { + constexpr uint32_t kLocalSsrc = 3; + constexpr uint32_t kRemoteSsrc = 4; + ExampleStatsGraph graph; + + // codec (send) + graph.send_codec_id = "RTCCodec_VoiceMid_Outbound_1"; + cricket::VoiceMediaInfo media_info; + RtpCodecParameters send_codec; + send_codec.payload_type = 1; + send_codec.clock_rate = 0; + media_info.send_codecs.insert( + std::make_pair(send_codec.payload_type, send_codec)); + // codec (recv) + graph.recv_codec_id = "RTCCodec_VoiceMid_Inbound_2"; + RtpCodecParameters recv_codec; + recv_codec.payload_type = 2; + recv_codec.clock_rate = 0; + media_info.receive_codecs.insert( + std::make_pair(recv_codec.payload_type, recv_codec)); + // outbound-rtp + graph.outbound_rtp_id = "RTCOutboundRTPAudioStream_3"; + media_info.senders.push_back(cricket::VoiceSenderInfo()); + media_info.senders[0].local_stats.push_back(cricket::SsrcSenderInfo()); + media_info.senders[0].local_stats[0].ssrc = kLocalSsrc; + media_info.senders[0].codec_payload_type = send_codec.payload_type; + // inbound-rtp + graph.inbound_rtp_id = "RTCInboundRTPAudioStream_4"; + media_info.receivers.push_back(cricket::VoiceReceiverInfo()); + media_info.receivers[0].local_stats.push_back(cricket::SsrcReceiverInfo()); + media_info.receivers[0].local_stats[0].ssrc = kRemoteSsrc; + media_info.receivers[0].codec_payload_type = recv_codec.payload_type; + // remote-outbound-rtp + if (add_remote_outbound_stats) { + graph.remote_outbound_rtp_id = "RTCRemoteOutboundRTPAudioStream_4"; + media_info.receivers[0].last_sender_report_timestamp_ms = + kRemoteOutboundStatsTimestampMs; + media_info.receivers[0].last_sender_report_remote_timestamp_ms = + kRemoteOutboundStatsRemoteTimestampMs; + media_info.receivers[0].sender_reports_packets_sent = + kRemoteOutboundStatsPacketsSent; + media_info.receivers[0].sender_reports_bytes_sent = + kRemoteOutboundStatsBytesSent; + media_info.receivers[0].sender_reports_reports_count = + kRemoteOutboundStatsReportsCount; + } + + // transport + graph.transport_id = "RTCTransport_TransportName_1"; + auto* video_media_channel = + pc_->AddVoiceChannel("VoiceMid", "TransportName"); + video_media_channel->SetStats(media_info); + // track (sender) + graph.sender = stats_->SetupLocalTrackAndSender( + cricket::MEDIA_TYPE_AUDIO, "LocalAudioTrackID", kLocalSsrc, false, 50); + graph.sender_track_id = "RTCMediaStreamTrack_sender_" + + rtc::ToString(graph.sender->AttachmentId()); + // track (receiver) and stream (remote stream) + graph.receiver = stats_->SetupRemoteTrackAndReceiver( + cricket::MEDIA_TYPE_AUDIO, "RemoteAudioTrackID", "RemoteStreamId", + kRemoteSsrc); + graph.receiver_track_id = "RTCMediaStreamTrack_receiver_" + + rtc::ToString(graph.receiver->AttachmentId()); + graph.remote_stream_id = "RTCMediaStream_RemoteStreamId"; + // peer-connection + graph.peer_connection_id = "RTCPeerConnection"; + // media-source (kind: video) + graph.media_source_id = + "RTCAudioSource_" + rtc::ToString(graph.sender->AttachmentId()); + + // Expected stats graph: + // + // +--- track (sender) stream (remote stream) ---> track (receiver) + // | ^ ^ + // | | | + // | +--------- outbound-rtp inbound-rtp ---------------+ + // | | | | | | + // | | v v v v + // | | codec (send) transport codec (recv) peer-connection + // v v + // media-source + + // Verify the stats graph is set up correctly. + graph.full_report = stats_->GetStatsReport(); + EXPECT_EQ(graph.full_report->size(), add_remote_outbound_stats ? 
11u : 10u); + EXPECT_TRUE(graph.full_report->Get(graph.send_codec_id)); + EXPECT_TRUE(graph.full_report->Get(graph.recv_codec_id)); + EXPECT_TRUE(graph.full_report->Get(graph.outbound_rtp_id)); + EXPECT_TRUE(graph.full_report->Get(graph.inbound_rtp_id)); + EXPECT_TRUE(graph.full_report->Get(graph.transport_id)); + EXPECT_TRUE(graph.full_report->Get(graph.sender_track_id)); + EXPECT_TRUE(graph.full_report->Get(graph.receiver_track_id)); + EXPECT_TRUE(graph.full_report->Get(graph.remote_stream_id)); + EXPECT_TRUE(graph.full_report->Get(graph.peer_connection_id)); + EXPECT_TRUE(graph.full_report->Get(graph.media_source_id)); + // `graph.remote_outbound_rtp_id` is omitted on purpose so that expectations + // can be added by the caller depending on what value it sets for the + // `add_remote_outbound_stats` argument. + const auto& sender_track = graph.full_report->Get(graph.sender_track_id) + ->cast_to(); + EXPECT_EQ(*sender_track.media_source_id, graph.media_source_id); + const auto& outbound_rtp = graph.full_report->Get(graph.outbound_rtp_id) + ->cast_to(); + EXPECT_EQ(*outbound_rtp.media_source_id, graph.media_source_id); + EXPECT_EQ(*outbound_rtp.codec_id, graph.send_codec_id); + EXPECT_EQ(*outbound_rtp.track_id, graph.sender_track_id); + EXPECT_EQ(*outbound_rtp.transport_id, graph.transport_id); + const auto& inbound_rtp = graph.full_report->Get(graph.inbound_rtp_id) + ->cast_to(); + EXPECT_EQ(*inbound_rtp.codec_id, graph.recv_codec_id); + EXPECT_EQ(*inbound_rtp.track_id, graph.receiver_track_id); + EXPECT_EQ(*inbound_rtp.transport_id, graph.transport_id); + + return graph; + } + protected: rtc::ScopedFakeClock fake_clock_; rtc::scoped_refptr pc_; @@ -784,9 +951,14 @@ TEST_F(RTCStatsCollectorTest, ToJsonProducesParseableJson) { ExampleStatsGraph graph = SetupExampleStatsGraphForSelectorTests(); rtc::scoped_refptr report = stats_->GetStatsReport(); std::string json_format = report->ToJson(); - Json::Reader reader; + + Json::CharReaderBuilder builder; Json::Value json_value; - ASSERT_TRUE(reader.parse(json_format, json_value)); + std::unique_ptr reader(builder.newCharReader()); + ASSERT_TRUE(reader->parse(json_format.c_str(), + json_format.c_str() + json_format.size(), + &json_value, nullptr)); + // A very brief sanity check on the result. 
EXPECT_EQ(report->size(), json_value.size()); } @@ -1075,6 +1247,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { expected_a_local_host.transport_id = "RTCTransport_a_0"; expected_a_local_host.network_type = "vpn"; expected_a_local_host.ip = "1.2.3.4"; + expected_a_local_host.address = "1.2.3.4"; expected_a_local_host.port = 5; expected_a_local_host.protocol = "a_local_host's protocol"; expected_a_local_host.candidate_type = "host"; @@ -1088,11 +1261,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { "RTCIceCandidate_" + a_remote_srflx->id(), 0); expected_a_remote_srflx.transport_id = "RTCTransport_a_0"; expected_a_remote_srflx.ip = "6.7.8.9"; + expected_a_remote_srflx.address = "6.7.8.9"; expected_a_remote_srflx.port = 10; expected_a_remote_srflx.protocol = "remote_srflx's protocol"; expected_a_remote_srflx.candidate_type = "srflx"; expected_a_remote_srflx.priority = 1; - expected_a_remote_srflx.deleted = false; EXPECT_TRUE(*expected_a_remote_srflx.is_remote); std::unique_ptr a_local_prflx = CreateFakeCandidate( @@ -1103,11 +1276,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { expected_a_local_prflx.transport_id = "RTCTransport_a_0"; expected_a_local_prflx.network_type = "cellular"; expected_a_local_prflx.ip = "11.12.13.14"; + expected_a_local_prflx.address = "11.12.13.14"; expected_a_local_prflx.port = 15; expected_a_local_prflx.protocol = "a_local_prflx's protocol"; expected_a_local_prflx.candidate_type = "prflx"; expected_a_local_prflx.priority = 2; - expected_a_local_prflx.deleted = false; EXPECT_FALSE(*expected_a_local_prflx.is_remote); std::unique_ptr a_remote_relay = CreateFakeCandidate( @@ -1117,11 +1290,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { "RTCIceCandidate_" + a_remote_relay->id(), 0); expected_a_remote_relay.transport_id = "RTCTransport_a_0"; expected_a_remote_relay.ip = "16.17.18.19"; + expected_a_remote_relay.address = "16.17.18.19"; expected_a_remote_relay.port = 20; expected_a_remote_relay.protocol = "a_remote_relay's protocol"; expected_a_remote_relay.candidate_type = "relay"; expected_a_remote_relay.priority = 3; - expected_a_remote_relay.deleted = false; EXPECT_TRUE(*expected_a_remote_relay.is_remote); std::unique_ptr a_local_relay = CreateFakeCandidate( @@ -1133,12 +1306,12 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { "RTCIceCandidate_" + a_local_relay->id(), 0); expected_a_local_relay.transport_id = "RTCTransport_a_0"; expected_a_local_relay.ip = "16.17.18.19"; + expected_a_local_relay.address = "16.17.18.19"; expected_a_local_relay.port = 21; expected_a_local_relay.protocol = "a_local_relay's protocol"; expected_a_local_relay.relay_protocol = "tcp"; expected_a_local_relay.candidate_type = "relay"; expected_a_local_relay.priority = 1; - expected_a_local_relay.deleted = false; EXPECT_TRUE(*expected_a_local_relay.is_remote); // Candidates in the second transport stats. 
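The ToJsonProducesParseableJson change above moves from the deprecated Json::Reader to Json::CharReaderBuilder. For reference, a small standalone jsoncpp example in that newer style, assuming jsoncpp is installed with its headers reachable as <json/json.h>; like the test, it passes nullptr for the error string to keep things minimal.

#include <json/json.h>

#include <iostream>
#include <memory>
#include <string>

int main() {
  const std::string json_format =
      R"({"id": "RTCPeerConnection", "type": "peer-connection"})";

  // Non-deprecated parsing path: build a CharReader, then parse a char range.
  Json::CharReaderBuilder builder;
  Json::Value json_value;
  std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
  if (!reader->parse(json_format.c_str(),
                     json_format.c_str() + json_format.size(), &json_value,
                     /*errs=*/nullptr)) {
    std::cerr << "parse failed\n";
    return 1;
  }
  std::cout << json_value["type"].asString() << "\n";  // "peer-connection"
  return 0;
}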
@@ -1150,11 +1323,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { expected_b_local.transport_id = "RTCTransport_b_0"; expected_b_local.network_type = "wifi"; expected_b_local.ip = "42.42.42.42"; + expected_b_local.address = "42.42.42.42"; expected_b_local.port = 42; expected_b_local.protocol = "b_local's protocol"; expected_b_local.candidate_type = "host"; expected_b_local.priority = 42; - expected_b_local.deleted = false; EXPECT_FALSE(*expected_b_local.is_remote); std::unique_ptr b_remote = CreateFakeCandidate( @@ -1164,11 +1337,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidateStats) { "RTCIceCandidate_" + b_remote->id(), 0); expected_b_remote.transport_id = "RTCTransport_b_0"; expected_b_remote.ip = "42.42.42.42"; + expected_b_remote.address = "42.42.42.42"; expected_b_remote.port = 42; expected_b_remote.protocol = "b_remote's protocol"; expected_b_remote.candidate_type = "host"; expected_b_remote.priority = 42; - expected_b_remote.deleted = false; EXPECT_TRUE(*expected_b_remote.is_remote); // Add candidate pairs to connection. @@ -1365,11 +1538,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidatePairStats) { expected_local_candidate.transport_id = *expected_pair.transport_id; expected_local_candidate.network_type = "wifi"; expected_local_candidate.ip = "42.42.42.42"; + expected_local_candidate.address = "42.42.42.42"; expected_local_candidate.port = 42; expected_local_candidate.protocol = "protocol"; expected_local_candidate.candidate_type = "host"; expected_local_candidate.priority = 42; - expected_local_candidate.deleted = false; EXPECT_FALSE(*expected_local_candidate.is_remote); ASSERT_TRUE(report->Get(expected_local_candidate.id())); EXPECT_EQ(expected_local_candidate, @@ -1380,11 +1553,11 @@ TEST_F(RTCStatsCollectorTest, CollectRTCIceCandidatePairStats) { *expected_pair.remote_candidate_id, report->timestamp_us()); expected_remote_candidate.transport_id = *expected_pair.transport_id; expected_remote_candidate.ip = "42.42.42.42"; + expected_remote_candidate.address = "42.42.42.42"; expected_remote_candidate.port = 42; expected_remote_candidate.protocol = "protocol"; expected_remote_candidate.candidate_type = "host"; expected_remote_candidate.priority = 42; - expected_remote_candidate.deleted = false; EXPECT_TRUE(*expected_remote_candidate.is_remote); ASSERT_TRUE(report->Get(expected_remote_candidate.id())); EXPECT_EQ(expected_remote_candidate, @@ -1564,7 +1737,7 @@ TEST_F(RTCStatsCollectorTest, voice_receiver_info.inserted_samples_for_deceleration = 987; voice_receiver_info.removed_samples_for_acceleration = 876; voice_receiver_info.silent_concealed_samples = 765; - voice_receiver_info.jitter_buffer_delay_seconds = 3456; + voice_receiver_info.jitter_buffer_delay_seconds = 3.456; voice_receiver_info.jitter_buffer_emitted_count = 13; voice_receiver_info.jitter_buffer_target_delay_seconds = 7.894; voice_receiver_info.jitter_buffer_flushes = 7; @@ -1609,7 +1782,7 @@ TEST_F(RTCStatsCollectorTest, expected_remote_audio_track.inserted_samples_for_deceleration = 987; expected_remote_audio_track.removed_samples_for_acceleration = 876; expected_remote_audio_track.silent_concealed_samples = 765; - expected_remote_audio_track.jitter_buffer_delay = 3456; + expected_remote_audio_track.jitter_buffer_delay = 3.456; expected_remote_audio_track.jitter_buffer_emitted_count = 13; expected_remote_audio_track.jitter_buffer_target_delay = 7.894; expected_remote_audio_track.jitter_buffer_flushes = 7; @@ -1781,6 +1954,7 @@ TEST_F(RTCStatsCollectorTest, 
CollectRTCInboundRTPStreamStats_Audio) { voice_media_info.receivers[0].local_stats[0].ssrc = 1; voice_media_info.receivers[0].packets_lost = -1; // Signed per RFC3550 voice_media_info.receivers[0].packets_rcvd = 2; + voice_media_info.receivers[0].nacks_sent = 5; voice_media_info.receivers[0].fec_packets_discarded = 5566; voice_media_info.receivers[0].fec_packets_received = 6677; voice_media_info.receivers[0].payload_bytes_rcvd = 3; @@ -1823,13 +1997,13 @@ TEST_F(RTCStatsCollectorTest, CollectRTCInboundRTPStreamStats_Audio) { RTCInboundRTPStreamStats expected_audio("RTCInboundRTPAudioStream_1", report->timestamp_us()); expected_audio.ssrc = 1; - expected_audio.is_remote = false; expected_audio.media_type = "audio"; expected_audio.kind = "audio"; expected_audio.track_id = stats_of_track_type[0]->id(); expected_audio.transport_id = "RTCTransport_TransportName_1"; expected_audio.codec_id = "RTCCodec_AudioMid_Inbound_42"; expected_audio.packets_received = 2; + expected_audio.nack_count = 5; expected_audio.fec_packets_discarded = 5566; expected_audio.fec_packets_received = 6677; expected_audio.bytes_received = 3; @@ -1856,7 +2030,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCInboundRTPStreamStats_Audio) { // Set previously undefined values and "GetStats" again. voice_media_info.receivers[0].last_packet_received_timestamp_ms = 3000; - expected_audio.last_packet_received_timestamp = 3.0; + expected_audio.last_packet_received_timestamp = 3000.0; voice_media_info.receivers[0].estimated_playout_ntp_timestamp_ms = 4567; expected_audio.estimated_playout_timestamp = 4567; voice_media_channel->SetStats(voice_media_info); @@ -1895,6 +2069,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCInboundRTPStreamStats_Video) { video_media_info.receivers[0].total_decode_time_ms = 9000; video_media_info.receivers[0].total_inter_frame_delay = 0.123; video_media_info.receivers[0].total_squared_inter_frame_delay = 0.00456; + video_media_info.receivers[0].jitter_ms = 1199; + video_media_info.receivers[0].jitter_buffer_delay_seconds = 3.456; + video_media_info.receivers[0].jitter_buffer_emitted_count = 13; video_media_info.receivers[0].last_packet_received_timestamp_ms = absl::nullopt; @@ -1921,7 +2098,6 @@ TEST_F(RTCStatsCollectorTest, CollectRTCInboundRTPStreamStats_Video) { RTCInboundRTPStreamStats expected_video("RTCInboundRTPVideoStream_1", report->timestamp_us()); expected_video.ssrc = 1; - expected_video.is_remote = false; expected_video.media_type = "video"; expected_video.kind = "video"; expected_video.track_id = IdForType(report); @@ -1942,6 +2118,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCInboundRTPStreamStats_Video) { expected_video.total_decode_time = 9.0; expected_video.total_inter_frame_delay = 0.123; expected_video.total_squared_inter_frame_delay = 0.00456; + expected_video.jitter = 1.199; + expected_video.jitter_buffer_delay = 3.456; + expected_video.jitter_buffer_emitted_count = 13; // |expected_video.last_packet_received_timestamp| should be undefined. // |expected_video.content_type| should be undefined. // |expected_video.decoder_implementation| should be undefined. 
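Several of the updated expectations above come down to unit handling: jitter_ms (1199) and the jitter buffer delay feed stats expressed in seconds, while last_packet_received_timestamp stays in milliseconds, hence the expected value changing from 3.0 to 3000.0. A tiny illustrative conversion using the same example numbers (nothing here is a WebRTC API, just the arithmetic behind the expectations):

#include <cstdint>
#include <iostream>

int main() {
  // Values mirrored from the test expectations above.
  const int64_t jitter_ms = 1199;
  const double jitter_seconds = jitter_ms / 1000.0;  // 1.199

  const int64_t last_packet_received_timestamp_ms = 3000;
  // This stat is reported in milliseconds, so the expected value is 3000.0,
  // not 3.0.
  const double last_packet_received_timestamp =
      static_cast<double>(last_packet_received_timestamp_ms);

  std::cout << jitter_seconds << " " << last_packet_received_timestamp << "\n";
  return 0;
}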
@@ -1955,7 +2134,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCInboundRTPStreamStats_Video) { video_media_info.receivers[0].qp_sum = 9; expected_video.qp_sum = 9; video_media_info.receivers[0].last_packet_received_timestamp_ms = 1000; - expected_video.last_packet_received_timestamp = 1.0; + expected_video.last_packet_received_timestamp = 1000.0; video_media_info.receivers[0].content_type = VideoContentType::SCREENSHARE; expected_video.content_type = "screenshare"; video_media_info.receivers[0].estimated_playout_ntp_timestamp_ms = 1234; @@ -1986,6 +2165,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCOutboundRTPStreamStats_Audio) { voice_media_info.senders[0].payload_bytes_sent = 3; voice_media_info.senders[0].header_and_padding_bytes_sent = 12; voice_media_info.senders[0].retransmitted_bytes_sent = 30; + voice_media_info.senders[0].nacks_rcvd = 31; voice_media_info.senders[0].codec_payload_type = 42; RtpCodecParameters codec_parameters; @@ -2009,7 +2189,6 @@ TEST_F(RTCStatsCollectorTest, CollectRTCOutboundRTPStreamStats_Audio) { expected_audio.media_source_id = "RTCAudioSource_50"; // |expected_audio.remote_id| should be undefined. expected_audio.ssrc = 1; - expected_audio.is_remote = false; expected_audio.media_type = "audio"; expected_audio.kind = "audio"; expected_audio.track_id = IdForType(report); @@ -2020,6 +2199,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCOutboundRTPStreamStats_Audio) { expected_audio.bytes_sent = 3; expected_audio.header_bytes_sent = 12; expected_audio.retransmitted_bytes_sent = 30; + expected_audio.nack_count = 31; ASSERT_TRUE(report->Get(expected_audio.id())); EXPECT_EQ( @@ -2057,6 +2237,8 @@ TEST_F(RTCStatsCollectorTest, CollectRTCOutboundRTPStreamStats_Video) { video_media_info.senders[0].total_packet_send_delay_ms = 10000; video_media_info.senders[0].quality_limitation_reason = QualityLimitationReason::kBandwidth; + video_media_info.senders[0].quality_limitation_durations_ms + [webrtc::QualityLimitationReason::kBandwidth] = 300; video_media_info.senders[0].quality_limitation_resolution_changes = 56u; video_media_info.senders[0].qp_sum = absl::nullopt; video_media_info.senders[0].content_type = VideoContentType::UNSPECIFIED; @@ -2093,7 +2275,6 @@ TEST_F(RTCStatsCollectorTest, CollectRTCOutboundRTPStreamStats_Video) { expected_video.media_source_id = "RTCVideoSource_50"; // |expected_video.remote_id| should be undefined. 
expected_video.ssrc = 1; - expected_video.is_remote = false; expected_video.media_type = "video"; expected_video.kind = "video"; expected_video.track_id = stats_of_track_type[0]->id(); @@ -2113,6 +2294,9 @@ TEST_F(RTCStatsCollectorTest, CollectRTCOutboundRTPStreamStats_Video) { expected_video.total_encoded_bytes_target = 1234; expected_video.total_packet_send_delay = 10.0; expected_video.quality_limitation_reason = "bandwidth"; + expected_video.quality_limitation_durations = std::map{ + std::pair{"bandwidth", 300.0}, + }; expected_video.quality_limitation_resolution_changes = 56u; expected_video.frame_width = 200u; expected_video.frame_height = 100u; @@ -2182,7 +2366,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCTransportStats) { rtp_transport_channel_stats.component = cricket::ICE_CANDIDATE_COMPONENT_RTP; rtp_transport_channel_stats.ice_transport_stats.connection_infos.push_back( rtp_connection_info); - rtp_transport_channel_stats.dtls_state = cricket::DTLS_TRANSPORT_NEW; + rtp_transport_channel_stats.dtls_state = DtlsTransportState::kNew; rtp_transport_channel_stats.ice_transport_stats .selected_candidate_pair_changes = 1; pc_->SetTransportStats(kTransportName, {rtp_transport_channel_stats}); @@ -2220,7 +2404,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCTransportStats) { cricket::ICE_CANDIDATE_COMPONENT_RTCP; rtcp_transport_channel_stats.ice_transport_stats.connection_infos.push_back( rtcp_connection_info); - rtcp_transport_channel_stats.dtls_state = cricket::DTLS_TRANSPORT_CONNECTING; + rtcp_transport_channel_stats.dtls_state = DtlsTransportState::kConnecting; pc_->SetTransportStats(kTransportName, {rtp_transport_channel_stats, rtcp_transport_channel_stats}); @@ -2336,7 +2520,7 @@ TEST_F(RTCStatsCollectorTest, CollectRTCTransportStatsWithCrypto) { rtp_transport_channel_stats.ice_transport_stats.connection_infos.push_back( rtp_connection_info); // The state must be connected in order for crypto parameters to show up. 
- rtp_transport_channel_stats.dtls_state = cricket::DTLS_TRANSPORT_CONNECTED; + rtp_transport_channel_stats.dtls_state = DtlsTransportState::kConnected; rtp_transport_channel_stats.ice_transport_stats .selected_candidate_pair_changes = 1; rtp_transport_channel_stats.ssl_version_bytes = 0x0203; @@ -2380,6 +2564,7 @@ TEST_F(RTCStatsCollectorTest, CollectNoStreamRTCOutboundRTPStreamStats_Audio) { voice_media_info.senders[0].payload_bytes_sent = 3; voice_media_info.senders[0].header_and_padding_bytes_sent = 4; voice_media_info.senders[0].retransmitted_bytes_sent = 30; + voice_media_info.senders[0].nacks_rcvd = 31; voice_media_info.senders[0].codec_payload_type = 42; RtpCodecParameters codec_parameters; @@ -2403,7 +2588,6 @@ TEST_F(RTCStatsCollectorTest, CollectNoStreamRTCOutboundRTPStreamStats_Audio) { report->timestamp_us()); expected_audio.media_source_id = "RTCAudioSource_50"; expected_audio.ssrc = 1; - expected_audio.is_remote = false; expected_audio.media_type = "audio"; expected_audio.kind = "audio"; expected_audio.track_id = IdForType(report); @@ -2414,6 +2598,7 @@ TEST_F(RTCStatsCollectorTest, CollectNoStreamRTCOutboundRTPStreamStats_Audio) { expected_audio.bytes_sent = 3; expected_audio.header_bytes_sent = 4; expected_audio.retransmitted_bytes_sent = 30; + expected_audio.nack_count = 31; ASSERT_TRUE(report->Get(expected_audio.id())); EXPECT_EQ( @@ -2435,6 +2620,9 @@ TEST_F(RTCStatsCollectorTest, RTCAudioSourceStatsCollectedForSenderWithTrack) { voice_media_info.senders[0].audio_level = 32767; // [0,32767] voice_media_info.senders[0].total_input_energy = 2.0; voice_media_info.senders[0].total_input_duration = 3.0; + voice_media_info.senders[0].apm_statistics.echo_return_loss = 42.0; + voice_media_info.senders[0].apm_statistics.echo_return_loss_enhancement = + 52.0; auto* voice_media_channel = pc_->AddVoiceChannel("AudioMid", "TransportName"); voice_media_channel->SetStats(voice_media_info); stats_->SetupLocalTrackAndSender(cricket::MEDIA_TYPE_AUDIO, @@ -2450,6 +2638,8 @@ TEST_F(RTCStatsCollectorTest, RTCAudioSourceStatsCollectedForSenderWithTrack) { expected_audio.audio_level = 1.0; // [0,1] expected_audio.total_audio_energy = 2.0; expected_audio.total_samples_duration = 3.0; + expected_audio.echo_return_loss = 42.0; + expected_audio.echo_return_loss_enhancement = 52.0; ASSERT_TRUE(report->Get(expected_audio.id())); EXPECT_EQ(report->Get(expected_audio.id())->cast_to(), @@ -2472,6 +2662,7 @@ TEST_F(RTCStatsCollectorTest, RTCVideoSourceStatsCollectedForSenderWithTrack) { cricket::SsrcSenderInfo()); video_media_info.aggregated_senders[0].local_stats[0].ssrc = kSsrc; video_media_info.aggregated_senders[0].framerate_input = 29; + video_media_info.aggregated_senders[0].frames = 10001; auto* video_media_channel = pc_->AddVideoChannel("VideoMid", "TransportName"); video_media_channel->SetStats(video_media_info); @@ -2491,9 +2682,8 @@ TEST_F(RTCStatsCollectorTest, RTCVideoSourceStatsCollectedForSenderWithTrack) { expected_video.kind = "video"; expected_video.width = kVideoSourceWidth; expected_video.height = kVideoSourceHeight; - // |expected_video.frames| is expected to be undefined because it is not set. - // TODO(hbos): When implemented, set its expected value here. 
expected_video.frames_per_second = 29; + expected_video.frames = 10001; ASSERT_TRUE(report->Get(expected_video.id())); EXPECT_EQ(report->Get(expected_video.id())->cast_to(), @@ -2533,6 +2723,7 @@ TEST_F(RTCStatsCollectorTest, auto video_stats = report->Get("RTCVideoSource_42")->cast_to(); EXPECT_FALSE(video_stats.frames_per_second.is_defined()); + EXPECT_FALSE(video_stats.frames.is_defined()); } // The track not having a source is not expected to be true in practise, but @@ -2674,8 +2865,11 @@ class RTCStatsCollectorTestWithParamKind TEST_P(RTCStatsCollectorTestWithParamKind, RTCRemoteInboundRtpStreamStatsCollectedFromReportBlock) { const int64_t kReportBlockTimestampUtcUs = 123456789; - const int64_t kRoundTripTimeMs = 13000; - const double kRoundTripTimeSeconds = 13.0; + const uint8_t kFractionLost = 12; + const int64_t kRoundTripTimeSample1Ms = 1234; + const double kRoundTripTimeSample1Seconds = 1.234; + const int64_t kRoundTripTimeSample2Ms = 13000; + const double kRoundTripTimeSample2Seconds = 13; // The report block's timestamp cannot be from the future, set the fake clock // to match. @@ -2688,12 +2882,13 @@ TEST_P(RTCStatsCollectorTestWithParamKind, // |source_ssrc|, "SSRC of the RTP packet sender". report_block.source_ssrc = ssrc; report_block.packets_lost = 7; + report_block.fraction_lost = kFractionLost; ReportBlockData report_block_data; report_block_data.SetReportBlock(report_block, kReportBlockTimestampUtcUs); - report_block_data.AddRoundTripTimeSample(1234); + report_block_data.AddRoundTripTimeSample(kRoundTripTimeSample1Ms); // Only the last sample should be exposed as the // |RTCRemoteInboundRtpStreamStats::round_trip_time|. - report_block_data.AddRoundTripTimeSample(kRoundTripTimeMs); + report_block_data.AddRoundTripTimeSample(kRoundTripTimeSample2Ms); report_block_datas.push_back(report_block_data); } AddSenderInfoAndMediaChannel("TransportName", report_block_datas, @@ -2706,6 +2901,8 @@ TEST_P(RTCStatsCollectorTestWithParamKind, "RTCRemoteInboundRtp" + MediaTypeUpperCase() + stream_id, kReportBlockTimestampUtcUs); expected_remote_inbound_rtp.ssrc = ssrc; + expected_remote_inbound_rtp.fraction_lost = + static_cast(kFractionLost) / (1 << 8); expected_remote_inbound_rtp.kind = MediaTypeLowerCase(); expected_remote_inbound_rtp.transport_id = "RTCTransport_TransportName_1"; // 1 for RTP (we have no RTCP @@ -2713,7 +2910,10 @@ TEST_P(RTCStatsCollectorTestWithParamKind, expected_remote_inbound_rtp.packets_lost = 7; expected_remote_inbound_rtp.local_id = "RTCOutboundRTP" + MediaTypeUpperCase() + stream_id; - expected_remote_inbound_rtp.round_trip_time = kRoundTripTimeSeconds; + expected_remote_inbound_rtp.round_trip_time = kRoundTripTimeSample2Seconds; + expected_remote_inbound_rtp.total_round_trip_time = + kRoundTripTimeSample1Seconds + kRoundTripTimeSample2Seconds; + expected_remote_inbound_rtp.round_trip_time_measurements = 2; // This test does not set up RTCCodecStats, so |codec_id| and |jitter| are // expected to be missing. These are tested separately. 
@@ -2814,11 +3014,11 @@ TEST_P(RTCStatsCollectorTestWithParamKind, cricket::TransportChannelStats rtp_transport_channel_stats; rtp_transport_channel_stats.component = cricket::ICE_CANDIDATE_COMPONENT_RTP; - rtp_transport_channel_stats.dtls_state = cricket::DTLS_TRANSPORT_NEW; + rtp_transport_channel_stats.dtls_state = DtlsTransportState::kNew; cricket::TransportChannelStats rtcp_transport_channel_stats; rtcp_transport_channel_stats.component = cricket::ICE_CANDIDATE_COMPONENT_RTCP; - rtcp_transport_channel_stats.dtls_state = cricket::DTLS_TRANSPORT_NEW; + rtcp_transport_channel_stats.dtls_state = DtlsTransportState::kNew; pc_->SetTransportStats("TransportName", {rtp_transport_channel_stats, rtcp_transport_channel_stats}); AddSenderInfoAndMediaChannel("TransportName", {report_block_data}, @@ -2843,6 +3043,43 @@ INSTANTIATE_TEST_SUITE_P(All, ::testing::Values(cricket::MEDIA_TYPE_AUDIO, // "/0" cricket::MEDIA_TYPE_VIDEO)); // "/1" +// Checks that no remote outbound stats are collected if not available in +// `VoiceMediaInfo`. +TEST_F(RTCStatsCollectorTest, + RTCRemoteOutboundRtpAudioStreamStatsNotCollected) { + ExampleStatsGraph graph = + SetupExampleStatsVoiceGraph(/*add_remote_outbound_stats=*/false); + EXPECT_FALSE(graph.full_report->Get(graph.remote_outbound_rtp_id)); + // Also check that no other remote outbound report is created (in case the + // expected ID is incorrect). + rtc::scoped_refptr report = stats_->GetStatsReport(); + ASSERT_NE(report->begin(), report->end()) + << "No reports have been generated."; + for (const auto& stats : *report) { + SCOPED_TRACE(stats.id()); + EXPECT_NE(stats.type(), RTCRemoteOutboundRtpStreamStats::kType); + } +} + +// Checks that the remote outbound stats are collected when available in +// `VoiceMediaInfo`. +TEST_F(RTCStatsCollectorTest, RTCRemoteOutboundRtpAudioStreamStatsCollected) { + ExampleStatsGraph graph = + SetupExampleStatsVoiceGraph(/*add_remote_outbound_stats=*/true); + ASSERT_TRUE(graph.full_report->Get(graph.remote_outbound_rtp_id)); + const auto& remote_outbound_rtp = + graph.full_report->Get(graph.remote_outbound_rtp_id) + ->cast_to(); + EXPECT_EQ(remote_outbound_rtp.timestamp_us(), + kRemoteOutboundStatsTimestampMs * rtc::kNumMicrosecsPerMillisec); + EXPECT_FLOAT_EQ(*remote_outbound_rtp.remote_timestamp, + static_cast(kRemoteOutboundStatsRemoteTimestampMs)); + EXPECT_EQ(*remote_outbound_rtp.packets_sent, kRemoteOutboundStatsPacketsSent); + EXPECT_EQ(*remote_outbound_rtp.bytes_sent, kRemoteOutboundStatsBytesSent); + EXPECT_EQ(*remote_outbound_rtp.reports_sent, + kRemoteOutboundStatsReportsCount); +} + TEST_F(RTCStatsCollectorTest, RTCVideoSourceStatsNotCollectedForSenderWithoutTrack) { const uint32_t kSsrc = 4; @@ -2864,6 +3101,64 @@ TEST_F(RTCStatsCollectorTest, EXPECT_FALSE(report->Get("RTCVideoSource_42")); } +// Test collecting echo return loss stats from the audio processor attached to +// the track, rather than the voice sender info. 
+TEST_F(RTCStatsCollectorTest, CollectEchoReturnLossFromTrackAudioProcessor) { + rtc::scoped_refptr local_stream = + MediaStream::Create("LocalStreamId"); + pc_->mutable_local_streams()->AddStream(local_stream); + + // Local audio track + rtc::scoped_refptr local_audio_track = + CreateFakeTrack(cricket::MEDIA_TYPE_AUDIO, "LocalAudioTrackID", + MediaStreamTrackInterface::kEnded, + /*create_fake_audio_processor=*/true); + local_stream->AddTrack( + static_cast(local_audio_track.get())); + + cricket::VoiceSenderInfo voice_sender_info_ssrc1; + voice_sender_info_ssrc1.local_stats.push_back(cricket::SsrcSenderInfo()); + voice_sender_info_ssrc1.local_stats[0].ssrc = 1; + + stats_->CreateMockRtpSendersReceiversAndChannels( + {std::make_pair(local_audio_track.get(), voice_sender_info_ssrc1)}, {}, + {}, {}, {local_stream->id()}, {}); + + rtc::scoped_refptr report = stats_->GetStatsReport(); + + RTCMediaStreamTrackStats expected_local_audio_track_ssrc1( + IdForType(report), report->timestamp_us(), + RTCMediaStreamTrackKind::kAudio); + expected_local_audio_track_ssrc1.track_identifier = local_audio_track->id(); + expected_local_audio_track_ssrc1.media_source_id = + "RTCAudioSource_11"; // Attachment ID = SSRC + 10 + expected_local_audio_track_ssrc1.remote_source = false; + expected_local_audio_track_ssrc1.ended = true; + expected_local_audio_track_ssrc1.detached = false; + expected_local_audio_track_ssrc1.echo_return_loss = 2.0; + expected_local_audio_track_ssrc1.echo_return_loss_enhancement = 3.0; + ASSERT_TRUE(report->Get(expected_local_audio_track_ssrc1.id())) + << "Did not find " << expected_local_audio_track_ssrc1.id() << " in " + << report->ToJson(); + EXPECT_EQ(expected_local_audio_track_ssrc1, + report->Get(expected_local_audio_track_ssrc1.id()) + ->cast_to()); + + RTCAudioSourceStats expected_audio("RTCAudioSource_11", + report->timestamp_us()); + expected_audio.track_identifier = "LocalAudioTrackID"; + expected_audio.kind = "audio"; + expected_audio.audio_level = 0; + expected_audio.total_audio_energy = 0; + expected_audio.total_samples_duration = 0; + expected_audio.echo_return_loss = 2.0; + expected_audio.echo_return_loss_enhancement = 3.0; + + ASSERT_TRUE(report->Get(expected_audio.id())); + EXPECT_EQ(report->Get(expected_audio.id())->cast_to(), + expected_audio); +} + TEST_F(RTCStatsCollectorTest, GetStatsWithSenderSelector) { ExampleStatsGraph graph = SetupExampleStatsGraphForSelectorTests(); // Expected stats graph when filtered by sender: @@ -3029,11 +3324,20 @@ class FakeRTCStatsCollector : public RTCStatsCollector, static rtc::scoped_refptr Create( PeerConnectionInternal* pc, int64_t cache_lifetime_us) { - return rtc::scoped_refptr( - new rtc::RefCountedObject(pc, - cache_lifetime_us)); + return new rtc::RefCountedObject(pc, + cache_lifetime_us); } + // Since FakeRTCStatsCollector inherits twice from RefCountInterface, once via + // RTCStatsCollector and once via RTCStatsCollectorCallback, scoped_refptr + // will get confused about which AddRef()/Release() methods to call. + // So to remove all doubt, we declare them here again in the class that we + // give to scoped_refptr. + // Satisfying the implementation of these methods and associating them with a + // reference counter, will be done by RefCountedObject. + virtual void AddRef() const = 0; + virtual rtc::RefCountReleaseStatus Release() const = 0; + // RTCStatsCollectorCallback implementation. 
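The comment block above explains why FakeRTCStatsCollector re-declares AddRef()/Release(): it reaches a ref-count interface through two base classes, so the name would otherwise be ambiguous at the call site. A self-contained sketch of that situation with made-up types (none of these are WebRTC classes; the concrete subclass plays the role RefCountedObject<> plays in the real code):

#include <iostream>

// Two independent interfaces that each expose ref-counting, standing in for a
// ref-count interface being reachable through two base classes.
class InterfaceA {
 public:
  virtual void AddRef() const = 0;
  virtual void Release() const = 0;
  virtual ~InterfaceA() = default;
};

class InterfaceB {
 public:
  virtual void AddRef() const = 0;
  virtual void Release() const = 0;
  virtual ~InterfaceB() = default;
};

// The combined class redeclares AddRef()/Release(). Without this, name lookup
// for AddRef() through a FakeCombined pointer finds two candidates (one from
// each interface) and fails as ambiguous; the single declaration here
// overrides the pure virtuals from both bases and gives callers one name.
class FakeCombined : public InterfaceA, public InterfaceB {
 public:
  void AddRef() const override = 0;
  void Release() const override = 0;
};

// Concrete implementation supplying the actual counter.
class RefCountedFake final : public FakeCombined {
 public:
  void AddRef() const override { ++count_; }
  void Release() const override {
    if (--count_ == 0) delete this;
  }

 private:
  mutable int count_ = 0;
};

int main() {
  FakeCombined* fake = new RefCountedFake();
  fake->AddRef();   // Unambiguous thanks to the redeclaration in FakeCombined.
  fake->Release();  // Count drops to zero and the object deletes itself.
  return 0;
}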
void OnStatsDelivered( const rtc::scoped_refptr& report) override { diff --git a/pc/rtc_stats_integrationtest.cc b/pc/rtc_stats_integrationtest.cc index ee68ec9a0b..2dfe1b5cd5 100644 --- a/pc/rtc_stats_integrationtest.cc +++ b/pc/rtc_stats_integrationtest.cc @@ -114,9 +114,9 @@ class RTCStatsIntegrationTest : public ::testing::Test { RTC_CHECK(network_thread_->Start()); RTC_CHECK(worker_thread_->Start()); - caller_ = new rtc::RefCountedObject( + caller_ = rtc::make_ref_counted( "caller", network_thread_.get(), worker_thread_.get()); - callee_ = new rtc::RefCountedObject( + callee_ = rtc::make_ref_counted( "callee", network_thread_.get(), worker_thread_.get()); } @@ -399,6 +399,9 @@ class RTCStatsReportVerifier { } else if (stats.type() == RTCRemoteInboundRtpStreamStats::kType) { verify_successful &= VerifyRTCRemoteInboundRtpStreamStats( stats.cast_to()); + } else if (stats.type() == RTCRemoteOutboundRtpStreamStats::kType) { + verify_successful &= VerifyRTCRemoteOutboundRTPStreamStats( + stats.cast_to()); } else if (stats.type() == RTCAudioSourceStats::kType) { // RTCAudioSourceStats::kType and RTCVideoSourceStats::kType both have // the value "media-source", but they are distinguishable with pointer @@ -528,12 +531,12 @@ class RTCStatsReportVerifier { verifier.TestMemberIsDefined(candidate.network_type); } verifier.TestMemberIsDefined(candidate.ip); + verifier.TestMemberIsDefined(candidate.address); verifier.TestMemberIsNonNegative(candidate.port); verifier.TestMemberIsDefined(candidate.protocol); verifier.TestMemberIsDefined(candidate.candidate_type); verifier.TestMemberIsNonNegative(candidate.priority); verifier.TestMemberIsUndefined(candidate.url); - verifier.TestMemberIsDefined(candidate.deleted); verifier.TestMemberIsUndefined(candidate.relay_protocol); return verifier.ExpectAllMembersSuccessfullyTested(); } @@ -768,32 +771,38 @@ class RTCStatsReportVerifier { } void VerifyRTCRTPStreamStats(const RTCRTPStreamStats& stream, - RTCStatsVerifier* verifier) { - verifier->TestMemberIsDefined(stream.ssrc); - verifier->TestMemberIsDefined(stream.is_remote); - verifier->TestMemberIsDefined(stream.media_type); - verifier->TestMemberIsDefined(stream.kind); - verifier->TestMemberIsIDReference(stream.track_id, - RTCMediaStreamTrackStats::kType); - verifier->TestMemberIsIDReference(stream.transport_id, - RTCTransportStats::kType); - verifier->TestMemberIsIDReference(stream.codec_id, RTCCodecStats::kType); - if (stream.media_type.is_defined() && *stream.media_type == "video") { - verifier->TestMemberIsNonNegative(stream.fir_count); - verifier->TestMemberIsNonNegative(stream.pli_count); - verifier->TestMemberIsNonNegative(stream.nack_count); + RTCStatsVerifier& verifier) { + verifier.TestMemberIsDefined(stream.ssrc); + verifier.TestMemberIsDefined(stream.kind); + // Some legacy metrics are only defined for some of the RTP types in the + // hierarcy. 
+ if (stream.type() == RTCInboundRTPStreamStats::kType || + stream.type() == RTCOutboundRTPStreamStats::kType) { + verifier.TestMemberIsDefined(stream.media_type); + verifier.TestMemberIsIDReference(stream.track_id, + RTCMediaStreamTrackStats::kType); } else { - verifier->TestMemberIsUndefined(stream.fir_count); - verifier->TestMemberIsUndefined(stream.pli_count); - verifier->TestMemberIsUndefined(stream.nack_count); + verifier.TestMemberIsUndefined(stream.media_type); + verifier.TestMemberIsUndefined(stream.track_id); } - verifier->TestMemberIsUndefined(stream.sli_count); + verifier.TestMemberIsIDReference(stream.transport_id, + RTCTransportStats::kType); + verifier.TestMemberIsIDReference(stream.codec_id, RTCCodecStats::kType); + } + + void VerifyRTCSentRTPStreamStats(const RTCSentRtpStreamStats& sent_stream, + RTCStatsVerifier& verifier) { + VerifyRTCRTPStreamStats(sent_stream, verifier); + verifier.TestMemberIsDefined(sent_stream.packets_sent); + verifier.TestMemberIsDefined(sent_stream.bytes_sent); } bool VerifyRTCInboundRTPStreamStats( const RTCInboundRTPStreamStats& inbound_stream) { RTCStatsVerifier verifier(report_, &inbound_stream); - VerifyRTCRTPStreamStats(inbound_stream, &verifier); + VerifyRTCReceivedRtpStreamStats(inbound_stream, verifier); + verifier.TestMemberIsOptionalIDReference( + inbound_stream.remote_id, RTCRemoteOutboundRtpStreamStats::kType); if (inbound_stream.media_type.is_defined() && *inbound_stream.media_type == "video") { verifier.TestMemberIsNonNegative(inbound_stream.qp_sum); @@ -816,9 +825,6 @@ class RTCStatsReportVerifier { verifier.TestMemberIsNonNegative(inbound_stream.bytes_received); verifier.TestMemberIsNonNegative( inbound_stream.header_bytes_received); - // packets_lost is defined as signed, but this should never happen in - // this test. See RFC 3550. 
- verifier.TestMemberIsNonNegative(inbound_stream.packets_lost); verifier.TestMemberIsDefined(inbound_stream.last_packet_received_timestamp); if (inbound_stream.frames_received.ValueOrDefault(0) > 0) { verifier.TestMemberIsNonNegative(inbound_stream.frame_width); @@ -834,12 +840,12 @@ class RTCStatsReportVerifier { verifier.TestMemberIsUndefined(inbound_stream.frames_per_second); } verifier.TestMemberIsUndefined(inbound_stream.frame_bit_depth); + verifier.TestMemberIsNonNegative( + inbound_stream.jitter_buffer_delay); + verifier.TestMemberIsNonNegative( + inbound_stream.jitter_buffer_emitted_count); if (inbound_stream.media_type.is_defined() && *inbound_stream.media_type == "video") { - verifier.TestMemberIsUndefined(inbound_stream.jitter); - verifier.TestMemberIsUndefined(inbound_stream.jitter_buffer_delay); - verifier.TestMemberIsUndefined( - inbound_stream.jitter_buffer_emitted_count); verifier.TestMemberIsUndefined(inbound_stream.total_samples_received); verifier.TestMemberIsUndefined(inbound_stream.concealed_samples); verifier.TestMemberIsUndefined(inbound_stream.silent_concealed_samples); @@ -852,12 +858,13 @@ class RTCStatsReportVerifier { verifier.TestMemberIsUndefined(inbound_stream.total_audio_energy); verifier.TestMemberIsUndefined(inbound_stream.total_samples_duration); verifier.TestMemberIsNonNegative(inbound_stream.frames_received); + verifier.TestMemberIsNonNegative(inbound_stream.fir_count); + verifier.TestMemberIsNonNegative(inbound_stream.pli_count); + verifier.TestMemberIsNonNegative(inbound_stream.nack_count); } else { - verifier.TestMemberIsNonNegative(inbound_stream.jitter); - verifier.TestMemberIsNonNegative( - inbound_stream.jitter_buffer_delay); - verifier.TestMemberIsNonNegative( - inbound_stream.jitter_buffer_emitted_count); + verifier.TestMemberIsUndefined(inbound_stream.fir_count); + verifier.TestMemberIsUndefined(inbound_stream.pli_count); + verifier.TestMemberIsUndefined(inbound_stream.nack_count); verifier.TestMemberIsPositive( inbound_stream.total_samples_received); verifier.TestMemberIsNonNegative( @@ -920,21 +927,26 @@ class RTCStatsReportVerifier { bool VerifyRTCOutboundRTPStreamStats( const RTCOutboundRTPStreamStats& outbound_stream) { RTCStatsVerifier verifier(report_, &outbound_stream); - VerifyRTCRTPStreamStats(outbound_stream, &verifier); + VerifyRTCRTPStreamStats(outbound_stream, verifier); if (outbound_stream.media_type.is_defined() && *outbound_stream.media_type == "video") { verifier.TestMemberIsIDReference(outbound_stream.media_source_id, RTCVideoSourceStats::kType); + verifier.TestMemberIsNonNegative(outbound_stream.fir_count); + verifier.TestMemberIsNonNegative(outbound_stream.pli_count); if (*outbound_stream.frames_encoded > 0) { verifier.TestMemberIsNonNegative(outbound_stream.qp_sum); } else { verifier.TestMemberIsUndefined(outbound_stream.qp_sum); } } else { + verifier.TestMemberIsUndefined(outbound_stream.fir_count); + verifier.TestMemberIsUndefined(outbound_stream.pli_count); verifier.TestMemberIsIDReference(outbound_stream.media_source_id, RTCAudioSourceStats::kType); verifier.TestMemberIsUndefined(outbound_stream.qp_sum); } + verifier.TestMemberIsNonNegative(outbound_stream.nack_count); verifier.TestMemberIsOptionalIDReference( outbound_stream.remote_id, RTCRemoteInboundRtpStreamStats::kType); verifier.TestMemberIsNonNegative(outbound_stream.packets_sent); @@ -957,6 +969,8 @@ class RTCStatsReportVerifier { verifier.TestMemberIsNonNegative( outbound_stream.total_packet_send_delay); 
verifier.TestMemberIsDefined(outbound_stream.quality_limitation_reason); + verifier.TestMemberIsDefined( + outbound_stream.quality_limitation_durations); verifier.TestMemberIsNonNegative( outbound_stream.quality_limitation_resolution_changes); // The integration test is not set up to test screen share; don't require @@ -989,6 +1003,8 @@ class RTCStatsReportVerifier { // TODO(https://crbug.com/webrtc/10635): Implement for audio as well. verifier.TestMemberIsUndefined(outbound_stream.total_packet_send_delay); verifier.TestMemberIsUndefined(outbound_stream.quality_limitation_reason); + verifier.TestMemberIsUndefined( + outbound_stream.quality_limitation_durations); verifier.TestMemberIsUndefined( outbound_stream.quality_limitation_resolution_changes); verifier.TestMemberIsUndefined(outbound_stream.content_type); @@ -1004,23 +1020,40 @@ class RTCStatsReportVerifier { return verifier.ExpectAllMembersSuccessfullyTested(); } + void VerifyRTCReceivedRtpStreamStats( + const RTCReceivedRtpStreamStats& received_rtp, + RTCStatsVerifier& verifier) { + VerifyRTCRTPStreamStats(received_rtp, verifier); + verifier.TestMemberIsNonNegative(received_rtp.jitter); + verifier.TestMemberIsDefined(received_rtp.packets_lost); + } + bool VerifyRTCRemoteInboundRtpStreamStats( const RTCRemoteInboundRtpStreamStats& remote_inbound_stream) { RTCStatsVerifier verifier(report_, &remote_inbound_stream); - verifier.TestMemberIsDefined(remote_inbound_stream.ssrc); - verifier.TestMemberIsDefined(remote_inbound_stream.kind); - verifier.TestMemberIsIDReference(remote_inbound_stream.transport_id, - RTCTransportStats::kType); - verifier.TestMemberIsIDReference(remote_inbound_stream.codec_id, - RTCCodecStats::kType); - verifier.TestMemberIsDefined(remote_inbound_stream.packets_lost); - // Note that the existance of RTCCodecStats is needed for |codec_id| and - // |jitter| to be present. - verifier.TestMemberIsNonNegative(remote_inbound_stream.jitter); + VerifyRTCReceivedRtpStreamStats(remote_inbound_stream, verifier); + verifier.TestMemberIsDefined(remote_inbound_stream.fraction_lost); verifier.TestMemberIsIDReference(remote_inbound_stream.local_id, RTCOutboundRTPStreamStats::kType); verifier.TestMemberIsNonNegative( remote_inbound_stream.round_trip_time); + verifier.TestMemberIsNonNegative( + remote_inbound_stream.total_round_trip_time); + verifier.TestMemberIsNonNegative( + remote_inbound_stream.round_trip_time_measurements); + return verifier.ExpectAllMembersSuccessfullyTested(); + } + + bool VerifyRTCRemoteOutboundRTPStreamStats( + const RTCRemoteOutboundRtpStreamStats& remote_outbound_stream) { + RTCStatsVerifier verifier(report_, &remote_outbound_stream); + VerifyRTCRTPStreamStats(remote_outbound_stream, verifier); + VerifyRTCSentRTPStreamStats(remote_outbound_stream, verifier); + verifier.TestMemberIsIDReference(remote_outbound_stream.local_id, + RTCOutboundRTPStreamStats::kType); + verifier.TestMemberIsNonNegative( + remote_outbound_stream.remote_timestamp); + verifier.TestMemberIsDefined(remote_outbound_stream.reports_sent); return verifier.ExpectAllMembersSuccessfullyTested(); } @@ -1045,6 +1078,12 @@ class RTCStatsReportVerifier { verifier.TestMemberIsNonNegative(audio_source.audio_level); verifier.TestMemberIsPositive(audio_source.total_audio_energy); verifier.TestMemberIsPositive(audio_source.total_samples_duration); + // TODO(hbos): |echo_return_loss| and |echo_return_loss_enhancement| are + // flaky on msan bot (sometimes defined, sometimes undefined). 
Should the + // test run until available or is there a way to have it always be + // defined? crbug.com/627816 + verifier.MarkMemberTested(audio_source.echo_return_loss, true); + verifier.MarkMemberTested(audio_source.echo_return_loss_enhancement, true); return verifier.ExpectAllMembersSuccessfullyTested(); } @@ -1057,9 +1096,7 @@ class RTCStatsReportVerifier { // reflect real code. verifier.TestMemberIsUndefined(video_source.width); verifier.TestMemberIsUndefined(video_source.height); - // TODO(hbos): When |frames| is implemented test that this member should be - // expected to be non-negative. - verifier.TestMemberIsUndefined(video_source.frames); + verifier.TestMemberIsNonNegative(video_source.frames); verifier.TestMemberIsNonNegative(video_source.frames_per_second); return verifier.ExpectAllMembersSuccessfullyTested(); } @@ -1091,7 +1128,7 @@ class RTCStatsReportVerifier { rtc::scoped_refptr report_; }; -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP TEST_F(RTCStatsIntegrationTest, GetStatsFromCaller) { StartCall(); @@ -1254,7 +1291,21 @@ TEST_F(RTCStatsIntegrationTest, GetStatsReferencedIds) { } } } -#endif // HAVE_SCTP + +TEST_F(RTCStatsIntegrationTest, GetStatsContainsNoDuplicateMembers) { + StartCall(); + + rtc::scoped_refptr report = GetStatsFromCallee(); + for (const RTCStats& stats : *report) { + std::set member_names; + for (const auto* member : stats.Members()) { + EXPECT_TRUE(member_names.find(member->name()) == member_names.end()) + << member->name() << " is a duplicate!"; + member_names.insert(member->name()); + } + } +} +#endif // WEBRTC_HAVE_SCTP } // namespace diff --git a/pc/rtc_stats_traversal.cc b/pc/rtc_stats_traversal.cc index aa53dde180..e579072ea5 100644 --- a/pc/rtc_stats_traversal.cc +++ b/pc/rtc_stats_traversal.cc @@ -99,24 +99,36 @@ std::vector GetStatsReferencedIds(const RTCStats& stats) { AddIdIfDefined(track.media_source_id, &neighbor_ids); } else if (type == RTCPeerConnectionStats::kType) { // RTCPeerConnectionStats does not have any neighbor references. 
- } else if (type == RTCInboundRTPStreamStats::kType || - type == RTCOutboundRTPStreamStats::kType) { - const auto& rtp = static_cast(stats); - AddIdIfDefined(rtp.track_id, &neighbor_ids); - AddIdIfDefined(rtp.transport_id, &neighbor_ids); - AddIdIfDefined(rtp.codec_id, &neighbor_ids); - if (type == RTCOutboundRTPStreamStats::kType) { - const auto& outbound_rtp = - static_cast(stats); - AddIdIfDefined(outbound_rtp.media_source_id, &neighbor_ids); - AddIdIfDefined(outbound_rtp.remote_id, &neighbor_ids); - } + } else if (type == RTCInboundRTPStreamStats::kType) { + const auto& inbound_rtp = + static_cast(stats); + AddIdIfDefined(inbound_rtp.remote_id, &neighbor_ids); + AddIdIfDefined(inbound_rtp.track_id, &neighbor_ids); + AddIdIfDefined(inbound_rtp.transport_id, &neighbor_ids); + AddIdIfDefined(inbound_rtp.codec_id, &neighbor_ids); + } else if (type == RTCOutboundRTPStreamStats::kType) { + const auto& outbound_rtp = + static_cast(stats); + AddIdIfDefined(outbound_rtp.remote_id, &neighbor_ids); + AddIdIfDefined(outbound_rtp.track_id, &neighbor_ids); + AddIdIfDefined(outbound_rtp.transport_id, &neighbor_ids); + AddIdIfDefined(outbound_rtp.codec_id, &neighbor_ids); + AddIdIfDefined(outbound_rtp.media_source_id, &neighbor_ids); } else if (type == RTCRemoteInboundRtpStreamStats::kType) { const auto& remote_inbound_rtp = static_cast(stats); AddIdIfDefined(remote_inbound_rtp.transport_id, &neighbor_ids); AddIdIfDefined(remote_inbound_rtp.codec_id, &neighbor_ids); AddIdIfDefined(remote_inbound_rtp.local_id, &neighbor_ids); + } else if (type == RTCRemoteOutboundRtpStreamStats::kType) { + const auto& remote_outbound_rtp = + static_cast(stats); + // Inherited from `RTCRTPStreamStats`. + AddIdIfDefined(remote_outbound_rtp.track_id, &neighbor_ids); + AddIdIfDefined(remote_outbound_rtp.transport_id, &neighbor_ids); + AddIdIfDefined(remote_outbound_rtp.codec_id, &neighbor_ids); + // Direct members of `RTCRemoteOutboundRtpStreamStats`. + AddIdIfDefined(remote_outbound_rtp.local_id, &neighbor_ids); } else if (type == RTCAudioSourceStats::kType || type == RTCVideoSourceStats::kType) { // RTC[Audio/Video]SourceStats does not have any neighbor references. diff --git a/pc/rtp_data_channel.cc b/pc/rtp_data_channel.cc deleted file mode 100644 index b08b2b2ffb..0000000000 --- a/pc/rtp_data_channel.cc +++ /dev/null @@ -1,394 +0,0 @@ -/* - * Copyright 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "pc/rtp_data_channel.h" - -#include -#include -#include - -#include "api/proxy.h" -#include "rtc_base/checks.h" -#include "rtc_base/location.h" -#include "rtc_base/logging.h" -#include "rtc_base/ref_counted_object.h" -#include "rtc_base/thread.h" - -namespace webrtc { - -namespace { - -static size_t kMaxQueuedReceivedDataBytes = 16 * 1024 * 1024; - -static std::atomic g_unique_id{0}; - -int GenerateUniqueId() { - return ++g_unique_id; -} - -// Define proxy for DataChannelInterface. 
-BEGIN_SIGNALING_PROXY_MAP(DataChannel) -PROXY_SIGNALING_THREAD_DESTRUCTOR() -PROXY_METHOD1(void, RegisterObserver, DataChannelObserver*) -PROXY_METHOD0(void, UnregisterObserver) -BYPASS_PROXY_CONSTMETHOD0(std::string, label) -BYPASS_PROXY_CONSTMETHOD0(bool, reliable) -BYPASS_PROXY_CONSTMETHOD0(bool, ordered) -BYPASS_PROXY_CONSTMETHOD0(uint16_t, maxRetransmitTime) -BYPASS_PROXY_CONSTMETHOD0(uint16_t, maxRetransmits) -BYPASS_PROXY_CONSTMETHOD0(absl::optional, maxRetransmitsOpt) -BYPASS_PROXY_CONSTMETHOD0(absl::optional, maxPacketLifeTime) -BYPASS_PROXY_CONSTMETHOD0(std::string, protocol) -BYPASS_PROXY_CONSTMETHOD0(bool, negotiated) -// Can't bypass the proxy since the id may change. -PROXY_CONSTMETHOD0(int, id) -BYPASS_PROXY_CONSTMETHOD0(Priority, priority) -PROXY_CONSTMETHOD0(DataState, state) -PROXY_CONSTMETHOD0(RTCError, error) -PROXY_CONSTMETHOD0(uint32_t, messages_sent) -PROXY_CONSTMETHOD0(uint64_t, bytes_sent) -PROXY_CONSTMETHOD0(uint32_t, messages_received) -PROXY_CONSTMETHOD0(uint64_t, bytes_received) -PROXY_CONSTMETHOD0(uint64_t, buffered_amount) -PROXY_METHOD0(void, Close) -// TODO(bugs.webrtc.org/11547): Change to run on the network thread. -PROXY_METHOD1(bool, Send, const DataBuffer&) -END_PROXY_MAP() - -} // namespace - -rtc::scoped_refptr RtpDataChannel::Create( - RtpDataChannelProviderInterface* provider, - const std::string& label, - const DataChannelInit& config, - rtc::Thread* signaling_thread) { - rtc::scoped_refptr channel( - new rtc::RefCountedObject(config, provider, label, - signaling_thread)); - if (!channel->Init()) { - return nullptr; - } - return channel; -} - -// static -rtc::scoped_refptr RtpDataChannel::CreateProxy( - rtc::scoped_refptr channel) { - return DataChannelProxy::Create(channel->signaling_thread_, channel.get()); -} - -RtpDataChannel::RtpDataChannel(const DataChannelInit& config, - RtpDataChannelProviderInterface* provider, - const std::string& label, - rtc::Thread* signaling_thread) - : signaling_thread_(signaling_thread), - internal_id_(GenerateUniqueId()), - label_(label), - config_(config), - provider_(provider) { - RTC_DCHECK_RUN_ON(signaling_thread_); -} - -bool RtpDataChannel::Init() { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (config_.reliable || config_.id != -1 || config_.maxRetransmits || - config_.maxRetransmitTime) { - RTC_LOG(LS_ERROR) << "Failed to initialize the RTP data channel due to " - "invalid DataChannelInit."; - return false; - } - - return true; -} - -RtpDataChannel::~RtpDataChannel() { - RTC_DCHECK_RUN_ON(signaling_thread_); -} - -void RtpDataChannel::RegisterObserver(DataChannelObserver* observer) { - RTC_DCHECK_RUN_ON(signaling_thread_); - observer_ = observer; - DeliverQueuedReceivedData(); -} - -void RtpDataChannel::UnregisterObserver() { - RTC_DCHECK_RUN_ON(signaling_thread_); - observer_ = nullptr; -} - -void RtpDataChannel::Close() { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (state_ == kClosed) - return; - send_ssrc_ = 0; - send_ssrc_set_ = false; - SetState(kClosing); - UpdateState(); -} - -RtpDataChannel::DataState RtpDataChannel::state() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - return state_; -} - -RTCError RtpDataChannel::error() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - return error_; -} - -uint32_t RtpDataChannel::messages_sent() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - return messages_sent_; -} - -uint64_t RtpDataChannel::bytes_sent() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - return bytes_sent_; -} - -uint32_t RtpDataChannel::messages_received() const { - 
RTC_DCHECK_RUN_ON(signaling_thread_); - return messages_received_; -} - -uint64_t RtpDataChannel::bytes_received() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - return bytes_received_; -} - -bool RtpDataChannel::Send(const DataBuffer& buffer) { - RTC_DCHECK_RUN_ON(signaling_thread_); - - if (state_ != kOpen) { - return false; - } - - // TODO(jiayl): the spec is unclear about if the remote side should get the - // onmessage event. We need to figure out the expected behavior and change the - // code accordingly. - if (buffer.size() == 0) { - return true; - } - - return SendDataMessage(buffer); -} - -void RtpDataChannel::SetReceiveSsrc(uint32_t receive_ssrc) { - RTC_DCHECK_RUN_ON(signaling_thread_); - - if (receive_ssrc_set_) { - return; - } - receive_ssrc_ = receive_ssrc; - receive_ssrc_set_ = true; - UpdateState(); -} - -void RtpDataChannel::OnTransportChannelClosed() { - RTCError error = RTCError(RTCErrorType::OPERATION_ERROR_WITH_DATA, - "Transport channel closed"); - CloseAbruptlyWithError(std::move(error)); -} - -DataChannelStats RtpDataChannel::GetStats() const { - RTC_DCHECK_RUN_ON(signaling_thread_); - DataChannelStats stats{internal_id_, id(), label(), - protocol(), state(), messages_sent(), - messages_received(), bytes_sent(), bytes_received()}; - return stats; -} - -// The remote peer request that this channel shall be closed. -void RtpDataChannel::RemotePeerRequestClose() { - // Close with error code explicitly set to OK. - CloseAbruptlyWithError(RTCError()); -} - -void RtpDataChannel::SetSendSsrc(uint32_t send_ssrc) { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (send_ssrc_set_) { - return; - } - send_ssrc_ = send_ssrc; - send_ssrc_set_ = true; - UpdateState(); -} - -void RtpDataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, - const rtc::CopyOnWriteBuffer& payload) { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (params.ssrc != receive_ssrc_) { - return; - } - - RTC_DCHECK(params.type == cricket::DMT_BINARY || - params.type == cricket::DMT_TEXT); - - RTC_LOG(LS_VERBOSE) << "DataChannel received DATA message, sid = " - << params.sid; - - bool binary = (params.type == cricket::DMT_BINARY); - auto buffer = std::make_unique(payload, binary); - if (state_ == kOpen && observer_) { - ++messages_received_; - bytes_received_ += buffer->size(); - observer_->OnMessage(*buffer.get()); - } else { - if (queued_received_data_.byte_count() + payload.size() > - kMaxQueuedReceivedDataBytes) { - RTC_LOG(LS_ERROR) << "Queued received data exceeds the max buffer size."; - - queued_received_data_.Clear(); - CloseAbruptlyWithError( - RTCError(RTCErrorType::RESOURCE_EXHAUSTED, - "Queued received data exceeds the max buffer size.")); - - return; - } - queued_received_data_.PushBack(std::move(buffer)); - } -} - -void RtpDataChannel::OnChannelReady(bool writable) { - RTC_DCHECK_RUN_ON(signaling_thread_); - - writable_ = writable; - if (!writable) { - return; - } - - UpdateState(); -} - -void RtpDataChannel::CloseAbruptlyWithError(RTCError error) { - RTC_DCHECK_RUN_ON(signaling_thread_); - - if (state_ == kClosed) { - return; - } - - if (connected_to_provider_) { - DisconnectFromProvider(); - } - - // Still go to "kClosing" before "kClosed", since observers may be expecting - // that. - SetState(kClosing); - error_ = std::move(error); - SetState(kClosed); -} - -void RtpDataChannel::UpdateState() { - RTC_DCHECK_RUN_ON(signaling_thread_); - // UpdateState determines what to do from a few state variables. 
Include - // all conditions required for each state transition here for - // clarity. - switch (state_) { - case kConnecting: { - if (send_ssrc_set_ == receive_ssrc_set_) { - if (!connected_to_provider_) { - connected_to_provider_ = provider_->ConnectDataChannel(this); - } - if (connected_to_provider_ && writable_) { - SetState(kOpen); - // If we have received buffers before the channel got writable. - // Deliver them now. - DeliverQueuedReceivedData(); - } - } - break; - } - case kOpen: { - break; - } - case kClosing: { - // For RTP data channels, we can go to "closed" after we finish - // sending data and the send/recv SSRCs are unset. - if (connected_to_provider_) { - DisconnectFromProvider(); - } - if (!send_ssrc_set_ && !receive_ssrc_set_) { - SetState(kClosed); - } - break; - } - case kClosed: - break; - } -} - -void RtpDataChannel::SetState(DataState state) { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (state_ == state) { - return; - } - - state_ = state; - if (observer_) { - observer_->OnStateChange(); - } - if (state_ == kOpen) { - SignalOpened(this); - } else if (state_ == kClosed) { - SignalClosed(this); - } -} - -void RtpDataChannel::DisconnectFromProvider() { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (!connected_to_provider_) - return; - - provider_->DisconnectDataChannel(this); - connected_to_provider_ = false; -} - -void RtpDataChannel::DeliverQueuedReceivedData() { - RTC_DCHECK_RUN_ON(signaling_thread_); - if (!observer_) { - return; - } - - while (!queued_received_data_.Empty()) { - std::unique_ptr buffer = queued_received_data_.PopFront(); - ++messages_received_; - bytes_received_ += buffer->size(); - observer_->OnMessage(*buffer); - } -} - -bool RtpDataChannel::SendDataMessage(const DataBuffer& buffer) { - RTC_DCHECK_RUN_ON(signaling_thread_); - cricket::SendDataParams send_params; - - send_params.ssrc = send_ssrc_; - send_params.type = buffer.binary ? cricket::DMT_BINARY : cricket::DMT_TEXT; - - cricket::SendDataResult send_result = cricket::SDR_SUCCESS; - bool success = provider_->SendData(send_params, buffer.data, &send_result); - - if (success) { - ++messages_sent_; - bytes_sent_ += buffer.size(); - if (observer_ && buffer.size() > 0) { - observer_->OnBufferedAmountChange(buffer.size()); - } - return true; - } - - return false; -} - -// static -void RtpDataChannel::ResetInternalIdAllocatorForTesting(int new_value) { - g_unique_id = new_value; -} - -} // namespace webrtc diff --git a/pc/rtp_data_channel.h b/pc/rtp_data_channel.h deleted file mode 100644 index ea2de49b5a..0000000000 --- a/pc/rtp_data_channel.h +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright 2020 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#ifndef PC_RTP_DATA_CHANNEL_H_ -#define PC_RTP_DATA_CHANNEL_H_ - -#include -#include - -#include "api/data_channel_interface.h" -#include "api/priority.h" -#include "api/scoped_refptr.h" -#include "api/transport/data_channel_transport_interface.h" -#include "media/base/media_channel.h" -#include "pc/channel.h" -#include "pc/data_channel_utils.h" -#include "rtc_base/async_invoker.h" -#include "rtc_base/third_party/sigslot/sigslot.h" - -namespace webrtc { - -class RtpDataChannel; - -// TODO(deadbeef): Once RTP data channels go away, get rid of this and have -// DataChannel depend on SctpTransportInternal (pure virtual SctpTransport -// interface) instead. -class RtpDataChannelProviderInterface { - public: - // Sends the data to the transport. - virtual bool SendData(const cricket::SendDataParams& params, - const rtc::CopyOnWriteBuffer& payload, - cricket::SendDataResult* result) = 0; - // Connects to the transport signals. - virtual bool ConnectDataChannel(RtpDataChannel* data_channel) = 0; - // Disconnects from the transport signals. - virtual void DisconnectDataChannel(RtpDataChannel* data_channel) = 0; - // Returns true if the transport channel is ready to send data. - virtual bool ReadyToSendData() const = 0; - - protected: - virtual ~RtpDataChannelProviderInterface() {} -}; - -// RtpDataChannel is an implementation of the DataChannelInterface based on -// libjingle's data engine. It provides an implementation of unreliable data -// channels. - -// DataChannel states: -// kConnecting: The channel has been created the transport might not yet be -// ready. -// kOpen: The channel have a local SSRC set by a call to UpdateSendSsrc -// and a remote SSRC set by call to UpdateReceiveSsrc and the transport -// has been writable once. -// kClosing: DataChannelInterface::Close has been called or UpdateReceiveSsrc -// has been called with SSRC==0 -// kClosed: Both UpdateReceiveSsrc and UpdateSendSsrc has been called with -// SSRC==0. -class RtpDataChannel : public DataChannelInterface, - public sigslot::has_slots<> { - public: - static rtc::scoped_refptr Create( - RtpDataChannelProviderInterface* provider, - const std::string& label, - const DataChannelInit& config, - rtc::Thread* signaling_thread); - - // Instantiates an API proxy for a DataChannel instance that will be handed - // out to external callers. - static rtc::scoped_refptr CreateProxy( - rtc::scoped_refptr channel); - - void RegisterObserver(DataChannelObserver* observer) override; - void UnregisterObserver() override; - - std::string label() const override { return label_; } - bool reliable() const override { return false; } - bool ordered() const override { return config_.ordered; } - // Backwards compatible accessors - uint16_t maxRetransmitTime() const override { - return config_.maxRetransmitTime ? *config_.maxRetransmitTime - : static_cast(-1); - } - uint16_t maxRetransmits() const override { - return config_.maxRetransmits ? *config_.maxRetransmits - : static_cast(-1); - } - absl::optional maxPacketLifeTime() const override { - return config_.maxRetransmitTime; - } - absl::optional maxRetransmitsOpt() const override { - return config_.maxRetransmits; - } - std::string protocol() const override { return config_.protocol; } - bool negotiated() const override { return config_.negotiated; } - int id() const override { return config_.id; } - Priority priority() const override { - return config_.priority ? 
*config_.priority : Priority::kLow; - } - - virtual int internal_id() const { return internal_id_; } - - uint64_t buffered_amount() const override { return 0; } - void Close() override; - DataState state() const override; - RTCError error() const override; - uint32_t messages_sent() const override; - uint64_t bytes_sent() const override; - uint32_t messages_received() const override; - uint64_t bytes_received() const override; - bool Send(const DataBuffer& buffer) override; - - // Close immediately, ignoring any queued data or closing procedure. - // This is called when SDP indicates a channel should be removed. - void CloseAbruptlyWithError(RTCError error); - - // Called when the channel's ready to use. That can happen when the - // underlying DataMediaChannel becomes ready, or when this channel is a new - // stream on an existing DataMediaChannel, and we've finished negotiation. - void OnChannelReady(bool writable); - - // Slots for provider to connect signals to. - void OnDataReceived(const cricket::ReceiveDataParams& params, - const rtc::CopyOnWriteBuffer& payload); - - // Called when the transport channel is unusable. - // This method makes sure the DataChannel is disconnected and changes state - // to kClosed. - void OnTransportChannelClosed(); - - DataChannelStats GetStats() const; - - // The remote peer requested that this channel should be closed. - void RemotePeerRequestClose(); - // Set the SSRC this channel should use to send data on the - // underlying data engine. |send_ssrc| == 0 means that the channel is no - // longer part of the session negotiation. - void SetSendSsrc(uint32_t send_ssrc); - // Set the SSRC this channel should use to receive data from the - // underlying data engine. - void SetReceiveSsrc(uint32_t receive_ssrc); - - // Emitted when state transitions to kOpen. - sigslot::signal1 SignalOpened; - // Emitted when state transitions to kClosed. - sigslot::signal1 SignalClosed; - - // Reset the allocator for internal ID values for testing, so that - // the internal IDs generated are predictable. Test only. 
- static void ResetInternalIdAllocatorForTesting(int new_value); - - protected: - RtpDataChannel(const DataChannelInit& config, - RtpDataChannelProviderInterface* client, - const std::string& label, - rtc::Thread* signaling_thread); - ~RtpDataChannel() override; - - private: - bool Init(); - void UpdateState(); - void SetState(DataState state); - void DisconnectFromProvider(); - - void DeliverQueuedReceivedData(); - - bool SendDataMessage(const DataBuffer& buffer); - - rtc::Thread* const signaling_thread_; - const int internal_id_; - const std::string label_; - const DataChannelInit config_; - DataChannelObserver* observer_ RTC_GUARDED_BY(signaling_thread_) = nullptr; - DataState state_ RTC_GUARDED_BY(signaling_thread_) = kConnecting; - RTCError error_ RTC_GUARDED_BY(signaling_thread_); - uint32_t messages_sent_ RTC_GUARDED_BY(signaling_thread_) = 0; - uint64_t bytes_sent_ RTC_GUARDED_BY(signaling_thread_) = 0; - uint32_t messages_received_ RTC_GUARDED_BY(signaling_thread_) = 0; - uint64_t bytes_received_ RTC_GUARDED_BY(signaling_thread_) = 0; - RtpDataChannelProviderInterface* const provider_; - bool connected_to_provider_ RTC_GUARDED_BY(signaling_thread_) = false; - bool send_ssrc_set_ RTC_GUARDED_BY(signaling_thread_) = false; - bool receive_ssrc_set_ RTC_GUARDED_BY(signaling_thread_) = false; - bool writable_ RTC_GUARDED_BY(signaling_thread_) = false; - uint32_t send_ssrc_ RTC_GUARDED_BY(signaling_thread_) = 0; - uint32_t receive_ssrc_ RTC_GUARDED_BY(signaling_thread_) = 0; - PacketQueue queued_received_data_ RTC_GUARDED_BY(signaling_thread_); -}; - -} // namespace webrtc - -#endif // PC_RTP_DATA_CHANNEL_H_ diff --git a/pc/rtp_media_utils.h b/pc/rtp_media_utils.h index e90a76eecb..d45cc744a1 100644 --- a/pc/rtp_media_utils.h +++ b/pc/rtp_media_utils.h @@ -11,6 +11,7 @@ #ifndef PC_RTP_MEDIA_UTILS_H_ #define PC_RTP_MEDIA_UTILS_H_ +#include "api/rtp_transceiver_direction.h" #include "api/rtp_transceiver_interface.h" namespace webrtc { diff --git a/pc/rtp_parameters_conversion.cc b/pc/rtp_parameters_conversion.cc index 68a948ea8e..8d3064ed93 100644 --- a/pc/rtp_parameters_conversion.cc +++ b/pc/rtp_parameters_conversion.cc @@ -10,10 +10,10 @@ #include "pc/rtp_parameters_conversion.h" +#include #include #include #include -#include #include #include "api/array_view.h" diff --git a/pc/rtp_receiver.cc b/pc/rtp_receiver.cc index f65afd7dc4..2444c9b60d 100644 --- a/pc/rtp_receiver.cc +++ b/pc/rtp_receiver.cc @@ -15,13 +15,9 @@ #include #include -#include "api/media_stream_proxy.h" -#include "api/media_stream_track_proxy.h" #include "pc/media_stream.h" -#include "rtc_base/checks.h" +#include "pc/media_stream_proxy.h" #include "rtc_base/location.h" -#include "rtc_base/logging.h" -#include "rtc_base/trace_event.h" namespace webrtc { @@ -43,20 +39,4 @@ RtpReceiverInternal::CreateStreamsFromIds(std::vector stream_ids) { return streams; } -// Attempt to attach the frame decryptor to the current media channel on the -// correct worker thread only if both the media channel exists and a ssrc has -// been allocated to the stream. 
-void RtpReceiverInternal::MaybeAttachFrameDecryptorToMediaChannel( - const absl::optional& ssrc, - rtc::Thread* worker_thread, - rtc::scoped_refptr frame_decryptor, - cricket::MediaChannel* media_channel, - bool stopped) { - if (media_channel && frame_decryptor && ssrc.has_value() && !stopped) { - worker_thread->Invoke(RTC_FROM_HERE, [&] { - media_channel->SetFrameDecryptor(*ssrc, frame_decryptor); - }); - } -} - } // namespace webrtc diff --git a/pc/rtp_receiver.h b/pc/rtp_receiver.h index 2cfccd4e63..73fc5b9858 100644 --- a/pc/rtp_receiver.h +++ b/pc/rtp_receiver.h @@ -22,6 +22,7 @@ #include "absl/types/optional.h" #include "api/crypto/frame_decryptor_interface.h" +#include "api/dtls_transport_interface.h" #include "api/media_stream_interface.h" #include "api/media_types.h" #include "api/rtp_parameters.h" @@ -91,13 +92,6 @@ class RtpReceiverInternal : public RtpReceiverInterface { static std::vector> CreateStreamsFromIds(std::vector stream_ids); - - static void MaybeAttachFrameDecryptorToMediaChannel( - const absl::optional& ssrc, - rtc::Thread* worker_thread, - rtc::scoped_refptr frame_decryptor, - cricket::MediaChannel* media_channel, - bool stopped); }; } // namespace webrtc diff --git a/pc/rtp_receiver_proxy.h b/pc/rtp_receiver_proxy.h new file mode 100644 index 0000000000..d4114e0f0b --- /dev/null +++ b/pc/rtp_receiver_proxy.h @@ -0,0 +1,54 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef PC_RTP_RECEIVER_PROXY_H_ +#define PC_RTP_RECEIVER_PROXY_H_ + +#include +#include + +#include "api/rtp_receiver_interface.h" +#include "pc/proxy.h" + +namespace webrtc { + +// Define proxy for RtpReceiverInterface. +// TODO(deadbeef): Move this to .cc file. What threads methods are called on is +// an implementation detail. +BEGIN_PROXY_MAP(RtpReceiver) +PROXY_PRIMARY_THREAD_DESTRUCTOR() +BYPASS_PROXY_CONSTMETHOD0(rtc::scoped_refptr, track) +PROXY_CONSTMETHOD0(rtc::scoped_refptr, dtls_transport) +PROXY_CONSTMETHOD0(std::vector, stream_ids) +PROXY_CONSTMETHOD0(std::vector>, + streams) +BYPASS_PROXY_CONSTMETHOD0(cricket::MediaType, media_type) +BYPASS_PROXY_CONSTMETHOD0(std::string, id) +PROXY_SECONDARY_CONSTMETHOD0(RtpParameters, GetParameters) +PROXY_METHOD1(void, SetObserver, RtpReceiverObserverInterface*) +PROXY_SECONDARY_METHOD1(void, + SetJitterBufferMinimumDelay, + absl::optional) +PROXY_SECONDARY_CONSTMETHOD0(std::vector, GetSources) +// TODO(bugs.webrtc.org/12772): Remove. +PROXY_SECONDARY_METHOD1(void, + SetFrameDecryptor, + rtc::scoped_refptr) +// TODO(bugs.webrtc.org/12772): Remove. 
+PROXY_SECONDARY_CONSTMETHOD0(rtc::scoped_refptr, + GetFrameDecryptor) +PROXY_SECONDARY_METHOD1(void, + SetDepacketizerToDecoderFrameTransformer, + rtc::scoped_refptr) +END_PROXY_MAP(RtpReceiver) + +} // namespace webrtc + +#endif // PC_RTP_RECEIVER_PROXY_H_ diff --git a/pc/rtp_sender.cc b/pc/rtp_sender.cc index 5a7e237c90..aa268cef45 100644 --- a/pc/rtp_sender.cc +++ b/pc/rtp_sender.cc @@ -10,18 +10,22 @@ #include "pc/rtp_sender.h" +#include #include #include #include +#include "absl/algorithm/container.h" #include "api/audio_options.h" #include "api/media_stream_interface.h" +#include "api/priority.h" #include "media/base/media_engine.h" #include "pc/stats_collector_interface.h" #include "rtc_base/checks.h" #include "rtc_base/helpers.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" +#include "rtc_base/ref_counted_object.h" #include "rtc_base/trace_event.h" namespace webrtc { @@ -420,9 +424,8 @@ rtc::scoped_refptr AudioRtpSender::Create( const std::string& id, StatsCollectorInterface* stats, SetStreamsObserver* set_streams_observer) { - return rtc::scoped_refptr( - new rtc::RefCountedObject(worker_thread, id, stats, - set_streams_observer)); + return rtc::make_ref_counted(worker_thread, id, stats, + set_streams_observer); } AudioRtpSender::AudioRtpSender(rtc::Thread* worker_thread, @@ -567,9 +570,8 @@ rtc::scoped_refptr VideoRtpSender::Create( rtc::Thread* worker_thread, const std::string& id, SetStreamsObserver* set_streams_observer) { - return rtc::scoped_refptr( - new rtc::RefCountedObject(worker_thread, id, - set_streams_observer)); + return rtc::make_ref_counted(worker_thread, id, + set_streams_observer); } VideoRtpSender::VideoRtpSender(rtc::Thread* worker_thread, diff --git a/pc/rtp_sender.h b/pc/rtp_sender.h index 51ae1e978b..0b4c204902 100644 --- a/pc/rtp_sender.h +++ b/pc/rtp_sender.h @@ -15,16 +15,30 @@ #ifndef PC_RTP_SENDER_H_ #define PC_RTP_SENDER_H_ +#include +#include #include #include #include +#include "absl/types/optional.h" +#include "api/crypto/frame_encryptor_interface.h" +#include "api/dtls_transport_interface.h" +#include "api/dtmf_sender_interface.h" +#include "api/frame_transformer_interface.h" #include "api/media_stream_interface.h" +#include "api/media_types.h" +#include "api/rtc_error.h" +#include "api/rtp_parameters.h" #include "api/rtp_sender_interface.h" +#include "api/scoped_refptr.h" #include "media/base/audio_source.h" #include "media/base/media_channel.h" #include "pc/dtmf_sender.h" +#include "pc/stats_collector_interface.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" namespace webrtc { diff --git a/pc/rtp_sender_proxy.h b/pc/rtp_sender_proxy.h new file mode 100644 index 0000000000..2f8fe2c0bf --- /dev/null +++ b/pc/rtp_sender_proxy.h @@ -0,0 +1,51 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef PC_RTP_SENDER_PROXY_H_ +#define PC_RTP_SENDER_PROXY_H_ + +#include +#include + +#include "api/rtp_sender_interface.h" +#include "pc/proxy.h" + +namespace webrtc { + +// Define proxy for RtpSenderInterface. +// TODO(deadbeef): Move this to .cc file. 
What threads methods are called on is +// an implementation detail. +BEGIN_PRIMARY_PROXY_MAP(RtpSender) +PROXY_PRIMARY_THREAD_DESTRUCTOR() +PROXY_METHOD1(bool, SetTrack, MediaStreamTrackInterface*) +PROXY_CONSTMETHOD0(rtc::scoped_refptr, track) +PROXY_CONSTMETHOD0(rtc::scoped_refptr, dtls_transport) +PROXY_CONSTMETHOD0(uint32_t, ssrc) +BYPASS_PROXY_CONSTMETHOD0(cricket::MediaType, media_type) +BYPASS_PROXY_CONSTMETHOD0(std::string, id) +PROXY_CONSTMETHOD0(std::vector, stream_ids) +PROXY_CONSTMETHOD0(std::vector, init_send_encodings) +PROXY_CONSTMETHOD0(RtpParameters, GetParameters) +PROXY_METHOD1(RTCError, SetParameters, const RtpParameters&) +PROXY_CONSTMETHOD0(rtc::scoped_refptr, GetDtmfSender) +PROXY_METHOD1(void, + SetFrameEncryptor, + rtc::scoped_refptr) +PROXY_CONSTMETHOD0(rtc::scoped_refptr, + GetFrameEncryptor) +PROXY_METHOD1(void, SetStreams, const std::vector&) +PROXY_METHOD1(void, + SetEncoderToPacketizerFrameTransformer, + rtc::scoped_refptr) +END_PROXY_MAP(RtpSender) + +} // namespace webrtc + +#endif // PC_RTP_SENDER_PROXY_H_ diff --git a/pc/rtp_sender_receiver_unittest.cc b/pc/rtp_sender_receiver_unittest.cc index 364e87a89f..10dc894518 100644 --- a/pc/rtp_sender_receiver_unittest.cc +++ b/pc/rtp_sender_receiver_unittest.cc @@ -37,7 +37,6 @@ #include "media/base/media_channel.h" #include "media/base/media_config.h" #include "media/base/media_engine.h" -#include "media/base/rtp_data_engine.h" #include "media/base/stream_params.h" #include "media/base/test_utils.h" #include "media/engine/fake_webrtc_call.h" @@ -64,6 +63,7 @@ #include "rtc_base/thread.h" #include "test/gmock.h" #include "test/gtest.h" +#include "test/run_loop.h" using ::testing::_; using ::testing::ContainerEq; @@ -108,24 +108,24 @@ class RtpSenderReceiverTest // Create fake media engine/etc. so we can create channels to use to // test RtpSenders/RtpReceivers. media_engine_(new cricket::FakeMediaEngine()), - channel_manager_(absl::WrapUnique(media_engine_), - std::make_unique(), - worker_thread_, - network_thread_), - fake_call_(), + fake_call_(worker_thread_, network_thread_), local_stream_(MediaStream::Create(kStreamId1)) { - // Create channels to be used by the RtpSenders and RtpReceivers. 
- channel_manager_.Init(); + worker_thread_->Invoke(RTC_FROM_HERE, [&]() { + channel_manager_ = cricket::ChannelManager::Create( + absl::WrapUnique(media_engine_), false, worker_thread_, + network_thread_); + }); + bool srtp_required = true; rtp_dtls_transport_ = std::make_unique( "fake_dtls_transport", cricket::ICE_CANDIDATE_COMPONENT_RTP); rtp_transport_ = CreateDtlsSrtpTransport(); - voice_channel_ = channel_manager_.CreateVoiceChannel( + voice_channel_ = channel_manager_->CreateVoiceChannel( &fake_call_, cricket::MediaConfig(), rtp_transport_.get(), rtc::Thread::Current(), cricket::CN_AUDIO, srtp_required, webrtc::CryptoOptions(), &ssrc_generator_, cricket::AudioOptions()); - video_channel_ = channel_manager_.CreateVideoChannel( + video_channel_ = channel_manager_->CreateVideoChannel( &fake_call_, cricket::MediaConfig(), rtp_transport_.get(), rtc::Thread::Current(), cricket::CN_VIDEO, srtp_required, webrtc::CryptoOptions(), &ssrc_generator_, cricket::VideoOptions(), @@ -161,6 +161,18 @@ class RtpSenderReceiverTest cricket::StreamParams::CreateLegacy(kVideoSsrc2)); } + ~RtpSenderReceiverTest() { + audio_rtp_sender_ = nullptr; + video_rtp_sender_ = nullptr; + audio_rtp_receiver_ = nullptr; + video_rtp_receiver_ = nullptr; + local_stream_ = nullptr; + video_track_ = nullptr; + audio_track_ = nullptr; + worker_thread_->Invoke(RTC_FROM_HERE, + [&]() { channel_manager_.reset(); }); + } + std::unique_ptr CreateDtlsSrtpTransport() { auto dtls_srtp_transport = std::make_unique( /*rtcp_mux_required=*/true); @@ -288,8 +300,9 @@ class RtpSenderReceiverTest void CreateAudioRtpReceiver( std::vector> streams = {}) { - audio_rtp_receiver_ = - new AudioRtpReceiver(rtc::Thread::Current(), kAudioTrackId, streams); + audio_rtp_receiver_ = rtc::make_ref_counted( + rtc::Thread::Current(), kAudioTrackId, streams, + /*is_unified_plan=*/true); audio_rtp_receiver_->SetMediaChannel(voice_media_channel_); audio_rtp_receiver_->SetupMediaChannel(kAudioSsrc); audio_track_ = audio_rtp_receiver_->audio_track(); @@ -298,8 +311,8 @@ class RtpSenderReceiverTest void CreateVideoRtpReceiver( std::vector> streams = {}) { - video_rtp_receiver_ = - new VideoRtpReceiver(rtc::Thread::Current(), kVideoTrackId, streams); + video_rtp_receiver_ = rtc::make_ref_counted( + rtc::Thread::Current(), kVideoTrackId, streams); video_rtp_receiver_->SetMediaChannel(video_media_channel_); video_rtp_receiver_->SetupMediaChannel(kVideoSsrc); video_track_ = video_rtp_receiver_->video_track(); @@ -318,19 +331,25 @@ class RtpSenderReceiverTest video_media_channel_->AddRecvStream(stream_params); uint32_t primary_ssrc = stream_params.first_ssrc(); - video_rtp_receiver_ = - new VideoRtpReceiver(rtc::Thread::Current(), kVideoTrackId, streams); + video_rtp_receiver_ = rtc::make_ref_counted( + rtc::Thread::Current(), kVideoTrackId, streams); video_rtp_receiver_->SetMediaChannel(video_media_channel_); video_rtp_receiver_->SetupMediaChannel(primary_ssrc); video_track_ = video_rtp_receiver_->video_track(); } void DestroyAudioRtpReceiver() { + if (!audio_rtp_receiver_) + return; + audio_rtp_receiver_->Stop(); audio_rtp_receiver_ = nullptr; VerifyVoiceChannelNoOutput(); } void DestroyVideoRtpReceiver() { + if (!video_rtp_receiver_) + return; + video_rtp_receiver_->Stop(); video_rtp_receiver_ = nullptr; VerifyVideoChannelNoOutput(); } @@ -486,6 +505,7 @@ class RtpSenderReceiverTest } protected: + test::RunLoop run_loop_; rtc::Thread* const network_thread_; rtc::Thread* const worker_thread_; webrtc::RtcEventLogNull event_log_; @@ -497,7 +517,7 @@ class 
RtpSenderReceiverTest video_bitrate_allocator_factory_; // |media_engine_| is actually owned by |channel_manager_|. cricket::FakeMediaEngine* media_engine_; - cricket::ChannelManager channel_manager_; + std::unique_ptr channel_manager_; cricket::FakeCall fake_call_; cricket::VoiceChannel* voice_channel_; cricket::VideoChannel* video_channel_; @@ -587,11 +607,15 @@ TEST_F(RtpSenderReceiverTest, RemoteAudioTrackDisable) { EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(1, volume); + // Handling of enable/disable is applied asynchronously. audio_track_->set_enabled(false); + run_loop_.Flush(); + EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(0, volume); audio_track_->set_enabled(true); + run_loop_.Flush(); EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(1, volume); @@ -624,6 +648,7 @@ TEST_F(RtpSenderReceiverTest, RemoteVideoTrackState) { EXPECT_EQ(webrtc::MediaStreamTrackInterface::kEnded, video_track_->state()); EXPECT_EQ(webrtc::MediaSourceInterface::kEnded, video_track_->GetSource()->state()); + DestroyVideoRtpReceiver(); } // Currently no action is taken when a remote video track is disabled or @@ -645,22 +670,27 @@ TEST_F(RtpSenderReceiverTest, RemoteAudioTrackSetVolume) { double volume; audio_track_->GetSource()->SetVolume(0.5); + run_loop_.Flush(); EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(0.5, volume); // Disable the audio track, this should prevent setting the volume. audio_track_->set_enabled(false); + RTC_DCHECK_EQ(worker_thread_, run_loop_.task_queue()); + run_loop_.Flush(); audio_track_->GetSource()->SetVolume(0.8); EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(0, volume); // When the track is enabled, the previously set volume should take effect. audio_track_->set_enabled(true); + run_loop_.Flush(); EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(0.8, volume); // Try changing volume one more time. audio_track_->GetSource()->SetVolume(0.9); + run_loop_.Flush(); EXPECT_TRUE(voice_media_channel_->GetOutputVolume(kAudioSsrc, &volume)); EXPECT_EQ(0.9, volume); @@ -671,12 +701,14 @@ TEST_F(RtpSenderReceiverTest, AudioRtpReceiverDelay) { CreateAudioRtpReceiver(); VerifyRtpReceiverDelayBehaviour(voice_media_channel_, audio_rtp_receiver_.get(), kAudioSsrc); + DestroyAudioRtpReceiver(); } TEST_F(RtpSenderReceiverTest, VideoRtpReceiverDelay) { CreateVideoRtpReceiver(); VerifyRtpReceiverDelayBehaviour(video_media_channel_, video_rtp_receiver_.get(), kVideoSsrc); + DestroyVideoRtpReceiver(); } // Test that the media channel isn't enabled for sending if the audio sender @@ -1570,6 +1602,7 @@ TEST_F(RtpSenderReceiverTest, AudioReceiverCanSetFrameDecryptor) { audio_rtp_receiver_->SetFrameDecryptor(fake_frame_decryptor); EXPECT_EQ(fake_frame_decryptor.get(), audio_rtp_receiver_->GetFrameDecryptor().get()); + DestroyAudioRtpReceiver(); } // Validate that the default FrameEncryptor setting is nullptr. @@ -1581,6 +1614,7 @@ TEST_F(RtpSenderReceiverTest, AudioReceiverCannotSetFrameDecryptorAfterStop) { audio_rtp_receiver_->Stop(); audio_rtp_receiver_->SetFrameDecryptor(fake_frame_decryptor); // TODO(webrtc:9926) - Validate media channel not set once fakes updated. + DestroyAudioRtpReceiver(); } // Validate that the default FrameEncryptor setting is nullptr. 
@@ -1615,6 +1649,7 @@ TEST_F(RtpSenderReceiverTest, VideoReceiverCanSetFrameDecryptor) { video_rtp_receiver_->SetFrameDecryptor(fake_frame_decryptor); EXPECT_EQ(fake_frame_decryptor.get(), video_rtp_receiver_->GetFrameDecryptor().get()); + DestroyVideoRtpReceiver(); } // Validate that the default FrameEncryptor setting is nullptr. @@ -1626,6 +1661,7 @@ TEST_F(RtpSenderReceiverTest, VideoReceiverCannotSetFrameDecryptorAfterStop) { video_rtp_receiver_->Stop(); video_rtp_receiver_->SetFrameDecryptor(fake_frame_decryptor); // TODO(webrtc:9926) - Validate media channel not set once fakes updated. + DestroyVideoRtpReceiver(); } // Checks that calling the internal methods for get/set parameters do not diff --git a/pc/rtp_transceiver.cc b/pc/rtp_transceiver.cc index d11e04b277..a78b9d6be6 100644 --- a/pc/rtp_transceiver.cc +++ b/pc/rtp_transceiver.cc @@ -10,18 +10,23 @@ #include "pc/rtp_transceiver.h" +#include #include #include #include #include "absl/algorithm/container.h" #include "api/rtp_parameters.h" +#include "api/sequence_checker.h" +#include "media/base/codec.h" +#include "media/base/media_constants.h" #include "pc/channel_manager.h" #include "pc/rtp_media_utils.h" -#include "pc/rtp_parameters_conversion.h" #include "pc/session_description.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/thread.h" namespace webrtc { namespace { @@ -108,12 +113,16 @@ TaskQueueBase* GetCurrentTaskQueueOrThread() { } // namespace -RtpTransceiver::RtpTransceiver(cricket::MediaType media_type) +RtpTransceiver::RtpTransceiver( + cricket::MediaType media_type, + cricket::ChannelManager* channel_manager /* = nullptr*/) : thread_(GetCurrentTaskQueueOrThread()), unified_plan_(false), - media_type_(media_type) { + media_type_(media_type), + channel_manager_(channel_manager) { RTC_DCHECK(media_type == cricket::MEDIA_TYPE_AUDIO || media_type == cricket::MEDIA_TYPE_VIDEO); + RTC_DCHECK(channel_manager_); } RtpTransceiver::RtpTransceiver( @@ -132,52 +141,86 @@ RtpTransceiver::RtpTransceiver( RTC_DCHECK(media_type_ == cricket::MEDIA_TYPE_AUDIO || media_type_ == cricket::MEDIA_TYPE_VIDEO); RTC_DCHECK_EQ(sender->media_type(), receiver->media_type()); + RTC_DCHECK(channel_manager_); senders_.push_back(sender); receivers_.push_back(receiver); } RtpTransceiver::~RtpTransceiver() { - StopInternal(); + // TODO(tommi): On Android, when running PeerConnectionClientTest (e.g. + // PeerConnectionClientTest#testCameraSwitch), the instance doesn't get + // deleted on `thread_`. See if we can fix that. + if (!stopped_) { + RTC_DCHECK_RUN_ON(thread_); + StopInternal(); + } } void RtpTransceiver::SetChannel(cricket::ChannelInterface* channel) { + RTC_DCHECK_RUN_ON(thread_); // Cannot set a non-null channel on a stopped transceiver. if (stopped_ && channel) { return; } + RTC_DCHECK(channel || channel_); + + RTC_LOG_THREAD_BLOCK_COUNT(); + + if (channel_) { + signaling_thread_safety_->SetNotAlive(); + signaling_thread_safety_ = nullptr; + } + if (channel) { RTC_DCHECK_EQ(media_type(), channel->media_type()); + signaling_thread_safety_ = PendingTaskSafetyFlag::Create(); } - if (channel_) { - channel_->SignalFirstPacketReceived().disconnect(this); - } + // An alternative to this, could be to require SetChannel to be called + // on the network thread. 
The channel object operates for the most part + // on the network thread, as part of its initialization being on the network + // thread is required, so setting a channel object as part of the construction + // (without thread hopping) might be the more efficient thing to do than + // how SetChannel works today. + // Similarly, if the channel() accessor is limited to the network thread, that + // helps with keeping the channel implementation requirements being met and + // avoids synchronization for accessing the pointer or network related state. + channel_manager_->network_thread()->Invoke(RTC_FROM_HERE, [&]() { + if (channel_) { + channel_->SetFirstPacketReceivedCallback(nullptr); + } - channel_ = channel; + channel_ = channel; - if (channel_) { - channel_->SignalFirstPacketReceived().connect( - this, &RtpTransceiver::OnFirstPacketReceived); - } + if (channel_) { + channel_->SetFirstPacketReceivedCallback( + [thread = thread_, flag = signaling_thread_safety_, this]() mutable { + thread->PostTask(ToQueuedTask( + std::move(flag), [this]() { OnFirstPacketReceived(); })); + }); + } + }); for (const auto& sender : senders_) { sender->internal()->SetMediaChannel(channel_ ? channel_->media_channel() : nullptr); } + RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(1); + for (const auto& receiver : receivers_) { if (!channel_) { receiver->internal()->Stop(); + } else { + receiver->internal()->SetMediaChannel(channel_->media_channel()); } - - receiver->internal()->SetMediaChannel(channel_ ? channel_->media_channel() - : nullptr); } } void RtpTransceiver::AddSender( rtc::scoped_refptr> sender) { + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(!stopped_); RTC_DCHECK(!unified_plan_); RTC_DCHECK(sender); @@ -203,6 +246,7 @@ bool RtpTransceiver::RemoveSender(RtpSenderInterface* sender) { void RtpTransceiver::AddReceiver( rtc::scoped_refptr> receiver) { + RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(!stopped_); RTC_DCHECK(!unified_plan_); RTC_DCHECK(receiver); @@ -220,12 +264,8 @@ bool RtpTransceiver::RemoveReceiver(RtpReceiverInterface* receiver) { if (it == receivers_.end()) { return false; } + // `Stop()` will clear the internally cached pointer to the media channel. (*it)->internal()->Stop(); - // After the receiver has been removed, there's no guarantee that the - // contained media channel isn't deleted shortly after this. To make sure that - // the receiver doesn't spontaneously try to use it's (potentially stale) - // media channel reference, we clear it out. 
- (*it)->internal()->SetMediaChannel(nullptr); receivers_.erase(it); return true; } @@ -251,7 +291,7 @@ absl::optional RtpTransceiver::mid() const { return mid_; } -void RtpTransceiver::OnFirstPacketReceived(cricket::ChannelInterface*) { +void RtpTransceiver::OnFirstPacketReceived() { for (const auto& receiver : receivers_) { receiver->internal()->NotifyFirstPacketReceived(); } @@ -288,6 +328,7 @@ void RtpTransceiver::set_fired_direction(RtpTransceiverDirection direction) { } bool RtpTransceiver::stopped() const { + RTC_DCHECK_RUN_ON(thread_); return stopped_; } @@ -388,6 +429,7 @@ RTCError RtpTransceiver::StopStandard() { } void RtpTransceiver::StopInternal() { + RTC_DCHECK_RUN_ON(thread_); StopTransceiverProcedure(); } @@ -459,10 +501,9 @@ RtpTransceiver::HeaderExtensionsToOffer() const { std::vector RtpTransceiver::HeaderExtensionsNegotiated() const { - if (!channel_) - return {}; + RTC_DCHECK_RUN_ON(thread_); std::vector result; - for (const auto& ext : channel_->GetNegotiatedRtpHeaderExtensions()) { + for (const auto& ext : negotiated_header_extensions_) { result.emplace_back(ext.uri, ext.id, RtpTransceiverDirection::kSendRecv); } return result; @@ -512,6 +553,15 @@ RTCError RtpTransceiver::SetOfferedRtpHeaderExtensions( return RTCError::OK(); } +void RtpTransceiver::OnNegotiationUpdate( + SdpType sdp_type, + const cricket::MediaContentDescription* content) { + RTC_DCHECK_RUN_ON(thread_); + RTC_DCHECK(content); + if (sdp_type == SdpType::kAnswer) + negotiated_header_extensions_ = content->rtp_header_extensions(); +} + void RtpTransceiver::SetPeerConnectionClosed() { is_pc_closed_ = true; } diff --git a/pc/rtp_transceiver.h b/pc/rtp_transceiver.h index 57dbaeea85..6b1307b1db 100644 --- a/pc/rtp_transceiver.h +++ b/pc/rtp_transceiver.h @@ -11,14 +11,33 @@ #ifndef PC_RTP_TRANSCEIVER_H_ #define PC_RTP_TRANSCEIVER_H_ +#include + +#include +#include #include #include +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/media_types.h" +#include "api/rtc_error.h" +#include "api/rtp_parameters.h" +#include "api/rtp_transceiver_direction.h" #include "api/rtp_transceiver_interface.h" +#include "api/scoped_refptr.h" +#include "api/task_queue/task_queue_base.h" #include "pc/channel_interface.h" #include "pc/channel_manager.h" +#include "pc/proxy.h" #include "pc/rtp_receiver.h" +#include "pc/rtp_receiver_proxy.h" #include "pc/rtp_sender.h" +#include "pc/rtp_sender_proxy.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { @@ -60,7 +79,8 @@ class RtpTransceiver final // channel set. // |media_type| specifies the type of RtpTransceiver (and, by transitivity, // the type of senders, receivers, and channel). Can either by audio or video. - explicit RtpTransceiver(cricket::MediaType media_type); + RtpTransceiver(cricket::MediaType media_type, + cricket::ChannelManager* channel_manager); // Construct a Unified Plan-style RtpTransceiver with the given sender and // receiver. The media type will be derived from the media types of the sender // and receiver. The sender and receiver should have the same media type. @@ -213,21 +233,32 @@ class RtpTransceiver final rtc::ArrayView header_extensions_to_offer) override; + // Called on the signaling thread when the local or remote content description + // is updated. Used to update the negotiated header extensions. 
+ // TODO(tommi): The implementation of this method is currently very simple and + // only used for updating the negotiated headers. However, we're planning to + // move all the updates done on the channel from the transceiver into this + // method. This will happen with the ownership of the channel object being + // moved into the transceiver. + void OnNegotiationUpdate(SdpType sdp_type, + const cricket::MediaContentDescription* content); + private: - void OnFirstPacketReceived(cricket::ChannelInterface* channel); + void OnFirstPacketReceived(); void StopSendingAndReceiving(); // Enforce that this object is created, used and destroyed on one thread. - const TaskQueueBase* thread_; + TaskQueueBase* const thread_; const bool unified_plan_; const cricket::MediaType media_type_; + rtc::scoped_refptr signaling_thread_safety_; std::vector>> senders_; std::vector< rtc::scoped_refptr>> receivers_; - bool stopped_ = false; + bool stopped_ RTC_GUARDED_BY(thread_) = false; bool stopping_ RTC_GUARDED_BY(thread_) = false; bool is_pc_closed_ = false; RtpTransceiverDirection direction_ = RtpTransceiverDirection::kInactive; @@ -243,11 +274,19 @@ class RtpTransceiver final cricket::ChannelManager* channel_manager_ = nullptr; std::vector codec_preferences_; std::vector header_extensions_to_offer_; + + // |negotiated_header_extensions_| is read and written to on the signaling + // thread from the SdpOfferAnswerHandler class (e.g. + // PushdownMediaDescription(). + cricket::RtpHeaderExtensions negotiated_header_extensions_ + RTC_GUARDED_BY(thread_); + const std::function on_negotiation_needed_; }; -BEGIN_SIGNALING_PROXY_MAP(RtpTransceiver) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +BEGIN_PRIMARY_PROXY_MAP(RtpTransceiver) + +PROXY_PRIMARY_THREAD_DESTRUCTOR() BYPASS_PROXY_CONSTMETHOD0(cricket::MediaType, media_type) PROXY_CONSTMETHOD0(absl::optional, mid) PROXY_CONSTMETHOD0(rtc::scoped_refptr, sender) @@ -271,7 +310,7 @@ PROXY_CONSTMETHOD0(std::vector, PROXY_METHOD1(webrtc::RTCError, SetOfferedRtpHeaderExtensions, rtc::ArrayView) -END_PROXY_MAP() +END_PROXY_MAP(RtpTransceiver) } // namespace webrtc diff --git a/pc/rtp_transceiver_unittest.cc b/pc/rtp_transceiver_unittest.cc index 7b72620305..0128e912e3 100644 --- a/pc/rtp_transceiver_unittest.cc +++ b/pc/rtp_transceiver_unittest.cc @@ -23,6 +23,7 @@ #include "test/gmock.h" #include "test/gtest.h" +using ::testing::_; using ::testing::ElementsAre; using ::testing::Optional; using ::testing::Property; @@ -33,13 +34,13 @@ namespace webrtc { // Checks that a channel cannot be set on a stopped |RtpTransceiver|. 
TEST(RtpTransceiverTest, CannotSetChannelOnStoppedTransceiver) { - RtpTransceiver transceiver(cricket::MediaType::MEDIA_TYPE_AUDIO); + auto cm = cricket::ChannelManager::Create( + nullptr, true, rtc::Thread::Current(), rtc::Thread::Current()); + RtpTransceiver transceiver(cricket::MediaType::MEDIA_TYPE_AUDIO, cm.get()); cricket::MockChannelInterface channel1; - sigslot::signal1 signal; EXPECT_CALL(channel1, media_type()) .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); - EXPECT_CALL(channel1, SignalFirstPacketReceived()) - .WillRepeatedly(ReturnRef(signal)); + EXPECT_CALL(channel1, SetFirstPacketReceivedCallback(_)); transceiver.SetChannel(&channel1); EXPECT_EQ(&channel1, transceiver.channel()); @@ -59,13 +60,14 @@ TEST(RtpTransceiverTest, CannotSetChannelOnStoppedTransceiver) { // Checks that a channel can be unset on a stopped |RtpTransceiver| TEST(RtpTransceiverTest, CanUnsetChannelOnStoppedTransceiver) { - RtpTransceiver transceiver(cricket::MediaType::MEDIA_TYPE_VIDEO); + auto cm = cricket::ChannelManager::Create( + nullptr, true, rtc::Thread::Current(), rtc::Thread::Current()); + RtpTransceiver transceiver(cricket::MediaType::MEDIA_TYPE_VIDEO, cm.get()); cricket::MockChannelInterface channel; - sigslot::signal1 signal; EXPECT_CALL(channel, media_type()) .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_VIDEO)); - EXPECT_CALL(channel, SignalFirstPacketReceived()) - .WillRepeatedly(ReturnRef(signal)); + EXPECT_CALL(channel, SetFirstPacketReceivedCallback(_)) + .WillRepeatedly(testing::Return()); transceiver.SetChannel(&channel); EXPECT_EQ(&channel, transceiver.channel()); @@ -82,26 +84,48 @@ TEST(RtpTransceiverTest, CanUnsetChannelOnStoppedTransceiver) { class RtpTransceiverUnifiedPlanTest : public ::testing::Test { public: RtpTransceiverUnifiedPlanTest() - : channel_manager_(std::make_unique(), - std::make_unique(), - rtc::Thread::Current(), - rtc::Thread::Current()), + : channel_manager_(cricket::ChannelManager::Create( + std::make_unique(), + false, + rtc::Thread::Current(), + rtc::Thread::Current())), transceiver_(RtpSenderProxyWithInternal::Create( rtc::Thread::Current(), - new rtc::RefCountedObject()), + sender_), RtpReceiverProxyWithInternal::Create( rtc::Thread::Current(), - new rtc::RefCountedObject()), - &channel_manager_, - channel_manager_.GetSupportedAudioRtpHeaderExtensions(), + rtc::Thread::Current(), + receiver_), + channel_manager_.get(), + channel_manager_->GetSupportedAudioRtpHeaderExtensions(), /* on_negotiation_needed= */ [] {}) {} - cricket::ChannelManager channel_manager_; + static rtc::scoped_refptr MockReceiver() { + auto receiver = rtc::make_ref_counted(); + EXPECT_CALL(*receiver.get(), media_type()) + .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); + return receiver; + } + + static rtc::scoped_refptr MockSender() { + auto sender = rtc::make_ref_counted(); + EXPECT_CALL(*sender.get(), media_type()) + .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); + return sender; + } + + rtc::scoped_refptr receiver_ = MockReceiver(); + rtc::scoped_refptr sender_ = MockSender(); + std::unique_ptr channel_manager_; RtpTransceiver transceiver_; }; // Basic tests for Stop() TEST_F(RtpTransceiverUnifiedPlanTest, StopSetsDirection) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + EXPECT_EQ(RtpTransceiverDirection::kInactive, transceiver_.direction()); EXPECT_FALSE(transceiver_.current_direction()); transceiver_.StopStandard(); @@ 
-117,10 +141,11 @@ TEST_F(RtpTransceiverUnifiedPlanTest, StopSetsDirection) { class RtpTransceiverTestForHeaderExtensions : public ::testing::Test { public: RtpTransceiverTestForHeaderExtensions() - : channel_manager_(std::make_unique(), - std::make_unique(), - rtc::Thread::Current(), - rtc::Thread::Current()), + : channel_manager_(cricket::ChannelManager::Create( + std::make_unique(), + false, + rtc::Thread::Current(), + rtc::Thread::Current())), extensions_( {RtpHeaderExtensionCapability("uri1", 1, @@ -136,24 +161,50 @@ class RtpTransceiverTestForHeaderExtensions : public ::testing::Test { RtpTransceiverDirection::kSendRecv)}), transceiver_(RtpSenderProxyWithInternal::Create( rtc::Thread::Current(), - new rtc::RefCountedObject()), + sender_), RtpReceiverProxyWithInternal::Create( rtc::Thread::Current(), - new rtc::RefCountedObject()), - &channel_manager_, + rtc::Thread::Current(), + receiver_), + channel_manager_.get(), extensions_, /* on_negotiation_needed= */ [] {}) {} - cricket::ChannelManager channel_manager_; + static rtc::scoped_refptr MockReceiver() { + auto receiver = rtc::make_ref_counted(); + EXPECT_CALL(*receiver.get(), media_type()) + .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); + return receiver; + } + + static rtc::scoped_refptr MockSender() { + auto sender = rtc::make_ref_counted(); + EXPECT_CALL(*sender.get(), media_type()) + .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); + return sender; + } + + rtc::scoped_refptr receiver_ = MockReceiver(); + rtc::scoped_refptr sender_ = MockSender(); + + std::unique_ptr channel_manager_; std::vector extensions_; RtpTransceiver transceiver_; }; TEST_F(RtpTransceiverTestForHeaderExtensions, OffersChannelManagerList) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + EXPECT_EQ(transceiver_.HeaderExtensionsToOffer(), extensions_); } TEST_F(RtpTransceiverTestForHeaderExtensions, ModifiesDirection) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + auto modified_extensions = extensions_; modified_extensions[0].direction = RtpTransceiverDirection::kSendOnly; EXPECT_TRUE( @@ -174,6 +225,10 @@ TEST_F(RtpTransceiverTestForHeaderExtensions, ModifiesDirection) { } TEST_F(RtpTransceiverTestForHeaderExtensions, AcceptsStoppedExtension) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + auto modified_extensions = extensions_; modified_extensions[0].direction = RtpTransceiverDirection::kStopped; EXPECT_TRUE( @@ -182,6 +237,10 @@ TEST_F(RtpTransceiverTestForHeaderExtensions, AcceptsStoppedExtension) { } TEST_F(RtpTransceiverTestForHeaderExtensions, RejectsUnsupportedExtension) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + std::vector modified_extensions( {RtpHeaderExtensionCapability("uri3", 1, RtpTransceiverDirection::kSendRecv)}); @@ -192,6 +251,10 @@ TEST_F(RtpTransceiverTestForHeaderExtensions, RejectsUnsupportedExtension) { TEST_F(RtpTransceiverTestForHeaderExtensions, RejectsStoppedMandatoryExtensions) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + std::vector modified_extensions = extensions_; // 
Attempting to stop the mandatory MID extension. modified_extensions[2].direction = RtpTransceiverDirection::kStopped; @@ -208,28 +271,47 @@ TEST_F(RtpTransceiverTestForHeaderExtensions, TEST_F(RtpTransceiverTestForHeaderExtensions, NoNegotiatedHdrExtsWithoutChannel) { + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); EXPECT_THAT(transceiver_.HeaderExtensionsNegotiated(), ElementsAre()); } TEST_F(RtpTransceiverTestForHeaderExtensions, NoNegotiatedHdrExtsWithChannelWithoutNegotiation) { + EXPECT_CALL(*receiver_.get(), SetMediaChannel(_)); + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetMediaChannel(_)); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); cricket::MockChannelInterface mock_channel; - sigslot::signal1 signal; - ON_CALL(mock_channel, SignalFirstPacketReceived) - .WillByDefault(ReturnRef(signal)); + EXPECT_CALL(mock_channel, SetFirstPacketReceivedCallback(_)); + EXPECT_CALL(mock_channel, media_type()) + .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); + EXPECT_CALL(mock_channel, media_channel()).WillRepeatedly(Return(nullptr)); transceiver_.SetChannel(&mock_channel); EXPECT_THAT(transceiver_.HeaderExtensionsNegotiated(), ElementsAre()); } TEST_F(RtpTransceiverTestForHeaderExtensions, ReturnsNegotiatedHdrExts) { + EXPECT_CALL(*receiver_.get(), SetMediaChannel(_)); + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetMediaChannel(_)); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + cricket::MockChannelInterface mock_channel; - sigslot::signal1 signal; - ON_CALL(mock_channel, SignalFirstPacketReceived) - .WillByDefault(ReturnRef(signal)); + EXPECT_CALL(mock_channel, SetFirstPacketReceivedCallback(_)); + EXPECT_CALL(mock_channel, media_type()) + .WillRepeatedly(Return(cricket::MediaType::MEDIA_TYPE_AUDIO)); + EXPECT_CALL(mock_channel, media_channel()).WillRepeatedly(Return(nullptr)); + cricket::RtpHeaderExtensions extensions = {webrtc::RtpExtension("uri1", 1), webrtc::RtpExtension("uri2", 2)}; - EXPECT_CALL(mock_channel, GetNegotiatedRtpHeaderExtensions) - .WillOnce(Return(extensions)); + cricket::AudioContentDescription description; + description.set_rtp_header_extensions(extensions); + transceiver_.OnNegotiationUpdate(SdpType::kAnswer, &description); + transceiver_.SetChannel(&mock_channel); EXPECT_THAT(transceiver_.HeaderExtensionsNegotiated(), ElementsAre(RtpHeaderExtensionCapability( @@ -240,23 +322,27 @@ TEST_F(RtpTransceiverTestForHeaderExtensions, ReturnsNegotiatedHdrExts) { TEST_F(RtpTransceiverTestForHeaderExtensions, ReturnsNegotiatedHdrExtsSecondTime) { - cricket::MockChannelInterface mock_channel; - sigslot::signal1 signal; - ON_CALL(mock_channel, SignalFirstPacketReceived) - .WillByDefault(ReturnRef(signal)); + EXPECT_CALL(*receiver_.get(), StopAndEndTrack()); + EXPECT_CALL(*sender_.get(), SetTransceiverAsStopped()); + EXPECT_CALL(*sender_.get(), Stop()); + cricket::RtpHeaderExtensions extensions = {webrtc::RtpExtension("uri1", 1), webrtc::RtpExtension("uri2", 2)}; + cricket::AudioContentDescription description; + description.set_rtp_header_extensions(extensions); + transceiver_.OnNegotiationUpdate(SdpType::kAnswer, &description); - EXPECT_CALL(mock_channel, GetNegotiatedRtpHeaderExtensions) - .WillOnce(Return(extensions)); - transceiver_.SetChannel(&mock_channel); - 
transceiver_.HeaderExtensionsNegotiated(); - testing::Mock::VerifyAndClearExpectations(&mock_channel); + EXPECT_THAT(transceiver_.HeaderExtensionsNegotiated(), + ElementsAre(RtpHeaderExtensionCapability( + "uri1", 1, RtpTransceiverDirection::kSendRecv), + RtpHeaderExtensionCapability( + "uri2", 2, RtpTransceiverDirection::kSendRecv))); extensions = {webrtc::RtpExtension("uri3", 4), webrtc::RtpExtension("uri5", 6)}; - EXPECT_CALL(mock_channel, GetNegotiatedRtpHeaderExtensions) - .WillOnce(Return(extensions)); + description.set_rtp_header_extensions(extensions); + transceiver_.OnNegotiationUpdate(SdpType::kAnswer, &description); + EXPECT_THAT(transceiver_.HeaderExtensionsNegotiated(), ElementsAre(RtpHeaderExtensionCapability( "uri3", 4, RtpTransceiverDirection::kSendRecv), diff --git a/pc/rtp_transmission_manager.cc b/pc/rtp_transmission_manager.cc index e796f9b1b1..9040a69699 100644 --- a/pc/rtp_transmission_manager.cc +++ b/pc/rtp_transmission_manager.cc @@ -11,6 +11,7 @@ #include "pc/rtp_transmission_manager.h" #include +#include #include "absl/types/optional.h" #include "api/peer_connection_interface.h" @@ -240,14 +241,17 @@ RtpTransmissionManager::CreateReceiver(cricket::MediaType media_type, receiver; if (media_type == cricket::MEDIA_TYPE_AUDIO) { receiver = RtpReceiverProxyWithInternal::Create( - signaling_thread(), new AudioRtpReceiver(worker_thread(), receiver_id, - std::vector({}))); + signaling_thread(), worker_thread(), + rtc::make_ref_counted(worker_thread(), receiver_id, + std::vector({}), + IsUnifiedPlan())); NoteUsageEvent(UsageEvent::AUDIO_ADDED); } else { RTC_DCHECK_EQ(media_type, cricket::MEDIA_TYPE_VIDEO); receiver = RtpReceiverProxyWithInternal::Create( - signaling_thread(), new VideoRtpReceiver(worker_thread(), receiver_id, - std::vector({}))); + signaling_thread(), worker_thread(), + rtc::make_ref_counted(worker_thread(), receiver_id, + std::vector({}))); NoteUsageEvent(UsageEvent::VIDEO_ADDED); } return receiver; @@ -452,8 +456,8 @@ void RtpTransmissionManager::CreateAudioReceiver( streams.push_back(rtc::scoped_refptr(stream)); // TODO(https://crbug.com/webrtc/9480): When we remove remote_streams(), use // the constructor taking stream IDs instead. - auto* audio_receiver = new AudioRtpReceiver( - worker_thread(), remote_sender_info.sender_id, streams); + auto audio_receiver = rtc::make_ref_counted( + worker_thread(), remote_sender_info.sender_id, streams, IsUnifiedPlan()); audio_receiver->SetMediaChannel(voice_media_channel()); if (remote_sender_info.sender_id == kDefaultAudioSenderId) { audio_receiver->SetupUnsignaledMediaChannel(); @@ -461,7 +465,7 @@ void RtpTransmissionManager::CreateAudioReceiver( audio_receiver->SetupMediaChannel(remote_sender_info.first_ssrc); } auto receiver = RtpReceiverProxyWithInternal::Create( - signaling_thread(), audio_receiver); + signaling_thread(), worker_thread(), std::move(audio_receiver)); GetAudioTransceiver()->internal()->AddReceiver(receiver); Observer()->OnAddTrack(receiver, streams); NoteUsageEvent(UsageEvent::AUDIO_ADDED); @@ -475,7 +479,7 @@ void RtpTransmissionManager::CreateVideoReceiver( streams.push_back(rtc::scoped_refptr(stream)); // TODO(https://crbug.com/webrtc/9480): When we remove remote_streams(), use // the constructor taking stream IDs instead. 
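Aside on an allocation pattern that recurs throughout this change, including the receiver construction that follows: explicit `new rtc::RefCountedObject<T>(...)` is replaced by `rtc::make_ref_counted<T>(...)`, which hides the wrapper type and returns a scoped_refptr directly. A small sketch with a hypothetical Widget type, not taken from the patch; header locations follow the includes already used in this change.

#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/ref_counted_object.h"

// Hypothetical ref-counted interface, for illustration only.
class Widget : public rtc::RefCountInterface {
 public:
  explicit Widget(int id) : id_(id) {}
  int id() const { return id_; }

 private:
  const int id_;
};

void Example() {
  // Old style: spell out the RefCountedObject wrapper.
  rtc::scoped_refptr<Widget> a(new rtc::RefCountedObject<Widget>(1));
  // New style used in this change: same object graph, less boilerplate.
  rtc::scoped_refptr<Widget> b = rtc::make_ref_counted<Widget>(2);
}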
- auto* video_receiver = new VideoRtpReceiver( + auto video_receiver = rtc::make_ref_counted( worker_thread(), remote_sender_info.sender_id, streams); video_receiver->SetMediaChannel(video_media_channel()); if (remote_sender_info.sender_id == kDefaultVideoSenderId) { @@ -484,7 +488,7 @@ void RtpTransmissionManager::CreateVideoReceiver( video_receiver->SetupMediaChannel(remote_sender_info.first_ssrc); } auto receiver = RtpReceiverProxyWithInternal::Create( - signaling_thread(), video_receiver); + signaling_thread(), worker_thread(), std::move(video_receiver)); GetVideoTransceiver()->internal()->AddReceiver(receiver); Observer()->OnAddTrack(receiver, streams); NoteUsageEvent(UsageEvent::VIDEO_ADDED); diff --git a/pc/rtp_transmission_manager.h b/pc/rtp_transmission_manager.h index 731c3b74dd..fe0e3abdd3 100644 --- a/pc/rtp_transmission_manager.h +++ b/pc/rtp_transmission_manager.h @@ -12,6 +12,7 @@ #define PC_RTP_TRANSMISSION_MANAGER_H_ #include + #include #include #include @@ -24,6 +25,7 @@ #include "api/rtp_receiver_interface.h" #include "api/rtp_sender_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "media/base/media_channel.h" #include "pc/channel_manager.h" #include "pc/rtp_receiver.h" @@ -32,10 +34,10 @@ #include "pc/stats_collector_interface.h" #include "pc/transceiver_list.h" #include "pc/usage_pattern.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" +#include "rtc_base/weak_ptr.h" namespace rtc { class Thread; diff --git a/pc/rtp_transport.cc b/pc/rtp_transport.cc index fe7357fc94..d4edb9501c 100644 --- a/pc/rtp_transport.cc +++ b/pc/rtp_transport.cc @@ -11,12 +11,11 @@ #include "pc/rtp_transport.h" #include - #include #include -#include "api/rtp_headers.h" -#include "api/rtp_parameters.h" +#include "absl/strings/string_view.h" +#include "api/array_view.h" #include "media/base/rtp_utils.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "rtc_base/checks.h" @@ -182,16 +181,16 @@ bool RtpTransport::UnregisterRtpDemuxerSink(RtpPacketSinkInterface* sink) { void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { - webrtc::RtpPacketReceived parsed_packet(&header_extension_map_); + webrtc::RtpPacketReceived parsed_packet( + &header_extension_map_, packet_time_us == -1 + ? Timestamp::MinusInfinity() + : Timestamp::Micros(packet_time_us)); if (!parsed_packet.Parse(std::move(packet))) { RTC_LOG(LS_ERROR) << "Failed to parse the incoming RTP packet before demuxing. 
Drop it."; return; } - if (packet_time_us != -1) { - parsed_packet.set_arrival_time_ms((packet_time_us + 500) / 1000); - } if (!rtp_demuxer_.OnRtpPacket(parsed_packet)) { RTC_LOG(LS_WARNING) << "Failed to demux RTP packet: " << RtpDemuxer::DescribePacket(parsed_packet); diff --git a/pc/rtp_transport.h b/pc/rtp_transport.h index 57ad9e5fd0..893d91e734 100644 --- a/pc/rtp_transport.h +++ b/pc/rtp_transport.h @@ -11,11 +11,22 @@ #ifndef PC_RTP_TRANSPORT_H_ #define PC_RTP_TRANSPORT_H_ +#include +#include + #include +#include "absl/types/optional.h" #include "call/rtp_demuxer.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" +#include "p2p/base/packet_transport_internal.h" #include "pc/rtp_transport_internal.h" +#include "pc/session_description.h" +#include "rtc_base/async_packet_socket.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/network/sent_packet.h" +#include "rtc_base/network_route.h" +#include "rtc_base/socket.h" #include "rtc_base/third_party/sigslot/sigslot.h" namespace rtc { diff --git a/pc/scenario_tests/goog_cc_test.cc b/pc/scenario_tests/goog_cc_test.cc index 4a996b8684..d9e27e2edf 100644 --- a/pc/scenario_tests/goog_cc_test.cc +++ b/pc/scenario_tests/goog_cc_test.cc @@ -73,8 +73,8 @@ TEST(GoogCcPeerScenarioTest, MAYBE_NoBweChangeFromVideoUnmute) { ASSERT_EQ(num_video_streams, 1); // Exactly 1 video stream. auto get_bwe = [&] { - rtc::scoped_refptr callback( - new rtc::RefCountedObject()); + auto callback = + rtc::make_ref_counted(); caller->pc()->GetStats(callback); s.net()->time_controller()->Wait([&] { return callback->called(); }); auto stats = diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index c4357a8da6..0e4ef7de88 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -10,17 +10,19 @@ #include "pc/sctp_data_channel.h" +#include #include #include #include -#include "api/proxy.h" #include "media/sctp/sctp_transport_internal.h" +#include "pc/proxy.h" #include "pc/sctp_utils.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" +#include "rtc_base/system/unused.h" #include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" @@ -38,8 +40,8 @@ int GenerateUniqueId() { } // Define proxy for DataChannelInterface. -BEGIN_SIGNALING_PROXY_MAP(DataChannel) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +BEGIN_PRIMARY_PROXY_MAP(DataChannel) +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_METHOD1(void, RegisterObserver, DataChannelObserver*) PROXY_METHOD0(void, UnregisterObserver) BYPASS_PROXY_CONSTMETHOD0(std::string, label) @@ -64,7 +66,7 @@ PROXY_CONSTMETHOD0(uint64_t, buffered_amount) PROXY_METHOD0(void, Close) // TODO(bugs.webrtc.org/11547): Change to run on the network thread. PROXY_METHOD1(bool, Send, const DataBuffer&) -END_PROXY_MAP() +END_PROXY_MAP(DataChannel) } // namespace @@ -78,17 +80,27 @@ InternalDataChannelInit::InternalDataChannelInit(const DataChannelInit& base) // Specified in createDataChannel, WebRTC spec section 6.1 bullet 13. id = -1; } - // Backwards compatibility: If base.maxRetransmits or base.maxRetransmitTime - // have been set to -1, unset them. - if (maxRetransmits && *maxRetransmits == -1) { - RTC_LOG(LS_ERROR) - << "Accepting maxRetransmits = -1 for backwards compatibility"; - maxRetransmits = absl::nullopt; + // Backwards compatibility: If maxRetransmits or maxRetransmitTime + // are negative, the feature is not enabled. + // Values are clamped to a 16bit range. 
+ if (maxRetransmits) { + if (*maxRetransmits < 0) { + RTC_LOG(LS_ERROR) + << "Accepting maxRetransmits < 0 for backwards compatibility"; + maxRetransmits = absl::nullopt; + } else if (*maxRetransmits > std::numeric_limits<uint16_t>::max()) { + maxRetransmits = std::numeric_limits<uint16_t>::max(); + } } - if (maxRetransmitTime && *maxRetransmitTime == -1) { - RTC_LOG(LS_ERROR) - << "Accepting maxRetransmitTime = -1 for backwards compatibility"; - maxRetransmitTime = absl::nullopt; + + if (maxRetransmitTime) { + if (*maxRetransmitTime < 0) { + RTC_LOG(LS_ERROR) + << "Accepting maxRetransmitTime < 0 for backwards compatibility"; + maxRetransmitTime = absl::nullopt; + } else if (*maxRetransmitTime > std::numeric_limits<uint16_t>::max()) { + maxRetransmitTime = std::numeric_limits<uint16_t>::max(); + } } } @@ -135,9 +147,8 @@ rtc::scoped_refptr<SctpDataChannel> SctpDataChannel::Create( const InternalDataChannelInit& config, rtc::Thread* signaling_thread, rtc::Thread* network_thread) { - rtc::scoped_refptr<SctpDataChannel> channel( - new rtc::RefCountedObject<SctpDataChannel>( - config, provider, label, signaling_thread, network_thread)); + auto channel = rtc::make_ref_counted<SctpDataChannel>( + config, provider, label, signaling_thread, network_thread); if (!channel->Init()) { return nullptr; } @@ -168,6 +179,7 @@ SctpDataChannel::SctpDataChannel(const InternalDataChannelInit& config, observer_(nullptr), provider_(provider) { RTC_DCHECK_RUN_ON(signaling_thread_); + RTC_UNUSED(network_thread_); } bool SctpDataChannel::Init() { @@ -294,13 +306,6 @@ bool SctpDataChannel::Send(const DataBuffer& buffer) { return false; } - // TODO(jiayl): the spec is unclear about if the remote side should get the - // onmessage event. We need to figure out the expected behavior and change the - // code accordingly. - if (buffer.size() == 0) { - return true; - } - buffered_amount_ += buffer.size(); // If the queue is non-empty, we're waiting for SignalReadyToSend, @@ -378,13 +383,11 @@ void SctpDataChannel::OnTransportChannelCreated() { } } -void SctpDataChannel::OnTransportChannelClosed() { - // The SctpTransport is unusable (for example, because the SCTP m= section - // was rejected, or because the DTLS transport closed), so we need to close - // abruptly. - RTCError error = RTCError(RTCErrorType::OPERATION_ERROR_WITH_DATA, - "Transport channel closed"); - error.set_error_detail(RTCErrorDetailType::SCTP_FAILURE); +void SctpDataChannel::OnTransportChannelClosed(RTCError error) { + // The SctpTransport is unusable, which could be due to multiple reasons: + // - the SCTP m= section was rejected + // - the DTLS transport is closed + // - the SCTP transport is closed CloseAbruptlyWithError(std::move(error)); } @@ -403,7 +406,7 @@ void SctpDataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, return; } - if (params.type == cricket::DMT_CONTROL) { + if (params.type == DataMessageType::kControl) { if (handshake_state_ != kHandshakeWaitingForAck) { // Ignore it if we are not expecting an ACK message.
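Stepping back to the InternalDataChannelInit change at the top of this file: the new constructor keeps the legacy "negative means disabled" behaviour but additionally clamps values to the 16-bit range. A small sketch of the resulting mapping, not part of the patch; it assumes DataChannelInit from api/data_channel_interface.h and InternalDataChannelInit as declared for this patch in pc/sctp_data_channel.h.

#include "api/data_channel_interface.h"
#include "pc/sctp_data_channel.h"
#include "rtc_base/checks.h"

void ClampingExamples() {
  webrtc::DataChannelInit base;

  base.maxRetransmits = -1;  // legacy "disabled" value: ends up unset
  RTC_CHECK(!webrtc::InternalDataChannelInit(base).maxRetransmits.has_value());

  base.maxRetransmits = 70000;  // above the 16-bit range: clamped
  RTC_CHECK_EQ(*webrtc::InternalDataChannelInit(base).maxRetransmits, 65535);

  base.maxRetransmits = 5;  // in range: passed through unchanged
  RTC_CHECK_EQ(*webrtc::InternalDataChannelInit(base).maxRetransmits, 5);
}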
RTC_LOG(LS_WARNING) @@ -424,8 +427,8 @@ void SctpDataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, return; } - RTC_DCHECK(params.type == cricket::DMT_BINARY || - params.type == cricket::DMT_TEXT); + RTC_DCHECK(params.type == DataMessageType::kBinary || + params.type == DataMessageType::kText); RTC_LOG(LS_VERBOSE) << "DataChannel received DATA message, sid = " << params.sid; @@ -436,7 +439,7 @@ void SctpDataChannel::OnDataReceived(const cricket::ReceiveDataParams& params, handshake_state_ = kHandshakeReady; } - bool binary = (params.type == cricket::DMT_BINARY); + bool binary = (params.type == webrtc::DataMessageType::kBinary); auto buffer = std::make_unique(payload, binary); if (state_ == kOpen && observer_) { ++messages_received_; @@ -617,7 +620,7 @@ void SctpDataChannel::SendQueuedDataMessages() { bool SctpDataChannel::SendDataMessage(const DataBuffer& buffer, bool queue_if_blocked) { RTC_DCHECK_RUN_ON(signaling_thread_); - cricket::SendDataParams send_params; + SendDataParams send_params; send_params.ordered = config_.ordered; // Send as ordered if it is still going through OPEN/ACK signaling. @@ -628,15 +631,14 @@ bool SctpDataChannel::SendDataMessage(const DataBuffer& buffer, "because the OPEN_ACK message has not been received."; } - send_params.max_rtx_count = - config_.maxRetransmits ? *config_.maxRetransmits : -1; - send_params.max_rtx_ms = - config_.maxRetransmitTime ? *config_.maxRetransmitTime : -1; - send_params.sid = config_.id; - send_params.type = buffer.binary ? cricket::DMT_BINARY : cricket::DMT_TEXT; + send_params.max_rtx_count = config_.maxRetransmits; + send_params.max_rtx_ms = config_.maxRetransmitTime; + send_params.type = + buffer.binary ? DataMessageType::kBinary : DataMessageType::kText; cricket::SendDataResult send_result = cricket::SDR_SUCCESS; - bool success = provider_->SendData(send_params, buffer.data, &send_result); + bool success = + provider_->SendData(config_.id, send_params, buffer.data, &send_result); if (success) { ++messages_sent_; @@ -702,16 +704,16 @@ bool SctpDataChannel::SendControlMessage(const rtc::CopyOnWriteBuffer& buffer) { bool is_open_message = handshake_state_ == kHandshakeShouldSendOpen; RTC_DCHECK(!is_open_message || !config_.negotiated); - cricket::SendDataParams send_params; - send_params.sid = config_.id; + SendDataParams send_params; // Send data as ordered before we receive any message from the remote peer to // make sure the remote peer will not receive any data before it receives the // OPEN message. 
send_params.ordered = config_.ordered || is_open_message; - send_params.type = cricket::DMT_CONTROL; + send_params.type = DataMessageType::kControl; cricket::SendDataResult send_result = cricket::SDR_SUCCESS; - bool retval = provider_->SendData(send_params, buffer, &send_result); + bool retval = + provider_->SendData(config_.id, send_params, buffer, &send_result); if (retval) { RTC_LOG(LS_VERBOSE) << "Sent CONTROL message on channel " << config_.id; diff --git a/pc/sctp_data_channel.h b/pc/sctp_data_channel.h index 6d121e6f80..b0df48758b 100644 --- a/pc/sctp_data_channel.h +++ b/pc/sctp_data_channel.h @@ -11,18 +11,25 @@ #ifndef PC_SCTP_DATA_CHANNEL_H_ #define PC_SCTP_DATA_CHANNEL_H_ +#include + #include #include #include +#include "absl/types/optional.h" #include "api/data_channel_interface.h" #include "api/priority.h" +#include "api/rtc_error.h" #include "api/scoped_refptr.h" #include "api/transport/data_channel_transport_interface.h" #include "media/base/media_channel.h" #include "pc/data_channel_utils.h" +#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/ssl_stream_adapter.h" // For SSLRole #include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { @@ -33,7 +40,8 @@ class SctpDataChannel; class SctpDataChannelProviderInterface { public: // Sends the data to the transport. - virtual bool SendData(const cricket::SendDataParams& params, + virtual bool SendData(int sid, + const SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) = 0; // Connects to the transport signals. @@ -169,8 +177,6 @@ class SctpDataChannel : public DataChannelInterface, void CloseAbruptlyWithError(RTCError error); // Specializations of CloseAbruptlyWithError void CloseAbruptlyWithDataChannelFailure(const std::string& message); - void CloseAbruptlyWithSctpCauseCode(const std::string& message, - uint16_t cause_code); // Slots for provider to connect signals to. // @@ -201,7 +207,7 @@ class SctpDataChannel : public DataChannelInterface, // Called when the transport channel is unusable. // This method makes sure the DataChannel is disconnected and changes state // to kClosed. - void OnTransportChannelClosed(); + void OnTransportChannelClosed(RTCError error); DataChannelStats GetStats() const; diff --git a/pc/sctp_data_channel_transport.cc b/pc/sctp_data_channel_transport.cc index 497e11fcc9..f01f86ebd8 100644 --- a/pc/sctp_data_channel_transport.cc +++ b/pc/sctp_data_channel_transport.cc @@ -9,6 +9,8 @@ */ #include "pc/sctp_data_channel_transport.h" + +#include "absl/types/optional.h" #include "pc/sctp_utils.h" namespace webrtc { @@ -37,18 +39,8 @@ RTCError SctpDataChannelTransport::SendData( int channel_id, const SendDataParams& params, const rtc::CopyOnWriteBuffer& buffer) { - // Map webrtc::SendDataParams to cricket::SendDataParams. - // TODO(mellem): See about unifying these structs. 
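Note on the data-plane plumbing in this file and in sctp_data_channel.cc above: the unification the removed TODO asked for is what this patch does. The cricket::SendDataParams mapping goes away; callers fill in webrtc::SendDataParams directly and pass the SCTP stream id as a separate argument. A sketch of the new call shape against the SctpDataChannelProviderInterface signature shown earlier in this patch; the function and variable names here are illustrative only.

#include "absl/types/optional.h"
#include "api/transport/data_channel_transport_interface.h"  // webrtc::SendDataParams
#include "media/base/media_channel.h"                         // cricket::SendDataResult
#include "pc/sctp_data_channel.h"
#include "rtc_base/copy_on_write_buffer.h"

// `provider` implements webrtc::SctpDataChannelProviderInterface with the
// updated SendData() signature from pc/sctp_data_channel.h above.
bool SendText(webrtc::SctpDataChannelProviderInterface* provider,
              int sid,
              const rtc::CopyOnWriteBuffer& payload) {
  webrtc::SendDataParams params;
  params.type = webrtc::DataMessageType::kText;
  params.ordered = true;
  params.max_rtx_count = absl::nullopt;  // reliable: no retransmit limit
  params.max_rtx_ms = absl::nullopt;
  cricket::SendDataResult result = cricket::SDR_SUCCESS;
  return provider->SendData(sid, params, payload, &result);
}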
- cricket::SendDataParams sd_params; - sd_params.sid = channel_id; - sd_params.type = ToCricketDataMessageType(params.type); - sd_params.ordered = params.ordered; - sd_params.reliable = !(params.max_rtx_count || params.max_rtx_ms); - sd_params.max_rtx_count = params.max_rtx_count.value_or(-1); - sd_params.max_rtx_ms = params.max_rtx_ms.value_or(-1); - cricket::SendDataResult result; - sctp_transport_->SendData(sd_params, buffer, &result); + sctp_transport_->SendData(channel_id, params, buffer, &result); // TODO(mellem): See about changing the interfaces to not require mapping // SendDataResult to RTCError and back again. @@ -93,8 +85,7 @@ void SctpDataChannelTransport::OnDataReceived( const cricket::ReceiveDataParams& params, const rtc::CopyOnWriteBuffer& buffer) { if (sink_) { - sink_->OnDataReceived(params.sid, ToWebrtcDataMessageType(params.type), - buffer); + sink_->OnDataReceived(params.sid, params.type, buffer); } } @@ -111,9 +102,9 @@ void SctpDataChannelTransport::OnClosingProcedureComplete(int channel_id) { } } -void SctpDataChannelTransport::OnClosedAbruptly() { +void SctpDataChannelTransport::OnClosedAbruptly(RTCError error) { if (sink_) { - sink_->OnTransportClosed(); + sink_->OnTransportClosed(error); } } diff --git a/pc/sctp_data_channel_transport.h b/pc/sctp_data_channel_transport.h index 623a490053..4b89205ea1 100644 --- a/pc/sctp_data_channel_transport.h +++ b/pc/sctp_data_channel_transport.h @@ -11,8 +11,11 @@ #ifndef PC_SCTP_DATA_CHANNEL_TRANSPORT_H_ #define PC_SCTP_DATA_CHANNEL_TRANSPORT_H_ +#include "api/rtc_error.h" #include "api/transport/data_channel_transport_interface.h" +#include "media/base/media_channel.h" #include "media/sctp/sctp_transport_internal.h" +#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/third_party/sigslot/sigslot.h" namespace webrtc { @@ -38,7 +41,7 @@ class SctpDataChannelTransport : public DataChannelTransportInterface, const rtc::CopyOnWriteBuffer& buffer); void OnClosingProcedureStartedRemotely(int channel_id); void OnClosingProcedureComplete(int channel_id); - void OnClosedAbruptly(); + void OnClosedAbruptly(RTCError error); cricket::SctpTransportInternal* const sctp_transport_; diff --git a/pc/sctp_transport.cc b/pc/sctp_transport.cc index 9450469b8e..7d4e4551f1 100644 --- a/pc/sctp_transport.cc +++ b/pc/sctp_transport.cc @@ -13,7 +13,12 @@ #include #include -#include "rtc_base/bind.h" +#include "absl/types/optional.h" +#include "api/dtls_transport_interface.h" +#include "api/sequence_checker.h" +#include "rtc_base/checks.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" namespace webrtc { @@ -41,7 +46,15 @@ SctpTransport::~SctpTransport() { } SctpTransportInformation SctpTransport::Information() const { - MutexLock lock(&lock_); + // TODO(tommi): Update PeerConnection::GetSctpTransport to hand out a proxy + // to the transport so that we can be sure that methods get called on the + // expected thread. Chromium currently calls this method from + // TransceiverStateSurfacer. + if (!owner_thread_->IsCurrent()) { + return owner_thread_->Invoke( + RTC_FROM_HERE, [this] { return Information(); }); + } + RTC_DCHECK_RUN_ON(owner_thread_); return info_; } @@ -67,111 +80,91 @@ rtc::scoped_refptr SctpTransport::dtls_transport() void SctpTransport::Clear() { RTC_DCHECK_RUN_ON(owner_thread_); RTC_DCHECK(internal()); - { - MutexLock lock(&lock_); - // Note that we delete internal_sctp_transport_, but - // only drop the reference to dtls_transport_. 
- dtls_transport_ = nullptr; - internal_sctp_transport_ = nullptr; - } + // Note that we delete internal_sctp_transport_, but + // only drop the reference to dtls_transport_. + dtls_transport_ = nullptr; + internal_sctp_transport_ = nullptr; UpdateInformation(SctpTransportState::kClosed); } void SctpTransport::SetDtlsTransport( rtc::scoped_refptr transport) { RTC_DCHECK_RUN_ON(owner_thread_); - SctpTransportState next_state; - { - MutexLock lock(&lock_); - next_state = info_.state(); - dtls_transport_ = transport; - if (internal_sctp_transport_) { - if (transport) { - internal_sctp_transport_->SetDtlsTransport(transport->internal()); - transport->internal()->SignalDtlsState.connect( - this, &SctpTransport::OnDtlsStateChange); - if (info_.state() == SctpTransportState::kNew) { - next_state = SctpTransportState::kConnecting; - } - } else { - internal_sctp_transport_->SetDtlsTransport(nullptr); + SctpTransportState next_state = info_.state(); + dtls_transport_ = transport; + if (internal_sctp_transport_) { + if (transport) { + internal_sctp_transport_->SetDtlsTransport(transport->internal()); + + transport->internal()->SubscribeDtlsTransportState( + [this](cricket::DtlsTransportInternal* transport, + DtlsTransportState state) { + OnDtlsStateChange(transport, state); + }); + if (info_.state() == SctpTransportState::kNew) { + next_state = SctpTransportState::kConnecting; } + } else { + internal_sctp_transport_->SetDtlsTransport(nullptr); } } + UpdateInformation(next_state); } void SctpTransport::Start(int local_port, int remote_port, int max_message_size) { - { - MutexLock lock(&lock_); - // Record max message size on calling thread. - info_ = SctpTransportInformation(info_.state(), info_.dtls_transport(), - max_message_size, info_.MaxChannels()); - } - if (owner_thread_->IsCurrent()) { - if (!internal()->Start(local_port, remote_port, max_message_size)) { - RTC_LOG(LS_ERROR) << "Failed to push down SCTP parameters, closing."; - UpdateInformation(SctpTransportState::kClosed); - } - } else { - owner_thread_->Invoke( - RTC_FROM_HERE, rtc::Bind(&SctpTransport::Start, this, local_port, - remote_port, max_message_size)); + RTC_DCHECK_RUN_ON(owner_thread_); + info_ = SctpTransportInformation(info_.state(), info_.dtls_transport(), + max_message_size, info_.MaxChannels()); + + if (!internal()->Start(local_port, remote_port, max_message_size)) { + RTC_LOG(LS_ERROR) << "Failed to push down SCTP parameters, closing."; + UpdateInformation(SctpTransportState::kClosed); } } void SctpTransport::UpdateInformation(SctpTransportState state) { RTC_DCHECK_RUN_ON(owner_thread_); - bool must_send_update; - SctpTransportInformation info_copy(SctpTransportState::kNew); - { - MutexLock lock(&lock_); - must_send_update = (state != info_.state()); - // TODO(https://bugs.webrtc.org/10358): Update max channels from internal - // SCTP transport when available. - if (internal_sctp_transport_) { - info_ = SctpTransportInformation( - state, dtls_transport_, info_.MaxMessageSize(), info_.MaxChannels()); - } else { - info_ = SctpTransportInformation( - state, dtls_transport_, info_.MaxMessageSize(), info_.MaxChannels()); - } - if (observer_ && must_send_update) { - info_copy = info_; - } + bool must_send_update = (state != info_.state()); + // TODO(https://bugs.webrtc.org/10358): Update max channels from internal + // SCTP transport when available. 
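Note on the locking change running through pc/sctp_transport.cc: the mutex is removed and all state is now owned by owner_thread_ (the network thread in the PeerConnection case); the one accessor that may be reached from another thread, Information(), hops over with a blocking Invoke and returns a copy. A compact sketch of that idiom, not part of the patch, using a hypothetical Info snapshot type in place of SctpTransportInformation.

#include "api/sequence_checker.h"
#include "rtc_base/location.h"
#include "rtc_base/thread.h"
#include "rtc_base/thread_annotations.h"

struct Info {  // hypothetical copyable snapshot type
  int value = 0;
};

class OwnedOnNetworkThread {
 public:
  explicit OwnedOnNetworkThread(rtc::Thread* owner) : owner_thread_(owner) {}

  // Safe to call from any thread; off-thread callers block on an Invoke and
  // receive a copy taken on the owning thread.
  Info Information() const {
    if (!owner_thread_->IsCurrent()) {
      return owner_thread_->Invoke<Info>(RTC_FROM_HERE,
                                         [this] { return Information(); });
    }
    RTC_DCHECK_RUN_ON(owner_thread_);
    return info_;
  }

 private:
  rtc::Thread* const owner_thread_;
  Info info_ RTC_GUARDED_BY(owner_thread_);
};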
+ if (internal_sctp_transport_) { + info_ = SctpTransportInformation( + state, dtls_transport_, info_.MaxMessageSize(), info_.MaxChannels()); + } else { + info_ = SctpTransportInformation( + state, dtls_transport_, info_.MaxMessageSize(), info_.MaxChannels()); } - // We call the observer without holding the lock. + if (observer_ && must_send_update) { - observer_->OnStateChange(info_copy); + observer_->OnStateChange(info_); } } void SctpTransport::OnAssociationChangeCommunicationUp() { RTC_DCHECK_RUN_ON(owner_thread_); - { - MutexLock lock(&lock_); - RTC_DCHECK(internal_sctp_transport_); - if (internal_sctp_transport_->max_outbound_streams() && - internal_sctp_transport_->max_inbound_streams()) { - int max_channels = - std::min(*(internal_sctp_transport_->max_outbound_streams()), - *(internal_sctp_transport_->max_inbound_streams())); - // Record max channels. - info_ = SctpTransportInformation(info_.state(), info_.dtls_transport(), - info_.MaxMessageSize(), max_channels); - } + RTC_DCHECK(internal_sctp_transport_); + if (internal_sctp_transport_->max_outbound_streams() && + internal_sctp_transport_->max_inbound_streams()) { + int max_channels = + std::min(*(internal_sctp_transport_->max_outbound_streams()), + *(internal_sctp_transport_->max_inbound_streams())); + // Record max channels. + info_ = SctpTransportInformation(info_.state(), info_.dtls_transport(), + info_.MaxMessageSize(), max_channels); } + UpdateInformation(SctpTransportState::kConnected); } void SctpTransport::OnDtlsStateChange(cricket::DtlsTransportInternal* transport, - cricket::DtlsTransportState state) { + DtlsTransportState state) { RTC_DCHECK_RUN_ON(owner_thread_); RTC_CHECK(transport == dtls_transport_->internal()); - if (state == cricket::DTLS_TRANSPORT_CLOSED || - state == cricket::DTLS_TRANSPORT_FAILED) { + if (state == DtlsTransportState::kClosed || + state == DtlsTransportState::kFailed) { UpdateInformation(SctpTransportState::kClosed); // TODO(http://bugs.webrtc.org/11090): Close all the data channels } diff --git a/pc/sctp_transport.h b/pc/sctp_transport.h index a902ff02e8..87fde53d97 100644 --- a/pc/sctp_transport.h +++ b/pc/sctp_transport.h @@ -13,11 +13,15 @@ #include +#include "api/dtls_transport_interface.h" #include "api/scoped_refptr.h" #include "api/sctp_transport_interface.h" -#include "media/sctp/sctp_transport.h" +#include "media/sctp/sctp_transport_internal.h" +#include "p2p/base/dtls_transport_internal.h" #include "pc/dtls_transport.h" -#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { @@ -48,12 +52,12 @@ class SctpTransport : public SctpTransportInterface, // internal() to be functions on the webrtc::SctpTransport interface, // and make the internal() function private. 
cricket::SctpTransportInternal* internal() { - MutexLock lock(&lock_); + RTC_DCHECK_RUN_ON(owner_thread_); return internal_sctp_transport_.get(); } const cricket::SctpTransportInternal* internal() const { - MutexLock lock(&lock_); + RTC_DCHECK_RUN_ON(owner_thread_); return internal_sctp_transport_.get(); } @@ -67,17 +71,14 @@ class SctpTransport : public SctpTransportInterface, void OnInternalClosingProcedureStartedRemotely(int sid); void OnInternalClosingProcedureComplete(int sid); void OnDtlsStateChange(cricket::DtlsTransportInternal* transport, - cricket::DtlsTransportState state); + DtlsTransportState state); - // Note - owner_thread never changes, but can't be const if we do - // Invoke() on it. - rtc::Thread* owner_thread_; - mutable Mutex lock_; - // Variables accessible off-thread, guarded by lock_ - SctpTransportInformation info_ RTC_GUARDED_BY(lock_); + // NOTE: |owner_thread_| is the thread that the SctpTransport object is + // constructed on. In the context of PeerConnection, it's the network thread. + rtc::Thread* const owner_thread_; + SctpTransportInformation info_ RTC_GUARDED_BY(owner_thread_); std::unique_ptr internal_sctp_transport_ - RTC_GUARDED_BY(lock_); - // Variables only accessed on-thread + RTC_GUARDED_BY(owner_thread_); SctpTransportObserverInterface* observer_ RTC_GUARDED_BY(owner_thread_) = nullptr; rtc::scoped_refptr dtls_transport_ diff --git a/pc/sctp_transport_unittest.cc b/pc/sctp_transport_unittest.cc index f3070cd9a7..679b481f4c 100644 --- a/pc/sctp_transport_unittest.cc +++ b/pc/sctp_transport_unittest.cc @@ -14,6 +14,7 @@ #include #include "absl/memory/memory.h" +#include "api/dtls_transport_interface.h" #include "p2p/base/fake_dtls_transport.h" #include "pc/dtls_transport.h" #include "rtc_base/gunit.h" @@ -38,7 +39,8 @@ class FakeCricketSctpTransport : public cricket::SctpTransportInternal { } bool OpenStream(int sid) override { return true; } bool ResetStream(int sid) override { return true; } - bool SendData(const cricket::SendDataParams& params, + bool SendData(int sid, + const SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result = nullptr) override { return true; @@ -112,8 +114,8 @@ class SctpTransportTest : public ::testing::Test { void CreateTransport() { auto cricket_sctp_transport = absl::WrapUnique(new FakeCricketSctpTransport()); - transport_ = new rtc::RefCountedObject( - std::move(cricket_sctp_transport)); + transport_ = + rtc::make_ref_counted(std::move(cricket_sctp_transport)); } void AddDtlsTransport() { @@ -121,7 +123,7 @@ class SctpTransportTest : public ::testing::Test { std::make_unique( "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP); dtls_transport_ = - new rtc::RefCountedObject(std::move(cricket_transport)); + rtc::make_ref_counted(std::move(cricket_transport)); transport_->SetDtlsTransport(dtls_transport_); } @@ -147,7 +149,7 @@ TEST(SctpTransportSimpleTest, CreateClearDelete) { std::unique_ptr fake_cricket_sctp_transport = absl::WrapUnique(new FakeCricketSctpTransport()); rtc::scoped_refptr sctp_transport = - new rtc::RefCountedObject( + rtc::make_ref_counted( std::move(fake_cricket_sctp_transport)); ASSERT_TRUE(sctp_transport->internal()); ASSERT_EQ(SctpTransportState::kNew, sctp_transport->Information().state()); @@ -203,7 +205,7 @@ TEST_F(SctpTransportTest, CloseWhenTransportCloses) { ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(), kDefaultTimeout); static_cast(dtls_transport_->internal()) - ->SetDtlsState(cricket::DTLS_TRANSPORT_CLOSED); + 
->SetDtlsState(DtlsTransportState::kClosed); ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(), kDefaultTimeout); } diff --git a/pc/sctp_utils.cc b/pc/sctp_utils.cc index 1882a1525f..f7458405ea 100644 --- a/pc/sctp_utils.cc +++ b/pc/sctp_utils.cc @@ -13,8 +13,10 @@ #include #include +#include "absl/types/optional.h" #include "api/priority.h" #include "rtc_base/byte_buffer.h" +#include "rtc_base/checks.h" #include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/logging.h" @@ -228,33 +230,4 @@ void WriteDataChannelOpenAckMessage(rtc::CopyOnWriteBuffer* payload) { payload->SetData(&data, sizeof(data)); } -cricket::DataMessageType ToCricketDataMessageType(DataMessageType type) { - switch (type) { - case DataMessageType::kText: - return cricket::DMT_TEXT; - case DataMessageType::kBinary: - return cricket::DMT_BINARY; - case DataMessageType::kControl: - return cricket::DMT_CONTROL; - default: - return cricket::DMT_NONE; - } - return cricket::DMT_NONE; -} - -DataMessageType ToWebrtcDataMessageType(cricket::DataMessageType type) { - switch (type) { - case cricket::DMT_TEXT: - return DataMessageType::kText; - case cricket::DMT_BINARY: - return DataMessageType::kBinary; - case cricket::DMT_CONTROL: - return DataMessageType::kControl; - case cricket::DMT_NONE: - default: - RTC_NOTREACHED(); - } - return DataMessageType::kControl; -} - } // namespace webrtc diff --git a/pc/sctp_utils.h b/pc/sctp_utils.h index 339ef21163..da854458f4 100644 --- a/pc/sctp_utils.h +++ b/pc/sctp_utils.h @@ -16,6 +16,7 @@ #include "api/data_channel_interface.h" #include "api/transport/data_channel_transport_interface.h" #include "media/base/media_channel.h" +#include "rtc_base/copy_on_write_buffer.h" namespace rtc { class CopyOnWriteBuffer; @@ -39,10 +40,6 @@ bool WriteDataChannelOpenMessage(const std::string& label, void WriteDataChannelOpenAckMessage(rtc::CopyOnWriteBuffer* payload); -cricket::DataMessageType ToCricketDataMessageType(DataMessageType type); - -DataMessageType ToWebrtcDataMessageType(cricket::DataMessageType type); - } // namespace webrtc #endif // PC_SCTP_UTILS_H_ diff --git a/pc/sdp_offer_answer.cc b/pc/sdp_offer_answer.cc index f924c4060d..533bd84dbe 100644 --- a/pc/sdp_offer_answer.cc +++ b/pc/sdp_offer_answer.cc @@ -22,50 +22,45 @@ #include "absl/strings/string_view.h" #include "api/array_view.h" #include "api/crypto/crypto_options.h" -#include "api/data_channel_interface.h" #include "api/dtls_transport_interface.h" -#include "api/media_stream_proxy.h" #include "api/rtp_parameters.h" #include "api/rtp_receiver_interface.h" #include "api/rtp_sender_interface.h" -#include "api/uma_metrics.h" #include "api/video/builtin_video_bitrate_allocator_factory.h" #include "media/base/codec.h" #include "media/base/media_engine.h" #include "media/base/rid_description.h" +#include "p2p/base/ice_transport_internal.h" #include "p2p/base/p2p_constants.h" #include "p2p/base/p2p_transport_channel.h" #include "p2p/base/port.h" #include "p2p/base/transport_description.h" #include "p2p/base/transport_description_factory.h" #include "p2p/base/transport_info.h" -#include "pc/connection_context.h" #include "pc/data_channel_utils.h" -#include "pc/media_protocol_names.h" +#include "pc/dtls_transport.h" #include "pc/media_stream.h" +#include "pc/media_stream_proxy.h" #include "pc/peer_connection.h" #include "pc/peer_connection_message_handler.h" -#include "pc/rtp_data_channel.h" #include "pc/rtp_media_utils.h" #include "pc/rtp_sender.h" #include "pc/rtp_transport_internal.h" -#include "pc/sctp_transport.h" 
#include "pc/simulcast_description.h" #include "pc/stats_collector.h" #include "pc/usage_pattern.h" #include "pc/webrtc_session_description_factory.h" -#include "rtc_base/bind.h" #include "rtc_base/helpers.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" #include "rtc_base/rtc_certificate.h" -#include "rtc_base/socket_address.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/string_encode.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/trace_event.h" +#include "system_wrappers/include/field_trial.h" #include "system_wrappers/include/metrics.h" using cricket::ContentInfo; @@ -93,6 +88,9 @@ namespace { typedef webrtc::PeerConnectionInterface::RTCOfferAnswerOptions RTCOfferAnswerOptions; +constexpr const char* kAlwaysAllowPayloadTypeDemuxingFieldTrialName = + "WebRTC-AlwaysAllowPayloadTypeDemuxing"; + // Error messages const char kInvalidSdp[] = "Invalid session description."; const char kInvalidCandidates[] = "Description contains invalid candidates."; @@ -170,6 +168,19 @@ void NoteKeyProtocolAndMedia(KeyExchangeProtocolType protocol_type, } } +std::map GetBundleGroupsByMid( + const SessionDescription* desc) { + std::vector bundle_groups = + desc->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + std::map bundle_groups_by_mid; + for (const cricket::ContentGroup* bundle_group : bundle_groups) { + for (const std::string& content_name : bundle_group->content_names()) { + bundle_groups_by_mid[content_name] = bundle_group; + } + } + return bundle_groups_by_mid; +} + // Returns true if |new_desc| requests an ICE restart (i.e., new ufrag/pwd). bool CheckForRemoteIceRestart(const SessionDescriptionInterface* old_desc, const SessionDescriptionInterface* new_desc, @@ -253,7 +264,7 @@ void ReportSimulcastApiVersion(const char* name, } const ContentInfo* FindTransceiverMSection( - RtpTransceiverProxyWithInternal* transceiver, + RtpTransceiver* transceiver, const SessionDescriptionInterface* session_description) { return transceiver->mid() ? session_description->description()->GetContentByName( @@ -340,9 +351,10 @@ bool MediaSectionsHaveSameCount(const SessionDescription& desc1, // needs a ufrag and pwd. Mismatches, such as replying with a DTLS fingerprint // to SDES keys, will be caught in JsepTransport negotiation, and backstopped // by Channel's |srtp_required| check. -RTCError VerifyCrypto(const SessionDescription* desc, bool dtls_enabled) { - const cricket::ContentGroup* bundle = - desc->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); +RTCError VerifyCrypto(const SessionDescription* desc, + bool dtls_enabled, + const std::map& + bundle_groups_by_mid) { for (const cricket::ContentInfo& content_info : desc->contents()) { if (content_info.rejected) { continue; @@ -352,8 +364,10 @@ RTCError VerifyCrypto(const SessionDescription* desc, bool dtls_enabled) { : webrtc::kEnumCounterKeyProtocolSdes, content_info.media_description()->type()); const std::string& mid = content_info.name; - if (bundle && bundle->HasContentName(mid) && - mid != *(bundle->FirstContentName())) { + auto it = bundle_groups_by_mid.find(mid); + const cricket::ContentGroup* bundle = + it != bundle_groups_by_mid.end() ? it->second : nullptr; + if (bundle && mid != *(bundle->FirstContentName())) { // This isn't the first media section in the BUNDLE group, so it's not // required to have crypto attributes, since only the crypto attributes // from the first section actually get used. 
@@ -390,16 +404,19 @@ RTCError VerifyCrypto(const SessionDescription* desc, bool dtls_enabled) { // Checks that each non-rejected content has ice-ufrag and ice-pwd set, unless // it's in a BUNDLE group, in which case only the BUNDLE-tag section (first // media section/description in the BUNDLE group) needs a ufrag and pwd. -bool VerifyIceUfragPwdPresent(const SessionDescription* desc) { - const cricket::ContentGroup* bundle = - desc->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); +bool VerifyIceUfragPwdPresent( + const SessionDescription* desc, + const std::map& + bundle_groups_by_mid) { for (const cricket::ContentInfo& content_info : desc->contents()) { if (content_info.rejected) { continue; } const std::string& mid = content_info.name; - if (bundle && bundle->HasContentName(mid) && - mid != *(bundle->FirstContentName())) { + auto it = bundle_groups_by_mid.find(mid); + const cricket::ContentGroup* bundle = + it != bundle_groups_by_mid.end() ? it->second : nullptr; + if (bundle && mid != *(bundle->FirstContentName())) { // This isn't the first media section in the BUNDLE group, so it's not // required to have ufrag/password, since only the ufrag/password from // the first section actually get used. @@ -423,7 +440,7 @@ bool VerifyIceUfragPwdPresent(const SessionDescription* desc) { return true; } -static RTCError ValidateMids(const cricket::SessionDescription& description) { +RTCError ValidateMids(const cricket::SessionDescription& description) { std::set mids; for (const cricket::ContentInfo& content : description.contents()) { if (content.name.empty()) { @@ -475,7 +492,7 @@ std::string GetSignalingStateString( // This method will extract any send encodings that were sent by the remote // connection. This is currently only relevant for Simulcast scenario (where // the number of layers may be communicated by the server). -static std::vector GetSendEncodingsFromRemoteDescription( +std::vector GetSendEncodingsFromRemoteDescription( const MediaContentDescription& desc) { if (!desc.HasSimulcast()) { return {}; @@ -499,7 +516,7 @@ static std::vector GetSendEncodingsFromRemoteDescription( return result; } -static RTCError UpdateSimulcastLayerStatusInSender( +RTCError UpdateSimulcastLayerStatusInSender( const std::vector& layers, rtc::scoped_refptr sender) { RTC_DCHECK(sender); @@ -530,19 +547,22 @@ static RTCError UpdateSimulcastLayerStatusInSender( return result; } -static bool SimulcastIsRejected( - const ContentInfo* local_content, - const MediaContentDescription& answer_media_desc) { +bool SimulcastIsRejected(const ContentInfo* local_content, + const MediaContentDescription& answer_media_desc, + bool enable_encrypted_rtp_header_extensions) { bool simulcast_offered = local_content && local_content->media_description() && local_content->media_description()->HasSimulcast(); bool simulcast_answered = answer_media_desc.HasSimulcast(); bool rids_supported = RtpExtension::FindHeaderExtensionByUri( - answer_media_desc.rtp_header_extensions(), RtpExtension::kRidUri); + answer_media_desc.rtp_header_extensions(), RtpExtension::kRidUri, + enable_encrypted_rtp_header_extensions + ? 
RtpExtension::Filter::kPreferEncryptedExtension + : RtpExtension::Filter::kDiscardEncryptedExtension); return simulcast_offered && (!simulcast_answered || !rids_supported); } -static RTCError DisableSimulcastInSender( +RTCError DisableSimulcastInSender( rtc::scoped_refptr sender) { RTC_DCHECK(sender); RtpParameters parameters = sender->GetParametersInternal(); @@ -560,7 +580,7 @@ static RTCError DisableSimulcastInSender( // The SDP parser used to populate these values by default for the 'content // name' if an a=mid line was absent. -static absl::string_view GetDefaultMidForPlanB(cricket::MediaType media_type) { +absl::string_view GetDefaultMidForPlanB(cricket::MediaType media_type) { switch (media_type) { case cricket::MEDIA_TYPE_AUDIO: return cricket::CN_AUDIO; @@ -599,10 +619,8 @@ void AddPlanBRtpSenderOptions( } } -static cricket::MediaDescriptionOptions -GetMediaDescriptionOptionsForTransceiver( - rtc::scoped_refptr> - transceiver, +cricket::MediaDescriptionOptions GetMediaDescriptionOptionsForTransceiver( + RtpTransceiver* transceiver, const std::string& mid, bool is_create_offer) { // NOTE: a stopping transceiver should be treated as a stopped one in @@ -622,7 +640,7 @@ GetMediaDescriptionOptionsForTransceiver( // 2. If the MSID is included, then it must be included in any subsequent // offer/answer exactly the same until the RtpTransceiver is stopped. if (stopped || (!RtpTransceiverDirectionHasSend(transceiver->direction()) && - !transceiver->internal()->has_ever_been_used_to_send())) { + !transceiver->has_ever_been_used_to_send())) { return media_description_options; } @@ -633,7 +651,7 @@ GetMediaDescriptionOptionsForTransceiver( // The following sets up RIDs and Simulcast. // RIDs are included if Simulcast is requested or if any RID was specified. RtpParameters send_parameters = - transceiver->internal()->sender_internal()->GetParametersInternal(); + transceiver->sender_internal()->GetParametersInternal(); bool has_rids = std::any_of(send_parameters.encodings.begin(), send_parameters.encodings.end(), [](const RtpEncodingParameters& encoding) { @@ -665,9 +683,8 @@ GetMediaDescriptionOptionsForTransceiver( } // Returns the ContentInfo at mline index |i|, or null if none exists. -static const ContentInfo* GetContentByIndex( - const SessionDescriptionInterface* sdesc, - size_t i) { +const ContentInfo* GetContentByIndex(const SessionDescriptionInterface* sdesc, + size_t i) { if (!sdesc) { return nullptr; } @@ -696,27 +713,6 @@ std::string GenerateRtcpCname() { return cname; } -// Add options to |session_options| from |rtp_data_channels|. -void AddRtpDataChannelOptions( - const std::map>& - rtp_data_channels, - cricket::MediaDescriptionOptions* data_media_description_options) { - if (!data_media_description_options) { - return; - } - // Check for data channels. - for (const auto& kv : rtp_data_channels) { - const RtpDataChannel* channel = kv.second; - if (channel->state() == RtpDataChannel::kConnecting || - channel->state() == RtpDataChannel::kOpen) { - // Legacy RTP data channels are signaled with the track/stream ID set to - // the data channel's label. - data_media_description_options->AddRtpDataChannel(channel->label(), - channel->label()); - } - } -} - // Check if we can send |new_stream| on a PeerConnection. 
bool CanAddLocalMediaStream(webrtc::StreamCollectionInterface* current_streams, webrtc::MediaStreamInterface* new_stream) { @@ -731,6 +727,32 @@ bool CanAddLocalMediaStream(webrtc::StreamCollectionInterface* current_streams, return true; } +rtc::scoped_refptr LookupDtlsTransportByMid( + rtc::Thread* network_thread, + JsepTransportController* controller, + const std::string& mid) { + // TODO(tommi): Can we post this (and associated operations where this + // function is called) to the network thread and avoid this Invoke? + // We might be able to simplify a few things if we set the transport on + // the network thread and then update the implementation to check that + // the set_ and relevant get methods are always called on the network + // thread (we'll need to update proxy maps). + return network_thread->Invoke>( + RTC_FROM_HERE, + [controller, &mid] { return controller->LookupDtlsTransportByMid(mid); }); +} + +bool ContentHasHeaderExtension(const cricket::ContentInfo& content_info, + absl::string_view header_extension_uri) { + for (const RtpExtension& rtp_header_extension : + content_info.media_description()->rtp_header_extensions()) { + if (rtp_header_extension.uri == header_extension_uri) { + return true; + } + } + return false; +} + } // namespace // Used by parameterless SetLocalDescription() to create an offer or answer. @@ -1241,7 +1263,10 @@ void SdpOfferAnswerHandler::SetLocalDescription( } RTCError SdpOfferAnswerHandler::ApplyLocalDescription( - std::unique_ptr desc) { + std::unique_ptr desc, + const std::map& + bundle_groups_by_mid) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::ApplyLocalDescription"); RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(desc); @@ -1295,13 +1320,14 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( if (IsUnifiedPlan()) { RTCError error = UpdateTransceiversAndDataChannels( cricket::CS_LOCAL, *local_description(), old_local_description, - remote_description()); + remote_description(), bundle_groups_by_mid); if (!error.ok()) { return error; } std::vector> remove_list; std::vector> removed_streams; - for (const auto& transceiver : transceivers()->List()) { + for (const auto& transceiver_ext : transceivers()->List()) { + auto transceiver = transceiver_ext->internal(); if (transceiver->stopped()) { continue; } @@ -1310,12 +1336,10 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( // Note that code paths that don't set MID won't be able to use // information about DTLS transports. if (transceiver->mid()) { - auto dtls_transport = transport_controller()->LookupDtlsTransportByMid( - *transceiver->mid()); - transceiver->internal()->sender_internal()->set_transport( - dtls_transport); - transceiver->internal()->receiver_internal()->set_transport( - dtls_transport); + auto dtls_transport = LookupDtlsTransportByMid( + pc_->network_thread(), transport_controller(), *transceiver->mid()); + transceiver->sender_internal()->set_transport(dtls_transport); + transceiver->receiver_internal()->set_transport(dtls_transport); } const ContentInfo* content = @@ -1332,16 +1356,15 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( // "recvonly", process the removal of a remote track for the media // description, given transceiver, removeList, and muteTracks. 
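ContentHasHeaderExtension() introduced above is a plain URI scan over the negotiated RTP header extensions of one m= section. A self-contained sketch of the same check, with a minimal stand-in for webrtc::RtpExtension:

#include <string>
#include <vector>

struct HeaderExtension {  // stand-in for webrtc::RtpExtension
  std::string uri;
  int id = 0;
};

bool HasHeaderExtension(const std::vector<HeaderExtension>& extensions,
                        const std::string& uri) {
  for (const HeaderExtension& extension : extensions) {
    if (extension.uri == uri) {
      return true;
    }
  }
  return false;
}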
if (!RtpTransceiverDirectionHasRecv(media_desc->direction()) && - (transceiver->internal()->fired_direction() && - RtpTransceiverDirectionHasRecv( - *transceiver->internal()->fired_direction()))) { - ProcessRemovalOfRemoteTrack(transceiver, &remove_list, + (transceiver->fired_direction() && + RtpTransceiverDirectionHasRecv(*transceiver->fired_direction()))) { + ProcessRemovalOfRemoteTrack(transceiver_ext, &remove_list, &removed_streams); } // 2.2.7.1.6.2: Set transceiver's [[CurrentDirection]] and // [[FiredDirection]] slots to direction. - transceiver->internal()->set_current_direction(media_desc->direction()); - transceiver->internal()->set_fired_direction(media_desc->direction()); + transceiver->set_current_direction(media_desc->direction()); + transceiver->set_fired_direction(media_desc->direction()); } } auto observer = pc_->Observer(); @@ -1367,7 +1390,8 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( } error = UpdateSessionState(type, cricket::CS_LOCAL, - local_description()->description()); + local_description()->description(), + bundle_groups_by_mid); if (!error.ok()) { return error; } @@ -1385,12 +1409,15 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( // If setting the description decided our SSL role, allocate any necessary // SCTP sids. rtc::SSLRole role; - if (IsSctpLike(pc_->data_channel_type()) && pc_->GetSctpSslRole(&role)) { + if (pc_->GetSctpSslRole(&role)) { data_channel_controller()->AllocateSctpSids(role); } if (IsUnifiedPlan()) { - for (const auto& transceiver : transceivers()->List()) { + // We must use List and not ListInternal here because + // transceivers()->StableState() is indexed by the non-internal refptr. + for (const auto& transceiver_ext : transceivers()->List()) { + auto transceiver = transceiver_ext->internal(); if (transceiver->stopped()) { continue; } @@ -1399,20 +1426,24 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( if (!content) { continue; } - cricket::ChannelInterface* channel = transceiver->internal()->channel(); + cricket::ChannelInterface* channel = transceiver->channel(); if (content->rejected || !channel || channel->local_streams().empty()) { // 0 is a special value meaning "this sender has no associated send // stream". Need to call this so the sender won't attempt to configure // a no longer existing stream and run into DCHECKs in the lower // layers. - transceiver->internal()->sender_internal()->SetSsrc(0); + transceiver->sender_internal()->SetSsrc(0); } else { // Get the StreamParams from the channel which could generate SSRCs. const std::vector& streams = channel->local_streams(); - transceiver->internal()->sender_internal()->set_stream_ids( - streams[0].stream_ids()); - transceiver->internal()->sender_internal()->SetSsrc( - streams[0].first_ssrc()); + transceiver->sender_internal()->set_stream_ids(streams[0].stream_ids()); + auto encodings = transceiver->sender_internal()->init_send_encodings(); + transceiver->sender_internal()->SetSsrc(streams[0].first_ssrc()); + if (!encodings.empty()) { + transceivers() + ->StableState(transceiver_ext) + ->SetInitSendEncodings(encodings); + } } } } else { @@ -1445,17 +1476,7 @@ RTCError SdpOfferAnswerHandler::ApplyLocalDescription( } } - const cricket::ContentInfo* data_content = - GetFirstDataContent(local_description()->description()); - if (data_content) { - const cricket::RtpDataContentDescription* rtp_data_desc = - data_content->media_description()->as_rtp_data(); - // rtp_data_desc will be null if this is an SCTP description. 
- if (rtp_data_desc) { - data_channel_controller()->UpdateLocalRtpDataChannels( - rtp_data_desc->streams()); - } - } + // This function does nothing with data content. if (type == SdpType::kAnswer && local_ice_credentials_to_replace_->SatisfiesIceRestart( @@ -1532,7 +1553,10 @@ void SdpOfferAnswerHandler::SetRemoteDescription( } RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( - std::unique_ptr desc) { + std::unique_ptr desc, + const std::map& + bundle_groups_by_mid) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::ApplyRemoteDescription"); RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(desc); @@ -1576,7 +1600,7 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( if (IsUnifiedPlan()) { RTCError error = UpdateTransceiversAndDataChannels( cricket::CS_REMOTE, *remote_description(), local_description(), - old_remote_description); + old_remote_description, bundle_groups_by_mid); if (!error.ok()) { return error; } @@ -1598,7 +1622,8 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( // NOTE: Candidates allocation will be initiated only when // SetLocalDescription is called. error = UpdateSessionState(type, cricket::CS_REMOTE, - remote_description()->description()); + remote_description()->description(), + bundle_groups_by_mid); if (!error.ok()) { return error; } @@ -1657,7 +1682,7 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( // If setting the description decided our SSL role, allocate any necessary // SCTP sids. rtc::SSLRole role; - if (IsSctpLike(pc_->data_channel_type()) && pc_->GetSctpSslRole(&role)) { + if (pc_->GetSctpSslRole(&role)) { data_channel_controller()->AllocateSctpSids(role); } @@ -1667,7 +1692,8 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( std::vector> remove_list; std::vector> added_streams; std::vector> removed_streams; - for (const auto& transceiver : transceivers()->List()) { + for (const auto& transceiver_ext : transceivers()->List()) { + const auto transceiver = transceiver_ext->internal(); const ContentInfo* content = FindMediaSectionForTransceiver(transceiver, remote_description()); if (!content) { @@ -1687,14 +1713,13 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( stream_ids = media_desc->streams()[0].stream_ids(); } transceivers() - ->StableState(transceiver) + ->StableState(transceiver_ext) ->SetRemoteStreamIdsIfUnset(transceiver->receiver()->stream_ids()); RTC_LOG(LS_INFO) << "Processing the MSIDs for MID=" << content->name << " (" << GetStreamIdsString(stream_ids) << ")."; - SetAssociatedRemoteStreams(transceiver->internal()->receiver_internal(), - stream_ids, &added_streams, - &removed_streams); + SetAssociatedRemoteStreams(transceiver->receiver_internal(), stream_ids, + &added_streams, &removed_streams); // From the WebRTC specification, steps 2.2.8.5/6 of section 4.4.1.6 // "Set the RTCSessionDescription: If direction is sendrecv or recvonly, // and transceiver's current direction is neither sendrecv nor recvonly, @@ -1714,26 +1739,24 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( if (!RtpTransceiverDirectionHasRecv(local_direction) && (transceiver->fired_direction() && RtpTransceiverDirectionHasRecv(*transceiver->fired_direction()))) { - ProcessRemovalOfRemoteTrack(transceiver, &remove_list, + ProcessRemovalOfRemoteTrack(transceiver_ext, &remove_list, &removed_streams); } // 2.2.8.1.10: Set transceiver's [[FiredDirection]] slot to direction. 
- transceiver->internal()->set_fired_direction(local_direction); + transceiver->set_fired_direction(local_direction); // 2.2.8.1.11: If description is of type "answer" or "pranswer", then run // the following steps: if (type == SdpType::kPrAnswer || type == SdpType::kAnswer) { // 2.2.8.1.11.1: Set transceiver's [[CurrentDirection]] slot to // direction. - transceiver->internal()->set_current_direction(local_direction); + transceiver->set_current_direction(local_direction); // 2.2.8.1.11.[3-6]: Set the transport internal slots. if (transceiver->mid()) { - auto dtls_transport = - transport_controller()->LookupDtlsTransportByMid( - *transceiver->mid()); - transceiver->internal()->sender_internal()->set_transport( - dtls_transport); - transceiver->internal()->receiver_internal()->set_transport( - dtls_transport); + auto dtls_transport = LookupDtlsTransportByMid(pc_->network_thread(), + transport_controller(), + *transceiver->mid()); + transceiver->sender_internal()->set_transport(dtls_transport); + transceiver->receiver_internal()->set_transport(dtls_transport); } } // 2.2.8.1.12: If the media description is rejected, and transceiver is @@ -1741,18 +1764,16 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( if (content->rejected && !transceiver->stopped()) { RTC_LOG(LS_INFO) << "Stopping transceiver for MID=" << content->name << " since the media section was rejected."; - transceiver->internal()->StopTransceiverProcedure(); + transceiver->StopTransceiverProcedure(); } if (!content->rejected && RtpTransceiverDirectionHasRecv(local_direction)) { if (!media_desc->streams().empty() && media_desc->streams()[0].has_ssrcs()) { uint32_t ssrc = media_desc->streams()[0].first_ssrc(); - transceiver->internal()->receiver_internal()->SetupMediaChannel(ssrc); + transceiver->receiver_internal()->SetupMediaChannel(ssrc); } else { - transceiver->internal() - ->receiver_internal() - ->SetupUnsignaledMediaChannel(); + transceiver->receiver_internal()->SetupUnsignaledMediaChannel(); } } } @@ -1783,8 +1804,6 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( GetFirstAudioContentDescription(remote_description()->description()); const cricket::VideoContentDescription* video_desc = GetFirstVideoContentDescription(remote_description()->description()); - const cricket::RtpDataContentDescription* rtp_data_desc = - GetFirstRtpDataContentDescription(remote_description()->description()); // Check if the descriptions include streams, just in case the peer supports // MSID, but doesn't indicate so with "a=msid-semantic". @@ -1837,13 +1856,6 @@ RTCError SdpOfferAnswerHandler::ApplyRemoteDescription( } } - // If this is an RTP data transport, update the DataChannels with the - // information from the remote peer. - if (rtp_data_desc) { - data_channel_controller()->UpdateRemoteRtpDataChannels( - GetActiveStreams(rtp_data_desc)); - } - // Iterate new_streams and notify the observer about new MediaStreams. 
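Near the end of the remote-description transceiver loop above, the receiver is bound to a signaled SSRC when the media description carries one, and otherwise falls back to an unsignaled media channel. A small sketch of that decision, using a simplified stand-in for the cricket StreamParams list (MediaStreamDesc is not the real type):

#include <cstdint>
#include <optional>
#include <vector>

struct MediaStreamDesc {
  std::vector<uint32_t> ssrcs;
};

// Returns the SSRC to bind the receiver to, or std::nullopt for the
// unsignaled (SetupUnsignaledMediaChannel) path.
std::optional<uint32_t> SsrcForReceiverSetup(
    const std::vector<MediaStreamDesc>& streams) {
  if (!streams.empty() && !streams[0].ssrcs.empty()) {
    return streams[0].ssrcs.front();
  }
  return std::nullopt;
}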
auto observer = pc_->Observer(); for (size_t i = 0; i < new_streams->count(); ++i) { @@ -1904,7 +1916,10 @@ void SdpOfferAnswerHandler::DoSetLocalDescription( return; } - RTCError error = ValidateSessionDescription(desc.get(), cricket::CS_LOCAL); + std::map bundle_groups_by_mid = + GetBundleGroupsByMid(desc->description()); + RTCError error = ValidateSessionDescription(desc.get(), cricket::CS_LOCAL, + bundle_groups_by_mid); if (!error.ok()) { std::string error_message = GetSetDescriptionErrorMessage( cricket::CS_LOCAL, desc->GetType(), error); @@ -1918,7 +1933,7 @@ void SdpOfferAnswerHandler::DoSetLocalDescription( // which may destroy it before returning. const SdpType type = desc->GetType(); - error = ApplyLocalDescription(std::move(desc)); + error = ApplyLocalDescription(std::move(desc), bundle_groups_by_mid); // |desc| may be destroyed at this point. if (!error.ok()) { @@ -1941,8 +1956,7 @@ void SdpOfferAnswerHandler::DoSetLocalDescription( // TODO(deadbeef): We already had to hop to the network thread for // MaybeStartGathering... pc_->network_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&cricket::PortAllocator::DiscardCandidatePool, - port_allocator())); + RTC_FROM_HERE, [this] { port_allocator()->DiscardCandidatePool(); }); // Make UMA notes about what was agreed to. ReportNegotiatedSdpSemantics(*local_description()); } @@ -2028,6 +2042,7 @@ void SdpOfferAnswerHandler::DoCreateOffer( void SdpOfferAnswerHandler::CreateAnswer( CreateSessionDescriptionObserver* observer, const PeerConnectionInterface::RTCOfferAnswerOptions& options) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::CreateAnswer"); RTC_DCHECK_RUN_ON(signaling_thread()); // Chain this operation. If asynchronous operations are pending on the chain, // this operation will be queued to be invoked, otherwise the contents of the @@ -2158,13 +2173,17 @@ void SdpOfferAnswerHandler::DoSetRemoteDescription( desc->GetType() == SdpType::kAnswer) { // Report to UMA the format of the received offer or answer. pc_->ReportSdpFormatReceived(*desc); + pc_->ReportSdpBundleUsage(*desc); } // Handle remote descriptions missing a=mid lines for interop with legacy end // points. FillInMissingRemoteMids(desc->description()); - RTCError error = ValidateSessionDescription(desc.get(), cricket::CS_REMOTE); + std::map bundle_groups_by_mid = + GetBundleGroupsByMid(desc->description()); + RTCError error = ValidateSessionDescription(desc.get(), cricket::CS_REMOTE, + bundle_groups_by_mid); if (!error.ok()) { std::string error_message = GetSetDescriptionErrorMessage( cricket::CS_REMOTE, desc->GetType(), error); @@ -2178,7 +2197,7 @@ void SdpOfferAnswerHandler::DoSetRemoteDescription( // ApplyRemoteDescription, which may destroy it before returning. const SdpType type = desc->GetType(); - error = ApplyRemoteDescription(std::move(desc)); + error = ApplyRemoteDescription(std::move(desc), bundle_groups_by_mid); // |desc| may be destroyed at this point. if (!error.ok()) { @@ -2200,8 +2219,7 @@ void SdpOfferAnswerHandler::DoSetRemoteDescription( // TODO(deadbeef): We already had to hop to the network thread for // MaybeStartGathering... pc_->network_thread()->Invoke( - RTC_FROM_HERE, rtc::Bind(&cricket::PortAllocator::DiscardCandidatePool, - port_allocator())); + RTC_FROM_HERE, [this] { port_allocator()->DiscardCandidatePool(); }); // Make UMA notes about what was agreed to. 
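Several hunks above swap rtc::Bind for capturing lambdas when invoking work on another thread (DiscardCandidatePool here, GetPooledIceCredentials later). The following is only an illustration of the shape of that change, using std::bind and std::function as stand-ins for rtc::Bind and the callable handed to rtc::Thread::Invoke(); PortAllocatorLike is a made-up class, not cricket::PortAllocator:

#include <functional>
#include <iostream>

class PortAllocatorLike {  // hypothetical stand-in for cricket::PortAllocator
 public:
  void DiscardCandidatePool() { std::cout << "pool discarded\n"; }
};

int main() {
  PortAllocatorLike allocator;
  // Old style, roughly what rtc::Bind produced: a bound member call.
  std::function<void()> bound =
      std::bind(&PortAllocatorLike::DiscardCandidatePool, &allocator);
  // New style used in the diff: a capturing lambda.
  std::function<void()> lambda = [&allocator] {
    allocator.DiscardCandidatePool();
  };
  bound();
  lambda();
  return 0;
}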
ReportNegotiatedSdpSemantics(*remote_description()); } @@ -2268,60 +2286,64 @@ void SdpOfferAnswerHandler::SetAssociatedRemoteStreams( bool SdpOfferAnswerHandler::AddIceCandidate( const IceCandidateInterface* ice_candidate) { + const AddIceCandidateResult result = AddIceCandidateInternal(ice_candidate); + NoteAddIceCandidateResult(result); + // If the return value is kAddIceCandidateFailNotReady, the candidate has been + // added, although not 'ready', but that's a success. + return result == kAddIceCandidateSuccess || + result == kAddIceCandidateFailNotReady; +} + +AddIceCandidateResult SdpOfferAnswerHandler::AddIceCandidateInternal( + const IceCandidateInterface* ice_candidate) { RTC_DCHECK_RUN_ON(signaling_thread()); TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::AddIceCandidate"); if (pc_->IsClosed()) { RTC_LOG(LS_ERROR) << "AddIceCandidate: PeerConnection is closed."; - NoteAddIceCandidateResult(kAddIceCandidateFailClosed); - return false; + return kAddIceCandidateFailClosed; } if (!remote_description()) { RTC_LOG(LS_ERROR) << "AddIceCandidate: ICE candidates can't be added " "without any remote session description."; - NoteAddIceCandidateResult(kAddIceCandidateFailNoRemoteDescription); - return false; + return kAddIceCandidateFailNoRemoteDescription; } if (!ice_candidate) { RTC_LOG(LS_ERROR) << "AddIceCandidate: Candidate is null."; - NoteAddIceCandidateResult(kAddIceCandidateFailNullCandidate); - return false; + return kAddIceCandidateFailNullCandidate; } bool valid = false; bool ready = ReadyToUseRemoteCandidate(ice_candidate, nullptr, &valid); if (!valid) { - NoteAddIceCandidateResult(kAddIceCandidateFailNotValid); - return false; + return kAddIceCandidateFailNotValid; } // Add this candidate to the remote session description. if (!mutable_remote_description()->AddCandidate(ice_candidate)) { RTC_LOG(LS_ERROR) << "AddIceCandidate: Candidate cannot be used."; - NoteAddIceCandidateResult(kAddIceCandidateFailInAddition); - return false; + return kAddIceCandidateFailInAddition; } - if (ready) { - bool result = UseCandidate(ice_candidate); - if (result) { - pc_->NoteUsageEvent(UsageEvent::ADD_ICE_CANDIDATE_SUCCEEDED); - NoteAddIceCandidateResult(kAddIceCandidateSuccess); - } else { - NoteAddIceCandidateResult(kAddIceCandidateFailNotUsable); - } - return result; - } else { + if (!ready) { RTC_LOG(LS_INFO) << "AddIceCandidate: Not ready to use candidate."; - NoteAddIceCandidateResult(kAddIceCandidateFailNotReady); - return true; + return kAddIceCandidateFailNotReady; + } + + if (!UseCandidate(ice_candidate)) { + return kAddIceCandidateFailNotUsable; } + + pc_->NoteUsageEvent(UsageEvent::ADD_ICE_CANDIDATE_SUCCEEDED); + + return kAddIceCandidateSuccess; } void SdpOfferAnswerHandler::AddIceCandidate( std::unique_ptr candidate, std::function callback) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::AddIceCandidate"); RTC_DCHECK_RUN_ON(signaling_thread()); // Chain this operation. If asynchronous operations are pending on the chain, // this operation will be queued to be invoked, otherwise the contents of the @@ -2330,23 +2352,25 @@ void SdpOfferAnswerHandler::AddIceCandidate( [this_weak_ptr = weak_ptr_factory_.GetWeakPtr(), candidate = std::move(candidate), callback = std::move(callback)]( std::function operations_chain_callback) { - if (!this_weak_ptr) { - operations_chain_callback(); + auto result = + this_weak_ptr + ? 
this_weak_ptr->AddIceCandidateInternal(candidate.get()) + : kAddIceCandidateFailClosed; + NoteAddIceCandidateResult(result); + operations_chain_callback(); + if (result == kAddIceCandidateFailClosed) { callback(RTCError( RTCErrorType::INVALID_STATE, "AddIceCandidate failed because the session was shut down")); - return; - } - if (!this_weak_ptr->AddIceCandidate(candidate.get())) { - operations_chain_callback(); + } else if (result != kAddIceCandidateSuccess && + result != kAddIceCandidateFailNotReady) { // Fail with an error type and message consistent with Chromium. // TODO(hbos): Fail with error types according to spec. callback(RTCError(RTCErrorType::UNSUPPORTED_OPERATION, "Error processing ICE candidate")); - return; + } else { + callback(RTCError::OK()); } - operations_chain_callback(); - callback(RTCError::OK()); }); } @@ -2451,6 +2475,7 @@ PeerConnectionInterface::SignalingState SdpOfferAnswerHandler::signaling_state() void SdpOfferAnswerHandler::ChangeSignalingState( PeerConnectionInterface::SignalingState signaling_state) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::ChangeSignalingState"); RTC_DCHECK_RUN_ON(signaling_thread()); if (signaling_state_ == signaling_state) { return; @@ -2466,13 +2491,20 @@ void SdpOfferAnswerHandler::ChangeSignalingState( RTCError SdpOfferAnswerHandler::UpdateSessionState( SdpType type, cricket::ContentSource source, - const cricket::SessionDescription* description) { + const cricket::SessionDescription* description, + const std::map& + bundle_groups_by_mid) { RTC_DCHECK_RUN_ON(signaling_thread()); // If there's already a pending error then no state transition should happen. // But all call-sites should be verifying this before calling us! RTC_DCHECK(session_error() == SessionError::kNone); + // If this is answer-ish we're ready to let media flow. + if (type == SdpType::kPrAnswer || type == SdpType::kAnswer) { + EnableSending(); + } + // Update the signaling state according to the specified state machine (see // https://w3c.github.io/webrtc-pc/#rtcsignalingstate-enum). if (type == SdpType::kOffer) { @@ -2487,17 +2519,11 @@ RTCError SdpOfferAnswerHandler::UpdateSessionState( RTC_DCHECK(type == SdpType::kAnswer); ChangeSignalingState(PeerConnectionInterface::kStable); transceivers()->DiscardStableStates(); - have_pending_rtp_data_channel_ = false; } // Update internal objects according to the session description's media // descriptions. 
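The refactor above routes both AddIceCandidate() overloads through AddIceCandidateInternal() and folds the UMA reporting into one NoteAddIceCandidateResult() call. The boolean overload treats "not ready" as success because the candidate has still been stored in the remote description. A sketch of that mapping; the enumerator names follow the diff, but the enum is redeclared here only to keep the sketch self-contained:

enum AddIceCandidateResult {
  kAddIceCandidateSuccess,
  kAddIceCandidateFailClosed,
  kAddIceCandidateFailNoRemoteDescription,
  kAddIceCandidateFailNullCandidate,
  kAddIceCandidateFailNotValid,
  kAddIceCandidateFailInAddition,
  kAddIceCandidateFailNotReady,
  kAddIceCandidateFailNotUsable,
};

bool AddIceCandidateSucceeded(AddIceCandidateResult result) {
  // "Not ready" still means the candidate was added to the remote
  // description; it just cannot be used until a description is applied.
  return result == kAddIceCandidateSuccess ||
         result == kAddIceCandidateFailNotReady;
}

The asynchronous overload applies the same grouping: kAddIceCandidateFailClosed maps to an INVALID_STATE error, success and not-ready map to RTCError::OK(), and every other result maps to UNSUPPORTED_OPERATION.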
- RTCError error = PushdownMediaDescription(type, source); - if (!error.ok()) { - return error; - } - - return RTCError::OK(); + return PushdownMediaDescription(type, source, bundle_groups_by_mid); } bool SdpOfferAnswerHandler::ShouldFireNegotiationNeededEvent( @@ -2654,6 +2680,7 @@ void SdpOfferAnswerHandler::OnVideoTrackRemoved(VideoTrackInterface* track, } RTCError SdpOfferAnswerHandler::Rollback(SdpType desc_type) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::Rollback"); auto state = signaling_state(); if (state != PeerConnectionInterface::kHaveLocalOffer && state != PeerConnectionInterface::kHaveRemoteOffer) { @@ -2701,16 +2728,16 @@ RTCError SdpOfferAnswerHandler::Rollback(SdpType desc_type) { transceivers()->Remove(transceiver); } } + if (state.init_send_encodings()) { + transceiver->internal()->sender_internal()->set_init_send_encodings( + state.init_send_encodings().value()); + } transceiver->internal()->sender_internal()->set_transport(nullptr); transceiver->internal()->receiver_internal()->set_transport(nullptr); transceiver->internal()->set_mid(state.mid()); transceiver->internal()->set_mline_index(state.mline_index()); } transport_controller()->RollbackTransports(); - if (have_pending_rtp_data_channel_) { - DestroyDataChannelTransport(); - have_pending_rtp_data_channel_ = false; - } transceivers()->DiscardStableStates(); pending_local_description_.reset(); pending_remote_description_.reset(); @@ -2777,7 +2804,7 @@ bool SdpOfferAnswerHandler::IceRestartPending( bool SdpOfferAnswerHandler::NeedsIceRestart( const std::string& content_name) const { - return transport_controller()->NeedsIceRestart(content_name); + return pc_->NeedsIceRestart(content_name); } absl::optional SdpOfferAnswerHandler::GetDtlsRole( @@ -2873,12 +2900,12 @@ bool SdpOfferAnswerHandler::CheckIfNegotiationIsNeeded() { // 5. 
For each transceiver in connection's set of transceivers, perform the // following checks: - for (const auto& transceiver : transceivers()->List()) { + for (const auto& transceiver : transceivers()->ListInternal()) { const ContentInfo* current_local_msection = - FindTransceiverMSection(transceiver.get(), description); + FindTransceiverMSection(transceiver, description); - const ContentInfo* current_remote_msection = FindTransceiverMSection( - transceiver.get(), current_remote_description()); + const ContentInfo* current_remote_msection = + FindTransceiverMSection(transceiver, current_remote_description()); // 5.4 If transceiver is stopped and is associated with an m= section, // but the associated m= section is not yet rejected in @@ -2966,7 +2993,7 @@ bool SdpOfferAnswerHandler::CheckIfNegotiationIsNeeded() { return true; const ContentInfo* offered_remote_msection = - FindTransceiverMSection(transceiver.get(), remote_description()); + FindTransceiverMSection(transceiver, remote_description()); RtpTransceiverDirection offered_direction = offered_remote_msection @@ -2995,7 +3022,9 @@ void SdpOfferAnswerHandler::GenerateNegotiationNeededEvent() { RTCError SdpOfferAnswerHandler::ValidateSessionDescription( const SessionDescriptionInterface* sdesc, - cricket::ContentSource source) { + cricket::ContentSource source, + const std::map& + bundle_groups_by_mid) { if (session_error() != SessionError::kNone) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, GetSessionErrorMsg()); } @@ -3021,20 +3050,21 @@ RTCError SdpOfferAnswerHandler::ValidateSessionDescription( std::string crypto_error; if (webrtc_session_desc_factory_->SdesPolicy() == cricket::SEC_REQUIRED || pc_->dtls_enabled()) { - RTCError crypto_error = - VerifyCrypto(sdesc->description(), pc_->dtls_enabled()); + RTCError crypto_error = VerifyCrypto( + sdesc->description(), pc_->dtls_enabled(), bundle_groups_by_mid); if (!crypto_error.ok()) { return crypto_error; } } // Verify ice-ufrag and ice-pwd. - if (!VerifyIceUfragPwdPresent(sdesc->description())) { + if (!VerifyIceUfragPwdPresent(sdesc->description(), bundle_groups_by_mid)) { LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, kSdpWithoutIceUfragPwd); } - if (!pc_->ValidateBundleSettings(sdesc->description())) { + if (!pc_->ValidateBundleSettings(sdesc->description(), + bundle_groups_by_mid)) { LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, kBundleWithoutRtcpMux); } @@ -3107,18 +3137,25 @@ RTCError SdpOfferAnswerHandler::UpdateTransceiversAndDataChannels( cricket::ContentSource source, const SessionDescriptionInterface& new_session, const SessionDescriptionInterface* old_local_description, - const SessionDescriptionInterface* old_remote_description) { + const SessionDescriptionInterface* old_remote_description, + const std::map& + bundle_groups_by_mid) { + TRACE_EVENT0("webrtc", + "SdpOfferAnswerHandler::UpdateTransceiversAndDataChannels"); RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(IsUnifiedPlan()); - const cricket::ContentGroup* bundle_group = nullptr; if (new_session.GetType() == SdpType::kOffer) { - auto bundle_group_or_error = - GetEarlyBundleGroup(*new_session.description()); - if (!bundle_group_or_error.ok()) { - return bundle_group_or_error.MoveError(); + // If the BUNDLE policy is max-bundle, then we know for sure that all + // transports will be bundled from the start. Return an error if max-bundle + // is specified but the session description does not have a BUNDLE group. 
+ if (pc_->configuration()->bundle_policy == + PeerConnectionInterface::kBundlePolicyMaxBundle && + bundle_groups_by_mid.empty()) { + LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, + "max-bundle configured but session description " + "has no BUNDLE group"); } - bundle_group = bundle_group_or_error.MoveValue(); } const ContentInfos& new_contents = new_session.description()->contents(); @@ -3126,6 +3163,9 @@ RTCError SdpOfferAnswerHandler::UpdateTransceiversAndDataChannels( const cricket::ContentInfo& new_content = new_contents[i]; cricket::MediaType media_type = new_content.media_description()->type(); mid_generator_.AddKnownId(new_content.name); + auto it = bundle_groups_by_mid.find(new_content.name); + const cricket::ContentGroup* bundle_group = + it != bundle_groups_by_mid.end() ? it->second : nullptr; if (media_type == cricket::MEDIA_TYPE_AUDIO || media_type == cricket::MEDIA_TYPE_VIDEO) { const cricket::ContentInfo* old_local_content = nullptr; @@ -3188,6 +3228,7 @@ SdpOfferAnswerHandler::AssociateTransceiver( const ContentInfo& content, const ContentInfo* old_local_content, const ContentInfo* old_remote_content) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::AssociateTransceiver"); RTC_DCHECK(IsUnifiedPlan()); #if RTC_DCHECK_IS_ON // If this is an offer then the m= section might be recycled. If the m= @@ -3264,7 +3305,9 @@ SdpOfferAnswerHandler::AssociateTransceiver( // Check if the offer indicated simulcast but the answer rejected it. // This can happen when simulcast is not supported on the remote party. - if (SimulcastIsRejected(old_local_content, *media_desc)) { + if (SimulcastIsRejected(old_local_content, *media_desc, + pc_->GetCryptoOptions() + .srtp.enable_encrypted_rtp_header_extensions)) { RTC_HISTOGRAM_BOOLEAN(kSimulcastDisabled, true); RTCError error = DisableSimulcastInSender(transceiver->internal()->sender_internal()); @@ -3314,27 +3357,12 @@ SdpOfferAnswerHandler::AssociateTransceiver( return std::move(transceiver); } -RTCErrorOr -SdpOfferAnswerHandler::GetEarlyBundleGroup( - const SessionDescription& desc) const { - const cricket::ContentGroup* bundle_group = nullptr; - if (pc_->configuration()->bundle_policy == - PeerConnectionInterface::kBundlePolicyMaxBundle) { - bundle_group = desc.GetGroupByName(cricket::GROUP_TYPE_BUNDLE); - if (!bundle_group) { - LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, - "max-bundle configured but session description " - "has no BUNDLE group"); - } - } - return bundle_group; -} - RTCError SdpOfferAnswerHandler::UpdateTransceiverChannel( rtc::scoped_refptr> transceiver, const cricket::ContentInfo& content, const cricket::ContentGroup* bundle_group) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::UpdateTransceiverChannel"); RTC_DCHECK(IsUnifiedPlan()); RTC_DCHECK(transceiver); cricket::ChannelInterface* channel = transceiver->internal()->channel(); @@ -3366,30 +3394,23 @@ RTCError SdpOfferAnswerHandler::UpdateDataChannel( cricket::ContentSource source, const cricket::ContentInfo& content, const cricket::ContentGroup* bundle_group) { - if (pc_->data_channel_type() == cricket::DCT_NONE) { - // If data channels are disabled, ignore this media section. CreateAnswer - // will take care of rejecting it. 
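The inline max-bundle check above replaces the removed GetEarlyBundleGroup() helper: an offer made under bundle policy max-bundle must carry at least one BUNDLE group, and the emptiness of the precomputed mid map is enough to detect that. A minimal sketch of the rule, with BundlePolicy and the map value type as simplified stand-ins for the WebRTC types:

#include <map>
#include <string>

enum class BundlePolicy { kBalanced, kMaxCompat, kMaxBundle };

bool OfferSatisfiesMaxBundle(
    BundlePolicy policy,
    const std::map<std::string, const void*>& bundle_groups_by_mid) {
  // Only max-bundle requires a BUNDLE group to be present in the offer.
  return policy != BundlePolicy::kMaxBundle || !bundle_groups_by_mid.empty();
}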
- return RTCError::OK(); - } if (content.rejected) { - RTC_LOG(LS_INFO) << "Rejected data channel, mid=" << content.mid(); - DestroyDataChannelTransport(); + RTC_LOG(LS_INFO) << "Rejected data channel transport with mid=" + << content.mid(); + + rtc::StringBuilder sb; + sb << "Rejected data channel transport with mid=" << content.mid(); + RTCError error(RTCErrorType::OPERATION_ERROR_WITH_DATA, sb.Release()); + error.set_error_detail(RTCErrorDetailType::DATA_CHANNEL_FAILURE); + DestroyDataChannelTransport(error); } else { - if (!data_channel_controller()->rtp_data_channel() && - !data_channel_controller()->data_channel_transport()) { + if (!data_channel_controller()->data_channel_transport()) { RTC_LOG(LS_INFO) << "Creating data channel, mid=" << content.mid(); if (!CreateDataChannel(content.name)) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, "Failed to create data channel."); } } - if (source == cricket::CS_REMOTE) { - const MediaContentDescription* data_desc = content.media_description(); - if (data_desc && cricket::IsRtpProtocol(data_desc->protocol())) { - data_channel_controller()->UpdateRemoteRtpDataChannels( - GetActiveStreams(data_desc)); - } - } } return RTCError::OK(); } @@ -3483,19 +3504,17 @@ SdpOfferAnswerHandler::FindAvailableTransceiverToReceive( const cricket::ContentInfo* SdpOfferAnswerHandler::FindMediaSectionForTransceiver( - rtc::scoped_refptr> - transceiver, + const RtpTransceiver* transceiver, const SessionDescriptionInterface* sdesc) const { RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(transceiver); RTC_DCHECK(sdesc); if (IsUnifiedPlan()) { - if (!transceiver->internal()->mid()) { + if (!transceiver->mid()) { // This transceiver is not associated with a media section yet. return nullptr; } - return sdesc->description()->GetContentByName( - *transceiver->internal()->mid()); + return sdesc->description()->GetContentByName(*transceiver->mid()); } else { // Plan B only allows at most one audio and one video section, so use the // first media section of that type. @@ -3516,16 +3535,6 @@ void SdpOfferAnswerHandler::GetOptionsForOffer( GetOptionsForPlanBOffer(offer_answer_options, session_options); } - // Intentionally unset the data channel type for RTP data channel with the - // second condition. Otherwise the RTP data channels would be successfully - // negotiated by default and the unit tests in WebRtcDataBrowserTest will fail - // when building with chromium. We want to leave RTP data channels broken, so - // people won't try to use them. - if (data_channel_controller()->HasRtpDataChannels() || - pc_->data_channel_type() != cricket::DCT_RTP) { - session_options->data_channel_type = pc_->data_channel_type(); - } - // Apply ICE restart flag and renomination flag. 
bool ice_restart = offer_answer_options.ice_restart || HasNewIceCredentials(); for (auto& options : session_options->media_description_options) { @@ -3539,8 +3548,7 @@ void SdpOfferAnswerHandler::GetOptionsForOffer( session_options->pooled_ice_credentials = pc_->network_thread()->Invoke>( RTC_FROM_HERE, - rtc::Bind(&cricket::PortAllocator::GetPooledIceCredentials, - port_allocator())); + [this] { return port_allocator()->GetPooledIceCredentials(); }); session_options->offer_extmap_allow_mixed = pc_->configuration()->offer_extmap_allow_mixed; @@ -3703,7 +3711,7 @@ void SdpOfferAnswerHandler::GetOptionsForUnifiedPlanOffer( } else { session_options->media_description_options.push_back( GetMediaDescriptionOptionsForTransceiver( - transceiver, mid, + transceiver->internal(), mid, /*is_create_offer=*/true)); // CreateOffer shouldn't really cause any state changes in // PeerConnection, but we need a way to match new transceivers to new @@ -3741,7 +3749,7 @@ void SdpOfferAnswerHandler::GetOptionsForUnifiedPlanOffer( // and not associated). Reuse media sections marked as recyclable first, // otherwise append to the end of the offer. New media sections should be // added in the order they were added to the PeerConnection. - for (const auto& transceiver : transceivers()->List()) { + for (const auto& transceiver : transceivers()->ListInternal()) { if (transceiver->mid() || transceiver->stopping()) { continue; } @@ -3761,7 +3769,7 @@ void SdpOfferAnswerHandler::GetOptionsForUnifiedPlanOffer( /*is_create_offer=*/true)); } // See comment above for why CreateOffer changes the transceiver's state. - transceiver->internal()->set_mline_index(mline_index); + transceiver->set_mline_index(mline_index); } // Lastly, add a m-section if we have local data channels and an m section // does not already exist. @@ -3784,15 +3792,6 @@ void SdpOfferAnswerHandler::GetOptionsForAnswer( GetOptionsForPlanBAnswer(offer_answer_options, session_options); } - // Intentionally unset the data channel type for RTP data channel. Otherwise - // the RTP data channels would be successfully negotiated by default and the - // unit tests in WebRtcDataBrowserTest will fail when building with chromium. - // We want to leave RTP data channels broken, so people won't try to use them. - if (data_channel_controller()->HasRtpDataChannels() || - pc_->data_channel_type() != cricket::DCT_RTP) { - session_options->data_channel_type = pc_->data_channel_type(); - } - // Apply ICE renomination flag. for (auto& options : session_options->media_description_options) { options.transport_options.enable_ice_renomination = @@ -3804,8 +3803,7 @@ void SdpOfferAnswerHandler::GetOptionsForAnswer( session_options->pooled_ice_credentials = pc_->network_thread()->Invoke>( RTC_FROM_HERE, - rtc::Bind(&cricket::PortAllocator::GetPooledIceCredentials, - port_allocator())); + [this] { return port_allocator()->GetPooledIceCredentials(); }); } void SdpOfferAnswerHandler::GetOptionsForPlanBAnswer( @@ -3874,7 +3872,7 @@ void SdpOfferAnswerHandler::GetOptionsForUnifiedPlanAnswer( if (transceiver) { session_options->media_description_options.push_back( GetMediaDescriptionOptionsForTransceiver( - transceiver, content.name, + transceiver->internal(), content.name, /*is_create_offer=*/false)); } else { // This should only happen with rejected transceivers. @@ -3895,8 +3893,7 @@ void SdpOfferAnswerHandler::GetOptionsForUnifiedPlanAnswer( // Reject all data sections if data channels are disabled. // Reject a data section if it has already been rejected. 
// Reject all data sections except for the first one. - if (pc_->data_channel_type() == cricket::DCT_NONE || content.rejected || - content.name != *(pc_->GetDataMid())) { + if (content.rejected || content.name != *(pc_->GetDataMid())) { session_options->media_description_options.push_back( GetMediaDescriptionOptionsForRejectedData(content.name)); } else { @@ -4056,6 +4053,7 @@ void SdpOfferAnswerHandler::RemoveSenders(cricket::MediaType media_type) { void SdpOfferAnswerHandler::UpdateLocalSenders( const std::vector& streams, cricket::MediaType media_type) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::UpdateLocalSenders"); RTC_DCHECK_RUN_ON(signaling_thread()); std::vector* current_senders = rtp_manager()->GetLocalSenderInfos(media_type); @@ -4098,6 +4096,7 @@ void SdpOfferAnswerHandler::UpdateRemoteSendersList( bool default_sender_needed, cricket::MediaType media_type, StreamCollection* new_streams) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::UpdateRemoteSendersList"); RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(!IsUnifiedPlan()); @@ -4196,69 +4195,94 @@ void SdpOfferAnswerHandler::UpdateRemoteSendersList( } } +void SdpOfferAnswerHandler::EnableSending() { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::EnableSending"); + RTC_DCHECK_RUN_ON(signaling_thread()); + for (const auto& transceiver : transceivers()->ListInternal()) { + cricket::ChannelInterface* channel = transceiver->channel(); + if (channel) { + channel->Enable(true); + } + } +} + RTCError SdpOfferAnswerHandler::PushdownMediaDescription( SdpType type, - cricket::ContentSource source) { + cricket::ContentSource source, + const std::map& + bundle_groups_by_mid) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::PushdownMediaDescription"); const SessionDescriptionInterface* sdesc = (source == cricket::CS_LOCAL ? local_description() : remote_description()); RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK(sdesc); - // Gather lists of updates to be made on cricket channels on the signaling - // thread, before performing them all at once on the worker thread. Necessary - // due to threading restrictions. - auto payload_type_demuxing_updates = GetPayloadTypeDemuxingUpdates(source); - std::vector content_updates; + if (!UpdatePayloadTypeDemuxingState(source, bundle_groups_by_mid)) { + // Note that this is never expected to fail, since RtpDemuxer doesn't return + // an error when changing payload type demux criteria, which is all this + // does. + LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, + "Failed to update payload type demuxing state."); + } - // Collect updates for each audio/video transceiver. - for (const auto& transceiver : transceivers()->List()) { + // Push down the new SDP media section for each audio/video transceiver. + auto rtp_transceivers = transceivers()->ListInternal(); + std::vector< + std::pair> + channels; + for (const auto& transceiver : rtp_transceivers) { const ContentInfo* content_info = FindMediaSectionForTransceiver(transceiver, sdesc); - cricket::ChannelInterface* channel = transceiver->internal()->channel(); + cricket::ChannelInterface* channel = transceiver->channel(); if (!channel || !content_info || content_info->rejected) { continue; } const MediaContentDescription* content_desc = content_info->media_description(); - if (content_desc) { - content_updates.emplace_back(channel, content_desc); + if (!content_desc) { + continue; } - } - // If using the RtpDataChannel, add it to the list of updates. 
- if (data_channel_controller()->rtp_data_channel()) { - const ContentInfo* data_content = - cricket::GetFirstDataContent(sdesc->description()); - if (data_content && !data_content->rejected) { - const MediaContentDescription* data_desc = - data_content->media_description(); - if (data_desc) { - content_updates.push_back( - {data_channel_controller()->rtp_data_channel(), data_desc}); - } - } + transceiver->OnNegotiationUpdate(type, content_desc); + channels.push_back(std::make_pair(channel, content_desc)); } - RTCError error = pc_->worker_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&SdpOfferAnswerHandler::ApplyChannelUpdates, this, type, source, - std::move(payload_type_demuxing_updates), - std::move(content_updates))); - if (!error.ok()) { - return error; + // This for-loop of invokes helps audio impairment during re-negotiations. + // One of the causes is that downstairs decoder creation is synchronous at the + // moment, and that a decoder is created for each codec listed in the SDP. + // + // TODO(bugs.webrtc.org/12840): consider merging the invokes again after + // these projects have shipped: + // - bugs.webrtc.org/12462 + // - crbug.com/1157227 + // - crbug.com/1187289 + for (const auto& entry : channels) { + RTCError error = + pc_->worker_thread()->Invoke(RTC_FROM_HERE, [&]() { + std::string error; + bool success = + (source == cricket::CS_LOCAL) + ? entry.first->SetLocalContent(entry.second, type, &error) + : entry.first->SetRemoteContent(entry.second, type, &error); + if (!success) { + LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, error); + } + return RTCError::OK(); + }); + if (!error.ok()) { + return error; + } } // Need complete offer/answer with an SCTP m= section before starting SCTP, // according to https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-19 if (pc_->sctp_mid() && local_description() && remote_description()) { - rtc::scoped_refptr sctp_transport = - transport_controller()->GetSctpTransport(*(pc_->sctp_mid())); auto local_sctp_description = cricket::GetFirstSctpDataContentDescription( local_description()->description()); auto remote_sctp_description = cricket::GetFirstSctpDataContentDescription( remote_description()->description()); - if (sctp_transport && local_sctp_description && remote_sctp_description) { + if (local_sctp_description && remote_sctp_description) { int max_message_size; // A remote max message size of zero means "any size supported". // We configure the connection with our own max message size. @@ -4269,60 +4293,19 @@ RTCError SdpOfferAnswerHandler::PushdownMediaDescription( std::min(local_sctp_description->max_message_size(), remote_sctp_description->max_message_size()); } - sctp_transport->Start(local_sctp_description->port(), - remote_sctp_description->port(), max_message_size); + pc_->StartSctpTransport(local_sctp_description->port(), + remote_sctp_description->port(), + max_message_size); } } return RTCError::OK(); } -RTCError SdpOfferAnswerHandler::ApplyChannelUpdates( - SdpType type, - cricket::ContentSource source, - std::vector payload_type_demuxing_updates, - std::vector content_updates) { - RTC_DCHECK_RUN_ON(pc_->worker_thread()); - // If this is answer-ish we're ready to let media flow. 
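PushdownMediaDescription() now starts the SCTP transport through pc_->StartSctpTransport() once both a local and a remote SCTP description exist. The comments above spell out the max-message-size rule; the sketch below captures it under the assumption, implied by "we configure the connection with our own max message size", that the branch elided between the hunks falls back to the local limit when the remote advertises 0. The sizes in main() are arbitrary example values:

#include <algorithm>
#include <cassert>

int NegotiatedMaxMessageSize(int local_max, int remote_max) {
  // A remote max message size of 0 means "any size supported", so the local
  // limit applies; otherwise the smaller of the two limits wins.
  if (remote_max == 0) {
    return local_max;
  }
  return std::min(local_max, remote_max);
}

int main() {
  assert(NegotiatedMaxMessageSize(262144, 0) == 262144);
  assert(NegotiatedMaxMessageSize(262144, 65536) == 65536);
  return 0;
}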
- bool enable_sending = type == SdpType::kPrAnswer || type == SdpType::kAnswer; - std::set modified_channels; - for (const auto& update : payload_type_demuxing_updates) { - modified_channels.insert(update.channel); - update.channel->SetPayloadTypeDemuxingEnabled(update.enabled); - } - for (const auto& update : content_updates) { - modified_channels.insert(update.channel); - std::string error; - bool success = (source == cricket::CS_LOCAL) - ? update.channel->SetLocalContent( - update.content_description, type, &error) - : update.channel->SetRemoteContent( - update.content_description, type, &error); - if (!success) { - LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, error); - } - if (enable_sending && !update.channel->enabled()) { - update.channel->Enable(true); - } - } - // The above calls may have modified properties of the channel (header - // extension mappings, demuxer criteria) which still need to be applied to the - // RtpTransport. - return pc_->network_thread()->Invoke( - RTC_FROM_HERE, [modified_channels] { - for (auto channel : modified_channels) { - std::string error; - if (!channel->UpdateRtpTransport(&error)) { - LOG_AND_RETURN_ERROR(RTCErrorType::INVALID_PARAMETER, error); - } - } - return RTCError::OK(); - }); -} - RTCError SdpOfferAnswerHandler::PushdownTransportDescription( cricket::ContentSource source, SdpType type) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::PushdownTransportDescription"); RTC_DCHECK_RUN_ON(signaling_thread()); if (source == cricket::CS_LOCAL) { @@ -4339,6 +4322,7 @@ RTCError SdpOfferAnswerHandler::PushdownTransportDescription( } void SdpOfferAnswerHandler::RemoveStoppedTransceivers() { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::RemoveStoppedTransceivers"); RTC_DCHECK_RUN_ON(signaling_thread()); // 3.2.10.1: For each transceiver in the connection's set of transceivers // run the following steps: @@ -4355,27 +4339,23 @@ void SdpOfferAnswerHandler::RemoveStoppedTransceivers() { if (!transceiver->stopped()) { continue; } - const ContentInfo* local_content = - FindMediaSectionForTransceiver(transceiver, local_description()); - const ContentInfo* remote_content = - FindMediaSectionForTransceiver(transceiver, remote_description()); + const ContentInfo* local_content = FindMediaSectionForTransceiver( + transceiver->internal(), local_description()); + const ContentInfo* remote_content = FindMediaSectionForTransceiver( + transceiver->internal(), remote_description()); if ((local_content && local_content->rejected) || (remote_content && remote_content->rejected)) { RTC_LOG(LS_INFO) << "Dissociating transceiver" - << " since the media section is being recycled."; + " since the media section is being recycled."; transceiver->internal()->set_mid(absl::nullopt); transceiver->internal()->set_mline_index(absl::nullopt); - transceivers()->Remove(transceiver); - continue; - } - if (!local_content && !remote_content) { + } else if (!local_content && !remote_content) { // TODO(bugs.webrtc.org/11973): Consider if this should be removed already // See https://github.com/w3c/webrtc-pc/issues/2576 RTC_LOG(LS_INFO) << "Dropping stopped transceiver that was never associated"; - transceivers()->Remove(transceiver); - continue; } + transceivers()->Remove(transceiver); } } @@ -4395,8 +4375,18 @@ void SdpOfferAnswerHandler::RemoveUnusedChannels( } const cricket::ContentInfo* data_info = cricket::GetFirstDataContent(desc); - if (!data_info || data_info->rejected) { - DestroyDataChannelTransport(); + if (!data_info) { + RTCError 
error(RTCErrorType::OPERATION_ERROR_WITH_DATA, + "No data channel section in the description."); + error.set_error_detail(RTCErrorDetailType::DATA_CHANNEL_FAILURE); + DestroyDataChannelTransport(error); + } else if (data_info->rejected) { + rtc::StringBuilder sb; + sb << "Rejected data channel with mid=" << data_info->name << "."; + + RTCError error(RTCErrorType::OPERATION_ERROR_WITH_DATA, sb.Release()); + error.set_error_detail(RTCErrorDetailType::DATA_CHANNEL_FAILURE); + DestroyDataChannelTransport(error); } } @@ -4473,40 +4463,23 @@ bool SdpOfferAnswerHandler::UseCandidatesInSessionDescription( bool SdpOfferAnswerHandler::UseCandidate( const IceCandidateInterface* candidate) { RTC_DCHECK_RUN_ON(signaling_thread()); + + rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; + RTCErrorOr result = FindContentInfo(remote_description(), candidate); - if (!result.ok()) { - RTC_LOG(LS_ERROR) << "UseCandidate: Invalid candidate. " - << result.error().message(); + if (!result.ok()) return false; + + const cricket::Candidate& c = candidate->candidate(); + RTCError error = cricket::VerifyCandidate(c); + if (!error.ok()) { + RTC_LOG(LS_WARNING) << "Invalid candidate: " << c.ToString(); + return true; } - std::vector candidates; - candidates.push_back(candidate->candidate()); - // Invoking BaseSession method to handle remote candidates. - RTCError error = transport_controller()->AddRemoteCandidates( - result.value()->name, candidates); - if (error.ok()) { - ReportRemoteIceCandidateAdded(candidate->candidate()); - // Candidates successfully submitted for checking. - if (pc_->ice_connection_state() == - PeerConnectionInterface::kIceConnectionNew || - pc_->ice_connection_state() == - PeerConnectionInterface::kIceConnectionDisconnected) { - // If state is New, then the session has just gotten its first remote ICE - // candidates, so go to Checking. - // If state is Disconnected, the session is re-using old candidates or - // receiving additional ones, so go to Checking. - // If state is Connected, stay Connected. - // TODO(bemasc): If state is Connected, and the new candidates are for a - // newly added transport, then the state actually _should_ move to - // checking. Add a way to distinguish that case. - pc_->SetIceConnectionState( - PeerConnectionInterface::kIceConnectionChecking); - } - // TODO(bemasc): If state is Completed, go back to Connected. 
- } else { - RTC_LOG(LS_WARNING) << error.message(); - } + + pc_->AddRemoteCandidate(result.value()->name, c); + return true; } @@ -4539,41 +4512,13 @@ bool SdpOfferAnswerHandler::ReadyToUseRemoteCandidate( return false; } - std::string transport_name = GetTransportName(result.value()->name); - return !transport_name.empty(); -} - -void SdpOfferAnswerHandler::ReportRemoteIceCandidateAdded( - const cricket::Candidate& candidate) { - pc_->NoteUsageEvent(UsageEvent::REMOTE_CANDIDATE_ADDED); - if (candidate.address().IsPrivateIP()) { - pc_->NoteUsageEvent(UsageEvent::REMOTE_PRIVATE_CANDIDATE_ADDED); - } - if (candidate.address().IsUnresolvedIP()) { - pc_->NoteUsageEvent(UsageEvent::REMOTE_MDNS_CANDIDATE_ADDED); - } - if (candidate.address().family() == AF_INET6) { - pc_->NoteUsageEvent(UsageEvent::REMOTE_IPV6_CANDIDATE_ADDED); - } + return true; } RTCErrorOr SdpOfferAnswerHandler::FindContentInfo( const SessionDescriptionInterface* description, const IceCandidateInterface* candidate) { - if (candidate->sdp_mline_index() >= 0) { - size_t mediacontent_index = - static_cast(candidate->sdp_mline_index()); - size_t content_size = description->description()->contents().size(); - if (mediacontent_index < content_size) { - return &description->description()->contents()[mediacontent_index]; - } else { - return RTCError(RTCErrorType::INVALID_RANGE, - "Media line index (" + - rtc::ToString(candidate->sdp_mline_index()) + - ") out of range (number of mlines: " + - rtc::ToString(content_size) + ")."); - } - } else if (!candidate->sdp_mid().empty()) { + if (!candidate->sdp_mid().empty()) { auto& contents = description->description()->contents(); auto it = absl::c_find_if( contents, [candidate](const cricket::ContentInfo& content_info) { @@ -4587,6 +4532,19 @@ RTCErrorOr SdpOfferAnswerHandler::FindContentInfo( } else { return &*it; } + } else if (candidate->sdp_mline_index() >= 0) { + size_t mediacontent_index = + static_cast(candidate->sdp_mline_index()); + size_t content_size = description->description()->contents().size(); + if (mediacontent_index < content_size) { + return &description->description()->contents()[mediacontent_index]; + } else { + return RTCError(RTCErrorType::INVALID_RANGE, + "Media line index (" + + rtc::ToString(candidate->sdp_mline_index()) + + ") out of range (number of mlines: " + + rtc::ToString(content_size) + ")."); + } } return RTCError(RTCErrorType::INVALID_PARAMETER, @@ -4594,6 +4552,7 @@ RTCErrorOr SdpOfferAnswerHandler::FindContentInfo( } RTCError SdpOfferAnswerHandler::CreateChannels(const SessionDescription& desc) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::CreateChannels"); // Creating the media channels. Transports should already have been created // at this point. RTC_DCHECK_RUN_ON(signaling_thread()); @@ -4620,8 +4579,7 @@ RTCError SdpOfferAnswerHandler::CreateChannels(const SessionDescription& desc) { } const cricket::ContentInfo* data = cricket::GetFirstDataContent(&desc); - if (pc_->data_channel_type() != cricket::DCT_NONE && data && - !data->rejected && !data_channel_controller()->rtp_data_channel() && + if (data && !data->rejected && !data_channel_controller()->data_channel_transport()) { if (!CreateDataChannel(data->name)) { LOG_AND_RETURN_ERROR(RTCErrorType::INTERNAL_ERROR, @@ -4635,141 +4593,123 @@ RTCError SdpOfferAnswerHandler::CreateChannels(const SessionDescription& desc) { // TODO(steveanton): Perhaps this should be managed by the RtpTransceiver. 
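FindContentInfo() above is reordered so that a candidate's sdp_mid, when present, takes precedence over its sdp_mline_index. A self-contained sketch of that lookup order, using a simplified stand-in for the SDP content list rather than the real cricket::ContentInfo:

#include <cstddef>
#include <optional>
#include <string>
#include <vector>

struct Content {
  std::string mid;
};

// Returns the index of the matching m= section, or std::nullopt.
std::optional<std::size_t> FindContent(const std::vector<Content>& contents,
                                       const std::string& sdp_mid,
                                       int sdp_mline_index) {
  if (!sdp_mid.empty()) {
    for (std::size_t i = 0; i < contents.size(); ++i) {
      if (contents[i].mid == sdp_mid) {
        return i;
      }
    }
    return std::nullopt;  // mid was given but no section matches it
  }
  if (sdp_mline_index >= 0 &&
      static_cast<std::size_t>(sdp_mline_index) < contents.size()) {
    return static_cast<std::size_t>(sdp_mline_index);
  }
  return std::nullopt;  // neither a mid nor a valid m-line index
}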
cricket::VoiceChannel* SdpOfferAnswerHandler::CreateVoiceChannel( const std::string& mid) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::CreateVoiceChannel"); RTC_DCHECK_RUN_ON(signaling_thread()); + if (!channel_manager()->media_engine()) + return nullptr; + RtpTransportInternal* rtp_transport = pc_->GetRtpTransport(mid); // TODO(bugs.webrtc.org/11992): CreateVoiceChannel internally switches to the // worker thread. We shouldn't be using the |call_ptr_| hack here but simply // be on the worker thread and use |call_| (update upstream code). - cricket::VoiceChannel* voice_channel; - { - RTC_DCHECK_RUN_ON(pc_->signaling_thread()); - voice_channel = channel_manager()->CreateVoiceChannel( - pc_->call_ptr(), pc_->configuration()->media_config, rtp_transport, - signaling_thread(), mid, pc_->SrtpRequired(), pc_->GetCryptoOptions(), - &ssrc_generator_, audio_options()); - } - if (!voice_channel) { - return nullptr; - } - voice_channel->SignalSentPacket().connect(pc_, - &PeerConnection::OnSentPacket_w); - voice_channel->SetRtpTransport(rtp_transport); - - return voice_channel; + return channel_manager()->CreateVoiceChannel( + pc_->call_ptr(), pc_->configuration()->media_config, rtp_transport, + signaling_thread(), mid, pc_->SrtpRequired(), pc_->GetCryptoOptions(), + &ssrc_generator_, audio_options()); } // TODO(steveanton): Perhaps this should be managed by the RtpTransceiver. cricket::VideoChannel* SdpOfferAnswerHandler::CreateVideoChannel( const std::string& mid) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::CreateVideoChannel"); RTC_DCHECK_RUN_ON(signaling_thread()); + if (!channel_manager()->media_engine()) + return nullptr; + + // NOTE: This involves a non-ideal hop (Invoke) over to the network thread. RtpTransportInternal* rtp_transport = pc_->GetRtpTransport(mid); // TODO(bugs.webrtc.org/11992): CreateVideoChannel internally switches to the // worker thread. We shouldn't be using the |call_ptr_| hack here but simply // be on the worker thread and use |call_| (update upstream code). - cricket::VideoChannel* video_channel; - { - RTC_DCHECK_RUN_ON(pc_->signaling_thread()); - video_channel = channel_manager()->CreateVideoChannel( - pc_->call_ptr(), pc_->configuration()->media_config, rtp_transport, - signaling_thread(), mid, pc_->SrtpRequired(), pc_->GetCryptoOptions(), - &ssrc_generator_, video_options(), - video_bitrate_allocator_factory_.get()); - } - if (!video_channel) { - return nullptr; - } - video_channel->SignalSentPacket().connect(pc_, - &PeerConnection::OnSentPacket_w); - video_channel->SetRtpTransport(rtp_transport); - - return video_channel; + return channel_manager()->CreateVideoChannel( + pc_->call_ptr(), pc_->configuration()->media_config, rtp_transport, + signaling_thread(), mid, pc_->SrtpRequired(), pc_->GetCryptoOptions(), + &ssrc_generator_, video_options(), + video_bitrate_allocator_factory_.get()); } bool SdpOfferAnswerHandler::CreateDataChannel(const std::string& mid) { RTC_DCHECK_RUN_ON(signaling_thread()); - switch (pc_->data_channel_type()) { - case cricket::DCT_SCTP: - if (pc_->network_thread()->Invoke( - RTC_FROM_HERE, - rtc::Bind(&PeerConnection::SetupDataChannelTransport_n, pc_, - mid))) { - pc_->SetSctpDataMid(mid); - } else { - return false; - } - return true; - case cricket::DCT_RTP: - default: - RtpTransportInternal* rtp_transport = pc_->GetRtpTransport(mid); - // TODO(bugs.webrtc.org/9987): set_rtp_data_channel() should be called on - // the network thread like set_data_channel_transport is. 
- { - RTC_DCHECK_RUN_ON(pc_->signaling_thread()); - data_channel_controller()->set_rtp_data_channel( - channel_manager()->CreateRtpDataChannel( - pc_->configuration()->media_config, rtp_transport, - signaling_thread(), mid, pc_->SrtpRequired(), - pc_->GetCryptoOptions(), &ssrc_generator_)); - } - if (!data_channel_controller()->rtp_data_channel()) { - return false; - } - data_channel_controller()->rtp_data_channel()->SignalSentPacket().connect( - pc_, &PeerConnection::OnSentPacket_w); - data_channel_controller()->rtp_data_channel()->SetRtpTransport( - rtp_transport); - SetHavePendingRtpDataChannel(); - return true; + if (!pc_->network_thread()->Invoke(RTC_FROM_HERE, [this, &mid] { + RTC_DCHECK_RUN_ON(pc_->network_thread()); + return pc_->SetupDataChannelTransport_n(mid); + })) { + return false; } - return false; + // TODO(tommi): Is this necessary? SetupDataChannelTransport_n() above + // will have queued up updating the transport name on the signaling thread + // and could update the mid at the same time. This here is synchronous + // though, but it changes the state of PeerConnection and makes it be + // out of sync (transport name not set while the mid is set). + pc_->SetSctpDataMid(mid); + return true; } void SdpOfferAnswerHandler::DestroyTransceiverChannel( rtc::scoped_refptr> transceiver) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::DestroyTransceiverChannel"); RTC_DCHECK(transceiver); + RTC_LOG_THREAD_BLOCK_COUNT(); + + // TODO(tommi): We're currently on the signaling thread. + // There are multiple hops to the worker ahead. + // Consider if we can make the call to SetChannel() on the worker thread + // (and require that to be the context it's always called in) and also + // call DestroyChannelInterface there, since it also needs to hop to the + // worker. cricket::ChannelInterface* channel = transceiver->internal()->channel(); + RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(0); if (channel) { + // TODO(tommi): VideoRtpReceiver::SetMediaChannel blocks and jumps to the + // worker thread. When being set to nullptr, there are additional + // blocking calls to e.g. ClearRecordableEncodedFrameCallback which triggers + // another blocking call or Stop() for video channels. + // The channel object also needs to be de-initialized on the network thread + // so if ownership of the channel object lies with the transceiver, we could + // un-set the channel pointer and uninitialize/destruct the channel object + // at the same time, rather than in separate steps. transceiver->internal()->SetChannel(nullptr); + // TODO(tommi): All channel objects end up getting deleted on the + // worker thread (ideally should be on the network thread but the + // MediaChannel objects are tied to the worker. Can the teardown be done + // asynchronously across the threads rather than blocking? DestroyChannelInterface(channel); } } -void SdpOfferAnswerHandler::DestroyDataChannelTransport() { +void SdpOfferAnswerHandler::DestroyDataChannelTransport(RTCError error) { RTC_DCHECK_RUN_ON(signaling_thread()); - if (data_channel_controller()->rtp_data_channel()) { - data_channel_controller()->OnTransportChannelClosed(); - DestroyChannelInterface(data_channel_controller()->rtp_data_channel()); - data_channel_controller()->set_rtp_data_channel(nullptr); - } - - // Note: Cannot use rtc::Bind to create a functor to invoke because it will - // grab a reference to this PeerConnection. 
If this is called from the - // PeerConnection destructor, the RefCountedObject vtable will have already - // been destroyed (since it is a subclass of PeerConnection) and using - // rtc::Bind will cause "Pure virtual function called" error to appear. - - if (pc_->sctp_mid()) { - RTC_DCHECK_RUN_ON(pc_->signaling_thread()); - data_channel_controller()->OnTransportChannelClosed(); - pc_->network_thread()->Invoke(RTC_FROM_HERE, [this] { - RTC_DCHECK_RUN_ON(pc_->network_thread()); - pc_->TeardownDataChannelTransport_n(); - }); + const bool has_sctp = pc_->sctp_mid().has_value(); + + if (has_sctp) + data_channel_controller()->OnTransportChannelClosed(error); + + pc_->network_thread()->Invoke(RTC_FROM_HERE, [this] { + RTC_DCHECK_RUN_ON(pc_->network_thread()); + pc_->TeardownDataChannelTransport_n(); + }); + + if (has_sctp) pc_->ResetSctpDataMid(); - } } void SdpOfferAnswerHandler::DestroyChannelInterface( cricket::ChannelInterface* channel) { + TRACE_EVENT0("webrtc", "SdpOfferAnswerHandler::DestroyChannelInterface"); + RTC_DCHECK_RUN_ON(signaling_thread()); + RTC_DCHECK(channel_manager()->media_engine()); + RTC_DCHECK(channel); + // TODO(bugs.webrtc.org/11992): All the below methods should be called on the // worker thread. (they switch internally anyway). Change // DestroyChannelInterface to either be called on the worker thread, or do // this asynchronously on the worker. - RTC_DCHECK(channel); + RTC_LOG_THREAD_BLOCK_COUNT(); + switch (channel->media_type()) { case cricket::MEDIA_TYPE_AUDIO: channel_manager()->DestroyVoiceChannel( @@ -4780,13 +4720,19 @@ void SdpOfferAnswerHandler::DestroyChannelInterface( static_cast(channel)); break; case cricket::MEDIA_TYPE_DATA: - channel_manager()->DestroyRtpDataChannel( - static_cast(channel)); + RTC_NOTREACHED() + << "Trying to destroy datachannel through DestroyChannelInterface"; break; default: RTC_NOTREACHED() << "Unknown media type: " << channel->media_type(); break; } + + // TODO(tommi): Figure out why we can get 2 blocking calls when running + // PeerConnectionCryptoTest.CreateAnswerWithDifferentSslRoles. + // and 3 when running + // PeerConnectionCryptoTest.CreateAnswerWithDifferentSslRoles + // RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(1); } void SdpOfferAnswerHandler::DestroyAllChannels() { @@ -4794,19 +4740,26 @@ void SdpOfferAnswerHandler::DestroyAllChannels() { if (!transceivers()) { return; } + + RTC_LOG_THREAD_BLOCK_COUNT(); + // Destroy video channels first since they may have a pointer to a voice // channel. 
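// The Destroy* functions in the surrounding hunks use the thread-blocking
// instrumentation that several of the new TODOs refer to:
// RTC_LOG_THREAD_BLOCK_COUNT() arms a per-function counter of blocking
// Invoke calls made from the current thread, and
// RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(n) DCHECKs that no more than n such
// calls happened since the counter was armed. A fragment sketching the
// idiom as used here (TeardownOnWorker() is a hypothetical blocking call):
//
//   void Example::Teardown() {
//     RTC_LOG_THREAD_BLOCK_COUNT();            // Arm the counter.
//     TeardownOnWorker();                      // One blocking hop is expected.
//     RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(1);  // A second hop is a regression.
//   }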
- for (const auto& transceiver : transceivers()->List()) { + auto list = transceivers()->List(); + RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(0); + + for (const auto& transceiver : list) { if (transceiver->media_type() == cricket::MEDIA_TYPE_VIDEO) { DestroyTransceiverChannel(transceiver); } } - for (const auto& transceiver : transceivers()->List()) { + for (const auto& transceiver : list) { if (transceiver->media_type() == cricket::MEDIA_TYPE_AUDIO) { DestroyTransceiverChannel(transceiver); } } - DestroyDataChannelTransport(); + + DestroyDataChannelTransport({}); } void SdpOfferAnswerHandler::GenerateMediaDescriptionOptions( @@ -4884,8 +4837,6 @@ SdpOfferAnswerHandler::GetMediaDescriptionOptionsForActiveData( cricket::MediaDescriptionOptions options(cricket::MEDIA_TYPE_DATA, mid, RtpTransceiverDirection::kSendRecv, /*stopped=*/false); - AddRtpDataChannelOptions(*(data_channel_controller()->rtp_data_channels()), - &options); return options; } @@ -4896,31 +4847,15 @@ SdpOfferAnswerHandler::GetMediaDescriptionOptionsForRejectedData( cricket::MediaDescriptionOptions options(cricket::MEDIA_TYPE_DATA, mid, RtpTransceiverDirection::kInactive, /*stopped=*/true); - AddRtpDataChannelOptions(*(data_channel_controller()->rtp_data_channels()), - &options); return options; } -const std::string SdpOfferAnswerHandler::GetTransportName( - const std::string& content_name) { - RTC_DCHECK_RUN_ON(signaling_thread()); - cricket::ChannelInterface* channel = pc_->GetChannel(content_name); - if (channel) { - return channel->transport_name(); - } - if (data_channel_controller()->data_channel_transport()) { - RTC_DCHECK(pc_->sctp_mid()); - if (content_name == *(pc_->sctp_mid())) { - return *(pc_->sctp_transport_name()); - } - } - // Return an empty string if failed to retrieve the transport name. - return ""; -} - -std::vector -SdpOfferAnswerHandler::GetPayloadTypeDemuxingUpdates( - cricket::ContentSource source) { +bool SdpOfferAnswerHandler::UpdatePayloadTypeDemuxingState( + cricket::ContentSource source, + const std::map& + bundle_groups_by_mid) { + TRACE_EVENT0("webrtc", + "SdpOfferAnswerHandler::UpdatePayloadTypeDemuxingState"); RTC_DCHECK_RUN_ON(signaling_thread()); // We may need to delete any created default streams and disable creation of // new ones on the basis of payload type. This is needed to avoid SSRC @@ -4933,19 +4868,27 @@ SdpOfferAnswerHandler::GetPayloadTypeDemuxingUpdates( const SessionDescriptionInterface* sdesc = (source == cricket::CS_LOCAL ? local_description() : remote_description()); - const cricket::ContentGroup* bundle_group = - sdesc->description()->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); - std::set audio_payload_types; - std::set video_payload_types; - bool pt_demuxing_enabled_audio = true; - bool pt_demuxing_enabled_video = true; + struct PayloadTypes { + std::set audio_payload_types; + std::set video_payload_types; + bool pt_demuxing_possible_audio = true; + bool pt_demuxing_possible_video = true; + }; + std::map payload_types_by_bundle; + // If the MID is missing from *any* receiving m= section, this is set to true. + bool mid_header_extension_missing_audio = false; + bool mid_header_extension_missing_video = false; for (auto& content_info : sdesc->description()->contents()) { + auto it = bundle_groups_by_mid.find(content_info.name); + const cricket::ContentGroup* bundle_group = + it != bundle_groups_by_mid.end() ? 
it->second : nullptr; // If this m= section isn't bundled, it's safe to demux by payload type // since other m= sections using the same payload type will also be using // different transports. - if (!bundle_group || !bundle_group->HasContentName(content_info.name)) { + if (!bundle_group) { continue; } + PayloadTypes* payload_types = &payload_types_by_bundle[bundle_group]; if (content_info.rejected || (source == cricket::ContentSource::CS_LOCAL && !RtpTransceiverDirectionHasRecv( @@ -4958,28 +4901,36 @@ SdpOfferAnswerHandler::GetPayloadTypeDemuxingUpdates( } switch (content_info.media_description()->type()) { case cricket::MediaType::MEDIA_TYPE_AUDIO: { + if (!mid_header_extension_missing_audio) { + mid_header_extension_missing_audio = + !ContentHasHeaderExtension(content_info, RtpExtension::kMidUri); + } const cricket::AudioContentDescription* audio_desc = content_info.media_description()->as_audio(); for (const cricket::AudioCodec& audio : audio_desc->codecs()) { - if (audio_payload_types.count(audio.id)) { + if (payload_types->audio_payload_types.count(audio.id)) { // Two m= sections are using the same payload type, thus demuxing // by payload type is not possible. - pt_demuxing_enabled_audio = false; + payload_types->pt_demuxing_possible_audio = false; } - audio_payload_types.insert(audio.id); + payload_types->audio_payload_types.insert(audio.id); } break; } case cricket::MediaType::MEDIA_TYPE_VIDEO: { + if (!mid_header_extension_missing_video) { + mid_header_extension_missing_video = + !ContentHasHeaderExtension(content_info, RtpExtension::kMidUri); + } const cricket::VideoContentDescription* video_desc = content_info.media_description()->as_video(); for (const cricket::VideoCodec& video : video_desc->codecs()) { - if (video_payload_types.count(video.id)) { + if (payload_types->video_payload_types.count(video.id)) { // Two m= sections are using the same payload type, thus demuxing // by payload type is not possible. - pt_demuxing_enabled_video = false; + payload_types->pt_demuxing_possible_video = false; } - video_payload_types.insert(video.id); + payload_types->video_payload_types.insert(video.id); } break; } @@ -4991,9 +4942,10 @@ SdpOfferAnswerHandler::GetPayloadTypeDemuxingUpdates( // Gather all updates ahead of time so that all channels can be updated in a // single Invoke; necessary due to thread guards. 
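// The comment above ("gather all updates ahead of time ... single Invoke") is
// the key threading constraint for the code that follows: the per-channel
// decisions are computed on the signaling thread, but
// SetPayloadTypeDemuxingEnabled() must run on the worker thread, so the
// updates are applied in one batch inside a single blocking hop. A
// self-contained sketch of that gather-then-apply shape (FakeChannel,
// ApplyUpdatesInOneHop and run_on_worker are illustrative stand-ins, not
// WebRTC APIs):

   #include <functional>
   #include <utility>
   #include <vector>

   struct FakeChannel {
     bool SetPayloadTypeDemuxingEnabled(bool enabled) {
       enabled_ = enabled;
       return true;
     }
     bool enabled_ = false;
   };

   // Applies all gathered updates inside one hop to the worker context.
   // `run_on_worker` abstracts a blocking Invoke-style call.
   bool ApplyUpdatesInOneHop(
       std::vector<std::pair<FakeChannel*, bool>> updates,
       const std::function<bool(std::function<bool()>)>& run_on_worker) {
     return run_on_worker([updates = std::move(updates)] {
       for (const auto& update : updates) {
         if (!update.first->SetPayloadTypeDemuxingEnabled(update.second))
           return false;  // Abort on the first failing channel.
       }
       return true;
     });
   }

// In the patch itself the worker hop corresponds to the single
// pc_->worker_thread()->Invoke(...) call further down in this hunk.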
- std::vector channel_updates; - for (const auto& transceiver : transceivers()->List()) { - cricket::ChannelInterface* channel = transceiver->internal()->channel(); + std::vector> + channels_to_update; + for (const auto& transceiver : transceivers()->ListInternal()) { + cricket::ChannelInterface* channel = transceiver->channel(); const ContentInfo* content = FindMediaSectionForTransceiver(transceiver, sdesc); if (!channel || !content) { @@ -5004,22 +4956,81 @@ SdpOfferAnswerHandler::GetPayloadTypeDemuxingUpdates( if (source == cricket::CS_REMOTE) { local_direction = RtpTransceiverDirectionReversed(local_direction); } - cricket::MediaType media_type = channel->media_type(); - bool in_bundle_group = - (bundle_group && bundle_group->HasContentName(channel->content_name())); - bool payload_type_demuxing_enabled = false; - if (media_type == cricket::MediaType::MEDIA_TYPE_AUDIO) { - payload_type_demuxing_enabled = - (!in_bundle_group || pt_demuxing_enabled_audio) && - RtpTransceiverDirectionHasRecv(local_direction); - } else if (media_type == cricket::MediaType::MEDIA_TYPE_VIDEO) { - payload_type_demuxing_enabled = - (!in_bundle_group || pt_demuxing_enabled_video) && - RtpTransceiverDirectionHasRecv(local_direction); - } - channel_updates.emplace_back(channel, payload_type_demuxing_enabled); - } - return channel_updates; + channels_to_update.emplace_back(local_direction, transceiver->channel()); + } + + if (channels_to_update.empty()) { + return true; + } + + // In Unified Plan, payload type demuxing is useful for legacy endpoints that + // don't support the MID header extension, but it can also cause incorrrect + // forwarding of packets when going from one m= section to multiple m= + // sections in the same BUNDLE. This only happens if media arrives prior to + // negotiation, but this can cause missing video and unsignalled ssrc bugs + // severe enough to warrant disabling PT demuxing in such cases. Therefore, if + // a MID header extension is present on all m= sections for a given kind + // (audio/video) then we use that as an OK to disable payload type demuxing in + // BUNDLEs of that kind. However if PT demuxing was ever turned on (e.g. MID + // was ever removed on ANY m= section of that kind) then we continue to allow + // PT demuxing in order to prevent disabling it in follow-up O/A exchanges and + // allowing early media by PT. + bool bundled_pt_demux_allowed_audio = !IsUnifiedPlan() || + mid_header_extension_missing_audio || + pt_demuxing_has_been_used_audio_; + bool bundled_pt_demux_allowed_video = !IsUnifiedPlan() || + mid_header_extension_missing_video || + pt_demuxing_has_been_used_video_; + // Kill switch for the above change. + if (field_trial::IsEnabled(kAlwaysAllowPayloadTypeDemuxingFieldTrialName)) { + // TODO(https://crbug.com/webrtc/12814): If disabling PT-based demux does + // not trigger regressions, remove this kill switch. 
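// The kill switch referenced in the TODO above is applied by the two
// assignments just below: a change in default behavior is guarded by a field
// trial so it can be reverted without shipping a new binary.
// webrtc::field_trial::IsEnabled(name) reports whether the named trial is
// configured as Enabled. A fragment sketching the pattern (the trial-name
// constant is defined elsewhere in this file and its string value is not
// shown here; use_new_restriction is an illustrative variable):
//
//   bool use_new_restriction = ComputeNewDefaultDecision();  // Hypothetical.
//   if (field_trial::IsEnabled(kAlwaysAllowPayloadTypeDemuxingFieldTrialName)) {
//     use_new_restriction = false;  // Kill switch: revert to prior behavior.
//   }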
+ bundled_pt_demux_allowed_audio = true; + bundled_pt_demux_allowed_video = true; + } + + return pc_->worker_thread()->Invoke( + RTC_FROM_HERE, + [&channels_to_update, &bundle_groups_by_mid, &payload_types_by_bundle, + bundled_pt_demux_allowed_audio, bundled_pt_demux_allowed_video, + pt_demuxing_has_been_used_audio = &pt_demuxing_has_been_used_audio_, + pt_demuxing_has_been_used_video = &pt_demuxing_has_been_used_video_]() { + for (const auto& it : channels_to_update) { + RtpTransceiverDirection local_direction = it.first; + cricket::ChannelInterface* channel = it.second; + cricket::MediaType media_type = channel->media_type(); + auto bundle_it = bundle_groups_by_mid.find(channel->content_name()); + const cricket::ContentGroup* bundle_group = + bundle_it != bundle_groups_by_mid.end() ? bundle_it->second + : nullptr; + if (media_type == cricket::MediaType::MEDIA_TYPE_AUDIO) { + bool pt_demux_enabled = + RtpTransceiverDirectionHasRecv(local_direction) && + (!bundle_group || (bundled_pt_demux_allowed_audio && + payload_types_by_bundle[bundle_group] + .pt_demuxing_possible_audio)); + if (pt_demux_enabled) { + *pt_demuxing_has_been_used_audio = true; + } + if (!channel->SetPayloadTypeDemuxingEnabled(pt_demux_enabled)) { + return false; + } + } else if (media_type == cricket::MediaType::MEDIA_TYPE_VIDEO) { + bool pt_demux_enabled = + RtpTransceiverDirectionHasRecv(local_direction) && + (!bundle_group || (bundled_pt_demux_allowed_video && + payload_types_by_bundle[bundle_group] + .pt_demuxing_possible_video)); + if (pt_demux_enabled) { + *pt_demuxing_has_been_used_video = true; + } + if (!channel->SetPayloadTypeDemuxingEnabled(pt_demux_enabled)) { + return false; + } + } + } + return true; + }); } } // namespace webrtc diff --git a/pc/sdp_offer_answer.h b/pc/sdp_offer_answer.h index 4b14f20708..f86b900b91 100644 --- a/pc/sdp_offer_answer.h +++ b/pc/sdp_offer_answer.h @@ -13,6 +13,7 @@ #include #include + #include #include #include @@ -33,10 +34,12 @@ #include "api/rtp_transceiver_direction.h" #include "api/rtp_transceiver_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/set_local_description_observer_interface.h" #include "api/set_remote_description_observer_interface.h" #include "api/transport/data_channel_transport_interface.h" #include "api/turn_customizer.h" +#include "api/uma_metrics.h" #include "api/video/video_bitrate_allocator_factory.h" #include "media/base/media_channel.h" #include "media/base/stream_params.h" @@ -69,7 +72,6 @@ #include "rtc_base/race_checker.h" #include "rtc_base/rtc_certificate.h" #include "rtc_base/ssl_stream_adapter.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" @@ -172,19 +174,6 @@ class SdpOfferAnswerHandler : public SdpStateProvider, absl::optional is_caller(); bool HasNewIceCredentials(); void UpdateNegotiationNeeded(); - void SetHavePendingRtpDataChannel() { - RTC_DCHECK_RUN_ON(signaling_thread()); - have_pending_rtp_data_channel_ = true; - } - - // Returns the media section in the given session description that is - // associated with the RtpTransceiver. Returns null if none found or this - // RtpTransceiver is not associated. Logic varies depending on the - // SdpSemantics specified in the configuration. 
- const cricket::ContentInfo* FindMediaSectionForTransceiver( - rtc::scoped_refptr> - transceiver, - const SessionDescriptionInterface* sdesc) const; // Destroys all BaseChannels and destroys the SCTP data channel, if present. void DestroyAllChannels(); @@ -239,9 +228,13 @@ class SdpOfferAnswerHandler : public SdpStateProvider, // Synchronous implementations of SetLocalDescription/SetRemoteDescription // that return an RTCError instead of invoking a callback. RTCError ApplyLocalDescription( - std::unique_ptr desc); + std::unique_ptr desc, + const std::map& + bundle_groups_by_mid); RTCError ApplyRemoteDescription( - std::unique_ptr desc); + std::unique_ptr desc, + const std::map& + bundle_groups_by_mid); // Implementation of the offer/answer exchange operations. These are chained // onto the |operations_chain_| when the public CreateOffer(), CreateAnswer(), @@ -263,9 +256,12 @@ class SdpOfferAnswerHandler : public SdpStateProvider, void ChangeSignalingState( PeerConnectionInterface::SignalingState signaling_state); - RTCError UpdateSessionState(SdpType type, - cricket::ContentSource source, - const cricket::SessionDescription* description); + RTCError UpdateSessionState( + SdpType type, + cricket::ContentSource source, + const cricket::SessionDescription* description, + const std::map& + bundle_groups_by_mid); bool IsUnifiedPlan() const RTC_RUN_ON(signaling_thread()); @@ -298,9 +294,11 @@ class SdpOfferAnswerHandler : public SdpStateProvider, bool CheckIfNegotiationIsNeeded(); void GenerateNegotiationNeededEvent(); // Helper method which verifies SDP. - RTCError ValidateSessionDescription(const SessionDescriptionInterface* sdesc, - cricket::ContentSource source) - RTC_RUN_ON(signaling_thread()); + RTCError ValidateSessionDescription( + const SessionDescriptionInterface* sdesc, + cricket::ContentSource source, + const std::map& + bundle_groups_by_mid) RTC_RUN_ON(signaling_thread()); // Updates the local RtpTransceivers according to the JSEP rules. Called as // part of setting the local/remote description. @@ -308,7 +306,9 @@ class SdpOfferAnswerHandler : public SdpStateProvider, cricket::ContentSource source, const SessionDescriptionInterface& new_session, const SessionDescriptionInterface* old_local_description, - const SessionDescriptionInterface* old_remote_description); + const SessionDescriptionInterface* old_remote_description, + const std::map& + bundle_groups_by_mid); // Associate the given transceiver according to the JSEP rules. RTCErrorOr< @@ -321,14 +321,13 @@ class SdpOfferAnswerHandler : public SdpStateProvider, const cricket::ContentInfo* old_remote_content) RTC_RUN_ON(signaling_thread()); - // If the BUNDLE policy is max-bundle, then we know for sure that all - // transports will be bundled from the start. This method returns the BUNDLE - // group if that's the case, or null if BUNDLE will be negotiated later. An - // error is returned if max-bundle is specified but the session description - // does not have a BUNDLE group. - RTCErrorOr GetEarlyBundleGroup( - const cricket::SessionDescription& desc) const - RTC_RUN_ON(signaling_thread()); + // Returns the media section in the given session description that is + // associated with the RtpTransceiver. Returns null if none found or this + // RtpTransceiver is not associated. Logic varies depending on the + // SdpSemantics specified in the configuration. 
+ const cricket::ContentInfo* FindMediaSectionForTransceiver( + const RtpTransceiver* transceiver, + const SessionDescriptionInterface* sdesc) const; // Either creates or destroys the transceiver's BaseChannel according to the // given media section. @@ -422,7 +421,7 @@ class SdpOfferAnswerHandler : public SdpStateProvider, // |removed_streams| is the list of streams which no longer have a receiving // track so should be removed. void ProcessRemovalOfRemoteTrack( - rtc::scoped_refptr> + const rtc::scoped_refptr> transceiver, std::vector>* remove_list, std::vector>* removed_streams); @@ -455,31 +454,16 @@ class SdpOfferAnswerHandler : public SdpStateProvider, cricket::MediaType media_type, StreamCollection* new_streams); + // Enables media channels to allow sending of media. + // This enables media to flow on all configured audio/video channels. + void EnableSending(); // Push the media parts of the local or remote session description - // down to all of the channels, and enable sending if applicable. - RTCError PushdownMediaDescription(SdpType type, - cricket::ContentSource source); - - struct PayloadTypeDemuxingUpdate { - PayloadTypeDemuxingUpdate(cricket::ChannelInterface* channel, bool enabled) - : channel(channel), enabled(enabled) {} - cricket::ChannelInterface* channel; - bool enabled; - }; - struct ContentUpdate { - ContentUpdate(cricket::ChannelInterface* channel, - const cricket::MediaContentDescription* content_description) - : channel(channel), content_description(content_description) {} - cricket::ChannelInterface* channel; - const cricket::MediaContentDescription* content_description; - }; - // Helper method used by PushdownMediaDescription to apply a batch of updates - // to BaseChannels on the worker thread. - RTCError ApplyChannelUpdates( + // down to all of the channels. + RTCError PushdownMediaDescription( SdpType type, cricket::ContentSource source, - std::vector payload_type_demuxing_updates, - std::vector content_updates); + const std::map& + bundle_groups_by_mid); RTCError PushdownTransportDescription(cricket::ContentSource source, SdpType type); @@ -510,8 +494,6 @@ class SdpOfferAnswerHandler : public SdpStateProvider, bool ReadyToUseRemoteCandidate(const IceCandidateInterface* candidate, const SessionDescriptionInterface* remote_desc, bool* valid); - void ReportRemoteIceCandidateAdded(const cricket::Candidate& candidate) - RTC_RUN_ON(signaling_thread()); RTCErrorOr FindContentInfo( const SessionDescriptionInterface* description, @@ -539,7 +521,7 @@ class SdpOfferAnswerHandler : public SdpStateProvider, // Destroys the RTP data channel transport and/or the SCTP data channel // transport and clears it. - void DestroyDataChannelTransport(); + void DestroyDataChannelTransport(RTCError error); // Destroys the given ChannelInterface. // The channel cannot be accessed after this method is called. @@ -566,15 +548,12 @@ class SdpOfferAnswerHandler : public SdpStateProvider, cricket::MediaDescriptionOptions GetMediaDescriptionOptionsForRejectedData( const std::string& mid) const; - const std::string GetTransportName(const std::string& content_name); - - // Based on number of transceivers per media type, and their bundle status and - // payload types, determine whether payload type based demuxing should be - // enabled or disabled. Returns a list of channels and the corresponding - // value to be passed into SetPayloadTypeDemuxingEnabled, so that this action - // can be combined with other operations on the worker thread. 
- std::vector GetPayloadTypeDemuxingUpdates( - cricket::ContentSource source); + // Based on number of transceivers per media type, enabled or disable + // payload type based demuxing in the affected channels. + bool UpdatePayloadTypeDemuxingState( + cricket::ContentSource source, + const std::map& + bundle_groups_by_mid); // ================================================================== // Access to pc_ variables @@ -651,6 +630,11 @@ class SdpOfferAnswerHandler : public SdpStateProvider, uint32_t negotiation_needed_event_id_ = 0; bool update_negotiation_needed_on_empty_chain_ RTC_GUARDED_BY(signaling_thread()) = false; + // If PT demuxing is successfully negotiated one time we will allow PT + // demuxing for the rest of the session so that PT-based apps default to PT + // demuxing in follow-up O/A exchanges. + bool pt_demuxing_has_been_used_audio_ = false; + bool pt_demuxing_has_been_used_video_ = false; // In Unified Plan, if we encounter remote SDP that does not contain an a=msid // line we create and use a stream with a random ID for our receivers. This is @@ -659,13 +643,15 @@ class SdpOfferAnswerHandler : public SdpStateProvider, rtc::scoped_refptr missing_msid_default_stream_ RTC_GUARDED_BY(signaling_thread()); - // Used when rolling back RTP data channels. - bool have_pending_rtp_data_channel_ RTC_GUARDED_BY(signaling_thread()) = - false; - // Updates the error state, signaling if necessary. void SetSessionError(SessionError error, const std::string& error_desc); + // Implements AddIceCandidate without reporting usage, but returns the + // particular success/error value that should be reported (and can be utilized + // for other purposes). + AddIceCandidateResult AddIceCandidateInternal( + const IceCandidateInterface* candidate); + SessionError session_error_ RTC_GUARDED_BY(signaling_thread()) = SessionError::kNone; std::string session_error_desc_ RTC_GUARDED_BY(signaling_thread()); @@ -678,8 +664,9 @@ class SdpOfferAnswerHandler : public SdpStateProvider, // specified by the user (or by the remote party). // The generator is not used directly, instead it is passed on to the // channel manager and the session description factory. - rtc::UniqueRandomIdGenerator ssrc_generator_ - RTC_GUARDED_BY(signaling_thread()); + // TODO(bugs.webrtc.org/12666): This variable is used from both the signaling + // and worker threads. See if we can't restrict usage to a single thread. + rtc::UniqueRandomIdGenerator ssrc_generator_; // A video bitrate allocator factory. 
// This can be injected using the PeerConnectionDependencies, diff --git a/pc/sdp_serializer.cc b/pc/sdp_serializer.cc index 7ebaffda86..107431627c 100644 --- a/pc/sdp_serializer.cc +++ b/pc/sdp_serializer.cc @@ -10,12 +10,14 @@ #include "pc/sdp_serializer.h" +#include +#include #include #include #include #include "absl/algorithm/container.h" -#include "api/jsep.h" +#include "absl/types/optional.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "rtc_base/checks.h" #include "rtc_base/string_encode.h" diff --git a/pc/sdp_serializer.h b/pc/sdp_serializer.h index 476ebafbdc..1223cd1af7 100644 --- a/pc/sdp_serializer.h +++ b/pc/sdp_serializer.h @@ -17,6 +17,7 @@ #include "api/rtc_error.h" #include "media/base/rid_description.h" #include "pc/session_description.h" +#include "pc/simulcast_description.h" namespace webrtc { diff --git a/pc/sdp_utils.cc b/pc/sdp_utils.cc index f5385a6529..b750b04a46 100644 --- a/pc/sdp_utils.cc +++ b/pc/sdp_utils.cc @@ -11,10 +11,10 @@ #include "pc/sdp_utils.h" #include -#include #include #include "api/jsep_session_description.h" +#include "rtc_base/checks.h" namespace webrtc { diff --git a/pc/sdp_utils.h b/pc/sdp_utils.h index fc4b289f91..effd7cd034 100644 --- a/pc/sdp_utils.h +++ b/pc/sdp_utils.h @@ -16,6 +16,7 @@ #include #include "api/jsep.h" +#include "p2p/base/transport_info.h" #include "pc/session_description.h" #include "rtc_base/system/rtc_export.h" diff --git a/pc/session_description.cc b/pc/session_description.cc index 87d6667270..7b878cbf7b 100644 --- a/pc/session_description.cc +++ b/pc/session_description.cc @@ -10,12 +10,10 @@ #include "pc/session_description.h" -#include #include #include "absl/algorithm/container.h" #include "absl/memory/memory.h" -#include "pc/media_protocol_names.h" #include "rtc_base/checks.h" namespace cricket { @@ -87,6 +85,18 @@ bool ContentGroup::RemoveContentName(const std::string& content_name) { return true; } +std::string ContentGroup::ToString() const { + rtc::StringBuilder acc; + acc << semantics_ << "("; + if (!content_names_.empty()) { + for (const auto& name : content_names_) { + acc << name << " "; + } + } + acc << ")"; + return acc.Release(); +} + SessionDescription::SessionDescription() = default; SessionDescription::SessionDescription(const SessionDescription&) = default; @@ -261,6 +271,17 @@ const ContentGroup* SessionDescription::GetGroupByName( return NULL; } +std::vector SessionDescription::GetGroupsByName( + const std::string& name) const { + std::vector content_groups; + for (const ContentGroup& content_group : content_groups_) { + if (content_group.semantics() == name) { + content_groups.push_back(&content_group); + } + } + return content_groups; +} + ContentInfo::~ContentInfo() { } diff --git a/pc/session_description.h b/pc/session_description.h index 52a3a1fe04..a20caf624a 100644 --- a/pc/session_description.h +++ b/pc/session_description.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -24,15 +25,18 @@ #include "api/crypto_params.h" #include "api/media_types.h" #include "api/rtp_parameters.h" +#include "api/rtp_transceiver_direction.h" #include "api/rtp_transceiver_interface.h" +#include "media/base/codec.h" #include "media/base/media_channel.h" #include "media/base/media_constants.h" +#include "media/base/rid_description.h" #include "media/base/stream_params.h" #include "p2p/base/transport_description.h" #include "p2p/base/transport_info.h" #include "pc/media_protocol_names.h" #include "pc/simulcast_description.h" -#include "rtc_base/deprecation.h" 
+#include "rtc_base/checks.h" #include "rtc_base/socket_address.h" #include "rtc_base/system/rtc_export.h" @@ -40,7 +44,6 @@ namespace cricket { typedef std::vector AudioCodecs; typedef std::vector VideoCodecs; -typedef std::vector RtpDataCodecs; typedef std::vector CryptoParamsVec; typedef std::vector RtpHeaderExtensions; @@ -56,7 +59,6 @@ const int kAutoBandwidth = -1; class AudioContentDescription; class VideoContentDescription; -class RtpDataContentDescription; class SctpDataContentDescription; class UnsupportedContentDescription; @@ -79,11 +81,6 @@ class MediaContentDescription { virtual VideoContentDescription* as_video() { return nullptr; } virtual const VideoContentDescription* as_video() const { return nullptr; } - virtual RtpDataContentDescription* as_rtp_data() { return nullptr; } - virtual const RtpDataContentDescription* as_rtp_data() const { - return nullptr; - } - virtual SctpDataContentDescription* as_sctp() { return nullptr; } virtual const SctpDataContentDescription* as_sctp() const { return nullptr; } @@ -146,6 +143,11 @@ class MediaContentDescription { cryptos_ = cryptos; } + // List of RTP header extensions. URIs are **NOT** guaranteed to be unique + // as they can appear twice when both encrypted and non-encrypted extensions + // are present. + // Use RtpExtension::FindHeaderExtensionByUri for finding and + // RtpExtension::DeduplicateHeaderExtensions for filtering. virtual const RtpHeaderExtensions& rtp_header_extensions() const { return rtp_header_extensions_; } @@ -272,10 +274,7 @@ class MediaContentDescription { webrtc::RtpTransceiverDirection direction_ = webrtc::RtpTransceiverDirection::kSendRecv; rtc::SocketAddress connection_address_; - // Mixed one- and two-byte header not included in offer on media level or - // session level, but we will respond that we support it. The plan is to add - // it to our offer on session level. See todo in SessionDescription. - ExtmapAllowMixed extmap_allow_mixed_enum_ = kNo; + ExtmapAllowMixed extmap_allow_mixed_enum_ = kMedia; SimulcastDescription simulcast_; std::vector receive_rids_; @@ -360,20 +359,6 @@ class VideoContentDescription : public MediaContentDescriptionImpl { } }; -class RtpDataContentDescription - : public MediaContentDescriptionImpl { - public: - RtpDataContentDescription() {} - MediaType type() const override { return MEDIA_TYPE_DATA; } - RtpDataContentDescription* as_rtp_data() override { return this; } - const RtpDataContentDescription* as_rtp_data() const override { return this; } - - private: - RtpDataContentDescription* CloneInternal() const override { - return new RtpDataContentDescription(*this); - } -}; - class SctpDataContentDescription : public MediaContentDescription { public: SctpDataContentDescription() {} @@ -503,6 +488,8 @@ class ContentGroup { bool HasContentName(const std::string& content_name) const; void AddContentName(const std::string& content_name); bool RemoveContentName(const std::string& content_name); + // for debugging + std::string ToString() const; private: std::string semantics_; @@ -587,6 +574,8 @@ class SessionDescription { // Group accessors. const ContentGroups& groups() const { return content_groups_; } const ContentGroup* GetGroupByName(const std::string& name) const; + std::vector GetGroupsByName( + const std::string& name) const; bool HasGroup(const std::string& name) const; // Group mutators. @@ -633,12 +622,7 @@ class SessionDescription { // Default to what Plan B would do. // TODO(bugs.webrtc.org/8530): Change default to kMsidSignalingMediaSection. 
int msid_signaling_ = kMsidSignalingSsrcAttribute; - // TODO(webrtc:9985): Activate mixed one- and two-byte header extension in - // offer at session level. It's currently not included in offer by default - // because clients prior to https://bugs.webrtc.org/9712 cannot parse this - // correctly. If it's included in offer to us we will respond that we support - // it. - bool extmap_allow_mixed_ = false; + bool extmap_allow_mixed_ = true; }; // Indicates whether a session description was sent by the local client or diff --git a/pc/session_description_unittest.cc b/pc/session_description_unittest.cc index 75e0974ecd..00ce538398 100644 --- a/pc/session_description_unittest.cc +++ b/pc/session_description_unittest.cc @@ -17,7 +17,8 @@ namespace cricket { TEST(MediaContentDescriptionTest, ExtmapAllowMixedDefaultValue) { VideoContentDescription video_desc; - EXPECT_EQ(MediaContentDescription::kNo, video_desc.extmap_allow_mixed_enum()); + EXPECT_EQ(MediaContentDescription::kMedia, + video_desc.extmap_allow_mixed_enum()); } TEST(MediaContentDescriptionTest, SetExtmapAllowMixed) { @@ -129,16 +130,6 @@ TEST(SessionDescriptionTest, AddContentTransfersExtmapAllowMixedSetting) { EXPECT_EQ(MediaContentDescription::kSession, session_desc.GetContentDescriptionByName("video") ->extmap_allow_mixed_enum()); - - // Session level setting overrides media level when new content is added. - std::unique_ptr data_desc = - std::make_unique(); - data_desc->set_extmap_allow_mixed_enum(MediaContentDescription::kMedia); - session_desc.AddContent("data", MediaProtocolType::kRtp, - std::move(data_desc)); - EXPECT_EQ(MediaContentDescription::kSession, - session_desc.GetContentDescriptionByName("data") - ->extmap_allow_mixed_enum()); } } // namespace cricket diff --git a/pc/simulcast_description.cc b/pc/simulcast_description.cc index 8b510febaa..0ae3e2074e 100644 --- a/pc/simulcast_description.cc +++ b/pc/simulcast_description.cc @@ -10,8 +10,6 @@ #include "pc/simulcast_description.h" -#include - #include "rtc_base/checks.h" namespace cricket { diff --git a/pc/simulcast_description.h b/pc/simulcast_description.h index 1337a9ce4d..f7ae28837e 100644 --- a/pc/simulcast_description.h +++ b/pc/simulcast_description.h @@ -11,6 +11,8 @@ #ifndef PC_SIMULCAST_DESCRIPTION_H_ #define PC_SIMULCAST_DESCRIPTION_H_ +#include + #include #include diff --git a/pc/srtp_filter.cc b/pc/srtp_filter.cc index bd48eac83d..2f8d06cbea 100644 --- a/pc/srtp_filter.cc +++ b/pc/srtp_filter.cc @@ -11,8 +11,8 @@ #include "pc/srtp_filter.h" #include - #include +#include #include "absl/strings/match.h" #include "rtc_base/logging.h" @@ -210,9 +210,9 @@ bool SrtpFilter::ApplySendParams(const CryptoParams& send_params) { int send_key_len, send_salt_len; if (!rtc::GetSrtpKeyAndSaltLengths(*send_cipher_suite_, &send_key_len, &send_salt_len)) { - RTC_LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):" - " send cipher_suite " - << send_params.cipher_suite; + RTC_LOG(LS_ERROR) << "Could not get lengths for crypto suite(s):" + " send cipher_suite " + << send_params.cipher_suite; return false; } @@ -241,9 +241,9 @@ bool SrtpFilter::ApplyRecvParams(const CryptoParams& recv_params) { int recv_key_len, recv_salt_len; if (!rtc::GetSrtpKeyAndSaltLengths(*recv_cipher_suite_, &recv_key_len, &recv_salt_len)) { - RTC_LOG(LS_WARNING) << "Could not get lengths for crypto suite(s):" - " recv cipher_suite " - << recv_params.cipher_suite; + RTC_LOG(LS_ERROR) << "Could not get lengths for crypto suite(s):" + " recv cipher_suite " + << recv_params.cipher_suite; return 
false; } diff --git a/pc/srtp_filter.h b/pc/srtp_filter.h index fc60a356fe..f1e164936c 100644 --- a/pc/srtp_filter.h +++ b/pc/srtp_filter.h @@ -11,6 +11,9 @@ #ifndef PC_SRTP_FILTER_H_ #define PC_SRTP_FILTER_H_ +#include +#include + #include #include #include @@ -21,11 +24,11 @@ #include "api/array_view.h" #include "api/crypto_params.h" #include "api/jsep.h" +#include "api/sequence_checker.h" #include "pc/session_description.h" #include "rtc_base/buffer.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ssl_stream_adapter.h" -#include "rtc_base/thread_checker.h" // Forward declaration to avoid pulling in libsrtp headers here struct srtp_event_data_t; diff --git a/pc/srtp_session.cc b/pc/srtp_session.cc index 0315c6a63e..45f6b67d12 100644 --- a/pc/srtp_session.cc +++ b/pc/srtp_session.cc @@ -80,6 +80,10 @@ bool SrtpSession::ProtectRtp(void* p, int in_len, int max_len, int* out_len) { return false; } + // Note: the need_len differs from the libsrtp recommendatіon to ensure + // SRTP_MAX_TRAILER_LEN bytes of free space after the data. WebRTC + // never includes a MKI, therefore the amount of bytes added by the + // srtp_protect call is known in advance and depends on the cipher suite. int need_len = in_len + rtp_auth_tag_len_; // NOLINT if (max_len < need_len) { RTC_LOG(LS_WARNING) << "Failed to protect SRTP packet: The buffer length " @@ -122,6 +126,10 @@ bool SrtpSession::ProtectRtcp(void* p, int in_len, int max_len, int* out_len) { return false; } + // Note: the need_len differs from the libsrtp recommendatіon to ensure + // SRTP_MAX_TRAILER_LEN bytes of free space after the data. WebRTC + // never includes a MKI, therefore the amount of bytes added by the + // srtp_protect_rtp call is known in advance and depends on the cipher suite. int need_len = in_len + sizeof(uint32_t) + rtcp_auth_tag_len_; // NOLINT if (max_len < need_len) { RTC_LOG(LS_WARNING) << "Failed to protect SRTCP packet: The buffer length " @@ -261,42 +269,18 @@ bool SrtpSession::DoSetKey(int type, srtp_policy_t policy; memset(&policy, 0, sizeof(policy)); - if (cs == rtc::SRTP_AES128_CM_SHA1_80) { - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - } else if (cs == rtc::SRTP_AES128_CM_SHA1_32) { - // RTP HMAC is shortened to 32 bits, but RTCP remains 80 bits. - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); - srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - } else if (cs == rtc::SRTP_AEAD_AES_128_GCM) { - srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); - srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); - } else if (cs == rtc::SRTP_AEAD_AES_256_GCM) { - srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); - srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); - } else { - RTC_LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") - << " SRTP session: unsupported cipher_suite " << cs; - return false; - } - - int expected_key_len; - int expected_salt_len; - if (!rtc::GetSrtpKeyAndSaltLengths(cs, &expected_key_len, - &expected_salt_len)) { - // This should never happen. - RTC_NOTREACHED(); - RTC_LOG(LS_WARNING) - << "Failed to " << (session_ ? 
"update" : "create") - << " SRTP session: unsupported cipher_suite without length information" - << cs; + if (!(srtp_crypto_policy_set_from_profile_for_rtp( + &policy.rtp, (srtp_profile_t)cs) == srtp_err_status_ok && + srtp_crypto_policy_set_from_profile_for_rtcp( + &policy.rtcp, (srtp_profile_t)cs) == srtp_err_status_ok)) { + RTC_LOG(LS_ERROR) << "Failed to " << (session_ ? "update" : "create") + << " SRTP session: unsupported cipher_suite " << cs; return false; } - if (!key || - len != static_cast(expected_key_len + expected_salt_len)) { - RTC_LOG(LS_WARNING) << "Failed to " << (session_ ? "update" : "create") - << " SRTP session: invalid key"; + if (!key || len != static_cast(policy.rtp.cipher_key_len)) { + RTC_LOG(LS_ERROR) << "Failed to " << (session_ ? "update" : "create") + << " SRTP session: invalid key"; return false; } @@ -477,9 +461,10 @@ void SrtpSession::DumpPacket(const void* buf, int len, bool outbound) { int64_t seconds = (time_of_day / 1000) % 60; int64_t millis = time_of_day % 1000; RTC_LOG(LS_VERBOSE) << "\n" << (outbound ? "O" : "I") << " " - << std::setw(2) << hours << ":" << std::setw(2) << minutes << ":" - << std::setw(2) << seconds << "." << std::setw(3) - << std::setfill('0') << millis << " " + << std::setfill('0') << std::setw(2) << hours << ":" + << std::setfill('0') << std::setw(2) << minutes << ":" + << std::setfill('0') << std::setw(2) << seconds << "." + << std::setfill('0') << std::setw(3) << millis << " " << "000000 " << rtc::hex_encode_with_delimiter((const char *)buf, len, ' ') << " # RTP_DUMP"; } diff --git a/pc/srtp_session.h b/pc/srtp_session.h index 62327a9039..0396412481 100644 --- a/pc/srtp_session.h +++ b/pc/srtp_session.h @@ -14,9 +14,9 @@ #include #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" // Forward declaration to avoid pulling in libsrtp headers here struct srtp_event_data_t; @@ -124,10 +124,16 @@ class SrtpSession { void HandleEvent(const srtp_event_data_t* ev); static void HandleEventThunk(srtp_event_data_t* ev); - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; srtp_ctx_t_* session_ = nullptr; + + // Overhead of the SRTP auth tag for RTP and RTCP in bytes. + // Depends on the cipher suite used and is usually the same with the exception + // of the CS_AES_CM_128_HMAC_SHA1_32 cipher suite. The additional four bytes + // required for RTCP protection are not included. int rtp_auth_tag_len_ = 0; int rtcp_auth_tag_len_ = 0; + bool inited_ = false; static webrtc::GlobalMutex lock_; int last_send_seq_num_ = -1; diff --git a/pc/srtp_transport.cc b/pc/srtp_transport.cc index 6acb6b327b..c90b3fa227 100644 --- a/pc/srtp_transport.cc +++ b/pc/srtp_transport.cc @@ -10,7 +10,6 @@ #include "pc/srtp_transport.h" -#include #include #include @@ -202,12 +201,12 @@ bool SrtpTransport::SendRtcpPacket(rtc::CopyOnWriteBuffer* packet, void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { + TRACE_EVENT0("webrtc", "SrtpTransport::OnRtpPacketReceived"); if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Inactive SRTP transport received an RTP packet. 
Drop it."; return; } - TRACE_EVENT0("webrtc", "SRTP Decode"); char* data = packet.MutableData(); int len = rtc::checked_cast(packet.size()); if (!UnprotectRtp(data, len, &len)) { @@ -234,12 +233,12 @@ void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet, void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) { + TRACE_EVENT0("webrtc", "SrtpTransport::OnRtcpPacketReceived"); if (!IsSrtpActive()) { RTC_LOG(LS_WARNING) << "Inactive SRTP transport received an RTCP packet. Drop it."; return; } - TRACE_EVENT0("webrtc", "SRTP Decode"); char* data = packet.MutableData(); int len = rtc::checked_cast(packet.size()); if (!UnprotectRtcp(data, len, &len)) { diff --git a/pc/stats_collector.cc b/pc/stats_collector.cc index 991cc4eb2b..eb2176ed38 100644 --- a/pc/stats_collector.cc +++ b/pc/stats_collector.cc @@ -10,14 +10,47 @@ #include "pc/stats_collector.h" +#include +#include + #include #include #include #include +#include "absl/types/optional.h" +#include "api/audio_codecs/audio_encoder.h" +#include "api/candidate.h" +#include "api/data_channel_interface.h" +#include "api/media_types.h" +#include "api/rtp_receiver_interface.h" +#include "api/rtp_sender_interface.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" +#include "api/video/video_content_type.h" +#include "api/video/video_timing.h" +#include "call/call.h" +#include "media/base/media_channel.h" +#include "modules/audio_processing/include/audio_processing_statistics.h" +#include "p2p/base/ice_transport_internal.h" +#include "p2p/base/p2p_constants.h" #include "pc/channel.h" +#include "pc/channel_interface.h" +#include "pc/data_channel_utils.h" +#include "pc/rtp_receiver.h" +#include "pc/rtp_transceiver.h" +#include "pc/transport_stats.h" #include "rtc_base/checks.h" -#include "rtc_base/third_party/base64/base64.h" +#include "rtc_base/ip_address.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" +#include "rtc_base/rtc_certificate.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/string_encode.h" +#include "rtc_base/thread.h" +#include "rtc_base/time_utils.h" +#include "rtc_base/trace_event.h" #include "system_wrappers/include/field_trial.h" namespace webrtc { @@ -286,6 +319,10 @@ void ExtractStats(const cricket::VideoReceiverInfo& info, if (info.qp_sum) report->AddInt64(StatsReport::kStatsValueNameQpSum, *info.qp_sum); + if (info.nacks_sent) { + report->AddInt(StatsReport::kStatsValueNameNacksSent, *info.nacks_sent); + } + const IntForAdd ints[] = { {StatsReport::kStatsValueNameCurrentDelayMs, info.current_delay_ms}, {StatsReport::kStatsValueNameDecodeMs, info.decode_ms}, @@ -299,7 +336,6 @@ void ExtractStats(const cricket::VideoReceiverInfo& info, {StatsReport::kStatsValueNameMaxDecodeMs, info.max_decode_ms}, {StatsReport::kStatsValueNameMinPlayoutDelayMs, info.min_playout_delay_ms}, - {StatsReport::kStatsValueNameNacksSent, info.nacks_sent}, {StatsReport::kStatsValueNamePacketsLost, info.packets_lost}, {StatsReport::kStatsValueNamePacketsReceived, info.packets_rcvd}, {StatsReport::kStatsValueNamePlisSent, info.plis_sent}, @@ -508,7 +544,7 @@ StatsCollector::StatsCollector(PeerConnectionInternal* pc) } StatsCollector::~StatsCollector() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); } // Wallclock time in ms. 
@@ -519,7 +555,7 @@ double StatsCollector::GetTimeNow() { // Adds a MediaStream with tracks that can be used as a |selector| in a call // to GetStats. void StatsCollector::AddStream(MediaStreamInterface* stream) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); RTC_DCHECK(stream != NULL); CreateTrackReports(stream->GetAudioTracks(), &reports_, @@ -542,7 +578,7 @@ void StatsCollector::AddTrack(MediaStreamTrackInterface* track) { void StatsCollector::AddLocalAudioTrack(AudioTrackInterface* audio_track, uint32_t ssrc) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); RTC_DCHECK(audio_track != NULL); #if RTC_DCHECK_IS_ON for (const auto& track : local_audio_tracks_) @@ -576,7 +612,7 @@ void StatsCollector::RemoveLocalAudioTrack(AudioTrackInterface* audio_track, void StatsCollector::GetStats(MediaStreamTrackInterface* track, StatsReports* reports) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); RTC_DCHECK(reports != NULL); RTC_DCHECK(reports->empty()); @@ -616,26 +652,33 @@ void StatsCollector::GetStats(MediaStreamTrackInterface* track, void StatsCollector::UpdateStats( PeerConnectionInterface::StatsOutputLevel level) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); - double time_now = GetTimeNow(); - // Calls to UpdateStats() that occur less than kMinGatherStatsPeriod number of - // ms apart will be ignored. - const double kMinGatherStatsPeriod = 50; - if (stats_gathering_started_ != 0 && - stats_gathering_started_ + kMinGatherStatsPeriod > time_now) { + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); + // Calls to UpdateStats() that occur less than kMinGatherStatsPeriodMs apart + // will be ignored. Using a monotonic clock specifically for this, while using + // a UTC clock for the reports themselves. + const int64_t kMinGatherStatsPeriodMs = 50; + int64_t cache_now_ms = rtc::TimeMillis(); + if (cache_timestamp_ms_ != 0 && + cache_timestamp_ms_ + kMinGatherStatsPeriodMs > cache_now_ms) { return; } - stats_gathering_started_ = time_now; + cache_timestamp_ms_ = cache_now_ms; + stats_gathering_started_ = GetTimeNow(); + + // TODO(tommi): ExtractSessionInfo now has a single hop to the network thread + // to fetch stats, then applies them on the signaling thread. See if we need + // to do this synchronously or if updating the stats without blocking is safe. + std::map transport_names_by_mid = + ExtractSessionInfo(); // TODO(tommi): All of these hop over to the worker thread to fetch - // information. We could use an AsyncInvoker to run all of these and post + // information. We could post a task to run all of these and post // the information back to the signaling thread where we can create and // update stats reports. That would also clean up the threading story a bit // since we'd be creating/updating the stats report objects consistently on // the same thread (this class has no locks right now). - ExtractSessionInfo(); ExtractBweInfo(); - ExtractMediaInfo(); + ExtractMediaInfo(transport_names_by_mid); ExtractSenderInfo(); ExtractDataInfo(); UpdateTrackReports(); @@ -646,7 +689,7 @@ StatsReport* StatsCollector::PrepareReport(bool local, const std::string& track_id, const StatsReport::Id& transport_id, StatsReport::Direction direction) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); StatsReport::Id id(StatsReport::NewIdWithDirection( local ? 
StatsReport::kStatsReportTypeSsrc : StatsReport::kStatsReportTypeRemoteSsrc, @@ -669,7 +712,7 @@ StatsReport* StatsCollector::PrepareReport(bool local, } StatsReport* StatsCollector::PrepareADMReport() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); StatsReport::Id id(StatsReport::NewTypedId( StatsReport::kStatsReportTypeSession, pc_->session_id())); StatsReport* report = reports_.FindOrAddNew(id); @@ -683,7 +726,7 @@ bool StatsCollector::IsValidTrack(const std::string& track_id) { StatsReport* StatsCollector::AddCertificateReports( std::unique_ptr cert_stats) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); StatsReport* first_report = nullptr; StatsReport* prev_report = nullptr; @@ -771,7 +814,7 @@ StatsReport* StatsCollector::AddConnectionInfoReport( StatsReport* StatsCollector::AddCandidateReport( const cricket::CandidateStats& candidate_stats, bool local) { - const auto& candidate = candidate_stats.candidate; + const auto& candidate = candidate_stats.candidate(); StatsReport::Id id(StatsReport::NewCandidateId(local, candidate.id())); StatsReport* report = reports_.Find(id); if (!report) { @@ -794,8 +837,8 @@ StatsReport* StatsCollector::AddCandidateReport( } report->set_timestamp(stats_gathering_started_); - if (local && candidate_stats.stun_stats.has_value()) { - const auto& stun_stats = candidate_stats.stun_stats.value(); + if (local && candidate_stats.stun_stats().has_value()) { + const auto& stun_stats = candidate_stats.stun_stats().value(); report->AddInt64(StatsReport::kStatsValueNameSentStunKeepaliveRequests, stun_stats.stun_binding_requests_sent); report->AddInt64(StatsReport::kStatsValueNameRecvStunKeepaliveResponses, @@ -809,35 +852,58 @@ StatsReport* StatsCollector::AddCandidateReport( return report; } -void StatsCollector::ExtractSessionInfo() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); +std::map StatsCollector::ExtractSessionInfo() { + TRACE_EVENT0("webrtc", "StatsCollector::ExtractSessionInfo"); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); - // Extract information from the base session. 
- StatsReport::Id id(StatsReport::NewTypedId( - StatsReport::kStatsReportTypeSession, pc_->session_id())); - StatsReport* report = reports_.ReplaceOrAddNew(id); - report->set_timestamp(stats_gathering_started_); - report->AddBoolean(StatsReport::kStatsValueNameInitiator, - pc_->initial_offerer()); + SessionStats stats; + auto transceivers = pc_->GetTransceiversInternal(); + pc_->network_thread()->Invoke( + RTC_FROM_HERE, [&, sctp_transport_name = pc_->sctp_transport_name(), + sctp_mid = pc_->sctp_mid()]() mutable { + stats = ExtractSessionInfo_n( + transceivers, std::move(sctp_transport_name), std::move(sctp_mid)); + }); - cricket::CandidateStatsList pooled_candidate_stats_list = - pc_->GetPooledCandidateStats(); + ExtractSessionInfo_s(stats); - for (const cricket::CandidateStats& stats : pooled_candidate_stats_list) { - AddCandidateReport(stats, true); + return std::move(stats.transport_names_by_mid); +} + +StatsCollector::SessionStats StatsCollector::ExtractSessionInfo_n( + const std::vector>>& transceivers, + absl::optional sctp_transport_name, + absl::optional sctp_mid) { + TRACE_EVENT0("webrtc", "StatsCollector::ExtractSessionInfo_n"); + RTC_DCHECK_RUN_ON(pc_->network_thread()); + rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; + SessionStats stats; + stats.candidate_stats = pc_->GetPooledCandidateStats(); + for (auto& transceiver : transceivers) { + cricket::ChannelInterface* channel = transceiver->internal()->channel(); + if (channel) { + stats.transport_names_by_mid[channel->content_name()] = + channel->transport_name(); + } + } + + if (sctp_transport_name) { + RTC_DCHECK(sctp_mid); + stats.transport_names_by_mid[*sctp_mid] = *sctp_transport_name; } std::set transport_names; - for (const auto& entry : pc_->GetTransportNamesByMid()) { + for (const auto& entry : stats.transport_names_by_mid) { transport_names.insert(entry.second); } std::map transport_stats_by_name = pc_->GetTransportStatsByNames(transport_names); - for (const auto& entry : transport_stats_by_name) { - const std::string& transport_name = entry.first; - const cricket::TransportStats& transport_stats = entry.second; + for (auto& entry : transport_stats_by_name) { + stats.transport_stats.emplace_back(entry.first, std::move(entry.second)); + TransportStats& transport = stats.transport_stats.back(); // Attempt to get a copy of the certificates from the transport and // expose them in stats reports. 
All channels in a transport share the @@ -845,24 +911,59 @@ void StatsCollector::ExtractSessionInfo() { // StatsReport::Id local_cert_report_id, remote_cert_report_id; rtc::scoped_refptr certificate; - if (pc_->GetLocalCertificate(transport_name, &certificate)) { - StatsReport* r = AddCertificateReports( - certificate->GetSSLCertificateChain().GetStats()); - if (r) - local_cert_report_id = r->id(); + if (pc_->GetLocalCertificate(transport.name, &certificate)) { + transport.local_cert_stats = + certificate->GetSSLCertificateChain().GetStats(); } std::unique_ptr remote_cert_chain = - pc_->GetRemoteSSLCertChain(transport_name); + pc_->GetRemoteSSLCertChain(transport.name); if (remote_cert_chain) { - StatsReport* r = AddCertificateReports(remote_cert_chain->GetStats()); + transport.remote_cert_stats = remote_cert_chain->GetStats(); + } + } + + return stats; +} + +void StatsCollector::ExtractSessionInfo_s(SessionStats& session_stats) { + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); + rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; + + StatsReport::Id id(StatsReport::NewTypedId( + StatsReport::kStatsReportTypeSession, pc_->session_id())); + StatsReport* report = reports_.ReplaceOrAddNew(id); + report->set_timestamp(stats_gathering_started_); + report->AddBoolean(StatsReport::kStatsValueNameInitiator, + pc_->initial_offerer()); + + for (const cricket::CandidateStats& stats : session_stats.candidate_stats) { + AddCandidateReport(stats, true); + } + + for (auto& transport : session_stats.transport_stats) { + // Attempt to get a copy of the certificates from the transport and + // expose them in stats reports. All channels in a transport share the + // same local and remote certificates. + // + StatsReport::Id local_cert_report_id, remote_cert_report_id; + if (transport.local_cert_stats) { + StatsReport* r = + AddCertificateReports(std::move(transport.local_cert_stats)); + if (r) + local_cert_report_id = r->id(); + } + + if (transport.remote_cert_stats) { + StatsReport* r = + AddCertificateReports(std::move(transport.remote_cert_stats)); if (r) remote_cert_report_id = r->id(); } - for (const auto& channel_iter : transport_stats.channel_stats) { + for (const auto& channel_iter : transport.stats.channel_stats) { StatsReport::Id id( - StatsReport::NewComponentId(transport_name, channel_iter.component)); + StatsReport::NewComponentId(transport.name, channel_iter.component)); StatsReport* channel_report = reports_.ReplaceOrAddNew(id); channel_report->set_timestamp(stats_gathering_started_); channel_report->AddInt(StatsReport::kStatsValueNameComponent, @@ -905,7 +1006,7 @@ void StatsCollector::ExtractSessionInfo() { for (const cricket::ConnectionInfo& info : channel_iter.ice_transport_stats.connection_infos) { StatsReport* connection_report = AddConnectionInfoReport( - transport_name, channel_iter.component, connection_id++, + transport.name, channel_iter.component, connection_id++, channel_report->id(), info); if (info.best_connection) { channel_report->AddId( @@ -918,7 +1019,7 @@ void StatsCollector::ExtractSessionInfo() { } void StatsCollector::ExtractBweInfo() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); if (pc_->signaling_state() == PeerConnectionInterface::kClosed) return; @@ -931,16 +1032,25 @@ void StatsCollector::ExtractBweInfo() { // Fill in target encoder bitrate, actual encoder bitrate, rtx bitrate, etc. // TODO(holmer): Also fill this in for audio. 
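// Both ExtractSessionInfo_n() above and ExtractMediaInfo() below guard their
// gather phase with rtc::Thread::ScopedDisallowBlockingCalls: the stats are
// snapshotted into plain structs on the thread that owns them, and the RAII
// guard is meant to assert (via DCHECK) if any blocking Invoke is issued from
// the current thread while it is in scope, keeping the snapshot phase
// non-blocking. A fragment sketching the idiom (BuildSnapshot() is a
// hypothetical non-blocking helper):
//
//   {
//     rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls;
//     BuildSnapshot();  // Any blocking hop inside this scope would DCHECK.
//   }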
- for (const auto& transceiver : pc_->GetTransceiversInternal()) { + auto transceivers = pc_->GetTransceiversInternal(); + std::vector<cricket::VideoChannel*> video_channels; + for (const auto& transceiver : transceivers) { if (transceiver->media_type() != cricket::MEDIA_TYPE_VIDEO) { continue; } auto* video_channel = static_cast<cricket::VideoChannel*>(transceiver->internal()->channel()); - if (!video_channel) { - continue; + if (video_channel) { + video_channels.push_back(video_channel); } - video_channel->FillBitrateInfo(&bwe_info); + } + + if (!video_channels.empty()) { + pc_->worker_thread()->Invoke<void>(RTC_FROM_HERE, [&] { + for (const auto& channel : video_channels) { + channel->FillBitrateInfo(&bwe_info); + } + }); } StatsReport::Id report_id(StatsReport::NewBandwidthEstimationId()); @@ -1053,14 +1163,16 @@ std::unique_ptr<MediaChannelStatsGatherer> CreateMediaChannelStatsGatherer( } // namespace -void StatsCollector::ExtractMediaInfo() { +void StatsCollector::ExtractMediaInfo( + const std::map<std::string, std::string>& transport_names_by_mid) { RTC_DCHECK_RUN_ON(pc_->signaling_thread()); std::vector<std::unique_ptr<MediaChannelStatsGatherer>> gatherers; + auto transceivers = pc_->GetTransceiversInternal(); { rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; - for (const auto& transceiver : pc_->GetTransceiversInternal()) { + for (const auto& transceiver : transceivers) { cricket::ChannelInterface* channel = transceiver->internal()->channel(); if (!channel) { continue; @@ -1068,22 +1180,40 @@ void StatsCollector::ExtractMediaInfo() { std::unique_ptr<MediaChannelStatsGatherer> gatherer = CreateMediaChannelStatsGatherer(channel->media_channel()); gatherer->mid = channel->content_name(); - gatherer->transport_name = channel->transport_name(); + gatherer->transport_name = transport_names_by_mid.at(gatherer->mid); + for (const auto& sender : transceiver->internal()->senders()) { - std::string track_id = (sender->track() ? sender->track()->id() : ""); + auto track = sender->track(); + std::string track_id = (track ? track->id() : ""); gatherer->sender_track_id_by_ssrc.insert( std::make_pair(sender->ssrc(), track_id)); } - for (const auto& receiver : transceiver->internal()->receivers()) { - gatherer->receiver_track_id_by_ssrc.insert(std::make_pair( - receiver->internal()->ssrc(), receiver->track()->id())); - } + + // Populating `receiver_track_id_by_ssrc` will be done on the worker + // thread as the `ssrc` property of the receiver needs to be accessed + // there. + gatherers.push_back(std::move(gatherer)); } } pc_->worker_thread()->Invoke<void>(RTC_FROM_HERE, [&] { rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; + // Populate `receiver_track_id_by_ssrc` for the gatherers. + int i = 0; + for (const auto& transceiver : transceivers) { + cricket::ChannelInterface* channel = transceiver->internal()->channel(); + if (!channel) + continue; + MediaChannelStatsGatherer* gatherer = gatherers[i++].get(); + RTC_DCHECK_EQ(gatherer->mid, channel->content_name()); + + for (const auto& receiver : transceiver->internal()->receivers()) { + gatherer->receiver_track_id_by_ssrc.insert(std::make_pair( + receiver->internal()->ssrc(), receiver->track()->id())); + } + } + for (auto it = gatherers.begin(); it != gatherers.end(); /* incremented manually */) { MediaChannelStatsGatherer* gatherer = it->get(); @@ -1109,7 +1239,7 @@ void StatsCollector::ExtractMediaInfo() { } void StatsCollector::ExtractSenderInfo() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); for (const auto& sender : pc_->GetSenders()) { // TODO(nisse): SSRC == 0 currently means none.
Delete check when @@ -1142,7 +1272,7 @@ void StatsCollector::ExtractSenderInfo() { } void StatsCollector::ExtractDataInfo() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; @@ -1166,7 +1296,7 @@ void StatsCollector::ExtractDataInfo() { StatsReport* StatsCollector::GetReport(const StatsReport::StatsType& type, const std::string& id, StatsReport::Direction direction) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); RTC_DCHECK(type == StatsReport::kStatsReportTypeSsrc || type == StatsReport::kStatsReportTypeRemoteSsrc); return reports_.Find(StatsReport::NewIdWithDirection(type, id, direction)); @@ -1174,7 +1304,7 @@ StatsReport* StatsCollector::GetReport(const StatsReport::StatsType& type, void StatsCollector::UpdateStatsFromExistingLocalAudioTracks( bool has_remote_tracks) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); // Loop through the existing local audio tracks. for (const auto& it : local_audio_tracks_) { AudioTrackInterface* track = it.first; @@ -1202,7 +1332,7 @@ void StatsCollector::UpdateStatsFromExistingLocalAudioTracks( void StatsCollector::UpdateReportFromAudioTrack(AudioTrackInterface* track, StatsReport* report, bool has_remote_tracks) { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); RTC_DCHECK(track != NULL); // Don't overwrite report values if they're not available. @@ -1224,7 +1354,7 @@ void StatsCollector::UpdateReportFromAudioTrack(AudioTrackInterface* track, } void StatsCollector::UpdateTrackReports() { - RTC_DCHECK(pc_->signaling_thread()->IsCurrent()); + RTC_DCHECK_RUN_ON(pc_->signaling_thread()); rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; @@ -1235,7 +1365,7 @@ void StatsCollector::UpdateTrackReports() { } void StatsCollector::ClearUpdateStatsCacheForTest() { - stats_gathering_started_ = 0; + cache_timestamp_ms_ = 0; } } // namespace webrtc diff --git a/pc/stats_collector.h b/pc/stats_collector.h index befbcabbf0..2fd5d9d8f8 100644 --- a/pc/stats_collector.h +++ b/pc/stats_collector.h @@ -16,6 +16,8 @@ #include +#include +#include #include #include #include @@ -25,6 +27,7 @@ #include "api/media_stream_interface.h" #include "api/peer_connection_interface.h" #include "api/stats_types.h" +#include "p2p/base/connection_info.h" #include "p2p/base/port.h" #include "pc/peer_connection_internal.h" #include "pc/stats_collector_interface.h" @@ -52,7 +55,7 @@ class StatsCollector : public StatsCollectorInterface { explicit StatsCollector(PeerConnectionInternal* pc); virtual ~StatsCollector(); - // Adds a MediaStream with tracks that can be used as a |selector| in a call + // Adds a MediaStream with tracks that can be used as a `selector` in a call // to GetStats. void AddStream(MediaStreamInterface* stream); void AddTrack(MediaStreamTrackInterface* track); @@ -70,12 +73,12 @@ class StatsCollector : public StatsCollectorInterface { void UpdateStats(PeerConnectionInterface::StatsOutputLevel level); // Gets a StatsReports of the last collected stats. Note that UpdateStats must - // be called before this function to get the most recent stats. |selector| is + // be called before this function to get the most recent stats. `selector` is // a track label or empty string. The most recent reports are stored in - // |reports|. + // `reports`. 
// TODO(tommi): Change this contract to accept a callback object instead - // of filling in |reports|. As is, there's a requirement that the caller - // uses |reports| immediately without allowing any async activity on + // of filling in `reports`. As is, there's a requirement that the caller + // uses `reports` immediately without allowing any async activity on // the thread (message handling etc) and then discard the results. void GetStats(MediaStreamTrackInterface* track, StatsReports* reports) override; @@ -103,19 +106,48 @@ class StatsCollector : public StatsCollectorInterface { private: friend class StatsCollectorTest; + // Struct that's populated on the network thread and carries the values to + // the signaling thread where the stats are added to the stats reports. + struct TransportStats { + TransportStats() = default; + TransportStats(std::string transport_name, + cricket::TransportStats transport_stats) + : name(std::move(transport_name)), stats(std::move(transport_stats)) {} + TransportStats(TransportStats&&) = default; + TransportStats(const TransportStats&) = delete; + + std::string name; + cricket::TransportStats stats; + std::unique_ptr<rtc::SSLCertificateStats> local_cert_stats; + std::unique_ptr<rtc::SSLCertificateStats> remote_cert_stats; + }; + + struct SessionStats { + SessionStats() = default; + SessionStats(SessionStats&&) = default; + SessionStats(const SessionStats&) = delete; + + SessionStats& operator=(SessionStats&&) = default; + SessionStats& operator=(SessionStats&) = delete; + + cricket::CandidateStatsList candidate_stats; + std::vector<TransportStats> transport_stats; + std::map<std::string, std::string> transport_names_by_mid; + }; + // Overridden in unit tests to fake timing. virtual double GetTimeNow(); bool CopySelectedReports(const std::string& selector, StatsReports* reports); - // Helper method for creating IceCandidate report. |is_local| indicates + // Helper method for creating IceCandidate report. `is_local` indicates // whether this candidate is local or remote. StatsReport* AddCandidateReport( const cricket::CandidateStats& candidate_stats, bool local); // Adds a report for this certificate and every certificate in its chain, and - // returns the leaf certificate's report (|cert_stats|'s report). + // returns the leaf certificate's report (`cert_stats`'s report). StatsReport* AddCertificateReports( std::unique_ptr<rtc::SSLCertificateStats> cert_stats); @@ -126,9 +158,14 @@ class StatsCollector : public StatsCollectorInterface { const cricket::ConnectionInfo& info); void ExtractDataInfo(); - void ExtractSessionInfo(); + + // Returns the `transport_names_by_mid` member from the SessionStats as + // gathered and used to populate the stats. + std::map<std::string, std::string> ExtractSessionInfo(); + void ExtractBweInfo(); - void ExtractMediaInfo(); + void ExtractMediaInfo( + const std::map<std::string, std::string>& transport_names_by_mid); void ExtractSenderInfo(); webrtc::StatsReport* GetReport(const StatsReport::StatsType& type, const std::string& id, @@ -143,11 +180,19 @@ class StatsCollector : public StatsCollectorInterface { // Helper method to update the timestamp of track records. void UpdateTrackReports(); + SessionStats ExtractSessionInfo_n( + const std::vector<rtc::scoped_refptr<RtpTransceiverProxyWithInternal<RtpTransceiver>>>& transceivers, + absl::optional<std::string> sctp_transport_name, + absl::optional<std::string> sctp_mid); + void ExtractSessionInfo_s(SessionStats& session_stats); + // A collection for all of our stats reports. StatsCollection reports_; TrackIdMap track_ids_; // Raw pointer to the peer connection the statistics are gathered from.
PeerConnectionInternal* const pc_; + int64_t cache_timestamp_ms_ = 0; double stats_gathering_started_; const bool use_standard_bytes_stats_; diff --git a/pc/stats_collector_unittest.cc b/pc/stats_collector_unittest.cc index 3767081b56..c630c3af6c 100644 --- a/pc/stats_collector_unittest.cc +++ b/pc/stats_collector_unittest.cc @@ -96,7 +96,7 @@ class FakeAudioTrack : public MediaStreamTrack<AudioTrackInterface> { public: explicit FakeAudioTrack(const std::string& id) : MediaStreamTrack(id), - processor_(new rtc::RefCountedObject<FakeAudioProcessor>()) {} + processor_(rtc::make_ref_counted<FakeAudioProcessor>()) {} std::string kind() const override { return "audio"; } AudioSourceInterface* GetSource() const override { return NULL; } void AddSink(AudioTrackSinkInterface* sink) override {} @@ -134,8 +134,7 @@ class FakeAudioTrackWithInitValue public: explicit FakeAudioTrackWithInitValue(const std::string& id) : MediaStreamTrack(id), - processor_( - new rtc::RefCountedObject<FakeAudioProcessorWithInitValue>()) {} + processor_(rtc::make_ref_counted<FakeAudioProcessorWithInitValue>()) {} std::string kind() const override { return "audio"; } AudioSourceInterface* GetSource() const override { return NULL; } void AddSink(AudioTrackSinkInterface* sink) override {} @@ -600,7 +599,7 @@ class StatsCollectorForTest : public StatsCollector { class StatsCollectorTest : public ::testing::Test { protected: rtc::scoped_refptr<FakePeerConnectionForStats> CreatePeerConnection() { - return new rtc::RefCountedObject<FakePeerConnectionForStats>(); + return rtc::make_ref_counted<FakePeerConnectionForStats>(); } std::unique_ptr<StatsCollectorForTest> CreateStatsCollector( @@ -738,8 +737,7 @@ class StatsCollectorTest : public ::testing::Test { static rtc::scoped_refptr<MockRtpSenderInternal> CreateMockSender( rtc::scoped_refptr<MediaStreamTrackInterface> track, uint32_t ssrc) { - rtc::scoped_refptr<MockRtpSenderInternal> sender( - new rtc::RefCountedObject<MockRtpSenderInternal>()); + auto sender = rtc::make_ref_counted<MockRtpSenderInternal>(); EXPECT_CALL(*sender, track()).WillRepeatedly(Return(track)); EXPECT_CALL(*sender, ssrc()).WillRepeatedly(Return(ssrc)); EXPECT_CALL(*sender, media_type()) @@ -753,8 +751,7 @@ static rtc::scoped_refptr<MockRtpSenderInternal> CreateMockSender( static rtc::scoped_refptr<MockRtpReceiverInternal> CreateMockReceiver( rtc::scoped_refptr<MediaStreamTrackInterface> track, uint32_t ssrc) { - rtc::scoped_refptr<MockRtpReceiverInternal> receiver( - new rtc::RefCountedObject<MockRtpReceiverInternal>()); + auto receiver = rtc::make_ref_counted<MockRtpReceiverInternal>(); EXPECT_CALL(*receiver, track()).WillRepeatedly(Return(track)); EXPECT_CALL(*receiver, ssrc()).WillRepeatedly(Return(ssrc)); EXPECT_CALL(*receiver, media_type()) @@ -808,7 +805,7 @@ class StatsCollectorTrackTest : public StatsCollectorTest, rtc::scoped_refptr<RtpSenderInterface> AddOutgoingAudioTrack( FakePeerConnectionForStats* pc, StatsCollectorForTest* stats) { - audio_track_ = new rtc::RefCountedObject<FakeAudioTrack>(kLocalTrackId); + audio_track_ = rtc::make_ref_counted<FakeAudioTrack>(kLocalTrackId); if (GetParam()) { if (!stream_) stream_ = MediaStream::Create("streamid"); @@ -823,7 +820,7 @@ class StatsCollectorTrackTest : public StatsCollectorTest, // Adds a incoming audio track with a given SSRC into the stats. void AddIncomingAudioTrack(FakePeerConnectionForStats* pc, StatsCollectorForTest* stats) { - audio_track_ = new rtc::RefCountedObject<FakeAudioTrack>(kRemoteTrackId); + audio_track_ = rtc::make_ref_counted<FakeAudioTrack>(kRemoteTrackId); if (GetParam()) { if (stream_ == NULL) stream_ = MediaStream::Create("streamid"); @@ -1483,8 +1480,8 @@ TEST_P(StatsCollectorTrackTest, FilterOutNegativeInitialValues) { // Create a local stream with a local audio track and adds it to the stats.
stream_ = MediaStream::Create("streamid"); - rtc::scoped_refptr<FakeAudioTrackWithInitValue> local_track( - new rtc::RefCountedObject<FakeAudioTrackWithInitValue>(kLocalTrackId)); + auto local_track = + rtc::make_ref_counted<FakeAudioTrackWithInitValue>(kLocalTrackId); stream_->AddTrack(local_track); pc->AddSender(CreateMockSender(local_track, kSsrcOfTrack)); if (GetParam()) { @@ -1495,8 +1492,8 @@ TEST_P(StatsCollectorTrackTest, FilterOutNegativeInitialValues) { // Create a remote stream with a remote audio track and adds it to the stats. rtc::scoped_refptr<MediaStream> remote_stream( MediaStream::Create("remotestreamid")); - rtc::scoped_refptr<FakeAudioTrackWithInitValue> remote_track( - new rtc::RefCountedObject<FakeAudioTrackWithInitValue>(kRemoteTrackId)); + auto remote_track = + rtc::make_ref_counted<FakeAudioTrackWithInitValue>(kRemoteTrackId); remote_stream->AddTrack(remote_track); pc->AddReceiver(CreateMockReceiver(remote_track, kSsrcOfTrack)); if (GetParam()) { @@ -1665,8 +1662,7 @@ TEST_P(StatsCollectorTrackTest, LocalAndRemoteTracksWithSameSsrc) { // Create a remote stream with a remote audio track and adds it to the stats. rtc::scoped_refptr<MediaStream> remote_stream( MediaStream::Create("remotestreamid")); - rtc::scoped_refptr<FakeAudioTrack> remote_track( - new rtc::RefCountedObject<FakeAudioTrack>(kRemoteTrackId)); + auto remote_track = rtc::make_ref_counted<FakeAudioTrack>(kRemoteTrackId); pc->AddReceiver(CreateMockReceiver(remote_track, kSsrcOfTrack)); remote_stream->AddTrack(remote_track); stats->AddStream(remote_stream); @@ -1755,8 +1751,7 @@ TEST_P(StatsCollectorTrackTest, TwoLocalTracksWithSameSsrc) { // Create a new audio track and adds it to the stream and stats. static const std::string kNewTrackId = "new_track_id"; - rtc::scoped_refptr<FakeAudioTrack> new_audio_track( - new rtc::RefCountedObject<FakeAudioTrack>(kNewTrackId)); + auto new_audio_track = rtc::make_ref_counted<FakeAudioTrack>(kNewTrackId); pc->AddSender(CreateMockSender(new_audio_track, kSsrcOfTrack)); stream_->AddTrack(new_audio_track); @@ -1785,8 +1780,8 @@ TEST_P(StatsCollectorTrackTest, TwoLocalSendersWithSameTrack) { auto pc = CreatePeerConnection(); auto stats = CreateStatsCollector(pc); - rtc::scoped_refptr<FakeAudioTrackWithInitValue> local_track( - new rtc::RefCountedObject<FakeAudioTrackWithInitValue>(kLocalTrackId)); + auto local_track = + rtc::make_ref_counted<FakeAudioTrackWithInitValue>(kLocalTrackId); pc->AddSender(CreateMockSender(local_track, kFirstSsrc)); stats->AddLocalAudioTrack(local_track.get(), kFirstSsrc); pc->AddSender(CreateMockSender(local_track, kSecondSsrc)); diff --git a/pc/stream_collection.h b/pc/stream_collection.h index 28cd46fc5d..9bbf957efd 100644 --- a/pc/stream_collection.h +++ b/pc/stream_collection.h @@ -22,16 +22,12 @@ namespace webrtc { class StreamCollection : public StreamCollectionInterface { public: static rtc::scoped_refptr<StreamCollection> Create() { - rtc::RefCountedObject<StreamCollection>* implementation = - new rtc::RefCountedObject<StreamCollection>(); - return implementation; + return rtc::make_ref_counted<StreamCollection>(); } static rtc::scoped_refptr<StreamCollection> Create( StreamCollection* streams) { - rtc::RefCountedObject<StreamCollection>* implementation = - new rtc::RefCountedObject<StreamCollection>(streams); - return implementation; + return rtc::make_ref_counted<StreamCollection>(streams); } virtual size_t count() { return media_streams_.size(); } diff --git a/pc/test/fake_audio_capture_module.cc b/pc/test/fake_audio_capture_module.cc index a395df0409..214ed6b523 100644 --- a/pc/test/fake_audio_capture_module.cc +++ b/pc/test/fake_audio_capture_module.cc @@ -58,8 +58,7 @@ FakeAudioCaptureModule::~FakeAudioCaptureModule() { } rtc::scoped_refptr<FakeAudioCaptureModule> FakeAudioCaptureModule::Create() { - rtc::scoped_refptr<FakeAudioCaptureModule> capture_module( - new rtc::RefCountedObject<FakeAudioCaptureModule>()); + auto capture_module = rtc::make_ref_counted<FakeAudioCaptureModule>(); if (!capture_module->Initialize()) { return nullptr; } diff --git a/pc/test/fake_audio_capture_module.h b/pc/test/fake_audio_capture_module.h index
ee85c9a490..d2db3d666d 100644 --- a/pc/test/fake_audio_capture_module.h +++ b/pc/test/fake_audio_capture_module.h @@ -20,13 +20,20 @@ #ifndef PC_TEST_FAKE_AUDIO_CAPTURE_MODULE_H_ #define PC_TEST_FAKE_AUDIO_CAPTURE_MODULE_H_ +#include +#include + #include #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "modules/audio_device/include/audio_device.h" +#include "modules/audio_device/include/audio_device_defines.h" #include "rtc_base/message_handler.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" +#include "rtc_base/thread_message.h" namespace rtc { class Thread; diff --git a/pc/test/fake_data_channel_provider.h b/pc/test/fake_data_channel_provider.h index 7145225ca6..f9e9e91d48 100644 --- a/pc/test/fake_data_channel_provider.h +++ b/pc/test/fake_data_channel_provider.h @@ -26,7 +26,8 @@ class FakeDataChannelProvider transport_error_(false) {} virtual ~FakeDataChannelProvider() {} - bool SendData(const cricket::SendDataParams& params, + bool SendData(int sid, + const webrtc::SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result) override { RTC_CHECK(ready_to_send_); @@ -36,11 +37,12 @@ class FakeDataChannelProvider return false; } - if (transport_error_ || payload.size() == 0) { + if (transport_error_) { *result = cricket::SDR_ERROR; return false; } + last_sid_ = sid; last_send_data_params_ = params; return true; } @@ -127,7 +129,8 @@ class FakeDataChannelProvider void set_transport_error() { transport_error_ = true; } - cricket::SendDataParams last_send_data_params() const { + int last_sid() const { return last_sid_; } + const webrtc::SendDataParams& last_send_data_params() const { return last_send_data_params_; } @@ -144,7 +147,8 @@ class FakeDataChannelProvider } private: - cricket::SendDataParams last_send_data_params_; + int last_sid_; + webrtc::SendDataParams last_send_data_params_; bool send_blocked_; bool transport_available_; bool ready_to_send_; diff --git a/pc/test/fake_peer_connection_base.h b/pc/test/fake_peer_connection_base.h index 9531c6de5b..7970dd0f0f 100644 --- a/pc/test/fake_peer_connection_base.h +++ b/pc/test/fake_peer_connection_base.h @@ -120,10 +120,11 @@ class FakePeerConnectionBase : public PeerConnectionInternal { return nullptr; } - rtc::scoped_refptr CreateDataChannel( + RTCErrorOr> CreateDataChannelOrError( const std::string& label, const DataChannelInit* config) override { - return nullptr; + return RTCError(RTCErrorType::UNSUPPORTED_OPERATION, + "Fake function called"); } const SessionDescriptionInterface* local_description() const override { @@ -248,22 +249,16 @@ class FakePeerConnectionBase : public PeerConnectionInternal { return {}; } - sigslot::signal1& SignalRtpDataChannelCreated() override { - return SignalRtpDataChannelCreated_; - } - sigslot::signal1& SignalSctpDataChannelCreated() override { return SignalSctpDataChannelCreated_; } - cricket::RtpDataChannel* rtp_data_channel() const override { return nullptr; } - absl::optional sctp_transport_name() const override { return absl::nullopt; } - std::map GetTransportNamesByMid() const override { - return {}; + absl::optional sctp_mid() const override { + return absl::nullopt; } std::map GetTransportStatsByNames( @@ -298,7 +293,6 @@ class FakePeerConnectionBase : public PeerConnectionInternal { } protected: - sigslot::signal1 SignalRtpDataChannelCreated_; sigslot::signal1 SignalSctpDataChannelCreated_; }; diff --git 
a/pc/test/fake_peer_connection_for_stats.h b/pc/test/fake_peer_connection_for_stats.h index 70f8dd50a1..4cdbd82162 100644 --- a/pc/test/fake_peer_connection_for_stats.h +++ b/pc/test/fake_peer_connection_for_stats.h @@ -28,8 +28,10 @@ namespace webrtc { // Fake VoiceMediaChannel where the result of GetStats can be configured. class FakeVoiceMediaChannelForStats : public cricket::FakeVoiceMediaChannel { public: - FakeVoiceMediaChannelForStats() - : cricket::FakeVoiceMediaChannel(nullptr, cricket::AudioOptions()) {} + explicit FakeVoiceMediaChannelForStats(TaskQueueBase* network_thread) + : cricket::FakeVoiceMediaChannel(nullptr, + cricket::AudioOptions(), + network_thread) {} void SetStats(const cricket::VoiceMediaInfo& voice_info) { stats_ = voice_info; @@ -52,8 +54,10 @@ class FakeVoiceMediaChannelForStats : public cricket::FakeVoiceMediaChannel { // Fake VideoMediaChannel where the result of GetStats can be configured. class FakeVideoMediaChannelForStats : public cricket::FakeVideoMediaChannel { public: - FakeVideoMediaChannelForStats() - : cricket::FakeVideoMediaChannel(nullptr, cricket::VideoOptions()) {} + explicit FakeVideoMediaChannelForStats(TaskQueueBase* network_thread) + : cricket::FakeVideoMediaChannel(nullptr, + cricket::VideoOptions(), + network_thread) {} void SetStats(const cricket::VideoMediaInfo& video_info) { stats_ = video_info; @@ -75,6 +79,64 @@ class FakeVideoMediaChannelForStats : public cricket::FakeVideoMediaChannel { constexpr bool kDefaultRtcpMuxRequired = true; constexpr bool kDefaultSrtpRequired = true; +class VoiceChannelForTesting : public cricket::VoiceChannel { + public: + VoiceChannelForTesting(rtc::Thread* worker_thread, + rtc::Thread* network_thread, + rtc::Thread* signaling_thread, + std::unique_ptr channel, + const std::string& content_name, + bool srtp_required, + webrtc::CryptoOptions crypto_options, + rtc::UniqueRandomIdGenerator* ssrc_generator, + std::string transport_name) + : VoiceChannel(worker_thread, + network_thread, + signaling_thread, + std::move(channel), + content_name, + srtp_required, + std::move(crypto_options), + ssrc_generator), + test_transport_name_(std::move(transport_name)) {} + + private: + const std::string& transport_name() const override { + return test_transport_name_; + } + + const std::string test_transport_name_; +}; + +class VideoChannelForTesting : public cricket::VideoChannel { + public: + VideoChannelForTesting(rtc::Thread* worker_thread, + rtc::Thread* network_thread, + rtc::Thread* signaling_thread, + std::unique_ptr channel, + const std::string& content_name, + bool srtp_required, + webrtc::CryptoOptions crypto_options, + rtc::UniqueRandomIdGenerator* ssrc_generator, + std::string transport_name) + : VideoChannel(worker_thread, + network_thread, + signaling_thread, + std::move(channel), + content_name, + srtp_required, + std::move(crypto_options), + ssrc_generator), + test_transport_name_(std::move(transport_name)) {} + + private: + const std::string& transport_name() const override { + return test_transport_name_; + } + + const std::string test_transport_name_; +}; + // This class is intended to be fed into the StatsCollector and // RTCStatsCollector so that the stats functionality can be unit tested. // Individual tests can configure this fake as needed to simulate scenarios @@ -120,7 +182,7 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { // TODO(steveanton): Switch tests to use RtpTransceivers directly. 
auto receiver_proxy = RtpReceiverProxyWithInternal::Create( - signaling_thread_, receiver); + signaling_thread_, worker_thread_, receiver); GetOrCreateFirstTransceiverOfType(receiver->media_type()) ->internal() ->AddReceiver(receiver_proxy); @@ -138,13 +200,12 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { const std::string& transport_name) { RTC_DCHECK(!voice_channel_); auto voice_media_channel = - std::make_unique(); + std::make_unique(network_thread_); auto* voice_media_channel_ptr = voice_media_channel.get(); - voice_channel_ = std::make_unique( + voice_channel_ = std::make_unique( worker_thread_, network_thread_, signaling_thread_, std::move(voice_media_channel), mid, kDefaultSrtpRequired, - webrtc::CryptoOptions(), &ssrc_generator_); - voice_channel_->set_transport_name_for_testing(transport_name); + webrtc::CryptoOptions(), &ssrc_generator_, transport_name); GetOrCreateFirstTransceiverOfType(cricket::MEDIA_TYPE_AUDIO) ->internal() ->SetChannel(voice_channel_.get()); @@ -156,13 +217,12 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { const std::string& transport_name) { RTC_DCHECK(!video_channel_); auto video_media_channel = - std::make_unique(); + std::make_unique(network_thread_); auto video_media_channel_ptr = video_media_channel.get(); - video_channel_ = std::make_unique( + video_channel_ = std::make_unique( worker_thread_, network_thread_, signaling_thread_, std::move(video_media_channel), mid, kDefaultSrtpRequired, - webrtc::CryptoOptions(), &ssrc_generator_); - video_channel_->set_transport_name_for_testing(transport_name); + webrtc::CryptoOptions(), &ssrc_generator_, transport_name); GetOrCreateFirstTransceiverOfType(cricket::MEDIA_TYPE_VIDEO) ->internal() ->SetChannel(video_channel_.get()); @@ -272,21 +332,9 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { return {}; } - std::map GetTransportNamesByMid() const override { - std::map transport_names_by_mid; - if (voice_channel_) { - transport_names_by_mid[voice_channel_->content_name()] = - voice_channel_->transport_name(); - } - if (video_channel_) { - transport_names_by_mid[video_channel_->content_name()] = - video_channel_->transport_name(); - } - return transport_names_by_mid; - } - std::map GetTransportStatsByNames( const std::set& transport_names) override { + RTC_DCHECK_RUN_ON(network_thread_); std::map transport_stats_by_name; for (const std::string& transport_name : transport_names) { transport_stats_by_name[transport_name] = @@ -344,7 +392,8 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { } } auto transceiver = RtpTransceiverProxyWithInternal::Create( - signaling_thread_, new RtpTransceiver(media_type)); + signaling_thread_, + new RtpTransceiver(media_type, channel_manager_.get())); transceivers_.push_back(transceiver); return transceiver; } @@ -353,6 +402,12 @@ class FakePeerConnectionForStats : public FakePeerConnectionBase { rtc::Thread* const worker_thread_; rtc::Thread* const signaling_thread_; + std::unique_ptr channel_manager_ = + cricket::ChannelManager::Create(nullptr /* MediaEngineInterface */, + true, + worker_thread_, + network_thread_); + rtc::scoped_refptr local_streams_; rtc::scoped_refptr remote_streams_; diff --git a/pc/test/fake_periodic_video_source.h b/pc/test/fake_periodic_video_source.h index ac6e5a43e7..871c29cbae 100644 --- a/pc/test/fake_periodic_video_source.h +++ b/pc/test/fake_periodic_video_source.h @@ -86,7 +86,7 @@ class FakePeriodicVideoSource final } private: - rtc::ThreadChecker thread_checker_; + 
SequenceChecker thread_checker_; rtc::VideoBroadcaster broadcaster_; cricket::FakeFrameSource frame_source_; diff --git a/pc/test/fake_video_track_source.h b/pc/test/fake_video_track_source.h index d6562313c5..2042c39175 100644 --- a/pc/test/fake_video_track_source.h +++ b/pc/test/fake_video_track_source.h @@ -22,7 +22,7 @@ namespace webrtc { class FakeVideoTrackSource : public VideoTrackSource { public: static rtc::scoped_refptr Create(bool is_screencast) { - return new rtc::RefCountedObject(is_screencast); + return rtc::make_ref_counted(is_screencast); } static rtc::scoped_refptr Create() { diff --git a/pc/test/integration_test_helpers.cc b/pc/test/integration_test_helpers.cc new file mode 100644 index 0000000000..10e4f455ba --- /dev/null +++ b/pc/test/integration_test_helpers.cc @@ -0,0 +1,59 @@ +/* + * Copyright 2012 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "pc/test/integration_test_helpers.h" + +namespace webrtc { + +PeerConnectionInterface::RTCOfferAnswerOptions IceRestartOfferAnswerOptions() { + PeerConnectionInterface::RTCOfferAnswerOptions options; + options.ice_restart = true; + return options; +} + +void RemoveSsrcsAndMsids(cricket::SessionDescription* desc) { + for (ContentInfo& content : desc->contents()) { + content.media_description()->mutable_streams().clear(); + } + desc->set_msid_supported(false); + desc->set_msid_signaling(0); +} + +void RemoveSsrcsAndKeepMsids(cricket::SessionDescription* desc) { + for (ContentInfo& content : desc->contents()) { + std::string track_id; + std::vector stream_ids; + if (!content.media_description()->streams().empty()) { + const StreamParams& first_stream = + content.media_description()->streams()[0]; + track_id = first_stream.id; + stream_ids = first_stream.stream_ids(); + } + content.media_description()->mutable_streams().clear(); + StreamParams new_stream; + new_stream.id = track_id; + new_stream.set_stream_ids(stream_ids); + content.media_description()->AddStream(new_stream); + } +} + +int FindFirstMediaStatsIndexByKind( + const std::string& kind, + const std::vector& + media_stats_vec) { + for (size_t i = 0; i < media_stats_vec.size(); i++) { + if (media_stats_vec[i]->kind.ValueToString() == kind) { + return i; + } + } + return -1; +} + +} // namespace webrtc diff --git a/pc/test/integration_test_helpers.h b/pc/test/integration_test_helpers.h new file mode 100644 index 0000000000..117f1b428b --- /dev/null +++ b/pc/test/integration_test_helpers.h @@ -0,0 +1,1852 @@ +/* + * Copyright 2012 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef PC_TEST_INTEGRATION_TEST_HELPERS_H_ +#define PC_TEST_INTEGRATION_TEST_HELPERS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/audio_options.h" +#include "api/call/call_factory_interface.h" +#include "api/candidate.h" +#include "api/crypto/crypto_options.h" +#include "api/data_channel_interface.h" +#include "api/ice_transport_interface.h" +#include "api/jsep.h" +#include "api/media_stream_interface.h" +#include "api/media_types.h" +#include "api/peer_connection_interface.h" +#include "api/rtc_error.h" +#include "api/rtc_event_log/rtc_event_log_factory.h" +#include "api/rtc_event_log/rtc_event_log_factory_interface.h" +#include "api/rtc_event_log_output.h" +#include "api/rtp_receiver_interface.h" +#include "api/rtp_sender_interface.h" +#include "api/rtp_transceiver_interface.h" +#include "api/scoped_refptr.h" +#include "api/stats/rtc_stats.h" +#include "api/stats/rtc_stats_report.h" +#include "api/stats/rtcstats_objects.h" +#include "api/task_queue/default_task_queue_factory.h" +#include "api/task_queue/task_queue_factory.h" +#include "api/transport/field_trial_based_config.h" +#include "api/transport/webrtc_key_value_config.h" +#include "api/uma_metrics.h" +#include "api/video/video_rotation.h" +#include "api/video_codecs/sdp_video_format.h" +#include "api/video_codecs/video_decoder_factory.h" +#include "api/video_codecs/video_encoder_factory.h" +#include "call/call.h" +#include "logging/rtc_event_log/fake_rtc_event_log_factory.h" +#include "media/base/media_engine.h" +#include "media/base/stream_params.h" +#include "media/engine/fake_webrtc_video_engine.h" +#include "media/engine/webrtc_media_engine.h" +#include "media/engine/webrtc_media_engine_defaults.h" +#include "modules/audio_device/include/audio_device.h" +#include "modules/audio_processing/include/audio_processing.h" +#include "modules/audio_processing/test/audio_processing_builder_for_testing.h" +#include "p2p/base/fake_ice_transport.h" +#include "p2p/base/ice_transport_internal.h" +#include "p2p/base/mock_async_resolver.h" +#include "p2p/base/p2p_constants.h" +#include "p2p/base/port.h" +#include "p2p/base/port_allocator.h" +#include "p2p/base/port_interface.h" +#include "p2p/base/test_stun_server.h" +#include "p2p/base/test_turn_customizer.h" +#include "p2p/base/test_turn_server.h" +#include "p2p/client/basic_port_allocator.h" +#include "pc/dtmf_sender.h" +#include "pc/local_audio_source.h" +#include "pc/media_session.h" +#include "pc/peer_connection.h" +#include "pc/peer_connection_factory.h" +#include "pc/peer_connection_proxy.h" +#include "pc/rtp_media_utils.h" +#include "pc/session_description.h" +#include "pc/test/fake_audio_capture_module.h" +#include "pc/test/fake_periodic_video_source.h" +#include "pc/test/fake_periodic_video_track_source.h" +#include "pc/test/fake_rtc_certificate_generator.h" +#include "pc/test/fake_video_track_renderer.h" +#include "pc/test/mock_peer_connection_observers.h" +#include "pc/video_track_source.h" +#include "rtc_base/checks.h" +#include "rtc_base/fake_clock.h" +#include "rtc_base/fake_mdns_responder.h" +#include "rtc_base/fake_network.h" +#include "rtc_base/firewall_socket_server.h" +#include "rtc_base/gunit.h" +#include "rtc_base/helpers.h" +#include "rtc_base/ip_address.h" +#include "rtc_base/location.h" +#include "rtc_base/logging.h" +#include "rtc_base/mdns_responder_interface.h" +#include 
"rtc_base/numerics/safe_conversions.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/rtc_certificate_generator.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/ssl_stream_adapter.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/test_certificate_verifier.h" +#include "rtc_base/thread.h" +#include "rtc_base/time_utils.h" +#include "rtc_base/virtual_socket_server.h" +#include "system_wrappers/include/metrics.h" +#include "test/field_trial.h" +#include "test/gmock.h" + +namespace webrtc { + +using ::cricket::ContentInfo; +using ::cricket::StreamParams; +using ::rtc::SocketAddress; +using ::testing::_; +using ::testing::Combine; +using ::testing::Contains; +using ::testing::DoAll; +using ::testing::ElementsAre; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SetArgPointee; +using ::testing::UnorderedElementsAreArray; +using ::testing::Values; +using RTCConfiguration = PeerConnectionInterface::RTCConfiguration; + +static const int kDefaultTimeout = 10000; +static const int kMaxWaitForStatsMs = 3000; +static const int kMaxWaitForActivationMs = 5000; +static const int kMaxWaitForFramesMs = 10000; +// Default number of audio/video frames to wait for before considering a test +// successful. +static const int kDefaultExpectedAudioFrameCount = 3; +static const int kDefaultExpectedVideoFrameCount = 3; + +static const char kDataChannelLabel[] = "data_channel"; + +// SRTP cipher name negotiated by the tests. This must be updated if the +// default changes. +static const int kDefaultSrtpCryptoSuite = rtc::SRTP_AES128_CM_SHA1_80; +static const int kDefaultSrtpCryptoSuiteGcm = rtc::SRTP_AEAD_AES_256_GCM; + +static const SocketAddress kDefaultLocalAddress("192.168.1.1", 0); + +// Helper function for constructing offer/answer options to initiate an ICE +// restart. +PeerConnectionInterface::RTCOfferAnswerOptions IceRestartOfferAnswerOptions(); + +// Remove all stream information (SSRCs, track IDs, etc.) and "msid-semantic" +// attribute from received SDP, simulating a legacy endpoint. +void RemoveSsrcsAndMsids(cricket::SessionDescription* desc); + +// Removes all stream information besides the stream ids, simulating an +// endpoint that only signals a=msid lines to convey stream_ids. +void RemoveSsrcsAndKeepMsids(cricket::SessionDescription* desc); + +int FindFirstMediaStatsIndexByKind( + const std::string& kind, + const std::vector& + media_stats_vec); + +class SignalingMessageReceiver { + public: + virtual void ReceiveSdpMessage(SdpType type, const std::string& msg) = 0; + virtual void ReceiveIceMessage(const std::string& sdp_mid, + int sdp_mline_index, + const std::string& msg) = 0; + + protected: + SignalingMessageReceiver() {} + virtual ~SignalingMessageReceiver() {} +}; + +class MockRtpReceiverObserver : public webrtc::RtpReceiverObserverInterface { + public: + explicit MockRtpReceiverObserver(cricket::MediaType media_type) + : expected_media_type_(media_type) {} + + void OnFirstPacketReceived(cricket::MediaType media_type) override { + ASSERT_EQ(expected_media_type_, media_type); + first_packet_received_ = true; + } + + bool first_packet_received() const { return first_packet_received_; } + + virtual ~MockRtpReceiverObserver() {} + + private: + bool first_packet_received_ = false; + cricket::MediaType expected_media_type_; +}; + +// Helper class that wraps a peer connection, observes it, and can accept +// signaling messages from another wrapper. 
+// +// Uses a fake network, fake A/V capture, and optionally fake +// encoders/decoders, though they aren't used by default since they don't +// advertise support of any codecs. +// TODO(steveanton): See how this could become a subclass of +// PeerConnectionWrapper defined in peerconnectionwrapper.h. +class PeerConnectionIntegrationWrapper : public webrtc::PeerConnectionObserver, + public SignalingMessageReceiver { + public: + // Different factory methods for convenience. + // TODO(deadbeef): Could use the pattern of: + // + // PeerConnectionIntegrationWrapper = + // WrapperBuilder.WithConfig(...).WithOptions(...).build(); + // + // To reduce some code duplication. + static PeerConnectionIntegrationWrapper* CreateWithDtlsIdentityStore( + const std::string& debug_name, + std::unique_ptr cert_generator, + rtc::Thread* network_thread, + rtc::Thread* worker_thread) { + PeerConnectionIntegrationWrapper* client( + new PeerConnectionIntegrationWrapper(debug_name)); + webrtc::PeerConnectionDependencies dependencies(nullptr); + dependencies.cert_generator = std::move(cert_generator); + if (!client->Init(nullptr, nullptr, std::move(dependencies), network_thread, + worker_thread, nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false)) { + delete client; + return nullptr; + } + return client; + } + + webrtc::PeerConnectionFactoryInterface* pc_factory() const { + return peer_connection_factory_.get(); + } + + webrtc::PeerConnectionInterface* pc() const { return peer_connection_.get(); } + + // If a signaling message receiver is set (via ConnectFakeSignaling), this + // will set the whole offer/answer exchange in motion. Just need to wait for + // the signaling state to reach "stable". + void CreateAndSetAndSignalOffer() { + auto offer = CreateOfferAndWait(); + ASSERT_NE(nullptr, offer); + EXPECT_TRUE(SetLocalDescriptionAndSendSdpMessage(std::move(offer))); + } + + // Sets the options to be used when CreateAndSetAndSignalOffer is called, or + // when a remote offer is received (via fake signaling) and an answer is + // generated. By default, uses default options. + void SetOfferAnswerOptions( + const PeerConnectionInterface::RTCOfferAnswerOptions& options) { + offer_answer_options_ = options; + } + + // Set a callback to be invoked when SDP is received via the fake signaling + // channel, which provides an opportunity to munge (modify) the SDP. This is + // used to test SDP being applied that a PeerConnection would normally not + // generate, but a non-JSEP endpoint might. + void SetReceivedSdpMunger( + std::function munger) { + received_sdp_munger_ = std::move(munger); + } + + // Similar to the above, but this is run on SDP immediately after it's + // generated. + void SetGeneratedSdpMunger( + std::function munger) { + generated_sdp_munger_ = std::move(munger); + } + + // Set a callback to be invoked when a remote offer is received via the fake + // signaling channel. This provides an opportunity to change the + // PeerConnection state before an answer is created and sent to the caller. + void SetRemoteOfferHandler(std::function handler) { + remote_offer_handler_ = std::move(handler); + } + + void SetRemoteAsyncResolver(rtc::MockAsyncResolver* resolver) { + remote_async_resolver_ = resolver; + } + + // Every ICE connection state in order that has been seen by the observer. 
+ std::vector + ice_connection_state_history() const { + return ice_connection_state_history_; + } + void clear_ice_connection_state_history() { + ice_connection_state_history_.clear(); + } + + // Every standardized ICE connection state in order that has been seen by the + // observer. + std::vector + standardized_ice_connection_state_history() const { + return standardized_ice_connection_state_history_; + } + + // Every PeerConnection state in order that has been seen by the observer. + std::vector + peer_connection_state_history() const { + return peer_connection_state_history_; + } + + // Every ICE gathering state in order that has been seen by the observer. + std::vector + ice_gathering_state_history() const { + return ice_gathering_state_history_; + } + std::vector + ice_candidate_pair_change_history() const { + return ice_candidate_pair_change_history_; + } + + // Every PeerConnection signaling state in order that has been seen by the + // observer. + std::vector + peer_connection_signaling_state_history() const { + return peer_connection_signaling_state_history_; + } + + void AddAudioVideoTracks() { + AddAudioTrack(); + AddVideoTrack(); + } + + rtc::scoped_refptr AddAudioTrack() { + return AddTrack(CreateLocalAudioTrack()); + } + + rtc::scoped_refptr AddVideoTrack() { + return AddTrack(CreateLocalVideoTrack()); + } + + rtc::scoped_refptr CreateLocalAudioTrack() { + cricket::AudioOptions options; + // Disable highpass filter so that we can get all the test audio frames. + options.highpass_filter = false; + rtc::scoped_refptr source = + peer_connection_factory_->CreateAudioSource(options); + // TODO(perkj): Test audio source when it is implemented. Currently audio + // always use the default input. + return peer_connection_factory_->CreateAudioTrack(rtc::CreateRandomUuid(), + source); + } + + rtc::scoped_refptr CreateLocalVideoTrack() { + webrtc::FakePeriodicVideoSource::Config config; + config.timestamp_offset_ms = rtc::TimeMillis(); + return CreateLocalVideoTrackInternal(config); + } + + rtc::scoped_refptr + CreateLocalVideoTrackWithConfig( + webrtc::FakePeriodicVideoSource::Config config) { + return CreateLocalVideoTrackInternal(config); + } + + rtc::scoped_refptr + CreateLocalVideoTrackWithRotation(webrtc::VideoRotation rotation) { + webrtc::FakePeriodicVideoSource::Config config; + config.rotation = rotation; + config.timestamp_offset_ms = rtc::TimeMillis(); + return CreateLocalVideoTrackInternal(config); + } + + rtc::scoped_refptr AddTrack( + rtc::scoped_refptr track, + const std::vector& stream_ids = {}) { + auto result = pc()->AddTrack(track, stream_ids); + EXPECT_EQ(RTCErrorType::NONE, result.error().type()); + return result.MoveValue(); + } + + std::vector> GetReceiversOfType( + cricket::MediaType media_type) { + std::vector> receivers; + for (const auto& receiver : pc()->GetReceivers()) { + if (receiver->media_type() == media_type) { + receivers.push_back(receiver); + } + } + return receivers; + } + + rtc::scoped_refptr GetFirstTransceiverOfType( + cricket::MediaType media_type) { + for (auto transceiver : pc()->GetTransceivers()) { + if (transceiver->receiver()->media_type() == media_type) { + return transceiver; + } + } + return nullptr; + } + + bool SignalingStateStable() { + return pc()->signaling_state() == webrtc::PeerConnectionInterface::kStable; + } + + void CreateDataChannel() { CreateDataChannel(nullptr); } + + void CreateDataChannel(const webrtc::DataChannelInit* init) { + CreateDataChannel(kDataChannelLabel, init); + } + + void CreateDataChannel(const 
std::string& label, + const webrtc::DataChannelInit* init) { + data_channel_ = pc()->CreateDataChannel(label, init); + ASSERT_TRUE(data_channel_.get() != nullptr); + data_observer_.reset(new MockDataChannelObserver(data_channel_)); + } + + DataChannelInterface* data_channel() { return data_channel_; } + const MockDataChannelObserver* data_observer() const { + return data_observer_.get(); + } + + int audio_frames_received() const { + return fake_audio_capture_module_->frames_received(); + } + + // Takes minimum of video frames received for each track. + // + // Can be used like: + // EXPECT_GE(expected_frames, min_video_frames_received_per_track()); + // + // To ensure that all video tracks received at least a certain number of + // frames. + int min_video_frames_received_per_track() const { + int min_frames = INT_MAX; + if (fake_video_renderers_.empty()) { + return 0; + } + + for (const auto& pair : fake_video_renderers_) { + min_frames = std::min(min_frames, pair.second->num_rendered_frames()); + } + return min_frames; + } + + // Returns a MockStatsObserver in a state after stats gathering finished, + // which can be used to access the gathered stats. + rtc::scoped_refptr OldGetStatsForTrack( + webrtc::MediaStreamTrackInterface* track) { + auto observer = rtc::make_ref_counted(); + EXPECT_TRUE(peer_connection_->GetStats( + observer, nullptr, PeerConnectionInterface::kStatsOutputLevelStandard)); + EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); + return observer; + } + + // Version that doesn't take a track "filter", and gathers all stats. + rtc::scoped_refptr OldGetStats() { + return OldGetStatsForTrack(nullptr); + } + + // Synchronously gets stats and returns them. If it times out, fails the test + // and returns null. + rtc::scoped_refptr NewGetStats() { + auto callback = + rtc::make_ref_counted(); + peer_connection_->GetStats(callback); + EXPECT_TRUE_WAIT(callback->called(), kDefaultTimeout); + return callback->report(); + } + + int rendered_width() { + EXPECT_FALSE(fake_video_renderers_.empty()); + return fake_video_renderers_.empty() + ? 0 + : fake_video_renderers_.begin()->second->width(); + } + + int rendered_height() { + EXPECT_FALSE(fake_video_renderers_.empty()); + return fake_video_renderers_.empty() + ? 0 + : fake_video_renderers_.begin()->second->height(); + } + + double rendered_aspect_ratio() { + if (rendered_height() == 0) { + return 0.0; + } + return static_cast(rendered_width()) / rendered_height(); + } + + webrtc::VideoRotation rendered_rotation() { + EXPECT_FALSE(fake_video_renderers_.empty()); + return fake_video_renderers_.empty() + ? webrtc::kVideoRotation_0 + : fake_video_renderers_.begin()->second->rotation(); + } + + int local_rendered_width() { + return local_video_renderer_ ? local_video_renderer_->width() : 0; + } + + int local_rendered_height() { + return local_video_renderer_ ? 
local_video_renderer_->height() : 0; + } + + double local_rendered_aspect_ratio() { + if (local_rendered_height() == 0) { + return 0.0; + } + return static_cast(local_rendered_width()) / + local_rendered_height(); + } + + size_t number_of_remote_streams() { + if (!pc()) { + return 0; + } + return pc()->remote_streams()->count(); + } + + StreamCollectionInterface* remote_streams() const { + if (!pc()) { + ADD_FAILURE(); + return nullptr; + } + return pc()->remote_streams(); + } + + StreamCollectionInterface* local_streams() { + if (!pc()) { + ADD_FAILURE(); + return nullptr; + } + return pc()->local_streams(); + } + + webrtc::PeerConnectionInterface::SignalingState signaling_state() { + return pc()->signaling_state(); + } + + webrtc::PeerConnectionInterface::IceConnectionState ice_connection_state() { + return pc()->ice_connection_state(); + } + + webrtc::PeerConnectionInterface::IceConnectionState + standardized_ice_connection_state() { + return pc()->standardized_ice_connection_state(); + } + + webrtc::PeerConnectionInterface::IceGatheringState ice_gathering_state() { + return pc()->ice_gathering_state(); + } + + // Returns a MockRtpReceiverObserver for each RtpReceiver returned by + // GetReceivers. They're updated automatically when a remote offer/answer + // from the fake signaling channel is applied, or when + // ResetRtpReceiverObservers below is called. + const std::vector>& + rtp_receiver_observers() { + return rtp_receiver_observers_; + } + + void ResetRtpReceiverObservers() { + rtp_receiver_observers_.clear(); + for (const rtc::scoped_refptr& receiver : + pc()->GetReceivers()) { + std::unique_ptr observer( + new MockRtpReceiverObserver(receiver->media_type())); + receiver->SetObserver(observer.get()); + rtp_receiver_observers_.push_back(std::move(observer)); + } + } + + rtc::FakeNetworkManager* network_manager() const { + return fake_network_manager_.get(); + } + cricket::PortAllocator* port_allocator() const { return port_allocator_; } + + webrtc::FakeRtcEventLogFactory* event_log_factory() const { + return event_log_factory_; + } + + const cricket::Candidate& last_candidate_gathered() const { + return last_candidate_gathered_; + } + const cricket::IceCandidateErrorEvent& error_event() const { + return error_event_; + } + + // Sets the mDNS responder for the owned fake network manager and keeps a + // reference to the responder. + void SetMdnsResponder( + std::unique_ptr mdns_responder) { + RTC_DCHECK(mdns_responder != nullptr); + mdns_responder_ = mdns_responder.get(); + network_manager()->set_mdns_responder(std::move(mdns_responder)); + } + + // Returns null on failure. + std::unique_ptr CreateOfferAndWait() { + auto observer = + rtc::make_ref_counted(); + pc()->CreateOffer(observer, offer_answer_options_); + return WaitForDescriptionFromObserver(observer); + } + bool Rollback() { + return SetRemoteDescription( + webrtc::CreateSessionDescription(SdpType::kRollback, "")); + } + + // Functions for querying stats. + void StartWatchingDelayStats() { + // Get the baseline numbers for audio_packets and audio_delay. 
+ auto received_stats = NewGetStats(); + auto track_stats = + received_stats->GetStatsOfType()[0]; + ASSERT_TRUE(track_stats->relative_packet_arrival_delay.is_defined()); + auto rtp_stats = + received_stats->GetStatsOfType()[0]; + ASSERT_TRUE(rtp_stats->packets_received.is_defined()); + ASSERT_TRUE(rtp_stats->track_id.is_defined()); + audio_track_stats_id_ = track_stats->id(); + ASSERT_TRUE(received_stats->Get(audio_track_stats_id_)); + rtp_stats_id_ = rtp_stats->id(); + ASSERT_EQ(audio_track_stats_id_, *rtp_stats->track_id); + audio_packets_stat_ = *rtp_stats->packets_received; + audio_delay_stat_ = *track_stats->relative_packet_arrival_delay; + audio_samples_stat_ = *track_stats->total_samples_received; + audio_concealed_stat_ = *track_stats->concealed_samples; + } + + void UpdateDelayStats(std::string tag, int desc_size) { + auto report = NewGetStats(); + auto track_stats = + report->GetAs(audio_track_stats_id_); + ASSERT_TRUE(track_stats); + auto rtp_stats = + report->GetAs(rtp_stats_id_); + ASSERT_TRUE(rtp_stats); + auto delta_packets = *rtp_stats->packets_received - audio_packets_stat_; + auto delta_rpad = + *track_stats->relative_packet_arrival_delay - audio_delay_stat_; + auto recent_delay = delta_packets > 0 ? delta_rpad / delta_packets : -1; + // The purpose of these checks is to sound the alarm early if we introduce + // serious regressions. The numbers are not acceptable for production, but + // occur on slow bots. + // + // An average relative packet arrival delay over the renegotiation of + // > 100 ms indicates that something is dramatically wrong, and will impact + // quality for sure. + // Worst bots: + // linux_x86_dbg at 0.206 +#if !defined(NDEBUG) + EXPECT_GT(0.25, recent_delay) << tag << " size " << desc_size; +#else + EXPECT_GT(0.1, recent_delay) << tag << " size " << desc_size; +#endif + auto delta_samples = + *track_stats->total_samples_received - audio_samples_stat_; + auto delta_concealed = + *track_stats->concealed_samples - audio_concealed_stat_; + // These limits should be adjusted down as we improve: + // + // Concealing more than 4000 samples during a renegotiation is unacceptable. + // But some bots are slow. + + // Worst bots: + // linux_more_configs bot at conceal count 5184 + // android_arm_rel at conceal count 9241 + // linux_x86_dbg at 15174 +#if !defined(NDEBUG) + EXPECT_GT(18000U, delta_concealed) << "Concealed " << delta_concealed + << " of " << delta_samples << " samples"; +#else + EXPECT_GT(15000U, delta_concealed) << "Concealed " << delta_concealed + << " of " << delta_samples << " samples"; +#endif + // Concealing more than 20% of samples during a renegotiation is + // unacceptable. 
+ // Worst bots: + // linux_more_configs bot at conceal rate 0.516 + // linux_x86_dbg bot at conceal rate 0.854 + if (delta_samples > 0) { +#if !defined(NDEBUG) + EXPECT_GT(0.95, 1.0 * delta_concealed / delta_samples) + << "Concealed " << delta_concealed << " of " << delta_samples + << " samples"; +#else + EXPECT_GT(0.6, 1.0 * delta_concealed / delta_samples) + << "Concealed " << delta_concealed << " of " << delta_samples + << " samples"; +#endif + } + // Increment trailing counters + audio_packets_stat_ = *rtp_stats->packets_received; + audio_delay_stat_ = *track_stats->relative_packet_arrival_delay; + audio_samples_stat_ = *track_stats->total_samples_received; + audio_concealed_stat_ = *track_stats->concealed_samples; + } + + private: + explicit PeerConnectionIntegrationWrapper(const std::string& debug_name) + : debug_name_(debug_name) {} + + bool Init(const PeerConnectionFactory::Options* options, + const PeerConnectionInterface::RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies, + rtc::Thread* network_thread, + rtc::Thread* worker_thread, + std::unique_ptr event_log_factory, + bool reset_encoder_factory, + bool reset_decoder_factory) { + // There's an error in this test code if Init ends up being called twice. + RTC_DCHECK(!peer_connection_); + RTC_DCHECK(!peer_connection_factory_); + + fake_network_manager_.reset(new rtc::FakeNetworkManager()); + fake_network_manager_->AddInterface(kDefaultLocalAddress); + + std::unique_ptr port_allocator( + new cricket::BasicPortAllocator(fake_network_manager_.get())); + port_allocator_ = port_allocator.get(); + fake_audio_capture_module_ = FakeAudioCaptureModule::Create(); + if (!fake_audio_capture_module_) { + return false; + } + rtc::Thread* const signaling_thread = rtc::Thread::Current(); + + webrtc::PeerConnectionFactoryDependencies pc_factory_dependencies; + pc_factory_dependencies.network_thread = network_thread; + pc_factory_dependencies.worker_thread = worker_thread; + pc_factory_dependencies.signaling_thread = signaling_thread; + pc_factory_dependencies.task_queue_factory = + webrtc::CreateDefaultTaskQueueFactory(); + pc_factory_dependencies.trials = std::make_unique(); + cricket::MediaEngineDependencies media_deps; + media_deps.task_queue_factory = + pc_factory_dependencies.task_queue_factory.get(); + media_deps.adm = fake_audio_capture_module_; + webrtc::SetMediaEngineDefaults(&media_deps); + + if (reset_encoder_factory) { + media_deps.video_encoder_factory.reset(); + } + if (reset_decoder_factory) { + media_deps.video_decoder_factory.reset(); + } + + if (!media_deps.audio_processing) { + // If the standard Creation method for APM returns a null pointer, instead + // use the builder for testing to create an APM object. 
+ media_deps.audio_processing = AudioProcessingBuilderForTesting().Create(); + } + + media_deps.trials = pc_factory_dependencies.trials.get(); + + pc_factory_dependencies.media_engine = + cricket::CreateMediaEngine(std::move(media_deps)); + pc_factory_dependencies.call_factory = webrtc::CreateCallFactory(); + if (event_log_factory) { + event_log_factory_ = event_log_factory.get(); + pc_factory_dependencies.event_log_factory = std::move(event_log_factory); + } else { + pc_factory_dependencies.event_log_factory = + std::make_unique( + pc_factory_dependencies.task_queue_factory.get()); + } + peer_connection_factory_ = webrtc::CreateModularPeerConnectionFactory( + std::move(pc_factory_dependencies)); + + if (!peer_connection_factory_) { + return false; + } + if (options) { + peer_connection_factory_->SetOptions(*options); + } + if (config) { + sdp_semantics_ = config->sdp_semantics; + } + + dependencies.allocator = std::move(port_allocator); + peer_connection_ = CreatePeerConnection(config, std::move(dependencies)); + return peer_connection_.get() != nullptr; + } + + rtc::scoped_refptr CreatePeerConnection( + const PeerConnectionInterface::RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies) { + PeerConnectionInterface::RTCConfiguration modified_config; + // If |config| is null, this will result in a default configuration being + // used. + if (config) { + modified_config = *config; + } + // Disable resolution adaptation; we don't want it interfering with the + // test results. + // TODO(deadbeef): Do something more robust. Since we're testing for aspect + // ratios and not specific resolutions, is this even necessary? + modified_config.set_cpu_adaptation(false); + + dependencies.observer = this; + return peer_connection_factory_->CreatePeerConnection( + modified_config, std::move(dependencies)); + } + + void set_signaling_message_receiver( + SignalingMessageReceiver* signaling_message_receiver) { + signaling_message_receiver_ = signaling_message_receiver; + } + + void set_signaling_delay_ms(int delay_ms) { signaling_delay_ms_ = delay_ms; } + + void set_signal_ice_candidates(bool signal) { + signal_ice_candidates_ = signal; + } + + rtc::scoped_refptr CreateLocalVideoTrackInternal( + webrtc::FakePeriodicVideoSource::Config config) { + // Set max frame rate to 10fps to reduce the risk of test flakiness. + // TODO(deadbeef): Do something more robust. + config.frame_interval_ms = 100; + + video_track_sources_.emplace_back( + rtc::make_ref_counted( + config, false /* remote */)); + rtc::scoped_refptr track( + peer_connection_factory_->CreateVideoTrack( + rtc::CreateRandomUuid(), video_track_sources_.back())); + if (!local_video_renderer_) { + local_video_renderer_.reset(new webrtc::FakeVideoTrackRenderer(track)); + } + return track; + } + + void HandleIncomingOffer(const std::string& msg) { + RTC_LOG(LS_INFO) << debug_name_ << ": HandleIncomingOffer"; + std::unique_ptr desc = + webrtc::CreateSessionDescription(SdpType::kOffer, msg); + if (received_sdp_munger_) { + received_sdp_munger_(desc->description()); + } + + EXPECT_TRUE(SetRemoteDescription(std::move(desc))); + // Setting a remote description may have changed the number of receivers, + // so reset the receiver observers. 
+ ResetRtpReceiverObservers(); + if (remote_offer_handler_) { + remote_offer_handler_(); + } + auto answer = CreateAnswer(); + ASSERT_NE(nullptr, answer); + EXPECT_TRUE(SetLocalDescriptionAndSendSdpMessage(std::move(answer))); + } + + void HandleIncomingAnswer(const std::string& msg) { + RTC_LOG(LS_INFO) << debug_name_ << ": HandleIncomingAnswer"; + std::unique_ptr desc = + webrtc::CreateSessionDescription(SdpType::kAnswer, msg); + if (received_sdp_munger_) { + received_sdp_munger_(desc->description()); + } + + EXPECT_TRUE(SetRemoteDescription(std::move(desc))); + // Set the RtpReceiverObserver after receivers are created. + ResetRtpReceiverObservers(); + } + + // Returns null on failure. + std::unique_ptr CreateAnswer() { + auto observer = + rtc::make_ref_counted(); + pc()->CreateAnswer(observer, offer_answer_options_); + return WaitForDescriptionFromObserver(observer); + } + + std::unique_ptr WaitForDescriptionFromObserver( + MockCreateSessionDescriptionObserver* observer) { + EXPECT_EQ_WAIT(true, observer->called(), kDefaultTimeout); + if (!observer->result()) { + return nullptr; + } + auto description = observer->MoveDescription(); + if (generated_sdp_munger_) { + generated_sdp_munger_(description->description()); + } + return description; + } + + // Setting the local description and sending the SDP message over the fake + // signaling channel are combined into the same method because the SDP + // message needs to be sent as soon as SetLocalDescription finishes, without + // waiting for the observer to be called. This ensures that ICE candidates + // don't outrace the description. + bool SetLocalDescriptionAndSendSdpMessage( + std::unique_ptr desc) { + auto observer = rtc::make_ref_counted(); + RTC_LOG(LS_INFO) << debug_name_ << ": SetLocalDescriptionAndSendSdpMessage"; + SdpType type = desc->GetType(); + std::string sdp; + EXPECT_TRUE(desc->ToString(&sdp)); + RTC_LOG(LS_INFO) << debug_name_ << ": local SDP contents=\n" << sdp; + pc()->SetLocalDescription(observer, desc.release()); + RemoveUnusedVideoRenderers(); + // As mentioned above, we need to send the message immediately after + // SetLocalDescription. + SendSdpMessage(type, sdp); + EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); + return true; + } + + bool SetRemoteDescription(std::unique_ptr desc) { + auto observer = rtc::make_ref_counted(); + RTC_LOG(LS_INFO) << debug_name_ << ": SetRemoteDescription"; + pc()->SetRemoteDescription(observer, desc.release()); + RemoveUnusedVideoRenderers(); + EXPECT_TRUE_WAIT(observer->called(), kDefaultTimeout); + return observer->result(); + } + + // This is a work around to remove unused fake_video_renderers from + // transceivers that have either stopped or are no longer receiving. + void RemoveUnusedVideoRenderers() { + if (sdp_semantics_ != SdpSemantics::kUnifiedPlan) { + return; + } + auto transceivers = pc()->GetTransceivers(); + std::set active_renderers; + for (auto& transceiver : transceivers) { + // Note - we don't check for direction here. This function is called + // before direction is set, and in that case, we should not remove + // the renderer. + if (transceiver->receiver()->media_type() == cricket::MEDIA_TYPE_VIDEO) { + active_renderers.insert(transceiver->receiver()->track()->id()); + } + } + for (auto it = fake_video_renderers_.begin(); + it != fake_video_renderers_.end();) { + // Remove fake video renderers belonging to any non-active transceivers. 
+ if (!active_renderers.count(it->first)) { + it = fake_video_renderers_.erase(it); + } else { + it++; + } + } + } + + // Simulate sending a blob of SDP with delay |signaling_delay_ms_| (0 by + // default). + void SendSdpMessage(SdpType type, const std::string& msg) { + if (signaling_delay_ms_ == 0) { + RelaySdpMessageIfReceiverExists(type, msg); + } else { + rtc::Thread::Current()->PostDelayedTask( + ToQueuedTask(task_safety_.flag(), + [this, type, msg] { + RelaySdpMessageIfReceiverExists(type, msg); + }), + signaling_delay_ms_); + } + } + + void RelaySdpMessageIfReceiverExists(SdpType type, const std::string& msg) { + if (signaling_message_receiver_) { + signaling_message_receiver_->ReceiveSdpMessage(type, msg); + } + } + + // Simulate trickling an ICE candidate with delay |signaling_delay_ms_| (0 by + // default). + void SendIceMessage(const std::string& sdp_mid, + int sdp_mline_index, + const std::string& msg) { + if (signaling_delay_ms_ == 0) { + RelayIceMessageIfReceiverExists(sdp_mid, sdp_mline_index, msg); + } else { + rtc::Thread::Current()->PostDelayedTask( + ToQueuedTask(task_safety_.flag(), + [this, sdp_mid, sdp_mline_index, msg] { + RelayIceMessageIfReceiverExists(sdp_mid, + sdp_mline_index, msg); + }), + signaling_delay_ms_); + } + } + + void RelayIceMessageIfReceiverExists(const std::string& sdp_mid, + int sdp_mline_index, + const std::string& msg) { + if (signaling_message_receiver_) { + signaling_message_receiver_->ReceiveIceMessage(sdp_mid, sdp_mline_index, + msg); + } + } + + // SignalingMessageReceiver callbacks. + void ReceiveSdpMessage(SdpType type, const std::string& msg) override { + if (type == SdpType::kOffer) { + HandleIncomingOffer(msg); + } else { + HandleIncomingAnswer(msg); + } + } + + void ReceiveIceMessage(const std::string& sdp_mid, + int sdp_mline_index, + const std::string& msg) override { + RTC_LOG(LS_INFO) << debug_name_ << ": ReceiveIceMessage"; + std::unique_ptr candidate( + webrtc::CreateIceCandidate(sdp_mid, sdp_mline_index, msg, nullptr)); + EXPECT_TRUE(pc()->AddIceCandidate(candidate.get())); + } + + // PeerConnectionObserver callbacks. 
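// Aside on SendSdpMessage()/SendIceMessage() above (illustrative note, not
// part of the original patch): the delayed relay follows the usual
// ScopedTaskSafety pattern, e.g.
//
//   rtc::Thread::Current()->PostDelayedTask(
//       ToQueuedTask(task_safety_.flag(),
//                    [this] { /* safe to touch |this| here */ }),
//       signaling_delay_ms_);
//
// If this wrapper is destroyed before the delay elapses, |task_safety_| marks
// its flag as no longer alive and the posted closure is dropped instead of
// running against a deleted object.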
+ void OnSignalingChange( + webrtc::PeerConnectionInterface::SignalingState new_state) override { + EXPECT_EQ(pc()->signaling_state(), new_state); + peer_connection_signaling_state_history_.push_back(new_state); + } + void OnAddTrack(rtc::scoped_refptr receiver, + const std::vector>& + streams) override { + if (receiver->media_type() == cricket::MEDIA_TYPE_VIDEO) { + rtc::scoped_refptr video_track( + static_cast(receiver->track().get())); + ASSERT_TRUE(fake_video_renderers_.find(video_track->id()) == + fake_video_renderers_.end()); + fake_video_renderers_[video_track->id()] = + std::make_unique(video_track); + } + } + void OnRemoveTrack( + rtc::scoped_refptr receiver) override { + if (receiver->media_type() == cricket::MEDIA_TYPE_VIDEO) { + auto it = fake_video_renderers_.find(receiver->track()->id()); + if (it != fake_video_renderers_.end()) { + fake_video_renderers_.erase(it); + } else { + RTC_LOG(LS_ERROR) << "OnRemoveTrack called for non-active renderer"; + } + } + } + void OnRenegotiationNeeded() override {} + void OnIceConnectionChange( + webrtc::PeerConnectionInterface::IceConnectionState new_state) override { + EXPECT_EQ(pc()->ice_connection_state(), new_state); + ice_connection_state_history_.push_back(new_state); + } + void OnStandardizedIceConnectionChange( + webrtc::PeerConnectionInterface::IceConnectionState new_state) override { + standardized_ice_connection_state_history_.push_back(new_state); + } + void OnConnectionChange( + webrtc::PeerConnectionInterface::PeerConnectionState new_state) override { + peer_connection_state_history_.push_back(new_state); + } + + void OnIceGatheringChange( + webrtc::PeerConnectionInterface::IceGatheringState new_state) override { + EXPECT_EQ(pc()->ice_gathering_state(), new_state); + ice_gathering_state_history_.push_back(new_state); + } + + void OnIceSelectedCandidatePairChanged( + const cricket::CandidatePairChangeEvent& event) { + ice_candidate_pair_change_history_.push_back(event); + } + + void OnIceCandidate(const webrtc::IceCandidateInterface* candidate) override { + RTC_LOG(LS_INFO) << debug_name_ << ": OnIceCandidate"; + + if (remote_async_resolver_) { + const auto& local_candidate = candidate->candidate(); + if (local_candidate.address().IsUnresolvedIP()) { + RTC_DCHECK(local_candidate.type() == cricket::LOCAL_PORT_TYPE); + rtc::SocketAddress resolved_addr(local_candidate.address()); + const auto resolved_ip = mdns_responder_->GetMappedAddressForName( + local_candidate.address().hostname()); + RTC_DCHECK(!resolved_ip.IsNil()); + resolved_addr.SetResolvedIP(resolved_ip); + EXPECT_CALL(*remote_async_resolver_, GetResolvedAddress(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(resolved_addr), Return(true))); + EXPECT_CALL(*remote_async_resolver_, Destroy(_)); + } + } + + std::string ice_sdp; + EXPECT_TRUE(candidate->ToString(&ice_sdp)); + if (signaling_message_receiver_ == nullptr || !signal_ice_candidates_) { + // Remote party may be deleted. 
+ return; + } + SendIceMessage(candidate->sdp_mid(), candidate->sdp_mline_index(), ice_sdp); + last_candidate_gathered_ = candidate->candidate(); + } + void OnIceCandidateError(const std::string& address, + int port, + const std::string& url, + int error_code, + const std::string& error_text) override { + error_event_ = cricket::IceCandidateErrorEvent(address, port, url, + error_code, error_text); + } + void OnDataChannel( + rtc::scoped_refptr data_channel) override { + RTC_LOG(LS_INFO) << debug_name_ << ": OnDataChannel"; + data_channel_ = data_channel; + data_observer_.reset(new MockDataChannelObserver(data_channel)); + } + + std::string debug_name_; + + std::unique_ptr fake_network_manager_; + // Reference to the mDNS responder owned by |fake_network_manager_| after set. + webrtc::FakeMdnsResponder* mdns_responder_ = nullptr; + + rtc::scoped_refptr peer_connection_; + rtc::scoped_refptr + peer_connection_factory_; + + cricket::PortAllocator* port_allocator_; + // Needed to keep track of number of frames sent. + rtc::scoped_refptr fake_audio_capture_module_; + // Needed to keep track of number of frames received. + std::map> + fake_video_renderers_; + // Needed to ensure frames aren't received for removed tracks. + std::vector> + removed_fake_video_renderers_; + + // For remote peer communication. + SignalingMessageReceiver* signaling_message_receiver_ = nullptr; + int signaling_delay_ms_ = 0; + bool signal_ice_candidates_ = true; + cricket::Candidate last_candidate_gathered_; + cricket::IceCandidateErrorEvent error_event_; + + // Store references to the video sources we've created, so that we can stop + // them, if required. + std::vector> + video_track_sources_; + // |local_video_renderer_| attached to the first created local video track. + std::unique_ptr local_video_renderer_; + + SdpSemantics sdp_semantics_; + PeerConnectionInterface::RTCOfferAnswerOptions offer_answer_options_; + std::function received_sdp_munger_; + std::function generated_sdp_munger_; + std::function remote_offer_handler_; + rtc::MockAsyncResolver* remote_async_resolver_ = nullptr; + rtc::scoped_refptr data_channel_; + std::unique_ptr data_observer_; + + std::vector> rtp_receiver_observers_; + + std::vector + ice_connection_state_history_; + std::vector + standardized_ice_connection_state_history_; + std::vector + peer_connection_state_history_; + std::vector + ice_gathering_state_history_; + std::vector + ice_candidate_pair_change_history_; + std::vector + peer_connection_signaling_state_history_; + webrtc::FakeRtcEventLogFactory* event_log_factory_; + + // Variables for tracking delay stats on an audio track + int audio_packets_stat_ = 0; + double audio_delay_stat_ = 0.0; + uint64_t audio_samples_stat_ = 0; + uint64_t audio_concealed_stat_ = 0; + std::string rtp_stats_id_; + std::string audio_track_stats_id_; + + ScopedTaskSafety task_safety_; + + friend class PeerConnectionIntegrationBaseTest; +}; + +class MockRtcEventLogOutput : public webrtc::RtcEventLogOutput { + public: + virtual ~MockRtcEventLogOutput() = default; + MOCK_METHOD(bool, IsActive, (), (const, override)); + MOCK_METHOD(bool, Write, (const std::string&), (override)); +}; + +// This helper object is used for both specifying how many audio/video frames +// are expected to be received for a caller/callee. It provides helper functions +// to specify these expectations. The object initially starts in a state of no +// expectations. 
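// Example of how a test built on the fixture below typically consumes
// MediaExpectations (an illustrative sketch, not part of the original patch):
//
//   MediaExpectations media_expectations;
//   media_expectations.ExpectBidirectionalAudioAndVideo();
//   ASSERT_TRUE(ExpectNewFrames(media_expectations));
//
// Caller/callee-specific and "no frames" variants are configured the same way
// through the CallerExpects*/CalleeExpects* helpers declared below.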
+class MediaExpectations { + public: + enum ExpectFrames { + kExpectSomeFrames, + kExpectNoFrames, + kNoExpectation, + }; + + void ExpectBidirectionalAudioAndVideo() { + ExpectBidirectionalAudio(); + ExpectBidirectionalVideo(); + } + + void ExpectBidirectionalAudio() { + CallerExpectsSomeAudio(); + CalleeExpectsSomeAudio(); + } + + void ExpectNoAudio() { + CallerExpectsNoAudio(); + CalleeExpectsNoAudio(); + } + + void ExpectBidirectionalVideo() { + CallerExpectsSomeVideo(); + CalleeExpectsSomeVideo(); + } + + void ExpectNoVideo() { + CallerExpectsNoVideo(); + CalleeExpectsNoVideo(); + } + + void CallerExpectsSomeAudioAndVideo() { + CallerExpectsSomeAudio(); + CallerExpectsSomeVideo(); + } + + void CalleeExpectsSomeAudioAndVideo() { + CalleeExpectsSomeAudio(); + CalleeExpectsSomeVideo(); + } + + // Caller's audio functions. + void CallerExpectsSomeAudio( + int expected_audio_frames = kDefaultExpectedAudioFrameCount) { + caller_audio_expectation_ = kExpectSomeFrames; + caller_audio_frames_expected_ = expected_audio_frames; + } + + void CallerExpectsNoAudio() { + caller_audio_expectation_ = kExpectNoFrames; + caller_audio_frames_expected_ = 0; + } + + // Caller's video functions. + void CallerExpectsSomeVideo( + int expected_video_frames = kDefaultExpectedVideoFrameCount) { + caller_video_expectation_ = kExpectSomeFrames; + caller_video_frames_expected_ = expected_video_frames; + } + + void CallerExpectsNoVideo() { + caller_video_expectation_ = kExpectNoFrames; + caller_video_frames_expected_ = 0; + } + + // Callee's audio functions. + void CalleeExpectsSomeAudio( + int expected_audio_frames = kDefaultExpectedAudioFrameCount) { + callee_audio_expectation_ = kExpectSomeFrames; + callee_audio_frames_expected_ = expected_audio_frames; + } + + void CalleeExpectsNoAudio() { + callee_audio_expectation_ = kExpectNoFrames; + callee_audio_frames_expected_ = 0; + } + + // Callee's video functions. 
+ void CalleeExpectsSomeVideo( + int expected_video_frames = kDefaultExpectedVideoFrameCount) { + callee_video_expectation_ = kExpectSomeFrames; + callee_video_frames_expected_ = expected_video_frames; + } + + void CalleeExpectsNoVideo() { + callee_video_expectation_ = kExpectNoFrames; + callee_video_frames_expected_ = 0; + } + + ExpectFrames caller_audio_expectation_ = kNoExpectation; + ExpectFrames caller_video_expectation_ = kNoExpectation; + ExpectFrames callee_audio_expectation_ = kNoExpectation; + ExpectFrames callee_video_expectation_ = kNoExpectation; + int caller_audio_frames_expected_ = 0; + int caller_video_frames_expected_ = 0; + int callee_audio_frames_expected_ = 0; + int callee_video_frames_expected_ = 0; +}; + +class MockIceTransport : public webrtc::IceTransportInterface { + public: + MockIceTransport(const std::string& name, int component) + : internal_(std::make_unique( + name, + component, + nullptr /* network_thread */)) {} + ~MockIceTransport() = default; + cricket::IceTransportInternal* internal() { return internal_.get(); } + + private: + std::unique_ptr internal_; +}; + +class MockIceTransportFactory : public IceTransportFactory { + public: + ~MockIceTransportFactory() override = default; + rtc::scoped_refptr CreateIceTransport( + const std::string& transport_name, + int component, + IceTransportInit init) { + RecordIceTransportCreated(); + return rtc::make_ref_counted(transport_name, component); + } + MOCK_METHOD(void, RecordIceTransportCreated, ()); +}; + +// Tests two PeerConnections connecting to each other end-to-end, using a +// virtual network, fake A/V capture and fake encoder/decoders. The +// PeerConnections share the threads/socket servers, but use separate versions +// of everything else (including "PeerConnectionFactory"s). +class PeerConnectionIntegrationBaseTest : public ::testing::Test { + public: + PeerConnectionIntegrationBaseTest( + SdpSemantics sdp_semantics, + absl::optional field_trials = absl::nullopt) + : sdp_semantics_(sdp_semantics), + ss_(new rtc::VirtualSocketServer()), + fss_(new rtc::FirewallSocketServer(ss_.get())), + network_thread_(new rtc::Thread(fss_.get())), + worker_thread_(rtc::Thread::Create()), + field_trials_(field_trials.has_value() + ? new test::ScopedFieldTrials(*field_trials) + : nullptr) { + network_thread_->SetName("PCNetworkThread", this); + worker_thread_->SetName("PCWorkerThread", this); + RTC_CHECK(network_thread_->Start()); + RTC_CHECK(worker_thread_->Start()); + webrtc::metrics::Reset(); + } + + ~PeerConnectionIntegrationBaseTest() { + // The PeerConnections should be deleted before the TurnCustomizers. + // A TurnPort is created with a raw pointer to a TurnCustomizer. The + // TurnPort has the same lifetime as the PeerConnection, so it's expected + // that the TurnCustomizer outlives the life of the PeerConnection or else + // when Send() is called it will hit a seg fault. + if (caller_) { + caller_->set_signaling_message_receiver(nullptr); + caller_->pc()->Close(); + delete SetCallerPcWrapperAndReturnCurrent(nullptr); + } + if (callee_) { + callee_->set_signaling_message_receiver(nullptr); + callee_->pc()->Close(); + delete SetCalleePcWrapperAndReturnCurrent(nullptr); + } + + // If turn servers were created for the test they need to be destroyed on + // the network thread. 
+ network_thread()->Invoke(RTC_FROM_HERE, [this] { + turn_servers_.clear(); + turn_customizers_.clear(); + }); + } + + bool SignalingStateStable() { + return caller_->SignalingStateStable() && callee_->SignalingStateStable(); + } + + bool DtlsConnected() { + // TODO(deadbeef): kIceConnectionConnected currently means both ICE and DTLS + // are connected. This is an important distinction. Once we have separate + // ICE and DTLS state, this check needs to use the DTLS state. + return (callee()->ice_connection_state() == + webrtc::PeerConnectionInterface::kIceConnectionConnected || + callee()->ice_connection_state() == + webrtc::PeerConnectionInterface::kIceConnectionCompleted) && + (caller()->ice_connection_state() == + webrtc::PeerConnectionInterface::kIceConnectionConnected || + caller()->ice_connection_state() == + webrtc::PeerConnectionInterface::kIceConnectionCompleted); + } + + // When |event_log_factory| is null, the default implementation of the event + // log factory will be used. + std::unique_ptr CreatePeerConnectionWrapper( + const std::string& debug_name, + const PeerConnectionFactory::Options* options, + const RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies, + std::unique_ptr event_log_factory, + bool reset_encoder_factory, + bool reset_decoder_factory) { + RTCConfiguration modified_config; + if (config) { + modified_config = *config; + } + modified_config.sdp_semantics = sdp_semantics_; + if (!dependencies.cert_generator) { + dependencies.cert_generator = + std::make_unique(); + } + std::unique_ptr client( + new PeerConnectionIntegrationWrapper(debug_name)); + + if (!client->Init(options, &modified_config, std::move(dependencies), + network_thread_.get(), worker_thread_.get(), + std::move(event_log_factory), reset_encoder_factory, + reset_decoder_factory)) { + return nullptr; + } + return client; + } + + std::unique_ptr + CreatePeerConnectionWrapperWithFakeRtcEventLog( + const std::string& debug_name, + const PeerConnectionFactory::Options* options, + const RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies) { + return CreatePeerConnectionWrapper( + debug_name, options, config, std::move(dependencies), + std::make_unique(), + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + } + + bool CreatePeerConnectionWrappers() { + return CreatePeerConnectionWrappersWithConfig( + PeerConnectionInterface::RTCConfiguration(), + PeerConnectionInterface::RTCConfiguration()); + } + + bool CreatePeerConnectionWrappersWithSdpSemantics( + SdpSemantics caller_semantics, + SdpSemantics callee_semantics) { + // Can't specify the sdp_semantics in the passed-in configuration since it + // will be overwritten by CreatePeerConnectionWrapper with whatever is + // stored in sdp_semantics_. So get around this by modifying the instance + // variable before calling CreatePeerConnectionWrapper for the caller and + // callee PeerConnections. 
+ SdpSemantics original_semantics = sdp_semantics_; + sdp_semantics_ = caller_semantics; + caller_ = CreatePeerConnectionWrapper( + "Caller", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), + nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + sdp_semantics_ = callee_semantics; + callee_ = CreatePeerConnectionWrapper( + "Callee", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), + nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + sdp_semantics_ = original_semantics; + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithConfig( + const PeerConnectionInterface::RTCConfiguration& caller_config, + const PeerConnectionInterface::RTCConfiguration& callee_config) { + caller_ = CreatePeerConnectionWrapper( + "Caller", nullptr, &caller_config, + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + callee_ = CreatePeerConnectionWrapper( + "Callee", nullptr, &callee_config, + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithConfigAndDeps( + const PeerConnectionInterface::RTCConfiguration& caller_config, + webrtc::PeerConnectionDependencies caller_dependencies, + const PeerConnectionInterface::RTCConfiguration& callee_config, + webrtc::PeerConnectionDependencies callee_dependencies) { + caller_ = + CreatePeerConnectionWrapper("Caller", nullptr, &caller_config, + std::move(caller_dependencies), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + callee_ = + CreatePeerConnectionWrapper("Callee", nullptr, &callee_config, + std::move(callee_dependencies), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithOptions( + const PeerConnectionFactory::Options& caller_options, + const PeerConnectionFactory::Options& callee_options) { + caller_ = CreatePeerConnectionWrapper( + "Caller", &caller_options, nullptr, + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + callee_ = CreatePeerConnectionWrapper( + "Callee", &callee_options, nullptr, + webrtc::PeerConnectionDependencies(nullptr), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithFakeRtcEventLog() { + PeerConnectionInterface::RTCConfiguration default_config; + caller_ = CreatePeerConnectionWrapperWithFakeRtcEventLog( + "Caller", nullptr, &default_config, + webrtc::PeerConnectionDependencies(nullptr)); + callee_ = CreatePeerConnectionWrapperWithFakeRtcEventLog( + "Callee", nullptr, &default_config, + webrtc::PeerConnectionDependencies(nullptr)); + return caller_ && callee_; + } + + std::unique_ptr + CreatePeerConnectionWrapperWithAlternateKey() { + std::unique_ptr cert_generator( + new FakeRTCCertificateGenerator()); + cert_generator->use_alternate_key(); + + webrtc::PeerConnectionDependencies dependencies(nullptr); + dependencies.cert_generator = std::move(cert_generator); + return CreatePeerConnectionWrapper("New Peer", nullptr, nullptr, + std::move(dependencies), nullptr, + /*reset_encoder_factory=*/false, + /*reset_decoder_factory=*/false); + } + + bool CreateOneDirectionalPeerConnectionWrappers(bool caller_to_callee) { + 
caller_ = CreatePeerConnectionWrapper( + "Caller", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), + nullptr, + /*reset_encoder_factory=*/!caller_to_callee, + /*reset_decoder_factory=*/caller_to_callee); + callee_ = CreatePeerConnectionWrapper( + "Callee", nullptr, nullptr, webrtc::PeerConnectionDependencies(nullptr), + nullptr, + /*reset_encoder_factory=*/caller_to_callee, + /*reset_decoder_factory=*/!caller_to_callee); + return caller_ && callee_; + } + + cricket::TestTurnServer* CreateTurnServer( + rtc::SocketAddress internal_address, + rtc::SocketAddress external_address, + cricket::ProtocolType type = cricket::ProtocolType::PROTO_UDP, + const std::string& common_name = "test turn server") { + rtc::Thread* thread = network_thread(); + std::unique_ptr turn_server = + network_thread()->Invoke>( + RTC_FROM_HERE, + [thread, internal_address, external_address, type, common_name] { + return std::make_unique( + thread, internal_address, external_address, type, + /*ignore_bad_certs=*/true, common_name); + }); + turn_servers_.push_back(std::move(turn_server)); + // Interactions with the turn server should be done on the network thread. + return turn_servers_.back().get(); + } + + cricket::TestTurnCustomizer* CreateTurnCustomizer() { + std::unique_ptr turn_customizer = + network_thread()->Invoke>( + RTC_FROM_HERE, + [] { return std::make_unique(); }); + turn_customizers_.push_back(std::move(turn_customizer)); + // Interactions with the turn customizer should be done on the network + // thread. + return turn_customizers_.back().get(); + } + + // Checks that the function counters for a TestTurnCustomizer are greater than + // 0. + void ExpectTurnCustomizerCountersIncremented( + cricket::TestTurnCustomizer* turn_customizer) { + unsigned int allow_channel_data_counter = + network_thread()->Invoke( + RTC_FROM_HERE, [turn_customizer] { + return turn_customizer->allow_channel_data_cnt_; + }); + EXPECT_GT(allow_channel_data_counter, 0u); + unsigned int modify_counter = network_thread()->Invoke( + RTC_FROM_HERE, + [turn_customizer] { return turn_customizer->modify_cnt_; }); + EXPECT_GT(modify_counter, 0u); + } + + // Once called, SDP blobs and ICE candidates will be automatically signaled + // between PeerConnections. + void ConnectFakeSignaling() { + caller_->set_signaling_message_receiver(callee_.get()); + callee_->set_signaling_message_receiver(caller_.get()); + } + + // Once called, SDP blobs will be automatically signaled between + // PeerConnections. Note that ICE candidates will not be signaled unless they + // are in the exchanged SDP blobs. + void ConnectFakeSignalingForSdpOnly() { + ConnectFakeSignaling(); + SetSignalIceCandidates(false); + } + + void SetSignalingDelayMs(int delay_ms) { + caller_->set_signaling_delay_ms(delay_ms); + callee_->set_signaling_delay_ms(delay_ms); + } + + void SetSignalIceCandidates(bool signal) { + caller_->set_signal_ice_candidates(signal); + callee_->set_signal_ice_candidates(signal); + } + + // Messages may get lost on the unreliable DataChannel, so we send multiple + // times to avoid test flakiness. 
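// Typical end-to-end use of the signaling helpers above (illustrative sketch,
// not part of the original patch):
//
//   ASSERT_TRUE(CreatePeerConnectionWrappers());
//   ConnectFakeSignaling();
//   caller()->AddAudioVideoTracks();
//   callee()->AddAudioVideoTracks();
//   caller()->CreateAndSetAndSignalOffer();
//   ASSERT_TRUE_WAIT(DtlsConnected(), kDefaultTimeout);
//
// After this point media flows in both directions and can be verified with
// ExpectNewFrames() and a MediaExpectations object.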
+ void SendRtpDataWithRetries(webrtc::DataChannelInterface* dc, + const std::string& data, + int retries) { + for (int i = 0; i < retries; ++i) { + dc->Send(DataBuffer(data)); + } + } + + rtc::Thread* network_thread() { return network_thread_.get(); } + + rtc::VirtualSocketServer* virtual_socket_server() { return ss_.get(); } + + PeerConnectionIntegrationWrapper* caller() { return caller_.get(); } + + // Set the |caller_| to the |wrapper| passed in and return the + // original |caller_|. + PeerConnectionIntegrationWrapper* SetCallerPcWrapperAndReturnCurrent( + PeerConnectionIntegrationWrapper* wrapper) { + PeerConnectionIntegrationWrapper* old = caller_.release(); + caller_.reset(wrapper); + return old; + } + + PeerConnectionIntegrationWrapper* callee() { return callee_.get(); } + + // Set the |callee_| to the |wrapper| passed in and return the + // original |callee_|. + PeerConnectionIntegrationWrapper* SetCalleePcWrapperAndReturnCurrent( + PeerConnectionIntegrationWrapper* wrapper) { + PeerConnectionIntegrationWrapper* old = callee_.release(); + callee_.reset(wrapper); + return old; + } + + void SetPortAllocatorFlags(uint32_t caller_flags, uint32_t callee_flags) { + network_thread()->Invoke(RTC_FROM_HERE, [this, caller_flags] { + caller()->port_allocator()->set_flags(caller_flags); + }); + network_thread()->Invoke(RTC_FROM_HERE, [this, callee_flags] { + callee()->port_allocator()->set_flags(callee_flags); + }); + } + + rtc::FirewallSocketServer* firewall() const { return fss_.get(); } + + // Expects the provided number of new frames to be received within + // kMaxWaitForFramesMs. The new expected frames are specified in + // |media_expectations|. Returns false if any of the expectations were + // not met. + bool ExpectNewFrames(const MediaExpectations& media_expectations) { + // Make sure there are no bogus tracks confusing the issue. + caller()->RemoveUnusedVideoRenderers(); + callee()->RemoveUnusedVideoRenderers(); + // First initialize the expected frame counts based upon the current + // frame count. + int total_caller_audio_frames_expected = caller()->audio_frames_received(); + if (media_expectations.caller_audio_expectation_ == + MediaExpectations::kExpectSomeFrames) { + total_caller_audio_frames_expected += + media_expectations.caller_audio_frames_expected_; + } + int total_caller_video_frames_expected = + caller()->min_video_frames_received_per_track(); + if (media_expectations.caller_video_expectation_ == + MediaExpectations::kExpectSomeFrames) { + total_caller_video_frames_expected += + media_expectations.caller_video_frames_expected_; + } + int total_callee_audio_frames_expected = callee()->audio_frames_received(); + if (media_expectations.callee_audio_expectation_ == + MediaExpectations::kExpectSomeFrames) { + total_callee_audio_frames_expected += + media_expectations.callee_audio_frames_expected_; + } + int total_callee_video_frames_expected = + callee()->min_video_frames_received_per_track(); + if (media_expectations.callee_video_expectation_ == + MediaExpectations::kExpectSomeFrames) { + total_callee_video_frames_expected += + media_expectations.callee_video_frames_expected_; + } + + // Wait for the expected frames. 
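// Worked example (illustrative): if the caller has already received 100 audio
// frames and the test called CallerExpectsSomeAudio() with the default
// kDefaultExpectedAudioFrameCount, then total_caller_audio_frames_expected is
// 100 + kDefaultExpectedAudioFrameCount, and the wait below only succeeds once
// audio_frames_received() reaches at least that total.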
+ EXPECT_TRUE_WAIT(caller()->audio_frames_received() >= + total_caller_audio_frames_expected && + caller()->min_video_frames_received_per_track() >= + total_caller_video_frames_expected && + callee()->audio_frames_received() >= + total_callee_audio_frames_expected && + callee()->min_video_frames_received_per_track() >= + total_callee_video_frames_expected, + kMaxWaitForFramesMs); + bool expectations_correct = + caller()->audio_frames_received() >= + total_caller_audio_frames_expected && + caller()->min_video_frames_received_per_track() >= + total_caller_video_frames_expected && + callee()->audio_frames_received() >= + total_callee_audio_frames_expected && + callee()->min_video_frames_received_per_track() >= + total_callee_video_frames_expected; + + // After the combined wait, print out a more detailed message upon + // failure. + EXPECT_GE(caller()->audio_frames_received(), + total_caller_audio_frames_expected); + EXPECT_GE(caller()->min_video_frames_received_per_track(), + total_caller_video_frames_expected); + EXPECT_GE(callee()->audio_frames_received(), + total_callee_audio_frames_expected); + EXPECT_GE(callee()->min_video_frames_received_per_track(), + total_callee_video_frames_expected); + + // We want to make sure nothing unexpected was received. + if (media_expectations.caller_audio_expectation_ == + MediaExpectations::kExpectNoFrames) { + EXPECT_EQ(caller()->audio_frames_received(), + total_caller_audio_frames_expected); + if (caller()->audio_frames_received() != + total_caller_audio_frames_expected) { + expectations_correct = false; + } + } + if (media_expectations.caller_video_expectation_ == + MediaExpectations::kExpectNoFrames) { + EXPECT_EQ(caller()->min_video_frames_received_per_track(), + total_caller_video_frames_expected); + if (caller()->min_video_frames_received_per_track() != + total_caller_video_frames_expected) { + expectations_correct = false; + } + } + if (media_expectations.callee_audio_expectation_ == + MediaExpectations::kExpectNoFrames) { + EXPECT_EQ(callee()->audio_frames_received(), + total_callee_audio_frames_expected); + if (callee()->audio_frames_received() != + total_callee_audio_frames_expected) { + expectations_correct = false; + } + } + if (media_expectations.callee_video_expectation_ == + MediaExpectations::kExpectNoFrames) { + EXPECT_EQ(callee()->min_video_frames_received_per_track(), + total_callee_video_frames_expected); + if (callee()->min_video_frames_received_per_track() != + total_callee_video_frames_expected) { + expectations_correct = false; + } + } + return expectations_correct; + } + + void ClosePeerConnections() { + if (caller()) + caller()->pc()->Close(); + if (callee()) + callee()->pc()->Close(); + } + + void TestNegotiatedCipherSuite( + const PeerConnectionFactory::Options& caller_options, + const PeerConnectionFactory::Options& callee_options, + int expected_cipher_suite) { + ASSERT_TRUE(CreatePeerConnectionWrappersWithOptions(caller_options, + callee_options)); + ConnectFakeSignaling(); + caller()->AddAudioVideoTracks(); + callee()->AddAudioVideoTracks(); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(DtlsConnected(), kDefaultTimeout); + EXPECT_EQ_WAIT(rtc::SrtpCryptoSuiteToName(expected_cipher_suite), + caller()->OldGetStats()->SrtpCipher(), kDefaultTimeout); + // TODO(bugs.webrtc.org/9456): Fix it. 
+ EXPECT_METRIC_EQ(1, webrtc::metrics::NumEvents( + "WebRTC.PeerConnection.SrtpCryptoSuite.Audio", + expected_cipher_suite)); + } + + void TestGcmNegotiationUsesCipherSuite(bool local_gcm_enabled, + bool remote_gcm_enabled, + bool aes_ctr_enabled, + int expected_cipher_suite) { + PeerConnectionFactory::Options caller_options; + caller_options.crypto_options.srtp.enable_gcm_crypto_suites = + local_gcm_enabled; + caller_options.crypto_options.srtp.enable_aes128_sha1_80_crypto_cipher = + aes_ctr_enabled; + PeerConnectionFactory::Options callee_options; + callee_options.crypto_options.srtp.enable_gcm_crypto_suites = + remote_gcm_enabled; + callee_options.crypto_options.srtp.enable_aes128_sha1_80_crypto_cipher = + aes_ctr_enabled; + TestNegotiatedCipherSuite(caller_options, callee_options, + expected_cipher_suite); + } + + protected: + SdpSemantics sdp_semantics_; + + private: + // |ss_| is used by |network_thread_| so it must be destroyed later. + std::unique_ptr ss_; + std::unique_ptr fss_; + // |network_thread_| and |worker_thread_| are used by both + // |caller_| and |callee_| so they must be destroyed + // later. + std::unique_ptr network_thread_; + std::unique_ptr worker_thread_; + // The turn servers and turn customizers should be accessed & deleted on the + // network thread to avoid a race with the socket read/write that occurs + // on the network thread. + std::vector> turn_servers_; + std::vector> turn_customizers_; + std::unique_ptr caller_; + std::unique_ptr callee_; + std::unique_ptr field_trials_; +}; + +} // namespace webrtc + +#endif // PC_TEST_INTEGRATION_TEST_HELPERS_H_ diff --git a/pc/test/mock_channel_interface.h b/pc/test/mock_channel_interface.h index 52404f1dea..6faba5c8fc 100644 --- a/pc/test/mock_channel_interface.h +++ b/pc/test/mock_channel_interface.h @@ -28,11 +28,10 @@ class MockChannelInterface : public cricket::ChannelInterface { MOCK_METHOD(MediaChannel*, media_channel, (), (const, override)); MOCK_METHOD(const std::string&, transport_name, (), (const, override)); MOCK_METHOD(const std::string&, content_name, (), (const, override)); - MOCK_METHOD(bool, enabled, (), (const, override)); - MOCK_METHOD(bool, Enable, (bool), (override)); - MOCK_METHOD(sigslot::signal1&, - SignalFirstPacketReceived, - (), + MOCK_METHOD(void, Enable, (bool), (override)); + MOCK_METHOD(void, + SetFirstPacketReceivedCallback, + (std::function), (override)); MOCK_METHOD(bool, SetLocalContent, @@ -46,8 +45,7 @@ class MockChannelInterface : public cricket::ChannelInterface { webrtc::SdpType, std::string*), (override)); - MOCK_METHOD(void, SetPayloadTypeDemuxingEnabled, (bool), (override)); - MOCK_METHOD(bool, UpdateRtpTransport, (std::string*), (override)); + MOCK_METHOD(bool, SetPayloadTypeDemuxingEnabled, (bool), (override)); MOCK_METHOD(const std::vector&, local_streams, (), @@ -60,10 +58,6 @@ class MockChannelInterface : public cricket::ChannelInterface { SetRtpTransport, (webrtc::RtpTransportInternal*), (override)); - MOCK_METHOD(RtpHeaderExtensions, - GetNegotiatedRtpHeaderExtensions, - (), - (const)); }; } // namespace cricket diff --git a/pc/test/mock_delayable.h b/pc/test/mock_delayable.h deleted file mode 100644 index bef07c1970..0000000000 --- a/pc/test/mock_delayable.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. 
An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef PC_TEST_MOCK_DELAYABLE_H_ -#define PC_TEST_MOCK_DELAYABLE_H_ - -#include - -#include "absl/types/optional.h" -#include "media/base/delayable.h" -#include "test/gmock.h" - -namespace webrtc { - -class MockDelayable : public cricket::Delayable { - public: - MOCK_METHOD(bool, - SetBaseMinimumPlayoutDelayMs, - (uint32_t ssrc, int delay_ms), - (override)); - MOCK_METHOD(absl::optional, - GetBaseMinimumPlayoutDelayMs, - (uint32_t ssrc), - (const, override)); -}; - -} // namespace webrtc - -#endif // PC_TEST_MOCK_DELAYABLE_H_ diff --git a/pc/test/mock_peer_connection_observers.h b/pc/test/mock_peer_connection_observers.h index 7766297843..413339dbf7 100644 --- a/pc/test/mock_peer_connection_observers.h +++ b/pc/test/mock_peer_connection_observers.h @@ -286,7 +286,7 @@ class MockSetSessionDescriptionObserver : public webrtc::SetSessionDescriptionObserver { public: static rtc::scoped_refptr Create() { - return new rtc::RefCountedObject(); + return rtc::make_ref_counted(); } MockSetSessionDescriptionObserver() @@ -351,32 +351,51 @@ class FakeSetRemoteDescriptionObserver class MockDataChannelObserver : public webrtc::DataChannelObserver { public: + struct Message { + std::string data; + bool binary; + }; + explicit MockDataChannelObserver(webrtc::DataChannelInterface* channel) : channel_(channel) { channel_->RegisterObserver(this); - state_ = channel_->state(); + states_.push_back(channel_->state()); } virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); } void OnBufferedAmountChange(uint64_t previous_amount) override {} - void OnStateChange() override { state_ = channel_->state(); } + void OnStateChange() override { states_.push_back(channel_->state()); } void OnMessage(const DataBuffer& buffer) override { messages_.push_back( - std::string(buffer.data.data(), buffer.data.size())); + {std::string(buffer.data.data(), buffer.data.size()), + buffer.binary}); } - bool IsOpen() const { return state_ == DataChannelInterface::kOpen; } - std::vector messages() const { return messages_; } + bool IsOpen() const { return state() == DataChannelInterface::kOpen; } + std::vector messages() const { return messages_; } std::string last_message() const { - return messages_.empty() ? 
std::string() : messages_.back(); + if (messages_.empty()) + return {}; + + return messages_.back().data; + } + bool last_message_is_binary() const { + if (messages_.empty()) + return false; + return messages_.back().binary; } size_t received_message_count() const { return messages_.size(); } + DataChannelInterface::DataState state() const { return states_.back(); } + const std::vector& states() const { + return states_; + } + private: rtc::scoped_refptr channel_; - DataChannelInterface::DataState state_; - std::vector messages_; + std::vector states_; + std::vector messages_; }; class MockStatsObserver : public webrtc::StatsObserver { diff --git a/pc/test/peer_connection_test_wrapper.cc b/pc/test/peer_connection_test_wrapper.cc index 946f459f3b..8fdfb1bbb8 100644 --- a/pc/test/peer_connection_test_wrapper.cc +++ b/pc/test/peer_connection_test_wrapper.cc @@ -20,6 +20,7 @@ #include "absl/types/optional.h" #include "api/audio/audio_mixer.h" #include "api/create_peerconnection_factory.h" +#include "api/sequence_checker.h" #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" #include "api/video_codecs/video_decoder_factory.h" @@ -37,7 +38,6 @@ #include "rtc_base/ref_counted_object.h" #include "rtc_base/rtc_certificate_generator.h" #include "rtc_base/string_encode.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/time_utils.h" #include "test/gtest.h" @@ -123,17 +123,31 @@ bool PeerConnectionTestWrapper::CreatePc( std::unique_ptr cert_generator( new FakeRTCCertificateGenerator()); - peer_connection_ = peer_connection_factory_->CreatePeerConnection( - config, std::move(port_allocator), std::move(cert_generator), this); - - return peer_connection_.get() != NULL; + webrtc::PeerConnectionDependencies deps(this); + deps.allocator = std::move(port_allocator); + deps.cert_generator = std::move(cert_generator); + auto result = peer_connection_factory_->CreatePeerConnectionOrError( + config, std::move(deps)); + if (result.ok()) { + peer_connection_ = result.MoveValue(); + return true; + } else { + return false; + } } rtc::scoped_refptr PeerConnectionTestWrapper::CreateDataChannel( const std::string& label, const webrtc::DataChannelInit& init) { - return peer_connection_->CreateDataChannel(label, &init); + auto result = peer_connection_->CreateDataChannelOrError(label, &init); + if (!result.ok()) { + RTC_LOG(LS_ERROR) << "CreateDataChannel failed: " + << ToString(result.error().type()) << " " + << result.error().message(); + return nullptr; + } + return result.MoveValue(); } void PeerConnectionTestWrapper::WaitForNegotiation() { @@ -221,8 +235,7 @@ void PeerConnectionTestWrapper::SetLocalDescription(SdpType type, << ": SetLocalDescription " << webrtc::SdpTypeToString(type) << " " << sdp; - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); peer_connection_->SetLocalDescription( observer, webrtc::CreateSessionDescription(type, sdp).release()); } @@ -233,8 +246,7 @@ void PeerConnectionTestWrapper::SetRemoteDescription(SdpType type, << ": SetRemoteDescription " << webrtc::SdpTypeToString(type) << " " << sdp; - rtc::scoped_refptr observer( - new rtc::RefCountedObject()); + auto observer = rtc::make_ref_counted(); peer_connection_->SetRemoteDescription( observer, webrtc::CreateSessionDescription(type, sdp).release()); } @@ -331,9 +343,8 @@ PeerConnectionTestWrapper::GetUserMedia( config.frame_interval_ms = 100; config.timestamp_offset_ms = rtc::TimeMillis(); - rtc::scoped_refptr 
source = - new rtc::RefCountedObject( - config, /* remote */ false); + auto source = rtc::make_ref_counted( + config, /* remote */ false); std::string videotrack_label = stream_id + kVideoTrackLabelBase; rtc::scoped_refptr video_track( diff --git a/pc/test/peer_connection_test_wrapper.h b/pc/test/peer_connection_test_wrapper.h index 92599b78ab..4abf6c9ea5 100644 --- a/pc/test/peer_connection_test_wrapper.h +++ b/pc/test/peer_connection_test_wrapper.h @@ -25,11 +25,11 @@ #include "api/rtc_error.h" #include "api/rtp_receiver_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "pc/test/fake_audio_capture_module.h" #include "pc/test/fake_video_track_renderer.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" class PeerConnectionTestWrapper : public webrtc::PeerConnectionObserver, @@ -120,7 +120,7 @@ class PeerConnectionTestWrapper std::string name_; rtc::Thread* const network_thread_; rtc::Thread* const worker_thread_; - rtc::ThreadChecker pc_thread_checker_; + webrtc::SequenceChecker pc_thread_checker_; rtc::scoped_refptr peer_connection_; rtc::scoped_refptr peer_connection_factory_; diff --git a/pc/test/rtc_stats_obtainer.h b/pc/test/rtc_stats_obtainer.h index 95201f6649..4da23c6628 100644 --- a/pc/test/rtc_stats_obtainer.h +++ b/pc/test/rtc_stats_obtainer.h @@ -20,8 +20,7 @@ class RTCStatsObtainer : public RTCStatsCollectorCallback { public: static rtc::scoped_refptr Create( rtc::scoped_refptr* report_ptr = nullptr) { - return rtc::scoped_refptr( - new rtc::RefCountedObject(report_ptr)); + return rtc::make_ref_counted(report_ptr); } void OnStatsDelivered( @@ -43,7 +42,7 @@ class RTCStatsObtainer : public RTCStatsCollectorCallback { : report_ptr_(report_ptr) {} private: - rtc::ThreadChecker thread_checker_; + SequenceChecker thread_checker_; rtc::scoped_refptr report_; rtc::scoped_refptr* report_ptr_; }; diff --git a/pc/test/test_sdp_strings.h b/pc/test/test_sdp_strings.h index 849757d300..6394ac5f5e 100644 --- a/pc/test/test_sdp_strings.h +++ b/pc/test/test_sdp_strings.h @@ -60,7 +60,7 @@ static const char kFireFoxSdpOffer[] = "a=candidate:4 2 UDP 2113667326 10.0.254.2 58890 typ host\r\n" "a=candidate:5 2 UDP 1694302206 74.95.2.170 33611 typ srflx raddr" " 10.0.254.2 rport 58890\r\n" -#ifdef HAVE_SCTP +#ifdef WEBRTC_HAVE_SCTP "m=application 45536 DTLS/SCTP 5000\r\n" "c=IN IP4 74.95.2.170\r\n" "a=fmtp:5000 protocol=webrtc-datachannel;streams=16\r\n" diff --git a/pc/track_media_info_map.cc b/pc/track_media_info_map.cc index b3ec68bb27..66f4c461df 100644 --- a/pc/track_media_info_map.cc +++ b/pc/track_media_info_map.cc @@ -10,10 +10,15 @@ #include "pc/track_media_info_map.h" +#include #include #include #include +#include "api/media_types.h" +#include "api/rtp_parameters.h" +#include "media/base/stream_params.h" +#include "rtc_base/checks.h" #include "rtc_base/thread.h" namespace webrtc { diff --git a/pc/track_media_info_map.h b/pc/track_media_info_map.h index 542501eb16..c8c6da2701 100644 --- a/pc/track_media_info_map.h +++ b/pc/track_media_info_map.h @@ -11,12 +11,16 @@ #ifndef PC_TRACK_MEDIA_INFO_MAP_H_ #define PC_TRACK_MEDIA_INFO_MAP_H_ +#include + #include #include #include #include +#include "absl/types/optional.h" #include "api/media_stream_interface.h" +#include "api/scoped_refptr.h" #include "media/base/media_channel.h" #include "pc/rtp_receiver.h" #include "pc/rtp_sender.h" diff --git a/pc/track_media_info_map_unittest.cc b/pc/track_media_info_map_unittest.cc index 
0cb1e0e277..1d5caacddb 100644 --- a/pc/track_media_info_map_unittest.cc +++ b/pc/track_media_info_map_unittest.cc @@ -31,6 +31,45 @@ namespace webrtc { namespace { +class MockVideoTrack : public VideoTrackInterface { + public: + // NotifierInterface + MOCK_METHOD(void, + RegisterObserver, + (ObserverInterface * observer), + (override)); + MOCK_METHOD(void, + UnregisterObserver, + (ObserverInterface * observer), + (override)); + + // MediaStreamTrackInterface + MOCK_METHOD(std::string, kind, (), (const, override)); + MOCK_METHOD(std::string, id, (), (const, override)); + MOCK_METHOD(bool, enabled, (), (const, override)); + MOCK_METHOD(bool, set_enabled, (bool enable), (override)); + MOCK_METHOD(TrackState, state, (), (const, override)); + + // VideoSourceInterface + MOCK_METHOD(void, + AddOrUpdateSink, + (rtc::VideoSinkInterface * sink, + const rtc::VideoSinkWants& wants), + (override)); + // RemoveSink must guarantee that at the time the method returns, + // there is no current and no future calls to VideoSinkInterface::OnFrame. + MOCK_METHOD(void, + RemoveSink, + (rtc::VideoSinkInterface * sink), + (override)); + + // VideoTrackInterface + MOCK_METHOD(VideoTrackSourceInterface*, GetSource, (), (const, override)); + + MOCK_METHOD(ContentHint, content_hint, (), (const, override)); + MOCK_METHOD(void, set_content_hint, (ContentHint hint), (override)); +}; + RtpParameters CreateRtpParametersWithSsrcs( std::initializer_list ssrcs) { RtpParameters params; @@ -52,8 +91,7 @@ rtc::scoped_refptr CreateMockRtpSender( } else { first_ssrc = 0; } - rtc::scoped_refptr sender( - new rtc::RefCountedObject()); + auto sender = rtc::make_ref_counted(); EXPECT_CALL(*sender, track()) .WillRepeatedly(::testing::Return(std::move(track))); EXPECT_CALL(*sender, ssrc()).WillRepeatedly(::testing::Return(first_ssrc)); @@ -69,8 +107,7 @@ rtc::scoped_refptr CreateMockRtpReceiver( cricket::MediaType media_type, std::initializer_list ssrcs, rtc::scoped_refptr track) { - rtc::scoped_refptr receiver( - new rtc::RefCountedObject()); + auto receiver = rtc::make_ref_counted(); EXPECT_CALL(*receiver, track()) .WillRepeatedly(::testing::Return(std::move(track))); EXPECT_CALL(*receiver, media_type()) @@ -81,23 +118,35 @@ rtc::scoped_refptr CreateMockRtpReceiver( return receiver; } +rtc::scoped_refptr CreateVideoTrack( + const std::string& id) { + return VideoTrack::Create(id, FakeVideoTrackSource::Create(false), + rtc::Thread::Current()); +} + +rtc::scoped_refptr CreateMockVideoTrack( + const std::string& id) { + auto track = rtc::make_ref_counted(); + EXPECT_CALL(*track, kind()) + .WillRepeatedly(::testing::Return(VideoTrack::kVideoKind)); + return track; +} + class TrackMediaInfoMapTest : public ::testing::Test { public: TrackMediaInfoMapTest() : TrackMediaInfoMapTest(true) {} - explicit TrackMediaInfoMapTest(bool use_current_thread) + explicit TrackMediaInfoMapTest(bool use_real_video_track) : voice_media_info_(new cricket::VoiceMediaInfo()), video_media_info_(new cricket::VideoMediaInfo()), local_audio_track_(AudioTrack::Create("LocalAudioTrack", nullptr)), remote_audio_track_(AudioTrack::Create("RemoteAudioTrack", nullptr)), - local_video_track_(VideoTrack::Create( - "LocalVideoTrack", - FakeVideoTrackSource::Create(false), - use_current_thread ? rtc::Thread::Current() : nullptr)), - remote_video_track_(VideoTrack::Create( - "RemoteVideoTrack", - FakeVideoTrackSource::Create(false), - use_current_thread ? rtc::Thread::Current() : nullptr)) {} + local_video_track_(use_real_video_track + ? 
CreateVideoTrack("LocalVideoTrack") + : CreateMockVideoTrack("LocalVideoTrack")), + remote_video_track_(use_real_video_track + ? CreateVideoTrack("RemoteVideoTrack") + : CreateMockVideoTrack("LocalVideoTrack")) {} ~TrackMediaInfoMapTest() { // If we have a map the ownership has been passed to the map, only delete if @@ -181,8 +230,8 @@ class TrackMediaInfoMapTest : public ::testing::Test { std::unique_ptr map_; rtc::scoped_refptr local_audio_track_; rtc::scoped_refptr remote_audio_track_; - rtc::scoped_refptr local_video_track_; - rtc::scoped_refptr remote_video_track_; + rtc::scoped_refptr local_video_track_; + rtc::scoped_refptr remote_video_track_; }; } // namespace diff --git a/pc/transceiver_list.cc b/pc/transceiver_list.cc index 5fe148a222..235c9af036 100644 --- a/pc/transceiver_list.cc +++ b/pc/transceiver_list.cc @@ -10,6 +10,8 @@ #include "pc/transceiver_list.h" +#include "rtc_base/checks.h" + namespace webrtc { void TransceiverStableState::set_newly_created() { @@ -34,8 +36,23 @@ void TransceiverStableState::SetRemoteStreamIdsIfUnset( } } +void TransceiverStableState::SetInitSendEncodings( + const std::vector& encodings) { + init_send_encodings_ = encodings; +} + +std::vector TransceiverList::ListInternal() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + std::vector internals; + for (auto transceiver : transceivers_) { + internals.push_back(transceiver->internal()); + } + return internals; +} + RtpTransceiverProxyRefPtr TransceiverList::FindBySender( rtc::scoped_refptr sender) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); for (auto transceiver : transceivers_) { if (transceiver->sender() == sender) { return transceiver; @@ -46,6 +63,7 @@ RtpTransceiverProxyRefPtr TransceiverList::FindBySender( RtpTransceiverProxyRefPtr TransceiverList::FindByMid( const std::string& mid) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); for (auto transceiver : transceivers_) { if (transceiver->mid() == mid) { return transceiver; @@ -56,6 +74,7 @@ RtpTransceiverProxyRefPtr TransceiverList::FindByMid( RtpTransceiverProxyRefPtr TransceiverList::FindByMLineIndex( size_t mline_index) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); for (auto transceiver : transceivers_) { if (transceiver->internal()->mline_index() == mline_index) { return transceiver; diff --git a/pc/transceiver_list.h b/pc/transceiver_list.h index cd77d67f44..568c9c7e7a 100644 --- a/pc/transceiver_list.h +++ b/pc/transceiver_list.h @@ -11,12 +11,24 @@ #ifndef PC_TRANSCEIVER_LIST_H_ #define PC_TRANSCEIVER_LIST_H_ +#include + #include #include #include #include +#include "absl/types/optional.h" +#include "api/media_types.h" +#include "api/rtc_error.h" +#include "api/rtp_parameters.h" +#include "api/rtp_sender_interface.h" +#include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "pc/rtp_transceiver.h" +#include "rtc_base/checks.h" +#include "rtc_base/system/no_unique_address.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { @@ -32,11 +44,17 @@ class TransceiverStableState { void SetMSectionIfUnset(absl::optional mid, absl::optional mline_index); void SetRemoteStreamIdsIfUnset(const std::vector& ids); + void SetInitSendEncodings( + const std::vector& encodings); absl::optional mid() const { return mid_; } absl::optional mline_index() const { return mline_index_; } absl::optional> remote_stream_ids() const { return remote_stream_ids_; } + absl::optional> init_send_encodings() + const { + return init_send_encodings_; + } bool has_m_section() const { return has_m_section_; } bool newly_created() const { 
return newly_created_; } @@ -44,6 +62,7 @@ class TransceiverStableState { absl::optional mid_; absl::optional mline_index_; absl::optional> remote_stream_ids_; + absl::optional> init_send_encodings_; // Indicates that mid value from stable state has been captured and // that rollback has to restore the transceiver. Also protects against // subsequent overwrites. @@ -54,14 +73,36 @@ class TransceiverStableState { bool newly_created_ = false; }; +// This class encapsulates the active list of transceivers on a +// PeerConnection, and offers convenient functions on that list. +// It is a single-thread class; all operations must be performed +// on the same thread. class TransceiverList { public: - std::vector List() const { return transceivers_; } + // Returns a copy of the currently active list of transceivers. The + // list consists of rtc::scoped_refptrs, which will keep the transceivers + // from being deallocated, even if they are removed from the TransceiverList. + std::vector List() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + return transceivers_; + } + // As above, but does not check thread ownership. Unsafe. + // TODO(bugs.webrtc.org/12692): Refactor and remove + std::vector UnsafeList() const { + return transceivers_; + } + + // Returns a list of the internal() pointers of the currently active list + // of transceivers. These raw pointers are not thread-safe, so need to + // be consumed on the same thread. + std::vector ListInternal() const; void Add(RtpTransceiverProxyRefPtr transceiver) { + RTC_DCHECK_RUN_ON(&sequence_checker_); transceivers_.push_back(transceiver); } void Remove(RtpTransceiverProxyRefPtr transceiver) { + RTC_DCHECK_RUN_ON(&sequence_checker_); transceivers_.erase( std::remove(transceivers_.begin(), transceivers_.end(), transceiver), transceivers_.end()); @@ -73,26 +114,33 @@ class TransceiverList { // Find or create the stable state for a transceiver. TransceiverStableState* StableState(RtpTransceiverProxyRefPtr transceiver) { + RTC_DCHECK_RUN_ON(&sequence_checker_); return &(transceiver_stable_states_by_transceivers_[transceiver]); } void DiscardStableStates() { + RTC_DCHECK_RUN_ON(&sequence_checker_); transceiver_stable_states_by_transceivers_.clear(); } std::map& StableStates() { + RTC_DCHECK_RUN_ON(&sequence_checker_); return transceiver_stable_states_by_transceivers_; } private: + RTC_NO_UNIQUE_ADDRESS SequenceChecker sequence_checker_; std::vector transceivers_; + // TODO(bugs.webrtc.org/12692): Add RTC_GUARDED_BY(sequence_checker_); + // Holds changes made to transceivers during applying descriptors for // potential rollback. Gets cleared once signaling state goes to stable. std::map - transceiver_stable_states_by_transceivers_; + transceiver_stable_states_by_transceivers_ + RTC_GUARDED_BY(sequence_checker_); // Holds remote stream ids for transceivers from stable state. 
std::map> - remote_stream_ids_by_transceivers_; + remote_stream_ids_by_transceivers_ RTC_GUARDED_BY(sequence_checker_); }; } // namespace webrtc diff --git a/pc/transport_stats.h b/pc/transport_stats.h index 7cb95f4ad2..173af91fba 100644 --- a/pc/transport_stats.h +++ b/pc/transport_stats.h @@ -14,6 +14,7 @@ #include #include +#include "api/dtls_transport_interface.h" #include "p2p/base/dtls_transport_internal.h" #include "p2p/base/ice_transport_internal.h" #include "p2p/base/port.h" @@ -30,7 +31,7 @@ struct TransportChannelStats { int ssl_version_bytes = 0; int srtp_crypto_suite = rtc::SRTP_INVALID_CRYPTO_SUITE; int ssl_cipher_suite = rtc::TLS_NULL_WITH_NULL_NULL; - DtlsTransportState dtls_state = DTLS_TRANSPORT_NEW; + webrtc::DtlsTransportState dtls_state = webrtc::DtlsTransportState::kNew; IceTransportStats ice_transport_stats; }; diff --git a/pc/usage_pattern.h b/pc/usage_pattern.h index c4a8918ac2..0182999d6b 100644 --- a/pc/usage_pattern.h +++ b/pc/usage_pattern.h @@ -11,6 +11,8 @@ #ifndef PC_USAGE_PATTERN_H_ #define PC_USAGE_PATTERN_H_ +#include "api/peer_connection_interface.h" + namespace webrtc { class PeerConnectionObserver; diff --git a/pc/used_ids.h b/pc/used_ids.h index 78e64caa41..62b2faa018 100644 --- a/pc/used_ids.h +++ b/pc/used_ids.h @@ -60,7 +60,9 @@ class UsedIds { } protected: - bool IsIdUsed(int new_id) { return id_set_.find(new_id) != id_set_.end(); } + virtual bool IsIdUsed(int new_id) { + return id_set_.find(new_id) != id_set_.end(); + } const int min_allowed_id_; const int max_allowed_id_; @@ -92,11 +94,24 @@ class UsedIds { class UsedPayloadTypes : public UsedIds { public: UsedPayloadTypes() - : UsedIds(kDynamicPayloadTypeMin, kDynamicPayloadTypeMax) {} + : UsedIds(kFirstDynamicPayloadTypeLowerRange, + kLastDynamicPayloadTypeUpperRange) {} + + protected: + bool IsIdUsed(int new_id) override { + // Range marked for RTCP avoidance is "used". 
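// Illustrative note (not part of the original patch): with these limits,
// dynamic payload types can be assigned from 35-63 and 96-127, while 64-95 is
// reported as "used" so it is never handed out. That range is avoided because
// RTP payload types there can be mistaken for RTCP packet types when RTP and
// RTCP are multiplexed on a single port (see RFC 5761).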
+ if (new_id > kLastDynamicPayloadTypeLowerRange && + new_id < kFirstDynamicPayloadTypeUpperRange) + return true; + return UsedIds::IsIdUsed(new_id); + } private: - static const int kDynamicPayloadTypeMin = 96; - static const int kDynamicPayloadTypeMax = 127; + static const int kFirstDynamicPayloadTypeLowerRange = 35; + static const int kLastDynamicPayloadTypeLowerRange = 63; + + static const int kFirstDynamicPayloadTypeUpperRange = 96; + static const int kLastDynamicPayloadTypeUpperRange = 127; }; // Helper class used for finding duplicate RTP Header extension ids among diff --git a/pc/video_rtp_receiver.cc b/pc/video_rtp_receiver.cc index dd601259ec..8db4d9f02f 100644 --- a/pc/video_rtp_receiver.cc +++ b/pc/video_rtp_receiver.cc @@ -15,16 +15,12 @@ #include #include -#include "api/media_stream_proxy.h" -#include "api/video_track_source_proxy.h" -#include "pc/jitter_buffer_delay.h" -#include "pc/jitter_buffer_delay_proxy.h" -#include "pc/media_stream.h" +#include "api/video/recordable_encoded_frame.h" +#include "api/video_track_source_proxy_factory.h" #include "pc/video_track.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" -#include "rtc_base/trace_event.h" namespace webrtc { @@ -41,121 +37,139 @@ VideoRtpReceiver::VideoRtpReceiver( const std::vector>& streams) : worker_thread_(worker_thread), id_(receiver_id), - source_(new RefCountedObject(this)), + source_(rtc::make_ref_counted(&source_callback_)), track_(VideoTrackProxyWithInternal::Create( rtc::Thread::Current(), worker_thread, - VideoTrack::Create( - receiver_id, - VideoTrackSourceProxy::Create(rtc::Thread::Current(), - worker_thread, - source_), - worker_thread))), - attachment_id_(GenerateUniqueId()), - delay_(JitterBufferDelayProxy::Create( - rtc::Thread::Current(), - worker_thread, - new rtc::RefCountedObject(worker_thread))) { + VideoTrack::Create(receiver_id, + CreateVideoTrackSourceProxy(rtc::Thread::Current(), + worker_thread, + source_), + worker_thread))), + attachment_id_(GenerateUniqueId()) { RTC_DCHECK(worker_thread_); SetStreams(streams); - source_->SetState(MediaSourceInterface::kLive); + RTC_DCHECK_EQ(source_->state(), MediaSourceInterface::kLive); } VideoRtpReceiver::~VideoRtpReceiver() { - // Since cricket::VideoRenderer is not reference counted, - // we need to remove it from the channel before we are deleted. - Stop(); - // Make sure we can't be called by the |source_| anymore. - worker_thread_->Invoke(RTC_FROM_HERE, - [this] { source_->ClearCallback(); }); + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + RTC_DCHECK(stopped_); + RTC_DCHECK(!media_channel_); } std::vector VideoRtpReceiver::stream_ids() const { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); std::vector stream_ids(streams_.size()); for (size_t i = 0; i < streams_.size(); ++i) stream_ids[i] = streams_[i]->id(); return stream_ids; } +rtc::scoped_refptr VideoRtpReceiver::dtls_transport() + const { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + return dtls_transport_; +} + +std::vector> +VideoRtpReceiver::streams() const { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + return streams_; +} + RtpParameters VideoRtpReceiver::GetParameters() const { - if (!media_channel_ || stopped_) { + RTC_DCHECK_RUN_ON(worker_thread_); + if (!media_channel_) return RtpParameters(); - } - return worker_thread_->Invoke(RTC_FROM_HERE, [&] { - return ssrc_ ? media_channel_->GetRtpReceiveParameters(*ssrc_) - : media_channel_->GetDefaultRtpReceiveParameters(); - }); + return ssrc_ ? 
media_channel_->GetRtpReceiveParameters(*ssrc_) + : media_channel_->GetDefaultRtpReceiveParameters(); } void VideoRtpReceiver::SetFrameDecryptor( rtc::scoped_refptr frame_decryptor) { + RTC_DCHECK_RUN_ON(worker_thread_); frame_decryptor_ = std::move(frame_decryptor); // Special Case: Set the frame decryptor to any value on any existing channel. - if (media_channel_ && ssrc_.has_value() && !stopped_) { - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - media_channel_->SetFrameDecryptor(*ssrc_, frame_decryptor_); - }); + if (media_channel_ && ssrc_) { + media_channel_->SetFrameDecryptor(*ssrc_, frame_decryptor_); } } rtc::scoped_refptr VideoRtpReceiver::GetFrameDecryptor() const { + RTC_DCHECK_RUN_ON(worker_thread_); return frame_decryptor_; } void VideoRtpReceiver::SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) { - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - RTC_DCHECK_RUN_ON(worker_thread_); - frame_transformer_ = std::move(frame_transformer); - if (media_channel_ && !stopped_) { - media_channel_->SetDepacketizerToDecoderFrameTransformer( - ssrc_.value_or(0), frame_transformer_); - } - }); + RTC_DCHECK_RUN_ON(worker_thread_); + frame_transformer_ = std::move(frame_transformer); + if (media_channel_) { + media_channel_->SetDepacketizerToDecoderFrameTransformer( + ssrc_.value_or(0), frame_transformer_); + } } void VideoRtpReceiver::Stop() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); // TODO(deadbeef): Need to do more here to fully stop receiving packets. - if (stopped_) { - return; + + if (!stopped_) { + source_->SetState(MediaSourceInterface::kEnded); + stopped_ = true; } - source_->SetState(MediaSourceInterface::kEnded); - if (!media_channel_) { - RTC_LOG(LS_WARNING) << "VideoRtpReceiver::Stop: No video channel exists."; - } else { - // Allow that SetSink fails. This is the normal case when the underlying - // media channel has already been deleted. - worker_thread_->Invoke(RTC_FROM_HERE, [&] { - RTC_DCHECK_RUN_ON(worker_thread_); + + worker_thread_->Invoke(RTC_FROM_HERE, [&] { + RTC_DCHECK_RUN_ON(worker_thread_); + if (media_channel_) { SetSink(nullptr); - }); - } - delay_->OnStop(); - stopped_ = true; + SetMediaChannel_w(nullptr); + } + source_->ClearCallback(); + }); } void VideoRtpReceiver::StopAndEndTrack() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); Stop(); track_->internal()->set_ended(); } void VideoRtpReceiver::RestartMediaChannel(absl::optional ssrc) { - RTC_DCHECK(media_channel_); - if (!stopped_ && ssrc_ == ssrc) { - return; - } - worker_thread_->Invoke(RTC_FROM_HERE, [&] { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + + // `stopped_` will be `true` on construction. RestartMediaChannel + // can in this case function like "ensure started" and flip `stopped_` + // to false. + + // TODO(tommi): Can we restart the media channel without blocking? + bool ok = worker_thread_->Invoke(RTC_FROM_HERE, [&, was_stopped = + stopped_] { RTC_DCHECK_RUN_ON(worker_thread_); - if (!stopped_) { + if (!media_channel_) { + // Ignore further negotiations if we've already been stopped and don't + // have an associated media channel. + RTC_DCHECK(was_stopped); + return false; // Can't restart. + } + + if (!was_stopped && ssrc_ == ssrc) { + // Already running with that ssrc. + return true; + } + + // Disconnect from the previous ssrc. + if (!was_stopped) { SetSink(nullptr); } + bool encoded_sink_enabled = saved_encoded_sink_enabled_; SetEncodedSinkEnabled(false); - stopped_ = false; - - ssrc_ = ssrc; + // Set up the new ssrc. 
+ ssrc_ = std::move(ssrc); SetSink(source_->sink()); if (encoded_sink_enabled) { SetEncodedSinkEnabled(true); @@ -165,47 +179,62 @@ void VideoRtpReceiver::RestartMediaChannel(absl::optional ssrc) { media_channel_->SetDepacketizerToDecoderFrameTransformer( ssrc_.value_or(0), frame_transformer_); } + + if (media_channel_ && ssrc_) { + if (frame_decryptor_) { + media_channel_->SetFrameDecryptor(*ssrc_, frame_decryptor_); + } + + media_channel_->SetBaseMinimumPlayoutDelayMs(*ssrc_, delay_.GetMs()); + } + + return true; }); - // Attach any existing frame decryptor to the media channel. - MaybeAttachFrameDecryptorToMediaChannel( - ssrc, worker_thread_, frame_decryptor_, media_channel_, stopped_); - // TODO(bugs.webrtc.org/8694): Stop using 0 to mean unsignalled SSRC - // value. - delay_->OnStart(media_channel_, ssrc.value_or(0)); + if (!ok) + return; + + stopped_ = false; } +// RTC_RUN_ON(worker_thread_) void VideoRtpReceiver::SetSink(rtc::VideoSinkInterface* sink) { - RTC_DCHECK(media_channel_); if (ssrc_) { media_channel_->SetSink(*ssrc_, sink); - return; + } else { + media_channel_->SetDefaultSink(sink); } - media_channel_->SetDefaultSink(sink); } void VideoRtpReceiver::SetupMediaChannel(uint32_t ssrc) { - if (!media_channel_) { - RTC_LOG(LS_ERROR) - << "VideoRtpReceiver::SetupMediaChannel: No video channel exists."; - } + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RestartMediaChannel(ssrc); } void VideoRtpReceiver::SetupUnsignaledMediaChannel() { - if (!media_channel_) { - RTC_LOG(LS_ERROR) << "VideoRtpReceiver::SetupUnsignaledMediaChannel: No " - "video channel exists."; - } + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RestartMediaChannel(absl::nullopt); } +uint32_t VideoRtpReceiver::ssrc() const { + RTC_DCHECK_RUN_ON(worker_thread_); + return ssrc_.value_or(0); +} + void VideoRtpReceiver::set_stream_ids(std::vector stream_ids) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); SetStreams(CreateStreamsFromIds(std::move(stream_ids))); } +void VideoRtpReceiver::set_transport( + rtc::scoped_refptr dtls_transport) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + dtls_transport_ = std::move(dtls_transport); +} + void VideoRtpReceiver::SetStreams( const std::vector>& streams) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); // Remove remote track from any streams that are going away. for (const auto& existing_stream : streams_) { bool removed = true; @@ -238,6 +267,7 @@ void VideoRtpReceiver::SetStreams( } void VideoRtpReceiver::SetObserver(RtpReceiverObserverInterface* observer) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); observer_ = observer; // Deliver any notifications the observer may have missed by being set late. 
if (received_first_packet_ && observer_) { @@ -247,40 +277,57 @@ void VideoRtpReceiver::SetObserver(RtpReceiverObserverInterface* observer) { void VideoRtpReceiver::SetJitterBufferMinimumDelay( absl::optional delay_seconds) { - delay_->Set(delay_seconds); + RTC_DCHECK_RUN_ON(worker_thread_); + delay_.Set(delay_seconds); + if (media_channel_ && ssrc_) + media_channel_->SetBaseMinimumPlayoutDelayMs(*ssrc_, delay_.GetMs()); } void VideoRtpReceiver::SetMediaChannel(cricket::MediaChannel* media_channel) { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); RTC_DCHECK(media_channel == nullptr || media_channel->media_type() == media_type()); + + if (stopped_ && !media_channel) + return; + worker_thread_->Invoke(RTC_FROM_HERE, [&] { RTC_DCHECK_RUN_ON(worker_thread_); - bool encoded_sink_enabled = saved_encoded_sink_enabled_; - if (encoded_sink_enabled && media_channel_) { - // Turn off the old sink, if any. - SetEncodedSinkEnabled(false); - } + SetMediaChannel_w(media_channel); + }); +} - media_channel_ = static_cast(media_channel); +// RTC_RUN_ON(worker_thread_) +void VideoRtpReceiver::SetMediaChannel_w(cricket::MediaChannel* media_channel) { + if (media_channel == media_channel_) + return; - if (media_channel_) { - if (saved_generate_keyframe_) { - // TODO(bugs.webrtc.org/8694): Stop using 0 to mean unsignalled SSRC - media_channel_->GenerateKeyFrame(ssrc_.value_or(0)); - saved_generate_keyframe_ = false; - } - if (encoded_sink_enabled) { - SetEncodedSinkEnabled(true); - } - if (frame_transformer_) { - media_channel_->SetDepacketizerToDecoderFrameTransformer( - ssrc_.value_or(0), frame_transformer_); - } + bool encoded_sink_enabled = saved_encoded_sink_enabled_; + if (encoded_sink_enabled && media_channel_) { + // Turn off the old sink, if any. + SetEncodedSinkEnabled(false); + } + + media_channel_ = static_cast(media_channel); + + if (media_channel_) { + if (saved_generate_keyframe_) { + // TODO(bugs.webrtc.org/8694): Stop using 0 to mean unsignalled SSRC + media_channel_->GenerateKeyFrame(ssrc_.value_or(0)); + saved_generate_keyframe_ = false; } - }); + if (encoded_sink_enabled) { + SetEncodedSinkEnabled(true); + } + if (frame_transformer_) { + media_channel_->SetDepacketizerToDecoderFrameTransformer( + ssrc_.value_or(0), frame_transformer_); + } + } } void VideoRtpReceiver::NotifyFirstPacketReceived() { + RTC_DCHECK_RUN_ON(&signaling_thread_checker_); if (observer_) { observer_->OnFirstPacketReceived(media_type()); } @@ -288,11 +335,10 @@ void VideoRtpReceiver::NotifyFirstPacketReceived() { } std::vector VideoRtpReceiver::GetSources() const { - if (!media_channel_ || !ssrc_ || stopped_) { - return {}; - } - return worker_thread_->Invoke>( - RTC_FROM_HERE, [&] { return media_channel_->GetSources(*ssrc_); }); + RTC_DCHECK_RUN_ON(worker_thread_); + if (!ssrc_ || !media_channel_) + return std::vector(); + return media_channel_->GetSources(*ssrc_); } void VideoRtpReceiver::OnGenerateKeyFrame() { @@ -318,20 +364,21 @@ void VideoRtpReceiver::OnEncodedSinkEnabled(bool enable) { saved_encoded_sink_enabled_ = enable; } +// RTC_RUN_ON(worker_thread_) void VideoRtpReceiver::SetEncodedSinkEnabled(bool enable) { - if (media_channel_) { - if (enable) { - // TODO(bugs.webrtc.org/8694): Stop using 0 to mean unsignalled SSRC - auto source = source_; - media_channel_->SetRecordableEncodedFrameCallback( - ssrc_.value_or(0), - [source = std::move(source)](const RecordableEncodedFrame& frame) { - source->BroadcastRecordableEncodedFrame(frame); - }); - } else { - // TODO(bugs.webrtc.org/8694): Stop using 0 to mean 
unsignalled SSRC - media_channel_->ClearRecordableEncodedFrameCallback(ssrc_.value_or(0)); - } + if (!media_channel_) + return; + + // TODO(bugs.webrtc.org/8694): Stop using 0 to mean unsignalled SSRC + const auto ssrc = ssrc_.value_or(0); + + if (enable) { + media_channel_->SetRecordableEncodedFrameCallback( + ssrc, [source = source_](const RecordableEncodedFrame& frame) { + source->BroadcastRecordableEncodedFrame(frame); + }); + } else { + media_channel_->ClearRecordableEncodedFrameCallback(ssrc); } } diff --git a/pc/video_rtp_receiver.h b/pc/video_rtp_receiver.h index 74ae44431e..f59db7a840 100644 --- a/pc/video_rtp_receiver.h +++ b/pc/video_rtp_receiver.h @@ -18,28 +18,32 @@ #include "absl/types/optional.h" #include "api/crypto/frame_decryptor_interface.h" +#include "api/dtls_transport_interface.h" #include "api/frame_transformer_interface.h" #include "api/media_stream_interface.h" -#include "api/media_stream_track_proxy.h" #include "api/media_types.h" #include "api/rtp_parameters.h" #include "api/rtp_receiver_interface.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" +#include "api/transport/rtp/rtp_source.h" #include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" #include "api/video/video_source_interface.h" #include "media/base/media_channel.h" -#include "pc/jitter_buffer_delay_interface.h" +#include "pc/jitter_buffer_delay.h" +#include "pc/media_stream_track_proxy.h" #include "pc/rtp_receiver.h" #include "pc/video_rtp_track_source.h" #include "pc/video_track.h" #include "rtc_base/ref_counted_object.h" +#include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { -class VideoRtpReceiver : public rtc::RefCountedObject, - public VideoRtpTrackSource::Callback { +class VideoRtpReceiver : public RtpReceiverInternal { public: // An SSRC of 0 will create a receiver that will match the first SSRC it // sees. Must be called on signaling thread. 
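Editor's note: with VideoRtpReceiver now implementing RtpReceiverInternal instead of inheriting rtc::RefCountedObject directly, the ref-counting wrapper is supplied at the call site via rtc::make_ref_counted(), as the updated unit test below does. The following is a minimal usage sketch, not part of the patch; it assumes rtc::make_ref_counted is available through rtc_base/ref_counted_object.h (as included by this header), and the helper name CreateVideoReceiverForTest is illustrative only.

#include <string>
#include <vector>

#include "api/scoped_refptr.h"
#include "pc/video_rtp_receiver.h"
#include "rtc_base/ref_counted_object.h"
#include "rtc_base/thread.h"

// Sketch only: wraps construction the way the updated tests do. Must run on
// the signaling thread; the media channel and SSRC are attached later via
// SetMediaChannel() and SetupMediaChannel()/SetupUnsignaledMediaChannel().
rtc::scoped_refptr<webrtc::VideoRtpReceiver> CreateVideoReceiverForTest(
    rtc::Thread* worker_thread) {
  return rtc::make_ref_counted<webrtc::VideoRtpReceiver>(
      worker_thread, std::string("receiver_id"),
      std::vector<std::string>({"stream_id"}));
}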
@@ -55,23 +59,16 @@ class VideoRtpReceiver : public rtc::RefCountedObject, virtual ~VideoRtpReceiver(); - rtc::scoped_refptr video_track() const { - return track_.get(); - } + rtc::scoped_refptr video_track() const { return track_; } // RtpReceiverInterface implementation rtc::scoped_refptr track() const override { - return track_.get(); - } - rtc::scoped_refptr dtls_transport() const override { - return dtls_transport_; + return track_; } + rtc::scoped_refptr dtls_transport() const override; std::vector stream_ids() const override; std::vector> streams() - const override { - return streams_; - } - + const override; cricket::MediaType media_type() const override { return cricket::MEDIA_TYPE_VIDEO; } @@ -94,13 +91,11 @@ class VideoRtpReceiver : public rtc::RefCountedObject, void StopAndEndTrack() override; void SetupMediaChannel(uint32_t ssrc) override; void SetupUnsignaledMediaChannel() override; - uint32_t ssrc() const override { return ssrc_.value_or(0); } + uint32_t ssrc() const override; void NotifyFirstPacketReceived() override; void set_stream_ids(std::vector stream_ids) override; void set_transport( - rtc::scoped_refptr dtls_transport) override { - dtls_transport_ = dtls_transport; - } + rtc::scoped_refptr dtls_transport) override; void SetStreams(const std::vector>& streams) override; @@ -119,33 +114,68 @@ class VideoRtpReceiver : public rtc::RefCountedObject, void RestartMediaChannel(absl::optional ssrc); void SetSink(rtc::VideoSinkInterface* sink) RTC_RUN_ON(worker_thread_); + void SetMediaChannel_w(cricket::MediaChannel* media_channel) + RTC_RUN_ON(worker_thread_); // VideoRtpTrackSource::Callback - void OnGenerateKeyFrame() override; - void OnEncodedSinkEnabled(bool enable) override; + void OnGenerateKeyFrame(); + void OnEncodedSinkEnabled(bool enable); + void SetEncodedSinkEnabled(bool enable) RTC_RUN_ON(worker_thread_); + class SourceCallback : public VideoRtpTrackSource::Callback { + public: + explicit SourceCallback(VideoRtpReceiver* receiver) : receiver_(receiver) {} + ~SourceCallback() override = default; + + private: + void OnGenerateKeyFrame() override { receiver_->OnGenerateKeyFrame(); } + void OnEncodedSinkEnabled(bool enable) override { + receiver_->OnEncodedSinkEnabled(enable); + } + + VideoRtpReceiver* const receiver_; + } source_callback_{this}; + + RTC_NO_UNIQUE_ADDRESS SequenceChecker signaling_thread_checker_; rtc::Thread* const worker_thread_; const std::string id_; - cricket::VideoMediaChannel* media_channel_ = nullptr; - absl::optional ssrc_; + // See documentation for `stopped_` below for when a valid media channel + // has been assigned and when this pointer will be null. + cricket::VideoMediaChannel* media_channel_ RTC_GUARDED_BY(worker_thread_) = + nullptr; + absl::optional ssrc_ RTC_GUARDED_BY(worker_thread_); // |source_| is held here to be able to change the state of the source when // the VideoRtpReceiver is stopped. - rtc::scoped_refptr source_; - rtc::scoped_refptr> track_; - std::vector> streams_; - bool stopped_ = true; - RtpReceiverObserverInterface* observer_ = nullptr; - bool received_first_packet_ = false; - int attachment_id_ = 0; - rtc::scoped_refptr frame_decryptor_; - rtc::scoped_refptr dtls_transport_; + const rtc::scoped_refptr source_; + const rtc::scoped_refptr> track_; + std::vector> streams_ + RTC_GUARDED_BY(&signaling_thread_checker_); + // `stopped` is state that's used on the signaling thread to indicate whether + // a valid `media_channel_` has been assigned and configured. 
When an instance + // of VideoRtpReceiver is initially created, `stopped_` is true and will + // remain true until either `SetupMediaChannel` or + // `SetupUnsignaledMediaChannel` is called after assigning a media channel. + // After that, `stopped_` will remain false until `Stop()` is called. + // Note, for checking the state of the class on the worker thread, + // check `media_channel_` instead, as that's the main worker thread state. + bool stopped_ RTC_GUARDED_BY(&signaling_thread_checker_) = true; + RtpReceiverObserverInterface* observer_ + RTC_GUARDED_BY(&signaling_thread_checker_) = nullptr; + bool received_first_packet_ RTC_GUARDED_BY(&signaling_thread_checker_) = + false; + const int attachment_id_; + rtc::scoped_refptr frame_decryptor_ + RTC_GUARDED_BY(worker_thread_); + rtc::scoped_refptr dtls_transport_ + RTC_GUARDED_BY(&signaling_thread_checker_); rtc::scoped_refptr frame_transformer_ RTC_GUARDED_BY(worker_thread_); - // Allows to thread safely change jitter buffer delay. Handles caching cases + // Stores the minimum jitter buffer delay. Handles caching cases // if |SetJitterBufferMinimumDelay| is called before start. - rtc::scoped_refptr delay_; + JitterBufferDelay delay_ RTC_GUARDED_BY(worker_thread_); + // Records if we should generate a keyframe when |media_channel_| gets set up // or switched. bool saved_generate_keyframe_ RTC_GUARDED_BY(worker_thread_) = false; diff --git a/pc/video_rtp_receiver_unittest.cc b/pc/video_rtp_receiver_unittest.cc index b3eb6e6e35..3a8099d30f 100644 --- a/pc/video_rtp_receiver_unittest.cc +++ b/pc/video_rtp_receiver_unittest.cc @@ -17,8 +17,10 @@ #include "test/gmock.h" using ::testing::_; +using ::testing::AnyNumber; using ::testing::InSequence; using ::testing::Mock; +using ::testing::NiceMock; using ::testing::SaveArg; using ::testing::StrictMock; @@ -29,9 +31,11 @@ class VideoRtpReceiverTest : public testing::Test { protected: class MockVideoMediaChannel : public cricket::FakeVideoMediaChannel { public: - MockVideoMediaChannel(cricket::FakeVideoEngine* engine, - const cricket::VideoOptions& options) - : FakeVideoMediaChannel(engine, options) {} + MockVideoMediaChannel( + cricket::FakeVideoEngine* engine, + const cricket::VideoOptions& options, + TaskQueueBase* network_thread = rtc::Thread::Current()) + : FakeVideoMediaChannel(engine, options, network_thread) {} MOCK_METHOD(void, SetRecordableEncodedFrameCallback, (uint32_t, std::function), @@ -51,19 +55,26 @@ class VideoRtpReceiverTest : public testing::Test { VideoRtpReceiverTest() : worker_thread_(rtc::Thread::Create()), channel_(nullptr, cricket::VideoOptions()), - receiver_(new VideoRtpReceiver(worker_thread_.get(), - "receiver", - {"stream"})) { + receiver_(rtc::make_ref_counted( + worker_thread_.get(), + std::string("receiver"), + std::vector({"stream"}))) { worker_thread_->Start(); receiver_->SetMediaChannel(&channel_); } + ~VideoRtpReceiverTest() override { + // Clear expectations that tests may have set up before calling Stop(). + Mock::VerifyAndClearExpectations(&channel_); + receiver_->Stop(); + } + webrtc::VideoTrackSourceInterface* Source() { return receiver_->streams()[0]->FindVideoTrack("receiver")->GetSource(); } std::unique_ptr worker_thread_; - MockVideoMediaChannel channel_; + NiceMock channel_; rtc::scoped_refptr receiver_; }; @@ -96,6 +107,10 @@ TEST_F(VideoRtpReceiverTest, // Switching to a new channel should now not cause calls to GenerateKeyFrame. 
StrictMock channel4(nullptr, cricket::VideoOptions()); receiver_->SetMediaChannel(&channel4); + + // We must call Stop() here since the mock media channels live on the stack + // and `receiver_` still has a pointer to those objects. + receiver_->Stop(); } TEST_F(VideoRtpReceiverTest, EnablesEncodedOutput) { @@ -129,6 +144,10 @@ TEST_F(VideoRtpReceiverTest, DisablesEnablesEncodedOutputOnChannelSwitch) { Source()->RemoveEncodedSink(&sink); StrictMock channel3(nullptr, cricket::VideoOptions()); receiver_->SetMediaChannel(&channel3); + + // We must call Stop() here since the mock media channels live on the stack + // and `receiver_` still has a pointer to those objects. + receiver_->Stop(); } TEST_F(VideoRtpReceiverTest, BroadcastsEncodedFramesWhenEnabled) { diff --git a/pc/video_rtp_track_source.cc b/pc/video_rtp_track_source.cc index f96db962b1..bcfcdcbdf9 100644 --- a/pc/video_rtp_track_source.cc +++ b/pc/video_rtp_track_source.cc @@ -10,6 +10,12 @@ #include "pc/video_rtp_track_source.h" +#include + +#include + +#include "rtc_base/checks.h" + namespace webrtc { VideoRtpTrackSource::VideoRtpTrackSource(Callback* callback) diff --git a/pc/video_rtp_track_source.h b/pc/video_rtp_track_source.h index 9903aaa232..47b7bc9eef 100644 --- a/pc/video_rtp_track_source.h +++ b/pc/video_rtp_track_source.h @@ -13,11 +13,17 @@ #include +#include "api/sequence_checker.h" +#include "api/video/recordable_encoded_frame.h" +#include "api/video/video_frame.h" +#include "api/video/video_sink_interface.h" +#include "api/video/video_source_interface.h" #include "media/base/video_broadcaster.h" #include "pc/video_track_source.h" -#include "rtc_base/callback.h" +#include "rtc_base/constructor_magic.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/system/no_unique_address.h" +#include "rtc_base/thread_annotations.h" namespace webrtc { diff --git a/pc/video_rtp_track_source_unittest.cc b/pc/video_rtp_track_source_unittest.cc index ea1b4cacf8..5666b77d5f 100644 --- a/pc/video_rtp_track_source_unittest.cc +++ b/pc/video_rtp_track_source_unittest.cc @@ -30,9 +30,7 @@ class MockSink : public rtc::VideoSinkInterface { rtc::scoped_refptr MakeSource( VideoRtpTrackSource::Callback* callback) { - rtc::scoped_refptr source( - new rtc::RefCountedObject(callback)); - return source; + return rtc::make_ref_counted(callback); } TEST(VideoRtpTrackSourceTest, CreatesWithRemoteAtttributeSet) { diff --git a/pc/video_track.cc b/pc/video_track.cc index 55356e7046..d0246faa87 100644 --- a/pc/video_track.cc +++ b/pc/video_track.cc @@ -11,9 +11,11 @@ #include "pc/video_track.h" #include +#include #include #include "api/notifier.h" +#include "api/sequence_checker.h" #include "rtc_base/checks.h" #include "rtc_base/location.h" #include "rtc_base/ref_counted_object.h" @@ -27,10 +29,16 @@ VideoTrack::VideoTrack(const std::string& label, worker_thread_(worker_thread), video_source_(video_source), content_hint_(ContentHint::kNone) { + RTC_DCHECK_RUN_ON(&signaling_thread_); + // Detach the thread checker for VideoSourceBaseGuarded since we'll make calls + // to VideoSourceBaseGuarded on the worker thread, but we're currently on the + // signaling thread. + source_sequence_.Detach(); video_source_->RegisterObserver(this); } VideoTrack::~VideoTrack() { + RTC_DCHECK_RUN_ON(&signaling_thread_); video_source_->UnregisterObserver(this); } @@ -42,26 +50,31 @@ std::string VideoTrack::kind() const { // thread. 
void VideoTrack::AddOrUpdateSink(rtc::VideoSinkInterface* sink, const rtc::VideoSinkWants& wants) { - RTC_DCHECK(worker_thread_->IsCurrent()); - VideoSourceBase::AddOrUpdateSink(sink, wants); + RTC_DCHECK_RUN_ON(worker_thread_); + VideoSourceBaseGuarded::AddOrUpdateSink(sink, wants); rtc::VideoSinkWants modified_wants = wants; modified_wants.black_frames = !enabled(); video_source_->AddOrUpdateSink(sink, modified_wants); } void VideoTrack::RemoveSink(rtc::VideoSinkInterface* sink) { - RTC_DCHECK(worker_thread_->IsCurrent()); - VideoSourceBase::RemoveSink(sink); + RTC_DCHECK_RUN_ON(worker_thread_); + VideoSourceBaseGuarded::RemoveSink(sink); video_source_->RemoveSink(sink); } +VideoTrackSourceInterface* VideoTrack::GetSource() const { + // Callable from any thread. + return video_source_.get(); +} + VideoTrackInterface::ContentHint VideoTrack::content_hint() const { - RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + RTC_DCHECK_RUN_ON(worker_thread_); return content_hint_; } void VideoTrack::set_content_hint(ContentHint hint) { - RTC_DCHECK_RUN_ON(&signaling_thread_checker_); + RTC_DCHECK_RUN_ON(worker_thread_); if (content_hint_ == hint) return; content_hint_ = hint; @@ -69,34 +82,43 @@ void VideoTrack::set_content_hint(ContentHint hint) { } bool VideoTrack::set_enabled(bool enable) { - RTC_DCHECK(signaling_thread_checker_.IsCurrent()); - worker_thread_->Invoke(RTC_FROM_HERE, [enable, this] { - RTC_DCHECK(worker_thread_->IsCurrent()); - for (auto& sink_pair : sink_pairs()) { - rtc::VideoSinkWants modified_wants = sink_pair.wants; - modified_wants.black_frames = !enable; - video_source_->AddOrUpdateSink(sink_pair.sink, modified_wants); - } - }); + RTC_DCHECK_RUN_ON(worker_thread_); + for (auto& sink_pair : sink_pairs()) { + rtc::VideoSinkWants modified_wants = sink_pair.wants; + modified_wants.black_frames = !enable; + video_source_->AddOrUpdateSink(sink_pair.sink, modified_wants); + } return MediaStreamTrack::set_enabled(enable); } +bool VideoTrack::enabled() const { + RTC_DCHECK_RUN_ON(worker_thread_); + return MediaStreamTrack::enabled(); +} + +MediaStreamTrackInterface::TrackState VideoTrack::state() const { + RTC_DCHECK_RUN_ON(worker_thread_); + return MediaStreamTrack::state(); +} + void VideoTrack::OnChanged() { - RTC_DCHECK(signaling_thread_checker_.IsCurrent()); - if (video_source_->state() == MediaSourceInterface::kEnded) { - set_state(kEnded); - } else { - set_state(kLive); - } + RTC_DCHECK_RUN_ON(&signaling_thread_); + worker_thread_->Invoke( + RTC_FROM_HERE, [this, state = video_source_->state()]() { + // TODO(tommi): Calling set_state() this way isn't ideal since we're + // currently blocking the signaling thread and set_state() may + // internally fire notifications via `FireOnChanged()` which may further + // amplify the blocking effect on the signaling thread. + rtc::Thread::ScopedDisallowBlockingCalls no_blocking_calls; + set_state(state == MediaSourceInterface::kEnded ? 
kEnded : kLive); + }); } rtc::scoped_refptr VideoTrack::Create( const std::string& id, VideoTrackSourceInterface* source, rtc::Thread* worker_thread) { - rtc::RefCountedObject* track = - new rtc::RefCountedObject(id, source, worker_thread); - return track; + return rtc::make_ref_counted(id, source, worker_thread); } } // namespace webrtc diff --git a/pc/video_track.h b/pc/video_track.h index b7835dee29..e840c8097f 100644 --- a/pc/video_track.h +++ b/pc/video_track.h @@ -16,18 +16,18 @@ #include "api/media_stream_interface.h" #include "api/media_stream_track.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" #include "api/video/video_source_interface.h" #include "media/base/video_source_base.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" namespace webrtc { class VideoTrack : public MediaStreamTrack, - public rtc::VideoSourceBase, + public rtc::VideoSourceBaseGuarded, public ObserverInterface { public: static rtc::scoped_refptr Create( @@ -38,13 +38,13 @@ class VideoTrack : public MediaStreamTrack, void AddOrUpdateSink(rtc::VideoSinkInterface* sink, const rtc::VideoSinkWants& wants) override; void RemoveSink(rtc::VideoSinkInterface* sink) override; + VideoTrackSourceInterface* GetSource() const override; - VideoTrackSourceInterface* GetSource() const override { - return video_source_.get(); - } ContentHint content_hint() const override; void set_content_hint(ContentHint hint) override; bool set_enabled(bool enable) override; + bool enabled() const override; + MediaStreamTrackInterface::TrackState state() const override; std::string kind() const override; protected: @@ -57,10 +57,10 @@ class VideoTrack : public MediaStreamTrack, // Implements ObserverInterface. Observes |video_source_| state. 
void OnChanged() override; + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker signaling_thread_; rtc::Thread* const worker_thread_; - rtc::ThreadChecker signaling_thread_checker_; - rtc::scoped_refptr video_source_; - ContentHint content_hint_ RTC_GUARDED_BY(signaling_thread_checker_); + const rtc::scoped_refptr video_source_; + ContentHint content_hint_ RTC_GUARDED_BY(worker_thread_); }; } // namespace webrtc diff --git a/pc/video_track_source.cc b/pc/video_track_source.cc index f45d44aa32..d15eaaf43c 100644 --- a/pc/video_track_source.cc +++ b/pc/video_track_source.cc @@ -15,7 +15,7 @@ namespace webrtc { VideoTrackSource::VideoTrackSource(bool remote) - : state_(kInitializing), remote_(remote) { + : state_(kLive), remote_(remote) { worker_thread_checker_.Detach(); } diff --git a/pc/video_track_source.h b/pc/video_track_source.h index 27331eac4f..4a29381c4c 100644 --- a/pc/video_track_source.h +++ b/pc/video_track_source.h @@ -11,12 +11,16 @@ #ifndef PC_VIDEO_TRACK_SOURCE_H_ #define PC_VIDEO_TRACK_SOURCE_H_ +#include "absl/types/optional.h" #include "api/media_stream_interface.h" #include "api/notifier.h" +#include "api/sequence_checker.h" +#include "api/video/recordable_encoded_frame.h" +#include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" +#include "api/video/video_source_interface.h" #include "media/base/media_channel.h" #include "rtc_base/system/rtc_export.h" -#include "rtc_base/thread_checker.h" namespace webrtc { @@ -52,7 +56,7 @@ class RTC_EXPORT VideoTrackSource : public Notifier { virtual rtc::VideoSourceInterface* source() = 0; private: - rtc::ThreadChecker worker_thread_checker_; + SequenceChecker worker_thread_checker_; SourceState state_; const bool remote_; }; diff --git a/pc/video_track_source_proxy.cc b/pc/video_track_source_proxy.cc new file mode 100644 index 0000000000..309c1f20f8 --- /dev/null +++ b/pc/video_track_source_proxy.cc @@ -0,0 +1,25 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "pc/video_track_source_proxy.h" + +#include "api/media_stream_interface.h" +#include "api/video_track_source_proxy_factory.h" + +namespace webrtc { + +rtc::scoped_refptr CreateVideoTrackSourceProxy( + rtc::Thread* signaling_thread, + rtc::Thread* worker_thread, + VideoTrackSourceInterface* source) { + return VideoTrackSourceProxy::Create(signaling_thread, worker_thread, source); +} + +} // namespace webrtc diff --git a/api/video_track_source_proxy.h b/pc/video_track_source_proxy.h similarity index 53% rename from api/video_track_source_proxy.h rename to pc/video_track_source_proxy.h index 692ff6493f..8914dd0525 100644 --- a/api/video_track_source_proxy.h +++ b/pc/video_track_source_proxy.h @@ -8,42 +8,42 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef API_VIDEO_TRACK_SOURCE_PROXY_H_ -#define API_VIDEO_TRACK_SOURCE_PROXY_H_ +#ifndef PC_VIDEO_TRACK_SOURCE_PROXY_H_ +#define PC_VIDEO_TRACK_SOURCE_PROXY_H_ #include "api/media_stream_interface.h" -#include "api/proxy.h" +#include "pc/proxy.h" namespace webrtc { // Makes sure the real VideoTrackSourceInterface implementation is destroyed on // the signaling thread and marshals all method calls to the signaling thread. 
-// TODO(deadbeef): Move this to .cc file and out of api/. What threads methods -// are called on is an implementation detail. +// TODO(deadbeef): Move this to .cc file. Which threads methods are called on is +// an implementation detail. BEGIN_PROXY_MAP(VideoTrackSource) -PROXY_SIGNALING_THREAD_DESTRUCTOR() +PROXY_PRIMARY_THREAD_DESTRUCTOR() PROXY_CONSTMETHOD0(SourceState, state) BYPASS_PROXY_CONSTMETHOD0(bool, remote) BYPASS_PROXY_CONSTMETHOD0(bool, is_screencast) PROXY_CONSTMETHOD0(absl::optional, needs_denoising) PROXY_METHOD1(bool, GetStats, Stats*) -PROXY_WORKER_METHOD2(void, - AddOrUpdateSink, - rtc::VideoSinkInterface*, - const rtc::VideoSinkWants&) -PROXY_WORKER_METHOD1(void, RemoveSink, rtc::VideoSinkInterface*) +PROXY_SECONDARY_METHOD2(void, + AddOrUpdateSink, + rtc::VideoSinkInterface*, + const rtc::VideoSinkWants&) +PROXY_SECONDARY_METHOD1(void, RemoveSink, rtc::VideoSinkInterface*) PROXY_METHOD1(void, RegisterObserver, ObserverInterface*) PROXY_METHOD1(void, UnregisterObserver, ObserverInterface*) PROXY_CONSTMETHOD0(bool, SupportsEncodedOutput) -PROXY_WORKER_METHOD0(void, GenerateKeyFrame) -PROXY_WORKER_METHOD1(void, - AddEncodedSink, - rtc::VideoSinkInterface*) -PROXY_WORKER_METHOD1(void, - RemoveEncodedSink, - rtc::VideoSinkInterface*) -END_PROXY_MAP() +PROXY_SECONDARY_METHOD0(void, GenerateKeyFrame) +PROXY_SECONDARY_METHOD1(void, + AddEncodedSink, + rtc::VideoSinkInterface*) +PROXY_SECONDARY_METHOD1(void, + RemoveEncodedSink, + rtc::VideoSinkInterface*) +END_PROXY_MAP(VideoTrackSource) } // namespace webrtc -#endif // API_VIDEO_TRACK_SOURCE_PROXY_H_ +#endif // PC_VIDEO_TRACK_SOURCE_PROXY_H_ diff --git a/pc/video_track_unittest.cc b/pc/video_track_unittest.cc index f86bec8321..ab094ec487 100644 --- a/pc/video_track_unittest.cc +++ b/pc/video_track_unittest.cc @@ -32,7 +32,7 @@ class VideoTrackTest : public ::testing::Test { public: VideoTrackTest() : frame_source_(640, 480, rtc::kNumMicrosecsPerSec / 30) { static const char kVideoTrackId[] = "track_id"; - video_track_source_ = new rtc::RefCountedObject( + video_track_source_ = rtc::make_ref_counted( /*is_screencast=*/false); video_track_ = VideoTrack::Create(kVideoTrackId, video_track_source_, rtc::Thread::Current()); diff --git a/pc/webrtc_sdp.cc b/pc/webrtc_sdp.cc index edd8db6a96..379b2f30c2 100644 --- a/pc/webrtc_sdp.cc +++ b/pc/webrtc_sdp.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -24,29 +25,46 @@ #include #include "absl/algorithm/container.h" -#include "absl/strings/match.h" #include "api/candidate.h" #include "api/crypto_params.h" #include "api/jsep_ice_candidate.h" #include "api/jsep_session_description.h" #include "api/media_types.h" // for RtpExtension +#include "absl/types/optional.h" +#include "api/rtc_error.h" #include "api/rtp_parameters.h" +#include "api/rtp_transceiver_direction.h" #include "media/base/codec.h" #include "media/base/media_constants.h" +#include "media/base/rid_description.h" #include "media/base/rtp_utils.h" +#include "media/base/stream_params.h" #include "media/sctp/sctp_transport_internal.h" +#include "p2p/base/candidate_pair_interface.h" +#include "p2p/base/ice_transport_internal.h" #include "p2p/base/p2p_constants.h" #include "p2p/base/port.h" +#include "p2p/base/port_interface.h" +#include "p2p/base/transport_description.h" +#include "p2p/base/transport_info.h" +#include "pc/media_protocol_names.h" #include "pc/media_session.h" #include "pc/sdp_serializer.h" +#include "pc/session_description.h" +#include "pc/simulcast_description.h"
#include "rtc_base/arraysize.h" #include "rtc_base/checks.h" +#include "rtc_base/helpers.h" +#include "rtc_base/ip_address.h" #include "rtc_base/logging.h" -#include "rtc_base/message_digest.h" +#include "rtc_base/net_helper.h" +#include "rtc_base/network_constants.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/ssl_fingerprint.h" +#include "rtc_base/string_encode.h" #include "rtc_base/string_utils.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/third_party/base64/base64.h" using cricket::AudioContentDescription; using cricket::Candidate; @@ -64,7 +82,6 @@ using cricket::MediaContentDescription; using cricket::MediaProtocolType; using cricket::MediaType; using cricket::RidDescription; -using cricket::RtpDataContentDescription; using cricket::RtpHeaderExtensions; using cricket::SctpDataContentDescription; using cricket::SimulcastDescription; @@ -79,10 +96,6 @@ using cricket::UnsupportedContentDescription; using cricket::VideoContentDescription; using rtc::SocketAddress; -namespace cricket { -class SessionDescription; -} - // TODO(deadbeef): Switch to using anonymous namespace rather than declaring // everything "static". namespace webrtc { @@ -93,6 +106,15 @@ namespace webrtc { // the form: // <type>=<value> // where <type> MUST be exactly one case-significant character. + +// Legal characters in a value (RFC 4566 section 9): // token-char = %x21 / %x23-27 / %x2A-2B / %x2D-2E / %x30-39 // / %x41-5A / %x5E-7E +static const char kLegalTokenCharacters[] = + "!#$%&'*+-." // %x21, %x23-27, %x2A-2B, %x2D-2E + "0123456789" // %x30-39 + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" // %x41-5A + "^_`abcdefghijklmnopqrstuvwxyz{|}~"; // %x5E-7E static const int kLinePrefixLength = 2; // Length of <type>= static const char kLineTypeVersion = 'v'; static const char kLineTypeOrigin = 'o'; @@ -605,6 +627,22 @@ static bool GetValue(const std::string& message, return true; } +// Get a single [token] from <attribute>:<value> +static bool GetSingleTokenValue(const std::string& message, + const std::string& attribute, + std::string* value, + SdpParseError* error) { + if (!GetValue(message, attribute, value, error)) { + return false; + } + if (strspn(value->c_str(), kLegalTokenCharacters) != value->size()) { + rtc::StringBuilder description; + description << "Illegal character found in the value of " << attribute; + return ParseFailed(message, description.str(), error); + } + return true; +} + static bool CaseInsensitiveFind(std::string str1, std::string str2) { absl::c_transform(str1, str1.begin(), ::tolower); absl::c_transform(str2, str2.begin(), ::tolower); @@ -862,11 +900,11 @@ std::string SdpSerialize(const JsepSessionDescription& jdesc) { // Time Description.
AddLine(kTimeDescription, &message); - // Group - if (desc->HasGroup(cricket::GROUP_TYPE_BUNDLE)) { + // BUNDLE Groups + std::vector groups = + desc->GetGroupsByName(cricket::GROUP_TYPE_BUNDLE); + for (const cricket::ContentGroup* group : groups) { std::string group_line = kAttrGroup; - const cricket::ContentGroup* group = - desc->GetGroupByName(cricket::GROUP_TYPE_BUNDLE); RTC_DCHECK(group != NULL); for (const std::string& content_name : group->content_names()) { group_line.append(" "); @@ -1376,12 +1414,7 @@ void BuildMediaDescription(const ContentInfo* content_info, fmt.append(kDefaultSctpmapProtocol); } } else { - const RtpDataContentDescription* rtp_data_desc = - media_desc->as_rtp_data(); - for (const cricket::RtpDataCodec& codec : rtp_data_desc->codecs()) { - fmt.append(" "); - fmt.append(rtc::ToString(codec.id)); - } + RTC_NOTREACHED() << "Data description without SCTP"; } } else if (media_type == cricket::MEDIA_TYPE_UNSUPPORTED) { const UnsupportedContentDescription* unsupported_desc = @@ -1933,19 +1966,6 @@ void BuildRtpMap(const MediaContentDescription* media_desc, ptime = std::max(ptime, max_minptime); AddAttributeLine(kCodecParamPTime, ptime, message); } - } else if (media_type == cricket::MEDIA_TYPE_DATA) { - if (media_desc->as_rtp_data()) { - for (const cricket::RtpDataCodec& codec : - media_desc->as_rtp_data()->codecs()) { - // RFC 4566 - // a=rtpmap: / - // [/] - InitAttrLine(kAttributeRtpmap, &os); - os << kSdpDelimiterColon << codec.id << " " << codec.name << "/" - << codec.clockrate; - AddLine(os.str(), message); - } - } } } @@ -2570,6 +2590,7 @@ static std::unique_ptr ParseContentDescription( std::vector>* candidates, webrtc::SdpParseError* error) { auto media_desc = std::make_unique(); + media_desc->set_extmap_allow_mixed_enum(MediaContentDescription::kNo); if (!ParseContent(message, media_type, mline_index, protocol, payload_types, pos, content_name, bundle_only, msid_signaling, media_desc.get(), transport, candidates, error)) { @@ -2659,6 +2680,10 @@ bool ParseMediaDescription( bool bundle_only = false; int section_msid_signaling = 0; const std::string& media_type = fields[0]; + if ((media_type == kMediaTypeVideo || media_type == kMediaTypeAudio) && + !cricket::IsRtpProtocol(protocol)) { + return ParseFailed(line, "Unsupported protocol for media type", error); + } if (media_type == kMediaTypeVideo) { content = ParseContentDescription( message, cricket::MEDIA_TYPE_VIDEO, mline_index, protocol, @@ -2696,13 +2721,7 @@ bool ParseMediaDescription( data_desc->set_protocol(protocol); content = std::move(data_desc); } else { - // RTP - std::unique_ptr data_desc = - ParseContentDescription( - message, cricket::MEDIA_TYPE_DATA, mline_index, protocol, - payload_types, pos, &content_name, &bundle_only, - §ion_msid_signaling, &transport, candidates, error); - content = std::move(data_desc); + return ParseFailed(line, "Unsupported protocol for media type", error); } } else { RTC_LOG(LS_WARNING) << "Unsupported media type: " << line; @@ -3030,21 +3049,6 @@ bool ParseContent(const std::string& message, return ParseFailed( line, "b=" + bandwidth_type + " value can't be negative.", error); } - // We should never use more than the default bandwidth for RTP-based - // data channels. Don't allow SDP to set the bandwidth, because - // that would give JS the opportunity to "break the Internet". 
- // See: https://code.google.com/p/chromium/issues/detail?id=280726 - // Disallow TIAS since it shouldn't be generated for RTP data channels in - // the first place and provides another way to get around the limitation. - if (media_type == cricket::MEDIA_TYPE_DATA && - cricket::IsRtpProtocol(protocol) && - (b > cricket::kRtpDataMaxBandwidth / 1000 || - bandwidth_type == kTransportSpecificBandwidth)) { - rtc::StringBuilder description; - description << "RTP-based data channels may not send more than " - << cricket::kRtpDataMaxBandwidth / 1000 << "kbps."; - return ParseFailed(line, description.str(), error); - } // Convert values. Prevent integer overflow. if (bandwidth_type == kApplicationSpecificBandwidth) { b = std::min(b, INT_MAX / 1000) * 1000; @@ -3078,7 +3082,7 @@ bool ParseContent(const std::string& message, // mid-attribute = "a=mid:" identification-tag // identification-tag = token // Use the mid identification-tag as the content name. - if (!GetValue(line, kAttributeMid, &mline_id, error)) { + if (!GetSingleTokenValue(line, kAttributeMid, &mline_id, error)) { return false; } *content_name = mline_id; @@ -3122,16 +3126,12 @@ bool ParseContent(const std::string& message, if (!ParseDtlsSetup(line, &(transport->connection_role), error)) { return false; } - } else if (cricket::IsDtlsSctp(protocol)) { + } else if (cricket::IsDtlsSctp(protocol) && + media_type == cricket::MEDIA_TYPE_DATA) { // // SCTP specific attributes // if (HasAttribute(line, kAttributeSctpPort)) { - if (media_type != cricket::MEDIA_TYPE_DATA) { - return ParseFailed( - line, "sctp-port attribute found in non-data media description.", - error); - } if (media_desc->as_sctp()->use_sctpmap()) { return ParseFailed( line, "sctp-port attribute can't be used with sctpmap.", error); @@ -3142,12 +3142,6 @@ bool ParseContent(const std::string& message, } media_desc->as_sctp()->set_port(sctp_port); } else if (HasAttribute(line, kAttributeMaxMessageSize)) { - if (media_type != cricket::MEDIA_TYPE_DATA) { - return ParseFailed( - line, - "max-message-size attribute found in non-data media description.", - error); - } int max_message_size; if (!ParseSctpMaxMessageSize(line, &max_message_size, error)) { return false; @@ -3635,11 +3629,6 @@ bool ParseRtpmapAttribute(const std::string& line, AudioContentDescription* audio_desc = media_desc->as_audio(); UpdateCodec(payload_type, encoding_name, clock_rate, 0, channels, audio_desc); - } else if (media_type == cricket::MEDIA_TYPE_DATA) { - RtpDataContentDescription* data_desc = media_desc->as_rtp_data(); - if (data_desc) { - data_desc->AddCodec(cricket::RtpDataCodec(payload_type, encoding_name)); - } } return true; } diff --git a/pc/webrtc_sdp.h b/pc/webrtc_sdp.h index 588e02f139..aa3317f341 100644 --- a/pc/webrtc_sdp.h +++ b/pc/webrtc_sdp.h @@ -22,7 +22,12 @@ #include +#include "api/candidate.h" +#include "api/jsep.h" +#include "api/jsep_ice_candidate.h" +#include "api/jsep_session_description.h" #include "media/base/codec.h" +#include "rtc_base/strings/string_builder.h" #include "rtc_base/system/rtc_export.h" namespace cricket { diff --git a/pc/webrtc_sdp_unittest.cc b/pc/webrtc_sdp_unittest.cc index cf5384725b..266fd3dfd6 100644 --- a/pc/webrtc_sdp_unittest.cc +++ b/pc/webrtc_sdp_unittest.cc @@ -56,7 +56,6 @@ using cricket::Candidate; using cricket::ContentGroup; using cricket::ContentInfo; using cricket::CryptoParams; -using cricket::DataCodec; using cricket::ICE_CANDIDATE_COMPONENT_RTCP; using cricket::ICE_CANDIDATE_COMPONENT_RTP; using cricket::kFecSsrcGroupSemantics; @@ -65,7 +64,6 
@@ using cricket::MediaProtocolType; using cricket::RELAY_PORT_TYPE; using cricket::RidDescription; using cricket::RidDirection; -using cricket::RtpDataContentDescription; using cricket::SctpDataContentDescription; using cricket::SessionDescription; using cricket::SimulcastDescription; @@ -153,6 +151,7 @@ static const char kSdpFullString[] = "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" "s=-\r\n" "t=0 0\r\n" + "a=extmap-allow-mixed\r\n" "a=msid-semantic: WMS local_stream_1\r\n" "m=audio 2345 RTP/SAVPF 111 103 104\r\n" "c=IN IP4 74.125.127.126\r\n" @@ -223,6 +222,7 @@ static const char kSdpString[] = "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" "s=-\r\n" "t=0 0\r\n" + "a=extmap-allow-mixed\r\n" "a=msid-semantic: WMS local_stream_1\r\n" "m=audio 9 RTP/SAVPF 111 103 104\r\n" "c=IN IP4 0.0.0.0\r\n" @@ -261,22 +261,6 @@ static const char kSdpString[] = "a=ssrc:3 mslabel:local_stream_1\r\n" "a=ssrc:3 label:video_track_id_1\r\n"; -static const char kSdpRtpDataChannelString[] = - "m=application 9 RTP/SAVPF 101\r\n" - "c=IN IP4 0.0.0.0\r\n" - "a=rtcp:9 IN IP4 0.0.0.0\r\n" - "a=ice-ufrag:ufrag_data\r\n" - "a=ice-pwd:pwd_data\r\n" - "a=mid:data_content_name\r\n" - "a=sendrecv\r\n" - "a=crypto:1 AES_CM_128_HMAC_SHA1_80 " - "inline:FvLcvU2P3ZWmQxgPAgcDu7Zl9vftYElFOjEzhWs5\r\n" - "a=rtpmap:101 google-data/90000\r\n" - "a=ssrc:10 cname:data_channel_cname\r\n" - "a=ssrc:10 msid:data_channel data_channeld0\r\n" - "a=ssrc:10 mslabel:data_channel\r\n" - "a=ssrc:10 label:data_channeld0\r\n"; - // draft-ietf-mmusic-sctp-sdp-03 static const char kSdpSctpDataChannelString[] = "m=application 9 UDP/DTLS/SCTP 5000\r\n" @@ -373,6 +357,7 @@ static const char kBundleOnlySdpFullString[] = "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" "s=-\r\n" "t=0 0\r\n" + "a=extmap-allow-mixed\r\n" "a=group:BUNDLE audio_content_name video_content_name\r\n" "a=msid-semantic: WMS local_stream_1\r\n" "m=audio 2345 RTP/SAVPF 111 103 104\r\n" @@ -433,6 +418,7 @@ static const char kPlanBSdpFullString[] = "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" "s=-\r\n" "t=0 0\r\n" + "a=extmap-allow-mixed\r\n" "a=msid-semantic: WMS local_stream_1 local_stream_2\r\n" "m=audio 2345 RTP/SAVPF 111 103 104\r\n" "c=IN IP4 74.125.127.126\r\n" @@ -516,6 +502,7 @@ static const char kUnifiedPlanSdpFullString[] = "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" "s=-\r\n" "t=0 0\r\n" + "a=extmap-allow-mixed\r\n" "a=msid-semantic: WMS local_stream_1\r\n" // Audio track 1, stream 1 (with candidates). "m=audio 2345 RTP/SAVPF 111 103 104\r\n" @@ -628,6 +615,7 @@ static const char kUnifiedPlanSdpFullStringWithSpecialMsid[] = "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" "s=-\r\n" "t=0 0\r\n" + "a=extmap-allow-mixed\r\n" "a=msid-semantic: WMS local_stream_1\r\n" // Audio track 1, with 1 stream id. 
"m=audio 2345 RTP/SAVPF 111 103 104\r\n" @@ -900,12 +888,6 @@ static const uint32_t kVideoTrack3Ssrc = 6; static const char kAudioTrackId3[] = "audio_track_id_3"; static const uint32_t kAudioTrack3Ssrc = 7; -// DataChannel -static const char kDataChannelLabel[] = "data_channel"; -static const char kDataChannelMsid[] = "data_channeld0"; -static const char kDataChannelCname[] = "data_channel_cname"; -static const uint32_t kDataChannelSsrc = 10; - // Candidate static const char kDummyMid[] = "dummy_mid"; static const int kDummyIndex = 123; @@ -938,15 +920,16 @@ static void Replace(const std::string& line, absl::StrReplaceAll({{line, newlines}}, message); } -// Expect fail to parase |bad_sdp| and expect |bad_part| be part of the error -// message. +// Expect a parse failure on the line containing |bad_part| when attempting to +// parse |bad_sdp|. static void ExpectParseFailure(const std::string& bad_sdp, const std::string& bad_part) { JsepSessionDescription desc(kDummyType); SdpParseError error; bool ret = webrtc::SdpDeserialize(bad_sdp, &desc, &error); - EXPECT_FALSE(ret); - EXPECT_NE(std::string::npos, error.line.find(bad_part.c_str())); + ASSERT_FALSE(ret); + EXPECT_NE(std::string::npos, error.line.find(bad_part.c_str())) + << "Did not find " << bad_part << " in " << error.line; } // Expect fail to parse kSdpFullString if replace |good_part| with |bad_part|. @@ -1459,11 +1442,6 @@ class WebRtcSdpTest : public ::testing::Test { simulcast2.receive_layers().size()); } - void CompareRtpDataContentDescription(const RtpDataContentDescription* dcd1, - const RtpDataContentDescription* dcd2) { - CompareMediaContentDescription(dcd1, dcd2); - } - void CompareSctpDataContentDescription( const SctpDataContentDescription* dcd1, const SctpDataContentDescription* dcd2) { @@ -1514,14 +1492,6 @@ class WebRtcSdpTest : public ::testing::Test { const SctpDataContentDescription* scd2 = c2.media_description()->as_sctp(); CompareSctpDataContentDescription(scd1, scd2); - } else { - if (IsDataContent(&c1)) { - const RtpDataContentDescription* dcd1 = - c1.media_description()->as_rtp_data(); - const RtpDataContentDescription* dcd2 = - c2.media_description()->as_rtp_data(); - CompareRtpDataContentDescription(dcd1, dcd2); - } } CompareSimulcastDescription( @@ -1809,28 +1779,6 @@ class WebRtcSdpTest : public ::testing::Test { kDataContentName, TransportDescription(kUfragData, kPwdData))); } - void AddRtpDataChannel() { - std::unique_ptr data( - new RtpDataContentDescription()); - data_desc_ = data.get(); - - data_desc_->AddCodec(DataCodec(101, "google-data")); - StreamParams data_stream; - data_stream.id = kDataChannelMsid; - data_stream.cname = kDataChannelCname; - data_stream.set_stream_ids({kDataChannelLabel}); - data_stream.ssrcs.push_back(kDataChannelSsrc); - data_desc_->AddStream(data_stream); - data_desc_->AddCrypto( - CryptoParams(1, "AES_CM_128_HMAC_SHA1_80", - "inline:FvLcvU2P3ZWmQxgPAgcDu7Zl9vftYElFOjEzhWs5", "")); - data_desc_->set_protocol(cricket::kMediaProtocolSavpf); - desc_.AddContent(kDataContentName, MediaProtocolType::kRtp, - std::move(data)); - desc_.AddTransportInfo(TransportInfo( - kDataContentName, TransportDescription(kUfragData, kPwdData))); - } - bool TestDeserializeDirection(RtpTransceiverDirection direction) { std::string new_sdp = kSdpFullString; ReplaceDirection(direction, &new_sdp); @@ -1958,7 +1906,8 @@ class WebRtcSdpTest : public ::testing::Test { os.clear(); os.str(""); // Pl type 100 preferred. 
- os << "m=video 9 RTP/SAVPF 99 95\r\n" + os << "m=video 9 RTP/SAVPF 99 95 96\r\n" + "a=rtpmap:96 VP9/90000\r\n" // out-of-order wrt the m= line. "a=rtpmap:99 VP8/90000\r\n" "a=rtpmap:95 RTX/90000\r\n" "a=fmtp:95 apt=99;\r\n"; @@ -2006,6 +1955,10 @@ class WebRtcSdpTest : public ::testing::Test { EXPECT_EQ("RTX", rtx.name); EXPECT_EQ(95, rtx.id); VerifyCodecParameter(rtx.params, "apt", vp8.id); + // VP9 is listed last in the m= line so should come after VP8 and RTX. + cricket::VideoCodec vp9 = vcd->codecs()[2]; + EXPECT_EQ("VP9", vp9.name); + EXPECT_EQ(96, vp9.id); } void TestDeserializeRtcpFb(JsepSessionDescription* jdesc_output, @@ -2096,7 +2049,6 @@ class WebRtcSdpTest : public ::testing::Test { SessionDescription desc_; AudioContentDescription* audio_desc_; VideoContentDescription* video_desc_; - RtpDataContentDescription* data_desc_; SctpDataContentDescription* sctp_desc_; Candidates candidates_; std::unique_ptr jcandidate_; @@ -2172,17 +2124,21 @@ TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithoutCandidates) { EXPECT_EQ(std::string(kSdpString), message); } -TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithBundle) { - ContentGroup group(cricket::GROUP_TYPE_BUNDLE); - group.AddContentName(kAudioContentName); - group.AddContentName(kVideoContentName); - desc_.AddGroup(group); +TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithBundles) { + ContentGroup group1(cricket::GROUP_TYPE_BUNDLE); + group1.AddContentName(kAudioContentName); + group1.AddContentName(kVideoContentName); + desc_.AddGroup(group1); + ContentGroup group2(cricket::GROUP_TYPE_BUNDLE); + group2.AddContentName(kAudioContentName2); + desc_.AddGroup(group2); ASSERT_TRUE(jdesc_.Initialize(desc_.Clone(), jdesc_.session_id(), jdesc_.session_version())); std::string message = webrtc::SdpSerialize(jdesc_); std::string sdp_with_bundle = kSdpFullString; InjectAfter(kSessionTime, - "a=group:BUNDLE audio_content_name video_content_name\r\n", + "a=group:BUNDLE audio_content_name video_content_name\r\n" + "a=group:BUNDLE audio_content_name_2\r\n", &sdp_with_bundle); EXPECT_EQ(sdp_with_bundle, message); } @@ -2262,18 +2218,6 @@ TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithAudioVideoRejected) { EXPECT_TRUE(TestSerializeRejected(true, true)); } -TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithRtpDataChannel) { - AddRtpDataChannel(); - JsepSessionDescription jsep_desc(kDummyType); - - MakeDescriptionWithoutCandidates(&jsep_desc); - std::string message = webrtc::SdpSerialize(jsep_desc); - - std::string expected_sdp = kSdpString; - expected_sdp.append(kSdpRtpDataChannelString); - EXPECT_EQ(expected_sdp, message); -} - TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithSctpDataChannel) { bool use_sctpmap = true; AddSctpDataChannel(use_sctpmap); @@ -2320,22 +2264,6 @@ TEST_F(WebRtcSdpTest, SerializeWithSctpDataChannelAndNewPort) { EXPECT_EQ(expected_sdp, message); } -TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithDataChannelAndBandwidth) { - JsepSessionDescription jsep_desc(kDummyType); - AddRtpDataChannel(); - data_desc_->set_bandwidth(100 * 1000); - data_desc_->set_bandwidth_type("AS"); - MakeDescriptionWithoutCandidates(&jsep_desc); - std::string message = webrtc::SdpSerialize(jsep_desc); - - std::string expected_sdp = kSdpString; - expected_sdp.append(kSdpRtpDataChannelString); - // Serializing data content shouldn't ignore bandwidth settings. 
- InjectAfter("m=application 9 RTP/SAVPF 101\r\nc=IN IP4 0.0.0.0\r\n", - "b=AS:100\r\n", &expected_sdp); - EXPECT_EQ(expected_sdp, message); -} - TEST_F(WebRtcSdpTest, SerializeSessionDescriptionWithExtmapAllowMixed) { jdesc_.description()->set_extmap_allow_mixed(true); TestSerialize(jdesc_); @@ -2752,10 +2680,9 @@ TEST_F(WebRtcSdpTest, DeserializeSessionDescriptionWithoutMsid) { TEST_F(WebRtcSdpTest, DeserializeSessionDescriptionWithExtmapAllowMixed) { jdesc_.description()->set_extmap_allow_mixed(true); std::string sdp_with_extmap_allow_mixed = kSdpFullString; - InjectAfter("t=0 0\r\n", kExtmapAllowMixed, &sdp_with_extmap_allow_mixed); // Deserialize JsepSessionDescription jdesc_deserialized(kDummyType); - EXPECT_TRUE(SdpDeserialize(sdp_with_extmap_allow_mixed, &jdesc_deserialized)); + ASSERT_TRUE(SdpDeserialize(sdp_with_extmap_allow_mixed, &jdesc_deserialized)); // Verify EXPECT_TRUE(CompareSessionDescription(jdesc_, jdesc_deserialized)); } @@ -2763,9 +2690,10 @@ TEST_F(WebRtcSdpTest, DeserializeSessionDescriptionWithExtmapAllowMixed) { TEST_F(WebRtcSdpTest, DeserializeSessionDescriptionWithoutExtmapAllowMixed) { jdesc_.description()->set_extmap_allow_mixed(false); std::string sdp_without_extmap_allow_mixed = kSdpFullString; + Replace(kExtmapAllowMixed, "", &sdp_without_extmap_allow_mixed); // Deserialize JsepSessionDescription jdesc_deserialized(kDummyType); - EXPECT_TRUE( + ASSERT_TRUE( SdpDeserialize(sdp_without_extmap_allow_mixed, &jdesc_deserialized)); // Verify EXPECT_TRUE(CompareSessionDescription(jdesc_, jdesc_deserialized)); @@ -2906,21 +2834,6 @@ TEST_F(WebRtcSdpTest, DeserializeInvalidCandidiate) { EXPECT_FALSE(SdpDeserializeCandidate(kSdpTcpInvalidCandidate, &jcandidate)); } -TEST_F(WebRtcSdpTest, DeserializeSdpWithRtpDataChannels) { - AddRtpDataChannel(); - JsepSessionDescription jdesc(kDummyType); - ASSERT_TRUE(jdesc.Initialize(desc_.Clone(), kSessionId, kSessionVersion)); - - std::string sdp_with_data = kSdpString; - sdp_with_data.append(kSdpRtpDataChannelString); - JsepSessionDescription jdesc_output(kDummyType); - - // Deserialize - EXPECT_TRUE(SdpDeserialize(sdp_with_data, &jdesc_output)); - // Verify - EXPECT_TRUE(CompareSessionDescription(jdesc, jdesc_output)); -} - TEST_F(WebRtcSdpTest, DeserializeSdpWithSctpDataChannels) { bool use_sctpmap = true; AddSctpDataChannel(use_sctpmap); @@ -3081,8 +2994,9 @@ TEST_F(WebRtcSdpTest, DeserializeSdpWithRtpmapAttribute) { } TEST_F(WebRtcSdpTest, DeserializeSdpWithStrangeApplicationProtocolNames) { - static const char* bad_strings[] = {"DTLS/SCTPRTP/", "obviously-bogus", - "UDP/TL/RTSP/SAVPF", "UDP/TL/RTSP/S"}; + static const char* bad_strings[] = { + "DTLS/SCTPRTP/", "obviously-bogus", "UDP/TL/RTSP/SAVPF", + "UDP/TL/RTSP/S", "DTLS/SCTP/RTP/FOO", "obviously-bogus/RTP/"}; for (auto proto : bad_strings) { std::string sdp_with_data = kSdpString; sdp_with_data.append("m=application 9 "); @@ -3092,18 +3006,6 @@ TEST_F(WebRtcSdpTest, DeserializeSdpWithStrangeApplicationProtocolNames) { EXPECT_FALSE(SdpDeserialize(sdp_with_data, &jdesc_output)) << "Parsing should have failed on " << proto; } - // The following strings are strange, but acceptable as RTP. 
- static const char* weird_strings[] = {"DTLS/SCTP/RTP/FOO", - "obviously-bogus/RTP/"}; - for (auto proto : weird_strings) { - std::string sdp_with_data = kSdpString; - sdp_with_data.append("m=application 9 "); - sdp_with_data.append(proto); - sdp_with_data.append(" 47\r\n"); - JsepSessionDescription jdesc_output(kDummyType); - EXPECT_TRUE(SdpDeserialize(sdp_with_data, &jdesc_output)) - << "Parsing should have succeeded on " << proto; - } } // For crbug/344475. @@ -3161,21 +3063,6 @@ TEST_F(WebRtcSdpTest, EXPECT_TRUE(CompareSessionDescription(jdesc, jdesc_output)); } -TEST_F(WebRtcSdpTest, DeserializeSdpWithRtpDataChannelsAndBandwidth) { - // We want to test that deserializing data content limits bandwidth - // settings (it should never be greater than the default). - // This should prevent someone from using unlimited data bandwidth through - // JS and "breaking the Internet". - // See: https://code.google.com/p/chromium/issues/detail?id=280726 - std::string sdp_with_bandwidth = kSdpString; - sdp_with_bandwidth.append(kSdpRtpDataChannelString); - InjectAfter("a=mid:data_content_name\r\n", "b=AS:100\r\n", - &sdp_with_bandwidth); - JsepSessionDescription jdesc_with_bandwidth(kDummyType); - - EXPECT_FALSE(SdpDeserialize(sdp_with_bandwidth, &jdesc_with_bandwidth)); -} - TEST_F(WebRtcSdpTest, DeserializeSdpWithSctpDataChannelsAndBandwidth) { bool use_sctpmap = true; AddSctpDataChannel(use_sctpmap); @@ -4052,24 +3939,6 @@ TEST_F(WebRtcSdpTest, SerializeBothMediaSectionAndSsrcAttributeMsid) { EXPECT_NE(std::string::npos, sdp.find(kSsrcAttributeMsidLine)); } -// Regression test for heap overflow bug: -// https://bugs.chromium.org/p/chromium/issues/detail?id=647916 -TEST_F(WebRtcSdpTest, DeserializeSctpPortInVideoDescription) { - // The issue occurs when the sctp-port attribute is found in a video - // description. The actual heap overflow occurs when parsing the fmtp line. 
- static const char kSdpWithSctpPortInVideoDescription[] = - "v=0\r\n" - "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" - "s=-\r\n" - "t=0 0\r\n" - "m=video 9 UDP/DTLS/SCTP 120\r\n" - "a=sctp-port 5000\r\n" - "a=fmtp:108 foo=10\r\n"; - - ExpectParseFailure(std::string(kSdpWithSctpPortInVideoDescription), - "sctp-port"); -} - // Regression test for integer overflow bug: // https://bugs.chromium.org/p/chromium/issues/detail?id=648071 TEST_F(WebRtcSdpTest, DeserializeLargeBandwidthLimit) { @@ -4755,3 +4624,42 @@ TEST_F(WebRtcSdpTest, DeserializeSdpWithUnsupportedMediaType) { EXPECT_EQ(jdesc_output.description()->contents()[0].name, "bogusmid"); EXPECT_EQ(jdesc_output.description()->contents()[1].name, "somethingmid"); } + +TEST_F(WebRtcSdpTest, MediaTypeProtocolMismatch) { + std::string sdp = + "v=0\r\n" + "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" + "s=-\r\n" + "t=0 0\r\n"; + + ExpectParseFailure(std::string(sdp + "m=audio 9 UDP/DTLS/SCTP 120\r\n"), + "m=audio"); + ExpectParseFailure(std::string(sdp + "m=video 9 UDP/DTLS/SCTP 120\r\n"), + "m=video"); + ExpectParseFailure(std::string(sdp + "m=video 9 SOMETHING 120\r\n"), + "m=video"); + ExpectParseFailure(std::string(sdp + "m=application 9 SOMETHING 120\r\n"), + "m=application"); +} + +// Regression test for: +// https://bugs.chromium.org/p/chromium/issues/detail?id=1171965 +TEST_F(WebRtcSdpTest, SctpPortInUnsupportedContent) { + std::string sdp = + "v=0\r\n" + "o=- 18446744069414584320 18446462598732840960 IN IP4 127.0.0.1\r\n" + "s=-\r\n" + "t=0 0\r\n" + "m=o 1 DTLS/SCTP 5000\r\n" + "a=sctp-port\r\n"; + + JsepSessionDescription jdesc_output(kDummyType); + EXPECT_TRUE(SdpDeserialize(sdp, &jdesc_output)); +} + +TEST_F(WebRtcSdpTest, IllegalMidCharacterValue) { + std::string sdp = kSdpString; + // [ is an illegal token value. + Replace("a=mid:", "a=mid:[]", &sdp); + ExpectParseFailure(std::string(sdp), "a=mid:[]"); +} diff --git a/pc/webrtc_session_description_factory.cc b/pc/webrtc_session_description_factory.cc index 2a9dc3fbd8..33826347ff 100644 --- a/pc/webrtc_session_description_factory.cc +++ b/pc/webrtc_session_description_factory.cc @@ -174,8 +174,7 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory( // Generate certificate. certificate_request_state_ = CERTIFICATE_WAITING; - rtc::scoped_refptr callback( - new rtc::RefCountedObject()); + auto callback = rtc::make_ref_counted(); callback->SignalRequestFailed.connect( this, &WebRtcSessionDescriptionFactory::OnCertificateRequestFailed); callback->SignalCertificateReady.connect( @@ -194,7 +193,7 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory( } WebRtcSessionDescriptionFactory::~WebRtcSessionDescriptionFactory() { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); // Fail any requests that were asked for before identity generation completed. 
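Editor's note (sketch, not part of the diff): the webrtc_session_description_factory.cc hunk above makes two mechanical migrations: explicit new rtc::RefCountedObject construction becomes the rtc::make_ref_counted<T>() factory, and RTC_DCHECK(thread->IsCurrent()) becomes RTC_DCHECK_RUN_ON(thread), which also feeds Clang's thread-safety analysis. A minimal usage sketch; PendingCallback is a hypothetical type and the include paths are from memory and may differ by revision:

#include "api/sequence_checker.h"         // assumed home of RTC_DCHECK_RUN_ON here
#include "rtc_base/ref_counted_object.h"  // assumed home of rtc::make_ref_counted
#include "rtc_base/thread.h"

namespace {

struct PendingCallback {  // hypothetical payload; RefCountedObject supplies AddRef/Release
  void OnReady() {}
};

void Example(rtc::Thread* signaling_thread) {
  // Before: rtc::scoped_refptr<...> cb(new rtc::RefCountedObject<PendingCallback>());
  auto cb = rtc::make_ref_counted<PendingCallback>();
  cb->OnReady();

  // Before: RTC_DCHECK(signaling_thread->IsCurrent());
  RTC_DCHECK_RUN_ON(signaling_thread);
}

}  // namespace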
FailPendingRequests(kFailedDueToSessionShutdown); @@ -222,6 +221,7 @@ void WebRtcSessionDescriptionFactory::CreateOffer( CreateSessionDescriptionObserver* observer, const PeerConnectionInterface::RTCOfferAnswerOptions& options, const cricket::MediaSessionOptions& session_options) { + RTC_DCHECK_RUN_ON(signaling_thread_); std::string error = "CreateOffer"; if (certificate_request_state_ == CERTIFICATE_FAILED) { error += kFailedDueToIdentityFailed; @@ -441,7 +441,7 @@ void WebRtcSessionDescriptionFactory::InternalCreateAnswer( void WebRtcSessionDescriptionFactory::FailPendingRequests( const std::string& reason) { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); while (!create_session_description_requests_.empty()) { const CreateSessionDescriptionRequest& request = create_session_description_requests_.front(); @@ -476,7 +476,7 @@ void WebRtcSessionDescriptionFactory::PostCreateSessionDescriptionSucceeded( } void WebRtcSessionDescriptionFactory::OnCertificateRequestFailed() { - RTC_DCHECK(signaling_thread_->IsCurrent()); + RTC_DCHECK_RUN_ON(signaling_thread_); RTC_LOG(LS_ERROR) << "Asynchronous certificate generation request failed."; certificate_request_state_ = CERTIFICATE_FAILED; diff --git a/pc/webrtc_session_description_factory.h b/pc/webrtc_session_description_factory.h index 9256045d6b..bd2636c0dd 100644 --- a/pc/webrtc_session_description_factory.h +++ b/pc/webrtc_session_description_factory.h @@ -12,6 +12,8 @@ #define PC_WEBRTC_SESSION_DESCRIPTION_FACTORY_H_ #include + +#include #include #include #include diff --git a/resources/audio_processing/output_data_float.pb.sha1 b/resources/audio_processing/output_data_float.pb.sha1 index a19c6c3b60..d3375949ac 100644 --- a/resources/audio_processing/output_data_float.pb.sha1 +++ b/resources/audio_processing/output_data_float.pb.sha1 @@ -1 +1 @@ -1dd2c11da1f1dec49f728881628c1348e07a19cd \ No newline at end of file +749efdfd1e3c3ace434b3673dac9ce4938534449 \ No newline at end of file diff --git a/resources/audio_processing/output_data_float_avx2.pb.sha1 b/resources/audio_processing/output_data_float_avx2.pb.sha1 index 54a5b06963..79a95efc0e 100644 --- a/resources/audio_processing/output_data_float_avx2.pb.sha1 +++ b/resources/audio_processing/output_data_float_avx2.pb.sha1 @@ -1 +1 @@ -16e9d8f3b8b6c23b2b5100a1162acfe67acc37a7 \ No newline at end of file +78c1a84de332173863c997538aa19b8cdcba5020 \ No newline at end of file diff --git a/resources/audio_processing/test/py_quality_assessment/BUILD.gn b/resources/audio_processing/test/py_quality_assessment/BUILD.gn index 5f2d34dd49..594efb05bb 100644 --- a/resources/audio_processing/test/py_quality_assessment/BUILD.gn +++ b/resources/audio_processing/test/py_quality_assessment/BUILD.gn @@ -8,7 +8,7 @@ import("../../../../webrtc.gni") -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { copy("noise_tracks") { testonly = true sources = [ "noise_tracks/city.wav" ] diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 7181c234f3..8dc89fafba 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -8,6 +8,7 @@ import("//build/config/crypto.gni") import("//build/config/ui.gni") +import("//third_party/google_benchmark/buildconfig.gni") import("../webrtc.gni") if (is_android) { @@ -15,7 +16,7 @@ if (is_android) { import("//build/config/android/rules.gni") } -config("rtc_base_chromium_config") { +config("threading_chromium_config") { defines = [ "NO_MAIN_THREAD_WRAPPING" ] } @@ -74,18 +75,20 @@ rtc_library("rtc_base_approved") { ":type_traits", 
"../api:array_view", "../api:scoped_refptr", + "../api:sequence_checker", "synchronization:mutex", "system:arch", "system:no_unique_address", "system:rtc_export", - "system:unused", "third_party/base64", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/base:core_headers", + "//third_party/abseil-cpp/absl/types:optional", + ] public_deps = [] # no-presubmit-check TODO(webrtc:8603) sources = [ - "bind.h", "bit_buffer.cc", "bit_buffer.h", "buffer.h", @@ -98,9 +101,9 @@ rtc_library("rtc_base_approved") { "copy_on_write_buffer.h", "event_tracer.cc", "event_tracer.h", + "hash.h", "location.cc", "location.h", - "message_buffer_reader.h", "numerics/histogram_percentile_counter.cc", "numerics/histogram_percentile_counter.h", "numerics/mod_ops.h", @@ -126,8 +129,6 @@ rtc_library("rtc_base_approved") { if (is_win) { sources += [ - "win/create_direct3d_device.cc", - "win/create_direct3d_device.h", "win/get_activation_factory.cc", "win/get_activation_factory.h", "win/hstring.cc", @@ -140,6 +141,14 @@ rtc_library("rtc_base_approved") { data_deps = [ "//build/win:runtime_libs" ] } + # These files add a dependency on the Win10 SDK v10.0.10240. + if (rtc_enable_win_wgc) { + sources += [ + "win/create_direct3d_device.cc", + "win/create_direct3d_device.h", + ] + } + if (is_nacl) { public_deps += # no-presubmit-check TODO(webrtc:8603) [ "//native_client_sdk/src/libraries/nacl_io" ] @@ -151,7 +160,6 @@ rtc_library("rtc_base_approved") { public_deps += [ # no-presubmit-check TODO(webrtc:8603) ":atomicops", - ":criticalsection", ":logging", ":macromagic", ":platform_thread", @@ -160,9 +168,8 @@ rtc_library("rtc_base_approved") { ":rtc_event", ":safe_conversions", ":stringutils", - ":thread_checker", ":timeutils", - "synchronization:sequence_checker", + "../api:sequence_checker", ] } @@ -191,7 +198,10 @@ rtc_source_set("refcount") { "ref_counted_object.h", "ref_counter.h", ] - deps = [ ":macromagic" ] + deps = [ + ":macromagic", + "../api:scoped_refptr", + ] } rtc_library("criticalsection") { @@ -215,8 +225,8 @@ rtc_library("platform_thread") { ":rtc_task_queue_libevent", ":rtc_task_queue_stdlib", ":rtc_task_queue_win", + "../api:sequence_checker", "synchronization:mutex", - "synchronization:sequence_checker", ] sources = [ "platform_thread.cc", @@ -228,10 +238,14 @@ rtc_library("platform_thread") { ":macromagic", ":platform_thread_types", ":rtc_event", - ":thread_checker", ":timeutils", + "../api:sequence_checker", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", ] - absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] } rtc_library("rtc_event") { @@ -264,7 +278,6 @@ rtc_library("logging") { libs = [] deps = [ ":checks", - ":criticalsection", ":macromagic", ":platform_thread_types", ":stringutils", @@ -299,24 +312,12 @@ rtc_library("logging") { frameworks = [ "Foundation.framework" ] } - # logging.h needs the deprecation header while downstream projects are - # removing code that depends on logging implementation details. 
- deps += [ ":deprecation" ] - if (is_android) { libs += [ "log" ] } } } -rtc_source_set("thread_checker") { - sources = [ "thread_checker.h" ] - deps = [ - ":deprecation", - "synchronization:sequence_checker", - ] -} - rtc_source_set("atomicops") { sources = [ "atomic_ops.h" ] } @@ -400,6 +401,8 @@ rtc_source_set("safe_conversions") { rtc_library("timeutils") { visibility = [ "*" ] sources = [ + "system_time.cc", + "system_time.h", "time_utils.cc", "time_utils.h", ] @@ -409,6 +412,10 @@ rtc_library("timeutils") { ":stringutils", "system:rtc_export", ] + if (rtc_exclude_system_time) { + defines = [ "WEBRTC_EXCLUDE_SYSTEM_TIME" ] + } + libs = [] if (is_win) { libs += [ "winmm.lib" ] @@ -455,10 +462,6 @@ rtc_source_set("type_traits") { sources = [ "type_traits.h" ] } -rtc_source_set("deprecation") { - sources = [ "deprecation.h" ] -} - rtc_library("rtc_task_queue") { visibility = [ "*" ] sources = [ @@ -485,7 +488,7 @@ rtc_source_set("rtc_operations_chain") { ":macromagic", ":refcount", "../api:scoped_refptr", - "synchronization:sequence_checker", + "../api:sequence_checker", "system:no_unique_address", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -500,7 +503,6 @@ if (rtc_enable_libevent) { ] deps = [ ":checks", - ":criticalsection", ":logging", ":macromagic", ":platform_thread", @@ -547,7 +549,6 @@ if (is_win) { ] deps = [ ":checks", - ":criticalsection", ":logging", ":macromagic", ":platform_thread", @@ -557,7 +558,10 @@ if (is_win) { "../api/task_queue", "synchronization:mutex", ] - absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] } } @@ -568,7 +572,6 @@ rtc_library("rtc_task_queue_stdlib") { ] deps = [ ":checks", - ":criticalsection", ":logging", ":macromagic", ":platform_thread", @@ -589,7 +592,7 @@ rtc_library("weak_ptr") { deps = [ ":refcount", "../api:scoped_refptr", - "synchronization:sequence_checker", + "../api:sequence_checker", "system:no_unique_address", ] } @@ -634,10 +637,6 @@ rtc_library("rtc_stats_counters") { config("rtc_json_suppressions") { if (!is_win || is_clang) { cflags_cc = [ - # TODO(bugs.webrtc.org/10770): Update jsoncpp API usage and remove - # -Wno-deprecated-declarations. - "-Wno-deprecated-declarations", - # TODO(bugs.webrtc.org/10814): Remove -Wno-undef as soon as it get # removed upstream. "-Wno-undef", @@ -667,140 +666,214 @@ rtc_library("rtc_json") { } } -rtc_source_set("async_resolver") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # async_resolver source files (see - # https://webrtc-review.googlesource.com/c/src/+/196903). 
- sources = [ "async_resolver.h" ] -} - -rtc_source_set("net_helpers") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "net_helpers.cc", - # "net_helpers.h", - # ] +rtc_library("net_helpers") { + sources = [ + "net_helpers.cc", + "net_helpers.h", + ] + deps = [] + if (is_android) { + deps += [ ":ifaddrs_android" ] + } + if (is_win) { + deps += [ ":win32" ] + } } -rtc_source_set("async_resolver_interface") { +rtc_library("async_resolver_interface") { visibility = [ "*" ] - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "async_resolver_interface.cc", - # "async_resolver_interface.h", - # ] + sources = [ + "async_resolver_interface.cc", + "async_resolver_interface.h", + ] + deps = [ + ":socket_address", + "system:rtc_export", + "third_party/sigslot", + ] } -rtc_source_set("ip_address") { +rtc_library("ip_address") { visibility = [ "*" ] - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "ip_address.cc", - # "ip_address.h", - # ] + sources = [ + "ip_address.cc", + "ip_address.h", + ] + deps = [ + ":net_helpers", + ":rtc_base_approved", + ":stringutils", + "system:rtc_export", + ] + if (is_win) { + deps += [ ":win32" ] + } } -rtc_source_set("socket_address") { +rtc_library("socket_address") { visibility = [ "*" ] - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "socket_address.cc", - # "socket_address.h", - # ] + sources = [ + "socket_address.cc", + "socket_address.h", + ] + deps = [ + ":checks", + ":ip_address", + ":logging", + ":net_helpers", + ":rtc_base_approved", + ":safe_conversions", + ":stringutils", + "system:rtc_export", + ] + if (is_win) { + deps += [ ":win32" ] + } } -rtc_source_set("null_socket_server") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "null_socket_server.cc", - # "null_socket_server.h", - # ] +rtc_library("null_socket_server") { + sources = [ + "null_socket_server.cc", + "null_socket_server.h", + ] + deps = [ + ":async_socket", + ":checks", + ":rtc_event", + ":socket", + ":socket_server", + "system:rtc_export", + ] } rtc_source_set("socket_server") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "socket_server.h", - # ] + sources = [ "socket_server.h" ] + deps = [ ":socket_factory" ] } -rtc_source_set("threading") { +rtc_library("threading") { visibility = [ "*" ] - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "asyncresolver.cc", - # "asyncresolver.h", - # "defaultsocketserver.cc", - # "defaultsocketserver.h", - # "message_handler.cc", - # "message_handler.h", - # "network_monitor.cc", - # "network_monitor.h", - # "network_monitor_factory.cc", - # "network_monitor_factory.h", - # "physical_socket_server.cc", - # "physical_socket_server.h", - # "signal_thread.cc", - # "signal_thread.h", - # "thread.cc", - # "thread.h", - # ] + + if (build_with_chromium) { + public_configs = [ ":threading_chromium_config" ] + } + + sources = [ + "async_resolver.cc", + "async_resolver.h", + "internal/default_socket_server.cc", + "internal/default_socket_server.h", + "message_handler.cc", + "message_handler.h", + "network_monitor.cc", + "network_monitor.h", + "network_monitor_factory.cc", + "network_monitor_factory.h", + 
"physical_socket_server.cc", + "physical_socket_server.h", + "thread.cc", + "thread.h", + "thread_message.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container" ] + deps = [ + ":async_resolver_interface", + ":atomicops", + ":checks", + ":criticalsection", + ":ip_address", + ":logging", + ":macromagic", + ":network_constants", + ":null_socket_server", + ":platform_thread_types", + ":rtc_base_approved", + ":rtc_event", + ":rtc_task_queue", + ":socket_address", + ":socket_server", + ":timeutils", + "../api:function_view", + "../api:refcountedbase", + "../api:scoped_refptr", + "../api:sequence_checker", + "../api/task_queue", + "synchronization:mutex", + "system:no_unique_address", + "system:rtc_export", + "task_utils:pending_task_safety_flag", + "task_utils:to_queued_task", + "third_party/sigslot", + ] + if (is_android) { + deps += [ ":ifaddrs_android" ] + } + if (is_win) { + deps += [ ":win32" ] + } + if (is_mac || is_ios) { + deps += [ "system:cocoa_threading" ] + } } rtc_source_set("socket_factory") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "socket_factory.h", - # ] + sources = [ "socket_factory.h" ] + deps = [ + ":async_socket", + ":socket", + ] } -rtc_source_set("async_socket") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "async_socket.cc", - # "async_socket.h", - # ] +rtc_library("async_socket") { + sources = [ + "async_socket.cc", + "async_socket.h", + ] + deps = [ + ":checks", + ":socket", + ":socket_address", + "third_party/sigslot", + ] } -rtc_source_set("socket") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "socket.cc", - # "socket.h", - # ] +rtc_library("socket") { + sources = [ + "socket.cc", + "socket.h", + ] + deps = [ + ":macromagic", + ":socket_address", + ] + if (is_win) { + deps += [ ":win32" ] + } } rtc_source_set("network_constants") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "network_constants.h", - # ] + sources = [ + "network_constants.cc", + "network_constants.h", + ] + deps = [ ":checks" ] } if (is_android) { - rtc_source_set("ifaddrs_android") { - # TODO(bugs.webrtc.org/9987): This build target will soon contain - # the following files: - # sources = [ - # "ifaddrs_android.cc", - # "ifaddrs_android.h", - # ] + rtc_library("ifaddrs_android") { + sources = [ + "ifaddrs_android.cc", + "ifaddrs_android.h", + ] + libs = [ + "log", + "GLESv2", + ] } } if (is_win) { - rtc_source_set("win32") { + rtc_library("win32") { sources = [ "win32.cc", "win32.h", @@ -831,19 +904,29 @@ rtc_library("rtc_base") { libs = [] defines = [] deps = [ + ":async_resolver_interface", + ":async_socket", ":checks", - ":deprecation", + ":ip_address", + ":network_constants", + ":null_socket_server", ":rtc_task_queue", + ":socket", + ":socket_address", + ":socket_factory", + ":socket_server", ":stringutils", + ":threading", "../api:array_view", "../api:function_view", + "../api:refcountedbase", "../api:scoped_refptr", + "../api:sequence_checker", "../api/numerics", "../api/task_queue", "../system_wrappers:field_trial", "network:sent_packet", "synchronization:mutex", - "synchronization:sequence_checker", "system:file_wrapper", "system:inline", "system:no_unique_address", @@ -870,10 +953,6 @@ rtc_library("rtc_base") { "async_invoker_inl.h", "async_packet_socket.cc", "async_packet_socket.h", - 
"async_resolver_interface.cc", - "async_resolver_interface.h", - "async_socket.cc", - "async_socket.h", "async_tcp_socket.cc", "async_tcp_socket.h", "async_udp_socket.cc", @@ -884,8 +963,6 @@ rtc_library("rtc_base") { "crypt_string.h", "data_rate_limiter.cc", "data_rate_limiter.h", - "deprecated/signal_thread.cc", - "deprecated/signal_thread.h", "dscp.h", "file_rotating_stream.cc", "file_rotating_stream.h", @@ -893,30 +970,15 @@ rtc_library("rtc_base") { "helpers.h", "http_common.cc", "http_common.h", - "ip_address.cc", - "ip_address.h", - "keep_ref_until_done.h", "mdns_responder_interface.h", "message_digest.cc", "message_digest.h", - "message_handler.cc", - "message_handler.h", "net_helper.cc", "net_helper.h", - "net_helpers.cc", - "net_helpers.h", "network.cc", "network.h", - "network_constants.cc", - "network_constants.h", - "network_monitor.cc", - "network_monitor.h", - "network_monitor_factory.cc", - "network_monitor_factory.h", "network_route.cc", "network_route.h", - "null_socket_server.cc", - "null_socket_server.h", "openssl.h", "openssl_adapter.cc", "openssl_adapter.h", @@ -930,26 +992,17 @@ rtc_library("rtc_base") { "openssl_stream_adapter.h", "openssl_utility.cc", "openssl_utility.h", - "physical_socket_server.cc", - "physical_socket_server.h", "proxy_info.cc", "proxy_info.h", "rtc_certificate.cc", "rtc_certificate.h", "rtc_certificate_generator.cc", "rtc_certificate_generator.h", - "signal_thread.h", "sigslot_repeater.h", - "socket.cc", - "socket.h", "socket_adapters.cc", "socket_adapters.h", - "socket_address.cc", - "socket_address.h", "socket_address_pair.cc", "socket_address_pair.h", - "socket_factory.h", - "socket_server.h", "ssl_adapter.cc", "ssl_adapter.h", "ssl_certificate.cc", @@ -962,9 +1015,6 @@ rtc_library("rtc_base") { "ssl_stream_adapter.h", "stream.cc", "stream.h", - "thread.cc", - "thread.h", - "thread_message.h", "unique_id_generator.cc", "unique_id_generator.h", ] @@ -988,10 +1038,8 @@ rtc_library("rtc_base") { if (build_with_chromium) { include_dirs = [ "../../boringssl/src/include" ] - public_configs += [ ":rtc_base_chromium_config" ] } else { sources += [ - "callback.h", "log_sinks.cc", "log_sinks.h", "rolling_accumulator.h", @@ -1018,20 +1066,11 @@ rtc_library("rtc_base") { } if (is_android) { - sources += [ - "ifaddrs_android.cc", - "ifaddrs_android.h", - ] - - libs += [ - "log", - "GLESv2", - ] + deps += [ ":ifaddrs_android" ] } if (is_ios || is_mac) { sources += [ "mac_ifaddrs_converter.cc" ] - deps += [ "system:cocoa_threading" ] } if (is_linux || is_chromeos) { @@ -1086,6 +1125,7 @@ rtc_library("gunit_helpers") { ":rtc_base", ":rtc_base_tests_utils", ":stringutils", + ":threading", "../test:test_support", ] absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] @@ -1098,10 +1138,10 @@ rtc_library("testclient") { "test_client.h", ] deps = [ - ":criticalsection", ":gunit_helpers", ":rtc_base", ":rtc_base_tests_utils", + ":threading", ":timeutils", "synchronization:mutex", ] @@ -1160,12 +1200,20 @@ rtc_library("rtc_base_tests_utils") { "virtual_socket_server.h", ] deps = [ + ":async_socket", ":checks", + ":ip_address", ":rtc_base", + ":socket", + ":socket_address", + ":socket_factory", + ":socket_server", + ":threading", "../api/units:time_delta", "../api/units:timestamp", "memory:fifo_buffer", "synchronization:mutex", + "task_utils:to_queued_task", "third_party/sigslot", ] absl_deps = [ @@ -1216,138 +1264,6 @@ if (rtc_include_tests) { ] } - rtc_library("rtc_base_nonparallel_tests") { - testonly = true - - sources = [ - "cpu_time_unittest.cc", - 
"file_rotating_stream_unittest.cc", - "null_socket_server_unittest.cc", - "physical_socket_server_unittest.cc", - "socket_address_unittest.cc", - "socket_unittest.cc", - "socket_unittest.h", - ] - deps = [ - ":checks", - ":gunit_helpers", - ":rtc_base", - ":rtc_base_tests_utils", - ":testclient", - "../system_wrappers", - "../test:fileutils", - "../test:test_main", - "../test:test_support", - "third_party/sigslot", - "//testing/gtest", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] - if (is_win) { - sources += [ "win32_socket_server_unittest.cc" ] - } - } - - rtc_library("rtc_base_approved_unittests") { - testonly = true - sources = [ - "atomic_ops_unittest.cc", - "base64_unittest.cc", - "bind_unittest.cc", - "bit_buffer_unittest.cc", - "bounded_inline_vector_unittest.cc", - "buffer_queue_unittest.cc", - "buffer_unittest.cc", - "byte_buffer_unittest.cc", - "byte_order_unittest.cc", - "checks_unittest.cc", - "copy_on_write_buffer_unittest.cc", - "deprecated/recursive_critical_section_unittest.cc", - "event_tracer_unittest.cc", - "event_unittest.cc", - "logging_unittest.cc", - "numerics/divide_round_unittest.cc", - "numerics/histogram_percentile_counter_unittest.cc", - "numerics/mod_ops_unittest.cc", - "numerics/moving_max_counter_unittest.cc", - "numerics/safe_compare_unittest.cc", - "numerics/safe_minmax_unittest.cc", - "numerics/sample_counter_unittest.cc", - "one_time_event_unittest.cc", - "platform_thread_unittest.cc", - "random_unittest.cc", - "rate_limiter_unittest.cc", - "rate_statistics_unittest.cc", - "rate_tracker_unittest.cc", - "ref_counted_object_unittest.cc", - "sanitizer_unittest.cc", - "string_encode_unittest.cc", - "string_to_number_unittest.cc", - "string_utils_unittest.cc", - "strings/string_builder_unittest.cc", - "strings/string_format_unittest.cc", - "swap_queue_unittest.cc", - "thread_annotations_unittest.cc", - "thread_checker_unittest.cc", - "time_utils_unittest.cc", - "timestamp_aligner_unittest.cc", - "virtual_socket_unittest.cc", - "zero_memory_unittest.cc", - ] - if (is_win) { - sources += [ "win/windows_version_unittest.cc" ] - } - deps = [ - ":bounded_inline_vector", - ":checks", - ":divide_round", - ":gunit_helpers", - ":rate_limiter", - ":rtc_base", - ":rtc_base_approved", - ":rtc_base_tests_utils", - ":rtc_numerics", - ":rtc_task_queue", - ":safe_compare", - ":safe_minmax", - ":sanitizer", - ":stringutils", - ":testclient", - "../api:array_view", - "../api:scoped_refptr", - "../api/numerics", - "../api/units:time_delta", - "../system_wrappers", - "../test:fileutils", - "../test:test_main", - "../test:test_support", - "memory:unittests", - "synchronization:mutex", - "task_utils:to_queued_task", - "third_party/base64", - "third_party/sigslot", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/base:core_headers", - "//third_party/abseil-cpp/absl/memory", - ] - } - - rtc_library("rtc_task_queue_unittests") { - testonly = true - - sources = [ "task_queue_unittest.cc" ] - deps = [ - ":gunit_helpers", - ":rtc_base_approved", - ":rtc_base_tests_utils", - ":rtc_task_queue", - ":task_queue_for_test", - "../test:test_main", - "../test:test_support", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] - } - rtc_library("rtc_operations_chain_unittests") { testonly = true @@ -1358,138 +1274,298 @@ if (rtc_include_tests) { ":rtc_base_approved", ":rtc_event", ":rtc_operations_chain", + ":threading", "../test:test_support", ] } - rtc_library("weak_ptr_unittests") { - testonly = true - - sources = [ "weak_ptr_unittest.cc" ] - deps = [ - 
":gunit_helpers", - ":rtc_base_approved", - ":rtc_base_tests_utils", - ":rtc_event", - ":task_queue_for_test", - ":weak_ptr", - "../test:test_main", - "../test:test_support", - ] - } + if (!build_with_chromium) { + rtc_library("rtc_base_nonparallel_tests") { + testonly = true + + sources = [ + "cpu_time_unittest.cc", + "file_rotating_stream_unittest.cc", + "null_socket_server_unittest.cc", + "physical_socket_server_unittest.cc", + "socket_address_unittest.cc", + "socket_unittest.cc", + "socket_unittest.h", + ] + deps = [ + ":async_socket", + ":checks", + ":gunit_helpers", + ":ip_address", + ":net_helpers", + ":null_socket_server", + ":rtc_base", + ":rtc_base_tests_utils", + ":socket", + ":socket_address", + ":socket_server", + ":testclient", + ":threading", + "../system_wrappers", + "../test:fileutils", + "../test:test_main", + "../test:test_support", + "third_party/sigslot", + "//testing/gtest", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] + if (is_win) { + sources += [ "win32_socket_server_unittest.cc" ] + } + } - rtc_library("rtc_numerics_unittests") { - testonly = true + rtc_library("rtc_base_approved_unittests") { + testonly = true + sources = [ + "atomic_ops_unittest.cc", + "base64_unittest.cc", + "bit_buffer_unittest.cc", + "bounded_inline_vector_unittest.cc", + "buffer_queue_unittest.cc", + "buffer_unittest.cc", + "byte_buffer_unittest.cc", + "byte_order_unittest.cc", + "checks_unittest.cc", + "copy_on_write_buffer_unittest.cc", + "deprecated/recursive_critical_section_unittest.cc", + "event_tracer_unittest.cc", + "event_unittest.cc", + "hash_unittest.cc", + "logging_unittest.cc", + "numerics/divide_round_unittest.cc", + "numerics/histogram_percentile_counter_unittest.cc", + "numerics/mod_ops_unittest.cc", + "numerics/moving_max_counter_unittest.cc", + "numerics/safe_compare_unittest.cc", + "numerics/safe_minmax_unittest.cc", + "numerics/sample_counter_unittest.cc", + "one_time_event_unittest.cc", + "platform_thread_unittest.cc", + "random_unittest.cc", + "rate_limiter_unittest.cc", + "rate_statistics_unittest.cc", + "rate_tracker_unittest.cc", + "ref_counted_object_unittest.cc", + "sanitizer_unittest.cc", + "string_encode_unittest.cc", + "string_to_number_unittest.cc", + "string_utils_unittest.cc", + "strings/string_builder_unittest.cc", + "strings/string_format_unittest.cc", + "swap_queue_unittest.cc", + "thread_annotations_unittest.cc", + "time_utils_unittest.cc", + "timestamp_aligner_unittest.cc", + "virtual_socket_unittest.cc", + "zero_memory_unittest.cc", + ] + if (is_win) { + sources += [ "win/windows_version_unittest.cc" ] + } + deps = [ + ":async_socket", + ":bounded_inline_vector", + ":checks", + ":criticalsection", + ":divide_round", + ":gunit_helpers", + ":ip_address", + ":null_socket_server", + ":rate_limiter", + ":rtc_base", + ":rtc_base_approved", + ":rtc_base_tests_utils", + ":rtc_numerics", + ":rtc_task_queue", + ":safe_compare", + ":safe_minmax", + ":sanitizer", + ":socket", + ":socket_address", + ":socket_server", + ":stringutils", + ":testclient", + ":threading", + "../api:array_view", + "../api:scoped_refptr", + "../api/numerics", + "../api/units:time_delta", + "../system_wrappers", + "../test:fileutils", + "../test:test_main", + "../test:test_support", + "containers:unittests", + "memory:unittests", + "synchronization:mutex", + "task_utils:to_queued_task", + "third_party/base64", + "third_party/sigslot", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/base:core_headers", + "//third_party/abseil-cpp/absl/memory", + 
"//third_party/abseil-cpp/absl/types:optional", + ] + } - sources = [ - "numerics/event_based_exponential_moving_average_unittest.cc", - "numerics/exp_filter_unittest.cc", - "numerics/moving_average_unittest.cc", - "numerics/moving_median_filter_unittest.cc", - "numerics/percentile_filter_unittest.cc", - "numerics/running_statistics_unittest.cc", - "numerics/sequence_number_util_unittest.cc", - ] - deps = [ - ":rtc_base_approved", - ":rtc_numerics", - "../test:test_main", - "../test:test_support", - ] - absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container" ] - } + rtc_library("rtc_task_queue_unittests") { + testonly = true + + sources = [ "task_queue_unittest.cc" ] + deps = [ + ":gunit_helpers", + ":rtc_base_approved", + ":rtc_base_tests_utils", + ":rtc_task_queue", + ":task_queue_for_test", + "../test:test_main", + "../test:test_support", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] + } - rtc_library("rtc_json_unittests") { - testonly = true + rtc_library("weak_ptr_unittests") { + testonly = true + + sources = [ "weak_ptr_unittest.cc" ] + deps = [ + ":gunit_helpers", + ":rtc_base_approved", + ":rtc_base_tests_utils", + ":rtc_event", + ":task_queue_for_test", + ":weak_ptr", + "../test:test_main", + "../test:test_support", + ] + } - sources = [ "strings/json_unittest.cc" ] - deps = [ - ":gunit_helpers", - ":rtc_base_tests_utils", - ":rtc_json", - "../test:test_main", - "../test:test_support", - ] - } + rtc_library("rtc_numerics_unittests") { + testonly = true + + sources = [ + "numerics/event_based_exponential_moving_average_unittest.cc", + "numerics/exp_filter_unittest.cc", + "numerics/moving_average_unittest.cc", + "numerics/moving_median_filter_unittest.cc", + "numerics/percentile_filter_unittest.cc", + "numerics/running_statistics_unittest.cc", + "numerics/sequence_number_util_unittest.cc", + ] + deps = [ + ":rtc_base_approved", + ":rtc_numerics", + "../test:test_main", + "../test:test_support", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container" ] + } - rtc_library("rtc_base_unittests") { - testonly = true - defines = [] + rtc_library("rtc_json_unittests") { + testonly = true - sources = [ - "callback_unittest.cc", - "crc32_unittest.cc", - "data_rate_limiter_unittest.cc", - "deprecated/signal_thread_unittest.cc", - "fake_clock_unittest.cc", - "helpers_unittest.cc", - "ip_address_unittest.cc", - "memory_usage_unittest.cc", - "message_digest_unittest.cc", - "nat_unittest.cc", - "network_route_unittest.cc", - "network_unittest.cc", - "proxy_unittest.cc", - "rolling_accumulator_unittest.cc", - "rtc_certificate_generator_unittest.cc", - "rtc_certificate_unittest.cc", - "sigslot_tester_unittest.cc", - "test_client_unittest.cc", - "thread_unittest.cc", - "unique_id_generator_unittest.cc", - ] - deps = [ - ":checks", - ":gunit_helpers", - ":rtc_base_tests_utils", - ":stringutils", - ":testclient", - "../api:array_view", - "../api/task_queue", - "../api/task_queue:task_queue_test", - "../test:field_trial", - "../test:fileutils", - "../test:rtc_expect_death", - "../test:test_main", - "../test:test_support", - "memory:fifo_buffer", - "synchronization:mutex", - "synchronization:synchronization_unittests", - "task_utils:pending_task_safety_flag", - "task_utils:to_queued_task", - "third_party/sigslot", - ] - if (is_win) { - sources += [ - "win32_unittest.cc", - "win32_window_unittest.cc", + sources = [ "strings/json_unittest.cc" ] + deps = [ + ":gunit_helpers", + ":rtc_base_tests_utils", + ":rtc_json", + "../test:test_main", + "../test:test_support", ] 
- deps += [ ":win32" ] } - if (is_posix || is_fuchsia) { - sources += [ - "openssl_adapter_unittest.cc", - "openssl_session_cache_unittest.cc", - "openssl_utility_unittest.cc", - "ssl_adapter_unittest.cc", - "ssl_identity_unittest.cc", - "ssl_stream_adapter_unittest.cc", + + rtc_library("rtc_base_unittests") { + testonly = true + defines = [] + + sources = [ + "crc32_unittest.cc", + "data_rate_limiter_unittest.cc", + "fake_clock_unittest.cc", + "helpers_unittest.cc", + "ip_address_unittest.cc", + "memory_usage_unittest.cc", + "message_digest_unittest.cc", + "nat_unittest.cc", + "network_route_unittest.cc", + "network_unittest.cc", + "proxy_unittest.cc", + "rolling_accumulator_unittest.cc", + "rtc_certificate_generator_unittest.cc", + "rtc_certificate_unittest.cc", + "sigslot_tester_unittest.cc", + "test_client_unittest.cc", + "thread_unittest.cc", + "unique_id_generator_unittest.cc", ] - } - absl_deps = [ - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/memory", - "//third_party/abseil-cpp/absl/strings", - "//third_party/abseil-cpp/absl/types:optional", - ] - public_deps = [ ":rtc_base" ] # no-presubmit-check TODO(webrtc:8603) - if (build_with_chromium) { - include_dirs = [ "../../boringssl/src/include" ] - } - if (rtc_build_ssl) { - deps += [ "//third_party/boringssl" ] - } else { - configs += [ ":external_ssl_library" ] + deps = [ + ":async_socket", + ":checks", + ":gunit_helpers", + ":ip_address", + ":net_helpers", + ":null_socket_server", + ":rtc_base_tests_utils", + ":socket_address", + ":socket_factory", + ":socket_server", + ":stringutils", + ":testclient", + ":threading", + "../api:array_view", + "../api/task_queue", + "../api/task_queue:task_queue_test", + "../test:field_trial", + "../test:fileutils", + "../test:rtc_expect_death", + "../test:test_main", + "../test:test_support", + "memory:fifo_buffer", + "synchronization:mutex", + "task_utils:pending_task_safety_flag", + "task_utils:to_queued_task", + "third_party/sigslot", + ] + if (enable_google_benchmarks) { + deps += [ "synchronization:synchronization_unittests" ] + } + if (is_win) { + sources += [ + "win32_unittest.cc", + "win32_window_unittest.cc", + ] + deps += [ ":win32" ] + } + if (is_posix || is_fuchsia) { + sources += [ + "openssl_adapter_unittest.cc", + "openssl_session_cache_unittest.cc", + "openssl_utility_unittest.cc", + "ssl_adapter_unittest.cc", + "ssl_identity_unittest.cc", + "ssl_stream_adapter_unittest.cc", + ] + } + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + public_deps = [ ":rtc_base" ] # no-presubmit-check TODO(webrtc:8603) + if (build_with_chromium) { + include_dirs = [ "../../boringssl/src/include" ] + } + if (rtc_build_ssl) { + deps += [ "//third_party/boringssl" ] + } else { + configs += [ ":external_ssl_library" ] + } } } } diff --git a/rtc_base/DEPS b/rtc_base/DEPS index c9f7dc5898..3fdc4bc10e 100644 --- a/rtc_base/DEPS +++ b/rtc_base/DEPS @@ -12,4 +12,7 @@ specific_include_rules = { "gunit\.h": [ "+testing/base/public/gunit.h" ], + "logging\.cc": [ + "+absl/synchronization" + ], } diff --git a/rtc_base/async_invoker.cc b/rtc_base/async_invoker.cc index 8b410a4561..87d039373d 100644 --- a/rtc_base/async_invoker.cc +++ b/rtc_base/async_invoker.cc @@ -15,12 +15,12 @@ namespace rtc { -AsyncInvoker::AsyncInvoker() +DEPRECATED_AsyncInvoker::DEPRECATED_AsyncInvoker() : pending_invocations_(0), - 
invocation_complete_(new RefCountedObject()), + invocation_complete_(make_ref_counted()), destroying_(false) {} -AsyncInvoker::~AsyncInvoker() { +DEPRECATED_AsyncInvoker::~DEPRECATED_AsyncInvoker() { destroying_.store(true, std::memory_order_relaxed); // Messages for this need to be cleared *before* our destructor is complete. ThreadManager::Clear(this); @@ -37,7 +37,7 @@ AsyncInvoker::~AsyncInvoker() { } } -void AsyncInvoker::OnMessage(Message* msg) { +void DEPRECATED_AsyncInvoker::OnMessage(Message* msg) { // Get the AsyncClosure shared ptr from this message's data. ScopedMessageData* data = static_cast*>(msg->pdata); @@ -46,7 +46,8 @@ void AsyncInvoker::OnMessage(Message* msg) { delete data; } -void AsyncInvoker::Flush(Thread* thread, uint32_t id /*= MQID_ANY*/) { +void DEPRECATED_AsyncInvoker::Flush(Thread* thread, + uint32_t id /*= MQID_ANY*/) { // If the destructor is waiting for invocations to finish, don't start // running even more tasks. if (destroying_.load(std::memory_order_relaxed)) @@ -55,7 +56,7 @@ void AsyncInvoker::Flush(Thread* thread, uint32_t id /*= MQID_ANY*/) { // Run this on |thread| to reduce the number of context switches. if (Thread::Current() != thread) { thread->Invoke(RTC_FROM_HERE, - Bind(&AsyncInvoker::Flush, this, thread, id)); + [this, thread, id] { Flush(thread, id); }); return; } @@ -67,14 +68,14 @@ void AsyncInvoker::Flush(Thread* thread, uint32_t id /*= MQID_ANY*/) { } } -void AsyncInvoker::Clear() { +void DEPRECATED_AsyncInvoker::Clear() { ThreadManager::Clear(this); } -void AsyncInvoker::DoInvoke(const Location& posted_from, - Thread* thread, - std::unique_ptr closure, - uint32_t id) { +void DEPRECATED_AsyncInvoker::DoInvoke(const Location& posted_from, + Thread* thread, + std::unique_ptr closure, + uint32_t id) { if (destroying_.load(std::memory_order_relaxed)) { // Note that this may be expected, if the application is AsyncInvoking // tasks that AsyncInvoke other tasks. But otherwise it indicates a race @@ -87,11 +88,12 @@ void AsyncInvoker::DoInvoke(const Location& posted_from, new ScopedMessageData(std::move(closure))); } -void AsyncInvoker::DoInvokeDelayed(const Location& posted_from, - Thread* thread, - std::unique_ptr closure, - uint32_t delay_ms, - uint32_t id) { +void DEPRECATED_AsyncInvoker::DoInvokeDelayed( + const Location& posted_from, + Thread* thread, + std::unique_ptr closure, + uint32_t delay_ms, + uint32_t id) { if (destroying_.load(std::memory_order_relaxed)) { // See above comment. RTC_LOG(LS_WARNING) << "Tried to invoke while destroying the invoker."; @@ -101,7 +103,7 @@ void AsyncInvoker::DoInvokeDelayed(const Location& posted_from, new ScopedMessageData(std::move(closure))); } -AsyncClosure::AsyncClosure(AsyncInvoker* invoker) +AsyncClosure::AsyncClosure(DEPRECATED_AsyncInvoker* invoker) : invoker_(invoker), invocation_complete_(invoker_->invocation_complete_) { invoker_->pending_invocations_.fetch_add(1, std::memory_order_relaxed); } diff --git a/rtc_base/async_invoker.h b/rtc_base/async_invoker.h index 983e710bcd..fd42ca76de 100644 --- a/rtc_base/async_invoker.h +++ b/rtc_base/async_invoker.h @@ -15,9 +15,9 @@ #include #include +#include "absl/base/attributes.h" #include "api/scoped_refptr.h" #include "rtc_base/async_invoker_inl.h" -#include "rtc_base/bind.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/event.h" #include "rtc_base/ref_counted_object.h" @@ -87,10 +87,10 @@ namespace rtc { // destruction. 
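Editor's note (sketch, not part of the diff): the async_invoker hunks rename the class to DEPRECATED_AsyncInvoker and, in the header change just below, keep the old name compiling via `using AsyncInvoker ABSL_DEPRECATED("bugs.webrtc.org/12339") = DEPRECATED_AsyncInvoker;`. A self-contained illustration of that deprecated-alias pattern with a hypothetical class and bug URL:

#include "absl/base/attributes.h"

namespace example {

class DEPRECATED_Widget {  // hypothetical stand-in
 public:
  void Poke() {}
};

// The old spelling still compiles, but each use typically triggers a
// -Wdeprecated-declarations warning carrying the message below.
using Widget ABSL_DEPRECATED("bugs.example.org/123") = DEPRECATED_Widget;

void Use() {
  Widget w;  // warns here, nudging callers toward migrating off the type
  w.Poke();
}

}  // namespace example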
This can be done by starting each chain of invocations on the // same thread on which it will be destroyed, or by using some other // synchronization method. -class AsyncInvoker : public MessageHandlerAutoCleanup { +class DEPRECATED_AsyncInvoker : public MessageHandlerAutoCleanup { public: - AsyncInvoker(); - ~AsyncInvoker() override; + DEPRECATED_AsyncInvoker(); + ~DEPRECATED_AsyncInvoker() override; // Call |functor| asynchronously on |thread|, with no callback upon // completion. Returns immediately. @@ -157,7 +157,7 @@ class AsyncInvoker : public MessageHandlerAutoCleanup { // an AsyncClosure's destructor that's about to call // "invocation_complete_->Set()", it's not dereferenced after being // destroyed. - scoped_refptr> invocation_complete_; + rtc::Ref::Ptr invocation_complete_; // This flag is used to ensure that if an application AsyncInvokes tasks that // recursively AsyncInvoke other tasks ad infinitum, the cycle eventually @@ -166,9 +166,12 @@ class AsyncInvoker : public MessageHandlerAutoCleanup { friend class AsyncClosure; - RTC_DISALLOW_COPY_AND_ASSIGN(AsyncInvoker); + RTC_DISALLOW_COPY_AND_ASSIGN(DEPRECATED_AsyncInvoker); }; +using AsyncInvoker ABSL_DEPRECATED("bugs.webrtc.org/12339") = + DEPRECATED_AsyncInvoker; + } // namespace rtc #endif // RTC_BASE_ASYNC_INVOKER_H_ diff --git a/rtc_base/async_invoker_inl.h b/rtc_base/async_invoker_inl.h index 6307afe220..9fb328782c 100644 --- a/rtc_base/async_invoker_inl.h +++ b/rtc_base/async_invoker_inl.h @@ -12,7 +12,6 @@ #define RTC_BASE_ASYNC_INVOKER_INL_H_ #include "api/scoped_refptr.h" -#include "rtc_base/bind.h" #include "rtc_base/event.h" #include "rtc_base/message_handler.h" #include "rtc_base/ref_counted_object.h" @@ -22,32 +21,33 @@ namespace rtc { -class AsyncInvoker; +class DEPRECATED_AsyncInvoker; // Helper class for AsyncInvoker. Runs a task and triggers a callback // on the calling thread if necessary. class AsyncClosure { public: - explicit AsyncClosure(AsyncInvoker* invoker); + explicit AsyncClosure(DEPRECATED_AsyncInvoker* invoker); virtual ~AsyncClosure(); // Runs the asynchronous task, and triggers a callback to the calling // thread if needed. Should be called from the target thread. virtual void Execute() = 0; protected: - AsyncInvoker* invoker_; + DEPRECATED_AsyncInvoker* invoker_; // Reference counted so that if the AsyncInvoker destructor finishes before // an AsyncClosure's destructor that's about to call // "invocation_complete_->Set()", it's not dereferenced after being // destroyed. - scoped_refptr> invocation_complete_; + rtc::Ref::Ptr invocation_complete_; }; // Simple closure that doesn't trigger a callback for the calling thread. template class FireAndForgetAsyncClosure : public AsyncClosure { public: - explicit FireAndForgetAsyncClosure(AsyncInvoker* invoker, FunctorT&& functor) + explicit FireAndForgetAsyncClosure(DEPRECATED_AsyncInvoker* invoker, + FunctorT&& functor) : AsyncClosure(invoker), functor_(std::forward(functor)) {} virtual void Execute() { functor_(); } diff --git a/rtc_base/async_resolver.cc b/rtc_base/async_resolver.cc new file mode 100644 index 0000000000..d482b4e681 --- /dev/null +++ b/rtc_base/async_resolver.cc @@ -0,0 +1,206 @@ +/* + * Copyright 2008 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/async_resolver.h" + +#include +#include +#include + +#include "api/ref_counted_base.h" +#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/thread_annotations.h" + +#if defined(WEBRTC_WIN) +#include +#include + +#include "rtc_base/win32.h" +#endif +#if defined(WEBRTC_POSIX) && !defined(__native_client__) +#if defined(WEBRTC_ANDROID) +#include "rtc_base/ifaddrs_android.h" +#else +#include +#endif +#endif // defined(WEBRTC_POSIX) && !defined(__native_client__) + +#include "api/task_queue/task_queue_base.h" +#include "rtc_base/ip_address.h" +#include "rtc_base/logging.h" +#include "rtc_base/platform_thread.h" +#include "rtc_base/task_queue.h" +#include "rtc_base/task_utils/to_queued_task.h" +#include "rtc_base/third_party/sigslot/sigslot.h" // for signal_with_thread... + +namespace rtc { + +int ResolveHostname(const std::string& hostname, + int family, + std::vector* addresses) { +#ifdef __native_client__ + RTC_NOTREACHED(); + RTC_LOG(LS_WARNING) << "ResolveHostname() is not implemented for NaCl"; + return -1; +#else // __native_client__ + if (!addresses) { + return -1; + } + addresses->clear(); + struct addrinfo* result = nullptr; + struct addrinfo hints = {0}; + hints.ai_family = family; + // |family| here will almost always be AF_UNSPEC, because |family| comes from + // AsyncResolver::addr_.family(), which comes from a SocketAddress constructed + // with a hostname. When a SocketAddress is constructed with a hostname, its + // family is AF_UNSPEC. However, if someday in the future we construct + // a SocketAddress with both a hostname and a family other than AF_UNSPEC, + // then it would be possible to get a specific family value here. + + // The behavior of AF_UNSPEC is roughly "get both ipv4 and ipv6", as + // documented by the various operating systems: + // Linux: http://man7.org/linux/man-pages/man3/getaddrinfo.3.html + // Windows: https://msdn.microsoft.com/en-us/library/windows/desktop/ + // ms738520(v=vs.85).aspx + // Mac: https://developer.apple.com/legacy/library/documentation/Darwin/ + // Reference/ManPages/man3/getaddrinfo.3.html + // Android (source code, not documentation): + // https://android.googlesource.com/platform/bionic/+/ + // 7e0bfb511e85834d7c6cb9631206b62f82701d60/libc/netbsd/net/getaddrinfo.c#1657 + hints.ai_flags = AI_ADDRCONFIG; + int ret = getaddrinfo(hostname.c_str(), nullptr, &hints, &result); + if (ret != 0) { + return ret; + } + struct addrinfo* cursor = result; + for (; cursor; cursor = cursor->ai_next) { + if (family == AF_UNSPEC || cursor->ai_family == family) { + IPAddress ip; + if (IPFromAddrInfo(cursor, &ip)) { + addresses->push_back(ip); + } + } + } + freeaddrinfo(result); + return 0; +#endif // !__native_client__ +} + +struct AsyncResolver::State : public RefCountedBase { + webrtc::Mutex mutex; + enum class Status { + kLive, + kDead + } status RTC_GUARDED_BY(mutex) = Status::kLive; +}; + +AsyncResolver::AsyncResolver() : error_(-1), state_(new State) {} + +AsyncResolver::~AsyncResolver() { + RTC_DCHECK_RUN_ON(&sequence_checker_); + + // Ensure the thread isn't using a stale reference to the current task queue, + // or calling into ResolveDone post destruction. 
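Editor's note (sketch, not part of the diff): ResolveHostname() above passes AF_UNSPEC (resolve both IPv4 and IPv6) together with AI_ADDRCONFIG (only families the host is actually configured for), as its long comment explains. A standalone POSIX sketch of the same getaddrinfo pattern, with none of the WebRTC types:

#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>

#include <cstring>

// Returns the number of addresses getaddrinfo reports for |hostname|, across
// whichever families (IPv4/IPv6) the local host is configured for.
int CountResolvedAddresses(const char* hostname) {
  addrinfo hints;
  std::memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;     // both IPv4 and IPv6
  hints.ai_flags = AI_ADDRCONFIG;  // skip families with no configured address
  addrinfo* result = nullptr;
  if (getaddrinfo(hostname, nullptr, &hints, &result) != 0)
    return -1;
  int count = 0;
  for (addrinfo* cursor = result; cursor != nullptr; cursor = cursor->ai_next)
    ++count;
  freeaddrinfo(result);
  return count;
}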
+ webrtc::MutexLock lock(&state_->mutex); + state_->status = State::Status::kDead; +} + +void RunResolution(void* obj) { + std::function* function_ptr = + static_cast*>(obj); + (*function_ptr)(); + delete function_ptr; +} + +void AsyncResolver::Start(const SocketAddress& addr) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(!destroy_called_); + addr_ = addr; + PlatformThread::SpawnDetached( + [this, addr, caller_task_queue = webrtc::TaskQueueBase::Current(), + state = state_] { + std::vector addresses; + int error = + ResolveHostname(addr.hostname().c_str(), addr.family(), &addresses); + webrtc::MutexLock lock(&state->mutex); + if (state->status == State::Status::kLive) { + caller_task_queue->PostTask(webrtc::ToQueuedTask( + [this, error, addresses = std::move(addresses), state] { + bool live; + { + // ResolveDone can lead to instance destruction, so make sure + // we don't deadlock. + webrtc::MutexLock lock(&state->mutex); + live = state->status == State::Status::kLive; + } + if (live) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + ResolveDone(std::move(addresses), error); + } + })); + } + }, + "AsyncResolver"); +} + +bool AsyncResolver::GetResolvedAddress(int family, SocketAddress* addr) const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(!destroy_called_); + if (error_ != 0 || addresses_.empty()) + return false; + + *addr = addr_; + for (size_t i = 0; i < addresses_.size(); ++i) { + if (family == addresses_[i].family()) { + addr->SetResolvedIP(addresses_[i]); + return true; + } + } + return false; +} + +int AsyncResolver::GetError() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(!destroy_called_); + return error_; +} + +void AsyncResolver::Destroy(bool wait) { + // Some callers have trouble guaranteeing that Destroy is called on the + // sequence guarded by |sequence_checker_|. + // RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(!destroy_called_); + destroy_called_ = true; + MaybeSelfDestruct(); +} + +const std::vector& AsyncResolver::addresses() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + RTC_DCHECK(!destroy_called_); + return addresses_; +} + +void AsyncResolver::ResolveDone(std::vector addresses, int error) { + addresses_ = addresses; + error_ = error; + recursion_check_ = true; + SignalDone(this); + MaybeSelfDestruct(); +} + +void AsyncResolver::MaybeSelfDestruct() { + if (!recursion_check_) { + delete this; + } else { + recursion_check_ = false; + } +} + +} // namespace rtc diff --git a/rtc_base/async_resolver.h b/rtc_base/async_resolver.h index 3c3ad82870..0c053eed81 100644 --- a/rtc_base/async_resolver.h +++ b/rtc_base/async_resolver.h @@ -11,7 +11,65 @@ #ifndef RTC_BASE_ASYNC_RESOLVER_H_ #define RTC_BASE_ASYNC_RESOLVER_H_ -// Placeholder header for the refactoring in: -// https://webrtc-review.googlesource.com/c/src/+/196903 +#if defined(WEBRTC_POSIX) +#include +#elif WEBRTC_WIN +#include // NOLINT +#endif + +#include + +#include "api/sequence_checker.h" +#include "rtc_base/async_resolver_interface.h" +#include "rtc_base/event.h" +#include "rtc_base/ip_address.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/system/no_unique_address.h" +#include "rtc_base/system/rtc_export.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" + +namespace rtc { + +// AsyncResolver will perform async DNS resolution, signaling the result on +// the SignalDone from AsyncResolverInterface when the operation completes. 
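Editor's note (sketch, not part of the diff): Start() above hands the detached worker a ref-counted State whose mutex-guarded kLive/kDead flag lets a late DNS result be dropped once the resolver is gone. A much-simplified, WebRTC-free illustration of that shared-state pattern; unlike the real class, which re-posts the result onto the sequence it was created on, this sketch requires that the callback neither capture nor destroy the owner:

#include <functional>
#include <memory>
#include <mutex>
#include <thread>

class Resolverish {
 public:
  // |on_done| must not capture or destroy the Resolverish itself; see the
  // note above.
  explicit Resolverish(std::function<void(int)> on_done) {
    shared_->on_done = std::move(on_done);
  }

  ~Resolverish() {
    // The "kDead" transition: once this runs, a late result is dropped.
    std::lock_guard<std::mutex> lock(shared_->mutex);
    shared_->on_done = nullptr;
  }

  void Start(int input) {
    std::thread([shared = shared_, input] {
      int result = input * 2;  // stand-in for the blocking lookup
      std::lock_guard<std::mutex> lock(shared->mutex);
      if (shared->on_done)  // still "kLive"?
        shared->on_done(result);
    }).detach();
  }

 private:
  struct State {
    std::mutex mutex;
    std::function<void(int)> on_done;
  };
  // Shared with the worker so the flag outlives whichever side finishes last.
  std::shared_ptr<State> shared_ = std::make_shared<State>();
};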
+// +// This class is thread-compatible, and all methods and destruction needs to +// happen from the same rtc::Thread, except for Destroy which is allowed to +// happen on another context provided it's not happening concurrently to another +// public API call, and is the last access to the object. +class RTC_EXPORT AsyncResolver : public AsyncResolverInterface { + public: + AsyncResolver(); + ~AsyncResolver() override; + + void Start(const SocketAddress& addr) override; + bool GetResolvedAddress(int family, SocketAddress* addr) const override; + int GetError() const override; + void Destroy(bool wait) override; + + const std::vector& addresses() const; + + private: + // Fwd decl. + struct State; + + void ResolveDone(std::vector addresses, int error) + RTC_EXCLUSIVE_LOCKS_REQUIRED(sequence_checker_); + void MaybeSelfDestruct(); + + SocketAddress addr_ RTC_GUARDED_BY(sequence_checker_); + std::vector addresses_ RTC_GUARDED_BY(sequence_checker_); + int error_ RTC_GUARDED_BY(sequence_checker_); + bool recursion_check_ = + false; // Protects against SignalDone calling into Destroy. + bool destroy_called_ = false; + scoped_refptr state_; + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker sequence_checker_; +}; + +} // namespace rtc #endif // RTC_BASE_ASYNC_RESOLVER_H_ diff --git a/rtc_base/bind.h b/rtc_base/bind.h deleted file mode 100644 index b61d189f7a..0000000000 --- a/rtc_base/bind.h +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Copyright 2012 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -// Bind() is an overloaded function that converts method calls into function -// objects (aka functors). The method object is captured as a scoped_refptr<> if -// possible, and as a raw pointer otherwise. Any arguments to the method are -// captured by value. The return value of Bind is a stateful, nullary function -// object. Care should be taken about the lifetime of objects captured by -// Bind(); the returned functor knows nothing about the lifetime of a non -// ref-counted method object or any arguments passed by pointer, and calling the -// functor with a destroyed object will surely do bad things. -// -// To prevent the method object from being captured as a scoped_refptr<>, you -// can use Unretained. But this should only be done when absolutely necessary, -// and when the caller knows the extra reference isn't needed. -// -// Example usage: -// struct Foo { -// int Test1() { return 42; } -// int Test2() const { return 52; } -// int Test3(int x) { return x*x; } -// float Test4(int x, float y) { return x + y; } -// }; -// -// int main() { -// Foo foo; -// cout << rtc::Bind(&Foo::Test1, &foo)() << endl; -// cout << rtc::Bind(&Foo::Test2, &foo)() << endl; -// cout << rtc::Bind(&Foo::Test3, &foo, 3)() << endl; -// cout << rtc::Bind(&Foo::Test4, &foo, 7, 8.5f)() << endl; -// } -// -// Example usage of ref counted objects: -// struct Bar { -// int AddRef(); -// int Release(); -// -// void Test() {} -// void BindThis() { -// // The functor passed to AsyncInvoke() will keep this object alive. 
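Editor's note (sketch, not part of the diff): the bind.h documentation being deleted around this point explains that rtc::Bind() captured ref-counted receivers as a scoped_refptr and all arguments by value; with the header removed elsewhere in this change in favor of lambdas (see the Flush() hunk in async_invoker.cc above), the equivalent spelling makes those captures explicit. A small WebRTC-free illustration, with std::shared_ptr standing in for the ref-counted receiver; Counter is hypothetical:

#include <functional>
#include <memory>

struct Counter {  // hypothetical receiver
  int value = 0;
  int Add(int amount) { return value += amount; }
};

// Roughly what rtc::Bind(&Counter::Add, counter, 3) used to build: a nullary
// functor that keeps its receiver alive and holds the argument by value.
std::function<int()> MakeAddTask(std::shared_ptr<Counter> counter) {
  return [counter, amount = 3] { return counter->Add(amount); };
}

// Usage:
//   auto task = MakeAddTask(std::make_shared<Counter>());
//   int value = task();  // 3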
-// invoker.AsyncInvoke(RTC_FROM_HERE,rtc::Bind(&Bar::Test, this)); -// } -// }; -// -// int main() { -// rtc::scoped_refptr bar = new rtc::RefCountedObject(); -// auto functor = rtc::Bind(&Bar::Test, bar); -// bar = nullptr; -// // The functor stores an internal scoped_refptr, so this is safe. -// functor(); -// } -// - -#ifndef RTC_BASE_BIND_H_ -#define RTC_BASE_BIND_H_ - -#include -#include - -#include "api/scoped_refptr.h" - -#define NONAME - -namespace rtc { -namespace detail { -// This is needed because the template parameters in Bind can't be resolved -// if they're used both as parameters of the function pointer type and as -// parameters to Bind itself: the function pointer parameters are exact -// matches to the function prototype, but the parameters to bind have -// references stripped. This trick allows the compiler to dictate the Bind -// parameter types rather than deduce them. -template -struct identity { - typedef T type; -}; - -// IsRefCounted::value will be true for types that can be used in -// rtc::scoped_refptr, i.e. types that implements nullary functions AddRef() -// and Release(), regardless of their return types. AddRef() and Release() can -// be defined in T or any superclass of T. -template -class IsRefCounted { - // This is a complex implementation detail done with SFINAE. - - // Define types such that sizeof(Yes) != sizeof(No). - struct Yes { - char dummy[1]; - }; - struct No { - char dummy[2]; - }; - // Define two overloaded template functions with return types of different - // size. This way, we can use sizeof() on the return type to determine which - // function the compiler would have chosen. One function will be preferred - // over the other if it is possible to create it without compiler errors, - // otherwise the compiler will simply remove it, and default to the less - // preferred function. - template - static Yes test(R* r, decltype(r->AddRef(), r->Release(), 42)); - template - static No test(...); - - public: - // Trick the compiler to tell if it's possible to call AddRef() and Release(). - static const bool value = sizeof(test((T*)nullptr, 42)) == sizeof(Yes); -}; - -// TernaryTypeOperator is a helper class to select a type based on a static bool -// value. -template -struct TernaryTypeOperator {}; - -template -struct TernaryTypeOperator { - typedef IfTrueT type; -}; - -template -struct TernaryTypeOperator { - typedef IfFalseT type; -}; - -// PointerType::type will be scoped_refptr for ref counted types, and T* -// otherwise. -template -struct PointerType { - typedef typename TernaryTypeOperator::value, - scoped_refptr, - T*>::type type; -}; - -template -class UnretainedWrapper { - public: - explicit UnretainedWrapper(T* o) : ptr_(o) {} - T* get() const { return ptr_; } - - private: - T* ptr_; -}; - -} // namespace detail - -template -static inline detail::UnretainedWrapper Unretained(T* o) { - return detail::UnretainedWrapper(o); -} - -template -class MethodFunctor { - public: - MethodFunctor(MethodT method, ObjectT* object, Args... args) - : method_(method), object_(object), args_(args...) {} - R operator()() const { - return CallMethod(std::index_sequence_for()); - } - - private: - template - R CallMethod(std::index_sequence) const { - return (object_->*method_)(std::get(args_)...); - } - - MethodT method_; - typename detail::PointerType::type object_; - typename std::tuple::type...> args_; -}; - -template -class UnretainedMethodFunctor { - public: - UnretainedMethodFunctor(MethodT method, - detail::UnretainedWrapper object, - Args... 
args) - : method_(method), object_(object.get()), args_(args...) {} - R operator()() const { - return CallMethod(std::index_sequence_for()); - } - - private: - template - R CallMethod(std::index_sequence) const { - return (object_->*method_)(std::get(args_)...); - } - - MethodT method_; - ObjectT* object_; - typename std::tuple::type...> args_; -}; - -template -class Functor { - public: - Functor(const FunctorT& functor, Args... args) - : functor_(functor), args_(args...) {} - R operator()() const { - return CallFunction(std::index_sequence_for()); - } - - private: - template - R CallFunction(std::index_sequence) const { - return functor_(std::get(args_)...); - } - - FunctorT functor_; - typename std::tuple::type...> args_; -}; - -#define FP_T(x) R (ObjectT::*x)(Args...) - -template -MethodFunctor Bind( - FP_T(method), - ObjectT* object, - typename detail::identity::type... args) { - return MethodFunctor(method, object, - args...); -} - -template -MethodFunctor Bind( - FP_T(method), - const scoped_refptr& object, - typename detail::identity::type... args) { - return MethodFunctor(method, object.get(), - args...); -} - -template -UnretainedMethodFunctor Bind( - FP_T(method), - detail::UnretainedWrapper object, - typename detail::identity::type... args) { - return UnretainedMethodFunctor( - method, object, args...); -} - -#undef FP_T -#define FP_T(x) R (ObjectT::*x)(Args...) const - -template -MethodFunctor Bind( - FP_T(method), - const ObjectT* object, - typename detail::identity::type... args) { - return MethodFunctor(method, object, - args...); -} -template -UnretainedMethodFunctor Bind( - FP_T(method), - detail::UnretainedWrapper object, - typename detail::identity::type... args) { - return UnretainedMethodFunctor( - method, object, args...); -} - -#undef FP_T -#define FP_T(x) R (*x)(Args...) - -template -Functor Bind( - FP_T(function), - typename detail::identity::type... args) { - return Functor(function, args...); -} - -#undef FP_T - -} // namespace rtc - -#undef NONAME - -#endif // RTC_BASE_BIND_H_ diff --git a/rtc_base/bind_unittest.cc b/rtc_base/bind_unittest.cc deleted file mode 100644 index 664cb54500..0000000000 --- a/rtc_base/bind_unittest.cc +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright 2004 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "rtc_base/bind.h" - -#include - -#include "rtc_base/ref_count.h" -#include "rtc_base/ref_counted_object.h" -#include "test/gtest.h" - -namespace rtc { - -namespace { - -struct LifeTimeCheck; - -struct MethodBindTester { - void NullaryVoid() { ++call_count; } - int NullaryInt() { - ++call_count; - return 1; - } - int NullaryConst() const { - ++call_count; - return 2; - } - void UnaryVoid(int dummy) { ++call_count; } - template - T Identity(T value) { - ++call_count; - return value; - } - int UnaryByPointer(int* value) const { - ++call_count; - return ++(*value); - } - int UnaryByRef(const int& value) const { - ++call_count; - return ++const_cast(value); - } - int Multiply(int a, int b) const { - ++call_count; - return a * b; - } - void RefArgument(const scoped_refptr& object) { - EXPECT_TRUE(object.get() != nullptr); - } - - mutable int call_count; -}; - -struct A { - int dummy; -}; -struct B : public RefCountInterface { - int dummy; -}; -struct C : public A, B {}; -struct D { - int AddRef(); -}; -struct E : public D { - int Release(); -}; -struct F { - void AddRef(); - void Release(); -}; - -struct LifeTimeCheck { - LifeTimeCheck() : ref_count_(0) {} - void AddRef() { ++ref_count_; } - void Release() { --ref_count_; } - void NullaryVoid() {} - int ref_count_; -}; - -int Return42() { - return 42; -} -int Negate(int a) { - return -a; -} -int Multiply(int a, int b) { - return a * b; -} - -} // namespace - -// Try to catch any problem with scoped_refptr type deduction in rtc::Bind at -// compile time. -#define EXPECT_IS_CAPTURED_AS_PTR(T) \ - static_assert(std::is_same::type, T*>::value, \ - "PointerType") -#define EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(T) \ - static_assert( \ - std::is_same::type, scoped_refptr>::value, \ - "PointerType") - -EXPECT_IS_CAPTURED_AS_PTR(void); -EXPECT_IS_CAPTURED_AS_PTR(int); -EXPECT_IS_CAPTURED_AS_PTR(double); -EXPECT_IS_CAPTURED_AS_PTR(A); -EXPECT_IS_CAPTURED_AS_PTR(D); -EXPECT_IS_CAPTURED_AS_PTR(RefCountInterface*); -EXPECT_IS_CAPTURED_AS_PTR( - decltype(Unretained>)); - -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(RefCountInterface); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(B); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(C); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(E); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(F); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(RefCountedObject); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(RefCountedObject); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(RefCountedObject); -EXPECT_IS_CAPTURED_AS_SCOPED_REFPTR(const RefCountedObject); - -TEST(BindTest, BindToMethod) { - MethodBindTester object = {0}; - EXPECT_EQ(0, object.call_count); - Bind(&MethodBindTester::NullaryVoid, &object)(); - EXPECT_EQ(1, object.call_count); - EXPECT_EQ(1, Bind(&MethodBindTester::NullaryInt, &object)()); - EXPECT_EQ(2, object.call_count); - EXPECT_EQ(2, Bind(&MethodBindTester::NullaryConst, - static_cast(&object))()); - EXPECT_EQ(3, object.call_count); - Bind(&MethodBindTester::UnaryVoid, &object, 5)(); - EXPECT_EQ(4, object.call_count); - EXPECT_EQ(100, Bind(&MethodBindTester::Identity, &object, 100)()); - EXPECT_EQ(5, object.call_count); - const std::string string_value("test string"); - EXPECT_EQ(string_value, Bind(&MethodBindTester::Identity, - &object, string_value)()); - EXPECT_EQ(6, object.call_count); - int value = 11; - // Bind binds by value, even if the method signature is by reference, so - // "reference" binds require pointers. 
- EXPECT_EQ(12, Bind(&MethodBindTester::UnaryByPointer, &object, &value)()); - EXPECT_EQ(12, value); - EXPECT_EQ(7, object.call_count); - // It's possible to bind to a function that takes a const reference, though - // the capture will be a copy. See UnaryByRef hackery above where it removes - // the const to make sure the underlying storage is, in fact, a copy. - EXPECT_EQ(13, Bind(&MethodBindTester::UnaryByRef, &object, value)()); - // But the original value is unmodified. - EXPECT_EQ(12, value); - EXPECT_EQ(8, object.call_count); - EXPECT_EQ(56, Bind(&MethodBindTester::Multiply, &object, 7, 8)()); - EXPECT_EQ(9, object.call_count); -} - -TEST(BindTest, BindToFunction) { - EXPECT_EQ(42, Bind(&Return42)()); - EXPECT_EQ(3, Bind(&Negate, -3)()); - EXPECT_EQ(56, Bind(&Multiply, 8, 7)()); -} - -// Test Bind where method object implements RefCountInterface and is passed as a -// pointer. -TEST(BindTest, CapturePointerAsScopedRefPtr) { - LifeTimeCheck object; - EXPECT_EQ(object.ref_count_, 0); - scoped_refptr scoped_object(&object); - EXPECT_EQ(object.ref_count_, 1); - { - auto functor = Bind(&LifeTimeCheck::NullaryVoid, &object); - EXPECT_EQ(object.ref_count_, 2); - scoped_object = nullptr; - EXPECT_EQ(object.ref_count_, 1); - } - EXPECT_EQ(object.ref_count_, 0); -} - -// Test Bind where method object implements RefCountInterface and is passed as a -// scoped_refptr<>. -TEST(BindTest, CaptureScopedRefPtrAsScopedRefPtr) { - LifeTimeCheck object; - EXPECT_EQ(object.ref_count_, 0); - scoped_refptr scoped_object(&object); - EXPECT_EQ(object.ref_count_, 1); - { - auto functor = Bind(&LifeTimeCheck::NullaryVoid, scoped_object); - EXPECT_EQ(object.ref_count_, 2); - scoped_object = nullptr; - EXPECT_EQ(object.ref_count_, 1); - } - EXPECT_EQ(object.ref_count_, 0); -} - -// Test Bind where method object is captured as scoped_refptr<> and the functor -// dies while there are references left. -TEST(BindTest, FunctorReleasesObjectOnDestruction) { - LifeTimeCheck object; - EXPECT_EQ(object.ref_count_, 0); - scoped_refptr scoped_object(&object); - EXPECT_EQ(object.ref_count_, 1); - Bind(&LifeTimeCheck::NullaryVoid, &object)(); - EXPECT_EQ(object.ref_count_, 1); - scoped_object = nullptr; - EXPECT_EQ(object.ref_count_, 0); -} - -// Test Bind with scoped_refptr<> argument. -TEST(BindTest, ScopedRefPointerArgument) { - LifeTimeCheck object; - EXPECT_EQ(object.ref_count_, 0); - scoped_refptr scoped_object(&object); - EXPECT_EQ(object.ref_count_, 1); - { - MethodBindTester bind_tester; - auto functor = - Bind(&MethodBindTester::RefArgument, &bind_tester, scoped_object); - EXPECT_EQ(object.ref_count_, 2); - } - EXPECT_EQ(object.ref_count_, 1); - scoped_object = nullptr; - EXPECT_EQ(object.ref_count_, 0); -} - -namespace { - -const int* Ref(const int& a) { - return &a; -} - -} // anonymous namespace - -// Test Bind with non-scoped_refptr<> reference argument, which should be -// modified to a non-reference capture. -TEST(BindTest, RefArgument) { - const int x = 42; - EXPECT_EQ(&x, Ref(x)); - // Bind() should make a copy of |x|, i.e. the pointers should be different. 
- auto functor = Bind(&Ref, x); - EXPECT_NE(&x, functor()); -} - -} // namespace rtc diff --git a/rtc_base/bit_buffer.cc b/rtc_base/bit_buffer.cc index 540141fe52..d212ef5637 100644 --- a/rtc_base/bit_buffer.cc +++ b/rtc_base/bit_buffer.cc @@ -83,36 +83,36 @@ uint64_t BitBuffer::RemainingBitCount() const { return (static_cast(byte_count_) - byte_offset_) * 8 - bit_offset_; } -bool BitBuffer::ReadUInt8(uint8_t* val) { +bool BitBuffer::ReadUInt8(uint8_t& val) { uint32_t bit_val; - if (!ReadBits(&bit_val, sizeof(uint8_t) * 8)) { + if (!ReadBits(sizeof(uint8_t) * 8, bit_val)) { return false; } RTC_DCHECK(bit_val <= std::numeric_limits::max()); - *val = static_cast(bit_val); + val = static_cast(bit_val); return true; } -bool BitBuffer::ReadUInt16(uint16_t* val) { +bool BitBuffer::ReadUInt16(uint16_t& val) { uint32_t bit_val; - if (!ReadBits(&bit_val, sizeof(uint16_t) * 8)) { + if (!ReadBits(sizeof(uint16_t) * 8, bit_val)) { return false; } RTC_DCHECK(bit_val <= std::numeric_limits::max()); - *val = static_cast(bit_val); + val = static_cast(bit_val); return true; } -bool BitBuffer::ReadUInt32(uint32_t* val) { - return ReadBits(val, sizeof(uint32_t) * 8); +bool BitBuffer::ReadUInt32(uint32_t& val) { + return ReadBits(sizeof(uint32_t) * 8, val); } -bool BitBuffer::PeekBits(uint32_t* val, size_t bit_count) { +bool BitBuffer::PeekBits(size_t bit_count, uint32_t& val) { // TODO(nisse): Could allow bit_count == 0 and always return success. But // current code reads one byte beyond end of buffer in the case that // RemainingBitCount() == 0 and bit_count == 0. RTC_DCHECK(bit_count > 0); - if (!val || bit_count > RemainingBitCount() || bit_count > 32) { + if (bit_count > RemainingBitCount() || bit_count > 32) { return false; } const uint8_t* bytes = bytes_ + byte_offset_; @@ -121,7 +121,7 @@ bool BitBuffer::PeekBits(uint32_t* val, size_t bit_count) { // If we're reading fewer bits than what's left in the current byte, just // return the portion of this byte that we need. if (bit_count < remaining_bits_in_current_byte) { - *val = HighestBits(bits, bit_offset_ + bit_count); + val = HighestBits(bits, bit_offset_ + bit_count); return true; } // Otherwise, subtract what we've read from the bit count and read as many @@ -137,12 +137,50 @@ bool BitBuffer::PeekBits(uint32_t* val, size_t bit_count) { bits <<= bit_count; bits |= HighestBits(*bytes, bit_count); } - *val = bits; + val = bits; return true; } -bool BitBuffer::ReadBits(uint32_t* val, size_t bit_count) { - return PeekBits(val, bit_count) && ConsumeBits(bit_count); +bool BitBuffer::PeekBits(size_t bit_count, uint64_t& val) { + // TODO(nisse): Could allow bit_count == 0 and always return success. But + // current code reads one byte beyond end of buffer in the case that + // RemainingBitCount() == 0 and bit_count == 0. + RTC_DCHECK(bit_count > 0); + if (bit_count > RemainingBitCount() || bit_count > 64) { + return false; + } + const uint8_t* bytes = bytes_ + byte_offset_; + size_t remaining_bits_in_current_byte = 8 - bit_offset_; + uint64_t bits = LowestBits(*bytes++, remaining_bits_in_current_byte); + // If we're reading fewer bits than what's left in the current byte, just + // return the portion of this byte that we need. + if (bit_count < remaining_bits_in_current_byte) { + val = HighestBits(bits, bit_offset_ + bit_count); + return true; + } + // Otherwise, subtract what we've read from the bit count and read as many + // full bytes as we can into bits. 
+ bit_count -= remaining_bits_in_current_byte; + while (bit_count >= 8) { + bits = (bits << 8) | *bytes++; + bit_count -= 8; + } + // Whatever we have left is smaller than a byte, so grab just the bits we need + // and shift them into the lowest bits. + if (bit_count > 0) { + bits <<= bit_count; + bits |= HighestBits(*bytes, bit_count); + } + val = bits; + return true; +} + +bool BitBuffer::ReadBits(size_t bit_count, uint32_t& val) { + return PeekBits(bit_count, val) && ConsumeBits(bit_count); +} + +bool BitBuffer::ReadBits(size_t bit_count, uint64_t& val) { + return PeekBits(bit_count, val) && ConsumeBits(bit_count); } bool BitBuffer::ConsumeBytes(size_t byte_count) { @@ -159,39 +197,36 @@ bool BitBuffer::ConsumeBits(size_t bit_count) { return true; } -bool BitBuffer::ReadNonSymmetric(uint32_t* val, uint32_t num_values) { +bool BitBuffer::ReadNonSymmetric(uint32_t num_values, uint32_t& val) { RTC_DCHECK_GT(num_values, 0); RTC_DCHECK_LE(num_values, uint32_t{1} << 31); if (num_values == 1) { // When there is only one possible value, it requires zero bits to store it. // But ReadBits doesn't support reading zero bits. - *val = 0; + val = 0; return true; } size_t count_bits = CountBits(num_values); uint32_t num_min_bits_values = (uint32_t{1} << count_bits) - num_values; - if (!ReadBits(val, count_bits - 1)) { + if (!ReadBits(count_bits - 1, val)) { return false; } - if (*val < num_min_bits_values) { + if (val < num_min_bits_values) { return true; } uint32_t extra_bit; - if (!ReadBits(&extra_bit, /*bit_count=*/1)) { + if (!ReadBits(/*bit_count=*/1, extra_bit)) { return false; } - *val = (*val << 1) + extra_bit - num_min_bits_values; + val = (val << 1) + extra_bit - num_min_bits_values; return true; } -bool BitBuffer::ReadExponentialGolomb(uint32_t* val) { - if (!val) { - return false; - } +bool BitBuffer::ReadExponentialGolomb(uint32_t& val) { // Store off the current byte/bit offset, in case we want to restore them due // to a failed parse. size_t original_byte_offset = byte_offset_; @@ -200,35 +235,35 @@ bool BitBuffer::ReadExponentialGolomb(uint32_t* val) { // Count the number of leading 0 bits by peeking/consuming them one at a time. size_t zero_bit_count = 0; uint32_t peeked_bit; - while (PeekBits(&peeked_bit, 1) && peeked_bit == 0) { + while (PeekBits(1, peeked_bit) && peeked_bit == 0) { zero_bit_count++; ConsumeBits(1); } // We should either be at the end of the stream, or the next bit should be 1. - RTC_DCHECK(!PeekBits(&peeked_bit, 1) || peeked_bit == 1); + RTC_DCHECK(!PeekBits(1, peeked_bit) || peeked_bit == 1); // The bit count of the value is the number of zeros + 1. Make sure that many // bits fits in a uint32_t and that we have enough bits left for it, and then // read the value. 
size_t value_bit_count = zero_bit_count + 1; - if (value_bit_count > 32 || !ReadBits(val, value_bit_count)) { + if (value_bit_count > 32 || !ReadBits(value_bit_count, val)) { RTC_CHECK(Seek(original_byte_offset, original_bit_offset)); return false; } - *val -= 1; + val -= 1; return true; } -bool BitBuffer::ReadSignedExponentialGolomb(int32_t* val) { +bool BitBuffer::ReadSignedExponentialGolomb(int32_t& val) { uint32_t unsigned_val; - if (!ReadExponentialGolomb(&unsigned_val)) { + if (!ReadExponentialGolomb(unsigned_val)) { return false; } if ((unsigned_val & 1) == 0) { - *val = -static_cast(unsigned_val / 2); + val = -static_cast(unsigned_val / 2); } else { - *val = (unsigned_val + 1) / 2; + val = (unsigned_val + 1) / 2; } return true; } diff --git a/rtc_base/bit_buffer.h b/rtc_base/bit_buffer.h index de7bf02d56..388218e698 100644 --- a/rtc_base/bit_buffer.h +++ b/rtc_base/bit_buffer.h @@ -14,6 +14,7 @@ #include // For size_t. #include // For integer types. +#include "absl/base/attributes.h" #include "rtc_base/constructor_magic.h" namespace rtc { @@ -38,18 +39,35 @@ class BitBuffer { // Reads byte-sized values from the buffer. Returns false if there isn't // enough data left for the specified type. - bool ReadUInt8(uint8_t* val); - bool ReadUInt16(uint16_t* val); - bool ReadUInt32(uint32_t* val); + bool ReadUInt8(uint8_t& val); + bool ReadUInt16(uint16_t& val); + bool ReadUInt32(uint32_t& val); + ABSL_DEPRECATED("") bool ReadUInt8(uint8_t* val) { + return val ? ReadUInt8(*val) : false; + } + ABSL_DEPRECATED("") bool ReadUInt16(uint16_t* val) { + return val ? ReadUInt16(*val) : false; + } + ABSL_DEPRECATED("") bool ReadUInt32(uint32_t* val) { + return val ? ReadUInt32(*val) : false; + } // Reads bit-sized values from the buffer. Returns false if there isn't enough // data left for the specified bit count. - bool ReadBits(uint32_t* val, size_t bit_count); + bool ReadBits(size_t bit_count, uint32_t& val); + bool ReadBits(size_t bit_count, uint64_t& val); + ABSL_DEPRECATED("") bool ReadBits(uint32_t* val, size_t bit_count) { + return val ? ReadBits(bit_count, *val) : false; + } // Peeks bit-sized values from the buffer. Returns false if there isn't enough // data left for the specified number of bits. Doesn't move the current // offset. - bool PeekBits(uint32_t* val, size_t bit_count); + bool PeekBits(size_t bit_count, uint32_t& val); + bool PeekBits(size_t bit_count, uint64_t& val); + ABSL_DEPRECATED("") bool PeekBits(uint32_t* val, size_t bit_count) { + return val ? PeekBits(bit_count, *val) : false; + } // Reads value in range [0, num_values - 1]. // This encoding is similar to ReadBits(val, Ceil(Log2(num_values)), @@ -61,7 +79,11 @@ class BitBuffer { // Value v in range [k, num_values - 1] is encoded as (v+k) in n bits. // https://aomediacodec.github.io/av1-spec/#nsn // Returns false if there isn't enough data left. - bool ReadNonSymmetric(uint32_t* val, uint32_t num_values); + bool ReadNonSymmetric(uint32_t num_values, uint32_t& val); + ABSL_DEPRECATED("") + bool ReadNonSymmetric(uint32_t* val, uint32_t num_values) { + return val ? ReadNonSymmetric(num_values, *val) : false; + } // Reads the exponential golomb encoded value at the current offset. // Exponential golomb values are encoded as: @@ -71,11 +93,18 @@ class BitBuffer { // and increment the result by 1. // Returns false if there isn't enough data left for the specified type, or if // the value wouldn't fit in a uint32_t. 
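For call sites migrating off the pointer overloads, here is a minimal sketch of the reference-based readers introduced in this hunk (the payload bytes are illustrative and not taken from the tests):

  const uint8_t bytes[] = {0x4D, 0x32};  // Illustrative two-byte payload.
  rtc::BitBuffer buffer(bytes, sizeof(bytes));
  uint32_t bits = 0;
  uint64_t wide = 0;
  uint32_t golomb = 0;
  // The new overloads take the bit count first and write through a reference;
  // the old pointer overloads remain as ABSL_DEPRECATED forwarding wrappers.
  bool ok = buffer.ReadBits(3, bits) &&            // previously ReadBits(&bits, 3)
            buffer.PeekBits(13, wide) &&           // new 64-bit capable overload
            buffer.ReadExponentialGolomb(golomb);  // previously took a pointer
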
- bool ReadExponentialGolomb(uint32_t* val); + bool ReadExponentialGolomb(uint32_t& val); + ABSL_DEPRECATED("") bool ReadExponentialGolomb(uint32_t* val) { + return val ? ReadExponentialGolomb(*val) : false; + } + // Reads signed exponential golomb values at the current offset. Signed // exponential golomb values are just the unsigned values mapped to the // sequence 0, 1, -1, 2, -2, etc. in order. - bool ReadSignedExponentialGolomb(int32_t* val); + bool ReadSignedExponentialGolomb(int32_t& val); + ABSL_DEPRECATED("") bool ReadSignedExponentialGolomb(int32_t* val) { + return val ? ReadSignedExponentialGolomb(*val) : false; + } // Moves current position |byte_count| bytes forward. Returns false if // there aren't enough bytes left in the buffer. diff --git a/rtc_base/bit_buffer_unittest.cc b/rtc_base/bit_buffer_unittest.cc index 656682c2ef..e6bb4270c7 100644 --- a/rtc_base/bit_buffer_unittest.cc +++ b/rtc_base/bit_buffer_unittest.cc @@ -49,13 +49,13 @@ TEST(BitBufferTest, ReadBytesAligned) { uint16_t val16; uint32_t val32; BitBuffer buffer(bytes, 8); - EXPECT_TRUE(buffer.ReadUInt8(&val8)); + EXPECT_TRUE(buffer.ReadUInt8(val8)); EXPECT_EQ(0x0Au, val8); - EXPECT_TRUE(buffer.ReadUInt8(&val8)); + EXPECT_TRUE(buffer.ReadUInt8(val8)); EXPECT_EQ(0xBCu, val8); - EXPECT_TRUE(buffer.ReadUInt16(&val16)); + EXPECT_TRUE(buffer.ReadUInt16(val16)); EXPECT_EQ(0xDEF1u, val16); - EXPECT_TRUE(buffer.ReadUInt32(&val32)); + EXPECT_TRUE(buffer.ReadUInt32(val32)); EXPECT_EQ(0x23456789u, val32); } @@ -68,13 +68,13 @@ TEST(BitBufferTest, ReadBytesOffset4) { BitBuffer buffer(bytes, 9); EXPECT_TRUE(buffer.ConsumeBits(4)); - EXPECT_TRUE(buffer.ReadUInt8(&val8)); + EXPECT_TRUE(buffer.ReadUInt8(val8)); EXPECT_EQ(0xABu, val8); - EXPECT_TRUE(buffer.ReadUInt8(&val8)); + EXPECT_TRUE(buffer.ReadUInt8(val8)); EXPECT_EQ(0xCDu, val8); - EXPECT_TRUE(buffer.ReadUInt16(&val16)); + EXPECT_TRUE(buffer.ReadUInt16(val16)); EXPECT_EQ(0xEF12u, val16); - EXPECT_TRUE(buffer.ReadUInt32(&val32)); + EXPECT_TRUE(buffer.ReadUInt32(val32)); EXPECT_EQ(0x34567890u, val32); } @@ -102,15 +102,15 @@ TEST(BitBufferTest, ReadBytesOffset3) { uint32_t val32; BitBuffer buffer(bytes, 8); EXPECT_TRUE(buffer.ConsumeBits(3)); - EXPECT_TRUE(buffer.ReadUInt8(&val8)); + EXPECT_TRUE(buffer.ReadUInt8(val8)); EXPECT_EQ(0xFEu, val8); - EXPECT_TRUE(buffer.ReadUInt16(&val16)); + EXPECT_TRUE(buffer.ReadUInt16(val16)); EXPECT_EQ(0xDCBAu, val16); - EXPECT_TRUE(buffer.ReadUInt32(&val32)); + EXPECT_TRUE(buffer.ReadUInt32(val32)); EXPECT_EQ(0x98765432u, val32); // 5 bits left unread. Not enough to read a uint8_t. 
EXPECT_EQ(5u, buffer.RemainingBitCount()); - EXPECT_FALSE(buffer.ReadUInt8(&val8)); + EXPECT_FALSE(buffer.ReadUInt8(val8)); } TEST(BitBufferTest, ReadBits) { @@ -120,26 +120,58 @@ TEST(BitBufferTest, ReadBits) { const uint8_t bytes[] = {0x4D, 0x32}; uint32_t val; BitBuffer buffer(bytes, 2); - EXPECT_TRUE(buffer.ReadBits(&val, 3)); + EXPECT_TRUE(buffer.ReadBits(3, val)); // 0b010 EXPECT_EQ(0x2u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 2)); + EXPECT_TRUE(buffer.ReadBits(2, val)); // 0b01 EXPECT_EQ(0x1u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 7)); + EXPECT_TRUE(buffer.ReadBits(7, val)); // 0b1010011 EXPECT_EQ(0x53u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 2)); + EXPECT_TRUE(buffer.ReadBits(2, val)); // 0b00 EXPECT_EQ(0x0u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 1)); + EXPECT_TRUE(buffer.ReadBits(1, val)); // 0b1 EXPECT_EQ(0x1u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 1)); + EXPECT_TRUE(buffer.ReadBits(1, val)); // 0b0 EXPECT_EQ(0x0u, val); - EXPECT_FALSE(buffer.ReadBits(&val, 1)); + EXPECT_FALSE(buffer.ReadBits(1, val)); +} + +TEST(BitBufferTest, ReadBits64) { + const uint8_t bytes[] = {0x4D, 0x32, 0xAB, 0x54, 0x00, 0xFF, 0xFE, 0x01, + 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89}; + BitBuffer buffer(bytes, 16); + uint64_t val; + + // Peek and read first 33 bits. + EXPECT_TRUE(buffer.PeekBits(33, val)); + EXPECT_EQ(0x4D32AB5400FFFE01ull >> (64 - 33), val); + val = 0; + EXPECT_TRUE(buffer.ReadBits(33, val)); + EXPECT_EQ(0x4D32AB5400FFFE01ull >> (64 - 33), val); + + // Peek and read next 31 bits. + constexpr uint64_t kMask31Bits = (1ull << 32) - 1; + EXPECT_TRUE(buffer.PeekBits(31, val)); + EXPECT_EQ(0x4D32AB5400FFFE01ull & kMask31Bits, val); + val = 0; + EXPECT_TRUE(buffer.ReadBits(31, val)); + EXPECT_EQ(0x4D32AB5400FFFE01ull & kMask31Bits, val); + + // Peek and read remaining 64 bits. + EXPECT_TRUE(buffer.PeekBits(64, val)); + EXPECT_EQ(0xABCDEF0123456789ull, val); + val = 0; + EXPECT_TRUE(buffer.ReadBits(64, val)); + EXPECT_EQ(0xABCDEF0123456789ull, val); + + // Nothing more to read. 
+ EXPECT_FALSE(buffer.ReadBits(1, val)); } TEST(BitBufferDeathTest, SetOffsetValues) { @@ -187,10 +219,10 @@ TEST(BitBufferTest, ReadNonSymmetricSameNumberOfBitsWhenNumValuesPowerOf2) { uint32_t values[4]; ASSERT_EQ(reader.RemainingBitCount(), 16u); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[0], /*num_values=*/1 << 4)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[1], /*num_values=*/1 << 4)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[2], /*num_values=*/1 << 4)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[3], /*num_values=*/1 << 4)); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/1 << 4, values[0])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/1 << 4, values[1])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/1 << 4, values[2])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/1 << 4, values[3])); ASSERT_EQ(reader.RemainingBitCount(), 0u); EXPECT_THAT(values, ElementsAre(0xf, 0x3, 0xa, 0x0)); @@ -244,12 +276,12 @@ TEST(BitBufferWriterTest, NonSymmetricReadsMatchesWrites) { rtc::BitBuffer reader(bytes, 2); uint32_t values[6]; - EXPECT_TRUE(reader.ReadNonSymmetric(&values[0], /*num_values=*/6)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[1], /*num_values=*/6)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[2], /*num_values=*/6)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[3], /*num_values=*/6)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[4], /*num_values=*/6)); - EXPECT_TRUE(reader.ReadNonSymmetric(&values[5], /*num_values=*/6)); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/6, values[0])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/6, values[1])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/6, values[2])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/6, values[3])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/6, values[4])); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/6, values[5])); EXPECT_THAT(values, ElementsAre(0, 1, 2, 3, 4, 5)); } @@ -260,7 +292,7 @@ TEST(BitBufferTest, ReadNonSymmetricOnlyValueConsumesNoBits) { uint32_t value = 0xFFFFFFFF; ASSERT_EQ(reader.RemainingBitCount(), 16u); - EXPECT_TRUE(reader.ReadNonSymmetric(&value, /*num_values=*/1)); + EXPECT_TRUE(reader.ReadNonSymmetric(/*num_values=*/1, value)); EXPECT_EQ(value, 0u); EXPECT_EQ(reader.RemainingBitCount(), 16u); @@ -302,7 +334,7 @@ TEST(BitBufferTest, GolombUint32Values) { byteBuffer.WriteUInt64(encoded_val); uint32_t decoded_val; EXPECT_TRUE(buffer.Seek(0, 0)); - EXPECT_TRUE(buffer.ReadExponentialGolomb(&decoded_val)); + EXPECT_TRUE(buffer.ReadExponentialGolomb(decoded_val)); EXPECT_EQ(i, decoded_val); } } @@ -319,7 +351,7 @@ TEST(BitBufferTest, SignedGolombValues) { for (size_t i = 0; i < sizeof(golomb_bits); ++i) { BitBuffer buffer(&golomb_bits[i], 1); int32_t decoded_val; - ASSERT_TRUE(buffer.ReadSignedExponentialGolomb(&decoded_val)); + ASSERT_TRUE(buffer.ReadSignedExponentialGolomb(decoded_val)); EXPECT_EQ(expected[i], decoded_val) << "Mismatch in expected/decoded value for golomb_bits[" << i << "]: " << static_cast(golomb_bits[i]); @@ -332,13 +364,13 @@ TEST(BitBufferTest, NoGolombOverread) { // If it didn't, the above buffer would be valid at 3 bytes. 
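To make the non-symmetric coding exercised by the tests above concrete, here is a worked decode assuming num_values = 6: CountBits(6) = 3, and (1 << 3) - 6 = 2 values fit in the short 2-bit form, so 0 -> 00, 1 -> 01, 2 -> 100, 3 -> 101, 4 -> 110, 5 -> 111. A small illustrative sketch using the new reference-based signature (the packed byte is invented for this example):

  // 0b01'100'110 packs the values 1, 2 and 4 into one byte.
  const uint8_t packed[] = {0b01100110};
  rtc::BitBuffer reader(packed, 1);
  uint32_t v1, v2, v3;
  bool ok = reader.ReadNonSymmetric(/*num_values=*/6, v1) &&  // v1 == 1 (2 bits)
            reader.ReadNonSymmetric(/*num_values=*/6, v2) &&  // v2 == 2 (3 bits)
            reader.ReadNonSymmetric(/*num_values=*/6, v3);    // v3 == 4 (3 bits)
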
BitBuffer buffer(bytes, 1); uint32_t decoded_val; - EXPECT_FALSE(buffer.ReadExponentialGolomb(&decoded_val)); + EXPECT_FALSE(buffer.ReadExponentialGolomb(decoded_val)); BitBuffer longer_buffer(bytes, 2); - EXPECT_FALSE(longer_buffer.ReadExponentialGolomb(&decoded_val)); + EXPECT_FALSE(longer_buffer.ReadExponentialGolomb(decoded_val)); BitBuffer longest_buffer(bytes, 3); - EXPECT_TRUE(longest_buffer.ReadExponentialGolomb(&decoded_val)); + EXPECT_TRUE(longest_buffer.ReadExponentialGolomb(decoded_val)); // Golomb should have read 9 bits, so 0x01FF, and since it is golomb, the // result is 0x01FF - 1 = 0x01FE. EXPECT_EQ(0x01FEu, decoded_val); @@ -360,20 +392,20 @@ TEST(BitBufferWriterTest, SymmetricReadWrite) { EXPECT_TRUE(buffer.Seek(0, 0)); uint32_t val; - EXPECT_TRUE(buffer.ReadBits(&val, 3)); + EXPECT_TRUE(buffer.ReadBits(3, val)); EXPECT_EQ(0x2u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 2)); + EXPECT_TRUE(buffer.ReadBits(2, val)); EXPECT_EQ(0x1u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 7)); + EXPECT_TRUE(buffer.ReadBits(7, val)); EXPECT_EQ(0x53u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 2)); + EXPECT_TRUE(buffer.ReadBits(2, val)); EXPECT_EQ(0x0u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 1)); + EXPECT_TRUE(buffer.ReadBits(1, val)); EXPECT_EQ(0x1u, val); - EXPECT_TRUE(buffer.ReadBits(&val, 17)); + EXPECT_TRUE(buffer.ReadBits(17, val)); EXPECT_EQ(0x1ABCDu, val); // And there should be nothing left. - EXPECT_FALSE(buffer.ReadBits(&val, 1)); + EXPECT_FALSE(buffer.ReadBits(1, val)); } TEST(BitBufferWriterTest, SymmetricBytesMisaligned) { @@ -390,11 +422,11 @@ TEST(BitBufferWriterTest, SymmetricBytesMisaligned) { uint8_t val8; uint16_t val16; uint32_t val32; - EXPECT_TRUE(buffer.ReadUInt8(&val8)); + EXPECT_TRUE(buffer.ReadUInt8(val8)); EXPECT_EQ(0x12u, val8); - EXPECT_TRUE(buffer.ReadUInt16(&val16)); + EXPECT_TRUE(buffer.ReadUInt16(val16)); EXPECT_EQ(0x3456u, val16); - EXPECT_TRUE(buffer.ReadUInt32(&val32)); + EXPECT_TRUE(buffer.ReadUInt32(val32)); EXPECT_EQ(0x789ABCDEu, val32); } @@ -408,7 +440,7 @@ TEST(BitBufferWriterTest, SymmetricGolomb) { buffer.Seek(0, 0); for (size_t i = 0; i < arraysize(test_string); ++i) { uint32_t val; - EXPECT_TRUE(buffer.ReadExponentialGolomb(&val)); + EXPECT_TRUE(buffer.ReadExponentialGolomb(val)); EXPECT_LE(val, std::numeric_limits::max()); EXPECT_EQ(test_string[i], static_cast(val)); } diff --git a/rtc_base/boringssl_certificate.cc b/rtc_base/boringssl_certificate.cc index 4e55cf398f..bb14036a3e 100644 --- a/rtc_base/boringssl_certificate.cc +++ b/rtc_base/boringssl_certificate.cc @@ -291,7 +291,7 @@ std::unique_ptr BoringSSLCertificate::FromPEMString( #define OID_MATCHES(oid, oid_other) \ (CBS_len(&oid) == sizeof(oid_other) && \ - 0 == memcmp(CBS_data(&oid), oid_other, sizeof(oid_other))) + 0 == memcmp(CBS_data(&oid), oid_other, sizeof(oid_other))) bool BoringSSLCertificate::GetSignatureDigestAlgorithm( std::string* algorithm) const { diff --git a/rtc_base/buffer_queue.h b/rtc_base/buffer_queue.h index 5895530969..09c6c4f734 100644 --- a/rtc_base/buffer_queue.h +++ b/rtc_base/buffer_queue.h @@ -16,9 +16,9 @@ #include #include +#include "api/sequence_checker.h" #include "rtc_base/buffer.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" diff --git a/rtc_base/callback.h b/rtc_base/callback.h deleted file mode 100644 index 47512214e3..0000000000 --- a/rtc_base/callback.h +++ /dev/null @@ -1,250 +0,0 @@ -// This file was 
GENERATED by command: -// pump.py callback.h.pump -// DO NOT EDIT BY HAND!!! - -/* - * Copyright 2012 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -// To generate callback.h from callback.h.pump, execute: -// ../third_party/googletest/src/googletest/scripts/pump.py callback.h.pump - -// Callbacks are callable object containers. They can hold a function pointer -// or a function object and behave like a value type. Internally, data is -// reference-counted, making copies and pass-by-value inexpensive. -// -// Callbacks are typed using template arguments. The format is: -// CallbackN -// where N is the number of arguments supplied to the callable object. -// Callbacks are invoked using operator(), just like a function or a function -// object. Default-constructed callbacks are "empty," and executing an empty -// callback does nothing. A callback can be made empty by assigning it from -// a default-constructed callback. -// -// Callbacks are similar in purpose to std::function (which isn't available on -// all platforms we support) and a lightweight alternative to sigslots. Since -// they effectively hide the type of the object they call, they're useful in -// breaking dependencies between objects that need to interact with one another. -// Notably, they can hold the results of Bind(), std::bind*, etc, without -// needing -// to know the resulting object type of those calls. -// -// Sigslots, on the other hand, provide a fuller feature set, such as multiple -// subscriptions to a signal, optional thread-safety, and lifetime tracking of -// slots. When these features are needed, choose sigslots. -// -// Example: -// int sqr(int x) { return x * x; } -// struct AddK { -// int k; -// int operator()(int x) const { return x + k; } -// } add_k = {5}; -// -// Callback1 my_callback; -// cout << my_callback.empty() << endl; // true -// -// my_callback = Callback1(&sqr); -// cout << my_callback.empty() << endl; // false -// cout << my_callback(3) << endl; // 9 -// -// my_callback = Callback1(add_k); -// cout << my_callback(10) << endl; // 15 -// -// my_callback = Callback1(); -// cout << my_callback.empty() << endl; // true - -#ifndef RTC_BASE_CALLBACK_H_ -#define RTC_BASE_CALLBACK_H_ - -#include "api/scoped_refptr.h" -#include "rtc_base/ref_count.h" -#include "rtc_base/ref_counted_object.h" - -namespace rtc { - -template -class Callback0 { - public: - // Default copy operations are appropriate for this class. - Callback0() {} - template - Callback0(const T& functor) - : helper_(new RefCountedObject >(functor)) {} - R operator()() { - if (empty()) - return R(); - return helper_->Run(); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run() = 0; - }; - template - struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run() { return functor_(); } - T functor_; - }; - scoped_refptr helper_; -}; - -template -class Callback1 { - public: - // Default copy operations are appropriate for this class. 
- Callback1() {} - template - Callback1(const T& functor) - : helper_(new RefCountedObject >(functor)) {} - R operator()(P1 p1) { - if (empty()) - return R(); - return helper_->Run(p1); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run(P1 p1) = 0; - }; - template - struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run(P1 p1) { return functor_(p1); } - T functor_; - }; - scoped_refptr helper_; -}; - -template -class Callback2 { - public: - // Default copy operations are appropriate for this class. - Callback2() {} - template - Callback2(const T& functor) - : helper_(new RefCountedObject >(functor)) {} - R operator()(P1 p1, P2 p2) { - if (empty()) - return R(); - return helper_->Run(p1, p2); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run(P1 p1, P2 p2) = 0; - }; - template - struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run(P1 p1, P2 p2) { return functor_(p1, p2); } - T functor_; - }; - scoped_refptr helper_; -}; - -template -class Callback3 { - public: - // Default copy operations are appropriate for this class. - Callback3() {} - template - Callback3(const T& functor) - : helper_(new RefCountedObject >(functor)) {} - R operator()(P1 p1, P2 p2, P3 p3) { - if (empty()) - return R(); - return helper_->Run(p1, p2, p3); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run(P1 p1, P2 p2, P3 p3) = 0; - }; - template - struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run(P1 p1, P2 p2, P3 p3) { return functor_(p1, p2, p3); } - T functor_; - }; - scoped_refptr helper_; -}; - -template -class Callback4 { - public: - // Default copy operations are appropriate for this class. - Callback4() {} - template - Callback4(const T& functor) - : helper_(new RefCountedObject >(functor)) {} - R operator()(P1 p1, P2 p2, P3 p3, P4 p4) { - if (empty()) - return R(); - return helper_->Run(p1, p2, p3, p4); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run(P1 p1, P2 p2, P3 p3, P4 p4) = 0; - }; - template - struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run(P1 p1, P2 p2, P3 p3, P4 p4) { - return functor_(p1, p2, p3, p4); - } - T functor_; - }; - scoped_refptr helper_; -}; - -template -class Callback5 { - public: - // Default copy operations are appropriate for this class. 
- Callback5() {} - template - Callback5(const T& functor) - : helper_(new RefCountedObject >(functor)) {} - R operator()(P1 p1, P2 p2, P3 p3, P4 p4, P5 p5) { - if (empty()) - return R(); - return helper_->Run(p1, p2, p3, p4, p5); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run(P1 p1, P2 p2, P3 p3, P4 p4, P5 p5) = 0; - }; - template - struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run(P1 p1, P2 p2, P3 p3, P4 p4, P5 p5) { - return functor_(p1, p2, p3, p4, p5); - } - T functor_; - }; - scoped_refptr helper_; -}; -} // namespace rtc - -#endif // RTC_BASE_CALLBACK_H_ diff --git a/rtc_base/callback.h.pump b/rtc_base/callback.h.pump deleted file mode 100644 index dc5fb3ae1d..0000000000 --- a/rtc_base/callback.h.pump +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright 2012 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -// To generate callback.h from callback.h.pump, execute: -// ../third_party/googletest/src/googletest/scripts/pump.py callback.h.pump - -// Callbacks are callable object containers. They can hold a function pointer -// or a function object and behave like a value type. Internally, data is -// reference-counted, making copies and pass-by-value inexpensive. -// -// Callbacks are typed using template arguments. The format is: -// CallbackN -// where N is the number of arguments supplied to the callable object. -// Callbacks are invoked using operator(), just like a function or a function -// object. Default-constructed callbacks are "empty," and executing an empty -// callback does nothing. A callback can be made empty by assigning it from -// a default-constructed callback. -// -// Callbacks are similar in purpose to std::function (which isn't available on -// all platforms we support) and a lightweight alternative to sigslots. Since -// they effectively hide the type of the object they call, they're useful in -// breaking dependencies between objects that need to interact with one another. -// Notably, they can hold the results of Bind(), std::bind*, etc, without needing -// to know the resulting object type of those calls. -// -// Sigslots, on the other hand, provide a fuller feature set, such as multiple -// subscriptions to a signal, optional thread-safety, and lifetime tracking of -// slots. When these features are needed, choose sigslots. 
-// -// Example: -// int sqr(int x) { return x * x; } -// struct AddK { -// int k; -// int operator()(int x) const { return x + k; } -// } add_k = {5}; -// -// Callback1 my_callback; -// cout << my_callback.empty() << endl; // true -// -// my_callback = Callback1(&sqr); -// cout << my_callback.empty() << endl; // false -// cout << my_callback(3) << endl; // 9 -// -// my_callback = Callback1(add_k); -// cout << my_callback(10) << endl; // 15 -// -// my_callback = Callback1(); -// cout << my_callback.empty() << endl; // true - -#ifndef RTC_BASE_CALLBACK_H_ -#define RTC_BASE_CALLBACK_H_ - -#include "rtc_base/ref_count.h" -#include "rtc_base/ref_counted_object.h" -#include "api/scoped_refptr.h" - -namespace rtc { - -$var n = 5 -$range i 0..n -$for i [[ -$range j 1..i - -template -class Callback$i { - public: - // Default copy operations are appropriate for this class. - Callback$i() {} - template Callback$i(const T& functor) - : helper_(new RefCountedObject< HelperImpl >(functor)) {} - R operator()($for j , [[P$j p$j]]) { - if (empty()) - return R(); - return helper_->Run($for j , [[p$j]]); - } - bool empty() const { return !helper_; } - - private: - struct Helper : RefCountInterface { - virtual ~Helper() {} - virtual R Run($for j , [[P$j p$j]]) = 0; - }; - template struct HelperImpl : Helper { - explicit HelperImpl(const T& functor) : functor_(functor) {} - virtual R Run($for j , [[P$j p$j]]) { - return functor_($for j , [[p$j]]); - } - T functor_; - }; - scoped_refptr helper_; -}; - -]] -} // namespace rtc - -#endif // RTC_BASE_CALLBACK_H_ diff --git a/rtc_base/callback_list_unittest.cc b/rtc_base/callback_list_unittest.cc index 119f88f543..665d779739 100644 --- a/rtc_base/callback_list_unittest.cc +++ b/rtc_base/callback_list_unittest.cc @@ -11,7 +11,6 @@ #include #include "api/function_view.h" -#include "rtc_base/bind.h" #include "rtc_base/callback_list.h" #include "test/gtest.h" @@ -209,8 +208,6 @@ TEST(CallbackList, MemberFunctionTest) { } // todo(glahiru): Add a test case to catch some error for Karl's first fix -// todo(glahiru): Add a test for rtc::Bind -// which used the following code in the Send TEST(CallbackList, RemoveOneReceiver) { int removal_tag[2]; diff --git a/rtc_base/callback_unittest.cc b/rtc_base/callback_unittest.cc deleted file mode 100644 index 876729570c..0000000000 --- a/rtc_base/callback_unittest.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright 2004 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "rtc_base/callback.h" - -#include "rtc_base/bind.h" -#include "rtc_base/keep_ref_until_done.h" -#include "rtc_base/ref_count.h" -#include "test/gtest.h" - -namespace rtc { - -namespace { - -void f() {} -int g() { - return 42; -} -int h(int x) { - return x * x; -} -void i(int& x) { - x *= x; -} // NOLINT: Testing refs - -struct BindTester { - int a() { return 24; } - int b(int x) const { return x * x; } -}; - -class RefCountedBindTester : public RefCountInterface { - public: - RefCountedBindTester() : count_(0) {} - void AddRef() const override { ++count_; } - RefCountReleaseStatus Release() const override { - --count_; - return count_ == 0 ? 
RefCountReleaseStatus::kDroppedLastRef - : RefCountReleaseStatus::kOtherRefsRemained; - } - int RefCount() const { return count_; } - - private: - mutable int count_; -}; - -} // namespace - -TEST(CallbackTest, VoidReturn) { - Callback0 cb; - EXPECT_TRUE(cb.empty()); - cb(); // Executing an empty callback should not crash. - cb = Callback0(&f); - EXPECT_FALSE(cb.empty()); - cb(); -} - -TEST(CallbackTest, IntReturn) { - Callback0 cb; - EXPECT_TRUE(cb.empty()); - cb = Callback0(&g); - EXPECT_FALSE(cb.empty()); - EXPECT_EQ(42, cb()); - EXPECT_EQ(42, cb()); -} - -TEST(CallbackTest, OneParam) { - Callback1 cb1(&h); - EXPECT_FALSE(cb1.empty()); - EXPECT_EQ(9, cb1(-3)); - EXPECT_EQ(100, cb1(10)); - - // Try clearing a callback. - cb1 = Callback1(); - EXPECT_TRUE(cb1.empty()); - - // Try a callback with a ref parameter. - Callback1 cb2(&i); - int x = 3; - cb2(x); - EXPECT_EQ(9, x); - cb2(x); - EXPECT_EQ(81, x); -} - -TEST(CallbackTest, WithBind) { - BindTester t; - Callback0 cb1 = Bind(&BindTester::a, &t); - EXPECT_EQ(24, cb1()); - EXPECT_EQ(24, cb1()); - cb1 = Bind(&BindTester::b, &t, 10); - EXPECT_EQ(100, cb1()); - EXPECT_EQ(100, cb1()); - cb1 = Bind(&BindTester::b, &t, 5); - EXPECT_EQ(25, cb1()); - EXPECT_EQ(25, cb1()); -} - -TEST(KeepRefUntilDoneTest, simple) { - RefCountedBindTester t; - EXPECT_EQ(0, t.RefCount()); - { - Callback0 cb = KeepRefUntilDone(&t); - EXPECT_EQ(1, t.RefCount()); - cb(); - EXPECT_EQ(1, t.RefCount()); - cb(); - EXPECT_EQ(1, t.RefCount()); - } - EXPECT_EQ(0, t.RefCount()); -} - -TEST(KeepRefUntilDoneTest, copy) { - RefCountedBindTester t; - EXPECT_EQ(0, t.RefCount()); - Callback0 cb2; - { - Callback0 cb = KeepRefUntilDone(&t); - EXPECT_EQ(1, t.RefCount()); - cb2 = cb; - } - EXPECT_EQ(1, t.RefCount()); - cb2 = Callback0(); - EXPECT_EQ(0, t.RefCount()); -} - -TEST(KeepRefUntilDoneTest, scopedref) { - RefCountedBindTester t; - EXPECT_EQ(0, t.RefCount()); - { - scoped_refptr t_scoped_ref(&t); - Callback0 cb = KeepRefUntilDone(t_scoped_ref); - t_scoped_ref = nullptr; - EXPECT_EQ(1, t.RefCount()); - cb(); - EXPECT_EQ(1, t.RefCount()); - } - EXPECT_EQ(0, t.RefCount()); -} - -} // namespace rtc diff --git a/rtc_base/containers/BUILD.gn b/rtc_base/containers/BUILD.gn new file mode 100644 index 0000000000..f303e706e4 --- /dev/null +++ b/rtc_base/containers/BUILD.gn @@ -0,0 +1,59 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. 
+ +import("../../webrtc.gni") + +rtc_library("flat_containers_internal") { + sources = [ + "as_const.h", + "flat_tree.cc", + "flat_tree.h", + "identity.h", + "invoke.h", + "move_only_int.h", + "not_fn.h", + "void_t.h", + ] + deps = [ + "..:checks", + "../system:no_unique_address", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container" ] + visibility = [ ":*" ] +} + +rtc_source_set("flat_set") { + sources = [ "flat_set.h" ] + deps = [ ":flat_containers_internal" ] +} + +rtc_source_set("flat_map") { + sources = [ "flat_map.h" ] + deps = [ + ":flat_containers_internal", + "..:checks", + ] +} + +rtc_library("unittests") { + testonly = true + sources = [ + "flat_map_unittest.cc", + "flat_set_unittest.cc", + "flat_tree_unittest.cc", + ] + deps = [ + ":flat_containers_internal", + ":flat_map", + ":flat_set", + "../../test:test_support", + "//testing/gmock:gmock", + "//testing/gtest:gtest", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container" ] +} diff --git a/rtc_base/containers/as_const.h b/rtc_base/containers/as_const.h new file mode 100644 index 0000000000..a41b3bc378 --- /dev/null +++ b/rtc_base/containers/as_const.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_AS_CONST_H_ +#define RTC_BASE_CONTAINERS_AS_CONST_H_ + +#include + +namespace webrtc { + +// C++14 implementation of C++17's std::as_const(): +// https://en.cppreference.com/w/cpp/utility/as_const +template +constexpr std::add_const_t& as_const(T& t) noexcept { + return t; +} + +template +void as_const(const T&& t) = delete; + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_AS_CONST_H_ diff --git a/rtc_base/containers/flat_map.h b/rtc_base/containers/flat_map.h new file mode 100644 index 0000000000..1dfae51655 --- /dev/null +++ b/rtc_base/containers/flat_map.h @@ -0,0 +1,374 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_FLAT_MAP_H_ +#define RTC_BASE_CONTAINERS_FLAT_MAP_H_ + +#include +#include +#include +#include + +#include "rtc_base/checks.h" +#include "rtc_base/containers/flat_tree.h" + +namespace webrtc { + +namespace flat_containers_internal { + +// An implementation of the flat_tree GetKeyFromValue template parameter that +// extracts the key as the first element of a pair. +struct GetFirst { + template + constexpr const Key& operator()(const std::pair& p) const { + return p.first; + } +}; + +} // namespace flat_containers_internal + +// flat_map is a container with a std::map-like interface that stores its +// contents in a sorted container, by default a vector. 
+// +// Its implementation mostly tracks the corresponding standardization proposal +// https://wg21.link/P0429, except that the storage of keys and values is not +// split. +// +// PROS +// +// - Good memory locality. +// - Low overhead, especially for smaller maps. +// - Performance is good for more workloads than you might expect (see +// //base/containers/README.md in Chromium repository) +// - Supports C++14 map interface. +// +// CONS +// +// - Inserts and removals are O(n). +// +// IMPORTANT NOTES +// +// - Iterators are invalidated across mutations. This means that the following +// line of code has undefined behavior since adding a new element could +// resize the container, invalidating all iterators: +// container["new element"] = it.second; +// - If possible, construct a flat_map in one operation by inserting into +// a container and moving that container into the flat_map constructor. +// +// QUICK REFERENCE +// +// Most of the core functionality is inherited from flat_tree. Please see +// flat_tree.h for more details for most of these functions. As a quick +// reference, the functions available are: +// +// Constructors (inputs need not be sorted): +// flat_map(const flat_map&); +// flat_map(flat_map&&); +// flat_map(InputIterator first, InputIterator last, +// const Compare& compare = Compare()); +// flat_map(const container_type& items, +// const Compare& compare = Compare()); +// flat_map(container_type&& items, +// const Compare& compare = Compare()); // Re-use storage. +// flat_map(std::initializer_list ilist, +// const Compare& comp = Compare()); +// +// Constructors (inputs need to be sorted): +// flat_map(sorted_unique_t, +// InputIterator first, InputIterator last, +// const Compare& compare = Compare()); +// flat_map(sorted_unique_t, +// const container_type& items, +// const Compare& compare = Compare()); +// flat_map(sorted_unique_t, +// container_type&& items, +// const Compare& compare = Compare()); // Re-use storage. 
+// flat_map(sorted_unique_t, +// std::initializer_list ilist, +// const Compare& comp = Compare()); +// +// Assignment functions: +// flat_map& operator=(const flat_map&); +// flat_map& operator=(flat_map&&); +// flat_map& operator=(initializer_list); +// +// Memory management functions: +// void reserve(size_t); +// size_t capacity() const; +// void shrink_to_fit(); +// +// Size management functions: +// void clear(); +// size_t size() const; +// size_t max_size() const; +// bool empty() const; +// +// Iterator functions: +// iterator begin(); +// const_iterator begin() const; +// const_iterator cbegin() const; +// iterator end(); +// const_iterator end() const; +// const_iterator cend() const; +// reverse_iterator rbegin(); +// const reverse_iterator rbegin() const; +// const_reverse_iterator crbegin() const; +// reverse_iterator rend(); +// const_reverse_iterator rend() const; +// const_reverse_iterator crend() const; +// +// Insert and accessor functions: +// mapped_type& operator[](const key_type&); +// mapped_type& operator[](key_type&&); +// mapped_type& at(const K&); +// const mapped_type& at(const K&) const; +// pair insert(const value_type&); +// pair insert(value_type&&); +// iterator insert(const_iterator hint, const value_type&); +// iterator insert(const_iterator hint, value_type&&); +// void insert(InputIterator first, InputIterator last); +// pair insert_or_assign(K&&, M&&); +// iterator insert_or_assign(const_iterator hint, K&&, M&&); +// pair emplace(Args&&...); +// iterator emplace_hint(const_iterator, Args&&...); +// pair try_emplace(K&&, Args&&...); +// iterator try_emplace(const_iterator hint, K&&, Args&&...); + +// Underlying type functions: +// container_type extract() &&; +// void replace(container_type&&); +// +// Erase functions: +// iterator erase(iterator); +// iterator erase(const_iterator); +// iterator erase(const_iterator first, const_iterator& last); +// template size_t erase(const K& key); +// +// Comparators (see std::map documentation). 
+// key_compare key_comp() const; +// value_compare value_comp() const; +// +// Search functions: +// template size_t count(const K&) const; +// template iterator find(const K&); +// template const_iterator find(const K&) const; +// template bool contains(const K&) const; +// template pair equal_range(const K&); +// template iterator lower_bound(const K&); +// template const_iterator lower_bound(const K&) const; +// template iterator upper_bound(const K&); +// template const_iterator upper_bound(const K&) const; +// +// General functions: +// void swap(flat_map&); +// +// Non-member operators: +// bool operator==(const flat_map&, const flat_map); +// bool operator!=(const flat_map&, const flat_map); +// bool operator<(const flat_map&, const flat_map); +// bool operator>(const flat_map&, const flat_map); +// bool operator>=(const flat_map&, const flat_map); +// bool operator<=(const flat_map&, const flat_map); +// +template , + class Container = std::vector>> +class flat_map : public ::webrtc::flat_containers_internal::flat_tree< + Key, + flat_containers_internal::GetFirst, + Compare, + Container> { + private: + using tree = typename ::webrtc::flat_containers_internal:: + flat_tree; + + public: + using key_type = typename tree::key_type; + using mapped_type = Mapped; + using value_type = typename tree::value_type; + using reference = typename Container::reference; + using const_reference = typename Container::const_reference; + using size_type = typename Container::size_type; + using difference_type = typename Container::difference_type; + using iterator = typename tree::iterator; + using const_iterator = typename tree::const_iterator; + using reverse_iterator = typename tree::reverse_iterator; + using const_reverse_iterator = typename tree::const_reverse_iterator; + using container_type = typename tree::container_type; + + // -------------------------------------------------------------------------- + // Lifetime and assignments. + // + // Note: we explicitly bring operator= in because otherwise + // flat_map<...> x; + // x = {...}; + // Would first create a flat_map and then move assign it. This most likely + // would be optimized away but still affects our debug builds. + + using tree::tree; + using tree::operator=; + + // Out-of-bound calls to at() will CHECK. + template + mapped_type& at(const K& key); + template + const mapped_type& at(const K& key) const; + + // -------------------------------------------------------------------------- + // Map-specific insert operations. + // + // Normal insert() functions are inherited from flat_tree. + // + // Assume that every operation invalidates iterators and references. + // Insertion of one element can take O(size). + + mapped_type& operator[](const key_type& key); + mapped_type& operator[](key_type&& key); + + template + std::pair insert_or_assign(K&& key, M&& obj); + template + iterator insert_or_assign(const_iterator hint, K&& key, M&& obj); + + template + std::enable_if_t::value, + std::pair> + try_emplace(K&& key, Args&&... args); + + template + std::enable_if_t::value, iterator> + try_emplace(const_iterator hint, K&& key, Args&&... args); + + // -------------------------------------------------------------------------- + // General operations. + // + // Assume that swap invalidates iterators and references. + + void swap(flat_map& other) noexcept; + + friend void swap(flat_map& lhs, flat_map& rhs) noexcept { lhs.swap(rhs); } +}; + +// ---------------------------------------------------------------------------- +// Lookups. 
+ +template +template +auto flat_map::at(const K& key) + -> mapped_type& { + iterator found = tree::find(key); + RTC_CHECK(found != tree::end()); + return found->second; +} + +template +template +auto flat_map::at(const K& key) const + -> const mapped_type& { + const_iterator found = tree::find(key); + RTC_CHECK(found != tree::cend()); + return found->second; +} + +// ---------------------------------------------------------------------------- +// Insert operations. + +template +auto flat_map::operator[](const key_type& key) + -> mapped_type& { + iterator found = tree::lower_bound(key); + if (found == tree::end() || tree::key_comp()(key, found->first)) + found = tree::unsafe_emplace(found, key, mapped_type()); + return found->second; +} + +template +auto flat_map::operator[](key_type&& key) + -> mapped_type& { + iterator found = tree::lower_bound(key); + if (found == tree::end() || tree::key_comp()(key, found->first)) + found = tree::unsafe_emplace(found, std::move(key), mapped_type()); + return found->second; +} + +template +template +auto flat_map::insert_or_assign(K&& key, + M&& obj) + -> std::pair { + auto result = + tree::emplace_key_args(key, std::forward(key), std::forward(obj)); + if (!result.second) + result.first->second = std::forward(obj); + return result; +} + +template +template +auto flat_map::insert_or_assign( + const_iterator hint, + K&& key, + M&& obj) -> iterator { + auto result = tree::emplace_hint_key_args(hint, key, std::forward(key), + std::forward(obj)); + if (!result.second) + result.first->second = std::forward(obj); + return result.first; +} + +template +template +auto flat_map::try_emplace(K&& key, + Args&&... args) + -> std::enable_if_t::value, + std::pair> { + return tree::emplace_key_args( + key, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); +} + +template +template +auto flat_map::try_emplace(const_iterator hint, + K&& key, + Args&&... args) + -> std::enable_if_t::value, iterator> { + return tree::emplace_hint_key_args( + hint, key, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)) + .first; +} + +// ---------------------------------------------------------------------------- +// General operations. + +template +void flat_map::swap(flat_map& other) noexcept { + tree::swap(other); +} + +// Erases all elements that match predicate. It has O(size) complexity. +// +// flat_map last_times; +// ... +// EraseIf(last_times, +// [&](const auto& element) { return now - element.second > kLimit; }); + +// NOLINTNEXTLINE(misc-unused-using-decls) +using ::webrtc::flat_containers_internal::EraseIf; + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_FLAT_MAP_H_ diff --git a/rtc_base/containers/flat_map_unittest.cc b/rtc_base/containers/flat_map_unittest.cc new file mode 100644 index 0000000000..8f0b77fc30 --- /dev/null +++ b/rtc_base/containers/flat_map_unittest.cc @@ -0,0 +1,454 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. 
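As a quick orientation before the unit tests, a hedged usage sketch of the new container. The keys, values, and the function Example are illustrative; the interface is the one documented in flat_map.h above, assuming the default std::less<> comparator and std::vector storage:

  #include <string>
  #include <utility>
  #include <vector>

  #include "rtc_base/containers/flat_map.h"

  void Example() {
    // Preferred one-shot construction: fill a plain vector first, then move it
    // into the flat_map so sorting happens once rather than on every insert.
    std::vector<std::pair<std::string, int>> items = {{"b", 2}, {"a", 1}};
    webrtc::flat_map<std::string, int> map(std::move(items));

    map["c"] = 3;                                // O(n) insert into sorted storage.
    map.insert_or_assign(std::string("a"), 10);  // Overwrites the existing mapping.
    auto it = map.find("b");                     // Binary search over contiguous memory.
    if (it != map.end())
      it->second += 1;
    // Any mutation may invalidate iterators, so re-acquire them afterwards.
  }
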
+ +#include "rtc_base/containers/flat_map.h" + +#include +#include +#include + +#include "rtc_base/containers/move_only_int.h" +#include "test/gmock.h" +#include "test/gtest.h" + +// A flat_map is basically a interface to flat_tree. So several basic +// operations are tested to make sure things are set up properly, but the bulk +// of the tests are in flat_tree_unittests.cc. + +using ::testing::ElementsAre; + +namespace webrtc { + +namespace { + +struct Unsortable { + int value; +}; + +bool operator==(const Unsortable& lhs, const Unsortable& rhs) { + return lhs.value == rhs.value; +} + +bool operator<(const Unsortable& lhs, const Unsortable& rhs) = delete; +bool operator<=(const Unsortable& lhs, const Unsortable& rhs) = delete; +bool operator>(const Unsortable& lhs, const Unsortable& rhs) = delete; +bool operator>=(const Unsortable& lhs, const Unsortable& rhs) = delete; + +TEST(FlatMap, IncompleteType) { + struct A { + using Map = flat_map; + int data; + Map set_with_incomplete_type; + Map::iterator it; + Map::const_iterator cit; + + // We do not declare operator< because clang complains that it's unused. + }; + + A a; +} + +TEST(FlatMap, RangeConstructor) { + flat_map::value_type input_vals[] = { + {1, 1}, {1, 2}, {1, 3}, {2, 1}, {2, 2}, {2, 3}, {3, 1}, {3, 2}, {3, 3}}; + + flat_map first(std::begin(input_vals), std::end(input_vals)); + EXPECT_THAT(first, ElementsAre(std::make_pair(1, 1), std::make_pair(2, 1), + std::make_pair(3, 1))); +} + +TEST(FlatMap, MoveConstructor) { + using pair = std::pair; + + flat_map original; + original.insert(pair(MoveOnlyInt(1), MoveOnlyInt(1))); + original.insert(pair(MoveOnlyInt(2), MoveOnlyInt(2))); + original.insert(pair(MoveOnlyInt(3), MoveOnlyInt(3))); + original.insert(pair(MoveOnlyInt(4), MoveOnlyInt(4))); + + flat_map moved(std::move(original)); + + EXPECT_EQ(1U, moved.count(MoveOnlyInt(1))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(2))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(3))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(4))); +} + +TEST(FlatMap, VectorConstructor) { + using IntPair = std::pair; + using IntMap = flat_map; + std::vector vect{{1, 1}, {1, 2}, {2, 1}}; + IntMap map(std::move(vect)); + EXPECT_THAT(map, ElementsAre(IntPair(1, 1), IntPair(2, 1))); +} + +TEST(FlatMap, InitializerListConstructor) { + flat_map cont( + {{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {1, 2}, {10, 10}, {8, 8}}); + EXPECT_THAT(cont, ElementsAre(std::make_pair(1, 1), std::make_pair(2, 2), + std::make_pair(3, 3), std::make_pair(4, 4), + std::make_pair(5, 5), std::make_pair(8, 8), + std::make_pair(10, 10))); +} + +TEST(FlatMap, SortedRangeConstructor) { + using PairType = std::pair; + using MapType = flat_map; + MapType::value_type input_vals[] = {{1, {1}}, {2, {1}}, {3, {1}}}; + MapType map(sorted_unique, std::begin(input_vals), std::end(input_vals)); + EXPECT_THAT( + map, ElementsAre(PairType(1, {1}), PairType(2, {1}), PairType(3, {1}))); +} + +TEST(FlatMap, SortedCopyFromVectorConstructor) { + using PairType = std::pair; + using MapType = flat_map; + std::vector vect{{1, {1}}, {2, {1}}}; + MapType map(sorted_unique, vect); + EXPECT_THAT(map, ElementsAre(PairType(1, {1}), PairType(2, {1}))); +} + +TEST(FlatMap, SortedMoveFromVectorConstructor) { + using PairType = std::pair; + using MapType = flat_map; + std::vector vect{{1, {1}}, {2, {1}}}; + MapType map(sorted_unique, std::move(vect)); + EXPECT_THAT(map, ElementsAre(PairType(1, {1}), PairType(2, {1}))); +} + +TEST(FlatMap, SortedInitializerListConstructor) { + using PairType = std::pair; + flat_map map( + 
sorted_unique,
+      {{1, {1}}, {2, {2}}, {3, {3}}, {4, {4}}, {5, {5}}, {8, {8}}, {10, {10}}});
+  EXPECT_THAT(map,
+              ElementsAre(PairType(1, {1}), PairType(2, {2}), PairType(3, {3}),
+                          PairType(4, {4}), PairType(5, {5}), PairType(8, {8}),
+                          PairType(10, {10})));
+}
+
+TEST(FlatMap, InitializerListAssignment) {
+  flat_map<int, int> cont;
+  cont = {{1, 1}, {2, 2}};
+  EXPECT_THAT(cont, ElementsAre(std::make_pair(1, 1), std::make_pair(2, 2)));
+}
+
+TEST(FlatMap, InsertFindSize) {
+  flat_map<int, int> s;
+  s.insert(std::make_pair(1, 1));
+  s.insert(std::make_pair(1, 1));
+  s.insert(std::make_pair(2, 2));
+
+  EXPECT_EQ(2u, s.size());
+  EXPECT_EQ(std::make_pair(1, 1), *s.find(1));
+  EXPECT_EQ(std::make_pair(2, 2), *s.find(2));
+  EXPECT_EQ(s.end(), s.find(7));
+}
+
+TEST(FlatMap, CopySwap) {
+  flat_map<int, int> original;
+  original.insert({1, 1});
+  original.insert({2, 2});
+  EXPECT_THAT(original,
+              ElementsAre(std::make_pair(1, 1), std::make_pair(2, 2)));
+
+  flat_map<int, int> copy(original);
+  EXPECT_THAT(copy, ElementsAre(std::make_pair(1, 1), std::make_pair(2, 2)));
+
+  copy.erase(copy.begin());
+  copy.insert({10, 10});
+  EXPECT_THAT(copy, ElementsAre(std::make_pair(2, 2), std::make_pair(10, 10)));
+
+  original.swap(copy);
+  EXPECT_THAT(original,
+              ElementsAre(std::make_pair(2, 2), std::make_pair(10, 10)));
+  EXPECT_THAT(copy, ElementsAre(std::make_pair(1, 1), std::make_pair(2, 2)));
+}
+
+// operator[](const Key&)
+TEST(FlatMap, SubscriptConstKey) {
+  flat_map<std::string, int> m;
+
+  // Default construct elements that don't exist yet.
+  int& s = m["a"];
+  EXPECT_EQ(0, s);
+  EXPECT_EQ(1u, m.size());
+
+  // The returned mapped reference should refer into the map.
+  s = 22;
+  EXPECT_EQ(22, m["a"]);
+
+  // Overwrite existing elements.
+  m["a"] = 44;
+  EXPECT_EQ(44, m["a"]);
+}
+
+// operator[](Key&&)
+TEST(FlatMap, SubscriptMoveOnlyKey) {
+  flat_map<MoveOnlyInt, int> m;
+
+  // Default construct elements that don't exist yet.
+  int& s = m[MoveOnlyInt(1)];
+  EXPECT_EQ(0, s);
+  EXPECT_EQ(1u, m.size());
+
+  // The returned mapped reference should refer into the map.
+  s = 22;
+  EXPECT_EQ(22, m[MoveOnlyInt(1)]);
+
+  // Overwrite existing elements.
+  m[MoveOnlyInt(1)] = 44;
+  EXPECT_EQ(44, m[MoveOnlyInt(1)]);
+}
+
+// Mapped& at(const Key&)
+// const Mapped& at(const Key&) const
+TEST(FlatMap, AtFunction) {
+  flat_map<int, std::string> m = {{1, "a"}, {2, "b"}};
+
+  // Basic Usage.
+  EXPECT_EQ("a", m.at(1));
+  EXPECT_EQ("b", m.at(2));
+
+  // Const reference works.
+  const std::string& const_ref = webrtc::as_const(m).at(1);
+  EXPECT_EQ("a", const_ref);
+
+  // Reference works, can operate on the string.
+  m.at(1)[0] = 'x';
+  EXPECT_EQ("x", m.at(1));
+
+  // Out-of-bounds will CHECK.
+  EXPECT_DEATH_IF_SUPPORTED(m.at(-1), "");
+  EXPECT_DEATH_IF_SUPPORTED({ m.at(-1)[0] = 'z'; }, "");
+
+  // Heterogeneous look-up works.
+  flat_map<std::string, int> m2 = {{"a", 1}, {"b", 2}};
+  EXPECT_EQ(1, m2.at(absl::string_view("a")));
+  EXPECT_EQ(2, webrtc::as_const(m2).at(absl::string_view("b")));
+}
+
+// insert_or_assign(K&&, M&&)
+TEST(FlatMap, InsertOrAssignMoveOnlyKey) {
+  flat_map<MoveOnlyInt, MoveOnlyInt> m;
+
+  // Initial insertion should return an iterator to the element and set the
+  // second pair member to |true|. The inserted key and value should be moved
+  // from.
+ MoveOnlyInt key(1); + MoveOnlyInt val(22); + auto result = m.insert_or_assign(std::move(key), std::move(val)); + EXPECT_EQ(1, result.first->first.data()); + EXPECT_EQ(22, result.first->second.data()); + EXPECT_TRUE(result.second); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(0, key.data()); // moved from + EXPECT_EQ(0, val.data()); // moved from + + // Second call with same key should result in an assignment, overwriting the + // old value. Assignment should be indicated by setting the second pair member + // to |false|. Only the inserted value should be moved from, the key should be + // left intact. + key = MoveOnlyInt(1); + val = MoveOnlyInt(44); + result = m.insert_or_assign(std::move(key), std::move(val)); + EXPECT_EQ(1, result.first->first.data()); + EXPECT_EQ(44, result.first->second.data()); + EXPECT_FALSE(result.second); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(1, key.data()); // not moved from + EXPECT_EQ(0, val.data()); // moved from + + // Check that random insertion results in sorted range. + flat_map map; + for (int i : {3, 1, 5, 6, 8, 7, 0, 9, 4, 2}) { + map.insert_or_assign(MoveOnlyInt(i), i); + EXPECT_TRUE(absl::c_is_sorted(map)); + } +} + +// insert_or_assign(const_iterator hint, K&&, M&&) +TEST(FlatMap, InsertOrAssignMoveOnlyKeyWithHint) { + flat_map m; + + // Initial insertion should return an iterator to the element. The inserted + // key and value should be moved from. + MoveOnlyInt key(1); + MoveOnlyInt val(22); + auto result = m.insert_or_assign(m.end(), std::move(key), std::move(val)); + EXPECT_EQ(1, result->first.data()); + EXPECT_EQ(22, result->second.data()); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(0, key.data()); // moved from + EXPECT_EQ(0, val.data()); // moved from + + // Second call with same key should result in an assignment, overwriting the + // old value. Only the inserted value should be moved from, the key should be + // left intact. + key = MoveOnlyInt(1); + val = MoveOnlyInt(44); + result = m.insert_or_assign(m.end(), std::move(key), std::move(val)); + EXPECT_EQ(1, result->first.data()); + EXPECT_EQ(44, result->second.data()); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(1, key.data()); // not moved from + EXPECT_EQ(0, val.data()); // moved from + + // Check that random insertion results in sorted range. + flat_map map; + for (int i : {3, 1, 5, 6, 8, 7, 0, 9, 4, 2}) { + map.insert_or_assign(map.end(), MoveOnlyInt(i), i); + EXPECT_TRUE(absl::c_is_sorted(map)); + } +} + +// try_emplace(K&&, Args&&...) +TEST(FlatMap, TryEmplaceMoveOnlyKey) { + flat_map> m; + + // Trying to emplace into an empty map should succeed. Insertion should return + // an iterator to the element and set the second pair member to |true|. The + // inserted key and value should be moved from. + MoveOnlyInt key(1); + MoveOnlyInt val1(22); + MoveOnlyInt val2(44); + // Test piecewise construction of mapped_type. + auto result = m.try_emplace(std::move(key), std::move(val1), std::move(val2)); + EXPECT_EQ(1, result.first->first.data()); + EXPECT_EQ(22, result.first->second.first.data()); + EXPECT_EQ(44, result.first->second.second.data()); + EXPECT_TRUE(result.second); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(0, key.data()); // moved from + EXPECT_EQ(0, val1.data()); // moved from + EXPECT_EQ(0, val2.data()); // moved from + + // Second call with same key should result in a no-op, returning an iterator + // to the existing element and returning false as the second pair member. + // Key and values that were attempted to be inserted should be left intact. 
+ key = MoveOnlyInt(1); + auto paired_val = std::make_pair(MoveOnlyInt(33), MoveOnlyInt(55)); + // Test construction of mapped_type from pair. + result = m.try_emplace(std::move(key), std::move(paired_val)); + EXPECT_EQ(1, result.first->first.data()); + EXPECT_EQ(22, result.first->second.first.data()); + EXPECT_EQ(44, result.first->second.second.data()); + EXPECT_FALSE(result.second); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(1, key.data()); // not moved from + EXPECT_EQ(33, paired_val.first.data()); // not moved from + EXPECT_EQ(55, paired_val.second.data()); // not moved from + + // Check that random insertion results in sorted range. + flat_map map; + for (int i : {3, 1, 5, 6, 8, 7, 0, 9, 4, 2}) { + map.try_emplace(MoveOnlyInt(i), i); + EXPECT_TRUE(absl::c_is_sorted(map)); + } +} + +// try_emplace(const_iterator hint, K&&, Args&&...) +TEST(FlatMap, TryEmplaceMoveOnlyKeyWithHint) { + flat_map> m; + + // Trying to emplace into an empty map should succeed. Insertion should return + // an iterator to the element. The inserted key and value should be moved + // from. + MoveOnlyInt key(1); + MoveOnlyInt val1(22); + MoveOnlyInt val2(44); + // Test piecewise construction of mapped_type. + auto result = + m.try_emplace(m.end(), std::move(key), std::move(val1), std::move(val2)); + EXPECT_EQ(1, result->first.data()); + EXPECT_EQ(22, result->second.first.data()); + EXPECT_EQ(44, result->second.second.data()); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(0, key.data()); // moved from + EXPECT_EQ(0, val1.data()); // moved from + EXPECT_EQ(0, val2.data()); // moved from + + // Second call with same key should result in a no-op, returning an iterator + // to the existing element. Key and values that were attempted to be inserted + // should be left intact. + key = MoveOnlyInt(1); + val1 = MoveOnlyInt(33); + val2 = MoveOnlyInt(55); + auto paired_val = std::make_pair(MoveOnlyInt(33), MoveOnlyInt(55)); + // Test construction of mapped_type from pair. + result = m.try_emplace(m.end(), std::move(key), std::move(paired_val)); + EXPECT_EQ(1, result->first.data()); + EXPECT_EQ(22, result->second.first.data()); + EXPECT_EQ(44, result->second.second.data()); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(1, key.data()); // not moved from + EXPECT_EQ(33, paired_val.first.data()); // not moved from + EXPECT_EQ(55, paired_val.second.data()); // not moved from + + // Check that random insertion results in sorted range. + flat_map map; + for (int i : {3, 1, 5, 6, 8, 7, 0, 9, 4, 2}) { + map.try_emplace(map.end(), MoveOnlyInt(i), i); + EXPECT_TRUE(absl::c_is_sorted(map)); + } +} + +TEST(FlatMap, UsingTransparentCompare) { + using ExplicitInt = MoveOnlyInt; + flat_map m; + const auto& m1 = m; + int x = 0; + + // Check if we can use lookup functions without converting to key_type. + // Correctness is checked in flat_tree tests. + m.count(x); + m1.count(x); + m.find(x); + m1.find(x); + m.equal_range(x); + m1.equal_range(x); + m.lower_bound(x); + m1.lower_bound(x); + m.upper_bound(x); + m1.upper_bound(x); + m.erase(x); + + // Check if we broke overload resolution. 
+ m.emplace(ExplicitInt(0), 0); + m.emplace(ExplicitInt(1), 0); + m.erase(m.begin()); + m.erase(m.cbegin()); +} + +TEST(FlatMap, SupportsEraseIf) { + flat_map m; + m.insert(std::make_pair(MoveOnlyInt(1), MoveOnlyInt(1))); + m.insert(std::make_pair(MoveOnlyInt(2), MoveOnlyInt(2))); + m.insert(std::make_pair(MoveOnlyInt(3), MoveOnlyInt(3))); + m.insert(std::make_pair(MoveOnlyInt(4), MoveOnlyInt(4))); + m.insert(std::make_pair(MoveOnlyInt(5), MoveOnlyInt(5))); + + EraseIf(m, [to_be_removed = MoveOnlyInt(2)]( + const std::pair& e) { + return e.first == to_be_removed; + }); + + EXPECT_EQ(m.size(), 4u); + ASSERT_TRUE(m.find(MoveOnlyInt(1)) != m.end()); + ASSERT_FALSE(m.find(MoveOnlyInt(2)) != m.end()); + ASSERT_TRUE(m.find(MoveOnlyInt(3)) != m.end()); + ASSERT_TRUE(m.find(MoveOnlyInt(4)) != m.end()); + ASSERT_TRUE(m.find(MoveOnlyInt(5)) != m.end()); +} + +} // namespace +} // namespace webrtc diff --git a/rtc_base/containers/flat_set.h b/rtc_base/containers/flat_set.h new file mode 100644 index 0000000000..e088cc5314 --- /dev/null +++ b/rtc_base/containers/flat_set.h @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_FLAT_SET_H_ +#define RTC_BASE_CONTAINERS_FLAT_SET_H_ + +#include +#include + +#include "rtc_base/containers/flat_tree.h" +#include "rtc_base/containers/identity.h" + +namespace webrtc { + +// flat_set is a container with a std::set-like interface that stores its +// contents in a sorted container, by default a vector. +// +// Its implementation mostly tracks the corresponding standardization proposal +// https://wg21.link/P1222. +// +// +// PROS +// +// - Good memory locality. +// - Low overhead, especially for smaller sets. +// - Performance is good for more workloads than you might expect (see +// //base/containers/README.md in Chromium repository) +// - Supports C++14 set interface. +// +// CONS +// +// - Inserts and removals are O(n). +// +// IMPORTANT NOTES +// +// - Iterators are invalidated across mutations. +// - If possible, construct a flat_set in one operation by inserting into +// a container and moving that container into the flat_set constructor. +// - For multiple removals use base::EraseIf() which is O(n) rather than +// O(n * removed_items). +// +// QUICK REFERENCE +// +// Most of the core functionality is inherited from flat_tree. Please see +// flat_tree.h for more details for most of these functions. As a quick +// reference, the functions available are: +// +// Constructors (inputs need not be sorted): +// flat_set(const flat_set&); +// flat_set(flat_set&&); +// flat_set(InputIterator first, InputIterator last, +// const Compare& compare = Compare()); +// flat_set(const container_type& items, +// const Compare& compare = Compare()); +// flat_set(container_type&& items, +// const Compare& compare = Compare()); // Re-use storage. 
+// flat_set(std::initializer_list ilist, +// const Compare& comp = Compare()); +// +// Constructors (inputs need to be sorted): +// flat_set(sorted_unique_t, +// InputIterator first, InputIterator last, +// const Compare& compare = Compare()); +// flat_set(sorted_unique_t, +// const container_type& items, +// const Compare& compare = Compare()); +// flat_set(sorted_unique_t, +// container_type&& items, +// const Compare& compare = Compare()); // Re-use storage. +// flat_set(sorted_unique_t, +// std::initializer_list ilist, +// const Compare& comp = Compare()); +// +// Assignment functions: +// flat_set& operator=(const flat_set&); +// flat_set& operator=(flat_set&&); +// flat_set& operator=(initializer_list); +// +// Memory management functions: +// void reserve(size_t); +// size_t capacity() const; +// void shrink_to_fit(); +// +// Size management functions: +// void clear(); +// size_t size() const; +// size_t max_size() const; +// bool empty() const; +// +// Iterator functions: +// iterator begin(); +// const_iterator begin() const; +// const_iterator cbegin() const; +// iterator end(); +// const_iterator end() const; +// const_iterator cend() const; +// reverse_iterator rbegin(); +// const reverse_iterator rbegin() const; +// const_reverse_iterator crbegin() const; +// reverse_iterator rend(); +// const_reverse_iterator rend() const; +// const_reverse_iterator crend() const; +// +// Insert and accessor functions: +// pair insert(const key_type&); +// pair insert(key_type&&); +// void insert(InputIterator first, InputIterator last); +// iterator insert(const_iterator hint, const key_type&); +// iterator insert(const_iterator hint, key_type&&); +// pair emplace(Args&&...); +// iterator emplace_hint(const_iterator, Args&&...); +// +// Underlying type functions: +// container_type extract() &&; +// void replace(container_type&&); +// +// Erase functions: +// iterator erase(iterator); +// iterator erase(const_iterator); +// iterator erase(const_iterator first, const_iterator& last); +// template size_t erase(const K& key); +// +// Comparators (see std::set documentation). +// key_compare key_comp() const; +// value_compare value_comp() const; +// +// Search functions: +// template size_t count(const K&) const; +// template iterator find(const K&); +// template const_iterator find(const K&) const; +// template bool contains(const K&) const; +// template pair equal_range(K&); +// template iterator lower_bound(const K&); +// template const_iterator lower_bound(const K&) const; +// template iterator upper_bound(const K&); +// template const_iterator upper_bound(const K&) const; +// +// General functions: +// void swap(flat_set&); +// +// Non-member operators: +// bool operator==(const flat_set&, const flat_set); +// bool operator!=(const flat_set&, const flat_set); +// bool operator<(const flat_set&, const flat_set); +// bool operator>(const flat_set&, const flat_set); +// bool operator>=(const flat_set&, const flat_set); +// bool operator<=(const flat_set&, const flat_set); +// +template , + class Container = std::vector> +using flat_set = typename ::webrtc::flat_containers_internal:: + flat_tree; + +// ---------------------------------------------------------------------------- +// General operations. + +// Erases all elements that match predicate. It has O(size) complexity. +// +// flat_set numbers; +// ... 
+// EraseIf(numbers, [](int number) { return number % 2 == 1; }); + +// NOLINTNEXTLINE(misc-unused-using-decls) +using ::webrtc::flat_containers_internal::EraseIf; + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_FLAT_SET_H_ diff --git a/rtc_base/containers/flat_set_unittest.cc b/rtc_base/containers/flat_set_unittest.cc new file mode 100644 index 0000000000..617db92440 --- /dev/null +++ b/rtc_base/containers/flat_set_unittest.cc @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#include "rtc_base/containers/flat_set.h" + +#include +#include +#include +#include + +#include "rtc_base/containers/move_only_int.h" +#include "test/gmock.h" +#include "test/gtest.h" + +// A flat_set is basically a interface to flat_tree. So several basic +// operations are tested to make sure things are set up properly, but the bulk +// of the tests are in flat_tree_unittests.cc. + +using ::testing::ElementsAre; + +namespace webrtc { +namespace { + +TEST(FlatSet, IncompleteType) { + struct A { + using Set = flat_set; + int data; + Set set_with_incomplete_type; + Set::iterator it; + Set::const_iterator cit; + + // We do not declare operator< because clang complains that it's unused. + }; + + A a; +} + +TEST(FlatSet, RangeConstructor) { + flat_set::value_type input_vals[] = {1, 1, 1, 2, 2, 2, 3, 3, 3}; + + flat_set cont(std::begin(input_vals), std::end(input_vals)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3)); +} + +TEST(FlatSet, MoveConstructor) { + int input_range[] = {1, 2, 3, 4}; + + flat_set original(std::begin(input_range), + std::end(input_range)); + flat_set moved(std::move(original)); + + EXPECT_EQ(1U, moved.count(MoveOnlyInt(1))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(2))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(3))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(4))); +} + +TEST(FlatSet, InitializerListConstructor) { + flat_set cont({1, 2, 3, 4, 5, 6, 10, 8}); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); +} + +TEST(FlatSet, InsertFindSize) { + flat_set s; + s.insert(1); + s.insert(1); + s.insert(2); + + EXPECT_EQ(2u, s.size()); + EXPECT_EQ(1, *s.find(1)); + EXPECT_EQ(2, *s.find(2)); + EXPECT_EQ(s.end(), s.find(7)); +} + +TEST(FlatSet, CopySwap) { + flat_set original; + original.insert(1); + original.insert(2); + EXPECT_THAT(original, ElementsAre(1, 2)); + + flat_set copy(original); + EXPECT_THAT(copy, ElementsAre(1, 2)); + + copy.erase(copy.begin()); + copy.insert(10); + EXPECT_THAT(copy, ElementsAre(2, 10)); + + original.swap(copy); + EXPECT_THAT(original, ElementsAre(2, 10)); + EXPECT_THAT(copy, ElementsAre(1, 2)); +} + +TEST(FlatSet, UsingTransparentCompare) { + using ExplicitInt = webrtc::MoveOnlyInt; + flat_set s; + const auto& s1 = s; + int x = 0; + + // Check if we can use lookup functions without converting to key_type. + // Correctness is checked in flat_tree tests. + s.count(x); + s1.count(x); + s.find(x); + s1.find(x); + s.equal_range(x); + s1.equal_range(x); + s.lower_bound(x); + s1.lower_bound(x); + s.upper_bound(x); + s1.upper_bound(x); + s.erase(x); + + // Check if we broke overload resolution. 
+ s.emplace(0); + s.emplace(1); + s.erase(s.begin()); + s.erase(s.cbegin()); +} + +TEST(FlatSet, SupportsEraseIf) { + flat_set s; + s.emplace(MoveOnlyInt(1)); + s.emplace(MoveOnlyInt(2)); + s.emplace(MoveOnlyInt(3)); + s.emplace(MoveOnlyInt(4)); + s.emplace(MoveOnlyInt(5)); + + EraseIf(s, [to_be_removed = MoveOnlyInt(2)](const MoveOnlyInt& elem) { + return elem == to_be_removed; + }); + + EXPECT_EQ(s.size(), 4u); + ASSERT_TRUE(s.find(MoveOnlyInt(1)) != s.end()); + ASSERT_FALSE(s.find(MoveOnlyInt(2)) != s.end()); + ASSERT_TRUE(s.find(MoveOnlyInt(3)) != s.end()); + ASSERT_TRUE(s.find(MoveOnlyInt(4)) != s.end()); + ASSERT_TRUE(s.find(MoveOnlyInt(5)) != s.end()); +} +} // namespace +} // namespace webrtc diff --git a/media/engine/constants.cc b/rtc_base/containers/flat_tree.cc similarity index 57% rename from media/engine/constants.cc rename to rtc_base/containers/flat_tree.cc index 12d6ddde5a..9e86db191a 100644 --- a/media/engine/constants.cc +++ b/rtc_base/containers/flat_tree.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source @@ -8,12 +8,12 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "media/engine/constants.h" +// This implementation is borrowed from Chromium. -namespace cricket { +#include "rtc_base/containers/flat_tree.h" -const int kVideoMtu = 1200; -const int kVideoRtpSendBufferSize = 65536; -const int kVideoRtpRecvBufferSize = 262144; +namespace webrtc { -} // namespace cricket +sorted_unique_t sorted_unique; + +} // namespace webrtc diff --git a/rtc_base/containers/flat_tree.h b/rtc_base/containers/flat_tree.h new file mode 100644 index 0000000000..1b02cce1b4 --- /dev/null +++ b/rtc_base/containers/flat_tree.h @@ -0,0 +1,1102 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_FLAT_TREE_H_ +#define RTC_BASE_CONTAINERS_FLAT_TREE_H_ + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "rtc_base/checks.h" +#include "rtc_base/containers/as_const.h" +#include "rtc_base/containers/not_fn.h" +#include "rtc_base/containers/void_t.h" +#include "rtc_base/system/no_unique_address.h" + +namespace webrtc { +// Tag type that allows skipping the sort_and_unique step when constructing a +// flat_tree in case the underlying container is already sorted and has no +// duplicate elements. +struct sorted_unique_t { + constexpr sorted_unique_t() = default; +}; +extern sorted_unique_t sorted_unique; + +namespace flat_containers_internal { + +// Helper functions used in RTC_DCHECKs below to make sure that inputs tagged +// with sorted_unique are indeed sorted and unique. +template +constexpr bool is_sorted_and_unique(const Range& range, Comp comp) { + // Being unique implies that there are no adjacent elements that + // compare equal. So this checks that each element is strictly less + // than the element after it. 
+ return absl::c_adjacent_find(range, webrtc::not_fn(comp)) == std::end(range); +} + +// This is a convenience trait inheriting from std::true_type if Iterator is at +// least a ForwardIterator and thus supports multiple passes over a range. +template +using is_multipass = + std::is_base_of::iterator_category>; + +// Uses SFINAE to detect whether type has is_transparent member. +template +struct IsTransparentCompare : std::false_type {}; +template +struct IsTransparentCompare> + : std::true_type {}; + +// Helper inspired by C++20's std::to_array to convert a C-style array to a +// std::array. As opposed to the C++20 version this implementation does not +// provide an overload for rvalues and does not strip cv qualifers from the +// returned std::array::value_type. The returned value_type needs to be +// specified explicitly, allowing the construction of std::arrays with const +// elements. +// +// Reference: https://en.cppreference.com/w/cpp/container/array/to_array +template +constexpr std::array ToArrayImpl(const T (&data)[N], + std::index_sequence) { + return {{data[I]...}}; +} + +template +constexpr std::array ToArray(const T (&data)[N]) { + return ToArrayImpl(data, std::make_index_sequence()); +} + +// std::pair's operator= is not constexpr prior to C++20. Thus we need this +// small helper to invoke operator= on the .first and .second member explicitly. +template +constexpr void Assign(T& lhs, T&& rhs) { + lhs = std::move(rhs); +} + +template +constexpr void Assign(std::pair& lhs, std::pair&& rhs) { + Assign(lhs.first, std::move(rhs.first)); + Assign(lhs.second, std::move(rhs.second)); +} + +// constexpr swap implementation. std::swap is not constexpr prior to C++20. +template +constexpr void Swap(T& lhs, T& rhs) { + T tmp = std::move(lhs); + Assign(lhs, std::move(rhs)); + Assign(rhs, std::move(tmp)); +} + +// constexpr prev implementation. std::prev is not constexpr prior to C++17. +template +constexpr BidirIt Prev(BidirIt it) { + return --it; +} + +// constexpr next implementation. std::next is not constexpr prior to C++17. +template +constexpr InputIt Next(InputIt it) { + return ++it; +} + +// constexpr sort implementation. std::sort is not constexpr prior to C++20. +// While insertion sort has a quadratic worst case complexity, it was chosen +// because it has linear complexity for nearly sorted data, is stable, and +// simple to implement. +template +constexpr void InsertionSort(BidirIt first, BidirIt last, const Compare& comp) { + if (first == last) + return; + + for (auto it = Next(first); it != last; ++it) { + for (auto curr = it; curr != first && comp(*curr, *Prev(curr)); --curr) + Swap(*curr, *Prev(curr)); + } +} + +// Implementation ------------------------------------------------------------- + +// Implementation for the sorted associative flat_set and flat_map using a +// sorted vector as the backing store. Do not use directly. +// +// The use of "value" in this is like std::map uses, meaning it's the thing +// contained (in the case of map it's a pair). The Key is how +// things are looked up. In the case of a set, Key == Value. In the case of +// a map, the Key is a component of a Value. +// +// The helper class GetKeyFromValue provides the means to extract a key from a +// value for comparison purposes. It should implement: +// const Key& operator()(const Value&). +template +class flat_tree { + public: + // -------------------------------------------------------------------------- + // Types. 
+ // + using key_type = Key; + using key_compare = KeyCompare; + using value_type = typename Container::value_type; + + // Wraps the templated key comparison to compare values. + struct value_compare { + constexpr bool operator()(const value_type& left, + const value_type& right) const { + GetKeyFromValue extractor; + return comp(extractor(left), extractor(right)); + } + + RTC_NO_UNIQUE_ADDRESS key_compare comp; + }; + + using pointer = typename Container::pointer; + using const_pointer = typename Container::const_pointer; + using reference = typename Container::reference; + using const_reference = typename Container::const_reference; + using size_type = typename Container::size_type; + using difference_type = typename Container::difference_type; + using iterator = typename Container::iterator; + using const_iterator = typename Container::const_iterator; + using reverse_iterator = typename Container::reverse_iterator; + using const_reverse_iterator = typename Container::const_reverse_iterator; + using container_type = Container; + + // -------------------------------------------------------------------------- + // Lifetime. + // + // Constructors that take range guarantee O(N * log^2(N)) + O(N) complexity + // and take O(N * log(N)) + O(N) if extra memory is available (N is a range + // length). + // + // Assume that move constructors invalidate iterators and references. + // + // The constructors that take ranges, lists, and vectors do not require that + // the input be sorted. + // + // When passing the webrtc::sorted_unique tag as the first argument no sort + // and unique step takes places. This is useful if the underlying container + // already has the required properties. + + flat_tree() = default; + flat_tree(const flat_tree&) = default; + flat_tree(flat_tree&&) = default; + + explicit flat_tree(const key_compare& comp); + + template + flat_tree(InputIterator first, + InputIterator last, + const key_compare& comp = key_compare()); + + flat_tree(const container_type& items, + const key_compare& comp = key_compare()); + + explicit flat_tree(container_type&& items, + const key_compare& comp = key_compare()); + + flat_tree(std::initializer_list ilist, + const key_compare& comp = key_compare()); + + template + flat_tree(sorted_unique_t, + InputIterator first, + InputIterator last, + const key_compare& comp = key_compare()); + + flat_tree(sorted_unique_t, + const container_type& items, + const key_compare& comp = key_compare()); + + constexpr flat_tree(sorted_unique_t, + container_type&& items, + const key_compare& comp = key_compare()); + + flat_tree(sorted_unique_t, + std::initializer_list ilist, + const key_compare& comp = key_compare()); + + ~flat_tree() = default; + + // -------------------------------------------------------------------------- + // Assignments. + // + // Assume that move assignment invalidates iterators and references. + + flat_tree& operator=(const flat_tree&) = default; + flat_tree& operator=(flat_tree&&) = default; + // Takes the first if there are duplicates in the initializer list. + flat_tree& operator=(std::initializer_list ilist); + + // -------------------------------------------------------------------------- + // Memory management. + // + // Beware that shrink_to_fit() simply forwards the request to the + // container_type and its implementation is free to optimize otherwise and + // leave capacity() to be greater that its size. + // + // reserve() and shrink_to_fit() invalidate iterators and references. 
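The lifetime comments above amount to two recommended construction paths: hand the constructor an arbitrary container and let it sort and deduplicate once, or use the sorted_unique tag when the input is already in its final form. A hedged sketch (function and variable names are illustrative, not from the patch; flat_set is used as the public wrapper around flat_tree):

#include <utility>
#include <vector>

#include "rtc_base/containers/flat_set.h"

// Hypothetical helper, not part of the patch.
webrtc::flat_set<int> BuildSet(std::vector<int> unsorted,
                               std::vector<int> sorted_and_unique) {
  // Path 1: move in an arbitrary vector; the constructor sorts it and
  // drops duplicates in O(N log N), reusing the vector's storage.
  webrtc::flat_set<int> a(std::move(unsorted));

  // Path 2: if the caller already guarantees sorted, duplicate-free
  // input, the sorted_unique tag skips that step (DCHECKed in debug).
  webrtc::flat_set<int> b(webrtc::sorted_unique,
                          std::move(sorted_and_unique));

  return a.size() >= b.size() ? a : b;
}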
+ + void reserve(size_type new_capacity); + size_type capacity() const; + void shrink_to_fit(); + + // -------------------------------------------------------------------------- + // Size management. + // + // clear() leaves the capacity() of the flat_tree unchanged. + + void clear(); + + constexpr size_type size() const; + constexpr size_type max_size() const; + constexpr bool empty() const; + + // -------------------------------------------------------------------------- + // Iterators. + // + // Iterators follow the ordering defined by the key comparator used in + // construction of the flat_tree. + + iterator begin(); + constexpr const_iterator begin() const; + const_iterator cbegin() const; + + iterator end(); + constexpr const_iterator end() const; + const_iterator cend() const; + + reverse_iterator rbegin(); + const_reverse_iterator rbegin() const; + const_reverse_iterator crbegin() const; + + reverse_iterator rend(); + const_reverse_iterator rend() const; + const_reverse_iterator crend() const; + + // -------------------------------------------------------------------------- + // Insert operations. + // + // Assume that every operation invalidates iterators and references. + // Insertion of one element can take O(size). Capacity of flat_tree grows in + // an implementation-defined manner. + // + // NOTE: Prefer to build a new flat_tree from a std::vector (or similar) + // instead of calling insert() repeatedly. + + std::pair insert(const value_type& val); + std::pair insert(value_type&& val); + + iterator insert(const_iterator position_hint, const value_type& x); + iterator insert(const_iterator position_hint, value_type&& x); + + // This method inserts the values from the range [first, last) into the + // current tree. + template + void insert(InputIterator first, InputIterator last); + + template + std::pair emplace(Args&&... args); + + template + iterator emplace_hint(const_iterator position_hint, Args&&... args); + + // -------------------------------------------------------------------------- + // Underlying type operations. + // + // Assume that either operation invalidates iterators and references. + + // Extracts the container_type and returns it to the caller. Ensures that + // `this` is `empty()` afterwards. + container_type extract() &&; + + // Replaces the container_type with `body`. Expects that `body` is sorted + // and has no repeated elements with regard to value_comp(). + void replace(container_type&& body); + + // -------------------------------------------------------------------------- + // Erase operations. + // + // Assume that every operation invalidates iterators and references. + // + // erase(position), erase(first, last) can take O(size). + // erase(key) may take O(size) + O(log(size)). + // + // Prefer webrtc::EraseIf() or some other variation on erase(remove(), end()) + // idiom when deleting multiple non-consecutive elements. + + iterator erase(iterator position); + // Artificially templatized to break ambiguity if `iterator` and + // `const_iterator` are the same type. + template + iterator erase(const_iterator position); + iterator erase(const_iterator first, const_iterator last); + template + size_type erase(const K& key); + + // -------------------------------------------------------------------------- + // Comparators. + + constexpr key_compare key_comp() const; + constexpr value_compare value_comp() const; + + // -------------------------------------------------------------------------- + // Search operations. 
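The note above about preferring bulk construction over repeated single-element insert() can be made concrete with a short sketch (hypothetical helper; assumes the flat_map header added by this patch):

#include <utility>
#include <vector>

#include "rtc_base/containers/flat_map.h"

// Hypothetical helper, not part of the patch.
webrtc::flat_map<int, int> IndexByKey(
    const std::vector<std::pair<int, int>>& input) {
  // Each single insert() is O(size), so N repeated inserts approach
  // O(N^2):
  //
  //   webrtc::flat_map<int, int> slow;
  //   for (const auto& kv : input)
  //     slow.insert(kv);  // avoid in hot paths
  //
  // Building the backing vector first and sorting once is O(N log N).
  std::vector<std::pair<int, int>> storage(input.begin(), input.end());
  return webrtc::flat_map<int, int>(std::move(storage));
}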
+ // + // Search operations have O(log(size)) complexity. + + template + size_type count(const K& key) const; + + template + iterator find(const K& key); + + template + const_iterator find(const K& key) const; + + template + bool contains(const K& key) const; + + template + std::pair equal_range(const K& key); + + template + std::pair equal_range(const K& key) const; + + template + iterator lower_bound(const K& key); + + template + const_iterator lower_bound(const K& key) const; + + template + iterator upper_bound(const K& key); + + template + const_iterator upper_bound(const K& key) const; + + // -------------------------------------------------------------------------- + // General operations. + // + // Assume that swap invalidates iterators and references. + // + // Implementation note: currently we use operator==() and operator<() on + // std::vector, because they have the same contract we need, so we use them + // directly for brevity and in case it is more optimal than calling equal() + // and lexicograhpical_compare(). If the underlying container type is changed, + // this code may need to be modified. + + void swap(flat_tree& other) noexcept; + + friend bool operator==(const flat_tree& lhs, const flat_tree& rhs) { + return lhs.body_ == rhs.body_; + } + + friend bool operator!=(const flat_tree& lhs, const flat_tree& rhs) { + return !(lhs == rhs); + } + + friend bool operator<(const flat_tree& lhs, const flat_tree& rhs) { + return lhs.body_ < rhs.body_; + } + + friend bool operator>(const flat_tree& lhs, const flat_tree& rhs) { + return rhs < lhs; + } + + friend bool operator>=(const flat_tree& lhs, const flat_tree& rhs) { + return !(lhs < rhs); + } + + friend bool operator<=(const flat_tree& lhs, const flat_tree& rhs) { + return !(lhs > rhs); + } + + friend void swap(flat_tree& lhs, flat_tree& rhs) noexcept { lhs.swap(rhs); } + + protected: + // Emplaces a new item into the tree that is known not to be in it. This + // is for implementing map operator[]. + template + iterator unsafe_emplace(const_iterator position, Args&&... args); + + // Attempts to emplace a new element with key |key|. Only if |key| is not yet + // present, construct value_type from |args| and insert it. Returns an + // iterator to the element with key |key| and a bool indicating whether an + // insertion happened. + template + std::pair emplace_key_args(const K& key, Args&&... args); + + // Similar to |emplace_key_args|, but checks |hint| first as a possible + // insertion position. + template + std::pair emplace_hint_key_args(const_iterator hint, + const K& key, + Args&&... args); + + private: + // Helper class for e.g. lower_bound that can compare a value on the left + // to a key on the right. + struct KeyValueCompare { + // The key comparison object must outlive this class. + explicit KeyValueCompare(const key_compare& comp) : comp_(comp) {} + + template + bool operator()(const T& lhs, const U& rhs) const { + return comp_(extract_if_value_type(lhs), extract_if_value_type(rhs)); + } + + private: + const key_type& extract_if_value_type(const value_type& v) const { + GetKeyFromValue extractor; + return extractor(v); + } + + template + const K& extract_if_value_type(const K& k) const { + return k; + } + + const key_compare& comp_; + }; + + iterator const_cast_it(const_iterator c_it) { + auto distance = std::distance(cbegin(), c_it); + return std::next(begin(), distance); + } + + // This method is inspired by both std::map::insert(P&&) and + // std::map::insert_or_assign(const K&, V&&). 
It inserts val if an equivalent + // element is not present yet, otherwise it overwrites. It returns an iterator + // to the modified element and a flag indicating whether insertion or + // assignment happened. + template + std::pair insert_or_assign(V&& val) { + auto position = lower_bound(GetKeyFromValue()(val)); + + if (position == end() || value_comp()(val, *position)) + return {body_.emplace(position, std::forward(val)), true}; + + *position = std::forward(val); + return {position, false}; + } + + // This method is similar to insert_or_assign, with the following differences: + // - Instead of searching [begin(), end()) it only searches [first, last). + // - In case no equivalent element is found, val is appended to the end of the + // underlying body and an iterator to the next bigger element in [first, + // last) is returned. + template + std::pair append_or_assign(iterator first, + iterator last, + V&& val) { + auto position = std::lower_bound(first, last, val, value_comp()); + + if (position == last || value_comp()(val, *position)) { + // emplace_back might invalidate position, which is why distance needs to + // be cached. + const difference_type distance = std::distance(begin(), position); + body_.emplace_back(std::forward(val)); + return {std::next(begin(), distance), true}; + } + + *position = std::forward(val); + return {position, false}; + } + + // This method is similar to insert, with the following differences: + // - Instead of searching [begin(), end()) it only searches [first, last). + // - In case no equivalent element is found, val is appended to the end of the + // underlying body and an iterator to the next bigger element in [first, + // last) is returned. + template + std::pair append_unique(iterator first, + iterator last, + V&& val) { + auto position = std::lower_bound(first, last, val, value_comp()); + + if (position == last || value_comp()(val, *position)) { + // emplace_back might invalidate position, which is why distance needs to + // be cached. + const difference_type distance = std::distance(begin(), position); + body_.emplace_back(std::forward(val)); + return {std::next(begin(), distance), true}; + } + + return {position, false}; + } + + void sort_and_unique(iterator first, iterator last) { + // Preserve stability for the unique code below. + std::stable_sort(first, last, value_comp()); + + // lhs is already <= rhs due to sort, therefore !(lhs < rhs) <=> lhs == rhs. + auto equal_comp = webrtc::not_fn(value_comp()); + erase(std::unique(first, last, equal_comp), last); + } + + void sort_and_unique() { sort_and_unique(begin(), end()); } + + // To support comparators that may not be possible to default-construct, we + // have to store an instance of Compare. Since Compare commonly is stateless, + // we use the RTC_NO_UNIQUE_ADDRESS attribute to save space. + RTC_NO_UNIQUE_ADDRESS key_compare comp_; + // Declare after |key_compare_comp_| to workaround GCC ICE. For details + // see https://crbug.com/1156268 + container_type body_; + + // If the compare is not transparent we want to construct key_type once. + template + using KeyTypeOrK = typename std:: + conditional::value, K, key_type>::type; +}; + +// ---------------------------------------------------------------------------- +// Lifetime. 
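The sort_and_unique() member above relies on the fact that, after a stable sort, negating the less-than comparator gives an equivalence test suitable for std::unique. The same idiom on a plain std::vector, purely as a standalone illustration of the technique (this is not the member function itself):

#include <algorithm>
#include <functional>
#include <vector>

// Illustration only: after a stable sort, adjacent elements satisfy
// a <= b, so !(a < b) holds exactly when a == b. std::unique with that
// negated comparator drops equivalent neighbors, keeping the first
// occurrence (stability).
void SortAndUniqueSketch(std::vector<int>& v) {
  std::less<int> comp;
  std::stable_sort(v.begin(), v.end(), comp);
  auto equal = [&comp](const int& a, const int& b) { return !comp(a, b); };
  v.erase(std::unique(v.begin(), v.end(), equal), v.end());
}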
+ +template +flat_tree::flat_tree( + const KeyCompare& comp) + : comp_(comp) {} + +template +template +flat_tree::flat_tree( + InputIterator first, + InputIterator last, + const KeyCompare& comp) + : comp_(comp), body_(first, last) { + sort_and_unique(); +} + +template +flat_tree::flat_tree( + const container_type& items, + const KeyCompare& comp) + : comp_(comp), body_(items) { + sort_and_unique(); +} + +template +flat_tree::flat_tree( + container_type&& items, + const KeyCompare& comp) + : comp_(comp), body_(std::move(items)) { + sort_and_unique(); +} + +template +flat_tree::flat_tree( + std::initializer_list ilist, + const KeyCompare& comp) + : flat_tree(std::begin(ilist), std::end(ilist), comp) {} + +template +template +flat_tree::flat_tree( + sorted_unique_t, + InputIterator first, + InputIterator last, + const KeyCompare& comp) + : comp_(comp), body_(first, last) { + RTC_DCHECK(is_sorted_and_unique(*this, value_comp())); +} + +template +flat_tree::flat_tree( + sorted_unique_t, + const container_type& items, + const KeyCompare& comp) + : comp_(comp), body_(items) { + RTC_DCHECK(is_sorted_and_unique(*this, value_comp())); +} + +template +constexpr flat_tree::flat_tree( + sorted_unique_t, + container_type&& items, + const KeyCompare& comp) + : comp_(comp), body_(std::move(items)) { + RTC_DCHECK(is_sorted_and_unique(*this, value_comp())); +} + +template +flat_tree::flat_tree( + sorted_unique_t, + std::initializer_list ilist, + const KeyCompare& comp) + : flat_tree(sorted_unique, std::begin(ilist), std::end(ilist), comp) {} + +// ---------------------------------------------------------------------------- +// Assignments. + +template +auto flat_tree::operator=( + std::initializer_list ilist) -> flat_tree& { + body_ = ilist; + sort_and_unique(); + return *this; +} + +// ---------------------------------------------------------------------------- +// Memory management. + +template +void flat_tree::reserve( + size_type new_capacity) { + body_.reserve(new_capacity); +} + +template +auto flat_tree::capacity() const + -> size_type { + return body_.capacity(); +} + +template +void flat_tree::shrink_to_fit() { + body_.shrink_to_fit(); +} + +// ---------------------------------------------------------------------------- +// Size management. + +template +void flat_tree::clear() { + body_.clear(); +} + +template +constexpr auto flat_tree::size() + const -> size_type { + return body_.size(); +} + +template +constexpr auto +flat_tree::max_size() const + -> size_type { + return body_.max_size(); +} + +template +constexpr bool flat_tree::empty() + const { + return body_.empty(); +} + +// ---------------------------------------------------------------------------- +// Iterators. 
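For bulk rewrites, the extract()/replace() pair declared earlier avoids erasing and re-inserting element by element. A sketch under the assumption that the transformation preserves ordering and uniqueness (the helper name is invented, not from the patch):

#include <utility>
#include <vector>

#include "rtc_base/containers/flat_set.h"

// Hypothetical helper, not part of the patch: rewrites a flat_set by
// round-tripping through the underlying container.
void DoubleAll(webrtc::flat_set<int>& s) {
  std::vector<int> body = std::move(s).extract();  // s is now empty.
  for (int& value : body)
    value *= 2;
  // Doubling keeps the elements strictly increasing, so the
  // sorted-and-unique precondition of replace() still holds.
  s.replace(std::move(body));
}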
+ +template +auto flat_tree::begin() + -> iterator { + return body_.begin(); +} + +template +constexpr auto flat_tree::begin() + const -> const_iterator { + return std::begin(body_); +} + +template +auto flat_tree::cbegin() const + -> const_iterator { + return body_.cbegin(); +} + +template +auto flat_tree::end() -> iterator { + return body_.end(); +} + +template +constexpr auto flat_tree::end() + const -> const_iterator { + return std::end(body_); +} + +template +auto flat_tree::cend() const + -> const_iterator { + return body_.cend(); +} + +template +auto flat_tree::rbegin() + -> reverse_iterator { + return body_.rbegin(); +} + +template +auto flat_tree::rbegin() const + -> const_reverse_iterator { + return body_.rbegin(); +} + +template +auto flat_tree::crbegin() const + -> const_reverse_iterator { + return body_.crbegin(); +} + +template +auto flat_tree::rend() + -> reverse_iterator { + return body_.rend(); +} + +template +auto flat_tree::rend() const + -> const_reverse_iterator { + return body_.rend(); +} + +template +auto flat_tree::crend() const + -> const_reverse_iterator { + return body_.crend(); +} + +// ---------------------------------------------------------------------------- +// Insert operations. +// +// Currently we use position_hint the same way as eastl or boost: +// https://github.com/electronicarts/EASTL/blob/master/include/EASTL/vector_set.h#L493 + +template +auto flat_tree::insert( + const value_type& val) -> std::pair { + return emplace_key_args(GetKeyFromValue()(val), val); +} + +template +auto flat_tree::insert( + value_type&& val) -> std::pair { + return emplace_key_args(GetKeyFromValue()(val), std::move(val)); +} + +template +auto flat_tree::insert( + const_iterator position_hint, + const value_type& val) -> iterator { + return emplace_hint_key_args(position_hint, GetKeyFromValue()(val), val) + .first; +} + +template +auto flat_tree::insert( + const_iterator position_hint, + value_type&& val) -> iterator { + return emplace_hint_key_args(position_hint, GetKeyFromValue()(val), + std::move(val)) + .first; +} + +template +template +void flat_tree::insert( + InputIterator first, + InputIterator last) { + if (first == last) + return; + + // Dispatch to single element insert if the input range contains a single + // element. + if (is_multipass() && std::next(first) == last) { + insert(end(), *first); + return; + } + + // Provide a convenience lambda to obtain an iterator pointing past the last + // old element. This needs to be dymanic due to possible re-allocations. + auto middle = [this, size = size()] { return std::next(begin(), size); }; + + // For batch updates initialize the first insertion point. + difference_type pos_first_new = size(); + + // Loop over the input range while appending new values and overwriting + // existing ones, if applicable. Keep track of the first insertion point. + for (; first != last; ++first) { + std::pair result = append_unique(begin(), middle(), *first); + if (result.second) { + pos_first_new = + std::min(pos_first_new, std::distance(begin(), result.first)); + } + } + + // The new elements might be unordered and contain duplicates, so post-process + // the just inserted elements and merge them with the rest, inserting them at + // the previously found spot. + sort_and_unique(middle(), end()); + std::inplace_merge(std::next(begin(), pos_first_new), middle(), end(), + value_comp()); +} + +template +template +auto flat_tree::emplace( + Args&&... 
args) -> std::pair { + return insert(value_type(std::forward(args)...)); +} + +template +template +auto flat_tree::emplace_hint( + const_iterator position_hint, + Args&&... args) -> iterator { + return insert(position_hint, value_type(std::forward(args)...)); +} + +// ---------------------------------------------------------------------------- +// Underlying type operations. + +template +auto flat_tree:: + extract() && -> container_type { + return std::exchange(body_, container_type()); +} + +template +void flat_tree::replace( + container_type&& body) { + // Ensure that `body` is sorted and has no repeated elements according to + // `value_comp()`. + RTC_DCHECK(is_sorted_and_unique(body, value_comp())); + body_ = std::move(body); +} + +// ---------------------------------------------------------------------------- +// Erase operations. + +template +auto flat_tree::erase( + iterator position) -> iterator { + RTC_CHECK(position != body_.end()); + return body_.erase(position); +} + +template +template +auto flat_tree::erase( + const_iterator position) -> iterator { + RTC_CHECK(position != body_.end()); + return body_.erase(position); +} + +template +template +auto flat_tree::erase(const K& val) + -> size_type { + auto eq_range = equal_range(val); + auto res = std::distance(eq_range.first, eq_range.second); + erase(eq_range.first, eq_range.second); + return res; +} + +template +auto flat_tree::erase( + const_iterator first, + const_iterator last) -> iterator { + return body_.erase(first, last); +} + +// ---------------------------------------------------------------------------- +// Comparators. + +template +constexpr auto +flat_tree::key_comp() const + -> key_compare { + return comp_; +} + +template +constexpr auto +flat_tree::value_comp() const + -> value_compare { + return value_compare{comp_}; +} + +// ---------------------------------------------------------------------------- +// Search operations. + +template +template +auto flat_tree::count( + const K& key) const -> size_type { + auto eq_range = equal_range(key); + return std::distance(eq_range.first, eq_range.second); +} + +template +template +auto flat_tree::find(const K& key) + -> iterator { + return const_cast_it(webrtc::as_const(*this).find(key)); +} + +template +template +auto flat_tree::find( + const K& key) const -> const_iterator { + auto eq_range = equal_range(key); + return (eq_range.first == eq_range.second) ? 
end() : eq_range.first; +} + +template +template +bool flat_tree::contains( + const K& key) const { + auto lower = lower_bound(key); + return lower != end() && !comp_(key, GetKeyFromValue()(*lower)); +} + +template +template +auto flat_tree::equal_range( + const K& key) -> std::pair { + auto res = webrtc::as_const(*this).equal_range(key); + return {const_cast_it(res.first), const_cast_it(res.second)}; +} + +template +template +auto flat_tree::equal_range( + const K& key) const -> std::pair { + auto lower = lower_bound(key); + + KeyValueCompare comp(comp_); + if (lower == end() || comp(key, *lower)) + return {lower, lower}; + + return {lower, std::next(lower)}; +} + +template +template +auto flat_tree::lower_bound( + const K& key) -> iterator { + return const_cast_it(webrtc::as_const(*this).lower_bound(key)); +} + +template +template +auto flat_tree::lower_bound( + const K& key) const -> const_iterator { + static_assert(std::is_convertible&, const K&>::value, + "Requested type cannot be bound to the container's key_type " + "which is required for a non-transparent compare."); + + const KeyTypeOrK& key_ref = key; + + KeyValueCompare comp(comp_); + return absl::c_lower_bound(*this, key_ref, comp); +} + +template +template +auto flat_tree::upper_bound( + const K& key) -> iterator { + return const_cast_it(webrtc::as_const(*this).upper_bound(key)); +} + +template +template +auto flat_tree::upper_bound( + const K& key) const -> const_iterator { + static_assert(std::is_convertible&, const K&>::value, + "Requested type cannot be bound to the container's key_type " + "which is required for a non-transparent compare."); + + const KeyTypeOrK& key_ref = key; + + KeyValueCompare comp(comp_); + return absl::c_upper_bound(*this, key_ref, comp); +} + +// ---------------------------------------------------------------------------- +// General operations. + +template +void flat_tree::swap( + flat_tree& other) noexcept { + std::swap(*this, other); +} + +template +template +auto flat_tree::unsafe_emplace( + const_iterator position, + Args&&... args) -> iterator { + return body_.emplace(position, std::forward(args)...); +} + +template +template +auto flat_tree::emplace_key_args( + const K& key, + Args&&... args) -> std::pair { + auto lower = lower_bound(key); + if (lower == end() || comp_(key, GetKeyFromValue()(*lower))) + return {unsafe_emplace(lower, std::forward(args)...), true}; + return {lower, false}; +} + +template +template +auto flat_tree:: + emplace_hint_key_args(const_iterator hint, const K& key, Args&&... args) + -> std::pair { + KeyValueCompare comp(comp_); + if ((hint == begin() || comp(*std::prev(hint), key))) { + if (hint == end() || comp(key, *hint)) { + // *(hint - 1) < key < *hint => key did not exist and hint is correct. + return {unsafe_emplace(hint, std::forward(args)...), true}; + } + if (!comp(*hint, key)) { + // key == *hint => no-op, return correct hint. + return {const_cast_it(hint), false}; + } + } + // hint was not helpful, dispatch to hintless version. + return emplace_key_args(key, std::forward(args)...); +} + +// ---------------------------------------------------------------------------- +// Free functions. + +// Erases all elements that match predicate. It has O(size) complexity. 
+template +size_t EraseIf( + webrtc::flat_containers_internal:: + flat_tree& container, + Predicate pred) { + auto it = std::remove_if(container.begin(), container.end(), + std::forward(pred)); + size_t removed = std::distance(it, container.end()); + container.erase(it, container.end()); + return removed; +} + +} // namespace flat_containers_internal +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_FLAT_TREE_H_ diff --git a/rtc_base/containers/flat_tree_unittest.cc b/rtc_base/containers/flat_tree_unittest.cc new file mode 100644 index 0000000000..9bb803d16d --- /dev/null +++ b/rtc_base/containers/flat_tree_unittest.cc @@ -0,0 +1,1484 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#include "rtc_base/containers/flat_tree.h" + +// Following tests are ported and extended tests from libcpp for std::set. +// They can be found here: +// https://github.com/llvm/llvm-project/tree/main/libcxx/test/std/containers/associative/set +// +// Not ported tests: +// * No tests with PrivateConstructor and std::less<> changed to std::less +// These tests have to do with C++14 std::less<> +// http://en.cppreference.com/w/cpp/utility/functional/less_void +// and add support for templated versions of lookup functions. +// Because we use same implementation, we figured that it's OK just to check +// compilation and this is what we do in flat_set_unittest/flat_map_unittest. +// * No tests for max_size() +// Has to do with allocator support. +// * No tests with DefaultOnly. +// Standard containers allocate each element in the separate node on the heap +// and then manipulate these nodes. Flat containers store their elements in +// contiguous memory and move them around, type is required to be movable. +// * No tests for N3644. +// This proposal suggests that all default constructed iterators compare +// equal. Currently we use std::vector iterators and they don't implement +// this. +// * No tests with min_allocator and no tests counting allocations. +// Flat sets currently don't support allocators. 
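The EraseIf() free function above is the erase-remove idiom packaged for these containers: one std::remove_if pass followed by a single range erase, O(size) overall instead of O(size * removed) for repeated single-element erase() calls. Illustrative usage (hypothetical helper, assuming the flat_map/flat_set headers added by this patch):

#include <utility>

#include "rtc_base/containers/flat_map.h"
#include "rtc_base/containers/flat_set.h"

// Hypothetical helper, not part of the patch.
void DropNegatives(webrtc::flat_set<int>& numbers,
                   webrtc::flat_map<int, int>& scores) {
  webrtc::EraseIf(numbers, [](int n) { return n < 0; });
  webrtc::EraseIf(scores, [](const std::pair<int, int>& kv) {
    return kv.second < 0;
  });
}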
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rtc_base/containers/identity.h" +#include "rtc_base/containers/move_only_int.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace flat_containers_internal { +namespace { + +template +class InputIterator { + public: + using iterator_category = std::input_iterator_tag; + using value_type = typename std::iterator_traits::value_type; + using difference_type = typename std::iterator_traits::difference_type; + using pointer = It; + using reference = typename std::iterator_traits::reference; + + InputIterator() : it_() {} + explicit InputIterator(It it) : it_(it) {} + + reference operator*() const { return *it_; } + pointer operator->() const { return it_; } + + InputIterator& operator++() { + ++it_; + return *this; + } + InputIterator operator++(int) { + InputIterator tmp(*this); + ++(*this); + return tmp; + } + + friend bool operator==(const InputIterator& lhs, const InputIterator& rhs) { + return lhs.it_ == rhs.it_; + } + friend bool operator!=(const InputIterator& lhs, const InputIterator& rhs) { + return !(lhs == rhs); + } + + private: + It it_; +}; + +template +InputIterator MakeInputIterator(It it) { + return InputIterator(it); +} + +class Emplaceable { + public: + Emplaceable() : Emplaceable(0, 0.0) {} + Emplaceable(int i, double d) : int_(i), double_(d) {} + Emplaceable(Emplaceable&& other) : int_(other.int_), double_(other.double_) { + other.int_ = 0; + other.double_ = 0.0; + } + Emplaceable(const Emplaceable&) = delete; + Emplaceable& operator=(const Emplaceable&) = delete; + + Emplaceable& operator=(Emplaceable&& other) { + int_ = other.int_; + other.int_ = 0; + double_ = other.double_; + other.double_ = 0.0; + return *this; + } + + friend bool operator==(const Emplaceable& lhs, const Emplaceable& rhs) { + return std::tie(lhs.int_, lhs.double_) == std::tie(rhs.int_, rhs.double_); + } + + friend bool operator<(const Emplaceable& lhs, const Emplaceable& rhs) { + return std::tie(lhs.int_, lhs.double_) < std::tie(rhs.int_, rhs.double_); + } + + private: + int int_; + double double_; +}; + +struct TemplateConstructor { + template + explicit TemplateConstructor(const T&) {} + + friend bool operator<(const TemplateConstructor&, + const TemplateConstructor&) { + return false; + } +}; + +class NonDefaultConstructibleCompare { + public: + explicit NonDefaultConstructibleCompare(int) {} + + template + bool operator()(const T& lhs, const T& rhs) const { + return std::less()(lhs, rhs); + } +}; + +template +struct LessByFirst { + bool operator()(const PairType& lhs, const PairType& rhs) const { + return lhs.first < rhs.first; + } +}; + +// Common test trees. 
+template +using TypedTree = flat_tree, + ContainerT>; +using IntTree = TypedTree>; +using IntPair = std::pair; +using IntPairTree = + flat_tree, std::vector>; +using MoveOnlyTree = + flat_tree, std::vector>; +using EmplaceableTree = + flat_tree, std::vector>; +using ReversedTree = + flat_tree, std::vector>; + +using TreeWithStrangeCompare = + flat_tree>; + +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +template +class FlatTreeTest : public testing::Test {}; +TYPED_TEST_SUITE_P(FlatTreeTest); + +TEST(FlatTree, IsMultipass) { + static_assert(!is_multipass>(), + "InputIterator is not multipass"); + static_assert(!is_multipass>(), + "OutputIterator is not multipass"); + + static_assert(is_multipass::iterator>(), + "ForwardIterator is multipass"); + static_assert(is_multipass::iterator>(), + "BidirectionalIterator is multipass"); + static_assert(is_multipass::iterator>(), + "RandomAccessIterator is multipass"); +} + +// Tests that the compiler generated move operators propagrate noexcept +// specifiers. +TEST(FlatTree, NoExcept) { + struct MoveThrows { + MoveThrows(MoveThrows&&) noexcept(false) {} + MoveThrows& operator=(MoveThrows&&) noexcept(false) { return *this; } + }; + + using MoveThrowsTree = + flat_tree, std::array>; + + static_assert(std::is_nothrow_move_constructible::value, + "Error: IntTree is not nothrow move constructible"); + static_assert(std::is_nothrow_move_assignable::value, + "Error: IntTree is not nothrow move assignable"); + + static_assert(!std::is_nothrow_move_constructible::value, + "Error: MoveThrowsTree is nothrow move constructible"); + static_assert(!std::is_nothrow_move_assignable::value, + "Error: MoveThrowsTree is nothrow move assignable"); +} + +// ---------------------------------------------------------------------------- +// Class. + +// Check that flat_tree and its iterators can be instantiated with an +// incomplete type. + +TEST(FlatTree, IncompleteType) { + struct A { + using Tree = flat_tree, std::vector>; + int data; + Tree set_with_incomplete_type; + Tree::iterator it; + Tree::const_iterator cit; + + // We do not declare operator< because clang complains that it's unused. + }; + + A a; +} + +TEST(FlatTree, Stability) { + using Pair = std::pair; + + using Tree = flat_tree, std::vector>; + + // Constructors are stable. + Tree cont({{0, 0}, {1, 0}, {0, 1}, {2, 0}, {0, 2}, {1, 1}}); + + auto AllOfSecondsAreZero = [&cont] { + return absl::c_all_of(cont, + [](const Pair& elem) { return elem.second == 0; }); + }; + + EXPECT_TRUE(AllOfSecondsAreZero()) << "constructor should be stable"; + + // Should not replace existing. + cont.insert(Pair(0, 2)); + cont.insert(Pair(1, 2)); + cont.insert(Pair(2, 2)); + + EXPECT_TRUE(AllOfSecondsAreZero()) << "insert should be stable"; + + cont.insert(Pair(3, 0)); + cont.insert(Pair(3, 2)); + + EXPECT_TRUE(AllOfSecondsAreZero()) << "insert should be stable"; +} + +// ---------------------------------------------------------------------------- +// Types. + +// key_type +// key_compare +// value_type +// value_compare +// pointer +// const_pointer +// reference +// const_reference +// size_type +// difference_type +// iterator +// const_iterator +// reverse_iterator +// const_reverse_iterator + +TEST(FlatTree, Types) { + // These are guaranteed to be portable. 
+ static_assert((std::is_same::value), ""); + static_assert((std::is_same::value), ""); + static_assert((std::is_same, IntTree::key_compare>::value), ""); + static_assert((std::is_same::value), ""); + static_assert((std::is_same::value), + ""); + static_assert((std::is_same::value), ""); + static_assert((std::is_same::value), ""); +} + +// ---------------------------------------------------------------------------- +// Lifetime. + +// flat_tree() +// flat_tree(const Compare& comp) + +TYPED_TEST_P(FlatTreeTest, DefaultConstructor) { + { + TypedTree cont; + EXPECT_THAT(cont, ElementsAre()); + } + + { + TreeWithStrangeCompare cont(NonDefaultConstructibleCompare(0)); + EXPECT_THAT(cont, ElementsAre()); + } +} + +// flat_tree(const flat_tree& x) + +TYPED_TEST_P(FlatTreeTest, CopyConstructor) { + TypedTree original({1, 2, 3, 4}); + TypedTree copied(original); + + EXPECT_THAT(copied, ElementsAre(1, 2, 3, 4)); + + EXPECT_THAT(copied, ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(original, ElementsAre(1, 2, 3, 4)); + EXPECT_EQ(original, copied); +} + +// flat_tree(flat_tree&& x) + +TEST(FlatTree, MoveConstructor) { + int input_range[] = {1, 2, 3, 4}; + + MoveOnlyTree original(std::begin(input_range), std::end(input_range)); + MoveOnlyTree moved(std::move(original)); + + EXPECT_EQ(1U, moved.count(MoveOnlyInt(1))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(2))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(3))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(4))); +} + +// flat_tree(InputIterator first, +// InputIterator last, +// const Compare& comp = Compare()) + +TEST(FlatTree, RangeConstructor) { + { + IntPair input_vals[] = {{1, 1}, {1, 2}, {2, 1}, {2, 2}, {1, 3}, + {2, 3}, {3, 1}, {3, 2}, {3, 3}}; + + IntPairTree first_of(MakeInputIterator(std::begin(input_vals)), + MakeInputIterator(std::end(input_vals))); + EXPECT_THAT(first_of, + ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1))); + } + { + TreeWithStrangeCompare::value_type input_vals[] = {1, 1, 1, 2, 2, + 2, 3, 3, 3}; + + TreeWithStrangeCompare cont(MakeInputIterator(std::begin(input_vals)), + MakeInputIterator(std::end(input_vals)), + NonDefaultConstructibleCompare(0)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3)); + } +} + +// flat_tree(const container_type&) + +TYPED_TEST_P(FlatTreeTest, ContainerCopyConstructor) { + TypeParam items = {1, 2, 3, 4}; + TypedTree tree(items); + + EXPECT_THAT(tree, ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(items, ElementsAre(1, 2, 3, 4)); +} + +// flat_tree(container_type&&) + +TEST(FlatTree, ContainerMoveConstructor) { + using Pair = std::pair; + + // Construct an unsorted vector with a duplicate item in it. Sorted by the + // first item, the second allows us to test for stability. Using a move + // only type to ensure the vector is not copied. + std::vector storage; + storage.push_back(Pair(2, MoveOnlyInt(0))); + storage.push_back(Pair(1, MoveOnlyInt(0))); + storage.push_back(Pair(2, MoveOnlyInt(1))); + + using Tree = flat_tree, std::vector>; + Tree tree(std::move(storage)); + + // The list should be two items long, with only the first "2" saved. 
+ ASSERT_EQ(2u, tree.size()); + const Pair& zeroth = *tree.begin(); + ASSERT_EQ(1, zeroth.first); + ASSERT_EQ(0, zeroth.second.data()); + + const Pair& first = *(tree.begin() + 1); + ASSERT_EQ(2, first.first); + ASSERT_EQ(0, first.second.data()); +} + +// flat_tree(std::initializer_list ilist, +// const Compare& comp = Compare()) + +TYPED_TEST_P(FlatTreeTest, InitializerListConstructor) { + { + TypedTree cont({1, 2, 3, 4, 5, 6, 10, 8}); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); + } + { + TypedTree cont({1, 2, 3, 4, 5, 6, 10, 8}); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); + } + { + TreeWithStrangeCompare cont({1, 2, 3, 4, 5, 6, 10, 8}, + NonDefaultConstructibleCompare(0)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); + } + { + IntPairTree first_of({{1, 1}, {2, 1}, {1, 2}}); + EXPECT_THAT(first_of, ElementsAre(IntPair(1, 1), IntPair(2, 1))); + } +} + +// flat_tree(sorted_unique_t, +// InputIterator first, +// InputIterator last, +// const Compare& comp = Compare()) + +TEST(FlatTree, SortedUniqueRangeConstructor) { + { + IntPair input_vals[] = {{1, 1}, {2, 1}, {3, 1}}; + + IntPairTree first_of(sorted_unique, + MakeInputIterator(std::begin(input_vals)), + MakeInputIterator(std::end(input_vals))); + EXPECT_THAT(first_of, + ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1))); + } + { + TreeWithStrangeCompare::value_type input_vals[] = {1, 2, 3}; + + TreeWithStrangeCompare cont(sorted_unique, + MakeInputIterator(std::begin(input_vals)), + MakeInputIterator(std::end(input_vals)), + NonDefaultConstructibleCompare(0)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3)); + } +} + +// flat_tree(sorted_unique_t, const container_type&) + +TYPED_TEST_P(FlatTreeTest, SortedUniqueContainerCopyConstructor) { + TypeParam items = {1, 2, 3, 4}; + TypedTree tree(sorted_unique, items); + + EXPECT_THAT(tree, ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(items, ElementsAre(1, 2, 3, 4)); +} + +// flat_tree(sorted_unique_t, std::vector&&) + +TEST(FlatTree, SortedUniqueVectorMoveConstructor) { + using Pair = std::pair; + + std::vector storage; + storage.push_back(Pair(1, MoveOnlyInt(0))); + storage.push_back(Pair(2, MoveOnlyInt(0))); + + using Tree = flat_tree, std::vector>; + Tree tree(sorted_unique, std::move(storage)); + + ASSERT_EQ(2u, tree.size()); + const Pair& zeroth = *tree.begin(); + ASSERT_EQ(1, zeroth.first); + ASSERT_EQ(0, zeroth.second.data()); + + const Pair& first = *(tree.begin() + 1); + ASSERT_EQ(2, first.first); + ASSERT_EQ(0, first.second.data()); +} + +// flat_tree(sorted_unique_t, +// std::initializer_list ilist, +// const Compare& comp = Compare()) + +TYPED_TEST_P(FlatTreeTest, SortedUniqueInitializerListConstructor) { + { + TypedTree cont(sorted_unique, {1, 2, 3, 4, 5, 6, 8, 10}); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); + } + { + TypedTree cont(sorted_unique, {1, 2, 3, 4, 5, 6, 8, 10}); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); + } + { + TreeWithStrangeCompare cont(sorted_unique, {1, 2, 3, 4, 5, 6, 8, 10}, + NonDefaultConstructibleCompare(0)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); + } + { + IntPairTree first_of(sorted_unique, {{1, 1}, {2, 1}}); + EXPECT_THAT(first_of, ElementsAre(IntPair(1, 1), IntPair(2, 1))); + } +} + +// ---------------------------------------------------------------------------- +// Assignments. 
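The sorted_unique constructors exercised above exist so that callers who already hold sorted, duplicate-free data can skip the sort-and-dedupe pass that the ordinary range and container constructors perform. A hedged sketch, using an IntTree alias that mirrors the tests' TypedTree instantiation (the exact namespace of sorted_unique is assumed to match the Chromium original):

#include <functional>
#include <utility>
#include <vector>

#include "rtc_base/containers/flat_tree.h"
#include "rtc_base/containers/identity.h"

// Set-like instantiation, mirroring the tests above (assumed parameter order:
// key, key extractor, comparator, underlying container).
using IntTree = webrtc::flat_containers_internal::
    flat_tree<int, webrtc::identity, std::less<int>, std::vector<int>>;

void ConstructorSketch() {
  // Ordinary container constructor: sorts and removes duplicates itself.
  IntTree from_unsorted(std::vector<int>{3, 1, 4, 2, 3});

  // sorted_unique tag: the caller promises the input is already sorted and
  // duplicate-free, so the tree can adopt the vector without the
  // O(N log N) sort-and-dedupe step.
  std::vector<int> sorted = {1, 2, 3, 4};
  IntTree from_sorted(webrtc::sorted_unique, std::move(sorted));
}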
+ +// flat_tree& operator=(const flat_tree&) + +TYPED_TEST_P(FlatTreeTest, CopyAssignable) { + TypedTree original({1, 2, 3, 4}); + TypedTree copied; + copied = original; + + EXPECT_THAT(copied, ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(original, ElementsAre(1, 2, 3, 4)); + EXPECT_EQ(original, copied); +} + +// flat_tree& operator=(flat_tree&&) + +TEST(FlatTree, MoveAssignable) { + int input_range[] = {1, 2, 3, 4}; + + MoveOnlyTree original(std::begin(input_range), std::end(input_range)); + MoveOnlyTree moved; + moved = std::move(original); + + EXPECT_EQ(1U, moved.count(MoveOnlyInt(1))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(2))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(3))); + EXPECT_EQ(1U, moved.count(MoveOnlyInt(4))); +} + +// flat_tree& operator=(std::initializer_list ilist) + +TYPED_TEST_P(FlatTreeTest, InitializerListAssignable) { + TypedTree cont({0}); + cont = {1, 2, 3, 4, 5, 6, 10, 8}; + + EXPECT_EQ(0U, cont.count(0)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 8, 10)); +} + +// -------------------------------------------------------------------------- +// Memory management. + +// void reserve(size_type new_capacity) + +TEST(FlatTreeTest, Reserve) { + IntTree cont({1, 2, 3}); + + cont.reserve(5); + EXPECT_LE(5U, cont.capacity()); +} + +// size_type capacity() const + +TEST(FlatTreeTest, Capacity) { + IntTree cont({1, 2, 3}); + + EXPECT_LE(cont.size(), cont.capacity()); + cont.reserve(5); + EXPECT_LE(cont.size(), cont.capacity()); +} + +// void shrink_to_fit() + +TEST(FlatTreeTest, ShrinkToFit) { + IntTree cont({1, 2, 3}); + + IntTree::size_type capacity_before = cont.capacity(); + cont.shrink_to_fit(); + EXPECT_GE(capacity_before, cont.capacity()); +} + +// ---------------------------------------------------------------------------- +// Size management. + +// void clear() + +TYPED_TEST_P(FlatTreeTest, Clear) { + TypedTree cont({1, 2, 3, 4, 5, 6, 7, 8}); + cont.clear(); + EXPECT_THAT(cont, ElementsAre()); +} + +// size_type size() const + +TYPED_TEST_P(FlatTreeTest, Size) { + TypedTree cont; + + EXPECT_EQ(0U, cont.size()); + cont.insert(2); + EXPECT_EQ(1U, cont.size()); + cont.insert(1); + EXPECT_EQ(2U, cont.size()); + cont.insert(3); + EXPECT_EQ(3U, cont.size()); + cont.erase(cont.begin()); + EXPECT_EQ(2U, cont.size()); + cont.erase(cont.begin()); + EXPECT_EQ(1U, cont.size()); + cont.erase(cont.begin()); + EXPECT_EQ(0U, cont.size()); +} + +// bool empty() const + +TYPED_TEST_P(FlatTreeTest, Empty) { + TypedTree cont; + + EXPECT_TRUE(cont.empty()); + cont.insert(1); + EXPECT_FALSE(cont.empty()); + cont.clear(); + EXPECT_TRUE(cont.empty()); +} + +// ---------------------------------------------------------------------------- +// Iterators. 
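The iterator tests below lean on the fact that a flat_tree's iterators are simply the underlying vector's iterators. A sketch of the practical consequences, reusing the assumed IntTree alias from the previous sketch:

#include <functional>
#include <iterator>
#include <vector>

#include "rtc_base/containers/flat_tree.h"
#include "rtc_base/containers/identity.h"

using IntTree = webrtc::flat_containers_internal::
    flat_tree<int, webrtc::identity, std::less<int>, std::vector<int>>;

void IteratorSketch() {
  IntTree cont({1, 2, 3, 4, 5, 6, 7, 8});

  // Iterators come from the underlying std::vector, i.e. random access:
  // std::next and std::distance are O(1), unlike their std::set counterparts.
  auto third = std::next(cont.begin(), 2);  // points at 3

  // The flip side of contiguous storage: inserts and erases may shift or
  // reallocate the vector, so existing iterators must be treated as
  // invalidated and re-derived after any mutation.
  cont.insert(9);
  third = std::next(cont.begin(), 2);
  (void)third;
}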
+ +// iterator begin() +// const_iterator begin() const +// iterator end() +// const_iterator end() const +// +// reverse_iterator rbegin() +// const_reverse_iterator rbegin() const +// reverse_iterator rend() +// const_reverse_iterator rend() const +// +// const_iterator cbegin() const +// const_iterator cend() const +// const_reverse_iterator crbegin() const +// const_reverse_iterator crend() const + +TYPED_TEST_P(FlatTreeTest, Iterators) { + TypedTree cont({1, 2, 3, 4, 5, 6, 7, 8}); + + auto size = + static_cast::difference_type>(cont.size()); + + EXPECT_EQ(size, std::distance(cont.begin(), cont.end())); + EXPECT_EQ(size, std::distance(cont.cbegin(), cont.cend())); + EXPECT_EQ(size, std::distance(cont.rbegin(), cont.rend())); + EXPECT_EQ(size, std::distance(cont.crbegin(), cont.crend())); + + { + auto it = cont.begin(); + auto c_it = cont.cbegin(); + EXPECT_EQ(it, c_it); + for (int j = 1; it != cont.end(); ++it, ++c_it, ++j) { + EXPECT_EQ(j, *it); + EXPECT_EQ(j, *c_it); + } + } + { + auto rit = cont.rbegin(); + auto c_rit = cont.crbegin(); + EXPECT_EQ(rit, c_rit); + for (int j = static_cast(size); rit != cont.rend(); + ++rit, ++c_rit, --j) { + EXPECT_EQ(j, *rit); + EXPECT_EQ(j, *c_rit); + } + } +} + +// ---------------------------------------------------------------------------- +// Insert operations. + +// pair insert(const value_type& val) + +TYPED_TEST_P(FlatTreeTest, InsertLValue) { + TypedTree cont; + + int value = 2; + std::pair::iterator, bool> result = + cont.insert(value); + EXPECT_TRUE(result.second); + EXPECT_EQ(cont.begin(), result.first); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(2, *result.first); + + value = 1; + result = cont.insert(value); + EXPECT_TRUE(result.second); + EXPECT_EQ(cont.begin(), result.first); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(1, *result.first); + + value = 3; + result = cont.insert(value); + EXPECT_TRUE(result.second); + EXPECT_EQ(std::prev(cont.end()), result.first); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, *result.first); + + value = 3; + result = cont.insert(value); + EXPECT_FALSE(result.second); + EXPECT_EQ(std::prev(cont.end()), result.first); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, *result.first); +} + +// pair insert(value_type&& val) + +TEST(FlatTree, InsertRValue) { + MoveOnlyTree cont; + + std::pair result = cont.insert(MoveOnlyInt(2)); + EXPECT_TRUE(result.second); + EXPECT_EQ(cont.begin(), result.first); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(2, result.first->data()); + + result = cont.insert(MoveOnlyInt(1)); + EXPECT_TRUE(result.second); + EXPECT_EQ(cont.begin(), result.first); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(1, result.first->data()); + + result = cont.insert(MoveOnlyInt(3)); + EXPECT_TRUE(result.second); + EXPECT_EQ(std::prev(cont.end()), result.first); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, result.first->data()); + + result = cont.insert(MoveOnlyInt(3)); + EXPECT_FALSE(result.second); + EXPECT_EQ(std::prev(cont.end()), result.first); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, result.first->data()); +} + +// iterator insert(const_iterator position_hint, const value_type& val) + +TYPED_TEST_P(FlatTreeTest, InsertPositionLValue) { + TypedTree cont; + + auto result = cont.insert(cont.cend(), 2); + EXPECT_EQ(cont.begin(), result); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(2, *result); + + result = cont.insert(cont.cend(), 1); + EXPECT_EQ(cont.begin(), result); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(1, *result); + + result = cont.insert(cont.cend(), 3); + EXPECT_EQ(std::prev(cont.end()), result); + 
EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, *result); + + result = cont.insert(cont.cend(), 3); + EXPECT_EQ(std::prev(cont.end()), result); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, *result); +} + +// iterator insert(const_iterator position_hint, value_type&& val) + +TEST(FlatTree, InsertPositionRValue) { + MoveOnlyTree cont; + + auto result = cont.insert(cont.cend(), MoveOnlyInt(2)); + EXPECT_EQ(cont.begin(), result); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(2, result->data()); + + result = cont.insert(cont.cend(), MoveOnlyInt(1)); + EXPECT_EQ(cont.begin(), result); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(1, result->data()); + + result = cont.insert(cont.cend(), MoveOnlyInt(3)); + EXPECT_EQ(std::prev(cont.end()), result); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, result->data()); + + result = cont.insert(cont.cend(), MoveOnlyInt(3)); + EXPECT_EQ(std::prev(cont.end()), result); + EXPECT_EQ(3U, cont.size()); + EXPECT_EQ(3, result->data()); +} + +// template +// void insert(InputIterator first, InputIterator last); + +TEST(FlatTree, InsertIterIter) { + struct GetKeyFromIntIntPair { + const int& operator()(const std::pair& p) const { + return p.first; + } + }; + + using IntIntMap = flat_tree, + std::vector>; + + { + IntIntMap cont; + IntPair int_pairs[] = {{3, 1}, {1, 1}, {4, 1}, {2, 1}}; + cont.insert(std::begin(int_pairs), std::end(int_pairs)); + EXPECT_THAT(cont, ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1), + IntPair(4, 1))); + } + + { + IntIntMap cont({{1, 1}, {2, 1}, {3, 1}, {4, 1}}); + std::vector int_pairs; + cont.insert(std::begin(int_pairs), std::end(int_pairs)); + EXPECT_THAT(cont, ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1), + IntPair(4, 1))); + } + + { + IntIntMap cont({{1, 1}, {2, 1}, {3, 1}, {4, 1}}); + IntPair int_pairs[] = {{1, 1}}; + cont.insert(std::begin(int_pairs), std::end(int_pairs)); + EXPECT_THAT(cont, ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1), + IntPair(4, 1))); + } + + { + IntIntMap cont({{1, 1}, {2, 1}, {3, 1}, {4, 1}}); + IntPair int_pairs[] = {{5, 1}}; + cont.insert(std::begin(int_pairs), std::end(int_pairs)); + EXPECT_THAT(cont, ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1), + IntPair(4, 1), IntPair(5, 1))); + } + + { + IntIntMap cont({{1, 1}, {2, 1}, {3, 1}, {4, 1}}); + IntPair int_pairs[] = {{3, 2}, {1, 2}, {4, 2}, {2, 2}}; + cont.insert(std::begin(int_pairs), std::end(int_pairs)); + EXPECT_THAT(cont, ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1), + IntPair(4, 1))); + } + + { + IntIntMap cont({{1, 1}, {2, 1}, {3, 1}, {4, 1}}); + IntPair int_pairs[] = {{3, 2}, {1, 2}, {4, 2}, {2, 2}, {7, 2}, {6, 2}, + {8, 2}, {5, 2}, {5, 3}, {6, 3}, {7, 3}, {8, 3}}; + cont.insert(std::begin(int_pairs), std::end(int_pairs)); + EXPECT_THAT(cont, ElementsAre(IntPair(1, 1), IntPair(2, 1), IntPair(3, 1), + IntPair(4, 1), IntPair(5, 2), IntPair(6, 2), + IntPair(7, 2), IntPair(8, 2))); + } +} + +// template +// pair emplace(Args&&... 
args) + +TYPED_TEST_P(FlatTreeTest, Emplace) { + { + EmplaceableTree cont; + + std::pair result = cont.emplace(); + EXPECT_TRUE(result.second); + EXPECT_EQ(cont.begin(), result.first); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(Emplaceable(), *cont.begin()); + + result = cont.emplace(2, 3.5); + EXPECT_TRUE(result.second); + EXPECT_EQ(std::next(cont.begin()), result.first); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(Emplaceable(2, 3.5), *result.first); + + result = cont.emplace(2, 3.5); + EXPECT_FALSE(result.second); + EXPECT_EQ(std::next(cont.begin()), result.first); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(Emplaceable(2, 3.5), *result.first); + } + { + TypedTree cont; + + std::pair::iterator, bool> result = + cont.emplace(2); + EXPECT_TRUE(result.second); + EXPECT_EQ(cont.begin(), result.first); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(2, *result.first); + } +} + +// template +// iterator emplace_hint(const_iterator position_hint, Args&&... args) + +TYPED_TEST_P(FlatTreeTest, EmplacePosition) { + { + EmplaceableTree cont; + + auto result = cont.emplace_hint(cont.cend()); + EXPECT_EQ(cont.begin(), result); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(Emplaceable(), *cont.begin()); + + result = cont.emplace_hint(cont.cend(), 2, 3.5); + EXPECT_EQ(std::next(cont.begin()), result); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(Emplaceable(2, 3.5), *result); + + result = cont.emplace_hint(cont.cbegin(), 2, 3.5); + EXPECT_EQ(std::next(cont.begin()), result); + EXPECT_EQ(2U, cont.size()); + EXPECT_EQ(Emplaceable(2, 3.5), *result); + } + { + TypedTree cont; + + auto result = cont.emplace_hint(cont.cend(), 2); + EXPECT_EQ(cont.begin(), result); + EXPECT_EQ(1U, cont.size()); + EXPECT_EQ(2, *result); + } +} + +// ---------------------------------------------------------------------------- +// Underlying type operations. + +// underlying_type extract() && +TYPED_TEST_P(FlatTreeTest, Extract) { + TypedTree cont; + cont.emplace(3); + cont.emplace(1); + cont.emplace(2); + cont.emplace(4); + + TypeParam body = std::move(cont).extract(); + EXPECT_THAT(cont, IsEmpty()); + EXPECT_THAT(body, ElementsAre(1, 2, 3, 4)); +} + +// replace(underlying_type&&) +TYPED_TEST_P(FlatTreeTest, Replace) { + TypeParam body = {1, 2, 3, 4}; + TypedTree cont; + cont.replace(std::move(body)); + + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4)); +} + +// ---------------------------------------------------------------------------- +// Erase operations. 
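The Extract and Replace tests above cover the escape hatch for bulk updates: move the underlying vector out, edit it directly, and hand it back. A hypothetical helper sketching that pattern for any vector<int>-backed, set-like instantiation (such as the IntTree alias assumed earlier):

#include <algorithm>
#include <utility>
#include <vector>

// FlatIntSet stands for any set-like flat_tree over std::vector<int>.
template <typename FlatIntSet>
void DoubleAllValues(FlatIntSet& cont) {
  // extract() is rvalue-qualified: it moves the sorted vector out and leaves
  // the flat container empty.
  std::vector<int> body = std::move(cont).extract();

  // Edit the plain vector in bulk, then restore the invariant ourselves...
  for (int& v : body)
    v *= 2;
  std::sort(body.begin(), body.end());
  body.erase(std::unique(body.begin(), body.end()), body.end());

  // ...because replace() adopts the container as-is, trusting the caller to
  // have kept it sorted and duplicate-free.
  cont.replace(std::move(body));
}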
+ +// iterator erase(const_iterator position_hint) + +TYPED_TEST_P(FlatTreeTest, ErasePosition) { + { + TypedTree cont({1, 2, 3, 4, 5, 6, 7, 8}); + + auto it = cont.erase(std::next(cont.cbegin(), 3)); + EXPECT_EQ(std::next(cont.begin(), 3), it); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 5, 6, 7, 8)); + + it = cont.erase(std::next(cont.cbegin(), 0)); + EXPECT_EQ(cont.begin(), it); + EXPECT_THAT(cont, ElementsAre(2, 3, 5, 6, 7, 8)); + + it = cont.erase(std::next(cont.cbegin(), 5)); + EXPECT_EQ(cont.end(), it); + EXPECT_THAT(cont, ElementsAre(2, 3, 5, 6, 7)); + + it = cont.erase(std::next(cont.cbegin(), 1)); + EXPECT_EQ(std::next(cont.begin()), it); + EXPECT_THAT(cont, ElementsAre(2, 5, 6, 7)); + + it = cont.erase(std::next(cont.cbegin(), 2)); + EXPECT_EQ(std::next(cont.begin(), 2), it); + EXPECT_THAT(cont, ElementsAre(2, 5, 7)); + + it = cont.erase(std::next(cont.cbegin(), 2)); + EXPECT_EQ(std::next(cont.begin(), 2), it); + EXPECT_THAT(cont, ElementsAre(2, 5)); + + it = cont.erase(std::next(cont.cbegin(), 0)); + EXPECT_EQ(std::next(cont.begin(), 0), it); + EXPECT_THAT(cont, ElementsAre(5)); + + it = cont.erase(cont.cbegin()); + EXPECT_EQ(cont.begin(), it); + EXPECT_EQ(cont.end(), it); + } + // This is LWG #2059. + // There is a potential ambiguity between erase with an iterator and erase + // with a key, if key has a templated constructor. + { + using T = TemplateConstructor; + + flat_tree, std::vector> cont; + T v(0); + + auto it = cont.find(v); + if (it != cont.end()) + cont.erase(it); + } +} + +// iterator erase(const_iterator first, const_iterator last) + +TYPED_TEST_P(FlatTreeTest, EraseRange) { + TypedTree cont({1, 2, 3, 4, 5, 6, 7, 8}); + + auto it = + cont.erase(std::next(cont.cbegin(), 5), std::next(cont.cbegin(), 5)); + EXPECT_EQ(std::next(cont.begin(), 5), it); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8)); + + it = cont.erase(std::next(cont.cbegin(), 3), std::next(cont.cbegin(), 4)); + EXPECT_EQ(std::next(cont.begin(), 3), it); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 5, 6, 7, 8)); + + it = cont.erase(std::next(cont.cbegin(), 2), std::next(cont.cbegin(), 5)); + EXPECT_EQ(std::next(cont.begin(), 2), it); + EXPECT_THAT(cont, ElementsAre(1, 2, 7, 8)); + + it = cont.erase(std::next(cont.cbegin(), 0), std::next(cont.cbegin(), 2)); + EXPECT_EQ(std::next(cont.begin(), 0), it); + EXPECT_THAT(cont, ElementsAre(7, 8)); + + it = cont.erase(cont.cbegin(), cont.cend()); + EXPECT_EQ(cont.begin(), it); + EXPECT_EQ(cont.end(), it); +} + +// size_type erase(const key_type& key) + +TYPED_TEST_P(FlatTreeTest, EraseKey) { + TypedTree cont({1, 2, 3, 4, 5, 6, 7, 8}); + + EXPECT_EQ(0U, cont.erase(9)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8)); + + EXPECT_EQ(1U, cont.erase(4)); + EXPECT_THAT(cont, ElementsAre(1, 2, 3, 5, 6, 7, 8)); + + EXPECT_EQ(1U, cont.erase(1)); + EXPECT_THAT(cont, ElementsAre(2, 3, 5, 6, 7, 8)); + + EXPECT_EQ(1U, cont.erase(8)); + EXPECT_THAT(cont, ElementsAre(2, 3, 5, 6, 7)); + + EXPECT_EQ(1U, cont.erase(3)); + EXPECT_THAT(cont, ElementsAre(2, 5, 6, 7)); + + EXPECT_EQ(1U, cont.erase(6)); + EXPECT_THAT(cont, ElementsAre(2, 5, 7)); + + EXPECT_EQ(1U, cont.erase(7)); + EXPECT_THAT(cont, ElementsAre(2, 5)); + + EXPECT_EQ(1U, cont.erase(2)); + EXPECT_THAT(cont, ElementsAre(5)); + + EXPECT_EQ(1U, cont.erase(5)); + EXPECT_THAT(cont, ElementsAre()); +} + +TYPED_TEST_P(FlatTreeTest, EraseEndDeath) { + { + TypedTree tree; + ASSERT_DEATH_IF_SUPPORTED(tree.erase(tree.cend()), ""); + } + + { + TypedTree tree = {1, 2, 3, 4}; + 
ASSERT_DEATH_IF_SUPPORTED(tree.erase(tree.find(5)), ""); + } +} + +// ---------------------------------------------------------------------------- +// Comparators. + +// key_compare key_comp() const + +TEST(FlatTree, KeyComp) { + ReversedTree cont({1, 2, 3, 4, 5}); + + EXPECT_TRUE(absl::c_is_sorted(cont, cont.key_comp())); + int new_elements[] = {6, 7, 8, 9, 10}; + std::copy(std::begin(new_elements), std::end(new_elements), + std::inserter(cont, cont.end())); + EXPECT_TRUE(absl::c_is_sorted(cont, cont.key_comp())); +} + +// value_compare value_comp() const + +TEST(FlatTree, ValueComp) { + ReversedTree cont({1, 2, 3, 4, 5}); + + EXPECT_TRUE(absl::c_is_sorted(cont, cont.value_comp())); + int new_elements[] = {6, 7, 8, 9, 10}; + std::copy(std::begin(new_elements), std::end(new_elements), + std::inserter(cont, cont.end())); + EXPECT_TRUE(absl::c_is_sorted(cont, cont.value_comp())); +} + +// ---------------------------------------------------------------------------- +// Search operations. + +// size_type count(const key_type& key) const + +TYPED_TEST_P(FlatTreeTest, Count) { + const TypedTree cont({5, 6, 7, 8, 9, 10, 11, 12}); + + EXPECT_EQ(1U, cont.count(5)); + EXPECT_EQ(1U, cont.count(6)); + EXPECT_EQ(1U, cont.count(7)); + EXPECT_EQ(1U, cont.count(8)); + EXPECT_EQ(1U, cont.count(9)); + EXPECT_EQ(1U, cont.count(10)); + EXPECT_EQ(1U, cont.count(11)); + EXPECT_EQ(1U, cont.count(12)); + EXPECT_EQ(0U, cont.count(4)); +} + +// iterator find(const key_type& key) +// const_iterator find(const key_type& key) const + +TYPED_TEST_P(FlatTreeTest, Find) { + { + TypedTree cont({5, 6, 7, 8, 9, 10, 11, 12}); + + EXPECT_EQ(cont.begin(), cont.find(5)); + EXPECT_EQ(std::next(cont.begin()), cont.find(6)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.find(7)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.find(8)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.find(9)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.find(10)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.find(11)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.find(12)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.find(4)); + } + { + const TypedTree cont({5, 6, 7, 8, 9, 10, 11, 12}); + + EXPECT_EQ(cont.begin(), cont.find(5)); + EXPECT_EQ(std::next(cont.begin()), cont.find(6)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.find(7)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.find(8)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.find(9)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.find(10)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.find(11)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.find(12)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.find(4)); + } +} + +// bool contains(const key_type& key) const + +TYPED_TEST_P(FlatTreeTest, Contains) { + const TypedTree cont({5, 6, 7, 8, 9, 10, 11, 12}); + + EXPECT_TRUE(cont.contains(5)); + EXPECT_TRUE(cont.contains(6)); + EXPECT_TRUE(cont.contains(7)); + EXPECT_TRUE(cont.contains(8)); + EXPECT_TRUE(cont.contains(9)); + EXPECT_TRUE(cont.contains(10)); + EXPECT_TRUE(cont.contains(11)); + EXPECT_TRUE(cont.contains(12)); + EXPECT_FALSE(cont.contains(4)); +} + +// pair equal_range(const key_type& key) +// pair equal_range(const key_type& key) const + +TYPED_TEST_P(FlatTreeTest, EqualRange) { + { + TypedTree cont({5, 7, 9, 11, 13, 15, 17, 19}); + + std::pair::iterator, + typename TypedTree::iterator> + result = cont.equal_range(5); + EXPECT_EQ(std::next(cont.begin(), 0), result.first); + EXPECT_EQ(std::next(cont.begin(), 1), result.second); + result = cont.equal_range(7); + 
EXPECT_EQ(std::next(cont.begin(), 1), result.first); + EXPECT_EQ(std::next(cont.begin(), 2), result.second); + result = cont.equal_range(9); + EXPECT_EQ(std::next(cont.begin(), 2), result.first); + EXPECT_EQ(std::next(cont.begin(), 3), result.second); + result = cont.equal_range(11); + EXPECT_EQ(std::next(cont.begin(), 3), result.first); + EXPECT_EQ(std::next(cont.begin(), 4), result.second); + result = cont.equal_range(13); + EXPECT_EQ(std::next(cont.begin(), 4), result.first); + EXPECT_EQ(std::next(cont.begin(), 5), result.second); + result = cont.equal_range(15); + EXPECT_EQ(std::next(cont.begin(), 5), result.first); + EXPECT_EQ(std::next(cont.begin(), 6), result.second); + result = cont.equal_range(17); + EXPECT_EQ(std::next(cont.begin(), 6), result.first); + EXPECT_EQ(std::next(cont.begin(), 7), result.second); + result = cont.equal_range(19); + EXPECT_EQ(std::next(cont.begin(), 7), result.first); + EXPECT_EQ(std::next(cont.begin(), 8), result.second); + result = cont.equal_range(4); + EXPECT_EQ(std::next(cont.begin(), 0), result.first); + EXPECT_EQ(std::next(cont.begin(), 0), result.second); + result = cont.equal_range(6); + EXPECT_EQ(std::next(cont.begin(), 1), result.first); + EXPECT_EQ(std::next(cont.begin(), 1), result.second); + result = cont.equal_range(8); + EXPECT_EQ(std::next(cont.begin(), 2), result.first); + EXPECT_EQ(std::next(cont.begin(), 2), result.second); + result = cont.equal_range(10); + EXPECT_EQ(std::next(cont.begin(), 3), result.first); + EXPECT_EQ(std::next(cont.begin(), 3), result.second); + result = cont.equal_range(12); + EXPECT_EQ(std::next(cont.begin(), 4), result.first); + EXPECT_EQ(std::next(cont.begin(), 4), result.second); + result = cont.equal_range(14); + EXPECT_EQ(std::next(cont.begin(), 5), result.first); + EXPECT_EQ(std::next(cont.begin(), 5), result.second); + result = cont.equal_range(16); + EXPECT_EQ(std::next(cont.begin(), 6), result.first); + EXPECT_EQ(std::next(cont.begin(), 6), result.second); + result = cont.equal_range(18); + EXPECT_EQ(std::next(cont.begin(), 7), result.first); + EXPECT_EQ(std::next(cont.begin(), 7), result.second); + result = cont.equal_range(20); + EXPECT_EQ(std::next(cont.begin(), 8), result.first); + EXPECT_EQ(std::next(cont.begin(), 8), result.second); + } + { + const TypedTree cont({5, 7, 9, 11, 13, 15, 17, 19}); + + std::pair::const_iterator, + typename TypedTree::const_iterator> + result = cont.equal_range(5); + EXPECT_EQ(std::next(cont.begin(), 0), result.first); + EXPECT_EQ(std::next(cont.begin(), 1), result.second); + result = cont.equal_range(7); + EXPECT_EQ(std::next(cont.begin(), 1), result.first); + EXPECT_EQ(std::next(cont.begin(), 2), result.second); + result = cont.equal_range(9); + EXPECT_EQ(std::next(cont.begin(), 2), result.first); + EXPECT_EQ(std::next(cont.begin(), 3), result.second); + result = cont.equal_range(11); + EXPECT_EQ(std::next(cont.begin(), 3), result.first); + EXPECT_EQ(std::next(cont.begin(), 4), result.second); + result = cont.equal_range(13); + EXPECT_EQ(std::next(cont.begin(), 4), result.first); + EXPECT_EQ(std::next(cont.begin(), 5), result.second); + result = cont.equal_range(15); + EXPECT_EQ(std::next(cont.begin(), 5), result.first); + EXPECT_EQ(std::next(cont.begin(), 6), result.second); + result = cont.equal_range(17); + EXPECT_EQ(std::next(cont.begin(), 6), result.first); + EXPECT_EQ(std::next(cont.begin(), 7), result.second); + result = cont.equal_range(19); + EXPECT_EQ(std::next(cont.begin(), 7), result.first); + EXPECT_EQ(std::next(cont.begin(), 8), result.second); + result 
= cont.equal_range(4); + EXPECT_EQ(std::next(cont.begin(), 0), result.first); + EXPECT_EQ(std::next(cont.begin(), 0), result.second); + result = cont.equal_range(6); + EXPECT_EQ(std::next(cont.begin(), 1), result.first); + EXPECT_EQ(std::next(cont.begin(), 1), result.second); + result = cont.equal_range(8); + EXPECT_EQ(std::next(cont.begin(), 2), result.first); + EXPECT_EQ(std::next(cont.begin(), 2), result.second); + result = cont.equal_range(10); + EXPECT_EQ(std::next(cont.begin(), 3), result.first); + EXPECT_EQ(std::next(cont.begin(), 3), result.second); + result = cont.equal_range(12); + EXPECT_EQ(std::next(cont.begin(), 4), result.first); + EXPECT_EQ(std::next(cont.begin(), 4), result.second); + result = cont.equal_range(14); + EXPECT_EQ(std::next(cont.begin(), 5), result.first); + EXPECT_EQ(std::next(cont.begin(), 5), result.second); + result = cont.equal_range(16); + EXPECT_EQ(std::next(cont.begin(), 6), result.first); + EXPECT_EQ(std::next(cont.begin(), 6), result.second); + result = cont.equal_range(18); + EXPECT_EQ(std::next(cont.begin(), 7), result.first); + EXPECT_EQ(std::next(cont.begin(), 7), result.second); + result = cont.equal_range(20); + EXPECT_EQ(std::next(cont.begin(), 8), result.first); + EXPECT_EQ(std::next(cont.begin(), 8), result.second); + } +} + +// iterator lower_bound(const key_type& key); +// const_iterator lower_bound(const key_type& key) const; + +TYPED_TEST_P(FlatTreeTest, LowerBound) { + { + TypedTree cont({5, 7, 9, 11, 13, 15, 17, 19}); + + EXPECT_EQ(cont.begin(), cont.lower_bound(5)); + EXPECT_EQ(std::next(cont.begin()), cont.lower_bound(7)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.lower_bound(9)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.lower_bound(11)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.lower_bound(13)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.lower_bound(15)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.lower_bound(17)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.lower_bound(19)); + EXPECT_EQ(std::next(cont.begin(), 0), cont.lower_bound(4)); + EXPECT_EQ(std::next(cont.begin(), 1), cont.lower_bound(6)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.lower_bound(8)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.lower_bound(10)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.lower_bound(12)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.lower_bound(14)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.lower_bound(16)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.lower_bound(18)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.lower_bound(20)); + } + { + const TypedTree cont({5, 7, 9, 11, 13, 15, 17, 19}); + + EXPECT_EQ(cont.begin(), cont.lower_bound(5)); + EXPECT_EQ(std::next(cont.begin()), cont.lower_bound(7)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.lower_bound(9)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.lower_bound(11)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.lower_bound(13)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.lower_bound(15)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.lower_bound(17)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.lower_bound(19)); + EXPECT_EQ(std::next(cont.begin(), 0), cont.lower_bound(4)); + EXPECT_EQ(std::next(cont.begin(), 1), cont.lower_bound(6)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.lower_bound(8)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.lower_bound(10)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.lower_bound(12)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.lower_bound(14)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.lower_bound(16)); + 
EXPECT_EQ(std::next(cont.begin(), 7), cont.lower_bound(18)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.lower_bound(20)); + } +} + +// iterator upper_bound(const key_type& key) +// const_iterator upper_bound(const key_type& key) const + +TYPED_TEST_P(FlatTreeTest, UpperBound) { + { + TypedTree cont({5, 7, 9, 11, 13, 15, 17, 19}); + + EXPECT_EQ(std::next(cont.begin(), 1), cont.upper_bound(5)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.upper_bound(7)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.upper_bound(9)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.upper_bound(11)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.upper_bound(13)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.upper_bound(15)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.upper_bound(17)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.upper_bound(19)); + EXPECT_EQ(std::next(cont.begin(), 0), cont.upper_bound(4)); + EXPECT_EQ(std::next(cont.begin(), 1), cont.upper_bound(6)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.upper_bound(8)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.upper_bound(10)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.upper_bound(12)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.upper_bound(14)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.upper_bound(16)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.upper_bound(18)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.upper_bound(20)); + } + { + const TypedTree cont({5, 7, 9, 11, 13, 15, 17, 19}); + + EXPECT_EQ(std::next(cont.begin(), 1), cont.upper_bound(5)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.upper_bound(7)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.upper_bound(9)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.upper_bound(11)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.upper_bound(13)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.upper_bound(15)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.upper_bound(17)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.upper_bound(19)); + EXPECT_EQ(std::next(cont.begin(), 0), cont.upper_bound(4)); + EXPECT_EQ(std::next(cont.begin(), 1), cont.upper_bound(6)); + EXPECT_EQ(std::next(cont.begin(), 2), cont.upper_bound(8)); + EXPECT_EQ(std::next(cont.begin(), 3), cont.upper_bound(10)); + EXPECT_EQ(std::next(cont.begin(), 4), cont.upper_bound(12)); + EXPECT_EQ(std::next(cont.begin(), 5), cont.upper_bound(14)); + EXPECT_EQ(std::next(cont.begin(), 6), cont.upper_bound(16)); + EXPECT_EQ(std::next(cont.begin(), 7), cont.upper_bound(18)); + EXPECT_EQ(std::next(cont.begin(), 8), cont.upper_bound(20)); + } +} + +// ---------------------------------------------------------------------------- +// General operations. + +// void swap(flat_tree& other) +// void swap(flat_tree& lhs, flat_tree& rhs) + +TYPED_TEST_P(FlatTreeTest, Swap) { + TypedTree x({1, 2, 3}); + TypedTree y({4}); + swap(x, y); + EXPECT_THAT(x, ElementsAre(4)); + EXPECT_THAT(y, ElementsAre(1, 2, 3)); + + y.swap(x); + EXPECT_THAT(x, ElementsAre(1, 2, 3)); + EXPECT_THAT(y, ElementsAre(4)); +} + +// bool operator==(const flat_tree& lhs, const flat_tree& rhs) +// bool operator!=(const flat_tree& lhs, const flat_tree& rhs) +// bool operator<(const flat_tree& lhs, const flat_tree& rhs) +// bool operator>(const flat_tree& lhs, const flat_tree& rhs) +// bool operator<=(const flat_tree& lhs, const flat_tree& rhs) +// bool operator>=(const flat_tree& lhs, const flat_tree& rhs) + +TEST(FlatTree, Comparison) { + // Provided comparator does not participate in comparison. 
+ ReversedTree biggest({3}); + ReversedTree smallest({1}); + ReversedTree middle({1, 2}); + + EXPECT_EQ(biggest, biggest); + EXPECT_NE(biggest, smallest); + EXPECT_LT(smallest, middle); + EXPECT_LE(smallest, middle); + EXPECT_LE(middle, middle); + EXPECT_GT(biggest, middle); + EXPECT_GE(biggest, middle); + EXPECT_GE(biggest, biggest); +} + +TYPED_TEST_P(FlatTreeTest, SupportsEraseIf) { + TypedTree x; + EXPECT_EQ(0u, EraseIf(x, [](int) { return false; })); + EXPECT_THAT(x, ElementsAre()); + + x = {1, 2, 3}; + EXPECT_EQ(1u, EraseIf(x, [](int elem) { return !(elem & 1); })); + EXPECT_THAT(x, ElementsAre(1, 3)); + + x = {1, 2, 3, 4}; + EXPECT_EQ(2u, EraseIf(x, [](int elem) { return elem & 1; })); + EXPECT_THAT(x, ElementsAre(2, 4)); +} + +REGISTER_TYPED_TEST_SUITE_P(FlatTreeTest, + DefaultConstructor, + CopyConstructor, + ContainerCopyConstructor, + InitializerListConstructor, + SortedUniqueContainerCopyConstructor, + SortedUniqueInitializerListConstructor, + CopyAssignable, + InitializerListAssignable, + Clear, + Size, + Empty, + Iterators, + InsertLValue, + InsertPositionLValue, + Emplace, + EmplacePosition, + Extract, + Replace, + ErasePosition, + EraseRange, + EraseKey, + EraseEndDeath, + Count, + Find, + Contains, + EqualRange, + LowerBound, + UpperBound, + Swap, + SupportsEraseIf); + +using IntSequenceContainers = + ::testing::Types, std::vector>; +INSTANTIATE_TYPED_TEST_SUITE_P(My, FlatTreeTest, IntSequenceContainers); + +} // namespace +} // namespace flat_containers_internal +} // namespace webrtc diff --git a/rtc_base/containers/identity.h b/rtc_base/containers/identity.h new file mode 100644 index 0000000000..29592931bd --- /dev/null +++ b/rtc_base/containers/identity.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_IDENTITY_H_ +#define RTC_BASE_CONTAINERS_IDENTITY_H_ + +#include + +namespace webrtc { + +// Implementation of C++20's std::identity. +// +// Reference: +// - https://en.cppreference.com/w/cpp/utility/functional/identity +// - https://wg21.link/func.identity +struct identity { + template + constexpr T&& operator()(T&& t) const noexcept { + return std::forward(t); + } + + using is_transparent = void; +}; + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_IDENTITY_H_ diff --git a/rtc_base/containers/invoke.h b/rtc_base/containers/invoke.h new file mode 100644 index 0000000000..5d17a70beb --- /dev/null +++ b/rtc_base/containers/invoke.h @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. 
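identity (declared above) simply forwards its argument; in a set-like flat_tree instantiation it serves as the key extractor, since the stored value is the key, while a map-like tree would plug in an extractor returning pair.first (as in the Chromium original). A small illustrative sketch:

#include <cstdio>
#include <type_traits>

#include "rtc_base/containers/identity.h"

int main() {
  webrtc::identity id;
  int x = 42;
  int& same = id(x);                    // lvalue in, lvalue reference out
  std::printf("%d %d\n", same, id(7));  // prints "42 7"

  // Mirrors C++20's std::identity, including the is_transparent tag that
  // generic code uses to detect transparent functors.
  static_assert(std::is_same<webrtc::identity::is_transparent, void>::value,
                "identity is a transparent functor");
  return 0;
}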
+ +#ifndef RTC_BASE_CONTAINERS_INVOKE_H_ +#define RTC_BASE_CONTAINERS_INVOKE_H_ + +#include +#include + +namespace webrtc { + +namespace invoke_internal { + +// Helper struct and alias to deduce the class type from a member function +// pointer or member object pointer. +template +struct member_pointer_class {}; + +template +struct member_pointer_class { + using type = ClassT; +}; + +template +using member_pointer_class_t = typename member_pointer_class::type; + +// Utility struct to detect specializations of std::reference_wrapper. +template +struct is_reference_wrapper : std::false_type {}; + +template +struct is_reference_wrapper> : std::true_type {}; + +// Small helpers used below in invoke_internal::invoke to make the SFINAE more +// concise. +template +const bool& IsMemFunPtr = + std::is_member_function_pointer>::value; + +template +const bool& IsMemObjPtr = std::is_member_object_pointer>::value; + +template >> +const bool& IsMemPtrToBaseOf = + std::is_base_of>::value; + +template +const bool& IsRefWrapper = is_reference_wrapper>::value; + +template +using EnableIf = std::enable_if_t; + +// Invokes a member function pointer on a reference to an object of a suitable +// type. Covers bullet 1 of the INVOKE definition. +// +// Reference: https://wg21.link/func.require#1.1 +template && IsMemPtrToBaseOf> = true> +constexpr decltype(auto) InvokeImpl(F&& f, T1&& t1, Args&&... args) { + return (std::forward(t1).*f)(std::forward(args)...); +} + +// Invokes a member function pointer on a std::reference_wrapper to an object of +// a suitable type. Covers bullet 2 of the INVOKE definition. +// +// Reference: https://wg21.link/func.require#1.2 +template && IsRefWrapper> = true> +constexpr decltype(auto) InvokeImpl(F&& f, T1&& t1, Args&&... args) { + return (t1.get().*f)(std::forward(args)...); +} + +// Invokes a member function pointer on a pointer-like type to an object of a +// suitable type. Covers bullet 3 of the INVOKE definition. +// +// Reference: https://wg21.link/func.require#1.3 +template && !IsMemPtrToBaseOf && + !IsRefWrapper> = true> +constexpr decltype(auto) InvokeImpl(F&& f, T1&& t1, Args&&... args) { + return ((*std::forward(t1)).*f)(std::forward(args)...); +} + +// Invokes a member object pointer on a reference to an object of a suitable +// type. Covers bullet 4 of the INVOKE definition. +// +// Reference: https://wg21.link/func.require#1.4 +template && IsMemPtrToBaseOf> = true> +constexpr decltype(auto) InvokeImpl(F&& f, T1&& t1) { + return std::forward(t1).*f; +} + +// Invokes a member object pointer on a std::reference_wrapper to an object of +// a suitable type. Covers bullet 5 of the INVOKE definition. +// +// Reference: https://wg21.link/func.require#1.5 +template && IsRefWrapper> = true> +constexpr decltype(auto) InvokeImpl(F&& f, T1&& t1) { + return t1.get().*f; +} + +// Invokes a member object pointer on a pointer-like type to an object of a +// suitable type. Covers bullet 6 of the INVOKE definition. +// +// Reference: https://wg21.link/func.require#1.6 +template && !IsMemPtrToBaseOf && + !IsRefWrapper> = true> +constexpr decltype(auto) InvokeImpl(F&& f, T1&& t1) { + return (*std::forward(t1)).*f; +} + +// Invokes a regular function or function object. Covers bullet 7 of the INVOKE +// definition. +// +// Reference: https://wg21.link/func.require#1.7 +template +constexpr decltype(auto) InvokeImpl(F&& f, Args&&... args) { + return std::forward(f)(std::forward(args)...); +} + +} // namespace invoke_internal + +// Implementation of C++17's std::invoke. 
This is not based on implementation +// referenced in original std::invoke proposal, but rather a manual +// implementation, so that it can be constexpr. +// +// References: +// - https://wg21.link/n4169#implementability +// - https://en.cppreference.com/w/cpp/utility/functional/invoke +// - https://wg21.link/func.invoke +template +constexpr decltype(auto) invoke(F&& f, Args&&... args) { + return invoke_internal::InvokeImpl(std::forward(f), + std::forward(args)...); +} + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_INVOKE_H_ diff --git a/rtc_base/containers/move_only_int.h b/rtc_base/containers/move_only_int.h new file mode 100644 index 0000000000..8f745aa688 --- /dev/null +++ b/rtc_base/containers/move_only_int.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_MOVE_ONLY_INT_H_ +#define RTC_BASE_CONTAINERS_MOVE_ONLY_INT_H_ + +namespace webrtc { + +// A move-only class that holds an integer. This is designed for testing +// containers. See also CopyOnlyInt. +class MoveOnlyInt { + public: + explicit MoveOnlyInt(int data = 1) : data_(data) {} + MoveOnlyInt(const MoveOnlyInt& other) = delete; + MoveOnlyInt& operator=(const MoveOnlyInt& other) = delete; + MoveOnlyInt(MoveOnlyInt&& other) : data_(other.data_) { other.data_ = 0; } + ~MoveOnlyInt() { data_ = 0; } + + MoveOnlyInt& operator=(MoveOnlyInt&& other) { + data_ = other.data_; + other.data_ = 0; + return *this; + } + + friend bool operator==(const MoveOnlyInt& lhs, const MoveOnlyInt& rhs) { + return lhs.data_ == rhs.data_; + } + + friend bool operator!=(const MoveOnlyInt& lhs, const MoveOnlyInt& rhs) { + return !operator==(lhs, rhs); + } + + friend bool operator<(const MoveOnlyInt& lhs, int rhs) { + return lhs.data_ < rhs; + } + + friend bool operator<(int lhs, const MoveOnlyInt& rhs) { + return lhs < rhs.data_; + } + + friend bool operator<(const MoveOnlyInt& lhs, const MoveOnlyInt& rhs) { + return lhs.data_ < rhs.data_; + } + + friend bool operator>(const MoveOnlyInt& lhs, const MoveOnlyInt& rhs) { + return rhs < lhs; + } + + friend bool operator<=(const MoveOnlyInt& lhs, const MoveOnlyInt& rhs) { + return !(rhs < lhs); + } + + friend bool operator>=(const MoveOnlyInt& lhs, const MoveOnlyInt& rhs) { + return !(lhs < rhs); + } + + int data() const { return data_; } + + private: + volatile int data_; +}; + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_MOVE_ONLY_INT_H_ diff --git a/rtc_base/containers/not_fn.h b/rtc_base/containers/not_fn.h new file mode 100644 index 0000000000..39cfd2763c --- /dev/null +++ b/rtc_base/containers/not_fn.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. 
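webrtc::invoke (defined above) mirrors C++17's std::invoke while staying constexpr-friendly; its overload set covers the seven bullets of the INVOKE definition. A sketch of the main dispatch cases, with a hypothetical Widget type for illustration:

#include <functional>

#include "rtc_base/containers/invoke.h"

struct Widget {
  int Area() const { return width * height; }  // member function
  int width = 4;                                // member object
  int height = 3;
};

int main() {
  Widget w;
  Widget* pw = &w;

  // Member pointers applied to an object reference (bullets 1 and 4).
  int area = webrtc::invoke(&Widget::Area, w);    // 12
  int width = webrtc::invoke(&Widget::width, w);  // 4

  // Via std::reference_wrapper (bullets 2 and 5).
  area = webrtc::invoke(&Widget::Area, std::ref(w));

  // Via a pointer or other pointer-like type (bullets 3 and 6).
  area = webrtc::invoke(&Widget::Area, pw);
  width = webrtc::invoke(&Widget::height, pw);    // 3

  // Plain callables: lambdas, function objects, free functions (bullet 7).
  int doubled = webrtc::invoke([](int v) { return 2 * v; }, area);
  return area + width + doubled;
}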
+ +#ifndef RTC_BASE_CONTAINERS_NOT_FN_H_ +#define RTC_BASE_CONTAINERS_NOT_FN_H_ + +#include +#include + +#include "rtc_base/containers/invoke.h" + +namespace webrtc { + +namespace not_fn_internal { + +template +struct NotFnImpl { + F f; + + template + constexpr decltype(auto) operator()(Args&&... args) & noexcept { + return !webrtc::invoke(f, std::forward(args)...); + } + + template + constexpr decltype(auto) operator()(Args&&... args) const& noexcept { + return !webrtc::invoke(f, std::forward(args)...); + } + + template + constexpr decltype(auto) operator()(Args&&... args) && noexcept { + return !webrtc::invoke(std::move(f), std::forward(args)...); + } + + template + constexpr decltype(auto) operator()(Args&&... args) const&& noexcept { + return !webrtc::invoke(std::move(f), std::forward(args)...); + } +}; + +} // namespace not_fn_internal + +// Implementation of C++17's std::not_fn. +// +// Reference: +// - https://en.cppreference.com/w/cpp/utility/functional/not_fn +// - https://wg21.link/func.not.fn +template +constexpr not_fn_internal::NotFnImpl> not_fn(F&& f) { + return {std::forward(f)}; +} + +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_NOT_FN_H_ diff --git a/rtc_base/containers/void_t.h b/rtc_base/containers/void_t.h new file mode 100644 index 0000000000..62c57d4bec --- /dev/null +++ b/rtc_base/containers/void_t.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// This implementation is borrowed from Chromium. + +#ifndef RTC_BASE_CONTAINERS_VOID_T_H_ +#define RTC_BASE_CONTAINERS_VOID_T_H_ + +namespace webrtc { +namespace void_t_internal { +// Implementation detail of webrtc::void_t below. +template +struct make_void { + using type = void; +}; + +} // namespace void_t_internal + +// webrtc::void_t is an implementation of std::void_t from C++17. +// +// We use |webrtc::void_t_internal::make_void| as a helper struct to avoid a +// C++14 defect: +// http://en.cppreference.com/w/cpp/types/void_t +// http://open-std.org/JTC1/SC22/WG21/docs/cwg_defects.html#1558 +template +using void_t = typename ::webrtc::void_t_internal::make_void::type; +} // namespace webrtc + +#endif // RTC_BASE_CONTAINERS_VOID_T_H_ diff --git a/rtc_base/copy_on_write_buffer.cc b/rtc_base/copy_on_write_buffer.cc index 73182a12b1..f3cc710f85 100644 --- a/rtc_base/copy_on_write_buffer.cc +++ b/rtc_base/copy_on_write_buffer.cc @@ -32,16 +32,15 @@ CopyOnWriteBuffer::CopyOnWriteBuffer(const std::string& s) : CopyOnWriteBuffer(s.data(), s.length()) {} CopyOnWriteBuffer::CopyOnWriteBuffer(size_t size) - : buffer_(size > 0 ? new RefCountedObject(size) : nullptr), + : buffer_(size > 0 ? new RefCountedBuffer(size) : nullptr), offset_(0), size_(size) { RTC_DCHECK(IsConsistent()); } CopyOnWriteBuffer::CopyOnWriteBuffer(size_t size, size_t capacity) - : buffer_(size > 0 || capacity > 0 - ? new RefCountedObject(size, capacity) - : nullptr), + : buffer_(size > 0 || capacity > 0 ? 
new RefCountedBuffer(size, capacity) + : nullptr), offset_(0), size_(size) { RTC_DCHECK(IsConsistent()); @@ -61,7 +60,7 @@ void CopyOnWriteBuffer::SetSize(size_t size) { RTC_DCHECK(IsConsistent()); if (!buffer_) { if (size > 0) { - buffer_ = new RefCountedObject(size); + buffer_ = new RefCountedBuffer(size); offset_ = 0; size_ = size; } @@ -84,7 +83,7 @@ void CopyOnWriteBuffer::EnsureCapacity(size_t new_capacity) { RTC_DCHECK(IsConsistent()); if (!buffer_) { if (new_capacity > 0) { - buffer_ = new RefCountedObject(0, new_capacity); + buffer_ = new RefCountedBuffer(0, new_capacity); offset_ = 0; size_ = 0; } @@ -105,7 +104,7 @@ void CopyOnWriteBuffer::Clear() { if (buffer_->HasOneRef()) { buffer_->Clear(); } else { - buffer_ = new RefCountedObject(0, capacity()); + buffer_ = new RefCountedBuffer(0, capacity()); } offset_ = 0; size_ = 0; @@ -117,8 +116,8 @@ void CopyOnWriteBuffer::UnshareAndEnsureCapacity(size_t new_capacity) { return; } - buffer_ = new RefCountedObject(buffer_->data() + offset_, size_, - new_capacity); + buffer_ = + new RefCountedBuffer(buffer_->data() + offset_, size_, new_capacity); offset_ = 0; RTC_DCHECK(IsConsistent()); } diff --git a/rtc_base/copy_on_write_buffer.h b/rtc_base/copy_on_write_buffer.h index 87bf625fea..526cbe5c5c 100644 --- a/rtc_base/copy_on_write_buffer.h +++ b/rtc_base/copy_on_write_buffer.h @@ -95,14 +95,6 @@ class RTC_EXPORT CopyOnWriteBuffer { return buffer_->data() + offset_; } - // TODO(bugs.webrtc.org/12334): Delete when all usage updated to MutableData() - template ::value>::type* = nullptr> - T* data() { - return MutableData(); - } - // Get const pointer to the data. This will not create a copy of the // underlying data if it is shared with other buffers. template 0 ? new RefCountedObject(data, size) : nullptr; + buffer_ = size > 0 ? new RefCountedBuffer(data, size) : nullptr; } else if (!buffer_->HasOneRef()) { - buffer_ = new RefCountedObject(data, size, capacity()); + buffer_ = new RefCountedBuffer(data, size, capacity()); } else { buffer_->SetData(data, size); } @@ -210,7 +196,7 @@ class RTC_EXPORT CopyOnWriteBuffer { void AppendData(const T* data, size_t size) { RTC_DCHECK(IsConsistent()); if (!buffer_) { - buffer_ = new RefCountedObject(data, size); + buffer_ = new RefCountedBuffer(data, size); offset_ = 0; size_ = size; RTC_DCHECK(IsConsistent()); @@ -256,7 +242,7 @@ class RTC_EXPORT CopyOnWriteBuffer { // Swaps two buffers. friend void swap(CopyOnWriteBuffer& a, CopyOnWriteBuffer& b) { - std::swap(a.buffer_, b.buffer_); + a.buffer_.swap(b.buffer_); std::swap(a.offset_, b.offset_); std::swap(a.size_, b.size_); } @@ -271,6 +257,7 @@ class RTC_EXPORT CopyOnWriteBuffer { } private: + using RefCountedBuffer = FinalRefCountedObject; // Create a copy of the underlying data if it is referenced from other Buffer // objects or there is not enough capacity. void UnshareAndEnsureCapacity(size_t new_capacity); @@ -286,7 +273,7 @@ class RTC_EXPORT CopyOnWriteBuffer { } // buffer_ is either null, or points to an rtc::Buffer with capacity > 0. - scoped_refptr> buffer_; + scoped_refptr buffer_; // This buffer may represent a slice of a original data. size_t offset_; // Offset of a current slice in the original data in buffer_. // Should be 0 if the buffer_ is empty. 
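With the templated non-const data() accessor gone, reading a CopyOnWriteBuffer never detaches it from buffers it shares storage with; writers must call MutableData(), which copies only when the storage is actually shared. A short illustrative sketch of the resulting semantics:

#include <cstdint>

#include "rtc_base/copy_on_write_buffer.h"

void CowSketch() {
  const uint8_t payload[] = {1, 2, 3};
  rtc::CopyOnWriteBuffer buf1(payload, 3, 10);
  rtc::CopyOnWriteBuffer buf2(buf1);  // shares the underlying storage

  // Read-only access never detaches: both buffers still point at the same
  // bytes, so these two pointers compare equal.
  const uint8_t* r1 = buf1.data();
  const uint8_t* r2 = buf2.data();
  (void)r1;
  (void)r2;

  // MutableData() is the only accessor that may copy: buf1 detaches here
  // because the storage is shared; buf2 keeps the original bytes untouched.
  uint8_t* w1 = buf1.MutableData();
  w1[0] = 42;  // not visible through buf2
}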
diff --git a/rtc_base/copy_on_write_buffer_unittest.cc b/rtc_base/copy_on_write_buffer_unittest.cc index 5c29c10465..d3978686a8 100644 --- a/rtc_base/copy_on_write_buffer_unittest.cc +++ b/rtc_base/copy_on_write_buffer_unittest.cc @@ -261,46 +261,33 @@ TEST(CopyOnWriteBufferTest, ClearDoesntChangeCapacity) { EXPECT_EQ(10u, buf2.capacity()); } -TEST(CopyOnWriteBufferTest, TestConstDataAccessor) { +TEST(CopyOnWriteBufferTest, DataAccessorDoesntCloneData) { CopyOnWriteBuffer buf1(kTestData, 3, 10); CopyOnWriteBuffer buf2(buf1); - // .cdata() doesn't clone data. - const uint8_t* cdata1 = buf1.cdata(); - const uint8_t* cdata2 = buf2.cdata(); - EXPECT_EQ(cdata1, cdata2); - - // Non-const .data() clones data if shared. - const uint8_t* data1 = buf1.data(); - const uint8_t* data2 = buf2.data(); - EXPECT_NE(data1, data2); - // buf1 was cloned above. - EXPECT_NE(data1, cdata1); - // Therefore buf2 was no longer sharing data and was not cloned. - EXPECT_EQ(data2, cdata1); + EXPECT_EQ(buf1.data(), buf2.data()); } -// TODO(bugs.webrtc.org/12334): Delete when all reads become const -TEST(CopyOnWriteBufferTest, SeveralReads) { +TEST(CopyOnWriteBufferTest, MutableDataClonesDataWhenShared) { CopyOnWriteBuffer buf1(kTestData, 3, 10); CopyOnWriteBuffer buf2(buf1); + const uint8_t* cdata = buf1.data(); - EnsureBuffersShareData(buf1, buf2); - // Non-const reads clone the data if shared. - for (size_t i = 0; i != 3u; ++i) { - EXPECT_EQ(buf1[i], kTestData[i]); - } - EnsureBuffersDontShareData(buf1, buf2); + uint8_t* data1 = buf1.MutableData(); + uint8_t* data2 = buf2.MutableData(); + // buf1 was cloned above. + EXPECT_NE(data1, cdata); + // Therefore buf2 was no longer sharing data and was not cloned. + EXPECT_EQ(data2, cdata); } -TEST(CopyOnWriteBufferTest, SeveralConstReads) { +TEST(CopyOnWriteBufferTest, SeveralReads) { CopyOnWriteBuffer buf1(kTestData, 3, 10); CopyOnWriteBuffer buf2(buf1); EnsureBuffersShareData(buf1, buf2); - const CopyOnWriteBuffer& cbuf1 = buf1; for (size_t i = 0; i != 3u; ++i) { - EXPECT_EQ(cbuf1[i], kTestData[i]); + EXPECT_EQ(buf1[i], kTestData[i]); } EnsureBuffersShareData(buf1, buf2); } diff --git a/rtc_base/cpu_time_unittest.cc b/rtc_base/cpu_time_unittest.cc index 675e86307c..94f82f4306 100644 --- a/rtc_base/cpu_time_unittest.cc +++ b/rtc_base/cpu_time_unittest.cc @@ -30,8 +30,7 @@ const int kProcessingTimeMillisecs = 500; const int kWorkingThreads = 2; // Consumes approximately kProcessingTimeMillisecs of CPU time in single thread. 
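The thread-related changes in this patch (cpu_time, recursive_critical_section, event_tracer, event tests) all follow the same migration: a static thread function passed to the PlatformThread constructor plus explicit Start()/Stop() becomes a lambda handed to PlatformThread::SpawnJoinable. A hedged sketch of the new pattern (names are illustrative):

#include "rtc_base/platform_thread.h"

void SpawnJoinableSketch() {
  int counter = 0;
  // SpawnJoinable starts the closure on a new thread and returns a handle
  // that joins the thread on Finalize() or when the handle is destroyed.
  auto thread = rtc::PlatformThread::SpawnJoinable(
      [&counter] { counter = 42; }, "ExampleThread");
  thread.Finalize();  // Blocks until the closure has run; counter == 42 now.
}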
-void WorkingFunction(void* counter_pointer) { - int64_t* counter = reinterpret_cast(counter_pointer); +void WorkingFunction(int64_t* counter) { *counter = 0; int64_t stop_cpu_time = rtc::GetThreadCpuTimeNanos() + @@ -62,14 +61,12 @@ TEST(CpuTimeTest, MAYBE_TEST(TwoThreads)) { int64_t thread_start_time_nanos = GetThreadCpuTimeNanos(); int64_t counter1; int64_t counter2; - PlatformThread thread1(WorkingFunction, reinterpret_cast(&counter1), - "Thread1"); - PlatformThread thread2(WorkingFunction, reinterpret_cast(&counter2), - "Thread2"); - thread1.Start(); - thread2.Start(); - thread1.Stop(); - thread2.Stop(); + auto thread1 = PlatformThread::SpawnJoinable( + [&counter1] { WorkingFunction(&counter1); }, "Thread1"); + auto thread2 = PlatformThread::SpawnJoinable( + [&counter2] { WorkingFunction(&counter2); }, "Thread2"); + thread1.Finalize(); + thread2.Finalize(); EXPECT_GE(counter1, 0); EXPECT_GE(counter2, 0); diff --git a/rtc_base/deprecated/recursive_critical_section_unittest.cc b/rtc_base/deprecated/recursive_critical_section_unittest.cc index 3fb7c519c1..9256a76f58 100644 --- a/rtc_base/deprecated/recursive_critical_section_unittest.cc +++ b/rtc_base/deprecated/recursive_critical_section_unittest.cc @@ -329,33 +329,28 @@ class PerfTestData { class PerfTestThread { public: - PerfTestThread() : thread_(&ThreadFunc, this, "CsPerf") {} - void Start(PerfTestData* data, int repeats, int id) { - RTC_DCHECK(!thread_.IsRunning()); RTC_DCHECK(!data_); data_ = data; repeats_ = repeats; my_id_ = id; - thread_.Start(); + thread_ = PlatformThread::SpawnJoinable( + [this] { + for (int i = 0; i < repeats_; ++i) + data_->AddToCounter(my_id_); + }, + "CsPerf"); } void Stop() { - RTC_DCHECK(thread_.IsRunning()); RTC_DCHECK(data_); - thread_.Stop(); + thread_.Finalize(); repeats_ = 0; data_ = nullptr; my_id_ = 0; } private: - static void ThreadFunc(void* param) { - PerfTestThread* me = static_cast(param); - for (int i = 0; i < me->repeats_; ++i) - me->data_->AddToCounter(me->my_id_); - } - PlatformThread thread_; PerfTestData* data_ = nullptr; int repeats_ = 0; diff --git a/rtc_base/deprecated/signal_thread.cc b/rtc_base/deprecated/signal_thread.cc deleted file mode 100644 index 96bdd65155..0000000000 --- a/rtc_base/deprecated/signal_thread.cc +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright 2004 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "rtc_base/deprecated/signal_thread.h" - -#include - -#include "rtc_base/checks.h" -#include "rtc_base/location.h" -#include "rtc_base/null_socket_server.h" -#include "rtc_base/socket_server.h" - -namespace rtc { - -/////////////////////////////////////////////////////////////////////////////// -// SignalThread -/////////////////////////////////////////////////////////////////////////////// - -DEPRECATED_SignalThread::DEPRECATED_SignalThread() - : main_(Thread::Current()), worker_(this), state_(kInit), refcount_(1) { - main_->SignalQueueDestroyed.connect( - this, &DEPRECATED_SignalThread::OnMainThreadDestroyed); - worker_.SetName("SignalThread", this); -} - -DEPRECATED_SignalThread::~DEPRECATED_SignalThread() { - rtc::CritScope lock(&cs_); - RTC_DCHECK(refcount_ == 0); -} - -bool DEPRECATED_SignalThread::SetName(const std::string& name, - const void* obj) { - EnterExit ee(this); - RTC_DCHECK(!destroy_called_); - RTC_DCHECK(main_->IsCurrent()); - RTC_DCHECK(kInit == state_); - return worker_.SetName(name, obj); -} - -void DEPRECATED_SignalThread::Start() { - EnterExit ee(this); - RTC_DCHECK(!destroy_called_); - RTC_DCHECK(main_->IsCurrent()); - if (kInit == state_ || kComplete == state_) { - state_ = kRunning; - OnWorkStart(); - worker_.Start(); - } else { - RTC_NOTREACHED(); - } -} - -void DEPRECATED_SignalThread::Destroy(bool wait) { - EnterExit ee(this); - // Sometimes the caller can't guarantee which thread will call Destroy, only - // that it will be the last thing it does. - // RTC_DCHECK(main_->IsCurrent()); - RTC_DCHECK(!destroy_called_); - destroy_called_ = true; - if ((kInit == state_) || (kComplete == state_)) { - refcount_--; - } else if (kRunning == state_ || kReleasing == state_) { - state_ = kStopping; - // OnWorkStop() must follow Quit(), so that when the thread wakes up due to - // OWS(), ContinueWork() will return false. - worker_.Quit(); - OnWorkStop(); - if (wait) { - // Release the thread's lock so that it can return from ::Run. - cs_.Leave(); - worker_.Stop(); - cs_.Enter(); - refcount_--; - } - } else { - RTC_NOTREACHED(); - } -} - -void DEPRECATED_SignalThread::Release() { - EnterExit ee(this); - RTC_DCHECK(!destroy_called_); - RTC_DCHECK(main_->IsCurrent()); - if (kComplete == state_) { - refcount_--; - } else if (kRunning == state_) { - state_ = kReleasing; - } else { - // if (kInit == state_) use Destroy() - RTC_NOTREACHED(); - } -} - -bool DEPRECATED_SignalThread::ContinueWork() { - EnterExit ee(this); - RTC_DCHECK(!destroy_called_); - RTC_DCHECK(worker_.IsCurrent()); - return worker_.ProcessMessages(0); -} - -void DEPRECATED_SignalThread::OnMessage(Message* msg) { - EnterExit ee(this); - if (ST_MSG_WORKER_DONE == msg->message_id) { - RTC_DCHECK(main_->IsCurrent()); - OnWorkDone(); - bool do_delete = false; - if (kRunning == state_) { - state_ = kComplete; - } else { - do_delete = true; - } - if (kStopping != state_) { - // Before signaling that the work is done, make sure that the worker - // thread actually is done. We got here because DoWork() finished and - // Run() posted the ST_MSG_WORKER_DONE message. This means the worker - // thread is about to go away anyway, but sometimes it doesn't actually - // finish before SignalWorkDone is processed, and for a reusable - // SignalThread this makes an assert in thread.cc fire. - // - // Calling Stop() on the worker ensures that the OS thread that underlies - // the worker will finish, and will be set to null, enabling us to call - // Start() again. 
- worker_.Stop(); - SignalWorkDone(this); - } - if (do_delete) { - refcount_--; - } - } -} - -DEPRECATED_SignalThread::Worker::Worker(DEPRECATED_SignalThread* parent) - : Thread(std::make_unique(), /*do_init=*/false), - parent_(parent) { - DoInit(); -} - -DEPRECATED_SignalThread::Worker::~Worker() { - Stop(); -} - -void DEPRECATED_SignalThread::Worker::Run() { - parent_->Run(); -} - -void DEPRECATED_SignalThread::Run() { - DoWork(); - { - EnterExit ee(this); - if (main_) { - main_->Post(RTC_FROM_HERE, this, ST_MSG_WORKER_DONE); - } - } -} - -void DEPRECATED_SignalThread::OnMainThreadDestroyed() { - EnterExit ee(this); - main_ = nullptr; -} - -bool DEPRECATED_SignalThread::Worker::IsProcessingMessagesForTesting() { - return false; -} - -} // namespace rtc diff --git a/rtc_base/deprecated/signal_thread.h b/rtc_base/deprecated/signal_thread.h deleted file mode 100644 index 10805ad456..0000000000 --- a/rtc_base/deprecated/signal_thread.h +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright 2004 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef RTC_BASE_DEPRECATED_SIGNAL_THREAD_H_ -#define RTC_BASE_DEPRECATED_SIGNAL_THREAD_H_ - -#include - -#include "rtc_base/checks.h" -#include "rtc_base/constructor_magic.h" -#include "rtc_base/deprecated/recursive_critical_section.h" -#include "rtc_base/deprecation.h" -#include "rtc_base/message_handler.h" -#include "rtc_base/third_party/sigslot/sigslot.h" -#include "rtc_base/thread.h" -#include "rtc_base/thread_annotations.h" - -namespace rtc { - -/////////////////////////////////////////////////////////////////////////////// -// NOTE: this class has been deprecated. Do not use for new code. New code -// should use factilities exposed by api/task_queue/ instead. -// -// SignalThread - Base class for worker threads. The main thread should call -// Start() to begin work, and then follow one of these models: -// Normal: Wait for SignalWorkDone, and then call Release to destroy. -// Cancellation: Call Release(true), to abort the worker thread. -// Fire-and-forget: Call Release(false), which allows the thread to run to -// completion, and then self-destruct without further notification. -// Periodic tasks: Wait for SignalWorkDone, then eventually call Start() -// again to repeat the task. When the instance isn't needed anymore, -// call Release. DoWork, OnWorkStart and OnWorkStop are called again, -// on a new thread. -// The subclass should override DoWork() to perform the background task. By -// periodically calling ContinueWork(), it can check for cancellation. -// OnWorkStart and OnWorkDone can be overridden to do pre- or post-work -// tasks in the context of the main thread. -/////////////////////////////////////////////////////////////////////////////// - -class DEPRECATED_SignalThread : public sigslot::has_slots<>, - protected MessageHandlerAutoCleanup { - public: - DEPRECATED_SignalThread(); - - // Context: Main Thread. Call before Start to change the worker's name. - bool SetName(const std::string& name, const void* obj); - - // Context: Main Thread. Call to begin the worker thread. - void Start(); - - // Context: Main Thread. If the worker thread is not running, deletes the - // object immediately. 
Otherwise, asks the worker thread to abort processing, - // and schedules the object to be deleted once the worker exits. - // SignalWorkDone will not be signalled. If wait is true, does not return - // until the thread is deleted. - void Destroy(bool wait); - - // Context: Main Thread. If the worker thread is complete, deletes the - // object immediately. Otherwise, schedules the object to be deleted once - // the worker thread completes. SignalWorkDone will be signalled. - void Release(); - - // Context: Main Thread. Signalled when work is complete. - sigslot::signal1 SignalWorkDone; - - enum { ST_MSG_WORKER_DONE, ST_MSG_FIRST_AVAILABLE }; - - protected: - ~DEPRECATED_SignalThread() override; - - Thread* worker() { return &worker_; } - - // Context: Main Thread. Subclass should override to do pre-work setup. - virtual void OnWorkStart() {} - - // Context: Worker Thread. Subclass should override to do work. - virtual void DoWork() = 0; - - // Context: Worker Thread. Subclass should call periodically to - // dispatch messages and determine if the thread should terminate. - bool ContinueWork(); - - // Context: Worker Thread. Subclass should override when extra work is - // needed to abort the worker thread. - virtual void OnWorkStop() {} - - // Context: Main Thread. Subclass should override to do post-work cleanup. - virtual void OnWorkDone() {} - - // Context: Any Thread. If subclass overrides, be sure to call the base - // implementation. Do not use (message_id < ST_MSG_FIRST_AVAILABLE) - void OnMessage(Message* msg) override; - - private: - enum State { - kInit, // Initialized, but not started - kRunning, // Started and doing work - kReleasing, // Same as running, but to be deleted when work is done - kComplete, // Work is done - kStopping, // Work is being interrupted - }; - - class Worker : public Thread { - public: - explicit Worker(DEPRECATED_SignalThread* parent); - - Worker() = delete; - Worker(const Worker&) = delete; - Worker& operator=(const Worker&) = delete; - - ~Worker() override; - void Run() override; - bool IsProcessingMessagesForTesting() override; - - private: - DEPRECATED_SignalThread* parent_; - }; - - class RTC_SCOPED_LOCKABLE EnterExit { - public: - explicit EnterExit(DEPRECATED_SignalThread* t) - RTC_EXCLUSIVE_LOCK_FUNCTION(t->cs_) - : t_(t) { - t_->cs_.Enter(); - // If refcount_ is zero then the object has already been deleted and we - // will be double-deleting it in ~EnterExit()! 
(shouldn't happen) - RTC_DCHECK_NE(0, t_->refcount_); - ++t_->refcount_; - } - - EnterExit() = delete; - EnterExit(const EnterExit&) = delete; - EnterExit& operator=(const EnterExit&) = delete; - - ~EnterExit() RTC_UNLOCK_FUNCTION() { - bool d = (0 == --t_->refcount_); - t_->cs_.Leave(); - if (d) - delete t_; - } - - private: - DEPRECATED_SignalThread* t_; - }; - - void Run(); - void OnMainThreadDestroyed(); - - Thread* main_; - Worker worker_; - RecursiveCriticalSection cs_; - State state_ RTC_GUARDED_BY(cs_); - int refcount_ RTC_GUARDED_BY(cs_); - bool destroy_called_ RTC_GUARDED_BY(cs_) = false; - - RTC_DISALLOW_COPY_AND_ASSIGN(DEPRECATED_SignalThread); -}; - -typedef RTC_DEPRECATED DEPRECATED_SignalThread SignalThread; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace rtc - -#endif // RTC_BASE_DEPRECATED_SIGNAL_THREAD_H_ diff --git a/rtc_base/deprecated/signal_thread_unittest.cc b/rtc_base/deprecated/signal_thread_unittest.cc deleted file mode 100644 index f5a49aad63..0000000000 --- a/rtc_base/deprecated/signal_thread_unittest.cc +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright 2004 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "rtc_base/signal_thread.h" - -#include - -#include "rtc_base/constructor_magic.h" -#include "rtc_base/gunit.h" -#include "rtc_base/null_socket_server.h" -#include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread.h" -#include "rtc_base/thread_annotations.h" -#include "test/gtest.h" - -namespace rtc { -namespace { - -// 10 seconds. 
-static const int kTimeout = 10000; - -class SignalThreadTest : public ::testing::Test, public sigslot::has_slots<> { - public: - class SlowSignalThread : public DEPRECATED_SignalThread { - public: - explicit SlowSignalThread(SignalThreadTest* harness) : harness_(harness) {} - - ~SlowSignalThread() override { - EXPECT_EQ(harness_->main_thread_, Thread::Current()); - ++harness_->thread_deleted_; - } - - const SignalThreadTest* harness() { return harness_; } - - protected: - void OnWorkStart() override { - ASSERT_TRUE(harness_ != nullptr); - ++harness_->thread_started_; - EXPECT_EQ(harness_->main_thread_, Thread::Current()); - EXPECT_FALSE(worker()->RunningForTest()); // not started yet - } - - void OnWorkStop() override { - ++harness_->thread_stopped_; - EXPECT_EQ(harness_->main_thread_, Thread::Current()); - EXPECT_TRUE(worker()->RunningForTest()); // not stopped yet - } - - void OnWorkDone() override { - ++harness_->thread_done_; - EXPECT_EQ(harness_->main_thread_, Thread::Current()); - EXPECT_TRUE(worker()->RunningForTest()); // not stopped yet - } - - void DoWork() override { - EXPECT_NE(harness_->main_thread_, Thread::Current()); - EXPECT_EQ(worker(), Thread::Current()); - Thread::Current()->socketserver()->Wait(250, false); - } - - private: - SignalThreadTest* harness_; - RTC_DISALLOW_COPY_AND_ASSIGN(SlowSignalThread); - }; - - void OnWorkComplete(rtc::DEPRECATED_SignalThread* thread) { - SlowSignalThread* t = static_cast(thread); - EXPECT_EQ(t->harness(), this); - EXPECT_EQ(main_thread_, Thread::Current()); - - ++thread_completed_; - if (!called_release_) { - thread->Release(); - } - } - - void SetUp() override { - main_thread_ = Thread::Current(); - thread_ = new SlowSignalThread(this); - thread_->SignalWorkDone.connect(this, &SignalThreadTest::OnWorkComplete); - called_release_ = false; - thread_started_ = 0; - thread_done_ = 0; - thread_completed_ = 0; - thread_stopped_ = 0; - thread_deleted_ = 0; - } - - void ExpectState(int started, - int done, - int completed, - int stopped, - int deleted) { - EXPECT_EQ(started, thread_started_); - EXPECT_EQ(done, thread_done_); - EXPECT_EQ(completed, thread_completed_); - EXPECT_EQ(stopped, thread_stopped_); - EXPECT_EQ(deleted, thread_deleted_); - } - - void ExpectStateWait(int started, - int done, - int completed, - int stopped, - int deleted, - int timeout) { - EXPECT_EQ_WAIT(started, thread_started_, timeout); - EXPECT_EQ_WAIT(done, thread_done_, timeout); - EXPECT_EQ_WAIT(completed, thread_completed_, timeout); - EXPECT_EQ_WAIT(stopped, thread_stopped_, timeout); - EXPECT_EQ_WAIT(deleted, thread_deleted_, timeout); - } - - Thread* main_thread_; - SlowSignalThread* thread_; - bool called_release_; - - int thread_started_; - int thread_done_; - int thread_completed_; - int thread_stopped_; - int thread_deleted_; -}; - -class OwnerThread : public Thread, public sigslot::has_slots<> { - public: - explicit OwnerThread(SignalThreadTest* harness) - : Thread(std::make_unique()), - harness_(harness), - has_run_(false) {} - - ~OwnerThread() override { Stop(); } - - void Run() override { - SignalThreadTest::SlowSignalThread* signal_thread = - new SignalThreadTest::SlowSignalThread(harness_); - signal_thread->SignalWorkDone.connect(this, &OwnerThread::OnWorkDone); - signal_thread->Start(); - Thread::Current()->socketserver()->Wait(100, false); - signal_thread->Release(); - // Delete |signal_thread|. 
- signal_thread->Destroy(true); - { - webrtc::MutexLock lock(&mutex_); - has_run_ = true; - } - } - - bool has_run() { - webrtc::MutexLock lock(&mutex_); - return has_run_; - } - void OnWorkDone(DEPRECATED_SignalThread* /*signal_thread*/) { - FAIL() << " This shouldn't get called."; - } - - private: - webrtc::Mutex mutex_; - SignalThreadTest* harness_; - bool has_run_ RTC_GUARDED_BY(mutex_); - RTC_DISALLOW_COPY_AND_ASSIGN(OwnerThread); -}; - -// Test for when the main thread goes away while the -// signal thread is still working. This may happen -// when shutting down the process. -TEST_F(SignalThreadTest, OwnerThreadGoesAway) { - // We don't use |thread_| for this test, so destroy it. - thread_->Destroy(true); - - { - std::unique_ptr owner(new OwnerThread(this)); - main_thread_ = owner.get(); - owner->Start(); - while (!owner->has_run()) { - Thread::Current()->socketserver()->Wait(10, false); - } - } - // At this point the main thread has gone away. - // Give the SignalThread a little time to do its callback, - // which will crash if the signal thread doesn't handle - // this situation well. - Thread::Current()->socketserver()->Wait(500, false); -} - -TEST_F(SignalThreadTest, ThreadFinishes) { - thread_->Start(); - ExpectState(1, 0, 0, 0, 0); - ExpectStateWait(1, 1, 1, 0, 1, kTimeout); -} - -TEST_F(SignalThreadTest, ReleasedThreadFinishes) { - thread_->Start(); - ExpectState(1, 0, 0, 0, 0); - thread_->Release(); - called_release_ = true; - ExpectState(1, 0, 0, 0, 0); - ExpectStateWait(1, 1, 1, 0, 1, kTimeout); -} - -TEST_F(SignalThreadTest, DestroyedThreadCleansUp) { - thread_->Start(); - ExpectState(1, 0, 0, 0, 0); - thread_->Destroy(true); - ExpectState(1, 0, 0, 1, 1); - Thread::Current()->ProcessMessages(0); - ExpectState(1, 0, 0, 1, 1); -} - -TEST_F(SignalThreadTest, DeferredDestroyedThreadCleansUp) { - thread_->Start(); - ExpectState(1, 0, 0, 0, 0); - thread_->Destroy(false); - ExpectState(1, 0, 0, 1, 0); - ExpectStateWait(1, 1, 0, 1, 1, kTimeout); -} - -} // namespace -} // namespace rtc diff --git a/rtc_base/deprecation.h b/rtc_base/deprecation.h deleted file mode 100644 index f285ab04bb..0000000000 --- a/rtc_base/deprecation.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef RTC_BASE_DEPRECATION_H_ -#define RTC_BASE_DEPRECATION_H_ - -// Annotate the declarations of deprecated functions with this to cause a -// compiler warning when they're used. Like so: -// -// RTC_DEPRECATED std::pony PonyPlz(const std::pony_spec& ps); -// -// NOTE 1: The annotation goes on the declaration in the .h file, not the -// definition in the .cc file! -// -// NOTE 2: In order to keep unit testing the deprecated function without -// getting warnings, do something like this: -// -// std::pony DEPRECATED_PonyPlz(const std::pony_spec& ps); -// RTC_DEPRECATED inline std::pony PonyPlz(const std::pony_spec& ps) { -// return DEPRECATED_PonyPlz(ps); -// } -// -// In other words, rename the existing function, and provide an inline wrapper -// using the original name that calls it. That way, callers who are willing to -// call it using the DEPRECATED_-prefixed name don't get the warning. 
-// -// TODO(kwiberg): Remove this when we can use [[deprecated]] from C++14. -#if defined(_MSC_VER) -// Note: Deprecation warnings seem to fail to trigger on Windows -// (https://bugs.chromium.org/p/webrtc/issues/detail?id=5368). -#define RTC_DEPRECATED __declspec(deprecated) -#elif defined(__GNUC__) -#define RTC_DEPRECATED __attribute__((__deprecated__)) -#else -#define RTC_DEPRECATED -#endif - -#endif // RTC_BASE_DEPRECATION_H_ diff --git a/rtc_base/event_tracer.cc b/rtc_base/event_tracer.cc index 3af8183b1f..1a2b41ec5c 100644 --- a/rtc_base/event_tracer.cc +++ b/rtc_base/event_tracer.cc @@ -17,6 +17,7 @@ #include #include +#include "api/sequence_checker.h" #include "rtc_base/atomic_ops.h" #include "rtc_base/checks.h" #include "rtc_base/event.h" @@ -25,7 +26,6 @@ #include "rtc_base/platform_thread_types.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" @@ -79,19 +79,12 @@ namespace rtc { namespace tracing { namespace { -static void EventTracingThreadFunc(void* params); - // Atomic-int fast path for avoiding logging when disabled. static volatile int g_event_logging_active = 0; // TODO(pbos): Log metadata for all threads, etc. class EventLogger final { public: - EventLogger() - : logging_thread_(EventTracingThreadFunc, - this, - "EventTracingThread", - kLowPriority) {} ~EventLogger() { RTC_DCHECK(thread_checker_.IsCurrent()); } void AddTraceEvent(const char* name, @@ -209,7 +202,8 @@ class EventLogger final { rtc::AtomicOps::CompareAndSwap(&g_event_logging_active, 0, 1)); // Finally start, everything should be set up now. - logging_thread_.Start(); + logging_thread_ = + PlatformThread::SpawnJoinable([this] { Log(); }, "EventTracingThread"); TRACE_EVENT_INSTANT0("webrtc", "EventLogger::Start"); } @@ -223,7 +217,7 @@ class EventLogger final { // Wake up logging thread to finish writing. shutdown_event_.Set(); // Join the logging thread. 
- logging_thread_.Stop(); + logging_thread_.Finalize(); } private: @@ -321,15 +315,11 @@ class EventLogger final { std::vector trace_events_ RTC_GUARDED_BY(mutex_); rtc::PlatformThread logging_thread_; rtc::Event shutdown_event_; - rtc::ThreadChecker thread_checker_; + webrtc::SequenceChecker thread_checker_; FILE* output_file_ = nullptr; bool output_file_owned_ = false; }; -static void EventTracingThreadFunc(void* params) { - static_cast(params)->Log(); -} - static EventLogger* volatile g_event_logger = nullptr; static const char* const kDisabledTracePrefix = TRACE_DISABLED_BY_DEFAULT(""); const unsigned char* InternalGetCategoryEnabled(const char* name) { diff --git a/rtc_base/event_unittest.cc b/rtc_base/event_unittest.cc index 31118877cf..a634d6e426 100644 --- a/rtc_base/event_unittest.cc +++ b/rtc_base/event_unittest.cc @@ -43,22 +43,21 @@ TEST(EventTest, AutoReset) { class SignalerThread { public: - SignalerThread() : thread_(&ThreadFn, this, "EventPerf") {} void Start(Event* writer, Event* reader) { writer_ = writer; reader_ = reader; - thread_.Start(); + thread_ = PlatformThread::SpawnJoinable( + [this] { + while (!stop_event_.Wait(0)) { + writer_->Set(); + reader_->Wait(Event::kForever); + } + }, + "EventPerf"); } void Stop() { stop_event_.Set(); - thread_.Stop(); - } - static void ThreadFn(void* param) { - auto* me = static_cast(param); - while (!me->stop_event_.Wait(0)) { - me->writer_->Set(); - me->reader_->Wait(Event::kForever); - } + thread_.Finalize(); } Event stop_event_; Event* writer_; diff --git a/rtc_base/experiments/BUILD.gn b/rtc_base/experiments/BUILD.gn index a40c9e0d80..b0a729abfe 100644 --- a/rtc_base/experiments/BUILD.gn +++ b/rtc_base/experiments/BUILD.gn @@ -130,6 +130,20 @@ rtc_library("cpu_speed_experiment") { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } +rtc_library("encoder_info_settings") { + sources = [ + "encoder_info_settings.cc", + "encoder_info_settings.h", + ] + deps = [ + ":field_trial_parser", + "../:rtc_base_approved", + "../../api/video_codecs:video_codecs_api", + "../../system_wrappers:field_trial", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + rtc_library("rtt_mult_experiment") { sources = [ "rtt_mult_experiment.cc", @@ -217,13 +231,14 @@ rtc_library("min_video_bitrate_experiment") { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_library("experiments_unittests") { testonly = true sources = [ "balanced_degradation_settings_unittest.cc", "cpu_speed_experiment_unittest.cc", + "encoder_info_settings_unittest.cc", "field_trial_list_unittest.cc", "field_trial_parser_unittest.cc", "field_trial_units_unittest.cc", @@ -241,6 +256,7 @@ if (rtc_include_tests) { deps = [ ":balanced_degradation_settings", ":cpu_speed_experiment", + ":encoder_info_settings", ":field_trial_parser", ":keyframe_interval_settings_experiment", ":min_video_bitrate_experiment", diff --git a/rtc_base/experiments/balanced_degradation_settings.cc b/rtc_base/experiments/balanced_degradation_settings.cc index d061597f70..90d44efb10 100644 --- a/rtc_base/experiments/balanced_degradation_settings.cc +++ b/rtc_base/experiments/balanced_degradation_settings.cc @@ -93,7 +93,8 @@ bool IsValid(const BalancedDegradationSettings::CodecTypeSpecific& config1, bool IsValid(const std::vector& configs) { if (configs.size() <= 1) { - RTC_LOG(LS_WARNING) << "Unsupported size, value ignored."; + if (configs.size() == 1) + RTC_LOG(LS_WARNING) << "Unsupported 
size, value ignored."; return false; } for (const auto& config : configs) { diff --git a/rtc_base/experiments/cpu_speed_experiment.cc b/rtc_base/experiments/cpu_speed_experiment.cc index 0f53320093..7e61255260 100644 --- a/rtc_base/experiments/cpu_speed_experiment.cc +++ b/rtc_base/experiments/cpu_speed_experiment.cc @@ -25,7 +25,6 @@ constexpr int kMaxSetting = -1; std::vector GetValidOrEmpty( const std::vector& configs) { if (configs.empty()) { - RTC_LOG(LS_WARNING) << "Unsupported size, value ignored."; return {}; } diff --git a/rtc_base/experiments/encoder_info_settings.cc b/rtc_base/experiments/encoder_info_settings.cc new file mode 100644 index 0000000000..9e1a5190a3 --- /dev/null +++ b/rtc_base/experiments/encoder_info_settings.cc @@ -0,0 +1,120 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/experiments/encoder_info_settings.h" + +#include + +#include "rtc_base/experiments/field_trial_list.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { +namespace { + +std::vector ToResolutionBitrateLimits( + const std::vector& limits) { + std::vector result; + for (const auto& limit : limits) { + result.push_back(VideoEncoder::ResolutionBitrateLimits( + limit.frame_size_pixels, limit.min_start_bitrate_bps, + limit.min_bitrate_bps, limit.max_bitrate_bps)); + } + return result; +} + +} // namespace + +// Default bitrate limits for simulcast with one active stream: +// {frame_size_pixels, min_start_bitrate_bps, min_bitrate_bps, max_bitrate_bps}. +std::vector +EncoderInfoSettings::GetDefaultSinglecastBitrateLimits( + VideoCodecType codec_type) { + // Specific limits for VP9. Other codecs use VP8 limits. 
+ if (codec_type == kVideoCodecVP9) { + return {{320 * 180, 0, 30000, 150000}, + {480 * 270, 120000, 30000, 300000}, + {640 * 360, 190000, 30000, 420000}, + {960 * 540, 350000, 30000, 1000000}, + {1280 * 720, 480000, 30000, 1500000}}; + } + + return {{320 * 180, 0, 30000, 300000}, + {480 * 270, 200000, 30000, 500000}, + {640 * 360, 300000, 30000, 800000}, + {960 * 540, 500000, 30000, 1500000}, + {1280 * 720, 900000, 30000, 2500000}}; +} + +absl::optional +EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + VideoCodecType codec_type, + int frame_size_pixels) { + VideoEncoder::EncoderInfo info; + info.resolution_bitrate_limits = + GetDefaultSinglecastBitrateLimits(codec_type); + return info.GetEncoderBitrateLimitsForResolution(frame_size_pixels); +} + +EncoderInfoSettings::EncoderInfoSettings(std::string name) + : requested_resolution_alignment_("requested_resolution_alignment"), + apply_alignment_to_all_simulcast_layers_( + "apply_alignment_to_all_simulcast_layers") { + FieldTrialStructList bitrate_limits( + {FieldTrialStructMember( + "frame_size_pixels", + [](BitrateLimit* b) { return &b->frame_size_pixels; }), + FieldTrialStructMember( + "min_start_bitrate_bps", + [](BitrateLimit* b) { return &b->min_start_bitrate_bps; }), + FieldTrialStructMember( + "min_bitrate_bps", + [](BitrateLimit* b) { return &b->min_bitrate_bps; }), + FieldTrialStructMember( + "max_bitrate_bps", + [](BitrateLimit* b) { return &b->max_bitrate_bps; })}, + {}); + + if (field_trial::FindFullName(name).empty()) { + // Encoder name not found, use common string applying to all encoders. + name = "WebRTC-GetEncoderInfoOverride"; + } + + ParseFieldTrial({&bitrate_limits, &requested_resolution_alignment_, + &apply_alignment_to_all_simulcast_layers_}, + field_trial::FindFullName(name)); + + resolution_bitrate_limits_ = ToResolutionBitrateLimits(bitrate_limits.Get()); +} + +absl::optional EncoderInfoSettings::requested_resolution_alignment() + const { + if (requested_resolution_alignment_ && + requested_resolution_alignment_.Value() < 1) { + RTC_LOG(LS_WARNING) << "Unsupported alignment value, ignored."; + return absl::nullopt; + } + return requested_resolution_alignment_.GetOptional(); +} + +EncoderInfoSettings::~EncoderInfoSettings() {} + +SimulcastEncoderAdapterEncoderInfoSettings:: + SimulcastEncoderAdapterEncoderInfoSettings() + : EncoderInfoSettings( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride") {} + +LibvpxVp8EncoderInfoSettings::LibvpxVp8EncoderInfoSettings() + : EncoderInfoSettings("WebRTC-VP8-GetEncoderInfoOverride") {} + +LibvpxVp9EncoderInfoSettings::LibvpxVp9EncoderInfoSettings() + : EncoderInfoSettings("WebRTC-VP9-GetEncoderInfoOverride") {} + +} // namespace webrtc diff --git a/rtc_base/experiments/encoder_info_settings.h b/rtc_base/experiments/encoder_info_settings.h new file mode 100644 index 0000000000..9cbb5875bb --- /dev/null +++ b/rtc_base/experiments/encoder_info_settings.h @@ -0,0 +1,83 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef RTC_BASE_EXPERIMENTS_ENCODER_INFO_SETTINGS_H_ +#define RTC_BASE_EXPERIMENTS_ENCODER_INFO_SETTINGS_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "api/video_codecs/video_encoder.h" +#include "rtc_base/experiments/field_trial_parser.h" + +namespace webrtc { + +class EncoderInfoSettings { + public: + virtual ~EncoderInfoSettings(); + + // Bitrate limits per resolution. + struct BitrateLimit { + int frame_size_pixels = 0; // The video frame size. + int min_start_bitrate_bps = 0; // The minimum bitrate to start encoding. + int min_bitrate_bps = 0; // The minimum bitrate. + int max_bitrate_bps = 0; // The maximum bitrate. + }; + + absl::optional requested_resolution_alignment() const; + bool apply_alignment_to_all_simulcast_layers() const { + return apply_alignment_to_all_simulcast_layers_.Get(); + } + std::vector resolution_bitrate_limits() + const { + return resolution_bitrate_limits_; + } + + static std::vector + GetDefaultSinglecastBitrateLimits(VideoCodecType codec_type); + + static absl::optional + GetDefaultSinglecastBitrateLimitsForResolution(VideoCodecType codec_type, + int frame_size_pixels); + + protected: + explicit EncoderInfoSettings(std::string name); + + private: + FieldTrialOptional requested_resolution_alignment_; + FieldTrialFlag apply_alignment_to_all_simulcast_layers_; + std::vector resolution_bitrate_limits_; +}; + +// EncoderInfo settings for SimulcastEncoderAdapter. +class SimulcastEncoderAdapterEncoderInfoSettings : public EncoderInfoSettings { + public: + SimulcastEncoderAdapterEncoderInfoSettings(); + ~SimulcastEncoderAdapterEncoderInfoSettings() override {} +}; + +// EncoderInfo settings for LibvpxVp8Encoder. +class LibvpxVp8EncoderInfoSettings : public EncoderInfoSettings { + public: + LibvpxVp8EncoderInfoSettings(); + ~LibvpxVp8EncoderInfoSettings() override {} +}; + +// EncoderInfo settings for LibvpxVp9Encoder. +class LibvpxVp9EncoderInfoSettings : public EncoderInfoSettings { + public: + LibvpxVp9EncoderInfoSettings(); + ~LibvpxVp9EncoderInfoSettings() override {} +}; + +} // namespace webrtc + +#endif // RTC_BASE_EXPERIMENTS_ENCODER_INFO_SETTINGS_H_ diff --git a/rtc_base/experiments/encoder_info_settings_unittest.cc b/rtc_base/experiments/encoder_info_settings_unittest.cc new file mode 100644 index 0000000000..aabb68718c --- /dev/null +++ b/rtc_base/experiments/encoder_info_settings_unittest.cc @@ -0,0 +1,102 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "rtc_base/experiments/encoder_info_settings.h" + +#include "rtc_base/gunit.h" +#include "test/field_trial.h" +#include "test/gmock.h" + +namespace webrtc { + +TEST(SimulcastEncoderAdapterSettingsTest, NoValuesWithoutFieldTrial) { + SimulcastEncoderAdapterEncoderInfoSettings settings; + EXPECT_EQ(absl::nullopt, settings.requested_resolution_alignment()); + EXPECT_FALSE(settings.apply_alignment_to_all_simulcast_layers()); + EXPECT_TRUE(settings.resolution_bitrate_limits().empty()); +} + +TEST(SimulcastEncoderAdapterSettingsTest, NoValueForInvalidAlignment) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "requested_resolution_alignment:0/"); + + SimulcastEncoderAdapterEncoderInfoSettings settings; + EXPECT_EQ(absl::nullopt, settings.requested_resolution_alignment()); +} + +TEST(SimulcastEncoderAdapterSettingsTest, GetResolutionAlignment) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "requested_resolution_alignment:2/"); + + SimulcastEncoderAdapterEncoderInfoSettings settings; + EXPECT_EQ(2, settings.requested_resolution_alignment()); + EXPECT_FALSE(settings.apply_alignment_to_all_simulcast_layers()); + EXPECT_TRUE(settings.resolution_bitrate_limits().empty()); +} + +TEST(SimulcastEncoderAdapterSettingsTest, GetApplyAlignment) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "requested_resolution_alignment:3," + "apply_alignment_to_all_simulcast_layers/"); + + SimulcastEncoderAdapterEncoderInfoSettings settings; + EXPECT_EQ(3, settings.requested_resolution_alignment()); + EXPECT_TRUE(settings.apply_alignment_to_all_simulcast_layers()); + EXPECT_TRUE(settings.resolution_bitrate_limits().empty()); +} + +TEST(SimulcastEncoderAdapterSettingsTest, GetResolutionBitrateLimits) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "frame_size_pixels:123," + "min_start_bitrate_bps:11000," + "min_bitrate_bps:44000," + "max_bitrate_bps:77000/"); + + SimulcastEncoderAdapterEncoderInfoSettings settings; + EXPECT_EQ(absl::nullopt, settings.requested_resolution_alignment()); + EXPECT_FALSE(settings.apply_alignment_to_all_simulcast_layers()); + EXPECT_THAT(settings.resolution_bitrate_limits(), + ::testing::ElementsAre(VideoEncoder::ResolutionBitrateLimits{ + 123, 11000, 44000, 77000})); +} + +TEST(SimulcastEncoderAdapterSettingsTest, GetResolutionBitrateLimitsWithList) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-SimulcastEncoderAdapter-GetEncoderInfoOverride/" + "frame_size_pixels:123|456|789," + "min_start_bitrate_bps:11000|22000|33000," + "min_bitrate_bps:44000|55000|66000," + "max_bitrate_bps:77000|88000|99000/"); + + SimulcastEncoderAdapterEncoderInfoSettings settings; + EXPECT_THAT( + settings.resolution_bitrate_limits(), + ::testing::ElementsAre( + VideoEncoder::ResolutionBitrateLimits{123, 11000, 44000, 77000}, + VideoEncoder::ResolutionBitrateLimits{456, 22000, 55000, 88000}, + VideoEncoder::ResolutionBitrateLimits{789, 33000, 66000, 99000})); +} + +TEST(EncoderSettingsTest, CommonSettingsUsedIfEncoderNameUnspecified) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-VP8-GetEncoderInfoOverride/requested_resolution_alignment:2/" + "WebRTC-GetEncoderInfoOverride/requested_resolution_alignment:3/"); + + LibvpxVp8EncoderInfoSettings vp8_settings; + EXPECT_EQ(2, vp8_settings.requested_resolution_alignment()); + 
  LibvpxVp9EncoderInfoSettings vp9_settings;
+  EXPECT_EQ(3, vp9_settings.requested_resolution_alignment());
+}
+
+}  // namespace webrtc
diff --git a/rtc_base/experiments/field_trial_parser.cc b/rtc_base/experiments/field_trial_parser.cc
index b88d0f97c4..8fc89cec8f 100644
--- a/rtc_base/experiments/field_trial_parser.cc
+++ b/rtc_base/experiments/field_trial_parser.cc
@@ -83,7 +83,10 @@ void ParseFieldTrial(
       RTC_LOG(LS_WARNING) << "Failed to read empty key field with value '"
                           << key << "' in trial: \"" << trial_string << "\"";
     }
-  } else {
+  } else if (key.empty() || key[0] != '_') {
+    // "_" is used to prefix keys that are part of the string for debugging
+    // purposes but not necessarily used,
+    // e.g. WebRTC-Experiment/param: value, _DebuggingString
     RTC_LOG(LS_INFO) << "No field with key: '" << key
                      << "' (found in trial: \"" << trial_string << "\")";
     std::string valid_keys;
diff --git a/rtc_base/experiments/keyframe_interval_settings.cc b/rtc_base/experiments/keyframe_interval_settings.cc
index 2f19a1c53f..76c85cbbad 100644
--- a/rtc_base/experiments/keyframe_interval_settings.cc
+++ b/rtc_base/experiments/keyframe_interval_settings.cc
@@ -22,11 +22,8 @@ constexpr char kFieldTrialName[] = "WebRTC-KeyframeInterval";
 
 KeyframeIntervalSettings::KeyframeIntervalSettings(
     const WebRtcKeyValueConfig* const key_value_config)
-    : min_keyframe_send_interval_ms_("min_keyframe_send_interval_ms"),
-      max_wait_for_keyframe_ms_("max_wait_for_keyframe_ms"),
-      max_wait_for_frame_ms_("max_wait_for_frame_ms") {
-  ParseFieldTrial({&min_keyframe_send_interval_ms_, &max_wait_for_keyframe_ms_,
-                   &max_wait_for_frame_ms_},
+    : min_keyframe_send_interval_ms_("min_keyframe_send_interval_ms") {
+  ParseFieldTrial({&min_keyframe_send_interval_ms_},
                   key_value_config->Lookup(kFieldTrialName));
 }
 
@@ -39,13 +36,4 @@ absl::optional<int> KeyframeIntervalSettings::MinKeyframeSendIntervalMs()
     const {
   return min_keyframe_send_interval_ms_.GetOptional();
 }
-
-absl::optional<int> KeyframeIntervalSettings::MaxWaitForKeyframeMs() const {
-  return max_wait_for_keyframe_ms_.GetOptional();
-}
-
-absl::optional<int> KeyframeIntervalSettings::MaxWaitForFrameMs() const {
-  return max_wait_for_frame_ms_.GetOptional();
-}
-
 }  // namespace webrtc
diff --git a/rtc_base/experiments/keyframe_interval_settings.h b/rtc_base/experiments/keyframe_interval_settings.h
index 7c8d6d364a..3f253f0022 100644
--- a/rtc_base/experiments/keyframe_interval_settings.h
+++ b/rtc_base/experiments/keyframe_interval_settings.h
@@ -17,6 +17,9 @@
 
 namespace webrtc {
 
+// TODO(bugs.webrtc.org/10427): Remove and replace with a proper configuration
+// parameter, or move to using FIR if the intent is to avoid triggering
+// multiple times in response to PLIs for the same request when RTT is large.
 class KeyframeIntervalSettings final {
  public:
   static KeyframeIntervalSettings ParseFromFieldTrials();
@@ -25,22 +28,11 @@ class KeyframeIntervalSettings final {
   // The encoded keyframe send rate is <= 1/MinKeyframeSendIntervalMs().
   absl::optional<int> MinKeyframeSendIntervalMs() const;
 
-  // Receiver side.
- // The keyframe request send rate is - // - when we have not received a key frame at all: - // <= 1/MaxWaitForKeyframeMs() - // - when we have not received a frame recently: - // <= 1/MaxWaitForFrameMs() - absl::optional MaxWaitForKeyframeMs() const; - absl::optional MaxWaitForFrameMs() const; - private: explicit KeyframeIntervalSettings( const WebRtcKeyValueConfig* key_value_config); FieldTrialOptional min_keyframe_send_interval_ms_; - FieldTrialOptional max_wait_for_keyframe_ms_; - FieldTrialOptional max_wait_for_frame_ms_; }; } // namespace webrtc diff --git a/rtc_base/experiments/keyframe_interval_settings_unittest.cc b/rtc_base/experiments/keyframe_interval_settings_unittest.cc index 7d89a4c000..25cebbcd70 100644 --- a/rtc_base/experiments/keyframe_interval_settings_unittest.cc +++ b/rtc_base/experiments/keyframe_interval_settings_unittest.cc @@ -27,60 +27,16 @@ TEST(KeyframeIntervalSettingsTest, ParsesMinKeyframeSendIntervalMs) { 100); } -TEST(KeyframeIntervalSettingsTest, ParsesMaxWaitForKeyframeMs) { - EXPECT_FALSE( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForKeyframeMs()); - - test::ScopedFieldTrials field_trials( - "WebRTC-KeyframeInterval/max_wait_for_keyframe_ms:100/"); - EXPECT_EQ( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForKeyframeMs(), - 100); -} - -TEST(KeyframeIntervalSettingsTest, ParsesMaxWaitForFrameMs) { - EXPECT_FALSE( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForFrameMs()); - - test::ScopedFieldTrials field_trials( - "WebRTC-KeyframeInterval/max_wait_for_frame_ms:100/"); - EXPECT_EQ( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForFrameMs(), - 100); -} - -TEST(KeyframeIntervalSettingsTest, ParsesAllValues) { - test::ScopedFieldTrials field_trials( - "WebRTC-KeyframeInterval/" - "min_keyframe_send_interval_ms:100," - "max_wait_for_keyframe_ms:101," - "max_wait_for_frame_ms:102/"); - EXPECT_EQ(KeyframeIntervalSettings::ParseFromFieldTrials() - .MinKeyframeSendIntervalMs(), - 100); - EXPECT_EQ( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForKeyframeMs(), - 101); - EXPECT_EQ( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForFrameMs(), - 102); -} - -TEST(KeyframeIntervalSettingsTest, DoesNotParseAllValuesWhenIncorrectlySet) { - EXPECT_FALSE( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForFrameMs()); +TEST(KeyframeIntervalSettingsTest, DoesNotParseIncorrectValues) { + EXPECT_FALSE(KeyframeIntervalSettings::ParseFromFieldTrials() + .MinKeyframeSendIntervalMs()); test::ScopedFieldTrials field_trials( - "WebRTC-KeyframeInterval/" - "min_keyframe_send_interval_ms:a," - "max_wait_for_keyframe_ms:b," - "max_wait_for_frame_ms:c/"); + "WebRTC-KeyframeInterval/min_keyframe_send_interval_ms:a/"); + EXPECT_FALSE(KeyframeIntervalSettings::ParseFromFieldTrials() + .MinKeyframeSendIntervalMs()); EXPECT_FALSE(KeyframeIntervalSettings::ParseFromFieldTrials() .MinKeyframeSendIntervalMs()); - EXPECT_FALSE( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForKeyframeMs()); - EXPECT_FALSE( - KeyframeIntervalSettings::ParseFromFieldTrials().MaxWaitForFrameMs()); } } // namespace diff --git a/rtc_base/experiments/quality_scaling_experiment.cc b/rtc_base/experiments/quality_scaling_experiment.cc index ca58ba858a..7d5722bbe3 100644 --- a/rtc_base/experiments/quality_scaling_experiment.cc +++ b/rtc_base/experiments/quality_scaling_experiment.cc @@ -25,6 +25,11 @@ constexpr int kMaxVp9Qp = 255; constexpr int kMaxH264Qp = 51; constexpr int kMaxGenericQp = 255; +#if !defined(WEBRTC_IOS) 
+constexpr char kDefaultQualityScalingSetttings[] = + "Enabled-29,95,149,205,24,37,26,36,0.9995,0.9999,1"; +#endif + absl::optional GetThresholds(int low, int high, int max) { @@ -38,15 +43,22 @@ absl::optional GetThresholds(int low, } // namespace bool QualityScalingExperiment::Enabled() { +#if defined(WEBRTC_IOS) return webrtc::field_trial::IsEnabled(kFieldTrial); +#else + return !webrtc::field_trial::IsDisabled(kFieldTrial); +#endif } absl::optional QualityScalingExperiment::ParseSettings() { - const std::string group = webrtc::field_trial::FindFullName(kFieldTrial); + std::string group = webrtc::field_trial::FindFullName(kFieldTrial); + // TODO(http://crbug.com/webrtc/12401): Completely remove the experiment code + // after few releases. +#if !defined(WEBRTC_IOS) if (group.empty()) - return absl::nullopt; - + group = kDefaultQualityScalingSetttings; +#endif Settings s; if (sscanf(group.c_str(), "Enabled-%d,%d,%d,%d,%d,%d,%d,%d,%f,%f,%d", &s.vp8_low, &s.vp8_high, &s.vp9_low, &s.vp9_high, &s.h264_low, diff --git a/rtc_base/experiments/quality_scaling_experiment_unittest.cc b/rtc_base/experiments/quality_scaling_experiment_unittest.cc index 7a345b629f..4507f1514f 100644 --- a/rtc_base/experiments/quality_scaling_experiment_unittest.cc +++ b/rtc_base/experiments/quality_scaling_experiment_unittest.cc @@ -38,10 +38,18 @@ void ExpectEqualConfig(QualityScalingExperiment::Config a, } } // namespace -TEST(QualityScalingExperimentTest, DisabledWithoutFieldTrial) { +#if !defined(WEBRTC_IOS) +// TODO(bugs.webrtc.org/12401): investigate why QualityScaler kicks in on iOS. +TEST(QualityScalingExperimentTest, DefaultEnabledWithoutFieldTrial) { + webrtc::test::ScopedFieldTrials field_trials(""); + EXPECT_TRUE(QualityScalingExperiment::Enabled()); +} +#else +TEST(QualityScalingExperimentTest, DefaultDisabledWithoutFieldTrialIOS) { webrtc::test::ScopedFieldTrials field_trials(""); EXPECT_FALSE(QualityScalingExperiment::Enabled()); } +#endif TEST(QualityScalingExperimentTest, EnabledWithFieldTrial) { webrtc::test::ScopedFieldTrials field_trials( @@ -59,10 +67,19 @@ TEST(QualityScalingExperimentTest, ParseSettings) { ExpectEqualSettings(kExpected, *settings); } +#if !defined(WEBRTC_IOS) +// TODO(bugs.webrtc.org/12401): investigate why QualityScaler kicks in on iOS. +TEST(QualityScalingExperimentTest, ParseSettingsUsesDefaultsWithoutFieldTrial) { + webrtc::test::ScopedFieldTrials field_trials(""); + // Uses some default hard coded values. 
+ EXPECT_TRUE(QualityScalingExperiment::ParseSettings()); +} +#else TEST(QualityScalingExperimentTest, ParseSettingsFailsWithoutFieldTrial) { webrtc::test::ScopedFieldTrials field_trials(""); EXPECT_FALSE(QualityScalingExperiment::ParseSettings()); } +#endif TEST(QualityScalingExperimentTest, ParseSettingsFailsWithInvalidFieldTrial) { webrtc::test::ScopedFieldTrials field_trials( diff --git a/rtc_base/experiments/rate_control_settings.cc b/rtc_base/experiments/rate_control_settings.cc index 6766db62c3..bed194e683 100644 --- a/rtc_base/experiments/rate_control_settings.cc +++ b/rtc_base/experiments/rate_control_settings.cc @@ -24,10 +24,13 @@ namespace webrtc { namespace { -const int kDefaultAcceptedQueueMs = 250; +const int kDefaultAcceptedQueueMs = 350; const int kDefaultMinPushbackTargetBitrateBps = 30000; +const char kCongestionWindowDefaultFieldTrialString[] = + "QueueSize:350,MinBitrate:30000,DropFrame:true"; + const char kUseBaseHeavyVp8Tl3RateAllocationFieldTrialName[] = "WebRTC-UseBaseHeavyVP8TL3RateAllocation"; @@ -91,9 +94,13 @@ std::unique_ptr VideoRateControlConfig::Parser() { } RateControlSettings::RateControlSettings( - const WebRtcKeyValueConfig* const key_value_config) - : congestion_window_config_(CongestionWindowConfig::Parse( - key_value_config->Lookup(CongestionWindowConfig::kKey))) { + const WebRtcKeyValueConfig* const key_value_config) { + std::string congestion_window_config = + key_value_config->Lookup(CongestionWindowConfig::kKey).empty() + ? kCongestionWindowDefaultFieldTrialString + : key_value_config->Lookup(CongestionWindowConfig::kKey); + congestion_window_config_ = + CongestionWindowConfig::Parse(congestion_window_config); video_config_.vp8_base_heavy_tl3_alloc = IsEnabled( key_value_config, kUseBaseHeavyVp8Tl3RateAllocationFieldTrialName); ParseHysteresisFactor(key_value_config, kVideoHysteresisFieldTrialname, diff --git a/rtc_base/experiments/rate_control_settings.h b/rtc_base/experiments/rate_control_settings.h index db7f1cd136..1c38e927dc 100644 --- a/rtc_base/experiments/rate_control_settings.h +++ b/rtc_base/experiments/rate_control_settings.h @@ -96,7 +96,7 @@ class RateControlSettings final { explicit RateControlSettings( const WebRtcKeyValueConfig* const key_value_config); - const CongestionWindowConfig congestion_window_config_; + CongestionWindowConfig congestion_window_config_; VideoRateControlConfig video_config_; }; diff --git a/rtc_base/experiments/rate_control_settings_unittest.cc b/rtc_base/experiments/rate_control_settings_unittest.cc index 8d722722e4..84e5825224 100644 --- a/rtc_base/experiments/rate_control_settings_unittest.cc +++ b/rtc_base/experiments/rate_control_settings_unittest.cc @@ -20,7 +20,7 @@ namespace webrtc { namespace { TEST(RateControlSettingsTest, CongestionWindow) { - EXPECT_FALSE( + EXPECT_TRUE( RateControlSettings::ParseFromFieldTrials().UseCongestionWindow()); test::ScopedFieldTrials field_trials( @@ -32,8 +32,8 @@ TEST(RateControlSettingsTest, CongestionWindow) { } TEST(RateControlSettingsTest, CongestionWindowPushback) { - EXPECT_FALSE(RateControlSettings::ParseFromFieldTrials() - .UseCongestionWindowPushback()); + EXPECT_TRUE(RateControlSettings::ParseFromFieldTrials() + .UseCongestionWindowPushback()); test::ScopedFieldTrials field_trials( "WebRTC-CongestionWindow/QueueSize:100,MinBitrate:100000/"); @@ -44,6 +44,29 @@ TEST(RateControlSettingsTest, CongestionWindowPushback) { 100000u); } +TEST(RateControlSettingsTest, CongestionWindowPushbackDropframe) { + EXPECT_TRUE(RateControlSettings::ParseFromFieldTrials() + 
.UseCongestionWindowPushback());
+
+  test::ScopedFieldTrials field_trials(
+      "WebRTC-CongestionWindow/"
+      "QueueSize:100,MinBitrate:100000,DropFrame:true/");
+  const RateControlSettings settings_after =
+      RateControlSettings::ParseFromFieldTrials();
+  EXPECT_TRUE(settings_after.UseCongestionWindowPushback());
+  EXPECT_EQ(settings_after.CongestionWindowMinPushbackTargetBitrateBps(),
+            100000u);
+  EXPECT_TRUE(settings_after.UseCongestionWindowDropFrameOnly());
+}
+
+TEST(RateControlSettingsTest, CongestionWindowPushbackDefaultConfig) {
+  const RateControlSettings settings =
+      RateControlSettings::ParseFromFieldTrials();
+  EXPECT_TRUE(settings.UseCongestionWindowPushback());
+  EXPECT_EQ(settings.CongestionWindowMinPushbackTargetBitrateBps(), 30000u);
+  EXPECT_TRUE(settings.UseCongestionWindowDropFrameOnly());
+}
+
 TEST(RateControlSettingsTest, PacingFactor) {
   EXPECT_FALSE(RateControlSettings::ParseFromFieldTrials().GetPacingFactor());
diff --git a/rtc_base/experiments/struct_parameters_parser.cc b/rtc_base/experiments/struct_parameters_parser.cc
index 2605da8fef..d62eb6f1ea 100644
--- a/rtc_base/experiments/struct_parameters_parser.cc
+++ b/rtc_base/experiments/struct_parameters_parser.cc
@@ -107,7 +107,10 @@ void StructParametersParser::Parse(absl::string_view src) {
         break;
       }
     }
-    if (!found) {
+    // "_" is used to prefix keys that are part of the string for debugging
+    // purposes but not necessarily used,
+    // e.g. WebRTC-Experiment/param: value, _DebuggingString
+    if (!found && (key.empty() || key[0] != '_')) {
      RTC_LOG(LS_INFO) << "No field with key: '" << key
                        << "' (found in trial: \"" << src << "\")";
     }
diff --git a/rtc_base/fake_mdns_responder.h b/rtc_base/fake_mdns_responder.h
index 42908764ab..1f87cf4b81 100644
--- a/rtc_base/fake_mdns_responder.h
+++ b/rtc_base/fake_mdns_responder.h
@@ -15,14 +15,17 @@
 #include <map>
 #include <string>
 
-#include "rtc_base/async_invoker.h"
 #include "rtc_base/ip_address.h"
 #include "rtc_base/location.h"
 #include "rtc_base/mdns_responder_interface.h"
+#include "rtc_base/task_utils/to_queued_task.h"
 #include "rtc_base/thread.h"
 
 namespace webrtc {
 
+// This class posts tasks on the given `thread` to invoke callbacks. It's the
+// callback's responsibility to be aware of potential destruction of state it
+// depends on, e.g., using WeakPtrFactory or PendingTaskSafetyFlag.
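The comment above describes the AsyncInvoker-to-PostTask migration applied to this fake; a small sketch of the posting pattern the new code relies on (the free function is illustrative, not part of the patch):

#include <functional>

#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread.h"

// Posts `callback(result)` asynchronously on `thread`. Everything the task
// needs is captured by value; state captured by pointer or reference would
// have to be guarded against destruction (e.g. with PendingTaskSafetyFlag).
void PostResult(rtc::Thread* thread,
                std::function<void(bool)> callback,
                bool result) {
  thread->PostTask(
      webrtc::ToQueuedTask([callback, result]() { callback(result); }));
}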
class FakeMdnsResponder : public MdnsResponderInterface { public: explicit FakeMdnsResponder(rtc::Thread* thread) : thread_(thread) {} @@ -37,9 +40,8 @@ class FakeMdnsResponder : public MdnsResponderInterface { name = std::to_string(next_available_id_++) + ".local"; addr_name_map_[addr] = name; } - invoker_.AsyncInvoke( - RTC_FROM_HERE, thread_, - [callback, addr, name]() { callback(addr, name); }); + thread_->PostTask( + ToQueuedTask([callback, addr, name]() { callback(addr, name); })); } void RemoveNameForAddress(const rtc::IPAddress& addr, NameRemovedCallback callback) override { @@ -48,8 +50,7 @@ class FakeMdnsResponder : public MdnsResponderInterface { addr_name_map_.erase(it); } bool result = it != addr_name_map_.end(); - invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, - [callback, result]() { callback(result); }); + thread_->PostTask(ToQueuedTask([callback, result]() { callback(result); })); } rtc::IPAddress GetMappedAddressForName(const std::string& name) const { @@ -64,8 +65,7 @@ class FakeMdnsResponder : public MdnsResponderInterface { private: uint32_t next_available_id_ = 0; std::map addr_name_map_; - rtc::Thread* thread_; - rtc::AsyncInvoker invoker_; + rtc::Thread* const thread_; }; } // namespace webrtc diff --git a/rtc_base/file_rotating_stream.cc b/rtc_base/file_rotating_stream.cc index 826e6745f3..b7d64ba92d 100644 --- a/rtc_base/file_rotating_stream.cc +++ b/rtc_base/file_rotating_stream.cc @@ -193,49 +193,40 @@ FileRotatingStream::FileRotatingStream(const std::string& dir_path, FileRotatingStream::~FileRotatingStream() {} -StreamState FileRotatingStream::GetState() const { - return (file_.is_open() ? SS_OPEN : SS_CLOSED); +bool FileRotatingStream::IsOpen() const { + return file_.is_open(); } -StreamResult FileRotatingStream::Read(void* buffer, - size_t buffer_len, - size_t* read, - int* error) { - RTC_DCHECK(buffer); - RTC_NOTREACHED(); - return SR_EOS; -} - -StreamResult FileRotatingStream::Write(const void* data, - size_t data_len, - size_t* written, - int* error) { +bool FileRotatingStream::Write(const void* data, size_t data_len) { if (!file_.is_open()) { std::fprintf(stderr, "Open() must be called before Write.\n"); - return SR_ERROR; + return false; } - // Write as much as will fit in to the current file. - RTC_DCHECK_LT(current_bytes_written_, max_file_size_); - size_t remaining_bytes = max_file_size_ - current_bytes_written_; - size_t write_length = std::min(data_len, remaining_bytes); + while (data_len > 0) { + // Write as much as will fit in to the current file. + RTC_DCHECK_LT(current_bytes_written_, max_file_size_); + size_t remaining_bytes = max_file_size_ - current_bytes_written_; + size_t write_length = std::min(data_len, remaining_bytes); + + if (!file_.Write(data, write_length)) { + return false; + } + if (disable_buffering_ && !file_.Flush()) { + return false; + } - if (!file_.Write(data, write_length)) { - return SR_ERROR; - } - if (disable_buffering_ && !file_.Flush()) { - return SR_ERROR; - } + current_bytes_written_ += write_length; - current_bytes_written_ += write_length; - if (written) { - *written = write_length; - } - // If we're done with this file, rotate it out. - if (current_bytes_written_ >= max_file_size_) { - RTC_DCHECK_EQ(current_bytes_written_, max_file_size_); - RotateFiles(); + // If we're done with this file, rotate it out. 
+ if (current_bytes_written_ >= max_file_size_) { + RTC_DCHECK_EQ(current_bytes_written_, max_file_size_); + RotateFiles(); + } + data_len -= write_length; + data = + static_cast(static_cast(data) + write_length); } - return SR_SUCCESS; + return true; } bool FileRotatingStream::Flush() { diff --git a/rtc_base/file_rotating_stream.h b/rtc_base/file_rotating_stream.h index 117cf2019a..88461e344f 100644 --- a/rtc_base/file_rotating_stream.h +++ b/rtc_base/file_rotating_stream.h @@ -18,7 +18,6 @@ #include #include "rtc_base/constructor_magic.h" -#include "rtc_base/stream.h" #include "rtc_base/system/file_wrapper.h" namespace rtc { @@ -27,13 +26,8 @@ namespace rtc { // constructor. It rotates the files once the current file is full. The // individual file size and the number of files used is configurable in the // constructor. Open() must be called before using this stream. -class FileRotatingStream : public StreamInterface { +class FileRotatingStream { public: - // Use this constructor for reading a directory previously written to with - // this stream. - FileRotatingStream(const std::string& dir_path, - const std::string& file_prefix); - // Use this constructor for writing to a directory. Files in the directory // matching the prefix will be deleted on open. FileRotatingStream(const std::string& dir_path, @@ -41,20 +35,13 @@ class FileRotatingStream : public StreamInterface { size_t max_file_size, size_t num_files); - ~FileRotatingStream() override; - - // StreamInterface methods. - StreamState GetState() const override; - StreamResult Read(void* buffer, - size_t buffer_len, - size_t* read, - int* error) override; - StreamResult Write(const void* data, - size_t data_len, - size_t* written, - int* error) override; - bool Flush() override; - void Close() override; + virtual ~FileRotatingStream(); + + bool IsOpen() const; + + bool Write(const void* data, size_t data_len); + bool Flush(); + void Close(); // Opens the appropriate file(s). Call this before using the stream. bool Open(); @@ -63,6 +50,8 @@ class FileRotatingStream : public StreamInterface { // enabled by default for performance. bool DisableBuffering(); + // Below two methods are public for testing only. + // Returns the path used for the i-th newest file, where the 0th file is the // newest file. The file may or may not exist, this is just used for // formatting. Index must be less than GetNumFiles(). @@ -72,8 +61,6 @@ class FileRotatingStream : public StreamInterface { size_t GetNumFiles() const { return file_names_.size(); } protected: - size_t GetMaxFileSize() const { return max_file_size_; } - void SetMaxFileSize(size_t size) { max_file_size_ = size; } size_t GetRotationIndex() const { return rotation_index_; } diff --git a/rtc_base/file_rotating_stream_unittest.cc b/rtc_base/file_rotating_stream_unittest.cc index c2ba06773a..849b111148 100644 --- a/rtc_base/file_rotating_stream_unittest.cc +++ b/rtc_base/file_rotating_stream_unittest.cc @@ -72,7 +72,7 @@ class MAYBE_FileRotatingStreamTest : public ::testing::Test { // Writes the data to the stream and flushes it. 
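With the boolean Open/Write/Flush API above, typical use of FileRotatingStream looks roughly like the following sketch (the directory, prefix and sizes are made-up values). Note that a single Write() larger than max_file_size now spans several rotated files instead of being truncated to the current file.

#include <string>

#include "rtc_base/file_rotating_stream.h"

void WriteExample() {
  rtc::FileRotatingStream stream("/tmp/example_logs", "example_log",
                                 /*max_file_size=*/1024, /*num_files=*/4);
  if (!stream.Open())
    return;
  const std::string blob(4096, 'x');       // Deliberately larger than one file.
  stream.Write(blob.data(), blob.size());  // Loops internally, rotating files.
  stream.Flush();
  stream.Close();
}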
void WriteAndFlush(const void* data, const size_t data_len) { - EXPECT_EQ(SR_SUCCESS, stream_->WriteAll(data, data_len, nullptr, nullptr)); + EXPECT_TRUE(stream_->Write(data, data_len)); EXPECT_TRUE(stream_->Flush()); } @@ -114,11 +114,11 @@ const size_t MAYBE_FileRotatingStreamTest::kMaxFileSize = 2; TEST_F(MAYBE_FileRotatingStreamTest, State) { Init("FileRotatingStreamTestState", kFilePrefix, kMaxFileSize, 3); - EXPECT_EQ(SS_CLOSED, stream_->GetState()); + EXPECT_FALSE(stream_->IsOpen()); ASSERT_TRUE(stream_->Open()); - EXPECT_EQ(SS_OPEN, stream_->GetState()); + EXPECT_TRUE(stream_->IsOpen()); stream_->Close(); - EXPECT_EQ(SS_CLOSED, stream_->GetState()); + EXPECT_FALSE(stream_->IsOpen()); } // Tests that nothing is written to file when data of length zero is written. @@ -277,7 +277,7 @@ class MAYBE_CallSessionFileRotatingStreamTest : public ::testing::Test { // Writes the data to the stream and flushes it. void WriteAndFlush(const void* data, const size_t data_len) { - EXPECT_EQ(SR_SUCCESS, stream_->WriteAll(data, data_len, nullptr, nullptr)); + EXPECT_TRUE(stream_->Write(data, data_len)); EXPECT_TRUE(stream_->Flush()); } @@ -334,8 +334,7 @@ TEST_F(MAYBE_CallSessionFileRotatingStreamTest, WriteAndReadLarge) { std::unique_ptr buffer(new uint8_t[buffer_size]); for (int i = 0; i < 8; i++) { memset(buffer.get(), i, buffer_size); - EXPECT_EQ(SR_SUCCESS, - stream_->WriteAll(buffer.get(), buffer_size, nullptr, nullptr)); + EXPECT_TRUE(stream_->Write(buffer.get(), buffer_size)); } const int expected_vals[] = {0, 1, 2, 6, 7}; @@ -369,8 +368,7 @@ TEST_F(MAYBE_CallSessionFileRotatingStreamTest, WriteAndReadFirstHalf) { std::unique_ptr buffer(new uint8_t[buffer_size]); for (int i = 0; i < 2; i++) { memset(buffer.get(), i, buffer_size); - EXPECT_EQ(SR_SUCCESS, - stream_->WriteAll(buffer.get(), buffer_size, nullptr, nullptr)); + EXPECT_TRUE(stream_->Write(buffer.get(), buffer_size)); } const int expected_vals[] = {0, 1}; diff --git a/rtc_base/hash.h b/rtc_base/hash.h new file mode 100644 index 0000000000..56d581cdf1 --- /dev/null +++ b/rtc_base/hash.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef RTC_BASE_HASH_H_ +#define RTC_BASE_HASH_H_ + +#include + +#include +#include + +namespace webrtc { + +// A custom hash function for std::pair, to be able to be used as key in a +// std::unordered_map. If absl::flat_hash_map would ever be used, this is +// unnecessary as it already has a hash function for std::pair. +struct PairHash { + template + size_t operator()(const std::pair& p) const { + return (3 * std::hash{}(p.first)) ^ std::hash{}(p.second); + } +}; + +} // namespace webrtc + +#endif // RTC_BASE_HASH_H_ diff --git a/rtc_base/hash_unittest.cc b/rtc_base/hash_unittest.cc new file mode 100644 index 0000000000..e86c8a8586 --- /dev/null +++ b/rtc_base/hash_unittest.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "rtc_base/hash.h" + +#include +#include +#include + +#include "test/gmock.h" + +namespace webrtc { +namespace { + +TEST(PairHashTest, CanInsertIntoSet) { + using MyPair = std::pair; + + std::unordered_set pairs; + + pairs.insert({1, 2}); + pairs.insert({3, 4}); + + EXPECT_NE(pairs.find({1, 2}), pairs.end()); + EXPECT_NE(pairs.find({3, 4}), pairs.end()); + EXPECT_EQ(pairs.find({1, 3}), pairs.end()); + EXPECT_EQ(pairs.find({3, 3}), pairs.end()); +} + +TEST(PairHashTest, CanInsertIntoMap) { + using MyPair = std::pair; + + std::unordered_map pairs; + + pairs[{"1", 2}] = 99; + pairs[{"3", 4}] = 100; + + EXPECT_EQ((pairs[{"1", 2}]), 99); + EXPECT_EQ((pairs[{"3", 4}]), 100); + EXPECT_EQ(pairs.find({"1", 3}), pairs.end()); + EXPECT_EQ(pairs.find({"3", 3}), pairs.end()); +} +} // namespace +} // namespace webrtc diff --git a/rtc_base/internal/default_socket_server.cc b/rtc_base/internal/default_socket_server.cc new file mode 100644 index 0000000000..5632b989fc --- /dev/null +++ b/rtc_base/internal/default_socket_server.cc @@ -0,0 +1,33 @@ +/* + * Copyright 2020 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/internal/default_socket_server.h" + +#include + +#include "rtc_base/socket_server.h" + +#if defined(__native_client__) +#include "rtc_base/null_socket_server.h" +#else +#include "rtc_base/physical_socket_server.h" +#endif + +namespace rtc { + +std::unique_ptr CreateDefaultSocketServer() { +#if defined(__native_client__) + return std::unique_ptr(new rtc::NullSocketServer); +#else + return std::unique_ptr(new rtc::PhysicalSocketServer); +#endif +} + +} // namespace rtc diff --git a/rtc_base/signal_thread.h b/rtc_base/internal/default_socket_server.h similarity index 56% rename from rtc_base/signal_thread.h rename to rtc_base/internal/default_socket_server.h index b444d54994..5b3489f613 100644 --- a/rtc_base/signal_thread.h +++ b/rtc_base/internal/default_socket_server.h @@ -8,12 +8,17 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef RTC_BASE_SIGNAL_THREAD_H_ -#define RTC_BASE_SIGNAL_THREAD_H_ +#ifndef RTC_BASE_INTERNAL_DEFAULT_SOCKET_SERVER_H_ +#define RTC_BASE_INTERNAL_DEFAULT_SOCKET_SERVER_H_ -// The facilities in this file have been deprecated. Please do not use them -// in new code. New code should use factilities exposed by api/task_queue/ -// instead. 
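For reference, PairHash is supplied as the hash functor template argument of the standard unordered containers, as in the sketch below. The key and mapped types are arbitrary examples; the snippet assumes only what rtc_base/hash.h above declares.

#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

#include "rtc_base/hash.h"

// std::pair has no std::hash specialization, so PairHash goes in as the
// third (Hash) template argument of std::unordered_map.
using HostPortPair = std::pair<std::string, uint16_t>;
std::unordered_map<HostPortPair, int, webrtc::PairHash> connection_counts;

void CountConnection() {
  connection_counts[{"192.0.2.1", 443}] += 1;
}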
-#include "rtc_base/deprecated/signal_thread.h" +#include -#endif // RTC_BASE_SIGNAL_THREAD_H_ +#include "rtc_base/socket_server.h" + +namespace rtc { + +std::unique_ptr CreateDefaultSocketServer(); + +} // namespace rtc + +#endif // RTC_BASE_INTERNAL_DEFAULT_SOCKET_SERVER_H_ diff --git a/rtc_base/ip_address.cc b/rtc_base/ip_address.cc index 9dd534c2b5..86f42e0bf9 100644 --- a/rtc_base/ip_address.cc +++ b/rtc_base/ip_address.cc @@ -20,8 +20,9 @@ #include #endif -#include "rtc_base/byte_order.h" #include "rtc_base/ip_address.h" + +#include "rtc_base/byte_order.h" #include "rtc_base/net_helpers.h" #include "rtc_base/string_utils.h" @@ -148,10 +149,6 @@ std::string IPAddress::ToString() const { } std::string IPAddress::ToSensitiveString() const { -#if !defined(NDEBUG) - // Return non-stripped in debug. - return ToString(); -#else switch (family_) { case AF_INET: { std::string address = ToString(); @@ -175,7 +172,6 @@ std::string IPAddress::ToSensitiveString() const { } } return std::string(); -#endif } IPAddress IPAddress::Normalized() const { diff --git a/rtc_base/ip_address_unittest.cc b/rtc_base/ip_address_unittest.cc index d79a7b4bd6..f94649cfee 100644 --- a/rtc_base/ip_address_unittest.cc +++ b/rtc_base/ip_address_unittest.cc @@ -938,15 +938,9 @@ TEST(IPAddressTest, TestToSensitiveString) { EXPECT_EQ(kIPv4PublicAddrString, addr_v4.ToString()); EXPECT_EQ(kIPv6PublicAddrString, addr_v6.ToString()); EXPECT_EQ(kIPv6PublicAddr2String, addr_v6_2.ToString()); -#if defined(NDEBUG) EXPECT_EQ(kIPv4PublicAddrAnonymizedString, addr_v4.ToSensitiveString()); EXPECT_EQ(kIPv6PublicAddrAnonymizedString, addr_v6.ToSensitiveString()); EXPECT_EQ(kIPv6PublicAddr2AnonymizedString, addr_v6_2.ToSensitiveString()); -#else - EXPECT_EQ(kIPv4PublicAddrString, addr_v4.ToSensitiveString()); - EXPECT_EQ(kIPv6PublicAddrString, addr_v6.ToSensitiveString()); - EXPECT_EQ(kIPv6PublicAddr2String, addr_v6_2.ToSensitiveString()); -#endif // defined(NDEBUG) } TEST(IPAddressTest, TestInterfaceAddress) { diff --git a/rtc_base/java/src/org/webrtc/OWNERS b/rtc_base/java/src/org/webrtc/OWNERS index 299e8b20ec..109bea2725 100644 --- a/rtc_base/java/src/org/webrtc/OWNERS +++ b/rtc_base/java/src/org/webrtc/OWNERS @@ -1,2 +1,2 @@ magjed@webrtc.org -sakal@webrtc.org +xalep@webrtc.org diff --git a/rtc_base/keep_ref_until_done.h b/rtc_base/keep_ref_until_done.h deleted file mode 100644 index 5ae0ed1b21..0000000000 --- a/rtc_base/keep_ref_until_done.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2015 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef RTC_BASE_KEEP_REF_UNTIL_DONE_H_ -#define RTC_BASE_KEEP_REF_UNTIL_DONE_H_ - -#include "api/scoped_refptr.h" -#include "rtc_base/callback.h" - -namespace rtc { - -// KeepRefUntilDone keeps a reference to |object| until the returned -// callback goes out of scope. If the returned callback is copied, the -// reference will be released when the last callback goes out of scope. 
-template -static inline Callback0 KeepRefUntilDone(ObjectT* object) { - scoped_refptr p(object); - return [p] {}; -} - -template -static inline Callback0 KeepRefUntilDone( - const scoped_refptr& object) { - return [object] {}; -} - -} // namespace rtc - -#endif // RTC_BASE_KEEP_REF_UNTIL_DONE_H_ diff --git a/rtc_base/log_sinks.cc b/rtc_base/log_sinks.cc index a3019b9786..4365142517 100644 --- a/rtc_base/log_sinks.cc +++ b/rtc_base/log_sinks.cc @@ -16,7 +16,6 @@ #include #include "rtc_base/checks.h" -#include "rtc_base/stream.h" namespace rtc { @@ -37,23 +36,23 @@ FileRotatingLogSink::FileRotatingLogSink(FileRotatingStream* stream) FileRotatingLogSink::~FileRotatingLogSink() {} void FileRotatingLogSink::OnLogMessage(const std::string& message) { - if (stream_->GetState() != SS_OPEN) { + if (!stream_->IsOpen()) { std::fprintf(stderr, "Init() must be called before adding this sink.\n"); return; } - stream_->WriteAll(message.c_str(), message.size(), nullptr, nullptr); + stream_->Write(message.c_str(), message.size()); } void FileRotatingLogSink::OnLogMessage(const std::string& message, LoggingSeverity sev, const char* tag) { - if (stream_->GetState() != SS_OPEN) { + if (!stream_->IsOpen()) { std::fprintf(stderr, "Init() must be called before adding this sink.\n"); return; } - stream_->WriteAll(tag, strlen(tag), nullptr, nullptr); - stream_->WriteAll(": ", 2, nullptr, nullptr); - stream_->WriteAll(message.c_str(), message.size(), nullptr, nullptr); + stream_->Write(tag, strlen(tag)); + stream_->Write(": ", 2); + stream_->Write(message.c_str(), message.size()); } bool FileRotatingLogSink::Init() { diff --git a/rtc_base/logging.cc b/rtc_base/logging.cc index 13a5f02597..a333d83970 100644 --- a/rtc_base/logging.cc +++ b/rtc_base/logging.cc @@ -51,6 +51,17 @@ static const int kMaxLogLineSize = 1024 - 60; #include "rtc_base/thread_annotations.h" #include "rtc_base/time_utils.h" +#if defined(WEBRTC_RACE_CHECK_MUTEX) +#if defined(WEBRTC_ABSL_MUTEX) +#error Please only define one of WEBRTC_RACE_CHECK_MUTEX and WEBRTC_ABSL_MUTEX. +#endif +#include "absl/base/const_init.h" +#include "absl/synchronization/mutex.h" // nogncheck +using LoggingMutexLock = ::absl::MutexLock; +#else +using LoggingMutexLock = ::webrtc::MutexLock; +#endif // if defined(WEBRTC_RACE_CHECK_MUTEX) + namespace rtc { namespace { // By default, release builds don't log, debug builds at info level @@ -75,7 +86,15 @@ const char* FilenameFromPath(const char* file) { // Global lock for log subsystem, only needed to serialize access to streams_. // TODO(bugs.webrtc.org/11665): this is not currently constant initialized and // trivially destructible. +#if defined(WEBRTC_RACE_CHECK_MUTEX) +// When WEBRTC_RACE_CHECK_MUTEX is defined, even though WebRTC objects are +// invoked serially, the logging is static, invoked concurrently and hence needs +// protection. 
+absl::Mutex g_log_mutex_(absl::kConstInit); +#else webrtc::Mutex g_log_mutex_; +#endif + } // namespace ///////////////////////////////////////////////////////////////////////////// @@ -201,7 +220,7 @@ LogMessage::~LogMessage() { #endif } - webrtc::MutexLock lock(&g_log_mutex_); + LoggingMutexLock lock(&g_log_mutex_); for (LogSink* entry = streams_; entry != nullptr; entry = entry->next_) { if (severity_ >= entry->min_severity_) { #if defined(WEBRTC_ANDROID) @@ -250,7 +269,7 @@ void LogMessage::LogTimestamps(bool on) { void LogMessage::LogToDebug(LoggingSeverity min_sev) { g_dbg_sev = min_sev; - webrtc::MutexLock lock(&g_log_mutex_); + LoggingMutexLock lock(&g_log_mutex_); UpdateMinLogSeverity(); } @@ -259,7 +278,7 @@ void LogMessage::SetLogToStderr(bool log_to_stderr) { } int LogMessage::GetLogToStream(LogSink* stream) { - webrtc::MutexLock lock(&g_log_mutex_); + LoggingMutexLock lock(&g_log_mutex_); LoggingSeverity sev = LS_NONE; for (LogSink* entry = streams_; entry != nullptr; entry = entry->next_) { if (stream == nullptr || stream == entry) { @@ -270,7 +289,7 @@ int LogMessage::GetLogToStream(LogSink* stream) { } void LogMessage::AddLogToStream(LogSink* stream, LoggingSeverity min_sev) { - webrtc::MutexLock lock(&g_log_mutex_); + LoggingMutexLock lock(&g_log_mutex_); stream->min_severity_ = min_sev; stream->next_ = streams_; streams_ = stream; @@ -279,7 +298,7 @@ void LogMessage::AddLogToStream(LogSink* stream, LoggingSeverity min_sev) { } void LogMessage::RemoveLogToStream(LogSink* stream) { - webrtc::MutexLock lock(&g_log_mutex_); + LoggingMutexLock lock(&g_log_mutex_); for (LogSink** entry = &streams_; *entry != nullptr; entry = &(*entry)->next_) { if (*entry == stream) { diff --git a/rtc_base/logging.h b/rtc_base/logging.h index d2607c28b7..e21c30e21a 100644 --- a/rtc_base/logging.h +++ b/rtc_base/logging.h @@ -51,10 +51,10 @@ #include #include +#include "absl/base/attributes.h" #include "absl/meta/type_traits.h" #include "absl/strings/string_view.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/deprecation.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/system/inline.h" @@ -434,7 +434,7 @@ class LogMessage { // DEPRECATED - DO NOT USE - PLEASE USE THE MACROS INSTEAD OF THE CLASS. // Android code should use the 'const char*' version since tags are static // and we want to avoid allocating a std::string copy per log line. - RTC_DEPRECATED + ABSL_DEPRECATED("Use RTC_LOG macros instead of accessing this class directly") LogMessage(const char* file, int line, LoggingSeverity sev, @@ -508,7 +508,7 @@ class LogMessage { // DEPRECATED - DO NOT USE - PLEASE USE THE MACROS INSTEAD OF THE CLASS. // Android code should use the 'const char*' version since tags are static // and we want to avoid allocating a std::string copy per log line. 
- RTC_DEPRECATED + ABSL_DEPRECATED("Use RTC_LOG macros instead of accessing this class directly") LogMessage(const char* file, int line, LoggingSeverity sev, diff --git a/rtc_base/logging_unittest.cc b/rtc_base/logging_unittest.cc index 6bb20abcc1..dc1208f3f6 100644 --- a/rtc_base/logging_unittest.cc +++ b/rtc_base/logging_unittest.cc @@ -20,94 +20,23 @@ #include "rtc_base/checks.h" #include "rtc_base/event.h" #include "rtc_base/platform_thread.h" -#include "rtc_base/stream.h" #include "rtc_base/time_utils.h" #include "test/gtest.h" namespace rtc { -namespace { - -class StringStream : public StreamInterface { +class LogSinkImpl : public LogSink { public: - explicit StringStream(std::string* str); - explicit StringStream(const std::string& str); - - StreamState GetState() const override; - StreamResult Read(void* buffer, - size_t buffer_len, - size_t* read, - int* error) override; - StreamResult Write(const void* data, - size_t data_len, - size_t* written, - int* error) override; - void Close() override; - - private: - std::string& str_; - size_t read_pos_; - bool read_only_; -}; - -StringStream::StringStream(std::string* str) - : str_(*str), read_pos_(0), read_only_(false) {} - -StringStream::StringStream(const std::string& str) - : str_(const_cast(str)), read_pos_(0), read_only_(true) {} - -StreamState StringStream::GetState() const { - return SS_OPEN; -} - -StreamResult StringStream::Read(void* buffer, - size_t buffer_len, - size_t* read, - int* error) { - size_t available = std::min(buffer_len, str_.size() - read_pos_); - if (!available) - return SR_EOS; - memcpy(buffer, str_.data() + read_pos_, available); - read_pos_ += available; - if (read) - *read = available; - return SR_SUCCESS; -} - -StreamResult StringStream::Write(const void* data, - size_t data_len, - size_t* written, - int* error) { - if (read_only_) { - if (error) { - *error = -1; - } - return SR_ERROR; - } - str_.append(static_cast(data), - static_cast(data) + data_len); - if (written) - *written = data_len; - return SR_SUCCESS; -} - -void StringStream::Close() {} - -} // namespace - -template -class LogSinkImpl : public LogSink, public Base { - public: - LogSinkImpl() {} + explicit LogSinkImpl(std::string* log_data) : log_data_(log_data) {} template - explicit LogSinkImpl(P* p) : Base(p) {} + explicit LogSinkImpl(P* p) {} private: void OnLogMessage(const std::string& message) override { - static_cast(this)->WriteAll(message.data(), message.size(), nullptr, - nullptr); + log_data_->append(message); } + std::string* const log_data_; }; class LogMessageForTesting : public LogMessage { @@ -145,7 +74,7 @@ TEST(LogTest, SingleStream) { int sev = LogMessage::GetLogToStream(nullptr); std::string str; - LogSinkImpl stream(&str); + LogSinkImpl stream(&str); LogMessage::AddLogToStream(&stream, LS_INFO); EXPECT_EQ(LS_INFO, LogMessage::GetLogToStream(&stream)); @@ -207,7 +136,7 @@ TEST(LogTest, MultipleStreams) { int sev = LogMessage::GetLogToStream(nullptr); std::string str1, str2; - LogSinkImpl stream1(&str1), stream2(&str2); + LogSinkImpl stream1(&str1), stream2(&str2); LogMessage::AddLogToStream(&stream1, LS_INFO); LogMessage::AddLogToStream(&stream2, LS_VERBOSE); EXPECT_EQ(LS_INFO, LogMessage::GetLogToStream(&stream1)); @@ -231,18 +160,13 @@ TEST(LogTest, MultipleStreams) { class LogThread { public: - LogThread() : thread_(&ThreadEntry, this, "LogThread") {} - ~LogThread() { thread_.Stop(); } - - void Start() { thread_.Start(); } + void Start() { + thread_ = PlatformThread::SpawnJoinable( + [] { RTC_LOG(LS_VERBOSE) << "RTC_LOG"; 
}, "LogThread"); + } private: - void Run() { RTC_LOG(LS_VERBOSE) << "RTC_LOG"; } - - static void ThreadEntry(void* p) { static_cast(p)->Run(); } - PlatformThread thread_; - Event event_; }; // Ensure we don't crash when adding/removing streams while threads are going. @@ -256,7 +180,7 @@ TEST(LogTest, MultipleThreads) { thread3.Start(); std::string s1, s2, s3; - LogSinkImpl stream1(&s1), stream2(&s2), stream3(&s3); + LogSinkImpl stream1(&s1), stream2(&s2), stream3(&s3); for (int i = 0; i < 1000; ++i) { LogMessage::AddLogToStream(&stream1, LS_WARNING); LogMessage::AddLogToStream(&stream2, LS_INFO); @@ -303,7 +227,7 @@ TEST(LogTest, CheckFilePathParsed) { #if defined(WEBRTC_ANDROID) TEST(LogTest, CheckTagAddedToStringInDefaultOnLogMessageAndroid) { std::string str; - LogSinkImpl stream(&str); + LogSinkImpl stream(&str); LogMessage::AddLogToStream(&stream, LS_INFO); EXPECT_EQ(LS_INFO, LogMessage::GetLogToStream(&stream)); @@ -316,7 +240,7 @@ TEST(LogTest, CheckTagAddedToStringInDefaultOnLogMessageAndroid) { // Test the time required to write 1000 80-character logs to a string. TEST(LogTest, Perf) { std::string str; - LogSinkImpl stream(&str); + LogSinkImpl stream(&str); LogMessage::AddLogToStream(&stream, LS_VERBOSE); const std::string message(80, 'X'); @@ -336,7 +260,6 @@ TEST(LogTest, Perf) { finish = TimeMillis(); LogMessage::RemoveLogToStream(&stream); - stream.Close(); EXPECT_EQ(str.size(), (message.size() + logging_overhead) * kRepetitions); RTC_LOG(LS_INFO) << "Total log time: " << TimeDiff(finish, start) @@ -348,7 +271,7 @@ TEST(LogTest, Perf) { TEST(LogTest, EnumsAreSupported) { enum class TestEnum { kValue0 = 0, kValue1 = 1 }; std::string str; - LogSinkImpl stream(&str); + LogSinkImpl stream(&str); LogMessage::AddLogToStream(&stream, LS_INFO); RTC_LOG(LS_INFO) << "[" << TestEnum::kValue0 << "]"; EXPECT_NE(std::string::npos, str.find("[0]")); @@ -356,7 +279,6 @@ TEST(LogTest, EnumsAreSupported) { RTC_LOG(LS_INFO) << "[" << TestEnum::kValue1 << "]"; EXPECT_NE(std::string::npos, str.find("[1]")); LogMessage::RemoveLogToStream(&stream); - stream.Close(); } TEST(LogTest, NoopSeverityDoesNotRunStringFormatting) { diff --git a/rtc_base/memory/BUILD.gn b/rtc_base/memory/BUILD.gn index 838fbc68d4..ee66ac0df8 100644 --- a/rtc_base/memory/BUILD.gn +++ b/rtc_base/memory/BUILD.gn @@ -21,14 +21,12 @@ rtc_library("aligned_malloc") { } # Test only utility. -# TODO: Tag with `testonly = true` once all depending targets are correctly -# tagged. rtc_library("fifo_buffer") { + testonly = true visibility = [ ":unittests", "..:rtc_base_tests_utils", "..:rtc_base_unittests", - "../../p2p:rtc_p2p", # This needs to be fixed. ] sources = [ "fifo_buffer.cc", @@ -36,6 +34,7 @@ rtc_library("fifo_buffer") { ] deps = [ "..:rtc_base", + "..:threading", "../synchronization:mutex", "../task_utils:pending_task_safety_flag", "../task_utils:to_queued_task", diff --git a/rtc_base/message_buffer_reader.h b/rtc_base/message_buffer_reader.h deleted file mode 100644 index 32b8f336b1..0000000000 --- a/rtc_base/message_buffer_reader.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2018 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#ifndef RTC_BASE_MESSAGE_BUFFER_READER_H_ -#define RTC_BASE_MESSAGE_BUFFER_READER_H_ - -#include "rtc_base/byte_buffer.h" - -namespace webrtc { - -// A simple subclass of the ByteBufferReader that exposes the starting address -// of the message and its length, so that we can recall previously parsed data. -class MessageBufferReader : public rtc::ByteBufferReader { - public: - MessageBufferReader(const char* bytes, size_t len) - : rtc::ByteBufferReader(bytes, len) {} - ~MessageBufferReader() = default; - - // Starting address of the message. - const char* MessageData() const { return bytes_; } - // Total length of the message. Note that this is different from Length(), - // which is the length of the remaining message from the current offset. - size_t MessageLength() const { return size_; } - // Current offset in the message. - size_t CurrentOffset() const { return start_; } -}; - -} // namespace webrtc - -#endif // RTC_BASE_MESSAGE_BUFFER_READER_H_ diff --git a/rtc_base/nat_socket_factory.cc b/rtc_base/nat_socket_factory.cc index 3edf4cecf4..effbb5a6c3 100644 --- a/rtc_base/nat_socket_factory.cc +++ b/rtc_base/nat_socket_factory.cc @@ -428,14 +428,15 @@ NATSocketServer::Translator::Translator(NATSocketServer* server, // Create a new private network, and a NATServer running on the private // network that bridges to the external network. Also tell the private // network to use the same message queue as us. - VirtualSocketServer* internal_server = new VirtualSocketServer(); - internal_server->SetMessageQueue(server_->queue()); - internal_factory_.reset(internal_server); - nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip, - ext_factory, ext_ip)); + internal_server_ = std::make_unique(); + internal_server_->SetMessageQueue(server_->queue()); + nat_server_ = std::make_unique( + type, internal_server_.get(), int_ip, int_ip, ext_factory, ext_ip); } -NATSocketServer::Translator::~Translator() = default; +NATSocketServer::Translator::~Translator() { + internal_server_->SetMessageQueue(nullptr); +} NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator( const SocketAddress& ext_ip) { diff --git a/rtc_base/nat_socket_factory.h b/rtc_base/nat_socket_factory.h index e649d19a8e..70030d834e 100644 --- a/rtc_base/nat_socket_factory.h +++ b/rtc_base/nat_socket_factory.h @@ -107,7 +107,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory { const SocketAddress& ext_addr); ~Translator(); - SocketFactory* internal_factory() { return internal_factory_.get(); } + SocketFactory* internal_factory() { return internal_server_.get(); } SocketAddress internal_udp_address() const { return nat_server_->internal_udp_address(); } @@ -129,7 +129,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory { private: NATSocketServer* server_; - std::unique_ptr internal_factory_; + std::unique_ptr internal_server_; std::unique_ptr nat_server_; TranslatorMap nats_; std::set clients_; diff --git a/rtc_base/net_helpers.cc b/rtc_base/net_helpers.cc index c6685e2a65..bec854af03 100644 --- a/rtc_base/net_helpers.cc +++ b/rtc_base/net_helpers.cc @@ -10,6 +10,8 @@ #include "rtc_base/net_helpers.h" +#include + #if defined(WEBRTC_WIN) #include #include @@ -17,6 +19,7 @@ #include "rtc_base/win32.h" #endif #if defined(WEBRTC_POSIX) && !defined(__native_client__) +#include #if defined(WEBRTC_ANDROID) #include "rtc_base/ifaddrs_android.h" #else @@ -24,145 +27,8 @@ #endif #endif // defined(WEBRTC_POSIX) && !defined(__native_client__) -#include 
"api/task_queue/task_queue_base.h" -#include "rtc_base/logging.h" -#include "rtc_base/signal_thread.h" -#include "rtc_base/task_queue.h" -#include "rtc_base/task_utils/to_queued_task.h" -#include "rtc_base/third_party/sigslot/sigslot.h" // for signal_with_thread... - namespace rtc { -int ResolveHostname(const std::string& hostname, - int family, - std::vector* addresses) { -#ifdef __native_client__ - RTC_NOTREACHED(); - RTC_LOG(LS_WARNING) << "ResolveHostname() is not implemented for NaCl"; - return -1; -#else // __native_client__ - if (!addresses) { - return -1; - } - addresses->clear(); - struct addrinfo* result = nullptr; - struct addrinfo hints = {0}; - hints.ai_family = family; - // |family| here will almost always be AF_UNSPEC, because |family| comes from - // AsyncResolver::addr_.family(), which comes from a SocketAddress constructed - // with a hostname. When a SocketAddress is constructed with a hostname, its - // family is AF_UNSPEC. However, if someday in the future we construct - // a SocketAddress with both a hostname and a family other than AF_UNSPEC, - // then it would be possible to get a specific family value here. - - // The behavior of AF_UNSPEC is roughly "get both ipv4 and ipv6", as - // documented by the various operating systems: - // Linux: http://man7.org/linux/man-pages/man3/getaddrinfo.3.html - // Windows: https://msdn.microsoft.com/en-us/library/windows/desktop/ - // ms738520(v=vs.85).aspx - // Mac: https://developer.apple.com/legacy/library/documentation/Darwin/ - // Reference/ManPages/man3/getaddrinfo.3.html - // Android (source code, not documentation): - // https://android.googlesource.com/platform/bionic/+/ - // 7e0bfb511e85834d7c6cb9631206b62f82701d60/libc/netbsd/net/getaddrinfo.c#1657 - hints.ai_flags = AI_ADDRCONFIG; - int ret = getaddrinfo(hostname.c_str(), nullptr, &hints, &result); - if (ret != 0) { - return ret; - } - struct addrinfo* cursor = result; - for (; cursor; cursor = cursor->ai_next) { - if (family == AF_UNSPEC || cursor->ai_family == family) { - IPAddress ip; - if (IPFromAddrInfo(cursor, &ip)) { - addresses->push_back(ip); - } - } - } - freeaddrinfo(result); - return 0; -#endif // !__native_client__ -} - -AsyncResolver::AsyncResolver() : error_(-1) {} - -AsyncResolver::~AsyncResolver() { - RTC_DCHECK_RUN_ON(&sequence_checker_); -} - -void AsyncResolver::Start(const SocketAddress& addr) { - RTC_DCHECK_RUN_ON(&sequence_checker_); - RTC_DCHECK(!destroy_called_); - addr_ = addr; - webrtc::TaskQueueBase* current_task_queue = webrtc::TaskQueueBase::Current(); - popup_thread_ = Thread::Create(); - popup_thread_->Start(); - popup_thread_->PostTask(webrtc::ToQueuedTask( - [this, flag = safety_.flag(), addr, current_task_queue] { - std::vector addresses; - int error = - ResolveHostname(addr.hostname().c_str(), addr.family(), &addresses); - current_task_queue->PostTask(webrtc::ToQueuedTask( - std::move(flag), [this, error, addresses = std::move(addresses)] { - RTC_DCHECK_RUN_ON(&sequence_checker_); - ResolveDone(std::move(addresses), error); - })); - })); -} - -bool AsyncResolver::GetResolvedAddress(int family, SocketAddress* addr) const { - RTC_DCHECK_RUN_ON(&sequence_checker_); - RTC_DCHECK(!destroy_called_); - if (error_ != 0 || addresses_.empty()) - return false; - - *addr = addr_; - for (size_t i = 0; i < addresses_.size(); ++i) { - if (family == addresses_[i].family()) { - addr->SetResolvedIP(addresses_[i]); - return true; - } - } - return false; -} - -int AsyncResolver::GetError() const { - RTC_DCHECK_RUN_ON(&sequence_checker_); - 
RTC_DCHECK(!destroy_called_); - return error_; -} - -void AsyncResolver::Destroy(bool wait) { - // Some callers have trouble guaranteeing that Destroy is called on the - // sequence guarded by |sequence_checker_|. - // RTC_DCHECK_RUN_ON(&sequence_checker_); - RTC_DCHECK(!destroy_called_); - destroy_called_ = true; - MaybeSelfDestruct(); -} - -const std::vector& AsyncResolver::addresses() const { - RTC_DCHECK_RUN_ON(&sequence_checker_); - RTC_DCHECK(!destroy_called_); - return addresses_; -} - -void AsyncResolver::ResolveDone(std::vector addresses, int error) { - addresses_ = addresses; - error_ = error; - recursion_check_ = true; - SignalDone(this); - MaybeSelfDestruct(); -} - -void AsyncResolver::MaybeSelfDestruct() { - if (!recursion_check_) { - delete this; - } else { - recursion_check_ = false; - } -} - const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) { #if defined(WEBRTC_WIN) return win32_inet_ntop(af, src, dst, size); @@ -187,7 +53,7 @@ bool HasIPv4Enabled() { return false; } for (struct ifaddrs* cur = ifa; cur != nullptr; cur = cur->ifa_next) { - if (cur->ifa_addr->sa_family == AF_INET) { + if (cur->ifa_addr != nullptr && cur->ifa_addr->sa_family == AF_INET) { has_ipv4 = true; break; } @@ -246,7 +112,7 @@ bool HasIPv6Enabled() { return false; } for (struct ifaddrs* cur = ifa; cur != nullptr; cur = cur->ifa_next) { - if (cur->ifa_addr->sa_family == AF_INET6) { + if (cur->ifa_addr != nullptr && cur->ifa_addr->sa_family == AF_INET6) { has_ipv6 = true; break; } diff --git a/rtc_base/net_helpers.h b/rtc_base/net_helpers.h index 172a222456..4ed84786b3 100644 --- a/rtc_base/net_helpers.h +++ b/rtc_base/net_helpers.h @@ -15,57 +15,12 @@ #include #elif WEBRTC_WIN #include // NOLINT -#endif - -#include -#include "rtc_base/async_resolver_interface.h" -#include "rtc_base/ip_address.h" -#include "rtc_base/socket_address.h" -#include "rtc_base/synchronization/sequence_checker.h" -#include "rtc_base/system/no_unique_address.h" -#include "rtc_base/system/rtc_export.h" -#include "rtc_base/task_utils/pending_task_safety_flag.h" -#include "rtc_base/thread.h" -#include "rtc_base/thread_annotations.h" +#include "rtc_base/win32.h" +#endif namespace rtc { -// AsyncResolver will perform async DNS resolution, signaling the result on -// the SignalDone from AsyncResolverInterface when the operation completes. -// -// This class is thread-compatible, and all methods and destruction needs to -// happen from the same rtc::Thread, except for Destroy which is allowed to -// happen on another context provided it's not happening concurrently to another -// public API call, and is the last access to the object. 
-class RTC_EXPORT AsyncResolver : public AsyncResolverInterface { - public: - AsyncResolver(); - ~AsyncResolver() override; - - void Start(const SocketAddress& addr) override; - bool GetResolvedAddress(int family, SocketAddress* addr) const override; - int GetError() const override; - void Destroy(bool wait) override; - - const std::vector& addresses() const; - - private: - void ResolveDone(std::vector addresses, int error) - RTC_EXCLUSIVE_LOCKS_REQUIRED(sequence_checker_); - void MaybeSelfDestruct(); - - SocketAddress addr_ RTC_GUARDED_BY(sequence_checker_); - std::vector addresses_ RTC_GUARDED_BY(sequence_checker_); - int error_ RTC_GUARDED_BY(sequence_checker_); - webrtc::ScopedTaskSafety safety_ RTC_GUARDED_BY(sequence_checker_); - std::unique_ptr popup_thread_ RTC_GUARDED_BY(sequence_checker_); - bool recursion_check_ = - false; // Protects against SignalDone calling into Destroy. - bool destroy_called_ = false; - RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker sequence_checker_; -}; - // rtc namespaced wrappers for inet_ntop and inet_pton so we can avoid // the windows-native versions of these. const char* inet_ntop(int af, const void* src, char* dst, socklen_t size); @@ -73,6 +28,7 @@ int inet_pton(int af, const char* src, void* dst); bool HasIPv4Enabled(); bool HasIPv6Enabled(); + } // namespace rtc #endif // RTC_BASE_NET_HELPERS_H_ diff --git a/rtc_base/network.cc b/rtc_base/network.cc index 07c39ae5c1..f4a349bae0 100644 --- a/rtc_base/network.cc +++ b/rtc_base/network.cc @@ -212,7 +212,8 @@ AdapterType GetAdapterTypeFromName(const char* network_name) { return ADAPTER_TYPE_ETHERNET; } - if (MatchTypeNameWithIndexPattern(network_name, "wlan")) { + if (MatchTypeNameWithIndexPattern(network_name, "wlan") || + MatchTypeNameWithIndexPattern(network_name, "v4-wlan")) { return ADAPTER_TYPE_WIFI; } @@ -478,15 +479,15 @@ Network* NetworkManagerBase::GetNetworkFromAddress( return nullptr; } -BasicNetworkManager::BasicNetworkManager() - : allow_mac_based_ipv6_( - webrtc::field_trial::IsEnabled("WebRTC-AllowMACBasedIPv6")) {} +BasicNetworkManager::BasicNetworkManager() : BasicNetworkManager(nullptr) {} BasicNetworkManager::BasicNetworkManager( NetworkMonitorFactory* network_monitor_factory) : network_monitor_factory_(network_monitor_factory), allow_mac_based_ipv6_( - webrtc::field_trial::IsEnabled("WebRTC-AllowMACBasedIPv6")) {} + webrtc::field_trial::IsEnabled("WebRTC-AllowMACBasedIPv6")), + bind_using_ifname_( + !webrtc::field_trial::IsDisabled("WebRTC-BindUsingInterfaceName")) {} BasicNetworkManager::~BasicNetworkManager() {} @@ -865,6 +866,15 @@ void BasicNetworkManager::StartNetworkMonitor() { network_monitor_->SignalNetworksChanged.connect( this, &BasicNetworkManager::OnNetworksChanged); } + + if (network_monitor_->SupportsBindSocketToNetwork()) { + // Set NetworkBinder on SocketServer so that + // PhysicalSocket::Bind will call + // BasicNetworkManager::BindSocketToNetwork(), (that will lookup interface + // name and then call network_monitor_->BindSocketToNetwork()). + thread_->socketserver()->set_network_binder(this); + } + network_monitor_->Start(); } @@ -873,6 +883,13 @@ void BasicNetworkManager::StopNetworkMonitor() { return; } network_monitor_->Stop(); + + if (network_monitor_->SupportsBindSocketToNetwork()) { + // Reset NetworkBinder on SocketServer. 
+ if (thread_->socketserver()->network_binder() == this) { + thread_->socketserver()->set_network_binder(nullptr); + } + } } void BasicNetworkManager::OnMessage(Message* msg) { @@ -954,6 +971,20 @@ void BasicNetworkManager::DumpNetworks() { } } +NetworkBindingResult BasicNetworkManager::BindSocketToNetwork( + int socket_fd, + const IPAddress& address) { + RTC_DCHECK_RUN_ON(thread_); + std::string if_name; + if (bind_using_ifname_) { + Network* net = GetNetworkFromAddress(address); + if (net != nullptr) { + if_name = net->name(); + } + } + return network_monitor_->BindSocketToNetwork(socket_fd, address, if_name); +} + Network::Network(const std::string& name, const std::string& desc, const IPAddress& prefix, diff --git a/rtc_base/network.h b/rtc_base/network.h index 3107b728d7..8b6b6235fa 100644 --- a/rtc_base/network.h +++ b/rtc_base/network.h @@ -19,12 +19,12 @@ #include #include +#include "api/sequence_checker.h" #include "rtc_base/ip_address.h" #include "rtc_base/mdns_responder_interface.h" #include "rtc_base/message_handler.h" #include "rtc_base/network_monitor.h" #include "rtc_base/network_monitor_factory.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread_annotations.h" @@ -194,11 +194,11 @@ class RTC_EXPORT NetworkManagerBase : public NetworkManager { void set_default_local_addresses(const IPAddress& ipv4, const IPAddress& ipv6); + Network* GetNetworkFromAddress(const rtc::IPAddress& ip) const; + private: friend class NetworkTest; - Network* GetNetworkFromAddress(const rtc::IPAddress& ip) const; - EnumerationPermission enumeration_permission_; NetworkList networks_; @@ -225,6 +225,7 @@ class RTC_EXPORT NetworkManagerBase : public NetworkManager { // of networks using OS APIs. class RTC_EXPORT BasicNetworkManager : public NetworkManagerBase, public MessageHandlerAutoCleanup, + public NetworkBinderInterface, public sigslot::has_slots<> { public: BasicNetworkManager(); @@ -248,6 +249,15 @@ class RTC_EXPORT BasicNetworkManager : public NetworkManagerBase, network_ignore_list_ = list; } + // Bind a socket to the interface that the ip address belongs to. + // The implementation looks up the interface name and calls + // BindSocketToNetwork on the NetworkMonitor. + // The interface name is needed as e.g. ipv4 over ipv6 addresses + // are not exposed using Android functions, but it is possible + // to bind an ipv4 address to the interface. + NetworkBindingResult BindSocketToNetwork(int socket_fd, + const IPAddress& address) override; + protected: #if defined(WEBRTC_POSIX) // Separated from CreateNetworks for tests. @@ -293,7 +303,8 @@ class RTC_EXPORT BasicNetworkManager : public NetworkManagerBase, nullptr; std::unique_ptr network_monitor_ RTC_GUARDED_BY(thread_); - bool allow_mac_based_ipv6_ = false; + bool allow_mac_based_ipv6_ RTC_GUARDED_BY(thread_) = false; + bool bind_using_ifname_ RTC_GUARDED_BY(thread_) = false; }; // Represents a Unix-type network interface, with a name and single address. diff --git a/rtc_base/network_monitor.h b/rtc_base/network_monitor.h index 4a3002f427..dddc2f60f4 100644 --- a/rtc_base/network_monitor.h +++ b/rtc_base/network_monitor.h @@ -36,6 +36,8 @@ enum class NetworkPreference { const char* NetworkPreferenceToString(NetworkPreference preference); +// This interface is set onto a socket server, +// where only the ip address is known at the time of binding.
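As an illustration of the interface introduced here, a trivial binder could look like the sketch below. LoggingNetworkBinder is a made-up name, and the sketch assumes BindSocketToNetwork is the interface's only required method; a real binder (such as BasicNetworkManager above) performs an actual platform bind.

#include "rtc_base/ip_address.h"
#include "rtc_base/logging.h"
#include "rtc_base/network_monitor.h"

class LoggingNetworkBinder : public rtc::NetworkBinderInterface {
 public:
  rtc::NetworkBindingResult BindSocketToNetwork(
      int socket_fd,
      const rtc::IPAddress& address) override {
    // Only reports the request; a real implementation would bind the fd to
    // the network that owns `address` and return SUCCESS or FAILURE.
    RTC_LOG(LS_INFO) << "Bind requested for fd " << socket_fd << " on "
                     << address.ToSensitiveString();
    return rtc::NetworkBindingResult::NOT_IMPLEMENTED;
  }
};

// Installed the same way BasicNetworkManager installs itself above, assuming
// `thread` owns the SocketServer whose sockets should be bound:
//   thread->socketserver()->set_network_binder(&binder);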
class NetworkBinderInterface { public: // Binds a socket to the network that is attached to |address| so that all @@ -83,6 +85,19 @@ class NetworkMonitorInterface { virtual NetworkPreference GetNetworkPreference( const std::string& interface_name) = 0; + // Does |this| NetworkMonitorInterface implement BindSocketToNetwork? + // Only Android returns true. + virtual bool SupportsBindSocketToNetwork() const { return false; } + + // Bind a socket to an interface specified by ip address and/or interface + // name. Only implemented on Android. + virtual NetworkBindingResult BindSocketToNetwork( + int socket_fd, + const IPAddress& address, + const std::string& interface_name) { + return NetworkBindingResult::NOT_IMPLEMENTED; + } + // Is this interface available to use? WebRTC shouldn't attempt to use it if // this returns false. // diff --git a/rtc_base/network_unittest.cc b/rtc_base/network_unittest.cc index 73ddd81ce8..75856634ab 100644 --- a/rtc_base/network_unittest.cc +++ b/rtc_base/network_unittest.cc @@ -76,9 +76,35 @@ class FakeNetworkMonitor : public NetworkMonitorInterface { unavailable_adapters_ = unavailable_adapters; } + bool SupportsBindSocketToNetwork() const override { return true; } + + NetworkBindingResult BindSocketToNetwork( + int socket_fd, + const IPAddress& address, + const std::string& if_name) override { + if (absl::c_count(addresses_, address) > 0) { + return NetworkBindingResult::SUCCESS; + } + + for (auto const& iter : adapters_) { + if (if_name.find(iter) != std::string::npos) { + return NetworkBindingResult::SUCCESS; + } + } + return NetworkBindingResult::ADDRESS_NOT_FOUND; + } + + void set_ip_addresses(std::vector addresses) { + addresses_ = addresses; + } + + void set_adapters(std::vector adapters) { adapters_ = adapters; } + private: bool started_ = false; + std::vector adapters_; std::vector unavailable_adapters_; + std::vector addresses_; }; class FakeNetworkMonitorFactory : public NetworkMonitorFactory { @@ -1279,4 +1305,45 @@ TEST_F(NetworkTest, WebRTC_AllowMACBasedIPv6Address) { } #endif +#if defined(WEBRTC_POSIX) +TEST_F(NetworkTest, WebRTC_BindUsingInterfaceName) { + char if_name1[20] = "wlan0"; + char if_name2[20] = "v4-wlan0"; + ifaddrs* list = nullptr; + list = AddIpv6Address(list, if_name1, "1000:2000:3000:4000:0:0:0:1", + "FFFF:FFFF:FFFF:FFFF::", 0); + list = AddIpv4Address(list, if_name2, "192.168.0.2", "255.255.255.255"); + NetworkManager::NetworkList result; + + // Sanity check that both interfaces are included by default. + FakeNetworkMonitorFactory factory; + BasicNetworkManager manager(&factory); + manager.StartUpdating(); + CallConvertIfAddrs(manager, list, /*include_ignored=*/false, &result); + EXPECT_EQ(2u, result.size()); + ReleaseIfAddrs(list); + bool changed; + // This ensures we release the objects created in CallConvertIfAddrs. + MergeNetworkList(manager, result, &changed); + result.clear(); + + FakeNetworkMonitor* network_monitor = GetNetworkMonitor(manager); + + IPAddress ipv6; + EXPECT_TRUE(IPFromString("1000:2000:3000:4000:0:0:0:1", &ipv6)); + IPAddress ipv4; + EXPECT_TRUE(IPFromString("192.168.0.2", &ipv4)); + + // The network monitor only knows about the ipv6 address and its interface. + network_monitor->set_adapters({"wlan0"}); + network_monitor->set_ip_addresses({ipv6}); + EXPECT_EQ(manager.BindSocketToNetwork(/* fd */ 77, ipv6), + NetworkBindingResult::SUCCESS); + + // But it will bind anyway using string matching...
+ EXPECT_EQ(manager.BindSocketToNetwork(/* fd */ 77, ipv4), + NetworkBindingResult::SUCCESS); +} +#endif + } // namespace rtc diff --git a/rtc_base/openssl_adapter.cc b/rtc_base/openssl_adapter.cc index e5c2c42761..c381f04899 100644 --- a/rtc_base/openssl_adapter.cc +++ b/rtc_base/openssl_adapter.cc @@ -289,8 +289,8 @@ int OpenSSLAdapter::BeginSSL() { RTC_LOG(LS_INFO) << "OpenSSLAdapter::BeginSSL: " << ssl_host_name_; RTC_DCHECK(state_ == SSL_CONNECTING); - int err = 0; - BIO* bio = nullptr; + // Scoped cleanup action that makes error-path cleanup a bit cleaner. + EarlyExitCatcher early_exit_catcher(*this); // First set up the context. We should either have a factory, with its own // pre-existing context, or be running standalone, in which case we will @@ -301,26 +301,22 @@ int OpenSSLAdapter::BeginSSL() { } if (!ssl_ctx_) { - err = -1; - goto ssl_error; + return -1; } if (identity_ && !identity_->ConfigureIdentity(ssl_ctx_)) { - SSL_CTX_free(ssl_ctx_); - err = -1; - goto ssl_error; + return -1; } - bio = BIO_new_socket(socket_); + std::unique_ptr bio{BIO_new_socket(socket_), + ::BIO_free}; if (!bio) { - err = -1; - goto ssl_error; + return -1; } ssl_ = SSL_new(ssl_ctx_); if (!ssl_) { - err = -1; - goto ssl_error; + return -1; } SSL_set_app_data(ssl_, this); @@ -346,8 +342,7 @@ int OpenSSLAdapter::BeginSSL() { if (cached) { if (SSL_set_session(ssl_, cached) == 0) { RTC_LOG(LS_WARNING) << "Failed to apply SSL session from cache"; - err = -1; - goto ssl_error; + return -1; } RTC_LOG(LS_INFO) << "Attempting to resume SSL session to " @@ -377,24 +372,16 @@ int OpenSSLAdapter::BeginSSL() { // Now that the initial config is done, transfer ownership of |bio| to the // SSL object. If ContinueSSL() fails, the bio will be freed in Cleanup(). - SSL_set_bio(ssl_, bio, bio); - bio = nullptr; + SSL_set_bio(ssl_, bio.get(), bio.get()); + bio.release(); // Do the connect. - err = ContinueSSL(); + int err = ContinueSSL(); if (err != 0) { - goto ssl_error; - } - - return err; - -ssl_error: - Cleanup(); - if (bio) { - BIO_free(bio); + return err; } - - return err; + early_exit_catcher.disable(); + return 0; } int OpenSSLAdapter::ContinueSSL() { @@ -981,6 +968,9 @@ SSL_CTX* OpenSSLAdapter::CreateContext(SSLMode mode, bool enable_cache) { SSL_CTX_set_custom_verify(ctx, SSL_VERIFY_PEER, SSLVerifyCallback); #else SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, SSLVerifyCallback); + // Verify certificate chains up to a depth of 4. This is not + // needed for DTLS-SRTP which uses self-signed certificates + // (so the depth is 0) but is required to support TURN/TLS.
SSL_CTX_set_verify_depth(ctx, 4); #endif // Use defaults, but disable HMAC-SHA256 and HMAC-SHA384 ciphers @@ -1057,4 +1047,17 @@ OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) { ssl_cert_verifier_); } +OpenSSLAdapter::EarlyExitCatcher::EarlyExitCatcher(OpenSSLAdapter& adapter_ptr) + : adapter_ptr_(adapter_ptr) {} + +void OpenSSLAdapter::EarlyExitCatcher::disable() { + disabled_ = true; +} + +OpenSSLAdapter::EarlyExitCatcher::~EarlyExitCatcher() { + if (!disabled_) { + adapter_ptr_.Cleanup(); + } +} + } // namespace rtc diff --git a/rtc_base/openssl_adapter.h b/rtc_base/openssl_adapter.h index 76b003a7dd..9b2a36e00f 100644 --- a/rtc_base/openssl_adapter.h +++ b/rtc_base/openssl_adapter.h @@ -89,6 +89,16 @@ class OpenSSLAdapter final : public SSLAdapter, void OnCloseEvent(AsyncSocket* socket, int err) override; private: + class EarlyExitCatcher { + public: + EarlyExitCatcher(OpenSSLAdapter& adapter_ptr); + void disable(); + ~EarlyExitCatcher(); + + private: + bool disabled_ = false; + OpenSSLAdapter& adapter_ptr_; + }; enum SSLState { SSL_NONE, SSL_WAIT, @@ -202,6 +212,10 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory { friend class OpenSSLAdapter; }; +// The EarlyExitCatcher is responsible for calling OpenSSLAdapter::Cleanup on +// destruction. By doing this we have scoped cleanup which can be disabled if +// there were no errors, aka early exits. + std::string TransformAlpnProtocols(const std::vector& protos); } // namespace rtc diff --git a/rtc_base/openssl_certificate.cc b/rtc_base/openssl_certificate.cc index bd9bb04fd4..802787dcfb 100644 --- a/rtc_base/openssl_certificate.cc +++ b/rtc_base/openssl_certificate.cc @@ -59,27 +59,30 @@ static X509* MakeCertificate(EVP_PKEY* pkey, const SSLIdentityParams& params) { RTC_LOG(LS_INFO) << "Making certificate for " << params.common_name; ASN1_INTEGER* asn1_serial_number = nullptr; - BIGNUM* serial_number = nullptr; - X509* x509 = nullptr; - X509_NAME* name = nullptr; + std::unique_ptr serial_number{nullptr, + ::BN_free}; + std::unique_ptr x509{nullptr, ::X509_free}; + std::unique_ptr name{ + nullptr, ::X509_NAME_free}; time_t epoch_off = 0; // Time offset since epoch. - - if ((x509 = X509_new()) == nullptr) { - goto error; + x509.reset(X509_new()); + if (x509 == nullptr) { + return nullptr; } - if (!X509_set_pubkey(x509, pkey)) { - goto error; + if (!X509_set_pubkey(x509.get(), pkey)) { + return nullptr; } // serial number - temporary reference to serial number inside x509 struct - if ((serial_number = BN_new()) == nullptr || - !BN_pseudo_rand(serial_number, SERIAL_RAND_BITS, 0, 0) || - (asn1_serial_number = X509_get_serialNumber(x509)) == nullptr || - !BN_to_ASN1_INTEGER(serial_number, asn1_serial_number)) { - goto error; + serial_number.reset(BN_new()); + if (serial_number == nullptr || + !BN_pseudo_rand(serial_number.get(), SERIAL_RAND_BITS, 0, 0) || + (asn1_serial_number = X509_get_serialNumber(x509.get())) == nullptr || + !BN_to_ASN1_INTEGER(serial_number.get(), asn1_serial_number)) { + return nullptr; } // Set version to X509.V3 - if (!X509_set_version(x509, 2L)) { - goto error; + if (!X509_set_version(x509.get(), 2L)) { + return nullptr; } // There are a lot of possible components for the name entries. In @@ -89,31 +92,27 @@ static X509* MakeCertificate(EVP_PKEY* pkey, const SSLIdentityParams& params) { // arbitrary common_name. Note that this certificate goes out in // clear during SSL negotiation, so there may be a privacy issue in // putting anything recognizable here. 
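The EarlyExitCatcher commented on above is an instance of the general scope-guard idiom. A minimal standalone sketch, with all names hypothetical, looks like this:

#include <functional>
#include <utility>

class ScopedCleanup {
 public:
  explicit ScopedCleanup(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ScopedCleanup() {
    if (!disabled_)
      fn_();  // Runs on every path that did not explicitly disable the guard.
  }
  void disable() { disabled_ = true; }

 private:
  std::function<void()> fn_;
  bool disabled_ = false;
};

bool DoStepOne();  // Hypothetical helpers standing in for the individual
bool DoStepTwo();  // setup steps of a function like BeginSSL().

int Connect() {
  ScopedCleanup cleanup([] { /* tear down partially initialized state */ });
  if (!DoStepOne())
    return -1;        // cleanup runs here...
  if (!DoStepTwo())
    return -1;        // ...and here.
  cleanup.disable();  // Success: keep the state.
  return 0;
}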
- if ((name = X509_NAME_new()) == nullptr || - !X509_NAME_add_entry_by_NID(name, NID_commonName, MBSTRING_UTF8, + name.reset(X509_NAME_new()); + if (name == nullptr || + !X509_NAME_add_entry_by_NID(name.get(), NID_commonName, MBSTRING_UTF8, (unsigned char*)params.common_name.c_str(), -1, -1, 0) || - !X509_set_subject_name(x509, name) || !X509_set_issuer_name(x509, name)) { - goto error; + !X509_set_subject_name(x509.get(), name.get()) || + !X509_set_issuer_name(x509.get(), name.get())) { + return nullptr; } - if (!X509_time_adj(X509_get_notBefore(x509), params.not_before, &epoch_off) || - !X509_time_adj(X509_get_notAfter(x509), params.not_after, &epoch_off)) { - goto error; + if (!X509_time_adj(X509_get_notBefore(x509.get()), params.not_before, + &epoch_off) || + !X509_time_adj(X509_get_notAfter(x509.get()), params.not_after, + &epoch_off)) { + return nullptr; } - if (!X509_sign(x509, pkey, EVP_sha256())) { - goto error; + if (!X509_sign(x509.get(), pkey, EVP_sha256())) { + return nullptr; } - BN_free(serial_number); - X509_NAME_free(name); RTC_LOG(LS_INFO) << "Returning certificate"; - return x509; - -error: - BN_free(serial_number); - X509_NAME_free(name); - X509_free(x509); - return nullptr; + return x509.release(); } } // namespace diff --git a/rtc_base/openssl_stream_adapter.cc b/rtc_base/openssl_stream_adapter.cc index 63b8069e0e..aa0bc3d40c 100644 --- a/rtc_base/openssl_stream_adapter.cc +++ b/rtc_base/openssl_stream_adapter.cc @@ -288,7 +288,7 @@ bool ShouldAllowLegacyTLSProtocols() { OpenSSLStreamAdapter::OpenSSLStreamAdapter( std::unique_ptr stream) - : SSLStreamAdapter(std::move(stream)), + : stream_(std::move(stream)), owner_(rtc::Thread::Current()), state_(SSL_NONE), role_(SSL_CLIENT), @@ -300,7 +300,9 @@ OpenSSLStreamAdapter::OpenSSLStreamAdapter( ssl_max_version_(SSL_PROTOCOL_TLS_12), // Default is to support legacy TLS protocols. // This will be changed to default non-support in M82 or M83. - support_legacy_tls_protocols_flag_(ShouldAllowLegacyTLSProtocols()) {} + support_legacy_tls_protocols_flag_(ShouldAllowLegacyTLSProtocols()) { + stream_->SignalEvent.connect(this, &OpenSSLStreamAdapter::OnEvent); +} OpenSSLStreamAdapter::~OpenSSLStreamAdapter() { timeout_task_.Stop(); @@ -519,7 +521,7 @@ int OpenSSLStreamAdapter::StartSSL() { return -1; } - if (StreamAdapterInterface::GetState() != SS_OPEN) { + if (stream_->GetState() != SS_OPEN) { state_ = SSL_WAIT; return 0; } @@ -561,7 +563,7 @@ StreamResult OpenSSLStreamAdapter::Write(const void* data, switch (state_) { case SSL_NONE: // pass-through in clear text - return StreamAdapterInterface::Write(data, data_len, written, error); + return stream_->Write(data, data_len, written, error); case SSL_WAIT: case SSL_CONNECTING: @@ -629,7 +631,7 @@ StreamResult OpenSSLStreamAdapter::Read(void* data, switch (state_) { case SSL_NONE: // pass-through in clear text - return StreamAdapterInterface::Read(data, data_len, read, error); + return stream_->Read(data, data_len, read, error); case SSL_WAIT: case SSL_CONNECTING: return SR_BLOCK; @@ -733,7 +735,7 @@ void OpenSSLStreamAdapter::Close() { // When we're closed at SSL layer, also close the stream level which // performs necessary clean up. Otherwise, a new incoming packet after // this could overflow the stream buffer. 
- StreamAdapterInterface::Close(); + stream_->Close(); } StreamState OpenSSLStreamAdapter::GetState() const { @@ -757,7 +759,7 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, int err) { int events_to_signal = 0; int signal_error = 0; - RTC_DCHECK(stream == this->stream()); + RTC_DCHECK(stream == stream_.get()); if ((events & SE_OPEN)) { RTC_DLOG(LS_VERBOSE) << "OpenSSLStreamAdapter::OnEvent SE_OPEN"; @@ -809,7 +811,9 @@ void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, } if (events_to_signal) { - StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error); + // Note that the adapter presents itself as the origin of the stream events, + // since users of the adapter may not recognize the adapted object. + SignalEvent(this, events_to_signal, signal_error); } } @@ -830,7 +834,12 @@ void OpenSSLStreamAdapter::SetTimeout(int delay_ms) { if (flag->alive()) { RTC_DLOG(LS_INFO) << "DTLS timeout expired"; timeout_task_.Stop(); - DTLSv1_handle_timeout(ssl_); + int res = DTLSv1_handle_timeout(ssl_); + if (res > 0) { + RTC_LOG(LS_INFO) << "DTLS retransmission"; + } else if (res < 0) { + RTC_LOG(LS_INFO) << "DTLSv1_handle_timeout() return -1"; + } ContinueSSL(); } else { RTC_NOTREACHED(); @@ -854,7 +863,7 @@ int OpenSSLStreamAdapter::BeginSSL() { return -1; } - bio = BIO_new_stream(static_cast(stream())); + bio = BIO_new_stream(stream_.get()); if (!bio) { return -1; } @@ -912,8 +921,7 @@ int OpenSSLStreamAdapter::ContinueSSL() { // The caller of ContinueSSL may be the same object listening for these // events and may not be prepared for reentrancy. // PostEvent(SE_OPEN | SE_READ | SE_WRITE, 0); - StreamAdapterInterface::OnEvent(stream(), SE_OPEN | SE_READ | SE_WRITE, - 0); + SignalEvent(this, SE_OPEN | SE_READ | SE_WRITE, 0); } break; @@ -956,7 +964,7 @@ void OpenSSLStreamAdapter::Error(const char* context, ssl_error_code_ = err; Cleanup(alert); if (signal) { - StreamAdapterInterface::OnEvent(stream(), SE_CLOSE, err); + SignalEvent(this, SE_CLOSE, err); } } diff --git a/rtc_base/openssl_stream_adapter.h b/rtc_base/openssl_stream_adapter.h index a09737c024..58e15e3e6f 100644 --- a/rtc_base/openssl_stream_adapter.h +++ b/rtc_base/openssl_stream_adapter.h @@ -136,9 +136,6 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter { // using a fake clock. 
static void EnableTimeCallbackForTesting(); - protected: - void OnEvent(StreamInterface* stream, int events, int err) override; - private: enum SSLState { // Before calling one of the StartSSL methods, data flows @@ -151,6 +148,8 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter { SSL_CLOSED // Clean close }; + void OnEvent(StreamInterface* stream, int events, int err); + void PostEvent(int events, int err); void SetTimeout(int delay_ms); @@ -203,6 +202,8 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter { !peer_certificate_digest_value_.empty(); } + const std::unique_ptr stream_; + rtc::Thread* const owner_; webrtc::ScopedTaskSafety task_safety_; webrtc::RepeatingTaskHandle timeout_task_; diff --git a/rtc_base/operations_chain.h b/rtc_base/operations_chain.h index a7252d46f0..3dc5995114 100644 --- a/rtc_base/operations_chain.h +++ b/rtc_base/operations_chain.h @@ -20,11 +20,11 @@ #include "absl/types/optional.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ref_count.h" #include "rtc_base/ref_counted_object.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace rtc { diff --git a/rtc_base/operations_chain_unittest.cc b/rtc_base/operations_chain_unittest.cc index 5f183e42cb..792a2c76ff 100644 --- a/rtc_base/operations_chain_unittest.cc +++ b/rtc_base/operations_chain_unittest.cc @@ -16,7 +16,6 @@ #include #include -#include "rtc_base/bind.h" #include "rtc_base/event.h" #include "rtc_base/gunit.h" #include "rtc_base/thread.h" diff --git a/rtc_base/physical_socket_server.cc b/rtc_base/physical_socket_server.cc index 3cb7c2008c..7904548041 100644 --- a/rtc_base/physical_socket_server.cc +++ b/rtc_base/physical_socket_server.cc @@ -48,6 +48,7 @@ #include "rtc_base/logging.h" #include "rtc_base/network_monitor.h" #include "rtc_base/null_socket_server.h" +#include "rtc_base/synchronization/mutex.h" #include "rtc_base/time_utils.h" #if defined(WEBRTC_LINUX) @@ -119,14 +120,6 @@ class ScopedSetTrue { namespace rtc { -std::unique_ptr SocketServer::CreateDefault() { -#if defined(__native_client__) - return std::unique_ptr(new rtc::NullSocketServer); -#else - return std::unique_ptr(new rtc::PhysicalSocketServer); -#endif -} - PhysicalSocket::PhysicalSocket(PhysicalSocketServer* ss, SOCKET s) : ss_(ss), s_(s), @@ -281,12 +274,12 @@ int PhysicalSocket::DoConnect(const SocketAddress& connect_addr) { } int PhysicalSocket::GetError() const { - CritScope cs(&crit_); + webrtc::MutexLock lock(&mutex_); return error_; } void PhysicalSocket::SetError(int error) { - CritScope cs(&crit_); + webrtc::MutexLock lock(&mutex_); error_ = error; } @@ -767,21 +760,14 @@ uint32_t SocketDispatcher::GetRequestedEvents() { return enabled_events(); } -void SocketDispatcher::OnPreEvent(uint32_t ff) { +#if defined(WEBRTC_WIN) + +void SocketDispatcher::OnEvent(uint32_t ff, int err) { if ((ff & DE_CONNECT) != 0) state_ = CS_CONNECTED; -#if defined(WEBRTC_WIN) -// We set CS_CLOSED from CheckSignalClose. -#elif defined(WEBRTC_POSIX) - if ((ff & DE_CLOSE) != 0) - state_ = CS_CLOSED; -#endif -} - -#if defined(WEBRTC_WIN) + // We set CS_CLOSED from CheckSignalClose. -void SocketDispatcher::OnEvent(uint32_t ff, int err) { int cache_id = id_; // Make sure we deliver connect/accept first. Otherwise, consumers may see // something like a READ followed by a CONNECT, which would be odd. 
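PhysicalSocket's GetError()/SetError() above switch from CritScope over a RecursiveCriticalSection to webrtc::MutexLock over a non-reentrant webrtc::Mutex, with RTC_GUARDED_BY tying the field to its lock for thread-safety analysis. A minimal sketch of the idiom (class and field names illustrative):

#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/thread_annotations.h"

// Sketch only: a field guarded by a webrtc::Mutex via scoped MutexLock.
class ErrorHolder {
 public:
  int GetError() const {
    webrtc::MutexLock lock(&mutex_);
    return error_;
  }
  void SetError(int error) {
    webrtc::MutexLock lock(&mutex_);
    error_ = error;
  }

 private:
  mutable webrtc::Mutex mutex_;
  int error_ RTC_GUARDED_BY(mutex_) = 0;
};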
@@ -816,6 +802,12 @@ void SocketDispatcher::OnEvent(uint32_t ff, int err) { #elif defined(WEBRTC_POSIX) void SocketDispatcher::OnEvent(uint32_t ff, int err) { + if ((ff & DE_CONNECT) != 0) + state_ = CS_CONNECTED; + + if ((ff & DE_CLOSE) != 0) + state_ = CS_CLOSED; + #if defined(WEBRTC_USE_EPOLL) // Remember currently enabled events so we can combine multiple changes // into one update call later. @@ -927,22 +919,32 @@ int SocketDispatcher::Close() { } #if defined(WEBRTC_POSIX) -class EventDispatcher : public Dispatcher { +// Sets the value of a boolean value to false when signaled. +class Signaler : public Dispatcher { public: - EventDispatcher(PhysicalSocketServer* ss) : ss_(ss), fSignaled_(false) { - if (pipe(afd_) < 0) - RTC_LOG(LERROR) << "pipe failed"; + Signaler(PhysicalSocketServer* ss, bool& flag_to_clear) + : ss_(ss), + afd_([] { + std::array afd = {-1, -1}; + + if (pipe(afd.data()) < 0) { + RTC_LOG(LERROR) << "pipe failed"; + } + return afd; + }()), + fSignaled_(false), + flag_to_clear_(flag_to_clear) { ss_->Add(this); } - ~EventDispatcher() override { + ~Signaler() override { ss_->Remove(this); close(afd_[0]); close(afd_[1]); } virtual void Signal() { - CritScope cs(&crit_); + webrtc::MutexLock lock(&mutex_); if (!fSignaled_) { const uint8_t b[1] = {0}; const ssize_t res = write(afd_[1], b, sizeof(b)); @@ -953,30 +955,30 @@ class EventDispatcher : public Dispatcher { uint32_t GetRequestedEvents() override { return DE_READ; } - void OnPreEvent(uint32_t ff) override { + void OnEvent(uint32_t ff, int err) override { // It is not possible to perfectly emulate an auto-resetting event with // pipes. This simulates it by resetting before the event is handled. - CritScope cs(&crit_); + webrtc::MutexLock lock(&mutex_); if (fSignaled_) { uint8_t b[4]; // Allow for reading more than 1 byte, but expect 1. const ssize_t res = read(afd_[0], b, sizeof(b)); RTC_DCHECK_EQ(1, res); fSignaled_ = false; } + flag_to_clear_ = false; } - void OnEvent(uint32_t ff, int err) override { RTC_NOTREACHED(); } - int GetDescriptor() override { return afd_[0]; } bool IsDescriptorClosed() override { return false; } private: - PhysicalSocketServer* ss_; - int afd_[2]; - bool fSignaled_; - RecursiveCriticalSection crit_; + PhysicalSocketServer* const ss_; + const std::array afd_; + bool fSignaled_ RTC_GUARDED_BY(mutex_); + webrtc::Mutex mutex_; + bool& flag_to_clear_; }; #endif // WEBRTC_POSIX @@ -995,16 +997,18 @@ static uint32_t FlagsToEvents(uint32_t events) { return ffFD; } -class EventDispatcher : public Dispatcher { +// Sets the value of a boolean value to false when signaled. 
+class Signaler : public Dispatcher { public: - EventDispatcher(PhysicalSocketServer* ss) : ss_(ss) { + Signaler(PhysicalSocketServer* ss, bool& flag_to_clear) + : ss_(ss), flag_to_clear_(flag_to_clear) { hev_ = WSACreateEvent(); if (hev_) { ss_->Add(this); } } - ~EventDispatcher() override { + ~Signaler() override { if (hev_ != nullptr) { ss_->Remove(this); WSACloseEvent(hev_); @@ -1019,9 +1023,10 @@ class EventDispatcher : public Dispatcher { uint32_t GetRequestedEvents() override { return 0; } - void OnPreEvent(uint32_t ff) override { WSAResetEvent(hev_); } - - void OnEvent(uint32_t ff, int err) override {} + void OnEvent(uint32_t ff, int err) override { + WSAResetEvent(hev_); + flag_to_clear_ = false; + } WSAEVENT GetWSAEvent() override { return hev_; } @@ -1032,24 +1037,10 @@ class EventDispatcher : public Dispatcher { private: PhysicalSocketServer* ss_; WSAEVENT hev_; + bool& flag_to_clear_; }; #endif // WEBRTC_WIN -// Sets the value of a boolean value to false when signaled. -class Signaler : public EventDispatcher { - public: - Signaler(PhysicalSocketServer* ss, bool* pf) : EventDispatcher(ss), pf_(pf) {} - ~Signaler() override {} - - void OnEvent(uint32_t ff, int err) override { - if (pf_) - *pf_ = false; - } - - private: - bool* pf_; -}; - PhysicalSocketServer::PhysicalSocketServer() : #if defined(WEBRTC_USE_EPOLL) @@ -1069,7 +1060,8 @@ PhysicalSocketServer::PhysicalSocketServer() // Note that -1 == INVALID_SOCKET, the alias used by later checks. } #endif - signal_wakeup_ = new Signaler(this, &fWait_); + // The `fWait_` flag to be cleared by the Signaler. + signal_wakeup_ = new Signaler(this, fWait_); } PhysicalSocketServer::~PhysicalSocketServer() { @@ -1237,7 +1229,6 @@ static void ProcessEvents(Dispatcher* dispatcher, // Tell the descriptor about the event. 
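With OnPreEvent() removed from the Dispatcher interface above, state that used to be updated before event delivery now has to happen at the top of OnEvent() itself, as the SocketDispatcher and Signaler changes show. A minimal POSIX-flavoured sketch of a custom dispatcher under the trimmed interface (class name and descriptor handling hypothetical; assumes the DE_* event flags are visible through the physical_socket_server.h include):

#include "rtc_base/physical_socket_server.h"

// Sketch only: a dispatcher that folds its former OnPreEvent() work into
// OnEvent().
class FdDispatcher : public rtc::Dispatcher {
 public:
  explicit FdDispatcher(int fd) : fd_(fd) {}
  uint32_t GetRequestedEvents() override { return rtc::DE_READ; }
  void OnEvent(uint32_t ff, int err) override {
    if ((ff & rtc::DE_CLOSE) != 0)
      closed_ = true;  // Previously done in OnPreEvent().
    // ... handle the readable descriptor here ...
  }
  int GetDescriptor() override { return fd_; }
  bool IsDescriptorClosed() override { return closed_; }

 private:
  const int fd_;
  bool closed_ = false;
};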
if (ff != 0) { - dispatcher->OnPreEvent(ff); dispatcher->OnEvent(ff, errcode); } } @@ -1641,7 +1632,6 @@ bool PhysicalSocketServer::Wait(int cmsWait, bool process_io) { continue; } Dispatcher* disp = dispatcher_by_key_.at(key); - disp->OnPreEvent(0); disp->OnEvent(0, 0); } else if (process_io) { // Iterate only on the dispatchers whose sockets were passed into @@ -1712,7 +1702,6 @@ bool PhysicalSocketServer::Wait(int cmsWait, bool process_io) { errcode = wsaEvents.iErrorCode[FD_CLOSE_BIT]; } if (ff != 0) { - disp->OnPreEvent(ff); disp->OnEvent(ff, errcode); } } diff --git a/rtc_base/physical_socket_server.h b/rtc_base/physical_socket_server.h index cc21a67b1a..4b7957eb20 100644 --- a/rtc_base/physical_socket_server.h +++ b/rtc_base/physical_socket_server.h @@ -21,9 +21,11 @@ #include #include +#include "rtc_base/async_resolver.h" +#include "rtc_base/async_resolver_interface.h" #include "rtc_base/deprecated/recursive_critical_section.h" -#include "rtc_base/net_helpers.h" #include "rtc_base/socket_server.h" +#include "rtc_base/synchronization/mutex.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/thread_annotations.h" @@ -48,7 +50,6 @@ class Dispatcher { public: virtual ~Dispatcher() {} virtual uint32_t GetRequestedEvents() = 0; - virtual void OnPreEvent(uint32_t ff) = 0; virtual void OnEvent(uint32_t ff, int err) = 0; #if defined(WEBRTC_WIN) virtual WSAEVENT GetWSAEvent() = 0; @@ -202,8 +203,8 @@ class PhysicalSocket : public AsyncSocket, public sigslot::has_slots<> { SOCKET s_; bool udp_; int family_ = 0; - RecursiveCriticalSection crit_; - int error_ RTC_GUARDED_BY(crit_); + mutable webrtc::Mutex mutex_; + int error_ RTC_GUARDED_BY(mutex_); ConnState state_; AsyncResolver* resolver_; @@ -236,7 +237,6 @@ class SocketDispatcher : public Dispatcher, public PhysicalSocket { #endif uint32_t GetRequestedEvents() override; - void OnPreEvent(uint32_t ff) override; void OnEvent(uint32_t ff, int err) override; int Close() override; diff --git a/rtc_base/physical_socket_server_unittest.cc b/rtc_base/physical_socket_server_unittest.cc index 648f39701a..3762762f85 100644 --- a/rtc_base/physical_socket_server_unittest.cc +++ b/rtc_base/physical_socket_server_unittest.cc @@ -18,6 +18,7 @@ #include "rtc_base/gunit.h" #include "rtc_base/ip_address.h" #include "rtc_base/logging.h" +#include "rtc_base/net_helpers.h" #include "rtc_base/network_monitor.h" #include "rtc_base/socket_unittest.h" #include "rtc_base/test_utils.h" diff --git a/rtc_base/platform_thread.cc b/rtc_base/platform_thread.cc index 8a5f2c9d6d..6d369d747e 100644 --- a/rtc_base/platform_thread.cc +++ b/rtc_base/platform_thread.cc @@ -10,131 +10,37 @@ #include "rtc_base/platform_thread.h" +#include +#include + #if !defined(WEBRTC_WIN) #include #endif -#include -#include - -#include #include "rtc_base/checks.h" namespace rtc { namespace { -#if !defined(WEBRTC_WIN) -struct ThreadAttributes { - ThreadAttributes() { pthread_attr_init(&attr); } - ~ThreadAttributes() { pthread_attr_destroy(&attr); } - pthread_attr_t* operator&() { return &attr; } - pthread_attr_t attr; -}; -#endif // defined(WEBRTC_WIN) -} // namespace - -PlatformThread::PlatformThread(ThreadRunFunction func, - void* obj, - absl::string_view thread_name, - ThreadPriority priority /*= kNormalPriority*/) - : run_function_(func), priority_(priority), obj_(obj), name_(thread_name) { - RTC_DCHECK(func); - RTC_DCHECK(!name_.empty()); - // TODO(tommi): Consider lowering the limit to 15 (limit on Linux). 
- RTC_DCHECK(name_.length() < 64); - spawned_thread_checker_.Detach(); -} - -PlatformThread::~PlatformThread() { - RTC_DCHECK(thread_checker_.IsCurrent()); -#if defined(WEBRTC_WIN) - RTC_DCHECK(!thread_); - RTC_DCHECK(!thread_id_); -#endif // defined(WEBRTC_WIN) -} - -#if defined(WEBRTC_WIN) -DWORD WINAPI PlatformThread::StartThread(void* param) { - // The GetLastError() function only returns valid results when it is called - // after a Win32 API function that returns a "failed" result. A crash dump - // contains the result from GetLastError() and to make sure it does not - // falsely report a Windows error we call SetLastError here. - ::SetLastError(ERROR_SUCCESS); - static_cast(param)->Run(); - return 0; -} -#else -void* PlatformThread::StartThread(void* param) { - static_cast(param)->Run(); - return 0; -} -#endif // defined(WEBRTC_WIN) - -void PlatformThread::Start() { - RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_DCHECK(!thread_) << "Thread already started?"; -#if defined(WEBRTC_WIN) - // See bug 2902 for background on STACK_SIZE_PARAM_IS_A_RESERVATION. - // Set the reserved stack stack size to 1M, which is the default on Windows - // and Linux. - thread_ = ::CreateThread(nullptr, 1024 * 1024, &StartThread, this, - STACK_SIZE_PARAM_IS_A_RESERVATION, &thread_id_); - RTC_CHECK(thread_) << "CreateThread failed"; - RTC_DCHECK(thread_id_); -#else - ThreadAttributes attr; - // Set the stack stack size to 1M. - pthread_attr_setstacksize(&attr, 1024 * 1024); - RTC_CHECK_EQ(0, pthread_create(&thread_, &attr, &StartThread, this)); -#endif // defined(WEBRTC_WIN) -} -bool PlatformThread::IsRunning() const { - RTC_DCHECK(thread_checker_.IsCurrent()); #if defined(WEBRTC_WIN) - return thread_ != nullptr; -#else - return thread_ != 0; -#endif // defined(WEBRTC_WIN) -} - -PlatformThreadRef PlatformThread::GetThreadRef() const { -#if defined(WEBRTC_WIN) - return thread_id_; -#else - return thread_; -#endif // defined(WEBRTC_WIN) -} - -void PlatformThread::Stop() { - RTC_DCHECK(thread_checker_.IsCurrent()); - if (!IsRunning()) - return; - -#if defined(WEBRTC_WIN) - WaitForSingleObject(thread_, INFINITE); - CloseHandle(thread_); - thread_ = nullptr; - thread_id_ = 0; -#else - RTC_CHECK_EQ(0, pthread_join(thread_, nullptr)); - thread_ = 0; -#endif // defined(WEBRTC_WIN) - spawned_thread_checker_.Detach(); -} - -void PlatformThread::Run() { - // Attach the worker thread checker to this thread. - RTC_DCHECK(spawned_thread_checker_.IsCurrent()); - rtc::SetCurrentThreadName(name_.c_str()); - SetPriority(priority_); - run_function_(obj_); +int Win32PriorityFromThreadPriority(ThreadPriority priority) { + switch (priority) { + case ThreadPriority::kLow: + return THREAD_PRIORITY_BELOW_NORMAL; + case ThreadPriority::kNormal: + return THREAD_PRIORITY_NORMAL; + case ThreadPriority::kHigh: + return THREAD_PRIORITY_ABOVE_NORMAL; + case ThreadPriority::kRealtime: + return THREAD_PRIORITY_TIME_CRITICAL; + } } +#endif -bool PlatformThread::SetPriority(ThreadPriority priority) { - RTC_DCHECK(spawned_thread_checker_.IsCurrent()); - +bool SetPriority(ThreadPriority priority) { #if defined(WEBRTC_WIN) - return SetThreadPriority(thread_, priority) != FALSE; + return SetThreadPriority(GetCurrentThread(), + Win32PriorityFromThreadPriority(priority)) != FALSE; #elif defined(__native_client__) || defined(WEBRTC_FUCHSIA) // Setting thread priorities is not supported in NaCl or Fuchsia. 
return true; @@ -158,35 +64,148 @@ bool PlatformThread::SetPriority(ThreadPriority priority) { const int top_prio = max_prio - 1; const int low_prio = min_prio + 1; switch (priority) { - case kLowPriority: + case ThreadPriority::kLow: param.sched_priority = low_prio; break; - case kNormalPriority: + case ThreadPriority::kNormal: // The -1 ensures that the kHighPriority is always greater or equal to // kNormalPriority. param.sched_priority = (low_prio + top_prio - 1) / 2; break; - case kHighPriority: + case ThreadPriority::kHigh: param.sched_priority = std::max(top_prio - 2, low_prio); break; - case kHighestPriority: - param.sched_priority = std::max(top_prio - 1, low_prio); - break; - case kRealtimePriority: + case ThreadPriority::kRealtime: param.sched_priority = top_prio; break; } - return pthread_setschedparam(thread_, policy, ¶m) == 0; + return pthread_setschedparam(pthread_self(), policy, ¶m) == 0; #endif // defined(WEBRTC_WIN) } #if defined(WEBRTC_WIN) -bool PlatformThread::QueueAPC(PAPCFUNC function, ULONG_PTR data) { - RTC_DCHECK(thread_checker_.IsCurrent()); - RTC_DCHECK(IsRunning()); +DWORD WINAPI RunPlatformThread(void* param) { + // The GetLastError() function only returns valid results when it is called + // after a Win32 API function that returns a "failed" result. A crash dump + // contains the result from GetLastError() and to make sure it does not + // falsely report a Windows error we call SetLastError here. + ::SetLastError(ERROR_SUCCESS); + auto function = static_cast*>(param); + (*function)(); + delete function; + return 0; +} +#else +void* RunPlatformThread(void* param) { + auto function = static_cast*>(param); + (*function)(); + delete function; + return 0; +} +#endif // defined(WEBRTC_WIN) + +} // namespace + +PlatformThread::PlatformThread(Handle handle, bool joinable) + : handle_(handle), joinable_(joinable) {} + +PlatformThread::PlatformThread(PlatformThread&& rhs) + : handle_(rhs.handle_), joinable_(rhs.joinable_) { + rhs.handle_ = absl::nullopt; +} + +PlatformThread& PlatformThread::operator=(PlatformThread&& rhs) { + Finalize(); + handle_ = rhs.handle_; + joinable_ = rhs.joinable_; + rhs.handle_ = absl::nullopt; + return *this; +} + +PlatformThread::~PlatformThread() { + Finalize(); +} + +PlatformThread PlatformThread::SpawnJoinable( + std::function thread_function, + absl::string_view name, + ThreadAttributes attributes) { + return SpawnThread(std::move(thread_function), name, attributes, + /*joinable=*/true); +} + +PlatformThread PlatformThread::SpawnDetached( + std::function thread_function, + absl::string_view name, + ThreadAttributes attributes) { + return SpawnThread(std::move(thread_function), name, attributes, + /*joinable=*/false); +} + +absl::optional PlatformThread::GetHandle() const { + return handle_; +} - return QueueUserAPC(function, thread_, data) != FALSE; +#if defined(WEBRTC_WIN) +bool PlatformThread::QueueAPC(PAPCFUNC function, ULONG_PTR data) { + RTC_DCHECK(handle_.has_value()); + return handle_.has_value() ? 
QueueUserAPC(function, *handle_, data) != FALSE + : false; } #endif +void PlatformThread::Finalize() { + if (!handle_.has_value()) + return; +#if defined(WEBRTC_WIN) + if (joinable_) + WaitForSingleObject(*handle_, INFINITE); + CloseHandle(*handle_); +#else + if (joinable_) + RTC_CHECK_EQ(0, pthread_join(*handle_, nullptr)); +#endif + handle_ = absl::nullopt; +} + +PlatformThread PlatformThread::SpawnThread( + std::function thread_function, + absl::string_view name, + ThreadAttributes attributes, + bool joinable) { + RTC_DCHECK(thread_function); + RTC_DCHECK(!name.empty()); + // TODO(tommi): Consider lowering the limit to 15 (limit on Linux). + RTC_DCHECK(name.length() < 64); + auto start_thread_function_ptr = + new std::function([thread_function = std::move(thread_function), + name = std::string(name), attributes] { + rtc::SetCurrentThreadName(name.c_str()); + SetPriority(attributes.priority); + thread_function(); + }); +#if defined(WEBRTC_WIN) + // See bug 2902 for background on STACK_SIZE_PARAM_IS_A_RESERVATION. + // Set the reserved stack stack size to 1M, which is the default on Windows + // and Linux. + DWORD thread_id = 0; + PlatformThread::Handle handle = ::CreateThread( + nullptr, 1024 * 1024, &RunPlatformThread, start_thread_function_ptr, + STACK_SIZE_PARAM_IS_A_RESERVATION, &thread_id); + RTC_CHECK(handle) << "CreateThread failed"; +#else + pthread_attr_t attr; + pthread_attr_init(&attr); + // Set the stack stack size to 1M. + pthread_attr_setstacksize(&attr, 1024 * 1024); + pthread_attr_setdetachstate( + &attr, joinable ? PTHREAD_CREATE_JOINABLE : PTHREAD_CREATE_DETACHED); + PlatformThread::Handle handle; + RTC_CHECK_EQ(0, pthread_create(&handle, &attr, &RunPlatformThread, + start_thread_function_ptr)); + pthread_attr_destroy(&attr); +#endif // defined(WEBRTC_WIN) + return PlatformThread(handle, joinable); +} + } // namespace rtc diff --git a/rtc_base/platform_thread.h b/rtc_base/platform_thread.h index 4968de9ee5..11ccfae3d0 100644 --- a/rtc_base/platform_thread.h +++ b/rtc_base/platform_thread.h @@ -11,92 +11,101 @@ #ifndef RTC_BASE_PLATFORM_THREAD_H_ #define RTC_BASE_PLATFORM_THREAD_H_ -#ifndef WEBRTC_WIN -#include -#endif +#include #include #include "absl/strings/string_view.h" -#include "rtc_base/constructor_magic.h" +#include "absl/types/optional.h" #include "rtc_base/platform_thread_types.h" -#include "rtc_base/thread_checker.h" namespace rtc { -// Callback function that the spawned thread will enter once spawned. -typedef void (*ThreadRunFunction)(void*); +enum class ThreadPriority { + kLow = 1, + kNormal, + kHigh, + kRealtime, +}; -enum ThreadPriority { -#ifdef WEBRTC_WIN - kLowPriority = THREAD_PRIORITY_BELOW_NORMAL, - kNormalPriority = THREAD_PRIORITY_NORMAL, - kHighPriority = THREAD_PRIORITY_ABOVE_NORMAL, - kHighestPriority = THREAD_PRIORITY_HIGHEST, - kRealtimePriority = THREAD_PRIORITY_TIME_CRITICAL -#else - kLowPriority = 1, - kNormalPriority = 2, - kHighPriority = 3, - kHighestPriority = 4, - kRealtimePriority = 5 -#endif +struct ThreadAttributes { + ThreadPriority priority = ThreadPriority::kNormal; + ThreadAttributes& SetPriority(ThreadPriority priority_param) { + priority = priority_param; + return *this; + } }; -// Represents a simple worker thread. The implementation must be assumed -// to be single threaded, meaning that all methods of the class, must be -// called from the same thread, including instantiation. -class PlatformThread { +// Represents a simple worker thread. 
+class PlatformThread final { public: - PlatformThread(ThreadRunFunction func, - void* obj, - absl::string_view thread_name, - ThreadPriority priority = kNormalPriority); + // Handle is the base platform thread handle. +#if defined(WEBRTC_WIN) + using Handle = HANDLE; +#else + using Handle = pthread_t; +#endif // defined(WEBRTC_WIN) + // This ctor creates the PlatformThread with an unset handle (returning true + // in empty()) and is provided for convenience. + // TODO(bugs.webrtc.org/12727) Look into if default and move support can be + // removed. + PlatformThread() = default; + + // Moves |rhs| into this, storing an empty state in |rhs|. + // TODO(bugs.webrtc.org/12727) Look into if default and move support can be + // removed. + PlatformThread(PlatformThread&& rhs); + + // Moves |rhs| into this, storing an empty state in |rhs|. + // TODO(bugs.webrtc.org/12727) Look into if default and move support can be + // removed. + PlatformThread& operator=(PlatformThread&& rhs); + + // For a PlatformThread that's been spawned joinable, the destructor suspends + // the calling thread until the created thread exits unless the thread has + // already exited. virtual ~PlatformThread(); - const std::string& name() const { return name_; } - - // Spawns a thread and tries to set thread priority according to the priority - // from when CreateThread was called. - void Start(); + // Finalizes any allocated resources. + // For a PlatformThread that's been spawned joinable, Finalize() suspends + // the calling thread until the created thread exits unless the thread has + // already exited. + // empty() returns true after completion. + void Finalize(); + + // Returns true if default constructed, moved from, or Finalize()ed. + bool empty() const { return !handle_.has_value(); } + + // Creates a started joinable thread which will be joined when the returned + // PlatformThread destructs or Finalize() is called. + static PlatformThread SpawnJoinable( + std::function thread_function, + absl::string_view name, + ThreadAttributes attributes = ThreadAttributes()); + + // Creates a started detached thread. The caller has to use external + // synchronization as nothing is provided by the PlatformThread construct. + static PlatformThread SpawnDetached( + std::function thread_function, + absl::string_view name, + ThreadAttributes attributes = ThreadAttributes()); + + // Returns the base platform thread handle of this thread. + absl::optional GetHandle() const; - bool IsRunning() const; - - // Returns an identifier for the worker thread that can be used to do - // thread checks. - PlatformThreadRef GetThreadRef() const; - - // Stops (joins) the spawned thread. - void Stop(); - - protected: #if defined(WEBRTC_WIN) - // Exposed to derived classes to allow for special cases specific to Windows. + // Queue a Windows APC function that runs when the thread is alertable. bool QueueAPC(PAPCFUNC apc_function, ULONG_PTR data); #endif private: - void Run(); - bool SetPriority(ThreadPriority priority); - - ThreadRunFunction const run_function_ = nullptr; - const ThreadPriority priority_ = kNormalPriority; - void* const obj_; - // TODO(pbos): Make sure call sites use string literals and update to a const - // char* instead of a std::string. 
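The rewritten PlatformThread above is a value type created through factory functions instead of a Start()/Stop() object. A usage sketch of the API as declared in the new header:

#include "rtc_base/event.h"
#include "rtc_base/platform_thread.h"

// Sketch only: spawning joinable and detached threads with the new API.
void SpawnExample() {
  rtc::Event done;

  // Joined automatically when |worker| is destroyed or Finalize() is called.
  rtc::PlatformThread worker = rtc::PlatformThread::SpawnJoinable(
      [] { /* do some work */ }, "worker",
      rtc::ThreadAttributes().SetPriority(rtc::ThreadPriority::kHigh));

  // Detached: the caller provides its own completion signalling.
  rtc::PlatformThread::SpawnDetached([&done] { done.Set(); }, "detached");
  done.Wait(rtc::Event::kForever);

  worker.Finalize();  // Explicit join; worker.empty() is true afterwards.
}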
- const std::string name_; - rtc::ThreadChecker thread_checker_; - rtc::ThreadChecker spawned_thread_checker_; -#if defined(WEBRTC_WIN) - static DWORD WINAPI StartThread(void* param); - - HANDLE thread_ = nullptr; - DWORD thread_id_ = 0; -#else - static void* StartThread(void* param); - - pthread_t thread_ = 0; -#endif // defined(WEBRTC_WIN) - RTC_DISALLOW_COPY_AND_ASSIGN(PlatformThread); + PlatformThread(Handle handle, bool joinable); + static PlatformThread SpawnThread(std::function thread_function, + absl::string_view name, + ThreadAttributes attributes, + bool joinable); + + absl::optional handle_; + bool joinable_ = false; }; } // namespace rtc diff --git a/rtc_base/platform_thread_unittest.cc b/rtc_base/platform_thread_unittest.cc index a52e4cd9f5..b60d2131b7 100644 --- a/rtc_base/platform_thread_unittest.cc +++ b/rtc_base/platform_thread_unittest.cc @@ -10,51 +10,99 @@ #include "rtc_base/platform_thread.h" -#include "test/gtest.h" +#include "absl/types/optional.h" +#include "rtc_base/event.h" +#include "system_wrappers/include/sleep.h" +#include "test/gmock.h" namespace rtc { -namespace { -void NullRunFunction(void* obj) {} +TEST(PlatformThreadTest, DefaultConstructedIsEmpty) { + PlatformThread thread; + EXPECT_EQ(thread.GetHandle(), absl::nullopt); + EXPECT_TRUE(thread.empty()); +} -// Function that sets a boolean. -void SetFlagRunFunction(void* obj) { - bool* obj_as_bool = static_cast(obj); - *obj_as_bool = true; +TEST(PlatformThreadTest, StartFinalize) { + PlatformThread thread = PlatformThread::SpawnJoinable([] {}, "1"); + EXPECT_NE(thread.GetHandle(), absl::nullopt); + EXPECT_FALSE(thread.empty()); + thread.Finalize(); + EXPECT_TRUE(thread.empty()); + rtc::Event done; + thread = PlatformThread::SpawnDetached([&] { done.Set(); }, "2"); + EXPECT_FALSE(thread.empty()); + thread.Finalize(); + EXPECT_TRUE(thread.empty()); + done.Wait(30000); } -} // namespace +TEST(PlatformThreadTest, MovesEmpty) { + PlatformThread thread1; + PlatformThread thread2 = std::move(thread1); + EXPECT_TRUE(thread1.empty()); + EXPECT_TRUE(thread2.empty()); +} -TEST(PlatformThreadTest, StartStop) { - PlatformThread thread(&NullRunFunction, nullptr, "PlatformThreadTest"); - EXPECT_TRUE(thread.name() == "PlatformThreadTest"); - EXPECT_TRUE(thread.GetThreadRef() == 0); - thread.Start(); - EXPECT_TRUE(thread.GetThreadRef() != 0); - thread.Stop(); - EXPECT_TRUE(thread.GetThreadRef() == 0); +TEST(PlatformThreadTest, MovesHandles) { + PlatformThread thread1 = PlatformThread::SpawnJoinable([] {}, "1"); + PlatformThread thread2 = std::move(thread1); + EXPECT_TRUE(thread1.empty()); + EXPECT_FALSE(thread2.empty()); + rtc::Event done; + thread1 = PlatformThread::SpawnDetached([&] { done.Set(); }, "2"); + thread2 = std::move(thread1); + EXPECT_TRUE(thread1.empty()); + EXPECT_FALSE(thread2.empty()); + done.Wait(30000); } -TEST(PlatformThreadTest, StartStop2) { - PlatformThread thread1(&NullRunFunction, nullptr, "PlatformThreadTest1"); - PlatformThread thread2(&NullRunFunction, nullptr, "PlatformThreadTest2"); - EXPECT_TRUE(thread1.GetThreadRef() == thread2.GetThreadRef()); - thread1.Start(); - thread2.Start(); - EXPECT_TRUE(thread1.GetThreadRef() != thread2.GetThreadRef()); - thread2.Stop(); - thread1.Stop(); +TEST(PlatformThreadTest, + TwoThreadHandlesAreDifferentWhenStartedAndEqualWhenJoined) { + PlatformThread thread1 = PlatformThread(); + PlatformThread thread2 = PlatformThread(); + EXPECT_EQ(thread1.GetHandle(), thread2.GetHandle()); + thread1 = PlatformThread::SpawnJoinable([] {}, "1"); + thread2 = 
PlatformThread::SpawnJoinable([] {}, "2"); + EXPECT_NE(thread1.GetHandle(), thread2.GetHandle()); + thread1.Finalize(); + EXPECT_NE(thread1.GetHandle(), thread2.GetHandle()); + thread2.Finalize(); + EXPECT_EQ(thread1.GetHandle(), thread2.GetHandle()); } TEST(PlatformThreadTest, RunFunctionIsCalled) { bool flag = false; - PlatformThread thread(&SetFlagRunFunction, &flag, "RunFunctionIsCalled"); - thread.Start(); + PlatformThread::SpawnJoinable([&] { flag = true; }, "T"); + EXPECT_TRUE(flag); +} - // At this point, the flag may be either true or false. - thread.Stop(); +TEST(PlatformThreadTest, JoinsThread) { + // This test flakes if there are problems with the join implementation. + rtc::Event event; + PlatformThread::SpawnJoinable([&] { event.Set(); }, "T"); + EXPECT_TRUE(event.Wait(/*give_up_after_ms=*/0)); +} - // We expect the thread to have run at least once. +TEST(PlatformThreadTest, StopsBeforeDetachedThreadExits) { + // This test flakes if there are problems with the detached thread + // implementation. + bool flag = false; + rtc::Event thread_started; + rtc::Event thread_continue; + rtc::Event thread_exiting; + PlatformThread::SpawnDetached( + [&] { + thread_started.Set(); + thread_continue.Wait(Event::kForever); + flag = true; + thread_exiting.Set(); + }, + "T"); + thread_started.Wait(Event::kForever); + EXPECT_FALSE(flag); + thread_continue.Set(); + thread_exiting.Wait(Event::kForever); EXPECT_TRUE(flag); } diff --git a/rtc_base/random.cc b/rtc_base/random.cc index 5deb621727..5206b817f3 100644 --- a/rtc_base/random.cc +++ b/rtc_base/random.cc @@ -49,14 +49,14 @@ int32_t Random::Rand(int32_t low, int32_t high) { template <> float Random::Rand() { double result = NextOutput() - 1; - result = result / 0xFFFFFFFFFFFFFFFEull; + result = result / static_cast(0xFFFFFFFFFFFFFFFFull); return static_cast(result); } template <> double Random::Rand() { double result = NextOutput() - 1; - result = result / 0xFFFFFFFFFFFFFFFEull; + result = result / static_cast(0xFFFFFFFFFFFFFFFFull); return result; } @@ -72,8 +72,10 @@ double Random::Gaussian(double mean, double standard_deviation) { // in the range [1, 2^64-1]. Normally this behavior is a bit frustrating, // but here it is exactly what we need. const double kPi = 3.14159265358979323846; - double u1 = static_cast(NextOutput()) / 0xFFFFFFFFFFFFFFFFull; - double u2 = static_cast(NextOutput()) / 0xFFFFFFFFFFFFFFFFull; + double u1 = static_cast(NextOutput()) / + static_cast(0xFFFFFFFFFFFFFFFFull); + double u2 = static_cast(NextOutput()) / + static_cast(0xFFFFFFFFFFFFFFFFull); return mean + standard_deviation * sqrt(-2 * log(u1)) * cos(2 * kPi * u2); } diff --git a/rtc_base/rate_limiter_unittest.cc b/rtc_base/rate_limiter_unittest.cc index 8ebf8aa67b..eda644b4ca 100644 --- a/rtc_base/rate_limiter_unittest.cc +++ b/rtc_base/rate_limiter_unittest.cc @@ -127,10 +127,6 @@ class ThreadTask { rtc::Event end_signal_; }; -void RunTask(void* thread_task) { - reinterpret_cast(thread_task)->Run(); -} - TEST_F(RateLimitTest, MultiThreadedUsage) { // Simple sanity test, with different threads calling the various methods. 
// Runs a few simple tasks, each on its own thread, but coordinated with @@ -149,8 +145,8 @@ TEST_F(RateLimitTest, MultiThreadedUsage) { EXPECT_TRUE(rate_limiter_->SetWindowSize(kWindowSizeMs / 2)); } } set_window_size_task(rate_limiter.get()); - rtc::PlatformThread thread1(RunTask, &set_window_size_task, "Thread1"); - thread1.Start(); + auto thread1 = rtc::PlatformThread::SpawnJoinable( + [&set_window_size_task] { set_window_size_task.Run(); }, "Thread1"); class SetMaxRateTask : public ThreadTask { public: @@ -160,8 +156,8 @@ TEST_F(RateLimitTest, MultiThreadedUsage) { void DoRun() override { rate_limiter_->SetMaxRate(kMaxRateBps * 2); } } set_max_rate_task(rate_limiter.get()); - rtc::PlatformThread thread2(RunTask, &set_max_rate_task, "Thread2"); - thread2.Start(); + auto thread2 = rtc::PlatformThread::SpawnJoinable( + [&set_max_rate_task] { set_max_rate_task.Run(); }, "Thread2"); class UseRateTask : public ThreadTask { public: @@ -177,8 +173,8 @@ TEST_F(RateLimitTest, MultiThreadedUsage) { SimulatedClock* const clock_; } use_rate_task(rate_limiter.get(), &clock_); - rtc::PlatformThread thread3(RunTask, &use_rate_task, "Thread3"); - thread3.Start(); + auto thread3 = rtc::PlatformThread::SpawnJoinable( + [&use_rate_task] { use_rate_task.Run(); }, "Thread3"); set_window_size_task.start_signal_.Set(); EXPECT_TRUE(set_window_size_task.end_signal_.Wait(kMaxTimeoutMs)); @@ -191,10 +187,6 @@ TEST_F(RateLimitTest, MultiThreadedUsage) { // All rate consumed. EXPECT_FALSE(rate_limiter->TryUseRate(1)); - - thread1.Stop(); - thread2.Stop(); - thread3.Stop(); } } // namespace webrtc diff --git a/rtc_base/ref_counted_object.h b/rtc_base/ref_counted_object.h index ce18379d50..331132c569 100644 --- a/rtc_base/ref_counted_object.h +++ b/rtc_base/ref_counted_object.h @@ -13,6 +13,7 @@ #include #include +#include "api/scoped_refptr.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ref_count.h" #include "rtc_base/ref_counter.h" @@ -33,9 +34,9 @@ class RefCountedObject : public T { std::forward(p1), std::forward(args)...) {} - virtual void AddRef() const { ref_count_.IncRef(); } + void AddRef() const override { ref_count_.IncRef(); } - virtual RefCountReleaseStatus Release() const { + RefCountReleaseStatus Release() const override { const auto status = ref_count_.DecRef(); if (status == RefCountReleaseStatus::kDroppedLastRef) { delete this; @@ -52,13 +53,146 @@ class RefCountedObject : public T { virtual bool HasOneRef() const { return ref_count_.HasOneRef(); } protected: - virtual ~RefCountedObject() {} + ~RefCountedObject() override {} mutable webrtc::webrtc_impl::RefCounter ref_count_{0}; RTC_DISALLOW_COPY_AND_ASSIGN(RefCountedObject); }; +template +class FinalRefCountedObject final : public T { + public: + using T::T; + // Until c++17 compilers are allowed not to inherit the default constructors. + // Thus the default constructors are forwarded explicitly. 
+ FinalRefCountedObject() = default; + explicit FinalRefCountedObject(const T& other) : T(other) {} + explicit FinalRefCountedObject(T&& other) : T(std::move(other)) {} + FinalRefCountedObject(const FinalRefCountedObject&) = delete; + FinalRefCountedObject& operator=(const FinalRefCountedObject&) = delete; + + void AddRef() const { ref_count_.IncRef(); } + void Release() const { + if (ref_count_.DecRef() == RefCountReleaseStatus::kDroppedLastRef) { + delete this; + } + } + bool HasOneRef() const { return ref_count_.HasOneRef(); } + + private: + ~FinalRefCountedObject() = default; + + mutable webrtc::webrtc_impl::RefCounter ref_count_{0}; +}; + +// General utilities for constructing a reference counted class and the +// appropriate reference count implementation for that class. +// +// These utilities select either the `RefCountedObject` implementation or +// `FinalRefCountedObject` depending on whether the to-be-shared class is +// derived from the RefCountInterface interface or not (respectively). + +// `make_ref_counted`: +// +// Use this when you want to construct a reference counted object of type T and +// get a `scoped_refptr<>` back. Example: +// +// auto p = make_ref_counted("bar", 123); +// +// For a class that inherits from RefCountInterface, this is equivalent to: +// +// auto p = scoped_refptr(new RefCountedObject("bar", 123)); +// +// If the class does not inherit from RefCountInterface, the example is +// equivalent to: +// +// auto p = scoped_refptr>( +// new FinalRefCountedObject("bar", 123)); +// +// In these cases, `make_ref_counted` reduces the amount of boilerplate code but +// also helps with the most commonly intended usage of RefCountedObject whereby +// methods for reference counting, are virtual and designed to satisfy the need +// of an interface. When such a need does not exist, it is more efficient to use +// the `FinalRefCountedObject` template, which does not add the vtable overhead. +// +// Note that in some cases, using RefCountedObject directly may still be what's +// needed. + +// `make_ref_counted` for classes that are convertible to RefCountInterface. +template < + typename T, + typename... Args, + typename std::enable_if::value, + T>::type* = nullptr> +scoped_refptr make_ref_counted(Args&&... args) { + return new RefCountedObject(std::forward(args)...); +} + +// `make_ref_counted` for complete classes that are not convertible to +// RefCountInterface. +template < + typename T, + typename... Args, + typename std::enable_if::value, + T>::type* = nullptr> +scoped_refptr> make_ref_counted(Args&&... args) { + return new FinalRefCountedObject(std::forward(args)...); +} + +// `Ref<>`, `Ref<>::Type` and `Ref<>::Ptr`: +// +// `Ref` is a type declaring utility that is compatible with `make_ref_counted` +// and can be used in classes and methods where it's more convenient (or +// readable) to have the compiler figure out the fully fleshed out type for a +// class rather than spell it out verbatim in all places the type occurs (which +// can mean maintenance work if the class layout changes). +// +// Usage examples: +// +// If you want to declare the parameter type that's always compatible with +// this code: +// +// Bar(make_ref_counted()); +// +// You can use `Ref<>::Ptr` to declare a compatible scoped_refptr type: +// +// void Bar(Ref::Ptr p); +// +// This might be more practically useful in templates though. 
+// +// In rare cases you might need to be able to declare a parameter that's fully +// compatible with the reference counted T type - and just using T* is not +// enough. To give a code example, we can declare a function, `Foo` that is +// compatible with this code: +// auto p = make_ref_counted(); +// Foo(p.get()); +// +// void Foo(Ref::Type* foo_ptr); +// +// Alternatively this would be: +// void Foo(Foo* foo_ptr); +// or +// void Foo(FinalRefCountedObject* foo_ptr); + +// Declares the approprate reference counted type for T depending on whether +// T is convertible to RefCountInterface or not. +// For classes that are convertible, the type will simply be T. +// For classes that cannot be converted to RefCountInterface, the type will be +// FinalRefCountedObject. +// This is most useful for declaring a scoped_refptr instance for a class +// that may or may not implement a virtual reference counted interface: +// * scoped_refptr::Type> my_ptr; +template +struct Ref { + typedef typename std::conditional< + std::is_convertible::value, + T, + FinalRefCountedObject>::type Type; + + typedef scoped_refptr Ptr; +}; + } // namespace rtc #endif // RTC_BASE_REF_COUNTED_OBJECT_H_ diff --git a/rtc_base/ref_counted_object_unittest.cc b/rtc_base/ref_counted_object_unittest.cc index eacf731782..ab7bb09191 100644 --- a/rtc_base/ref_counted_object_unittest.cc +++ b/rtc_base/ref_counted_object_unittest.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include "api/scoped_refptr.h" @@ -63,6 +64,20 @@ class RefClassWithMixedValues : public RefCountInterface { std::string c_; }; +class Foo { + public: + Foo() {} + Foo(int i, int j) : foo_(i + j) {} + int foo_ = 0; +}; + +class FooItf : public RefCountInterface { + public: + FooItf() {} + FooItf(int i, int j) : foo_(i + j) {} + int foo_ = 0; +}; + } // namespace TEST(RefCountedObject, HasOneRef) { @@ -95,4 +110,73 @@ TEST(RefCountedObject, SupportMixedTypesInCtor) { EXPECT_EQ(c, ref->c_); } +TEST(FinalRefCountedObject, CanWrapIntoScopedRefptr) { + using WrappedTyped = FinalRefCountedObject; + static_assert(!std::is_polymorphic::value, ""); + scoped_refptr ref(new WrappedTyped()); + EXPECT_TRUE(ref.get()); + EXPECT_TRUE(ref->HasOneRef()); + // Test reference counter is updated on some simple operations. + scoped_refptr ref2 = ref; + EXPECT_FALSE(ref->HasOneRef()); + EXPECT_FALSE(ref2->HasOneRef()); + + ref = nullptr; + EXPECT_TRUE(ref2->HasOneRef()); +} + +TEST(FinalRefCountedObject, CanCreateFromMovedType) { + class MoveOnly { + public: + MoveOnly(int a) : a_(a) {} + MoveOnly(MoveOnly&&) = default; + + int a() { return a_; } + + private: + int a_; + }; + MoveOnly foo(5); + auto ref = make_ref_counted(std::move(foo)); + EXPECT_EQ(ref->a(), 5); +} + +// This test is mostly a compile-time test for scoped_refptr compatibility. +TEST(RefCounted, SmartPointers) { + // Sanity compile-time tests. FooItf is virtual, Foo is not, FooItf inherits + // from RefCountInterface, Foo does not. + static_assert(std::is_base_of::value, ""); + static_assert(!std::is_base_of::value, ""); + static_assert(std::is_polymorphic::value, ""); + static_assert(!std::is_polymorphic::value, ""); + + // Check if Ref generates the expected types for Foo and FooItf. + static_assert(std::is_base_of::Type>::value && + !std::is_same::Type>::value, + ""); + static_assert(std::is_same::Type>::value, ""); + + { + // Test with FooItf, a class that inherits from RefCountInterface. + // Check that we get a valid FooItf reference counted object. 
+ auto p = make_ref_counted(2, 3); + EXPECT_NE(p.get(), nullptr); + EXPECT_EQ(p->foo_, 5); // the FooItf ctor just stores 2+3 in foo_. + + // Use a couple of different ways of declaring what should result in the + // same type as `p` is of. + scoped_refptr::Type> p2 = p; + Ref::Ptr p3 = p; + } + + { + // Same for `Foo` + auto p = make_ref_counted(2, 3); + EXPECT_NE(p.get(), nullptr); + EXPECT_EQ(p->foo_, 5); + scoped_refptr::Type> p2 = p; + Ref::Ptr p3 = p; + } +} + } // namespace rtc diff --git a/rtc_base/rtc_certificate.cc b/rtc_base/rtc_certificate.cc index 04ae99685d..496b4ac4b4 100644 --- a/rtc_base/rtc_certificate.cc +++ b/rtc_base/rtc_certificate.cc @@ -13,7 +13,6 @@ #include #include "rtc_base/checks.h" -#include "rtc_base/ref_counted_object.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_identity.h" #include "rtc_base/time_utils.h" @@ -22,14 +21,14 @@ namespace rtc { scoped_refptr RTCCertificate::Create( std::unique_ptr identity) { - return new RefCountedObject(identity.release()); + return new RTCCertificate(identity.release()); } RTCCertificate::RTCCertificate(SSLIdentity* identity) : identity_(identity) { RTC_DCHECK(identity_); } -RTCCertificate::~RTCCertificate() {} +RTCCertificate::~RTCCertificate() = default; uint64_t RTCCertificate::Expires() const { int64_t expires = GetSSLCertificate().CertificateExpirationTime(); @@ -47,11 +46,6 @@ const SSLCertificate& RTCCertificate::GetSSLCertificate() const { return identity_->certificate(); } -// Deprecated: TODO(benwright) - Remove once chromium is updated. -const SSLCertificate& RTCCertificate::ssl_certificate() const { - return identity_->certificate(); -} - const SSLCertChain& RTCCertificate::GetSSLCertificateChain() const { return identity_->cert_chain(); } @@ -67,7 +61,7 @@ scoped_refptr RTCCertificate::FromPEM( SSLIdentity::CreateFromPEMStrings(pem.private_key(), pem.certificate())); if (!identity) return nullptr; - return new RefCountedObject(identity.release()); + return new RTCCertificate(identity.release()); } bool RTCCertificate::operator==(const RTCCertificate& certificate) const { diff --git a/rtc_base/rtc_certificate.h b/rtc_base/rtc_certificate.h index 102385e5a2..fa026ec331 100644 --- a/rtc_base/rtc_certificate.h +++ b/rtc_base/rtc_certificate.h @@ -16,8 +16,9 @@ #include #include +#include "absl/base/attributes.h" +#include "api/ref_counted_base.h" #include "api/scoped_refptr.h" -#include "rtc_base/ref_count.h" #include "rtc_base/system/rtc_export.h" namespace rtc { @@ -49,7 +50,8 @@ class RTCCertificatePEM { // A thin abstraction layer between "lower level crypto stuff" like // SSLCertificate and WebRTC usage. Takes ownership of some lower level objects, // reference counting protects these from premature destruction. -class RTC_EXPORT RTCCertificate : public RefCountInterface { +class RTC_EXPORT RTCCertificate final + : public RefCountedNonVirtual { public: // Takes ownership of |identity|. static scoped_refptr Create( @@ -64,9 +66,6 @@ class RTC_EXPORT RTCCertificate : public RefCountInterface { const SSLCertificate& GetSSLCertificate() const; const SSLCertChain& GetSSLCertificateChain() const; - // Deprecated: TODO(benwright) - Remove once chromium is updated. - const SSLCertificate& ssl_certificate() const; - // TODO(hbos): If possible, remove once RTCCertificate and its // GetSSLCertificate() is used in all relevant places. Should not pass around // raw SSLIdentity* for the sake of accessing SSLIdentity::certificate(). 
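The make_ref_counted() and Ref<> helpers above pick the reference-counting wrapper at compile time, so call sites no longer spell out RefCountedObject or FinalRefCountedObject. A usage sketch based on the documented behaviour (the Logger class is illustrative and does not derive from RefCountInterface):

#include "api/scoped_refptr.h"
#include "rtc_base/checks.h"
#include "rtc_base/ref_counted_object.h"

// Sketch only: make_ref_counted() selects FinalRefCountedObject<Logger>
// because Logger is not convertible to RefCountInterface.
class Logger {
 public:
  explicit Logger(int id) : id_(id) {}
  int id() const { return id_; }

 private:
  int id_;
};

void RefCountingExample() {
  auto logger = rtc::make_ref_counted<Logger>(42);

  // Ref<Logger>::Ptr names the same scoped_refptr type without spelling out
  // the FinalRefCountedObject wrapper.
  rtc::Ref<Logger>::Ptr alias = logger;
  RTC_DCHECK_EQ(alias->id(), 42);
}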
@@ -82,12 +81,14 @@ class RTC_EXPORT RTCCertificate : public RefCountInterface { protected: explicit RTCCertificate(SSLIdentity* identity); - ~RTCCertificate() override; + + friend class RefCountedNonVirtual; + ~RTCCertificate(); private: // The SSLIdentity is the owner of the SSLCertificate. To protect our // GetSSLCertificate() we take ownership of |identity_|. - std::unique_ptr identity_; + const std::unique_ptr identity_; }; } // namespace rtc diff --git a/rtc_base/socket.h b/rtc_base/socket.h index c2d1e3d29a..6b3ad5e9f2 100644 --- a/rtc_base/socket.h +++ b/rtc_base/socket.h @@ -59,6 +59,8 @@ #define ECONNREFUSED WSAECONNREFUSED #undef EHOSTUNREACH #define EHOSTUNREACH WSAEHOSTUNREACH +#undef ENETUNREACH +#define ENETUNREACH WSAENETUNREACH #define SOCKET_EACCES WSAEACCES #endif // WEBRTC_WIN diff --git a/rtc_base/socket_address.cc b/rtc_base/socket_address.cc index 639be52c54..2996ede9d2 100644 --- a/rtc_base/socket_address.cc +++ b/rtc_base/socket_address.cc @@ -178,6 +178,16 @@ std::string SocketAddress::ToSensitiveString() const { return sb.str(); } +std::string SocketAddress::ToResolvedSensitiveString() const { + if (IsUnresolvedIP()) { + return ""; + } + char buf[1024]; + rtc::SimpleStringBuilder sb(buf); + sb << ipaddr().ToSensitiveString() << ":" << port(); + return sb.str(); +} + bool SocketAddress::FromString(const std::string& str) { if (str.at(0) == '[') { std::string::size_type closebracket = str.rfind(']'); diff --git a/rtc_base/socket_address.h b/rtc_base/socket_address.h index f459407f54..570a71281e 100644 --- a/rtc_base/socket_address.h +++ b/rtc_base/socket_address.h @@ -124,6 +124,10 @@ class RTC_EXPORT SocketAddress { // Same as ToString but anonymizes it by hiding the last part. std::string ToSensitiveString() const; + // Returns hostname:port string if address is resolved, otherwise returns + // empty string. + std::string ToResolvedSensitiveString() const; + // Parses hostname:port and [hostname]:port. 
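The new ToResolvedSensitiveString() above appears aimed at logging: it anonymizes the address like ToSensitiveString(), but yields an empty string while the address is still unresolved. A small usage sketch (function name illustrative):

#include <string>

#include "rtc_base/logging.h"
#include "rtc_base/socket_address.h"

// Sketch only: log a peer address only once it has been resolved.
void LogPeer(const rtc::SocketAddress& remote) {
  std::string printable = remote.ToResolvedSensitiveString();
  if (printable.empty()) {
    RTC_LOG(LS_VERBOSE) << "Remote address not resolved yet";
  } else {
    RTC_LOG(LS_VERBOSE) << "Remote: " << printable;  // e.g. "1.2.3.x:5678"
  }
}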
bool FromString(const std::string& str); diff --git a/rtc_base/socket_address_unittest.cc b/rtc_base/socket_address_unittest.cc index 14da8cb519..d1c911abff 100644 --- a/rtc_base/socket_address_unittest.cc +++ b/rtc_base/socket_address_unittest.cc @@ -323,25 +323,15 @@ TEST(SocketAddressTest, TestToSensitiveString) { EXPECT_EQ("1.2.3.4", addr_v4.HostAsURIString()); EXPECT_EQ("1.2.3.4:5678", addr_v4.ToString()); -#if defined(NDEBUG) EXPECT_EQ("1.2.3.x", addr_v4.HostAsSensitiveURIString()); EXPECT_EQ("1.2.3.x:5678", addr_v4.ToSensitiveString()); -#else - EXPECT_EQ("1.2.3.4", addr_v4.HostAsSensitiveURIString()); - EXPECT_EQ("1.2.3.4:5678", addr_v4.ToSensitiveString()); -#endif // defined(NDEBUG) SocketAddress addr_v6(kTestV6AddrString, 5678); EXPECT_EQ("[" + kTestV6AddrString + "]", addr_v6.HostAsURIString()); EXPECT_EQ(kTestV6AddrFullString, addr_v6.ToString()); -#if defined(NDEBUG) EXPECT_EQ("[" + kTestV6AddrAnonymizedString + "]", addr_v6.HostAsSensitiveURIString()); EXPECT_EQ(kTestV6AddrFullAnonymizedString, addr_v6.ToSensitiveString()); -#else - EXPECT_EQ("[" + kTestV6AddrString + "]", addr_v6.HostAsSensitiveURIString()); - EXPECT_EQ(kTestV6AddrFullString, addr_v6.ToSensitiveString()); -#endif // defined(NDEBUG) } } // namespace rtc diff --git a/rtc_base/socket_server.h b/rtc_base/socket_server.h index 98971e4d84..face04dbc2 100644 --- a/rtc_base/socket_server.h +++ b/rtc_base/socket_server.h @@ -33,9 +33,10 @@ class SocketServer : public SocketFactory { static const int kForever = -1; static std::unique_ptr CreateDefault(); - // When the socket server is installed into a Thread, this function is - // called to allow the socket server to use the thread's message queue for - // any messaging that it might need to perform. + // When the socket server is installed into a Thread, this function is called + // to allow the socket server to use the thread's message queue for any + // messaging that it might need to perform. It is also called with a null + // argument before the thread is destroyed. 
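The expanded SetMessageQueue() comment above documents that the hook is also invoked with a null pointer before the owning thread is destroyed, so implementations should not assume the stored Thread stays valid. A minimal sketch of a subclass honouring that contract (class name hypothetical, built on NullSocketServer for brevity):

#include "rtc_base/null_socket_server.h"
#include "rtc_base/thread.h"

// Sketch only: remember the owning thread, and drop it again on nullptr.
class ThreadAwareSocketServer : public rtc::NullSocketServer {
 public:
  void SetMessageQueue(rtc::Thread* queue) override {
    // |queue| becomes nullptr again shortly before the Thread is destroyed.
    thread_ = queue;
  }

 private:
  rtc::Thread* thread_ = nullptr;
};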
virtual void SetMessageQueue(Thread* queue) {} // Sleeps until: diff --git a/rtc_base/ssl_fingerprint.cc b/rtc_base/ssl_fingerprint.cc index 5b261e0f53..358402eb03 100644 --- a/rtc_base/ssl_fingerprint.cc +++ b/rtc_base/ssl_fingerprint.cc @@ -103,9 +103,6 @@ SSLFingerprint::SSLFingerprint(const std::string& algorithm, size_t digest_len) : SSLFingerprint(algorithm, MakeArrayView(digest_in, digest_len)) {} -SSLFingerprint::SSLFingerprint(const SSLFingerprint& from) - : algorithm(from.algorithm), digest(from.digest) {} - bool SSLFingerprint::operator==(const SSLFingerprint& other) const { return algorithm == other.algorithm && digest == other.digest; } diff --git a/rtc_base/ssl_fingerprint.h b/rtc_base/ssl_fingerprint.h index d65d665d83..add3ab7911 100644 --- a/rtc_base/ssl_fingerprint.h +++ b/rtc_base/ssl_fingerprint.h @@ -57,7 +57,8 @@ struct RTC_EXPORT SSLFingerprint { const uint8_t* digest_in, size_t digest_len); - SSLFingerprint(const SSLFingerprint& from); + SSLFingerprint(const SSLFingerprint& from) = default; + SSLFingerprint& operator=(const SSLFingerprint& from) = default; bool operator==(const SSLFingerprint& other) const; diff --git a/rtc_base/ssl_identity.h b/rtc_base/ssl_identity.h index d078b045a7..a9167ef5eb 100644 --- a/rtc_base/ssl_identity.h +++ b/rtc_base/ssl_identity.h @@ -18,7 +18,6 @@ #include #include -#include "rtc_base/deprecation.h" #include "rtc_base/system/rtc_export.h" namespace rtc { diff --git a/rtc_base/ssl_stream_adapter.cc b/rtc_base/ssl_stream_adapter.cc index 354622e6f0..5730af63d8 100644 --- a/rtc_base/ssl_stream_adapter.cc +++ b/rtc_base/ssl_stream_adapter.cc @@ -95,11 +95,6 @@ std::unique_ptr SSLStreamAdapter::Create( return std::make_unique(std::move(stream)); } -SSLStreamAdapter::SSLStreamAdapter(std::unique_ptr stream) - : StreamAdapterInterface(stream.release()) {} - -SSLStreamAdapter::~SSLStreamAdapter() {} - bool SSLStreamAdapter::GetSslCipherSuite(int* cipher_suite) { return false; } diff --git a/rtc_base/ssl_stream_adapter.h b/rtc_base/ssl_stream_adapter.h index 7bff726510..6b44c76455 100644 --- a/rtc_base/ssl_stream_adapter.h +++ b/rtc_base/ssl_stream_adapter.h @@ -18,7 +18,6 @@ #include #include "absl/memory/memory.h" -#include "rtc_base/deprecation.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_identity.h" #include "rtc_base/stream.h" @@ -119,7 +118,7 @@ enum { SSE_MSG_TRUNC = 0xff0001 }; // Used to send back UMA histogram value. Logged when Dtls handshake fails. enum class SSLHandshakeError { UNKNOWN, INCOMPATIBLE_CIPHERSUITE, MAX_VALUE }; -class SSLStreamAdapter : public StreamAdapterInterface { +class SSLStreamAdapter : public StreamInterface, public sigslot::has_slots<> { public: // Instantiate an SSLStreamAdapter wrapping the given stream, // (using the selected implementation for the platform). @@ -127,8 +126,8 @@ class SSLStreamAdapter : public StreamAdapterInterface { static std::unique_ptr Create( std::unique_ptr stream); - explicit SSLStreamAdapter(std::unique_ptr stream); - ~SSLStreamAdapter() override; + SSLStreamAdapter() = default; + ~SSLStreamAdapter() override = default; // Specify our SSL identity: key and certificate. SSLStream takes ownership // of the SSLIdentity object and will free it when appropriate. 
Should be diff --git a/rtc_base/stream.cc b/rtc_base/stream.cc index ee72f8d2b8..30c767888c 100644 --- a/rtc_base/stream.cc +++ b/rtc_base/stream.cc @@ -49,68 +49,4 @@ bool StreamInterface::Flush() { StreamInterface::StreamInterface() {} -/////////////////////////////////////////////////////////////////////////////// -// StreamAdapterInterface -/////////////////////////////////////////////////////////////////////////////// - -StreamAdapterInterface::StreamAdapterInterface(StreamInterface* stream, - bool owned) - : stream_(stream), owned_(owned) { - if (nullptr != stream_) - stream_->SignalEvent.connect(this, &StreamAdapterInterface::OnEvent); -} - -StreamState StreamAdapterInterface::GetState() const { - return stream_->GetState(); -} -StreamResult StreamAdapterInterface::Read(void* buffer, - size_t buffer_len, - size_t* read, - int* error) { - return stream_->Read(buffer, buffer_len, read, error); -} -StreamResult StreamAdapterInterface::Write(const void* data, - size_t data_len, - size_t* written, - int* error) { - return stream_->Write(data, data_len, written, error); -} -void StreamAdapterInterface::Close() { - stream_->Close(); -} - -bool StreamAdapterInterface::Flush() { - return stream_->Flush(); -} - -void StreamAdapterInterface::Attach(StreamInterface* stream, bool owned) { - if (nullptr != stream_) - stream_->SignalEvent.disconnect(this); - if (owned_) - delete stream_; - stream_ = stream; - owned_ = owned; - if (nullptr != stream_) - stream_->SignalEvent.connect(this, &StreamAdapterInterface::OnEvent); -} - -StreamInterface* StreamAdapterInterface::Detach() { - if (nullptr != stream_) - stream_->SignalEvent.disconnect(this); - StreamInterface* stream = stream_; - stream_ = nullptr; - return stream; -} - -StreamAdapterInterface::~StreamAdapterInterface() { - if (owned_) - delete stream_; -} - -void StreamAdapterInterface::OnEvent(StreamInterface* stream, - int events, - int err) { - SignalEvent(this, events, err); -} - } // namespace rtc diff --git a/rtc_base/stream.h b/rtc_base/stream.h index 9bf11a2405..70de65a75d 100644 --- a/rtc_base/stream.h +++ b/rtc_base/stream.h @@ -115,50 +115,6 @@ class RTC_EXPORT StreamInterface { RTC_DISALLOW_COPY_AND_ASSIGN(StreamInterface); }; -/////////////////////////////////////////////////////////////////////////////// -// StreamAdapterInterface is a convenient base-class for adapting a stream. -// By default, all operations are pass-through. Override the methods that you -// require adaptation. Streams should really be upgraded to reference-counted. -// In the meantime, use the owned flag to indicate whether the adapter should -// own the adapted stream. -/////////////////////////////////////////////////////////////////////////////// - -class StreamAdapterInterface : public StreamInterface, - public sigslot::has_slots<> { - public: - explicit StreamAdapterInterface(StreamInterface* stream, bool owned = true); - - // Core Stream Interface - StreamState GetState() const override; - StreamResult Read(void* buffer, - size_t buffer_len, - size_t* read, - int* error) override; - StreamResult Write(const void* data, - size_t data_len, - size_t* written, - int* error) override; - void Close() override; - - bool Flush() override; - - void Attach(StreamInterface* stream, bool owned = true); - StreamInterface* Detach(); - - protected: - ~StreamAdapterInterface() override; - - // Note that the adapter presents itself as the origin of the stream events, - // since users of the adapter may not recognize the adapted object. 
- virtual void OnEvent(StreamInterface* stream, int events, int err); - StreamInterface* stream() { return stream_; } - - private: - StreamInterface* stream_; - bool owned_; - RTC_DISALLOW_COPY_AND_ASSIGN(StreamAdapterInterface); -}; - } // namespace rtc #endif // RTC_BASE_STREAM_H_ diff --git a/rtc_base/string_utils.h b/rtc_base/string_utils.h index 23c55cb893..d844e5e125 100644 --- a/rtc_base/string_utils.h +++ b/rtc_base/string_utils.h @@ -88,6 +88,43 @@ std::string string_trim(const std::string& s); // TODO(jonasolsson): replace with absl::Hex when that becomes available. std::string ToHex(const int i); +// CompileTimeString comprises of a string-like object which can be used as a +// regular const char* in compile time and supports concatenation. Useful for +// concatenating constexpr strings in for example macro declarations. +namespace rtc_base_string_utils_internal { +template +struct CompileTimeString { + char string[NPlus1] = {0}; + constexpr CompileTimeString() = default; + template + explicit constexpr CompileTimeString(const char (&chars)[MPlus1]) { + char* chars_pointer = string; + for (auto c : chars) + *chars_pointer++ = c; + } + template + constexpr auto Concat(CompileTimeString b) { + CompileTimeString result; + char* chars_pointer = result.string; + for (auto c : string) + *chars_pointer++ = c; + chars_pointer = result.string + NPlus1 - 1; + for (auto c : b.string) + *chars_pointer++ = c; + result.string[NPlus1 + MPlus1 - 2] = 0; + return result; + } + constexpr operator const char*() { return string; } +}; +} // namespace rtc_base_string_utils_internal + +// Makes a constexpr CompileTimeString without having to specify X +// explicitly. +template +constexpr auto MakeCompileTimeString(const char (&a)[N]) { + return rtc_base_string_utils_internal::CompileTimeString(a); +} + } // namespace rtc #endif // RTC_BASE_STRING_UTILS_H_ diff --git a/rtc_base/string_utils_unittest.cc b/rtc_base/string_utils_unittest.cc index 2fa1f220ac..120f7e60f5 100644 --- a/rtc_base/string_utils_unittest.cc +++ b/rtc_base/string_utils_unittest.cc @@ -39,4 +39,29 @@ TEST(string_toutf, Empty) { #endif // WEBRTC_WIN +TEST(CompileTimeString, MakeActsLikeAString) { + EXPECT_STREQ(MakeCompileTimeString("abc123"), "abc123"); +} + +TEST(CompileTimeString, ConvertibleToStdString) { + EXPECT_EQ(std::string(MakeCompileTimeString("abab")), "abab"); +} + +namespace detail { +constexpr bool StringEquals(const char* a, const char* b) { + while (*a && *a == *b) + a++, b++; + return *a == *b; +} +} // namespace detail + +static_assert(detail::StringEquals(MakeCompileTimeString("handellm"), + "handellm"), + "String should initialize."); + +static_assert(detail::StringEquals(MakeCompileTimeString("abc123").Concat( + MakeCompileTimeString("def456ghi")), + "abc123def456ghi"), + "Strings should concatenate."); + } // namespace rtc diff --git a/rtc_base/strings/json.cc b/rtc_base/strings/json.cc index 8a544a0c0d..99664404cf 100644 --- a/rtc_base/strings/json.cc +++ b/rtc_base/strings/json.cc @@ -286,9 +286,9 @@ bool GetDoubleFromJsonObject(const Json::Value& in, } std::string JsonValueToString(const Json::Value& json) { - Json::FastWriter w; - std::string value = w.write(json); - return value.substr(0, value.size() - 1); // trim trailing newline + Json::StreamWriterBuilder builder; + std::string output = Json::writeString(builder, json); + return output.substr(0, output.size() - 1); // trim trailing newline } } // namespace rtc diff --git a/rtc_base/swap_queue.h b/rtc_base/swap_queue.h index 9eac49a933..3c8149c163 
100644 --- a/rtc_base/swap_queue.h +++ b/rtc_base/swap_queue.h @@ -17,8 +17,8 @@ #include #include +#include "absl/base/attributes.h" #include "rtc_base/checks.h" -#include "rtc_base/system/unused.h" namespace webrtc { @@ -127,7 +127,7 @@ class SwapQueue { // When specified, the T given in *input must pass the ItemVerifier() test. // The contents of *input after the call are then also guaranteed to pass the // ItemVerifier() test. - bool Insert(T* input) RTC_WARN_UNUSED_RESULT { + ABSL_MUST_USE_RESULT bool Insert(T* input) { RTC_DCHECK(input); RTC_DCHECK(queue_item_verifier_(*input)); @@ -168,7 +168,7 @@ class SwapQueue { // empty). When specified, The T given in *output must pass the ItemVerifier() // test and the contents of *output after the call are then also guaranteed to // pass the ItemVerifier() test. - bool Remove(T* output) RTC_WARN_UNUSED_RESULT { + ABSL_MUST_USE_RESULT bool Remove(T* output) { RTC_DCHECK(output); RTC_DCHECK(queue_item_verifier_(*output)); diff --git a/rtc_base/synchronization/BUILD.gn b/rtc_base/synchronization/BUILD.gn index 618e224a5d..3cddc55c72 100644 --- a/rtc_base/synchronization/BUILD.gn +++ b/rtc_base/synchronization/BUILD.gn @@ -6,6 +6,7 @@ # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. +import("//third_party/google_benchmark/buildconfig.gni") import("../../webrtc.gni") if (is_android) { import("//build/config/android/config.gni") @@ -26,6 +27,7 @@ rtc_library("mutex") { "mutex.h", "mutex_critical_section.h", "mutex_pthread.h", + "mutex_race_check.h", ] if (rtc_use_absl_mutex) { sources += [ "mutex_abseil.h" ] @@ -44,15 +46,15 @@ rtc_library("mutex") { } } -rtc_library("sequence_checker") { +rtc_library("sequence_checker_internal") { + visibility = [ "../../api:sequence_checker" ] sources = [ - "sequence_checker.cc", - "sequence_checker.h", + "sequence_checker_internal.cc", + "sequence_checker_internal.h", ] deps = [ ":mutex", "..:checks", - "..:criticalsection", "..:macromagic", "..:platform_thread_types", "..:stringutils", @@ -74,47 +76,35 @@ rtc_library("yield_policy") { } if (rtc_include_tests) { - rtc_library("synchronization_unittests") { - testonly = true - sources = [ - "mutex_unittest.cc", - "yield_policy_unittest.cc", - ] - deps = [ - ":mutex", - ":yield", - ":yield_policy", - "..:checks", - "..:macromagic", - "..:rtc_base", - "..:rtc_event", - "../../test:test_support", - "//third_party/google_benchmark", - ] - } - - rtc_library("mutex_benchmark") { - testonly = true - sources = [ "mutex_benchmark.cc" ] - deps = [ - ":mutex", - "../system:unused", - "//third_party/google_benchmark", - ] - } - - rtc_library("sequence_checker_unittests") { - testonly = true + if (enable_google_benchmarks) { + rtc_library("synchronization_unittests") { + testonly = true + sources = [ + "mutex_unittest.cc", + "yield_policy_unittest.cc", + ] + deps = [ + ":mutex", + ":yield", + ":yield_policy", + "..:checks", + "..:macromagic", + "..:rtc_base", + "..:rtc_event", + "..:threading", + "../../test:test_support", + "//third_party/google_benchmark", + ] + } - sources = [ "sequence_checker_unittest.cc" ] - deps = [ - ":sequence_checker", - "..:checks", - "..:rtc_base_approved", - "..:task_queue_for_test", - "../../api:function_view", - "../../test:test_main", - "../../test:test_support", - ] + rtc_library("mutex_benchmark") { + testonly = true + sources = [ "mutex_benchmark.cc" ] + deps = [ + ":mutex", + "../system:unused", + "//third_party/google_benchmark", + ] + } } } diff --git 
a/rtc_base/synchronization/mutex.h b/rtc_base/synchronization/mutex.h index 620fe74e4a..e1512e96cc 100644 --- a/rtc_base/synchronization/mutex.h +++ b/rtc_base/synchronization/mutex.h @@ -13,12 +13,17 @@ #include +#include "absl/base/attributes.h" #include "absl/base/const_init.h" #include "rtc_base/checks.h" -#include "rtc_base/system/unused.h" #include "rtc_base/thread_annotations.h" -#if defined(WEBRTC_ABSL_MUTEX) +#if defined(WEBRTC_RACE_CHECK_MUTEX) +// To use the race check mutex, define WEBRTC_RACE_CHECK_MUTEX globally. This +// also adds a dependency to absl::Mutex from logging.cc due to concurrent +// invocation of the static logging system. +#include "rtc_base/synchronization/mutex_race_check.h" +#elif defined(WEBRTC_ABSL_MUTEX) #include "rtc_base/synchronization/mutex_abseil.h" // nogncheck #elif defined(WEBRTC_WIN) #include "rtc_base/synchronization/mutex_critical_section.h" @@ -41,7 +46,7 @@ class RTC_LOCKABLE Mutex final { void Lock() RTC_EXCLUSIVE_LOCK_FUNCTION() { impl_.Lock(); } - RTC_WARN_UNUSED_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { + ABSL_MUST_USE_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { return impl_.TryLock(); } void Unlock() RTC_UNLOCK_FUNCTION() { diff --git a/rtc_base/synchronization/mutex_abseil.h b/rtc_base/synchronization/mutex_abseil.h index 4ad1d07eef..9247065ae6 100644 --- a/rtc_base/synchronization/mutex_abseil.h +++ b/rtc_base/synchronization/mutex_abseil.h @@ -11,6 +11,7 @@ #ifndef RTC_BASE_SYNCHRONIZATION_MUTEX_ABSEIL_H_ #define RTC_BASE_SYNCHRONIZATION_MUTEX_ABSEIL_H_ +#include "absl/base/attributes.h" #include "absl/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" @@ -23,7 +24,7 @@ class RTC_LOCKABLE MutexImpl final { MutexImpl& operator=(const MutexImpl&) = delete; void Lock() RTC_EXCLUSIVE_LOCK_FUNCTION() { mutex_.Lock(); } - RTC_WARN_UNUSED_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { + ABSL_MUST_USE_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { return mutex_.TryLock(); } void Unlock() RTC_UNLOCK_FUNCTION() { mutex_.Unlock(); } diff --git a/rtc_base/synchronization/mutex_critical_section.h b/rtc_base/synchronization/mutex_critical_section.h index d206794988..cb3d6a095c 100644 --- a/rtc_base/synchronization/mutex_critical_section.h +++ b/rtc_base/synchronization/mutex_critical_section.h @@ -23,6 +23,7 @@ #include // must come after windows headers. 
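Aside on the annotation change in this hunk: swapping RTC_WARN_UNUSED_RESULT for ABSL_MUST_USE_RESULT means a caller of TryLock() has to consume the result before touching guarded state. A minimal usage sketch, assuming the webrtc::Mutex class from this header together with its MutexLock RAII helper (declared in the same file but not shown in this hunk); the Counter class and value_ member are illustrative only:

#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/thread_annotations.h"

class Counter {
 public:
  void Add(int v) {
    webrtc::MutexLock lock(&mu_);  // Blocking RAII lock.
    value_ += v;
  }
  bool TryAdd(int v) {
    // The ABSL_MUST_USE_RESULT annotation makes silently dropping this
    // return value a compiler warning.
    if (!mu_.TryLock())
      return false;
    value_ += v;
    mu_.Unlock();
    return true;
  }

 private:
  webrtc::Mutex mu_;
  int value_ RTC_GUARDED_BY(mu_) = 0;
};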
// clang-format on +#include "absl/base/attributes.h" #include "rtc_base/thread_annotations.h" namespace webrtc { @@ -37,7 +38,7 @@ class RTC_LOCKABLE MutexImpl final { void Lock() RTC_EXCLUSIVE_LOCK_FUNCTION() { EnterCriticalSection(&critical_section_); } - RTC_WARN_UNUSED_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { + ABSL_MUST_USE_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { return TryEnterCriticalSection(&critical_section_) != FALSE; } void Unlock() RTC_UNLOCK_FUNCTION() { diff --git a/rtc_base/synchronization/mutex_pthread.h b/rtc_base/synchronization/mutex_pthread.h index c9496e72c9..8898ca5348 100644 --- a/rtc_base/synchronization/mutex_pthread.h +++ b/rtc_base/synchronization/mutex_pthread.h @@ -18,6 +18,7 @@ #include #endif +#include "absl/base/attributes.h" #include "rtc_base/thread_annotations.h" namespace webrtc { @@ -39,7 +40,7 @@ class RTC_LOCKABLE MutexImpl final { ~MutexImpl() { pthread_mutex_destroy(&mutex_); } void Lock() RTC_EXCLUSIVE_LOCK_FUNCTION() { pthread_mutex_lock(&mutex_); } - RTC_WARN_UNUSED_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { + ABSL_MUST_USE_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { return pthread_mutex_trylock(&mutex_) == 0; } void Unlock() RTC_UNLOCK_FUNCTION() { pthread_mutex_unlock(&mutex_); } diff --git a/rtc_base/synchronization/mutex_race_check.h b/rtc_base/synchronization/mutex_race_check.h new file mode 100644 index 0000000000..cada6292b5 --- /dev/null +++ b/rtc_base/synchronization/mutex_race_check.h @@ -0,0 +1,65 @@ +/* + * Copyright 2020 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef RTC_BASE_SYNCHRONIZATION_MUTEX_RACE_CHECK_H_ +#define RTC_BASE_SYNCHRONIZATION_MUTEX_RACE_CHECK_H_ + +#include + +#include "absl/base/attributes.h" +#include "rtc_base/checks.h" +#include "rtc_base/system/unused.h" +#include "rtc_base/thread_annotations.h" + +namespace webrtc { + +// This implementation class is useful when a consuming project can guarantee +// that all WebRTC invocation is happening serially. Additionally, the consuming +// project cannot use WebRTC code that spawn threads or task queues. +// +// The class internally check fails on Lock() if it finds the consumer actually +// invokes WebRTC concurrently. +// +// To use the race check mutex, define WEBRTC_RACE_CHECK_MUTEX globally. This +// also adds a dependency to absl::Mutex from logging.cc because even though +// objects are invoked serially, the logging is static and invoked concurrently +// and hence needs protection. 
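To make the contract described above concrete: with WEBRTC_RACE_CHECK_MUTEX defined, the mutex below does not provide mutual exclusion at all; it only detects that the caller's "everything runs serially" guarantee was broken. A hedged sketch of the two outcomes (function names are illustrative, not part of the patch):

#include "rtc_base/synchronization/mutex.h"

webrtc::Mutex g_mutex;  // Backed by the race-check MutexImpl below.

void SerialUse() {
  g_mutex.Lock();    // Succeeds: the atomic flag was free.
  // ... WebRTC work, invoked serially by the embedder ...
  g_mutex.Unlock();  // Flag released; the next serial caller succeeds too.
}

void ConcurrentUse() {
  // If this runs on a second thread while SerialUse() holds the lock,
  // Lock() RTC_CHECK-fails with "mutex locked concurrently." rather than
  // blocking, surfacing the threading violation immediately.
  g_mutex.Lock();
  g_mutex.Unlock();
}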
+class RTC_LOCKABLE MutexImpl final { + public: + MutexImpl() = default; + MutexImpl(const MutexImpl&) = delete; + MutexImpl& operator=(const MutexImpl&) = delete; + + void Lock() RTC_EXCLUSIVE_LOCK_FUNCTION() { + bool was_free = free_.exchange(false, std::memory_order_acquire); + RTC_CHECK(was_free) + << "WEBRTC_RACE_CHECK_MUTEX: mutex locked concurrently."; + } + ABSL_MUST_USE_RESULT bool TryLock() RTC_EXCLUSIVE_TRYLOCK_FUNCTION(true) { + bool was_free = free_.exchange(false, std::memory_order_acquire); + return was_free; + } + void Unlock() RTC_UNLOCK_FUNCTION() { + free_.store(true, std::memory_order_release); + } + + private: + // Release-acquire ordering is used. + // - In the Lock methods we're guaranteeing that reads and writes happening + // after the (Try)Lock don't appear to have happened before the Lock (acquire + // ordering). + // - In the Unlock method we're guaranteeing that reads and writes happening + // before the Unlock don't appear to happen after it (release ordering). + std::atomic free_{true}; +}; + +} // namespace webrtc + +#endif // RTC_BASE_SYNCHRONIZATION_MUTEX_RACE_CHECK_H_ diff --git a/rtc_base/synchronization/sequence_checker.h b/rtc_base/synchronization/sequence_checker.h deleted file mode 100644 index ecf8490cec..0000000000 --- a/rtc_base/synchronization/sequence_checker.h +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ -#ifndef RTC_BASE_SYNCHRONIZATION_SEQUENCE_CHECKER_H_ -#define RTC_BASE_SYNCHRONIZATION_SEQUENCE_CHECKER_H_ - -#include - -#include "api/task_queue/task_queue_base.h" -#include "rtc_base/platform_thread_types.h" -#include "rtc_base/synchronization/mutex.h" -#include "rtc_base/system/rtc_export.h" -#include "rtc_base/thread_annotations.h" - -namespace webrtc { -// Real implementation of SequenceChecker, for use in debug mode, or -// for temporary use in release mode (e.g. to RTC_CHECK on a threading issue -// seen only in the wild). -// -// Note: You should almost always use the SequenceChecker class to get the -// right version for your build configuration. -class RTC_EXPORT SequenceCheckerImpl { - public: - SequenceCheckerImpl(); - ~SequenceCheckerImpl(); - - bool IsCurrent() const; - // Changes the task queue or thread that is checked for in IsCurrent. This can - // be useful when an object may be created on one task queue / thread and then - // used exclusively on another thread. - void Detach(); - - // Returns a string that is formatted to match with the error string printed - // by RTC_CHECK() when a condition is not met. - // This is used in conjunction with the RTC_DCHECK_RUN_ON() macro. - std::string ExpectationToString() const; - - private: - mutable Mutex lock_; - // These are mutable so that IsCurrent can set them. - mutable bool attached_ RTC_GUARDED_BY(lock_); - mutable rtc::PlatformThreadRef valid_thread_ RTC_GUARDED_BY(lock_); - mutable const TaskQueueBase* valid_queue_ RTC_GUARDED_BY(lock_); - mutable const void* valid_system_queue_ RTC_GUARDED_BY(lock_); -}; - -// Do nothing implementation, for use in release mode. 
-// -// Note: You should almost always use the SequenceChecker class to get the -// right version for your build configuration. -class SequenceCheckerDoNothing { - public: - bool IsCurrent() const { return true; } - void Detach() {} -}; - -// SequenceChecker is a helper class used to help verify that some methods -// of a class are called on the same task queue or thread. A -// SequenceChecker is bound to a a task queue if the object is -// created on a task queue, or a thread otherwise. -// -// -// Example: -// class MyClass { -// public: -// void Foo() { -// RTC_DCHECK_RUN_ON(sequence_checker_); -// ... (do stuff) ... -// } -// -// private: -// SequenceChecker sequence_checker_; -// } -// -// In Release mode, IsCurrent will always return true. -#if RTC_DCHECK_IS_ON -class RTC_LOCKABLE SequenceChecker : public SequenceCheckerImpl {}; -#else -class RTC_LOCKABLE SequenceChecker : public SequenceCheckerDoNothing {}; -#endif // RTC_ENABLE_THREAD_CHECKER - -namespace webrtc_seq_check_impl { -// Helper class used by RTC_DCHECK_RUN_ON (see example usage below). -class RTC_SCOPED_LOCKABLE SequenceCheckerScope { - public: - template - explicit SequenceCheckerScope(const ThreadLikeObject* thread_like_object) - RTC_EXCLUSIVE_LOCK_FUNCTION(thread_like_object) {} - SequenceCheckerScope(const SequenceCheckerScope&) = delete; - SequenceCheckerScope& operator=(const SequenceCheckerScope&) = delete; - ~SequenceCheckerScope() RTC_UNLOCK_FUNCTION() {} - - template - static bool IsCurrent(const ThreadLikeObject* thread_like_object) { - return thread_like_object->IsCurrent(); - } -}; -} // namespace webrtc_seq_check_impl -} // namespace webrtc - -// RTC_RUN_ON/RTC_GUARDED_BY/RTC_DCHECK_RUN_ON macros allows to annotate -// variables are accessed from same thread/task queue. -// Using tools designed to check mutexes, it checks at compile time everywhere -// variable is access, there is a run-time dcheck thread/task queue is correct. -// -// class ThreadExample { -// public: -// void NeedVar1() { -// RTC_DCHECK_RUN_ON(network_thread_); -// transport_->Send(); -// } -// -// private: -// rtc::Thread* network_thread_; -// int transport_ RTC_GUARDED_BY(network_thread_); -// }; -// -// class SequenceCheckerExample { -// public: -// int CalledFromPacer() RTC_RUN_ON(pacer_sequence_checker_) { -// return var2_; -// } -// -// void CallMeFromPacer() { -// RTC_DCHECK_RUN_ON(&pacer_sequence_checker_) -// << "Should be called from pacer"; -// CalledFromPacer(); -// } -// -// private: -// int pacer_var_ RTC_GUARDED_BY(pacer_sequence_checker_); -// SequenceChecker pacer_sequence_checker_; -// }; -// -// class TaskQueueExample { -// public: -// class Encoder { -// public: -// rtc::TaskQueue* Queue() { return encoder_queue_; } -// void Encode() { -// RTC_DCHECK_RUN_ON(encoder_queue_); -// DoSomething(var_); -// } -// -// private: -// rtc::TaskQueue* const encoder_queue_; -// Frame var_ RTC_GUARDED_BY(encoder_queue_); -// }; -// -// void Encode() { -// // Will fail at runtime when DCHECK is enabled: -// // encoder_->Encode(); -// // Will work: -// rtc::scoped_refptr encoder = encoder_; -// encoder_->Queue()->PostTask([encoder] { encoder->Encode(); }); -// } -// -// private: -// rtc::scoped_refptr encoder_; -// } - -// Document if a function expected to be called from same thread/task queue. 
-#define RTC_RUN_ON(x) \ - RTC_THREAD_ANNOTATION_ATTRIBUTE__(exclusive_locks_required(x)) - -namespace webrtc { -std::string ExpectationToString(const webrtc::SequenceChecker* checker); - -// Catch-all implementation for types other than explicitly supported above. -template -std::string ExpectationToString(const ThreadLikeObject*) { - return std::string(); -} - -} // namespace webrtc - -#define RTC_DCHECK_RUN_ON(x) \ - webrtc::webrtc_seq_check_impl::SequenceCheckerScope seq_check_scope(x); \ - RTC_DCHECK((x)->IsCurrent()) << webrtc::ExpectationToString(x) - -#endif // RTC_BASE_SYNCHRONIZATION_SEQUENCE_CHECKER_H_ diff --git a/rtc_base/synchronization/sequence_checker.cc b/rtc_base/synchronization/sequence_checker_internal.cc similarity index 92% rename from rtc_base/synchronization/sequence_checker.cc rename to rtc_base/synchronization/sequence_checker_internal.cc index 1de26cf0fe..63badd9538 100644 --- a/rtc_base/synchronization/sequence_checker.cc +++ b/rtc_base/synchronization/sequence_checker_internal.cc @@ -7,15 +7,19 @@ * in the file PATENTS. All contributing project authors may * be found in the AUTHORS file in the root of the source tree. */ -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/synchronization/sequence_checker_internal.h" + +#include #if defined(WEBRTC_MAC) #include #endif +#include "rtc_base/checks.h" #include "rtc_base/strings/string_builder.h" namespace webrtc { +namespace webrtc_sequence_checker_internal { namespace { // On Mac, returns the label of the current dispatch queue; elsewhere, return // null. @@ -29,21 +33,12 @@ const void* GetSystemQueueRef() { } // namespace -std::string ExpectationToString(const webrtc::SequenceChecker* checker) { -#if RTC_DCHECK_IS_ON - return checker->ExpectationToString(); -#endif - return std::string(); -} - SequenceCheckerImpl::SequenceCheckerImpl() : attached_(true), valid_thread_(rtc::CurrentThreadRef()), valid_queue_(TaskQueueBase::Current()), valid_system_queue_(GetSystemQueueRef()) {} -SequenceCheckerImpl::~SequenceCheckerImpl() = default; - bool SequenceCheckerImpl::IsCurrent() const { const TaskQueueBase* const current_queue = TaskQueueBase::Current(); const rtc::PlatformThreadRef current_thread = rtc::CurrentThreadRef(); @@ -109,4 +104,13 @@ std::string SequenceCheckerImpl::ExpectationToString() const { } #endif // RTC_DCHECK_IS_ON +std::string ExpectationToString(const SequenceCheckerImpl* checker) { +#if RTC_DCHECK_IS_ON + return checker->ExpectationToString(); +#else + return std::string(); +#endif +} + +} // namespace webrtc_sequence_checker_internal } // namespace webrtc diff --git a/rtc_base/synchronization/sequence_checker_internal.h b/rtc_base/synchronization/sequence_checker_internal.h new file mode 100644 index 0000000000..f7ac6de125 --- /dev/null +++ b/rtc_base/synchronization/sequence_checker_internal.h @@ -0,0 +1,93 @@ +/* + * Copyright 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#ifndef RTC_BASE_SYNCHRONIZATION_SEQUENCE_CHECKER_INTERNAL_H_ +#define RTC_BASE_SYNCHRONIZATION_SEQUENCE_CHECKER_INTERNAL_H_ + +#include +#include + +#include "api/task_queue/task_queue_base.h" +#include "rtc_base/platform_thread_types.h" +#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/system/rtc_export.h" +#include "rtc_base/thread_annotations.h" + +namespace webrtc { +namespace webrtc_sequence_checker_internal { + +// Real implementation of SequenceChecker, for use in debug mode, or +// for temporary use in release mode (e.g. to RTC_CHECK on a threading issue +// seen only in the wild). +// +// Note: You should almost always use the SequenceChecker class to get the +// right version for your build configuration. +class RTC_EXPORT SequenceCheckerImpl { + public: + SequenceCheckerImpl(); + ~SequenceCheckerImpl() = default; + + bool IsCurrent() const; + // Changes the task queue or thread that is checked for in IsCurrent. This can + // be useful when an object may be created on one task queue / thread and then + // used exclusively on another thread. + void Detach(); + + // Returns a string that is formatted to match with the error string printed + // by RTC_CHECK() when a condition is not met. + // This is used in conjunction with the RTC_DCHECK_RUN_ON() macro. + std::string ExpectationToString() const; + + private: + mutable Mutex lock_; + // These are mutable so that IsCurrent can set them. + mutable bool attached_ RTC_GUARDED_BY(lock_); + mutable rtc::PlatformThreadRef valid_thread_ RTC_GUARDED_BY(lock_); + mutable const TaskQueueBase* valid_queue_ RTC_GUARDED_BY(lock_); + mutable const void* valid_system_queue_ RTC_GUARDED_BY(lock_); +}; + +// Do nothing implementation, for use in release mode. +// +// Note: You should almost always use the SequenceChecker class to get the +// right version for your build configuration. +class SequenceCheckerDoNothing { + public: + bool IsCurrent() const { return true; } + void Detach() {} +}; + +// Helper class used by RTC_DCHECK_RUN_ON (see example usage below). +class RTC_SCOPED_LOCKABLE SequenceCheckerScope { + public: + template + explicit SequenceCheckerScope(const ThreadLikeObject* thread_like_object) + RTC_EXCLUSIVE_LOCK_FUNCTION(thread_like_object) {} + SequenceCheckerScope(const SequenceCheckerScope&) = delete; + SequenceCheckerScope& operator=(const SequenceCheckerScope&) = delete; + ~SequenceCheckerScope() RTC_UNLOCK_FUNCTION() {} + + template + static bool IsCurrent(const ThreadLikeObject* thread_like_object) { + return thread_like_object->IsCurrent(); + } +}; + +std::string ExpectationToString(const SequenceCheckerImpl* checker); + +// Catch-all implementation for types other than explicitly supported above. 
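The internal header above only carries the implementation pieces; per the BUILD.gn visibility change earlier in this patch, the public class is exposed through the api:sequence_checker target (api/sequence_checker.h). The intended call pattern is unchanged from the removed rtc_base header, roughly as sketched below; Stats, AddSample() and sum_ are illustrative names:

#include "api/sequence_checker.h"
#include "rtc_base/thread_annotations.h"

class Stats {
 public:
  void AddSample(int v) {
    // Runtime DCHECK plus compile-time lock analysis via SequenceCheckerScope.
    RTC_DCHECK_RUN_ON(&sequence_checker_);
    sum_ += v;
  }

 private:
  webrtc::SequenceChecker sequence_checker_;
  int sum_ RTC_GUARDED_BY(sequence_checker_) = 0;
};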
+template +std::string ExpectationToString(const ThreadLikeObject*) { + return std::string(); +} + +} // namespace webrtc_sequence_checker_internal +} // namespace webrtc + +#endif // RTC_BASE_SYNCHRONIZATION_SEQUENCE_CHECKER_INTERNAL_H_ diff --git a/rtc_base/system/BUILD.gn b/rtc_base/system/BUILD.gn index 385f2e1d84..c604796e60 100644 --- a/rtc_base/system/BUILD.gn +++ b/rtc_base/system/BUILD.gn @@ -32,6 +32,19 @@ rtc_library("file_wrapper") { ] } +if (rtc_include_tests) { + rtc_library("file_wrapper_unittests") { + testonly = true + sources = [ "file_wrapper_unittest.cc" ] + deps = [ + ":file_wrapper", + "//rtc_base:checks", + "//test:fileutils", + "//test:test_support", + ] + } +} + rtc_source_set("ignore_warnings") { sources = [ "ignore_warnings.h" ] } @@ -57,7 +70,6 @@ rtc_source_set("rtc_export") { rtc_source_set("no_unique_address") { sources = [ "no_unique_address.h" ] - deps = [ "..:sanitizer" ] } if (is_mac || is_ios) { diff --git a/rtc_base/system/file_wrapper.cc b/rtc_base/system/file_wrapper.cc index 2828790e09..3e49315793 100644 --- a/rtc_base/system/file_wrapper.cc +++ b/rtc_base/system/file_wrapper.cc @@ -89,6 +89,22 @@ bool FileWrapper::SeekTo(int64_t position) { return fseek(file_, rtc::checked_cast(position), SEEK_SET) == 0; } +long FileWrapper::FileSize() { + if (file_ == nullptr) + return -1; + long original_position = ftell(file_); + if (original_position < 0) + return -1; + int seek_error = fseek(file_, 0, SEEK_END); + if (seek_error) + return -1; + long file_size = ftell(file_); + seek_error = fseek(file_, original_position, SEEK_SET); + if (seek_error) + return -1; + return file_size; +} + bool FileWrapper::Flush() { RTC_DCHECK(file_); return fflush(file_) == 0; diff --git a/rtc_base/system/file_wrapper.h b/rtc_base/system/file_wrapper.h index 42c463cb15..0b293d9a80 100644 --- a/rtc_base/system/file_wrapper.h +++ b/rtc_base/system/file_wrapper.h @@ -38,7 +38,6 @@ class FileWrapper final { static FileWrapper OpenReadOnly(const std::string& file_name_utf8); static FileWrapper OpenWriteOnly(const char* file_name_utf8, int* error = nullptr); - static FileWrapper OpenWriteOnly(const std::string& file_name_utf8, int* error = nullptr); @@ -87,6 +86,11 @@ class FileWrapper final { // Seek to given position. bool SeekTo(int64_t position); + // Returns the file size or -1 if a size could not be determined. + // (A file size might not exists for non-seekable files or file-like + // objects, for example /dev/tty on unix.) + long FileSize(); + // Returns number of bytes read. Short count indicates EOF or error. size_t Read(void* buf, size_t length); diff --git a/rtc_base/system/file_wrapper_unittest.cc b/rtc_base/system/file_wrapper_unittest.cc new file mode 100644 index 0000000000..980b565c73 --- /dev/null +++ b/rtc_base/system/file_wrapper_unittest.cc @@ -0,0 +1,69 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "rtc_base/system/file_wrapper.h" + +#include "rtc_base/checks.h" +#include "test/gtest.h" +#include "test/testsupport/file_utils.h" + +namespace webrtc { + +TEST(FileWrapper, FileSize) { + auto test_info = ::testing::UnitTest::GetInstance()->current_test_info(); + std::string test_name = + std::string(test_info->test_case_name()) + "_" + test_info->name(); + std::replace(test_name.begin(), test_name.end(), '/', '_'); + const std::string temp_filename = test::OutputPath() + test_name; + + // Write + { + FileWrapper file = FileWrapper::OpenWriteOnly(temp_filename); + ASSERT_TRUE(file.is_open()); + EXPECT_EQ(file.FileSize(), 0); + + EXPECT_TRUE(file.Write("foo", 3)); + EXPECT_EQ(file.FileSize(), 3); + + // FileSize() doesn't change the file size. + EXPECT_EQ(file.FileSize(), 3); + + // FileSize() doesn't move the write position. + EXPECT_TRUE(file.Write("bar", 3)); + EXPECT_EQ(file.FileSize(), 6); + } + + // Read + { + FileWrapper file = FileWrapper::OpenReadOnly(temp_filename); + ASSERT_TRUE(file.is_open()); + EXPECT_EQ(file.FileSize(), 6); + + char buf[10]; + size_t bytes_read = file.Read(buf, 3); + EXPECT_EQ(bytes_read, 3u); + EXPECT_EQ(memcmp(buf, "foo", 3), 0); + + // FileSize() doesn't move the read position. + EXPECT_EQ(file.FileSize(), 6); + + // Attempting to read past the end reads what is available + // and sets the EOF flag. + bytes_read = file.Read(buf, 5); + EXPECT_EQ(bytes_read, 3u); + EXPECT_EQ(memcmp(buf, "bar", 3), 0); + EXPECT_TRUE(file.ReadEof()); + } + + // Clean up temporary file. + remove(temp_filename.c_str()); +} + +} // namespace webrtc diff --git a/rtc_base/system/no_unique_address.h b/rtc_base/system/no_unique_address.h index eca349c0cc..77e7a99526 100644 --- a/rtc_base/system/no_unique_address.h +++ b/rtc_base/system/no_unique_address.h @@ -11,8 +11,6 @@ #ifndef RTC_BASE_SYSTEM_NO_UNIQUE_ADDRESS_H_ #define RTC_BASE_SYSTEM_NO_UNIQUE_ADDRESS_H_ -#include "rtc_base/sanitizer.h" - // RTC_NO_UNIQUE_ADDRESS is a portable annotation to tell the compiler that // a data member need not have an address distinct from all other non-static // data members of its class. @@ -26,10 +24,7 @@ // should add support for it starting from C++20. Among clang compilers, // clang-cl doesn't support it yet and support is unclear also when the target // platform is iOS. -// -// TODO(bugs.webrtc.org/12218): Re-enable on MSan builds. -#if !RTC_HAS_MSAN && \ - ((defined(__clang__) && !defined(_MSC_VER) && !defined(WEBRTC_IOS)) || \ +#if ((defined(__clang__) && !defined(_MSC_VER) && !defined(WEBRTC_IOS)) || \ __cplusplus > 201703L) // NOLINTNEXTLINE(whitespace/braces) #define RTC_NO_UNIQUE_ADDRESS [[no_unique_address]] diff --git a/rtc_base/system/unused.h b/rtc_base/system/unused.h index a0add4ee29..a5732a7e84 100644 --- a/rtc_base/system/unused.h +++ b/rtc_base/system/unused.h @@ -11,24 +11,9 @@ #ifndef RTC_BASE_SYSTEM_UNUSED_H_ #define RTC_BASE_SYSTEM_UNUSED_H_ -// Annotate a function indicating the caller must examine the return value. -// Use like: -// int foo() RTC_WARN_UNUSED_RESULT; -// To explicitly ignore a result, cast to void. -// TODO(kwiberg): Remove when we can use [[nodiscard]] from C++17. -#if defined(__clang__) -#define RTC_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__)) -#elif defined(__GNUC__) -// gcc has a __warn_unused_result__ attribute, but you can't quiet it by -// casting to void, so we don't use it. -#define RTC_WARN_UNUSED_RESULT -#else -#define RTC_WARN_UNUSED_RESULT -#endif - // Prevent the compiler from warning about an unused variable. 
For example: // int result = DoSomething(); -// assert(result == 17); +// RTC_DCHECK(result == 17); // RTC_UNUSED(result); // Note: In most cases it is better to remove the unused variable rather than // suppressing the compiler warning. diff --git a/rtc_base/system_time.cc b/rtc_base/system_time.cc new file mode 100644 index 0000000000..9efe76e3a6 --- /dev/null +++ b/rtc_base/system_time.cc @@ -0,0 +1,97 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// If WEBRTC_EXCLUDE_SYSTEM_TIME is set, an implementation of +// rtc::SystemTimeNanos() must be provided externally. +#ifndef WEBRTC_EXCLUDE_SYSTEM_TIME + +#include + +#include + +#if defined(WEBRTC_POSIX) +#include +#if defined(WEBRTC_MAC) +#include +#endif +#endif + +#if defined(WEBRTC_WIN) +// clang-format off +// clang formatting would put last, +// which leads to compilation failure. +#include +#include +#include +// clang-format on +#endif + +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/system_time.h" +#include "rtc_base/time_utils.h" + +namespace rtc { + +int64_t SystemTimeNanos() { + int64_t ticks; +#if defined(WEBRTC_MAC) + static mach_timebase_info_data_t timebase; + if (timebase.denom == 0) { + // Get the timebase if this is the first time we run. + // Recommended by Apple's QA1398. + if (mach_timebase_info(&timebase) != KERN_SUCCESS) { + RTC_NOTREACHED(); + } + } + // Use timebase to convert absolute time tick units into nanoseconds. + const auto mul = [](uint64_t a, uint32_t b) -> int64_t { + RTC_DCHECK_NE(b, 0); + RTC_DCHECK_LE(a, std::numeric_limits::max() / b) + << "The multiplication " << a << " * " << b << " overflows"; + return rtc::dchecked_cast(a * b); + }; + ticks = mul(mach_absolute_time(), timebase.numer) / timebase.denom; +#elif defined(WEBRTC_POSIX) + struct timespec ts; + // TODO(deadbeef): Do we need to handle the case when CLOCK_MONOTONIC is not + // supported? + clock_gettime(CLOCK_MONOTONIC, &ts); + ticks = kNumNanosecsPerSec * static_cast(ts.tv_sec) + + static_cast(ts.tv_nsec); +#elif defined(WINUWP) + ticks = WinUwpSystemTimeNanos(); +#elif defined(WEBRTC_WIN) + static volatile LONG last_timegettime = 0; + static volatile int64_t num_wrap_timegettime = 0; + volatile LONG* last_timegettime_ptr = &last_timegettime; + DWORD now = timeGetTime(); + // Atomically update the last gotten time + DWORD old = InterlockedExchange(last_timegettime_ptr, now); + if (now < old) { + // If now is earlier than old, there may have been a race between threads. + // 0x0fffffff ~3.1 days, the code will not take that long to execute + // so it must have been a wrap around. + if (old > 0xf0000000 && now < 0x0fffffff) { + num_wrap_timegettime++; + } + } + ticks = now + (num_wrap_timegettime << 32); + // TODO(deadbeef): Calculate with nanosecond precision. Otherwise, we're + // just wasting a multiply and divide when doing Time() on Windows. + ticks = ticks * kNumNanosecsPerMillisec; +#else +#error Unsupported platform. 
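The #ifndef guarding this file is the extension point named in the comment at its top: a build that defines WEBRTC_EXCLUDE_SYSTEM_TIME must supply its own rtc::SystemTimeNanos() at link time. A minimal external definition could look like the sketch below, using std::chrono::steady_clock as a stand-in for whatever clock the embedding project prefers:

#include <chrono>
#include <cstdint>

#include "rtc_base/system_time.h"

namespace rtc {

// Linked in by the embedder, instead of the definition above, when
// WEBRTC_EXCLUDE_SYSTEM_TIME is defined.
int64_t SystemTimeNanos() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

}  // namespace rtc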
+#endif + return ticks; +} + +} // namespace rtc +#endif // WEBRTC_EXCLUDE_SYSTEM_TIME diff --git a/rtc_base/system_time.h b/rtc_base/system_time.h new file mode 100644 index 0000000000..d86e94adf4 --- /dev/null +++ b/rtc_base/system_time.h @@ -0,0 +1,22 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef RTC_BASE_SYSTEM_TIME_H_ +#define RTC_BASE_SYSTEM_TIME_H_ + +namespace rtc { + +// Returns the actual system time, even if a clock is set for testing. +// Useful for timeouts while using a test clock, or for logging. +int64_t SystemTimeNanos(); + +} // namespace rtc + +#endif // RTC_BASE_SYSTEM_TIME_H_ diff --git a/rtc_base/task_queue_libevent.cc b/rtc_base/task_queue_libevent.cc index 38660cd5a2..909698611e 100644 --- a/rtc_base/task_queue_libevent.cc +++ b/rtc_base/task_queue_libevent.cc @@ -93,16 +93,12 @@ void EventAssign(struct event* ev, rtc::ThreadPriority TaskQueuePriorityToThreadPriority(Priority priority) { switch (priority) { case Priority::HIGH: - return rtc::kRealtimePriority; + return rtc::ThreadPriority::kRealtime; case Priority::LOW: - return rtc::kLowPriority; + return rtc::ThreadPriority::kLow; case Priority::NORMAL: - return rtc::kNormalPriority; - default: - RTC_NOTREACHED(); - break; + return rtc::ThreadPriority::kNormal; } - return rtc::kNormalPriority; } class TaskQueueLibevent final : public TaskQueueBase { @@ -120,7 +116,6 @@ class TaskQueueLibevent final : public TaskQueueBase { ~TaskQueueLibevent() override = default; - static void ThreadMain(void* context); static void OnWakeup(int socket, short flags, void* context); // NOLINT static void RunTimer(int fd, short flags, void* context); // NOLINT @@ -172,8 +167,7 @@ class TaskQueueLibevent::SetTimerTask : public QueuedTask { TaskQueueLibevent::TaskQueueLibevent(absl::string_view queue_name, rtc::ThreadPriority priority) - : event_base_(event_base_new()), - thread_(&TaskQueueLibevent::ThreadMain, this, queue_name, priority) { + : event_base_(event_base_new()) { int fds[2]; RTC_CHECK(pipe(fds) == 0); SetNonBlocking(fds[0]); @@ -184,7 +178,18 @@ TaskQueueLibevent::TaskQueueLibevent(absl::string_view queue_name, EventAssign(&wakeup_event_, event_base_, wakeup_pipe_out_, EV_READ | EV_PERSIST, OnWakeup, this); event_add(&wakeup_event_, 0); - thread_.Start(); + thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { + { + CurrentTaskQueueSetter set_current(this); + while (is_active_) + event_base_loop(event_base_, 0); + } + + for (TimerEvent* timer : pending_timers_) + delete timer; + }, + queue_name, rtc::ThreadAttributes().SetPriority(priority)); } void TaskQueueLibevent::Delete() { @@ -199,7 +204,7 @@ void TaskQueueLibevent::Delete() { nanosleep(&ts, nullptr); } - thread_.Stop(); + thread_.Finalize(); event_del(&wakeup_event_); @@ -252,20 +257,6 @@ void TaskQueueLibevent::PostDelayedTask(std::unique_ptr task, } } -// static -void TaskQueueLibevent::ThreadMain(void* context) { - TaskQueueLibevent* me = static_cast(context); - - { - CurrentTaskQueueSetter set_current(me); - while (me->is_active_) - event_base_loop(me->event_base_, 0); - } - - for (TimerEvent* timer : me->pending_timers_) - delete timer; -} - // static void 
TaskQueueLibevent::OnWakeup(int socket, short flags, // NOLINT diff --git a/rtc_base/task_queue_stdlib.cc b/rtc_base/task_queue_stdlib.cc index 5de634512e..41da285ee7 100644 --- a/rtc_base/task_queue_stdlib.cc +++ b/rtc_base/task_queue_stdlib.cc @@ -36,14 +36,11 @@ rtc::ThreadPriority TaskQueuePriorityToThreadPriority( TaskQueueFactory::Priority priority) { switch (priority) { case TaskQueueFactory::Priority::HIGH: - return rtc::kRealtimePriority; + return rtc::ThreadPriority::kRealtime; case TaskQueueFactory::Priority::LOW: - return rtc::kLowPriority; + return rtc::ThreadPriority::kLow; case TaskQueueFactory::Priority::NORMAL: - return rtc::kNormalPriority; - default: - RTC_NOTREACHED(); - return rtc::kNormalPriority; + return rtc::ThreadPriority::kNormal; } } @@ -78,8 +75,6 @@ class TaskQueueStdlib final : public TaskQueueBase { NextTask GetNextTask(); - static void ThreadMain(void* context); - void ProcessTasks(); void NotifyWake(); @@ -87,16 +82,9 @@ class TaskQueueStdlib final : public TaskQueueBase { // Indicates if the thread has started. rtc::Event started_; - // Indicates if the thread has stopped. - rtc::Event stopped_; - // Signaled whenever a new task is pending. rtc::Event flag_notify_; - // Contains the active worker thread assigned to processing - // tasks (including delayed tasks). - rtc::PlatformThread thread_; - Mutex pending_lock_; // Indicates if the worker thread needs to shutdown now. @@ -119,15 +107,25 @@ class TaskQueueStdlib final : public TaskQueueBase { // std::unique_ptr out of the queue without the presence of a hack. std::map> delayed_queue_ RTC_GUARDED_BY(pending_lock_); + + // Contains the active worker thread assigned to processing + // tasks (including delayed tasks). + // Placing this last ensures the thread doesn't touch uninitialized attributes + // throughout it's lifetime. 
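The comment about declaration order above is really about member initialization order: members are constructed in the order they are declared, and PlatformThread::SpawnJoinable() starts running the lambda immediately, so everything the lambda touches must already be alive. A reduced sketch of the hazard being avoided, assuming the SpawnJoinable() factory used throughout this patch (Worker and started_ are illustrative names):

#include "rtc_base/event.h"
#include "rtc_base/platform_thread.h"

class Worker {
 public:
  Worker()
      : thread_(rtc::PlatformThread::SpawnJoinable(
            [this] { started_.Set(); },  // May run before Worker() returns.
            "worker")) {}

 private:
  rtc::Event started_;
  // Constructed last, so the lambda can never observe a not-yet-constructed
  // started_. Declaring thread_ before started_ would reintroduce exactly the
  // race the comment above warns about.
  rtc::PlatformThread thread_;
};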
+ rtc::PlatformThread thread_; }; TaskQueueStdlib::TaskQueueStdlib(absl::string_view queue_name, rtc::ThreadPriority priority) : started_(/*manual_reset=*/false, /*initially_signaled=*/false), - stopped_(/*manual_reset=*/false, /*initially_signaled=*/false), flag_notify_(/*manual_reset=*/false, /*initially_signaled=*/false), - thread_(&TaskQueueStdlib::ThreadMain, this, queue_name, priority) { - thread_.Start(); + thread_(rtc::PlatformThread::SpawnJoinable( + [this] { + CurrentTaskQueueSetter set_current(this); + ProcessTasks(); + }, + queue_name, + rtc::ThreadAttributes().SetPriority(priority))) { started_.Wait(rtc::Event::kForever); } @@ -141,8 +139,6 @@ void TaskQueueStdlib::Delete() { NotifyWake(); - stopped_.Wait(rtc::Event::kForever); - thread_.Stop(); delete this; } @@ -219,13 +215,6 @@ TaskQueueStdlib::NextTask TaskQueueStdlib::GetNextTask() { return result; } -// static -void TaskQueueStdlib::ThreadMain(void* context) { - TaskQueueStdlib* me = static_cast(context); - CurrentTaskQueueSetter set_current(me); - me->ProcessTasks(); -} - void TaskQueueStdlib::ProcessTasks() { started_.Set(); @@ -250,8 +239,6 @@ void TaskQueueStdlib::ProcessTasks() { else flag_notify_.Wait(task.sleep_time_ms_); } - - stopped_.Set(); } void TaskQueueStdlib::NotifyWake() { diff --git a/rtc_base/task_queue_unittest.cc b/rtc_base/task_queue_unittest.cc index a7148dcdd1..0c79858630 100644 --- a/rtc_base/task_queue_unittest.cc +++ b/rtc_base/task_queue_unittest.cc @@ -21,7 +21,6 @@ #include #include "absl/memory/memory.h" -#include "rtc_base/bind.h" #include "rtc_base/event.h" #include "rtc_base/task_queue_for_test.h" #include "rtc_base/time_utils.h" @@ -67,7 +66,7 @@ TEST(TaskQueueTest, DISABLED_PostDelayedHighRes) { webrtc::TaskQueueForTest queue(kQueueName, TaskQueue::Priority::HIGH); uint32_t start = Time(); - queue.PostDelayedTask(Bind(&CheckCurrent, &event, &queue), 3); + queue.PostDelayedTask([&event, &queue] { CheckCurrent(&event, &queue); }, 3); EXPECT_TRUE(event.Wait(1000)); uint32_t end = TimeMillis(); // These tests are a little relaxed due to how "powerful" our test bots can diff --git a/rtc_base/task_queue_win.cc b/rtc_base/task_queue_win.cc index 5eb3776cea..d797d478f4 100644 --- a/rtc_base/task_queue_win.cc +++ b/rtc_base/task_queue_win.cc @@ -29,16 +29,18 @@ #include #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "api/task_queue/queued_task.h" #include "api/task_queue/task_queue_base.h" #include "rtc_base/arraysize.h" #include "rtc_base/checks.h" +#include "rtc_base/constructor_magic.h" #include "rtc_base/event.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/platform_thread.h" -#include "rtc_base/time_utils.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/time_utils.h" namespace webrtc { namespace { @@ -56,16 +58,12 @@ rtc::ThreadPriority TaskQueuePriorityToThreadPriority( TaskQueueFactory::Priority priority) { switch (priority) { case TaskQueueFactory::Priority::HIGH: - return rtc::kRealtimePriority; + return rtc::ThreadPriority::kRealtime; case TaskQueueFactory::Priority::LOW: - return rtc::kLowPriority; + return rtc::ThreadPriority::kLow; case TaskQueueFactory::Priority::NORMAL: - return rtc::kNormalPriority; - default: - RTC_NOTREACHED(); - break; + return rtc::ThreadPriority::kNormal; } - return rtc::kNormalPriority; } int64_t GetTick() { @@ -167,21 +165,6 @@ class TaskQueueWin : public TaskQueueBase { void RunPendingTasks(); private: - static void ThreadMain(void* context); - - class 
WorkerThread : public rtc::PlatformThread { - public: - WorkerThread(rtc::ThreadRunFunction func, - void* obj, - absl::string_view thread_name, - rtc::ThreadPriority priority) - : PlatformThread(func, obj, thread_name, priority) {} - - bool QueueAPC(PAPCFUNC apc_function, ULONG_PTR data) { - return rtc::PlatformThread::QueueAPC(apc_function, data); - } - }; - void RunThreadMain(); bool ProcessQueuedMessages(); void RunDueTasks(); @@ -204,7 +187,7 @@ class TaskQueueWin : public TaskQueueBase { greater> timer_tasks_; UINT_PTR timer_id_ = 0; - WorkerThread thread_; + rtc::PlatformThread thread_; Mutex pending_lock_; std::queue> pending_ RTC_GUARDED_BY(pending_lock_); @@ -213,10 +196,12 @@ class TaskQueueWin : public TaskQueueBase { TaskQueueWin::TaskQueueWin(absl::string_view queue_name, rtc::ThreadPriority priority) - : thread_(&TaskQueueWin::ThreadMain, this, queue_name, priority), - in_queue_(::CreateEvent(nullptr, true, false, nullptr)) { + : in_queue_(::CreateEvent(nullptr, true, false, nullptr)) { RTC_DCHECK(in_queue_); - thread_.Start(); + thread_ = rtc::PlatformThread::SpawnJoinable( + [this] { RunThreadMain(); }, queue_name, + rtc::ThreadAttributes().SetPriority(priority)); + rtc::Event event(false, false); RTC_CHECK(thread_.QueueAPC(&InitializeQueueThread, reinterpret_cast(&event))); @@ -225,11 +210,13 @@ TaskQueueWin::TaskQueueWin(absl::string_view queue_name, void TaskQueueWin::Delete() { RTC_DCHECK(!IsCurrent()); - while (!::PostThreadMessage(thread_.GetThreadRef(), WM_QUIT, 0, 0)) { + RTC_CHECK(thread_.GetHandle() != absl::nullopt); + while ( + !::PostThreadMessage(GetThreadId(*thread_.GetHandle()), WM_QUIT, 0, 0)) { RTC_CHECK_EQ(ERROR_NOT_ENOUGH_QUOTA, ::GetLastError()); Sleep(1); } - thread_.Stop(); + thread_.Finalize(); ::CloseHandle(in_queue_); delete this; } @@ -252,7 +239,9 @@ void TaskQueueWin::PostDelayedTask(std::unique_ptr task, // and WPARAM is 32bits in 32bit builds. Otherwise, we could pass the // task pointer and timestamp as LPARAM and WPARAM. 
auto* task_info = new DelayedTaskInfo(milliseconds, std::move(task)); - if (!::PostThreadMessage(thread_.GetThreadRef(), WM_QUEUE_DELAYED_TASK, 0, + RTC_CHECK(thread_.GetHandle() != absl::nullopt); + if (!::PostThreadMessage(GetThreadId(*thread_.GetHandle()), + WM_QUEUE_DELAYED_TASK, 0, reinterpret_cast(task_info))) { delete task_info; } @@ -274,11 +263,6 @@ void TaskQueueWin::RunPendingTasks() { } } -// static -void TaskQueueWin::ThreadMain(void* context) { - static_cast(context)->RunThreadMain(); -} - void TaskQueueWin::RunThreadMain() { CurrentTaskQueueSetter set_current(this); HANDLE handles[2] = {*timer_.event_for_wait(), in_queue_}; diff --git a/rtc_base/task_utils/BUILD.gn b/rtc_base/task_utils/BUILD.gn index 018844fe65..64a041908e 100644 --- a/rtc_base/task_utils/BUILD.gn +++ b/rtc_base/task_utils/BUILD.gn @@ -14,15 +14,15 @@ rtc_library("repeating_task") { "repeating_task.h", ] deps = [ + ":pending_task_safety_flag", ":to_queued_task", "..:logging", - "..:thread_checker", "..:timeutils", + "../../api:sequence_checker", "../../api/task_queue", "../../api/units:time_delta", "../../api/units:timestamp", "../../system_wrappers:system_wrappers", - "../synchronization:sequence_checker", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] } @@ -34,10 +34,9 @@ rtc_library("pending_task_safety_flag") { ] deps = [ "..:checks", - "..:refcount", - "..:thread_checker", + "../../api:refcountedbase", "../../api:scoped_refptr", - "../synchronization:sequence_checker", + "../../api:sequence_checker", "../system:no_unique_address", ] } diff --git a/rtc_base/task_utils/pending_task_safety_flag.cc b/rtc_base/task_utils/pending_task_safety_flag.cc index 4be2131f3f..57b3f6ce88 100644 --- a/rtc_base/task_utils/pending_task_safety_flag.cc +++ b/rtc_base/task_utils/pending_task_safety_flag.cc @@ -10,13 +10,27 @@ #include "rtc_base/task_utils/pending_task_safety_flag.h" -#include "rtc_base/ref_counted_object.h" - namespace webrtc { // static rtc::scoped_refptr PendingTaskSafetyFlag::Create() { - return new rtc::RefCountedObject(); + return new PendingTaskSafetyFlag(true); +} + +rtc::scoped_refptr +PendingTaskSafetyFlag::CreateDetached() { + rtc::scoped_refptr safety_flag( + new PendingTaskSafetyFlag(true)); + safety_flag->main_sequence_.Detach(); + return safety_flag; +} + +rtc::scoped_refptr +PendingTaskSafetyFlag::CreateDetachedInactive() { + rtc::scoped_refptr safety_flag( + new PendingTaskSafetyFlag(false)); + safety_flag->main_sequence_.Detach(); + return safety_flag; } void PendingTaskSafetyFlag::SetNotAlive() { @@ -24,6 +38,11 @@ void PendingTaskSafetyFlag::SetNotAlive() { alive_ = false; } +void PendingTaskSafetyFlag::SetAlive() { + RTC_DCHECK_RUN_ON(&main_sequence_); + alive_ = true; +} + bool PendingTaskSafetyFlag::alive() const { RTC_DCHECK_RUN_ON(&main_sequence_); return alive_; diff --git a/rtc_base/task_utils/pending_task_safety_flag.h b/rtc_base/task_utils/pending_task_safety_flag.h index 182db2cbbc..fc1b5bd878 100644 --- a/rtc_base/task_utils/pending_task_safety_flag.h +++ b/rtc_base/task_utils/pending_task_safety_flag.h @@ -11,24 +11,30 @@ #ifndef RTC_BASE_TASK_UTILS_PENDING_TASK_SAFETY_FLAG_H_ #define RTC_BASE_TASK_UTILS_PENDING_TASK_SAFETY_FLAG_H_ +#include "api/ref_counted_base.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/checks.h" -#include "rtc_base/ref_count.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { -// Use this flag to drop pending tasks that have been 
posted to the "main" -// thread/TQ and end up running after the owning instance has been -// deleted. The owning instance signals deletion by calling SetNotAlive() from -// its destructor. -// +// The PendingTaskSafetyFlag and the ScopedTaskSafety are designed to address +// the issue where you have a task to be executed later that has references, +// but cannot guarantee that the referenced object is alive when the task is +// executed. + +// This mechanism can be used with tasks that are created and destroyed +// on a single thread / task queue, and with tasks posted to the same +// thread/task queue, but tasks can be posted from any thread/TQ. + +// Typical usage: // When posting a task, post a copy (capture by-value in a lambda) of the flag -// instance and before performing the work, check the |alive()| state. Abort if +// reference and before performing the work, check the |alive()| state. Abort if // alive() returns |false|: // -// // Running outside of the main thread. +// class ExampleClass { +// .... // my_task_queue_->PostTask(ToQueuedTask( // [safety = pending_task_safety_flag_, this]() { // // Now running on the main thread. @@ -36,39 +42,79 @@ namespace webrtc { // return; // MyMethod(); // })); +// .... +// ~ExampleClass() { +// pending_task_safety_flag_->SetNotAlive(); +// } +// scoped_refptr pending_task_safety_flag_ +// = PendingTaskSafetyFlag::Create(); +// } // -// Or implicitly by letting ToQueuedTask do the checking: +// ToQueuedTask has an overload that makes this check automatic: // -// // Running outside of the main thread. // my_task_queue_->PostTask(ToQueuedTask(pending_task_safety_flag_, // [this]() { MyMethod(); })); // -// Note that checking the state only works on the construction/destruction -// thread of the ReceiveStatisticsProxy instance. -class PendingTaskSafetyFlag : public rtc::RefCountInterface { +class PendingTaskSafetyFlag final + : public rtc::RefCountedNonVirtual { public: static rtc::scoped_refptr Create(); + // Creates a flag, but with its SequenceChecker initially detached. Hence, it + // may be created on a different thread than the flag will be used on. + static rtc::scoped_refptr CreateDetached(); + + // Same as `CreateDetached()` except the initial state of the returned flag + // will be `!alive()`. + static rtc::scoped_refptr CreateDetachedInactive(); + ~PendingTaskSafetyFlag() = default; void SetNotAlive(); + // The SetAlive method is intended to support Start/Stop/Restart usecases. + // When a class has called SetNotAlive on a flag used for posted tasks, and + // decides it wants to post new tasks and have them run, there are two + // reasonable ways to do that: + // + // (i) Use the below SetAlive method. One subtlety is that any task posted + // prior to SetNotAlive, and still in the queue, is resurrected and will + // run. + // + // (ii) Create a fresh flag, and just drop the reference to the old one. This + // avoids the above problem, and ensures that tasks poster prior to + // SetNotAlive stay cancelled. Instead, there's a potential data race on + // the flag pointer itself. Some synchronization is required between the + // thread overwriting the flag pointer, and the threads that want to post + // tasks and therefore read that same pointer. + void SetAlive(); bool alive() const; protected: - PendingTaskSafetyFlag() = default; + explicit PendingTaskSafetyFlag(bool alive) : alive_(alive) {} private: bool alive_ = true; RTC_NO_UNIQUE_ADDRESS SequenceChecker main_sequence_; }; -// Makes using PendingTaskSafetyFlag very simple. 
Automatic PTSF creation -// and signalling of destruction when the ScopedTaskSafety instance goes out -// of scope. -// Should be used by the class that wants tasks dropped after destruction. -// Requirements are that the instance be constructed and destructed on +// The ScopedTaskSafety makes using PendingTaskSafetyFlag very simple. +// It does automatic PTSF creation and signalling of destruction when the +// ScopedTaskSafety instance goes out of scope. +// +// ToQueuedTask has an overload that takes a ScopedTaskSafety too, so there +// is no need to explicitly call the "flag" method. +// +// Example usage: +// +// my_task_queue->PostTask(ToQueuedTask(scoped_task_safety, +// [this]() { +// // task goes here +// } +// +// This should be used by the class that wants tasks dropped after destruction. +// The requirement is that the instance has to be constructed and destructed on // the same thread as the potentially dropped tasks would be running on. -class ScopedTaskSafety { +class ScopedTaskSafety final { public: ScopedTaskSafety() = default; ~ScopedTaskSafety() { flag_->SetNotAlive(); } @@ -81,6 +127,21 @@ class ScopedTaskSafety { PendingTaskSafetyFlag::Create(); }; +// Like ScopedTaskSafety, but allows construction on a different thread than +// where the flag will be used. +class ScopedTaskSafetyDetached final { + public: + ScopedTaskSafetyDetached() = default; + ~ScopedTaskSafetyDetached() { flag_->SetNotAlive(); } + + // Returns a new reference to the safety flag. + rtc::scoped_refptr flag() const { return flag_; } + + private: + rtc::scoped_refptr flag_ = + PendingTaskSafetyFlag::CreateDetached(); +}; + } // namespace webrtc #endif // RTC_BASE_TASK_UTILS_PENDING_TASK_SAFETY_FLAG_H_ diff --git a/rtc_base/task_utils/pending_task_safety_flag_unittest.cc b/rtc_base/task_utils/pending_task_safety_flag_unittest.cc index 6df2fe2ffb..07bbea296e 100644 --- a/rtc_base/task_utils/pending_task_safety_flag_unittest.cc +++ b/rtc_base/task_utils/pending_task_safety_flag_unittest.cc @@ -156,8 +156,27 @@ TEST(PendingTaskSafetyFlagTest, PendingTaskDropped) { blocker.Set(); // Run an empty task on tq1 to flush all the queued tasks. - tq1.SendTask([]() {}, RTC_FROM_HERE); + tq1.WaitForPreviouslyPostedTasks(); ASSERT_FALSE(owner); EXPECT_FALSE(stuff_done); } + +TEST(PendingTaskSafetyFlagTest, PendingTaskNotAliveInitialized) { + TaskQueueForTest tq("PendingTaskNotAliveInitialized"); + + // Create a new flag that initially not `alive`. 
+ auto flag = PendingTaskSafetyFlag::CreateDetachedInactive(); + tq.SendTask([&flag]() { EXPECT_FALSE(flag->alive()); }, RTC_FROM_HERE); + + bool task_1_ran = false; + bool task_2_ran = false; + tq.PostTask(ToQueuedTask(flag, [&task_1_ran]() { task_1_ran = true; })); + tq.PostTask([&flag]() { flag->SetAlive(); }); + tq.PostTask(ToQueuedTask(flag, [&task_2_ran]() { task_2_ran = true; })); + + tq.WaitForPreviouslyPostedTasks(); + EXPECT_FALSE(task_1_ran); + EXPECT_TRUE(task_2_ran); +} + } // namespace webrtc diff --git a/rtc_base/task_utils/repeating_task.cc b/rtc_base/task_utils/repeating_task.cc index 574e6331f1..9636680cb4 100644 --- a/rtc_base/task_utils/repeating_task.cc +++ b/rtc_base/task_utils/repeating_task.cc @@ -12,32 +12,36 @@ #include "absl/memory/memory.h" #include "rtc_base/logging.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/time_utils.h" namespace webrtc { namespace webrtc_repeating_task_impl { -RepeatingTaskBase::RepeatingTaskBase(TaskQueueBase* task_queue, - TimeDelta first_delay, - Clock* clock) +RepeatingTaskBase::RepeatingTaskBase( + TaskQueueBase* task_queue, + TimeDelta first_delay, + Clock* clock, + rtc::scoped_refptr alive_flag) : task_queue_(task_queue), clock_(clock), - next_run_time_(clock_->CurrentTime() + first_delay) {} + next_run_time_(clock_->CurrentTime() + first_delay), + alive_flag_(std::move(alive_flag)) {} RepeatingTaskBase::~RepeatingTaskBase() = default; bool RepeatingTaskBase::Run() { RTC_DCHECK_RUN_ON(task_queue_); // Return true to tell the TaskQueue to destruct this object. - if (next_run_time_.IsPlusInfinity()) + if (!alive_flag_->alive()) return true; TimeDelta delay = RunClosure(); // The closure might have stopped this task, in which case we return true to // destruct this object. 
- if (next_run_time_.IsPlusInfinity()) + if (!alive_flag_->alive()) return true; RTC_DCHECK(delay.IsFinite()); @@ -53,33 +57,11 @@ bool RepeatingTaskBase::Run() { return false; } -void RepeatingTaskBase::Stop() { - RTC_DCHECK_RUN_ON(task_queue_); - RTC_DCHECK(next_run_time_.IsFinite()); - next_run_time_ = Timestamp::PlusInfinity(); -} - } // namespace webrtc_repeating_task_impl -RepeatingTaskHandle::RepeatingTaskHandle(RepeatingTaskHandle&& other) - : repeating_task_(other.repeating_task_) { - other.repeating_task_ = nullptr; -} - -RepeatingTaskHandle& RepeatingTaskHandle::operator=( - RepeatingTaskHandle&& other) { - repeating_task_ = other.repeating_task_; - other.repeating_task_ = nullptr; - return *this; -} - -RepeatingTaskHandle::RepeatingTaskHandle( - webrtc_repeating_task_impl::RepeatingTaskBase* repeating_task) - : repeating_task_(repeating_task) {} - void RepeatingTaskHandle::Stop() { if (repeating_task_) { - repeating_task_->Stop(); + repeating_task_->SetNotAlive(); repeating_task_ = nullptr; } } diff --git a/rtc_base/task_utils/repeating_task.h b/rtc_base/task_utils/repeating_task.h index 487b7d19d4..d5066fdb5c 100644 --- a/rtc_base/task_utils/repeating_task.h +++ b/rtc_base/task_utils/repeating_task.h @@ -19,22 +19,19 @@ #include "api/task_queue/task_queue_base.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "system_wrappers/include/clock.h" namespace webrtc { - -class RepeatingTaskHandle; - namespace webrtc_repeating_task_impl { class RepeatingTaskBase : public QueuedTask { public: RepeatingTaskBase(TaskQueueBase* task_queue, TimeDelta first_delay, - Clock* clock); + Clock* clock, + rtc::scoped_refptr alive_flag); ~RepeatingTaskBase() override; - void Stop(); - private: virtual TimeDelta RunClosure() = 0; @@ -42,9 +39,10 @@ class RepeatingTaskBase : public QueuedTask { TaskQueueBase* const task_queue_; Clock* const clock_; - // This is always finite, except for the special case where it's PlusInfinity - // to signal that the task should stop. + // This is always finite. Timestamp next_run_time_ RTC_GUARDED_BY(task_queue_); + rtc::scoped_refptr alive_flag_ + RTC_GUARDED_BY(task_queue_); }; // The template closure pattern is based on rtc::ClosureTask. 
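For orientation, the handle produced by the Start()/DelayedStart() factories in the hunks that follow is driven the same way as in the unit tests further down: the closure's return value is the delay until the next run, and Stop() now just flips the shared alive flag. Sketch below; DoPeriodicWork() is a hypothetical application hook, not part of this patch:

#include "api/task_queue/task_queue_base.h"
#include "api/units/time_delta.h"
#include "rtc_base/task_utils/repeating_task.h"

void DoPeriodicWork();  // Hypothetical application hook.

void StartPolling(webrtc::TaskQueueBase* task_queue) {
  webrtc::RepeatingTaskHandle handle = webrtc::RepeatingTaskHandle::Start(
      task_queue, [] {
        DoPeriodicWork();
        return webrtc::TimeDelta::Millis(100);  // Delay until the next run.
      });
  // ... later, from the same task queue:
  handle.Stop();  // Only marks the alive flag; the queue then drops the task.
}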
@@ -54,8 +52,12 @@ class RepeatingTaskImpl final : public RepeatingTaskBase { RepeatingTaskImpl(TaskQueueBase* task_queue, TimeDelta first_delay, Closure&& closure, - Clock* clock) - : RepeatingTaskBase(task_queue, first_delay, clock), + Clock* clock, + rtc::scoped_refptr alive_flag) + : RepeatingTaskBase(task_queue, + first_delay, + clock, + std::move(alive_flag)), closure_(std::forward(closure)) { static_assert( std::is_same static RepeatingTaskHandle Start(TaskQueueBase* task_queue, Closure&& closure, Clock* clock = Clock::GetRealTimeClock()) { - auto repeating_task = std::make_unique< - webrtc_repeating_task_impl::RepeatingTaskImpl>( - task_queue, TimeDelta::Zero(), std::forward(closure), clock); - auto* repeating_task_ptr = repeating_task.get(); - task_queue->PostTask(std::move(repeating_task)); - return RepeatingTaskHandle(repeating_task_ptr); + auto alive_flag = PendingTaskSafetyFlag::CreateDetached(); + task_queue->PostTask( + std::make_unique< + webrtc_repeating_task_impl::RepeatingTaskImpl>( + task_queue, TimeDelta::Zero(), std::forward(closure), + clock, alive_flag)); + return RepeatingTaskHandle(std::move(alive_flag)); } // DelayedStart is equivalent to Start except that the first invocation of the @@ -113,12 +114,14 @@ class RepeatingTaskHandle { TimeDelta first_delay, Closure&& closure, Clock* clock = Clock::GetRealTimeClock()) { - auto repeating_task = std::make_unique< - webrtc_repeating_task_impl::RepeatingTaskImpl>( - task_queue, first_delay, std::forward(closure), clock); - auto* repeating_task_ptr = repeating_task.get(); - task_queue->PostDelayedTask(std::move(repeating_task), first_delay.ms()); - return RepeatingTaskHandle(repeating_task_ptr); + auto alive_flag = PendingTaskSafetyFlag::CreateDetached(); + task_queue->PostDelayedTask( + std::make_unique< + webrtc_repeating_task_impl::RepeatingTaskImpl>( + task_queue, first_delay, std::forward(closure), clock, + alive_flag), + first_delay.ms()); + return RepeatingTaskHandle(std::move(alive_flag)); } // Stops future invocations of the repeating task closure. Can only be called @@ -127,15 +130,15 @@ class RepeatingTaskHandle { // closure itself. void Stop(); - // Returns true if Start() or DelayedStart() was called most recently. Returns - // false initially and if Stop() or PostStop() was called most recently. + // Returns true until Stop() was called. + // Can only be called from the TaskQueue where the task is running. bool Running() const; private: explicit RepeatingTaskHandle( - webrtc_repeating_task_impl::RepeatingTaskBase* repeating_task); - // Owned by the task queue. 
- webrtc_repeating_task_impl::RepeatingTaskBase* repeating_task_ = nullptr; + rtc::scoped_refptr alive_flag) + : repeating_task_(std::move(alive_flag)) {} + rtc::scoped_refptr repeating_task_; }; } // namespace webrtc diff --git a/rtc_base/task_utils/repeating_task_unittest.cc b/rtc_base/task_utils/repeating_task_unittest.cc index 2fb15d1e5a..b23284f988 100644 --- a/rtc_base/task_utils/repeating_task_unittest.cc +++ b/rtc_base/task_utils/repeating_task_unittest.cc @@ -276,4 +276,22 @@ TEST(RepeatingTaskTest, ClockIntegration) { handle.Stop(); } +TEST(RepeatingTaskTest, CanBeStoppedAfterTaskQueueDeletedTheRepeatingTask) { + std::unique_ptr repeating_task; + + MockTaskQueue task_queue; + EXPECT_CALL(task_queue, PostDelayedTask) + .WillOnce([&](std::unique_ptr task, uint32_t milliseconds) { + repeating_task = std::move(task); + }); + + RepeatingTaskHandle handle = + RepeatingTaskHandle::DelayedStart(&task_queue, TimeDelta::Millis(100), + [] { return TimeDelta::Millis(100); }); + + // shutdown task queue: delete all pending tasks and run 'regular' task. + repeating_task = nullptr; + handle.Stop(); +} + } // namespace webrtc diff --git a/rtc_base/task_utils/to_queued_task.h b/rtc_base/task_utils/to_queued_task.h index 07ab0ebe26..b2e3aae7ae 100644 --- a/rtc_base/task_utils/to_queued_task.h +++ b/rtc_base/task_utils/to_queued_task.h @@ -20,7 +20,7 @@ namespace webrtc { namespace webrtc_new_closure_impl { -// Simple implementation of QueuedTask for use with rtc::Bind and lambdas. +// Simple implementation of QueuedTask for use with lambdas. template class ClosureTask : public QueuedTask { public: diff --git a/rtc_base/test_utils.h b/rtc_base/test_utils.h index 4746e962ae..7068e73881 100644 --- a/rtc_base/test_utils.h +++ b/rtc_base/test_utils.h @@ -17,25 +17,23 @@ #include #include "rtc_base/async_socket.h" -#include "rtc_base/stream.h" #include "rtc_base/third_party/sigslot/sigslot.h" namespace webrtc { namespace testing { /////////////////////////////////////////////////////////////////////////////// -// StreamSink - Monitor asynchronously signalled events from StreamInterface -// or AsyncSocket (which should probably be a StreamInterface. +// StreamSink - Monitor asynchronously signalled events from AsyncSocket. /////////////////////////////////////////////////////////////////////////////// -// Note: Any event that is an error is treaded as SSE_ERROR instead of that +// Note: Any event that is an error is treated as SSE_ERROR instead of that // event. 
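For reference, the ClosureTask mentioned in the to_queued_task.h hunk above is essentially a one-shot QueuedTask wrapping a move-captured callable. A simplified, self-contained sketch of that shape; QueuedTaskLike, ClosureTaskSketch and WrapClosure are illustrative names, not the real webrtc::QueuedTask/ToQueuedTask API.

#include <memory>
#include <type_traits>
#include <utility>

// Minimal stand-in for the QueuedTask interface.
class QueuedTaskLike {
 public:
  virtual ~QueuedTaskLike() = default;
  // Returning true tells the task queue to delete the task after running it.
  virtual bool Run() = 0;
};

template <typename Closure>
class ClosureTaskSketch : public QueuedTaskLike {
 public:
  explicit ClosureTaskSketch(Closure&& closure)
      : closure_(std::forward<Closure>(closure)) {}
  bool Run() override {
    closure_();
    return true;  // One-shot task: let the queue delete it.
  }

 private:
  typename std::decay<Closure>::type closure_;
};

// Helper mirroring the "wrap any callable into a task" idea.
template <typename Closure>
std::unique_ptr<QueuedTaskLike> WrapClosure(Closure&& closure) {
  return std::make_unique<ClosureTaskSketch<Closure>>(
      std::forward<Closure>(closure));
}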
enum StreamSinkEvent { - SSE_OPEN = rtc::SE_OPEN, - SSE_READ = rtc::SE_READ, - SSE_WRITE = rtc::SE_WRITE, - SSE_CLOSE = rtc::SE_CLOSE, + SSE_OPEN = 1, + SSE_READ = 2, + SSE_WRITE = 4, + SSE_CLOSE = 8, SSE_ERROR = 16 }; @@ -44,24 +42,6 @@ class StreamSink : public sigslot::has_slots<> { StreamSink(); ~StreamSink() override; - void Monitor(rtc::StreamInterface* stream) { - stream->SignalEvent.connect(this, &StreamSink::OnEvent); - events_.erase(stream); - } - void Unmonitor(rtc::StreamInterface* stream) { - stream->SignalEvent.disconnect(this); - // In case you forgot to unmonitor a previous object with this address - events_.erase(stream); - } - bool Check(rtc::StreamInterface* stream, - StreamSinkEvent event, - bool reset = true) { - return DoCheck(stream, event, reset); - } - int Events(rtc::StreamInterface* stream, bool reset = true) { - return DoEvents(stream, reset); - } - void Monitor(rtc::AsyncSocket* socket) { socket->SignalConnectEvent.connect(this, &StreamSink::OnConnectEvent); socket->SignalReadEvent.connect(this, &StreamSink::OnReadEvent); @@ -82,19 +62,10 @@ class StreamSink : public sigslot::has_slots<> { bool reset = true) { return DoCheck(socket, event, reset); } - int Events(rtc::AsyncSocket* socket, bool reset = true) { - return DoEvents(socket, reset); - } private: - typedef std::map EventMap; + typedef std::map EventMap; - void OnEvent(rtc::StreamInterface* stream, int events, int error) { - if (error) { - events = SSE_ERROR; - } - AddEvents(stream, events); - } void OnConnectEvent(rtc::AsyncSocket* socket) { AddEvents(socket, SSE_OPEN); } void OnReadEvent(rtc::AsyncSocket* socket) { AddEvents(socket, SSE_READ); } void OnWriteEvent(rtc::AsyncSocket* socket) { AddEvents(socket, SSE_WRITE); } @@ -102,7 +73,7 @@ class StreamSink : public sigslot::has_slots<> { AddEvents(socket, (0 == error) ? 
SSE_CLOSE : SSE_ERROR); } - void AddEvents(void* obj, int events) { + void AddEvents(rtc::AsyncSocket* obj, int events) { EventMap::iterator it = events_.find(obj); if (events_.end() == it) { events_.insert(EventMap::value_type(obj, events)); @@ -110,7 +81,7 @@ class StreamSink : public sigslot::has_slots<> { it->second |= events; } } - bool DoCheck(void* obj, StreamSinkEvent event, bool reset) { + bool DoCheck(rtc::AsyncSocket* obj, StreamSinkEvent event, bool reset) { EventMap::iterator it = events_.find(obj); if ((events_.end() == it) || (0 == (it->second & event))) { return false; @@ -120,16 +91,6 @@ class StreamSink : public sigslot::has_slots<> { } return true; } - int DoEvents(void* obj, bool reset) { - EventMap::iterator it = events_.find(obj); - if (events_.end() == it) - return 0; - int events = it->second; - if (reset) { - it->second = 0; - } - return events; - } EventMap events_; }; diff --git a/rtc_base/third_party/base64/BUILD.gn b/rtc_base/third_party/base64/BUILD.gn index db03e0273d..969c7c0c64 100644 --- a/rtc_base/third_party/base64/BUILD.gn +++ b/rtc_base/third_party/base64/BUILD.gn @@ -14,5 +14,8 @@ rtc_library("base64") { "base64.cc", "base64.h", ] - deps = [ "../../system:rtc_export" ] + deps = [ + "../..:checks", + "../../system:rtc_export", + ] } diff --git a/rtc_base/third_party/base64/base64.cc b/rtc_base/third_party/base64/base64.cc index 53ff6b9d54..b9acf9a4c9 100644 --- a/rtc_base/third_party/base64/base64.cc +++ b/rtc_base/third_party/base64/base64.cc @@ -19,6 +19,8 @@ #include #include +#include "rtc_base/checks.h" + using std::vector; namespace rtc { @@ -95,7 +97,7 @@ bool Base64::IsBase64Encoded(const std::string& str) { void Base64::EncodeFromArray(const void* data, size_t len, std::string* result) { - assert(nullptr != result); + RTC_DCHECK(result); result->clear(); result->resize(((len + 2) / 3) * 4); const unsigned char* byte_data = static_cast(data); @@ -223,15 +225,15 @@ bool Base64::DecodeFromArrayTemplate(const char* data, DecodeFlags flags, T* result, size_t* data_used) { - assert(nullptr != result); - assert(flags <= (DO_PARSE_MASK | DO_PAD_MASK | DO_TERM_MASK)); + RTC_DCHECK(result); + RTC_DCHECK_LE(flags, (DO_PARSE_MASK | DO_PAD_MASK | DO_TERM_MASK)); const DecodeFlags parse_flags = flags & DO_PARSE_MASK; const DecodeFlags pad_flags = flags & DO_PAD_MASK; const DecodeFlags term_flags = flags & DO_TERM_MASK; - assert(0 != parse_flags); - assert(0 != pad_flags); - assert(0 != term_flags); + RTC_DCHECK_NE(0, parse_flags); + RTC_DCHECK_NE(0, pad_flags); + RTC_DCHECK_NE(0, term_flags); result->clear(); result->reserve(len); diff --git a/rtc_base/thread.cc b/rtc_base/thread.cc index 32449020c5..8ca9ce76a8 100644 --- a/rtc_base/thread.cc +++ b/rtc_base/thread.cc @@ -29,13 +29,14 @@ #include #include "absl/algorithm/container.h" +#include "api/sequence_checker.h" #include "rtc_base/atomic_ops.h" #include "rtc_base/checks.h" #include "rtc_base/deprecated/recursive_critical_section.h" #include "rtc_base/event.h" +#include "rtc_base/internal/default_socket_server.h" #include "rtc_base/logging.h" #include "rtc_base/null_socket_server.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" @@ -70,8 +71,6 @@ class ScopedAutoReleasePool { namespace rtc { namespace { -const int kSlowDispatchLoggingThreshold = 50; // 50 ms - class MessageHandlerWithTask final : public MessageHandler { public: MessageHandlerWithTask() {} @@ -257,7 +256,7 @@ 
Thread* Thread::Current() { #ifndef NO_MAIN_THREAD_WRAPPING // Only autowrap the thread which instantiated the ThreadManager. if (!thread && manager->IsMainThread()) { - thread = new Thread(SocketServer::CreateDefault()); + thread = new Thread(CreateDefaultSocketServer()); thread->WrapCurrentWithThreadManager(manager, true); } #endif @@ -326,7 +325,7 @@ void rtc::ThreadManager::ChangeCurrentThreadForTest(rtc::Thread* thread) { Thread* ThreadManager::WrapCurrentThread() { Thread* result = CurrentThread(); if (nullptr == result) { - result = new Thread(SocketServer::CreateDefault()); + result = new Thread(CreateDefaultSocketServer()); result->WrapCurrentWithThreadManager(this, true); } return result; @@ -353,6 +352,35 @@ Thread::ScopedDisallowBlockingCalls::~ScopedDisallowBlockingCalls() { thread_->SetAllowBlockingCalls(previous_state_); } +#if RTC_DCHECK_IS_ON +Thread::ScopedCountBlockingCalls::ScopedCountBlockingCalls( + std::function callback) + : thread_(Thread::Current()), + base_blocking_call_count_(thread_->GetBlockingCallCount()), + base_could_be_blocking_call_count_( + thread_->GetCouldBeBlockingCallCount()), + result_callback_(std::move(callback)) {} + +Thread::ScopedCountBlockingCalls::~ScopedCountBlockingCalls() { + if (GetTotalBlockedCallCount() >= min_blocking_calls_for_callback_) { + result_callback_(GetBlockingCallCount(), GetCouldBeBlockingCallCount()); + } +} + +uint32_t Thread::ScopedCountBlockingCalls::GetBlockingCallCount() const { + return thread_->GetBlockingCallCount() - base_blocking_call_count_; +} + +uint32_t Thread::ScopedCountBlockingCalls::GetCouldBeBlockingCallCount() const { + return thread_->GetCouldBeBlockingCallCount() - + base_could_be_blocking_call_count_; +} + +uint32_t Thread::ScopedCountBlockingCalls::GetTotalBlockedCallCount() const { + return GetBlockingCallCount() + GetCouldBeBlockingCallCount(); +} +#endif + Thread::Thread(SocketServer* ss) : Thread(ss, /*do_init=*/true) {} Thread::Thread(std::unique_ptr ss) @@ -401,13 +429,11 @@ void Thread::DoDestroy() { // The signal is done from here to ensure // that it always gets called when the queue // is going away. - SignalQueueDestroyed(); - ThreadManager::Remove(this); - ClearInternal(nullptr, MQID_ANY, nullptr); - if (ss_) { ss_->SetMessageQueue(nullptr); } + ThreadManager::Remove(this); + ClearInternal(nullptr, MQID_ANY, nullptr); } SocketServer* Thread::socketserver() { @@ -680,14 +706,18 @@ void Thread::Dispatch(Message* pmsg) { TRACE_EVENT2("webrtc", "Thread::Dispatch", "src_file", pmsg->posted_from.file_name(), "src_func", pmsg->posted_from.function_name()); + RTC_DCHECK_RUN_ON(this); int64_t start_time = TimeMillis(); pmsg->phandler->OnMessage(pmsg); int64_t end_time = TimeMillis(); int64_t diff = TimeDiff(end_time, start_time); - if (diff >= kSlowDispatchLoggingThreshold) { - RTC_LOG(LS_INFO) << "Message took " << diff + if (diff >= dispatch_warning_ms_) { + RTC_LOG(LS_INFO) << "Message to " << name() << " took " << diff << "ms to dispatch. Posted from: " << pmsg->posted_from.ToString(); + // To avoid log spew, move the warning limit to only give warning + // for delays that are larger than the one observed. 
+ dispatch_warning_ms_ = diff + 1; } } @@ -696,7 +726,7 @@ bool Thread::IsCurrent() const { } std::unique_ptr Thread::CreateWithSocketServer() { - return std::unique_ptr(new Thread(SocketServer::CreateDefault())); + return std::unique_ptr(new Thread(CreateDefaultSocketServer())); } std::unique_ptr Thread::Create() { @@ -739,6 +769,16 @@ bool Thread::SetName(const std::string& name, const void* obj) { return true; } +void Thread::SetDispatchWarningMs(int deadline) { + if (!IsCurrent()) { + PostTask(webrtc::ToQueuedTask( + [this, deadline]() { SetDispatchWarningMs(deadline); })); + return; + } + RTC_DCHECK_RUN_ON(this); + dispatch_warning_ms_ = deadline; +} + bool Thread::Start() { RTC_DCHECK(!IsRunning()); @@ -888,6 +928,11 @@ void Thread::Send(const Location& posted_from, msg.message_id = id; msg.pdata = pdata; if (IsCurrent()) { +#if RTC_DCHECK_IS_ON + RTC_DCHECK(this->IsInvokeToThreadAllowed(this)); + RTC_DCHECK_RUN_ON(this); + could_be_blocking_call_count_++; +#endif msg.phandler->OnMessage(&msg); return; } @@ -898,6 +943,8 @@ void Thread::Send(const Location& posted_from, #if RTC_DCHECK_IS_ON if (current_thread) { + RTC_DCHECK_RUN_ON(current_thread); + current_thread->blocking_call_count_++; RTC_DCHECK(current_thread->IsInvokeToThreadAllowed(this)); ThreadManager::Instance()->RegisterSendAndCheckForCycles(current_thread, this); @@ -1021,6 +1068,17 @@ void Thread::DisallowAllInvokes() { #endif } +#if RTC_DCHECK_IS_ON +uint32_t Thread::GetBlockingCallCount() const { + RTC_DCHECK_RUN_ON(this); + return blocking_call_count_; +} +uint32_t Thread::GetCouldBeBlockingCallCount() const { + RTC_DCHECK_RUN_ON(this); + return could_be_blocking_call_count_; +} +#endif + // Returns true if no policies added or if there is at least one policy // that permits invocation to |target| thread. bool Thread::IsInvokeToThreadAllowed(rtc::Thread* target) { @@ -1137,7 +1195,7 @@ MessageHandler* Thread::GetPostTaskMessageHandler() { } AutoThread::AutoThread() - : Thread(SocketServer::CreateDefault(), /*do_init=*/false) { + : Thread(CreateDefaultSocketServer(), /*do_init=*/false) { if (!ThreadManager::Instance()->CurrentThread()) { // DoInit registers with ThreadManager. Do that only if we intend to // be rtc::Thread::Current(), otherwise ProcessAllMessageQueuesInternal will diff --git a/rtc_base/thread.h b/rtc_base/thread.h index ed19e98927..6e68f1a679 100644 --- a/rtc_base/thread.h +++ b/rtc_base/thread.h @@ -42,6 +42,35 @@ #include "rtc_base/win32.h" #endif +#if RTC_DCHECK_IS_ON +// Counts how many blocking Thread::Invoke or Thread::Send calls are made from +// within a scope and logs the number of blocking calls at the end of the scope. +#define RTC_LOG_THREAD_BLOCK_COUNT() \ + rtc::Thread::ScopedCountBlockingCalls blocked_call_count_printer( \ + [func = __func__](uint32_t actual_block, uint32_t could_block) { \ + auto total = actual_block + could_block; \ + if (total) { \ + RTC_LOG(LS_WARNING) << "Blocking " << func << ": total=" << total \ + << " (actual=" << actual_block \ + << ", could=" << could_block << ")"; \ + } \ + }) + +// Adds an RTC_DCHECK_LE that checks that the number of blocking calls are +// less than or equal to a specific value. Use to avoid regressing in the +// number of blocking thread calls. +// Note: Use of this macro, requires RTC_LOG_THREAD_BLOCK_COUNT() to be called +// first. 
+#define RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(x) \ + do { \ + blocked_call_count_printer.set_minimum_call_count_for_callback(x + 1); \ + RTC_DCHECK_LE(blocked_call_count_printer.GetTotalBlockedCallCount(), x); \ + } while (0) +#else +#define RTC_LOG_THREAD_BLOCK_COUNT() +#define RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(x) +#endif + namespace rtc { class Thread; @@ -212,6 +241,39 @@ class RTC_LOCKABLE RTC_EXPORT Thread : public webrtc::TaskQueueBase { const bool previous_state_; }; +#if RTC_DCHECK_IS_ON + class ScopedCountBlockingCalls { + public: + ScopedCountBlockingCalls(std::function callback); + ScopedCountBlockingCalls(const ScopedDisallowBlockingCalls&) = delete; + ScopedCountBlockingCalls& operator=(const ScopedDisallowBlockingCalls&) = + delete; + ~ScopedCountBlockingCalls(); + + uint32_t GetBlockingCallCount() const; + uint32_t GetCouldBeBlockingCallCount() const; + uint32_t GetTotalBlockedCallCount() const; + + void set_minimum_call_count_for_callback(uint32_t minimum) { + min_blocking_calls_for_callback_ = minimum; + } + + private: + Thread* const thread_; + const uint32_t base_blocking_call_count_; + const uint32_t base_could_be_blocking_call_count_; + // The minimum number of blocking calls required in order to issue the + // result_callback_. This is used by RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN to + // tame log spam. + // By default we always issue the callback, regardless of callback count. + uint32_t min_blocking_calls_for_callback_ = 0; + std::function result_callback_; + }; + + uint32_t GetBlockingCallCount() const; + uint32_t GetCouldBeBlockingCallCount() const; +#endif + SocketServer* socketserver(); // Note: The behavior of Thread has changed. When a thread is stopped, @@ -274,10 +336,6 @@ class RTC_LOCKABLE RTC_EXPORT Thread : public webrtc::TaskQueueBase { } } - // When this signal is sent out, any references to this queue should - // no longer be used. - sigslot::signal0<> SignalQueueDestroyed; - bool IsCurrent() const; // Sleeps the calling thread for the specified number of milliseconds, during @@ -290,6 +348,11 @@ class RTC_LOCKABLE RTC_EXPORT Thread : public webrtc::TaskQueueBase { const std::string& name() const { return name_; } bool SetName(const std::string& name, const void* obj); + // Sets the expected processing time in ms. The thread will write + // log messages when Invoke() takes more time than this. + // Default is 50 ms. + void SetDispatchWarningMs(int deadline); + // Starts the execution of the thread. 
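Intended use of the two macros above, shown on a hypothetical function: ConfigureOnWorkerThread is made up for illustration, while RTC_LOG_THREAD_BLOCK_COUNT, RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN and rtc::Thread::Invoke are taken from this diff and the existing Thread API.

#include "rtc_base/thread.h"

// Hypothetical function, used only to illustrate the macros added above.
void ConfigureOnWorkerThread(rtc::Thread* worker) {
  // Logs a warning at the end of the scope if any blocking calls were made.
  RTC_LOG_THREAD_BLOCK_COUNT();

  // One blocking call: Invoke() on another thread.
  worker->Invoke<void>(RTC_FROM_HERE, [] { /* configure something */ });

  // DCHECKs (and raises the callback threshold) if this function ever starts
  // making more than one blocking call.
  RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(1);
}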
bool Start(); @@ -525,6 +588,8 @@ class RTC_LOCKABLE RTC_EXPORT Thread : public webrtc::TaskQueueBase { RecursiveCriticalSection* CritForTest() { return &crit_; } private: + static const int kSlowDispatchLoggingThreshold = 50; // 50 ms + class QueuedTaskHandler final : public MessageHandler { public: QueuedTaskHandler() {} @@ -570,7 +635,9 @@ class RTC_LOCKABLE RTC_EXPORT Thread : public webrtc::TaskQueueBase { MessageList messages_ RTC_GUARDED_BY(crit_); PriorityQueue delayed_messages_ RTC_GUARDED_BY(crit_); uint32_t delayed_next_num_ RTC_GUARDED_BY(crit_); -#if (!defined(NDEBUG) || defined(DCHECK_ALWAYS_ON)) +#if RTC_DCHECK_IS_ON + uint32_t blocking_call_count_ RTC_GUARDED_BY(this) = 0; + uint32_t could_be_blocking_call_count_ RTC_GUARDED_BY(this) = 0; std::vector allowed_threads_ RTC_GUARDED_BY(this); bool invoke_policy_enabled_ RTC_GUARDED_BY(this) = false; #endif @@ -614,6 +681,8 @@ class RTC_LOCKABLE RTC_EXPORT Thread : public webrtc::TaskQueueBase { friend class ThreadManager; + int dispatch_warning_ms_ RTC_GUARDED_BY(this) = kSlowDispatchLoggingThreshold; + RTC_DISALLOW_COPY_AND_ASSIGN(Thread); }; diff --git a/rtc_base/thread_checker.h b/rtc_base/thread_checker.h deleted file mode 100644 index 876a08e38c..0000000000 --- a/rtc_base/thread_checker.h +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -// Borrowed from Chromium's src/base/threading/thread_checker.h. - -#ifndef RTC_BASE_THREAD_CHECKER_H_ -#define RTC_BASE_THREAD_CHECKER_H_ - -#include "rtc_base/deprecation.h" -#include "rtc_base/synchronization/sequence_checker.h" - -namespace rtc { -// TODO(srte): Replace usages of this with SequenceChecker. -class ThreadChecker : public webrtc::SequenceChecker { - public: - RTC_DEPRECATED bool CalledOnValidThread() const { return IsCurrent(); } - RTC_DEPRECATED void DetachFromThread() { Detach(); } -}; -} // namespace rtc -#endif // RTC_BASE_THREAD_CHECKER_H_ diff --git a/rtc_base/thread_checker_unittest.cc b/rtc_base/thread_checker_unittest.cc deleted file mode 100644 index b5927043f0..0000000000 --- a/rtc_base/thread_checker_unittest.cc +++ /dev/null @@ -1,253 +0,0 @@ -/* - * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -// Borrowed from Chromium's src/base/threading/thread_checker_unittest.cc. - -#include "rtc_base/thread_checker.h" - -#include -#include - -#include "rtc_base/checks.h" -#include "rtc_base/constructor_magic.h" -#include "rtc_base/null_socket_server.h" -#include "rtc_base/socket_server.h" -#include "rtc_base/task_queue.h" -#include "rtc_base/thread.h" -#include "test/gtest.h" - -// Duplicated from base/threading/thread_checker.h so that we can be -// good citizens there and undef the macro. 
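Per the TODO in the deleted thread_checker.h above, callers migrate to webrtc::SequenceChecker. A sketch of the replacement pattern on a made-up class: Counter is illustrative only, while SequenceChecker::Detach(), RTC_DCHECK_RUN_ON and RTC_GUARDED_BY are the existing WebRTC facilities referenced elsewhere in this diff.

#include "api/sequence_checker.h"
#include "rtc_base/thread_annotations.h"

// Illustrative class, not part of this change.
class Counter {
 public:
  // Detach in the constructor; the checker attaches to whichever sequence
  // makes the first checked call (the old DetachFromThread() idiom).
  Counter() { sequence_checker_.Detach(); }

  void Increment() {
    // Replaces the old ThreadChecker::CalledOnValidThread() pattern.
    RTC_DCHECK_RUN_ON(&sequence_checker_);
    ++value_;
  }

 private:
  webrtc::SequenceChecker sequence_checker_;
  int value_ RTC_GUARDED_BY(sequence_checker_) = 0;
};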
-#define ENABLE_THREAD_CHECKER RTC_DCHECK_IS_ON - -namespace rtc { - -namespace { - -// Simple class to exercise the basics of ThreadChecker. -// Both the destructor and DoStuff should verify that they were -// called on the same thread as the constructor. -class ThreadCheckerClass : public ThreadChecker { - public: - ThreadCheckerClass() {} - - // Verifies that it was called on the same thread as the constructor. - void DoStuff() { RTC_DCHECK(IsCurrent()); } - - void Detach() { ThreadChecker::Detach(); } - - static void MethodOnDifferentThreadImpl(); - static void DetachThenCallFromDifferentThreadImpl(); - - private: - RTC_DISALLOW_COPY_AND_ASSIGN(ThreadCheckerClass); -}; - -// Calls ThreadCheckerClass::DoStuff on another thread. -class CallDoStuffOnThread : public Thread { - public: - explicit CallDoStuffOnThread(ThreadCheckerClass* thread_checker_class) - : Thread(std::unique_ptr(new rtc::NullSocketServer())), - thread_checker_class_(thread_checker_class) { - SetName("call_do_stuff_on_thread", nullptr); - } - - void Run() override { thread_checker_class_->DoStuff(); } - - // New method. Needed since Thread::Join is protected, and it is called by - // the TEST. - void Join() { Thread::Join(); } - - private: - ThreadCheckerClass* thread_checker_class_; - - RTC_DISALLOW_COPY_AND_ASSIGN(CallDoStuffOnThread); -}; - -// Deletes ThreadCheckerClass on a different thread. -class DeleteThreadCheckerClassOnThread : public Thread { - public: - explicit DeleteThreadCheckerClassOnThread( - std::unique_ptr thread_checker_class) - : Thread(std::unique_ptr(new rtc::NullSocketServer())), - thread_checker_class_(std::move(thread_checker_class)) { - SetName("delete_thread_checker_class_on_thread", nullptr); - } - - void Run() override { thread_checker_class_.reset(); } - - // New method. Needed since Thread::Join is protected, and it is called by - // the TEST. - void Join() { Thread::Join(); } - - bool has_been_deleted() const { return !thread_checker_class_; } - - private: - std::unique_ptr thread_checker_class_; - - RTC_DISALLOW_COPY_AND_ASSIGN(DeleteThreadCheckerClassOnThread); -}; - -} // namespace - -TEST(ThreadCheckerTest, CallsAllowedOnSameThread) { - std::unique_ptr thread_checker_class( - new ThreadCheckerClass); - - // Verify that DoStuff doesn't assert. - thread_checker_class->DoStuff(); - - // Verify that the destructor doesn't assert. - thread_checker_class.reset(); -} - -TEST(ThreadCheckerTest, DestructorAllowedOnDifferentThread) { - std::unique_ptr thread_checker_class( - new ThreadCheckerClass); - - // Verify that the destructor doesn't assert - // when called on a different thread. - DeleteThreadCheckerClassOnThread delete_on_thread( - std::move(thread_checker_class)); - - EXPECT_FALSE(delete_on_thread.has_been_deleted()); - - delete_on_thread.Start(); - delete_on_thread.Join(); - - EXPECT_TRUE(delete_on_thread.has_been_deleted()); -} - -TEST(ThreadCheckerTest, Detach) { - std::unique_ptr thread_checker_class( - new ThreadCheckerClass); - - // Verify that DoStuff doesn't assert when called on a different thread after - // a call to Detach. - thread_checker_class->Detach(); - CallDoStuffOnThread call_on_thread(thread_checker_class.get()); - - call_on_thread.Start(); - call_on_thread.Join(); -} - -#if GTEST_HAS_DEATH_TEST || !ENABLE_THREAD_CHECKER - -void ThreadCheckerClass::MethodOnDifferentThreadImpl() { - std::unique_ptr thread_checker_class( - new ThreadCheckerClass); - - // DoStuff should assert in debug builds only when called on a - // different thread. 
- CallDoStuffOnThread call_on_thread(thread_checker_class.get()); - - call_on_thread.Start(); - call_on_thread.Join(); -} - -#if ENABLE_THREAD_CHECKER -TEST(ThreadCheckerDeathTest, MethodNotAllowedOnDifferentThreadInDebug) { - ASSERT_DEATH({ ThreadCheckerClass::MethodOnDifferentThreadImpl(); }, ""); -} -#else -TEST(ThreadCheckerTest, MethodAllowedOnDifferentThreadInRelease) { - ThreadCheckerClass::MethodOnDifferentThreadImpl(); -} -#endif // ENABLE_THREAD_CHECKER - -void ThreadCheckerClass::DetachThenCallFromDifferentThreadImpl() { - std::unique_ptr thread_checker_class( - new ThreadCheckerClass); - - // DoStuff doesn't assert when called on a different thread - // after a call to Detach. - thread_checker_class->Detach(); - CallDoStuffOnThread call_on_thread(thread_checker_class.get()); - - call_on_thread.Start(); - call_on_thread.Join(); - - // DoStuff should assert in debug builds only after moving to - // another thread. - thread_checker_class->DoStuff(); -} - -#if ENABLE_THREAD_CHECKER -TEST(ThreadCheckerDeathTest, DetachFromThreadInDebug) { - ASSERT_DEATH({ ThreadCheckerClass::DetachThenCallFromDifferentThreadImpl(); }, - ""); -} -#else -TEST(ThreadCheckerTest, DetachFromThreadInRelease) { - ThreadCheckerClass::DetachThenCallFromDifferentThreadImpl(); -} -#endif // ENABLE_THREAD_CHECKER - -#endif // GTEST_HAS_DEATH_TEST || !ENABLE_THREAD_CHECKER - -class ThreadAnnotateTest { - public: - // Next two function should create warnings when compile (e.g. if used with - // specific T). - // TODO(danilchap): Find a way to test they do not compile when thread - // annotation checks enabled. - template - void access_var_no_annotate() { - var_thread_ = 42; - } - - template - void access_fun_no_annotate() { - function(); - } - - // Functions below should be able to compile. - void access_var_annotate_thread() { - RTC_DCHECK_RUN_ON(thread_); - var_thread_ = 42; - } - - void access_var_annotate_checker() { - RTC_DCHECK_RUN_ON(&checker_); - var_checker_ = 44; - } - - void access_var_annotate_queue() { - RTC_DCHECK_RUN_ON(queue_); - var_queue_ = 46; - } - - void access_fun_annotate() { - RTC_DCHECK_RUN_ON(thread_); - function(); - } - - void access_fun_and_var() { - RTC_DCHECK_RUN_ON(thread_); - fun_acccess_var(); - } - - private: - void function() RTC_RUN_ON(thread_) {} - void fun_acccess_var() RTC_RUN_ON(thread_) { var_thread_ = 13; } - - rtc::Thread* thread_; - rtc::ThreadChecker checker_; - rtc::TaskQueue* queue_; - - int var_thread_ RTC_GUARDED_BY(thread_); - int var_checker_ RTC_GUARDED_BY(checker_); - int var_queue_ RTC_GUARDED_BY(queue_); -}; - -// Just in case we ever get lumped together with other compilation units. 
-#undef ENABLE_THREAD_CHECKER - -} // namespace rtc diff --git a/rtc_base/thread_unittest.cc b/rtc_base/thread_unittest.cc index 51321985ed..789bdd943e 100644 --- a/rtc_base/thread_unittest.cc +++ b/rtc_base/thread_unittest.cc @@ -19,6 +19,7 @@ #include "rtc_base/atomic_ops.h" #include "rtc_base/event.h" #include "rtc_base/gunit.h" +#include "rtc_base/internal/default_socket_server.h" #include "rtc_base/null_socket_server.h" #include "rtc_base/physical_socket_server.h" #include "rtc_base/socket_address.h" @@ -255,6 +256,81 @@ TEST(ThreadTest, DISABLED_Main) { EXPECT_EQ(55, sock_client.last); } +TEST(ThreadTest, CountBlockingCalls) { + // When the test runs, this will print out: + // (thread_unittest.cc:262): Blocking TestBody: total=2 (actual=1, could=1) + RTC_LOG_THREAD_BLOCK_COUNT(); +#if RTC_DCHECK_IS_ON + rtc::Thread* current = rtc::Thread::Current(); + ASSERT_TRUE(current); + rtc::Thread::ScopedCountBlockingCalls blocked_calls( + [&](uint32_t actual_block, uint32_t could_block) { + EXPECT_EQ(1u, actual_block); + EXPECT_EQ(1u, could_block); + }); + + EXPECT_EQ(0u, blocked_calls.GetBlockingCallCount()); + EXPECT_EQ(0u, blocked_calls.GetCouldBeBlockingCallCount()); + EXPECT_EQ(0u, blocked_calls.GetTotalBlockedCallCount()); + + // Test invoking on the current thread. This should not count as an 'actual' + // invoke, but should still count as an invoke that could block since we + // that the call to Invoke serves a purpose in some configurations (and should + // not be used a general way to call methods on the same thread). + current->Invoke(RTC_FROM_HERE, []() {}); + EXPECT_EQ(0u, blocked_calls.GetBlockingCallCount()); + EXPECT_EQ(1u, blocked_calls.GetCouldBeBlockingCallCount()); + EXPECT_EQ(1u, blocked_calls.GetTotalBlockedCallCount()); + + // Create a new thread to invoke on. + auto thread = Thread::CreateWithSocketServer(); + thread->Start(); + EXPECT_EQ(42, thread->Invoke(RTC_FROM_HERE, []() { return 42; })); + EXPECT_EQ(1u, blocked_calls.GetBlockingCallCount()); + EXPECT_EQ(1u, blocked_calls.GetCouldBeBlockingCallCount()); + EXPECT_EQ(2u, blocked_calls.GetTotalBlockedCallCount()); + thread->Stop(); + RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(2); +#else + RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(0); + RTC_LOG(LS_INFO) << "Test not active in this config"; +#endif +} + +#if RTC_DCHECK_IS_ON +TEST(ThreadTest, CountBlockingCallsOneCallback) { + rtc::Thread* current = rtc::Thread::Current(); + ASSERT_TRUE(current); + bool was_called_back = false; + { + rtc::Thread::ScopedCountBlockingCalls blocked_calls( + [&](uint32_t actual_block, uint32_t could_block) { + was_called_back = true; + }); + current->Invoke(RTC_FROM_HERE, []() {}); + } + EXPECT_TRUE(was_called_back); +} + +TEST(ThreadTest, CountBlockingCallsSkipCallback) { + rtc::Thread* current = rtc::Thread::Current(); + ASSERT_TRUE(current); + bool was_called_back = false; + { + rtc::Thread::ScopedCountBlockingCalls blocked_calls( + [&](uint32_t actual_block, uint32_t could_block) { + was_called_back = true; + }); + // Changed `blocked_calls` to not issue the callback if there are 1 or + // fewer blocking calls (i.e. we set the minimum required number to 2). + blocked_calls.set_minimum_call_count_for_callback(2); + current->Invoke(RTC_FROM_HERE, []() {}); + } + // We should not have gotten a call back. + EXPECT_FALSE(was_called_back); +} +#endif + // Test that setting thread names doesn't cause a malfunction. // There's no easy way to verify the name was set properly at this time. 
TEST(ThreadTest, Names) { @@ -432,7 +508,7 @@ TEST(ThreadTest, ThreeThreadsInvoke) { struct LocalFuncs { static void Set(LockedBool* out) { out->Set(true); } static void InvokeSet(Thread* thread, LockedBool* out) { - thread->Invoke(RTC_FROM_HERE, Bind(&Set, out)); + thread->Invoke(RTC_FROM_HERE, [out] { Set(out); }); } // Set |out| true and call InvokeSet on |thread|. @@ -445,67 +521,41 @@ TEST(ThreadTest, ThreeThreadsInvoke) { // Asynchronously invoke SetAndInvokeSet on |thread1| and wait until // |thread1| starts the call. - static void AsyncInvokeSetAndWait(AsyncInvoker* invoker, + static void AsyncInvokeSetAndWait(DEPRECATED_AsyncInvoker* invoker, Thread* thread1, Thread* thread2, LockedBool* out) { LockedBool async_invoked(false); invoker->AsyncInvoke( - RTC_FROM_HERE, thread1, - Bind(&SetAndInvokeSet, &async_invoked, thread2, out)); + RTC_FROM_HERE, thread1, [&async_invoked, thread2, out] { + SetAndInvokeSet(&async_invoked, thread2, out); + }); EXPECT_TRUE_WAIT(async_invoked.Get(), 2000); } }; - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; LockedBool thread_a_called(false); // Start the sequence A --(invoke)--> B --(async invoke)--> C --(invoke)--> A. // Thread B returns when C receives the call and C should be blocked until A // starts to process messages. - thread_b->Invoke(RTC_FROM_HERE, - Bind(&LocalFuncs::AsyncInvokeSetAndWait, &invoker, - thread_c.get(), thread_a, &thread_a_called)); + Thread* thread_c_ptr = thread_c.get(); + thread_b->Invoke( + RTC_FROM_HERE, [&invoker, thread_c_ptr, thread_a, &thread_a_called] { + LocalFuncs::AsyncInvokeSetAndWait(&invoker, thread_c_ptr, thread_a, + &thread_a_called); + }); EXPECT_FALSE(thread_a_called.Get()); EXPECT_TRUE_WAIT(thread_a_called.Get(), 2000); } -// Set the name on a thread when the underlying QueueDestroyed signal is -// triggered. This causes an error if the object is already partially -// destroyed. -class SetNameOnSignalQueueDestroyedTester : public sigslot::has_slots<> { - public: - SetNameOnSignalQueueDestroyedTester(Thread* thread) : thread_(thread) { - thread->SignalQueueDestroyed.connect( - this, &SetNameOnSignalQueueDestroyedTester::OnQueueDestroyed); - } - - void OnQueueDestroyed() { - // Makes sure that if we access the Thread while it's being destroyed, that - // it doesn't cause a problem because the vtable has been modified. - thread_->SetName("foo", nullptr); - } - - private: - Thread* thread_; -}; - -TEST(ThreadTest, SetNameOnSignalQueueDestroyed) { - auto thread1 = Thread::CreateWithSocketServer(); - SetNameOnSignalQueueDestroyedTester tester1(thread1.get()); - thread1.reset(); - - Thread* thread2 = new AutoThread(); - SetNameOnSignalQueueDestroyedTester tester2(thread2); - delete thread2; -} - class ThreadQueueTest : public ::testing::Test, public Thread { public: - ThreadQueueTest() : Thread(SocketServer::CreateDefault(), true) {} + ThreadQueueTest() : Thread(CreateDefaultSocketServer(), true) {} bool IsLocked_Worker() { if (!CritForTest()->TryEnter()) { return true; @@ -518,8 +568,8 @@ class ThreadQueueTest : public ::testing::Test, public Thread { // succeed, since our critical sections are reentrant. 
std::unique_ptr worker(Thread::CreateWithSocketServer()); worker->Start(); - return worker->Invoke( - RTC_FROM_HERE, rtc::Bind(&ThreadQueueTest::IsLocked_Worker, this)); + return worker->Invoke(RTC_FROM_HERE, + [this] { return IsLocked_Worker(); }); } }; @@ -555,7 +605,7 @@ static void DelayedPostsWithIdenticalTimesAreProcessedInFifoOrder(Thread* q) { } TEST_F(ThreadQueueTest, DelayedPostsWithIdenticalTimesAreProcessedInFifoOrder) { - Thread q(SocketServer::CreateDefault(), true); + Thread q(CreateDefaultSocketServer(), true); DelayedPostsWithIdenticalTimesAreProcessedInFifoOrder(&q); NullSocketServer nullss; @@ -711,7 +761,7 @@ class AsyncInvokeTest : public ::testing::Test { }; TEST_F(AsyncInvokeTest, FireAndForget) { - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; // Create and start the thread. auto thread = Thread::CreateWithSocketServer(); thread->Start(); @@ -723,7 +773,7 @@ TEST_F(AsyncInvokeTest, FireAndForget) { } TEST_F(AsyncInvokeTest, NonCopyableFunctor) { - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; // Create and start the thread. auto thread = Thread::CreateWithSocketServer(); thread->Start(); @@ -754,7 +804,7 @@ TEST_F(AsyncInvokeTest, KillInvokerDuringExecute) { EXPECT_FALSE(invoker_destroyed); functor_finished.Set(); }; - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; invoker.AsyncInvoke(RTC_FROM_HERE, thread.get(), functor); functor_started.Wait(Event::kForever); @@ -783,7 +833,7 @@ TEST_F(AsyncInvokeTest, KillInvokerDuringExecuteWithReentrantInvoke) { Thread thread(std::make_unique()); thread.Start(); { - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; auto reentrant_functor = [&reentrant_functor_run] { reentrant_functor_run = true; }; @@ -802,7 +852,7 @@ TEST_F(AsyncInvokeTest, KillInvokerDuringExecuteWithReentrantInvoke) { } TEST_F(AsyncInvokeTest, Flush) { - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; AtomicBool flag1; AtomicBool flag2; // Queue two async calls to the current thread. @@ -818,7 +868,7 @@ TEST_F(AsyncInvokeTest, Flush) { } TEST_F(AsyncInvokeTest, FlushWithIds) { - AsyncInvoker invoker; + DEPRECATED_AsyncInvoker invoker; AtomicBool flag1; AtomicBool flag2; // Queue two async calls to the current thread, one with a message id. 
@@ -839,11 +889,6 @@ TEST_F(AsyncInvokeTest, FlushWithIds) { EXPECT_TRUE(flag2.get()); } -void ThreadIsCurrent(Thread* thread, bool* result, Event* event) { - *result = thread->IsCurrent(); - event->Set(); -} - void WaitAndSetEvent(Event* wait_event, Event* set_event) { wait_event->Wait(Event::kForever); set_event->Set(); @@ -908,15 +953,6 @@ class DestructionFunctor { bool was_invoked_ = false; }; -TEST(ThreadPostTaskTest, InvokesWithBind) { - std::unique_ptr background_thread(rtc::Thread::Create()); - background_thread->Start(); - - Event event; - background_thread->PostTask(RTC_FROM_HERE, Bind(&Event::Set, &event)); - event.Wait(Event::kForever); -} - TEST(ThreadPostTaskTest, InvokesWithLambda) { std::unique_ptr background_thread(rtc::Thread::Create()); background_thread->Start(); @@ -1019,9 +1055,13 @@ TEST(ThreadPostTaskTest, InvokesOnBackgroundThread) { Event event; bool was_invoked_on_background_thread = false; - background_thread->PostTask(RTC_FROM_HERE, - Bind(&ThreadIsCurrent, background_thread.get(), - &was_invoked_on_background_thread, &event)); + Thread* background_thread_ptr = background_thread.get(); + background_thread->PostTask( + RTC_FROM_HERE, + [background_thread_ptr, &was_invoked_on_background_thread, &event] { + was_invoked_on_background_thread = background_thread_ptr->IsCurrent(); + event.Set(); + }); event.Wait(Event::kForever); EXPECT_TRUE(was_invoked_on_background_thread); @@ -1035,9 +1075,10 @@ TEST(ThreadPostTaskTest, InvokesAsynchronously) { // thread. The second event ensures that the message is processed. Event event_set_by_test_thread; Event event_set_by_background_thread; - background_thread->PostTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &event_set_by_test_thread, - &event_set_by_background_thread)); + background_thread->PostTask(RTC_FROM_HERE, [&event_set_by_test_thread, + &event_set_by_background_thread] { + WaitAndSetEvent(&event_set_by_test_thread, &event_set_by_background_thread); + }); event_set_by_test_thread.Set(); event_set_by_background_thread.Wait(Event::kForever); } @@ -1051,12 +1092,12 @@ TEST(ThreadPostTaskTest, InvokesInPostedOrder) { Event third; Event fourth; - background_thread->PostTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &first, &second)); - background_thread->PostTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &second, &third)); - background_thread->PostTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &third, &fourth)); + background_thread->PostTask( + RTC_FROM_HERE, [&first, &second] { WaitAndSetEvent(&first, &second); }); + background_thread->PostTask( + RTC_FROM_HERE, [&second, &third] { WaitAndSetEvent(&second, &third); }); + background_thread->PostTask( + RTC_FROM_HERE, [&third, &fourth] { WaitAndSetEvent(&third, &fourth); }); // All tasks have been posted before the first one is unblocked. 
first.Set(); @@ -1074,8 +1115,10 @@ TEST(ThreadPostDelayedTaskTest, InvokesAsynchronously) { Event event_set_by_background_thread; background_thread->PostDelayedTask( RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &event_set_by_test_thread, - &event_set_by_background_thread), + [&event_set_by_test_thread, &event_set_by_background_thread] { + WaitAndSetEvent(&event_set_by_test_thread, + &event_set_by_background_thread); + }, /*milliseconds=*/10); event_set_by_test_thread.Set(); event_set_by_background_thread.Wait(Event::kForever); @@ -1091,15 +1134,15 @@ TEST(ThreadPostDelayedTaskTest, InvokesInDelayOrder) { Event third; Event fourth; - background_thread->PostDelayedTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &third, &fourth), - /*milliseconds=*/11); - background_thread->PostDelayedTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &first, &second), - /*milliseconds=*/9); - background_thread->PostDelayedTask(RTC_FROM_HERE, - Bind(&WaitAndSetEvent, &second, &third), - /*milliseconds=*/10); + background_thread->PostDelayedTask( + RTC_FROM_HERE, [&third, &fourth] { WaitAndSetEvent(&third, &fourth); }, + /*milliseconds=*/11); + background_thread->PostDelayedTask( + RTC_FROM_HERE, [&first, &second] { WaitAndSetEvent(&first, &second); }, + /*milliseconds=*/9); + background_thread->PostDelayedTask( + RTC_FROM_HERE, [&second, &third] { WaitAndSetEvent(&second, &third); }, + /*milliseconds=*/10); // All tasks have been posted before the first one is unblocked. first.Set(); diff --git a/rtc_base/time_utils.cc b/rtc_base/time_utils.cc index 11c9d5a47f..fe63d3a8ed 100644 --- a/rtc_base/time_utils.cc +++ b/rtc_base/time_utils.cc @@ -12,23 +12,15 @@ #if defined(WEBRTC_POSIX) #include -#if defined(WEBRTC_MAC) -#include -#endif #endif #if defined(WEBRTC_WIN) -// clang-format off -// clang formatting would put last, -// which leads to compilation failure. -#include -#include #include -// clang-format on #endif #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/system_time.h" #include "rtc_base/time_utils.h" namespace rtc { @@ -141,61 +133,12 @@ void SyncWithNtp(int64_t time_from_ntp_server_ms) { TimeHelper::SyncWithNtp(time_from_ntp_server_ms); } -#endif // defined(WINUWP) - -int64_t SystemTimeNanos() { - int64_t ticks; -#if defined(WEBRTC_MAC) - static mach_timebase_info_data_t timebase; - if (timebase.denom == 0) { - // Get the timebase if this is the first time we run. - // Recommended by Apple's QA1398. - if (mach_timebase_info(&timebase) != KERN_SUCCESS) { - RTC_NOTREACHED(); - } - } - // Use timebase to convert absolute time tick units into nanoseconds. - const auto mul = [](uint64_t a, uint32_t b) -> int64_t { - RTC_DCHECK_NE(b, 0); - RTC_DCHECK_LE(a, std::numeric_limits::max() / b) - << "The multiplication " << a << " * " << b << " overflows"; - return rtc::dchecked_cast(a * b); - }; - ticks = mul(mach_absolute_time(), timebase.numer) / timebase.denom; -#elif defined(WEBRTC_POSIX) - struct timespec ts; - // TODO(deadbeef): Do we need to handle the case when CLOCK_MONOTONIC is not - // supported? 
- clock_gettime(CLOCK_MONOTONIC, &ts); - ticks = kNumNanosecsPerSec * static_cast(ts.tv_sec) + - static_cast(ts.tv_nsec); -#elif defined(WINUWP) - ticks = TimeHelper::TicksNs(); -#elif defined(WEBRTC_WIN) - static volatile LONG last_timegettime = 0; - static volatile int64_t num_wrap_timegettime = 0; - volatile LONG* last_timegettime_ptr = &last_timegettime; - DWORD now = timeGetTime(); - // Atomically update the last gotten time - DWORD old = InterlockedExchange(last_timegettime_ptr, now); - if (now < old) { - // If now is earlier than old, there may have been a race between threads. - // 0x0fffffff ~3.1 days, the code will not take that long to execute - // so it must have been a wrap around. - if (old > 0xf0000000 && now < 0x0fffffff) { - num_wrap_timegettime++; - } - } - ticks = now + (num_wrap_timegettime << 32); - // TODO(deadbeef): Calculate with nanosecond precision. Otherwise, we're - // just wasting a multiply and divide when doing Time() on Windows. - ticks = ticks * kNumNanosecsPerMillisec; -#else -#error Unsupported platform. -#endif - return ticks; +int64_t WinUwpSystemTimeNanos() { + return TimeHelper::TicksNs(); } +#endif // defined(WINUWP) + int64_t SystemTimeMillis() { return static_cast(SystemTimeNanos() / kNumNanosecsPerMillisec); } diff --git a/rtc_base/time_utils.h b/rtc_base/time_utils.h index 147ab8daf8..de3c58c815 100644 --- a/rtc_base/time_utils.h +++ b/rtc_base/time_utils.h @@ -16,6 +16,7 @@ #include "rtc_base/checks.h" #include "rtc_base/system/rtc_export.h" +#include "rtc_base/system_time.h" namespace rtc { @@ -61,11 +62,16 @@ RTC_EXPORT ClockInterface* GetClockForTesting(); // Synchronizes the current clock based upon an NTP server's epoch in // milliseconds. void SyncWithNtp(int64_t time_from_ntp_server_ms); + +// Returns the current time in nanoseconds. The clock is synchonized with the +// system wall clock time upon instatiation. It may also be synchronized using +// the SyncWithNtp() function above. Please note that the clock will most likely +// drift away from the system wall clock time as time goes by. +int64_t WinUwpSystemTimeNanos(); #endif // defined(WINUWP) // Returns the actual system time, even if a clock is set for testing. // Useful for timeouts while using a test clock, or for logging. -int64_t SystemTimeNanos(); int64_t SystemTimeMillis(); // Returns the current time in milliseconds in 32 bits. 
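SystemTimeNanos() now lives in rtc_base/system_time.h (included above); on POSIX the removed branch reduced to reading CLOCK_MONOTONIC. An equivalent standalone reading of that clock, matching the deleted code path (POSIX only):

#include <cstdint>
#include <ctime>

constexpr int64_t kNumNanosecsPerSec = 1000000000;

// Monotonic time in nanoseconds, as the removed POSIX branch computed it.
int64_t MonotonicTimeNanos() {
  struct timespec ts;
  clock_gettime(CLOCK_MONOTONIC, &ts);
  return kNumNanosecsPerSec * static_cast<int64_t>(ts.tv_sec) +
         static_cast<int64_t>(ts.tv_nsec);
}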
diff --git a/rtc_base/unique_id_generator.cc b/rtc_base/unique_id_generator.cc index d41fa8d186..9fa3021c6f 100644 --- a/rtc_base/unique_id_generator.cc +++ b/rtc_base/unique_id_generator.cc @@ -26,6 +26,8 @@ UniqueRandomIdGenerator::UniqueRandomIdGenerator(ArrayView known_ids) UniqueRandomIdGenerator::~UniqueRandomIdGenerator() = default; uint32_t UniqueRandomIdGenerator::GenerateId() { + webrtc::MutexLock lock(&mutex_); + RTC_CHECK_LT(known_ids_.size(), std::numeric_limits::max() - 1); while (true) { auto pair = known_ids_.insert(CreateRandomNonZeroId()); @@ -36,6 +38,7 @@ uint32_t UniqueRandomIdGenerator::GenerateId() { } bool UniqueRandomIdGenerator::AddKnownId(uint32_t value) { + webrtc::MutexLock lock(&mutex_); return known_ids_.insert(value).second; } diff --git a/rtc_base/unique_id_generator.h b/rtc_base/unique_id_generator.h index 836dc70b61..22398fd3f2 100644 --- a/rtc_base/unique_id_generator.h +++ b/rtc_base/unique_id_generator.h @@ -16,6 +16,9 @@ #include #include "api/array_view.h" +#include "api/sequence_checker.h" +#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/system/no_unique_address.h" namespace rtc { @@ -47,9 +50,10 @@ class UniqueNumberGenerator { bool AddKnownId(TIntegral value); private: + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker sequence_checker_; static_assert(std::is_integral::value, "Must be integral type."); - TIntegral counter_; - std::set known_ids_; + TIntegral counter_ RTC_GUARDED_BY(sequence_checker_); + std::set known_ids_ RTC_GUARDED_BY(sequence_checker_); }; // This class will generate unique ids. Ids are 32 bit unsigned integers. @@ -76,7 +80,10 @@ class UniqueRandomIdGenerator { bool AddKnownId(uint32_t value); private: - std::set known_ids_; + // TODO(bugs.webrtc.org/12666): This lock is needed due to an instance in + // SdpOfferAnswerHandler being shared between threads. + webrtc::Mutex mutex_; + std::set known_ids_ RTC_GUARDED_BY(&mutex_); }; // This class will generate strings. A common use case is for identifiers. 
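The locking added to UniqueRandomIdGenerator above guards a simple "insert random values into a set until one is new" loop. A standalone sketch of that loop with std types; RandomIdSet is an illustrative name and std::mt19937 stands in for the WebRTC random helper.

#include <cstdint>
#include <mutex>
#include <random>
#include <set>

class RandomIdSet {
 public:
  // Returns an id that has never been returned (or registered) before.
  uint32_t GenerateId() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::uniform_int_distribution<uint32_t> dist(1, UINT32_MAX);
    while (true) {
      // set::insert reports whether the value was actually new.
      auto result = known_ids_.insert(dist(rng_));
      if (result.second)
        return *result.first;
    }
  }

  // Marks an externally chosen id as used; false if it was already known.
  bool AddKnownId(uint32_t value) {
    std::lock_guard<std::mutex> lock(mutex_);
    return known_ids_.insert(value).second;
  }

 private:
  std::mutex mutex_;
  std::mt19937 rng_{std::random_device{}()};
  std::set<uint32_t> known_ids_;
};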
@@ -104,18 +111,23 @@ class UniqueStringGenerator { }; template -UniqueNumberGenerator::UniqueNumberGenerator() : counter_(0) {} +UniqueNumberGenerator::UniqueNumberGenerator() : counter_(0) { + sequence_checker_.Detach(); +} template UniqueNumberGenerator::UniqueNumberGenerator( ArrayView known_ids) - : counter_(0), known_ids_(known_ids.begin(), known_ids.end()) {} + : counter_(0), known_ids_(known_ids.begin(), known_ids.end()) { + sequence_checker_.Detach(); +} template UniqueNumberGenerator::~UniqueNumberGenerator() {} template TIntegral UniqueNumberGenerator::GenerateNumber() { + RTC_DCHECK_RUN_ON(&sequence_checker_); while (true) { RTC_CHECK_LT(counter_, std::numeric_limits::max()); auto pair = known_ids_.insert(counter_++); @@ -127,6 +139,7 @@ TIntegral UniqueNumberGenerator::GenerateNumber() { template bool UniqueNumberGenerator::AddKnownId(TIntegral value) { + RTC_DCHECK_RUN_ON(&sequence_checker_); return known_ids_.insert(value).second; } } // namespace rtc diff --git a/rtc_base/unique_id_generator_unittest.cc b/rtc_base/unique_id_generator_unittest.cc index 868b348b11..835a57e162 100644 --- a/rtc_base/unique_id_generator_unittest.cc +++ b/rtc_base/unique_id_generator_unittest.cc @@ -15,6 +15,7 @@ #include "absl/algorithm/container.h" #include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" #include "rtc_base/gunit.h" #include "rtc_base/helpers.h" #include "test/gmock.h" @@ -23,6 +24,21 @@ using ::testing::IsEmpty; using ::testing::Test; namespace rtc { +namespace { +// Utility class that registers itself as the currently active task queue. +class FakeTaskQueue : public webrtc::TaskQueueBase { + public: + FakeTaskQueue() : task_queue_setter_(this) {} + + void Delete() override {} + void PostTask(std::unique_ptr task) override {} + void PostDelayedTask(std::unique_ptr task, + uint32_t milliseconds) override {} + + private: + CurrentTaskQueueSetter task_queue_setter_; +}; +} // namespace template class UniqueIdGeneratorTest : public Test {}; @@ -148,4 +164,39 @@ TYPED_TEST(UniqueIdGeneratorTest, EXPECT_FALSE(generator2.AddKnownId(id)); } +// Tests that it's OK to construct the generator in one execution environment +// (thread/task queue) but use it in another. +TEST(UniqueNumberGenerator, UsedOnSecondaryThread) { + const auto* current_tq = webrtc::TaskQueueBase::Current(); + // Construct the generator before `fake_task_queue` to ensure that it is + // constructed in a different execution environment than what + // `fake_task_queue` will represent. + UniqueNumberGenerator generator; + + FakeTaskQueue fake_task_queue; + // Sanity check to make sure we're in a different runtime environment. + ASSERT_NE(current_tq, webrtc::TaskQueueBase::Current()); + + // Generating an id should be fine in this context. + generator.GenerateNumber(); +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +TEST(UniqueNumberGeneratorDeathTest, FailsWhenUsedInWrongContext) { + // Instantiate the generator before the `loop`. This ensures that + // thread/sequence checkers will pick up a different thread environment than + // `fake_task_queue` will represent. + UniqueNumberGenerator generator; + // Generate an ID on the current thread. This causes the generator to attach + // to the current thread context. + generator.GenerateNumber(); + + // Instantiate a fake task queue that will register itself as the current tq. + FakeTaskQueue fake_task_queue; + + // Attempting to generate an id should now trigger a dcheck. 
+ EXPECT_DEATH(generator.GenerateNumber(), ""); +} +#endif + } // namespace rtc diff --git a/rtc_base/virtual_socket_server.cc b/rtc_base/virtual_socket_server.cc index 804dc75624..f5e993645e 100644 --- a/rtc_base/virtual_socket_server.cc +++ b/rtc_base/virtual_socket_server.cc @@ -19,7 +19,6 @@ #include "absl/algorithm/container.h" #include "rtc_base/checks.h" -#include "rtc_base/deprecated/recursive_critical_section.h" #include "rtc_base/fake_clock.h" #include "rtc_base/logging.h" #include "rtc_base/physical_socket_server.h" @@ -54,7 +53,6 @@ const int NUM_SAMPLES = 1000; enum { MSG_ID_PACKET, - MSG_ID_ADDRESS_BOUND, MSG_ID_CONNECT, MSG_ID_DISCONNECT, MSG_ID_SIGNALREADEVENT, @@ -149,9 +147,6 @@ int VirtualSocket::Bind(const SocketAddress& addr) { } else { bound_ = true; was_any_ = addr.IsAnyIP(); - // Post a message here such that test case could have chance to - // process the local address. (i.e. SetAlternativeLocalAddress). - server_->msg_queue_->Post(RTC_FROM_HERE, this, MSG_ID_ADDRESS_BOUND); } return result; } @@ -168,65 +163,29 @@ int VirtualSocket::Close() { } if (SOCK_STREAM == type_) { + webrtc::MutexLock lock(&mutex_); + // Cancel pending sockets if (listen_queue_) { while (!listen_queue_->empty()) { SocketAddress addr = listen_queue_->front(); // Disconnect listening socket. - server_->Disconnect(server_->LookupBinding(addr)); + server_->Disconnect(addr); listen_queue_->pop_front(); } - delete listen_queue_; listen_queue_ = nullptr; } // Disconnect stream sockets if (CS_CONNECTED == state_) { - // Disconnect remote socket, check if it is a child of a server socket. - VirtualSocket* socket = - server_->LookupConnection(local_addr_, remote_addr_); - if (!socket) { - // Not a server socket child, then see if it is bound. - // TODO(tbd): If this is indeed a server socket that has no - // children this will cause the server socket to be - // closed. This might lead to unexpected results, how to fix this? - socket = server_->LookupBinding(remote_addr_); - } - server_->Disconnect(socket); - - // Remove mapping for both directions. - server_->RemoveConnection(remote_addr_, local_addr_); - server_->RemoveConnection(local_addr_, remote_addr_); + server_->Disconnect(local_addr_, remote_addr_); } // Cancel potential connects - MessageList msgs; - if (server_->msg_queue_) { - server_->msg_queue_->Clear(this, MSG_ID_CONNECT, &msgs); - } - for (MessageList::iterator it = msgs.begin(); it != msgs.end(); ++it) { - RTC_DCHECK(nullptr != it->pdata); - MessageAddress* data = static_cast(it->pdata); - - // Lookup remote side. - VirtualSocket* socket = - server_->LookupConnection(local_addr_, data->addr); - if (socket) { - // Server socket, remote side is a socket retreived by - // accept. Accepted sockets are not bound so we will not - // find it by looking in the bindings table. - server_->Disconnect(socket); - server_->RemoveConnection(local_addr_, data->addr); - } else { - server_->Disconnect(server_->LookupBinding(data->addr)); - } - delete data; - } + server_->CancelConnects(this); } // Clear incoming packets and disconnect messages - if (server_->msg_queue_) { - server_->msg_queue_->Clear(this); - } + server_->Clear(this); state_ = CS_CLOSED; local_addr_.Clear(); @@ -272,6 +231,8 @@ int VirtualSocket::RecvFrom(void* pv, if (timestamp) { *timestamp = -1; } + + webrtc::MutexLock lock(&mutex_); // If we don't have a packet, then either error or wait for one to arrive. 
if (recv_buffer_.empty()) { if (async_) { @@ -279,9 +240,7 @@ int VirtualSocket::RecvFrom(void* pv, return -1; } while (recv_buffer_.empty()) { - Message msg; - server_->msg_queue_->Get(&msg); - server_->msg_queue_->Dispatch(&msg); + server_->ProcessOneMessage(); } } @@ -301,18 +260,14 @@ int VirtualSocket::RecvFrom(void* pv, // To behave like a real socket, SignalReadEvent should fire in the next // message loop pass if there's still data buffered. if (!recv_buffer_.empty()) { - // Clear the message so it doesn't end up posted multiple times. - server_->msg_queue_->Clear(this, MSG_ID_SIGNALREADEVENT); - server_->msg_queue_->Post(RTC_FROM_HERE, this, MSG_ID_SIGNALREADEVENT); + server_->PostSignalReadEvent(this); } if (SOCK_STREAM == type_) { - bool was_full = (recv_buffer_size_ == server_->recv_buffer_capacity_); + bool was_full = (recv_buffer_size_ == server_->recv_buffer_capacity()); recv_buffer_size_ -= data_read; if (was_full) { - VirtualSocket* sender = server_->LookupBinding(remote_addr_); - RTC_DCHECK(nullptr != sender); - server_->SendTcp(sender); + server_->SendTcp(remote_addr_); } } @@ -320,6 +275,7 @@ int VirtualSocket::RecvFrom(void* pv, } int VirtualSocket::Listen(int backlog) { + webrtc::MutexLock lock(&mutex_); RTC_DCHECK(SOCK_STREAM == type_); RTC_DCHECK(CS_CLOSED == state_); if (local_addr_.IsNil()) { @@ -327,12 +283,13 @@ int VirtualSocket::Listen(int backlog) { return -1; } RTC_DCHECK(nullptr == listen_queue_); - listen_queue_ = new ListenQueue; + listen_queue_ = std::make_unique(); state_ = CS_CONNECTING; return 0; } VirtualSocket* VirtualSocket::Accept(SocketAddress* paddr) { + webrtc::MutexLock lock(&mutex_); if (nullptr == listen_queue_) { error_ = EINVAL; return nullptr; @@ -351,7 +308,7 @@ VirtualSocket* VirtualSocket::Accept(SocketAddress* paddr) { delete socket; continue; } - socket->CompleteConnect(remote_addr, false); + socket->CompleteConnect(remote_addr); if (paddr) { *paddr = remote_addr; } @@ -388,49 +345,57 @@ int VirtualSocket::SetOption(Option opt, int value) { } void VirtualSocket::OnMessage(Message* pmsg) { - if (pmsg->message_id == MSG_ID_PACKET) { - RTC_DCHECK(nullptr != pmsg->pdata); - Packet* packet = static_cast(pmsg->pdata); - - recv_buffer_.push_back(packet); - - if (async_) { - SignalReadEvent(this); - } - } else if (pmsg->message_id == MSG_ID_CONNECT) { - RTC_DCHECK(nullptr != pmsg->pdata); - MessageAddress* data = static_cast(pmsg->pdata); - if (listen_queue_ != nullptr) { - listen_queue_->push_back(data->addr); - if (async_) { - SignalReadEvent(this); + bool signal_read_event = false; + bool signal_close_event = false; + bool signal_connect_event = false; + int error_to_signal = 0; + { + webrtc::MutexLock lock(&mutex_); + if (pmsg->message_id == MSG_ID_PACKET) { + RTC_DCHECK(nullptr != pmsg->pdata); + Packet* packet = static_cast(pmsg->pdata); + + recv_buffer_.push_back(packet); + signal_read_event = async_; + } else if (pmsg->message_id == MSG_ID_CONNECT) { + RTC_DCHECK(nullptr != pmsg->pdata); + MessageAddress* data = static_cast(pmsg->pdata); + if (listen_queue_ != nullptr) { + listen_queue_->push_back(data->addr); + signal_read_event = async_; + } else if ((SOCK_STREAM == type_) && (CS_CONNECTING == state_)) { + CompleteConnect(data->addr); + signal_connect_event = async_; + } else { + RTC_LOG(LS_VERBOSE) + << "Socket at " << local_addr_.ToString() << " is not listening"; + server_->Disconnect(data->addr); } - } else if ((SOCK_STREAM == type_) && (CS_CONNECTING == state_)) { - CompleteConnect(data->addr, true); - } else { - 
RTC_LOG(LS_VERBOSE) << "Socket at " << local_addr_.ToString() - << " is not listening"; - server_->Disconnect(server_->LookupBinding(data->addr)); - } - delete data; - } else if (pmsg->message_id == MSG_ID_DISCONNECT) { - RTC_DCHECK(SOCK_STREAM == type_); - if (CS_CLOSED != state_) { - int error = (CS_CONNECTING == state_) ? ECONNREFUSED : 0; - state_ = CS_CLOSED; - remote_addr_.Clear(); - if (async_) { - SignalCloseEvent(this, error); + delete data; + } else if (pmsg->message_id == MSG_ID_DISCONNECT) { + RTC_DCHECK(SOCK_STREAM == type_); + if (CS_CLOSED != state_) { + error_to_signal = (CS_CONNECTING == state_) ? ECONNREFUSED : 0; + state_ = CS_CLOSED; + remote_addr_.Clear(); + signal_close_event = async_; } + } else if (pmsg->message_id == MSG_ID_SIGNALREADEVENT) { + signal_read_event = !recv_buffer_.empty(); + } else { + RTC_NOTREACHED(); } - } else if (pmsg->message_id == MSG_ID_ADDRESS_BOUND) { - SignalAddressReady(this, GetLocalAddress()); - } else if (pmsg->message_id == MSG_ID_SIGNALREADEVENT) { - if (!recv_buffer_.empty()) { - SignalReadEvent(this); - } - } else { - RTC_NOTREACHED(); + } + // Signal events without holding `mutex_`, to avoid recursive locking, as well + // as issues with sigslot and lock order. + if (signal_read_event) { + SignalReadEvent(this); + } + if (signal_close_event) { + SignalCloseEvent(this, error_to_signal); + } + if (signal_connect_event) { + SignalConnectEvent(this); } } @@ -465,14 +430,11 @@ int VirtualSocket::InitiateConnect(const SocketAddress& addr, bool use_delay) { return 0; } -void VirtualSocket::CompleteConnect(const SocketAddress& addr, bool notify) { +void VirtualSocket::CompleteConnect(const SocketAddress& addr) { RTC_DCHECK(CS_CONNECTING == state_); remote_addr_ = addr; state_ = CS_CONNECTED; server_->AddConnection(remote_addr_, local_addr_, this); - if (async_ && notify) { - SignalConnectEvent(this); - } } int VirtualSocket::SendUdp(const void* pv, @@ -494,7 +456,7 @@ int VirtualSocket::SendUdp(const void* pv, } int VirtualSocket::SendTcp(const void* pv, size_t cb) { - size_t capacity = server_->send_buffer_capacity_ - send_buffer_.size(); + size_t capacity = server_->send_buffer_capacity() - send_buffer_.size(); if (0 == capacity) { ready_to_send_ = false; error_ = EWOULDBLOCK; @@ -523,6 +485,67 @@ void VirtualSocket::OnSocketServerReadyToSend() { } } +void VirtualSocket::SetToBlocked() { + webrtc::MutexLock lock(&mutex_); + ready_to_send_ = false; + error_ = EWOULDBLOCK; +} + +void VirtualSocket::UpdateRecv(size_t data_size) { + recv_buffer_size_ += data_size; +} + +void VirtualSocket::UpdateSend(size_t data_size) { + size_t new_buffer_size = send_buffer_.size() - data_size; + // Avoid undefined access beyond the last element of the vector. + // This only happens when new_buffer_size is 0. + if (data_size < send_buffer_.size()) { + // memmove is required for potentially overlapping source/destination. 
+ memmove(&send_buffer_[0], &send_buffer_[data_size], new_buffer_size); + } + send_buffer_.resize(new_buffer_size); +} + +void VirtualSocket::MaybeSignalWriteEvent(size_t capacity) { + if (!ready_to_send_ && (send_buffer_.size() < capacity)) { + ready_to_send_ = true; + SignalWriteEvent(this); + } +} + +uint32_t VirtualSocket::AddPacket(int64_t cur_time, size_t packet_size) { + network_size_ += packet_size; + uint32_t send_delay = + server_->SendDelay(static_cast(network_size_)); + + NetworkEntry entry; + entry.size = packet_size; + entry.done_time = cur_time + send_delay; + network_.push_back(entry); + + return send_delay; +} + +int64_t VirtualSocket::UpdateOrderedDelivery(int64_t ts) { + // Ensure that new packets arrive after previous ones + ts = std::max(ts, last_delivery_time_); + // A socket should not have both ordered and unordered delivery, so its last + // delivery time only needs to be updated when it has ordered delivery. + last_delivery_time_ = ts; + return ts; +} + +size_t VirtualSocket::PurgeNetworkPackets(int64_t cur_time) { + webrtc::MutexLock lock(&mutex_); + + while (!network_.empty() && (network_.front().done_time <= cur_time)) { + RTC_DCHECK(network_size_ >= network_.front().size); + network_size_ -= network_.front().size; + network_.pop_front(); + } + return network_size_; +} + VirtualSocketServer::VirtualSocketServer() : VirtualSocketServer(nullptr) {} VirtualSocketServer::VirtualSocketServer(ThreadProcessingFakeClock* fake_clock) @@ -596,17 +619,11 @@ AsyncSocket* VirtualSocketServer::CreateAsyncSocket(int family, int type) { } VirtualSocket* VirtualSocketServer::CreateSocketInternal(int family, int type) { - VirtualSocket* socket = new VirtualSocket(this, family, type, true); - SignalSocketCreated(socket); - return socket; + return new VirtualSocket(this, family, type, true); } void VirtualSocketServer::SetMessageQueue(Thread* msg_queue) { msg_queue_ = msg_queue; - if (msg_queue_) { - msg_queue_->SignalQueueDestroyed.connect( - this, &VirtualSocketServer::OnMessageQueueDestroyed); - } } bool VirtualSocketServer::Wait(int cmsWait, bool process_io) { @@ -814,19 +831,98 @@ bool VirtualSocketServer::Disconnect(VirtualSocket* socket) { return false; } +bool VirtualSocketServer::Disconnect(const SocketAddress& addr) { + return Disconnect(LookupBinding(addr)); +} + +bool VirtualSocketServer::Disconnect(const SocketAddress& local_addr, + const SocketAddress& remote_addr) { + // Disconnect remote socket, check if it is a child of a server socket. + VirtualSocket* socket = LookupConnection(local_addr, remote_addr); + if (!socket) { + // Not a server socket child, then see if it is bound. + // TODO(tbd): If this is indeed a server socket that has no + // children this will cause the server socket to be + // closed. This might lead to unexpected results, how to fix this? + socket = LookupBinding(remote_addr); + } + Disconnect(socket); + + // Remove mapping for both directions. + RemoveConnection(remote_addr, local_addr); + RemoveConnection(local_addr, remote_addr); + return socket != nullptr; +} + +void VirtualSocketServer::CancelConnects(VirtualSocket* socket) { + MessageList msgs; + if (msg_queue_) { + msg_queue_->Clear(socket, MSG_ID_CONNECT, &msgs); + } + for (MessageList::iterator it = msgs.begin(); it != msgs.end(); ++it) { + RTC_DCHECK(nullptr != it->pdata); + MessageAddress* data = static_cast(it->pdata); + SocketAddress local_addr = socket->GetLocalAddress(); + // Lookup remote side. 
+ VirtualSocket* socket = LookupConnection(local_addr, data->addr); + if (socket) { + // Server socket, remote side is a socket retreived by + // accept. Accepted sockets are not bound so we will not + // find it by looking in the bindings table. + Disconnect(socket); + RemoveConnection(local_addr, data->addr); + } else { + Disconnect(data->addr); + } + delete data; + } +} + +void VirtualSocketServer::Clear(VirtualSocket* socket) { + // Clear incoming packets and disconnect messages + if (msg_queue_) { + msg_queue_->Clear(socket); + } +} + +void VirtualSocketServer::ProcessOneMessage() { + Message msg; + msg_queue_->Get(&msg); + msg_queue_->Dispatch(&msg); +} + +void VirtualSocketServer::PostSignalReadEvent(VirtualSocket* socket) { + // Clear the message so it doesn't end up posted multiple times. + msg_queue_->Clear(socket, MSG_ID_SIGNALREADEVENT); + msg_queue_->Post(RTC_FROM_HERE, socket, MSG_ID_SIGNALREADEVENT); +} + int VirtualSocketServer::SendUdp(VirtualSocket* socket, const char* data, size_t data_size, const SocketAddress& remote_addr) { ++sent_packets_; if (sending_blocked_) { - CritScope cs(&socket->crit_); - socket->ready_to_send_ = false; - socket->error_ = EWOULDBLOCK; + socket->SetToBlocked(); return -1; } + if (data_size > largest_seen_udp_payload_) { + if (data_size > 1000) { + RTC_LOG(LS_VERBOSE) << "Largest UDP seen is " << data_size; + } + largest_seen_udp_payload_ = data_size; + } + // See if we want to drop this packet. + if (data_size > max_udp_payload_) { + RTC_LOG(LS_VERBOSE) << "Dropping too large UDP payload of size " + << data_size << ", UDP payload limit is " + << max_udp_payload_; + // Return as if send was successful; packet disappears. + return data_size; + } + if (Random() < drop_prob_) { RTC_LOG(LS_VERBOSE) << "Dropping packet: bad luck"; return static_cast(data_size); @@ -856,10 +952,8 @@ int VirtualSocketServer::SendUdp(VirtualSocket* socket, } { - CritScope cs(&socket->crit_); - int64_t cur_time = TimeMillis(); - PurgeNetworkPackets(socket, cur_time); + size_t network_size = socket->PurgeNetworkPackets(cur_time); // Determine whether we have enough bandwidth to accept this packet. To do // this, we need to update the send queue. Once we know it's current size, @@ -870,7 +964,7 @@ int VirtualSocketServer::SendUdp(VirtualSocket* socket, // simulation of what a normal network would do. size_t packet_size = data_size + UDP_HEADER_SIZE; - if (socket->network_size_ + packet_size > network_capacity_) { + if (network_size + packet_size > network_capacity_) { RTC_LOG(LS_VERBOSE) << "Dropping packet: network capacity exceeded"; return static_cast(data_size); } @@ -898,45 +992,36 @@ void VirtualSocketServer::SendTcp(VirtualSocket* socket) { // Lookup the local/remote pair in the connections table. 
VirtualSocket* recipient = - LookupConnection(socket->local_addr_, socket->remote_addr_); + LookupConnection(socket->GetLocalAddress(), socket->GetRemoteAddress()); if (!recipient) { RTC_LOG(LS_VERBOSE) << "Sending data to no one."; return; } - CritScope cs(&socket->crit_); - int64_t cur_time = TimeMillis(); - PurgeNetworkPackets(socket, cur_time); + socket->PurgeNetworkPackets(cur_time); while (true) { - size_t available = recv_buffer_capacity_ - recipient->recv_buffer_size_; + size_t available = recv_buffer_capacity_ - recipient->recv_buffer_size(); size_t max_data_size = std::min(available, TCP_MSS - TCP_HEADER_SIZE); - size_t data_size = std::min(socket->send_buffer_.size(), max_data_size); + size_t data_size = std::min(socket->send_buffer_size(), max_data_size); if (0 == data_size) break; - AddPacketToNetwork(socket, recipient, cur_time, &socket->send_buffer_[0], + AddPacketToNetwork(socket, recipient, cur_time, socket->send_buffer_data(), data_size, TCP_HEADER_SIZE, true); - recipient->recv_buffer_size_ += data_size; - - size_t new_buffer_size = socket->send_buffer_.size() - data_size; - // Avoid undefined access beyond the last element of the vector. - // This only happens when new_buffer_size is 0. - if (data_size < socket->send_buffer_.size()) { - // memmove is required for potentially overlapping source/destination. - memmove(&socket->send_buffer_[0], &socket->send_buffer_[data_size], - new_buffer_size); - } - socket->send_buffer_.resize(new_buffer_size); + recipient->UpdateRecv(data_size); + socket->UpdateSend(data_size); } - if (!socket->ready_to_send_ && - (socket->send_buffer_.size() < send_buffer_capacity_)) { - socket->ready_to_send_ = true; - socket->SignalWriteEvent(socket); - } + socket->MaybeSignalWriteEvent(send_buffer_capacity_); +} + +void VirtualSocketServer::SendTcp(const SocketAddress& addr) { + VirtualSocket* sender = LookupBinding(addr); + RTC_DCHECK(nullptr != sender); + SendTcp(sender); } void VirtualSocketServer::AddPacketToNetwork(VirtualSocket* sender, @@ -946,13 +1031,7 @@ void VirtualSocketServer::AddPacketToNetwork(VirtualSocket* sender, size_t data_size, size_t header_size, bool ordered) { - VirtualSocket::NetworkEntry entry; - entry.size = data_size + header_size; - - sender->network_size_ += entry.size; - uint32_t send_delay = SendDelay(static_cast(sender->network_size_)); - entry.done_time = cur_time + send_delay; - sender->network_.push_back(entry); + uint32_t send_delay = sender->AddPacket(cur_time, data_size + header_size); // Find the delay for crossing the many virtual hops of the network. uint32_t transit_delay = GetTransitDelay(sender); @@ -960,7 +1039,7 @@ void VirtualSocketServer::AddPacketToNetwork(VirtualSocket* sender, // When the incoming packet is from a binding of the any address, translate it // to the default route here such that the recipient will see the default // route. 
- SocketAddress sender_addr = sender->local_addr_; + SocketAddress sender_addr = sender->GetLocalAddress(); IPAddress default_ip = GetDefaultRoute(sender_addr.ipaddr().family()); if (sender_addr.IsAnyIP() && !IPIsUnspec(default_ip)) { sender_addr.SetIP(default_ip); @@ -971,25 +1050,11 @@ void VirtualSocketServer::AddPacketToNetwork(VirtualSocket* sender, int64_t ts = TimeAfter(send_delay + transit_delay); if (ordered) { - // Ensure that new packets arrive after previous ones - ts = std::max(ts, sender->last_delivery_time_); - // A socket should not have both ordered and unordered delivery, so its last - // delivery time only needs to be updated when it has ordered delivery. - sender->last_delivery_time_ = ts; + ts = sender->UpdateOrderedDelivery(ts); } msg_queue_->PostAt(RTC_FROM_HERE, ts, recipient, MSG_ID_PACKET, p); } -void VirtualSocketServer::PurgeNetworkPackets(VirtualSocket* socket, - int64_t cur_time) { - while (!socket->network_.empty() && - (socket->network_.front().done_time <= cur_time)) { - RTC_DCHECK(socket->network_size_ >= socket->network_.front().size); - socket->network_size_ -= socket->network_.front().size; - socket->network_.pop_front(); - } -} - uint32_t VirtualSocketServer::SendDelay(uint32_t size) { if (bandwidth_ == 0) return 0; @@ -1019,13 +1084,7 @@ void PrintFunction(std::vector >* f) { #endif // void VirtualSocketServer::UpdateDelayDistribution() { - Function* dist = - CreateDistribution(delay_mean_, delay_stddev_, delay_samples_); - // We take a lock just to make sure we don't leak memory. - { - CritScope cs(&delay_crit_); - delay_dist_.reset(dist); - } + delay_dist_ = CreateDistribution(delay_mean_, delay_stddev_, delay_samples_); } static double PI = 4 * atan(1.0); @@ -1044,11 +1103,11 @@ static double Pareto(double x, double min, double k) { } #endif -VirtualSocketServer::Function* VirtualSocketServer::CreateDistribution( - uint32_t mean, - uint32_t stddev, - uint32_t samples) { - Function* f = new Function(); +std::unique_ptr +VirtualSocketServer::CreateDistribution(uint32_t mean, + uint32_t stddev, + uint32_t samples) { + auto f = std::make_unique(); if (0 == stddev) { f->push_back(Point(mean, 1.0)); @@ -1064,7 +1123,7 @@ VirtualSocketServer::Function* VirtualSocketServer::CreateDistribution( f->push_back(Point(x, y)); } } - return Resample(Invert(Accumulate(f)), 0, 1, samples); + return Resample(Invert(Accumulate(std::move(f))), 0, 1, samples); } uint32_t VirtualSocketServer::GetTransitDelay(Socket* socket) { @@ -1093,7 +1152,8 @@ struct FunctionDomainCmp { } }; -VirtualSocketServer::Function* VirtualSocketServer::Accumulate(Function* f) { +std::unique_ptr VirtualSocketServer::Accumulate( + std::unique_ptr f) { RTC_DCHECK(f->size() >= 1); double v = 0; for (Function::size_type i = 0; i < f->size() - 1; ++i) { @@ -1106,7 +1166,8 @@ VirtualSocketServer::Function* VirtualSocketServer::Accumulate(Function* f) { return f; } -VirtualSocketServer::Function* VirtualSocketServer::Invert(Function* f) { +std::unique_ptr VirtualSocketServer::Invert( + std::unique_ptr f) { for (Function::size_type i = 0; i < f->size(); ++i) std::swap((*f)[i].first, (*f)[i].second); @@ -1114,24 +1175,25 @@ VirtualSocketServer::Function* VirtualSocketServer::Invert(Function* f) { return f; } -VirtualSocketServer::Function* VirtualSocketServer::Resample(Function* f, - double x1, - double x2, - uint32_t samples) { - Function* g = new Function(); +std::unique_ptr VirtualSocketServer::Resample( + std::unique_ptr f, + double x1, + double x2, + uint32_t samples) { + auto g = 
std::make_unique(); for (size_t i = 0; i < samples; i++) { double x = x1 + (x2 - x1) * i / (samples - 1); - double y = Evaluate(f, x); + double y = Evaluate(f.get(), x); g->push_back(Point(x, y)); } - delete f; return g; } -double VirtualSocketServer::Evaluate(Function* f, double x) { - Function::iterator iter = absl::c_lower_bound(*f, x, FunctionDomainCmp()); +double VirtualSocketServer::Evaluate(const Function* f, double x) { + Function::const_iterator iter = + absl::c_lower_bound(*f, x, FunctionDomainCmp()); if (iter == f->begin()) { return (*f)[0].second; } else if (iter == f->end()) { diff --git a/rtc_base/virtual_socket_server.h b/rtc_base/virtual_socket_server.h index 84f8fb1bdc..6c58a4bdfe 100644 --- a/rtc_base/virtual_socket_server.h +++ b/rtc_base/virtual_socket_server.h @@ -17,11 +17,11 @@ #include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" -#include "rtc_base/deprecated/recursive_critical_section.h" #include "rtc_base/event.h" #include "rtc_base/fake_clock.h" #include "rtc_base/message_handler.h" #include "rtc_base/socket_server.h" +#include "rtc_base/synchronization/mutex.h" namespace rtc { @@ -33,7 +33,7 @@ class SocketAddressPair; // interface can create as many addresses as you want. All of the sockets // created by this network will be able to communicate with one another, unless // they are bound to addresses from incompatible families. -class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { +class VirtualSocketServer : public SocketServer { public: VirtualSocketServer(); // This constructor needs to be used if the test uses a fake clock and @@ -94,6 +94,16 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { drop_prob_ = drop_prob; } + // Controls the maximum UDP payload for the networks simulated + // by this server. Any UDP payload sent that is larger than this will + // be dropped. + size_t max_udp_payload() { return max_udp_payload_; } + void set_max_udp_payload(size_t payload_size) { + max_udp_payload_ = payload_size; + } + + size_t largest_seen_udp_payload() { return largest_seen_udp_payload_; } + // If |blocked| is true, subsequent attempts to send will result in -1 being // returned, with the socket error set to EWOULDBLOCK. // @@ -130,9 +140,9 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { typedef std::pair Point; typedef std::vector Function; - static Function* CreateDistribution(uint32_t mean, - uint32_t stddev, - uint32_t samples); + static std::unique_ptr CreateDistribution(uint32_t mean, + uint32_t stddev, + uint32_t samples); // Similar to Thread::ProcessMessages, but it only processes messages until // there are no immediate messages or pending network traffic. Returns false @@ -151,25 +161,12 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { // socket server. Intended to be used for test assertions. uint32_t sent_packets() const { return sent_packets_; } - // For testing purpose only. Fired when a client socket is created. - sigslot::signal1 SignalSocketCreated; - - protected: - // Returns a new IP not used before in this network. - IPAddress GetNextIP(int family); - uint16_t GetNextPort(); - - VirtualSocket* CreateSocketInternal(int family, int type); - // Binds the given socket to addr, assigning and IP and Port if necessary int Bind(VirtualSocket* socket, SocketAddress* addr); // Binds the given socket to the given (fully-defined) address. 
int Bind(VirtualSocket* socket, const SocketAddress& addr); - // Find the socket bound to the given address - VirtualSocket* LookupBinding(const SocketAddress& addr); - int Unbind(const SocketAddress& addr, VirtualSocket* socket); // Adds a mapping between this socket pair and the socket. @@ -177,13 +174,6 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { const SocketAddress& server, VirtualSocket* socket); - // Find the socket pair corresponding to this server address. - VirtualSocket* LookupConnection(const SocketAddress& client, - const SocketAddress& server); - - void RemoveConnection(const SocketAddress& client, - const SocketAddress& server); - // Connects the given socket to the socket at the given address int Connect(VirtualSocket* socket, const SocketAddress& remote_addr, @@ -192,6 +182,13 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { // Sends a disconnect message to the socket at the given address bool Disconnect(VirtualSocket* socket); + // Lookup address, and disconnect corresponding socket. + bool Disconnect(const SocketAddress& addr); + + // Lookup connection, close corresponding socket. + bool Disconnect(const SocketAddress& local_addr, + const SocketAddress& remote_addr); + // Sends the given packet to the socket at the given address (if one exists). int SendUdp(VirtualSocket* socket, const char* data, @@ -201,6 +198,44 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { // Moves as much data as possible from the sender's buffer to the network void SendTcp(VirtualSocket* socket); + // Like above, but lookup sender by address. + void SendTcp(const SocketAddress& addr); + + // Computes the number of milliseconds required to send a packet of this size. + uint32_t SendDelay(uint32_t size); + + // Cancel attempts to connect to a socket that is being closed. + void CancelConnects(VirtualSocket* socket); + + // Clear incoming messages for a socket that is being closed. + void Clear(VirtualSocket* socket); + + void ProcessOneMessage(); + + void PostSignalReadEvent(VirtualSocket* socket); + + // Sending was previously blocked, but now isn't. + sigslot::signal0<> SignalReadyToSend; + + protected: + // Returns a new IP not used before in this network. + IPAddress GetNextIP(int family); + + // Find the socket bound to the given address + VirtualSocket* LookupBinding(const SocketAddress& addr); + + private: + uint16_t GetNextPort(); + + VirtualSocket* CreateSocketInternal(int family, int type); + + // Find the socket pair corresponding to this server address. + VirtualSocket* LookupConnection(const SocketAddress& client, + const SocketAddress& server); + + void RemoveConnection(const SocketAddress& client, + const SocketAddress& server); + // Places a packet on the network. void AddPacketToNetwork(VirtualSocket* socket, VirtualSocket* recipient, @@ -210,31 +245,19 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { size_t header_size, bool ordered); - // Removes stale packets from the network - void PurgeNetworkPackets(VirtualSocket* socket, int64_t cur_time); - - // Computes the number of milliseconds required to send a packet of this size. - uint32_t SendDelay(uint32_t size); - // If the delay has been set for the address of the socket, returns the set // delay. Otherwise, returns a random transit delay chosen from the // appropriate distribution. uint32_t GetTransitDelay(Socket* socket); - // Basic operations on functions. 
Those that return a function also take - // ownership of the function given (and hence, may modify or delete it). - static Function* Accumulate(Function* f); - static Function* Invert(Function* f); - static Function* Resample(Function* f, - double x1, - double x2, - uint32_t samples); - static double Evaluate(Function* f, double x); - - // Null out our message queue if it goes away. Necessary in the case where - // our lifetime is greater than that of the thread we are using, since we - // try to send Close messages for all connected sockets when we shutdown. - void OnMessageQueueDestroyed() { msg_queue_ = nullptr; } + // Basic operations on functions. + static std::unique_ptr Accumulate(std::unique_ptr f); + static std::unique_ptr Invert(std::unique_ptr f); + static std::unique_ptr Resample(std::unique_ptr f, + double x1, + double x2, + uint32_t samples); + static double Evaluate(const Function* f, double x); // Determine if two sockets should be able to communicate. // We don't (currently) specify an address family for sockets; instead, @@ -254,12 +277,6 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { // NB: This scheme doesn't permit non-dualstack IPv6 sockets. static bool CanInteractWith(VirtualSocket* local, VirtualSocket* remote); - private: - friend class VirtualSocket; - - // Sending was previously blocked, but now isn't. - sigslot::signal0<> SignalReadyToSend; - typedef std::map AddressMap; typedef std::map ConnectionMap; @@ -295,9 +312,14 @@ class VirtualSocketServer : public SocketServer, public sigslot::has_slots<> { std::map alternative_address_mapping_; std::unique_ptr delay_dist_; - RecursiveCriticalSection delay_crit_; - double drop_prob_; + // The largest UDP payload permitted on this virtual socket server. + // The default is the max size of IPv4 fragmented UDP packet payload: + // 65535 bytes - 8 bytes UDP header - 20 bytes IP header. + size_t max_udp_payload_ = 65507; + // The largest UDP payload seen so far. + size_t largest_seen_udp_payload_ = 0; + bool sending_blocked_ = false; RTC_DISALLOW_COPY_AND_ASSIGN(VirtualSocketServer); }; @@ -334,11 +356,30 @@ class VirtualSocket : public AsyncSocket, int SetOption(Option opt, int value) override; void OnMessage(Message* pmsg) override; + size_t recv_buffer_size() const { return recv_buffer_size_; } + size_t send_buffer_size() const { return send_buffer_.size(); } + const char* send_buffer_data() const { return send_buffer_.data(); } + + // Used by server sockets to set the local address without binding. + void SetLocalAddress(const SocketAddress& addr); + bool was_any() { return was_any_; } void set_was_any(bool was_any) { was_any_ = was_any; } - // For testing purpose only. Fired when client socket is bound to an address. - sigslot::signal2 SignalAddressReady; + void SetToBlocked(); + + void UpdateRecv(size_t data_size); + void UpdateSend(size_t data_size); + + void MaybeSignalWriteEvent(size_t capacity); + + // Adds a packet to be sent. Returns delay, based on network_size_. + uint32_t AddPacket(int64_t cur_time, size_t packet_size); + + int64_t UpdateOrderedDelivery(int64_t ts); + + // Removes stale packets from the network. Returns current size. 
+ size_t PurgeNetworkPackets(int64_t cur_time); private: struct NetworkEntry { @@ -353,25 +394,23 @@ class VirtualSocket : public AsyncSocket, typedef std::map OptionsMap; int InitiateConnect(const SocketAddress& addr, bool use_delay); - void CompleteConnect(const SocketAddress& addr, bool notify); + void CompleteConnect(const SocketAddress& addr); int SendUdp(const void* pv, size_t cb, const SocketAddress& addr); int SendTcp(const void* pv, size_t cb); - // Used by server sockets to set the local address without binding. - void SetLocalAddress(const SocketAddress& addr); - void OnSocketServerReadyToSend(); - VirtualSocketServer* server_; - int type_; - bool async_; + VirtualSocketServer* const server_; + const int type_; + const bool async_; ConnState state_; int error_; SocketAddress local_addr_; SocketAddress remote_addr_; // Pending sockets which can be Accepted - ListenQueue* listen_queue_; + std::unique_ptr listen_queue_ RTC_GUARDED_BY(mutex_) + RTC_PT_GUARDED_BY(mutex_); // Data which tcp has buffered for sending SendBuffer send_buffer_; @@ -379,8 +418,8 @@ class VirtualSocket : public AsyncSocket, // Set back to true when the socket can send again. bool ready_to_send_ = true; - // Critical section to protect the recv_buffer and queue_ - RecursiveCriticalSection crit_; + // Mutex to protect the recv_buffer and listen_queue_ + webrtc::Mutex mutex_; // Network model that enforces bandwidth and capacity constraints NetworkQueue network_; @@ -390,7 +429,7 @@ class VirtualSocket : public AsyncSocket, int64_t last_delivery_time_ = 0; // Data which has been received from the network - RecvBuffer recv_buffer_; + RecvBuffer recv_buffer_ RTC_GUARDED_BY(mutex_); // The amount of data which is in flight or in recv_buffer_ size_t recv_buffer_size_; @@ -405,8 +444,6 @@ class VirtualSocket : public AsyncSocket, // Store the options that are set OptionsMap options_map_; - - friend class VirtualSocketServer; }; } // namespace rtc diff --git a/rtc_base/virtual_socket_unittest.cc b/rtc_base/virtual_socket_unittest.cc index 78003f5cb2..96a359d187 100644 --- a/rtc_base/virtual_socket_unittest.cc +++ b/rtc_base/virtual_socket_unittest.cc @@ -1117,10 +1117,10 @@ TEST_F(VirtualSocketServerTest, CreatesStandardDistribution) { ASSERT_LT(0u, kTestSamples[sidx]); const uint32_t kStdDev = static_cast(kTestDev[didx] * kTestMean[midx]); - VirtualSocketServer::Function* f = + std::unique_ptr f = VirtualSocketServer::CreateDistribution(kTestMean[midx], kStdDev, kTestSamples[sidx]); - ASSERT_TRUE(nullptr != f); + ASSERT_TRUE(nullptr != f.get()); ASSERT_EQ(kTestSamples[sidx], f->size()); double sum = 0; for (uint32_t i = 0; i < f->size(); ++i) { @@ -1139,7 +1139,6 @@ TEST_F(VirtualSocketServerTest, CreatesStandardDistribution) { EXPECT_NEAR(kStdDev, stddev, 0.1 * kStdDev) << "M=" << kTestMean[midx] << " SD=" << kStdDev << " N=" << kTestSamples[sidx]; - delete f; } } } diff --git a/rtc_base/weak_ptr.h b/rtc_base/weak_ptr.h index 68d57fc557..a9e6b3a990 100644 --- a/rtc_base/weak_ptr.h +++ b/rtc_base/weak_ptr.h @@ -15,9 +15,9 @@ #include #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "rtc_base/ref_count.h" #include "rtc_base/ref_counted_object.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" // The implementation is borrowed from chromium except that it does not diff --git a/rtc_base/win/create_direct3d_device.h b/rtc_base/win/create_direct3d_device.h index 102f74148c..7c21f8720a 100644 --- a/rtc_base/win/create_direct3d_device.h +++ 
b/rtc_base/win/create_direct3d_device.h @@ -11,7 +11,7 @@ #ifndef RTC_BASE_WIN_CREATE_DIRECT3D_DEVICE_H_ #define RTC_BASE_WIN_CREATE_DIRECT3D_DEVICE_H_ -#include +#include #include #include #include diff --git a/rtc_tools/BUILD.gn b/rtc_tools/BUILD.gn index 9ba498c115..b841228a8e 100644 --- a/rtc_tools/BUILD.gn +++ b/rtc_tools/BUILD.gn @@ -24,27 +24,28 @@ group("rtc_tools") { ":rgba_to_i420_converter", ":video_quality_analysis", ] - if (rtc_enable_protobuf) { - deps += [ ":chart_proto" ] - } } - - if (rtc_include_tests) { + if (!build_with_chromium && rtc_enable_protobuf) { + deps += [ ":chart_proto" ] + } + if (!build_with_chromium && rtc_include_tests) { deps += [ ":tools_unittests", ":yuv_to_ivf_converter", ] - if (rtc_enable_protobuf) { - if (!build_with_chromium) { - deps += [ ":event_log_visualizer" ] - } - deps += [ - ":audioproc_f", - ":rtp_analyzer", - ":unpack_aecdump", - "network_tester", - ] - } + } + if (rtc_include_tests && rtc_enable_protobuf) { + deps += [ + ":rtp_analyzer", + "network_tester", + ] + } + if (rtc_include_tests && rtc_enable_protobuf && !build_with_chromium) { + deps += [ + ":audioproc_f", + ":event_log_visualizer", + ":unpack_aecdump", + ] } } @@ -113,6 +114,12 @@ rtc_library("video_quality_analysis") { absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } +# Abseil dependencies are not moved to the absl_deps field deliberately. +# If build_with_chromium is true, the absl_deps replaces the dependencies with +# the "//third_party/abseil-cpp:absl" target. Which doesn't include absl/flags +# (and some others) because they cannot be used in Chromiums. Special exception +# for the "frame_analyzer" target in "third_party/abseil-cpp/absl.gni" allows +# it to be build in chromium. rtc_executable("frame_analyzer") { visibility = [ "*" ] testonly = true @@ -129,6 +136,12 @@ rtc_executable("frame_analyzer") { "//third_party/abseil-cpp/absl/flags:parse", "//third_party/abseil-cpp/absl/strings", ] + + if (build_with_chromium) { + # When building from Chromium, WebRTC's metrics and field trial + # implementations need to be replaced by the Chromium ones. + deps += [ "//third_party/webrtc_overrides:webrtc_component" ] + } } # TODO(bugs.webrtc.org/11474): Enable this on win if needed. For now it @@ -148,6 +161,13 @@ if (!is_component_build) { # This target can be built from Chromium but it doesn't support # is_component_build=true because it depends on WebRTC testonly code # which is not part of //third_party/webrtc_overrides:webrtc_component. + + # Abseil dependencies are not moved to the absl_deps field deliberately. + # If build_with_chromium is true, the absl_deps replaces the dependencies with + # the "//third_party/abseil-cpp:absl" target. Which doesn't include absl/flags + # (and some others) because they cannot be used in Chromiums. Special exception + # for the "frame_analyzer" target in "third_party/abseil-cpp/absl.gni" allows + # it to be build in chromium. rtc_executable("rtp_generator") { visibility = [ "*" ] testonly = true @@ -181,6 +201,7 @@ if (!is_component_build) { "../rtc_base", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_json", + "../rtc_base:threading", "../rtc_base/system:file_wrapper", "../test:fileutils", "../test:rtp_test_utils", @@ -200,6 +221,13 @@ if (!is_component_build) { # This target can be built from Chromium but it doesn't support # is_component_build=true because it depends on WebRTC testonly code # which is not part of //third_party/webrtc_overrides:webrtc_component. 
+ + # Abseil dependencies are not moved to the absl_deps field deliberately. + # If build_with_chromium is true, the absl_deps replaces the dependencies with + # the "//third_party/abseil-cpp:absl" target. Which doesn't include absl/flags + # (and some others) because they cannot be used in Chromiums. Special exception + # for the "frame_analyzer" target in "third_party/abseil-cpp/absl.gni" allows + # it to be build in chromium. rtc_executable("video_replay") { visibility = [ "*" ] testonly = true @@ -209,11 +237,14 @@ if (!is_component_build) { "../api/task_queue:default_task_queue_factory", "../api/test/video:function_video_factory", "../api/transport:field_trial_based_config", + "../api/video:video_frame", "../api/video_codecs:video_codecs_api", "../call", "../call:call_interfaces", "../common_video", "../media:rtc_internal_video_codecs", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/video_coding:video_coding_utility", "../rtc_base:checks", "../rtc_base:rtc_json", "../rtc_base:stringutils", @@ -341,7 +372,6 @@ if (!build_with_chromium) { ":chart_proto", "../api:function_view", "../api:network_state_predictor_api", - "../rtc_base:deprecation", "../rtc_base:ignore_wundef", # TODO(kwiberg): Remove this dependency. @@ -368,10 +398,13 @@ if (!build_with_chromium) { "../rtc_base:rtc_base_approved", "../rtc_base:rtc_numerics", "../rtc_base:stringutils", + "../system_wrappers", "../test:explicit_key_value_config", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/base:core_headers", + "//third_party/abseil-cpp/absl/functional:bind_front", "//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/types:optional", ] @@ -380,152 +413,179 @@ if (!build_with_chromium) { } if (rtc_include_tests) { - rtc_executable("yuv_to_ivf_converter") { - visibility = [ "*" ] - testonly = true - sources = [ "converter/yuv_to_ivf_converter.cc" ] - deps = [ - "../api:create_frame_generator", - "../api:frame_generator_api", - "../api/task_queue:default_task_queue_factory", - "../api/video:encoded_image", - "../api/video:video_frame", - "../api/video_codecs:video_codecs_api", - "../media:rtc_media_base", - "../modules/rtp_rtcp:rtp_rtcp_format", - "../modules/video_coding:video_codec_interface", - "../modules/video_coding:video_coding_utility", - "../modules/video_coding:webrtc_h264", - "../modules/video_coding:webrtc_vp8", - "../modules/video_coding:webrtc_vp9", - "../rtc_base:checks", - "../rtc_base:criticalsection", - "../rtc_base:logging", - "../rtc_base:rtc_event", - "../rtc_base:rtc_task_queue", - "../rtc_base/synchronization:mutex", - "../rtc_base/system:file_wrapper", - "../test:video_test_common", - "../test:video_test_support", - "//third_party/abseil-cpp/absl/debugging:failure_signal_handler", - "//third_party/abseil-cpp/absl/debugging:symbolize", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/strings", - ] - } - - if (rtc_enable_protobuf && !build_with_chromium) { - rtc_executable("event_log_visualizer") { + if (!build_with_chromium) { + rtc_executable("yuv_to_ivf_converter") { + visibility = [ "*" ] testonly = true - sources = [ "rtc_event_log_visualizer/main.cc" ] - data = [ - # If --wav_filename is not provided, event_log_visualizer uses - # EN_script2_F_sp2_B1.wav by default. This is a good default to use - # for example with flags --plot=all when there is no need to use a - # specific .wav file. 
- "../resources/audio_processing/conversational_speech/EN_script2_F_sp2_B1.wav", - ] + sources = [ "converter/yuv_to_ivf_converter.cc" ] deps = [ - ":event_log_visualizer_utils", - "../api/neteq:neteq_api", - "../api/rtc_event_log", - "../logging:rtc_event_log_parser", - "../modules/audio_coding:neteq", + "../api:create_frame_generator", + "../api:frame_generator_api", + "../api/task_queue:default_task_queue_factory", + "../api/video:encoded_image", + "../api/video:video_frame", + "../api/video_codecs:video_codecs_api", + "../media:rtc_media_base", "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/video_coding:video_codec_interface", + "../modules/video_coding:video_coding_utility", + "../modules/video_coding:webrtc_h264", + "../modules/video_coding:webrtc_vp8", + "../modules/video_coding:webrtc_vp9", "../rtc_base:checks", - "../rtc_base:protobuf_utils", - "../rtc_base:rtc_base_approved", - "../system_wrappers:field_trial", - "../test:field_trial", - "../test:fileutils", - "../test:test_support", - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/flags:config", + "../rtc_base:criticalsection", + "../rtc_base:logging", + "../rtc_base:rtc_event", + "../rtc_base:rtc_task_queue", + "../rtc_base/synchronization:mutex", + "../rtc_base/system:file_wrapper", + "../test:video_test_common", + "../test:video_test_support", + "//third_party/abseil-cpp/absl/debugging:failure_signal_handler", + "//third_party/abseil-cpp/absl/debugging:symbolize", "//third_party/abseil-cpp/absl/flags:flag", "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/flags:usage", "//third_party/abseil-cpp/absl/strings", ] } - } - tools_unittests_resources = [ - "../resources/foreman_128x96.yuv", - "../resources/foreman_cif.yuv", - "../resources/reference_less_video_test_file.y4m", - ] + if (rtc_enable_protobuf) { + rtc_executable("event_log_visualizer") { + testonly = true + sources = [ "rtc_event_log_visualizer/main.cc" ] + data = [ + # If --wav_filename is not provided, event_log_visualizer uses + # EN_script2_F_sp2_B1.wav by default. This is a good default to use + # for example with flags --plot=all when there is no need to use a + # specific .wav file. 
+ "../resources/audio_processing/conversational_speech/EN_script2_F_sp2_B1.wav", + ] + deps = [ + ":event_log_visualizer_utils", + "../api/neteq:neteq_api", + "../api/rtc_event_log", + "../logging:rtc_event_log_parser", + "../modules/audio_coding:neteq", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../rtc_base:checks", + "../rtc_base:protobuf_utils", + "../rtc_base:rtc_base_approved", + "../system_wrappers:field_trial", + "../test:field_trial", + "../test:fileutils", + "../test:test_support", + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/flags:config", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/flags:usage", + "//third_party/abseil-cpp/absl/strings", + ] + } + } - if (is_ios) { - bundle_data("tools_unittests_bundle_data") { - testonly = true - sources = tools_unittests_resources - outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + tools_unittests_resources = [ + "../resources/foreman_128x96.yuv", + "../resources/foreman_cif.yuv", + "../resources/reference_less_video_test_file.y4m", + ] + + if (is_ios) { + bundle_data("tools_unittests_bundle_data") { + testonly = true + sources = tools_unittests_resources + outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + } } - } - rtc_test("tools_unittests") { - testonly = true + rtc_test("tools_unittests") { + testonly = true - sources = [ - "frame_analyzer/linear_least_squares_unittest.cc", - "frame_analyzer/reference_less_video_analysis_unittest.cc", - "frame_analyzer/video_color_aligner_unittest.cc", - "frame_analyzer/video_geometry_aligner_unittest.cc", - "frame_analyzer/video_quality_analysis_unittest.cc", - "frame_analyzer/video_temporal_aligner_unittest.cc", - "sanitizers_unittest.cc", - "video_file_reader_unittest.cc", - "video_file_writer_unittest.cc", - ] + sources = [ + "frame_analyzer/linear_least_squares_unittest.cc", + "frame_analyzer/reference_less_video_analysis_unittest.cc", + "frame_analyzer/video_color_aligner_unittest.cc", + "frame_analyzer/video_geometry_aligner_unittest.cc", + "frame_analyzer/video_quality_analysis_unittest.cc", + "frame_analyzer/video_temporal_aligner_unittest.cc", + "sanitizers_unittest.cc", + "video_file_reader_unittest.cc", + "video_file_writer_unittest.cc", + ] - deps = [ - ":video_file_reader", - ":video_file_writer", - ":video_quality_analysis", - "../api:scoped_refptr", - "../api/video:video_frame", - "../api/video:video_rtp_headers", - "../common_video", - "../rtc_base", - "../rtc_base:checks", - "../test:fileutils", - "../test:test_main", - "../test:test_support", - "//testing/gtest", - "//third_party/libyuv", - ] + deps = [ + ":video_file_reader", + ":video_file_writer", + ":video_quality_analysis", + "../api:scoped_refptr", + "../api/video:video_frame", + "../api/video:video_rtp_headers", + "../common_video", + "../rtc_base", + "../rtc_base:checks", + "../rtc_base:null_socket_server", + "../rtc_base:threading", + "../test:fileutils", + "../test:test_main", + "../test:test_support", + "//testing/gtest", + "//third_party/libyuv", + ] - if (!build_with_chromium) { - deps += [ ":reference_less_video_analysis_lib" ] + if (!build_with_chromium) { + deps += [ ":reference_less_video_analysis_lib" ] + } + + if (rtc_enable_protobuf) { + deps += [ "network_tester:network_tester_unittests" ] + } + + data = tools_unittests_resources + if (is_android) { + deps += [ "//testing/android/native_test:native_test_support" ] + shard_timeout = 900 + } + if (is_ios) { + deps += 
[ ":tools_unittests_bundle_data" ] + } } if (rtc_enable_protobuf) { - deps += [ "network_tester:network_tester_unittests" ] - } + rtc_executable("audioproc_f") { + testonly = true + sources = [ "audioproc_f/audioproc_float_main.cc" ] + deps = [ + "../api:audioproc_f_api", + "../modules/audio_processing", + "../modules/audio_processing:api", + "../rtc_base:rtc_base_approved", + ] + } - data = tools_unittests_resources - if (is_android) { - deps += [ "//testing/android/native_test:native_test_support" ] - shard_timeout = 900 - } - if (is_ios) { - deps += [ ":tools_unittests_bundle_data" ] + rtc_executable("unpack_aecdump") { + visibility = [ "*" ] + testonly = true + sources = [ "unpack_aecdump/unpack.cc" ] + + deps = [ + "../api:function_view", + "../common_audio", + "../modules/audio_processing", + "../modules/audio_processing:audioproc_debug_proto", + "../modules/audio_processing:audioproc_debug_proto", + "../modules/audio_processing:audioproc_protobuf_utils", + "../modules/audio_processing:audioproc_test_utils", + "../rtc_base:ignore_wundef", + "../rtc_base:protobuf_utils", + "../rtc_base:rtc_base_approved", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } # unpack_aecdump } } if (rtc_enable_protobuf) { - rtc_executable("audioproc_f") { - testonly = true - sources = [ "audioproc_f/audioproc_float_main.cc" ] - deps = [ - "../api:audioproc_f_api", - "../modules/audio_processing", - "../modules/audio_processing:api", - "../rtc_base:rtc_base_approved", - ] - } - copy("rtp_analyzer") { sources = [ "py_event_log_analyzer/misc.py", @@ -536,26 +596,5 @@ if (rtc_include_tests) { outputs = [ "$root_build_dir/{{source_file_part}}" ] deps = [ "../logging:rtc_event_log_proto" ] } # rtp_analyzer - - rtc_executable("unpack_aecdump") { - visibility = [ "*" ] - testonly = true - sources = [ "unpack_aecdump/unpack.cc" ] - - deps = [ - "../api:function_view", - "../common_audio", - "../modules/audio_processing", - "../modules/audio_processing:audioproc_debug_proto", - "../modules/audio_processing:audioproc_debug_proto", - "../modules/audio_processing:audioproc_protobuf_utils", - "../modules/audio_processing:audioproc_test_utils", - "../rtc_base:ignore_wundef", - "../rtc_base:protobuf_utils", - "../rtc_base:rtc_base_approved", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] - } # unpack_aecdump } } diff --git a/rtc_tools/DEPS b/rtc_tools/DEPS index 5ccd86b63b..3cf6080c93 100644 --- a/rtc_tools/DEPS +++ b/rtc_tools/DEPS @@ -29,4 +29,7 @@ specific_include_rules = { "+modules/video_coding/utility/ivf_file_writer.h", "+modules/video_coding/codecs/h264/include/h264.h", ], + ".*video_replay\.cc": [ + "+modules/video_coding/utility/ivf_file_writer.h", + ], } diff --git a/rtc_tools/frame_analyzer/video_geometry_aligner.cc b/rtc_tools/frame_analyzer/video_geometry_aligner.cc index db397bc3a5..88da26d4d0 100644 --- a/rtc_tools/frame_analyzer/video_geometry_aligner.cc +++ b/rtc_tools/frame_analyzer/video_geometry_aligner.cc @@ -61,7 +61,7 @@ rtc::scoped_refptr CropAndZoom( adjusted_frame->MutableDataY(), adjusted_frame->StrideY(), adjusted_frame->MutableDataU(), adjusted_frame->StrideU(), adjusted_frame->MutableDataV(), adjusted_frame->StrideV(), - frame->width(), frame->height(), libyuv::kFilterBilinear); + frame->width(), frame->height(), libyuv::kFilterBox); return adjusted_frame; } diff --git a/rtc_tools/network_tester/BUILD.gn b/rtc_tools/network_tester/BUILD.gn index b270262f0d..f7982d3eef 100644 --- 
a/rtc_tools/network_tester/BUILD.gn +++ b/rtc_tools/network_tester/BUILD.gn @@ -39,17 +39,20 @@ if (rtc_enable_protobuf) { deps = [ ":network_tester_config_proto", ":network_tester_packet_proto", + "../../api:sequence_checker", "../../api/task_queue", "../../api/task_queue:default_task_queue_factory", "../../p2p", "../../rtc_base", "../../rtc_base:checks", "../../rtc_base:ignore_wundef", + "../../rtc_base:ip_address", "../../rtc_base:protobuf_utils", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_task_queue", + "../../rtc_base:socket_address", + "../../rtc_base:threading", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/system:no_unique_address", "../../rtc_base/third_party/sigslot", ] diff --git a/rtc_tools/network_tester/packet_sender.h b/rtc_tools/network_tester/packet_sender.h index c0ea2c1680..7ccecdd84c 100644 --- a/rtc_tools/network_tester/packet_sender.h +++ b/rtc_tools/network_tester/packet_sender.h @@ -14,10 +14,10 @@ #include #include +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ignore_wundef.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" diff --git a/rtc_tools/network_tester/test_controller.h b/rtc_tools/network_tester/test_controller.h index b73ac94329..50055fcf4c 100644 --- a/rtc_tools/network_tester/test_controller.h +++ b/rtc_tools/network_tester/test_controller.h @@ -19,16 +19,15 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "p2p/base/basic_packet_socket_factory.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/ignore_wundef.h" #include "rtc_base/socket_address.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "rtc_tools/network_tester/packet_logger.h" #include "rtc_tools/network_tester/packet_sender.h" @@ -69,7 +68,7 @@ class TestController : public sigslot::has_slots<> { size_t len, const rtc::SocketAddress& remote_addr, const int64_t& packet_time_us); - rtc::ThreadChecker test_controller_thread_checker_; + SequenceChecker test_controller_thread_checker_; SequenceChecker packet_sender_checker_; rtc::BasicPacketSocketFactory socket_factory_; const std::string config_file_path_; diff --git a/rtc_tools/rtc_event_log_visualizer/analyze_audio.cc b/rtc_tools/rtc_event_log_visualizer/analyze_audio.cc index becc0044ab..02184a64ea 100644 --- a/rtc_tools/rtc_event_log_visualizer/analyze_audio.cc +++ b/rtc_tools/rtc_event_log_visualizer/analyze_audio.cc @@ -307,14 +307,10 @@ std::unique_ptr CreateNetEqTestAndRun( input.reset(new test::NetEqReplacementInput(std::move(input), kReplacementPt, cn_types, forbidden_types)); - NetEq::Config config; - config.max_packets_in_buffer = 200; - config.enable_fast_accelerate = true; - std::unique_ptr output(new test::VoidAudioSink()); rtc::scoped_refptr decoder_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( replacement_file_name, file_sample_rate_hz); test::NetEqTest::DecoderMap codecs = { @@ -330,6 +326,7 @@ std::unique_ptr CreateNetEqTestAndRun( callbacks.post_insert_packet = neteq_stats_getter->delay_analyzer(); callbacks.get_audio_callback = neteq_stats_getter.get(); + NetEq::Config config; 
test::NetEqTest test(config, decoder_factory, codecs, /*text_log=*/nullptr, /*factory=*/nullptr, std::move(input), std::move(output), callbacks); diff --git a/rtc_tools/rtc_event_log_visualizer/analyzer.cc b/rtc_tools/rtc_event_log_visualizer/analyzer.cc index a7153c6fbd..0f727f2815 100644 --- a/rtc_tools/rtc_event_log_visualizer/analyzer.cc +++ b/rtc_tools/rtc_event_log_visualizer/analyzer.cc @@ -19,6 +19,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/functional/bind_front.h" #include "absl/strings/string_view.h" #include "api/function_view.h" #include "api/network_state_predictor.h" @@ -445,6 +446,8 @@ void EventLogAnalyzer::CreateRtcpTypeGraph(PacketDirection direction, CreateRtcpTypeTimeSeries(parsed_log_.firs(direction), config_, "FIR", 7)); plot->AppendTimeSeries( CreateRtcpTypeTimeSeries(parsed_log_.plis(direction), config_, "PLI", 8)); + plot->AppendTimeSeries( + CreateRtcpTypeTimeSeries(parsed_log_.byes(direction), config_, "BYE", 9)); plot->SetXAxis(config_.CallBeginTimeSec(), config_.CallEndTimeSec(), "Time (s)", kLeftMargin, kRightMargin); plot->SetSuggestedYAxis(0, 1, "RTCP type", kBottomMargin, kTopMargin); @@ -456,7 +459,8 @@ void EventLogAnalyzer::CreateRtcpTypeGraph(PacketDirection direction, {5, "NACK"}, {6, "REMB"}, {7, "FIR"}, - {8, "PLI"}}); + {8, "PLI"}, + {9, "BYE"}}); } template @@ -1263,7 +1267,7 @@ void EventLogAnalyzer::CreateSendSideBweSimulationGraph(Plot* plot) { const RtpPacketType& rtp_packet = *rtp_iterator->second; if (rtp_packet.rtp.header.extension.hasTransportSequenceNumber) { RtpPacketSendInfo packet_info; - packet_info.ssrc = rtp_packet.rtp.header.ssrc; + packet_info.media_ssrc = rtp_packet.rtp.header.ssrc; packet_info.transport_sequence_number = rtp_packet.rtp.header.extension.transportSequenceNumber; packet_info.rtp_sequence_number = rtp_packet.rtp.header.sequenceNumber; @@ -1364,13 +1368,11 @@ void EventLogAnalyzer::CreateSendSideBweSimulationGraph(Plot* plot) { void EventLogAnalyzer::CreateReceiveSideBweSimulationGraph(Plot* plot) { using RtpPacketType = LoggedRtpPacketIncoming; - class RembInterceptingPacketRouter : public PacketRouter { + class RembInterceptor { public: - void OnReceiveBitrateChanged(const std::vector& ssrcs, - uint32_t bitrate_bps) override { + void SendRemb(uint32_t bitrate_bps, std::vector ssrcs) { last_bitrate_bps_ = bitrate_bps; bitrate_updated_ = true; - PacketRouter::OnReceiveBitrateChanged(ssrcs, bitrate_bps); } uint32_t last_bitrate_bps() const { return last_bitrate_bps_; } bool GetAndResetBitrateUpdated() { @@ -1397,10 +1399,10 @@ void EventLogAnalyzer::CreateReceiveSideBweSimulationGraph(Plot* plot) { } SimulatedClock clock(0); - RembInterceptingPacketRouter packet_router; - // TODO(terelius): The PacketRouter is used as the RemoteBitrateObserver. - // Is this intentional? - ReceiveSideCongestionController rscc(&clock, &packet_router); + RembInterceptor remb_interceptor; + ReceiveSideCongestionController rscc( + &clock, [](auto...) {}, + absl::bind_front(&RembInterceptor::SendRemb, &remb_interceptor), nullptr); // TODO(holmer): Log the call config and use that here instead. 
// static const uint32_t kDefaultStartBitrateBps = 300000; // rscc.SetBweBitrates(0, kDefaultStartBitrateBps, -1); @@ -1425,9 +1427,9 @@ void EventLogAnalyzer::CreateReceiveSideBweSimulationGraph(Plot* plot) { float x = config_.GetCallTimeSec(clock.TimeInMicroseconds()); acked_time_series.points.emplace_back(x, y); } - if (packet_router.GetAndResetBitrateUpdated() || + if (remb_interceptor.GetAndResetBitrateUpdated() || clock.TimeInMicroseconds() - last_update_us >= 1e6) { - uint32_t y = packet_router.last_bitrate_bps() / 1000; + uint32_t y = remb_interceptor.last_bitrate_bps() / 1000; float x = config_.GetCallTimeSec(clock.TimeInMicroseconds()); time_series.points.emplace_back(x, y); last_update_us = clock.TimeInMicroseconds(); diff --git a/rtc_tools/rtc_event_log_visualizer/log_simulation.cc b/rtc_tools/rtc_event_log_visualizer/log_simulation.cc index 0e5b5d04a9..c0b418de4b 100644 --- a/rtc_tools/rtc_event_log_visualizer/log_simulation.cc +++ b/rtc_tools/rtc_event_log_visualizer/log_simulation.cc @@ -14,6 +14,7 @@ #include "logging/rtc_event_log/rtc_event_processor.h" #include "modules/rtp_rtcp/source/time_util.h" +#include "system_wrappers/include/clock.h" namespace webrtc { @@ -83,7 +84,7 @@ void LogBasedNetworkControllerSimulation::OnPacketSent( } RtpPacketSendInfo packet_info; - packet_info.ssrc = packet.ssrc; + packet_info.media_ssrc = packet.ssrc; packet_info.transport_sequence_number = packet.transport_seq_no; packet_info.rtp_sequence_number = packet.stream_seq_no; packet_info.length = packet.size; @@ -142,11 +143,13 @@ void LogBasedNetworkControllerSimulation::OnReceiverReport( HandleStateUpdate(controller_->OnTransportLossReport(msg)); } + Clock* clock = Clock::GetRealTimeClock(); TimeDelta rtt = TimeDelta::PlusInfinity(); for (auto& rb : report.rr.report_blocks()) { if (rb.last_sr()) { + Timestamp report_log_time = Timestamp::Micros(report.log_time_us()); uint32_t receive_time_ntp = - CompactNtp(TimeMicrosToNtp(report.log_time_us())); + CompactNtp(clock->ConvertTimestampToNtpTime(report_log_time)); uint32_t rtt_ntp = receive_time_ntp - rb.delay_since_last_sr() - rb.last_sr(); rtt = std::min(rtt, TimeDelta::Millis(CompactNtpRttToMs(rtt_ntp))); diff --git a/rtc_tools/rtc_event_log_visualizer/plot_base.cc b/rtc_tools/rtc_event_log_visualizer/plot_base.cc index dce601a832..82533e6eb0 100644 --- a/rtc_tools/rtc_event_log_visualizer/plot_base.cc +++ b/rtc_tools/rtc_event_log_visualizer/plot_base.cc @@ -127,9 +127,8 @@ void Plot::PrintPythonCode() const { // There is a plt.bar function that draws bar plots, // but it is *way* too slow to be useful. printf( - "plt.vlines(x%zu, map(lambda t: min(t,0), y%zu), map(lambda t: " - "max(t,0), y%zu), color=colors[%zu], " - "label=\'%s\')\n", + "plt.vlines(x%zu, [min(t,0) for t in y%zu], [max(t,0) for t in " + "y%zu], color=colors[%zu], label=\'%s\')\n", i, i, i, i, series_list_[i].label.c_str()); if (series_list_[i].point_style == PointStyle::kHighlight) { printf( diff --git a/rtc_tools/rtc_event_log_visualizer/plot_base.h b/rtc_tools/rtc_event_log_visualizer/plot_base.h index 06a206f031..a26146b5e5 100644 --- a/rtc_tools/rtc_event_log_visualizer/plot_base.h +++ b/rtc_tools/rtc_event_log_visualizer/plot_base.h @@ -15,7 +15,7 @@ #include #include -#include "rtc_base/deprecation.h" +#include "absl/base/attributes.h" #include "rtc_base/ignore_wundef.h" RTC_PUSH_IGNORING_WUNDEF() @@ -101,8 +101,8 @@ class Plot { public: virtual ~Plot() {} - // Deprecated. Use PrintPythonCode() or ExportProtobuf() instead. 
- RTC_DEPRECATED virtual void Draw() {} + ABSL_DEPRECATED("Use PrintPythonCode() or ExportProtobuf() instead.") + virtual void Draw() {} // Sets the lower x-axis limit to min_value (if left_margin == 0). // Sets the upper x-axis limit to max_value (if right_margin == 0). @@ -189,8 +189,8 @@ class PlotCollection { public: virtual ~PlotCollection() {} - // Deprecated. Use PrintPythonCode() or ExportProtobuf() instead. - RTC_DEPRECATED virtual void Draw() {} + ABSL_DEPRECATED("Use PrintPythonCode() or ExportProtobuf() instead.") + virtual void Draw() {} virtual Plot* AppendNewPlot(); diff --git a/rtc_tools/rtc_event_log_visualizer/plot_protobuf.h b/rtc_tools/rtc_event_log_visualizer/plot_protobuf.h index 0773b58d20..fbe68853a3 100644 --- a/rtc_tools/rtc_event_log_visualizer/plot_protobuf.h +++ b/rtc_tools/rtc_event_log_visualizer/plot_protobuf.h @@ -10,6 +10,7 @@ #ifndef RTC_TOOLS_RTC_EVENT_LOG_VISUALIZER_PLOT_PROTOBUF_H_ #define RTC_TOOLS_RTC_EVENT_LOG_VISUALIZER_PLOT_PROTOBUF_H_ +#include "absl/base/attributes.h" #include "rtc_base/ignore_wundef.h" RTC_PUSH_IGNORING_WUNDEF() #include "rtc_tools/rtc_event_log_visualizer/proto/chart.pb.h" @@ -25,10 +26,10 @@ class ProtobufPlot final : public Plot { void Draw() override; }; -class ProtobufPlotCollection final : public PlotCollection { +class ABSL_DEPRECATED("Use PlotCollection and ExportProtobuf() instead") + ProtobufPlotCollection final : public PlotCollection { public: - // This class is deprecated. Use PlotCollection and ExportProtobuf() instead. - RTC_DEPRECATED ProtobufPlotCollection(); + ProtobufPlotCollection(); ~ProtobufPlotCollection() override; void Draw() override; Plot* AppendNewPlot() override; diff --git a/rtc_tools/rtc_event_log_visualizer/plot_python.h b/rtc_tools/rtc_event_log_visualizer/plot_python.h index 998ed7b221..6acc436d71 100644 --- a/rtc_tools/rtc_event_log_visualizer/plot_python.h +++ b/rtc_tools/rtc_event_log_visualizer/plot_python.h @@ -10,6 +10,7 @@ #ifndef RTC_TOOLS_RTC_EVENT_LOG_VISUALIZER_PLOT_PYTHON_H_ #define RTC_TOOLS_RTC_EVENT_LOG_VISUALIZER_PLOT_PYTHON_H_ +#include "absl/base/attributes.h" #include "rtc_tools/rtc_event_log_visualizer/plot_base.h" namespace webrtc { @@ -21,10 +22,10 @@ class PythonPlot final : public Plot { void Draw() override; }; -class PythonPlotCollection final : public PlotCollection { +class ABSL_DEPRECATED("Use PlotCollection and PrintPythonCode() instead.") + PythonPlotCollection final : public PlotCollection { public: - // This class is deprecated. Use PlotCollection and PrintPythonCode() instead. 
- RTC_DEPRECATED explicit PythonPlotCollection(bool shared_xaxis = false); + explicit PythonPlotCollection(bool shared_xaxis = false); ~PythonPlotCollection() override; void Draw() override; Plot* AppendNewPlot() override; diff --git a/rtc_tools/rtp_generator/rtp_generator.cc b/rtc_tools/rtp_generator/rtp_generator.cc index 21826c8dff..c2fc1cff06 100644 --- a/rtc_tools/rtp_generator/rtp_generator.cc +++ b/rtc_tools/rtp_generator/rtp_generator.cc @@ -136,10 +136,15 @@ absl::optional ParseRtpGeneratorOptionsFromFile( } // Parse the file as JSON - Json::Reader json_reader; + Json::CharReaderBuilder builder; Json::Value json; - if (!json_reader.parse(raw_json_buffer.data(), json)) { - RTC_LOG(LS_ERROR) << "Unable to parse the corpus config json file"; + std::string error_message; + std::unique_ptr json_reader(builder.newCharReader()); + if (!json_reader->parse(raw_json_buffer.data(), + raw_json_buffer.data() + raw_json_buffer.size(), + &json, &error_message)) { + RTC_LOG(LS_ERROR) << "Unable to parse the corpus config json file. Error:" + << error_message; return absl::nullopt; } @@ -188,15 +193,17 @@ RtpGenerator::RtpGenerator(const RtpGeneratorOptions& options) PayloadStringToCodecType(video_config.rtp.payload_name); if (video_config.rtp.payload_name == cricket::kVp8CodecName) { VideoCodecVP8 settings = VideoEncoder::GetDefaultVp8Settings(); - encoder_config.encoder_specific_settings = new rtc::RefCountedObject< - VideoEncoderConfig::Vp8EncoderSpecificSettings>(settings); + encoder_config.encoder_specific_settings = + rtc::make_ref_counted( + settings); } else if (video_config.rtp.payload_name == cricket::kVp9CodecName) { VideoCodecVP9 settings = VideoEncoder::GetDefaultVp9Settings(); - encoder_config.encoder_specific_settings = new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(settings); + encoder_config.encoder_specific_settings = + rtc::make_ref_counted( + settings); } else if (video_config.rtp.payload_name == cricket::kH264CodecName) { VideoCodecH264 settings = VideoEncoder::GetDefaultH264Settings(); - encoder_config.encoder_specific_settings = new rtc::RefCountedObject< + encoder_config.encoder_specific_settings = rtc::make_ref_counted< VideoEncoderConfig::H264EncoderSpecificSettings>(settings); } encoder_config.video_format.name = video_config.rtp.payload_name; @@ -217,7 +224,7 @@ RtpGenerator::RtpGenerator(const RtpGeneratorOptions& options) } encoder_config.video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( video_config.rtp.payload_name, /*max qp*/ 56, /*screencast*/ false, /*screenshare enabled*/ false); diff --git a/rtc_tools/unpack_aecdump/unpack.cc b/rtc_tools/unpack_aecdump/unpack.cc index ba3af129bf..4a98349820 100644 --- a/rtc_tools/unpack_aecdump/unpack.cc +++ b/rtc_tools/unpack_aecdump/unpack.cc @@ -81,6 +81,10 @@ ABSL_FLAG(bool, text, false, "Write non-audio files as text files instead of binary files."); +ABSL_FLAG(bool, + use_init_suffix, + false, + "Use init index instead of capture frame count as file name suffix."); #define PRINT_CONFIG(field_name) \ if (msg.has_##field_name()) { \ @@ -224,6 +228,16 @@ std::vector RuntimeSettingWriters() { })}; } +std::string GetWavFileIndex(int init_index, int frame_count) { + rtc::StringBuilder suffix; + if (absl::GetFlag(FLAGS_use_init_suffix)) { + suffix << "_" << init_index; + } else { + suffix << frame_count; + } + return suffix.str(); +} + } // namespace int do_main(int argc, char* argv[]) { @@ -243,6 +257,7 @@ int do_main(int argc, char* argv[]) { Event event_msg; int 
frame_count = 0; + int init_count = 0; size_t reverse_samples_per_channel = 0; size_t input_samples_per_channel = 0; size_t output_samples_per_channel = 0; @@ -452,9 +467,11 @@ int do_main(int argc, char* argv[]) { return 1; } + ++init_count; const Init msg = event_msg.init(); // These should print out zeros if they're missing. - fprintf(settings_file, "Init at frame: %d\n", frame_count); + fprintf(settings_file, "Init #%d at frame: %d\n", init_count, + frame_count); int input_sample_rate = msg.sample_rate(); fprintf(settings_file, " Input sample rate: %d\n", input_sample_rate); int output_sample_rate = msg.output_sample_rate(); @@ -495,24 +512,24 @@ int do_main(int argc, char* argv[]) { if (!absl::GetFlag(FLAGS_raw)) { // The WAV files need to be reset every time, because they cant change // their sample rate or number of channels. + + std::string suffix = GetWavFileIndex(init_count, frame_count); rtc::StringBuilder reverse_name; - reverse_name << absl::GetFlag(FLAGS_reverse_file) << frame_count - << ".wav"; + reverse_name << absl::GetFlag(FLAGS_reverse_file) << suffix << ".wav"; reverse_wav_file.reset(new WavWriter( reverse_name.str(), reverse_sample_rate, num_reverse_channels)); rtc::StringBuilder input_name; - input_name << absl::GetFlag(FLAGS_input_file) << frame_count << ".wav"; + input_name << absl::GetFlag(FLAGS_input_file) << suffix << ".wav"; input_wav_file.reset(new WavWriter(input_name.str(), input_sample_rate, num_input_channels)); rtc::StringBuilder output_name; - output_name << absl::GetFlag(FLAGS_output_file) << frame_count - << ".wav"; + output_name << absl::GetFlag(FLAGS_output_file) << suffix << ".wav"; output_wav_file.reset(new WavWriter( output_name.str(), output_sample_rate, num_output_channels)); if (WritingCallOrderFile()) { rtc::StringBuilder callorder_name; - callorder_name << absl::GetFlag(FLAGS_callorder_file) << frame_count + callorder_name << absl::GetFlag(FLAGS_callorder_file) << suffix << ".char"; callorder_char_file = OpenFile(callorder_name.str(), "wb"); } diff --git a/rtc_tools/video_file_reader.cc b/rtc_tools/video_file_reader.cc index b01fc0fcdd..bfdcba45fa 100644 --- a/rtc_tools/video_file_reader.cc +++ b/rtc_tools/video_file_reader.cc @@ -224,8 +224,8 @@ rtc::scoped_refptr`.h` and `.cc` files come in pairs - -`.h` and `.cc` files should come in pairs, with the same name (except -for the file type suffix), in the same directory, in the same build -target. - -* If a declaration in `path/to/foo.h` has a definition in some `.cc` - file, it should be in `path/to/foo.cc`. -* If a definition in `path/to/foo.cc` file has a declaration in some - `.h` file, it should be in `path/to/foo.h`. -* Omit the `.cc` file if it would have been empty, but still list the - `.h` file in a build target. -* Omit the `.h` file if it would have been empty. (This can happen - with unit test `.cc` files, and with `.cc` files that define - `main`.) - -This makes the source code easier to navigate and organize, and -precludes some questionable build system practices such as having -build targets that don’t pull in definitions for everything they -declare. - -[Examples and exceptions](style-guide/h-cc-pairs.md). - -### TODO comments - -Follow the [Google style][goog-style-todo]. When referencing a WebRTC bug, -prefer the url form, e.g. -``` -// TODO(bugs.webrtc.org/12345): Delete the hack when blocking bugs are resolved. 
-``` - -[goog-style-todo]: https://google.github.io/styleguide/cppguide.html#TODO_Comments - -### ArrayView - -When passing an array of values to a function, use `rtc::ArrayView` -whenever possible—that is, whenever you’re not passing ownership of -the array, and don’t allow the callee to change the array size. - -For example, - -instead of | use -------------------------------------|--------------------- -`const std::vector&` | `ArrayView` -`const T* ptr, size_t num_elements` | `ArrayView` -`T* ptr, size_t num_elements` | `ArrayView` - -See [the source](api/array_view.h) for more detailed docs. - -### sigslot - -sigslot is a lightweight library that adds a signal/slot language -construct to C++, making it easy to implement the observer pattern -with minimal boilerplate code. - -When adding a signal to a pure interface, **prefer to add a pure -virtual method that returns a reference to a signal**: - -``` -sigslot::signal& SignalFoo() = 0; -``` - -As opposed to making it a public member variable, as a lot of legacy -code does: - -``` -sigslot::signal SignalFoo; -``` - -The virtual method approach has the advantage that it keeps the -interface stateless, and gives the subclass more flexibility in how it -implements the signal. It may: - -* Have its own signal as a member variable. -* Use a `sigslot::repeater`, to repeat a signal of another object: - - ``` - sigslot::repeater foo_; - /* ... */ - foo_.repeat(bar_.SignalFoo()); - ``` -* Just return another object's signal directly, if the other object's - lifetime is the same as its own. - - ``` - sigslot::signal& SignalFoo() { return bar_.SignalFoo(); } - ``` - -### std::bind - -Don’t use `std::bind`—there are pitfalls, and lambdas are almost as -succinct and already familiar to modern C++ programmers. - -### std::function - -`std::function` is allowed, but remember that it’s not the right tool -for every occasion. Prefer to use interfaces when that makes sense, -and consider `rtc::FunctionView` for cases where the callee will not -save the function object. - -### Forward declarations - -WebRTC follows the [Google][goog-forward-declarations] C++ style guide -with respect to forward declarations. In summary: avoid using forward -declarations where possible; just `#include` the headers you need. - -[goog-forward-declarations]: https://google.github.io/styleguide/cppguide.html#Forward_Declarations - -## **C** - -There’s a substantial chunk of legacy C code in WebRTC, and a lot of -it is old enough that it violates the parts of the C++ style guide -that also applies to C (naming etc.) for the simple reason that it -pre-dates the use of the current C++ style guide for this code base. - -* If making small changes to C code, mimic the style of the - surrounding code. -* If making large changes to C code, consider converting the whole - thing to C++ first. - -## **Java** - -WebRTC follows the [Google Java style guide][goog-java-style]. - -[goog-java-style]: https://google.github.io/styleguide/javaguide.html - -## **Objective-C and Objective-C++** - -WebRTC follows the -[Chromium Objective-C and Objective-C++ style guide][chr-objc-style]. - -[chr-objc-style]: https://chromium.googlesource.com/chromium/src/+/HEAD/styleguide/objective-c/objective-c.md - -## **Python** - -WebRTC follows [Chromium’s Python style][chr-py-style]. 
- -[chr-py-style]: https://chromium.googlesource.com/chromium/src/+/HEAD/styleguide/styleguide.md#python - -## **Build files** - -The WebRTC build files are written in [GN][gn], and we follow -the [Chromium GN style guide][chr-gn-style]. Additionally, there are -some WebRTC-specific rules below; in case of conflict, they trump the -Chromium style guide. - -[gn]: https://chromium.googlesource.com/chromium/src/tools/gn/ -[chr-gn-style]: https://chromium.googlesource.com/chromium/src/tools/gn/+/HEAD/docs/style_guide.md - -### WebRTC-specific GN templates - -Use the following [GN templates][gn-templ] to ensure that all -our [targets][gn-target] are built with the same configuration: - -instead of | use ------------------|--------------------- -`executable` | `rtc_executable` -`shared_library` | `rtc_shared_library` -`source_set` | `rtc_source_set` -`static_library` | `rtc_static_library` -`test` | `rtc_test` - -[gn-templ]: https://chromium.googlesource.com/chromium/src/tools/gn/+/HEAD/docs/language.md#Templates -[gn-target]: https://chromium.googlesource.com/chromium/src/tools/gn/+/HEAD/docs/language.md#Targets - -### Target visibility and the native API - -The [WebRTC-specific GN templates](#webrtc-gn-templates) declare build -targets whose default `visibility` allows all other targets in the -WebRTC tree (and no targets outside the tree) to depend on them. - -Prefer to restrict the visibility if possible: - -* If a target is used by only one or a tiny number of other targets, - prefer to list them explicitly: `visibility = [ ":foo", ":bar" ]` -* If a target is used only by targets in the same `BUILD.gn` file: - `visibility = [ ":*" ]`. - -Setting `visibility = [ "*" ]` means that targets outside the WebRTC -tree can depend on this target; use this only for build targets whose -headers are part of the [native API](native-api.md). - -### Conditional compilation with the C preprocessor - -Avoid using the C preprocessor to conditionally enable or disable -pieces of code. But if you can’t avoid it, introduce a GN variable, -and then set a preprocessor constant to either 0 or 1 in the build -targets that need it: - -``` -if (apm_debug_dump) { - defines = [ "WEBRTC_APM_DEBUG_DUMP=1" ] -} else { - defines = [ "WEBRTC_APM_DEBUG_DUMP=0" ] -} -``` - -In the C, C++, or Objective-C files, use `#if` when testing the flag, -not `#ifdef` or `#if defined()`: - -``` -#if WEBRTC_APM_DEBUG_DUMP -// One way. -#else -// Or another. -#endif -``` - -When combined with the `-Wundef` compiler option, this produces -compile time warnings if preprocessor symbols are misspelled, or used -without corresponding build rules to set them. 
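To illustrate the `-Wundef` interaction described above, here is a small sketch; `WEBRTC_FOO_FEATURE` is a hypothetical flag used only for this example, assumed to be set to 0 or 1 by a GN rule in the same way as `WEBRTC_APM_DEBUG_DUMP`:

```
// Hypothetical flag, assumed to be defined to 0 or 1 by a GN build rule.
#if WEBRTC_FOO_FEATURE
// Taken when the GN rule sets the flag to 1. If the symbol is misspelled, or
// no build rule defines it at all, -Wundef produces a compile-time warning
// instead of silently falling through to the #else branch.
constexpr bool kFooEnabled = true;
#else
constexpr bool kFooEnabled = false;
#endif

// An #ifdef test, by contrast, treats a misspelled or unset symbol as simply
// "not defined": no diagnostic is emitted and the block is silently dropped.
#ifdef WEBRTC_FOO_FEATRUE  // Note the typo; the compiler stays quiet.
constexpr bool kNeverCompiled = true;
#endif
```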
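Similarly, for the `rtc::ArrayView` guidance earlier in the same style-guide text, a minimal sketch of the intended call pattern (the function and variable names are illustrative, not taken from this change):

```
#include <vector>

#include "api/array_view.h"

// The callee only reads the buffer; it does not take ownership or resize it,
// so a single rtc::ArrayView<const int> parameter covers all callers.
int Sum(rtc::ArrayView<const int> values) {
  int sum = 0;
  for (int v : values) {
    sum += v;
  }
  return sum;
}

int CallSites() {
  std::vector<int> from_vector = {1, 2, 3};
  const int from_buffer[] = {4, 5, 6};
  // The same signature serves a container and a raw pointer + length.
  return Sum(from_vector) + Sum(rtc::MakeArrayView(from_buffer, 3));
}
```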
diff --git a/system_wrappers/BUILD.gn b/system_wrappers/BUILD.gn index f44ff5b8bf..80088e0d01 100644 --- a/system_wrappers/BUILD.gn +++ b/system_wrappers/BUILD.gn @@ -108,7 +108,7 @@ rtc_library("metrics") { ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_test("system_wrappers_unittests") { testonly = true sources = [ @@ -130,9 +130,10 @@ if (rtc_include_tests) { "../test:test_main", "../test:test_support", "//testing/gtest", - "//third_party/abseil-cpp/absl/strings", ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] + if (is_android) { deps += [ "//testing/android/native_test:native_test_support" ] diff --git a/system_wrappers/include/clock.h b/system_wrappers/include/clock.h index 3c60f63da8..271291c214 100644 --- a/system_wrappers/include/clock.h +++ b/system_wrappers/include/clock.h @@ -32,22 +32,24 @@ const double kMagicNtpFractionalUnit = 4.294967296E+9; class RTC_EXPORT Clock { public: virtual ~Clock() {} + // Return a timestamp relative to an unspecified epoch. - virtual Timestamp CurrentTime() { - return Timestamp::Micros(TimeInMicroseconds()); + virtual Timestamp CurrentTime() = 0; + int64_t TimeInMilliseconds() { return CurrentTime().ms(); } + int64_t TimeInMicroseconds() { return CurrentTime().us(); } + + // Retrieve an NTP absolute timestamp (with an epoch of Jan 1, 1900). + // TODO(bugs.webrtc.org/11327): Make this non-virtual once + // "WebRTC-SystemIndependentNtpTimeKillSwitch" is removed. + virtual NtpTime CurrentNtpTime() { + return ConvertTimestampToNtpTime(CurrentTime()); } - virtual int64_t TimeInMilliseconds() { return CurrentTime().ms(); } - virtual int64_t TimeInMicroseconds() { return CurrentTime().us(); } - - // Retrieve an NTP absolute timestamp. - virtual NtpTime CurrentNtpTime() = 0; - - // Retrieve an NTP absolute timestamp in milliseconds. - virtual int64_t CurrentNtpInMilliseconds() = 0; + int64_t CurrentNtpInMilliseconds() { return CurrentNtpTime().ToMs(); } - // Converts an NTP timestamp to a millisecond timestamp. - static int64_t NtpToMs(uint32_t seconds, uint32_t fractions) { - return NtpTime(seconds, fractions).ToMs(); + // Converts between a relative timestamp returned by this clock, to NTP time. + virtual NtpTime ConvertTimestampToNtpTime(Timestamp timestamp) = 0; + int64_t ConvertTimestampToNtpTimeInMilliseconds(int64_t timestamp_ms) { + return ConvertTimestampToNtpTime(Timestamp::Millis(timestamp_ms)).ToMs(); } // Returns an instance of the real-time system clock implementation. @@ -56,20 +58,15 @@ class RTC_EXPORT Clock { class SimulatedClock : public Clock { public: + // The constructors assume an epoch of Jan 1, 1970. explicit SimulatedClock(int64_t initial_time_us); explicit SimulatedClock(Timestamp initial_time); - ~SimulatedClock() override; - // Return a timestamp relative to some arbitrary source; the source is fixed - // for this clock. + // Return a timestamp with an epoch of Jan 1, 1970. Timestamp CurrentTime() override; - // Retrieve an NTP absolute timestamp. - NtpTime CurrentNtpTime() override; - - // Converts an NTP timestamp to a millisecond timestamp. - int64_t CurrentNtpInMilliseconds() override; + NtpTime ConvertTimestampToNtpTime(Timestamp timestamp) override; // Advance the simulated clock with a given number of milliseconds or // microseconds. 
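To make the reworked `Clock` interface above concrete: `CurrentTime()` and `ConvertTimestampToNtpTime()` are now the only pure virtual methods, and the millisecond/NTP helpers are plain wrappers around them. A minimal sketch of a custom clock against the new contract follows; `OffsetClock` is a made-up example, not part of this change:

```
#include "api/units/time_delta.h"
#include "system_wrappers/include/clock.h"

namespace webrtc {

// Hypothetical clock that runs a fixed offset ahead of a wrapped clock.
// Only the two pure virtual methods need overriding; TimeInMilliseconds(),
// CurrentNtpTime() and CurrentNtpInMilliseconds() come from the base class.
class OffsetClock : public Clock {
 public:
  OffsetClock(Clock* base, TimeDelta offset) : base_(base), offset_(offset) {}

  // Relative timestamp, shifted by the configured offset.
  Timestamp CurrentTime() override { return base_->CurrentTime() + offset_; }

  // The mapping from relative time to NTP time is delegated to the wrapped
  // clock, mirroring what DriftingClock does elsewhere in this change.
  NtpTime ConvertTimestampToNtpTime(Timestamp timestamp) override {
    return base_->ConvertTimestampToNtpTime(timestamp);
  }

 private:
  Clock* const base_;
  const TimeDelta offset_;
};

}  // namespace webrtc
```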
diff --git a/system_wrappers/source/clock.cc b/system_wrappers/source/clock.cc index 0ae624d849..77c1d36327 100644 --- a/system_wrappers/source/clock.cc +++ b/system_wrappers/source/clock.cc @@ -10,6 +10,8 @@ #include "system_wrappers/include/clock.h" +#include "system_wrappers/include/field_trial.h" + #if defined(WEBRTC_WIN) // Windows needs to be included before mmsystem.h @@ -29,57 +31,88 @@ #include "rtc_base/time_utils.h" namespace webrtc { +namespace { + +int64_t NtpOffsetUsCalledOnce() { + constexpr int64_t kNtpJan1970Sec = 2208988800; + int64_t clock_time = rtc::TimeMicros(); + int64_t utc_time = rtc::TimeUTCMicros(); + return utc_time - clock_time + kNtpJan1970Sec * rtc::kNumMicrosecsPerSec; +} + +NtpTime TimeMicrosToNtp(int64_t time_us) { + static int64_t ntp_offset_us = NtpOffsetUsCalledOnce(); + + int64_t time_ntp_us = time_us + ntp_offset_us; + RTC_DCHECK_GE(time_ntp_us, 0); // Time before year 1900 is unsupported. + + // Convert seconds to uint32 through uint64 for a well-defined cast. + // A wrap around, which will happen in 2036, is expected for NTP time. + uint32_t ntp_seconds = + static_cast(time_ntp_us / rtc::kNumMicrosecsPerSec); + + // Scale fractions of the second to NTP resolution. + constexpr int64_t kNtpFractionsInSecond = 1LL << 32; + int64_t us_fractions = time_ntp_us % rtc::kNumMicrosecsPerSec; + uint32_t ntp_fractions = + us_fractions * kNtpFractionsInSecond / rtc::kNumMicrosecsPerSec; + + return NtpTime(ntp_seconds, ntp_fractions); +} + +void GetSecondsAndFraction(const timeval& time, + uint32_t* seconds, + double* fraction) { + *seconds = time.tv_sec + kNtpJan1970; + *fraction = time.tv_usec / 1e6; + + while (*fraction >= 1) { + --*fraction; + ++*seconds; + } + while (*fraction < 0) { + ++*fraction; + --*seconds; + } +} + +} // namespace class RealTimeClock : public Clock { + public: + RealTimeClock() + : use_system_independent_ntp_time_(!field_trial::IsEnabled( + "WebRTC-SystemIndependentNtpTimeKillSwitch")) {} + Timestamp CurrentTime() override { return Timestamp::Micros(rtc::TimeMicros()); } - // Return a timestamp in milliseconds relative to some arbitrary source; the - // source is fixed for this clock. - int64_t TimeInMilliseconds() override { return rtc::TimeMillis(); } - // Return a timestamp in microseconds relative to some arbitrary source; the - // source is fixed for this clock. - int64_t TimeInMicroseconds() override { return rtc::TimeMicros(); } - - // Retrieve an NTP absolute timestamp. NtpTime CurrentNtpTime() override { - timeval tv = CurrentTimeVal(); - double microseconds_in_seconds; - uint32_t seconds; - Adjust(tv, &seconds, &microseconds_in_seconds); - uint32_t fractions = static_cast( - microseconds_in_seconds * kMagicNtpFractionalUnit + 0.5); - return NtpTime(seconds, fractions); + return use_system_independent_ntp_time_ ? TimeMicrosToNtp(rtc::TimeMicros()) : SystemDependentNtpTime(); } - // Retrieve an NTP absolute timestamp in milliseconds. - int64_t CurrentNtpInMilliseconds() override { - timeval tv = CurrentTimeVal(); - uint32_t seconds; - double microseconds_in_seconds; - Adjust(tv, &seconds, &microseconds_in_seconds); - return 1000 * static_cast(seconds) + - static_cast(1000.0 * microseconds_in_seconds + 0.5); + NtpTime ConvertTimestampToNtpTime(Timestamp timestamp) override { + // This method does not check |use_system_independent_ntp_time_| because + // all callers never used the old behavior of |CurrentNtpTime|.
+ return TimeMicrosToNtp(timestamp.us()); } protected: virtual timeval CurrentTimeVal() = 0; - static void Adjust(const timeval& tv, - uint32_t* adjusted_s, - double* adjusted_us_in_s) { - *adjusted_s = tv.tv_sec + kNtpJan1970; - *adjusted_us_in_s = tv.tv_usec / 1e6; - - if (*adjusted_us_in_s >= 1) { - *adjusted_us_in_s -= 1; - ++*adjusted_s; - } else if (*adjusted_us_in_s < -1) { - *adjusted_us_in_s += 1; - --*adjusted_s; - } + private: + NtpTime SystemDependentNtpTime() { + uint32_t seconds; + double fraction; + GetSecondsAndFraction(CurrentTimeVal(), &seconds, &fraction); + + return NtpTime(seconds, static_cast( + fraction * kMagicNtpFractionalUnit + 0.5)); } + + bool use_system_independent_ntp_time_; }; #if defined(WINUWP) @@ -90,10 +123,10 @@ class WinUwpRealTimeClock final : public RealTimeClock { protected: timeval CurrentTimeVal() override { - // The rtc::SystemTimeNanos() method is already time offset from a base - // epoch value and might as be synchronized against an NTP time server as - // an added bonus. - auto nanos = rtc::SystemTimeNanos(); + // The rtc::WinUwpSystemTimeNanos() method is already time offset from a + // base epoch value and might as be synchronized against an NTP time server + // as an added bonus. + auto nanos = rtc::WinUwpSystemTimeNanos(); struct timeval tv; @@ -249,18 +282,14 @@ Timestamp SimulatedClock::CurrentTime() { return Timestamp::Micros(time_us_.load(std::memory_order_relaxed)); } -NtpTime SimulatedClock::CurrentNtpTime() { - int64_t now_ms = TimeInMilliseconds(); - uint32_t seconds = (now_ms / 1000) + kNtpJan1970; - uint32_t fractions = - static_cast((now_ms % 1000) * kMagicNtpFractionalUnit / 1000); +NtpTime SimulatedClock::ConvertTimestampToNtpTime(Timestamp timestamp) { + int64_t now_us = timestamp.us(); + uint32_t seconds = (now_us / 1'000'000) + kNtpJan1970; + uint32_t fractions = static_cast( + (now_us % 1'000'000) * kMagicNtpFractionalUnit / 1'000'000); return NtpTime(seconds, fractions); } -int64_t SimulatedClock::CurrentNtpInMilliseconds() { - return TimeInMilliseconds() + 1000 * static_cast(kNtpJan1970); -} - void SimulatedClock::AdvanceTimeMilliseconds(int64_t milliseconds) { AdvanceTime(TimeDelta::Millis(milliseconds)); } diff --git a/system_wrappers/source/field_trial.cc b/system_wrappers/source/field_trial.cc index f1dccc987b..d10b5cff3f 100644 --- a/system_wrappers/source/field_trial.cc +++ b/system_wrappers/source/field_trial.cc @@ -85,7 +85,7 @@ void InsertOrReplaceFieldTrialStringsInMap( (*fieldtrial_map)[tokens[idx]] = tokens[idx + 1]; } } else { - RTC_DCHECK(false) << "Invalid field trials string:" << trials_string; + RTC_NOTREACHED() << "Invalid field trials string:" << trials_string; } } diff --git a/system_wrappers/source/ntp_time_unittest.cc b/system_wrappers/source/ntp_time_unittest.cc index cdaca67fbe..0705531e37 100644 --- a/system_wrappers/source/ntp_time_unittest.cc +++ b/system_wrappers/source/ntp_time_unittest.cc @@ -56,7 +56,6 @@ TEST(NtpTimeTest, ToMsMeansToNtpMilliseconds) { SimulatedClock clock(0x123456789abc); NtpTime ntp = clock.CurrentNtpTime(); - EXPECT_EQ(ntp.ToMs(), Clock::NtpToMs(ntp.seconds(), ntp.fractions())); EXPECT_EQ(ntp.ToMs(), clock.CurrentNtpInMilliseconds()); } diff --git a/test/BUILD.gn b/test/BUILD.gn index 0e1209fd20..82d0b9ea28 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -7,29 +7,32 @@ # be found in the AUTHORS file in the root of the source tree. 
import("//build/config/ui.gni") +import("//third_party/google_benchmark/buildconfig.gni") import("../webrtc.gni") if (is_android) { import("//build/config/android/rules.gni") } -group("test") { - testonly = true - - deps = [ - ":copy_to_file_audio_capturer", - ":rtp_test_utils", - ":test_common", - ":test_renderer", - ":test_support", - ":video_test_common", - ] +if (!build_with_chromium) { + group("test") { + testonly = true - if (rtc_include_tests) { - deps += [ - ":test_main", - ":test_support_unittests", - "pc/e2e", + deps = [ + ":copy_to_file_audio_capturer", + ":rtp_test_utils", + ":test_common", + ":test_renderer", + ":test_support", + ":video_test_common", ] + + if (rtc_include_tests) { + deps += [ + ":test_main", + ":test_support_unittests", + "pc/e2e", + ] + } } } @@ -49,10 +52,10 @@ rtc_library("frame_generator_impl") { ":frame_utils", "../api:frame_generator_api", "../api:scoped_refptr", + "../api:sequence_checker", "../api/video:encoded_image", "../api/video:video_frame", "../api/video:video_frame_i010", - "../api/video:video_frame_nv12", "../api/video:video_rtp_headers", "../api/video_codecs:video_codecs_api", "../common_video", @@ -68,7 +71,6 @@ rtc_library("frame_generator_impl") { "../rtc_base:rtc_base_approved", "../rtc_base:rtc_event", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:file_wrapper", "../system_wrappers", ] @@ -98,6 +100,8 @@ rtc_library("video_test_common") { "frame_forwarder.h", "frame_generator_capturer.cc", "frame_generator_capturer.h", + "mappable_native_buffer.cc", + "mappable_native_buffer.h", "test_video_capturer.cc", "test_video_capturer.h", "video_codec_settings.h", @@ -106,6 +110,7 @@ rtc_library("video_test_common") { deps = [ ":fileutils", ":frame_utils", + "../api:array_view", "../api:create_frame_generator", "../api:frame_generator_api", "../api:scoped_refptr", @@ -127,7 +132,10 @@ rtc_library("video_test_common") { "../rtc_base/task_utils:repeating_task", "../system_wrappers", ] - absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + ] } if (!build_with_chromium) { @@ -145,6 +153,7 @@ if (!build_with_chromium) { "../api:scoped_refptr", "../modules/video_capture:video_capture_module", "../rtc_base", + "../rtc_base:threading", "../sdk:base_objc", "../sdk:native_api", "../sdk:native_video", @@ -203,6 +212,7 @@ rtc_library("rtp_test_utils") { "../rtc_base/synchronization:mutex", "../rtc_base/system:arch", ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] } rtc_library("field_trial") { @@ -248,10 +258,14 @@ rtc_library("perf_test") { "../rtc_base:criticalsection", "../rtc_base:logging", "../rtc_base:rtc_numerics", + "../rtc_base:stringutils", "../rtc_base/synchronization:mutex", "../test:fileutils", ] - absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] if (rtc_enable_protobuf) { sources += [ "testsupport/perf_test_histogram_writer.cc" ] deps += [ @@ -355,6 +369,7 @@ rtc_library("video_test_support") { ":test_support", ":video_test_common", "../api:scoped_refptr", + "../api:sequence_checker", "../api/video:encoded_image", "../api/video:video_frame", "../api/video_codecs:video_codecs_api", @@ -370,7 +385,6 @@ rtc_library("video_test_support") { "../rtc_base:logging", "../rtc_base:rtc_base_approved", "../rtc_base:rtc_event", - 
"../rtc_base/synchronization:sequence_checker", "../rtc_base/system:file_wrapper", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -387,7 +401,16 @@ rtc_library("video_test_support") { } } -if (rtc_include_tests) { +if (rtc_include_tests && enable_google_benchmarks) { + rtc_library("benchmark_main") { + testonly = true + sources = [ "benchmark_main.cc" ] + + deps = [ "//third_party/google_benchmark" ] + } +} + +if (rtc_include_tests && !build_with_chromium) { rtc_library("resources_dir_flag") { testonly = true visibility = [ "*" ] @@ -415,6 +438,7 @@ if (rtc_include_tests) { "../rtc_base:checks", "../rtc_base:logging", "../rtc_base:rtc_base_approved", + "../rtc_base:threading", "../system_wrappers:field_trial", "../system_wrappers:metrics", ] @@ -443,13 +467,6 @@ if (rtc_include_tests) { ] } - rtc_library("benchmark_main") { - testonly = true - sources = [ "benchmark_main.cc" ] - - deps = [ "//third_party/google_benchmark" ] - } - rtc_library("test_support_test_artifacts") { testonly = true sources = [ @@ -535,6 +552,8 @@ if (rtc_include_tests) { "scenario:scenario_unittests", "time_controller:time_controller", "time_controller:time_controller_unittests", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/flags:flag", "//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/types:optional", @@ -573,7 +592,7 @@ if (rtc_include_tests) { deps += [ ":test_support_unittests_bundle_data" ] } - if (!is_android && !build_with_chromium) { + if (!is_android) { # This is needed in order to avoid: # undefined symbol: webrtc::videocapturemodule::VideoCaptureImpl::Create deps += [ "../modules/video_capture:video_capture_internal_impl" ] @@ -717,17 +736,17 @@ rtc_library("direct_transport") { "direct_transport.h", ] deps = [ - ":rtp_test_utils", + "../api:sequence_checker", "../api:simulated_network_api", "../api:transport_api", "../api/task_queue", "../api/units:time_delta", "../call:call_interfaces", "../call:simulated_packet_receiver", + "../modules/rtp_rtcp:rtp_rtcp_format", "../rtc_base:macromagic", "../rtc_base:timeutils", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/task_utils:repeating_task", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] @@ -753,6 +772,7 @@ rtc_library("fake_video_codecs") { deps = [ "../api:fec_controller_api", "../api:scoped_refptr", + "../api:sequence_checker", "../api/task_queue", "../api/video:encoded_image", "../api/video:video_bitrate_allocation", @@ -769,7 +789,6 @@ rtc_library("fake_video_codecs") { "../rtc_base:rtc_task_queue", "../rtc_base:timeutils", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../system_wrappers", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] @@ -829,9 +848,9 @@ rtc_library("test_common") { ":fake_video_codecs", ":fileutils", ":mock_transport", - ":rtp_test_utils", ":test_support", ":video_test_common", + "../api:array_view", "../api:create_frame_generator", "../api:frame_generator_api", "../api:rtp_headers", @@ -865,6 +884,7 @@ rtc_library("test_common") { "../rtc_base:rtc_base", "../rtc_base:rtc_event", "../rtc_base:task_queue_for_test", + "../rtc_base:threading", "../rtc_base/task_utils:to_queued_task", "../system_wrappers", "../system_wrappers:field_trial", diff --git a/test/android/AndroidManifest.xml b/test/android/AndroidManifest.xml index ee2fec8716..ad3f434b4f 100644 --- a/test/android/AndroidManifest.xml +++ b/test/android/AndroidManifest.xml @@ -39,7 +39,7 @@ be found in the 
AUTHORS file in the root of the source tree. - diff --git a/test/call_test.cc b/test/call_test.cc index dd7c576ef9..11230dae2f 100644 --- a/test/call_test.cc +++ b/test/call_test.cc @@ -409,7 +409,7 @@ void CallTest::CreateMatchingAudioAndFecConfigs( if (num_flexfec_streams_ == 1) { CreateMatchingFecConfig(rtcp_send_transport, *GetVideoSendConfig()); for (const RtpExtension& extension : GetVideoSendConfig()->rtp.extensions) - GetFlexFecConfig()->rtp_header_extensions.push_back(extension); + GetFlexFecConfig()->rtp.extensions.push_back(extension); } } @@ -444,11 +444,13 @@ void CallTest::CreateMatchingFecConfig( const VideoSendStream::Config& send_config) { FlexfecReceiveStream::Config config(transport); config.payload_type = send_config.rtp.flexfec.payload_type; - config.remote_ssrc = send_config.rtp.flexfec.ssrc; + config.rtp.remote_ssrc = send_config.rtp.flexfec.ssrc; config.protected_media_ssrcs = send_config.rtp.flexfec.protected_media_ssrcs; - config.local_ssrc = kReceiverLocalVideoSsrc; - if (!video_receive_configs_.empty()) + config.rtp.local_ssrc = kReceiverLocalVideoSsrc; + if (!video_receive_configs_.empty()) { video_receive_configs_[0].rtp.protected_by_flexfec = true; + video_receive_configs_[0].rtp.packet_sink_ = this; + } flexfec_receive_configs_.push_back(config); } @@ -510,8 +512,6 @@ void CallTest::CreateVideoStreams() { video_receive_streams_.push_back(receiver_call_->CreateVideoReceiveStream( video_receive_configs_[i].Copy())); } - - AssociateFlexfecStreamsWithVideoStreams(); } void CallTest::CreateVideoSendStreams() { @@ -572,8 +572,6 @@ void CallTest::CreateFlexfecStreams() { receiver_call_->CreateFlexfecReceiveStream( flexfec_receive_configs_[i])); } - - AssociateFlexfecStreamsWithVideoStreams(); } void CallTest::ConnectVideoSourcesToStreams() { @@ -582,23 +580,6 @@ void CallTest::ConnectVideoSourcesToStreams() { degradation_preference_); } -void CallTest::AssociateFlexfecStreamsWithVideoStreams() { - // All FlexFEC streams protect all of the video streams. - for (FlexfecReceiveStream* flexfec_recv_stream : flexfec_receive_streams_) { - for (VideoReceiveStream* video_recv_stream : video_receive_streams_) { - video_recv_stream->AddSecondarySink(flexfec_recv_stream); - } - } -} - -void CallTest::DissociateFlexfecStreamsFromVideoStreams() { - for (FlexfecReceiveStream* flexfec_recv_stream : flexfec_receive_streams_) { - for (VideoReceiveStream* video_recv_stream : video_receive_streams_) { - video_recv_stream->RemoveSecondarySink(flexfec_recv_stream); - } - } -} - void CallTest::Start() { StartVideoStreams(); if (audio_send_stream_) { @@ -632,8 +613,6 @@ void CallTest::StopVideoStreams() { } void CallTest::DestroyStreams() { - DissociateFlexfecStreamsFromVideoStreams(); - if (audio_send_stream_) sender_call_->DestroyAudioSendStream(audio_send_stream_); audio_send_stream_ = nullptr; @@ -691,6 +670,12 @@ FlexfecReceiveStream::Config* CallTest::GetFlexFecConfig() { return &flexfec_receive_configs_[0]; } +void CallTest::OnRtpPacket(const RtpPacketReceived& packet) { + // All FlexFEC streams protect all of the video streams. 
+ for (FlexfecReceiveStream* flexfec_recv_stream : flexfec_receive_streams_) + flexfec_recv_stream->OnRtpPacket(packet); +} + absl::optional CallTest::GetRtpExtensionByUri( const std::string& uri) const { for (const auto& extension : rtp_extensions_) { diff --git a/test/call_test.h b/test/call_test.h index 4b26097b6c..adb21dd7f0 100644 --- a/test/call_test.h +++ b/test/call_test.h @@ -38,7 +38,7 @@ namespace test { class BaseTest; -class CallTest : public ::testing::Test { +class CallTest : public ::testing::Test, public RtpPacketSinkInterface { public: CallTest(); virtual ~CallTest(); @@ -156,9 +156,6 @@ class CallTest : public ::testing::Test { void ConnectVideoSourcesToStreams(); - void AssociateFlexfecStreamsWithVideoStreams(); - void DissociateFlexfecStreamsFromVideoStreams(); - void Start(); void StartVideoStreams(); void Stop(); @@ -177,6 +174,9 @@ class CallTest : public ::testing::Test { FlexfecReceiveStream::Config* GetFlexFecConfig(); TaskQueueBase* task_queue() { return task_queue_.get(); } + // RtpPacketSinkInterface implementation. + void OnRtpPacket(const RtpPacketReceived& packet) override; + test::RunLoop loop_; Clock* const clock_; diff --git a/test/direct_transport.cc b/test/direct_transport.cc index 9c7a8f88d0..7e9c5aefeb 100644 --- a/test/direct_transport.cc +++ b/test/direct_transport.cc @@ -14,9 +14,9 @@ #include "api/units/time_delta.h" #include "call/call.h" #include "call/fake_network_pipe.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "rtc_base/task_utils/repeating_task.h" #include "rtc_base/time_utils.h" -#include "test/rtp_header_parser.h" namespace webrtc { namespace test { @@ -26,7 +26,7 @@ Demuxer::Demuxer(const std::map& payload_type_map) MediaType Demuxer::GetMediaType(const uint8_t* packet_data, const size_t packet_length) const { - if (!RtpHeaderParser::IsRtcp(packet_data, packet_length)) { + if (IsRtpPacket(rtc::MakeArrayView(packet_data, packet_length))) { RTC_CHECK_GE(packet_length, 2); const uint8_t payload_type = packet_data[1] & 0x7f; std::map::const_iterator it = diff --git a/test/direct_transport.h b/test/direct_transport.h index 2fc3b7f76b..34b68555d5 100644 --- a/test/direct_transport.h +++ b/test/direct_transport.h @@ -13,12 +13,12 @@ #include #include "api/call/transport.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "api/test/simulated_network.h" #include "call/call.h" #include "call/simulated_packet_receiver.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_utils/repeating_task.h" #include "rtc_base/thread_annotations.h" diff --git a/test/direct_transport_unittest.cc b/test/direct_transport_unittest.cc index 66ab5bcac1..ab00971089 100644 --- a/test/direct_transport_unittest.cc +++ b/test/direct_transport_unittest.cc @@ -18,12 +18,13 @@ namespace test { TEST(DemuxerTest, Demuxing) { constexpr uint8_t kVideoPayloadType = 100; constexpr uint8_t kAudioPayloadType = 101; - constexpr size_t kPacketSize = 10; + constexpr size_t kPacketSize = 12; Demuxer demuxer({{kVideoPayloadType, MediaType::VIDEO}, {kAudioPayloadType, MediaType::AUDIO}}); uint8_t data[kPacketSize]; memset(data, 0, kPacketSize); + data[0] = 0x80; data[1] = kVideoPayloadType; EXPECT_EQ(demuxer.GetMediaType(data, kPacketSize), MediaType::VIDEO); data[1] = kAudioPayloadType; diff --git a/test/drifting_clock.cc b/test/drifting_clock.cc index 1a5154557e..47c8e56916 100644 --- a/test/drifting_clock.cc +++ b/test/drifting_clock.cc @@ -28,22 +28,18 @@ 
TimeDelta DriftingClock::Drift() const { return (now - start_time_) * drift_; } -Timestamp DriftingClock::CurrentTime() { - return clock_->CurrentTime() + Drift() / 1000.; +Timestamp DriftingClock::Drift(Timestamp timestamp) const { + return timestamp + Drift() / 1000.; } -NtpTime DriftingClock::CurrentNtpTime() { +NtpTime DriftingClock::Drift(NtpTime ntp_time) const { // NTP precision is 1/2^32 seconds, i.e. 2^32 ntp fractions = 1 second. const double kNtpFracPerMicroSecond = 4294.967296; // = 2^32 / 10^6 - NtpTime ntp = clock_->CurrentNtpTime(); - uint64_t total_fractions = static_cast(ntp); + uint64_t total_fractions = static_cast(ntp_time); total_fractions += Drift().us() * kNtpFracPerMicroSecond; return NtpTime(total_fractions); } -int64_t DriftingClock::CurrentNtpInMilliseconds() { - return clock_->CurrentNtpInMilliseconds() + Drift().ms(); -} } // namespace test } // namespace webrtc diff --git a/test/drifting_clock.h b/test/drifting_clock.h index 2539b61786..3471c008a1 100644 --- a/test/drifting_clock.h +++ b/test/drifting_clock.h @@ -30,12 +30,16 @@ class DriftingClock : public Clock { return 1.0f - percent / 100.0f; } - Timestamp CurrentTime() override; - NtpTime CurrentNtpTime() override; - int64_t CurrentNtpInMilliseconds() override; + Timestamp CurrentTime() override { return Drift(clock_->CurrentTime()); } + NtpTime CurrentNtpTime() override { return Drift(clock_->CurrentNtpTime()); } + NtpTime ConvertTimestampToNtpTime(Timestamp timestamp) override { + return Drift(clock_->ConvertTimestampToNtpTime(timestamp)); + } private: TimeDelta Drift() const; + Timestamp Drift(Timestamp timestamp) const; + NtpTime Drift(NtpTime ntp_time) const; Clock* const clock_; const float drift_; diff --git a/test/encoder_settings.cc b/test/encoder_settings.cc index f90931a83c..c8251883fd 100644 --- a/test/encoder_settings.cc +++ b/test/encoder_settings.cc @@ -68,10 +68,9 @@ std::vector CreateVideoStreams( : DefaultVideoStreamFactory::kMaxBitratePerStream[i]; max_bitrate_bps = std::min(bitrate_left_bps, max_bitrate_bps); - int target_bitrate_bps = - stream.target_bitrate_bps > 0 - ? stream.target_bitrate_bps - : DefaultVideoStreamFactory::kMaxBitratePerStream[i]; + int target_bitrate_bps = stream.target_bitrate_bps > 0 + ? 
stream.target_bitrate_bps + : max_bitrate_bps; target_bitrate_bps = std::min(max_bitrate_bps, target_bitrate_bps); if (stream.min_bitrate_bps > 0) { @@ -91,7 +90,8 @@ std::vector CreateVideoStreams( } stream_settings[i].target_bitrate_bps = target_bitrate_bps; stream_settings[i].max_bitrate_bps = max_bitrate_bps; - stream_settings[i].active = stream.active; + stream_settings[i].active = + encoder_config.number_of_streams == 1 || stream.active; bitrate_left_bps -= stream_settings[i].target_bitrate_bps; } @@ -120,7 +120,7 @@ void FillEncoderConfiguration(VideoCodecType codec_type, configuration->codec_type = codec_type; configuration->number_of_streams = num_streams; configuration->video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); configuration->max_bitrate_bps = 0; configuration->simulcast_layers = std::vector(num_streams); for (size_t i = 0; i < num_streams; ++i) { diff --git a/test/fake_decoder.cc b/test/fake_decoder.cc index e229bbb2a1..f164bfbe03 100644 --- a/test/fake_decoder.cc +++ b/test/fake_decoder.cc @@ -27,11 +27,6 @@ namespace webrtc { namespace test { -namespace { -const int kDefaultWidth = 320; -const int kDefaultHeight = 180; -} // namespace - FakeDecoder::FakeDecoder() : FakeDecoder(nullptr) {} FakeDecoder::FakeDecoder(TaskQueueFactory* task_queue_factory) diff --git a/test/fake_decoder.h b/test/fake_decoder.h index 2ac2045bc0..6a5d6cb419 100644 --- a/test/fake_decoder.h +++ b/test/fake_decoder.h @@ -25,6 +25,8 @@ namespace test { class FakeDecoder : public VideoDecoder { public: + enum { kDefaultWidth = 320, kDefaultHeight = 180 }; + FakeDecoder(); explicit FakeDecoder(TaskQueueFactory* task_queue_factory); virtual ~FakeDecoder() {} diff --git a/test/fake_encoder.h b/test/fake_encoder.h index abd3134154..9feed1455f 100644 --- a/test/fake_encoder.h +++ b/test/fake_encoder.h @@ -18,6 +18,7 @@ #include #include "api/fec_controller_override.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/video/encoded_image.h" #include "api/video/video_bitrate_allocation.h" @@ -26,7 +27,6 @@ #include "api/video_codecs/video_encoder.h" #include "modules/video_coding/include/video_codec_interface.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/thread_annotations.h" #include "system_wrappers/include/clock.h" diff --git a/test/fake_texture_frame.cc b/test/fake_texture_frame.cc index 4fa5e9d242..3f155184ab 100644 --- a/test/fake_texture_frame.cc +++ b/test/fake_texture_frame.cc @@ -23,7 +23,7 @@ VideoFrame FakeNativeBuffer::CreateFrame(int width, VideoRotation rotation) { return VideoFrame::Builder() .set_video_frame_buffer( - new rtc::RefCountedObject(width, height)) + rtc::make_ref_counted(width, height)) .set_timestamp_rtp(timestamp) .set_timestamp_ms(render_time_ms) .set_rotation(rotation) diff --git a/test/fake_vp8_encoder.h b/test/fake_vp8_encoder.h index 178a46070d..6aaf547379 100644 --- a/test/fake_vp8_encoder.h +++ b/test/fake_vp8_encoder.h @@ -17,13 +17,13 @@ #include #include "api/fec_controller_override.h" +#include "api/sequence_checker.h" #include "api/video/encoded_image.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_encoder.h" #include "api/video_codecs/vp8_frame_buffer_controller.h" #include "api/video_codecs/vp8_temporal_layers.h" #include "modules/video_coding/include/video_codec_interface.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/thread_annotations.h" #include 
"system_wrappers/include/clock.h" #include "test/fake_encoder.h" diff --git a/test/frame_generator.cc b/test/frame_generator.cc index 4594e1de20..913a4fb589 100644 --- a/test/frame_generator.cc +++ b/test/frame_generator.cc @@ -21,7 +21,6 @@ #include "common_video/include/video_frame_buffer.h" #include "common_video/libyuv/include/webrtc_libyuv.h" #include "rtc_base/checks.h" -#include "rtc_base/keep_ref_until_done.h" #include "test/frame_utils.h" namespace webrtc { @@ -368,7 +367,8 @@ void ScrollingImageFrameGenerator::CropSourceToScrolledImage( &i420_buffer->DataY()[offset_y], i420_buffer->StrideY(), &i420_buffer->DataU()[offset_u], i420_buffer->StrideU(), &i420_buffer->DataV()[offset_v], i420_buffer->StrideV(), - KeepRefUntilDone(i420_buffer)), + // To keep reference alive. + [i420_buffer] {}), update_rect); } diff --git a/test/frame_generator_unittest.cc b/test/frame_generator_unittest.cc index 12d5111bff..8e5cde8c5f 100644 --- a/test/frame_generator_unittest.cc +++ b/test/frame_generator_unittest.cc @@ -54,7 +54,7 @@ class FrameGeneratorTest : public ::testing::Test { protected: void WriteYuvFile(FILE* file, uint8_t y, uint8_t u, uint8_t v) { - assert(file); + RTC_DCHECK(file); std::unique_ptr plane_buffer(new uint8_t[y_size]); memset(plane_buffer.get(), y, y_size); fwrite(plane_buffer.get(), 1, y_size, file); diff --git a/test/fuzzers/BUILD.gn b/test/fuzzers/BUILD.gn index 4975f42a98..9824bebb5f 100644 --- a/test/fuzzers/BUILD.gn +++ b/test/fuzzers/BUILD.gn @@ -18,11 +18,12 @@ rtc_library("webrtc_fuzzer_main") { ] # When WebRTC fuzzer tests are built on Chromium bots they need to link - # with Chromium's implementation of metrics and field trial. + # with Chromium's implementation of metrics, field trial, and system time. if (build_with_chromium) { deps += [ "../../../webrtc_overrides:field_trial", "../../../webrtc_overrides:metrics", + "../../../webrtc_overrides:system_time", ] } } @@ -244,6 +245,7 @@ webrtc_fuzzer_test("congestion_controller_feedback_fuzzer") { "../../modules/remote_bitrate_estimator", "../../modules/rtp_rtcp:rtp_rtcp_format", ] + absl_deps = [ "//third_party/abseil-cpp/absl/functional:bind_front" ] } rtc_library("audio_decoder_fuzzer") { @@ -405,6 +407,23 @@ webrtc_fuzzer_test("sdp_parser_fuzzer") { seed_corpus = "corpora/sdp-corpus" } +if (!build_with_chromium) { + # This target depends on test infrastructure that can't be built + # with Chromium at the moment. + # TODO(bugs.chromium.org/12534): Make this fuzzer build in Chromium. 
+ + webrtc_fuzzer_test("sdp_integration_fuzzer") { + sources = [ "sdp_integration_fuzzer.cc" ] + deps = [ + "../../api:libjingle_peerconnection_api", + "../../pc:integration_test_helpers", + "../../pc:libjingle_peerconnection", + "../../test:test_support", + ] + seed_corpus = "corpora/sdp-corpus" + } +} + webrtc_fuzzer_test("stun_parser_fuzzer") { sources = [ "stun_parser_fuzzer.cc" ] deps = [ @@ -425,20 +444,12 @@ webrtc_fuzzer_test("stun_validator_fuzzer") { dict = "corpora/stun.tokens" } -webrtc_fuzzer_test("mdns_parser_fuzzer") { - sources = [ "mdns_parser_fuzzer.cc" ] - deps = [ - "../../p2p:rtc_p2p", - "../../rtc_base:rtc_base_approved", - ] - seed_corpus = "corpora/mdns-corpus" -} - webrtc_fuzzer_test("pseudotcp_parser_fuzzer") { sources = [ "pseudotcp_parser_fuzzer.cc" ] deps = [ "../../p2p:rtc_p2p", "../../rtc_base", + "../../rtc_base:threading", ] } @@ -594,14 +605,29 @@ webrtc_fuzzer_test("sctp_utils_fuzzer") { ] } +webrtc_fuzzer_test("dcsctp_socket_fuzzer") { + sources = [ "dcsctp_socket_fuzzer.cc" ] + deps = [ + "../../net/dcsctp/fuzzers:dcsctp_fuzzers", + "../../net/dcsctp/public:socket", + "../../net/dcsctp/public:types", + "../../net/dcsctp/socket:dcsctp_socket", + "../../rtc_base:rtc_base_approved", + ] +} + webrtc_fuzzer_test("rtp_header_parser_fuzzer") { sources = [ "rtp_header_parser_fuzzer.cc" ] deps = [ "../:rtp_test_utils" ] } webrtc_fuzzer_test("ssl_certificate_fuzzer") { - sources = [ "rtp_header_parser_fuzzer.cc" ] - deps = [ "../:rtp_test_utils" ] + sources = [ "ssl_certificate_fuzzer.cc" ] + deps = [ + "../:rtp_test_utils", + "../../rtc_base", + "../../rtc_base:stringutils", + ] } webrtc_fuzzer_test("vp8_replay_fuzzer") { @@ -613,6 +639,30 @@ webrtc_fuzzer_test("vp8_replay_fuzzer") { seed_corpus = "corpora/rtpdump-corpus/vp8" } +if (rtc_build_libvpx) { + webrtc_fuzzer_test("vp9_encoder_references_fuzzer") { + sources = [ "vp9_encoder_references_fuzzer.cc" ] + deps = [ + "..:test_support", + "../../api:array_view", + "../../api/transport:webrtc_key_value_config", + "../../api/video:video_frame", + "../../api/video_codecs:video_codecs_api", + "../../modules/video_coding:frame_dependencies_calculator", + "../../modules/video_coding:mock_libvpx_interface", + "../../modules/video_coding:webrtc_vp9", + "../../rtc_base:safe_compare", + rtc_libvpx_dir, + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/base:core_headers", + "//third_party/abseil-cpp/absl/container:inlined_vector", + ] + defines = [ "RTC_ENABLE_VP9" ] + } +} + webrtc_fuzzer_test("vp9_replay_fuzzer") { sources = [ "vp9_replay_fuzzer.cc" ] deps = [ diff --git a/test/fuzzers/DEPS b/test/fuzzers/DEPS index 82631c4a1b..50b1c8adce 100644 --- a/test/fuzzers/DEPS +++ b/test/fuzzers/DEPS @@ -1,4 +1,5 @@ include_rules = [ "+audio", "+pc", + "+net/dcsctp", ] diff --git a/test/fuzzers/congestion_controller_feedback_fuzzer.cc b/test/fuzzers/congestion_controller_feedback_fuzzer.cc index 084c8c300a..06a73b0434 100644 --- a/test/fuzzers/congestion_controller_feedback_fuzzer.cc +++ b/test/fuzzers/congestion_controller_feedback_fuzzer.cc @@ -8,6 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ +#include "absl/functional/bind_front.h" #include "modules/congestion_controller/include/receive_side_congestion_controller.h" #include "modules/pacing/packet_router.h" #include "modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h" @@ -21,7 +22,10 @@ void FuzzOneInput(const uint8_t* data, size_t size) { return; SimulatedClock clock(data[i++]); PacketRouter packet_router; - ReceiveSideCongestionController cc(&clock, &packet_router); + ReceiveSideCongestionController cc( + &clock, + absl::bind_front(&PacketRouter::SendCombinedRtcpPacket, &packet_router), + absl::bind_front(&PacketRouter::SendRemb, &packet_router), nullptr); RemoteBitrateEstimator* rbe = cc.GetRemoteBitrateEstimator(true); RTPHeader header; header.ssrc = ByteReader::ReadBigEndian(&data[i]); diff --git a/test/fuzzers/corpora/README b/test/fuzzers/corpora/README index d29e169417..cc87025ff6 100644 --- a/test/fuzzers/corpora/README +++ b/test/fuzzers/corpora/README @@ -31,4 +31,7 @@ which header extensions to enable, and the first byte of the fuzz data is used for this. ### PseudoTCP ### -Very small corpus minimised from the unit tests. \ No newline at end of file +Very small corpus minimised from the unit tests. + +### SCTP ### +This corpus was extracted from a few manually recorder wireshark dumps. diff --git a/test/fuzzers/corpora/sctp-packet-corpus/cookie-ack-sack.bin b/test/fuzzers/corpora/sctp-packet-corpus/cookie-ack-sack.bin new file mode 100644 index 0000000000..4374f5aad5 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/cookie-ack-sack.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data-data-data.bin b/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data-data-data.bin new file mode 100644 index 0000000000..1f1d0be301 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data-data-data.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data-data.bin b/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data-data.bin new file mode 100644 index 0000000000..21a0c22837 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data-data.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data.bin b/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data.bin new file mode 100644 index 0000000000..fc8600106e Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/cookie-echo-data.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/data-fragment1.bin b/test/fuzzers/corpora/sctp-packet-corpus/data-fragment1.bin new file mode 100644 index 0000000000..bec7b289e7 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/data-fragment1.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/forward-tsn.bin b/test/fuzzers/corpora/sctp-packet-corpus/forward-tsn.bin new file mode 100644 index 0000000000..ab98a0a4a7 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/forward-tsn.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/heartbeat-ack.bin b/test/fuzzers/corpora/sctp-packet-corpus/heartbeat-ack.bin new file mode 100644 index 0000000000..59200abe5e Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/heartbeat-ack.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/heartbeat.bin b/test/fuzzers/corpora/sctp-packet-corpus/heartbeat.bin new file mode 100644 index 0000000000..cef8cfe929 Binary files /dev/null and 
b/test/fuzzers/corpora/sctp-packet-corpus/heartbeat.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/init-ack.bin b/test/fuzzers/corpora/sctp-packet-corpus/init-ack.bin new file mode 100644 index 0000000000..80438434d0 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/init-ack.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/init.bin b/test/fuzzers/corpora/sctp-packet-corpus/init.bin new file mode 100644 index 0000000000..3fb4977d58 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/init.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/re-config.bin b/test/fuzzers/corpora/sctp-packet-corpus/re-config.bin new file mode 100644 index 0000000000..74c74f3377 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/re-config.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/sack-data.bin b/test/fuzzers/corpora/sctp-packet-corpus/sack-data.bin new file mode 100644 index 0000000000..fe4de63863 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/sack-data.bin differ diff --git a/test/fuzzers/corpora/sctp-packet-corpus/sack-gap-ack-1.bin b/test/fuzzers/corpora/sctp-packet-corpus/sack-gap-ack-1.bin new file mode 100644 index 0000000000..08494c1515 Binary files /dev/null and b/test/fuzzers/corpora/sctp-packet-corpus/sack-gap-ack-1.bin differ diff --git a/test/fuzzers/dcsctp_packet_fuzzer.cc b/test/fuzzers/dcsctp_packet_fuzzer.cc new file mode 100644 index 0000000000..2fc3fe10f1 --- /dev/null +++ b/test/fuzzers/dcsctp_packet_fuzzer.cc @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" + +namespace webrtc { +using dcsctp::SctpPacket; + +void FuzzOneInput(const uint8_t* data, size_t size) { + absl::optional c = + SctpPacket::Parse(rtc::ArrayView(data, size), + /*disable_checksum_verification=*/true); + + if (!c.has_value()) { + return; + } + + for (const SctpPacket::ChunkDescriptor& desc : c->descriptors()) { + dcsctp::DebugConvertChunkToString(desc.data); + } +} +} // namespace webrtc diff --git a/test/fuzzers/dcsctp_socket_fuzzer.cc b/test/fuzzers/dcsctp_socket_fuzzer.cc new file mode 100644 index 0000000000..390cbb7f6c --- /dev/null +++ b/test/fuzzers/dcsctp_socket_fuzzer.cc @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ +#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "rtc_base/logging.h" + +namespace webrtc { + +void FuzzOneInput(const uint8_t* data, size_t size) { + dcsctp::dcsctp_fuzzers::FuzzerCallbacks cb; + dcsctp::DcSctpOptions options; + options.disable_checksum_verification = true; + dcsctp::DcSctpSocket socket("A", cb, nullptr, options); + + dcsctp::dcsctp_fuzzers::FuzzSocket(socket, cb, + rtc::ArrayView(data, size)); +} +} // namespace webrtc diff --git a/test/fuzzers/frame_buffer2_fuzzer.cc b/test/fuzzers/frame_buffer2_fuzzer.cc index 7ec7da5eca..0572675f71 100644 --- a/test/fuzzers/frame_buffer2_fuzzer.cc +++ b/test/fuzzers/frame_buffer2_fuzzer.cc @@ -49,7 +49,7 @@ struct DataReader { size_t offset_ = 0; }; -class FuzzyFrameObject : public video_coding::EncodedFrame { +class FuzzyFrameObject : public EncodedFrame { public: FuzzyFrameObject() {} ~FuzzyFrameObject() {} @@ -77,11 +77,11 @@ void FuzzOneInput(const uint8_t* data, size_t size) { while (reader.MoreToRead()) { if (reader.GetNum() % 2) { std::unique_ptr frame(new FuzzyFrameObject()); - frame->id.picture_id = reader.GetNum(); - frame->id.spatial_layer = reader.GetNum() % 5; + frame->SetId(reader.GetNum()); + frame->SetSpatialIndex(reader.GetNum() % 5); frame->SetTimestamp(reader.GetNum()); - frame->num_references = reader.GetNum() % - video_coding::EncodedFrame::kMaxFrameReferences; + frame->num_references = + reader.GetNum() % EncodedFrame::kMaxFrameReferences; for (size_t r = 0; r < frame->num_references; ++r) frame->references[r] = reader.GetNum(); @@ -98,7 +98,7 @@ void FuzzOneInput(const uint8_t* data, size_t size) { frame_buffer.NextFrame( max_wait_time_ms, keyframe_required, &task_queue, [&next_frame_task_running]( - std::unique_ptr frame, + std::unique_ptr frame, video_coding::FrameBuffer::ReturnReason res) { next_frame_task_running = false; }); diff --git a/test/fuzzers/mdns_parser_fuzzer.cc b/test/fuzzers/mdns_parser_fuzzer.cc deleted file mode 100644 index 451742327f..0000000000 --- a/test/fuzzers/mdns_parser_fuzzer.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright 2018 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include -#include - -#include - -#include "p2p/base/mdns_message.h" -#include "rtc_base/message_buffer_reader.h" - -namespace webrtc { - -void FuzzOneInput(const uint8_t* data, size_t size) { - MessageBufferReader buf(reinterpret_cast(data), size); - auto mdns_msg = std::make_unique(); - mdns_msg->Read(&buf); -} - -} // namespace webrtc diff --git a/test/fuzzers/packet_buffer_fuzzer.cc b/test/fuzzers/packet_buffer_fuzzer.cc index 30f452c9b7..ea9d4896f1 100644 --- a/test/fuzzers/packet_buffer_fuzzer.cc +++ b/test/fuzzers/packet_buffer_fuzzer.cc @@ -13,7 +13,6 @@ #include "modules/video_coding/frame_object.h" #include "modules/video_coding/packet_buffer.h" -#include "system_wrappers/include/clock.h" #include "test/fuzzers/fuzz_data_helper.h" namespace webrtc { @@ -24,8 +23,7 @@ void FuzzOneInput(const uint8_t* data, size_t size) { if (size > 200000) { return; } - SimulatedClock clock(0); - video_coding::PacketBuffer packet_buffer(&clock, 8, 1024); + video_coding::PacketBuffer packet_buffer(8, 1024); test::FuzzDataHelper helper(rtc::ArrayView(data, size)); while (helper.BytesLeft()) { @@ -35,7 +33,6 @@ void FuzzOneInput(const uint8_t* data, size_t size) { helper.CopyTo(&packet->payload_type); helper.CopyTo(&packet->seq_num); helper.CopyTo(&packet->timestamp); - helper.CopyTo(&packet->ntp_time_ms); helper.CopyTo(&packet->times_nacked); // Fuzz non-POD member of the packet. diff --git a/test/fuzzers/rtp_frame_reference_finder_fuzzer.cc b/test/fuzzers/rtp_frame_reference_finder_fuzzer.cc index 8b19a088de..fdb4aa5f3c 100644 --- a/test/fuzzers/rtp_frame_reference_finder_fuzzer.cc +++ b/test/fuzzers/rtp_frame_reference_finder_fuzzer.cc @@ -12,9 +12,7 @@ #include "api/rtp_packet_infos.h" #include "modules/video_coding/frame_object.h" -#include "modules/video_coding/packet_buffer.h" #include "modules/video_coding/rtp_frame_reference_finder.h" -#include "system_wrappers/include/clock.h" namespace webrtc { @@ -58,11 +56,6 @@ class DataReader { size_t offset_ = 0; }; -class NullCallback : public video_coding::OnCompleteFrameCallback { - void OnCompleteFrame( - std::unique_ptr frame) override {} -}; - absl::optional GenerateGenericFrameDependencies(DataReader* reader) { absl::optional result; @@ -92,8 +85,7 @@ GenerateGenericFrameDependencies(DataReader* reader) { void FuzzOneInput(const uint8_t* data, size_t size) { DataReader reader(data, size); - NullCallback cb; - video_coding::RtpFrameReferenceFinder reference_finder(&cb); + RtpFrameReferenceFinder reference_finder; auto codec = static_cast(reader.GetNum() % 5); @@ -135,7 +127,7 @@ void FuzzOneInput(const uint8_t* data, size_t size) { video_header.generic = GenerateGenericFrameDependencies(&reader); // clang-format off - auto frame = std::make_unique( + auto frame = std::make_unique( first_seq_num, last_seq_num, marker_bit, diff --git a/test/fuzzers/rtp_header_parser_fuzzer.cc b/test/fuzzers/rtp_header_parser_fuzzer.cc index d6af5ca3ce..435c64bbb4 100644 --- a/test/fuzzers/rtp_header_parser_fuzzer.cc +++ b/test/fuzzers/rtp_header_parser_fuzzer.cc @@ -20,29 +20,7 @@ namespace webrtc { void FuzzOneInput(const uint8_t* data, size_t size) { - RtpHeaderParser::IsRtcp(data, size); RtpHeaderParser::GetSsrc(data, size); - RTPHeader rtp_header; - - std::unique_ptr rtp_header_parser( - RtpHeaderParser::CreateForTest()); - - rtp_header_parser->Parse(data, size, &rtp_header); - for (int i = 1; i < kRtpExtensionNumberOfExtensions; ++i) { - if (size > 0 && i >= data[size - 1]) { - RTPExtensionType add_extension = static_cast(i); - 
rtp_header_parser->RegisterRtpHeaderExtension(add_extension, i); - } - } - rtp_header_parser->Parse(data, size, &rtp_header); - - for (int i = 1; i < kRtpExtensionNumberOfExtensions; ++i) { - if (size > 1 && i >= data[size - 2]) { - RTPExtensionType remove_extension = static_cast(i); - rtp_header_parser->DeregisterRtpHeaderExtension(remove_extension); - } - } - rtp_header_parser->Parse(data, size, &rtp_header); } } // namespace webrtc diff --git a/test/fuzzers/rtp_packet_fuzzer.cc b/test/fuzzers/rtp_packet_fuzzer.cc index 3f03114a33..3f2fc5e668 100644 --- a/test/fuzzers/rtp_packet_fuzzer.cc +++ b/test/fuzzers/rtp_packet_fuzzer.cc @@ -9,6 +9,7 @@ */ #include +#include #include "absl/types/optional.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" @@ -76,6 +77,11 @@ void FuzzOneInput(const uint8_t* data, size_t size) { uint8_t audio_level; packet.GetExtension(&voice_activity, &audio_level); break; + case kRtpExtensionCsrcAudioLevel: { + std::vector audio_levels; + packet.GetExtension(&audio_levels); + break; + } case kRtpExtensionAbsoluteSendTime: uint32_t sendtime; packet.GetExtension(&sendtime); @@ -109,10 +115,11 @@ void FuzzOneInput(const uint8_t* data, size_t size) { VideoContentType content_type; packet.GetExtension(&content_type); break; - case kRtpExtensionVideoTiming: + case kRtpExtensionVideoTiming: { VideoSendTiming timing; packet.GetExtension(&timing); break; + } case kRtpExtensionRtpStreamId: { std::string rsid; packet.GetExtension(&rsid); @@ -148,6 +155,11 @@ void FuzzOneInput(const uint8_t* data, size_t size) { packet.GetExtension(&allocation); break; } + case kRtpExtensionVideoFrameTrackingId: { + uint16_t tracking_id; + packet.GetExtension(&tracking_id); + break; + } case kRtpExtensionGenericFrameDescriptor02: // This extension requires state to read and so complicated that // deserves own fuzzer. diff --git a/test/fuzzers/sdp_integration_fuzzer.cc b/test/fuzzers/sdp_integration_fuzzer.cc new file mode 100644 index 0000000000..bc181f0573 --- /dev/null +++ b/test/fuzzers/sdp_integration_fuzzer.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include + +#include "pc/test/integration_test_helpers.h" + +namespace webrtc { + +class FuzzerTest : public PeerConnectionIntegrationBaseTest { + public: + FuzzerTest() + : PeerConnectionIntegrationBaseTest(SdpSemantics::kUnifiedPlan) {} + + void TestBody() override {} +}; + +void FuzzOneInput(const uint8_t* data, size_t size) { + if (size > 16384) { + return; + } + std::string message(reinterpret_cast(data), size); + + FuzzerTest test; + test.CreatePeerConnectionWrappers(); + // Note - we do not do test.ConnectFakeSignaling(); all signals + // generated are discarded. + + auto srd_observer = + rtc::make_ref_counted(); + + webrtc::SdpParseError error; + std::unique_ptr sdp( + CreateSessionDescription("offer", message, &error)); + // Note: This form of SRD takes ownership of the description. + test.caller()->pc()->SetRemoteDescription(srd_observer, sdp.release()); + // Wait a short time for observer to be called. Timeout is short + // because the fuzzer should be trying many branches. 
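The flow exercised by the new sdp_integration_fuzzer.cc is: treat the raw fuzz input as an SDP offer, parse it, and only when parsing succeeds hand it to SetRemoteDescription() and, if that is accepted, answer via SetLocalDescription(). Below is a minimal standalone sketch of the parse step only; the helper name TryParseOffer is hypothetical and not part of the patch.

```cpp
#include <memory>
#include <string>

#include "api/jsep.h"

// Mirrors the first step the fuzzer performs: reject inputs that do not even
// parse as SDP before any SetRemoteDescription() call is attempted.
bool TryParseOffer(const std::string& message) {
  webrtc::SdpParseError error;
  std::unique_ptr<webrtc::SessionDescriptionInterface> offer(
      webrtc::CreateSessionDescription("offer", message, &error));
  if (!offer) {
    // error.line and error.description identify the rejected SDP line.
    return false;
  }
  return true;
}
```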
+ EXPECT_TRUE_WAIT(srd_observer->called(), 100); + + // If set-remote-description was successful, try to answer. + auto sld_observer = + rtc::make_ref_counted(); + if (srd_observer->result()) { + test.caller()->pc()->SetLocalDescription(sld_observer.get()); + EXPECT_TRUE_WAIT(sld_observer->called(), 100); + } +} + +} // namespace webrtc diff --git a/test/fuzzers/ssl_certificate_fuzzer.cc b/test/fuzzers/ssl_certificate_fuzzer.cc index 7ab59b51dd..4bab5c8f02 100644 --- a/test/fuzzers/ssl_certificate_fuzzer.cc +++ b/test/fuzzers/ssl_certificate_fuzzer.cc @@ -13,6 +13,7 @@ #include +#include "rtc_base/message_digest.h" #include "rtc_base/ssl_certificate.h" #include "rtc_base/string_encode.h" @@ -34,7 +35,7 @@ void FuzzOneInput(const uint8_t* data, size_t size) { cert->CertificateExpirationTime(); std::string algorithm; - cert->GetSignatureDigestAlgorithm(algorithm); + cert->GetSignatureDigestAlgorithm(&algorithm); unsigned char digest[rtc::MessageDigest::kMaxSize]; size_t digest_len; diff --git a/test/fuzzers/stun_parser_fuzzer.cc b/test/fuzzers/stun_parser_fuzzer.cc index 720a699662..6ca9eac8b2 100644 --- a/test/fuzzers/stun_parser_fuzzer.cc +++ b/test/fuzzers/stun_parser_fuzzer.cc @@ -24,5 +24,6 @@ void FuzzOneInput(const uint8_t* data, size_t size) { std::unique_ptr stun_msg(new cricket::IceMessage()); rtc::ByteBufferReader buf(message, size); stun_msg->Read(&buf); + stun_msg->ValidateMessageIntegrity(""); } } // namespace webrtc diff --git a/test/fuzzers/stun_validator_fuzzer.cc b/test/fuzzers/stun_validator_fuzzer.cc index 44252fafbc..421638db1b 100644 --- a/test/fuzzers/stun_validator_fuzzer.cc +++ b/test/fuzzers/stun_validator_fuzzer.cc @@ -18,6 +18,6 @@ void FuzzOneInput(const uint8_t* data, size_t size) { const char* message = reinterpret_cast(data); cricket::StunMessage::ValidateFingerprint(message, size); - cricket::StunMessage::ValidateMessageIntegrity(message, size, ""); + cricket::StunMessage::ValidateMessageIntegrityForTesting(message, size, ""); } } // namespace webrtc diff --git a/test/fuzzers/utils/BUILD.gn b/test/fuzzers/utils/BUILD.gn index 6249156058..3e0782f39d 100644 --- a/test/fuzzers/utils/BUILD.gn +++ b/test/fuzzers/utils/BUILD.gn @@ -24,6 +24,7 @@ rtc_library("rtp_replayer") { "../../../call:call_interfaces", "../../../common_video", "../../../media:rtc_internal_video_codecs", + "../../../modules/rtp_rtcp:rtp_rtcp_format", "../../../rtc_base:checks", "../../../rtc_base:rtc_base_approved", "../../../rtc_base:rtc_base_tests_utils", diff --git a/test/fuzzers/utils/rtp_replayer.cc b/test/fuzzers/utils/rtp_replayer.cc index a664adb31d..43b1fc2ea4 100644 --- a/test/fuzzers/utils/rtp_replayer.cc +++ b/test/fuzzers/utils/rtp_replayer.cc @@ -17,13 +17,13 @@ #include "api/task_queue/default_task_queue_factory.h" #include "api/transport/field_trial_based_config.h" +#include "modules/rtp_rtcp/source/rtp_packet.h" #include "rtc_base/strings/json.h" #include "system_wrappers/include/clock.h" #include "test/call_config_utils.h" #include "test/encoder_settings.h" #include "test/fake_decoder.h" #include "test/rtp_file_reader.h" -#include "test/rtp_header_parser.h" #include "test/run_loop.h" namespace webrtc { @@ -164,37 +164,32 @@ void RtpReplayer::ReplayPackets(rtc::FakeClock* clock, std::min(deliver_in_ms, static_cast(100)))); } + rtc::CopyOnWriteBuffer packet_buffer(packet.data, packet.length); ++num_packets; - switch (call->Receiver()->DeliverPacket( - webrtc::MediaType::VIDEO, - rtc::CopyOnWriteBuffer(packet.data, packet.length), - /* packet_time_us */ -1)) { + switch 
(call->Receiver()->DeliverPacket(webrtc::MediaType::VIDEO, + packet_buffer, + /* packet_time_us */ -1)) { case PacketReceiver::DELIVERY_OK: break; case PacketReceiver::DELIVERY_UNKNOWN_SSRC: { - RTPHeader header; - std::unique_ptr parser( - RtpHeaderParser::CreateForTest()); - - parser->Parse(packet.data, packet.length, &header); - if (unknown_packets[header.ssrc] == 0) { - RTC_LOG(LS_ERROR) << "Unknown SSRC: " << header.ssrc; + webrtc::RtpPacket header; + header.Parse(packet_buffer); + if (unknown_packets[header.Ssrc()] == 0) { + RTC_LOG(LS_ERROR) << "Unknown SSRC: " << header.Ssrc(); } - ++unknown_packets[header.ssrc]; + ++unknown_packets[header.Ssrc()]; break; } case PacketReceiver::DELIVERY_PACKET_ERROR: { RTC_LOG(LS_ERROR) << "Packet error, corrupt packets or incorrect setup?"; - RTPHeader header; - std::unique_ptr parser( - RtpHeaderParser::CreateForTest()); - parser->Parse(packet.data, packet.length, &header); + webrtc::RtpPacket header; + header.Parse(packet_buffer); RTC_LOG(LS_ERROR) << "Packet packet_length=" << packet.length - << " payload_type=" << header.payloadType - << " sequence_number=" << header.sequenceNumber - << " time_stamp=" << header.timestamp - << " ssrc=" << header.ssrc; + << " payload_type=" << header.PayloadType() + << " sequence_number=" << header.SequenceNumber() + << " time_stamp=" << header.Timestamp() + << " ssrc=" << header.Ssrc(); break; } } diff --git a/test/fuzzers/vp9_encoder_references_fuzzer.cc b/test/fuzzers/vp9_encoder_references_fuzzer.cc new file mode 100644 index 0000000000..9c793ae9aa --- /dev/null +++ b/test/fuzzers/vp9_encoder_references_fuzzer.cc @@ -0,0 +1,498 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include + +#include "absl/algorithm/container.h" +#include "absl/base/macros.h" +#include "absl/container/inlined_vector.h" +#include "api/array_view.h" +#include "api/transport/webrtc_key_value_config.h" +#include "api/video/video_frame.h" +#include "api/video_codecs/video_codec.h" +#include "api/video_codecs/video_encoder.h" +#include "modules/video_coding/codecs/interface/mock_libvpx_interface.h" +#include "modules/video_coding/codecs/vp9/libvpx_vp9_encoder.h" +#include "modules/video_coding/frame_dependencies_calculator.h" +#include "rtc_base/numerics/safe_compare.h" +#include "test/fuzzers/fuzz_data_helper.h" +#include "test/gmock.h" + +// Fuzzer simulates various svc configurations and libvpx encoder dropping +// layer frames. +// Validates vp9 encoder wrapper produces consistent frame references. 
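The core property this new fuzzer checks is that every encoded layer frame only references frames that are decodable before it, i.e. a dependency never points at a higher spatial or temporal layer (see CheckGenericReferences() below). A stripped-down sketch of that invariant follows; LayerInfo and DependenciesAreValid are hypothetical names used for illustration only.

```cpp
#include <vector>

// Illustrative only; not part of the patch.
struct LayerInfo {
  int spatial_id = 0;
  int temporal_id = 0;
};

// A frame may only depend on frames whose spatial and temporal ids do not
// exceed its own, otherwise selective layer decoding would break. This is the
// rule the FrameValidator below enforces with RTC_CHECK_GE.
bool DependenciesAreValid(const LayerInfo& frame,
                          const std::vector<LayerInfo>& dependencies) {
  for (const LayerInfo& dep : dependencies) {
    if (dep.spatial_id > frame.spatial_id ||
        dep.temporal_id > frame.temporal_id) {
      return false;
    }
  }
  return true;
}
```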
+namespace webrtc { +namespace { + +using test::FuzzDataHelper; +using ::testing::NiceMock; + +class FrameValidator : public EncodedImageCallback { + public: + ~FrameValidator() override = default; + + Result OnEncodedImage(const EncodedImage& encoded_image, + const CodecSpecificInfo* codec_specific_info) override { + RTC_CHECK(codec_specific_info); + RTC_CHECK_EQ(codec_specific_info->codecType, kVideoCodecVP9); + if (codec_specific_info->codecSpecific.VP9.first_frame_in_picture) { + ++picture_id_; + } + int64_t frame_id = frame_id_++; + LayerFrame& layer_frame = frames_[frame_id % kMaxFrameHistorySize]; + layer_frame.picture_id = picture_id_; + layer_frame.spatial_id = encoded_image.SpatialIndex().value_or(0); + layer_frame.frame_id = frame_id; + layer_frame.temporal_id = + codec_specific_info->codecSpecific.VP9.temporal_idx; + if (layer_frame.temporal_id == kNoTemporalIdx) { + layer_frame.temporal_id = 0; + } + layer_frame.vp9_non_ref_for_inter_layer_pred = + codec_specific_info->codecSpecific.VP9.non_ref_for_inter_layer_pred; + CheckVp9References(layer_frame, codec_specific_info->codecSpecific.VP9); + + if (codec_specific_info->generic_frame_info.has_value()) { + absl::InlinedVector frame_dependencies = + dependencies_calculator_.FromBuffersUsage( + frame_id, + codec_specific_info->generic_frame_info->encoder_buffers); + + CheckGenericReferences(frame_dependencies, + *codec_specific_info->generic_frame_info); + CheckGenericAndCodecSpecificReferencesAreConsistent( + frame_dependencies, *codec_specific_info, layer_frame); + } + + return Result(Result::OK); + } + + private: + // With 4 spatial layers and patterns up to 8 pictures, it should be enought + // to keep 32 last frames to validate dependencies. + static constexpr size_t kMaxFrameHistorySize = 32; + struct LayerFrame { + int64_t frame_id; + int64_t picture_id; + int spatial_id; + int temporal_id; + bool vp9_non_ref_for_inter_layer_pred; + }; + + void CheckVp9References(const LayerFrame& layer_frame, + const CodecSpecificInfoVP9& vp9_info) { + if (layer_frame.frame_id == 0) { + RTC_CHECK(!vp9_info.inter_layer_predicted); + } else { + const LayerFrame& previous_frame = Frame(layer_frame.frame_id - 1); + if (vp9_info.inter_layer_predicted) { + RTC_CHECK(!previous_frame.vp9_non_ref_for_inter_layer_pred); + RTC_CHECK_EQ(layer_frame.picture_id, previous_frame.picture_id); + } + if (previous_frame.picture_id == layer_frame.picture_id) { + RTC_CHECK_GT(layer_frame.spatial_id, previous_frame.spatial_id); + // The check below would fail for temporal shift structures. Remove it + // or move it to !flexible_mode section when vp9 encoder starts + // supporting such structures. 
+ RTC_CHECK_EQ(layer_frame.temporal_id, previous_frame.temporal_id); + } + } + if (!vp9_info.flexible_mode) { + if (vp9_info.gof.num_frames_in_gof > 0) { + gof_.CopyGofInfoVP9(vp9_info.gof); + } + RTC_CHECK_EQ(gof_.temporal_idx[vp9_info.gof_idx], + layer_frame.temporal_id); + } + } + + void CheckGenericReferences(rtc::ArrayView frame_dependencies, + const GenericFrameInfo& generic_info) const { + for (int64_t dependency_frame_id : frame_dependencies) { + RTC_CHECK_GE(dependency_frame_id, 0); + const LayerFrame& dependency = Frame(dependency_frame_id); + RTC_CHECK_GE(generic_info.spatial_id, dependency.spatial_id); + RTC_CHECK_GE(generic_info.temporal_id, dependency.temporal_id); + } + } + + void CheckGenericAndCodecSpecificReferencesAreConsistent( + rtc::ArrayView frame_dependencies, + const CodecSpecificInfo& info, + const LayerFrame& layer_frame) const { + const CodecSpecificInfoVP9& vp9_info = info.codecSpecific.VP9; + const GenericFrameInfo& generic_info = *info.generic_frame_info; + + RTC_CHECK_EQ(generic_info.spatial_id, layer_frame.spatial_id); + RTC_CHECK_EQ(generic_info.temporal_id, layer_frame.temporal_id); + auto picture_id_diffs = + rtc::MakeArrayView(vp9_info.p_diff, vp9_info.num_ref_pics); + RTC_CHECK_EQ( + frame_dependencies.size(), + picture_id_diffs.size() + (vp9_info.inter_layer_predicted ? 1 : 0)); + for (int64_t dependency_frame_id : frame_dependencies) { + RTC_CHECK_GE(dependency_frame_id, 0); + const LayerFrame& dependency = Frame(dependency_frame_id); + if (dependency.spatial_id != layer_frame.spatial_id) { + RTC_CHECK(vp9_info.inter_layer_predicted); + RTC_CHECK_EQ(layer_frame.picture_id, dependency.picture_id); + RTC_CHECK_GT(layer_frame.spatial_id, dependency.spatial_id); + } else { + RTC_CHECK(vp9_info.inter_pic_predicted); + RTC_CHECK_EQ(layer_frame.spatial_id, dependency.spatial_id); + RTC_CHECK(absl::c_linear_search( + picture_id_diffs, layer_frame.picture_id - dependency.picture_id)); + } + } + } + + const LayerFrame& Frame(int64_t frame_id) const { + auto& frame = frames_[frame_id % kMaxFrameHistorySize]; + RTC_CHECK_EQ(frame.frame_id, frame_id); + return frame; + } + + GofInfoVP9 gof_; + int64_t frame_id_ = 0; + int64_t picture_id_ = 1; + FrameDependenciesCalculator dependencies_calculator_; + LayerFrame frames_[kMaxFrameHistorySize]; +}; + +class FieldTrials : public WebRtcKeyValueConfig { + public: + explicit FieldTrials(FuzzDataHelper& config) + : flags_(config.ReadOrDefaultValue(0)) {} + + ~FieldTrials() override = default; + std::string Lookup(absl::string_view key) const override { + static constexpr absl::string_view kBinaryFieldTrials[] = { + "WebRTC-Vp9DependencyDescriptor", + "WebRTC-Vp9ExternalRefCtrl", + "WebRTC-Vp9IssueKeyFrameOnLayerDeactivation", + }; + for (size_t i = 0; i < ABSL_ARRAYSIZE(kBinaryFieldTrials); ++i) { + if (key == kBinaryFieldTrials[i]) { + return (flags_ & (1u << i)) ? "Enabled" : "Disabled"; + } + } + + // Ignore following field trials. + if (key == "WebRTC-CongestionWindow" || + key == "WebRTC-UseBaseHeavyVP8TL3RateAllocation" || + key == "WebRTC-SimulcastUpswitchHysteresisPercent" || + key == "WebRTC-SimulcastScreenshareUpswitchHysteresisPercent" || + key == "WebRTC-VideoRateControl" || + key == "WebRTC-VP9-PerformanceFlags" || + key == "WebRTC-VP9VariableFramerateScreenshare" || + key == "WebRTC-VP9QualityScaler") { + return ""; + } + // Crash when using unexpected field trial to decide if it should be fuzzed + // or have a constant value. 
+ RTC_CHECK(false) << "Unfuzzed field trial " << key << "\n"; + } + + private: + const uint8_t flags_; +}; + +VideoCodec CodecSettings(FuzzDataHelper& rng) { + uint16_t config = rng.ReadOrDefaultValue(0); + // Test up to to 4 spatial and 4 temporal layers. + int num_spatial_layers = 1 + (config & 0b11); + int num_temporal_layers = 1 + ((config >> 2) & 0b11); + + VideoCodec codec_settings = {}; + codec_settings.codecType = kVideoCodecVP9; + codec_settings.maxFramerate = 30; + codec_settings.width = 320 << (num_spatial_layers - 1); + codec_settings.height = 180 << (num_spatial_layers - 1); + if (num_spatial_layers > 1) { + for (int sid = 0; sid < num_spatial_layers; ++sid) { + SpatialLayer& spatial_layer = codec_settings.spatialLayers[sid]; + codec_settings.width = 320 << sid; + codec_settings.height = 180 << sid; + spatial_layer.maxFramerate = codec_settings.maxFramerate; + spatial_layer.numberOfTemporalLayers = num_temporal_layers; + } + } + codec_settings.VP9()->numberOfSpatialLayers = num_spatial_layers; + codec_settings.VP9()->numberOfTemporalLayers = num_temporal_layers; + int inter_layer_pred = (config >> 4) & 0b11; + // There are only 3 valid values. + codec_settings.VP9()->interLayerPred = static_cast( + inter_layer_pred < 3 ? inter_layer_pred : 0); + codec_settings.VP9()->flexibleMode = (config & (1u << 6)) != 0; + codec_settings.VP9()->frameDroppingOn = (config & (1u << 7)) != 0; + codec_settings.mode = VideoCodecMode::kRealtimeVideo; + return codec_settings; +} + +VideoEncoder::Settings EncoderSettings() { + return VideoEncoder::Settings(VideoEncoder::Capabilities(false), + /*number_of_cores=*/1, + /*max_payload_size=*/0); +} + +struct LibvpxState { + LibvpxState() { + pkt.kind = VPX_CODEC_CX_FRAME_PKT; + pkt.data.frame.buf = pkt_buffer; + pkt.data.frame.sz = ABSL_ARRAYSIZE(pkt_buffer); + layer_id.spatial_layer_id = -1; + } + + uint8_t pkt_buffer[1000] = {}; + vpx_codec_enc_cfg_t config = {}; + vpx_codec_priv_output_cx_pkt_cb_pair_t callback = {}; + vpx_image_t img = {}; + vpx_svc_ref_frame_config_t ref_config = {}; + vpx_svc_layer_id_t layer_id = {}; + vpx_svc_frame_drop_t frame_drop = {}; + vpx_codec_cx_pkt pkt = {}; +}; + +class StubLibvpx : public NiceMock { + public: + explicit StubLibvpx(LibvpxState* state) : state_(state) { RTC_CHECK(state_); } + + vpx_codec_err_t codec_enc_config_default(vpx_codec_iface_t* iface, + vpx_codec_enc_cfg_t* cfg, + unsigned int usage) const override { + state_->config = *cfg; + return VPX_CODEC_OK; + } + + vpx_codec_err_t codec_enc_init(vpx_codec_ctx_t* ctx, + vpx_codec_iface_t* iface, + const vpx_codec_enc_cfg_t* cfg, + vpx_codec_flags_t flags) const override { + RTC_CHECK(ctx); + ctx->err = VPX_CODEC_OK; + return VPX_CODEC_OK; + } + + vpx_image_t* img_wrap(vpx_image_t* img, + vpx_img_fmt_t fmt, + unsigned int d_w, + unsigned int d_h, + unsigned int stride_align, + unsigned char* img_data) const override { + state_->img.fmt = fmt; + state_->img.d_w = d_w; + state_->img.d_h = d_h; + return &state_->img; + } + + vpx_codec_err_t codec_encode(vpx_codec_ctx_t* ctx, + const vpx_image_t* img, + vpx_codec_pts_t pts, + uint64_t duration, + vpx_enc_frame_flags_t flags, + uint64_t deadline) const override { + if (flags & VPX_EFLAG_FORCE_KF) { + state_->pkt.data.frame.flags = VPX_FRAME_IS_KEY; + } else { + state_->pkt.data.frame.flags = 0; + } + state_->pkt.data.frame.duration = duration; + return VPX_CODEC_OK; + } + + vpx_codec_err_t codec_control(vpx_codec_ctx_t* ctx, + vp8e_enc_control_id ctrl_id, + void* param) const override { + if (ctrl_id == 
VP9E_REGISTER_CX_CALLBACK) { + state_->callback = + *reinterpret_cast(param); + } + return VPX_CODEC_OK; + } + + vpx_codec_err_t codec_control( + vpx_codec_ctx_t* ctx, + vp8e_enc_control_id ctrl_id, + vpx_svc_ref_frame_config_t* param) const override { + switch (ctrl_id) { + case VP9E_SET_SVC_REF_FRAME_CONFIG: + state_->ref_config = *param; + break; + case VP9E_GET_SVC_REF_FRAME_CONFIG: + *param = state_->ref_config; + break; + default: + break; + } + return VPX_CODEC_OK; + } + + vpx_codec_err_t codec_control(vpx_codec_ctx_t* ctx, + vp8e_enc_control_id ctrl_id, + vpx_svc_layer_id_t* param) const override { + switch (ctrl_id) { + case VP9E_SET_SVC_LAYER_ID: + state_->layer_id = *param; + break; + case VP9E_GET_SVC_LAYER_ID: + *param = state_->layer_id; + break; + default: + break; + } + return VPX_CODEC_OK; + } + + vpx_codec_err_t codec_control(vpx_codec_ctx_t* ctx, + vp8e_enc_control_id ctrl_id, + vpx_svc_frame_drop_t* param) const override { + if (ctrl_id == VP9E_SET_SVC_FRAME_DROP_LAYER) { + state_->frame_drop = *param; + } + return VPX_CODEC_OK; + } + + vpx_codec_err_t codec_enc_config_set( + vpx_codec_ctx_t* ctx, + const vpx_codec_enc_cfg_t* cfg) const override { + state_->config = *cfg; + return VPX_CODEC_OK; + } + + private: + LibvpxState* const state_; +}; + +enum Actions { + kEncode, + kSetRates, +}; + +// When a layer frame is marked for drop, drops all layer frames from that +// pictures with larger spatial ids. +constexpr bool DropAbove(uint8_t layers_mask, int sid) { + uint8_t full_mask = (uint8_t{1} << (sid + 1)) - 1; + return (layers_mask & full_mask) != full_mask; +} +// inline unittests +static_assert(DropAbove(0b1011, /*sid=*/0) == false, ""); +static_assert(DropAbove(0b1011, /*sid=*/1) == false, ""); +static_assert(DropAbove(0b1011, /*sid=*/2) == true, ""); +static_assert(DropAbove(0b1011, /*sid=*/3) == true, ""); + +// When a layer frame is marked for drop, drops all layer frames from that +// pictures with smaller spatial ids. +constexpr bool DropBelow(uint8_t layers_mask, int sid, int num_layers) { + return (layers_mask >> sid) != (1 << (num_layers - sid)) - 1; +} +// inline unittests +static_assert(DropBelow(0b1101, /*sid=*/0, 4) == true, ""); +static_assert(DropBelow(0b1101, /*sid=*/1, 4) == true, ""); +static_assert(DropBelow(0b1101, /*sid=*/2, 4) == false, ""); +static_assert(DropBelow(0b1101, /*sid=*/3, 4) == false, ""); + +} // namespace + +void FuzzOneInput(const uint8_t* data, size_t size) { + FuzzDataHelper helper(rtc::MakeArrayView(data, size)); + + FrameValidator validator; + FieldTrials field_trials(helper); + // Setup call callbacks for the fake + LibvpxState state; + + // Initialize encoder + LibvpxVp9Encoder encoder(cricket::VideoCodec(), + std::make_unique(&state), field_trials); + VideoCodec codec = CodecSettings(helper); + if (encoder.InitEncode(&codec, EncoderSettings()) != WEBRTC_VIDEO_CODEC_OK) { + return; + } + RTC_CHECK_EQ(encoder.RegisterEncodeCompleteCallback(&validator), + WEBRTC_VIDEO_CODEC_OK); + { + // Enable all the layers initially. Encoder doesn't support producing + // frames when no layers are enabled. 
+ LibvpxVp9Encoder::RateControlParameters parameters; + parameters.framerate_fps = 30.0; + for (int sid = 0; sid < codec.VP9()->numberOfSpatialLayers; ++sid) { + for (int tid = 0; tid < codec.VP9()->numberOfTemporalLayers; ++tid) { + parameters.bitrate.SetBitrate(sid, tid, 100'000); + } + } + encoder.SetRates(parameters); + } + + std::vector frame_types(1); + VideoFrame fake_image = VideoFrame::Builder() + .set_video_frame_buffer(I420Buffer::Create( + int{codec.width}, int{codec.height})) + .build(); + + // Start producing frames at random. + while (helper.CanReadBytes(1)) { + uint8_t action = helper.Read(); + switch (action & 0b11) { + case kEncode: { + // bitmask of the action: SSSS-K00, where + // four S bit indicate which spatial layers should be produced, + // K bit indicates if frame should be a key frame. + frame_types[0] = (action & 0b100) ? VideoFrameType::kVideoFrameKey + : VideoFrameType::kVideoFrameDelta; + encoder.Encode(fake_image, &frame_types); + uint8_t encode_spatial_layers = (action >> 4); + for (size_t sid = 0; sid < state.config.ss_number_layers; ++sid) { + bool drop = true; + switch (state.frame_drop.framedrop_mode) { + case FULL_SUPERFRAME_DROP: + drop = encode_spatial_layers == 0; + break; + case LAYER_DROP: + drop = (encode_spatial_layers & (1 << sid)) == 0; + break; + case CONSTRAINED_LAYER_DROP: + drop = DropBelow(encode_spatial_layers, sid, + state.config.ss_number_layers); + break; + case CONSTRAINED_FROM_ABOVE_DROP: + drop = DropAbove(encode_spatial_layers, sid); + break; + } + if (!drop) { + state.layer_id.spatial_layer_id = sid; + state.callback.output_cx_pkt(&state.pkt, state.callback.user_priv); + } + } + } break; + case kSetRates: { + // bitmask of the action: (S3)(S1)(S0)01, + // where Sx is number of temporal layers to enable for spatial layer x + // In pariculat Sx = 0 indicates spatial layer x should be disabled. + LibvpxVp9Encoder::RateControlParameters parameters; + parameters.framerate_fps = 30.0; + for (int sid = 0; sid < codec.VP9()->numberOfSpatialLayers; ++sid) { + int temporal_layers = (action >> ((1 + sid) * 2)) & 0b11; + for (int tid = 0; tid < temporal_layers; ++tid) { + parameters.bitrate.SetBitrate(sid, tid, 100'000); + } + } + // Ignore allocation that turns off all the layers. in such case + // it is up to upper-layer code not to call Encode. + if (parameters.bitrate.get_sum_bps() > 0) { + encoder.SetRates(parameters); + } + } break; + default: + // Unspecificed values are noop. 
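For reference, the kEncode branch above decodes its action byte as SSSS-K00: bits 0-1 select the action, bit 2 forces a key frame, and the top four bits choose which spatial layers the stubbed libvpx is simulated to produce. A standalone illustration of that decoding (plain C++, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t action = 0b1011'0100;            // example fuzzer byte
  const int kind = action & 0b11;                // 0 == kEncode
  const bool key_frame = (action & 0b100) != 0;  // maps to VPX_EFLAG_FORCE_KF
  const uint8_t spatial_mask = action >> 4;      // 0b1011: layers 0, 1 and 3
  std::printf("kind=%d key_frame=%d spatial_mask=0x%x\n", kind, key_frame,
              spatial_mask);
  return 0;
}
```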
+ break; + } + } +} +} // namespace webrtc diff --git a/test/linux/glx_renderer.cc b/test/linux/glx_renderer.cc index 50f2a06a8e..04d482c88b 100644 --- a/test/linux/glx_renderer.cc +++ b/test/linux/glx_renderer.cc @@ -20,8 +20,8 @@ namespace test { GlxRenderer::GlxRenderer(size_t width, size_t height) : width_(width), height_(height), display_(NULL), context_(NULL) { - assert(width > 0); - assert(height > 0); + RTC_DCHECK_GT(width, 0); + RTC_DCHECK_GT(height, 0); } GlxRenderer::~GlxRenderer() { diff --git a/test/logging/memory_log_writer.cc b/test/logging/memory_log_writer.cc index 2eb1cffb48..f57f0317a9 100644 --- a/test/logging/memory_log_writer.cc +++ b/test/logging/memory_log_writer.cc @@ -21,25 +21,18 @@ class MemoryLogWriter final : public RtcEventLogOutput { explicit MemoryLogWriter(std::map* target, std::string filename) : target_(target), filename_(filename) {} - ~MemoryLogWriter() final { - size_t size; - buffer_.GetSize(&size); - target_->insert({filename_, std::string(buffer_.GetBuffer(), size)}); - } + ~MemoryLogWriter() final { target_->insert({filename_, std::move(buffer_)}); } bool IsActive() const override { return true; } bool Write(const std::string& value) override { - size_t written; - int error; - return buffer_.WriteAll(value.data(), value.size(), &written, &error) == - rtc::SR_SUCCESS; - RTC_DCHECK_EQ(value.size(), written); + buffer_.append(value); + return true; } void Flush() override {} private: std::map* const target_; const std::string filename_; - rtc::MemoryStream buffer_; + std::string buffer_; }; class MemoryLogWriterFactory : public LogWriterFactoryInterface { diff --git a/test/logging/memory_log_writer.h b/test/logging/memory_log_writer.h index daef297b88..e795b2fd10 100644 --- a/test/logging/memory_log_writer.h +++ b/test/logging/memory_log_writer.h @@ -15,7 +15,6 @@ #include #include -#include "rtc_base/memory_stream.h" #include "test/logging/log_writer.h" namespace webrtc { diff --git a/test/mappable_native_buffer.cc b/test/mappable_native_buffer.cc new file mode 100644 index 0000000000..bd0b304545 --- /dev/null +++ b/test/mappable_native_buffer.cc @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "test/mappable_native_buffer.h" + +#include "absl/algorithm/container.h" +#include "api/video/i420_buffer.h" +#include "api/video/nv12_buffer.h" +#include "api/video/video_frame.h" +#include "api/video/video_rotation.h" +#include "common_video/include/video_frame_buffer.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace test { + +namespace { + +class NV12BufferWithDidConvertToI420 : public NV12Buffer { + public: + NV12BufferWithDidConvertToI420(int width, int height) + : NV12Buffer(width, height), did_convert_to_i420_(false) {} + + bool did_convert_to_i420() const { return did_convert_to_i420_; } + + rtc::scoped_refptr ToI420() override { + did_convert_to_i420_ = true; + return NV12Buffer::ToI420(); + } + + private: + bool did_convert_to_i420_; +}; + +} // namespace + +VideoFrame CreateMappableNativeFrame(int64_t ntp_time_ms, + VideoFrameBuffer::Type mappable_type, + int width, + int height) { + VideoFrame frame = + VideoFrame::Builder() + .set_video_frame_buffer(rtc::make_ref_counted( + mappable_type, width, height)) + .set_timestamp_rtp(99) + .set_timestamp_ms(99) + .set_rotation(kVideoRotation_0) + .build(); + frame.set_ntp_time_ms(ntp_time_ms); + return frame; +} + +rtc::scoped_refptr GetMappableNativeBufferFromVideoFrame( + const VideoFrame& frame) { + return static_cast(frame.video_frame_buffer().get()); +} + +MappableNativeBuffer::ScaledBuffer::ScaledBuffer( + rtc::scoped_refptr parent, + int width, + int height) + : parent_(std::move(parent)), width_(width), height_(height) {} + +MappableNativeBuffer::ScaledBuffer::~ScaledBuffer() {} + +rtc::scoped_refptr +MappableNativeBuffer::ScaledBuffer::CropAndScale(int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) { + return rtc::make_ref_counted(parent_, scaled_width, + scaled_height); +} + +rtc::scoped_refptr +MappableNativeBuffer::ScaledBuffer::ToI420() { + return parent_->GetOrCreateMappedBuffer(width_, height_)->ToI420(); +} + +rtc::scoped_refptr +MappableNativeBuffer::ScaledBuffer::GetMappedFrameBuffer( + rtc::ArrayView types) { + if (absl::c_find(types, parent_->mappable_type_) == types.end()) + return nullptr; + return parent_->GetOrCreateMappedBuffer(width_, height_); +} + +MappableNativeBuffer::MappableNativeBuffer(VideoFrameBuffer::Type mappable_type, + int width, + int height) + : mappable_type_(mappable_type), width_(width), height_(height) { + RTC_DCHECK(mappable_type_ == VideoFrameBuffer::Type::kI420 || + mappable_type_ == VideoFrameBuffer::Type::kNV12); +} + +MappableNativeBuffer::~MappableNativeBuffer() {} + +rtc::scoped_refptr MappableNativeBuffer::CropAndScale( + int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) { + return FullSizeBuffer()->CropAndScale( + offset_x, offset_y, crop_width, crop_height, scaled_width, scaled_height); +} + +rtc::scoped_refptr MappableNativeBuffer::ToI420() { + return FullSizeBuffer()->ToI420(); +} + +rtc::scoped_refptr MappableNativeBuffer::GetMappedFrameBuffer( + rtc::ArrayView types) { + return FullSizeBuffer()->GetMappedFrameBuffer(types); +} + +std::vector> +MappableNativeBuffer::GetMappedFramedBuffers() const { + MutexLock lock(&lock_); + return mapped_buffers_; +} + +bool MappableNativeBuffer::DidConvertToI420() const { + if (mappable_type_ != VideoFrameBuffer::Type::kNV12) + return false; + MutexLock lock(&lock_); + for (auto& mapped_buffer : mapped_buffers_) { + if (static_cast(mapped_buffer.get()) + ->did_convert_to_i420()) { + return true; 
+ } + } + return false; +} + +rtc::scoped_refptr +MappableNativeBuffer::FullSizeBuffer() { + return rtc::make_ref_counted(this, width_, height_); +} + +rtc::scoped_refptr +MappableNativeBuffer::GetOrCreateMappedBuffer(int width, int height) { + MutexLock lock(&lock_); + for (auto& mapped_buffer : mapped_buffers_) { + if (mapped_buffer->width() == width && mapped_buffer->height() == height) { + return mapped_buffer; + } + } + rtc::scoped_refptr mapped_buffer; + switch (mappable_type_) { + case VideoFrameBuffer::Type::kI420: { + rtc::scoped_refptr i420_buffer = + I420Buffer::Create(width, height); + I420Buffer::SetBlack(i420_buffer); + mapped_buffer = i420_buffer; + break; + } + case VideoFrameBuffer::Type::kNV12: { + auto nv12_buffer = + rtc::make_ref_counted(width, height); + nv12_buffer->InitializeData(); + mapped_buffer = std::move(nv12_buffer); + break; + } + default: + RTC_NOTREACHED(); + } + mapped_buffers_.push_back(mapped_buffer); + return mapped_buffer; +} + +} // namespace test +} // namespace webrtc diff --git a/test/mappable_native_buffer.h b/test/mappable_native_buffer.h new file mode 100644 index 0000000000..add22029c7 --- /dev/null +++ b/test/mappable_native_buffer.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef TEST_MAPPABLE_NATIVE_BUFFER_H_ +#define TEST_MAPPABLE_NATIVE_BUFFER_H_ + +#include +#include + +#include "api/array_view.h" +#include "api/video/video_frame.h" +#include "common_video/include/video_frame_buffer.h" +#include "rtc_base/ref_counted_object.h" +#include "rtc_base/synchronization/mutex.h" + +namespace webrtc { +namespace test { + +class MappableNativeBuffer; + +VideoFrame CreateMappableNativeFrame(int64_t ntp_time_ms, + VideoFrameBuffer::Type mappable_type, + int width, + int height); + +rtc::scoped_refptr GetMappableNativeBufferFromVideoFrame( + const VideoFrame& frame); + +// A for-testing native buffer that is scalable and mappable. The contents of +// the buffer is black and the pixels are created upon mapping. Mapped buffers +// are stored inside MappableNativeBuffer, allowing tests to verify which +// resolutions were mapped, e.g. when passing them in to an encoder or other +// modules. +class MappableNativeBuffer : public VideoFrameBuffer { + public: + // If |allow_i420_conversion| is false, calling ToI420() on a non-I420 buffer + // will DCHECK-crash. Used to ensure zero-copy in tests. + MappableNativeBuffer(VideoFrameBuffer::Type mappable_type, + int width, + int height); + ~MappableNativeBuffer() override; + + VideoFrameBuffer::Type mappable_type() const { return mappable_type_; } + + VideoFrameBuffer::Type type() const override { return Type::kNative; } + int width() const override { return width_; } + int height() const override { return height_; } + + rtc::scoped_refptr CropAndScale(int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) override; + + rtc::scoped_refptr ToI420() override; + rtc::scoped_refptr GetMappedFrameBuffer( + rtc::ArrayView types) override; + + // Gets all the buffers that have been mapped so far, including mappings of + // cropped and scaled buffers. 
+ std::vector> GetMappedFramedBuffers() + const; + bool DidConvertToI420() const; + + private: + friend class rtc::RefCountedObject; + + class ScaledBuffer : public VideoFrameBuffer { + public: + ScaledBuffer(rtc::scoped_refptr parent, + int width, + int height); + ~ScaledBuffer() override; + + VideoFrameBuffer::Type type() const override { return Type::kNative; } + int width() const override { return width_; } + int height() const override { return height_; } + + rtc::scoped_refptr CropAndScale( + int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) override; + + rtc::scoped_refptr ToI420() override; + rtc::scoped_refptr GetMappedFrameBuffer( + rtc::ArrayView types) override; + + private: + friend class rtc::RefCountedObject; + + const rtc::scoped_refptr parent_; + const int width_; + const int height_; + }; + + rtc::scoped_refptr FullSizeBuffer(); + rtc::scoped_refptr GetOrCreateMappedBuffer(int width, + int height); + + const VideoFrameBuffer::Type mappable_type_; + const int width_; + const int height_; + mutable Mutex lock_; + std::vector> mapped_buffers_ + RTC_GUARDED_BY(&lock_); +}; + +} // namespace test +} // namespace webrtc + +#endif // TEST_MAPPABLE_NATIVE_BUFFER_H_ diff --git a/test/mock_audio_decoder_factory.h b/test/mock_audio_decoder_factory.h index cdb03d3f38..4d3eed212c 100644 --- a/test/mock_audio_decoder_factory.h +++ b/test/mock_audio_decoder_factory.h @@ -52,7 +52,7 @@ class MockAudioDecoderFactory : public AudioDecoderFactory { using ::testing::Return; rtc::scoped_refptr factory = - new rtc::RefCountedObject; + rtc::make_ref_counted(); ON_CALL(*factory.get(), GetSupportedDecoders()) .WillByDefault(Return(std::vector())); EXPECT_CALL(*factory.get(), GetSupportedDecoders()).Times(AnyNumber()); @@ -73,7 +73,7 @@ class MockAudioDecoderFactory : public AudioDecoderFactory { using ::testing::SetArgPointee; rtc::scoped_refptr factory = - new rtc::RefCountedObject; + rtc::make_ref_counted(); ON_CALL(*factory.get(), GetSupportedDecoders()) .WillByDefault(Return(std::vector())); EXPECT_CALL(*factory.get(), GetSupportedDecoders()).Times(AnyNumber()); diff --git a/test/network/BUILD.gn b/test/network/BUILD.gn index 383f149699..1e39a3f89b 100644 --- a/test/network/BUILD.gn +++ b/test/network/BUILD.gn @@ -12,6 +12,7 @@ rtc_library("emulated_network") { visibility = [ ":*", "../../api:create_network_emulation_manager", + "../../api/test/network_emulation:create_cross_traffic", ] if (rtc_include_tests) { visibility += [ @@ -40,9 +41,12 @@ rtc_library("emulated_network") { "../../api:array_view", "../../api:network_emulation_manager_api", "../../api:packet_socket_factory", + "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api:simulated_network_api", "../../api:time_controller", "../../api/numerics", + "../../api/test/network_emulation", "../../api/transport:stun_types", "../../api/units:data_rate", "../../api/units:data_size", @@ -51,14 +55,21 @@ rtc_library("emulated_network") { "../../call:simulated_network", "../../p2p:p2p_server_utils", "../../rtc_base", + "../../rtc_base:async_socket", + "../../rtc_base:ip_address", + "../../rtc_base:network_constants", "../../rtc_base:rtc_base_tests_utils", "../../rtc_base:rtc_task_queue", "../../rtc_base:safe_minmax", + "../../rtc_base:socket_address", + "../../rtc_base:socket_server", + "../../rtc_base:stringutils", "../../rtc_base:task_queue_for_test", + "../../rtc_base:threading", "../../rtc_base/synchronization:mutex", - 
"../../rtc_base/synchronization:sequence_checker", + "../../rtc_base/task_utils:pending_task_safety_flag", "../../rtc_base/task_utils:repeating_task", - "../../rtc_base/third_party/sigslot", + "../../rtc_base/task_utils:to_queued_task", "../../system_wrappers", "../scenario:column_printer", "../time_controller", @@ -86,7 +97,7 @@ rtc_library("network_emulation_unittest") { ] } -if (rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { rtc_library("network_emulation_pc_unittest") { testonly = true sources = [ "network_emulation_pc_unittest.cc" ] @@ -126,10 +137,14 @@ rtc_library("cross_traffic_unittest") { "../../call:simulated_network", "../../rtc_base", "../../rtc_base:logging", + "../../rtc_base:network_constants", "../../rtc_base:rtc_event", "../time_controller", ] - absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] + absl_deps = [ + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/types:optional", + ] } if (rtc_include_tests) { @@ -158,13 +173,15 @@ if (rtc_include_tests) { ] } - rtc_library("network_emulation_unittests") { - testonly = true - deps = [ - ":cross_traffic_unittest", - ":feedback_generator_unittest", - ":network_emulation_pc_unittest", - ":network_emulation_unittest", - ] + if (!build_with_chromium) { + rtc_library("network_emulation_unittests") { + testonly = true + deps = [ + ":cross_traffic_unittest", + ":feedback_generator_unittest", + ":network_emulation_pc_unittest", + ":network_emulation_unittest", + ] + } } } diff --git a/test/network/cross_traffic.cc b/test/network/cross_traffic.cc index 56e7635142..ae5b156376 100644 --- a/test/network/cross_traffic.cc +++ b/test/network/cross_traffic.cc @@ -24,7 +24,7 @@ namespace webrtc { namespace test { RandomWalkCrossTraffic::RandomWalkCrossTraffic(RandomWalkConfig config, - TrafficRoute* traffic_route) + CrossTrafficRoute* traffic_route) : config_(config), traffic_route_(traffic_route), random_(config_.random_seed) { @@ -56,6 +56,10 @@ void RandomWalkCrossTraffic::Process(Timestamp at_time) { } } +TimeDelta RandomWalkCrossTraffic::GetProcessInterval() const { + return config_.min_packet_interval; +} + DataRate RandomWalkCrossTraffic::TrafficRate() const { RTC_DCHECK_RUN_ON(&sequence_checker_); return config_.peak_rate * intensity_; @@ -70,8 +74,9 @@ ColumnPrinter RandomWalkCrossTraffic::StatsPrinter() { 32); } -PulsedPeaksCrossTraffic::PulsedPeaksCrossTraffic(PulsedPeaksConfig config, - TrafficRoute* traffic_route) +PulsedPeaksCrossTraffic::PulsedPeaksCrossTraffic( + PulsedPeaksConfig config, + CrossTrafficRoute* traffic_route) : config_(config), traffic_route_(traffic_route) { sequence_checker_.Detach(); } @@ -102,6 +107,10 @@ void PulsedPeaksCrossTraffic::Process(Timestamp at_time) { } } +TimeDelta PulsedPeaksCrossTraffic::GetProcessInterval() const { + return config_.min_packet_interval; +} + DataRate PulsedPeaksCrossTraffic::TrafficRate() const { RTC_DCHECK_RUN_ON(&sequence_checker_); return sending_ ? 
config_.peak_rate : DataRate::Zero(); @@ -240,21 +249,13 @@ void TcpMessageRouteImpl::HandlePacketTimeout(int seq_num, Timestamp at_time) { } } -FakeTcpCrossTraffic::FakeTcpCrossTraffic(Clock* clock, - FakeTcpConfig config, +FakeTcpCrossTraffic::FakeTcpCrossTraffic(FakeTcpConfig config, EmulatedRoute* send_route, EmulatedRoute* ret_route) - : clock_(clock), conf_(config), route_(this, send_route, ret_route) {} - -void FakeTcpCrossTraffic::Start(TaskQueueBase* task_queue) { - repeating_task_handle_ = RepeatingTaskHandle::Start(task_queue, [this] { - Process(clock_->CurrentTime()); - return conf_.process_interval; - }); -} + : conf_(config), route_(this, send_route, ret_route) {} -void FakeTcpCrossTraffic::Stop() { - repeating_task_handle_.Stop(); +TimeDelta FakeTcpCrossTraffic::GetProcessInterval() const { + return conf_.process_interval; } void FakeTcpCrossTraffic::Process(Timestamp at_time) { diff --git a/test/network/cross_traffic.h b/test/network/cross_traffic.h index 942b863bbf..487622d4d4 100644 --- a/test/network/cross_traffic.h +++ b/test/network/cross_traffic.h @@ -15,41 +15,34 @@ #include #include +#include "api/sequence_checker.h" +#include "api/test/network_emulation_manager.h" #include "api/units/data_rate.h" #include "api/units/data_size.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" #include "rtc_base/random.h" -#include "rtc_base/synchronization/sequence_checker.h" -#include "test/network/traffic_route.h" +#include "test/network/network_emulation.h" #include "test/scenario/column_printer.h" namespace webrtc { namespace test { -struct RandomWalkConfig { - int random_seed = 1; - DataRate peak_rate = DataRate::KilobitsPerSec(100); - DataSize min_packet_size = DataSize::Bytes(200); - TimeDelta min_packet_interval = TimeDelta::Millis(1); - TimeDelta update_interval = TimeDelta::Millis(200); - double variance = 0.6; - double bias = -0.1; -}; - -class RandomWalkCrossTraffic { +class RandomWalkCrossTraffic final : public CrossTrafficGenerator { public: - RandomWalkCrossTraffic(RandomWalkConfig config, TrafficRoute* traffic_route); + RandomWalkCrossTraffic(RandomWalkConfig config, + CrossTrafficRoute* traffic_route); ~RandomWalkCrossTraffic(); - void Process(Timestamp at_time); + void Process(Timestamp at_time) override; + TimeDelta GetProcessInterval() const override; DataRate TrafficRate() const; ColumnPrinter StatsPrinter(); private: SequenceChecker sequence_checker_; const RandomWalkConfig config_; - TrafficRoute* const traffic_route_ RTC_PT_GUARDED_BY(sequence_checker_); + CrossTrafficRoute* const traffic_route_ RTC_PT_GUARDED_BY(sequence_checker_); webrtc::Random random_ RTC_GUARDED_BY(sequence_checker_); Timestamp last_process_time_ RTC_GUARDED_BY(sequence_checker_) = @@ -62,28 +55,21 @@ class RandomWalkCrossTraffic { DataSize pending_size_ RTC_GUARDED_BY(sequence_checker_) = DataSize::Zero(); }; -struct PulsedPeaksConfig { - DataRate peak_rate = DataRate::KilobitsPerSec(100); - DataSize min_packet_size = DataSize::Bytes(200); - TimeDelta min_packet_interval = TimeDelta::Millis(1); - TimeDelta send_duration = TimeDelta::Millis(100); - TimeDelta hold_duration = TimeDelta::Millis(2000); -}; - -class PulsedPeaksCrossTraffic { +class PulsedPeaksCrossTraffic final : public CrossTrafficGenerator { public: PulsedPeaksCrossTraffic(PulsedPeaksConfig config, - TrafficRoute* traffic_route); + CrossTrafficRoute* traffic_route); ~PulsedPeaksCrossTraffic(); - void Process(Timestamp at_time); + void Process(Timestamp at_time) override; + TimeDelta GetProcessInterval() 
const override; DataRate TrafficRate() const; ColumnPrinter StatsPrinter(); private: SequenceChecker sequence_checker_; const PulsedPeaksConfig config_; - TrafficRoute* const traffic_route_ RTC_PT_GUARDED_BY(sequence_checker_); + CrossTrafficRoute* const traffic_route_ RTC_PT_GUARDED_BY(sequence_checker_); Timestamp last_update_time_ RTC_GUARDED_BY(sequence_checker_) = Timestamp::MinusInfinity(); @@ -149,23 +135,17 @@ class TcpMessageRouteImpl final : public TcpMessageRoute { TimeDelta last_rtt_ = TimeDelta::Zero(); }; -struct FakeTcpConfig { - DataSize packet_size = DataSize::Bytes(1200); - DataSize send_limit = DataSize::PlusInfinity(); - TimeDelta process_interval = TimeDelta::Millis(200); - TimeDelta packet_timeout = TimeDelta::Seconds(1); -}; - class FakeTcpCrossTraffic - : public TwoWayFakeTrafficRoute::TrafficHandlerInterface { + : public TwoWayFakeTrafficRoute::TrafficHandlerInterface, + public CrossTrafficGenerator { public: - FakeTcpCrossTraffic(Clock* clock, - FakeTcpConfig config, + FakeTcpCrossTraffic(FakeTcpConfig config, EmulatedRoute* send_route, EmulatedRoute* ret_route); - void Start(TaskQueueBase* task_queue); - void Stop(); - void Process(Timestamp at_time); + + TimeDelta GetProcessInterval() const override; + void Process(Timestamp at_time) override; + void OnRequest(int sequence_number, Timestamp at_time) override; void OnResponse(int sequence_number, Timestamp at_time) override; @@ -174,7 +154,6 @@ class FakeTcpCrossTraffic void SendPackets(Timestamp at_time); private: - Clock* const clock_; const FakeTcpConfig conf_; TwoWayFakeTrafficRoute route_; @@ -187,7 +166,6 @@ class FakeTcpCrossTraffic Timestamp last_reduction_time_ = Timestamp::MinusInfinity(); TimeDelta last_rtt_ = TimeDelta::Zero(); DataSize total_sent_ = DataSize::Zero(); - RepeatingTaskHandle repeating_task_handle_; }; } // namespace test diff --git a/test/network/cross_traffic_unittest.cc b/test/network/cross_traffic_unittest.cc index c8d848f154..2744a90ce3 100644 --- a/test/network/cross_traffic_unittest.cc +++ b/test/network/cross_traffic_unittest.cc @@ -16,6 +16,7 @@ #include #include "absl/memory/memory.h" +#include "absl/types/optional.h" #include "api/test/network_emulation_manager.h" #include "api/test/simulated_network.h" #include "call/simulated_network.h" @@ -25,6 +26,7 @@ #include "test/gmock.h" #include "test/gtest.h" #include "test/network/network_emulation_manager.h" +#include "test/network/traffic_route.h" #include "test/time_controller/simulated_time_controller.h" namespace webrtc { @@ -47,21 +49,20 @@ struct TrafficCounterFixture { SimulatedClock clock{0}; CountingReceiver counter; TaskQueueForTest task_queue_; - EmulatedEndpointImpl endpoint{ - /*id=*/1, - rtc::IPAddress(kTestIpAddress), - EmulatedEndpointConfig::StatsGatheringMode::kDefault, - /*is_enabled=*/true, - /*type=*/rtc::AdapterType::ADAPTER_TYPE_UNKNOWN, - &task_queue_, - &clock}; + EmulatedEndpointImpl endpoint{EmulatedEndpointImpl::Options{ + /*id=*/1, + rtc::IPAddress(kTestIpAddress), + EmulatedEndpointConfig(), + }, + /*is_enabled=*/true, &task_queue_, &clock}; }; } // namespace TEST(CrossTrafficTest, TriggerPacketBurst) { TrafficCounterFixture fixture; - TrafficRoute traffic(&fixture.clock, &fixture.counter, &fixture.endpoint); + CrossTrafficRouteImpl traffic(&fixture.clock, &fixture.counter, + &fixture.endpoint); traffic.TriggerPacketBurst(100, 1000); EXPECT_EQ(fixture.counter.packets_count_, 100); @@ -70,7 +71,8 @@ TEST(CrossTrafficTest, TriggerPacketBurst) { TEST(CrossTrafficTest, PulsedPeaksCrossTraffic) { 
TrafficCounterFixture fixture; - TrafficRoute traffic(&fixture.clock, &fixture.counter, &fixture.endpoint); + CrossTrafficRouteImpl traffic(&fixture.clock, &fixture.counter, + &fixture.endpoint); PulsedPeaksConfig config; config.peak_rate = DataRate::KilobitsPerSec(1000); @@ -95,7 +97,8 @@ TEST(CrossTrafficTest, PulsedPeaksCrossTraffic) { TEST(CrossTrafficTest, RandomWalkCrossTraffic) { TrafficCounterFixture fixture; - TrafficRoute traffic(&fixture.clock, &fixture.counter, &fixture.endpoint); + CrossTrafficRouteImpl traffic(&fixture.clock, &fixture.counter, + &fixture.endpoint); RandomWalkConfig config; config.peak_rate = DataRate::KilobitsPerSec(1000); diff --git a/test/network/emulated_network_manager.h b/test/network/emulated_network_manager.h index 2321af0e04..fd2bb5b665 100644 --- a/test/network/emulated_network_manager.h +++ b/test/network/emulated_network_manager.h @@ -15,13 +15,13 @@ #include #include +#include "api/sequence_checker.h" #include "api/test/network_emulation_manager.h" #include "api/test/time_controller.h" #include "rtc_base/ip_address.h" #include "rtc_base/network.h" #include "rtc_base/socket_server.h" #include "rtc_base/thread.h" -#include "rtc_base/thread_checker.h" #include "test/network/network_emulation.h" namespace webrtc { diff --git a/test/network/emulated_turn_server.cc b/test/network/emulated_turn_server.cc index 06a8bf9a94..d67e4e337a 100644 --- a/test/network/emulated_turn_server.cc +++ b/test/network/emulated_turn_server.cc @@ -14,6 +14,7 @@ #include #include "api/packet_socket_factory.h" +#include "rtc_base/strings/string_builder.h" namespace { diff --git a/test/network/fake_network_socket_server.cc b/test/network/fake_network_socket_server.cc index bee2846be7..bf6ef5f12d 100644 --- a/test/network/fake_network_socket_server.cc +++ b/test/network/fake_network_socket_server.cc @@ -16,8 +16,10 @@ #include #include "absl/algorithm/container.h" -#include "rtc_base/async_invoker.h" +#include "api/scoped_refptr.h" #include "rtc_base/logging.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/thread.h" namespace webrtc { @@ -74,7 +76,7 @@ class FakeNetworkSocket : public rtc::AsyncSocket, std::map options_map_ RTC_GUARDED_BY(&thread_); absl::optional pending_ RTC_GUARDED_BY(thread_); - rtc::AsyncInvoker invoker_; + rtc::scoped_refptr alive_; }; FakeNetworkSocket::FakeNetworkSocket(FakeNetworkSocketServer* socket_server, @@ -82,9 +84,13 @@ FakeNetworkSocket::FakeNetworkSocket(FakeNetworkSocketServer* socket_server, : socket_server_(socket_server), thread_(thread), state_(CS_CLOSED), - error_(0) {} + error_(0), + alive_(PendingTaskSafetyFlag::Create()) {} FakeNetworkSocket::~FakeNetworkSocket() { + // Abandon all pending packets. 
+ alive_->SetNotAlive(); + Close(); socket_server_->Unregister(this); } @@ -103,7 +109,7 @@ void FakeNetworkSocket::OnPacketReceived(EmulatedIpPacket packet) { SignalReadEvent(this); RTC_DCHECK(!pending_); }; - invoker_.AsyncInvoke(RTC_FROM_HERE, thread_, std::move(task)); + thread_->PostTask(ToQueuedTask(alive_, std::move(task))); socket_server_->WakeUp(); } @@ -270,10 +276,6 @@ FakeNetworkSocketServer::FakeNetworkSocketServer( wakeup_(/*manual_reset=*/false, /*initially_signaled=*/false) {} FakeNetworkSocketServer::~FakeNetworkSocketServer() = default; -void FakeNetworkSocketServer::OnMessageQueueDestroyed() { - thread_ = nullptr; -} - EmulatedEndpointImpl* FakeNetworkSocketServer::GetEndpointNode( const rtc::IPAddress& ip) { return endpoints_container_->LookupByLocalAddress(ip); @@ -305,10 +307,6 @@ rtc::AsyncSocket* FakeNetworkSocketServer::CreateAsyncSocket(int family, void FakeNetworkSocketServer::SetMessageQueue(rtc::Thread* thread) { thread_ = thread; - if (thread_) { - thread_->SignalQueueDestroyed.connect( - this, &FakeNetworkSocketServer::OnMessageQueueDestroyed); - } } // Always returns true (if return false, it won't be invoked again...) diff --git a/test/network/fake_network_socket_server.h b/test/network/fake_network_socket_server.h index 2cf4d7c86d..d8be2e24b8 100644 --- a/test/network/fake_network_socket_server.h +++ b/test/network/fake_network_socket_server.h @@ -19,7 +19,6 @@ #include "rtc_base/event.h" #include "rtc_base/socket_server.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/third_party/sigslot/sigslot.h" #include "system_wrappers/include/clock.h" #include "test/network/network_emulation.h" @@ -28,8 +27,7 @@ namespace test { class FakeNetworkSocket; // FakeNetworkSocketServer must outlive any sockets it creates. 
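The fake_network_socket_server.cc change above replaces rtc::AsyncInvoker with a PendingTaskSafetyFlag: tasks posted to the thread carry the flag and silently become no-ops once the socket has been destroyed. A hedged sketch of the same pattern, using a hypothetical PacketForwarder class that is not part of the patch:

```cpp
#include <utility>

#include "api/scoped_refptr.h"
#include "rtc_base/logging.h"
#include "rtc_base/task_utils/pending_task_safety_flag.h"
#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/thread.h"

class PacketForwarder {
 public:
  explicit PacketForwarder(rtc::Thread* thread)
      : thread_(thread), alive_(webrtc::PendingTaskSafetyFlag::Create()) {}
  ~PacketForwarder() { alive_->SetNotAlive(); }  // Abandon queued tasks.

  void ForwardAsync(int packet_id) {
    thread_->PostTask(webrtc::ToQueuedTask(alive_, [packet_id]() {
      // Runs only if |alive_| was still set when the task was dequeued.
      RTC_LOG(LS_INFO) << "Forwarding packet " << packet_id;
    }));
  }

 private:
  rtc::Thread* const thread_;
  const rtc::scoped_refptr<webrtc::PendingTaskSafetyFlag> alive_;
};
```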
-class FakeNetworkSocketServer : public rtc::SocketServer, - public sigslot::has_slots<> { +class FakeNetworkSocketServer : public rtc::SocketServer { public: explicit FakeNetworkSocketServer(EndpointsContainer* endpoints_controller); ~FakeNetworkSocketServer() override; @@ -52,8 +50,6 @@ class FakeNetworkSocketServer : public rtc::SocketServer, void Unregister(FakeNetworkSocket* socket); private: - void OnMessageQueueDestroyed(); - const EndpointsContainer* endpoints_container_; rtc::Event wakeup_; rtc::Thread* thread_ = nullptr; diff --git a/test/network/g3doc/g3doc.lua b/test/network/g3doc/g3doc.lua new file mode 100644 index 0000000000..981393c826 --- /dev/null +++ b/test/network/g3doc/g3doc.lua @@ -0,0 +1,5 @@ +config = super() + +config.freshness.owner = 'titovartem' + +return config diff --git a/test/network/g3doc/index.md b/test/network/g3doc/index.md index 908e0e2ca6..5d511916c1 100644 --- a/test/network/g3doc/index.md +++ b/test/network/g3doc/index.md @@ -1,9 +1,6 @@ # Network Emulation Framework - + [TOC] diff --git a/test/network/network_emulation.cc b/test/network/network_emulation.cc index bf6c0683d4..ada9ab542a 100644 --- a/test/network/network_emulation.cc +++ b/test/network/network_emulation.cc @@ -13,10 +13,12 @@ #include #include #include +#include +#include "absl/types/optional.h" #include "api/numerics/samples_stats_counter.h" +#include "api/test/network_emulation/network_emulation_interfaces.h" #include "api/units/data_size.h" -#include "rtc_base/bind.h" #include "rtc_base/logging.h" namespace webrtc { @@ -346,6 +348,9 @@ void NetworkRouterNode::OnPacketReceived(EmulatedIpPacket packet) { } auto receiver_it = routing_.find(packet.to.ipaddr()); if (receiver_it == routing_.end()) { + if (default_receiver_.has_value()) { + (*default_receiver_)->OnPacketReceived(std::move(packet)); + } return; } RTC_CHECK(receiver_it != routing_.end()); @@ -370,6 +375,23 @@ void NetworkRouterNode::RemoveReceiver(const rtc::IPAddress& dest_ip) { routing_.erase(dest_ip); } +void NetworkRouterNode::SetDefaultReceiver( + EmulatedNetworkReceiverInterface* receiver) { + task_queue_->PostTask([=] { + RTC_DCHECK_RUN_ON(task_queue_); + if (default_receiver_.has_value()) { + RTC_CHECK_EQ(*default_receiver_, receiver) + << "Router already default receiver"; + } + default_receiver_ = receiver; + }); +} + +void NetworkRouterNode::RemoveDefaultReceiver() { + RTC_DCHECK_RUN_ON(task_queue_); + default_receiver_ = absl::nullopt; +} + void NetworkRouterNode::SetWatcher( std::function watcher) { task_queue_->PostTask([=] { @@ -415,61 +437,72 @@ void EmulatedNetworkNode::ClearRoute(const rtc::IPAddress& receiver_ip, EmulatedNetworkNode::~EmulatedNetworkNode() = default; -EmulatedEndpointImpl::EmulatedEndpointImpl( - uint64_t id, - const rtc::IPAddress& ip, - EmulatedEndpointConfig::StatsGatheringMode stats_gathering_mode, - bool is_enabled, - rtc::AdapterType type, - rtc::TaskQueue* task_queue, - Clock* clock) - : id_(id), - peer_local_addr_(ip), - stats_gathering_mode_(stats_gathering_mode), +EmulatedEndpointImpl::Options::Options(uint64_t id, + const rtc::IPAddress& ip, + const EmulatedEndpointConfig& config) + : id(id), + ip(ip), + stats_gathering_mode(config.stats_gathering_mode), + type(config.type), + allow_send_packet_with_different_source_ip( + config.allow_send_packet_with_different_source_ip), + allow_receive_packets_with_different_dest_ip( + config.allow_receive_packets_with_different_dest_ip), + log_name(ip.ToString() + " (" + config.name.value_or("") + ")") {} + 
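NetworkRouterNode above now supports a default receiver: a packet whose destination IP has no explicit route is handed to that receiver if one is set and is silently dropped otherwise. A minimal, illustrative-only model of that lookup rule (plain standard C++, not the WebRTC API):

```cpp
#include <cstdio>
#include <map>

struct Receiver {
  const char* name;
};

// Hypothetical stand-in for the routing table: explicit routes win, the
// default receiver catches everything else, and nullptr means "drop".
Receiver* Route(const std::map<int, Receiver*>& routes,
                Receiver* default_receiver,
                int dest_ip) {
  auto it = routes.find(dest_ip);
  if (it != routes.end()) {
    return it->second;
  }
  return default_receiver;  // May be null: packet is silently dropped.
}

int main() {
  Receiver endpoint{"endpoint A"};
  Receiver fallback{"TURN server"};
  std::map<int, Receiver*> routes = {{1, &endpoint}};
  std::printf("%s\n", Route(routes, &fallback, 1)->name);   // endpoint A
  std::printf("%s\n", Route(routes, &fallback, 99)->name);  // TURN server
  return 0;
}
```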
+EmulatedEndpointImpl::EmulatedEndpointImpl(const Options& options, + bool is_enabled, + rtc::TaskQueue* task_queue, + Clock* clock) + : options_(options), is_enabled_(is_enabled), - type_(type), clock_(clock), task_queue_(task_queue), router_(task_queue_), next_port_(kFirstEphemeralPort), - stats_builder_(peer_local_addr_) { + stats_builder_(options_.ip) { constexpr int kIPv4NetworkPrefixLength = 24; constexpr int kIPv6NetworkPrefixLength = 64; int prefix_length = 0; - if (ip.family() == AF_INET) { + if (options_.ip.family() == AF_INET) { prefix_length = kIPv4NetworkPrefixLength; - } else if (ip.family() == AF_INET6) { + } else if (options_.ip.family() == AF_INET6) { prefix_length = kIPv6NetworkPrefixLength; } - rtc::IPAddress prefix = TruncateIP(ip, prefix_length); + rtc::IPAddress prefix = TruncateIP(options_.ip, prefix_length); network_ = std::make_unique( - ip.ToString(), "Endpoint id=" + std::to_string(id_), prefix, - prefix_length, type_); - network_->AddIP(ip); + options_.ip.ToString(), "Endpoint id=" + std::to_string(options_.id), + prefix, prefix_length, options_.type); + network_->AddIP(options_.ip); enabled_state_checker_.Detach(); + RTC_LOG(INFO) << "Created emulated endpoint " << options_.log_name + << "; id=" << options_.id; } EmulatedEndpointImpl::~EmulatedEndpointImpl() = default; uint64_t EmulatedEndpointImpl::GetId() const { - return id_; + return options_.id; } void EmulatedEndpointImpl::SendPacket(const rtc::SocketAddress& from, const rtc::SocketAddress& to, rtc::CopyOnWriteBuffer packet_data, uint16_t application_overhead) { - RTC_CHECK(from.ipaddr() == peer_local_addr_); + if (!options_.allow_send_packet_with_different_source_ip) { + RTC_CHECK(from.ipaddr() == options_.ip); + } EmulatedIpPacket packet(from, to, std::move(packet_data), clock_->CurrentTime(), application_overhead); task_queue_->PostTask([this, packet = std::move(packet)]() mutable { RTC_DCHECK_RUN_ON(task_queue_); - stats_builder_.OnPacketSent( - packet.arrival_time, clock_->CurrentTime(), packet.to.ipaddr(), - DataSize::Bytes(packet.ip_packet_size()), stats_gathering_mode_); + stats_builder_.OnPacketSent(packet.arrival_time, clock_->CurrentTime(), + packet.to.ipaddr(), + DataSize::Bytes(packet.ip_packet_size()), + options_.stats_gathering_mode); - if (packet.to.ipaddr() == peer_local_addr_) { + if (packet.to.ipaddr() == options_.ip) { OnPacketReceived(std::move(packet)); } else { router_.OnPacketReceived(std::move(packet)); @@ -480,7 +513,20 @@ void EmulatedEndpointImpl::SendPacket(const rtc::SocketAddress& from, absl::optional EmulatedEndpointImpl::BindReceiver( uint16_t desired_port, EmulatedNetworkReceiverInterface* receiver) { - rtc::CritScope crit(&receiver_lock_); + return BindReceiverInternal(desired_port, receiver, /*is_one_shot=*/false); +} + +absl::optional EmulatedEndpointImpl::BindOneShotReceiver( + uint16_t desired_port, + EmulatedNetworkReceiverInterface* receiver) { + return BindReceiverInternal(desired_port, receiver, /*is_one_shot=*/true); +} + +absl::optional EmulatedEndpointImpl::BindReceiverInternal( + uint16_t desired_port, + EmulatedNetworkReceiverInterface* receiver, + bool is_one_shot) { + MutexLock lock(&receiver_lock_); uint16_t port = desired_port; if (port == 0) { // Because client can specify its own port, next_port_ can be already in @@ -496,15 +542,17 @@ absl::optional EmulatedEndpointImpl::BindReceiver( } } RTC_CHECK(port != 0) << "Can't find free port for receiver in endpoint " - << id_; - bool result = port_to_receiver_.insert({port, receiver}).second; + << 
options_.log_name << "; id=" << options_.id; + bool result = + port_to_receiver_.insert({port, {receiver, is_one_shot}}).second; if (!result) { RTC_LOG(INFO) << "Can't bind receiver to used port " << desired_port - << " in endpoint " << id_; + << " in endpoint " << options_.log_name + << "; id=" << options_.id; return absl::nullopt; } - RTC_LOG(INFO) << "New receiver is binded to endpoint " << id_ << " on port " - << port; + RTC_LOG(INFO) << "New receiver is binded to endpoint " << options_.log_name + << "; id=" << options_.id << " on port " << port; return port; } @@ -519,40 +567,71 @@ uint16_t EmulatedEndpointImpl::NextPort() { } void EmulatedEndpointImpl::UnbindReceiver(uint16_t port) { - rtc::CritScope crit(&receiver_lock_); + MutexLock lock(&receiver_lock_); + RTC_LOG(INFO) << "Receiver is removed on port " << port << " from endpoint " + << options_.log_name << "; id=" << options_.id; port_to_receiver_.erase(port); } +void EmulatedEndpointImpl::BindDefaultReceiver( + EmulatedNetworkReceiverInterface* receiver) { + MutexLock lock(&receiver_lock_); + RTC_CHECK(!default_receiver_.has_value()) + << "Endpoint " << options_.log_name << "; id=" << options_.id + << " already has default receiver"; + RTC_LOG(INFO) << "Default receiver is binded to endpoint " + << options_.log_name << "; id=" << options_.id; + default_receiver_ = receiver; +} + +void EmulatedEndpointImpl::UnbindDefaultReceiver() { + MutexLock lock(&receiver_lock_); + RTC_LOG(INFO) << "Default receiver is removed from endpoint " + << options_.log_name << "; id=" << options_.id; + default_receiver_ = absl::nullopt; +} + rtc::IPAddress EmulatedEndpointImpl::GetPeerLocalAddress() const { - return peer_local_addr_; + return options_.ip; } void EmulatedEndpointImpl::OnPacketReceived(EmulatedIpPacket packet) { RTC_DCHECK_RUN_ON(task_queue_); - RTC_CHECK(packet.to.ipaddr() == peer_local_addr_) - << "Routing error: wrong destination endpoint. Packet.to.ipaddr()=: " - << packet.to.ipaddr().ToString() - << "; Receiver peer_local_addr_=" << peer_local_addr_.ToString(); - rtc::CritScope crit(&receiver_lock_); + if (!options_.allow_receive_packets_with_different_dest_ip) { + RTC_CHECK(packet.to.ipaddr() == options_.ip) + << "Routing error: wrong destination endpoint. Packet.to.ipaddr()=: " + << packet.to.ipaddr().ToString() + << "; Receiver options_.ip=" << options_.ip.ToString(); + } + MutexLock lock(&receiver_lock_); stats_builder_.OnPacketReceived(clock_->CurrentTime(), packet.from.ipaddr(), DataSize::Bytes(packet.ip_packet_size()), - stats_gathering_mode_); + options_.stats_gathering_mode); auto it = port_to_receiver_.find(packet.to.port()); if (it == port_to_receiver_.end()) { + if (default_receiver_.has_value()) { + (*default_receiver_)->OnPacketReceived(std::move(packet)); + return; + } // It can happen, that remote peer closed connection, but there still some // packets, that are going to it. It can happen during peer connection close // process: one peer closed connection, second still sending data. 
- RTC_LOG(INFO) << "Drop packet: no receiver registered in " << id_ - << " on port " << packet.to.port(); + RTC_LOG(INFO) << "Drop packet: no receiver registered in " + << options_.log_name << "; id=" << options_.id << " on port " + << packet.to.port(); stats_builder_.OnPacketDropped(packet.from.ipaddr(), DataSize::Bytes(packet.ip_packet_size()), - stats_gathering_mode_); + options_.stats_gathering_mode); return; } - // Endpoint assumes frequent calls to bind and unbind methods, so it holds - // lock during packet processing to ensure that receiver won't be deleted - // before call to OnPacketReceived. - it->second->OnPacketReceived(std::move(packet)); + // Endpoint holds lock during packet processing to ensure that a call to + // UnbindReceiver followed by a delete of the receiver cannot race with this + // call to OnPacketReceived. + it->second.receiver->OnPacketReceived(std::move(packet)); + + if (it->second.is_one_shot) { + port_to_receiver_.erase(it); + } } void EmulatedEndpointImpl::Enable() { diff --git a/test/network/network_emulation.h b/test/network/network_emulation.h index c4d79661aa..f700beffcd 100644 --- a/test/network/network_emulation.h +++ b/test/network/network_emulation.h @@ -22,6 +22,7 @@ #include "absl/types/optional.h" #include "api/array_view.h" #include "api/numerics/samples_stats_counter.h" +#include "api/sequence_checker.h" #include "api/test/network_emulation_manager.h" #include "api/test/simulated_network.h" #include "api/units/timestamp.h" @@ -29,11 +30,10 @@ #include "rtc_base/network.h" #include "rtc_base/network_constants.h" #include "rtc_base/socket_address.h" -#include "rtc_base/synchronization/sequence_checker.h" +#include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue_for_test.h" #include "rtc_base/task_utils/repeating_task.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -419,6 +419,10 @@ class LinkEmulation : public EmulatedNetworkReceiverInterface { uint64_t next_packet_id_ RTC_GUARDED_BY(task_queue_) = 1; }; +// Represents a component responsible for routing packets based on their IP +// address. All possible routes have to be set explicitly before packet for +// desired destination will be seen for the first time. If route is unknown +// the packet will be silently dropped. class NetworkRouterNode : public EmulatedNetworkReceiverInterface { public: explicit NetworkRouterNode(rtc::TaskQueue* task_queue); @@ -427,11 +431,17 @@ class NetworkRouterNode : public EmulatedNetworkReceiverInterface { void SetReceiver(const rtc::IPAddress& dest_ip, EmulatedNetworkReceiverInterface* receiver); void RemoveReceiver(const rtc::IPAddress& dest_ip); + // Sets a default receive that will be used for all incoming packets for which + // there is no specific receiver binded to their destination port. + void SetDefaultReceiver(EmulatedNetworkReceiverInterface* receiver); + void RemoveDefaultReceiver(); void SetWatcher(std::function watcher); void SetFilter(std::function filter); private: rtc::TaskQueue* const task_queue_; + absl::optional default_receiver_ + RTC_GUARDED_BY(task_queue_); std::map routing_ RTC_GUARDED_BY(task_queue_); std::function watcher_ @@ -482,14 +492,33 @@ class EmulatedNetworkNode : public EmulatedNetworkReceiverInterface { // from other EmulatedNetworkNodes. 
class EmulatedEndpointImpl : public EmulatedEndpoint { public: - EmulatedEndpointImpl( - uint64_t id, - const rtc::IPAddress& ip, - EmulatedEndpointConfig::StatsGatheringMode stats_gathering_mode, - bool is_enabled, - rtc::AdapterType type, - rtc::TaskQueue* task_queue, - Clock* clock); + struct Options { + Options(uint64_t id, + const rtc::IPAddress& ip, + const EmulatedEndpointConfig& config); + + // TODO(titovartem) check if we can remove id. + uint64_t id; + // Endpoint local IP address. + rtc::IPAddress ip; + EmulatedEndpointConfig::StatsGatheringMode stats_gathering_mode; + rtc::AdapterType type; + // Allow endpoint to send packets specifying source IP address different to + // the current endpoint IP address. If false endpoint will crash if attempt + // to send such packet will be done. + bool allow_send_packet_with_different_source_ip; + // Allow endpoint to receive packet with destination IP address different to + // the current endpoint IP address. If false endpoint will crash if such + // packet will arrive. + bool allow_receive_packets_with_different_dest_ip; + // Name of the endpoint used for logging purposes. + std::string log_name; + }; + + EmulatedEndpointImpl(const Options& options, + bool is_enabled, + rtc::TaskQueue* task_queue, + Clock* clock); ~EmulatedEndpointImpl() override; uint64_t GetId() const; @@ -504,7 +533,14 @@ class EmulatedEndpointImpl : public EmulatedEndpoint { absl::optional BindReceiver( uint16_t desired_port, EmulatedNetworkReceiverInterface* receiver) override; + // Binds a receiver, and automatically removes the binding after first call to + // OnPacketReceived. + absl::optional BindOneShotReceiver( + uint16_t desired_port, + EmulatedNetworkReceiverInterface* receiver); void UnbindReceiver(uint16_t port) override; + void BindDefaultReceiver(EmulatedNetworkReceiverInterface* receiver) override; + void UnbindDefaultReceiver() override; rtc::IPAddress GetPeerLocalAddress() const override; @@ -520,25 +556,33 @@ class EmulatedEndpointImpl : public EmulatedEndpoint { std::unique_ptr stats() const; private: + struct ReceiverBinding { + EmulatedNetworkReceiverInterface* receiver; + bool is_one_shot; + }; + + absl::optional BindReceiverInternal( + uint16_t desired_port, + EmulatedNetworkReceiverInterface* receiver, + bool is_one_shot); + static constexpr uint16_t kFirstEphemeralPort = 49152; uint16_t NextPort() RTC_EXCLUSIVE_LOCKS_REQUIRED(receiver_lock_); - rtc::RecursiveCriticalSection receiver_lock_; - rtc::ThreadChecker enabled_state_checker_; + Mutex receiver_lock_; + SequenceChecker enabled_state_checker_; - const uint64_t id_; - // Peer's local IP address for this endpoint network interface. 
- const rtc::IPAddress peer_local_addr_; - const EmulatedEndpointConfig::StatsGatheringMode stats_gathering_mode_; + const Options options_; bool is_enabled_ RTC_GUARDED_BY(enabled_state_checker_); - const rtc::AdapterType type_; Clock* const clock_; rtc::TaskQueue* const task_queue_; std::unique_ptr network_; NetworkRouterNode router_; uint16_t next_port_ RTC_GUARDED_BY(receiver_lock_); - std::map port_to_receiver_ + absl::optional default_receiver_ + RTC_GUARDED_BY(receiver_lock_); + std::map port_to_receiver_ RTC_GUARDED_BY(receiver_lock_); EmulatedNetworkStatsBuilder stats_builder_ RTC_GUARDED_BY(task_queue_); @@ -548,13 +592,19 @@ class EmulatedRoute { public: EmulatedRoute(EmulatedEndpointImpl* from, std::vector via_nodes, - EmulatedEndpointImpl* to) - : from(from), via_nodes(std::move(via_nodes)), to(to), active(true) {} + EmulatedEndpointImpl* to, + bool is_default) + : from(from), + via_nodes(std::move(via_nodes)), + to(to), + active(true), + is_default(is_default) {} EmulatedEndpointImpl* from; std::vector via_nodes; EmulatedEndpointImpl* to; bool active; + bool is_default; }; // This object is immutable and so thread safe. diff --git a/test/network/network_emulation_manager.cc b/test/network/network_emulation_manager.cc index 57706fc782..2c96191200 100644 --- a/test/network/network_emulation_manager.cc +++ b/test/network/network_emulation_manager.cc @@ -18,6 +18,7 @@ #include "call/simulated_network.h" #include "rtc_base/fake_network.h" #include "test/network/emulated_turn_server.h" +#include "test/network/traffic_route.h" #include "test/time_controller/real_time_controller.h" #include "test/time_controller/simulated_time_controller.h" @@ -45,7 +46,8 @@ std::unique_ptr CreateTimeController(TimeMode mode) { } // namespace NetworkEmulationManagerImpl::NetworkEmulationManagerImpl(TimeMode mode) - : time_controller_(CreateTimeController(mode)), + : time_mode_(mode), + time_controller_(CreateTimeController(mode)), clock_(time_controller_->GetClock()), next_node_id_(1), next_ip4_address_(kMinIPv4Address), @@ -85,7 +87,7 @@ NetworkEmulationManagerImpl::NodeBuilder() { return SimulatedNetworkNode::Builder(this); } -EmulatedEndpoint* NetworkEmulationManagerImpl::CreateEndpoint( +EmulatedEndpointImpl* NetworkEmulationManagerImpl::CreateEndpoint( EmulatedEndpointConfig config) { absl::optional ip = config.ip; if (!ip) { @@ -105,9 +107,9 @@ EmulatedEndpoint* NetworkEmulationManagerImpl::CreateEndpoint( bool res = used_ip_addresses_.insert(*ip).second; RTC_CHECK(res) << "IP=" << ip->ToString() << " already in use"; auto node = std::make_unique( - next_node_id_++, *ip, config.stats_gathering_mode, - config.start_as_enabled, config.type, &task_queue_, clock_); - EmulatedEndpoint* out = node.get(); + EmulatedEndpointImpl::Options(next_node_id_++, *ip, config), + config.start_as_enabled, &task_queue_, clock_); + EmulatedEndpointImpl* out = node.get(); endpoints_.push_back(std::move(node)); return out; } @@ -146,7 +148,7 @@ EmulatedRoute* NetworkEmulationManagerImpl::CreateRoute( std::unique_ptr route = std::make_unique( static_cast(from), std::move(via_nodes), - static_cast(to)); + static_cast(to), /*is_default=*/false); EmulatedRoute* out = route.get(); routes_.push_back(std::move(route)); return out; @@ -159,26 +161,72 @@ EmulatedRoute* NetworkEmulationManagerImpl::CreateRoute( return CreateRoute(from, via_nodes, to); } +EmulatedRoute* NetworkEmulationManagerImpl::CreateDefaultRoute( + EmulatedEndpoint* from, + const std::vector& via_nodes, + EmulatedEndpoint* to) { + // Because endpoint has 
no send node by default at least one should be + // provided here. + RTC_CHECK(!via_nodes.empty()); + + static_cast(from)->router()->SetDefaultReceiver( + via_nodes[0]); + EmulatedNetworkNode* cur_node = via_nodes[0]; + for (size_t i = 1; i < via_nodes.size(); ++i) { + cur_node->router()->SetDefaultReceiver(via_nodes[i]); + cur_node = via_nodes[i]; + } + cur_node->router()->SetDefaultReceiver(to); + + std::unique_ptr route = std::make_unique( + static_cast(from), std::move(via_nodes), + static_cast(to), /*is_default=*/true); + EmulatedRoute* out = route.get(); + routes_.push_back(std::move(route)); + return out; +} + void NetworkEmulationManagerImpl::ClearRoute(EmulatedRoute* route) { RTC_CHECK(route->active) << "Route already cleared"; task_queue_.SendTask( [route]() { // Remove receiver from intermediate nodes. for (auto* node : route->via_nodes) { - node->router()->RemoveReceiver(route->to->GetPeerLocalAddress()); + if (route->is_default) { + node->router()->RemoveDefaultReceiver(); + } else { + node->router()->RemoveReceiver(route->to->GetPeerLocalAddress()); + } } // Remove destination endpoint from source endpoint's router. - route->from->router()->RemoveReceiver(route->to->GetPeerLocalAddress()); + if (route->is_default) { + route->from->router()->RemoveDefaultReceiver(); + } else { + route->from->router()->RemoveReceiver( + route->to->GetPeerLocalAddress()); + } route->active = false; }, RTC_FROM_HERE); } -TrafficRoute* NetworkEmulationManagerImpl::CreateTrafficRoute( +TcpMessageRoute* NetworkEmulationManagerImpl::CreateTcpRoute( + EmulatedRoute* send_route, + EmulatedRoute* ret_route) { + auto tcp_route = std::make_unique( + clock_, task_queue_.Get(), send_route, ret_route); + auto* route_ptr = tcp_route.get(); + task_queue_.PostTask([this, tcp_route = std::move(tcp_route)]() mutable { + tcp_message_routes_.push_back(std::move(tcp_route)); + }); + return route_ptr; +} + +CrossTrafficRoute* NetworkEmulationManagerImpl::CreateCrossTrafficRoute( const std::vector& via_nodes) { RTC_CHECK(!via_nodes.empty()); - EmulatedEndpoint* endpoint = CreateEndpoint(EmulatedEndpointConfig()); + EmulatedEndpointImpl* endpoint = CreateEndpoint(EmulatedEndpointConfig()); // Setup a route via specified nodes. 
EmulatedNetworkNode* cur_node = via_nodes[0]; @@ -189,88 +237,40 @@ TrafficRoute* NetworkEmulationManagerImpl::CreateTrafficRoute( } cur_node->router()->SetReceiver(endpoint->GetPeerLocalAddress(), endpoint); - std::unique_ptr traffic_route = - std::make_unique(clock_, via_nodes[0], endpoint); - TrafficRoute* out = traffic_route.get(); + std::unique_ptr traffic_route = + std::make_unique(clock_, via_nodes[0], endpoint); + CrossTrafficRoute* out = traffic_route.get(); traffic_routes_.push_back(std::move(traffic_route)); return out; } -RandomWalkCrossTraffic* -NetworkEmulationManagerImpl::CreateRandomWalkCrossTraffic( - TrafficRoute* traffic_route, - RandomWalkConfig config) { - auto traffic = - std::make_unique(config, traffic_route); - RandomWalkCrossTraffic* out = traffic.get(); - - task_queue_.PostTask( - [this, config, traffic = std::move(traffic)]() mutable { - auto* traffic_ptr = traffic.get(); - random_cross_traffics_.push_back(std::move(traffic)); - RepeatingTaskHandle::Start(task_queue_.Get(), - [this, config, traffic_ptr] { - traffic_ptr->Process(Now()); - return config.min_packet_interval; - }); - }); - return out; -} - -PulsedPeaksCrossTraffic* -NetworkEmulationManagerImpl::CreatePulsedPeaksCrossTraffic( - TrafficRoute* traffic_route, - PulsedPeaksConfig config) { - auto traffic = - std::make_unique(config, traffic_route); - PulsedPeaksCrossTraffic* out = traffic.get(); - task_queue_.PostTask( - [this, config, traffic = std::move(traffic)]() mutable { - auto* traffic_ptr = traffic.get(); - pulsed_cross_traffics_.push_back(std::move(traffic)); - RepeatingTaskHandle::Start(task_queue_.Get(), - [this, config, traffic_ptr] { - traffic_ptr->Process(Now()); - return config.min_packet_interval; - }); - }); - return out; -} +CrossTrafficGenerator* NetworkEmulationManagerImpl::StartCrossTraffic( + std::unique_ptr generator) { + CrossTrafficGenerator* out = generator.get(); + task_queue_.PostTask([this, generator = std::move(generator)]() mutable { + auto* generator_ptr = generator.get(); -FakeTcpCrossTraffic* NetworkEmulationManagerImpl::StartFakeTcpCrossTraffic( - std::vector send_link, - std::vector ret_link, - FakeTcpConfig config) { - auto traffic = std::make_unique( - clock_, config, CreateRoute(send_link), CreateRoute(ret_link)); - auto* traffic_ptr = traffic.get(); - task_queue_.PostTask([this, traffic = std::move(traffic)]() mutable { - traffic->Start(task_queue_.Get()); - tcp_cross_traffics_.push_back(std::move(traffic)); - }); - return traffic_ptr; -} + auto repeating_task_handle = + RepeatingTaskHandle::Start(task_queue_.Get(), [this, generator_ptr] { + generator_ptr->Process(Now()); + return generator_ptr->GetProcessInterval(); + }); -TcpMessageRoute* NetworkEmulationManagerImpl::CreateTcpRoute( - EmulatedRoute* send_route, - EmulatedRoute* ret_route) { - auto tcp_route = std::make_unique( - clock_, task_queue_.Get(), send_route, ret_route); - auto* route_ptr = tcp_route.get(); - task_queue_.PostTask([this, tcp_route = std::move(tcp_route)]() mutable { - tcp_message_routes_.push_back(std::move(tcp_route)); + cross_traffics_.push_back(CrossTrafficSource( + std::move(generator), std::move(repeating_task_handle))); }); - return route_ptr; + return out; } void NetworkEmulationManagerImpl::StopCrossTraffic( - FakeTcpCrossTraffic* traffic) { + CrossTrafficGenerator* generator) { task_queue_.PostTask([=]() { - traffic->Stop(); - tcp_cross_traffics_.remove_if( - [=](const std::unique_ptr& ptr) { - return ptr.get() == traffic; - }); + auto it = std::find_if(cross_traffics_.begin(), 
cross_traffics_.end(), + [=](const CrossTrafficSource& el) { + return el.first.get() == generator; + }); + it->second.Stop(); + cross_traffics_.erase(it); }); } @@ -278,6 +278,7 @@ EmulatedNetworkManagerInterface* NetworkEmulationManagerImpl::CreateEmulatedNetworkManagerInterface( const std::vector& endpoints) { std::vector endpoint_impls; + endpoint_impls.reserve(endpoints.size()); for (EmulatedEndpoint* endpoint : endpoints) { endpoint_impls.push_back(static_cast(endpoint)); } @@ -303,7 +304,7 @@ NetworkEmulationManagerImpl::CreateEmulatedNetworkManagerInterface( } void NetworkEmulationManagerImpl::GetStats( - rtc::ArrayView endpoints, + rtc::ArrayView endpoints, std::function)> stats_callback) { task_queue_.PostTask([endpoints, stats_callback]() { EmulatedNetworkStatsBuilder stats_builder; diff --git a/test/network/network_emulation_manager.h b/test/network/network_emulation_manager.h index b2b41b34a2..449441a3c1 100644 --- a/test/network/network_emulation_manager.h +++ b/test/network/network_emulation_manager.h @@ -34,7 +34,6 @@ #include "test/network/emulated_turn_server.h" #include "test/network/fake_network_socket_server.h" #include "test/network/network_emulation.h" -#include "test/network/traffic_route.h" namespace webrtc { namespace test { @@ -51,7 +50,7 @@ class NetworkEmulationManagerImpl : public NetworkEmulationManager { SimulatedNetworkNode::Builder NodeBuilder() override; - EmulatedEndpoint* CreateEndpoint(EmulatedEndpointConfig config) override; + EmulatedEndpointImpl* CreateEndpoint(EmulatedEndpointConfig config) override; void EnableEndpoint(EmulatedEndpoint* endpoint) override; void DisableEndpoint(EmulatedEndpoint* endpoint) override; @@ -62,42 +61,46 @@ class NetworkEmulationManagerImpl : public NetworkEmulationManager { EmulatedRoute* CreateRoute( const std::vector& via_nodes) override; - void ClearRoute(EmulatedRoute* route) override; + EmulatedRoute* CreateDefaultRoute( + EmulatedEndpoint* from, + const std::vector& via_nodes, + EmulatedEndpoint* to) override; - TrafficRoute* CreateTrafficRoute( - const std::vector& via_nodes); - RandomWalkCrossTraffic* CreateRandomWalkCrossTraffic( - TrafficRoute* traffic_route, - RandomWalkConfig config); - PulsedPeaksCrossTraffic* CreatePulsedPeaksCrossTraffic( - TrafficRoute* traffic_route, - PulsedPeaksConfig config); - FakeTcpCrossTraffic* StartFakeTcpCrossTraffic( - std::vector send_link, - std::vector ret_link, - FakeTcpConfig config); + void ClearRoute(EmulatedRoute* route) override; TcpMessageRoute* CreateTcpRoute(EmulatedRoute* send_route, EmulatedRoute* ret_route) override; - void StopCrossTraffic(FakeTcpCrossTraffic* traffic); + CrossTrafficRoute* CreateCrossTrafficRoute( + const std::vector& via_nodes) override; + + CrossTrafficGenerator* StartCrossTraffic( + std::unique_ptr generator) override; + void StopCrossTraffic(CrossTrafficGenerator* generator) override; EmulatedNetworkManagerInterface* CreateEmulatedNetworkManagerInterface( const std::vector& endpoints) override; - void GetStats(rtc::ArrayView endpoints, + void GetStats(rtc::ArrayView endpoints, std::function)> stats_callback) override; TimeController* time_controller() override { return time_controller_.get(); } + TimeMode time_mode() const override { return time_mode_; } + Timestamp Now() const; EmulatedTURNServerInterface* CreateTURNServer( EmulatedTURNServerConfig config) override; private: + using CrossTrafficSource = + std::pair, RepeatingTaskHandle>; + absl::optional GetNextIPv4Address(); + + const TimeMode time_mode_; const std::unique_ptr 
time_controller_; Clock* const clock_; int next_node_id_; @@ -111,10 +114,8 @@ class NetworkEmulationManagerImpl : public NetworkEmulationManager { std::vector> endpoints_; std::vector> network_nodes_; std::vector> routes_; - std::vector> traffic_routes_; - std::vector> random_cross_traffics_; - std::vector> pulsed_cross_traffics_; - std::list> tcp_cross_traffics_; + std::vector> traffic_routes_; + std::vector cross_traffics_; std::list> tcp_message_routes_; std::vector> endpoints_containers_; std::vector> network_managers_; diff --git a/test/network/network_emulation_pc_unittest.cc b/test/network/network_emulation_pc_unittest.cc index 6420e36275..bd15b5ad38 100644 --- a/test/network/network_emulation_pc_unittest.cc +++ b/test/network/network_emulation_pc_unittest.cc @@ -99,7 +99,12 @@ rtc::scoped_refptr CreatePeerConnection( rtc_configuration.servers.push_back(server); } - return pcf->CreatePeerConnection(rtc_configuration, std::move(pc_deps)); + auto result = + pcf->CreatePeerConnectionOrError(rtc_configuration, std::move(pc_deps)); + if (!result.ok()) { + return nullptr; + } + return result.MoveValue(); } } // namespace diff --git a/test/network/network_emulation_unittest.cc b/test/network/network_emulation_unittest.cc index c92b344872..fca10c40b7 100644 --- a/test/network/network_emulation_unittest.cc +++ b/test/network/network_emulation_unittest.cc @@ -207,8 +207,14 @@ TEST(NetworkEmulationManagerTest, Run) { rtc::CopyOnWriteBuffer data("Hello"); for (uint64_t j = 0; j < 2; j++) { - auto* s1 = t1->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); - auto* s2 = t2->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + rtc::AsyncSocket* s1 = nullptr; + rtc::AsyncSocket* s2 = nullptr; + t1->Invoke(RTC_FROM_HERE, [&] { + s1 = t1->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + }); + t2->Invoke(RTC_FROM_HERE, [&] { + s2 = t2->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + }); SocketReader r1(s1, t1); SocketReader r2(s2, t2); @@ -357,8 +363,14 @@ TEST(NetworkEmulationManagerTest, DebugStatsCollectedInDebugMode) { rtc::CopyOnWriteBuffer data("Hello"); for (uint64_t j = 0; j < 2; j++) { - auto* s1 = t1->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); - auto* s2 = t2->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + rtc::AsyncSocket* s1 = nullptr; + rtc::AsyncSocket* s2 = nullptr; + t1->Invoke(RTC_FROM_HERE, [&] { + s1 = t1->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + }); + t2->Invoke(RTC_FROM_HERE, [&] { + s2 = t2->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + }); SocketReader r1(s1, t1); SocketReader r2(s2, t2); @@ -454,8 +466,15 @@ TEST(NetworkEmulationManagerTest, ThroughputStats) { constexpr int64_t kUdpPayloadSize = 100; constexpr int64_t kSinglePacketSize = kUdpPayloadSize + kOverheadIpv4Udp; rtc::CopyOnWriteBuffer data(kUdpPayloadSize); - auto* s1 = t1->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); - auto* s2 = t2->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + + rtc::AsyncSocket* s1 = nullptr; + rtc::AsyncSocket* s2 = nullptr; + t1->Invoke(RTC_FROM_HERE, [&] { + s1 = t1->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + }); + t2->Invoke(RTC_FROM_HERE, [&] { + s2 = t2->socketserver()->CreateAsyncSocket(AF_INET, SOCK_DGRAM); + }); SocketReader r1(s1, t1); SocketReader r2(s2, t2); @@ -568,6 +587,51 @@ TEST(NetworkEmulationManagerTest, EndpointLoopback) { network_manager.time_controller()->AdvanceTime(TimeDelta::Seconds(1)); } +TEST(NetworkEmulationManagerTest, 
EndpointCanSendWithDifferentSourceIp) { + constexpr uint32_t kEndpointIp = 0xC0A80011; // 192.168.0.17 + constexpr uint32_t kSourceIp = 0xC0A80012; // 192.168.0.18 + NetworkEmulationManagerImpl network_manager(TimeMode::kSimulated); + EmulatedEndpointConfig endpoint_config; + endpoint_config.ip = rtc::IPAddress(kEndpointIp); + endpoint_config.allow_send_packet_with_different_source_ip = true; + auto endpoint = network_manager.CreateEndpoint(endpoint_config); + + MockReceiver receiver; + EXPECT_CALL(receiver, OnPacketReceived(::testing::_)).Times(1); + ASSERT_EQ(endpoint->BindReceiver(80, &receiver), 80); + + endpoint->SendPacket(rtc::SocketAddress(kSourceIp, 80), + rtc::SocketAddress(endpoint->GetPeerLocalAddress(), 80), + "Hello"); + network_manager.time_controller()->AdvanceTime(TimeDelta::Seconds(1)); +} + +TEST(NetworkEmulationManagerTest, + EndpointCanReceiveWithDifferentDestIpThroughDefaultRoute) { + constexpr uint32_t kDestEndpointIp = 0xC0A80011; // 192.168.0.17 + constexpr uint32_t kDestIp = 0xC0A80012; // 192.168.0.18 + NetworkEmulationManagerImpl network_manager(TimeMode::kSimulated); + auto sender_endpoint = + network_manager.CreateEndpoint(EmulatedEndpointConfig()); + EmulatedEndpointConfig endpoint_config; + endpoint_config.ip = rtc::IPAddress(kDestEndpointIp); + endpoint_config.allow_receive_packets_with_different_dest_ip = true; + auto receiver_endpoint = network_manager.CreateEndpoint(endpoint_config); + + MockReceiver receiver; + EXPECT_CALL(receiver, OnPacketReceived(::testing::_)).Times(1); + ASSERT_EQ(receiver_endpoint->BindReceiver(80, &receiver), 80); + + network_manager.CreateDefaultRoute( + sender_endpoint, {network_manager.NodeBuilder().Build().node}, + receiver_endpoint); + + sender_endpoint->SendPacket( + rtc::SocketAddress(sender_endpoint->GetPeerLocalAddress(), 80), + rtc::SocketAddress(kDestIp, 80), "Hello"); + network_manager.time_controller()->AdvanceTime(TimeDelta::Seconds(1)); +} + TEST(NetworkEmulationManagerTURNTest, GetIceServerConfig) { NetworkEmulationManagerImpl network_manager(TimeMode::kRealTime); auto turn = network_manager.CreateTURNServer(EmulatedTURNServerConfig()); diff --git a/test/network/traffic_route.cc b/test/network/traffic_route.cc index 98586337b9..81bb8ca514 100644 --- a/test/network/traffic_route.cc +++ b/test/network/traffic_route.cc @@ -29,33 +29,23 @@ class NullReceiver : public EmulatedNetworkReceiverInterface { class ActionReceiver : public EmulatedNetworkReceiverInterface { public: - ActionReceiver(std::function action, EmulatedEndpoint* endpoint) - : action_(action), endpoint_(endpoint) {} + explicit ActionReceiver(std::function action) : action_(action) {} ~ActionReceiver() override = default; void OnPacketReceived(EmulatedIpPacket packet) override { - RTC_DCHECK(port_); action_(); - endpoint_->UnbindReceiver(port_.value()); } - // We can't set port in constructor, because port will be provided by - // endpoint, when this receiver will be binded to that endpoint. - void SetPort(uint16_t port) { port_ = port; } - private: std::function action_; - // Endpoint and port will be used to free port in the endpoint after action - // will be done. 
- EmulatedEndpoint* endpoint_; - absl::optional port_ = absl::nullopt; }; } // namespace -TrafficRoute::TrafficRoute(Clock* clock, - EmulatedNetworkReceiverInterface* receiver, - EmulatedEndpoint* endpoint) +CrossTrafficRouteImpl::CrossTrafficRouteImpl( + Clock* clock, + EmulatedNetworkReceiverInterface* receiver, + EmulatedEndpointImpl* endpoint) : clock_(clock), receiver_(receiver), endpoint_(endpoint) { null_receiver_ = std::make_unique(); absl::optional port = @@ -63,30 +53,32 @@ TrafficRoute::TrafficRoute(Clock* clock, RTC_DCHECK(port); null_receiver_port_ = port.value(); } -TrafficRoute::~TrafficRoute() = default; +CrossTrafficRouteImpl::~CrossTrafficRouteImpl() = default; -void TrafficRoute::TriggerPacketBurst(size_t num_packets, size_t packet_size) { +void CrossTrafficRouteImpl::TriggerPacketBurst(size_t num_packets, + size_t packet_size) { for (size_t i = 0; i < num_packets; ++i) { SendPacket(packet_size); } } -void TrafficRoute::NetworkDelayedAction(size_t packet_size, - std::function action) { - auto action_receiver = std::make_unique(action, endpoint_); +void CrossTrafficRouteImpl::NetworkDelayedAction(size_t packet_size, + std::function action) { + auto action_receiver = std::make_unique(action); + // BindOneShotReceiver arranges to free the port in the endpoint after the + // action is done. absl::optional port = - endpoint_->BindReceiver(0, action_receiver.get()); + endpoint_->BindOneShotReceiver(0, action_receiver.get()); RTC_DCHECK(port); - action_receiver->SetPort(port.value()); actions_.push_back(std::move(action_receiver)); SendPacket(packet_size, port.value()); } -void TrafficRoute::SendPacket(size_t packet_size) { +void CrossTrafficRouteImpl::SendPacket(size_t packet_size) { SendPacket(packet_size, null_receiver_port_); } -void TrafficRoute::SendPacket(size_t packet_size, uint16_t dest_port) { +void CrossTrafficRouteImpl::SendPacket(size_t packet_size, uint16_t dest_port) { rtc::CopyOnWriteBuffer data(packet_size); std::fill_n(data.MutableData(), data.size(), 0); receiver_->OnPacketReceived(EmulatedIpPacket( diff --git a/test/network/traffic_route.h b/test/network/traffic_route.h index 1bb34c6b6c..2c2fadc427 100644 --- a/test/network/traffic_route.h +++ b/test/network/traffic_route.h @@ -14,6 +14,7 @@ #include #include +#include "api/test/network_emulation_manager.h" #include "rtc_base/copy_on_write_buffer.h" #include "system_wrappers/include/clock.h" #include "test/network/network_emulation.h" @@ -23,26 +24,27 @@ namespace test { // Represents the endpoint for cross traffic that is going through the network. // It can be used to emulate unexpected network load. -class TrafficRoute { +class CrossTrafficRouteImpl final : public CrossTrafficRoute { public: - TrafficRoute(Clock* clock, - EmulatedNetworkReceiverInterface* receiver, - EmulatedEndpoint* endpoint); - ~TrafficRoute(); + CrossTrafficRouteImpl(Clock* clock, + EmulatedNetworkReceiverInterface* receiver, + EmulatedEndpointImpl* endpoint); + ~CrossTrafficRouteImpl(); // Triggers sending of dummy packets with size |packet_size| bytes. - void TriggerPacketBurst(size_t num_packets, size_t packet_size); + void TriggerPacketBurst(size_t num_packets, size_t packet_size) override; // Sends a packet over the nodes and runs |action| when it has been delivered. 
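A short usage sketch of the renamed CrossTrafficRoute API, mirroring how the tests elsewhere in this patch drive the manager. The function name is hypothetical and the include paths and namespace qualifications are assumptions based on the files touched here, not verified against the tree.

#include "api/units/time_delta.h"
#include "rtc_base/logging.h"
#include "test/network/network_emulation_manager.h"

void CrossTrafficSketch() {
  webrtc::test::NetworkEmulationManagerImpl network_manager(
      webrtc::TimeMode::kSimulated);
  webrtc::EmulatedNetworkNode* node =
      network_manager.NodeBuilder().Build().node;
  webrtc::CrossTrafficRoute* route =
      network_manager.CreateCrossTrafficRoute({node});

  // Push ten 1200-byte dummy packets through the emulated node.
  route->TriggerPacketBurst(/*num_packets=*/10, /*packet_size=*/1200);

  // Run a callback once one more packet has been delivered over the route.
  // Internally this uses the new BindOneShotReceiver, so the receiving port is
  // freed automatically after the action fires.
  route->NetworkDelayedAction(/*packet_size=*/1200, [] {
    RTC_LOG(INFO) << "Delayed action delivered through the emulated route";
  });

  network_manager.time_controller()->AdvanceTime(
      webrtc::TimeDelta::Seconds(1));
}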
- void NetworkDelayedAction(size_t packet_size, std::function action); + void NetworkDelayedAction(size_t packet_size, + std::function action) override; - void SendPacket(size_t packet_size); + void SendPacket(size_t packet_size) override; private: void SendPacket(size_t packet_size, uint16_t dest_port); Clock* const clock_; EmulatedNetworkReceiverInterface* const receiver_; - EmulatedEndpoint* const endpoint_; + EmulatedEndpointImpl* const endpoint_; uint16_t null_receiver_port_; std::unique_ptr null_receiver_; diff --git a/test/pc/e2e/BUILD.gn b/test/pc/e2e/BUILD.gn index 3901297063..9e9d5c2db5 100644 --- a/test/pc/e2e/BUILD.gn +++ b/test/pc/e2e/BUILD.gn @@ -13,13 +13,12 @@ if (!build_with_chromium) { testonly = true deps = [ - ":default_encoded_image_data_injector", ":encoded_image_data_injector_api", ":example_video_quality_analyzer", - ":id_generator", ":quality_analyzing_video_decoder", ":quality_analyzing_video_encoder", ":single_process_encoded_image_data_injector", + ":video_frame_tracking_id_injector", ] if (rtc_include_tests) { deps += [ @@ -35,11 +34,11 @@ if (!build_with_chromium) { testonly = true deps = [ - ":default_encoded_image_data_injector_unittest", ":default_video_quality_analyzer_test", ":multi_head_queue_test", ":peer_connection_e2e_smoke_test", ":single_process_encoded_image_data_injector_unittest", + ":video_frame_tracking_id_injector_unittest", ] } } @@ -61,6 +60,7 @@ if (!build_with_chromium) { "../../../api/transport:webrtc_key_value_config", "../../../api/video_codecs:video_codecs_api", "../../../rtc_base", + "../../../rtc_base:threading", ] } @@ -72,12 +72,12 @@ if (!build_with_chromium) { deps = [ "../../../api/video:encoded_image" ] } - rtc_library("default_encoded_image_data_injector") { + rtc_library("single_process_encoded_image_data_injector") { visibility = [ "*" ] testonly = true sources = [ - "analyzer/video/default_encoded_image_data_injector.cc", - "analyzer/video/default_encoded_image_data_injector.h", + "analyzer/video/single_process_encoded_image_data_injector.cc", + "analyzer/video/single_process_encoded_image_data_injector.h", ] deps = [ @@ -85,38 +85,27 @@ if (!build_with_chromium) { "../../../api/video:encoded_image", "../../../rtc_base:checks", "../../../rtc_base:criticalsection", + "../../../rtc_base/synchronization:mutex", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] } - rtc_library("single_process_encoded_image_data_injector") { + rtc_library("video_frame_tracking_id_injector") { visibility = [ "*" ] testonly = true sources = [ - "analyzer/video/single_process_encoded_image_data_injector.cc", - "analyzer/video/single_process_encoded_image_data_injector.h", + "analyzer/video/video_frame_tracking_id_injector.cc", + "analyzer/video/video_frame_tracking_id_injector.h", ] deps = [ ":encoded_image_data_injector_api", "../../../api/video:encoded_image", "../../../rtc_base:checks", - "../../../rtc_base:criticalsection", - "../../../rtc_base/synchronization:mutex", ] absl_deps = [ "//third_party/abseil-cpp/absl/memory" ] } - rtc_library("id_generator") { - visibility = [ "*" ] - testonly = true - sources = [ - "analyzer/video/id_generator.cc", - "analyzer/video/id_generator.h", - ] - deps = [] - } - rtc_library("simulcast_dummy_buffer_helper") { visibility = [ "*" ] testonly = true @@ -136,7 +125,6 @@ if (!build_with_chromium) { ] deps = [ ":encoded_image_data_injector_api", - ":id_generator", ":simulcast_dummy_buffer_helper", "../../../api:video_quality_analyzer_api", "../../../api/video:encoded_image", @@ -163,7 +151,6 @@ if 
(!build_with_chromium) { ] deps = [ ":encoded_image_data_injector_api", - ":id_generator", "../../../api:video_quality_analyzer_api", "../../../api/video:encoded_image", "../../../api/video:video_frame", @@ -187,7 +174,6 @@ if (!build_with_chromium) { ] deps = [ ":encoded_image_data_injector_api", - ":id_generator", ":quality_analyzing_video_decoder", ":quality_analyzing_video_encoder", ":simulcast_dummy_buffer_helper", @@ -303,6 +289,7 @@ if (!build_with_chromium) { "../../../api:peer_connection_quality_test_fixture_api", "../../../api/video:video_frame", "../../../pc:peerconnection", + "../../../pc:session_description", "../../../pc:video_track_source", ] absl_deps = [ "//third_party/abseil-cpp/absl/types:variant" ] @@ -328,6 +315,7 @@ if (!build_with_chromium) { "../../../api/transport:network_control", "../../../api/video_codecs:video_codecs_api", "../../../rtc_base", + "../../../rtc_base:threading", ] absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] } @@ -404,6 +392,7 @@ if (!build_with_chromium) { "../../../rtc_base:rtc_base_approved", "../../../rtc_base:safe_conversions", "../../../rtc_base:task_queue_for_test", + "../../../rtc_base:threading", "../../../rtc_base/synchronization:mutex", "../../../system_wrappers", "../../../system_wrappers:field_trial", @@ -424,12 +413,12 @@ if (!build_with_chromium) { ] } - rtc_library("default_encoded_image_data_injector_unittest") { + rtc_library("video_frame_tracking_id_injector_unittest") { testonly = true sources = - [ "analyzer/video/default_encoded_image_data_injector_unittest.cc" ] + [ "analyzer/video/video_frame_tracking_id_injector_unittest.cc" ] deps = [ - ":default_encoded_image_data_injector", + ":video_frame_tracking_id_injector", "../../../api/video:encoded_image", "../../../rtc_base:rtc_base_approved", "../../../test:test_support", @@ -544,9 +533,9 @@ if (!build_with_chromium) { "analyzer_helper.h", ] deps = [ + "../../../api:sequence_checker", "../../../api:track_id_stream_info_map", "../../../rtc_base:macromagic", - "../../../rtc_base/synchronization:sequence_checker", ] absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] } @@ -697,6 +686,7 @@ if (!build_with_chromium) { "../../../api/units:data_size", "../../../api/units:timestamp", "../../../rtc_base", + "../../../rtc_base:ip_address", "../../../rtc_base:rtc_event", "../../../rtc_base:stringutils", "../../../rtc_base/synchronization:mutex", @@ -748,6 +738,8 @@ if (!build_with_chromium) { "../../../p2p:rtc_p2p", "../../../pc:peerconnection", "../../../pc:rtc_pc_base", + "../../../pc:session_description", + "../../../pc:simulcast_description", "../../../rtc_base:stringutils", ] absl_deps = [ diff --git a/test/pc/e2e/analyzer/audio/default_audio_quality_analyzer.cc b/test/pc/e2e/analyzer/audio/default_audio_quality_analyzer.cc index 8830436b09..30c17c1ca9 100644 --- a/test/pc/e2e/analyzer/audio/default_audio_quality_analyzer.cc +++ b/test/pc/e2e/analyzer/audio/default_audio_quality_analyzer.cc @@ -26,7 +26,7 @@ void DefaultAudioQualityAnalyzer::Start(std::string test_case_name, void DefaultAudioQualityAnalyzer::OnStatsReports( absl::string_view pc_label, const rtc::scoped_refptr& report) { - // TODO(https://crbug.com/webrtc/11683): use "inbound-rtp" instead of "track" + // TODO(https://crbug.com/webrtc/11789): use "inbound-rtp" instead of "track" // stats when required audio metrics moved there auto stats = report->GetStatsOfType(); diff --git a/test/pc/e2e/analyzer/video/default_encoded_image_data_injector.cc 
b/test/pc/e2e/analyzer/video/default_encoded_image_data_injector.cc deleted file mode 100644 index c5eab0a1b0..0000000000 --- a/test/pc/e2e/analyzer/video/default_encoded_image_data_injector.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "test/pc/e2e/analyzer/video/default_encoded_image_data_injector.h" - -#include -#include - -#include "absl/memory/memory.h" -#include "api/video/encoded_image.h" -#include "rtc_base/checks.h" - -namespace webrtc { -namespace webrtc_pc_e2e { -namespace { - -// The amount on which encoded image buffer will be expanded to inject frame id. -// This is 2 bytes for uint16_t frame id itself and 4 bytes for original length -// of the buffer. -constexpr int kEncodedImageBufferExpansion = 6; - -struct ExtractionInfo { - size_t length; - bool discard; -}; - -} // namespace - -DefaultEncodedImageDataInjector::DefaultEncodedImageDataInjector() = default; -DefaultEncodedImageDataInjector::~DefaultEncodedImageDataInjector() = default; - -EncodedImage DefaultEncodedImageDataInjector::InjectData( - uint16_t id, - bool discard, - const EncodedImage& source, - int /*coding_entity_id*/) { - auto buffer = - EncodedImageBuffer::Create(source.size() + kEncodedImageBufferExpansion); - memcpy(buffer->data(), source.data(), source.size()); - - size_t insertion_pos = source.size(); - buffer->data()[insertion_pos] = id & 0x00ff; - buffer->data()[insertion_pos + 1] = (id & 0xff00) >> 8; - buffer->data()[insertion_pos + 2] = source.size() & 0x000000ff; - buffer->data()[insertion_pos + 3] = (source.size() & 0x0000ff00) >> 8; - buffer->data()[insertion_pos + 4] = (source.size() & 0x00ff0000) >> 16; - buffer->data()[insertion_pos + 5] = (source.size() & 0xff000000) >> 24; - - // We will store discard flag in the high bit of high byte of the size. - RTC_CHECK_LT(source.size(), 1U << 31) << "High bit is already in use"; - buffer->data()[insertion_pos + 5] = - buffer->data()[insertion_pos + 5] | ((discard ? 1 : 0) << 7); - - EncodedImage out = source; - out.SetEncodedData(buffer); - return out; -} - -EncodedImageExtractionResult DefaultEncodedImageDataInjector::ExtractData( - const EncodedImage& source, - int /*coding_entity_id*/) { - auto buffer = EncodedImageBuffer::Create(source.size()); - EncodedImage out = source; - out.SetEncodedData(buffer); - - size_t source_pos = source.size() - 1; - absl::optional id = absl::nullopt; - bool discard = true; - std::vector extraction_infos; - // First make a reverse pass through whole buffer to populate frame id, - // discard flags and concatenated encoded images length. 
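The injector being deleted above appends a 6-byte trailer to each encoded image: 2 bytes of frame id, then 4 bytes of original payload length with the discard flag stored in the top bit, which is why payloads must stay below 2^31 bytes. A compact, self-contained sketch of that trailer layout, using my own helpers rather than the removed class:

#include <cstddef>
#include <cstdint>
#include <vector>

// Trailer layout: [id lo][id hi][len b0][len b1][len b2][len b3 | discard<<7].
void AppendTrailer(std::vector<uint8_t>& buffer, uint16_t id, bool discard) {
  const uint32_t original_size = static_cast<uint32_t>(buffer.size());
  buffer.push_back(id & 0xff);
  buffer.push_back((id >> 8) & 0xff);
  buffer.push_back(original_size & 0xff);
  buffer.push_back((original_size >> 8) & 0xff);
  buffer.push_back((original_size >> 16) & 0xff);
  // The top bit of the last length byte carries the discard flag.
  buffer.push_back(((original_size >> 24) & 0x7f) | (discard ? 0x80 : 0x00));
}

struct Trailer {
  uint16_t id;
  uint32_t original_size;
  bool discard;
};

Trailer ReadTrailer(const std::vector<uint8_t>& buffer) {
  const size_t pos = buffer.size() - 6;
  Trailer t;
  t.id = buffer[pos] | (buffer[pos + 1] << 8);
  t.original_size = buffer[pos + 2] | (buffer[pos + 3] << 8) |
                    (buffer[pos + 4] << 16) |
                    (static_cast<uint32_t>(buffer[pos + 5] & 0x7f) << 24);
  t.discard = (buffer[pos + 5] & 0x80) != 0;
  return t;
}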
- while (true) { - size_t insertion_pos = source_pos - kEncodedImageBufferExpansion + 1; - RTC_CHECK_GE(insertion_pos, 0); - RTC_CHECK_LE(insertion_pos + kEncodedImageBufferExpansion, source.size()); - uint16_t next_id = - source.data()[insertion_pos] + (source.data()[insertion_pos + 1] << 8); - RTC_CHECK(!id || id.value() == next_id) - << "Different frames encoded into single encoded image: " << id.value() - << " vs " << next_id; - id = next_id; - uint32_t length = source.data()[insertion_pos + 2] + - (source.data()[insertion_pos + 3] << 8) + - (source.data()[insertion_pos + 4] << 16) + - ((source.data()[insertion_pos + 5] << 24) & 0b01111111); - bool current_discard = (source.data()[insertion_pos + 5] & 0b10000000) != 0; - extraction_infos.push_back({length, current_discard}); - // Extraction result is discarded only if all encoded partitions are - // discarded. - discard = discard && current_discard; - if (source_pos < length + kEncodedImageBufferExpansion) { - break; - } - source_pos -= length + kEncodedImageBufferExpansion; - } - RTC_CHECK(id); - std::reverse(extraction_infos.begin(), extraction_infos.end()); - if (discard) { - out.set_size(0); - return EncodedImageExtractionResult{*id, out, true}; - } - - // Now basing on populated data make a forward pass to copy required pieces - // of data to the output buffer. - source_pos = 0; - size_t out_pos = 0; - auto extraction_infos_it = extraction_infos.begin(); - while (source_pos < source.size()) { - const ExtractionInfo& info = *extraction_infos_it; - RTC_CHECK_LE(source_pos + kEncodedImageBufferExpansion + info.length, - source.size()); - if (!info.discard) { - // Copy next encoded image payload from concatenated buffer only if it is - // not discarded. - memcpy(&buffer->data()[out_pos], &source.data()[source_pos], info.length); - out_pos += info.length; - } - source_pos += info.length + kEncodedImageBufferExpansion; - ++extraction_infos_it; - } - out.set_size(out_pos); - - return EncodedImageExtractionResult{id.value(), out, discard}; -} - -} // namespace webrtc_pc_e2e -} // namespace webrtc diff --git a/test/pc/e2e/analyzer/video/default_encoded_image_data_injector.h b/test/pc/e2e/analyzer/video/default_encoded_image_data_injector.h deleted file mode 100644 index b60c214703..0000000000 --- a/test/pc/e2e/analyzer/video/default_encoded_image_data_injector.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef TEST_PC_E2E_ANALYZER_VIDEO_DEFAULT_ENCODED_IMAGE_DATA_INJECTOR_H_ -#define TEST_PC_E2E_ANALYZER_VIDEO_DEFAULT_ENCODED_IMAGE_DATA_INJECTOR_H_ - -#include -#include -#include -#include -#include -#include - -#include "api/video/encoded_image.h" -#include "test/pc/e2e/analyzer/video/encoded_image_data_injector.h" - -namespace webrtc { -namespace webrtc_pc_e2e { - -// Injects frame id and discard flag into EncodedImage payload buffer. The -// payload buffer will be appended in the injector with 2 bytes frame id and 4 -// bytes original buffer length. Discarded flag will be put into the highest bit -// of the length. It is assumed, that frame's data can't be more then 2^31 -// bytes. 
In the decoder, frame id and discard flag will be extracted and the -// length will be used to restore original buffer. We can't put this data in the -// beginning of the payload, because first bytes are used in different parts of -// WebRTC pipeline. -// -// The data in the EncodedImage on encoder side after injection will look like -// this: -// 4 bytes frame length + discard flag -// _________________ _ _ _↓_ _ _ -// | original buffer | | | -// ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯ ¯↑¯ ¯ ¯ ¯ ¯ -// 2 bytes frame id -// -// But on decoder side multiple payloads can be concatenated into single -// EncodedImage in jitter buffer and its payload will look like this: -// _________ _ _ _ _ _ _ _________ _ _ _ _ _ _ _________ _ _ _ _ _ _ -// buf: | payload | | | payload | | | payload | | | -// ¯¯¯¯¯¯¯¯¯ ¯ ¯ ¯ ¯ ¯ ¯ ¯¯¯¯¯¯¯¯¯ ¯ ¯ ¯ ¯ ¯ ¯ ¯¯¯¯¯¯¯¯¯ ¯ ¯ ¯ ¯ ¯ ¯ -// To correctly restore such images we will extract id by this algorithm: -// 1. Make a pass from end to begin of the buffer to restore origin lengths, -// frame ids and discard flags from length high bit. -// 2. If all discard flags are true - discard this encoded image -// 3. Make a pass from begin to end copying data to the output basing on -// previously extracted length -// Also it will check, that all extracted ids are equals. -class DefaultEncodedImageDataInjector : public EncodedImageDataInjector, - public EncodedImageDataExtractor { - public: - DefaultEncodedImageDataInjector(); - ~DefaultEncodedImageDataInjector() override; - - EncodedImage InjectData(uint16_t id, - bool discard, - const EncodedImage& source, - int /*coding_entity_id*/) override; - - void Start(int expected_receivers_count) override {} - EncodedImageExtractionResult ExtractData(const EncodedImage& source, - int coding_entity_id) override; -}; - -} // namespace webrtc_pc_e2e -} // namespace webrtc - -#endif // TEST_PC_E2E_ANALYZER_VIDEO_DEFAULT_ENCODED_IMAGE_DATA_INJECTOR_H_ diff --git a/test/pc/e2e/analyzer/video/default_encoded_image_data_injector_unittest.cc b/test/pc/e2e/analyzer/video/default_encoded_image_data_injector_unittest.cc deleted file mode 100644 index 2ba2298fb5..0000000000 --- a/test/pc/e2e/analyzer/video/default_encoded_image_data_injector_unittest.cc +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. 
- */ - -#include "test/pc/e2e/analyzer/video/default_encoded_image_data_injector.h" - -#include - -#include "api/video/encoded_image.h" -#include "rtc_base/buffer.h" -#include "test/gtest.h" - -namespace webrtc { -namespace webrtc_pc_e2e { -namespace { - -rtc::scoped_refptr -CreateEncodedImageBufferOfSizeNFilledWithValuesFromX(size_t n, uint8_t x) { - auto buffer = EncodedImageBuffer::Create(n); - for (size_t i = 0; i < n; ++i) { - buffer->data()[i] = static_cast(x + i); - } - return buffer; -} - -EncodedImage CreateEncodedImageOfSizeNFilledWithValuesFromX(size_t n, - uint8_t x) { - EncodedImage image; - image.SetEncodedData( - CreateEncodedImageBufferOfSizeNFilledWithValuesFromX(n, x)); - return image; -} - -TEST(DefaultEncodedImageDataInjector, InjectExtractDiscardFalse) { - DefaultEncodedImageDataInjector injector; - injector.Start(1); - - EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); - source.SetTimestamp(123456789); - - EncodedImageExtractionResult out = - injector.ExtractData(injector.InjectData(512, false, source, 1), 2); - EXPECT_EQ(out.id, 512); - EXPECT_FALSE(out.discard); - EXPECT_EQ(out.image.size(), 10ul); - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(out.image.data()[i], i + 1); - } -} - -TEST(DefaultEncodedImageDataInjector, InjectExtractDiscardTrue) { - DefaultEncodedImageDataInjector injector; - injector.Start(1); - - EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); - source.SetTimestamp(123456789); - - EncodedImageExtractionResult out = - injector.ExtractData(injector.InjectData(512, true, source, 1), 2); - EXPECT_EQ(out.id, 512); - EXPECT_TRUE(out.discard); - EXPECT_EQ(out.image.size(), 0ul); -} - -TEST(DefaultEncodedImageDataInjector, Inject3Extract3) { - DefaultEncodedImageDataInjector injector; - injector.Start(1); - - // 1st frame - EncodedImage source1 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); - source1.SetTimestamp(123456710); - // 2nd frame 1st spatial layer - EncodedImage source2 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 11); - source2.SetTimestamp(123456720); - // 2nd frame 2nd spatial layer - EncodedImage source3 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 21); - source3.SetTimestamp(123456720); - - EncodedImage intermediate1 = injector.InjectData(510, false, source1, 1); - EncodedImage intermediate2 = injector.InjectData(520, true, source2, 1); - EncodedImage intermediate3 = injector.InjectData(520, false, source3, 1); - - // Extract ids in different order. 
- EncodedImageExtractionResult out3 = injector.ExtractData(intermediate3, 2); - EncodedImageExtractionResult out1 = injector.ExtractData(intermediate1, 2); - EncodedImageExtractionResult out2 = injector.ExtractData(intermediate2, 2); - - EXPECT_EQ(out1.id, 510); - EXPECT_FALSE(out1.discard); - EXPECT_EQ(out1.image.size(), 10ul); - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(out1.image.data()[i], i + 1); - } - EXPECT_EQ(out2.id, 520); - EXPECT_TRUE(out2.discard); - EXPECT_EQ(out2.image.size(), 0ul); - EXPECT_EQ(out3.id, 520); - EXPECT_FALSE(out3.discard); - EXPECT_EQ(out3.image.size(), 10ul); - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(out3.image.data()[i], i + 21); - } -} - -TEST(DefaultEncodedImageDataInjector, InjectExtractFromConcatenated) { - DefaultEncodedImageDataInjector injector; - injector.Start(1); - - EncodedImage source1 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); - source1.SetTimestamp(123456710); - EncodedImage source2 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 11); - source2.SetTimestamp(123456710); - EncodedImage source3 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 21); - source3.SetTimestamp(123456710); - - // Inject id into 3 images with same frame id. - EncodedImage intermediate1 = injector.InjectData(512, false, source1, 1); - EncodedImage intermediate2 = injector.InjectData(512, true, source2, 1); - EncodedImage intermediate3 = injector.InjectData(512, false, source3, 1); - - // Concatenate them into single encoded image, like it can be done in jitter - // buffer. - size_t concatenated_length = - intermediate1.size() + intermediate2.size() + intermediate3.size(); - rtc::Buffer concatenated_buffer; - concatenated_buffer.AppendData(intermediate1.data(), intermediate1.size()); - concatenated_buffer.AppendData(intermediate2.data(), intermediate2.size()); - concatenated_buffer.AppendData(intermediate3.data(), intermediate3.size()); - EncodedImage concatenated; - concatenated.SetEncodedData(EncodedImageBuffer::Create( - concatenated_buffer.data(), concatenated_length)); - - // Extract frame id from concatenated image - EncodedImageExtractionResult out = injector.ExtractData(concatenated, 2); - - EXPECT_EQ(out.id, 512); - EXPECT_FALSE(out.discard); - EXPECT_EQ(out.image.size(), 2 * 10ul); - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(out.image.data()[i], i + 1); - EXPECT_EQ(out.image.data()[i + 10], i + 21); - } -} - -TEST(DefaultEncodedImageDataInjector, - InjectExtractFromConcatenatedAllDiscarded) { - DefaultEncodedImageDataInjector injector; - injector.Start(1); - - EncodedImage source1 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); - source1.SetTimestamp(123456710); - EncodedImage source2 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 11); - source2.SetTimestamp(123456710); - EncodedImage source3 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 21); - source3.SetTimestamp(123456710); - - // Inject id into 3 images with same frame id. - EncodedImage intermediate1 = injector.InjectData(512, true, source1, 1); - EncodedImage intermediate2 = injector.InjectData(512, true, source2, 1); - EncodedImage intermediate3 = injector.InjectData(512, true, source3, 1); - - // Concatenate them into single encoded image, like it can be done in jitter - // buffer. 
- size_t concatenated_length = - intermediate1.size() + intermediate2.size() + intermediate3.size(); - rtc::Buffer concatenated_buffer; - concatenated_buffer.AppendData(intermediate1.data(), intermediate1.size()); - concatenated_buffer.AppendData(intermediate2.data(), intermediate2.size()); - concatenated_buffer.AppendData(intermediate3.data(), intermediate3.size()); - EncodedImage concatenated; - concatenated.SetEncodedData(EncodedImageBuffer::Create( - concatenated_buffer.data(), concatenated_length)); - - // Extract frame id from concatenated image - EncodedImageExtractionResult out = injector.ExtractData(concatenated, 2); - - EXPECT_EQ(out.id, 512); - EXPECT_TRUE(out.discard); - EXPECT_EQ(out.image.size(), 0ul); -} - -} // namespace -} // namespace webrtc_pc_e2e -} // namespace webrtc diff --git a/test/pc/e2e/analyzer/video/default_video_quality_analyzer.cc b/test/pc/e2e/analyzer/video/default_video_quality_analyzer.cc index 04999c3b49..53fb14e606 100644 --- a/test/pc/e2e/analyzer/video/default_video_quality_analyzer.cc +++ b/test/pc/e2e/analyzer/video/default_video_quality_analyzer.cc @@ -21,6 +21,7 @@ #include "common_video/libyuv/include/webrtc_libyuv.h" #include "rtc_base/cpu_time.h" #include "rtc_base/logging.h" +#include "rtc_base/platform_thread.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/time_utils.h" #include "rtc_tools/frame_analyzer/video_geometry_aligner.h" @@ -140,17 +141,14 @@ void DefaultVideoQualityAnalyzer::Start( rtc::ArrayView peer_names, int max_threads_count) { test_label_ = std::move(test_case_name); - peers_ = std::make_unique(peer_names); for (int i = 0; i < max_threads_count; i++) { - auto thread = std::make_unique( - &DefaultVideoQualityAnalyzer::ProcessComparisonsThread, this, - ("DefaultVideoQualityAnalyzerWorker-" + std::to_string(i)).data(), - rtc::ThreadPriority::kNormalPriority); - thread->Start(); - thread_pool_.push_back(std::move(thread)); + thread_pool_.push_back(rtc::PlatformThread::SpawnJoinable( + [this] { ProcessComparisons(); }, + "DefaultVideoQualityAnalyzerWorker-" + std::to_string(i))); } { MutexLock lock(&lock_); + peers_ = std::make_unique(peer_names); RTC_CHECK(start_time_.IsMinusInfinity()); state_ = State::kActive; @@ -166,19 +164,22 @@ uint16_t DefaultVideoQualityAnalyzer::OnFrameCaptured( // |next_frame_id| is atomic, so we needn't lock here. uint16_t frame_id = next_frame_id_++; Timestamp start_time = Timestamp::MinusInfinity(); - size_t peer_index = peers_->index(peer_name); + size_t peer_index = -1; + size_t peers_count = -1; size_t stream_index; { MutexLock lock(&lock_); - // Create a local copy of start_time_ to access it under - // |comparison_lock_| without holding a |lock_| + // Create a local copy of |start_time_|, peer's index and total peers count + // to access it under |comparison_lock_| without holding a |lock_| start_time = start_time_; + peer_index = peers_->index(peer_name); + peers_count = peers_->size(); stream_index = streams_.AddIfAbsent(stream_label); } { // Ensure stats for this stream exists. MutexLock lock(&comparison_lock_); - for (size_t i = 0; i < peers_->size(); ++i) { + for (size_t i = 0; i < peers_count; ++i) { if (i == peer_index) { continue; } @@ -349,17 +350,16 @@ void DefaultVideoQualityAnalyzer::OnFramePreDecode( stream_frame_counters_.at(key).received++; // Determine the time of the last received packet of this video frame. 
RTC_DCHECK(!input_image.PacketInfos().empty()); - int64_t last_receive_time = + Timestamp last_receive_time = std::max_element(input_image.PacketInfos().cbegin(), input_image.PacketInfos().cend(), [](const RtpPacketInfo& a, const RtpPacketInfo& b) { - return a.receive_time_ms() < b.receive_time_ms(); + return a.receive_time() < b.receive_time(); }) - ->receive_time_ms(); - it->second.OnFramePreDecode( - peer_index, - /*received_time=*/Timestamp::Millis(last_receive_time), - /*decode_start_time=*/Now()); + ->receive_time(); + it->second.OnFramePreDecode(peer_index, + /*received_time=*/last_receive_time, + /*decode_start_time=*/Now()); } void DefaultVideoQualityAnalyzer::OnFrameDecoded( @@ -463,7 +463,7 @@ void DefaultVideoQualityAnalyzer::OnFrameRendered( frame_in_flight->rendered_time(peer_index)); { MutexLock cr(&comparison_lock_); - stream_stats_[stats_key].skipped_between_rendered.AddSample( + stream_stats_.at(stats_key).skipped_between_rendered.AddSample( StatsSample(dropped_count, Now())); } @@ -516,6 +516,7 @@ void DefaultVideoQualityAnalyzer::RegisterParticipantInCall( counters.encoded = frames_count; stream_frame_counters_.insert({key, std::move(counters)}); + stream_stats_.insert({key, StreamStats()}); stream_last_freeze_end_time_.insert({key, start_time_}); } // Ensure, that frames states are handled correctly @@ -524,6 +525,10 @@ void DefaultVideoQualityAnalyzer::RegisterParticipantInCall( key_val.second.AddPeer(); } // Register new peer for every frame in flight. + // It is guaranteed, that no garbadge FrameInFlight objects will stay in + // memory because of adding new peer. Even if the new peer won't receive the + // frame, the frame will be removed by OnFrameRendered after next frame comes + // for the new peer. It is important because FrameInFlight is a large object. for (auto& key_val : captured_frames_in_flight_) { key_val.second.AddPeer(); } @@ -539,10 +544,6 @@ void DefaultVideoQualityAnalyzer::Stop() { } StopMeasuringCpuProcessTime(); comparison_available_event_.Set(); - for (auto& thread : thread_pool_) { - thread->Stop(); - } - // PlatformThread have to be deleted on the same thread, where it was created thread_pool_.clear(); // Perform final Metrics update. On this place analyzer is stopped and no one @@ -669,10 +670,6 @@ void DefaultVideoQualityAnalyzer::AddComparison( StopExcludingCpuThreadTime(); } -void DefaultVideoQualityAnalyzer::ProcessComparisonsThread(void* obj) { - static_cast(obj)->ProcessComparisons(); -} - void DefaultVideoQualityAnalyzer::ProcessComparisons() { while (true) { // Try to pick next comparison to perform from the queue. 
@@ -918,6 +915,9 @@ void DefaultVideoQualityAnalyzer::ReportResults( frame_counters.dropped, "count", /*important=*/false, ImproveDirection::kSmallerIsBetter); + test::PrintResult("rendered_frames", "", test_case_name, + frame_counters.rendered, "count", /*important=*/false, + ImproveDirection::kBiggerIsBetter); ReportResult("max_skipped", test_case_name, stats.skipped_between_rendered, "count", ImproveDirection::kSmallerIsBetter); ReportResult("target_encode_bitrate", test_case_name, @@ -956,7 +956,7 @@ StatsKey DefaultVideoQualityAnalyzer::ToStatsKey( } std::string DefaultVideoQualityAnalyzer::StatsKeyToMetricName( - const StatsKey& key) { + const StatsKey& key) const { if (peers_->size() <= 2) { return key.stream_label; } diff --git a/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h b/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h index f30e61b9d7..626fa246e5 100644 --- a/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h +++ b/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h @@ -504,7 +504,8 @@ class DefaultVideoQualityAnalyzer : public VideoQualityAnalyzerInterface { RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_); // Returns string representation of stats key for metrics naming. Used for // backward compatibility by metrics naming for 2 peers cases. - std::string StatsKeyToMetricName(const StatsKey& key); + std::string StatsKeyToMetricName(const StatsKey& key) const + RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_); void StartMeasuringCpuProcessTime(); void StopMeasuringCpuProcessTime(); @@ -517,9 +518,9 @@ class DefaultVideoQualityAnalyzer : public VideoQualityAnalyzerInterface { std::atomic next_frame_id_{0}; std::string test_label_; - std::unique_ptr peers_; mutable Mutex lock_; + std::unique_ptr peers_ RTC_GUARDED_BY(lock_); State state_ RTC_GUARDED_BY(lock_) = State::kNew; Timestamp start_time_ RTC_GUARDED_BY(lock_) = Timestamp::MinusInfinity(); // Mapping from stream label to unique size_t value to use in stats and avoid @@ -559,7 +560,7 @@ class DefaultVideoQualityAnalyzer : public VideoQualityAnalyzerInterface { std::deque comparisons_ RTC_GUARDED_BY(comparison_lock_); AnalyzerStats analyzer_stats_ RTC_GUARDED_BY(comparison_lock_); - std::vector> thread_pool_; + std::vector thread_pool_; rtc::Event comparison_available_event_; Mutex cpu_measurement_lock_; diff --git a/test/pc/e2e/analyzer/video/default_video_quality_analyzer_test.cc b/test/pc/e2e/analyzer/video/default_video_quality_analyzer_test.cc index 8b7ce86245..8d8a1af848 100644 --- a/test/pc/e2e/analyzer/video/default_video_quality_analyzer_test.cc +++ b/test/pc/e2e/analyzer/video/default_video_quality_analyzer_test.cc @@ -63,13 +63,13 @@ VideoFrame NextFrame(test::FrameGeneratorInterface* frame_generator, EncodedImage FakeEncode(const VideoFrame& frame) { EncodedImage image; std::vector packet_infos; - packet_infos.push_back( - RtpPacketInfo(/*ssrc=*/1, - /*csrcs=*/{}, - /*rtp_timestamp=*/frame.timestamp(), - /*audio_level=*/absl::nullopt, - /*absolute_capture_time=*/absl::nullopt, - /*receive_time_ms=*/frame.timestamp_us() + 10)); + packet_infos.push_back(RtpPacketInfo( + /*ssrc=*/1, + /*csrcs=*/{}, + /*rtp_timestamp=*/frame.timestamp(), + /*audio_level=*/absl::nullopt, + /*absolute_capture_time=*/absl::nullopt, + /*receive_time=*/Timestamp::Micros(frame.timestamp_us() + 10000))); image.SetPacketInfos(RtpPacketInfos(packet_infos)); return image; } @@ -100,7 +100,7 @@ std::string ToString(const std::vector& values) { } void FakeCPULoad() { - std::vector temp(100000); + std::vector temp(1000000); 
for (size_t i = 0; i < temp.size(); ++i) { temp[i] = rand(); } @@ -760,6 +760,11 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { constexpr char kAlice[] = "alice"; constexpr char kBob[] = "bob"; constexpr char kCharlie[] = "charlie"; + constexpr char kKatie[] = "katie"; + + constexpr int kFramesCount = 9; + constexpr int kOneThirdFrames = kFramesCount / 3; + constexpr int kTwoThirdFrames = 2 * kOneThirdFrames; DefaultVideoQualityAnalyzer analyzer(Clock::GetRealTimeClock(), AnalyzerOptionsForTest()); @@ -769,7 +774,9 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { std::vector frames_order; analyzer.RegisterParticipantInCall(kAlice); analyzer.RegisterParticipantInCall(kBob); - for (int i = 0; i < kMaxFramesInFlightPerStream; ++i) { + + // Alice is sending frames. + for (int i = 0; i < kFramesCount; ++i) { VideoFrame frame = NextFrame(frame_generator.get(), i); frame.set_id(analyzer.OnFrameCaptured(kAlice, kStreamLabel, frame)); frames_order.push_back(frame.id()); @@ -779,7 +786,8 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { VideoQualityAnalyzerInterface::EncoderStats()); } - for (size_t i = 0; i < frames_order.size() / 2; ++i) { + // Bob receives one third of the sent frames. + for (int i = 0; i < kOneThirdFrames; ++i) { uint16_t frame_id = frames_order.at(i); VideoFrame received_frame = DeepCopy(captured_frames.at(frame_id)); analyzer.OnFramePreDecode(kBob, received_frame.id(), @@ -790,8 +798,11 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { } analyzer.RegisterParticipantInCall(kCharlie); + analyzer.RegisterParticipantInCall(kKatie); - for (size_t i = frames_order.size() / 2; i < frames_order.size(); ++i) { + // New participants were dynamically added. Bob and Charlie receive second + // third of the sent frames. Katie drops the frames. + for (int i = kOneThirdFrames; i < kTwoThirdFrames; ++i) { uint16_t frame_id = frames_order.at(i); VideoFrame bob_received_frame = DeepCopy(captured_frames.at(frame_id)); analyzer.OnFramePreDecode(kBob, bob_received_frame.id(), @@ -808,6 +819,31 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { analyzer.OnFrameRendered(kCharlie, charlie_received_frame); } + // Bob, Charlie and Katie receive the rest of the sent frames. + for (int i = kTwoThirdFrames; i < kFramesCount; ++i) { + uint16_t frame_id = frames_order.at(i); + VideoFrame bob_received_frame = DeepCopy(captured_frames.at(frame_id)); + analyzer.OnFramePreDecode(kBob, bob_received_frame.id(), + FakeEncode(bob_received_frame)); + analyzer.OnFrameDecoded(kBob, bob_received_frame, + VideoQualityAnalyzerInterface::DecoderStats()); + analyzer.OnFrameRendered(kBob, bob_received_frame); + + VideoFrame charlie_received_frame = DeepCopy(captured_frames.at(frame_id)); + analyzer.OnFramePreDecode(kCharlie, charlie_received_frame.id(), + FakeEncode(charlie_received_frame)); + analyzer.OnFrameDecoded(kCharlie, charlie_received_frame, + VideoQualityAnalyzerInterface::DecoderStats()); + analyzer.OnFrameRendered(kCharlie, charlie_received_frame); + + VideoFrame katie_received_frame = DeepCopy(captured_frames.at(frame_id)); + analyzer.OnFramePreDecode(kKatie, katie_received_frame.id(), + FakeEncode(katie_received_frame)); + analyzer.OnFrameDecoded(kKatie, katie_received_frame, + VideoQualityAnalyzerInterface::DecoderStats()); + analyzer.OnFrameRendered(kKatie, katie_received_frame); + } + // Give analyzer some time to process frames on async thread. 
The computations // have to be fast (heavy metrics are disabled!), so if doesn't fit 100ms it // means we have an issue! @@ -816,8 +852,7 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { AnalyzerStats stats = analyzer.GetAnalyzerStats(); EXPECT_EQ(stats.memory_overloaded_comparisons_done, 0); - EXPECT_EQ(stats.comparisons_done, - kMaxFramesInFlightPerStream + kMaxFramesInFlightPerStream / 2); + EXPECT_EQ(stats.comparisons_done, kFramesCount + 2 * kTwoThirdFrames); std::vector frames_in_flight_sizes = GetSortedSamples(stats.frames_in_flight_left_count); @@ -825,37 +860,45 @@ TEST(DefaultVideoQualityAnalyzerTest, RuntimeParticipantsAdding) { << ToString(frames_in_flight_sizes); FrameCounters frame_counters = analyzer.GetGlobalCounters(); - EXPECT_EQ(frame_counters.captured, kMaxFramesInFlightPerStream); - EXPECT_EQ(frame_counters.received, - kMaxFramesInFlightPerStream + kMaxFramesInFlightPerStream / 2); - EXPECT_EQ(frame_counters.decoded, - kMaxFramesInFlightPerStream + kMaxFramesInFlightPerStream / 2); - EXPECT_EQ(frame_counters.rendered, - kMaxFramesInFlightPerStream + kMaxFramesInFlightPerStream / 2); - EXPECT_EQ(frame_counters.dropped, 0); + EXPECT_EQ(frame_counters.captured, kFramesCount); + EXPECT_EQ(frame_counters.received, 2 * kFramesCount); + EXPECT_EQ(frame_counters.decoded, 2 * kFramesCount); + EXPECT_EQ(frame_counters.rendered, 2 * kFramesCount); + EXPECT_EQ(frame_counters.dropped, kOneThirdFrames); - EXPECT_EQ(analyzer.GetKnownVideoStreams().size(), 2lu); + EXPECT_EQ(analyzer.GetKnownVideoStreams().size(), 3lu); const StatsKey kAliceBobStats(kStreamLabel, kAlice, kBob); const StatsKey kAliceCharlieStats(kStreamLabel, kAlice, kCharlie); + const StatsKey kAliceKatieStats(kStreamLabel, kAlice, kKatie); { FrameCounters stream_conters = analyzer.GetPerStreamCounters().at(kAliceBobStats); - EXPECT_EQ(stream_conters.captured, 10); - EXPECT_EQ(stream_conters.pre_encoded, 10); - EXPECT_EQ(stream_conters.encoded, 10); - EXPECT_EQ(stream_conters.received, 10); - EXPECT_EQ(stream_conters.decoded, 10); - EXPECT_EQ(stream_conters.rendered, 10); + EXPECT_EQ(stream_conters.captured, kFramesCount); + EXPECT_EQ(stream_conters.pre_encoded, kFramesCount); + EXPECT_EQ(stream_conters.encoded, kFramesCount); + EXPECT_EQ(stream_conters.received, kFramesCount); + EXPECT_EQ(stream_conters.decoded, kFramesCount); + EXPECT_EQ(stream_conters.rendered, kFramesCount); } { FrameCounters stream_conters = analyzer.GetPerStreamCounters().at(kAliceCharlieStats); - EXPECT_EQ(stream_conters.captured, 5); - EXPECT_EQ(stream_conters.pre_encoded, 5); - EXPECT_EQ(stream_conters.encoded, 5); - EXPECT_EQ(stream_conters.received, 5); - EXPECT_EQ(stream_conters.decoded, 5); - EXPECT_EQ(stream_conters.rendered, 5); + EXPECT_EQ(stream_conters.captured, kTwoThirdFrames); + EXPECT_EQ(stream_conters.pre_encoded, kTwoThirdFrames); + EXPECT_EQ(stream_conters.encoded, kTwoThirdFrames); + EXPECT_EQ(stream_conters.received, kTwoThirdFrames); + EXPECT_EQ(stream_conters.decoded, kTwoThirdFrames); + EXPECT_EQ(stream_conters.rendered, kTwoThirdFrames); + } + { + FrameCounters stream_conters = + analyzer.GetPerStreamCounters().at(kAliceKatieStats); + EXPECT_EQ(stream_conters.captured, kTwoThirdFrames); + EXPECT_EQ(stream_conters.pre_encoded, kTwoThirdFrames); + EXPECT_EQ(stream_conters.encoded, kTwoThirdFrames); + EXPECT_EQ(stream_conters.received, kOneThirdFrames); + EXPECT_EQ(stream_conters.decoded, kOneThirdFrames); + EXPECT_EQ(stream_conters.rendered, kOneThirdFrames); } } diff --git 
a/test/pc/e2e/analyzer/video/encoded_image_data_injector.h b/test/pc/e2e/analyzer/video/encoded_image_data_injector.h index ddd6959b91..154e38e43f 100644 --- a/test/pc/e2e/analyzer/video/encoded_image_data_injector.h +++ b/test/pc/e2e/analyzer/video/encoded_image_data_injector.h @@ -27,11 +27,10 @@ class EncodedImageDataInjector { // Return encoded image with specified |id| and |discard| flag injected into // its payload. |discard| flag mean does analyzing decoder should discard this // encoded image because it belongs to unnecessary simulcast stream or spatial - // layer. |coding_entity_id| is unique id of decoder or encoder. + // layer. virtual EncodedImage InjectData(uint16_t id, bool discard, - const EncodedImage& source, - int coding_entity_id) = 0; + const EncodedImage& source) = 0; }; struct EncodedImageExtractionResult { @@ -52,11 +51,15 @@ class EncodedImageDataExtractor { // encoded image. virtual void Start(int expected_receivers_count) = 0; + // Invoked by framework when it is required to add one more receiver for + // frames. Will be invoked before that receiver will start receive data. + virtual void AddParticipantInCall() = 0; + // Returns encoded image id, extracted from payload and also encoded image // with its original payload. For concatenated spatial layers it should be the - // same id. |coding_entity_id| is unique id of decoder or encoder. - virtual EncodedImageExtractionResult ExtractData(const EncodedImage& source, - int coding_entity_id) = 0; + // same id. + virtual EncodedImageExtractionResult ExtractData( + const EncodedImage& source) = 0; }; } // namespace webrtc_pc_e2e diff --git a/test/pc/e2e/analyzer/video/id_generator.h b/test/pc/e2e/analyzer/video/id_generator.h deleted file mode 100644 index 8c988f211a..0000000000 --- a/test/pc/e2e/analyzer/video/id_generator.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef TEST_PC_E2E_ANALYZER_VIDEO_ID_GENERATOR_H_ -#define TEST_PC_E2E_ANALYZER_VIDEO_ID_GENERATOR_H_ - -#include - -namespace webrtc { -namespace webrtc_pc_e2e { - -// IdGenerator generates ids. All provided ids have to be unique. There is no -// any order guarantees for provided ids. -template -class IdGenerator { - public: - virtual ~IdGenerator() = default; - - // Returns next unique id. There is no any order guarantees for provided ids. - virtual T GetNextId() = 0; -}; - -// Generates int ids. It is assumed, that no more then max int value ids will be -// requested from this generator. 
-class IntIdGenerator : public IdGenerator { - public: - explicit IntIdGenerator(int start_value); - ~IntIdGenerator() override; - - int GetNextId() override; - - private: - std::atomic next_id_; -}; - -} // namespace webrtc_pc_e2e -} // namespace webrtc - -#endif // TEST_PC_E2E_ANALYZER_VIDEO_ID_GENERATOR_H_ diff --git a/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.cc b/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.cc index 27b9af50bb..68b76cd37d 100644 --- a/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.cc +++ b/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.cc @@ -26,13 +26,11 @@ namespace webrtc { namespace webrtc_pc_e2e { QualityAnalyzingVideoDecoder::QualityAnalyzingVideoDecoder( - int id, absl::string_view peer_name, std::unique_ptr delegate, EncodedImageDataExtractor* extractor, VideoQualityAnalyzerInterface* analyzer) - : id_(id), - peer_name_(peer_name), + : peer_name_(peer_name), implementation_name_("AnalyzingDecoder-" + std::string(delegate->ImplementationName())), delegate_(std::move(delegate)), @@ -56,7 +54,7 @@ int32_t QualityAnalyzingVideoDecoder::Decode(const EncodedImage& input_image, // owner of original buffer will be responsible for deleting it, or extractor // can create a new buffer. In such case extractor will be responsible for // deleting it. - EncodedImageExtractionResult out = extractor_->ExtractData(input_image, id_); + EncodedImageExtractionResult out = extractor_->ExtractData(input_image); if (out.discard) { // To partly emulate behavior of Selective Forwarding Unit (SFU) in the @@ -235,12 +233,10 @@ void QualityAnalyzingVideoDecoder::OnFrameDecoded( QualityAnalyzingVideoDecoderFactory::QualityAnalyzingVideoDecoderFactory( absl::string_view peer_name, std::unique_ptr delegate, - IdGenerator* id_generator, EncodedImageDataExtractor* extractor, VideoQualityAnalyzerInterface* analyzer) : peer_name_(peer_name), delegate_(std::move(delegate)), - id_generator_(id_generator), extractor_(extractor), analyzer_(analyzer) {} QualityAnalyzingVideoDecoderFactory::~QualityAnalyzingVideoDecoderFactory() = @@ -256,19 +252,7 @@ QualityAnalyzingVideoDecoderFactory::CreateVideoDecoder( const SdpVideoFormat& format) { std::unique_ptr decoder = delegate_->CreateVideoDecoder(format); return std::make_unique( - id_generator_->GetNextId(), peer_name_, std::move(decoder), extractor_, - analyzer_); -} - -std::unique_ptr -QualityAnalyzingVideoDecoderFactory::LegacyCreateVideoDecoder( - const SdpVideoFormat& format, - const std::string& receive_stream_id) { - std::unique_ptr decoder = - delegate_->LegacyCreateVideoDecoder(format, receive_stream_id); - return std::make_unique( - id_generator_->GetNextId(), peer_name_, std::move(decoder), extractor_, - analyzer_); + peer_name_, std::move(decoder), extractor_, analyzer_); } } // namespace webrtc_pc_e2e diff --git a/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.h b/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.h index a26ccbe1ee..e150c91cb4 100644 --- a/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.h +++ b/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.h @@ -25,7 +25,6 @@ #include "api/video_codecs/video_decoder_factory.h" #include "rtc_base/synchronization/mutex.h" #include "test/pc/e2e/analyzer/video/encoded_image_data_injector.h" -#include "test/pc/e2e/analyzer/video/id_generator.h" namespace webrtc { namespace webrtc_pc_e2e { @@ -50,11 +49,7 @@ namespace webrtc_pc_e2e { // time the user registers their callback in quality decoder. 
class QualityAnalyzingVideoDecoder : public VideoDecoder { public: - // Creates analyzing decoder. |id| is unique coding entity id, that will - // be used to distinguish all encoders and decoders inside - // EncodedImageDataInjector and EncodedImageIdExtracor. - QualityAnalyzingVideoDecoder(int id, - absl::string_view peer_name, + QualityAnalyzingVideoDecoder(absl::string_view peer_name, std::unique_ptr delegate, EncodedImageDataExtractor* extractor, VideoQualityAnalyzerInterface* analyzer); @@ -105,7 +100,6 @@ class QualityAnalyzingVideoDecoder : public VideoDecoder { absl::optional decode_time_ms, absl::optional qp); - const int id_; const std::string peer_name_; const std::string implementation_name_; std::unique_ptr delegate_; @@ -134,7 +128,6 @@ class QualityAnalyzingVideoDecoderFactory : public VideoDecoderFactory { QualityAnalyzingVideoDecoderFactory( absl::string_view peer_name, std::unique_ptr delegate, - IdGenerator* id_generator, EncodedImageDataExtractor* extractor, VideoQualityAnalyzerInterface* analyzer); ~QualityAnalyzingVideoDecoderFactory() override; @@ -143,14 +136,10 @@ class QualityAnalyzingVideoDecoderFactory : public VideoDecoderFactory { std::vector GetSupportedFormats() const override; std::unique_ptr CreateVideoDecoder( const SdpVideoFormat& format) override; - std::unique_ptr LegacyCreateVideoDecoder( - const SdpVideoFormat& format, - const std::string& receive_stream_id) override; private: const std::string peer_name_; std::unique_ptr delegate_; - IdGenerator* const id_generator_; EncodedImageDataExtractor* const extractor_; VideoQualityAnalyzerInterface* const analyzer_; }; diff --git a/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.cc b/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.cc index 04ec892e12..5b8a571cd0 100644 --- a/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.cc +++ b/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.cc @@ -53,15 +53,13 @@ std::pair GetMinMaxBitratesBps(const VideoCodec& codec, } // namespace QualityAnalyzingVideoEncoder::QualityAnalyzingVideoEncoder( - int id, absl::string_view peer_name, std::unique_ptr delegate, double bitrate_multiplier, std::map> stream_required_spatial_index, EncodedImageDataInjector* injector, VideoQualityAnalyzerInterface* analyzer) - : id_(id), - peer_name_(peer_name), + : peer_name_(peer_name), delegate_(std::move(delegate)), bitrate_multiplier_(bitrate_multiplier), stream_required_spatial_index_(std::move(stream_required_spatial_index)), @@ -287,7 +285,7 @@ EncodedImageCallback::Result QualityAnalyzingVideoEncoder::OnEncodedImage( // it) or b) a new buffer (in such case injector will be responsible for // deleting it). 
const EncodedImage& image = - injector_->InjectData(frame_id, discard, encoded_image, id_); + injector_->InjectData(frame_id, discard, encoded_image); { MutexLock lock(&lock_); RTC_DCHECK(delegate_callback_); @@ -352,14 +350,12 @@ QualityAnalyzingVideoEncoderFactory::QualityAnalyzingVideoEncoderFactory( std::unique_ptr delegate, double bitrate_multiplier, std::map> stream_required_spatial_index, - IdGenerator* id_generator, EncodedImageDataInjector* injector, VideoQualityAnalyzerInterface* analyzer) : peer_name_(peer_name), delegate_(std::move(delegate)), bitrate_multiplier_(bitrate_multiplier), stream_required_spatial_index_(std::move(stream_required_spatial_index)), - id_generator_(id_generator), injector_(injector), analyzer_(analyzer) {} QualityAnalyzingVideoEncoderFactory::~QualityAnalyzingVideoEncoderFactory() = @@ -380,8 +376,7 @@ std::unique_ptr QualityAnalyzingVideoEncoderFactory::CreateVideoEncoder( const SdpVideoFormat& format) { return std::make_unique( - id_generator_->GetNextId(), peer_name_, - delegate_->CreateVideoEncoder(format), bitrate_multiplier_, + peer_name_, delegate_->CreateVideoEncoder(format), bitrate_multiplier_, stream_required_spatial_index_, injector_, analyzer_); } diff --git a/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.h b/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.h index 96d9d77e34..2ba8bdcb38 100644 --- a/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.h +++ b/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.h @@ -25,7 +25,6 @@ #include "api/video_codecs/video_encoder_factory.h" #include "rtc_base/synchronization/mutex.h" #include "test/pc/e2e/analyzer/video/encoded_image_data_injector.h" -#include "test/pc/e2e/analyzer/video/id_generator.h" namespace webrtc { namespace webrtc_pc_e2e { @@ -55,11 +54,7 @@ constexpr int kAnalyzeAnySpatialStream = -1; class QualityAnalyzingVideoEncoder : public VideoEncoder, public EncodedImageCallback { public: - // Creates analyzing encoder. |id| is unique coding entity id, that will - // be used to distinguish all encoders and decoders inside - // EncodedImageDataInjector and EncodedImageIdExtracor. 
QualityAnalyzingVideoEncoder( - int id, absl::string_view peer_name, std::unique_ptr delegate, double bitrate_multiplier, @@ -139,7 +134,6 @@ class QualityAnalyzingVideoEncoder : public VideoEncoder, bool ShouldDiscard(uint16_t frame_id, const EncodedImage& encoded_image) RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_); - const int id_; const std::string peer_name_; std::unique_ptr delegate_; const double bitrate_multiplier_; @@ -176,7 +170,6 @@ class QualityAnalyzingVideoEncoderFactory : public VideoEncoderFactory { std::unique_ptr delegate, double bitrate_multiplier, std::map> stream_required_spatial_index, - IdGenerator* id_generator, EncodedImageDataInjector* injector, VideoQualityAnalyzerInterface* analyzer); ~QualityAnalyzingVideoEncoderFactory() override; @@ -193,7 +186,6 @@ class QualityAnalyzingVideoEncoderFactory : public VideoEncoderFactory { std::unique_ptr delegate_; const double bitrate_multiplier_; std::map> stream_required_spatial_index_; - IdGenerator* const id_generator_; EncodedImageDataInjector* const injector_; VideoQualityAnalyzerInterface* const analyzer_; }; diff --git a/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.cc b/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.cc index 304cb67d37..d7ee0f41b9 100644 --- a/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.cc +++ b/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.cc @@ -28,8 +28,7 @@ SingleProcessEncodedImageDataInjector:: EncodedImage SingleProcessEncodedImageDataInjector::InjectData( uint16_t id, bool discard, - const EncodedImage& source, - int coding_entity_id) { + const EncodedImage& source) { RTC_CHECK(source.size() >= ExtractionInfo::kUsedBufferSize); ExtractionInfo info; @@ -55,9 +54,13 @@ EncodedImage SingleProcessEncodedImageDataInjector::InjectData( return out; } +void SingleProcessEncodedImageDataInjector::AddParticipantInCall() { + MutexLock crit(&lock_); + expected_receivers_count_++; +} + EncodedImageExtractionResult SingleProcessEncodedImageDataInjector::ExtractData( - const EncodedImage& source, - int coding_entity_id) { + const EncodedImage& source) { size_t size = source.size(); auto buffer = EncodedImageBuffer::Create(source.data(), source.size()); EncodedImage out = source; diff --git a/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.h b/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.h index 8cf1bc4828..03feb7997c 100644 --- a/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.h +++ b/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.h @@ -48,15 +48,14 @@ class SingleProcessEncodedImageDataInjector : public EncodedImageDataInjector, // changed. 
EncodedImage InjectData(uint16_t id, bool discard, - const EncodedImage& source, - int coding_entity_id) override; + const EncodedImage& source) override; void Start(int expected_receivers_count) override { MutexLock crit(&lock_); expected_receivers_count_ = expected_receivers_count; } - EncodedImageExtractionResult ExtractData(const EncodedImage& source, - int coding_entity_id) override; + void AddParticipantInCall() override; + EncodedImageExtractionResult ExtractData(const EncodedImage& source) override; private: // Contains data required to extract frame id from EncodedImage and restore diff --git a/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector_unittest.cc b/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector_unittest.cc index da2391467d..cfeab23562 100644 --- a/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector_unittest.cc +++ b/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector_unittest.cc @@ -37,7 +37,7 @@ EncodedImage CreateEncodedImageOfSizeNFilledWithValuesFromX(size_t n, return image; } -TEST(SingleProcessEncodedImageDataInjector, InjectExtractDiscardFalse) { +TEST(SingleProcessEncodedImageDataInjectorTest, InjectExtractDiscardFalse) { SingleProcessEncodedImageDataInjector injector; injector.Start(1); @@ -45,7 +45,7 @@ TEST(SingleProcessEncodedImageDataInjector, InjectExtractDiscardFalse) { source.SetTimestamp(123456789); EncodedImageExtractionResult out = - injector.ExtractData(injector.InjectData(512, false, source, 1), 2); + injector.ExtractData(injector.InjectData(512, false, source)); EXPECT_EQ(out.id, 512); EXPECT_FALSE(out.discard); EXPECT_EQ(out.image.size(), 10ul); @@ -55,7 +55,7 @@ TEST(SingleProcessEncodedImageDataInjector, InjectExtractDiscardFalse) { } } -TEST(SingleProcessEncodedImageDataInjector, InjectExtractDiscardTrue) { +TEST(SingleProcessEncodedImageDataInjectorTest, InjectExtractDiscardTrue) { SingleProcessEncodedImageDataInjector injector; injector.Start(1); @@ -63,24 +63,25 @@ TEST(SingleProcessEncodedImageDataInjector, InjectExtractDiscardTrue) { source.SetTimestamp(123456789); EncodedImageExtractionResult out = - injector.ExtractData(injector.InjectData(512, true, source, 1), 2); + injector.ExtractData(injector.InjectData(512, true, source)); EXPECT_EQ(out.id, 512); EXPECT_TRUE(out.discard); EXPECT_EQ(out.image.size(), 0ul); EXPECT_EQ(out.image.SpatialLayerFrameSize(0).value_or(0), 0ul); } -TEST(SingleProcessEncodedImageDataInjector, InjectWithUnsetSpatialLayerSizes) { +TEST(SingleProcessEncodedImageDataInjectorTest, + InjectWithUnsetSpatialLayerSizes) { SingleProcessEncodedImageDataInjector injector; injector.Start(1); EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); source.SetTimestamp(123456789); - EncodedImage intermediate = injector.InjectData(512, false, source, 1); + EncodedImage intermediate = injector.InjectData(512, false, source); intermediate.SetSpatialIndex(2); - EncodedImageExtractionResult out = injector.ExtractData(intermediate, 2); + EncodedImageExtractionResult out = injector.ExtractData(intermediate); EXPECT_EQ(out.id, 512); EXPECT_FALSE(out.discard); EXPECT_EQ(out.image.size(), 10ul); @@ -93,20 +94,21 @@ TEST(SingleProcessEncodedImageDataInjector, InjectWithUnsetSpatialLayerSizes) { } } -TEST(SingleProcessEncodedImageDataInjector, InjectWithZeroSpatialLayerSizes) { +TEST(SingleProcessEncodedImageDataInjectorTest, + InjectWithZeroSpatialLayerSizes) { SingleProcessEncodedImageDataInjector injector; injector.Start(1); EncodedImage 
source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); source.SetTimestamp(123456789); - EncodedImage intermediate = injector.InjectData(512, false, source, 1); + EncodedImage intermediate = injector.InjectData(512, false, source); intermediate.SetSpatialIndex(2); intermediate.SetSpatialLayerFrameSize(0, 0); intermediate.SetSpatialLayerFrameSize(1, 0); intermediate.SetSpatialLayerFrameSize(2, 0); - EncodedImageExtractionResult out = injector.ExtractData(intermediate, 2); + EncodedImageExtractionResult out = injector.ExtractData(intermediate); EXPECT_EQ(out.id, 512); EXPECT_FALSE(out.discard); EXPECT_EQ(out.image.size(), 10ul); @@ -119,7 +121,7 @@ TEST(SingleProcessEncodedImageDataInjector, InjectWithZeroSpatialLayerSizes) { } } -TEST(SingleProcessEncodedImageDataInjector, Inject3Extract3) { +TEST(SingleProcessEncodedImageDataInjectorTest, Inject3Extract3) { SingleProcessEncodedImageDataInjector injector; injector.Start(1); @@ -133,14 +135,14 @@ TEST(SingleProcessEncodedImageDataInjector, Inject3Extract3) { EncodedImage source3 = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 21); source3.SetTimestamp(123456720); - EncodedImage intermediate1 = injector.InjectData(510, false, source1, 1); - EncodedImage intermediate2 = injector.InjectData(520, true, source2, 1); - EncodedImage intermediate3 = injector.InjectData(520, false, source3, 1); + EncodedImage intermediate1 = injector.InjectData(510, false, source1); + EncodedImage intermediate2 = injector.InjectData(520, true, source2); + EncodedImage intermediate3 = injector.InjectData(520, false, source3); // Extract ids in different order. - EncodedImageExtractionResult out3 = injector.ExtractData(intermediate3, 2); - EncodedImageExtractionResult out1 = injector.ExtractData(intermediate1, 2); - EncodedImageExtractionResult out2 = injector.ExtractData(intermediate2, 2); + EncodedImageExtractionResult out3 = injector.ExtractData(intermediate3); + EncodedImageExtractionResult out1 = injector.ExtractData(intermediate1); + EncodedImageExtractionResult out2 = injector.ExtractData(intermediate2); EXPECT_EQ(out1.id, 510); EXPECT_FALSE(out1.discard); @@ -162,7 +164,7 @@ TEST(SingleProcessEncodedImageDataInjector, Inject3Extract3) { } } -TEST(SingleProcessEncodedImageDataInjector, InjectExtractFromConcatenated) { +TEST(SingleProcessEncodedImageDataInjectorTest, InjectExtractFromConcatenated) { SingleProcessEncodedImageDataInjector injector; injector.Start(1); @@ -174,9 +176,9 @@ TEST(SingleProcessEncodedImageDataInjector, InjectExtractFromConcatenated) { source3.SetTimestamp(123456710); // Inject id into 3 images with same frame id. - EncodedImage intermediate1 = injector.InjectData(512, false, source1, 1); - EncodedImage intermediate2 = injector.InjectData(512, true, source2, 1); - EncodedImage intermediate3 = injector.InjectData(512, false, source3, 1); + EncodedImage intermediate1 = injector.InjectData(512, false, source1); + EncodedImage intermediate2 = injector.InjectData(512, true, source2); + EncodedImage intermediate3 = injector.InjectData(512, false, source3); // Concatenate them into single encoded image, like it can be done in jitter // buffer. 
@@ -195,7 +197,7 @@ TEST(SingleProcessEncodedImageDataInjector, InjectExtractFromConcatenated) { concatenated.SetSpatialLayerFrameSize(2, intermediate3.size()); // Extract frame id from concatenated image - EncodedImageExtractionResult out = injector.ExtractData(concatenated, 2); + EncodedImageExtractionResult out = injector.ExtractData(concatenated); EXPECT_EQ(out.id, 512); EXPECT_FALSE(out.discard); @@ -223,9 +225,9 @@ TEST(SingleProcessEncodedImageDataInjector, source3.SetTimestamp(123456710); // Inject id into 3 images with same frame id. - EncodedImage intermediate1 = injector.InjectData(512, true, source1, 1); - EncodedImage intermediate2 = injector.InjectData(512, true, source2, 1); - EncodedImage intermediate3 = injector.InjectData(512, true, source3, 1); + EncodedImage intermediate1 = injector.InjectData(512, true, source1); + EncodedImage intermediate2 = injector.InjectData(512, true, source2); + EncodedImage intermediate3 = injector.InjectData(512, true, source3); // Concatenate them into single encoded image, like it can be done in jitter // buffer. @@ -244,7 +246,7 @@ TEST(SingleProcessEncodedImageDataInjector, concatenated.SetSpatialLayerFrameSize(2, intermediate3.size()); // Extract frame id from concatenated image - EncodedImageExtractionResult out = injector.ExtractData(concatenated, 2); + EncodedImageExtractionResult out = injector.ExtractData(concatenated); EXPECT_EQ(out.id, 512); EXPECT_TRUE(out.discard); @@ -255,17 +257,45 @@ TEST(SingleProcessEncodedImageDataInjector, } } -TEST(SingleProcessEncodedImageDataInjector, InjectOnceExtractTwice) { +TEST(SingleProcessEncodedImageDataInjectorTest, InjectOnceExtractTwice) { SingleProcessEncodedImageDataInjector injector; injector.Start(2); EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); source.SetTimestamp(123456789); - EncodedImageExtractionResult out = - injector.ExtractData(injector.InjectData(/*id=*/512, /*discard=*/false, - source, /*coding_entity_id=*/1), - /*coding_entity_id=*/2); + EncodedImageExtractionResult out = injector.ExtractData( + injector.InjectData(/*id=*/512, /*discard=*/false, source)); + EXPECT_EQ(out.id, 512); + EXPECT_FALSE(out.discard); + EXPECT_EQ(out.image.size(), 10ul); + EXPECT_EQ(out.image.SpatialLayerFrameSize(0).value_or(0), 0ul); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(out.image.data()[i], i + 1); + } + out = injector.ExtractData( + injector.InjectData(/*id=*/512, /*discard=*/false, source)); + EXPECT_EQ(out.id, 512); + EXPECT_FALSE(out.discard); + EXPECT_EQ(out.image.size(), 10ul); + EXPECT_EQ(out.image.SpatialLayerFrameSize(0).value_or(0), 0ul); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(out.image.data()[i], i + 1); + } +} + +TEST(SingleProcessEncodedImageDataInjectorTest, Add1stReceiverAfterStart) { + SingleProcessEncodedImageDataInjector injector; + injector.Start(0); + + EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); + source.SetTimestamp(123456789); + EncodedImage modified_image = injector.InjectData( + /*id=*/512, /*discard=*/false, source); + + injector.AddParticipantInCall(); + EncodedImageExtractionResult out = injector.ExtractData(modified_image); + EXPECT_EQ(out.id, 512); EXPECT_FALSE(out.discard); EXPECT_EQ(out.image.size(), 10ul); @@ -273,10 +303,22 @@ TEST(SingleProcessEncodedImageDataInjector, InjectOnceExtractTwice) { for (int i = 0; i < 10; ++i) { EXPECT_EQ(out.image.data()[i], i + 1); } - out = - injector.ExtractData(injector.InjectData(/*id=*/512, /*discard=*/false, - source, /*coding_entity_id=*/1), - 2); +} 
+ +TEST(SingleProcessEncodedImageDataInjectorTest, Add3rdReceiverAfterStart) { + SingleProcessEncodedImageDataInjector injector; + injector.Start(2); + + EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); + source.SetTimestamp(123456789); + EncodedImage modified_image = injector.InjectData( + /*id=*/512, /*discard=*/false, source); + injector.ExtractData(modified_image); + + injector.AddParticipantInCall(); + injector.ExtractData(modified_image); + EncodedImageExtractionResult out = injector.ExtractData(modified_image); + EXPECT_EQ(out.id, 512); EXPECT_FALSE(out.discard); EXPECT_EQ(out.image.size(), 10ul); @@ -296,20 +338,20 @@ EncodedImage DeepCopyEncodedImage(const EncodedImage& source) { return copy; } -TEST(SingleProcessEncodedImageDataInjector, InjectOnceExtractMoreThenExpected) { +TEST(SingleProcessEncodedImageDataInjectorTest, + InjectOnceExtractMoreThenExpected) { SingleProcessEncodedImageDataInjector injector; injector.Start(2); EncodedImage source = CreateEncodedImageOfSizeNFilledWithValuesFromX(10, 1); source.SetTimestamp(123456789); - EncodedImage modified = injector.InjectData(/*id=*/512, /*discard=*/false, - source, /*coding_entity_id=*/1); + EncodedImage modified = + injector.InjectData(/*id=*/512, /*discard=*/false, source); - injector.ExtractData(DeepCopyEncodedImage(modified), /*coding_entity_id=*/2); - injector.ExtractData(DeepCopyEncodedImage(modified), /*coding_entity_id=*/2); - EXPECT_DEATH(injector.ExtractData(DeepCopyEncodedImage(modified), - /*coding_entity_id=*/2), + injector.ExtractData(DeepCopyEncodedImage(modified)); + injector.ExtractData(DeepCopyEncodedImage(modified)); + EXPECT_DEATH(injector.ExtractData(DeepCopyEncodedImage(modified)), "Unknown sub_id=0 for frame_id=512"); } #endif // RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) diff --git a/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.cc b/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.cc new file mode 100644 index 0000000000..e149e3f250 --- /dev/null +++ b/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.h" + +#include "absl/memory/memory.h" +#include "api/video/encoded_image.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace webrtc_pc_e2e { + +EncodedImage VideoFrameTrackingIdInjector::InjectData( + uint16_t id, + bool unused_discard, + const EncodedImage& source) { + RTC_CHECK(!unused_discard); + EncodedImage out = source; + out.SetVideoFrameTrackingId(id); + return out; +} + +EncodedImageExtractionResult VideoFrameTrackingIdInjector::ExtractData( + const EncodedImage& source) { + return EncodedImageExtractionResult{source.VideoFrameTrackingId().value_or(0), + source, /*discard=*/false}; +} + +} // namespace webrtc_pc_e2e +} // namespace webrtc diff --git a/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.h b/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.h new file mode 100644 index 0000000000..aac7c3726a --- /dev/null +++ b/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef TEST_PC_E2E_ANALYZER_VIDEO_VIDEO_FRAME_TRACKING_ID_INJECTOR_H_ +#define TEST_PC_E2E_ANALYZER_VIDEO_VIDEO_FRAME_TRACKING_ID_INJECTOR_H_ + +#include + +#include "api/video/encoded_image.h" +#include "test/pc/e2e/analyzer/video/encoded_image_data_injector.h" + +namespace webrtc { +namespace webrtc_pc_e2e { + +// This injector sets and retrieves the provided id in the EncodedImage +// video_frame_tracking_id field. This is only possible with the RTP header +// extension VideoFrameTrackingIdExtension that will propagate the input +// tracking id to the received EncodedImage. This RTP header extension is +// enabled with the field trial WebRTC-VideoFrameTrackingIdAdvertised +// (http://www.webrtc.org/experiments/rtp-hdrext/video-frame-tracking-id). +// +// Note that this injector doesn't allow to discard frames. +class VideoFrameTrackingIdInjector : public EncodedImageDataInjector, + public EncodedImageDataExtractor { + public: + EncodedImage InjectData(uint16_t id, + bool unused_discard, + const EncodedImage& source) override; + + EncodedImageExtractionResult ExtractData(const EncodedImage& source) override; + + void Start(int) override {} + void AddParticipantInCall() override {} +}; + +} // namespace webrtc_pc_e2e +} // namespace webrtc + +#endif // TEST_PC_E2E_ANALYZER_VIDEO_VIDEO_FRAME_TRACKING_ID_INJECTOR_H_ diff --git a/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector_unittest.cc b/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector_unittest.cc new file mode 100644 index 0000000000..af85b2283f --- /dev/null +++ b/test/pc/e2e/analyzer/video/video_frame_tracking_id_injector_unittest.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "test/pc/e2e/analyzer/video/video_frame_tracking_id_injector.h" + +#include "api/video/encoded_image.h" +#include "rtc_base/buffer.h" +#include "test/gtest.h" + +namespace webrtc { +namespace webrtc_pc_e2e { +namespace { + +EncodedImage CreateEncodedImageOfSizeN(size_t n) { + EncodedImage image; + rtc::scoped_refptr buffer = EncodedImageBuffer::Create(n); + for (size_t i = 0; i < n; ++i) { + buffer->data()[i] = static_cast(i); + } + image.SetEncodedData(buffer); + return image; +} + +TEST(VideoFrameTrackingIdInjectorTest, InjectExtractDiscardFalse) { + VideoFrameTrackingIdInjector injector; + EncodedImage source = CreateEncodedImageOfSizeN(10); + EncodedImageExtractionResult out = + injector.ExtractData(injector.InjectData(512, false, source)); + + EXPECT_EQ(out.id, 512); + EXPECT_FALSE(out.discard); + EXPECT_EQ(out.image.size(), 10ul); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(source.data()[i], out.image.data()[i]); + } +} + +#if GTEST_HAS_DEATH_TEST +TEST(VideoFrameTrackingIdInjectorTest, InjectExtractDiscardTrue) { + VideoFrameTrackingIdInjector injector; + EncodedImage source = CreateEncodedImageOfSizeN(10); + + EXPECT_DEATH(injector.InjectData(512, true, source), ""); +} +#endif // GTEST_HAS_DEATH_TEST + +} // namespace +} // namespace webrtc_pc_e2e +} // namespace webrtc diff --git a/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.cc b/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.cc index ebfb41697d..b1a22209be 100644 --- a/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.cc +++ b/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.cc @@ -28,17 +28,23 @@ namespace { class VideoWriter final : public rtc::VideoSinkInterface { public: - VideoWriter(test::VideoFrameWriter* video_writer) - : video_writer_(video_writer) {} + VideoWriter(test::VideoFrameWriter* video_writer, int sampling_modulo) + : video_writer_(video_writer), sampling_modulo_(sampling_modulo) {} ~VideoWriter() override = default; void OnFrame(const VideoFrame& frame) override { + if (frames_counter_++ % sampling_modulo_ != 0) { + return; + } bool result = video_writer_->WriteFrame(frame); RTC_CHECK(result) << "Failed to write frame"; } private: - test::VideoFrameWriter* video_writer_; + test::VideoFrameWriter* const video_writer_; + const int sampling_modulo_; + + int64_t frames_counter_ = 0; }; class AnalyzingFramePreprocessor @@ -84,8 +90,7 @@ VideoQualityAnalyzerInjectionHelper::VideoQualityAnalyzerInjectionHelper( EncodedImageDataExtractor* extractor) : analyzer_(std::move(analyzer)), injector_(injector), - extractor_(extractor), - encoding_entities_id_generator_(std::make_unique(1)) { + extractor_(extractor) { RTC_DCHECK(injector_); RTC_DCHECK(extractor_); } @@ -101,8 +106,7 @@ VideoQualityAnalyzerInjectionHelper::WrapVideoEncoderFactory( const { return std::make_unique( peer_name, std::move(delegate), bitrate_multiplier, - std::move(stream_required_spatial_index), - encoding_entities_id_generator_.get(), injector_, analyzer_.get()); + std::move(stream_required_spatial_index), injector_, analyzer_.get()); } std::unique_ptr @@ -110,8 +114,7 @@ VideoQualityAnalyzerInjectionHelper::WrapVideoDecoderFactory( absl::string_view peer_name, std::unique_ptr delegate) const { return std::make_unique( - peer_name, std::move(delegate), encoding_entities_id_generator_.get(), - extractor_, analyzer_.get()); + peer_name, std::move(delegate), extractor_, analyzer_.get()); } std::unique_ptr @@ -122,7 +125,8 @@ 
VideoQualityAnalyzerInjectionHelper::CreateFramePreprocessor( test::VideoFrameWriter* writer = MaybeCreateVideoWriter(config.input_dump_file_name, config); if (writer) { - sinks.push_back(std::make_unique(writer)); + sinks.push_back(std::make_unique( + writer, config.input_dump_sampling_modulo)); } if (config.show_on_screen) { sinks.push_back(absl::WrapUnique( @@ -225,7 +229,8 @@ VideoQualityAnalyzerInjectionHelper::PopulateSinks( test::VideoFrameWriter* writer = MaybeCreateVideoWriter(config.output_dump_file_name, config); if (writer) { - sinks.push_back(std::make_unique(writer)); + sinks.push_back(std::make_unique( + writer, config.output_dump_sampling_modulo)); } if (config.show_on_screen) { sinks.push_back(absl::WrapUnique( diff --git a/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.h b/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.h index 111aa3484e..85874cb5bc 100644 --- a/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.h +++ b/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.h @@ -27,7 +27,6 @@ #include "api/video_codecs/video_encoder_factory.h" #include "rtc_base/synchronization/mutex.h" #include "test/pc/e2e/analyzer/video/encoded_image_data_injector.h" -#include "test/pc/e2e/analyzer/video/id_generator.h" #include "test/test_video_capturer.h" #include "test/testsupport/video_frame_writer.h" @@ -50,6 +49,7 @@ class VideoQualityAnalyzerInjectionHelper : public StatsObserverInterface { // The method should be called before the participant is actually added. void RegisterParticipantInCall(absl::string_view peer_name) { analyzer_->RegisterParticipantInCall(peer_name); + extractor_->AddParticipantInCall(); } // Wraps video encoder factory to give video quality analyzer access to frames @@ -74,7 +74,8 @@ class VideoQualityAnalyzerInjectionHelper : public StatsObserverInterface { const VideoConfig& config); // Creates sink, that will allow video quality analyzer to get access to // the rendered frames. If corresponding video track has - // |output_dump_file_name| in its VideoConfig, then video also will be written + // |output_dump_file_name| in its VideoConfig, which was used for + // CreateFramePreprocessor(...), then video also will be written // into that file. std::unique_ptr> CreateVideoSink( absl::string_view peer_name); @@ -130,8 +131,6 @@ class VideoQualityAnalyzerInjectionHelper : public StatsObserverInterface { std::map>>> sinks_ RTC_GUARDED_BY(lock_); - - std::unique_ptr> encoding_entities_id_generator_; }; } // namespace webrtc_pc_e2e diff --git a/test/pc/e2e/analyzer_helper.h b/test/pc/e2e/analyzer_helper.h index 4b0e0c3ac4..9cebd7015e 100644 --- a/test/pc/e2e/analyzer_helper.h +++ b/test/pc/e2e/analyzer_helper.h @@ -15,8 +15,8 @@ #include #include "absl/strings/string_view.h" +#include "api/sequence_checker.h" #include "api/test/track_id_stream_info_map.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/thread_annotations.h" namespace webrtc { diff --git a/test/pc/e2e/g3doc/architecture.md b/test/pc/e2e/g3doc/architecture.md new file mode 100644 index 0000000000..5708054c37 --- /dev/null +++ b/test/pc/e2e/g3doc/architecture.md @@ -0,0 +1,208 @@ + + +# PeerConnection level framework fixture architecture + +## Overview + +The main implementation of +[`webrtc::webrtc_pc_e2e::PeerConnectionE2EQualityTestFixture`][1] is +[`webrtc::webrtc_pc_e2e::PeerConnectionE2EQualityTest`][2]. 
Internally it owns +the following main pieces: + +* [`MediaHelper`][3] - responsible for adding audio and video tracks to the + peers. +* [`VideoQualityAnalyzerInjectionHelper`][4] and + [`SingleProcessEncodedImageDataInjector`][5] - used to inject video quality + analysis and properly match captured and rendered video frames. You can read + more about it in the + [DefaultVideoQualityAnalyzer](default_video_quality_analyzer.md) section. +* [`AudioQualityAnalyzerInterface`][6] - used to measure audio quality metrics. +* [`TestActivitiesExecutor`][7] - used to support the [`ExecuteAt(...)`][8] and + [`ExecuteEvery(...)`][9] APIs of `PeerConnectionE2EQualityTestFixture` to run + any arbitrary action during test execution, synchronized in time with the test + call. +* A vector of [`QualityMetricsReporter`][10] added by the + `PeerConnectionE2EQualityTestFixture` user. +* Two peers, Alice and Bob, represented by instances of the [`TestPeer`][11] + object. + +It also keeps a reference to [`webrtc::TimeController`][12], which is used to +create all required threads, task queues, task queue factories and time-related +objects. + +## TestPeer + +Call participants are represented by instances of the `TestPeer` object. +[`TestPeerFactory`][13] is used to create them. `TestPeer` owns all instances +related to the `webrtc::PeerConnection`, including required listeners and +callbacks. It also provides an API to do offer/answer exchange and ICE candidate +exchange. For this purpose it internally uses an instance of +[`webrtc::PeerConnectionWrapper`][14]. + +The `TestPeer` also owns the `PeerConnection` worker thread. The signaling +thread for all `PeerConnection`s is owned by +`PeerConnectionE2EQualityTestFixture` and shared between all participants in the +call. The network thread is owned by the network layer (it may be either an emulated +network provided by the [Network Emulation Framework][24] or a network thread and +`rtc::NetworkManager` provided by the user) and is provided when the peer is added to the +fixture via the [`AddPeer(...)`][15] API. + +## GetStats API based metrics reporters + +`PeerConnectionE2EQualityTestFixture` gives the user the ability to provide +different `QualityMetricsReporter`s which will listen to the `PeerConnection` +[`GetStats`][16] API. Such reporters are then able to report whatever metrics +the user wants to measure. + +`PeerConnectionE2EQualityTestFixture` itself also uses this mechanism to +measure: + +* Audio quality metrics +* Audio/Video sync metrics (with the help of [`CrossMediaMetricsReporter`][17]) + +The framework also provides a [`StatsBasedNetworkQualityMetricsReporter`][18] to +measure network-related WebRTC metrics and print raw emulated network +statistics for debugging. This reporter should be added by the user via the +[`AddQualityMetricsReporter(...)`][19] API if required. + +Internally, stats gathering is done by [`StatsPoller`][20]. Stats are requested +once per second for each `PeerConnection` and the resulting object is then passed +to each stats listener (a simplified sketch of this loop is shown below). + +## Offer/Answer exchange + +`PeerConnectionE2EQualityTest` provides the ability to test Simulcast and SVC for +video. These features aren't supported in a P2P call and in general require a +Selective Forwarding Unit (SFU), so special logic is applied to mimic SFU +behavior in a P2P call. This logic is located inside [`SignalingInterceptor`][21], +[`QualityAnalyzingVideoEncoder`][22] and [`QualityAnalyzingVideoDecoder`][23] +and consists of SDP modification during the offer/answer exchange and special +handling of video frames from unrelated Simulcast/SVC streams during decoding.
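To make the once-per-second stats-gathering flow from the "GetStats API based metrics reporters" section concrete, here is a minimal C++ sketch of that polling loop. The `StatsSource` and `StatsListener` types below are simplified, hypothetical stand-ins; the real fixture uses `StatsPoller` together with asynchronous `PeerConnection::GetStats()` calls and `RTCStatsReport` objects.

```
// Hypothetical, simplified sketch of the once-per-second stats polling loop.
// StatsSource stands in for a PeerConnection that can be asked for stats, and
// StatsListener for a QualityMetricsReporter.
#include <string>
#include <vector>

struct StatsReport {
  std::string pc_label;  // Which peer the report belongs to.
  // Flattened stats values would live here in a real report.
};

class StatsListener {
 public:
  virtual ~StatsListener() = default;
  virtual void OnStatsReports(const StatsReport& report) = 0;
};

class StatsSource {
 public:
  virtual ~StatsSource() = default;
  // Stands in for an asynchronous PeerConnection::GetStats() request.
  virtual StatsReport GetStats() = 0;
};

// Called once per second for every peer in the call: each peer's stats are
// fetched and fanned out to every registered listener.
void PollStatsOnce(const std::vector<StatsSource*>& peers,
                   const std::vector<StatsListener*>& listeners) {
  for (StatsSource* peer : peers) {
    StatsReport report = peer->GetStats();
    for (StatsListener* listener : listeners) {
      listener->OnStatsReports(report);
    }
  }
}
```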
+ +### Simulcast + +In the case of Simulcast we have a video track which internally contains multiple +video streams, for example low resolution, medium resolution and high +resolution. The WebRTC client doesn't support receiving an offer with multiple +streams in it, because usually an SFU will keep only a single stream for the client. +To bypass this, the framework modifies the offer by converting a single track with three +video streams into three independent video tracks. The sender will then think that +it sends simulcast, while the receiver will think that it receives 3 independent +tracks. + +To achieve such behavior some extra tweaks are required: + +* The MID RTP header extension from the original offer has to be removed. +* The RID RTP header extension from the original offer is replaced with the MID RTP header + extension, so the ID that the sender uses for RID will be parsed as + MID on the receiver. +* The answer has to be modified in the opposite way. + +The described modifications are illustrated in the picture below. + +![VP8 Simulcast offer modification](vp8_simulcast_offer_modification.png "VP8 Simulcast offer modification") + +The exchange will look like this: + +1. Alice creates an offer. +2. Alice sets the offer as local description. +3. The described offer modification is applied. +4. Alice sends the modified offer to Bob. +5. Bob sets the modified offer as remote description. +6. Bob creates an answer. +7. Bob sets the answer as local description. +8. The reverse modifications are applied to the answer. +9. Bob sends the modified answer to Alice. +10. Alice sets the modified answer as remote description. + +This mechanism imposes the constraint that RTX streams are not supported, because they +don't have the RID RTP header extension in their packets. + +### SVC + +In the case of SVC the framework will update the sender's offer before it is even set +as the local description on the sender side. No changes to the answer are then +required. + +`ssrc` is a 32-bit random value that is generated in RTP to denote a specific +source used to send media in an RTP connection. In the original offer the video track +section will look like this: + +``` +m=video 9 UDP/TLS/RTP/SAVPF 98 100 99 101 +... +a=ssrc-group:FID +a=ssrc: cname:... +.... +a=ssrc: cname:... +.... +``` + +To enable SVC for such a video track the framework will add extra `ssrc`s for each +required SVC stream like this: + +``` +a=ssrc-group:FID +a=ssrc: cname:... +.... +a=ssrc: cname:.... +... +a=ssrc-group:FID +a=ssrc: cname:... +.... +a=ssrc: cname:.... +... +a=ssrc-group:FID +a=ssrc: cname:... +.... +a=ssrc: cname:.... +... +``` + +The following line will also be added to the video track section of the offer: + +``` +a=ssrc-group:SIM +``` + +It tells the PeerConnection that this track should be configured as SVC. This approach +relies on the WebRTC Plan B offer structure to achieve SVC behavior and modifies the +offer before setting it as the local description, which violates the WebRTC standard. +It also adds the limitation that on lossy networks only the top resolution streams can +be analyzed, because WebRTC won't try to restore low resolution streams in case +of loss while it still receives the higher stream. + +### Handling in encoder/decoder + +In the encoder, for each encoded video frame the framework will propagate the +information required for the fake SFU to know whether it belongs to an interesting +simulcast stream/spatial layer or whether it should be "discarded". + +On the decoder side, frames that should be "discarded" by the fake SFU will be +automatically decoded into single-pixel images, and only the interesting simulcast +stream/spatial layer will go into the real decoder and then be analyzed.
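The discard handling described in the "Handling in encoder/decoder" section can be sketched with the `EncodedImageDataInjector`/`EncodedImageDataExtractor` interfaces shown earlier in this diff. The snippet below is a simplified illustration, not the actual `QualityAnalyzingVideoEncoder`/`QualityAnalyzingVideoDecoder` code: it omits callback handling and the single-pixel substitute frame, and the helper function names are invented for the example.

```
// Simplified sketch of how the "discard" flag travels with the encoded payload
// from the analyzing encoder to the analyzing decoder. The helper functions
// here are illustrative only; the real logic lives in
// QualityAnalyzingVideoEncoder::OnEncodedImage() and
// QualityAnalyzingVideoDecoder::Decode().
#include <cstdint>

#include "api/video/encoded_image.h"
#include "test/pc/e2e/analyzer/video/encoded_image_data_injector.h"

namespace webrtc {
namespace webrtc_pc_e2e {

// Sender side: tag the encoded image so the "fake SFU" knows whether the
// receiver should analyze it or drop it.
EncodedImage TagForFakeSfu(EncodedImageDataInjector& injector,
                           uint16_t frame_id,
                           bool is_interesting_layer,
                           const EncodedImage& encoded) {
  return injector.InjectData(frame_id, /*discard=*/!is_interesting_layer,
                             encoded);
}

// Receiver side: recover the frame id and decide whether the real decoder
// should see this image at all.
bool ShouldDecode(EncodedImageDataExtractor& extractor,
                  const EncodedImage& received,
                  EncodedImage* to_decode) {
  EncodedImageExtractionResult result = extractor.ExtractData(received);
  if (result.discard) {
    // The fake SFU would have dropped this simulcast stream/spatial layer, so
    // the analyzing decoder emulates it with a trivial frame instead.
    return false;
  }
  *to_decode = result.image;
  return true;
}

}  // namespace webrtc_pc_e2e
}  // namespace webrtc
```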
+ +[1]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=55;drc=484acf27231d931dbc99aedce85bc27e06486b96 +[2]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/peer_connection_quality_test.h;l=44;drc=6cc893ad778a0965e2b7a8e614f3c98aa81bee5b +[3]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/media/media_helper.h;l=27;drc=d46db9f1523ae45909b4a6fdc90a140443068bc6 +[4]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.h;l=38;drc=79020414fd5c71f9ec1f25445ea5f1c8001e1a49 +[5]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.h;l=40;drc=79020414fd5c71f9ec1f25445ea5f1c8001e1a49 +[6]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/audio_quality_analyzer_interface.h;l=23;drc=20f45823e37fd7272aa841831c029c21f29742c2 +[7]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/test_activities_executor.h;l=28;drc=6cc893ad778a0965e2b7a8e614f3c98aa81bee5b +[8]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=439;drc=484acf27231d931dbc99aedce85bc27e06486b96 +[9]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=445;drc=484acf27231d931dbc99aedce85bc27e06486b96 +[10]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=413;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[11]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/test_activities_executor.h;l=28;drc=6cc893ad778a0965e2b7a8e614f3c98aa81bee5b +[12]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/test_activities_executor.h;l=28;drc=6cc893ad778a0965e2b7a8e614f3c98aa81bee5b +[13]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/test_peer_factory.h;l=46;drc=0ef4a2488a466a24ab97b31fdddde55440d451f9 +[14]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/pc/peer_connection_wrapper.h;l=47;drc=5ab79e62f691875a237ea28ca3975ea1f0ed62ec +[15]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=459;drc=484acf27231d931dbc99aedce85bc27e06486b96 +[16]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/peer_connection_interface.h;l=886;drc=9438fb3fff97c803d1ead34c0e4f223db168526f +[17]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/cross_media_metrics_reporter.h;l=29;drc=9d777620236ec76754cfce19f6e82dd18e52d22c +[18]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/cross_media_metrics_reporter.h;l=29;drc=9d777620236ec76754cfce19f6e82dd18e52d22c +[19]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=450;drc=484acf27231d931dbc99aedce85bc27e06486b96 +[20]: 
https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/stats_poller.h;l=52;drc=9b526180c9e9722d3fc7f8689da6ec094fc7fc0a
+[21]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/sdp/sdp_changer.h;l=79;drc=ee558dcca89fd8b105114ededf9e74d948da85e8
+[22]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/quality_analyzing_video_encoder.h;l=54;drc=79020414fd5c71f9ec1f25445ea5f1c8001e1a49
+[23]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/quality_analyzing_video_decoder.h;l=50;drc=79020414fd5c71f9ec1f25445ea5f1c8001e1a49
+[24]: /test/network/g3doc/index.md
diff --git a/test/pc/e2e/g3doc/default_video_quality_analyzer.md b/test/pc/e2e/g3doc/default_video_quality_analyzer.md
new file mode 100644
index 0000000000..532226e350
--- /dev/null
+++ b/test/pc/e2e/g3doc/default_video_quality_analyzer.md
@@ -0,0 +1,196 @@
+
+
+# DefaultVideoQualityAnalyzer
+
+## Audience
+
+This document is for users of [`webrtc::webrtc_pc_e2e::DefaultVideoQualityAnalyzer`][1].
+
+## Overview
+
+`DefaultVideoQualityAnalyzer` implements [`webrtc::webrtc_pc_e2e::VideoQualityAnalyzerInterface`][2] and is the main implementation of a video quality analyzer for WebRTC. To operate correctly it needs to receive each video frame at the following steps:
+
+1. When the frame is captured - the analyzer will generate a unique ID for the frame, which the caller should attach to it.
+2. Immediately before the frame enters the encoder.
+3. Immediately after the frame was encoded.
+4. After the frame was received and immediately before it enters the decoder.
+5. Immediately after the frame was decoded.
+6. When the frame was rendered.
+
+![VideoQualityAnalyzerInterface pipeline](video_quality_analyzer_pipeline.png "VideoQualityAnalyzerInterface pipeline")
+
+The analyzer updates its internal per-frame metrics when the frame is rendered and reports all of them through the [WebRTC perf results reporting system][10] after it is stopped.
+
+To properly inject `DefaultVideoQualityAnalyzer` into the pipeline the following helpers can be used:
+
+### VideoQualityAnalyzerInjectionHelper
+
+[`webrtc::webrtc_pc_e2e::VideoQualityAnalyzerInjectionHelper`][3] provides factory methods for the components that are used to inject `VideoQualityAnalyzerInterface` into the `PeerConnection` pipeline:
+
+* Wrappers for [`webrtc::VideoEncoderFactory`][4] and [`webrtc::VideoDecoderFactory`][5] which properly pass [`webrtc::VideoFrame`][6] and [`webrtc::EncodedImage`][7] objects into the analyzer before and after the real video encoder and decoder.
+* [`webrtc::test::TestVideoCapturer::FramePreprocessor`][8] which is used to pass generated frames into the analyzer on capture and then set the returned frame ID. It also configures dumping of captured frames if required.
+* [`rtc::VideoSinkInterface`][9] which is used to pass frames to the analyzer before they are rendered in order to compute per-frame metrics. It also configures dumping of rendered video if required.
+
+Besides the factories, `VideoQualityAnalyzerInjectionHelper` has methods to orchestrate the `VideoQualityAnalyzerInterface` workflow:
+
+* `Start` - starts the video analyzer, so it is able to receive and analyze video frames.
+* `RegisterParticipantInCall` - adds new participants after the analyzer was started.
+* `Stop` - stops the analyzer, computes all metrics for the frames that were received before and reports them.
+
+`VideoQualityAnalyzerInjectionHelper` also implements [`webrtc::webrtc_pc_e2e::StatsObserverInterface`][11] to propagate WebRTC stats to `VideoQualityAnalyzerInterface`.
+
+### EncodedImageDataInjector and EncodedImageDataExtractor
+
+[`webrtc::webrtc_pc_e2e::EncodedImageDataInjector`][14] and [`webrtc::webrtc_pc_e2e::EncodedImageDataExtractor`][15] are used to inject data into and extract data from `webrtc::EncodedImage` in order to propagate the frame ID and other required information through the network.
+
+By default [`webrtc::webrtc_pc_e2e::SingleProcessEncodedImageDataInjector`][16] is used. It treats the `webrtc::EncodedImage` payload as a black box which remains unchanged from encoder to decoder, and stores the information required for its work in the last 3 bytes of the payload, replacing the original data during injection and restoring it back during extraction. `SingleProcessEncodedImageDataInjector` also requires that the sender and receiver run inside a single process.
+
+![SingleProcessEncodedImageDataInjector](single_process_encoded_image_data_injector.png "SingleProcessEncodedImageDataInjector")
+
+## Exported metrics
+
+Exported metrics are reported to the WebRTC perf results reporting system.
+
+### General
+
+* *`cpu_usage`* - CPU usage excluding the video analyzer.
+
+### Video
+
+* *`psnr`* - peak signal-to-noise ratio: [wikipedia](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio).
+* *`ssim`* - structural similarity: [wikipedia](https://en.wikipedia.org/wiki/Structural_similarity).
+* *`min_psnr`* - minimum value of psnr across all frames of the video stream.
+* *`encode_time`* - time to encode a single frame.
+* *`decode_time`* - time to decode a single frame.
+* *`transport_time`* - time from frame encoded to frame received for decoding.
+* *`receive_to_render_time`* - time from frame received for decoding to frame rendered.
+* *`total_delay_incl_transport`* - time from when the frame was captured on the sending device to when it was displayed on the receiving device.
+* *`encode_frame_rate`* - frame rate after the encoder.
+* *`harmonic_framerate`* - video duration divided by the sum of squared interframe delays. Reflects the render frame rate penalized by freezes.
+* *`time_between_rendered_frames`* - time between frames delivered to the renderer.
+* *`dropped_frames`* - number of frames that were sent but weren't rendered and are known not to be “on the way” from sender to receiver.
+
+A freeze is a pause when no new frames arrive from the decoder for 150ms + the average time between frames, or for 3 * the average time between frames.
+
+* *`time_between_freezes`* - mean time from the end of the previous freeze to the start of the next one.
+* *`freeze_time_ms`* - total freeze time in ms.
+* *`max_skipped`* - maximum number of frames skipped between two neighboring rendered frames.
+* *`pixels_per_frame`* - number of pixels in a frame (width * height).
+* *`target_encode_bitrate`* - target encode bitrate provided by the BWE to the encoder.
+* *`actual_encode_bitrate`* - actual encode bitrate produced by the encoder.
+* *`available_send_bandwidth`* - available send bandwidth estimated by the BWE.
+* *`transmission_bitrate`* - bitrate of media in the emulated network, not counting retransmissions, FEC, and RTCP messages.
+* *`retransmission_bitrate`* - bitrate of retransmission streams only.
+
+### Framework stability
+
+* *`frames_in_flight`* - number of frames that were captured but weren't seen on the receiver.
+
+## Debug metrics
+
+Debug metrics are not reported to the WebRTC perf results reporting system, but they are available through the `DefaultVideoQualityAnalyzer` API.
+
+### [FrameCounters][12]
+
+Frame counters consist of the following counters:
+
+* *`captured`* - count of frames that were passed into the WebRTC pipeline by the video stream source.
+* *`pre_encoded`* - count of frames that reached the video encoder.
+* *`encoded`* - count of encoded images that were produced by the encoder for all requested spatial layers and simulcast streams.
+* *`received`* - count of encoded images received by the decoder for all requested spatial layers and simulcast streams.
+* *`decoded`* - count of frames that were produced by the decoder.
+* *`rendered`* - count of frames that went out of the WebRTC pipeline to the video sink.
+* *`dropped`* - count of frames that were dropped at any point between capturing and rendering.
+
+`DefaultVideoQualityAnalyzer` exports these frame counters:
+
+* *`GlobalCounters`* - frame counters for frames seen at each stage of analysis for all media streams.
+* *`PerStreamCounters`* - frame counters for frames seen at each stage of analysis, separated per individual video track (single media section in the SDP offer).
+
+### [AnalyzerStats][13]
+
+Contains metrics about the internal state of the video analyzer during its work:
+
+* *`comparisons_queue_size`* - size of the analyzer's internal queue used to perform comparisons of captured and rendered frames, measured when a new element is added to the queue.
+* *`comparisons_done`* - number of performed comparisons of two video frames from the captured and rendered streams.
+* *`cpu_overloaded_comparisons_done`* - number of CPU overloaded comparisons. A comparison is CPU overloaded if it is queued when there are too many unprocessed comparisons in the queue. An overloaded comparison doesn't include metrics like SSIM and PSNR that require heavy computation.
+* *`memory_overloaded_comparisons_done`* - number of memory overloaded comparisons. A comparison is memory overloaded if it is queued after its captured frame was already removed due to high memory usage for that video stream.
+* *`frames_in_flight_left_count`* - count of frames in flight in the analyzer, measured when a new comparison is added and after the analyzer was stopped.
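+
+A hedged sketch of reading these debug metrics after a test run is shown below. The getter names (`GetGlobalCounters()`, `GetAnalyzerStats()`) are assumptions based on the structures described in this section; check `default_video_quality_analyzer.h` for the exact API.
+
+```c++
+#include "rtc_base/logging.h"
+#include "test/pc/e2e/analyzer/video/default_video_quality_analyzer.h"
+
+// Hedged sketch only: getter names are assumed, see the analyzer header.
+void LogDebugMetrics(webrtc::webrtc_pc_e2e::DefaultVideoQualityAnalyzer& analyzer) {
+  auto counters = analyzer.GetGlobalCounters();
+  RTC_LOG(INFO) << "captured=" << counters.captured
+                << " rendered=" << counters.rendered
+                << " dropped=" << counters.dropped;
+  auto stats = analyzer.GetAnalyzerStats();
+  RTC_LOG(INFO) << "comparisons_done=" << stats.comparisons_done
+                << " cpu_overloaded_comparisons_done="
+                << stats.cpu_overloaded_comparisons_done;
+}
+```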
+ +[1]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h;l=188;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[2]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/video_quality_analyzer_interface.h;l=56;drc=d7808f1c464a07c8f1e2f97ec7ee92fda998d590 +[3]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/video_quality_analyzer_injection_helper.h;l=39;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[4]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video_codecs/video_encoder_factory.h;l=27;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[5]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video_codecs/video_decoder_factory.h;l=27;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[6]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video/video_frame.h;l=30;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[7]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video/encoded_image.h;l=71;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[8]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/test_video_capturer.h;l=28;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[9]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video/video_sink_interface.h;l=19;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[10]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/testsupport/perf_test.h;drc=0710b401b1e5b500b8e84946fb657656ba1b58b7 +[11]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/stats_observer_interface.h;l=21;drc=9b526180c9e9722d3fc7f8689da6ec094fc7fc0a +[12]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h;l=57;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[13]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/default_video_quality_analyzer.h;l=113;drc=08f46909a8735cf181b99ef2f7e1791c5a7531d2 +[14]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/encoded_image_data_injector.h;l=23;drc=c57089a97a3df454f4356d882cc8df173e8b3ead +[15]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/encoded_image_data_injector.h;l=46;drc=c57089a97a3df454f4356d882cc8df173e8b3ead +[16]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/analyzer/video/single_process_encoded_image_data_injector.h;l=40;drc=c57089a97a3df454f4356d882cc8df173e8b3ead diff --git a/test/pc/e2e/g3doc/g3doc.lua b/test/pc/e2e/g3doc/g3doc.lua new file mode 100644 index 0000000000..981393c826 --- /dev/null +++ b/test/pc/e2e/g3doc/g3doc.lua @@ -0,0 +1,5 @@ +config = super() + +config.freshness.owner = 'titovartem' + +return config diff --git a/test/pc/e2e/g3doc/in_test_psnr_plot.png b/test/pc/e2e/g3doc/in_test_psnr_plot.png new file mode 100644 index 0000000000..3f36725727 Binary files /dev/null and b/test/pc/e2e/g3doc/in_test_psnr_plot.png differ diff --git a/test/pc/e2e/g3doc/index.md b/test/pc/e2e/g3doc/index.md new file mode 100644 index 0000000000..d676476ddc --- /dev/null +++ b/test/pc/e2e/g3doc/index.md @@ -0,0 +1,223 @@ + 
+
+# PeerConnection Level Framework
+
+## API
+
+* [Fixture][1]
+* [Fixture factory function][2]
+
+## Documentation
+
+The PeerConnection level framework is designed for end-to-end media quality testing through the PeerConnection level public API. The framework uses the *Unified plan* API to generate offers/answers during the signaling phase. The framework also wraps the video encoder/decoder and injects them into *`webrtc::PeerConnection`* to measure video quality, performing 1:1 frame matching between captured and rendered frames without any extra requirements on the input video. For audio quality evaluation the standard `GetStats()` API from PeerConnection is used.
+
+The framework API is located in the namespace *`webrtc::webrtc_pc_e2e`*.
+
+### Supported features
+
+* Single or bidirectional media in the call
+* RTC Event log dump per peer
+* AEC dump per peer
+* Compatible with *`webrtc::TimeController`* for both real and simulated time
+* Media
+    * AV sync
+* Video
+    * Any number of video tracks from both caller and callee sides
+    * Input video from
        * Video generator
        * Specified file
        * Any instance of *`webrtc::test::FrameGeneratorInterface`*
+    * Dumping of captured/rendered video into a file
+    * Screen sharing
+    * VP8 simulcast from the caller side
+    * VP9 SVC from the caller side
+    * Choice of video codec (name and parameters), with multiple codecs negotiated to support codec-switching testing
+    * FEC (ULP or Flex)
+    * Forced codec overshooting (to emulate encoder overshoot on some mobile devices, where the hardware encoder can overshoot the target bitrate)
+* Audio
+    * Up to 1 audio track from each of the caller and callee sides
+    * Generated audio
+    * Audio from a specified file
+    * Dumping of captured/rendered audio into a file
+    * Parameterization of `cricket::AudioOptions`
+    * Echo emulation
+* Injection of various WebRTC components into the underlying *`webrtc::PeerConnection`* or *`webrtc::PeerConnectionFactory`*. You can see the full list [here][11].
+* Scheduling of events that can happen during the test, for example:
+    * Changes in network configuration
+    * User statistics measurements
+    * Custom defined actions
+* User defined statistics reporting via the *`webrtc::webrtc_pc_e2e::PeerConnectionE2EQualityTestFixture::QualityMetricsReporter`* interface
+
+## Exported metrics
+
+### General
+
+* *`<peer_name>_connected`* - peer successfully established a connection to the remote side.
+* *`cpu_usage`* - CPU usage excluding the video analyzer.
+* *`audio_ahead_ms`* - used to estimate how much audio and video are out of sync when the two tracks come from the same source. Stats are polled periodically during a call. The metric represents how much earlier audio was played out, on average over the call. If, during a stats poll, video is ahead, then audio_ahead_ms will be equal to 0 for this poll.
+* *`video_ahead_ms`* - used to estimate how much audio and video are out of sync when the two tracks come from the same source. Stats are polled periodically during a call. The metric represents how much earlier video was played out, on average over the call. If, during a stats poll, audio is ahead, then video_ahead_ms will be equal to 0 for this poll.
+
+### Video
+
+See the documentation for [*`DefaultVideoQualityAnalyzer`*](default_video_quality_analyzer.md#exported-metrics).
+
+### Audio
+
+* *`accelerate_rate`* - when playout is sped up, this counter is increased by the difference between the number of samples received and the number of samples played out. If speedup is achieved by removing samples, this will be the count of samples removed. The rate is calculated as the difference between neighboring samples divided by the sample interval.
+* *`expand_rate`* - the total number of samples that are concealed samples over time. A concealed sample is a sample that was replaced with synthesized samples generated locally before being played out. Examples of samples that have to be concealed are samples from lost packets or samples from packets that arrive too late to be played out.
+* *`speech_expand_rate`* - the total number of samples that are concealed samples minus the total number of concealed samples inserted that are "silent" over time. Playing out silent samples results in silence or comfort noise.
+* *`preemptive_rate`* - when playout is slowed down, this counter is increased by the difference between the number of samples received and the number of samples played out. If playout is slowed down by inserting samples, this will be the number of inserted samples. The rate is calculated as the difference between neighboring samples divided by the sample interval.
+* *`average_jitter_buffer_delay_ms`* - average size of the NetEQ jitter buffer.
+* *`preferred_buffer_size_ms`* - preferred size of the NetEQ jitter buffer.
+* *`visqol_mos`* - proxy for audio quality itself.
+* *`asdm_samples`* - measure of how much acceleration/deceleration was in the signal.
+* *`word_error_rate`* - measure of how intelligible the audio was (percent of words that could not be recognized in the output audio).
+
+### Network
+
+* *`bytes_sent`* - represents the total number of payload bytes sent on this PeerConnection, i.e., not including headers or padding.
+* *`packets_sent`* - represents the total number of packets sent over this PeerConnection’s transports.
+* *`average_send_rate`* - average send rate calculated as bytes_sent divided by the test duration.
+* *`payload_bytes_sent`* - total number of bytes sent for all SSRCs plus the total number of RTP header and padding bytes sent for all SSRCs. This does not include the size of transport layer headers such as IP or UDP.
+* *`sent_packets_loss`* - packets_sent minus the corresponding packets_received.
+* *`bytes_received`* - represents the total number of bytes received on this PeerConnection, i.e., not including headers or padding.
+* *`packets_received`* - represents the total number of packets received on this PeerConnection’s transports.
+* *`average_receive_rate`* - average receive rate calculated as bytes_received divided by the test duration.
+* *`payload_bytes_received`* - total number of bytes received for all SSRCs plus the total number of RTP header and padding bytes received for all SSRCs. This does not include the size of transport layer headers such as IP or UDP.
+
+### Framework stability
+
+* *`frames_in_flight`* - number of frames that were captured but weren't seen on the receiver, such that all subsequent frames also weren't seen on the receiver.
+* *`bytes_discarded_no_receiver`* - total number of bytes that were received on network interfaces related to the peer, but whose destination port was closed.
+* *`packets_discarded_no_receiver`* - total number of packets that were received on network interfaces related to the peer, but whose destination port was closed.
+
+## Examples
+
+Examples can be found in:
+
+* [peer_connection_e2e_smoke_test.cc][3]
+* [pc_full_stack_tests.cc][4]
+
+## Stats plotting
+
+### Description
+
+Stats plotting provides the ability to plot statistics collected during the test.
+Right now it is used in the PeerConnection level framework and gives the ability to see how video quality metrics changed during test execution.
+
+### Usage
+
+To make any metric plottable you need to:
+
+1. Collect the metric data with [SamplesStatsCounter][5], which internally stores all intermediate points and the timestamps at which these points were added.
+2. Report the collected data with [`webrtc::test::PrintResult(...)`][6]. By using this method you also specify the name of the plottable metric (see the hedged sketch at the end of this section).
+
+After these steps it will be possible to export your metric for plotting. There are several ways to do this:
+
+1. Use [`webrtc::TestMain::Create()`][7] as the `main` function implementation, for example use [`test/test_main.cc`][8] as the `main` function for your test.
+
+   In this case your binary will have a `--plot` flag, where you can provide a list of metrics that you want to plot, or specify `all` to plot all available metrics.
+
+   If `--plot` is specified, the binary will output the metrics data to `stdout`. Then you need to pipe this `stdout` into the python plotter script [`rtc_tools/metrics_plotter.py`][9], which will plot the data.
+
+   Examples:
+
+   ```shell
+   $ ./out/Default/test_support_unittests \
+     --gtest_filter=PeerConnectionE2EQualityTestSmokeTest.Svc \
+     --nologs \
+     --plot=all \
+     | python rtc_tools/metrics_plotter.py
+   ```
+
+   ```shell
+   $ ./out/Default/test_support_unittests \
+     --gtest_filter=PeerConnectionE2EQualityTestSmokeTest.Svc \
+     --nologs \
+     --plot=psnr,ssim \
+     | python rtc_tools/metrics_plotter.py
+   ```
+
+   Example chart: ![PSNR changes during the test](in_test_psnr_plot.png)
+
+2. Use the API from [`test/testsupport/perf_test.h`][10] directly by invoking `webrtc::test::PrintPlottableResults(const std::vector<std::string>& desired_graphs)` to print plottable metrics to stdout. Then, as in the previous option, you need to pipe the result into the plotter script.
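+
+A minimal hedged sketch of the two collection/reporting steps above is shown here. It assumes the `PrintResult(...)` overload that accepts a `SamplesStatsCounter` directly; the exact signature should be checked in `test/testsupport/perf_test.h`, and the metric name `"processing_time"` is just an illustration.
+
+```c++
+#include "api/numerics/samples_stats_counter.h"
+#include "test/testsupport/perf_test.h"
+
+// Hedged sketch: collect samples over time, then report them under a
+// plottable metric name. The metric/trace names and the exact PrintResult
+// overload are assumptions used to illustrate the two steps above.
+void ReportProcessingTime() {
+  webrtc::SamplesStatsCounter processing_time_ms;
+  // 1. Each AddSample() call stores the value together with the time at
+  //    which it was added, so the metric can later be plotted over time.
+  processing_time_ms.AddSample(16.7);
+  processing_time_ms.AddSample(33.4);
+
+  // 2. Report the collected data; "processing_time" becomes the plottable
+  //    metric name.
+  webrtc::test::PrintResult("processing_time", /*modifier=*/"",
+                            /*trace=*/"my_test", processing_time_ms, "ms",
+                            /*important=*/false);
+}
+```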
+ +[1]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;drc=cbe6e8a2589a925d4c91a2ac2c69201f03de9c39 +[2]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/create_peerconnection_quality_test_fixture.h;drc=cbe6e8a2589a925d4c91a2ac2c69201f03de9c39 +[3]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/pc/e2e/peer_connection_e2e_smoke_test.cc;drc=cbe6e8a2589a925d4c91a2ac2c69201f03de9c39 +[4]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/video/pc_full_stack_tests.cc;drc=cbe6e8a2589a925d4c91a2ac2c69201f03de9c39 +[5]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/numerics/samples_stats_counter.h;drc=cbe6e8a2589a925d4c91a2ac2c69201f03de9c39 +[6]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/testsupport/perf_test.h;l=86;drc=0710b401b1e5b500b8e84946fb657656ba1b58b7 +[7]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/test_main_lib.h;l=23;drc=bcb42f1e4be136c390986a40d9d5cb3ad0de260b +[8]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/test_main.cc;drc=bcb42f1e4be136c390986a40d9d5cb3ad0de260b +[9]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/rtc_tools/metrics_plotter.py;drc=8cc6695652307929edfc877cd64b75cd9ec2d615 +[10]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/test/testsupport/perf_test.h;l=105;drc=0710b401b1e5b500b8e84946fb657656ba1b58b7 +[11]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/test/peerconnection_quality_test_fixture.h;l=272;drc=484acf27231d931dbc99aedce85bc27e06486b96 diff --git a/test/pc/e2e/g3doc/single_process_encoded_image_data_injector.png b/test/pc/e2e/g3doc/single_process_encoded_image_data_injector.png new file mode 100644 index 0000000000..73480bafbe Binary files /dev/null and b/test/pc/e2e/g3doc/single_process_encoded_image_data_injector.png differ diff --git a/test/pc/e2e/g3doc/video_quality_analyzer_pipeline.png b/test/pc/e2e/g3doc/video_quality_analyzer_pipeline.png new file mode 100644 index 0000000000..6cddb91110 Binary files /dev/null and b/test/pc/e2e/g3doc/video_quality_analyzer_pipeline.png differ diff --git a/test/pc/e2e/g3doc/vp8_simulcast_offer_modification.png b/test/pc/e2e/g3doc/vp8_simulcast_offer_modification.png new file mode 100644 index 0000000000..c7eaa04c0e Binary files /dev/null and b/test/pc/e2e/g3doc/vp8_simulcast_offer_modification.png differ diff --git a/test/pc/e2e/media/media_helper.cc b/test/pc/e2e/media/media_helper.cc index d1c27838a6..6b1996adaa 100644 --- a/test/pc/e2e/media/media_helper.cc +++ b/test/pc/e2e/media/media_helper.cc @@ -64,7 +64,7 @@ MediaHelper::MaybeAddVideo(TestPeer* peer) { video_config.content_hint == VideoTrackInterface::ContentHint::kDetailed; rtc::scoped_refptr source = - new rtc::RefCountedObject( + rtc::make_ref_counted( std::move(capturer), is_screencast); out.push_back(source); RTC_LOG(INFO) << "Adding video with video_config.stream_label=" diff --git a/test/pc/e2e/network_quality_metrics_reporter.cc b/test/pc/e2e/network_quality_metrics_reporter.cc index 2df45291d8..513bdc0a5f 100644 --- a/test/pc/e2e/network_quality_metrics_reporter.cc +++ b/test/pc/e2e/network_quality_metrics_reporter.cc @@ -116,10 +116,10 @@ void NetworkQualityMetricsReporter::ReportStats( "average_send_rate", 
network_label, stats->PacketsSent() >= 2 ? stats->AverageSendRate().bytes_per_sec() : 0, "bytesPerSecond"); - ReportResult("bytes_dropped", network_label, stats->BytesDropped().bytes(), - "sizeInBytes"); - ReportResult("packets_dropped", network_label, stats->PacketsDropped(), - "unitless"); + ReportResult("bytes_discarded_no_receiver", network_label, + stats->BytesDropped().bytes(), "sizeInBytes"); + ReportResult("packets_discarded_no_receiver", network_label, + stats->PacketsDropped(), "unitless"); ReportResult("bytes_received", network_label, stats->BytesReceived().bytes(), "sizeInBytes"); ReportResult("packets_received", network_label, stats->PacketsReceived(), diff --git a/test/pc/e2e/peer_configurer.cc b/test/pc/e2e/peer_configurer.cc index b5616b5d68..18570c2c6b 100644 --- a/test/pc/e2e/peer_configurer.cc +++ b/test/pc/e2e/peer_configurer.cc @@ -134,6 +134,15 @@ void ValidateParams( RTC_CHECK(inserted) << "Duplicate video_config.stream_label=" << video_config.stream_label.value(); + if (video_config.input_dump_file_name.has_value()) { + RTC_CHECK_GT(video_config.input_dump_sampling_modulo, 0) + << "video_config.input_dump_sampling_modulo must be greater than 0"; + } + if (video_config.output_dump_file_name.has_value()) { + RTC_CHECK_GT(video_config.output_dump_sampling_modulo, 0) + << "video_config.input_dump_sampling_modulo must be greater than 0"; + } + // TODO(bugs.webrtc.org/4762): remove this check after synchronization of // more than two streams is supported. if (video_config.sync_group.has_value()) { diff --git a/test/pc/e2e/peer_connection_quality_test.cc b/test/pc/e2e/peer_connection_quality_test.cc index a234d2b705..38a9ebf801 100644 --- a/test/pc/e2e/peer_connection_quality_test.cc +++ b/test/pc/e2e/peer_connection_quality_test.cc @@ -670,12 +670,12 @@ void PeerConnectionE2EQualityTest::TearDownCall() { video_source->Stop(); } - alice_->pc()->Close(); - bob_->pc()->Close(); - alice_video_sources_.clear(); bob_video_sources_.clear(); + alice_->Close(); + bob_->Close(); + media_helper_ = nullptr; } diff --git a/test/pc/e2e/sdp/sdp_changer.cc b/test/pc/e2e/sdp/sdp_changer.cc index f2aeb1b92d..b46aea1c5f 100644 --- a/test/pc/e2e/sdp/sdp_changer.cc +++ b/test/pc/e2e/sdp/sdp_changer.cc @@ -34,6 +34,23 @@ std::string CodecRequiredParamsToString( return out.str(); } +std::string SupportedCodecsToString( + rtc::ArrayView supported_codecs) { + rtc::StringBuilder out; + for (const auto& codec : supported_codecs) { + out << codec.name; + if (!codec.parameters.empty()) { + out << "("; + for (const auto& param : codec.parameters) { + out << param.first << "=" << param.second << ";"; + } + out << ")"; + } + out << "; "; + } + return out.str(); +} + } // namespace std::vector FilterVideoCodecCapabilities( @@ -42,16 +59,6 @@ std::vector FilterVideoCodecCapabilities( bool use_ulpfec, bool use_flexfec, rtc::ArrayView supported_codecs) { - RTC_LOG(INFO) << "Peer connection support these codecs:"; - for (const auto& codec : supported_codecs) { - RTC_LOG(INFO) << "Codec: " << codec.name; - if (!codec.parameters.empty()) { - RTC_LOG(INFO) << "Params:"; - for (const auto& param : codec.parameters) { - RTC_LOG(INFO) << " " << param.first << "=" << param.second; - } - } - } std::vector output_codecs; // Find requested codecs among supported and add them to output in the order // they were requested. 
@@ -80,7 +87,8 @@ std::vector FilterVideoCodecCapabilities( RTC_CHECK_GT(output_codecs.size(), size_before) << "Codec with name=" << codec_request.name << " and params {" << CodecRequiredParamsToString(codec_request.required_params) - << "} is unsupported for this peer connection"; + << "} is unsupported for this peer connection. Supported codecs are: " + << SupportedCodecsToString(supported_codecs); } // Add required FEC and RTX codecs to output. @@ -524,9 +532,11 @@ SignalingInterceptor::PatchOffererIceCandidates( context_.simulcast_infos_by_mid.find(candidate->sdp_mid()); if (simulcast_info_it != context_.simulcast_infos_by_mid.end()) { // This is candidate for simulcast section, so it should be transformed - // into candidates for replicated sections - out.push_back(CreateIceCandidate(simulcast_info_it->second->rids[0], 0, - candidate->candidate())); + // into candidates for replicated sections. The sdpMLineIndex is set to + // -1 and ignored if the rid is present. + for (auto rid : simulcast_info_it->second->rids) { + out.push_back(CreateIceCandidate(rid, -1, candidate->candidate())); + } } else { out.push_back(CreateIceCandidate(candidate->sdp_mid(), candidate->sdp_mline_index(), @@ -550,6 +560,9 @@ SignalingInterceptor::PatchAnswererIceCandidates( // section. out.push_back(CreateIceCandidate(simulcast_info_it->second->mid, 0, candidate->candidate())); + } else if (context_.simulcast_infos_by_rid.size()) { + // When using simulcast and bundle, put everything on the first m-line. + out.push_back(CreateIceCandidate("", 0, candidate->candidate())); } else { out.push_back(CreateIceCandidate(candidate->sdp_mid(), candidate->sdp_mline_index(), diff --git a/test/pc/e2e/stats_based_network_quality_metrics_reporter.cc b/test/pc/e2e/stats_based_network_quality_metrics_reporter.cc index e4efe1fd77..eb676a92bd 100644 --- a/test/pc/e2e/stats_based_network_quality_metrics_reporter.cc +++ b/test/pc/e2e/stats_based_network_quality_metrics_reporter.cc @@ -212,10 +212,10 @@ void StatsBasedNetworkQualityMetricsReporter::ReportStats( const NetworkLayerStats& network_layer_stats, int64_t packet_loss, const Timestamp& end_time) { - ReportResult("bytes_dropped", pc_label, + ReportResult("bytes_discarded_no_receiver", pc_label, network_layer_stats.stats->BytesDropped().bytes(), "sizeInBytes"); - ReportResult("packets_dropped", pc_label, + ReportResult("packets_discarded_no_receiver", pc_label, network_layer_stats.stats->PacketsDropped(), "unitless"); ReportResult("payload_bytes_received", pc_label, diff --git a/test/pc/e2e/stats_poller.cc b/test/pc/e2e/stats_poller.cc index e6973e6af1..5f1424cd29 100644 --- a/test/pc/e2e/stats_poller.cc +++ b/test/pc/e2e/stats_poller.cc @@ -31,7 +31,7 @@ void InternalStatsObserver::OnStatsDelivered( StatsPoller::StatsPoller(std::vector observers, std::map peers) { for (auto& peer : peers) { - pollers_.push_back(new rtc::RefCountedObject( + pollers_.push_back(rtc::make_ref_counted( peer.first, peer.second, observers)); } } diff --git a/test/pc/e2e/test_peer.cc b/test/pc/e2e/test_peer.cc index 65d3eb36b8..942bedfba3 100644 --- a/test/pc/e2e/test_peer.cc +++ b/test/pc/e2e/test_peer.cc @@ -21,6 +21,7 @@ namespace webrtc_pc_e2e { bool TestPeer::AddIceCandidates( std::vector> candidates) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; bool success = true; for (auto& candidate : candidates) { if (!pc()->AddIceCandidate(candidate.get())) { @@ -37,6 +38,15 @@ bool TestPeer::AddIceCandidates( return success; } +void TestPeer::Close() { + wrapper_->pc()->Close(); + 
remote_ice_candidates_.clear(); + audio_processing_ = nullptr; + video_sources_.clear(); + wrapper_ = nullptr; + worker_thread_ = nullptr; +} + TestPeer::TestPeer( rtc::scoped_refptr pc_factory, rtc::scoped_refptr pc, diff --git a/test/pc/e2e/test_peer.h b/test/pc/e2e/test_peer.h index 4310cbda1c..d8d5b2d1bb 100644 --- a/test/pc/e2e/test_peer.h +++ b/test/pc/e2e/test_peer.h @@ -30,63 +30,87 @@ class TestPeer final { public: Params* params() const { return params_.get(); } PeerConfigurerImpl::VideoSource ReleaseVideoSource(size_t i) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return std::move(video_sources_[i]); } PeerConnectionFactoryInterface* pc_factory() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->pc_factory(); } - PeerConnectionInterface* pc() { return wrapper_->pc(); } - MockPeerConnectionObserver* observer() { return wrapper_->observer(); } + PeerConnectionInterface* pc() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; + return wrapper_->pc(); + } + MockPeerConnectionObserver* observer() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; + return wrapper_->observer(); + } std::unique_ptr CreateOffer() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->CreateOffer(); } std::unique_ptr CreateAnswer() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->CreateAnswer(); } bool SetLocalDescription(std::unique_ptr desc, std::string* error_out = nullptr) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->SetLocalDescription(std::move(desc), error_out); } bool SetRemoteDescription(std::unique_ptr desc, std::string* error_out = nullptr) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->SetRemoteDescription(std::move(desc), error_out); } rtc::scoped_refptr AddTransceiver( cricket::MediaType media_type, const RtpTransceiverInit& init) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->AddTransceiver(media_type, init); } rtc::scoped_refptr AddTrack( rtc::scoped_refptr track, const std::vector& stream_ids = {}) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->AddTrack(track, stream_ids); } rtc::scoped_refptr CreateDataChannel( const std::string& label) { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->CreateDataChannel(label); } PeerConnectionInterface::SignalingState signaling_state() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->signaling_state(); } - bool IsIceGatheringDone() { return wrapper_->IsIceGatheringDone(); } + bool IsIceGatheringDone() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; + return wrapper_->IsIceGatheringDone(); + } - bool IsIceConnected() { return wrapper_->IsIceConnected(); } + bool IsIceConnected() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; + return wrapper_->IsIceConnected(); + } rtc::scoped_refptr GetStats() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; return wrapper_->GetStats(); } void DetachAecDump() { + RTC_CHECK(wrapper_) << "TestPeer is already closed"; if (audio_processing_) { audio_processing_->DetachAecDump(); } @@ -96,6 +120,10 @@ class TestPeer final { bool AddIceCandidates( std::vector> candidates); + // Closes underlying peer connection and destroys all related objects freeing + // up related resources. 
+ void Close(); + protected: friend class TestPeerFactory; TestPeer(rtc::scoped_refptr pc_factory, diff --git a/test/pc/e2e/test_peer_factory.cc b/test/pc/e2e/test_peer_factory.cc index eceec778df..869b40f703 100644 --- a/test/pc/e2e/test_peer_factory.cc +++ b/test/pc/e2e/test_peer_factory.cc @@ -348,8 +348,10 @@ std::unique_ptr TestPeerFactory::CreateTestPeer( PeerConnectionDependencies pc_deps = CreatePCDependencies( observer.get(), std::move(components->pc_dependencies)); rtc::scoped_refptr peer_connection = - peer_connection_factory->CreatePeerConnection(params->rtc_configuration, - std::move(pc_deps)); + peer_connection_factory + ->CreatePeerConnectionOrError(params->rtc_configuration, + std::move(pc_deps)) + .MoveValue(); peer_connection->SetBitrate(params->bitrate_settings); return absl::WrapUnique(new TestPeer( diff --git a/test/pc/sctp/BUILD.gn b/test/pc/sctp/BUILD.gn index 93ae1bf59c..b47cff2c0f 100644 --- a/test/pc/sctp/BUILD.gn +++ b/test/pc/sctp/BUILD.gn @@ -11,5 +11,5 @@ import("../../../webrtc.gni") rtc_source_set("fake_sctp_transport") { visibility = [ "*" ] sources = [ "fake_sctp_transport.h" ] - deps = [ "../../../media:rtc_data" ] + deps = [ "../../../media:rtc_data_sctp_transport_internal" ] } diff --git a/test/pc/sctp/fake_sctp_transport.h b/test/pc/sctp/fake_sctp_transport.h index 5fdb3bbe42..42b978a900 100644 --- a/test/pc/sctp/fake_sctp_transport.h +++ b/test/pc/sctp/fake_sctp_transport.h @@ -29,7 +29,8 @@ class FakeSctpTransport : public cricket::SctpTransportInternal { } bool OpenStream(int sid) override { return true; } bool ResetStream(int sid) override { return true; } - bool SendData(const cricket::SendDataParams& params, + bool SendData(int sid, + const webrtc::SendDataParams& params, const rtc::CopyOnWriteBuffer& payload, cricket::SendDataResult* result = nullptr) override { return true; @@ -40,8 +41,14 @@ class FakeSctpTransport : public cricket::SctpTransportInternal { int max_message_size() const { return max_message_size_; } absl::optional max_outbound_streams() const { return absl::nullopt; } absl::optional max_inbound_streams() const { return absl::nullopt; } - int local_port() const { return *local_port_; } - int remote_port() const { return *remote_port_; } + int local_port() const { + RTC_DCHECK(local_port_); + return *local_port_; + } + int remote_port() const { + RTC_DCHECK(remote_port_); + return *remote_port_; + } private: absl::optional local_port_; diff --git a/test/peer_scenario/BUILD.gn b/test/peer_scenario/BUILD.gn index 70a7471591..033ef4115a 100644 --- a/test/peer_scenario/BUILD.gn +++ b/test/peer_scenario/BUILD.gn @@ -47,7 +47,9 @@ if (rtc_include_tests) { "../../p2p:rtc_p2p", "../../pc:pc_test_utils", "../../pc:rtc_pc_base", + "../../pc:session_description", "../../rtc_base", + "../../rtc_base:null_socket_server", "../../rtc_base:stringutils", "../logging:log_writer", "../network:emulated_network", diff --git a/test/peer_scenario/peer_scenario.cc b/test/peer_scenario/peer_scenario.cc index c3443aa185..ea959c943a 100644 --- a/test/peer_scenario/peer_scenario.cc +++ b/test/peer_scenario/peer_scenario.cc @@ -77,8 +77,8 @@ SignalingRoute PeerScenario::ConnectSignaling( PeerScenarioClient* callee, std::vector send_link, std::vector ret_link) { - return SignalingRoute(caller, callee, net_.CreateTrafficRoute(send_link), - net_.CreateTrafficRoute(ret_link)); + return SignalingRoute(caller, callee, net_.CreateCrossTrafficRoute(send_link), + net_.CreateCrossTrafficRoute(ret_link)); } void PeerScenario::SimpleConnection( diff --git 
a/test/peer_scenario/peer_scenario_client.cc b/test/peer_scenario/peer_scenario_client.cc index 681a90704f..7f3e126287 100644 --- a/test/peer_scenario/peer_scenario_client.cc +++ b/test/peer_scenario/peer_scenario_client.cc @@ -241,7 +241,9 @@ PeerScenarioClient::PeerScenarioClient( pc_deps.allocator->set_flags(pc_deps.allocator->flags() | cricket::PORTALLOCATOR_DISABLE_TCP); peer_connection_ = - pc_factory_->CreatePeerConnection(config.rtc_config, std::move(pc_deps)); + pc_factory_ + ->CreatePeerConnectionOrError(config.rtc_config, std::move(pc_deps)) + .MoveValue(); if (log_writer_factory_) { peer_connection_->StartRtcEventLog(log_writer_factory_->Create(".rtc.dat"), /*output_period_ms=*/1000); diff --git a/test/peer_scenario/scenario_connection.cc b/test/peer_scenario/scenario_connection.cc index 92082f5097..fefaa00c72 100644 --- a/test/peer_scenario/scenario_connection.cc +++ b/test/peer_scenario/scenario_connection.cc @@ -97,8 +97,7 @@ ScenarioIceConnectionImpl::ScenarioIceConnectionImpl( port_allocator_( new cricket::BasicPortAllocator(manager_->network_manager())), jsep_controller_( - new JsepTransportController(signaling_thread_, - network_thread_, + new JsepTransportController(network_thread_, port_allocator_.get(), /*async_resolver_factory*/ nullptr, CreateJsepConfig())) { @@ -165,8 +164,12 @@ void ScenarioIceConnectionImpl::SetRemoteSdp(SdpType type, const std::string& remote_sdp) { RTC_DCHECK_RUN_ON(signaling_thread_); remote_description_ = webrtc::CreateSessionDescription(type, remote_sdp); - jsep_controller_->SignalIceCandidatesGathered.connect( - this, &ScenarioIceConnectionImpl::OnCandidates); + jsep_controller_->SubscribeIceCandidateGathered( + [this](const std::string& transport, + const std::vector& candidate) { + ScenarioIceConnectionImpl::OnCandidates(transport, candidate); + }); + auto res = jsep_controller_->SetRemoteDescription( remote_description_->GetType(), remote_description_->description()); RTC_CHECK(res.ok()) << res.message(); diff --git a/test/peer_scenario/signaling_route.cc b/test/peer_scenario/signaling_route.cc index 2e0213df16..908d405461 100644 --- a/test/peer_scenario/signaling_route.cc +++ b/test/peer_scenario/signaling_route.cc @@ -41,7 +41,7 @@ struct IceMessage { void StartIceSignalingForRoute(PeerScenarioClient* caller, PeerScenarioClient* callee, - TrafficRoute* send_route) { + CrossTrafficRoute* send_route) { caller->handlers()->on_ice_candidate.push_back( [=](const IceCandidateInterface* candidate) { IceMessage msg(candidate); @@ -56,8 +56,8 @@ void StartIceSignalingForRoute(PeerScenarioClient* caller, void StartSdpNegotiation( PeerScenarioClient* caller, PeerScenarioClient* callee, - TrafficRoute* send_route, - TrafficRoute* ret_route, + CrossTrafficRoute* send_route, + CrossTrafficRoute* ret_route, std::function munge_offer, std::function modify_offer, std::function exchange_finished) { @@ -80,8 +80,8 @@ void StartSdpNegotiation( SignalingRoute::SignalingRoute(PeerScenarioClient* caller, PeerScenarioClient* callee, - TrafficRoute* send_route, - TrafficRoute* ret_route) + CrossTrafficRoute* send_route, + CrossTrafficRoute* ret_route) : caller_(caller), callee_(callee), send_route_(send_route), diff --git a/test/peer_scenario/signaling_route.h b/test/peer_scenario/signaling_route.h index 7434551d3f..021fc4989b 100644 --- a/test/peer_scenario/signaling_route.h +++ b/test/peer_scenario/signaling_route.h @@ -25,8 +25,8 @@ class SignalingRoute { public: SignalingRoute(PeerScenarioClient* caller, PeerScenarioClient* callee, - TrafficRoute* 
send_route, - TrafficRoute* ret_route); + CrossTrafficRoute* send_route, + CrossTrafficRoute* ret_route); void StartIceSignaling(); @@ -57,8 +57,8 @@ class SignalingRoute { private: PeerScenarioClient* const caller_; PeerScenarioClient* const callee_; - TrafficRoute* const send_route_; - TrafficRoute* const ret_route_; + CrossTrafficRoute* const send_route_; + CrossTrafficRoute* const ret_route_; }; } // namespace test diff --git a/test/peer_scenario/tests/BUILD.gn b/test/peer_scenario/tests/BUILD.gn index 0cf7cf3472..a8b9c2563e 100644 --- a/test/peer_scenario/tests/BUILD.gn +++ b/test/peer_scenario/tests/BUILD.gn @@ -25,6 +25,7 @@ if (rtc_include_tests) { "../../../modules/rtp_rtcp:rtp_rtcp", "../../../modules/rtp_rtcp:rtp_rtcp_format", "../../../pc:rtc_pc_base", + "../../../pc:session_description", ] } } diff --git a/test/peer_scenario/tests/remote_estimate_test.cc b/test/peer_scenario/tests/remote_estimate_test.cc index b882ad9dc2..f1d8345fde 100644 --- a/test/peer_scenario/tests/remote_estimate_test.cc +++ b/test/peer_scenario/tests/remote_estimate_test.cc @@ -8,6 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ +#include "modules/rtp_rtcp/source/rtp_util.h" #include "modules/rtp_rtcp/source/rtp_utility.h" #include "pc/media_session.h" #include "pc/session_description.h" @@ -29,7 +30,7 @@ absl::optional GetRtpPacketExtensions( const rtc::ArrayView packet, const RtpHeaderExtensionMap& extension_map) { RtpUtility::RtpHeaderParser rtp_parser(packet.data(), packet.size()); - if (!rtp_parser.RTCP()) { + if (IsRtpPacket(packet)) { RTPHeader header; if (rtp_parser.Parse(&header, &extension_map, true)) { return header.extension; diff --git a/test/peer_scenario/tests/unsignaled_stream_test.cc b/test/peer_scenario/tests/unsignaled_stream_test.cc index 95510a24bd..e0fe02edcf 100644 --- a/test/peer_scenario/tests/unsignaled_stream_test.cc +++ b/test/peer_scenario/tests/unsignaled_stream_test.cc @@ -10,20 +10,44 @@ #include "media/base/stream_params.h" #include "modules/rtp_rtcp/source/byte_io.h" - +#include "modules/rtp_rtcp/source/rtp_util.h" #include "pc/media_session.h" #include "pc/session_description.h" #include "test/field_trial.h" -#include "test/peer_scenario/peer_scenario.h" -#include "test/rtp_header_parser.h" - #include "test/gmock.h" #include "test/gtest.h" +#include "test/peer_scenario/peer_scenario.h" +#include "test/rtp_header_parser.h" namespace webrtc { namespace test { namespace { +enum class MidTestConfiguration { + // Legacy endpoint setup where PT demuxing is used. + kMidNotNegotiated, + // MID is negotiated but missing from packets. PT demuxing is disabled, so + // SSRCs have to be added to the SDP for WebRTC to forward packets correctly. + // Happens when client is spec compliant but the SFU isn't. Popular legacy. + kMidNegotiatedButMissingFromPackets, + // Fully spec-compliant: MID is present so we can safely drop packets with + // unknown MIDs. + kMidNegotiatedAndPresentInPackets, +}; + +// Gives the parameterized test a readable suffix. 
+std::string TestParametersMidTestConfigurationToString( + testing::TestParamInfo info) { + switch (info.param) { + case MidTestConfiguration::kMidNotNegotiated: + return "MidNotNegotiated"; + case MidTestConfiguration::kMidNegotiatedButMissingFromPackets: + return "MidNegotiatedButMissingFromPackets"; + case MidTestConfiguration::kMidNegotiatedAndPresentInPackets: + return "MidNegotiatedAndPresentInPackets"; + } +} + class FrameObserver : public rtc::VideoSinkInterface { public: FrameObserver() : frame_observed_(false) {} @@ -53,19 +77,24 @@ void set_ssrc(SessionDescriptionInterface* offer, size_t index, uint32_t ssrc) { } // namespace -TEST(UnsignaledStreamTest, ReplacesUnsignaledStreamOnCompletedSignaling) { +class UnsignaledStreamTest + : public ::testing::Test, + public ::testing::WithParamInterface {}; + +TEST_P(UnsignaledStreamTest, ReplacesUnsignaledStreamOnCompletedSignaling) { // This test covers a scenario that might occur if a remote client starts - // sending media packets before negotiation has completed. These packets will - // trigger an unsignalled default stream to be created, and connects that to - // a default video sink. - // In some edge cases using unified plan, the default stream is create in a - // different transceiver to where the media SSRC will actually be used. - // This test verifies that the default stream is removed properly, and that - // packets are demuxed and video frames reach the desired sink. + // sending media packets before negotiation has completed. Depending on setup, + // these packets either get dropped or trigger an unsignalled default stream + // to be created, and connects that to a default video sink. + // In some edge cases using Unified Plan and PT demuxing, the default stream + // is create in a different transceiver to where the media SSRC will actually + // be used. This test verifies that the default stream is removed properly, + // and that packets are demuxed and video frames reach the desired sink. + const MidTestConfiguration kMidTestConfiguration = GetParam(); // Defined before PeerScenario so it gets destructed after, to avoid use after // free. - PeerScenario s(*test_info_); + PeerScenario s(*::testing::UnitTest::GetInstance()->current_test_info()); PeerScenarioClient::Config config = PeerScenarioClient::Config(); // Disable encryption so that we can inject a fake early media packet without @@ -93,34 +122,109 @@ TEST(UnsignaledStreamTest, ReplacesUnsignaledStreamOnCompletedSignaling) { std::atomic got_unsignaled_packet(false); // We will capture the media ssrc of the first added stream, and preemptively - // inject a new media packet using a different ssrc. - // This will create "default stream" for the second ssrc and connected it to - // the default video sink (not set in this test). + // inject a new media packet using a different ssrc. What happens depends on + // the test configuration. + // + // MidTestConfiguration::kMidNotNegotiated: + // - MID is not negotiated which means PT-based demuxing is enabled. Because + // the packets have no MID, the second ssrc packet gets forwarded to the + // first m= section. This will create a "default stream" for the second ssrc + // and connect it to the default video sink (not set in this test). The test + // verifies we can recover from this when we later get packets for the first + // ssrc. + // + // MidTestConfiguration::kMidNegotiatedButMissingFromPackets: + // - MID is negotiated wich means PT-based demuxing is disabled. 
Because we + // modify the packets not to contain the MID anyway (simulating a legacy SFU + // that does not negotiate properly) unknown SSRCs are dropped but do not + // otherwise cause any issues. + // + // MidTestConfiguration::kMidNegotiatedAndPresentInPackets: + // - MID is negotiated which means PT-based demuxing is enabled. In this case + // the packets have the MID so they either get forwarded or dropped + // depending on if the MID is known. The spec-compliant way is also the most + // straight-forward one. + uint32_t first_ssrc = 0; uint32_t second_ssrc = 0; + absl::optional mid_header_extension_id = absl::nullopt; signaling.NegotiateSdp( - /* munge_sdp = */ {}, + /* munge_sdp = */ + [&](SessionDescriptionInterface* offer) { + // Obtain the MID header extension ID and if we want the + // MidTestConfiguration::kMidNotNegotiated setup then we remove the MID + // header extension through SDP munging (otherwise SDP is not modified). + for (cricket::ContentInfo& content_info : + offer->description()->contents()) { + std::vector header_extensions = + content_info.media_description()->rtp_header_extensions(); + for (auto it = header_extensions.begin(); + it != header_extensions.end(); ++it) { + if (it->uri == RtpExtension::kMidUri) { + // MID header extension found! + mid_header_extension_id = it->id; + if (kMidTestConfiguration == + MidTestConfiguration::kMidNotNegotiated) { + // Munge away the extension. + header_extensions.erase(it); + } + break; + } + } + content_info.media_description()->set_rtp_header_extensions( + std::move(header_extensions)); + } + ASSERT_TRUE(mid_header_extension_id.has_value()); + }, /* modify_sdp = */ [&](SessionDescriptionInterface* offer) { first_ssrc = get_ssrc(offer, 0); second_ssrc = first_ssrc + 1; send_node->router()->SetWatcher([&](const EmulatedIpPacket& packet) { - if (packet.size() > 1 && packet.cdata()[0] >> 6 == 2 && - !RtpHeaderParser::IsRtcp(packet.data.cdata(), - packet.data.size())) { - if (ByteReader::ReadBigEndian(&(packet.cdata()[8])) == - first_ssrc && - !got_unsignaled_packet) { - rtc::CopyOnWriteBuffer updated_buffer = packet.data; - ByteWriter::WriteBigEndian( - updated_buffer.MutableData() + 8, second_ssrc); - EmulatedIpPacket updated_packet( - packet.from, packet.to, updated_buffer, packet.arrival_time); - send_node->OnPacketReceived(std::move(updated_packet)); - got_unsignaled_packet = true; + if (IsRtpPacket(packet.data) && + ByteReader::ReadBigEndian(&(packet.cdata()[8])) == + first_ssrc && + !got_unsignaled_packet) { + // Parse packet and modify the SSRC to simulate a second m= + // section that has not been negotiated yet. + std::vector extensions; + extensions.emplace_back(RtpExtension::kMidUri, + mid_header_extension_id.value()); + RtpHeaderExtensionMap extensions_map(extensions); + RtpPacket parsed_packet; + parsed_packet.IdentifyExtensions(extensions_map); + ASSERT_TRUE(parsed_packet.Parse(packet.data)); + parsed_packet.SetSsrc(second_ssrc); + // The MID extension is present if and only if it was negotiated. + // If present, we either want to remove it or modify it depending + // on setup. 
+ switch (kMidTestConfiguration) { + case MidTestConfiguration::kMidNotNegotiated: + EXPECT_FALSE(parsed_packet.HasExtension()); + break; + case MidTestConfiguration::kMidNegotiatedButMissingFromPackets: + EXPECT_TRUE(parsed_packet.HasExtension()); + ASSERT_TRUE(parsed_packet.RemoveExtension(RtpMid::kId)); + break; + case MidTestConfiguration::kMidNegotiatedAndPresentInPackets: + EXPECT_TRUE(parsed_packet.HasExtension()); + // The simulated second m= section would have a different MID. + // If we don't modify it here then |second_ssrc| would end up + // being mapped to the first m= section which would cause SSRC + // conflicts if we later add the same SSRC to a second m= + // section. Hidden assumption: first m= section does not use + // MID:1. + ASSERT_TRUE(parsed_packet.SetExtension("1")); + break; } + // Inject the modified packet. + rtc::CopyOnWriteBuffer updated_buffer = parsed_packet.Buffer(); + EmulatedIpPacket updated_packet( + packet.from, packet.to, updated_buffer, packet.arrival_time); + send_node->OnPacketReceived(std::move(updated_packet)); + got_unsignaled_packet = true; } }); }, @@ -153,5 +257,13 @@ TEST(UnsignaledStreamTest, ReplacesUnsignaledStreamOnCompletedSignaling) { EXPECT_TRUE(s.WaitAndProcess(&second_sink.frame_observed_)); } +INSTANTIATE_TEST_SUITE_P( + All, + UnsignaledStreamTest, + ::testing::Values(MidTestConfiguration::kMidNotNegotiated, + MidTestConfiguration::kMidNegotiatedButMissingFromPackets, + MidTestConfiguration::kMidNegotiatedAndPresentInPackets), + TestParametersMidTestConfigurationToString); + } // namespace test } // namespace webrtc diff --git a/test/rtp_file_reader.cc b/test/rtp_file_reader.cc index cc5f6f78a2..a09d5a66e4 100644 --- a/test/rtp_file_reader.cc +++ b/test/rtp_file_reader.cc @@ -17,6 +17,7 @@ #include #include +#include "modules/rtp_rtcp/source/rtp_util.h" #include "modules/rtp_rtcp/source/rtp_utility.h" #include "rtc_base/checks.h" #include "rtc_base/constructor_magic.h" @@ -82,7 +83,7 @@ class InterleavedRtpFileReader : public RtpFileReaderImpl { } bool NextPacket(RtpPacket* packet) override { - assert(file_ != nullptr); + RTC_DCHECK(file_); packet->length = RtpPacket::kMaxPacketBufferSize; uint32_t len = 0; TRY(ReadUint32(&len, file_)); @@ -275,7 +276,7 @@ class PcapReader : public RtpFileReaderImpl { if (result == kResultFail) { break; } else if (result == kResultSuccess && packets_.size() == 1) { - assert(stream_start_ms == 0); + RTC_DCHECK_EQ(stream_start_ms, 0); PacketIterator it = packets_.begin(); stream_start_ms = it->time_offset_ms; it->time_offset_ms = 0; @@ -329,9 +330,9 @@ class PcapReader : public RtpFileReaderImpl { } virtual int NextPcap(uint8_t* data, uint32_t* length, uint32_t* time_ms) { - assert(data); - assert(length); - assert(time_ms); + RTC_DCHECK(data); + RTC_DCHECK(length); + RTC_DCHECK(time_ms); if (next_packet_it_ == packets_.end()) { return -1; @@ -408,7 +409,7 @@ class PcapReader : public RtpFileReaderImpl { uint32_t stream_start_ms, uint32_t number, const std::set& ssrc_filter) { - assert(next_packet_pos); + RTC_DCHECK(next_packet_pos); uint32_t ts_sec; // Timestamp seconds. uint32_t ts_usec; // Timestamp microseconds. 
@@ -434,7 +435,7 @@ class PcapReader : public RtpFileReaderImpl { TRY_PCAP(Read(read_buffer_, marker.payload_length)); RtpUtility::RtpHeaderParser rtp_parser(read_buffer_, marker.payload_length); - if (rtp_parser.RTCP()) { + if (IsRtcpPacket(rtc::MakeArrayView(read_buffer_, marker.payload_length))) { rtp_parser.ParseRtcp(&marker.rtp_header); packets_.push_back(marker); } else { @@ -503,7 +504,7 @@ class PcapReader : public RtpFileReaderImpl { } int ReadXxpIpHeader(RtpPacketMarker* marker) { - assert(marker); + RTC_DCHECK(marker); uint16_t version; uint16_t length; @@ -533,7 +534,7 @@ class PcapReader : public RtpFileReaderImpl { // Skip remaining fields of IP header. uint16_t header_length = (version & 0x0f00) >> (8 - 2); - assert(header_length >= kMinIpHeaderLength); + RTC_DCHECK_GE(header_length, kMinIpHeaderLength); TRY_PCAP(Skip(header_length - kMinIpHeaderLength)); protocol = protocol & 0x00ff; diff --git a/test/rtp_header_parser.cc b/test/rtp_header_parser.cc index 45686acb4c..48e493ddeb 100644 --- a/test/rtp_header_parser.cc +++ b/test/rtp_header_parser.cc @@ -9,46 +9,10 @@ */ #include "test/rtp_header_parser.h" -#include - -#include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/source/rtp_utility.h" -#include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_annotations.h" namespace webrtc { -class RtpHeaderParserImpl : public RtpHeaderParser { - public: - RtpHeaderParserImpl(); - ~RtpHeaderParserImpl() override = default; - - bool Parse(const uint8_t* packet, - size_t length, - RTPHeader* header) const override; - - bool RegisterRtpHeaderExtension(RTPExtensionType type, uint8_t id) override; - bool RegisterRtpHeaderExtension(RtpExtension extension) override; - - bool DeregisterRtpHeaderExtension(RTPExtensionType type) override; - bool DeregisterRtpHeaderExtension(RtpExtension extension) override; - - private: - mutable Mutex mutex_; - RtpHeaderExtensionMap rtp_header_extension_map_ RTC_GUARDED_BY(mutex_); -}; - -std::unique_ptr RtpHeaderParser::CreateForTest() { - return std::make_unique(); -} - -RtpHeaderParserImpl::RtpHeaderParserImpl() {} - -bool RtpHeaderParser::IsRtcp(const uint8_t* packet, size_t length) { - RtpUtility::RtpHeaderParser rtp_parser(packet, length); - return rtp_parser.RTCP(); -} - absl::optional RtpHeaderParser::GetSsrc(const uint8_t* packet, size_t length) { RtpUtility::RtpHeaderParser rtp_parser(packet, length); @@ -59,43 +23,4 @@ absl::optional RtpHeaderParser::GetSsrc(const uint8_t* packet, return absl::nullopt; } -bool RtpHeaderParserImpl::Parse(const uint8_t* packet, - size_t length, - RTPHeader* header) const { - RtpUtility::RtpHeaderParser rtp_parser(packet, length); - *header = RTPHeader(); - - RtpHeaderExtensionMap map; - { - MutexLock lock(&mutex_); - map = rtp_header_extension_map_; - } - - const bool valid_rtpheader = rtp_parser.Parse(header, &map); - if (!valid_rtpheader) { - return false; - } - return true; -} -bool RtpHeaderParserImpl::RegisterRtpHeaderExtension(RtpExtension extension) { - MutexLock lock(&mutex_); - return rtp_header_extension_map_.RegisterByUri(extension.id, extension.uri); -} - -bool RtpHeaderParserImpl::RegisterRtpHeaderExtension(RTPExtensionType type, - uint8_t id) { - MutexLock lock(&mutex_); - return rtp_header_extension_map_.RegisterByType(id, type); -} - -bool RtpHeaderParserImpl::DeregisterRtpHeaderExtension(RtpExtension extension) { - MutexLock lock(&mutex_); - return rtp_header_extension_map_.Deregister( - rtp_header_extension_map_.GetType(extension.id)); -} - -bool 
RtpHeaderParserImpl::DeregisterRtpHeaderExtension(RTPExtensionType type) { - MutexLock lock(&mutex_); - return rtp_header_extension_map_.Deregister(type) == 0; -} } // namespace webrtc diff --git a/test/rtp_header_parser.h b/test/rtp_header_parser.h index 851ccf3bc2..f6ed74c043 100644 --- a/test/rtp_header_parser.h +++ b/test/rtp_header_parser.h @@ -10,44 +10,16 @@ #ifndef TEST_RTP_HEADER_PARSER_H_ #define TEST_RTP_HEADER_PARSER_H_ -#include +#include +#include -#include "api/rtp_parameters.h" -#include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" +#include "absl/types/optional.h" namespace webrtc { -struct RTPHeader; - class RtpHeaderParser { public: - static std::unique_ptr CreateForTest(); - virtual ~RtpHeaderParser() {} - - // Returns true if the packet is an RTCP packet, false otherwise. - static bool IsRtcp(const uint8_t* packet, size_t length); static absl::optional GetSsrc(const uint8_t* packet, size_t length); - - // Parses the packet and stores the parsed packet in |header|. Returns true on - // success, false otherwise. - // This method is thread-safe in the sense that it can parse multiple packets - // at once. - virtual bool Parse(const uint8_t* packet, - size_t length, - RTPHeader* header) const = 0; - - // Registers an RTP header extension and binds it to |id|. - virtual bool RegisterRtpHeaderExtension(RTPExtensionType type, - uint8_t id) = 0; - - // Registers an RTP header extension. - virtual bool RegisterRtpHeaderExtension(RtpExtension extension) = 0; - - // De-registers an RTP header extension. - virtual bool DeregisterRtpHeaderExtension(RTPExtensionType type) = 0; - - // De-registers an RTP header extension. - virtual bool DeregisterRtpHeaderExtension(RtpExtension extension) = 0; }; } // namespace webrtc #endif // TEST_RTP_HEADER_PARSER_H_ diff --git a/test/rtp_rtcp_observer.h b/test/rtp_rtcp_observer.h index 036f5cdc20..f17560f021 100644 --- a/test/rtp_rtcp_observer.h +++ b/test/rtp_rtcp_observer.h @@ -15,14 +15,15 @@ #include #include +#include "api/array_view.h" #include "api/test/simulated_network.h" #include "call/simulated_packet_receiver.h" #include "call/video_send_stream.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "rtc_base/event.h" #include "system_wrappers/include/field_trial.h" #include "test/direct_transport.h" #include "test/gtest.h" -#include "test/rtp_header_parser.h" namespace { const int kShortTimeoutMs = 500; @@ -98,7 +99,7 @@ class PacketTransport : public test::DirectTransport { bool SendRtp(const uint8_t* packet, size_t length, const PacketOptions& options) override { - EXPECT_FALSE(RtpHeaderParser::IsRtcp(packet, length)); + EXPECT_TRUE(IsRtpPacket(rtc::MakeArrayView(packet, length))); RtpRtcpObserver::Action action; { if (transport_type_ == kSender) { @@ -118,7 +119,7 @@ class PacketTransport : public test::DirectTransport { } bool SendRtcp(const uint8_t* packet, size_t length) override { - EXPECT_TRUE(RtpHeaderParser::IsRtcp(packet, length)); + EXPECT_TRUE(IsRtcpPacket(rtc::MakeArrayView(packet, length))); RtpRtcpObserver::Action action; { if (transport_type_ == kSender) { diff --git a/test/scenario/BUILD.gn b/test/scenario/BUILD.gn index f5c22fcafb..a64f8317a0 100644 --- a/test/scenario/BUILD.gn +++ b/test/scenario/BUILD.gn @@ -21,7 +21,7 @@ rtc_library("column_printer") { ] } -if (is_ios || rtc_include_tests) { +if (rtc_include_tests && !build_with_chromium) { scenario_resources = [ "../../resources/difficult_photo_1850_1110.yuv", "../../resources/photo_1850_1110.yuv", @@ -29,21 +29,20 @@ if (is_ios || rtc_include_tests) { 
"../../resources/web_screenshot_1850_1110.yuv", ] scenario_unittest_resources = [ "../../resources/foreman_cif.yuv" ] -} -if (is_ios) { - bundle_data("scenario_resources_bundle_data") { - testonly = true - sources = scenario_resources - outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] - } - bundle_data("scenario_unittest_resources_bundle_data") { - testonly = true - sources = scenario_unittest_resources - outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + if (is_ios) { + bundle_data("scenario_resources_bundle_data") { + testonly = true + sources = scenario_resources + outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + } + bundle_data("scenario_unittest_resources_bundle_data") { + testonly = true + sources = scenario_unittest_resources + outputs = [ "{{bundle_resources_dir}}/{{source_file_part}}" ] + } } -} -if (rtc_include_tests) { + rtc_library("scenario") { testonly = true sources = [ @@ -82,6 +81,7 @@ if (rtc_include_tests) { "../../api:libjingle_peerconnection_api", "../../api:rtc_event_log_output_file", "../../api:rtp_parameters", + "../../api:sequence_checker", "../../api:time_controller", "../../api:time_controller", "../../api:transport_api", @@ -132,9 +132,10 @@ if (rtc_include_tests) { "../../rtc_base:rtc_stats_counters", "../../rtc_base:rtc_task_queue", "../../rtc_base:safe_minmax", + "../../rtc_base:socket_address", "../../rtc_base:task_queue_for_test", + "../../rtc_base:threading", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/task_utils:repeating_task", "../../system_wrappers", "../../system_wrappers:field_trial", @@ -174,6 +175,8 @@ if (rtc_include_tests) { ] deps = [ ":scenario", + "../../api/test/network_emulation", + "../../api/test/network_emulation:create_cross_traffic", "../../logging:mocks", "../../rtc_base:checks", "../../rtc_base:rtc_base_approved", diff --git a/test/scenario/audio_stream.cc b/test/scenario/audio_stream.cc index f3cb8320aa..63f78c8f71 100644 --- a/test/scenario/audio_stream.cc +++ b/test/scenario/audio_stream.cc @@ -185,7 +185,6 @@ ReceiveAudioStream::ReceiveAudioStream( recv_config.rtp.extensions = {{RtpExtension::kTransportSequenceNumberUri, kTransportSequenceNumberExtensionId}}; } - receiver_->AddExtensions(recv_config.rtp.extensions); recv_config.decoder_factory = decoder_factory; recv_config.decoder_map = { {CallTest::kAudioSendPayloadType, {"opus", 48000, 2}}}; diff --git a/test/scenario/call_client.cc b/test/scenario/call_client.cc index f7cd47c36e..be8d39f2a5 100644 --- a/test/scenario/call_client.cc +++ b/test/scenario/call_client.cc @@ -17,6 +17,8 @@ #include "api/rtc_event_log/rtc_event_log_factory.h" #include "api/transport/network_types.h" #include "modules/audio_mixer/audio_mixer_impl.h" +#include "modules/rtp_rtcp/source/rtp_util.h" +#include "test/rtp_header_parser.h" namespace webrtc { namespace test { @@ -213,7 +215,6 @@ CallClient::CallClient( clock_(time_controller->GetClock()), log_writer_factory_(std::move(log_writer_factory)), network_controller_factory_(log_writer_factory_.get(), config.transport), - header_parser_(RtpHeaderParser::CreateForTest()), task_queue_(time_controller->GetTaskQueueFactory()->CreateTaskQueue( "CallClient", TaskQueueFactory::Priority::NORMAL)) { @@ -293,7 +294,7 @@ void CallClient::UpdateBitrateConstraints( void CallClient::OnPacketReceived(EmulatedIpPacket packet) { MediaType media_type = MediaType::ANY; - if (!RtpHeaderParser::IsRtcp(packet.cdata(), packet.data.size())) { + if 
(IsRtpPacket(packet.data)) { auto ssrc = RtpHeaderParser::GetSsrc(packet.cdata(), packet.data.size()); RTC_CHECK(ssrc.has_value()); media_type = ssrc_media_types_[*ssrc]; @@ -338,11 +339,6 @@ uint32_t CallClient::GetNextRtxSsrc() { return kSendRtxSsrcs[next_rtx_ssrc_index_++]; } -void CallClient::AddExtensions(std::vector extensions) { - for (const auto& extension : extensions) - header_parser_->RegisterRtpHeaderExtension(extension); -} - void CallClient::SendTask(std::function task) { task_queue_.SendTask(std::move(task), RTC_FROM_HERE); } diff --git a/test/scenario/call_client.h b/test/scenario/call_client.h index 27ec9fa39c..08b0131350 100644 --- a/test/scenario/call_client.h +++ b/test/scenario/call_client.h @@ -26,7 +26,6 @@ #include "rtc_base/task_queue_for_test.h" #include "test/logging/log_writer.h" #include "test/network/network_emulation.h" -#include "test/rtp_header_parser.h" #include "test/scenario/column_printer.h" #include "test/scenario/network_node.h" #include "test/scenario/scenario_config.h" @@ -137,7 +136,6 @@ class CallClient : public EmulatedNetworkReceiverInterface { uint32_t GetNextAudioSsrc(); uint32_t GetNextAudioLocalSsrc(); uint32_t GetNextRtxSsrc(); - void AddExtensions(std::vector extensions); int16_t Bind(EmulatedEndpoint* endpoint); void UnBind(); @@ -149,7 +147,6 @@ class CallClient : public EmulatedNetworkReceiverInterface { CallClientFakeAudio fake_audio_setup_; std::unique_ptr call_; std::unique_ptr transport_; - std::unique_ptr const header_parser_; std::vector> endpoints_; int next_video_ssrc_index_ = 0; diff --git a/test/scenario/scenario.cc b/test/scenario/scenario.cc index c1c664a754..239aad9dfe 100644 --- a/test/scenario/scenario.cc +++ b/test/scenario/scenario.cc @@ -198,7 +198,7 @@ SimulationNode* Scenario::CreateMutableSimulationNode( void Scenario::TriggerPacketBurst(std::vector over_nodes, size_t num_packets, size_t packet_size) { - network_manager_.CreateTrafficRoute(over_nodes) + network_manager_.CreateCrossTrafficRoute(over_nodes) ->TriggerPacketBurst(num_packets, packet_size); } @@ -206,7 +206,7 @@ void Scenario::NetworkDelayedAction( std::vector over_nodes, size_t packet_size, std::function action) { - network_manager_.CreateTrafficRoute(over_nodes) + network_manager_.CreateCrossTrafficRoute(over_nodes) ->NetworkDelayedAction(packet_size, action); } diff --git a/test/scenario/scenario_unittest.cc b/test/scenario/scenario_unittest.cc index 177ac27373..6861151a2d 100644 --- a/test/scenario/scenario_unittest.cc +++ b/test/scenario/scenario_unittest.cc @@ -11,6 +11,8 @@ #include +#include "api/test/network_emulation/create_cross_traffic.h" +#include "api/test/network_emulation/cross_traffic.h" #include "test/field_trial.h" #include "test/gtest.h" #include "test/logging/memory_log_writer.h" @@ -44,8 +46,8 @@ TEST(ScenarioTest, StartsAndStopsWithoutErrors) { s.CreateAudioStream(route->reverse(), audio_stream_config); RandomWalkConfig cross_traffic_config; - s.net()->CreateRandomWalkCrossTraffic( - s.net()->CreateTrafficRoute({alice_net}), cross_traffic_config); + s.net()->StartCrossTraffic(CreateRandomWalkCrossTraffic( + s.net()->CreateCrossTrafficRoute({alice_net}), cross_traffic_config)); s.NetworkDelayedAction({alice_net, bob_net}, 100, [&packet_received] { packet_received = true; }); @@ -180,7 +182,11 @@ TEST(ScenarioTest, s.RunFor(TimeDelta::Seconds(10)); // Make sure retransmissions have happened. 
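The scenario tests in this patch repeatedly switch from reading GetStats() inline to marshalling the call onto the owning CallClient's task queue. Condensed, the recurring pattern looks like the sketch below (names taken from the surrounding test code; illustrative only):

// GetStats() must run in the client's runtime environment, so the test
// blocks on SendTask() and copies the stats object out before inspecting it.
VideoSendStream::Stats stats;
alice->SendTask([&]() { stats = video->send()->GetStats(); });
int retransmit_packets = 0;
for (const auto& substream : stats.substreams)
  retransmit_packets += substream.second.rtp_stats.retransmitted.packets;
EXPECT_GT(retransmit_packets, 0);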
int retransmit_packets = 0; - for (const auto& substream : video->send()->GetStats().substreams) { + + VideoSendStream::Stats stats; + alice->SendTask([&]() { stats = video->send()->GetStats(); }); + + for (const auto& substream : stats.substreams) { retransmit_packets += substream.second.rtp_stats.retransmitted.packets; } EXPECT_GT(retransmit_packets, 0); diff --git a/test/scenario/stats_collection_unittest.cc b/test/scenario/stats_collection_unittest.cc index 17f0e3a656..96b2830c76 100644 --- a/test/scenario/stats_collection_unittest.cc +++ b/test/scenario/stats_collection_unittest.cc @@ -33,8 +33,14 @@ void CreateAnalyzedStream(Scenario* s, auto* audio = s->CreateAudioStream(route->forward(), AudioStreamConfig()); s->Every(TimeDelta::Seconds(1), [=] { collectors->call.AddStats(caller->GetStats()); - collectors->video_send.AddStats(video->send()->GetStats(), s->Now()); - collectors->audio_receive.AddStats(audio->receive()->GetStats()); + + VideoSendStream::Stats send_stats; + caller->SendTask([&]() { send_stats = video->send()->GetStats(); }); + collectors->video_send.AddStats(send_stats, s->Now()); + + AudioReceiveStream::Stats receive_stats; + caller->SendTask([&]() { receive_stats = audio->receive()->GetStats(); }); + collectors->audio_receive.AddStats(receive_stats); // Querying the video stats from within the expected runtime environment // (i.e. the TQ that belongs to the CallClient, not the Scenario TQ that @@ -87,7 +93,7 @@ TEST(ScenarioAnalyzerTest, PsnrIsLowWhenNetworkIsBad) { EXPECT_NEAR(stats.call.stats().target_rate.Mean().kbps(), 75, 50); EXPECT_NEAR(stats.video_send.stats().media_bitrate.Mean().kbps(), 100, 50); EXPECT_NEAR(stats.video_receive.stats().resolution.Mean(), 180, 10); - EXPECT_NEAR(stats.audio_receive.stats().jitter_buffer.Mean().ms(), 250, 150); + EXPECT_NEAR(stats.audio_receive.stats().jitter_buffer.Mean().ms(), 250, 200); } TEST(ScenarioAnalyzerTest, CountsCapturedButNotRendered) { diff --git a/test/scenario/video_frame_matcher.h b/test/scenario/video_frame_matcher.h index f7f62436ac..a3aa85447d 100644 --- a/test/scenario/video_frame_matcher.h +++ b/test/scenario/video_frame_matcher.h @@ -52,7 +52,7 @@ class VideoFrameMatcher { rtc::scoped_refptr thumb; int repeat_count = 0; }; - using DecodedFrame = rtc::RefCountedObject; + using DecodedFrame = rtc::FinalRefCountedObject; struct CapturedFrame { int id; Timestamp capture_time = Timestamp::PlusInfinity(); diff --git a/test/scenario/video_stream.cc b/test/scenario/video_stream.cc index 5525a9d203..96f6f5bc59 100644 --- a/test/scenario/video_stream.cc +++ b/test/scenario/video_stream.cc @@ -175,8 +175,8 @@ CreateVp9SpecificSettings(VideoStreamConfig video_config) { vp9.automaticResizeOn = conf.single.automatic_scaling; vp9.denoisingOn = conf.single.denoising; } - return new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9); + return rtc::make_ref_counted( + vp9); } rtc::scoped_refptr @@ -192,8 +192,8 @@ CreateVp8SpecificSettings(VideoStreamConfig config) { vp8_settings.automaticResizeOn = config.encoder.single.automatic_scaling; vp8_settings.denoisingOn = config.encoder.single.denoising; } - return new rtc::RefCountedObject< - VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8_settings); + return rtc::make_ref_counted( + vp8_settings); } rtc::scoped_refptr @@ -205,8 +205,8 @@ CreateH264SpecificSettings(VideoStreamConfig config) { h264_settings.frameDroppingOn = config.encoder.frame_dropping; h264_settings.keyFrameInterval = config.encoder.key_frame_interval.value_or(0); - return new 
rtc::RefCountedObject< - VideoEncoderConfig::H264EncoderSpecificSettings>(h264_settings); + return rtc::make_ref_counted( + h264_settings); } rtc::scoped_refptr @@ -248,11 +248,11 @@ VideoEncoderConfig CreateVideoEncoderConfig(VideoStreamConfig config) { bool screenshare = config.encoder.content_type == VideoStreamConfig::Encoder::ContentType::kScreen; encoder_config.video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( cricket_codec, kDefaultMaxQp, screenshare, screenshare); } else { encoder_config.video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); } // TODO(srte): Base this on encoder capabilities. @@ -571,10 +571,10 @@ ReceiveVideoStream::ReceiveVideoStream(CallClient* receiver, RTC_DCHECK(num_streams == 1); FlexfecReceiveStream::Config flexfec(feedback_transport); flexfec.payload_type = CallTest::kFlexfecPayloadType; - flexfec.remote_ssrc = CallTest::kFlexfecSendSsrc; + flexfec.rtp.remote_ssrc = CallTest::kFlexfecSendSsrc; flexfec.protected_media_ssrcs = send_stream->rtx_ssrcs_; - flexfec.local_ssrc = recv_config.rtp.local_ssrc; - receiver_->ssrc_media_types_[flexfec.remote_ssrc] = MediaType::VIDEO; + flexfec.rtp.local_ssrc = recv_config.rtp.local_ssrc; + receiver_->ssrc_media_types_[flexfec.rtp.remote_ssrc] = MediaType::VIDEO; receiver_->SendTask([this, &flexfec] { flecfec_stream_ = receiver_->call_->CreateFlexfecReceiveStream(flexfec); diff --git a/test/scenario/video_stream_unittest.cc b/test/scenario/video_stream_unittest.cc index 52be3f82ff..c1649a39b3 100644 --- a/test/scenario/video_stream_unittest.cc +++ b/test/scenario/video_stream_unittest.cc @@ -9,6 +9,8 @@ */ #include +#include "api/test/network_emulation/create_cross_traffic.h" +#include "api/test/network_emulation/cross_traffic.h" #include "test/field_trial.h" #include "test/gtest.h" #include "test/scenario/scenario.h" @@ -128,7 +130,9 @@ TEST(VideoStreamTest, SendsNacksOnLoss) { auto video = s.CreateVideoStream(route->forward(), VideoStreamConfig()); s.RunFor(TimeDelta::Seconds(1)); int retransmit_packets = 0; - for (const auto& substream : video->send()->GetStats().substreams) { + VideoSendStream::Stats stats; + route->first()->SendTask([&]() { stats = video->send()->GetStats(); }); + for (const auto& substream : stats.substreams) { retransmit_packets += substream.second.rtp_stats.retransmitted.packets; } EXPECT_GT(retransmit_packets, 0); @@ -150,7 +154,8 @@ TEST(VideoStreamTest, SendsFecWithUlpFec) { c->stream.use_ulpfec = true; }); s.RunFor(TimeDelta::Seconds(5)); - VideoSendStream::Stats video_stats = video->send()->GetStats(); + VideoSendStream::Stats video_stats; + route->first()->SendTask([&]() { video_stats = video->send()->GetStats(); }); EXPECT_GT(video_stats.substreams.begin()->second.rtp_stats.fec.packets, 0u); } TEST(VideoStreamTest, SendsFecWithFlexFec) { @@ -167,7 +172,8 @@ TEST(VideoStreamTest, SendsFecWithFlexFec) { c->stream.use_flexfec = true; }); s.RunFor(TimeDelta::Seconds(5)); - VideoSendStream::Stats video_stats = video->send()->GetStats(); + VideoSendStream::Stats video_stats; + route->first()->SendTask([&]() { video_stats = video->send()->GetStats(); }); EXPECT_GT(video_stats.substreams.begin()->second.rtp_stats.fec.packets, 0u); } @@ -217,8 +223,9 @@ TEST(VideoStreamTest, ResolutionAdaptsToAvailableBandwidth) { // Trigger cross traffic, run until we have seen 3 consecutive // seconds with no VGA frames due to reduced available bandwidth. 
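The video_stream.cc hunks above replace `new rtc::RefCountedObject<T>(...)` with `rtc::make_ref_counted<T>(...)`. A minimal sketch of the resulting call shape, reusing a type from the surrounding code (illustrative; the assignment target is assumed from the VideoEncoderConfig API):

// make_ref_counted constructs the ref-counted wrapper internally and returns
// an rtc::scoped_refptr<T>, so RefCountedObject no longer appears at call
// sites.
rtc::scoped_refptr<VideoEncoderConfig::Vp8EncoderSpecificSettings> settings =
    rtc::make_ref_counted<VideoEncoderConfig::Vp8EncoderSpecificSettings>(
        vp8_settings);
encoder_config.encoder_specific_settings = settings;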
- auto cross_traffic = - s.net()->StartFakeTcpCrossTraffic(send_net, ret_net, FakeTcpConfig()); + auto cross_traffic = s.net()->StartCrossTraffic(CreateFakeTcpCrossTraffic( + s.net()->CreateRoute(send_net), s.net()->CreateRoute(ret_net), + FakeTcpConfig())); int num_seconds_without_vga = 0; int num_iterations = 0; diff --git a/test/testsupport/file_utils.cc b/test/testsupport/file_utils.cc index 0b4ffa446c..1f829d320b 100644 --- a/test/testsupport/file_utils.cc +++ b/test/testsupport/file_utils.cc @@ -107,7 +107,7 @@ std::string TempFilename(const std::string& dir, const std::string& prefix) { if (::GetTempFileNameW(rtc::ToUtf16(dir).c_str(), rtc::ToUtf16(prefix).c_str(), 0, filename) != 0) return rtc::ToUtf8(filename); - assert(false); + RTC_NOTREACHED(); return ""; #else int len = dir.size() + prefix.size() + 2 + 6; @@ -116,7 +116,7 @@ std::string TempFilename(const std::string& dir, const std::string& prefix) { snprintf(tempname.get(), len, "%s/%sXXXXXX", dir.c_str(), prefix.c_str()); int fd = ::mkstemp(tempname.get()); if (fd == -1) { - assert(false); + RTC_NOTREACHED(); return ""; } else { ::close(fd); diff --git a/test/testsupport/ivf_video_frame_generator.h b/test/testsupport/ivf_video_frame_generator.h index 32ba21ed26..8ee9c03417 100644 --- a/test/testsupport/ivf_video_frame_generator.h +++ b/test/testsupport/ivf_video_frame_generator.h @@ -15,6 +15,7 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/test/frame_generator_interface.h" #include "api/video/video_codec_type.h" #include "api/video/video_frame.h" @@ -22,7 +23,6 @@ #include "modules/video_coding/utility/ivf_file_reader.h" #include "rtc_base/event.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" namespace webrtc { namespace test { diff --git a/test/testsupport/ivf_video_frame_generator_unittest.cc b/test/testsupport/ivf_video_frame_generator_unittest.cc index bea9cd2489..126f7203b8 100644 --- a/test/testsupport/ivf_video_frame_generator_unittest.cc +++ b/test/testsupport/ivf_video_frame_generator_unittest.cc @@ -48,7 +48,7 @@ constexpr int kMaxFramerate = 30; constexpr int kMaxFrameEncodeWaitTimeoutMs = 2000; static const VideoEncoder::Capabilities kCapabilities(false); -#if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS) +#if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS) || defined(WEBRTC_ARCH_ARM64) constexpr double kExpectedMinPsnr = 35; #else constexpr double kExpectedMinPsnr = 39; diff --git a/test/testsupport/perf_result_reporter.cc b/test/testsupport/perf_result_reporter.cc index e4c98e7446..158f1cd768 100644 --- a/test/testsupport/perf_result_reporter.cc +++ b/test/testsupport/perf_result_reporter.cc @@ -12,6 +12,8 @@ #include +#include "absl/strings/string_view.h" + namespace { // These characters mess with either the stdout parsing or the dashboard itself. 
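The perf_result_reporter and perf_test hunks that follow move public parameters from const std::string& to absl::string_view. One detail worth noting, sketched here under the assumption of a std::map keyed on std::string (as metric_map_ is): a string_view has to be materialized into an owning std::string before it can be inserted or looked up.

#include <map>
#include <string>

#include "absl/strings/string_view.h"

struct MetricInfoSketch {
  int unit = 0;
};

// Callers may pass string literals or std::string with no copy at the call
// site; the single copy happens where the map needs an owning key.
void RegisterSketch(std::map<std::string, MetricInfoSketch>& metrics,
                    absl::string_view name) {
  metrics.insert({std::string(name), MetricInfoSketch()});
}

bool IsRegisteredSketch(const std::map<std::string, MetricInfoSketch>& metrics,
                        absl::string_view name) {
  return metrics.count(std::string(name)) > 0;
}

A transparent comparator (std::less<>) could avoid the lookup copy, but the patch keeps the simpler std::string conversion.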
@@ -21,7 +23,7 @@ const std::vector& InvalidCharacters() { return kInvalidCharacters; } -void CheckForInvalidCharacters(const std::string& str) { +void CheckForInvalidCharacters(absl::string_view str) { for (const auto& invalid : InvalidCharacters()) { RTC_CHECK(str.find(invalid) == std::string::npos) << "Given invalid character for perf names '" << invalid << "'"; @@ -76,8 +78,8 @@ std::string UnitToString(Unit unit) { } // namespace -PerfResultReporter::PerfResultReporter(const std::string& metric_basename, - const std::string& story_name) +PerfResultReporter::PerfResultReporter(absl::string_view metric_basename, + absl::string_view story_name) : metric_basename_(metric_basename), story_name_(story_name) { CheckForInvalidCharacters(metric_basename_); CheckForInvalidCharacters(story_name_); @@ -85,19 +87,20 @@ PerfResultReporter::PerfResultReporter(const std::string& metric_basename, PerfResultReporter::~PerfResultReporter() = default; -void PerfResultReporter::RegisterMetric(const std::string& metric_suffix, +void PerfResultReporter::RegisterMetric(absl::string_view metric_suffix, Unit unit) { RegisterMetric(metric_suffix, unit, ImproveDirection::kNone); } -void PerfResultReporter::RegisterMetric(const std::string& metric_suffix, +void PerfResultReporter::RegisterMetric(absl::string_view metric_suffix, Unit unit, ImproveDirection improve_direction) { CheckForInvalidCharacters(metric_suffix); - RTC_CHECK(metric_map_.count(metric_suffix) == 0); - metric_map_.insert({metric_suffix, {unit, improve_direction}}); + std::string metric(metric_suffix); + RTC_CHECK(metric_map_.count(metric) == 0); + metric_map_.insert({std::move(metric), {unit, improve_direction}}); } -void PerfResultReporter::AddResult(const std::string& metric_suffix, +void PerfResultReporter::AddResult(absl::string_view metric_suffix, size_t value) const { auto info = GetMetricInfoOrFail(metric_suffix); @@ -105,7 +108,7 @@ void PerfResultReporter::AddResult(const std::string& metric_suffix, UnitToString(info.unit), kNotImportant, info.improve_direction); } -void PerfResultReporter::AddResult(const std::string& metric_suffix, +void PerfResultReporter::AddResult(absl::string_view metric_suffix, double value) const { auto info = GetMetricInfoOrFail(metric_suffix); @@ -114,7 +117,7 @@ void PerfResultReporter::AddResult(const std::string& metric_suffix, } void PerfResultReporter::AddResultList( - const std::string& metric_suffix, + absl::string_view metric_suffix, rtc::ArrayView values) const { auto info = GetMetricInfoOrFail(metric_suffix); @@ -123,7 +126,7 @@ void PerfResultReporter::AddResultList( info.improve_direction); } -void PerfResultReporter::AddResultMeanAndError(const std::string& metric_suffix, +void PerfResultReporter::AddResultMeanAndError(absl::string_view metric_suffix, const double mean, const double error) { auto info = GetMetricInfoOrFail(metric_suffix); @@ -134,8 +137,8 @@ void PerfResultReporter::AddResultMeanAndError(const std::string& metric_suffix, } absl::optional PerfResultReporter::GetMetricInfo( - const std::string& metric_suffix) const { - auto iter = metric_map_.find(metric_suffix); + absl::string_view metric_suffix) const { + auto iter = metric_map_.find(std::string(metric_suffix)); if (iter == metric_map_.end()) { return absl::optional(); } @@ -144,7 +147,7 @@ absl::optional PerfResultReporter::GetMetricInfo( } MetricInfo PerfResultReporter::GetMetricInfoOrFail( - const std::string& metric_suffix) const { + absl::string_view metric_suffix) const { absl::optional info = GetMetricInfo(metric_suffix); 
RTC_CHECK(info.has_value()) << "Attempted to use unregistered metric " << metric_suffix; diff --git a/test/testsupport/perf_result_reporter.h b/test/testsupport/perf_result_reporter.h index c8028574aa..aeb1786824 100644 --- a/test/testsupport/perf_result_reporter.h +++ b/test/testsupport/perf_result_reporter.h @@ -14,6 +14,7 @@ #include #include +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "api/array_view.h" #include "test/testsupport/perf_test.h" @@ -61,34 +62,34 @@ struct MetricInfo { // as separate subtests (e.g. next to bwe_15s). class PerfResultReporter { public: - PerfResultReporter(const std::string& metric_basename, - const std::string& story_name); + PerfResultReporter(absl::string_view metric_basename, + absl::string_view story_name); ~PerfResultReporter(); - void RegisterMetric(const std::string& metric_suffix, Unit unit); - void RegisterMetric(const std::string& metric_suffix, + void RegisterMetric(absl::string_view metric_suffix, Unit unit); + void RegisterMetric(absl::string_view metric_suffix, Unit unit, ImproveDirection improve_direction); - void AddResult(const std::string& metric_suffix, size_t value) const; - void AddResult(const std::string& metric_suffix, double value) const; + void AddResult(absl::string_view metric_suffix, size_t value) const; + void AddResult(absl::string_view metric_suffix, double value) const; - void AddResultList(const std::string& metric_suffix, + void AddResultList(absl::string_view metric_suffix, rtc::ArrayView values) const; // Users should prefer AddResultList if possible, as otherwise the min/max // values reported on the perf dashboard aren't useful. // |mean_and_error| should be a comma-separated string of mean then // error/stddev, e.g. "2.4,0.5". - void AddResultMeanAndError(const std::string& metric_suffix, + void AddResultMeanAndError(absl::string_view metric_suffix, const double mean, const double error); // Returns the metric info if it has been registered. 
absl::optional GetMetricInfo( - const std::string& metric_suffix) const; + absl::string_view metric_suffix) const; private: - MetricInfo GetMetricInfoOrFail(const std::string& metric_suffix) const; + MetricInfo GetMetricInfoOrFail(absl::string_view metric_suffix) const; std::string metric_basename_; std::string story_name_; diff --git a/test/testsupport/perf_test.cc b/test/testsupport/perf_test.cc index b68eaa46a1..d282bf23a1 100644 --- a/test/testsupport/perf_test.cc +++ b/test/testsupport/perf_test.cc @@ -17,7 +17,10 @@ #include #include +#include "absl/strings/string_view.h" +#include "api/numerics/samples_stats_counter.h" #include "rtc_base/checks.h" +#include "rtc_base/strings/string_builder.h" #include "rtc_base/synchronization/mutex.h" #include "test/testsupport/file_utils.h" #include "test/testsupport/perf_test_histogram_writer.h" @@ -28,18 +31,31 @@ namespace test { namespace { std::string UnitWithDirection( - const std::string& units, + absl::string_view units, webrtc::test::ImproveDirection improve_direction) { switch (improve_direction) { case webrtc::test::ImproveDirection::kNone: - return units; + return std::string(units); case webrtc::test::ImproveDirection::kSmallerIsBetter: - return units + "_smallerIsBetter"; + return std::string(units) + "_smallerIsBetter"; case webrtc::test::ImproveDirection::kBiggerIsBetter: - return units + "_biggerIsBetter"; + return std::string(units) + "_biggerIsBetter"; } } +std::vector GetSortedSamples( + const SamplesStatsCounter& counter) { + rtc::ArrayView view = + counter.GetTimedSamples(); + std::vector out(view.begin(), view.end()); + std::sort(out.begin(), out.end(), + [](const SamplesStatsCounter::StatsSample& a, + const SamplesStatsCounter::StatsSample& b) { + return a.time < b.time; + }); + return out; +} + template void OutputListToStream(std::ostream* ostream, const Container& values) { const char* sep = ""; @@ -65,12 +81,14 @@ class PlottableCounterPrinter { output_ = output; } - void AddCounter(const std::string& graph_name, - const std::string& trace_name, + void AddCounter(absl::string_view graph_name, + absl::string_view trace_name, const webrtc::SamplesStatsCounter& counter, - const std::string& units) { + absl::string_view units) { MutexLock lock(&mutex_); - plottable_counters_.push_back({graph_name, trace_name, counter, units}); + plottable_counters_.push_back({std::string(graph_name), + std::string(trace_name), counter, + std::string(units)}); } void Print(const std::vector& desired_graphs_raw) const { @@ -128,10 +146,10 @@ class ResultsLinePrinter { output_ = output; } - void PrintResult(const std::string& graph_name, - const std::string& trace_name, + void PrintResult(absl::string_view graph_name, + absl::string_view trace_name, const double value, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction) { std::ostringstream value_stream; @@ -143,11 +161,11 @@ class ResultsLinePrinter { important); } - void PrintResultMeanAndError(const std::string& graph_name, - const std::string& trace_name, + void PrintResultMeanAndError(absl::string_view graph_name, + absl::string_view trace_name, const double mean, const double error, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction) { std::ostringstream value_stream; @@ -157,10 +175,10 @@ class ResultsLinePrinter { UnitWithDirection(units, improve_direction), important); } - void PrintResultList(const std::string& graph_name, - const std::string& trace_name, + void 
PrintResultList(absl::string_view graph_name, + absl::string_view trace_name, const rtc::ArrayView values, - const std::string& units, + absl::string_view units, const bool important, webrtc::test::ImproveDirection improve_direction) { std::ostringstream value_stream; @@ -171,20 +189,21 @@ class ResultsLinePrinter { } private: - void PrintResultImpl(const std::string& graph_name, - const std::string& trace_name, - const std::string& values, - const std::string& prefix, - const std::string& suffix, - const std::string& units, + void PrintResultImpl(absl::string_view graph_name, + absl::string_view trace_name, + absl::string_view values, + absl::string_view prefix, + absl::string_view suffix, + absl::string_view units, bool important) { MutexLock lock(&mutex_); + rtc::StringBuilder message; + message << (important ? "*" : "") << "RESULT " << graph_name << ": " + << trace_name << "= " << prefix << values << suffix << " " << units; // <*>RESULT : = // <*>RESULT : = {, } // <*>RESULT : = [,value,value,...,] - fprintf(output_, "%sRESULT %s: %s= %s%s%s %s\n", important ? "*" : "", - graph_name.c_str(), trace_name.c_str(), prefix.c_str(), - values.c_str(), suffix.c_str(), units.c_str()); + fprintf(output_, "%s\n", message.str().c_str()); } Mutex mutex_; @@ -241,73 +260,94 @@ bool WritePerfResults(const std::string& output_path) { return true; } -void PrintResult(const std::string& measurement, - const std::string& modifier, - const std::string& trace, +void PrintResult(absl::string_view measurement, + absl::string_view modifier, + absl::string_view trace, const double value, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction) { - std::string graph_name = measurement + modifier; + rtc::StringBuilder graph_name; + graph_name << measurement << modifier; RTC_CHECK(std::isfinite(value)) - << "Expected finite value for graph " << graph_name << ", trace name " - << trace << ", units " << units << ", got " << value; - GetPerfWriter().LogResult(graph_name, trace, value, units, important, + << "Expected finite value for graph " << graph_name.str() + << ", trace name " << trace << ", units " << units << ", got " << value; + GetPerfWriter().LogResult(graph_name.str(), trace, value, units, important, improve_direction); - GetResultsLinePrinter().PrintResult(graph_name, trace, value, units, + GetResultsLinePrinter().PrintResult(graph_name.str(), trace, value, units, important, improve_direction); } -void PrintResult(const std::string& measurement, - const std::string& modifier, - const std::string& trace, +void PrintResult(absl::string_view measurement, + absl::string_view modifier, + absl::string_view trace, const SamplesStatsCounter& counter, - const std::string& units, + absl::string_view units, const bool important, ImproveDirection improve_direction) { - std::string graph_name = measurement + modifier; - GetPlottableCounterPrinter().AddCounter(graph_name, trace, counter, units); + rtc::StringBuilder graph_name; + graph_name << measurement << modifier; + GetPlottableCounterPrinter().AddCounter(graph_name.str(), trace, counter, + units); double mean = counter.IsEmpty() ? 0 : counter.GetAverage(); double error = counter.IsEmpty() ? 
0 : counter.GetStandardDeviation(); - PrintResultMeanAndError(measurement, modifier, trace, mean, error, units, - important, improve_direction); + + std::vector timed_samples = + GetSortedSamples(counter); + std::vector samples(timed_samples.size()); + for (size_t i = 0; i < timed_samples.size(); ++i) { + samples[i] = timed_samples[i].value; + } + // If we have an empty counter, default it to 0. + if (samples.empty()) { + samples.push_back(0); + } + + GetPerfWriter().LogResultList(graph_name.str(), trace, samples, units, + important, improve_direction); + GetResultsLinePrinter().PrintResultMeanAndError(graph_name.str(), trace, mean, + error, units, important, + improve_direction); } -void PrintResultMeanAndError(const std::string& measurement, - const std::string& modifier, - const std::string& trace, +void PrintResultMeanAndError(absl::string_view measurement, + absl::string_view modifier, + absl::string_view trace, const double mean, const double error, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction) { RTC_CHECK(std::isfinite(mean)); RTC_CHECK(std::isfinite(error)); - std::string graph_name = measurement + modifier; - GetPerfWriter().LogResultMeanAndError(graph_name, trace, mean, error, units, - important, improve_direction); - GetResultsLinePrinter().PrintResultMeanAndError( - graph_name, trace, mean, error, units, important, improve_direction); + rtc::StringBuilder graph_name; + graph_name << measurement << modifier; + GetPerfWriter().LogResultMeanAndError(graph_name.str(), trace, mean, error, + units, important, improve_direction); + GetResultsLinePrinter().PrintResultMeanAndError(graph_name.str(), trace, mean, + error, units, important, + improve_direction); } -void PrintResultList(const std::string& measurement, - const std::string& modifier, - const std::string& trace, +void PrintResultList(absl::string_view measurement, + absl::string_view modifier, + absl::string_view trace, const rtc::ArrayView values, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction) { for (double v : values) { RTC_CHECK(std::isfinite(v)); } - std::string graph_name = measurement + modifier; - GetPerfWriter().LogResultList(graph_name, trace, values, units, important, - improve_direction); - GetResultsLinePrinter().PrintResultList(graph_name, trace, values, units, - important, improve_direction); + rtc::StringBuilder graph_name; + graph_name << measurement << modifier; + GetPerfWriter().LogResultList(graph_name.str(), trace, values, units, + important, improve_direction); + GetResultsLinePrinter().PrintResultList(graph_name.str(), trace, values, + units, important, improve_direction); } } // namespace test diff --git a/test/testsupport/perf_test.h b/test/testsupport/perf_test.h index 25535bce82..41380241c3 100644 --- a/test/testsupport/perf_test.h +++ b/test/testsupport/perf_test.h @@ -15,6 +15,7 @@ #include #include +#include "absl/strings/string_view.h" #include "api/array_view.h" #include "api/numerics/samples_stats_counter.h" @@ -45,11 +46,11 @@ enum class ImproveDirection { // // The binary this runs in must be hooked up as a perf test in the WebRTC // recipes for this to actually be uploaded to chromeperf.appspot.com. 
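A usage sketch for the counter-based PrintResult() overload reworked above (type and function names from the surrounding code; the metric names and sample values are invented): each AddSample() now reaches the histogram writer as an individual sample, while the stdout line still carries the {mean,stddev} pair.

webrtc::SamplesStatsCounter decode_time_ms;
decode_time_ms.AddSample(1.0);
decode_time_ms.AddSample(2.0);
decode_time_ms.AddSample(3.0);
// Emits three histogram samples and a "RESULT decode_time_vp9: foreman=
// {mean,stddev} ms" line on stdout.
webrtc::test::PrintResult("decode_time", "_vp9", "foreman", decode_time_ms,
                          "ms", /*important=*/false);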
-void PrintResult(const std::string& measurement, - const std::string& modifier, - const std::string& user_story, +void PrintResult(absl::string_view measurement, + absl::string_view modifier, + absl::string_view user_story, const double value, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction = ImproveDirection::kNone); @@ -58,12 +59,12 @@ void PrintResult(const std::string& measurement, // standard deviation (or other error metric) of the measurement. // DEPRECATED: soon unsupported. void PrintResultMeanAndError( - const std::string& measurement, - const std::string& modifier, - const std::string& user_story, + absl::string_view measurement, + absl::string_view modifier, + absl::string_view user_story, const double mean, const double error, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction = ImproveDirection::kNone); @@ -72,21 +73,21 @@ void PrintResultMeanAndError( // post-processing step might produce plots of their mean and standard // deviation. void PrintResultList( - const std::string& measurement, - const std::string& modifier, - const std::string& user_story, + absl::string_view measurement, + absl::string_view modifier, + absl::string_view user_story, rtc::ArrayView values, - const std::string& units, + absl::string_view units, bool important, ImproveDirection improve_direction = ImproveDirection::kNone); // Like PrintResult(), but prints a (mean, standard deviation) from stats // counter. Also add specified metric to the plotable metrics output. -void PrintResult(const std::string& measurement, - const std::string& modifier, - const std::string& user_story, +void PrintResult(absl::string_view measurement, + absl::string_view modifier, + absl::string_view user_story, const SamplesStatsCounter& counter, - const std::string& units, + absl::string_view units, const bool important, ImproveDirection improve_direction = ImproveDirection::kNone); diff --git a/test/testsupport/perf_test_histogram_writer.cc b/test/testsupport/perf_test_histogram_writer.cc index a4f86dc5f0..096ca44571 100644 --- a/test/testsupport/perf_test_histogram_writer.cc +++ b/test/testsupport/perf_test_histogram_writer.cc @@ -15,7 +15,10 @@ #include #include +#include "absl/strings/string_view.h" +#include "api/numerics/samples_stats_counter.h" #include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" #include "rtc_base/synchronization/mutex.h" #include "third_party/catapult/tracing/tracing/value/diagnostics/reserved_infos.h" #include "third_party/catapult/tracing/tracing/value/histogram.h" @@ -39,20 +42,20 @@ class PerfTestHistogramWriter : public PerfTestResultWriter { histograms_.clear(); } - void LogResult(const std::string& graph_name, - const std::string& trace_name, + void LogResult(absl::string_view graph_name, + absl::string_view trace_name, const double value, - const std::string& units, + absl::string_view units, const bool important, ImproveDirection improve_direction) override { (void)important; AddSample(graph_name, trace_name, value, units, improve_direction); } - void LogResultMeanAndError(const std::string& graph_name, - const std::string& trace_name, + void LogResultMeanAndError(absl::string_view graph_name, + absl::string_view trace_name, const double mean, const double error, - const std::string& units, + absl::string_view units, const bool important, ImproveDirection improve_direction) override { RTC_LOG(LS_WARNING) << "Discarding stddev, not supported by 
histograms"; @@ -61,10 +64,10 @@ class PerfTestHistogramWriter : public PerfTestResultWriter { AddSample(graph_name, trace_name, mean, units, improve_direction); } - void LogResultList(const std::string& graph_name, - const std::string& trace_name, + void LogResultList(absl::string_view graph_name, + absl::string_view trace_name, const rtc::ArrayView values, - const std::string& units, + absl::string_view units, const bool important, ImproveDirection improve_direction) override { (void)important; @@ -88,14 +91,14 @@ class PerfTestHistogramWriter : public PerfTestResultWriter { } private: - void AddSample(const std::string& original_graph_name, - const std::string& trace_name, + void AddSample(absl::string_view original_graph_name, + absl::string_view trace_name, const double value, - const std::string& units, + absl::string_view units, ImproveDirection improve_direction) { // WebRTC annotates the units into the metric name when they are not // supported by the Histogram API. - std::string graph_name = original_graph_name; + std::string graph_name(original_graph_name); if (units == "dB") { graph_name += "_dB"; } else if (units == "fps") { @@ -107,9 +110,10 @@ class PerfTestHistogramWriter : public PerfTestResultWriter { // Lookup on graph name + trace name (or measurement + story in catapult // parlance). There should be several histograms with the same measurement // if they're for different stories. - std::string measurement_and_story = graph_name + trace_name; + rtc::StringBuilder measurement_and_story; + measurement_and_story << graph_name << trace_name; MutexLock lock(&mutex_); - if (histograms_.count(measurement_and_story) == 0) { + if (histograms_.count(measurement_and_story.str()) == 0) { proto::UnitAndDirection unit = ParseUnit(units, improve_direction); std::unique_ptr builder = std::make_unique(graph_name, unit); @@ -117,24 +121,24 @@ class PerfTestHistogramWriter : public PerfTestResultWriter { // Set all summary options as false - we don't want to generate // metric_std, metric_count, and so on for all metrics. builder->SetSummaryOptions(proto::SummaryOptions()); - histograms_[measurement_and_story] = std::move(builder); + histograms_[measurement_and_story.str()] = std::move(builder); proto::Diagnostic stories; proto::GenericSet* generic_set = stories.mutable_generic_set(); - generic_set->add_values(AsJsonString(trace_name)); - histograms_[measurement_and_story]->AddDiagnostic( + generic_set->add_values(AsJsonString(std::string(trace_name))); + histograms_[measurement_and_story.str()]->AddDiagnostic( catapult::kStoriesDiagnostic, stories); } if (units == "bps") { // Bps has been interpreted as bits per second in WebRTC tests. 
- histograms_[measurement_and_story]->AddSample(value / 8); + histograms_[measurement_and_story.str()]->AddSample(value / 8); } else { - histograms_[measurement_and_story]->AddSample(value); + histograms_[measurement_and_story.str()]->AddSample(value); } } - proto::UnitAndDirection ParseUnit(const std::string& units, + proto::UnitAndDirection ParseUnit(absl::string_view units, ImproveDirection improve_direction) { RTC_DCHECK(units.find('_') == std::string::npos) << "The unit_bigger|smallerIsBetter syntax isn't supported in WebRTC, " @@ -155,7 +159,7 @@ class PerfTestHistogramWriter : public PerfTestResultWriter { } else if (units == "%") { result.set_unit(proto::UNITLESS); } else { - proto::Unit unit = catapult::UnitFromJsonUnit(units); + proto::Unit unit = catapult::UnitFromJsonUnit(std::string(units)); // UnitFromJsonUnit returns UNITLESS if it doesn't recognize the unit. if (unit == proto::UNITLESS && units != "unitless") { diff --git a/test/testsupport/perf_test_histogram_writer_unittest.cc b/test/testsupport/perf_test_histogram_writer_unittest.cc index 6b083d6543..83025a7447 100644 --- a/test/testsupport/perf_test_histogram_writer_unittest.cc +++ b/test/testsupport/perf_test_histogram_writer_unittest.cc @@ -34,6 +34,25 @@ TEST(PerfHistogramWriterUnittest, TestSimpleHistogram) { ASSERT_EQ(histogram_set.histograms_size(), 1); } +TEST(PerfHistogramWriterUnittest, TestListOfValuesHistogram) { + std::unique_ptr writer = + std::unique_ptr(CreateHistogramWriter()); + + std::vector samples{0, 1, 2}; + writer->LogResultList("-", "-", samples, "ms", false, + ImproveDirection::kNone); + + proto::HistogramSet histogram_set; + EXPECT_TRUE(histogram_set.ParseFromString(writer->Serialize())) + << "Expected valid histogram set"; + + ASSERT_EQ(histogram_set.histograms_size(), 1); + ASSERT_EQ(histogram_set.histograms(0).sample_values_size(), 3); + EXPECT_EQ(histogram_set.histograms(0).sample_values(0), 0); + EXPECT_EQ(histogram_set.histograms(0).sample_values(1), 1); + EXPECT_EQ(histogram_set.histograms(0).sample_values(2), 2); +} + TEST(PerfHistogramWriterUnittest, WritesSamplesAndUserStory) { std::unique_ptr writer = std::unique_ptr(CreateHistogramWriter()); diff --git a/test/testsupport/perf_test_result_writer.h b/test/testsupport/perf_test_result_writer.h index d5d7011749..e7342c137f 100644 --- a/test/testsupport/perf_test_result_writer.h +++ b/test/testsupport/perf_test_result_writer.h @@ -12,8 +12,10 @@ #define TEST_TESTSUPPORT_PERF_TEST_RESULT_WRITER_H_ #include + #include +#include "absl/strings/string_view.h" #include "test/testsupport/perf_test.h" namespace webrtc { @@ -25,25 +27,25 @@ class PerfTestResultWriter { virtual ~PerfTestResultWriter() = default; virtual void ClearResults() = 0; - virtual void LogResult(const std::string& graph_name, - const std::string& trace_name, + virtual void LogResult(absl::string_view graph_name, + absl::string_view trace_name, const double value, - const std::string& units, + absl::string_view units, const bool important, webrtc::test::ImproveDirection improve_direction) = 0; virtual void LogResultMeanAndError( - const std::string& graph_name, - const std::string& trace_name, + absl::string_view graph_name, + absl::string_view trace_name, const double mean, const double error, - const std::string& units, + absl::string_view units, const bool important, webrtc::test::ImproveDirection improve_direction) = 0; virtual void LogResultList( - const std::string& graph_name, - const std::string& trace_name, + absl::string_view graph_name, + absl::string_view 
trace_name, const rtc::ArrayView values, - const std::string& units, + absl::string_view units, const bool important, webrtc::test::ImproveDirection improve_direction) = 0; diff --git a/test/testsupport/perf_test_unittest.cc b/test/testsupport/perf_test_unittest.cc index 3746e2494a..4cd925d8fb 100644 --- a/test/testsupport/perf_test_unittest.cc +++ b/test/testsupport/perf_test_unittest.cc @@ -103,6 +103,83 @@ TEST_F(PerfTest, TestGetPerfResultsHistograms) { EXPECT_EQ(hist2.unit().unit(), proto::MS_BEST_FIT_FORMAT); } +TEST_F(PerfTest, TestGetPerfResultsHistogramsWithEmptyCounter) { + ClearPerfResults(); + ::testing::internal::CaptureStdout(); + + SamplesStatsCounter empty_counter; + PrintResult("measurement", "_modifier", "story", empty_counter, "ms", false); + + proto::HistogramSet histogram_set; + EXPECT_TRUE(histogram_set.ParseFromString(GetPerfResults())) + << "Expected valid histogram set"; + + ASSERT_EQ(histogram_set.histograms_size(), 1) + << "Should be one histogram: measurement_modifier"; + const proto::Histogram& hist = histogram_set.histograms(0); + + EXPECT_EQ(hist.name(), "measurement_modifier"); + + // Spot check some things in here (there's a more thorough test on the + // histogram writer itself). + EXPECT_EQ(hist.unit().unit(), proto::MS_BEST_FIT_FORMAT); + EXPECT_EQ(hist.sample_values_size(), 1); + EXPECT_EQ(hist.sample_values(0), 0); + + EXPECT_EQ(hist.diagnostics().diagnostic_map().count("stories"), 1u); + const proto::Diagnostic& stories = + hist.diagnostics().diagnostic_map().at("stories"); + ASSERT_EQ(stories.generic_set().values_size(), 1); + EXPECT_EQ(stories.generic_set().values(0), "\"story\""); + + std::string expected = "RESULT measurement_modifier: story= {0,0} ms\n"; + EXPECT_EQ(expected, ::testing::internal::GetCapturedStdout()); +} + +TEST_F(PerfTest, TestGetPerfResultsHistogramsWithStatsCounter) { + ClearPerfResults(); + ::testing::internal::CaptureStdout(); + + SamplesStatsCounter counter; + counter.AddSample(1); + counter.AddSample(2); + counter.AddSample(3); + counter.AddSample(4); + counter.AddSample(5); + PrintResult("measurement", "_modifier", "story", counter, "ms", false); + + proto::HistogramSet histogram_set; + EXPECT_TRUE(histogram_set.ParseFromString(GetPerfResults())) + << "Expected valid histogram set"; + + ASSERT_EQ(histogram_set.histograms_size(), 1) + << "Should be one histogram: measurement_modifier"; + const proto::Histogram& hist = histogram_set.histograms(0); + + EXPECT_EQ(hist.name(), "measurement_modifier"); + + // Spot check some things in here (there's a more thorough test on the + // histogram writer itself). 
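For reference while reading the expected strings in these tests, the RESULT lines printed by ResultsLinePrinter come in three shapes (a leading '*' marks important results; placeholders and spacing approximated from the tests above):

// RESULT <graph_name>: <trace_name>= <value> <units>
// RESULT <graph_name>: <trace_name>= {<mean>,<stddev>} <units>
// RESULT <graph_name>: <trace_name>= [<value>,<value>,...] <units>

For example, the stats-counter test above expects "RESULT measurement_modifier: story= {3,1.4142136} ms".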
+ EXPECT_EQ(hist.unit().unit(), proto::MS_BEST_FIT_FORMAT); + EXPECT_EQ(hist.sample_values_size(), 5); + EXPECT_EQ(hist.sample_values(0), 1); + EXPECT_EQ(hist.sample_values(1), 2); + EXPECT_EQ(hist.sample_values(2), 3); + EXPECT_EQ(hist.sample_values(3), 4); + EXPECT_EQ(hist.sample_values(4), 5); + + EXPECT_EQ(hist.diagnostics().diagnostic_map().count("stories"), 1u); + const proto::Diagnostic& stories = + hist.diagnostics().diagnostic_map().at("stories"); + ASSERT_EQ(stories.generic_set().values_size(), 1); + EXPECT_EQ(stories.generic_set().values(0), "\"story\""); + + // mean = 3; std = sqrt(2) + std::string expected = + "RESULT measurement_modifier: story= {3,1.4142136} ms\n"; + EXPECT_EQ(expected, ::testing::internal::GetCapturedStdout()); +} + #endif // WEBRTC_ENABLE_PROTOBUF #if GTEST_HAS_DEATH_TEST diff --git a/test/time_controller/BUILD.gn b/test/time_controller/BUILD.gn index c9fffe6853..6c13a99648 100644 --- a/test/time_controller/BUILD.gn +++ b/test/time_controller/BUILD.gn @@ -26,6 +26,7 @@ rtc_library("time_controller") { ] deps = [ + "../../api:sequence_checker", "../../api:time_controller", "../../api/task_queue", "../../api/task_queue:default_task_queue_factory", @@ -35,10 +36,10 @@ rtc_library("time_controller") { "../../modules/utility:utility", "../../rtc_base", "../../rtc_base:checks", + "../../rtc_base:null_socket_server", "../../rtc_base:rtc_base_tests_utils", "../../rtc_base:rtc_event", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/synchronization:yield_policy", "../../rtc_base/task_utils:to_queued_task", "../../system_wrappers", @@ -62,6 +63,7 @@ if (rtc_include_tests) { "../../rtc_base", "../../rtc_base:rtc_base_approved", "../../rtc_base:rtc_task_queue", + "../../rtc_base:threading", "../../rtc_base/synchronization:mutex", "../../rtc_base/task_utils:repeating_task", "../../rtc_base/task_utils:to_queued_task", diff --git a/test/time_controller/simulated_time_controller.cc b/test/time_controller/simulated_time_controller.cc index aba8c6600e..a34abe8ced 100644 --- a/test/time_controller/simulated_time_controller.cc +++ b/test/time_controller/simulated_time_controller.cc @@ -226,4 +226,14 @@ void GlobalSimulatedTimeController::AdvanceTime(TimeDelta duration) { impl_.RunReadyRunners(); } +void GlobalSimulatedTimeController::Register( + sim_time_impl::SimulatedSequenceRunner* runner) { + impl_.Register(runner); +} + +void GlobalSimulatedTimeController::Unregister( + sim_time_impl::SimulatedSequenceRunner* runner) { + impl_.Unregister(runner); +} + } // namespace webrtc diff --git a/test/time_controller/simulated_time_controller.h b/test/time_controller/simulated_time_controller.h index 6c6dbfab9d..9ded4689de 100644 --- a/test/time_controller/simulated_time_controller.h +++ b/test/time_controller/simulated_time_controller.h @@ -17,6 +17,7 @@ #include #include "absl/strings/string_view.h" +#include "api/sequence_checker.h" #include "api/test/time_controller.h" #include "api/units/timestamp.h" #include "modules/include/module.h" @@ -25,7 +26,6 @@ #include "rtc_base/platform_thread_types.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/synchronization/yield_policy.h" -#include "rtc_base/thread_checker.h" namespace webrtc { namespace sim_time_impl { @@ -140,6 +140,17 @@ class GlobalSimulatedTimeController : public TimeController { void AdvanceTime(TimeDelta duration) override; + // Makes the simulated time controller aware of a custom + // SimulatedSequenceRunner. 
+ // TODO(bugs.webrtc.org/11581): remove method once the ModuleRtpRtcpImpl2 unit + // test stops using it. + void Register(sim_time_impl::SimulatedSequenceRunner* runner); + // Removes a previously installed custom SimulatedSequenceRunner from the + // simulated time controller. + // TODO(bugs.webrtc.org/11581): remove method once the ModuleRtpRtcpImpl2 unit + // test stops using it. + void Unregister(sim_time_impl::SimulatedSequenceRunner* runner); + private: rtc::ScopedBaseFakeClock global_clock_; // Provides simulated CurrentNtpInMilliseconds() diff --git a/test/time_controller/time_controller_conformance_test.cc b/test/time_controller/time_controller_conformance_test.cc index 10f0e1d724..3d582cad8e 100644 --- a/test/time_controller/time_controller_conformance_test.cc +++ b/test/time_controller/time_controller_conformance_test.cc @@ -92,6 +92,9 @@ TEST_P(SimulatedRealTimeControllerConformanceTest, ThreadPostOrderTest) { thread->PostTask(RTC_FROM_HERE, [&]() { execution_order.Executed(2); }); time_controller->AdvanceTime(TimeDelta::Millis(100)); EXPECT_THAT(execution_order.order(), ElementsAreArray({1, 2})); + // Destroy `thread` before `execution_order` to be sure `execution_order` + // is not accessed on the posted task after it is destroyed. + thread = nullptr; } TEST_P(SimulatedRealTimeControllerConformanceTest, ThreadPostDelayedOrderTest) { @@ -105,6 +108,9 @@ TEST_P(SimulatedRealTimeControllerConformanceTest, ThreadPostDelayedOrderTest) { thread->PostTask(ToQueuedTask([&]() { execution_order.Executed(1); })); time_controller->AdvanceTime(TimeDelta::Millis(600)); EXPECT_THAT(execution_order.order(), ElementsAreArray({1, 2})); + // Destroy `thread` before `execution_order` to be sure `execution_order` + // is not accessed on the posted task after it is destroyed. + thread = nullptr; } TEST_P(SimulatedRealTimeControllerConformanceTest, ThreadPostInvokeOrderTest) { @@ -119,6 +125,9 @@ TEST_P(SimulatedRealTimeControllerConformanceTest, ThreadPostInvokeOrderTest) { thread->Invoke(RTC_FROM_HERE, [&]() { execution_order.Executed(2); }); time_controller->AdvanceTime(TimeDelta::Millis(100)); EXPECT_THAT(execution_order.order(), ElementsAreArray({1, 2})); + // Destroy `thread` before `execution_order` to be sure `execution_order` + // is not accessed on the posted task after it is destroyed. + thread = nullptr; } TEST_P(SimulatedRealTimeControllerConformanceTest, @@ -136,6 +145,9 @@ TEST_P(SimulatedRealTimeControllerConformanceTest, }); time_controller->AdvanceTime(TimeDelta::Millis(100)); EXPECT_THAT(execution_order.order(), ElementsAreArray({1, 2})); + // Destroy `thread` before `execution_order` to be sure `execution_order` + // is not accessed on the posted task after it is destroyed. + thread = nullptr; } TEST_P(SimulatedRealTimeControllerConformanceTest, @@ -158,6 +170,9 @@ TEST_P(SimulatedRealTimeControllerConformanceTest, /*warn_after_ms=*/10'000)); time_controller->AdvanceTime(TimeDelta::Millis(100)); EXPECT_THAT(execution_order.order(), ElementsAreArray({1, 2})); + // Destroy `task_queue` before `execution_order` to be sure `execution_order` + // is not accessed on the posted task after it is destroyed. 
+ task_queue = nullptr; } INSTANTIATE_TEST_SUITE_P(ConformanceTest, diff --git a/tools_webrtc/android/OWNERS b/tools_webrtc/android/OWNERS index 3c4e54174e..cf092a316a 100644 --- a/tools_webrtc/android/OWNERS +++ b/tools_webrtc/android/OWNERS @@ -1 +1 @@ -sakal@webrtc.org +xalep@webrtc.org diff --git a/tools_webrtc/android/build_aar.py b/tools_webrtc/android/build_aar.py index 047be7b0a2..9fc4bb0f39 100755 --- a/tools_webrtc/android/build_aar.py +++ b/tools_webrtc/android/build_aar.py @@ -54,9 +54,11 @@ def _ParseArgs(): parser = argparse.ArgumentParser(description='libwebrtc.aar generator.') parser.add_argument( '--build-dir', + type=os.path.abspath, help='Build dir. By default will create and use temporary dir.') parser.add_argument('--output', default='libwebrtc.aar', + type=os.path.abspath, help='Output file of the script.') parser.add_argument( '--arch', diff --git a/tools_webrtc/autoroller/roll_deps.py b/tools_webrtc/autoroller/roll_deps.py index f1a1235f20..286c3c4cda 100755 --- a/tools_webrtc/autoroller/roll_deps.py +++ b/tools_webrtc/autoroller/roll_deps.py @@ -568,16 +568,16 @@ def _IsTreeClean(): return False -def _EnsureUpdatedMasterBranch(dry_run): +def _EnsureUpdatedMainBranch(dry_run): current_branch = _RunCommand(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])[0].splitlines()[0] - if current_branch != 'master': + if current_branch != 'main': logging.error( - 'Please checkout the master branch and re-run this script.') + 'Please checkout the main branch and re-run this script.') if not dry_run: sys.exit(-1) - logging.info('Updating master branch...') + logging.info('Updating main branch...') _RunCommand(['git', 'pull']) @@ -590,7 +590,7 @@ def _CreateRollBranch(dry_run): def _RemovePreviousRollBranch(dry_run): active_branch, branches = _GetBranches() if active_branch == ROLL_BRANCH_NAME: - active_branch = 'master' + active_branch = 'main' if ROLL_BRANCH_NAME in branches: logging.info('Removing previous roll branch (%s)', ROLL_BRANCH_NAME) if not dry_run: @@ -672,7 +672,7 @@ def main(): '--ignore-unclean-workdir', action='store_true', default=False, - help=('Ignore if the current branch is not master or if there ' + help=('Ignore if the current branch is not main or if there ' 'are uncommitted changes (default: %(default)s).')) grp = p.add_mutually_exclusive_group() grp.add_argument( @@ -705,7 +705,7 @@ def main(): _RemovePreviousRollBranch(opts.dry_run) if not opts.ignore_unclean_workdir: - _EnsureUpdatedMasterBranch(opts.dry_run) + _EnsureUpdatedMainBranch(opts.dry_run) deps_filename = os.path.join(CHECKOUT_SRC_DIR, 'DEPS') webrtc_deps = ParseLocalDepsFile(deps_filename) diff --git a/tools_webrtc/get_landmines.py b/tools_webrtc/get_landmines.py index 3b5965fce4..764f053f2a 100755 --- a/tools_webrtc/get_landmines.py +++ b/tools_webrtc/get_landmines.py @@ -55,6 +55,7 @@ def print_landmines(): # pylint: disable=invalid-name print 'Clobber to change neteq_rtpplay type to executable' print 'Clobber to remove .xctest files.' print 'Clobber to remove .xctest files (take 2).' + print 'Switching rtc_executable to rtc_test' def main(): diff --git a/tools_webrtc/ios/build_ios_libs.py b/tools_webrtc/ios/build_ios_libs.py index 8c8a8ac433..a06f829e00 100755 --- a/tools_webrtc/ios/build_ios_libs.py +++ b/tools_webrtc/ios/build_ios_libs.py @@ -7,7 +7,7 @@ # tree. An additional intellectual property rights grant can be found # in the file PATENTS. All contributing project authors may # be found in the AUTHORS file in the root of the source tree. -"""WebRTC iOS FAT libraries build script. 
+"""WebRTC iOS XCFramework build script. Each architecture is compiled separately before being merged together. By default, the library is created in out_ios_libs/. (Change with -o.) """ @@ -30,7 +30,16 @@ SDK_OUTPUT_DIR = os.path.join(SRC_DIR, 'out_ios_libs') SDK_FRAMEWORK_NAME = 'WebRTC.framework' -DEFAULT_ARCHS = ENABLED_ARCHS = ['arm64', 'arm', 'x64', 'x86'] +SDK_DSYM_NAME = 'WebRTC.dSYM' +SDK_XCFRAMEWORK_NAME = 'WebRTC.xcframework' + +ENABLED_ARCHS = [ + 'device:arm64', 'simulator:arm64', 'simulator:x64', + 'arm64', 'x64' +] +DEFAULT_ARCHS = [ + 'device:arm64', 'simulator:arm64', 'simulator:x64' +] IOS_DEPLOYMENT_TARGET = '9.0' LIBVPX_BUILD_VP9 = True @@ -67,6 +76,7 @@ def _ParseArgs(): parser.add_argument( '-o', '--output-dir', + type=os.path.abspath, default=SDK_OUTPUT_DIR, help='Specifies a directory to output the build artifacts to. ' 'If specified together with -c, deletes the dir.') @@ -113,15 +123,37 @@ def _CleanTemporary(output_dir, architectures): if os.path.isdir(output_dir): logging.info('Removing temporary build files.') for arch in architectures: - arch_lib_path = os.path.join(output_dir, arch + '_libs') + arch_lib_path = os.path.join(output_dir, arch) if os.path.isdir(arch_lib_path): shutil.rmtree(arch_lib_path) -def BuildWebRTC(output_dir, target_arch, flavor, gn_target_name, - ios_deployment_target, libvpx_build_vp9, use_bitcode, use_goma, - extra_gn_args): - output_dir = os.path.join(output_dir, target_arch + '_libs') +def _ParseArchitecture(architectures): + result = dict() + for arch in architectures: + if ":" in arch: + target_environment, target_cpu = arch.split(":") + else: + logging.warning('The environment for build is not specified.') + logging.warning('It is assumed based on cpu type.') + logging.warning('See crbug.com/1138425 for more details.') + if arch == "x64": + target_environment = "simulator" + else: + target_environment = "device" + target_cpu = arch + archs = result.get(target_environment) + if archs is None: + result[target_environment] = {target_cpu} + else: + archs.add(target_cpu) + + return result + + +def BuildWebRTC(output_dir, target_environment, target_arch, flavor, + gn_target_name, ios_deployment_target, libvpx_build_vp9, + use_bitcode, use_goma, extra_gn_args): gn_args = [ 'target_os="ios"', 'ios_enable_code_signing=false', 'use_xcode_clang=true', 'is_component_build=false', @@ -136,6 +168,8 @@ def BuildWebRTC(output_dir, target_arch, flavor, gn_target_name, else: raise ValueError('Unexpected flavor type: %s' % flavor) + gn_args.append('target_environment="%s"' % target_environment) + gn_args.append('target_cpu="%s"' % target_arch) gn_args.append('ios_deployment_target="%s"' % ios_deployment_target) @@ -181,11 +215,14 @@ def main(): _CleanArtifacts(args.output_dir) return 0 - architectures = list(args.arch) + # architectures is typed as Dict[str, Set[str]], + # where key is for the environment (device or simulator) + # and value is for the cpu type. + architectures = _ParseArchitecture(args.arch) gn_args = args.extra_gn_args if args.purify: - _CleanTemporary(args.output_dir, architectures) + _CleanTemporary(args.output_dir, architectures.keys()) return 0 gn_target_name = 'framework_objc' @@ -194,78 +231,101 @@ def main(): gn_args.append('enable_stripping=true') # Build all architectures. - for arch in architectures: - BuildWebRTC(args.output_dir, arch, args.build_config, gn_target_name, - IOS_DEPLOYMENT_TARGET, LIBVPX_BUILD_VP9, args.bitcode, - args.use_goma, gn_args) - - # Create FAT archive. 
- lib_paths = [ - os.path.join(args.output_dir, arch + '_libs') for arch in architectures - ] - - # Combine the slices. - dylib_path = os.path.join(SDK_FRAMEWORK_NAME, 'WebRTC') - # Dylibs will be combined, all other files are the same across archs. - # Use distutils instead of shutil to support merging folders. - distutils.dir_util.copy_tree( - os.path.join(lib_paths[0], SDK_FRAMEWORK_NAME), - os.path.join(args.output_dir, SDK_FRAMEWORK_NAME)) - logging.info('Merging framework slices.') - dylib_paths = [os.path.join(path, dylib_path) for path in lib_paths] - out_dylib_path = os.path.join(args.output_dir, dylib_path) - try: - os.remove(out_dylib_path) - except OSError: - pass - cmd = ['lipo'] + dylib_paths + ['-create', '-output', out_dylib_path] - _RunCommand(cmd) - - # Merge the dSYM slices. - lib_dsym_dir_path = os.path.join(lib_paths[0], 'WebRTC.dSYM') - if os.path.isdir(lib_dsym_dir_path): + framework_paths = [] + all_lib_paths = [] + for (environment, archs) in architectures.items(): + framework_path = os.path.join(args.output_dir, environment) + framework_paths.append(framework_path) + lib_paths = [] + for arch in archs: + lib_path = os.path.join(framework_path, arch + '_libs') + lib_paths.append(lib_path) + BuildWebRTC(lib_path, environment, arch, args.build_config, + gn_target_name, IOS_DEPLOYMENT_TARGET, + LIBVPX_BUILD_VP9, args.bitcode, args.use_goma, gn_args) + all_lib_paths.extend(lib_paths) + + # Combine the slices. + dylib_path = os.path.join(SDK_FRAMEWORK_NAME, 'WebRTC') + # Dylibs will be combined, all other files are the same across archs. + # Use distutils instead of shutil to support merging folders. distutils.dir_util.copy_tree( - lib_dsym_dir_path, os.path.join(args.output_dir, 'WebRTC.dSYM')) - logging.info('Merging dSYM slices.') - dsym_path = os.path.join('WebRTC.dSYM', 'Contents', 'Resources', - 'DWARF', 'WebRTC') - lib_dsym_paths = [os.path.join(path, dsym_path) for path in lib_paths] - out_dsym_path = os.path.join(args.output_dir, dsym_path) + os.path.join(lib_paths[0], SDK_FRAMEWORK_NAME), + os.path.join(framework_path, SDK_FRAMEWORK_NAME)) + logging.info('Merging framework slices for %s.', environment) + dylib_paths = [os.path.join(path, dylib_path) for path in lib_paths] + out_dylib_path = os.path.join(framework_path, dylib_path) try: - os.remove(out_dsym_path) + os.remove(out_dylib_path) except OSError: pass - cmd = ['lipo'] + lib_dsym_paths + ['-create', '-output', out_dsym_path] + cmd = ['lipo'] + dylib_paths + ['-create', '-output', out_dylib_path] _RunCommand(cmd) - # Generate the license file. - ninja_dirs = [ - os.path.join(args.output_dir, arch + '_libs') - for arch in architectures - ] - gn_target_full_name = '//sdk:' + gn_target_name - builder = LicenseBuilder(ninja_dirs, [gn_target_full_name]) - builder.GenerateLicenseText( - os.path.join(args.output_dir, SDK_FRAMEWORK_NAME)) - - # Modify the version number. - # Format should be ... - # e.g. 55.0.14986 means branch cut 55, no hotfixes, and revision 14986. - infoplist_path = os.path.join(args.output_dir, SDK_FRAMEWORK_NAME, - 'Info.plist') - cmd = [ - 'PlistBuddy', '-c', 'Print :CFBundleShortVersionString', - infoplist_path + # Merge the dSYM slices. 
+ lib_dsym_dir_path = os.path.join(lib_paths[0], SDK_DSYM_NAME) + if os.path.isdir(lib_dsym_dir_path): + distutils.dir_util.copy_tree( + lib_dsym_dir_path, os.path.join(framework_path, SDK_DSYM_NAME)) + logging.info('Merging dSYM slices.') + dsym_path = os.path.join(SDK_DSYM_NAME, 'Contents', 'Resources', + 'DWARF', 'WebRTC') + lib_dsym_paths = [ + os.path.join(path, dsym_path) for path in lib_paths + ] + out_dsym_path = os.path.join(framework_path, dsym_path) + try: + os.remove(out_dsym_path) + except OSError: + pass + cmd = ['lipo' + ] + lib_dsym_paths + ['-create', '-output', out_dsym_path] + _RunCommand(cmd) + + # Modify the version number. + # Format should be ... + # e.g. 55.0.14986 means + # branch cut 55, no hotfixes, and revision 14986. + infoplist_path = os.path.join(framework_path, SDK_FRAMEWORK_NAME, + 'Info.plist') + cmd = [ + 'PlistBuddy', '-c', 'Print :CFBundleShortVersionString', + infoplist_path + ] + major_minor = subprocess.check_output(cmd).strip() + version_number = '%s.%s' % (major_minor, args.revision) + logging.info('Substituting revision number: %s', version_number) + cmd = [ + 'PlistBuddy', '-c', 'Set :CFBundleVersion ' + version_number, + infoplist_path + ] + _RunCommand(cmd) + _RunCommand(['plutil', '-convert', 'binary1', infoplist_path]) + + xcframework_dir = os.path.join(args.output_dir, SDK_XCFRAMEWORK_NAME) + if os.path.isdir(xcframework_dir): + shutil.rmtree(xcframework_dir) + + logging.info('Creating xcframework.') + cmd = ['xcodebuild', '-create-xcframework', '-output', xcframework_dir] + + # Apparently, xcodebuild needs absolute paths for input arguments + for framework_path in framework_paths: + cmd += [ + '-framework', + os.path.abspath(os.path.join(framework_path, SDK_FRAMEWORK_NAME)), + '-debug-symbols', + os.path.abspath(os.path.join(framework_path, SDK_DSYM_NAME)) ] - major_minor = subprocess.check_output(cmd).strip() - version_number = '%s.%s' % (major_minor, args.revision) - logging.info('Substituting revision number: %s', version_number) - cmd = [ - 'PlistBuddy', '-c', 'Set :CFBundleVersion ' + version_number, - infoplist_path - ] - _RunCommand(cmd) - _RunCommand(['plutil', '-convert', 'binary1', infoplist_path]) + + _RunCommand(cmd) + + # Generate the license file. + logging.info('Generate license file.') + gn_target_full_name = '//sdk:' + gn_target_name + builder = LicenseBuilder(all_lib_paths, [gn_target_full_name]) + builder.GenerateLicenseText( + os.path.join(args.output_dir, SDK_XCFRAMEWORK_NAME)) logging.info('Done.') return 0 diff --git a/tools_webrtc/iwyu/apply-iwyu b/tools_webrtc/iwyu/apply-iwyu index 65950d307f..a26f46b933 100755 --- a/tools_webrtc/iwyu/apply-iwyu +++ b/tools_webrtc/iwyu/apply-iwyu @@ -15,28 +15,59 @@ FILE=$1 # the following variable to "yes". This is a style guide violation. REMOVE_CC_INCLUDES=no -if [ ! -f $FILE.h ]; then - echo "$FILE.h not found" - exit 1 +if [ ! -f $FILE ]; then + # See if we have the root name of a .cc/.h pair + if [ ! -f $FILE.h ]; then + echo "$FILE.h not found" + exit 1 + fi + FILE_H=$FILE.h + if [ ! -f $FILE.cc ]; then + echo "$FILE.cc not found" + exit 1 + fi + FILE_CC=$FILE.cc +else + # Exact file, no .h file + FILE_CC=$FILE + FILE_H="" fi -if [ ! -f $FILE.cc ]; then - echo "$FILE.cc not found" - exit 1 -fi +# IWYU has a confusing set of exit codes. Discard it. +iwyu -Xiwyu --no_fwd_decls -D__X86_64__ -DWEBRTC_POSIX -I . 
\ + -I third_party/abseil-cpp \ + -I third_party/googletest/src/googlemock/include \ + -I third_party/googletest/src/googletest/include \ + $FILE_CC >& /tmp/includefixes$$ || echo "IWYU done, code $?" -iwyu -Xiwyu --no_fwd_decls -D__X86_64__ -DWEBRTC_POSIX -I . -I third_party/abseil-cpp $FILE.cc |& fix_include || echo "Some files modified" +if grep 'fatal error' /tmp/includefixes$$; then + echo "iwyu run failed" + cat /tmp/includefixes$$ + rm /tmp/includefixes$$ + exit 1 +else + fix_include < /tmp/includefixes$$ || echo "Some files modified" + rm /tmp/includefixes$$ +fi if [ $REMOVE_CC_INCLUDES == "yes" ]; then - grep ^#include $FILE.h | grep -v -f - $FILE.cc > $FILE.ccnew - grep -v -f tools_webrtc/iwyu/iwyu-filter-list $FILE.ccnew > $FILE.cc + if [ -n "$FILE_H" ]; then + # Don't include in .cc what's already included in .h + grep ^#include $FILE_H | grep -v -f - $FILE_CC > $FILE_CC.new + else + cp $FILE_CC $FILE_CC.new + fi + # Don't include stuff on the banlist + grep -v -f tools_webrtc/iwyu/iwyu-filter-list $FILE_CC.new > $FILE_CC rm $FILE.ccnew else - grep -v -f tools_webrtc/iwyu/iwyu-filter-list $FILE.cc > $FILE.ccnew - mv $FILE.ccnew $FILE.cc + grep -v -f tools_webrtc/iwyu/iwyu-filter-list $FILE_CC > $FILE_CC.new + mv $FILE_CC.new $FILE_CC +fi +if [ -n "$FILE_H" ]; then + grep -v -f tools_webrtc/iwyu/iwyu-filter-list $FILE_H > $FILE_H.new + mv $FILE_H.new $FILE_H fi -grep -v -f tools_webrtc/iwyu/iwyu-filter-list $FILE.h > $FILE.hnew -mv $FILE.hnew $FILE.h echo "Finished. Check diff, compile and git cl format before uploading." diff --git a/tools_webrtc/libs/generate_licenses.py b/tools_webrtc/libs/generate_licenses.py index f33c050291..cbb1514d3c 100755 --- a/tools_webrtc/libs/generate_licenses.py +++ b/tools_webrtc/libs/generate_licenses.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2016 The WebRTC project authors. All Rights Reserved. # @@ -23,12 +23,16 @@ import sys import argparse -import cgi import json import logging import os import re import subprocess +try: + # python 3.2+ + from html import escape +except ImportError: + from cgi import escape # Third_party library to licences mapping. Keys are names of the libraries # (right after the `third_party/` prefix) @@ -42,6 +46,7 @@ ], 'bazel': ['third_party/bazel/LICENSE'], 'boringssl': ['third_party/boringssl/src/LICENSE'], + 'crc32c': ['third_party/crc32c/src/LICENSE'], 'errorprone': [ 'third_party/android_deps/libs/' 'com_google_errorprone_error_prone_core/LICENSE' @@ -78,6 +83,8 @@ # TODO(bugs.webrtc.org/1110): Remove this hack. This is not a lib. # For some reason it is listed as so in _GetThirdPartyLibraries. 'android_deps': [], + # This is not a library but a collection of libraries. 
+ 'androidx': [], # Compile time dependencies, no license needed: 'yasm': [], @@ -179,7 +186,7 @@ def _RunGN(buildfile_dir, target): target, ] logging.debug('Running: %r', cmd) - output_json = subprocess.check_output(cmd, cwd=WEBRTC_ROOT) + output_json = subprocess.check_output(cmd, cwd=WEBRTC_ROOT).decode('UTF-8') logging.debug('Output: %s', output_json) return output_json @@ -206,7 +213,7 @@ def GenerateLicenseText(self, output_dir): self.common_licenses_dict.keys()) if missing_licenses: error_msg = 'Missing licenses for following third_party targets: %s' % \ - ', '.join(missing_licenses) + ', '.join(sorted(missing_licenses)) logging.error(error_msg) raise Exception(error_msg) @@ -231,7 +238,7 @@ def GenerateLicenseText(self, output_dir): for path in self.common_licenses_dict[license_lib]: license_path = os.path.join(WEBRTC_ROOT, path) with open(license_path, 'r') as license_file: - license_text = cgi.escape(license_file.read(), quote=True) + license_text = escape(license_file.read(), quote=True) output_license_file.write(license_text) output_license_file.write('\n') output_license_file.write('```\n\n') diff --git a/tools_webrtc/libs/generate_licenses_test.py b/tools_webrtc/libs/generate_licenses_test.py index 51acb89881..ebef78e132 100755 --- a/tools_webrtc/libs/generate_licenses_test.py +++ b/tools_webrtc/libs/generate_licenses_test.py @@ -10,7 +10,12 @@ # be found in the AUTHORS file in the root of the source tree. import unittest -import mock +try: + # python 3.3+ + from unittest.mock import patch +except ImportError: + # From site-package + from mock import patch from generate_licenses import LicenseBuilder @@ -32,21 +37,21 @@ def _FakeRunGN(buildfile_dir, target): """ def testParseLibraryName(self): - self.assertEquals( + self.assertEqual( LicenseBuilder._ParseLibraryName('//a/b/third_party/libname1:c'), 'libname1') - self.assertEquals( + self.assertEqual( LicenseBuilder._ParseLibraryName( '//a/b/third_party/libname2:c(d)'), 'libname2') - self.assertEquals( + self.assertEqual( LicenseBuilder._ParseLibraryName( '//a/b/third_party/libname3/c:d(e)'), 'libname3') - self.assertEquals( + self.assertEqual( LicenseBuilder._ParseLibraryName('//a/b/not_third_party/c'), None) def testParseLibrarySimpleMatch(self): builder = LicenseBuilder([], [], {}, {}) - self.assertEquals(builder._ParseLibrary('//a/b/third_party/libname:c'), + self.assertEqual(builder._ParseLibrary('//a/b/third_party/libname:c'), 'libname') def testParseLibraryRegExNoMatchFallbacksToDefaultLibname(self): @@ -54,7 +59,7 @@ def testParseLibraryRegExNoMatchFallbacksToDefaultLibname(self): 'libname:foo.*': ['path/to/LICENSE'], } builder = LicenseBuilder([], [], lib_dict, {}) - self.assertEquals( + self.assertEqual( builder._ParseLibrary('//a/b/third_party/libname:bar_java'), 'libname') @@ -63,7 +68,7 @@ def testParseLibraryRegExMatch(self): 'libname:foo.*': ['path/to/LICENSE'], } builder = LicenseBuilder([], [], {}, lib_regex_dict) - self.assertEquals( + self.assertEqual( builder._ParseLibrary('//a/b/third_party/libname:foo_bar_java'), 'libname:foo.*') @@ -72,7 +77,7 @@ def testParseLibraryRegExMatchWithSubDirectory(self): 'libname/foo:bar.*': ['path/to/LICENSE'], } builder = LicenseBuilder([], [], {}, lib_regex_dict) - self.assertEquals( + self.assertEqual( builder._ParseLibrary('//a/b/third_party/libname/foo:bar_java'), 'libname/foo:bar.*') @@ -81,29 +86,29 @@ def testParseLibraryRegExMatchWithStarInside(self): 'libname/foo.*bar.*': ['path/to/LICENSE'], } builder = LicenseBuilder([], [], {}, lib_regex_dict) - 
self.assertEquals( + self.assertEqual( builder._ParseLibrary( '//a/b/third_party/libname/fooHAHA:bar_java'), 'libname/foo.*bar.*') - @mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN) + @patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN) def testGetThirdPartyLibrariesWithoutRegex(self): builder = LicenseBuilder([], [], {}, {}) - self.assertEquals( + self.assertEqual( builder._GetThirdPartyLibraries('out/arm', 'target1'), set(['libname1', 'libname2', 'libname3'])) - @mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN) + @patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN) def testGetThirdPartyLibrariesWithRegex(self): lib_regex_dict = { 'libname2:c.*': ['path/to/LICENSE'], } builder = LicenseBuilder([], [], {}, lib_regex_dict) - self.assertEquals( + self.assertEqual( builder._GetThirdPartyLibraries('out/arm', 'target1'), set(['libname1', 'libname2:c.*', 'libname3'])) - @mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN) + @patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN) def testGenerateLicenseTextFailIfUnknownLibrary(self): lib_dict = { 'simple_library': ['path/to/LICENSE'], @@ -113,8 +118,8 @@ def testGenerateLicenseTextFailIfUnknownLibrary(self): with self.assertRaises(Exception) as context: builder.GenerateLicenseText('dummy/dir') - self.assertEquals( - context.exception.message, + self.assertEqual( + context.exception.args[0], 'Missing licenses for following third_party targets: ' 'libname1, libname2, libname3') diff --git a/tools_webrtc/mb/gn_isolate_map.pyl b/tools_webrtc/mb/gn_isolate_map.pyl index dba0d97571..01993a8fcb 100644 --- a/tools_webrtc/mb/gn_isolate_map.pyl +++ b/tools_webrtc/mb/gn_isolate_map.pyl @@ -51,6 +51,10 @@ "label": "//common_video:common_video_unittests", "type": "console_test_launcher", }, + "dcsctp_unittests": { + "label": "//net/dcsctp:dcsctp_unittests", + "type": "console_test_launcher", + }, "isac_fix_test": { "label": "//modules/audio_coding:isac_fix_test", "type": "console_test_launcher", diff --git a/tools_webrtc/mb/mb.py b/tools_webrtc/mb/mb.py index 358a66ebc4..4aff74621f 100755 --- a/tools_webrtc/mb/mb.py +++ b/tools_webrtc/mb/mb.py @@ -81,8 +81,6 @@ def AddCommonOptions(subp): subp.add_argument('-b', '--builder', help='builder name to look up config from') subp.add_argument('-m', '--builder-group', - # TODO(crbug.com/1117773): Remove the 'master' args. - '--master', help='builder group name to look up config from') subp.add_argument('-c', '--config', help='configuration to analyze') @@ -325,11 +323,14 @@ def CmdRun(self): return ret if self.args.swarmed: - return self._RunUnderSwarming(build_dir, target) + cmd, _ = self.GetSwarmingCommand(self.args.target[0], vals) + return self._RunUnderSwarming(build_dir, target, cmd) else: return self._RunLocallyIsolated(build_dir, target) - def _RunUnderSwarming(self, build_dir, target): + def _RunUnderSwarming(self, build_dir, target, isolate_cmd): + cas_instance = 'chromium-swarm' + swarming_server = 'chromium-swarm.appspot.com' # TODO(dpranke): Look up the information for the target in # the //testing/buildbot.json file, if possible, so that we # can determine the isolate target, command line, and additional @@ -338,7 +339,7 @@ def _RunUnderSwarming(self, build_dir, target): # TODO(dpranke): Also, add support for sharding and merging results. 
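[Editorial note] For readers following the new `-s/--swarmed` path: the hunks below archive the isolate to a CAS instance and then trigger and collect the task with the `tools/luci-go` binaries instead of `swarming_client`. A rough outline of that three-step flow, using only flags visible in this change; the build dir, target name and exact binary invocation are illustrative assumptions, not taken from mb.py:

import json
import os
import subprocess
import tempfile

build_dir, target = 'out/Default', 'base_unittests'  # illustrative values
archive_json = os.path.join(build_dir, target + '.archive.json')

# 1. Archive the isolate to the CAS instance; the digest lands in a JSON file.
subprocess.check_call([
    os.path.join('tools', 'luci-go', 'isolate'), 'archive',
    '-i', os.path.join(build_dir, target + '.isolate'),
    '-cas-instance', 'chromium-swarm',
    '-dump-json', archive_json,
])
with open(archive_json) as f:
    cas_digest = json.load(f)[target]

# 2. Trigger a task from that digest; the task id is read back from task.json.
task_json = os.path.join(tempfile.mkdtemp(prefix='mb_'), 'task.json')
subprocess.check_call([
    os.path.join('tools', 'luci-go', 'swarming'), 'trigger',
    '-digest', cas_digest,
    '-server', 'chromium-swarm.appspot.com',
    '-relative-cwd', build_dir,
    '-dump-json', task_json,
    '--', './' + target,  # placeholder for the real isolate command
])
with open(task_json) as f:
    task_id = json.load(f)['tasks'][0]['task_id']

# 3. Collect the result, streaming the task's stdout to the console.
subprocess.check_call([
    os.path.join('tools', 'luci-go', 'swarming'), 'collect',
    '-server', 'chromium-swarm.appspot.com',
    '-task-output-stdout=console', task_id,
])
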
dimensions = [] for k, v in self.args.dimensions: - dimensions += ['-d', k, v] + dimensions += ['-d', '%s=%s' % (k, v)] archive_json_path = self.ToSrcRelPath( '%s/%s.archive.json' % (build_dir, target)) @@ -347,13 +348,29 @@ def _RunUnderSwarming(self, build_dir, target): 'archive', '-i', self.ToSrcRelPath('%s/%s.isolate' % (build_dir, target)), - '-s', - self.ToSrcRelPath('%s/%s.isolated' % (build_dir, target)), - '-I', 'isolateserver.appspot.com', - '-dump-json', archive_json_path, - ] - ret, _, _ = self.Run(cmd, force_verbose=False) + '-cas-instance', + cas_instance, + '-dump-json', + archive_json_path, + ] + + # Talking to the isolateserver may fail because we're not logged in. + # We trap the command explicitly and rewrite the error output so that + # the error message is actually correct for a Chromium check out. + self.PrintCmd(cmd, env=None) + ret, out, err = self.Run(cmd, force_verbose=False) if ret: + self.Print(' -> returned %d' % ret) + if out: + self.Print(out, end='') + if err: + # The swarming client will return an exit code of 2 (via + # argparse.ArgumentParser.error()) and print a message to indicate + # that auth failed, so we have to parse the message to check. + if (ret == 2 and 'Please login to' in err): + err = err.replace(' auth.py', ' tools/swarming_client/auth.py') + self.Print(err, end='', file=sys.stderr) + return ret try: @@ -363,7 +380,7 @@ def _RunUnderSwarming(self, build_dir, target): 'Failed to read JSON file "%s"' % archive_json_path, file=sys.stderr) return 1 try: - isolated_hash = archive_hashes[target] + cas_digest = archive_hashes[target] except Exception: self.Print( 'Cannot find hash for "%s" in "%s", file content: %s' % @@ -371,16 +388,44 @@ def _RunUnderSwarming(self, build_dir, target): file=sys.stderr) return 1 + try: + json_dir = self.TempDir() + json_file = self.PathJoin(json_dir, 'task.json') + + cmd = [ + self.PathJoin('tools', 'luci-go', 'swarming'), + 'trigger', + '-digest', + cas_digest, + '-server', + swarming_server, + '-tag=purpose:user-debug-mb', + '-relative-cwd', + self.ToSrcRelPath(build_dir), + '-dump-json', + json_file, + ] + dimensions + ['--'] + list(isolate_cmd) + + if self.args.extra_args: + cmd += ['--'] + self.args.extra_args + self.Print('') + ret, _, _ = self.Run(cmd, force_verbose=True, buffer_output=False) + if ret: + return ret + task_json = self.ReadFile(json_file) + task_id = json.loads(task_json)["tasks"][0]['task_id'] + finally: + if json_dir: + self.RemoveDirectory(json_dir) + cmd = [ - self.executable, - self.PathJoin('tools', 'swarming_client', 'swarming.py'), - 'run', - '-s', isolated_hash, - '-I', 'isolateserver.appspot.com', - '-S', 'chromium-swarm.appspot.com', - ] + dimensions - if self.args.extra_args: - cmd += ['--'] + self.args.extra_args + self.PathJoin('tools', 'luci-go', 'swarming'), + 'collect', + '-server', + swarming_server, + '-task-output-stdout=console', + task_id, + ] ret, _, _ = self.Run(cmd, force_verbose=True, buffer_output=False) return ret @@ -685,7 +730,7 @@ def RunGNGen(self, vals): raise MBErr('did not generate any of %s' % ', '.join(runtime_deps_targets)) - command, extra_files = self.GetIsolateCommand(target, vals) + command, extra_files = self.GetSwarmingCommand(target, vals) runtime_deps = self.ReadFile(runtime_deps_path).splitlines() @@ -703,7 +748,7 @@ def RunGNIsolate(self, vals): label = labels[0] build_dir = self.args.path[0] - command, extra_files = self.GetIsolateCommand(target, vals) + command, extra_files = self.GetSwarmingCommand(target, vals) cmd = self.GNCmd('desc', 
build_dir, label, 'runtime_deps') ret, out, _ = self.Call(cmd) @@ -826,7 +871,7 @@ def GNArgs(self, vals): gn_args = ('import("%s")\n' % vals['args_file']) + gn_args return gn_args - def GetIsolateCommand(self, target, vals): + def GetSwarmingCommand(self, target, vals): isolate_map = self.ReadIsolateMap() test_type = isolate_map[target]['type'] @@ -1190,6 +1235,10 @@ def RemoveDirectory(self, abs_path): else: shutil.rmtree(abs_path, ignore_errors=True) + def TempDir(self): + # This function largely exists so it can be overriden for testing. + return tempfile.mkdtemp(prefix='mb_') + def TempFile(self, mode='w'): # This function largely exists so it can be overriden for testing. return tempfile.NamedTemporaryFile(mode=mode, delete=False) diff --git a/tools_webrtc/mb/mb_config.pyl b/tools_webrtc/mb/mb_config.pyl index 4bb04aa403..253a57acc5 100644 --- a/tools_webrtc/mb/mb_config.pyl +++ b/tools_webrtc/mb/mb_config.pyl @@ -20,8 +20,6 @@ 'builder_groups': { 'client.webrtc': { # iOS - 'iOS32 Debug': 'ios_debug_bot_arm', - 'iOS32 Release': 'ios_release_bot_arm', 'iOS64 Debug': 'ios_debug_bot_arm64', 'iOS64 Release': 'ios_release_bot_arm64', 'iOS64 Sim Debug (iOS 12)': 'ios_debug_bot_x64', @@ -33,6 +31,7 @@ 'Mac64 Release': 'release_bot_x64', 'Mac64 Builder': 'pure_release_bot_x64', 'Mac Asan': 'mac_asan_clang_release_bot_x64', + 'MacARM64 M1 Release': 'release_bot_arm64', # Linux 'Linux32 Debug': 'no_h264_debug_bot_x86', @@ -44,7 +43,6 @@ 'Linux64 Builder': 'pure_release_bot_x64', 'Linux64 Debug (ARM)': 'debug_bot_arm64', 'Linux64 Release (ARM)': 'release_bot_arm64', - 'Linux64 Release (GCC)': 'gcc_release_bot_x64', 'Linux Asan': 'asan_lsan_clang_release_bot_x64', 'Linux MSan': 'msan_clang_release_bot_x64', 'Linux Tsan v2': 'tsan_clang_release_bot_x64', @@ -92,7 +90,6 @@ 'Win64 Debug (Clang)': 'win_clang_debug_bot_x64', 'Win64 Release (Clang)': 'win_clang_release_bot_x64', 'Win64 ASan': 'win_asan_clang_release_bot_x64', - 'Win64 UWP': 'win_uwp_release_bot_x64', 'Win (more configs)': { 'bwe_test_logging': 'bwe_test_logging_x86', @@ -113,6 +110,7 @@ 'Perf Android64 (M Nexus5X)': 'release_bot_x64', 'Perf Android64 (O Pixel2)': 'release_bot_x64', 'Perf Linux Trusty': 'release_bot_x64', + 'Perf Linux Bionic': 'release_bot_x64', 'Perf Mac 10.11': 'release_bot_x64', 'Perf Win7': 'release_bot_x64', }, @@ -149,8 +147,6 @@ }, 'tryserver.webrtc': { # iOS - 'ios_compile_arm_dbg': 'ios_debug_bot_arm', - 'ios_compile_arm_rel': 'ios_release_bot_arm', 'ios_compile_arm64_dbg': 'ios_debug_bot_arm64', 'ios_compile_arm64_rel': 'ios_release_bot_arm64', 'ios_sim_x64_dbg_ios12': 'ios_debug_bot_x64', @@ -173,7 +169,6 @@ 'linux_compile_arm_rel': 'release_bot_arm', 'linux_compile_arm64_dbg': 'debug_bot_arm64', 'linux_compile_arm64_rel': 'release_bot_arm64', - 'linux_compile_gcc_rel': 'gcc_release_bot_x64', 'linux_dbg': 'debug_bot_x64', 'linux_rel': 'release_bot_x64', 'linux_x86_rel': 'release_bot_x86', @@ -235,7 +230,6 @@ 'win_asan': 'win_asan_clang_release_bot_x64', 'win_x64_clang_dbg_win8': 'win_clang_debug_bot_x64', 'win_x64_clang_dbg_win10': 'win_clang_debug_bot_x64', - 'win_x64_uwp': 'win_uwp_release_bot_x64', 'win_x86_more_configs': { 'bwe_test_logging': 'bwe_test_logging_x86', @@ -253,9 +247,6 @@ # we might have mac, win, and linux bots all using the 'release_bot' config). 'configs': { # Linux, Mac and Windows - 'gcc_release_bot_x64': [ - 'gcc', 'release_bot_no_goma', 'x64', 'no_rtc_tests' - ], # TODO(kjellander): Restore Goma for this when crbug.com/726706 is fixed. 
'debug_bot_arm': [ 'openh264', 'debug', 'arm' @@ -310,7 +301,7 @@ ], 'libfuzzer_asan_release_bot_x64': [ 'libfuzzer', 'asan', 'optimize_for_fuzzing', 'openh264', 'release_bot', - 'x64', 'no_rtc_tests' + 'x64' ], # Windows @@ -345,10 +336,6 @@ 'asan', 'clang', 'full_symbols', 'openh264', 'release_bot', 'x64', 'win_fastlink', ], - 'win_uwp_release_bot_x64': [ - # UWP passes compiler flags that are not supported by goma. - 'no_clang', 'openh264', 'x64', 'winuwp', 'release_bot_no_goma' - ], # Mac 'mac_asan_clang_release_bot_x64': [ @@ -391,14 +378,6 @@ ], # iOS - 'ios_debug_bot_arm': [ - 'ios', 'debug_bot', 'arm', 'no_ios_code_signing', 'ios_use_goma_rbe', - 'xctest', - ], - 'ios_release_bot_arm': [ - 'ios', 'release_bot', 'arm', 'no_ios_code_signing', 'ios_use_goma_rbe', - 'xctest', - ], 'ios_debug_bot_arm64': [ 'ios', 'debug_bot', 'arm64', 'no_ios_code_signing', 'ios_use_goma_rbe', 'xctest', @@ -504,11 +483,6 @@ 'gn_args': 'symbol_level=2', }, - 'gcc': { - 'gn_args': ('is_clang=false use_sysroot=false ' - 'treat_warnings_as_errors=false'), - }, - 'goma': { 'gn_args': 'use_goma=true', }, @@ -557,10 +531,6 @@ 'gn_args': 'use_lld=false', }, - 'no_rtc_tests': { - 'gn_args': 'rtc_include_tests=false', - }, - 'openh264': { 'gn_args': 'ffmpeg_branding="Chrome" rtc_use_h264=true', }, @@ -629,10 +599,6 @@ 'gn_args': 'rtc_enable_sctp=false', }, - 'winuwp': { - 'gn_args': 'target_os="winuwp"', - }, - 'win_undef_unicode': { 'gn_args': 'rtc_win_undef_unicode=true', }, diff --git a/tools_webrtc/mb/mb_unittest.py b/tools_webrtc/mb/mb_unittest.py index eb11d092f8..fc359d9995 100755 --- a/tools_webrtc/mb/mb_unittest.py +++ b/tools_webrtc/mb/mb_unittest.py @@ -13,7 +13,9 @@ import json import StringIO import os +import re import sys +import tempfile import unittest import mb @@ -32,6 +34,7 @@ def __init__(self, win32=False): self.platform = 'win32' self.executable = 'c:\\python\\python.exe' self.sep = '\\' + self.cwd = 'c:\\fake_src\\out\\Default' else: self.src_dir = '/fake_src' self.default_config = '/fake_src/tools_webrtc/mb/mb_config.pyl' @@ -39,8 +42,10 @@ def __init__(self, win32=False): self.executable = '/usr/bin/python' self.platform = 'linux2' self.sep = '/' + self.cwd = '/fake_src/out/Default' self.files = {} + self.dirs = set() self.calls = [] self.cmds = [] self.cross_compile = None @@ -52,21 +57,24 @@ def ExpandUser(self, path): return '$HOME/%s' % path def Exists(self, path): - return self.files.get(path) is not None + abs_path = self._AbsPath(path) + return (self.files.get(abs_path) is not None or abs_path in self.dirs) def MaybeMakeDirectory(self, path): - self.files[path] = True + abpath = self._AbsPath(path) + self.dirs.add(abpath) def PathJoin(self, *comps): return self.sep.join(comps) def ReadFile(self, path): - return self.files[path] + return self.files[self._AbsPath(path)] def WriteFile(self, path, contents, force_verbose=False): if self.args.dryrun or self.args.verbose or force_verbose: self.Print('\nWriting """\\\n%s""" to %s.\n' % (contents, path)) - self.files[path] = contents + abpath = self._AbsPath(path) + self.files[abpath] = contents def Call(self, cmd, env=None, buffer_output=True): self.calls.append(cmd) @@ -83,18 +91,34 @@ def Print(self, *args, **kwargs): else: self.out += sep.join(args) + end + def TempDir(self): + tmp_dir = os.path.join(tempfile.gettempdir(), 'mb_test') + self.dirs.add(tmp_dir) + return tmp_dir + def TempFile(self, mode='w'): return FakeFile(self.files) def RemoveFile(self, path): - del self.files[path] + abpath = self._AbsPath(path) + self.files[abpath] 
= None def RemoveDirectory(self, path): - self.rmdirs.append(path) - files_to_delete = [f for f in self.files if f.startswith(path)] + abpath = self._AbsPath(path) + self.rmdirs.append(abpath) + files_to_delete = [f for f in self.files if f.startswith(abpath)] for f in files_to_delete: self.files[f] = None + def _AbsPath(self, path): + if not ((self.platform == 'win32' and path.startswith('c:')) or + (self.platform != 'win32' and path.startswith('/'))): + path = self.PathJoin(self.cwd, path) + if self.sep == '\\': + return re.sub(r'\\+', r'\\', path) + else: + return re.sub('/+', '/', path) + class FakeFile(object): def __init__(self, files): @@ -176,13 +200,20 @@ def fake_mbw(self, files=None, win32=False): mbw.files[path] = contents return mbw - def check(self, args, mbw=None, files=None, out=None, err=None, ret=None): + def check(self, args, mbw=None, files=None, out=None, err=None, ret=None, + env=None): if not mbw: mbw = self.fake_mbw(files) - actual_ret = mbw.Main(args) - - self.assertEqual(actual_ret, ret) + try: + prev_env = os.environ.copy() + os.environ = env if env else prev_env + actual_ret = mbw.Main(args) + finally: + os.environ = prev_env + self.assertEqual( + actual_ret, ret, + "ret: %s, out: %s, err: %s" % (actual_ret, mbw.out, mbw.err)) if out is not None: self.assertEqual(mbw.out, out) if err is not None: @@ -564,8 +595,8 @@ def test_isolate_windowed_test_launcher_linux(self): def test_gen_windowed_test_launcher_win(self): files = { - '/tmp/swarming_targets': 'unittests\n', - '/fake_src/testing/buildbot/gn_isolate_map.pyl': ( + 'c:\\fake_src\\out\\Default\\tmp\\swarming_targets': 'unittests\n', + 'c:\\fake_src\\testing\\buildbot\\gn_isolate_map.pyl': ( "{'unittests': {" " 'label': '//somewhere:unittests'," " 'type': 'windowed_test_launcher'," @@ -579,9 +610,10 @@ def test_gen_windowed_test_launcher_win(self): mbw = self.fake_mbw(files=files, win32=True) self.check(['gen', '-c', 'debug_goma', - '--swarming-targets-file', '/tmp/swarming_targets', + '--swarming-targets-file', + 'c:\\fake_src\\out\\Default\\tmp\\swarming_targets', '--isolate-map-file', - '/fake_src/testing/buildbot/gn_isolate_map.pyl', + 'c:\\fake_src\\testing\\buildbot\\gn_isolate_map.pyl', '//out/Default'], mbw=mbw, ret=0) isolate_file = mbw.files['c:\\fake_src\\out\\Default\\unittests.isolate'] @@ -750,23 +782,40 @@ def test_run(self): def test_run_swarmed(self): files = { - '/fake_src/testing/buildbot/gn_isolate_map.pyl': ( - "{'base_unittests': {" - " 'label': '//base:base_unittests'," - " 'type': 'raw'," - " 'args': []," - "}}\n" - ), - '/fake_src/out/Default/base_unittests.runtime_deps': ( - "base_unittests\n" - ), - 'out/Default/base_unittests.archive.json': ( - "{\"base_unittests\":\"fake_hash\"}"), + '/fake_src/testing/buildbot/gn_isolate_map.pyl': + ("{'base_unittests': {" + " 'label': '//base:base_unittests'," + " 'type': 'console_test_launcher'," + "}}\n"), + '/fake_src/out/Default/base_unittests.runtime_deps': + ("base_unittests\n"), + '/fake_src/out/Default/base_unittests.archive.json': + ("{\"base_unittests\":\"fake_hash\"}"), + '/fake_src/third_party/depot_tools/cipd_manifest.txt': + ("# vpython\n" + "/some/vpython/pkg git_revision:deadbeef\n"), } + task_json = json.dumps({'tasks': [{'task_id': '00000'}]}) + collect_json = json.dumps({'00000': {'results': {}}}) mbw = self.fake_mbw(files=files) + mbw.files[mbw.PathJoin(mbw.TempDir(), 'task.json')] = task_json + mbw.files[mbw.PathJoin(mbw.TempDir(), 'collect_output.json')] = collect_json + original_impl = mbw.ToSrcRelPath + + def 
to_src_rel_path_stub(path): + if path.endswith('base_unittests.archive.json'): + return 'base_unittests.archive.json' + return original_impl(path) + + mbw.ToSrcRelPath = to_src_rel_path_stub + self.check(['run', '-s', '-c', 'debug_goma', '//out/Default', 'base_unittests'], mbw=mbw, ret=0) + mbw = self.fake_mbw(files=files) + mbw.files[mbw.PathJoin(mbw.TempDir(), 'task.json')] = task_json + mbw.files[mbw.PathJoin(mbw.TempDir(), 'collect_output.json')] = collect_json + mbw.ToSrcRelPath = to_src_rel_path_stub self.check(['run', '-s', '-c', 'debug_goma', '-d', 'os', 'Win7', '//out/Default', 'base_unittests'], mbw=mbw, ret=0) diff --git a/tools_webrtc/msan/suppressions.txt b/tools_webrtc/msan/suppressions.txt index ce8b14292e..47a0dff16f 100644 --- a/tools_webrtc/msan/suppressions.txt +++ b/tools_webrtc/msan/suppressions.txt @@ -4,8 +4,8 @@ # # Please think twice before you add or remove these rules. -# This is a stripped down copy of Chromium's blacklist.txt, to enable -# adding WebRTC-specific blacklist entries. +# This is a stripped down copy of Chromium's ignorelist.txt, to enable +# adding WebRTC-specific ignorelist entries. # Uninit in zlib. http://crbug.com/116277 fun:*MOZ_Z_deflate* diff --git a/tools_webrtc/perf/catapult_uploader.py b/tools_webrtc/perf/catapult_uploader.py index de7bd81c73..a10dd84cb5 100644 --- a/tools_webrtc/perf/catapult_uploader.py +++ b/tools_webrtc/perf/catapult_uploader.py @@ -145,13 +145,13 @@ def _CheckFullUploadInfo(url, upload_token, '?additional_info=measurements', method='GET', headers=headers) - print 'Full upload info: %r.' % content - if response.status != 200: print 'Failed to reach the dashboard to get full upload info.' return False resp_json = json.loads(content) + print 'Full upload info: %s.' % json.dumps(resp_json, indent=4) + if 'measurements' in resp_json: measurements_cnt = len(resp_json['measurements']) not_completed_state_cnt = len([ @@ -247,10 +247,13 @@ def UploadToDashboard(options): print 'Upload completed.' return 0 - if response.status != 200 or resp_json['state'] == 'FAILED': - print('Upload failed with %d: %s\n\n%s' % (response.status, - response.reason, - str(resp_json))) + if response.status != 200: + print('Upload status poll failed with %d: %s' % (response.status, + response.reason)) + return 1 + + if resp_json['state'] == 'FAILED': + print 'Upload failed.' return 1 print('Upload wasn\'t completed in a given time: %d seconds.' 
% diff --git a/tools_webrtc/perf/webrtc_dashboard_upload.py b/tools_webrtc/perf/webrtc_dashboard_upload.py index a709af5dcd..19db0250cf 100644 --- a/tools_webrtc/perf/webrtc_dashboard_upload.py +++ b/tools_webrtc/perf/webrtc_dashboard_upload.py @@ -50,7 +50,8 @@ def _CreateParser(): help='Which dashboard to use.') parser.add_argument('--input-results-file', type=argparse.FileType(), required=True, - help='A JSON file with output from WebRTC tests.') + help='A HistogramSet proto file with output from ' + 'WebRTC tests.') parser.add_argument('--output-json-file', type=argparse.FileType('w'), help='Where to write the output (for debugging).') parser.add_argument('--outdir', required=True, diff --git a/tools_webrtc/sanitizers/tsan_suppressions_webrtc.cc b/tools_webrtc/sanitizers/tsan_suppressions_webrtc.cc index 3177fbc74a..3eb85e9fb5 100644 --- a/tools_webrtc/sanitizers/tsan_suppressions_webrtc.cc +++ b/tools_webrtc/sanitizers/tsan_suppressions_webrtc.cc @@ -31,8 +31,6 @@ char kTSanDefaultSuppressions[] = // rtc_unittests // https://code.google.com/p/webrtc/issues/detail?id=2080 "race:rtc_base/logging.cc\n" - "race:rtc_base/shared_exclusive_lock_unittest.cc\n" - "race:rtc_base/signal_thread_unittest.cc\n" // rtc_pc_unittests // https://code.google.com/p/webrtc/issues/detail?id=2079 diff --git a/tools_webrtc/ubsan/suppressions.txt b/tools_webrtc/ubsan/suppressions.txt index 50b66e915a..dc76f38c20 100644 --- a/tools_webrtc/ubsan/suppressions.txt +++ b/tools_webrtc/ubsan/suppressions.txt @@ -1,7 +1,7 @@ ############################################################################# -# UBSan blacklist. +# UBSan ignorelist. # -# This is a WebRTC-specific replacement of Chromium's blacklist.txt. +# This is a WebRTC-specific replacement of Chromium's ignorelist.txt. # Only exceptions for third party libraries go here. WebRTC's code should use # the RTC_NO_SANITIZE macro. Please think twice before adding new exceptions. diff --git a/tools_webrtc/ubsan/vptr_suppressions.txt b/tools_webrtc/ubsan/vptr_suppressions.txt index 739de36659..617ba88f98 100644 --- a/tools_webrtc/ubsan/vptr_suppressions.txt +++ b/tools_webrtc/ubsan/vptr_suppressions.txt @@ -1,5 +1,5 @@ ############################################################################# -# UBSan vptr blacklist. +# UBSan vptr ignorelist. # Function and type based blacklisting use a mangled name, and it is especially # tricky to represent C++ types. For now, any possible changes by name manglings # are simply represented as wildcard expressions of regexp, and thus it might be @@ -8,7 +8,7 @@ # Please think twice before you add or remove these rules. # # This is a stripped down copy of Chromium's vptr_blacklist.txt, to enable -# adding WebRTC-specific blacklist entries. +# adding WebRTC-specific ignorelist entries. ############################################################################# # Using raw pointer values. diff --git a/tools_webrtc/version_updater/update_version.py b/tools_webrtc/version_updater/update_version.py index d5a72b8710..3c2be3fe75 100644 --- a/tools_webrtc/version_updater/update_version.py +++ b/tools_webrtc/version_updater/update_version.py @@ -42,6 +42,14 @@ def _RemovePreviousUpdateBranch(): logging.info('No branch to remove') +def _GetLastAuthor(): + """Returns a string with the author of the last commit.""" + author = subprocess.check_output(['git', 'log', + '-1', + '--pretty=format:"%an"']).splitlines() + return author + + def _GetBranches(): """Returns a tuple (active, branches). 
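[Editorial note] `_GetLastAuthor()` above feeds the early-exit added to `main()` just below: if the most recent commit was made by the version-updater bot, the script skips creating another CL. A minimal sketch of that check, assuming a plain-text author name; the bot name comes from this change, while the helper itself is illustrative rather than the script's own `_GetLastAuthor`:

import subprocess

def last_commit_author():
    # %an expands to the author name of the most recent commit.
    out = subprocess.check_output(['git', 'log', '-1', '--pretty=format:%an'])
    return out.decode('utf-8').strip()

if last_commit_author() == 'webrtc-version-updater':
    print('Last commit is a version change, skipping CL.')
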
@@ -142,11 +150,15 @@ def main(): if opts.clean: _RemovePreviousUpdateBranch() + if _GetLastAuthor() == 'webrtc-version-updater': + logging.info('Last commit is a version change, skipping CL.') + return 0 + version_filename = os.path.join(CHECKOUT_SRC_DIR, 'call', 'version.cc') _CreateUpdateBranch() _UpdateWebRTCVersion(version_filename) if _IsTreeClean(): - logging.info("No WebRTC version change detected, skipping CL.") + logging.info('No WebRTC version change detected, skipping CL.') else: _LocalCommit() logging.info('Uploading CL...') diff --git a/tools_webrtc/whitespace.txt b/tools_webrtc/whitespace.txt index f85a7d2cf8..b1cfabb590 100644 --- a/tools_webrtc/whitespace.txt +++ b/tools_webrtc/whitespace.txt @@ -14,3 +14,4 @@ Foo Bar Baz Bur Alios ego vidi ventos; alias prospexi animo procellas - Cicero +Lahiru modifiying the line numbber 17Lahiru modifiying the line numbber 17 diff --git a/video/BUILD.gn b/video/BUILD.gn index f3e5817a84..7743aba944 100644 --- a/video/BUILD.gn +++ b/video/BUILD.gn @@ -12,8 +12,6 @@ rtc_library("video") { sources = [ "buffered_frame_decryptor.cc", "buffered_frame_decryptor.h", - "call_stats.cc", - "call_stats.h", "call_stats2.cc", "call_stats2.h", "encoder_rtcp_feedback.cc", @@ -22,18 +20,12 @@ rtc_library("video") { "quality_limitation_reason_tracker.h", "quality_threshold.cc", "quality_threshold.h", - "receive_statistics_proxy.cc", - "receive_statistics_proxy.h", "receive_statistics_proxy2.cc", "receive_statistics_proxy2.h", "report_block_stats.cc", "report_block_stats.h", - "rtp_streams_synchronizer.cc", - "rtp_streams_synchronizer.h", "rtp_streams_synchronizer2.cc", "rtp_streams_synchronizer2.h", - "rtp_video_stream_receiver.cc", - "rtp_video_stream_receiver.h", "rtp_video_stream_receiver2.cc", "rtp_video_stream_receiver2.h", "rtp_video_stream_receiver_frame_transformer_delegate.cc", @@ -48,20 +40,14 @@ rtc_library("video") { "stream_synchronization.h", "transport_adapter.cc", "transport_adapter.h", - "video_quality_observer.cc", - "video_quality_observer.h", "video_quality_observer2.cc", "video_quality_observer2.h", - "video_receive_stream.cc", - "video_receive_stream.h", "video_receive_stream2.cc", "video_receive_stream2.h", "video_send_stream.cc", "video_send_stream.h", "video_send_stream_impl.cc", "video_send_stream_impl.h", - "video_stream_decoder.cc", - "video_stream_decoder.h", "video_stream_decoder2.cc", "video_stream_decoder2.h", ] @@ -75,11 +61,13 @@ rtc_library("video") { "../api:libjingle_peerconnection_api", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api:transport_api", "../api/crypto:frame_decryptor_interface", "../api/crypto:options", "../api/rtc_event_log", "../api/task_queue", + "../api/units:time_delta", "../api/units:timestamp", "../api/video:encoded_image", "../api/video:recordable_encoded_frame", @@ -97,7 +85,6 @@ rtc_library("video") { "../call:rtp_sender", "../call:video_stream_api", "../common_video", - "../media:rtc_h264_profile_id", "../modules:module_api", "../modules:module_api_public", "../modules/pacing", @@ -111,7 +98,6 @@ rtc_library("video") { "../modules/video_coding:nack_module", "../modules/video_coding:video_codec_interface", "../modules/video_coding:video_coding_utility", - "../modules/video_coding/deprecated:nack_module", "../modules/video_processing", "../rtc_base:checks", "../rtc_base:rate_limiter", @@ -120,6 +106,7 @@ rtc_library("video") { "../rtc_base:rtc_numerics", "../rtc_base:rtc_task_queue", "../rtc_base:stringutils", + "../rtc_base:threading", 
"../rtc_base:weak_ptr", "../rtc_base/experiments:alr_experiment", "../rtc_base/experiments:field_trial_parser", @@ -128,7 +115,6 @@ rtc_library("video") { "../rtc_base/experiments:quality_scaling_experiment", "../rtc_base/experiments:rate_control_settings", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", "../rtc_base/system:thread_registry", "../rtc_base/task_utils:pending_task_safety_flag", @@ -154,22 +140,74 @@ rtc_library("video") { } rtc_source_set("video_legacy") { - # TODO(bugs.webrtc.org/11581): These files should be moved to this target: - # - # "call_stats.cc", - # "call_stats.h", - # "receive_statistics_proxy.cc", - # "receive_statistics_proxy.h", - # "rtp_streams_synchronizer.cc", - # "rtp_streams_synchronizer.h", - # "rtp_video_stream_receiver.cc", - # "rtp_video_stream_receiver.h", - # "video_quality_observer.cc", - # "video_quality_observer.h", - # "video_receive_stream.cc", - # "video_receive_stream.h", - # "video_stream_decoder.cc", - # "video_stream_decoder.h", + sources = [ + "call_stats.cc", + "call_stats.h", + "receive_statistics_proxy.cc", + "receive_statistics_proxy.h", + "rtp_streams_synchronizer.cc", + "rtp_streams_synchronizer.h", + "rtp_video_stream_receiver.cc", + "rtp_video_stream_receiver.h", + "video_quality_observer.cc", + "video_quality_observer.h", + "video_receive_stream.cc", + "video_receive_stream.h", + "video_stream_decoder.cc", + "video_stream_decoder.h", + ] + deps = [ + ":frame_dumping_decoder", + ":video", + "../api:array_view", + "../api:scoped_refptr", + "../api:sequence_checker", + "../api/crypto:frame_decryptor_interface", + "../api/task_queue", + "../api/units:timestamp", + "../api/video:encoded_image", + "../api/video:recordable_encoded_frame", + "../api/video:video_frame", + "../api/video:video_rtp_headers", + "../api/video_codecs:video_codecs_api", + "../call:call_interfaces", + "../call:rtp_interfaces", + "../call:rtp_receiver", # For RtxReceiveStream. 
+ "../call:video_stream_api", + "../common_video", + "../modules:module_api", + "../modules/pacing", + "../modules/remote_bitrate_estimator", + "../modules/rtp_rtcp", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/rtp_rtcp:rtp_rtcp_legacy", + "../modules/rtp_rtcp:rtp_video_header", + "../modules/utility", + "../modules/video_coding", + "../modules/video_coding:video_codec_interface", + "../modules/video_coding:video_coding_utility", + "../modules/video_coding/deprecated:nack_module", + "../rtc_base:checks", + "../rtc_base:rtc_base_approved", + "../rtc_base:rtc_numerics", + "../rtc_base:rtc_task_queue", + "../rtc_base/experiments:field_trial_parser", + "../rtc_base/experiments:keyframe_interval_settings_experiment", + "../rtc_base/synchronization:mutex", + "../rtc_base/system:no_unique_address", + "../rtc_base/system:thread_registry", + "../rtc_base/task_utils:to_queued_task", + "../system_wrappers", + "../system_wrappers:field_trial", + "../system_wrappers:metrics", + ] + if (!build_with_mozilla) { + deps += [ "../media:rtc_media_base" ] + } + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/types:optional", + ] } rtc_library("video_stream_decoder_impl") { @@ -181,6 +219,7 @@ rtc_library("video_stream_decoder_impl") { ] deps = [ + "../api:sequence_checker", "../api/task_queue", "../api/video:encoded_frame", "../api/video:video_frame", @@ -237,6 +276,7 @@ rtc_library("video_stream_encoder_impl") { deps = [ "../api:rtp_parameters", + "../api:sequence_checker", "../api/adaptation:resource_adaptation_api", "../api/task_queue:task_queue", "../api/units:data_rate", @@ -268,13 +308,13 @@ rtc_library("video_stream_encoder_impl") { "../rtc_base:timeutils", "../rtc_base/experiments:alr_experiment", "../rtc_base/experiments:balanced_degradation_settings", + "../rtc_base/experiments:encoder_info_settings", "../rtc_base/experiments:field_trial_parser", "../rtc_base/experiments:quality_rampup_experiment", "../rtc_base/experiments:quality_scaler_settings", "../rtc_base/experiments:quality_scaling_experiment", "../rtc_base/experiments:rate_control_settings", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/system:no_unique_address", "../rtc_base/task_utils:pending_task_safety_flag", "../rtc_base/task_utils:repeating_task", @@ -298,238 +338,239 @@ if (rtc_include_tests) { "../test:test_support", ] } - rtc_library("video_quality_test") { - testonly = true + if (!build_with_chromium) { + rtc_library("video_quality_test") { + testonly = true - # Only targets in this file and api/ can depend on this. 
- visibility = [ - ":*", - "../api:create_video_quality_test_fixture_api", - ] - sources = [ - "video_analyzer.cc", - "video_analyzer.h", - "video_quality_test.cc", - "video_quality_test.h", - ] - deps = [ - ":frame_dumping_decoder", - "../api:create_frame_generator", - "../api:fec_controller_api", - "../api:frame_generator_api", - "../api:libjingle_peerconnection_api", - "../api:rtc_event_log_output_file", - "../api:test_dependency_factory", - "../api:video_quality_test_fixture_api", - "../api/numerics", - "../api/rtc_event_log:rtc_event_log_factory", - "../api/task_queue", - "../api/task_queue:default_task_queue_factory", - "../api/video:builtin_video_bitrate_allocator_factory", - "../api/video:video_bitrate_allocator_factory", - "../api/video:video_frame", - "../api/video:video_rtp_headers", - "../api/video_codecs:video_codecs_api", - "../call:fake_network", - "../call:simulated_network", - "../common_video", - "../media:rtc_audio_video", - "../media:rtc_encoder_simulcast_proxy", - "../media:rtc_internal_video_codecs", - "../media:rtc_media_base", - "../modules/audio_device:audio_device_api", - "../modules/audio_device:audio_device_module_from_input_and_output", - "../modules/audio_device:windows_core_audio_utility", - "../modules/audio_mixer:audio_mixer_impl", - "../modules/rtp_rtcp", - "../modules/rtp_rtcp:rtp_rtcp_format", - "../modules/video_coding", - "../modules/video_coding:video_coding_utility", - "../modules/video_coding:webrtc_h264", - "../modules/video_coding:webrtc_multiplex", - "../modules/video_coding:webrtc_vp8", - "../modules/video_coding:webrtc_vp9", - "../rtc_base:rtc_base_approved", - "../rtc_base:rtc_base_tests_utils", - "../rtc_base:rtc_numerics", - "../rtc_base:task_queue_for_test", - "../rtc_base/synchronization:mutex", - "../rtc_base/task_utils:repeating_task", - "../system_wrappers", - "../test:fake_video_codecs", - "../test:fileutils", - "../test:perf_test", - "../test:platform_video_capturer", - "../test:rtp_test_utils", - "../test:test_common", - "../test:test_renderer", - "../test:test_support", - "../test:test_support_test_artifacts", - "../test:video_test_common", - "../test:video_test_support", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/algorithm:container", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - ] + # Only targets in this file and api/ can depend on this. 
+ visibility = [ + ":*", + "../api:create_video_quality_test_fixture_api", + ] + sources = [ + "video_analyzer.cc", + "video_analyzer.h", + "video_quality_test.cc", + "video_quality_test.h", + ] + deps = [ + ":frame_dumping_decoder", + "../api:create_frame_generator", + "../api:fec_controller_api", + "../api:frame_generator_api", + "../api:libjingle_peerconnection_api", + "../api:rtc_event_log_output_file", + "../api:test_dependency_factory", + "../api:video_quality_test_fixture_api", + "../api/numerics", + "../api/rtc_event_log:rtc_event_log_factory", + "../api/task_queue", + "../api/task_queue:default_task_queue_factory", + "../api/video:builtin_video_bitrate_allocator_factory", + "../api/video:video_bitrate_allocator_factory", + "../api/video:video_frame", + "../api/video:video_rtp_headers", + "../api/video_codecs:video_codecs_api", + "../call:fake_network", + "../call:simulated_network", + "../common_video", + "../media:rtc_audio_video", + "../media:rtc_encoder_simulcast_proxy", + "../media:rtc_internal_video_codecs", + "../media:rtc_media_base", + "../modules/audio_device:audio_device_api", + "../modules/audio_device:audio_device_module_from_input_and_output", + "../modules/audio_device:windows_core_audio_utility", + "../modules/audio_mixer:audio_mixer_impl", + "../modules/rtp_rtcp", + "../modules/rtp_rtcp:rtp_rtcp_format", + "../modules/video_coding", + "../modules/video_coding:video_coding_utility", + "../modules/video_coding:webrtc_h264", + "../modules/video_coding:webrtc_multiplex", + "../modules/video_coding:webrtc_vp8", + "../modules/video_coding:webrtc_vp9", + "../rtc_base:rtc_base_approved", + "../rtc_base:rtc_base_tests_utils", + "../rtc_base:rtc_numerics", + "../rtc_base:task_queue_for_test", + "../rtc_base/synchronization:mutex", + "../rtc_base/task_utils:repeating_task", + "../system_wrappers", + "../test:fake_video_codecs", + "../test:fileutils", + "../test:perf_test", + "../test:platform_video_capturer", + "../test:rtp_test_utils", + "../test:test_common", + "../test:test_renderer", + "../test:test_support", + "../test:test_support_test_artifacts", + "../test:video_test_common", + "../test:video_test_support", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] - if (is_mac || is_ios) { - deps += [ "../test:video_test_mac" ] + if (is_mac || is_ios) { + deps += [ "../test:video_test_mac" ] + } } - } - rtc_library("video_full_stack_tests") { - testonly = true - - sources = [ "full_stack_tests.cc" ] - deps = [ - ":video_quality_test", - "../api:simulated_network_api", - "../api:test_dependency_factory", - "../api:video_quality_test_fixture_api", - "../api/video_codecs:video_codecs_api", - "../media:rtc_vp9_profile", - "../modules/pacing", - "../modules/video_coding:webrtc_vp9", - "../rtc_base/experiments:alr_experiment", - "../system_wrappers:field_trial", - "../test:field_trial", - "../test:fileutils", - "../test:test_common", - "../test:test_support", - "//testing/gtest", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/types:optional", - ] - } + rtc_library("video_full_stack_tests") { + testonly = true - rtc_library("video_pc_full_stack_tests") { - testonly = true + sources = [ "full_stack_tests.cc" ] + deps = [ + ":video_quality_test", + "../api:simulated_network_api", + "../api:test_dependency_factory", + "../api:video_quality_test_fixture_api", + 
"../api/video_codecs:video_codecs_api", + "../modules/pacing", + "../modules/video_coding:webrtc_vp9", + "../rtc_base/experiments:alr_experiment", + "../system_wrappers:field_trial", + "../test:field_trial", + "../test:fileutils", + "../test:test_common", + "../test:test_support", + "//testing/gtest", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/types:optional", + ] + } - sources = [ "pc_full_stack_tests.cc" ] - deps = [ - "../api:create_network_emulation_manager", - "../api:create_peer_connection_quality_test_frame_generator", - "../api:create_peerconnection_quality_test_fixture", - "../api:frame_generator_api", - "../api:media_stream_interface", - "../api:network_emulation_manager_api", - "../api:peer_connection_quality_test_fixture_api", - "../api:simulated_network_api", - "../api:time_controller", - "../call:simulated_network", - "../media:rtc_vp9_profile", - "../modules/video_coding:webrtc_vp9", - "../system_wrappers:field_trial", - "../test:field_trial", - "../test:fileutils", - "../test:test_support", - "../test/pc/e2e:network_quality_metrics_reporter", - ] - } + rtc_library("video_pc_full_stack_tests") { + testonly = true - rtc_library("video_loopback_lib") { - testonly = true - sources = [ - "video_loopback.cc", - "video_loopback.h", - ] - deps = [ - ":video_quality_test", - "../api:libjingle_peerconnection_api", - "../api:simulated_network_api", - "../api:video_quality_test_fixture_api", - "../api/transport:bitrate_settings", - "../api/video_codecs:video_codecs_api", - "../rtc_base:checks", - "../rtc_base:logging", - "../system_wrappers:field_trial", - "../test:field_trial", - "../test:run_test", - "../test:run_test_interface", - "../test:test_common", - "../test:test_renderer", - "../test:test_support", - "//testing/gtest", - ] - absl_deps = [ - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/types:optional", - ] - } + sources = [ "pc_full_stack_tests.cc" ] + deps = [ + "../api:create_network_emulation_manager", + "../api:create_peer_connection_quality_test_frame_generator", + "../api:create_peerconnection_quality_test_fixture", + "../api:frame_generator_api", + "../api:media_stream_interface", + "../api:network_emulation_manager_api", + "../api:peer_connection_quality_test_fixture_api", + "../api:simulated_network_api", + "../api:time_controller", + "../api/video_codecs:video_codecs_api", + "../call:simulated_network", + "../modules/video_coding:webrtc_vp9", + "../system_wrappers:field_trial", + "../test:field_trial", + "../test:fileutils", + "../test:test_support", + "../test/pc/e2e:network_quality_metrics_reporter", + ] + } - if (is_mac) { - mac_app_bundle("video_loopback") { + rtc_library("video_loopback_lib") { testonly = true - sources = [ "video_loopback_main.mm" ] - info_plist = "../test/mac/Info.plist" - deps = [ ":video_loopback_lib" ] + sources = [ + "video_loopback.cc", + "video_loopback.h", + ] + deps = [ + ":video_quality_test", + "../api:libjingle_peerconnection_api", + "../api:simulated_network_api", + "../api:video_quality_test_fixture_api", + "../api/transport:bitrate_settings", + "../api/video_codecs:video_codecs_api", + "../rtc_base:checks", + "../rtc_base:logging", + "../system_wrappers:field_trial", + "../test:field_trial", + "../test:run_test", + "../test:run_test_interface", + "../test:test_common", + "../test:test_renderer", + "../test:test_support", + "//testing/gtest", + ] + 
absl_deps = [ + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/types:optional", + ] } - } else { - rtc_executable("video_loopback") { - testonly = true - sources = [ "video_loopback_main.cc" ] - deps = [ ":video_loopback_lib" ] + + if (is_mac) { + mac_app_bundle("video_loopback") { + testonly = true + sources = [ "video_loopback_main.mm" ] + info_plist = "../test/mac/Info.plist" + deps = [ ":video_loopback_lib" ] + } + } else { + rtc_executable("video_loopback") { + testonly = true + sources = [ "video_loopback_main.cc" ] + deps = [ ":video_loopback_lib" ] + } } - } - rtc_executable("screenshare_loopback") { - testonly = true - sources = [ "screenshare_loopback.cc" ] + rtc_executable("screenshare_loopback") { + testonly = true + sources = [ "screenshare_loopback.cc" ] - deps = [ - ":video_quality_test", - "../api:libjingle_peerconnection_api", - "../api:simulated_network_api", - "../api:video_quality_test_fixture_api", - "../api/transport:bitrate_settings", - "../api/video_codecs:video_codecs_api", - "../rtc_base:checks", - "../rtc_base:logging", - "../rtc_base:stringutils", - "../system_wrappers:field_trial", - "../test:field_trial", - "../test:run_test", - "../test:run_test_interface", - "../test:test_common", - "../test:test_renderer", - "../test:test_support", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/types:optional", - ] - } + deps = [ + ":video_quality_test", + "../api:libjingle_peerconnection_api", + "../api:simulated_network_api", + "../api:video_quality_test_fixture_api", + "../api/transport:bitrate_settings", + "../api/video_codecs:video_codecs_api", + "../rtc_base:checks", + "../rtc_base:logging", + "../rtc_base:stringutils", + "../system_wrappers:field_trial", + "../test:field_trial", + "../test:run_test", + "../test:run_test_interface", + "../test:test_common", + "../test:test_renderer", + "../test:test_support", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/types:optional", + ] + } - rtc_executable("sv_loopback") { - testonly = true - sources = [ "sv_loopback.cc" ] - deps = [ - ":video_quality_test", - "../api:libjingle_peerconnection_api", - "../api:simulated_network_api", - "../api:video_quality_test_fixture_api", - "../api/transport:bitrate_settings", - "../api/video_codecs:video_codecs_api", - "../rtc_base:checks", - "../rtc_base:logging", - "../rtc_base:stringutils", - "../system_wrappers:field_trial", - "../test:field_trial", - "../test:run_test", - "../test:run_test_interface", - "../test:test_common", - "../test:test_renderer", - "../test:test_support", - "//testing/gtest", - "//third_party/abseil-cpp/absl/flags:flag", - "//third_party/abseil-cpp/absl/flags:parse", - "//third_party/abseil-cpp/absl/types:optional", - ] + rtc_executable("sv_loopback") { + testonly = true + sources = [ "sv_loopback.cc" ] + deps = [ + ":video_quality_test", + "../api:libjingle_peerconnection_api", + "../api:simulated_network_api", + "../api:video_quality_test_fixture_api", + "../api/transport:bitrate_settings", + "../api/video_codecs:video_codecs_api", + "../rtc_base:checks", + "../rtc_base:logging", + "../rtc_base:stringutils", + "../system_wrappers:field_trial", + "../test:field_trial", + "../test:run_test", + "../test:run_test_interface", + "../test:test_common", + "../test:test_renderer", + "../test:test_support", + "//testing/gtest", + 
"//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + "//third_party/abseil-cpp/absl/types:optional", + ] + } } # TODO(pbos): Rename test suite. @@ -559,6 +600,7 @@ if (rtc_include_tests) { "end_to_end_tests/multi_stream_tester.h", "end_to_end_tests/multi_stream_tests.cc", "end_to_end_tests/network_state_tests.cc", + "end_to_end_tests/resolution_bitrate_limits_tests.cc", "end_to_end_tests/retransmission_tests.cc", "end_to_end_tests/rtp_rtcp_tests.cc", "end_to_end_tests/ssrc_tests.cc", @@ -600,10 +642,12 @@ if (rtc_include_tests) { "../api:libjingle_peerconnection_api", "../api:mock_fec_controller_override", "../api:mock_frame_decryptor", + "../api:mock_video_codec_factory", "../api:mock_video_encoder", "../api:rtp_headers", "../api:rtp_parameters", "../api:scoped_refptr", + "../api:sequence_checker", "../api:simulated_network_api", "../api:transport_api", "../api/adaptation:resource_adaptation_api", @@ -619,7 +663,6 @@ if (rtc_include_tests) { "../api/video:video_adaptation", "../api/video:video_bitrate_allocation", "../api/video:video_frame", - "../api/video:video_frame_nv12", "../api/video:video_frame_type", "../api/video:video_rtp_headers", "../api/video_codecs:video_codecs_api", @@ -659,6 +702,8 @@ if (rtc_include_tests) { "../modules/video_coding:webrtc_multiplex", "../modules/video_coding:webrtc_vp8", "../modules/video_coding:webrtc_vp9", + "../modules/video_coding:webrtc_vp9_helpers", + "../modules/video_coding/codecs/av1:libaom_av1_encoder", "../rtc_base", "../rtc_base:checks", "../rtc_base:gunit_helpers", @@ -668,9 +713,10 @@ if (rtc_include_tests) { "../rtc_base:rtc_numerics", "../rtc_base:rtc_task_queue", "../rtc_base:task_queue_for_test", + "../rtc_base:threading", "../rtc_base/experiments:alr_experiment", + "../rtc_base/experiments:encoder_info_settings", "../rtc_base/synchronization:mutex", - "../rtc_base/synchronization:sequence_checker", "../rtc_base/task_utils:to_queued_task", "../system_wrappers", "../system_wrappers:field_trial", diff --git a/video/adaptation/BUILD.gn b/video/adaptation/BUILD.gn index c5afb02c83..20a2370b57 100644 --- a/video/adaptation/BUILD.gn +++ b/video/adaptation/BUILD.gn @@ -33,6 +33,7 @@ rtc_library("video_adaptation") { deps = [ "../../api:rtp_parameters", "../../api:scoped_refptr", + "../../api:sequence_checker", "../../api/adaptation:resource_adaptation_api", "../../api/task_queue:task_queue", "../../api/units:data_rate", @@ -55,7 +56,6 @@ rtc_library("video_adaptation") { "../../rtc_base/experiments:quality_rampup_experiment", "../../rtc_base/experiments:quality_scaler_settings", "../../rtc_base/synchronization:mutex", - "../../rtc_base/synchronization:sequence_checker", "../../rtc_base/system:no_unique_address", "../../rtc_base/task_utils:repeating_task", "../../rtc_base/task_utils:to_queued_task", @@ -75,6 +75,7 @@ if (rtc_include_tests) { defines = [] sources = [ + "bitrate_constraint_unittest.cc", "overuse_frame_detector_unittest.cc", "pixel_limit_resource_unittest.cc", "quality_scaler_resource_unittest.cc", diff --git a/video/adaptation/balanced_constraint.cc b/video/adaptation/balanced_constraint.cc index b4926a4a26..ec0b8e41d5 100644 --- a/video/adaptation/balanced_constraint.cc +++ b/video/adaptation/balanced_constraint.cc @@ -8,12 +8,13 @@ * be found in the AUTHORS file in the root of the source tree. 
*/ +#include "video/adaptation/balanced_constraint.h" + #include #include -#include "rtc_base/synchronization/sequence_checker.h" +#include "api/sequence_checker.h" #include "rtc_base/task_utils/to_queued_task.h" -#include "video/adaptation/balanced_constraint.h" namespace webrtc { @@ -40,16 +41,16 @@ bool BalancedConstraint::IsAdaptationUpAllowed( // exceed bitrate constraints. if (degradation_preference_provider_->degradation_preference() == DegradationPreference::BALANCED) { + int frame_size_pixels = input_state.single_active_stream_pixels().value_or( + input_state.frame_size_pixels().value()); if (!balanced_settings_.CanAdaptUp( - input_state.video_codec_type(), - input_state.frame_size_pixels().value(), + input_state.video_codec_type(), frame_size_pixels, encoder_target_bitrate_bps_.value_or(0))) { return false; } if (DidIncreaseResolution(restrictions_before, restrictions_after) && !balanced_settings_.CanAdaptUpResolution( - input_state.video_codec_type(), - input_state.frame_size_pixels().value(), + input_state.video_codec_type(), frame_size_pixels, encoder_target_bitrate_bps_.value_or(0))) { return false; } diff --git a/video/adaptation/balanced_constraint.h b/video/adaptation/balanced_constraint.h index 15219360f5..0bbd670408 100644 --- a/video/adaptation/balanced_constraint.h +++ b/video/adaptation/balanced_constraint.h @@ -14,10 +14,10 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "call/adaptation/adaptation_constraint.h" #include "call/adaptation/degradation_preference_provider.h" #include "rtc_base/experiments/balanced_degradation_settings.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { diff --git a/video/adaptation/bitrate_constraint.cc b/video/adaptation/bitrate_constraint.cc index 1061c4557f..cd61e555cd 100644 --- a/video/adaptation/bitrate_constraint.cc +++ b/video/adaptation/bitrate_constraint.cc @@ -8,12 +8,14 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include +#include "video/adaptation/bitrate_constraint.h" + #include +#include +#include "api/sequence_checker.h" #include "call/adaptation/video_stream_adapter.h" -#include "rtc_base/synchronization/sequence_checker.h" -#include "video/adaptation/bitrate_constraint.h" +#include "video/adaptation/video_stream_encoder_resource_manager.h" namespace webrtc { @@ -42,19 +44,35 @@ bool BitrateConstraint::IsAdaptationUpAllowed( RTC_DCHECK_RUN_ON(&sequence_checker_); // Make sure bitrate limits are not violated. if (DidIncreaseResolution(restrictions_before, restrictions_after)) { + if (!encoder_settings_.has_value()) { + return true; + } + uint32_t bitrate_bps = encoder_target_bitrate_bps_.value_or(0); + if (bitrate_bps == 0) { + return true; + } + + if (VideoStreamEncoderResourceManager::IsSimulcast( + encoder_settings_->encoder_config())) { + // Resolution bitrate limits usage is restricted to singlecast. + return true; + } + + absl::optional current_frame_size_px = + input_state.single_active_stream_pixels(); + if (!current_frame_size_px.has_value()) { + return true; + } + absl::optional bitrate_limits = - encoder_settings_.has_value() - ? encoder_settings_->encoder_info() - .GetEncoderBitrateLimitsForResolution( - // Need some sort of expected resulting pixels to be used - // instead of unrestricted. 
- GetHigherResolutionThan( - input_state.frame_size_pixels().value())) - : absl::nullopt; - if (bitrate_limits.has_value() && bitrate_bps != 0) { - RTC_DCHECK_GE(bitrate_limits->frame_size_pixels, - input_state.frame_size_pixels().value()); + encoder_settings_->encoder_info().GetEncoderBitrateLimitsForResolution( + // Need some sort of expected resulting pixels to be used + // instead of unrestricted. + GetHigherResolutionThan(*current_frame_size_px)); + + if (bitrate_limits.has_value()) { + RTC_DCHECK_GE(bitrate_limits->frame_size_pixels, *current_frame_size_px); return bitrate_bps >= static_cast(bitrate_limits->min_start_bitrate_bps); } diff --git a/video/adaptation/bitrate_constraint.h b/video/adaptation/bitrate_constraint.h index 6fefb04c24..a608e5db5d 100644 --- a/video/adaptation/bitrate_constraint.h +++ b/video/adaptation/bitrate_constraint.h @@ -14,11 +14,11 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "call/adaptation/adaptation_constraint.h" #include "call/adaptation/encoder_settings.h" #include "call/adaptation/video_source_restrictions.h" #include "call/adaptation/video_stream_input_state.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { diff --git a/video/adaptation/bitrate_constraint_unittest.cc b/video/adaptation/bitrate_constraint_unittest.cc new file mode 100644 index 0000000000..d7865a12ed --- /dev/null +++ b/video/adaptation/bitrate_constraint_unittest.cc @@ -0,0 +1,191 @@ +/* + * Copyright 2021 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
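A minimal standalone sketch of the adapt-up gating introduced above in BitrateConstraint::IsAdaptationUpAllowed: a resolution increase is only allowed once the current target bitrate reaches the minimum start bitrate that the encoder reports for the next-higher resolution. The struct, the lookup helper and the fixed x4 resolution step below are illustrative stand-ins (the real code uses VideoEncoder::EncoderInfo and GetHigherResolutionThan()); this is not the patch itself.

#include <cstdint>
#include <optional>
#include <vector>

struct ResolutionBitrateLimit {
  int frame_size_pixels;
  int min_start_bitrate_bps;
};

// Picks the limits entry for the smallest listed resolution >= |pixels|.
std::optional<ResolutionBitrateLimit> LimitsForResolution(
    const std::vector<ResolutionBitrateLimit>& limits, int pixels) {
  std::optional<ResolutionBitrateLimit> best;
  for (const ResolutionBitrateLimit& limit : limits) {
    if (limit.frame_size_pixels >= pixels &&
        (!best || limit.frame_size_pixels < best->frame_size_pixels)) {
      best = limit;
    }
  }
  return best;
}

bool IsResolutionIncreaseAllowed(
    const std::vector<ResolutionBitrateLimit>& limits,
    std::optional<int> single_active_stream_pixels,
    uint32_t target_bitrate_bps) {
  // Mirrors the early returns above: without a target bitrate or a single
  // active stream there is nothing to check, so the adaptation is not blocked.
  if (target_bitrate_bps == 0 || !single_active_stream_pixels) {
    return true;
  }
  // Stand-in for GetHigherResolutionThan(): assume the next step is 4x pixels.
  const int next_resolution_pixels = *single_active_stream_pixels * 4;
  const std::optional<ResolutionBitrateLimit> limit =
      LimitsForResolution(limits, next_resolution_pixels);
  if (!limit) {
    return true;
  }
  return target_bitrate_bps >=
         static_cast<uint32_t>(limit->min_start_bitrate_bps);
}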
+ */ + +#include "video/adaptation/bitrate_constraint.h" + +#include +#include + +#include "api/video_codecs/video_encoder.h" +#include "call/adaptation/encoder_settings.h" +#include "call/adaptation/test/fake_frame_rate_provider.h" +#include "call/adaptation/video_source_restrictions.h" +#include "call/adaptation/video_stream_input_state_provider.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { +const VideoSourceRestrictions k360p{/*max_pixels_per_frame=*/640 * 360, + /*target_pixels_per_frame=*/640 * 360, + /*max_frame_rate=*/30}; +const VideoSourceRestrictions k720p{/*max_pixels_per_frame=*/1280 * 720, + /*target_pixels_per_frame=*/1280 * 720, + /*max_frame_rate=*/30}; + +void FillCodecConfig(VideoCodec* video_codec, + VideoEncoderConfig* encoder_config, + int width_px, + int height_px, + std::vector active_flags) { + size_t num_layers = active_flags.size(); + video_codec->codecType = kVideoCodecVP8; + video_codec->numberOfSimulcastStreams = num_layers; + + encoder_config->number_of_streams = num_layers; + encoder_config->simulcast_layers.resize(num_layers); + + for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + int layer_width_px = width_px >> (num_layers - 1 - layer_idx); + int layer_height_px = height_px >> (num_layers - 1 - layer_idx); + + video_codec->simulcastStream[layer_idx].active = active_flags[layer_idx]; + video_codec->simulcastStream[layer_idx].width = layer_width_px; + video_codec->simulcastStream[layer_idx].height = layer_height_px; + + encoder_config->simulcast_layers[layer_idx].active = + active_flags[layer_idx]; + encoder_config->simulcast_layers[layer_idx].width = layer_width_px; + encoder_config->simulcast_layers[layer_idx].height = layer_height_px; + } +} + +constexpr int kStartBitrateBps720p = 1000000; + +VideoEncoder::EncoderInfo MakeEncoderInfo() { + VideoEncoder::EncoderInfo encoder_info; + encoder_info.resolution_bitrate_limits = { + {640 * 360, 500000, 0, 5000000}, + {1280 * 720, kStartBitrateBps720p, 0, 5000000}, + {1920 * 1080, 2000000, 0, 5000000}}; + return encoder_info; +} + +} // namespace + +class BitrateConstraintTest : public ::testing::Test { + public: + BitrateConstraintTest() + : frame_rate_provider_(), input_state_provider_(&frame_rate_provider_) {} + + protected: + void OnEncoderSettingsUpdated(int width_px, + int height_px, + std::vector active_flags) { + VideoCodec video_codec; + VideoEncoderConfig encoder_config; + FillCodecConfig(&video_codec, &encoder_config, width_px, height_px, + active_flags); + + EncoderSettings encoder_settings(MakeEncoderInfo(), + std::move(encoder_config), video_codec); + bitrate_constraint_.OnEncoderSettingsUpdated(encoder_settings); + input_state_provider_.OnEncoderSettingsChanged(encoder_settings); + } + + FakeFrameRateProvider frame_rate_provider_; + VideoStreamInputStateProvider input_state_provider_; + BitrateConstraint bitrate_constraint_; +}; + +TEST_F(BitrateConstraintTest, AdaptUpAllowedAtSinglecastIfBitrateIsEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{true}); + + bitrate_constraint_.OnEncoderTargetBitrateUpdated(kStartBitrateBps720p); + + EXPECT_TRUE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k720p)); +} + +TEST_F(BitrateConstraintTest, + AdaptUpDisallowedAtSinglecastIfBitrateIsNotEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{true}); + + // 1 bps less than needed for 720p. 
+ bitrate_constraint_.OnEncoderTargetBitrateUpdated(kStartBitrateBps720p - 1); + + EXPECT_FALSE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k720p)); +} + +TEST_F(BitrateConstraintTest, + AdaptUpAllowedAtSinglecastUpperLayerActiveIfBitrateIsEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{false, true}); + + bitrate_constraint_.OnEncoderTargetBitrateUpdated(kStartBitrateBps720p); + + EXPECT_TRUE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k720p)); +} + +TEST_F(BitrateConstraintTest, + AdaptUpDisallowedAtSinglecastUpperLayerActiveIfBitrateIsNotEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{false, true}); + + // 1 bps less than needed for 720p. + bitrate_constraint_.OnEncoderTargetBitrateUpdated(kStartBitrateBps720p - 1); + + EXPECT_FALSE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k720p)); +} + +TEST_F(BitrateConstraintTest, + AdaptUpAllowedAtSinglecastLowestLayerActiveIfBitrateIsNotEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{true, false}); + + // 1 bps less than needed for 720p. + bitrate_constraint_.OnEncoderTargetBitrateUpdated(kStartBitrateBps720p - 1); + + EXPECT_TRUE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k720p)); +} + +TEST_F(BitrateConstraintTest, AdaptUpAllowedAtSimulcastIfBitrateIsNotEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{true, true}); + + // 1 bps less than needed for 720p. 
+ bitrate_constraint_.OnEncoderTargetBitrateUpdated(kStartBitrateBps720p - 1); + + EXPECT_TRUE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k720p)); +} + +TEST_F(BitrateConstraintTest, + AdaptUpInFpsAllowedAtNoResolutionIncreaseIfBitrateIsNotEnough) { + OnEncoderSettingsUpdated(/*width_px=*/640, /*height_px=*/360, + /*active_flags=*/{true}); + + bitrate_constraint_.OnEncoderTargetBitrateUpdated(1); + + EXPECT_TRUE(bitrate_constraint_.IsAdaptationUpAllowed( + input_state_provider_.InputState(), + /*restrictions_before=*/k360p, + /*restrictions_after=*/k360p)); +} + +} // namespace webrtc diff --git a/video/adaptation/encode_usage_resource.cc b/video/adaptation/encode_usage_resource.cc index 8fe7450a0c..c42c63f4b7 100644 --- a/video/adaptation/encode_usage_resource.cc +++ b/video/adaptation/encode_usage_resource.cc @@ -21,7 +21,7 @@ namespace webrtc { // static rtc::scoped_refptr EncodeUsageResource::Create( std::unique_ptr overuse_detector) { - return new rtc::RefCountedObject( + return rtc::make_ref_counted( std::move(overuse_detector)); } diff --git a/video/adaptation/overuse_frame_detector.h b/video/adaptation/overuse_frame_detector.h index c9095d63a5..2b4dd61d21 100644 --- a/video/adaptation/overuse_frame_detector.h +++ b/video/adaptation/overuse_frame_detector.h @@ -15,12 +15,12 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "api/video/video_stream_encoder_observer.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/numerics/exp_filter.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_utils/repeating_task.h" #include "rtc_base/thread_annotations.h" diff --git a/video/adaptation/overuse_frame_detector_unittest.cc b/video/adaptation/overuse_frame_detector_unittest.cc index d4bf910faa..37ad974a4c 100644 --- a/video/adaptation/overuse_frame_detector_unittest.cc +++ b/video/adaptation/overuse_frame_detector_unittest.cc @@ -455,6 +455,8 @@ TEST_F(OveruseFrameDetectorTest, RunOnTqNormalUsage) { EXPECT_TRUE(event.Wait(10000)); } +// TODO(crbug.com/webrtc/12846): investigate why the test fails on MAC bots. +#if !defined(WEBRTC_MAC) TEST_F(OveruseFrameDetectorTest, MaxIntervalScalesWithFramerate) { const int kCapturerMaxFrameRate = 30; const int kEncodeMaxFrameRate = 20; // Maximum fps the encoder can sustain. @@ -490,6 +492,7 @@ TEST_F(OveruseFrameDetectorTest, MaxIntervalScalesWithFramerate) { processing_time_us); overuse_detector_->CheckForOveruse(observer_); } +#endif TEST_F(OveruseFrameDetectorTest, RespectsMinFramerate) { const int kMinFrameRate = 7; // Minimum fps allowed by current detector impl. @@ -835,7 +838,7 @@ TEST_F(OveruseFrameDetectorTest2, ConvergesSlowly) { // Should have started to approach correct load of 15%, but not very far. EXPECT_LT(UsagePercent(), InitialUsage()); - EXPECT_GT(UsagePercent(), (InitialUsage() * 3 + 15) / 4); + EXPECT_GT(UsagePercent(), (InitialUsage() * 3 + 8) / 4); // Run for roughly 10s more, should now be closer. 
InsertAndSendFramesWithInterval(300, kFrameIntervalUs, kWidth, kHeight, diff --git a/video/adaptation/pixel_limit_resource.cc b/video/adaptation/pixel_limit_resource.cc index 96c8cac737..789dac2c0a 100644 --- a/video/adaptation/pixel_limit_resource.cc +++ b/video/adaptation/pixel_limit_resource.cc @@ -10,11 +10,11 @@ #include "video/adaptation/pixel_limit_resource.h" +#include "api/sequence_checker.h" #include "api/units/time_delta.h" #include "call/adaptation/video_stream_adapter.h" #include "rtc_base/checks.h" #include "rtc_base/ref_counted_object.h" -#include "rtc_base/synchronization/sequence_checker.h" namespace webrtc { @@ -28,8 +28,8 @@ constexpr TimeDelta kResourceUsageCheckIntervalMs = TimeDelta::Seconds(5); rtc::scoped_refptr PixelLimitResource::Create( TaskQueueBase* task_queue, VideoStreamInputStateProvider* input_state_provider) { - return new rtc::RefCountedObject(task_queue, - input_state_provider); + return rtc::make_ref_counted(task_queue, + input_state_provider); } PixelLimitResource::PixelLimitResource( diff --git a/video/adaptation/quality_scaler_resource.cc b/video/adaptation/quality_scaler_resource.cc index c438488182..c455252d45 100644 --- a/video/adaptation/quality_scaler_resource.cc +++ b/video/adaptation/quality_scaler_resource.cc @@ -22,7 +22,7 @@ namespace webrtc { // static rtc::scoped_refptr QualityScalerResource::Create() { - return new rtc::RefCountedObject(); + return rtc::make_ref_counted(); } QualityScalerResource::QualityScalerResource() diff --git a/video/adaptation/video_stream_encoder_resource.h b/video/adaptation/video_stream_encoder_resource.h index 477fdf492d..e10f595757 100644 --- a/video/adaptation/video_stream_encoder_resource.h +++ b/video/adaptation/video_stream_encoder_resource.h @@ -16,10 +16,10 @@ #include "absl/types/optional.h" #include "api/adaptation/resource.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "call/adaptation/adaptation_constraint.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" namespace webrtc { diff --git a/video/adaptation/video_stream_encoder_resource_manager.cc b/video/adaptation/video_stream_encoder_resource_manager.cc index 8d532f3e2c..2705bf9af7 100644 --- a/video/adaptation/video_stream_encoder_resource_manager.cc +++ b/video/adaptation/video_stream_encoder_resource_manager.cc @@ -11,6 +11,7 @@ #include "video/adaptation/video_stream_encoder_resource_manager.h" #include + #include #include #include @@ -20,6 +21,7 @@ #include "absl/algorithm/container.h" #include "absl/base/macros.h" #include "api/adaptation/resource.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "api/video/video_adaptation_reason.h" #include "api/video/video_source_interface.h" @@ -29,8 +31,8 @@ #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/ref_counted_object.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/time_utils.h" +#include "rtc_base/trace_event.h" #include "system_wrappers/include/field_trial.h" #include "video/adaptation/quality_scaler_resource.h" @@ -64,30 +66,6 @@ std::string ToString(VideoAdaptationReason reason) { RTC_CHECK_NOTREACHED(); } -absl::optional GetSingleActiveStreamPixels(const VideoCodec& codec) { - int num_active = 0; - absl::optional pixels; - if (codec.codecType == VideoCodecType::kVideoCodecVP9) { - for (int i = 0; i < codec.VP9().numberOfSpatialLayers; ++i) { - if 
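The scoped-refptr changes in encode_usage_resource.cc, pixel_limit_resource.cc and quality_scaler_resource.cc above all follow one pattern: explicit new rtc::RefCountedObject<T>(...) is replaced by the rtc::make_ref_counted<T>(...) factory. A minimal sketch of that pattern, assuming a WebRTC checkout of roughly this revision (where the helper is declared alongside RefCountedObject); the Widget class is made up purely for illustration.

#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/ref_counted_object.h"

// Hypothetical ref-counted class used only to illustrate the call sites.
class Widget : public rtc::RefCountInterface {
 public:
  explicit Widget(int id) : id_(id) {}
  int id() const { return id_; }

 private:
  const int id_;
};

rtc::scoped_refptr<Widget> CreateWidget(int id) {
  // Before this patch: return new rtc::RefCountedObject<Widget>(id);
  return rtc::make_ref_counted<Widget>(id);
}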
(codec.spatialLayers[i].active) { - ++num_active; - pixels = codec.spatialLayers[i].width * codec.spatialLayers[i].height; - } - } - } else { - for (int i = 0; i < codec.numberOfSimulcastStreams; ++i) { - if (codec.simulcastStream[i].active) { - ++num_active; - pixels = - codec.simulcastStream[i].width * codec.simulcastStream[i].height; - } - } - } - if (num_active > 1) - return absl::nullopt; - return pixels; -} - std::vector GetActiveLayersFlags(const VideoCodec& codec) { std::vector flags; if (codec.codecType == VideoCodecType::kVideoCodecVP9) { @@ -122,6 +100,8 @@ class VideoStreamEncoderResourceManager::InitialFrameDropper { set_start_bitrate_(DataRate::Zero()), set_start_bitrate_time_ms_(0), initial_framedrop_(0), + use_bandwidth_allocation_(false), + bandwidth_allocation_(DataRate::Zero()), last_input_width_(0), last_input_height_(0) { RTC_DCHECK(quality_scaler_resource_); @@ -136,12 +116,23 @@ class VideoStreamEncoderResourceManager::InitialFrameDropper { return single_active_stream_pixels_; } + absl::optional UseBandwidthAllocationBps() const { + return (use_bandwidth_allocation_ && + bandwidth_allocation_ > DataRate::Zero()) + ? absl::optional(bandwidth_allocation_.bps()) + : absl::nullopt; + } + // Input signals. void SetStartBitrate(DataRate start_bitrate, int64_t now_ms) { set_start_bitrate_ = start_bitrate; set_start_bitrate_time_ms_ = now_ms; } + void SetBandwidthAllocation(DataRate bandwidth_allocation) { + bandwidth_allocation_ = bandwidth_allocation; + } + void SetTargetBitrate(DataRate target_bitrate, int64_t now_ms) { if (set_start_bitrate_ > DataRate::Zero() && !has_seen_first_bwe_drop_ && quality_scaler_resource_->is_started() && @@ -182,18 +173,28 @@ class VideoStreamEncoderResourceManager::InitialFrameDropper { RTC_LOG(LS_INFO) << "Resetting initial_framedrop_ due to changed " "stream parameters"; initial_framedrop_ = 0; + if (single_active_stream_pixels_ && + VideoStreamAdapter::GetSingleActiveLayerPixels(codec) > + *single_active_stream_pixels_) { + // Resolution increased. + use_bandwidth_allocation_ = true; + } } } last_adaptation_counters_ = adaptation_counters; last_active_flags_ = active_flags; last_input_width_ = codec.width; last_input_height_ = codec.height; - single_active_stream_pixels_ = GetSingleActiveStreamPixels(codec); + single_active_stream_pixels_ = + VideoStreamAdapter::GetSingleActiveLayerPixels(codec); } void OnFrameDroppedDueToSize() { ++initial_framedrop_; } - void Disable() { initial_framedrop_ = kMaxInitialFramedrop; } + void Disable() { + initial_framedrop_ = kMaxInitialFramedrop; + use_bandwidth_allocation_ = false; + } void OnQualityScalerSettingsUpdated() { if (quality_scaler_resource_->is_started()) { @@ -201,7 +202,7 @@ class VideoStreamEncoderResourceManager::InitialFrameDropper { initial_framedrop_ = 0; } else { // Quality scaling disabled so we shouldn't drop initial frames. - initial_framedrop_ = kMaxInitialFramedrop; + Disable(); } } @@ -218,6 +219,8 @@ class VideoStreamEncoderResourceManager::InitialFrameDropper { // Counts how many frames we've dropped in the initial framedrop phase. 
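The block removed above is not deleted outright; the patch moves it into VideoStreamAdapter::GetSingleActiveLayerPixels() and calls it from the initial frame dropper and the constraints. A standalone restatement of the rule with a simplified layer struct (the real code inspects VP9 spatial layers or simulcast streams on VideoCodec):

#include <optional>
#include <vector>

struct LayerConfig {
  bool active;
  int width;
  int height;
};

// Yields a pixel count only when exactly one layer is active; with zero active
// layers there is nothing to report, and with several the resolution is
// ambiguous, so nullopt is returned.
std::optional<int> SingleActiveLayerPixels(
    const std::vector<LayerConfig>& layers) {
  int num_active = 0;
  std::optional<int> pixels;
  for (const LayerConfig& layer : layers) {
    if (layer.active) {
      ++num_active;
      pixels = layer.width * layer.height;
    }
  }
  if (num_active > 1) {
    return std::nullopt;
  }
  return pixels;
}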
int initial_framedrop_; absl::optional single_active_stream_pixels_; + bool use_bandwidth_allocation_; + DataRate bandwidth_allocation_; std::vector last_active_flags_; VideoAdaptationCounters last_adaptation_counters_; @@ -255,6 +258,9 @@ VideoStreamEncoderResourceManager::VideoStreamEncoderResourceManager( quality_rampup_experiment_( QualityRampUpExperimentHelper::CreateIfEnabled(this, clock_)), encoder_settings_(absl::nullopt) { + TRACE_EVENT0( + "webrtc", + "VideoStreamEncoderResourceManager::VideoStreamEncoderResourceManager"); RTC_CHECK(degradation_preference_provider_); RTC_CHECK(encoder_stats_observer_); } @@ -292,7 +298,7 @@ VideoStreamEncoderResourceManager::degradation_preference() const { return degradation_preference_; } -void VideoStreamEncoderResourceManager::EnsureEncodeUsageResourceStarted() { +void VideoStreamEncoderResourceManager::ConfigureEncodeUsageResource() { RTC_DCHECK_RUN_ON(encoder_queue_); RTC_DCHECK(encoder_settings_.has_value()); if (encode_usage_resource_->is_started()) { @@ -422,6 +428,8 @@ void VideoStreamEncoderResourceManager::SetEncoderRates( const VideoEncoder::RateControlParameters& encoder_rates) { RTC_DCHECK_RUN_ON(encoder_queue_); encoder_rates_ = encoder_rates; + initial_frame_dropper_->SetBandwidthAllocation( + encoder_rates.bandwidth_allocation); } void VideoStreamEncoderResourceManager::OnFrameDroppedDueToSize() { @@ -473,6 +481,12 @@ VideoStreamEncoderResourceManager::SingleActiveStreamPixels() const { return initial_frame_dropper_->single_active_stream_pixels(); } +absl::optional +VideoStreamEncoderResourceManager::UseBandwidthAllocationBps() const { + RTC_DCHECK_RUN_ON(encoder_queue_); + return initial_frame_dropper_->UseBandwidthAllocationBps(); +} + void VideoStreamEncoderResourceManager::OnMaybeEncodeFrame() { RTC_DCHECK_RUN_ON(encoder_queue_); initial_frame_dropper_->Disable(); @@ -484,7 +498,7 @@ void VideoStreamEncoderResourceManager::OnMaybeEncodeFrame() { quality_scaler_resource_, bandwidth, DataRate::BitsPerSec(encoder_target_bitrate_bps_.value_or(0)), DataRate::KilobitsPerSec(encoder_settings_->video_codec().maxBitrate), - LastInputFrameSizeOrDefault()); + LastFrameSizeOrDefault()); } } @@ -511,7 +525,9 @@ void VideoStreamEncoderResourceManager::ConfigureQualityScaler( const auto scaling_settings = encoder_info.scaling_settings; const bool quality_scaling_allowed = IsResolutionScalingEnabled(degradation_preference_) && - scaling_settings.thresholds; + (scaling_settings.thresholds.has_value() || + (encoder_settings_.has_value() && + encoder_settings_->encoder_config().is_quality_scaling_allowed)); // TODO(https://crbug.com/webrtc/11222): Should this move to // QualityScalerResource? @@ -525,9 +541,9 @@ void VideoStreamEncoderResourceManager::ConfigureQualityScaler( experimental_thresholds = QualityScalingExperiment::GetQpThresholds( GetVideoCodecTypeOrGeneric(encoder_settings_)); } - UpdateQualityScalerSettings(experimental_thresholds - ? *experimental_thresholds - : *(scaling_settings.thresholds)); + UpdateQualityScalerSettings(experimental_thresholds.has_value() + ? 
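A compact standalone sketch of the new signal wired up above: the initial frame dropper caches the encoder's bandwidth allocation and only exposes it while use_bandwidth_allocation_ is set, which per the change above happens after the configured resolution increased and stops again when the dropper is disabled. webrtc::DataRate is replaced by a plain int64_t bps value for illustration.

#include <cstdint>
#include <optional>

class InitialFrameDropperSketch {
 public:
  void SetBandwidthAllocation(int64_t bps) { bandwidth_allocation_bps_ = bps; }
  void OnResolutionIncreased() { use_bandwidth_allocation_ = true; }
  void Disable() { use_bandwidth_allocation_ = false; }

  // Mirrors UseBandwidthAllocationBps(): only meaningful while enabled and
  // while a non-zero allocation has been reported.
  std::optional<int64_t> UseBandwidthAllocationBps() const {
    return (use_bandwidth_allocation_ && bandwidth_allocation_bps_ > 0)
               ? std::optional<int64_t>(bandwidth_allocation_bps_)
               : std::nullopt;
  }

 private:
  bool use_bandwidth_allocation_ = false;
  int64_t bandwidth_allocation_bps_ = 0;
};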
experimental_thresholds + : scaling_settings.thresholds); } } else { UpdateQualityScalerSettings(absl::nullopt); @@ -539,7 +555,7 @@ void VideoStreamEncoderResourceManager::ConfigureQualityScaler( absl::optional thresholds = balanced_settings_.GetQpThresholds( GetVideoCodecTypeOrGeneric(encoder_settings_), - LastInputFrameSizeOrDefault()); + LastFrameSizeOrDefault()); if (thresholds) { quality_scaler_resource_->SetQpThresholds(*thresholds); } @@ -579,10 +595,13 @@ CpuOveruseOptions VideoStreamEncoderResourceManager::GetCpuOveruseOptions() return options; } -int VideoStreamEncoderResourceManager::LastInputFrameSizeOrDefault() const { +int VideoStreamEncoderResourceManager::LastFrameSizeOrDefault() const { RTC_DCHECK_RUN_ON(encoder_queue_); - return input_state_provider_->InputState().frame_size_pixels().value_or( - kDefaultInputPixelsWidth * kDefaultInputPixelsHeight); + return input_state_provider_->InputState() + .single_active_stream_pixels() + .value_or( + input_state_provider_->InputState().frame_size_pixels().value_or( + kDefaultInputPixelsWidth * kDefaultInputPixelsHeight)); } void VideoStreamEncoderResourceManager::OnVideoSourceRestrictionsUpdated( @@ -703,4 +722,25 @@ void VideoStreamEncoderResourceManager::OnQualityRampUp() { stream_adapter_->ClearRestrictions(); quality_rampup_experiment_.reset(); } + +bool VideoStreamEncoderResourceManager::IsSimulcast( + const VideoEncoderConfig& encoder_config) { + const std::vector& simulcast_layers = + encoder_config.simulcast_layers; + if (simulcast_layers.size() <= 1) { + return false; + } + + if (simulcast_layers[0].active) { + // We can't distinguish between simulcast and singlecast when only the + // lowest spatial layer is active. Treat this case as simulcast. + return true; + } + + int num_active_layers = + std::count_if(simulcast_layers.begin(), simulcast_layers.end(), + [](const VideoStream& layer) { return layer.active; }); + return num_active_layers > 1; +} + } // namespace webrtc diff --git a/video/adaptation/video_stream_encoder_resource_manager.h b/video/adaptation/video_stream_encoder_resource_manager.h index 30bab53cbf..e7174d2344 100644 --- a/video/adaptation/video_stream_encoder_resource_manager.h +++ b/video/adaptation/video_stream_encoder_resource_manager.h @@ -66,7 +66,7 @@ extern const int kDefaultInputPixelsHeight; // resources. // // The manager is also involved with various mitigations not part of the -// ResourceAdaptationProcessor code such as the inital frame dropping. +// ResourceAdaptationProcessor code such as the initial frame dropping. class VideoStreamEncoderResourceManager : public VideoSourceRestrictionsListener, public ResourceLimitationsListener, @@ -92,7 +92,7 @@ class VideoStreamEncoderResourceManager void SetDegradationPreferences(DegradationPreference degradation_preference); DegradationPreference degradation_preference() const; - void EnsureEncodeUsageResourceStarted(); + void ConfigureEncodeUsageResource(); // Initializes the pixel limit resource if the "WebRTC-PixelLimitResource" // field trial is enabled. This can be used for testing. void MaybeInitializePixelLimitResource(); @@ -130,6 +130,7 @@ class VideoStreamEncoderResourceManager // frames based on size and bitrate. bool DropInitialFrames() const; absl::optional SingleActiveStreamPixels() const; + absl::optional UseBandwidthAllocationBps() const; // VideoSourceRestrictionsListener implementation. // Updates |video_source_restrictions_|. 
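The new static helper VideoStreamEncoderResourceManager::IsSimulcast() above is what lets BitrateConstraint restrict resolution bitrate limits to singlecast. A standalone restatement of the rule with a minimal layer struct; the real code operates on VideoEncoderConfig::simulcast_layers:

#include <algorithm>
#include <vector>

struct Layer {
  bool active;
};

bool IsSimulcast(const std::vector<Layer>& simulcast_layers) {
  if (simulcast_layers.size() <= 1) {
    return false;
  }
  if (simulcast_layers[0].active) {
    // Only-the-lowest-layer-active cannot be told apart from simulcast here,
    // so this errs on the side of treating it as simulcast.
    return true;
  }
  const auto num_active_layers =
      std::count_if(simulcast_layers.begin(), simulcast_layers.end(),
                    [](const Layer& layer) { return layer.active; });
  return num_active_layers > 1;
}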
@@ -146,6 +147,8 @@ class VideoStreamEncoderResourceManager // QualityRampUpExperimentListener implementation. void OnQualityRampUp() override; + static bool IsSimulcast(const VideoEncoderConfig& encoder_config); + private: class InitialFrameDropper; @@ -153,7 +156,7 @@ class VideoStreamEncoderResourceManager rtc::scoped_refptr resource) const; CpuOveruseOptions GetCpuOveruseOptions() const; - int LastInputFrameSizeOrDefault() const; + int LastFrameSizeOrDefault() const; // Calculates an up-to-date value of the target frame rate and informs the // |encode_usage_resource_| of the new value. diff --git a/video/alignment_adjuster.cc b/video/alignment_adjuster.cc index b08f2f184a..6b1db9238b 100644 --- a/video/alignment_adjuster.cc +++ b/video/alignment_adjuster.cc @@ -66,7 +66,8 @@ double RoundToMultiple(int alignment, int AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( const VideoEncoder::EncoderInfo& encoder_info, - VideoEncoderConfig* config) { + VideoEncoderConfig* config, + absl::optional max_layers) { const int requested_alignment = encoder_info.requested_resolution_alignment; if (!encoder_info.apply_alignment_to_all_simulcast_layers) { return requested_alignment; @@ -85,7 +86,11 @@ int AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( if (!has_scale_resolution_down_by) { // Default resolution downscaling used (scale factors: 1, 2, 4, ...). - return requested_alignment * (1 << (config->simulcast_layers.size() - 1)); + size_t size = config->simulcast_layers.size(); + if (max_layers && *max_layers > 0 && *max_layers < size) { + size = *max_layers; + } + return requested_alignment * (1 << (size - 1)); } // Get alignment for downscaled layers. diff --git a/video/alignment_adjuster.h b/video/alignment_adjuster.h index 53d7927887..4b72623a19 100644 --- a/video/alignment_adjuster.h +++ b/video/alignment_adjuster.h @@ -28,9 +28,13 @@ class AlignmentAdjuster { // |scale_resolution_down_by| may be adjusted to a common multiple to limit // the alignment value to avoid largely cropped frames and possibly with an // aspect ratio far from the original. + + // Note: |max_layers| currently only taken into account when using default + // scale factors. 
static int GetAlignmentAndMaybeAdjustScaleFactors( const VideoEncoder::EncoderInfo& info, - VideoEncoderConfig* config); + VideoEncoderConfig* config, + absl::optional max_layers); }; } // namespace webrtc diff --git a/video/alignment_adjuster_unittest.cc b/video/alignment_adjuster_unittest.cc index 07c7de5f16..28e4bc0550 100644 --- a/video/alignment_adjuster_unittest.cc +++ b/video/alignment_adjuster_unittest.cc @@ -86,6 +86,30 @@ INSTANTIATE_TEST_SUITE_P( std::vector{1.5, 2.5}, 15)))); +class AlignmentAdjusterTestTwoLayers : public AlignmentAdjusterTest { + protected: + const int kMaxLayers = 2; +}; + +INSTANTIATE_TEST_SUITE_P( + ScaleFactorsAndAlignmentWithMaxLayers, + AlignmentAdjusterTestTwoLayers, + ::testing::Combine( + ::testing::Values(2), // kRequestedAlignment + ::testing::Values( + std::make_tuple(std::vector{-1.0}, // kScaleFactors + std::vector{-1.0}, // kAdjustedScaleFactors + 2), // default: {1.0} // kAdjustedAlignment + std::make_tuple(std::vector{-1.0, -1.0}, + std::vector{-1.0, -1.0}, + 4), // default: {1.0, 2.0} + std::make_tuple(std::vector{-1.0, -1.0, -1.0}, + std::vector{-1.0, -1.0, -1.0}, + 4), // default: {1.0, 2.0, 4.0} + std::make_tuple(std::vector{1.0, 2.0, 4.0}, + std::vector{1.0, 2.0, 4.0}, + 8)))); + TEST_P(AlignmentAdjusterTest, AlignmentAppliedToAllLayers) { const bool kApplyAlignmentToAllLayers = true; @@ -100,8 +124,8 @@ TEST_P(AlignmentAdjusterTest, AlignmentAppliedToAllLayers) { // Verify requested alignment from sink. VideoEncoder::EncoderInfo info = GetEncoderInfo(kRequestedAlignment, kApplyAlignmentToAllLayers); - int alignment = - AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors(info, &config); + int alignment = AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( + info, &config, absl::nullopt); EXPECT_EQ(alignment, kAdjustedAlignment); // Verify adjusted scale factors. @@ -125,8 +149,8 @@ TEST_P(AlignmentAdjusterTest, AlignmentNotAppliedToAllLayers) { // Verify requested alignment from sink, alignment is not adjusted. VideoEncoder::EncoderInfo info = GetEncoderInfo(kRequestedAlignment, kApplyAlignmentToAllLayers); - int alignment = - AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors(info, &config); + int alignment = AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( + info, &config, absl::nullopt); EXPECT_EQ(alignment, kRequestedAlignment); // Verify that scale factors are not adjusted. @@ -136,5 +160,30 @@ TEST_P(AlignmentAdjusterTest, AlignmentNotAppliedToAllLayers) { } } +TEST_P(AlignmentAdjusterTestTwoLayers, AlignmentAppliedToAllLayers) { + const bool kApplyAlignmentToAllLayers = true; + + // Fill config with the scaling factor by which to reduce encoding size. + const int num_streams = kScaleFactors.size(); + VideoEncoderConfig config; + test::FillEncoderConfiguration(kVideoCodecVP8, num_streams, &config); + for (int i = 0; i < num_streams; ++i) { + config.simulcast_layers[i].scale_resolution_down_by = kScaleFactors[i]; + } + + // Verify requested alignment from sink, alignment is not adjusted. + VideoEncoder::EncoderInfo info = + GetEncoderInfo(kRequestedAlignment, kApplyAlignmentToAllLayers); + int alignment = AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( + info, &config, absl::optional(kMaxLayers)); + EXPECT_EQ(alignment, kAdjustedAlignment); + + // Verify adjusted scale factors. 
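A quick standalone sketch of the alignment arithmetic that the new max_layers parameter above affects: with default downscale factors (1, 2, 4, ...) the requested alignment is multiplied by 2^(layers - 1), and max_layers now caps how many layers count. Names are illustrative; the real function takes a VideoEncoderConfig and may also adjust scale factors.

#include <cstddef>
#include <optional>

int AlignmentForDefaultScaleFactors(int requested_alignment,
                                    size_t num_configured_layers,
                                    std::optional<size_t> max_layers) {
  size_t size = num_configured_layers > 0 ? num_configured_layers : 1;
  if (max_layers && *max_layers > 0 && *max_layers < size) {
    size = *max_layers;
  }
  return requested_alignment * (1 << (size - 1));
}

// For example, requested_alignment = 2 with three default-scaled layers gives
// 2 * 2^2 = 8; capping with max_layers = 2 gives 4, which is what the new
// AlignmentAdjusterTestTwoLayers cases expect.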
+ for (int i = 0; i < num_streams; ++i) { + EXPECT_EQ(config.simulcast_layers[i].scale_resolution_down_by, + kAdjustedScaleFactors[i]); + } +} + } // namespace test } // namespace webrtc diff --git a/video/buffered_frame_decryptor.cc b/video/buffered_frame_decryptor.cc index 187bac6ee4..436fff83f8 100644 --- a/video/buffered_frame_decryptor.cc +++ b/video/buffered_frame_decryptor.cc @@ -36,7 +36,7 @@ void BufferedFrameDecryptor::SetFrameDecryptor( } void BufferedFrameDecryptor::ManageEncryptedFrame( - std::unique_ptr encrypted_frame) { + std::unique_ptr encrypted_frame) { switch (DecryptFrame(encrypted_frame.get())) { case FrameDecision::kStash: if (stashed_frames_.size() >= kMaxStashedFrames) { @@ -55,7 +55,7 @@ void BufferedFrameDecryptor::ManageEncryptedFrame( } BufferedFrameDecryptor::FrameDecision BufferedFrameDecryptor::DecryptFrame( - video_coding::RtpFrameObject* frame) { + RtpFrameObject* frame) { // Optionally attempt to decrypt the raw video frame if it was provided. if (frame_decryptor_ == nullptr) { RTC_LOG(LS_INFO) << "Frame decryption required but not attached to this " diff --git a/video/buffered_frame_decryptor.h b/video/buffered_frame_decryptor.h index ff04837bc0..f6dd8d8c2a 100644 --- a/video/buffered_frame_decryptor.h +++ b/video/buffered_frame_decryptor.h @@ -27,8 +27,7 @@ class OnDecryptedFrameCallback { public: virtual ~OnDecryptedFrameCallback() = default; // Called each time a decrypted frame is returned. - virtual void OnDecryptedFrame( - std::unique_ptr frame) = 0; + virtual void OnDecryptedFrame(std::unique_ptr frame) = 0; }; // This callback is called each time there is a status change in the decryption @@ -72,8 +71,7 @@ class BufferedFrameDecryptor final { // Determines whether the frame should be stashed, dropped or handed off to // the OnDecryptedFrameCallback. - void ManageEncryptedFrame( - std::unique_ptr encrypted_frame); + void ManageEncryptedFrame(std::unique_ptr encrypted_frame); private: // Represents what should be done with a given frame. @@ -82,7 +80,7 @@ class BufferedFrameDecryptor final { // Attempts to decrypt the frame, if it fails and no prior frames have been // decrypted it will return kStash. Otherwise fail to decrypts will return // kDrop. Successful decryptions will always return kDecrypted. - FrameDecision DecryptFrame(video_coding::RtpFrameObject* frame); + FrameDecision DecryptFrame(RtpFrameObject* frame); // Retries all the stashed frames this is triggered each time a kDecrypted // event occurs. 
void RetryStashedFrames(); @@ -96,7 +94,7 @@ class BufferedFrameDecryptor final { rtc::scoped_refptr frame_decryptor_; OnDecryptedFrameCallback* const decrypted_frame_callback_; OnDecryptionStatusChangeCallback* const decryption_status_change_callback_; - std::deque> stashed_frames_; + std::deque> stashed_frames_; }; } // namespace webrtc diff --git a/video/buffered_frame_decryptor_unittest.cc b/video/buffered_frame_decryptor_unittest.cc index bbc08b0da3..2f8a183ba1 100644 --- a/video/buffered_frame_decryptor_unittest.cc +++ b/video/buffered_frame_decryptor_unittest.cc @@ -43,8 +43,7 @@ class BufferedFrameDecryptorTest : public ::testing::Test, public OnDecryptionStatusChangeCallback { public: // Implements the OnDecryptedFrameCallbackInterface - void OnDecryptedFrame( - std::unique_ptr frame) override { + void OnDecryptedFrame(std::unique_ptr frame) override { decrypted_frame_call_count_++; } @@ -54,14 +53,13 @@ class BufferedFrameDecryptorTest : public ::testing::Test, // Returns a new fake RtpFrameObject it abstracts the difficult construction // of the RtpFrameObject to simplify testing. - std::unique_ptr CreateRtpFrameObject( - bool key_frame) { + std::unique_ptr CreateRtpFrameObject(bool key_frame) { seq_num_++; RTPVideoHeader rtp_video_header; rtp_video_header.generic.emplace(); // clang-format off - return std::make_unique( + return std::make_unique( seq_num_, seq_num_, /*markerBit=*/true, @@ -88,7 +86,7 @@ class BufferedFrameDecryptorTest : public ::testing::Test, decrypted_frame_call_count_ = 0; decryption_status_change_count_ = 0; seq_num_ = 0; - mock_frame_decryptor_ = new rtc::RefCountedObject(); + mock_frame_decryptor_ = rtc::make_ref_counted(); buffered_frame_decryptor_ = std::make_unique(this, this); buffered_frame_decryptor_->SetFrameDecryptor(mock_frame_decryptor_.get()); diff --git a/video/call_stats.h b/video/call_stats.h index 3bfb632446..5dc8fa0cbb 100644 --- a/video/call_stats.h +++ b/video/call_stats.h @@ -14,12 +14,12 @@ #include #include +#include "api/sequence_checker.h" #include "modules/include/module.h" #include "modules/include/module_common_types.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" #include "rtc_base/constructor_magic.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -110,8 +110,8 @@ class CallStats : public Module, public RtcpRttStats { // for the observers_ list, which makes the most common case lock free. 
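The BufferedFrameDecryptor changes above are a pure type rename (video_coding::RtpFrameObject becomes RtpFrameObject); the stash/drop/decrypt behaviour described in its comments is unchanged. For reference, a standalone sketch of that decision rule with the frame and key-state details stripped out:

enum class FrameDecision { kStash, kDrop, kDecrypted };

FrameDecision DecideAfterDecryptAttempt(bool decrypt_succeeded,
                                        bool has_decrypted_before) {
  if (decrypt_succeeded) {
    return FrameDecision::kDecrypted;
  }
  // Before the first successful decryption the keys may simply not have
  // arrived yet, so failing frames are stashed (and retried later) rather
  // than dropped.
  return has_decrypted_before ? FrameDecision::kDrop : FrameDecision::kStash;
}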
std::list observers_; - rtc::ThreadChecker construction_thread_checker_; - rtc::ThreadChecker process_thread_checker_; + SequenceChecker construction_thread_checker_; + SequenceChecker process_thread_checker_; ProcessThread* const process_thread_; bool process_thread_running_ RTC_GUARDED_BY(construction_thread_checker_); diff --git a/video/call_stats2.cc b/video/call_stats2.cc index fbbe2de4f9..2b7c61e0f8 100644 --- a/video/call_stats2.cc +++ b/video/call_stats2.cc @@ -77,11 +77,6 @@ CallStats::CallStats(Clock* clock, TaskQueueBase* task_queue) task_queue_(task_queue) { RTC_DCHECK(task_queue_); RTC_DCHECK_RUN_ON(task_queue_); - repeating_task_ = - RepeatingTaskHandle::DelayedStart(task_queue_, kUpdateInterval, [this]() { - UpdateAndReport(); - return kUpdateInterval; - }); } CallStats::~CallStats() { @@ -93,6 +88,15 @@ CallStats::~CallStats() { UpdateHistograms(); } +void CallStats::EnsureStarted() { + RTC_DCHECK_RUN_ON(task_queue_); + repeating_task_ = + RepeatingTaskHandle::DelayedStart(task_queue_, kUpdateInterval, [this]() { + UpdateAndReport(); + return kUpdateInterval; + }); +} + void CallStats::UpdateAndReport() { RTC_DCHECK_RUN_ON(task_queue_); diff --git a/video/call_stats2.h b/video/call_stats2.h index 5932fad9fb..35a7935581 100644 --- a/video/call_stats2.h +++ b/video/call_stats2.h @@ -35,6 +35,9 @@ class CallStats { CallStats(Clock* clock, TaskQueueBase* task_queue); ~CallStats(); + // Ensure that necessary repeating tasks are started. + void EnsureStarted(); + // Expose an RtcpRttStats implementation without inheriting from RtcpRttStats. // That allows us to separate the threading model of how RtcpRttStats is // used (mostly on a process thread) and how CallStats is used (mostly on diff --git a/video/call_stats2_unittest.cc b/video/call_stats2_unittest.cc index b3d43cb92a..33235faeaa 100644 --- a/video/call_stats2_unittest.cc +++ b/video/call_stats2_unittest.cc @@ -38,7 +38,10 @@ class MockStatsObserver : public CallStatsObserver { class CallStats2Test : public ::testing::Test { public: - CallStats2Test() { process_thread_->Start(); } + CallStats2Test() { + call_stats_.EnsureStarted(); + process_thread_->Start(); + } ~CallStats2Test() override { process_thread_->Stop(); } diff --git a/video/encoder_bitrate_adjuster.cc b/video/encoder_bitrate_adjuster.cc index 45d88875e3..6a2c99ffe3 100644 --- a/video/encoder_bitrate_adjuster.cc +++ b/video/encoder_bitrate_adjuster.cc @@ -314,15 +314,14 @@ void EncoderBitrateAdjuster::OnEncoderInfo( AdjustRateAllocation(current_rate_control_parameters_); } -void EncoderBitrateAdjuster::OnEncodedFrame(const EncodedImage& encoded_image, +void EncoderBitrateAdjuster::OnEncodedFrame(DataSize size, + int spatial_index, int temporal_index) { ++frames_since_layout_change_; // Detectors may not exist, for instance if ScreenshareLayers is used. - auto& detector = - overshoot_detectors_[encoded_image.SpatialIndex().value_or(0)] - [temporal_index]; + auto& detector = overshoot_detectors_[spatial_index][temporal_index]; if (detector) { - detector->OnEncodedFrame(encoded_image.size(), rtc::TimeMillis()); + detector->OnEncodedFrame(size.bytes(), rtc::TimeMillis()); } } diff --git a/video/encoder_bitrate_adjuster.h b/video/encoder_bitrate_adjuster.h index b142519b4e..74d0289ad0 100644 --- a/video/encoder_bitrate_adjuster.h +++ b/video/encoder_bitrate_adjuster.h @@ -47,7 +47,7 @@ class EncoderBitrateAdjuster { void OnEncoderInfo(const VideoEncoder::EncoderInfo& encoder_info); // Updates the overuse detectors according to the encoded image size. 
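The CallStats change above moves the RepeatingTaskHandle::DelayedStart() call out of the constructor and into an explicit EnsureStarted(), so construction has no side effects and the owner (or a test, as in call_stats2_unittest.cc) decides when the periodic update begins. A standalone sketch of that two-phase start-up pattern; the scheduling callback below stands in for the task-queue machinery:

#include <functional>
#include <utility>

class StatsUpdaterSketch {
 public:
  // |schedule_repeating| abstracts RepeatingTaskHandle::DelayedStart().
  explicit StatsUpdaterSketch(
      std::function<void(std::function<void()>)> schedule_repeating)
      : schedule_repeating_(std::move(schedule_repeating)) {}

  // Must be called once the object is fully wired up; nothing runs before it.
  void EnsureStarted() {
    if (started_) {
      return;
    }
    started_ = true;
    schedule_repeating_([this] { UpdateAndReport(); });
  }

 private:
  void UpdateAndReport() { /* aggregate RTT samples, notify observers */ }

  const std::function<void(std::function<void()>)> schedule_repeating_;
  bool started_ = false;
};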
- void OnEncodedFrame(const EncodedImage& encoded_image, int temporal_index); + void OnEncodedFrame(DataSize size, int spatial_index, int temporal_index); void Reset(); diff --git a/video/encoder_bitrate_adjuster_unittest.cc b/video/encoder_bitrate_adjuster_unittest.cc index d8fcf382b2..c249a5cb79 100644 --- a/video/encoder_bitrate_adjuster_unittest.cc +++ b/video/encoder_bitrate_adjuster_unittest.cc @@ -160,15 +160,12 @@ class EncoderBitrateAdjusterTest : public ::testing::Test { int sequence_idx = sequence_idx_[si][ti]; sequence_idx_[si][ti] = (sequence_idx_[si][ti] + 1) % kSequenceLength; - const size_t frame_size_bytes = + const DataSize frame_size = DataSize::Bytes( (sequence_idx < kSequenceLength / 2) ? media_frame_size - network_frame_size_diff_bytes - : media_frame_size + network_frame_size_diff_bytes; + : media_frame_size + network_frame_size_diff_bytes); - EncodedImage image; - image.SetEncodedData(EncodedImageBuffer::Create(frame_size_bytes)); - image.SetSpatialIndex(si); - adjuster_->OnEncodedFrame(image, ti); + adjuster_->OnEncodedFrame(frame_size, si, ti); sequence_idx = ++sequence_idx % kSequenceLength; } } diff --git a/video/encoder_rtcp_feedback.cc b/video/encoder_rtcp_feedback.cc index b81ff6120f..17095a0a0c 100644 --- a/video/encoder_rtcp_feedback.cc +++ b/video/encoder_rtcp_feedback.cc @@ -10,6 +10,9 @@ #include "video/encoder_rtcp_feedback.h" +#include +#include + #include "absl/types/optional.h" #include "api/video_codecs/video_encoder.h" #include "rtc_base/checks.h" @@ -21,47 +24,36 @@ namespace { constexpr int kMinKeyframeSendIntervalMs = 300; } // namespace -EncoderRtcpFeedback::EncoderRtcpFeedback(Clock* clock, - const std::vector& ssrcs, - VideoStreamEncoderInterface* encoder) +EncoderRtcpFeedback::EncoderRtcpFeedback( + Clock* clock, + const std::vector& ssrcs, + VideoStreamEncoderInterface* encoder, + std::function( + uint32_t ssrc, + const std::vector& seq_nums)> get_packet_infos) : clock_(clock), ssrcs_(ssrcs), - rtp_video_sender_(nullptr), + get_packet_infos_(std::move(get_packet_infos)), video_stream_encoder_(encoder), - time_last_intra_request_ms_(-1), - min_keyframe_send_interval_ms_( - KeyframeIntervalSettings::ParseFromFieldTrials() - .MinKeyframeSendIntervalMs() - .value_or(kMinKeyframeSendIntervalMs)) { + time_last_packet_delivery_queue_(Timestamp::Millis(0)), + min_keyframe_send_interval_( + TimeDelta::Millis(KeyframeIntervalSettings::ParseFromFieldTrials() + .MinKeyframeSendIntervalMs() + .value_or(kMinKeyframeSendIntervalMs))) { RTC_DCHECK(!ssrcs.empty()); + packet_delivery_queue_.Detach(); } -void EncoderRtcpFeedback::SetRtpVideoSender( - const RtpVideoSenderInterface* rtp_video_sender) { - RTC_DCHECK(rtp_video_sender); - RTC_DCHECK(!rtp_video_sender_); - rtp_video_sender_ = rtp_video_sender; -} +// Called via Call::DeliverRtcp. 
+void EncoderRtcpFeedback::OnReceivedIntraFrameRequest(uint32_t ssrc) { + RTC_DCHECK_RUN_ON(&packet_delivery_queue_); + RTC_DCHECK(std::find(ssrcs_.begin(), ssrcs_.end(), ssrc) != ssrcs_.end()); -bool EncoderRtcpFeedback::HasSsrc(uint32_t ssrc) { - for (uint32_t registered_ssrc : ssrcs_) { - if (registered_ssrc == ssrc) { - return true; - } - } - return false; -} + const Timestamp now = clock_->CurrentTime(); + if (time_last_packet_delivery_queue_ + min_keyframe_send_interval_ > now) + return; -void EncoderRtcpFeedback::OnReceivedIntraFrameRequest(uint32_t ssrc) { - RTC_DCHECK(HasSsrc(ssrc)); - { - int64_t now_ms = clock_->TimeInMilliseconds(); - MutexLock lock(&mutex_); - if (time_last_intra_request_ms_ + min_keyframe_send_interval_ms_ > now_ms) { - return; - } - time_last_intra_request_ms_ = now_ms; - } + time_last_packet_delivery_queue_ = now; // Always produce key frame for all streams. video_stream_encoder_->SendKeyFrame(); @@ -72,12 +64,12 @@ void EncoderRtcpFeedback::OnReceivedLossNotification( uint16_t seq_num_of_last_decodable, uint16_t seq_num_of_last_received, bool decodability_flag) { - RTC_DCHECK(rtp_video_sender_) << "Object initialization incomplete."; + RTC_DCHECK(get_packet_infos_) << "Object initialization incomplete."; const std::vector seq_nums = {seq_num_of_last_decodable, seq_num_of_last_received}; const std::vector infos = - rtp_video_sender_->GetSentRtpPacketInfos(ssrc, seq_nums); + get_packet_infos_(ssrc, seq_nums); if (infos.empty()) { return; } diff --git a/video/encoder_rtcp_feedback.h b/video/encoder_rtcp_feedback.h index 3bd1cb91f0..2aadcc34e7 100644 --- a/video/encoder_rtcp_feedback.h +++ b/video/encoder_rtcp_feedback.h @@ -10,12 +10,16 @@ #ifndef VIDEO_ENCODER_RTCP_FEEDBACK_H_ #define VIDEO_ENCODER_RTCP_FEEDBACK_H_ +#include #include +#include "api/sequence_checker.h" +#include "api/units/time_delta.h" +#include "api/units/timestamp.h" #include "api/video/video_stream_encoder_interface.h" #include "call/rtp_video_sender_interface.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" -#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/system/no_unique_address.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -27,13 +31,15 @@ class VideoStreamEncoderInterface; class EncoderRtcpFeedback : public RtcpIntraFrameObserver, public RtcpLossNotificationObserver { public: - EncoderRtcpFeedback(Clock* clock, - const std::vector& ssrcs, - VideoStreamEncoderInterface* encoder); + EncoderRtcpFeedback( + Clock* clock, + const std::vector& ssrcs, + VideoStreamEncoderInterface* encoder, + std::function( + uint32_t ssrc, + const std::vector& seq_nums)> get_packet_infos); ~EncoderRtcpFeedback() override = default; - void SetRtpVideoSender(const RtpVideoSenderInterface* rtp_video_sender); - void OnReceivedIntraFrameRequest(uint32_t ssrc) override; // Implements RtcpLossNotificationObserver. 
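EncoderRtcpFeedback above now receives its packet-info lookup as a std::function at construction instead of a raw RtpVideoSenderInterface pointer installed later through the removed SetRtpVideoSender(). A standalone sketch of that injection style, with made-up names and a simplified PacketInfo in place of the RTP types:

#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

struct PacketInfo {
  uint16_t sequence_number;
};

class LossNotificationHandlerSketch {
 public:
  using GetPacketInfos = std::function<std::vector<PacketInfo>(
      uint32_t ssrc, const std::vector<uint16_t>& seq_nums)>;

  explicit LossNotificationHandlerSketch(GetPacketInfos get_packet_infos)
      : get_packet_infos_(std::move(get_packet_infos)) {}

  void OnReceivedLossNotification(uint32_t ssrc,
                                  uint16_t seq_num_of_last_decodable,
                                  uint16_t seq_num_of_last_received) {
    const std::vector<PacketInfo> infos = get_packet_infos_(
        ssrc, {seq_num_of_last_decodable, seq_num_of_last_received});
    if (infos.empty()) {
      return;
    }
    // ... map the infos back to frames and forward them to the encoder ...
  }

 private:
  const GetPacketInfos get_packet_infos_;
};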
@@ -43,17 +49,19 @@ class EncoderRtcpFeedback : public RtcpIntraFrameObserver, bool decodability_flag) override; private: - bool HasSsrc(uint32_t ssrc); - Clock* const clock_; const std::vector ssrcs_; - const RtpVideoSenderInterface* rtp_video_sender_; + const std::function( + uint32_t ssrc, + const std::vector& seq_nums)> + get_packet_infos_; VideoStreamEncoderInterface* const video_stream_encoder_; - Mutex mutex_; - int64_t time_last_intra_request_ms_ RTC_GUARDED_BY(mutex_); + RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_delivery_queue_; + Timestamp time_last_packet_delivery_queue_ + RTC_GUARDED_BY(packet_delivery_queue_); - const int min_keyframe_send_interval_ms_; + const TimeDelta min_keyframe_send_interval_; }; } // namespace webrtc diff --git a/video/encoder_rtcp_feedback_unittest.cc b/video/encoder_rtcp_feedback_unittest.cc index 81ac22b6c6..4cbb747e51 100644 --- a/video/encoder_rtcp_feedback_unittest.cc +++ b/video/encoder_rtcp_feedback_unittest.cc @@ -26,7 +26,8 @@ class VieKeyRequestTest : public ::testing::Test { encoder_rtcp_feedback_( &simulated_clock_, std::vector(1, VieKeyRequestTest::kSsrc), - &encoder_) {} + &encoder_, + nullptr) {} protected: const uint32_t kSsrc = 1234; diff --git a/video/end_to_end_tests/config_tests.cc b/video/end_to_end_tests/config_tests.cc index bf63e2a51f..1bd897cb34 100644 --- a/video/end_to_end_tests/config_tests.cc +++ b/video/end_to_end_tests/config_tests.cc @@ -104,7 +104,7 @@ TEST_F(ConfigEndToEndTest, VerifyDefaultFlexfecReceiveConfigParameters) { FlexfecReceiveStream::Config default_receive_config(&rtcp_send_transport); EXPECT_EQ(-1, default_receive_config.payload_type) << "Enabling FlexFEC requires rtpmap: flexfec negotiation."; - EXPECT_EQ(0U, default_receive_config.remote_ssrc) + EXPECT_EQ(0U, default_receive_config.rtp.remote_ssrc) << "Enabling FlexFEC requires ssrc-group: FEC-FR negotiation."; EXPECT_TRUE(default_receive_config.protected_media_ssrcs.empty()) << "Enabling FlexFEC requires ssrc-group: FEC-FR negotiation."; diff --git a/video/end_to_end_tests/fec_tests.cc b/video/end_to_end_tests/fec_tests.cc index 0d4ddac5a4..77ad9eb666 100644 --- a/video/end_to_end_tests/fec_tests.cc +++ b/video/end_to_end_tests/fec_tests.cc @@ -314,7 +314,7 @@ class FlexfecRenderObserver : public test::EndToEndTest, void ModifyFlexfecConfigs( std::vector* receive_configs) override { - (*receive_configs)[0].local_ssrc = kFlexfecLocalSsrc; + (*receive_configs)[0].rtp.local_ssrc = kFlexfecLocalSsrc; } void PerformTest() override { diff --git a/video/end_to_end_tests/network_state_tests.cc b/video/end_to_end_tests/network_state_tests.cc index 9abde3bb32..4e0e86f987 100644 --- a/video/end_to_end_tests/network_state_tests.cc +++ b/video/end_to_end_tests/network_state_tests.cc @@ -10,13 +10,19 @@ #include +#include "api/media_types.h" +#include "api/task_queue/default_task_queue_factory.h" +#include "api/task_queue/task_queue_base.h" +#include "api/task_queue/task_queue_factory.h" #include "api/test/simulated_network.h" #include "api/video_codecs/video_encoder.h" #include "call/fake_network_pipe.h" #include "call/simulated_network.h" #include "modules/rtp_rtcp/source/rtp_packet.h" +#include "rtc_base/location.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue_for_test.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "system_wrappers/include/sleep.h" #include "test/call_test.h" #include "test/fake_encoder.h" @@ -166,7 +172,10 @@ TEST_F(NetworkStateEndToEndTest, RespectsNetworkState) { explicit 
NetworkStateTest(TaskQueueBase* task_queue) : EndToEndTest(kDefaultTimeoutMs), FakeEncoder(Clock::GetRealTimeClock()), - task_queue_(task_queue), + e2e_test_task_queue_(task_queue), + task_queue_(CreateDefaultTaskQueueFactory()->CreateTaskQueue( + "NetworkStateTest", + TaskQueueFactory::Priority::NORMAL)), sender_call_(nullptr), receiver_call_(nullptr), encoder_factory_(this), @@ -219,26 +228,36 @@ TEST_F(NetworkStateEndToEndTest, RespectsNetworkState) { send_config->encoder_settings.encoder_factory = &encoder_factory_; } + void SignalChannelNetworkState(Call* call, + MediaType media_type, + NetworkState network_state) { + SendTask(RTC_FROM_HERE, e2e_test_task_queue_, + [call, media_type, network_state] { + call->SignalChannelNetworkState(media_type, network_state); + }); + } + void PerformTest() override { EXPECT_TRUE(encoded_frames_.Wait(kDefaultTimeoutMs)) << "No frames received by the encoder."; - SendTask(RTC_FROM_HERE, task_queue_, [this]() { + SendTask(RTC_FROM_HERE, task_queue_.get(), [this]() { // Wait for packets from both sender/receiver. WaitForPacketsOrSilence(false, false); // Sender-side network down for audio; there should be no effect on // video - sender_call_->SignalChannelNetworkState(MediaType::AUDIO, kNetworkDown); + SignalChannelNetworkState(sender_call_, MediaType::AUDIO, kNetworkDown); + WaitForPacketsOrSilence(false, false); // Receiver-side network down for audio; no change expected - receiver_call_->SignalChannelNetworkState(MediaType::AUDIO, - kNetworkDown); + SignalChannelNetworkState(receiver_call_, MediaType::AUDIO, + kNetworkDown); WaitForPacketsOrSilence(false, false); // Sender-side network down. - sender_call_->SignalChannelNetworkState(MediaType::VIDEO, kNetworkDown); + SignalChannelNetworkState(sender_call_, MediaType::VIDEO, kNetworkDown); { MutexLock lock(&test_mutex_); // After network goes down we shouldn't be encoding more frames. @@ -248,14 +267,14 @@ TEST_F(NetworkStateEndToEndTest, RespectsNetworkState) { WaitForPacketsOrSilence(true, false); // Receiver-side network down. - receiver_call_->SignalChannelNetworkState(MediaType::VIDEO, - kNetworkDown); + SignalChannelNetworkState(receiver_call_, MediaType::VIDEO, + kNetworkDown); WaitForPacketsOrSilence(true, true); // Network up for audio for both sides; video is still not expected to // start - sender_call_->SignalChannelNetworkState(MediaType::AUDIO, kNetworkUp); - receiver_call_->SignalChannelNetworkState(MediaType::AUDIO, kNetworkUp); + SignalChannelNetworkState(sender_call_, MediaType::AUDIO, kNetworkUp); + SignalChannelNetworkState(receiver_call_, MediaType::AUDIO, kNetworkUp); WaitForPacketsOrSilence(true, true); // Network back up again for both. @@ -265,8 +284,8 @@ TEST_F(NetworkStateEndToEndTest, RespectsNetworkState) { // network. 
sender_state_ = kNetworkUp; } - sender_call_->SignalChannelNetworkState(MediaType::VIDEO, kNetworkUp); - receiver_call_->SignalChannelNetworkState(MediaType::VIDEO, kNetworkUp); + SignalChannelNetworkState(sender_call_, MediaType::VIDEO, kNetworkUp); + SignalChannelNetworkState(receiver_call_, MediaType::VIDEO, kNetworkUp); WaitForPacketsOrSilence(false, false); // TODO(skvlad): add tests to verify that the audio streams are stopped @@ -340,7 +359,8 @@ TEST_F(NetworkStateEndToEndTest, RespectsNetworkState) { } } - TaskQueueBase* const task_queue_; + TaskQueueBase* const e2e_test_task_queue_; + std::unique_ptr task_queue_; Mutex test_mutex_; rtc::Event encoded_frames_; rtc::Event packet_event_; diff --git a/video/end_to_end_tests/resolution_bitrate_limits_tests.cc b/video/end_to_end_tests/resolution_bitrate_limits_tests.cc new file mode 100644 index 0000000000..d46c40cd1e --- /dev/null +++ b/video/end_to_end_tests/resolution_bitrate_limits_tests.cc @@ -0,0 +1,375 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include + +#include "media/engine/webrtc_video_engine.h" +#include "rtc_base/experiments/encoder_info_settings.h" +#include "test/call_test.h" +#include "test/fake_encoder.h" +#include "test/field_trial.h" +#include "test/gtest.h" +#include "test/video_encoder_proxy_factory.h" + +namespace webrtc { +namespace test { +namespace { +void SetEncoderSpecific(VideoEncoderConfig* encoder_config, + VideoCodecType type, + size_t num_spatial_layers) { + if (type == kVideoCodecVP9) { + VideoCodecVP9 vp9 = VideoEncoder::GetDefaultVp9Settings(); + vp9.numberOfSpatialLayers = num_spatial_layers; + encoder_config->encoder_specific_settings = + rtc::make_ref_counted( + vp9); + } +} + +SpatialLayer GetLayer(int pixels, const VideoCodec& codec) { + if (codec.codecType == VideoCodecType::kVideoCodecVP9) { + for (size_t i = 0; i < codec.VP9().numberOfSpatialLayers; ++i) { + if (codec.spatialLayers[i].width * codec.spatialLayers[i].height == + pixels) { + return codec.spatialLayers[i]; + } + } + } else { + for (int i = 0; i < codec.numberOfSimulcastStreams; ++i) { + if (codec.simulcastStream[i].width * codec.simulcastStream[i].height == + pixels) { + return codec.simulcastStream[i]; + } + } + } + ADD_FAILURE(); + return SpatialLayer(); +} + +} // namespace + +class ResolutionBitrateLimitsTest + : public test::CallTest, + public ::testing::WithParamInterface { + public: + ResolutionBitrateLimitsTest() : payload_name_(GetParam()) {} + + const std::string payload_name_; +}; + +INSTANTIATE_TEST_SUITE_P(PayloadName, + ResolutionBitrateLimitsTest, + ::testing::Values("VP8", "VP9")); + +class InitEncodeTest : public test::EndToEndTest, + public test::FrameGeneratorCapturer::SinkWantsObserver, + public test::FakeEncoder { + public: + struct Bitrate { + const absl::optional min; + const absl::optional max; + }; + struct TestConfig { + const bool active; + const Bitrate bitrate_bps; + }; + struct Expectation { + const uint32_t pixels = 0; + const Bitrate eq_bitrate_bps; + const Bitrate ne_bitrate_bps; + }; + + InitEncodeTest(const std::string& payload_name, + const std::vector& configs, + const std::vector& expectations) + : 
EndToEndTest(test::CallTest::kDefaultTimeoutMs), + FakeEncoder(Clock::GetRealTimeClock()), + encoder_factory_(this), + payload_name_(payload_name), + configs_(configs), + expectations_(expectations) {} + + void OnFrameGeneratorCapturerCreated( + test::FrameGeneratorCapturer* frame_generator_capturer) override { + frame_generator_capturer->SetSinkWantsObserver(this); + // Set initial resolution. + frame_generator_capturer->ChangeResolution(1280, 720); + } + + void OnSinkWantsChanged(rtc::VideoSinkInterface* sink, + const rtc::VideoSinkWants& wants) override {} + + size_t GetNumVideoStreams() const override { + return (payload_name_ == "VP9") ? 1 : configs_.size(); + } + + void ModifyVideoConfigs( + VideoSendStream::Config* send_config, + std::vector* receive_configs, + VideoEncoderConfig* encoder_config) override { + send_config->encoder_settings.encoder_factory = &encoder_factory_; + send_config->rtp.payload_name = payload_name_; + send_config->rtp.payload_type = test::CallTest::kVideoSendPayloadType; + const VideoCodecType codec_type = PayloadStringToCodecType(payload_name_); + encoder_config->codec_type = codec_type; + encoder_config->video_stream_factory = + rtc::make_ref_counted( + payload_name_, /*max qp*/ 0, /*screencast*/ false, + /*screenshare enabled*/ false); + encoder_config->max_bitrate_bps = -1; + if (configs_.size() == 1 && configs_[0].bitrate_bps.max) + encoder_config->max_bitrate_bps = *configs_[0].bitrate_bps.max; + if (payload_name_ == "VP9") { + // Simulcast layers indicates which spatial layers are active. + encoder_config->simulcast_layers.resize(configs_.size()); + } + double scale_factor = 1.0; + for (int i = configs_.size() - 1; i >= 0; --i) { + VideoStream& stream = encoder_config->simulcast_layers[i]; + stream.active = configs_[i].active; + if (configs_[i].bitrate_bps.min) + stream.min_bitrate_bps = *configs_[i].bitrate_bps.min; + if (configs_[i].bitrate_bps.max) + stream.max_bitrate_bps = *configs_[i].bitrate_bps.max; + stream.scale_resolution_down_by = scale_factor; + scale_factor *= (payload_name_ == "VP9") ? 
1.0 : 2.0; + } + SetEncoderSpecific(encoder_config, codec_type, configs_.size()); + } + + int32_t InitEncode(const VideoCodec* codec, + const Settings& settings) override { + for (const auto& expected : expectations_) { + SpatialLayer layer = GetLayer(expected.pixels, *codec); + if (expected.eq_bitrate_bps.min) + EXPECT_EQ(*expected.eq_bitrate_bps.min, layer.minBitrate * 1000); + if (expected.eq_bitrate_bps.max) + EXPECT_EQ(*expected.eq_bitrate_bps.max, layer.maxBitrate * 1000); + EXPECT_NE(expected.ne_bitrate_bps.min, layer.minBitrate * 1000); + EXPECT_NE(expected.ne_bitrate_bps.max, layer.maxBitrate * 1000); + } + observation_complete_.Set(); + return 0; + } + + VideoEncoder::EncoderInfo GetEncoderInfo() const override { + EncoderInfo info = FakeEncoder::GetEncoderInfo(); + if (!encoder_info_override_.resolution_bitrate_limits().empty()) { + info.resolution_bitrate_limits = + encoder_info_override_.resolution_bitrate_limits(); + } + return info; + } + + void PerformTest() override { + ASSERT_TRUE(Wait()) << "Timed out while waiting for InitEncode() call."; + } + + private: + test::VideoEncoderProxyFactory encoder_factory_; + const std::string payload_name_; + const std::vector configs_; + const std::vector expectations_; + const LibvpxVp8EncoderInfoSettings encoder_info_override_; +}; + +TEST_P(ResolutionBitrateLimitsTest, LimitsApplied) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:921600," + "min_start_bitrate_bps:0," + "min_bitrate_bps:32000," + "max_bitrate_bps:3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{1280 * 720, + /*eq_bitrate_bps=*/{32000, 3333000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, EncodingsApplied) { + InitEncodeTest test(payload_name_, + {{/*active=*/true, /*bitrate_bps=*/{22000, 3555000}}}, + // Expectations: + {{1280 * 720, + /*eq_bitrate_bps=*/{22000, 3555000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, IntersectionApplied) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:921600," + "min_start_bitrate_bps:0," + "min_bitrate_bps:32000," + "max_bitrate_bps:3333000/"); + + InitEncodeTest test(payload_name_, + {{/*active=*/true, /*bitrate_bps=*/{22000, 1555000}}}, + // Expectations: + {{1280 * 720, + /*eq_bitrate_bps=*/{32000, 1555000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, LimitsAppliedMiddleActive) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:230400|921600," + "min_start_bitrate_bps:0|0," + "min_bitrate_bps:21000|32000," + "max_bitrate_bps:2222000|3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{640 * 360, + /*eq_bitrate_bps=*/{21000, 2222000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, IntersectionAppliedMiddleActive) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:230400|921600," + 
"min_start_bitrate_bps:0|0," + "min_bitrate_bps:31000|32000," + "max_bitrate_bps:2222000|3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/true, /*bitrate_bps=*/{30000, 1555000}}, + {/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{640 * 360, + /*eq_bitrate_bps=*/{31000, 1555000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, DefaultLimitsAppliedMiddleActive) { + const absl::optional + kDefaultSinglecastLimits360p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + PayloadStringToCodecType(payload_name_), 640 * 360); + + InitEncodeTest test( + payload_name_, + {{/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{640 * 360, + /*eq_bitrate_bps=*/ + {kDefaultSinglecastLimits360p->min_bitrate_bps, + kDefaultSinglecastLimits360p->max_bitrate_bps}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, LimitsAppliedHighestActive) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:230400|921600," + "min_start_bitrate_bps:0|0," + "min_bitrate_bps:31000|32000," + "max_bitrate_bps:2222000|3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{1280 * 720, + /*eq_bitrate_bps=*/{32000, 3333000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, IntersectionAppliedHighestActive) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:230400|921600," + "min_start_bitrate_bps:0|0," + "min_bitrate_bps:31000|32000," + "max_bitrate_bps:2222000|3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/true, /*bitrate_bps=*/{30000, 1555000}}}, + // Expectations: + {{1280 * 720, + /*eq_bitrate_bps=*/{32000, 1555000}, + /*ne_bitrate_bps=*/{absl::nullopt, absl::nullopt}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, LimitsNotAppliedLowestActive) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + "frame_size_pixels:230400|921600," + "min_start_bitrate_bps:0|0," + "min_bitrate_bps:31000|32000," + "max_bitrate_bps:2222000|3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/false, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{640 * 360, + /*eq_bitrate_bps=*/{absl::nullopt, absl::nullopt}, + /*ne_bitrate_bps=*/{31000, 2222000}}, + {1280 * 720, + /*eq_bitrate_bps=*/{absl::nullopt, absl::nullopt}, + /*ne_bitrate_bps=*/{32000, 3333000}}}); + RunBaseTest(&test); +} + +TEST_P(ResolutionBitrateLimitsTest, LimitsNotAppliedSimulcast) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-GetEncoderInfoOverride/" + 
"frame_size_pixels:230400|921600," + "min_start_bitrate_bps:0|0," + "min_bitrate_bps:31000|32000," + "max_bitrate_bps:2222000|3333000/"); + + InitEncodeTest test( + payload_name_, + {{/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}, + {/*active=*/true, /*bitrate_bps=*/{absl::nullopt, absl::nullopt}}}, + // Expectations: + {{640 * 360, + /*eq_bitrate_bps=*/{absl::nullopt, absl::nullopt}, + /*ne_bitrate_bps=*/{31000, 2222000}}, + {1280 * 720, + /*eq_bitrate_bps=*/{absl::nullopt, absl::nullopt}, + /*ne_bitrate_bps=*/{32000, 3333000}}}); + RunBaseTest(&test); +} + +} // namespace test +} // namespace webrtc diff --git a/video/end_to_end_tests/rtp_rtcp_tests.cc b/video/end_to_end_tests/rtp_rtcp_tests.cc index 76018027d6..a698328dad 100644 --- a/video/end_to_end_tests/rtp_rtcp_tests.cc +++ b/video/end_to_end_tests/rtp_rtcp_tests.cc @@ -316,7 +316,7 @@ void RtpRtcpEndToEndTest::TestRtpStatePreservation( } GetVideoEncoderConfig()->video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); // Use the same total bitrates when sending a single stream to avoid // lowering the bitrate estimate and requiring a subsequent rampup. one_stream = GetVideoEncoderConfig()->Copy(); @@ -537,12 +537,13 @@ TEST_F(RtpRtcpEndToEndTest, DISABLED_TestFlexfecRtpStatePreservation) { receive_transport.get()); flexfec_receive_config.payload_type = GetVideoSendConfig()->rtp.flexfec.payload_type; - flexfec_receive_config.remote_ssrc = GetVideoSendConfig()->rtp.flexfec.ssrc; + flexfec_receive_config.rtp.remote_ssrc = + GetVideoSendConfig()->rtp.flexfec.ssrc; flexfec_receive_config.protected_media_ssrcs = GetVideoSendConfig()->rtp.flexfec.protected_media_ssrcs; - flexfec_receive_config.local_ssrc = kReceiverLocalVideoSsrc; - flexfec_receive_config.transport_cc = true; - flexfec_receive_config.rtp_header_extensions.emplace_back( + flexfec_receive_config.rtp.local_ssrc = kReceiverLocalVideoSsrc; + flexfec_receive_config.rtp.transport_cc = true; + flexfec_receive_config.rtp.extensions.emplace_back( RtpExtension::kTransportSequenceNumberUri, kTransportSequenceNumberExtensionId); flexfec_receive_configs_.push_back(flexfec_receive_config); diff --git a/video/end_to_end_tests/ssrc_tests.cc b/video/end_to_end_tests/ssrc_tests.cc index cedae3934d..bdca05d647 100644 --- a/video/end_to_end_tests/ssrc_tests.cc +++ b/video/end_to_end_tests/ssrc_tests.cc @@ -14,6 +14,7 @@ #include "call/fake_network_pipe.h" #include "call/simulated_network.h" #include "modules/rtp_rtcp/source/rtp_packet.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "rtc_base/task_queue_for_test.h" #include "test/call_test.h" #include "test/gtest.h" @@ -60,7 +61,7 @@ TEST_F(SsrcEndToEndTest, UnknownRtpPacketGivesUnknownSsrcReturnCode) { DeliveryStatus DeliverPacket(MediaType media_type, rtc::CopyOnWriteBuffer packet, int64_t packet_time_us) override { - if (RtpHeaderParser::IsRtcp(packet.cdata(), packet.size())) { + if (IsRtcpPacket(packet)) { return receiver_->DeliverPacket(media_type, std::move(packet), packet_time_us); } @@ -132,13 +133,15 @@ void SsrcEndToEndTest::TestSendsSetSsrcs(size_t num_ssrcs, public: SendsSetSsrcs(const uint32_t* ssrcs, size_t num_ssrcs, - bool send_single_ssrc_first) + bool send_single_ssrc_first, + TaskQueueBase* task_queue) : EndToEndTest(kDefaultTimeoutMs), num_ssrcs_(num_ssrcs), send_single_ssrc_first_(send_single_ssrc_first), ssrcs_to_observe_(num_ssrcs), expect_single_ssrc_(send_single_ssrc_first), - send_stream_(nullptr) { + send_stream_(nullptr), + task_queue_(task_queue) { for (size_t 
i = 0; i < num_ssrcs; ++i) valid_ssrcs_[ssrcs[i]] = true; } @@ -200,8 +203,10 @@ void SsrcEndToEndTest::TestSendsSetSsrcs(size_t num_ssrcs, if (send_single_ssrc_first_) { // Set full simulcast and continue with the rest of the SSRCs. - send_stream_->ReconfigureVideoEncoder( - std::move(video_encoder_config_all_streams_)); + SendTask(RTC_FROM_HERE, task_queue_, [&]() { + send_stream_->ReconfigureVideoEncoder( + std::move(video_encoder_config_all_streams_)); + }); EXPECT_TRUE(Wait()) << "Timed out while waiting on additional SSRCs."; } } @@ -218,7 +223,8 @@ void SsrcEndToEndTest::TestSendsSetSsrcs(size_t num_ssrcs, VideoSendStream* send_stream_; VideoEncoderConfig video_encoder_config_all_streams_; - } test(kVideoSendSsrcs, num_ssrcs, send_single_ssrc_first); + TaskQueueBase* task_queue_; + } test(kVideoSendSsrcs, num_ssrcs, send_single_ssrc_first, task_queue()); RunBaseTest(&test); } diff --git a/video/end_to_end_tests/stats_tests.cc b/video/end_to_end_tests/stats_tests.cc index ae0532b9a3..54e7bcff1c 100644 --- a/video/end_to_end_tests/stats_tests.cc +++ b/video/end_to_end_tests/stats_tests.cc @@ -17,7 +17,7 @@ #include "api/test/video/function_video_encoder_factory.h" #include "call/fake_network_pipe.h" #include "call/simulated_network.h" -#include "modules/rtp_rtcp/source/rtp_utility.h" +#include "modules/rtp_rtcp/source/rtp_packet.h" #include "modules/video_coding/include/video_coding_defines.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/synchronization/mutex.h" @@ -71,12 +71,11 @@ TEST_F(StatsEndToEndTest, GetStats) { Action OnSendRtp(const uint8_t* packet, size_t length) override { // Drop every 25th packet => 4% loss. static const int kPacketLossFrac = 25; - RTPHeader header; - RtpUtility::RtpHeaderParser parser(packet, length); - if (parser.Parse(&header) && - expected_send_ssrcs_.find(header.ssrc) != + RtpPacket header; + if (header.Parse(packet, length) && + expected_send_ssrcs_.find(header.Ssrc()) != expected_send_ssrcs_.end() && - header.sequenceNumber % kPacketLossFrac == 0) { + header.SequenceNumber() % kPacketLossFrac == 0) { return DROP_PACKET; } check_stats_event_.Set(); @@ -143,8 +142,8 @@ TEST_F(StatsEndToEndTest, GetStats) { stats.rtcp_packet_type_counts.nack_requests != 0 || stats.rtcp_packet_type_counts.unique_nack_requests != 0; - assert(stats.current_payload_type == -1 || - stats.current_payload_type == kFakeVideoSendPayloadType); + RTC_DCHECK(stats.current_payload_type == -1 || + stats.current_payload_type == kFakeVideoSendPayloadType); receive_stats_filled_["IncomingPayloadType"] |= stats.current_payload_type == kFakeVideoSendPayloadType; } @@ -154,7 +153,10 @@ TEST_F(StatsEndToEndTest, GetStats) { bool CheckSendStats() { RTC_DCHECK(send_stream_); - VideoSendStream::Stats stats = send_stream_->GetStats(); + + VideoSendStream::Stats stats; + SendTask(RTC_FROM_HERE, task_queue_, + [&]() { stats = send_stream_->GetStats(); }); size_t expected_num_streams = kNumSimulcastStreams + expected_send_ssrcs_.size(); @@ -179,9 +181,7 @@ TEST_F(StatsEndToEndTest, GetStats) { const VideoSendStream::StreamStats& stream_stats = kv.second; send_stats_filled_[CompoundKey("StatisticsUpdated", kv.first)] |= - stream_stats.rtcp_stats.packets_lost != 0 || - stream_stats.rtcp_stats.extended_highest_sequence_number != 0 || - stream_stats.rtcp_stats.fraction_lost != 0; + stream_stats.report_block_data.has_value(); send_stats_filled_[CompoundKey("DataCountersUpdated", kv.first)] |= stream_stats.rtp_stats.fec.packets != 0 || @@ -612,11 +612,9 @@ TEST_F(StatsEndToEndTest, 
VerifyNackStats) { Action OnSendRtp(const uint8_t* packet, size_t length) override { MutexLock lock(&mutex_); if (++sent_rtp_packets_ == kPacketNumberToDrop) { - std::unique_ptr parser( - RtpHeaderParser::CreateForTest()); - RTPHeader header; - EXPECT_TRUE(parser->Parse(packet, length, &header)); - dropped_rtp_packet_ = header.sequenceNumber; + RtpPacket header; + EXPECT_TRUE(header.Parse(packet, length)); + dropped_rtp_packet_ = header.SequenceNumber(); return DROP_PACKET; } task_queue_->PostTask(std::unique_ptr(this)); diff --git a/video/frame_encode_metadata_writer.cc b/video/frame_encode_metadata_writer.cc index 0e604cd765..8a0f3b3867 100644 --- a/video/frame_encode_metadata_writer.cc +++ b/video/frame_encode_metadata_writer.cc @@ -217,7 +217,7 @@ void FrameEncodeMetadataWriter::UpdateBitstream( buffer, encoded_image->ColorSpace()); encoded_image->SetEncodedData( - new rtc::RefCountedObject( + rtc::make_ref_counted( std::move(modified_buffer))); } diff --git a/video/full_stack_tests.cc b/video/full_stack_tests.cc index ece756b2dc..3831fdfcef 100644 --- a/video/full_stack_tests.cc +++ b/video/full_stack_tests.cc @@ -21,7 +21,7 @@ #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_encoder_config.h" -#include "media/base/vp9_profile.h" +#include "api/video_codecs/vp9_profile.h" #include "modules/video_coding/codecs/vp9/include/vp9.h" #include "system_wrappers/include/field_trial.h" #include "test/field_trial.h" diff --git a/video/g3doc/adaptation.md b/video/g3doc/adaptation.md new file mode 100644 index 0000000000..084a0fd3aa --- /dev/null +++ b/video/g3doc/adaptation.md @@ -0,0 +1,114 @@ + + + +# Video Adaptation + +Video adaptation is a mechanism which reduces the bandwidth or CPU consumption +by reducing encoded video quality. + +## Overview + +Adaptation occurs when a _Resource_ signals that it is currently underused or +overused. When overused, the video quality is decreased and when underused, the +video quality is increased. There are currently two dimensions in which the +quality can be adapted: frame-rate and resolution. The dimension that is adapted +is based on the degradation preference for the video track. + +## Resources + +_Resources_ monitor metrics from the system or the video stream. For example, a +resource could monitor system temperature or the bandwidth usage of the video +stream. A resource implements the [Resource][resource.h] interface. When a +resource detects that it is overused, it calls `SetUsageState(kOveruse)`. When +the resource is no longer overused, it can signal this using +`SetUsageState(kUnderuse)`. + +There are two resources that are used by default on all video tracks: Quality +scaler resource and encode overuse resource. + +### QP Scaler Resource + +The quality scaler resource monitors the quantization parameter (QP) of the +encoded video frames for video send stream and ensures that the quality of the +stream is acceptable for the current resolution. After each frame is encoded the +[QualityScaler][quality_scaler.h] is given the QP of the encoded frame. Overuse +or underuse is signalled when the average QP is outside of the +[QP thresholds][VideoEncoder::QpThresholds]. If the average QP is above the +_high_ threshold, the QP scaler signals _overuse_, and when below the _low_ +threshold the QP scaler signals _underuse_. + +The thresholds are set by the video encoder in the `scaling_settings` property +of the [EncoderInfo][EncoderInfo]. 
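+
+For illustration, a minimal sketch of how an encoder implementation could
+advertise its QP thresholds through `GetEncoderInfo()` (the function name and
+threshold values below are made up for the example):
+
+```cpp
+#include "api/video_codecs/video_encoder.h"
+
+// Returns the EncoderInfo a hypothetical encoder could hand back from
+// VideoEncoder::GetEncoderInfo(). Average QP above `high` signals overuse
+// (scale down); average QP below `low` signals underuse (scale up).
+webrtc::VideoEncoder::EncoderInfo EncoderInfoWithQpThresholds() {
+  webrtc::VideoEncoder::EncoderInfo info;
+  info.scaling_settings =
+      webrtc::VideoEncoder::ScalingSettings(/*low=*/24, /*high=*/37);
+  // Assigning VideoEncoder::ScalingSettings::kOff instead disables the
+  // QP scaler for this encoder.
+  return info;
+}
+```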
+
+*Note:* The QP scaler is only enabled when the degradation preference is
+`MAINTAIN_FRAMERATE` or `BALANCED`.
+
+### Encode Usage Resource
+
+The [encoder usage resource][encode_usage_resource.h] monitors how long it takes
+to encode a video frame. This works as a good proxy measurement for CPU usage as
+contention increases when CPU usage is high, increasing the encode times of the
+video frames.
+
+The time is tracked from when frame encoding starts to when it is completed. If
+the average encoder usage exceeds the thresholds set, *overuse* is triggered.
+
+### Injecting other Resources
+
+A custom resource can be injected into the call using the
+[Call::AddAdaptationResource][Call::AddAdaptationResource] method.
+
+## Adaptation
+
+When a *resource* signals that it is over- or underused, this signal reaches the
+`ResourceAdaptationProcessor`, which requests an `Adaptation` proposal from the
+[VideoStreamAdapter][VideoStreamAdapter]. This proposal is based on the
+degradation preference of the video stream. `ResourceAdaptationProcessor` will
+determine if the `Adaptation` should be applied based on the current adaptation
+status and the `Adaptation` proposal.
+
+### Degradation Preference
+
+There are 3 degradation preferences, described in the
+[RtpParameters][RtpParameters] header. These are:
+
+* `MAINTAIN_FRAMERATE`: Adapt video resolution.
+* `MAINTAIN_RESOLUTION`: Adapt video frame-rate.
+* `BALANCED`: Adapt video frame-rate or resolution.
+
+The degradation preference is set for a video track using the
+`degradation_preference` property in the [RtpParameters][RtpParameters].
+
+## VideoSinkWants and video stream adaptation
+
+Once an adaptation is applied, it notifies the video stream. The video stream
+converts this adaptation to a [VideoSinkWants][VideoSinkWants]. These sink wants
+indicate to the video stream that some restrictions should be applied to the
+stream before it is sent to the encoder. It has a few properties, but for
+adaptation the properties that might be set are:
+
+* `target_pixel_count`: The desired number of pixels for each video frame. The
+  actual pixel count should be close to this but does not have to be exact so
+  that aspect ratio can be maintained.
+* `max_pixel_count`: The maximum number of pixels in each video frame. This
+  value cannot be exceeded if set.
+* `max_framerate_fps`: The maximum frame-rate for the video source. The source
+  is expected to drop frames that cause this threshold to be exceeded.
+
+The `VideoSinkWants` can be applied by any video source, or one may use the
+[AdaptedVideoTrackSource][adapted_video_track_source.h], which is a base class
+for sources that need video adaptation.
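+
+As a rough example of the injection mechanism described above, a custom
+resource only has to implement the [Resource][resource.h] interface and report
+measurements to the listener it is given. `ThermalResource` is a hypothetical
+name; threading, sequencing and lifetime details are omitted in this sketch:
+
+```cpp
+#include <string>
+
+#include "api/adaptation/resource.h"
+#include "api/scoped_refptr.h"
+
+// Hypothetical resource that reports application-measured thermal state.
+class ThermalResource : public webrtc::Resource {
+ public:
+  std::string Name() const override { return "ThermalResource"; }
+
+  void SetResourceListener(webrtc::ResourceListener* listener) override {
+    listener_ = listener;
+  }
+
+  // Called by application code, e.g. from a platform thermal-state callback.
+  void ReportMeasurement(bool overheated) {
+    if (!listener_)
+      return;
+    listener_->OnResourceUsageStateMeasured(
+        rtc::scoped_refptr<webrtc::Resource>(this),
+        overheated ? webrtc::ResourceUsageState::kOveruse
+                   : webrtc::ResourceUsageState::kUnderuse);
+  }
+
+ private:
+  webrtc::ResourceListener* listener_ = nullptr;
+};
+```
+
+Such a resource would then be registered with
+`call->AddAdaptationResource(rtc::make_ref_counted<ThermalResource>())`.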
+ +[RtpParameters]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/rtp_parameters.h?q=%22RTC_EXPORT%20RtpParameters%22 +[resource.h]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/adaptation/resource.h +[Call::AddAdaptationResource]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/call.h?q=Call::AddAdaptationResource +[quality_scaler.h]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/video_coding/utility/quality_scaler.h +[VideoEncoder::QpThresholds]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video_codecs/video_encoder.h?q=VideoEncoder::QpThresholds +[EncoderInfo]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video_codecs/video_encoder.h?q=VideoEncoder::EncoderInfo +[encode_usage_resource.h]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/video/adaptation/encode_usage_resource.h +[VideoStreamAdapter]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/adaptation/video_stream_adapter.h +[adaptation_constraint.h]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/adaptation/adaptation_constraint.h +[bitrate_constraint.h]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/video/adaptation/bitrate_constraint.h +[AddOrUpdateSink]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video/video_source_interface.h?q=AddOrUpdateSink +[VideoSinkWants]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/api/video/video_source_interface.h?q=%22RTC_EXPORT%20VideoSinkWants%22 +[adapted_video_track_source.h]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/media/base/adapted_video_track_source.h diff --git a/video/g3doc/stats.md b/video/g3doc/stats.md new file mode 100644 index 0000000000..a5d15fe2fa --- /dev/null +++ b/video/g3doc/stats.md @@ -0,0 +1,217 @@ + + + +# Video stats + +Overview of collected statistics for [VideoSendStream] and [VideoReceiveStream]. + +## VideoSendStream + +[VideoSendStream::Stats] for a sending stream can be gathered via `VideoSendStream::GetStats()`. + +Some statistics are collected per RTP stream (see [StreamStats]) and can be of `StreamType`: `kMedia`, `kRtx`, `kFlexfec`. + +Multiple `StreamStats` objects are for example present if simulcast is used (multiple `kMedia` objects) or if RTX or FlexFEC is negotiated. + +### SendStatisticsProxy +`VideoSendStream` owns a [SendStatisticsProxy] which implements +`VideoStreamEncoderObserver`, +`RtcpStatisticsCallback`, +`ReportBlockDataObserver`, +`RtcpPacketTypeCounterObserver`, +`StreamDataCountersCallback`, +`BitrateStatisticsObserver`, +`FrameCountObserver`, +`SendSideDelayObserver` +and holds a `VideoSendStream::Stats` object. + +`SendStatisticsProxy` is called via these interfaces by different components (e.g. `RtpRtcp` module) to update stats. + +#### StreamStats +* `type` - kMedia, kRtx or kFlexfec. +* `referenced_media_ssrc` - only present for type kRtx/kFlexfec. The SSRC for the kMedia stream that retransmissions or FEC is performed for. + +Updated when a frame has been encoded, `VideoStreamEncoder::OnEncodedImage`. +* `frames_encoded `- total number of encoded frames. +* `encode_frame_rate` - number of encoded frames during the last second. 
+* `width` - width of last encoded frame [[rtcoutboundrtpstreamstats-framewidth]].
+* `height` - height of last encoded frame [[rtcoutboundrtpstreamstats-frameheight]].
+* `total_encode_time_ms` - total encode time for encoded frames.
+* `qp_sum` - sum of quantizer values of encoded frames [[rtcoutboundrtpstreamstats-qpsum]].
+* `frame_counts` - total number of encoded key/delta frames [[rtcoutboundrtpstreamstats-keyframesencoded]].
+
+Updated when an RTP packet is transmitted to the network, `RtpSenderEgress::SendPacket`.
+* `rtp_stats` - total number of sent bytes/packets.
+* `total_bitrate_bps` - total bitrate sent in bits per second (over a one second window).
+* `retransmit_bitrate_bps` - total retransmit bitrate sent in bits per second (over a one second window).
+* `avg_delay_ms` - average capture-to-send delay for sent packets (over a one second window).
+* `max_delay_ms` - maximum capture-to-send delay for sent packets (over a one second window).
+* `total_packet_send_delay_ms` - total capture-to-send delay for sent packets [[rtcoutboundrtpstreamstats-totalpacketsenddelay]].
+
+Updated when an incoming RTCP packet is parsed, `RTCPReceiver::ParseCompoundPacket`.
+* `rtcp_packet_type_counts` - total number of received NACK/FIR/PLI packets [rtcoutboundrtpstreamstats-[nackcount], [fircount], [plicount]].
+
+Updated when an RTCP report block packet is received, `RTCPReceiver::TriggerCallbacksFromRtcpPacket`.
+* `rtcp_stats` - RTCP report block data.
+* `report_block_data` - RTCP report block data.
+
+#### Stats
+* `std::map<uint32_t, StreamStats> substreams` - StreamStats mapped per SSRC.
+
+Updated when a frame is received from the source, `VideoStreamEncoder::OnFrame`.
+* `frames` - total number of frames fed to VideoStreamEncoder.
+* `input_frame_rate` - number of frames fed to VideoStreamEncoder during the last second.
+* `frames_dropped_by_congestion_window` - total number of dropped frames due to congestion window pushback.
+* `frames_dropped_by_encoder_queue` - total number of frames dropped because the encoder is blocked.
+
+Updated if a frame from the source is dropped, `VideoStreamEncoder::OnDiscardedFrame`.
+* `frames_dropped_by_capturer` - total number of frames dropped by the source.
+
+Updated if a frame is dropped by `FrameDropper`, `VideoStreamEncoder::MaybeEncodeVideoFrame`.
+* `frames_dropped_by_rate_limiter` - total number of dropped frames to avoid bitrate overuse.
+
+Updated (if changed) before a frame is passed to the encoder, `VideoStreamEncoder::EncodeVideoFrame`.
+* `encoder_implementation_name` - name of encoder implementation [[rtcoutboundrtpstreamstats-encoderimplementation]].
+
+Updated after a frame has been encoded, `VideoStreamEncoder::OnEncodedImage`.
+* `frames_encoded` - total number of encoded frames [[rtcoutboundrtpstreamstats-framesencoded]].
+* `encode_frame_rate` - number of encoded frames during the last second [[rtcoutboundrtpstreamstats-framespersecond]].
+* `total_encoded_bytes_target` - total target frame size in bytes [[rtcoutboundrtpstreamstats-totalencodedbytestarget]].
+* `huge_frames_sent` - total number of huge frames sent [[rtcoutboundrtpstreamstats-hugeframessent]].
+* `media_bitrate_bps` - the actual bitrate the encoder is producing.
+* `avg_encode_time_ms` - average encode time for encoded frames.
+* `total_encode_time_ms` - total encode time for encoded frames [[rtcoutboundrtpstreamstats-totalencodetime]].
+* `frames_dropped_by_encoder` - total number of frames dropped by the encoder.
+
+Adaptation stats.
+* `bw_limited_resolution` - shows if resolution is limited due to restricted bandwidth. +* `cpu_limited_resolution` - shows if resolution is limited due to cpu. +* `bw_limited_framerate` - shows if framerate is limited due to restricted bandwidth. +* `cpu_limited_framerate` - shows if framerate is limited due to cpu. +* `quality_limitation_reason` - current reason for limiting resolution and/or framerate [[rtcoutboundrtpstreamstats-qualitylimitationreason]]. +* `quality_limitation_durations_ms` - total time spent in quality limitation state [[rtcoutboundrtpstreamstats-qualitylimitationdurations]]. +* `quality_limitation_resolution_changes` - total number of times that resolution has changed due to quality limitation [[rtcoutboundrtpstreamstats-qualitylimitationresolutionchanges]]. +* `number_of_cpu_adapt_changes` - total number of times resolution/framerate has changed due to cpu limitation. +* `number_of_quality_adapt_changes` - total number of times resolution/framerate has changed due to quality limitation. + +Updated when the encoder is configured, `VideoStreamEncoder::ReconfigureEncoder`. +* `content_type` - configured content type (UNSPECIFIED/SCREENSHARE). + +Updated when the available bitrate changes, `VideoSendStreamImpl::OnBitrateUpdated`. +* `target_media_bitrate_bps` - the bitrate the encoder is configured to use. +* `suspended` - shows if video is suspended due to zero target bitrate. + +## VideoReceiveStream +[VideoReceiveStream::Stats] for a receiving stream can be gathered via `VideoReceiveStream::GetStats()`. + +### ReceiveStatisticsProxy +`VideoReceiveStream` owns a [ReceiveStatisticsProxy] which implements +`VCMReceiveStatisticsCallback`, +`RtcpCnameCallback`, +`RtcpPacketTypeCounterObserver`, +`CallStatsObserver` +and holds a `VideoReceiveStream::Stats` object. + +`ReceiveStatisticsProxy` is called via these interfaces by different components (e.g. `RtpRtcp` module) to update stats. + +#### Stats +* `current_payload_type` - current payload type. +* `ssrc` - configured SSRC for the received stream. + +Updated when a complete frame is received, `FrameBuffer::InsertFrame`. +* `frame_counts` - total number of key/delta frames received [[rtcinboundrtpstreamstats-keyframesdecoded]]. +* `network_frame_rate` - number of frames received during the last second. + +Updated when a frame is ready for decoding, `FrameBuffer::GetNextFrame`. From `VCMTiming`: +* `jitter_buffer_ms` - jitter buffer delay in ms. +* `max_decode_ms` - the 95th percentile observed decode time within a time window (10 sec). +* `render_delay_ms` - render delay in ms. +* `min_playout_delay_ms` - minimum playout delay in ms. +* `target_delay_ms` - target playout delay in ms. Max(`min_playout_delay_ms`, `jitter_delay_ms` + `max_decode_ms` + `render_delay_ms`). +* `current_delay_ms` - actual playout delay in ms. +* `jitter_buffer_delay_seconds` - total jitter buffer delay in seconds [[rtcinboundrtpstreamstats-jitterbufferdelay]]. +* `jitter_buffer_emitted_count` - total number of frames that have come out from the jitter buffer [[rtcinboundrtpstreamstats-jitterbufferemittedcount]]. + +Updated (if changed) after a frame is passed to the decoder, `VCMGenericDecoder::Decode`. +* `decoder_implementation_name` - name of decoder implementation [[rtcinboundrtpstreamstats-decoderimplementation]]. + +Updated when a frame is ready for decoding, `FrameBuffer::GetNextFrame`. +* `timing_frame_info` - timestamps for a full lifetime of a frame. 
+* `first_frame_received_to_decoded_ms` - initial decoding latency between the first arrived frame and the first decoded frame.
+* `frames_dropped` - total number of frames dropped prior to decoding or because the system is too slow [[rtcreceivedrtpstreamstats-framesdropped]].
+
+Updated after a frame has been decoded, `VCMDecodedFrameCallback::Decoded`.
+* `frames_decoded` - total number of decoded frames [[rtcinboundrtpstreamstats-framesdecoded]].
+* `decode_frame_rate` - number of decoded frames during the last second [[rtcinboundrtpstreamstats-framespersecond]].
+* `decode_ms` - time to decode the last frame in ms.
+* `total_decode_time_ms` - total decode time for decoded frames [[rtcinboundrtpstreamstats-totaldecodetime]].
+* `qp_sum` - sum of quantizer values of decoded frames [[rtcinboundrtpstreamstats-qpsum]].
+* `content_type` - content type (UNSPECIFIED/SCREENSHARE).
+* `interframe_delay_max_ms` - max inter-frame delay between decoded frames within a time window.
+* `total_inter_frame_delay` - sum of inter-frame delays in seconds between decoded frames [[rtcinboundrtpstreamstats-totalinterframedelay]].
+* `total_squared_inter_frame_delay` - sum of squared inter-frame delays in seconds between decoded frames [[rtcinboundrtpstreamstats-totalsquaredinterframedelay]].
+
+Updated before a frame is sent to the renderer, `VideoReceiveStream2::OnFrame`.
+* `frames_rendered` - total number of rendered frames.
+* `render_frame_rate` - number of rendered frames during the last second.
+* `width` - width of last frame fed to the renderer [[rtcinboundrtpstreamstats-framewidth]].
+* `height` - height of last frame fed to the renderer [[rtcinboundrtpstreamstats-frameheight]].
+* `estimated_playout_ntp_timestamp_ms` - estimated playout NTP timestamp [[rtcinboundrtpstreamstats-estimatedplayouttimestamp]].
+* `sync_offset_ms` - NTP timestamp difference between the last played out audio and video frame.
+* `freeze_count` - total number of detected freezes.
+* `pause_count` - total number of detected pauses.
+* `total_freezes_duration_ms` - total duration of freezes in ms.
+* `total_pauses_duration_ms` - total duration of pauses in ms.
+* `total_frames_duration_ms` - time in ms between the last rendered frame and the first rendered frame.
+* `sum_squared_frame_durations` - sum of squared inter-frame delays in seconds between rendered frames.
+
+Updated for each received RTP packet, `ReceiveStatisticsImpl::OnRtpPacket`. From `ReceiveStatistics`:
+* `total_bitrate_bps` - incoming bitrate in bps.
+* `rtp_stats` - RTP statistics for the received stream.
+
+Updated when an RTCP packet is sent, `RTCPSender::ComputeCompoundRTCPPacket`.
+* `rtcp_packet_type_counts` - total number of sent NACK/FIR/PLI packets [rtcinboundrtpstreamstats-[nackcount], [fircount], [plicount]].
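+
+As a usage sketch, both stats objects are polled rather than pushed: an
+application (or test) can periodically call `GetStats()` and read the fields
+listed above. The helper below is hypothetical and only logs a few of the
+receive-side fields:
+
+```cpp
+#include "call/video_receive_stream.h"
+#include "rtc_base/logging.h"
+
+// Logs a handful of the receive-side stats described above. `stream` is
+// assumed to be a configured and started VideoReceiveStream.
+void LogReceiveStats(webrtc::VideoReceiveStream* stream) {
+  const webrtc::VideoReceiveStream::Stats stats = stream->GetStats();
+  RTC_LOG(LS_INFO) << "decode_fps=" << stats.decode_frame_rate
+                   << " render_fps=" << stats.render_frame_rate
+                   << " frames_decoded=" << stats.frames_decoded
+                   << " frames_dropped=" << stats.frames_dropped
+                   << " current_delay_ms=" << stats.current_delay_ms;
+}
+```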
+ + +[VideoSendStream]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/video_send_stream.h +[VideoSendStream::Stats]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/video_send_stream.h?q=VideoSendStream::Stats +[StreamStats]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/video_send_stream.h?q=VideoSendStream::StreamStats +[SendStatisticsProxy]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/video/send_statistics_proxy.h +[rtcoutboundrtpstreamstats-framewidth]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-framewidth +[rtcoutboundrtpstreamstats-frameheight]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-frameheight +[rtcoutboundrtpstreamstats-qpsum]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-qpsum +[rtcoutboundrtpstreamstats-keyframesencoded]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-keyframesencoded +[rtcoutboundrtpstreamstats-totalpacketsenddelay]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-totalpacketsenddelay +[nackcount]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-nackcount +[fircount]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-fircount +[plicount]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-plicount +[rtcoutboundrtpstreamstats-encoderimplementation]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-encoderimplementation +[rtcoutboundrtpstreamstats-framesencoded]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-framesencoded +[rtcoutboundrtpstreamstats-framespersecond]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-framespersecond +[rtcoutboundrtpstreamstats-totalencodedbytestarget]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-totalencodedbytestarget +[rtcoutboundrtpstreamstats-hugeframessent]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-hugeframessent +[rtcoutboundrtpstreamstats-totalencodetime]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-totalencodetime +[rtcoutboundrtpstreamstats-qualitylimitationreason]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-qualitylimitationreason +[rtcoutboundrtpstreamstats-qualitylimitationdurations]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-qualitylimitationdurations +[rtcoutboundrtpstreamstats-qualitylimitationresolutionchanges]: https://w3c.github.io/webrtc-stats/#dom-rtcoutboundrtpstreamstats-qualitylimitationresolutionchanges + +[VideoReceiveStream]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/video_receive_stream.h +[VideoReceiveStream::Stats]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/call/video_receive_stream.h?q=VideoReceiveStream::Stats +[ReceiveStatisticsProxy]: https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/video/receive_statistics_proxy2.h +[rtcinboundrtpstreamstats-keyframesdecoded]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-keyframesdecoded +[rtcinboundrtpstreamstats-jitterbufferdelay]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-jitterbufferdelay +[rtcinboundrtpstreamstats-jitterbufferemittedcount]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-jitterbufferemittedcount 
+[rtcinboundrtpstreamstats-decoderimplementation]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-decoderimplementation +[rtcreceivedrtpstreamstats-framesdropped]: https://www.w3.org/TR/webrtc-stats/#dom-rtcreceivedrtpstreamstats-framesdropped +[rtcinboundrtpstreamstats-framesdecoded]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-framesdecoded +[rtcinboundrtpstreamstats-framespersecond]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-framespersecond +[rtcinboundrtpstreamstats-totaldecodetime]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-totaldecodetime +[rtcinboundrtpstreamstats-qpsum]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-qpsum +[rtcinboundrtpstreamstats-totalinterframedelay]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-totalinterframedelay +[rtcinboundrtpstreamstats-totalsquaredinterframedelay]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-totalsquaredinterframedelay +[rtcinboundrtpstreamstats-estimatedplayouttimestamp]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-estimatedplayouttimestamp +[rtcinboundrtpstreamstats-framewidth]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-framewidth +[rtcinboundrtpstreamstats-frameheight]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-frameheight +[nackcount]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-nackcount +[fircount]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-fircount +[plicount]: https://w3c.github.io/webrtc-stats/#dom-rtcinboundrtpstreamstats-plicount diff --git a/video/pc_full_stack_tests.cc b/video/pc_full_stack_tests.cc index d515a5271b..5cebf41e91 100644 --- a/video/pc_full_stack_tests.cc +++ b/video/pc_full_stack_tests.cc @@ -21,8 +21,8 @@ #include "api/test/peerconnection_quality_test_fixture.h" #include "api/test/simulated_network.h" #include "api/test/time_controller.h" +#include "api/video_codecs/vp9_profile.h" #include "call/simulated_network.h" -#include "media/base/vp9_profile.h" #include "modules/video_coding/codecs/vp9/include/vp9.h" #include "system_wrappers/include/field_trial.h" #include "test/field_trial.h" @@ -1738,9 +1738,9 @@ TEST(PCFullStackTest, MAYBE_LargeRoomVP8_50thumb) { } */ +/* class PCDualStreamsTest : public ::testing::TestWithParam {}; -/* // Disable dual video test on mobile device becuase it's too heavy. // TODO(bugs.webrtc.org/9840): Investigate why is this test flaky on MAC. 
#if !defined(WEBRTC_ANDROID) && !defined(WEBRTC_IOS) && !defined(WEBRTC_MAC) @@ -1842,10 +1842,10 @@ TEST_P(PCDualStreamsTest, Conference_Restricted) { auto fixture = CreateVideoQualityTestFixture(); fixture->RunWithAnalyzer(dual_streams); } -*/ INSTANTIATE_TEST_SUITE_P(PCFullStackTest, PCDualStreamsTest, ::testing::Values(0, 1)); +*/ } // namespace webrtc diff --git a/video/quality_scaling_tests.cc b/video/quality_scaling_tests.cc index b72b25b86b..9837517b78 100644 --- a/video/quality_scaling_tests.cc +++ b/video/quality_scaling_tests.cc @@ -15,316 +15,513 @@ #include "modules/video_coding/codecs/h264/include/h264.h" #include "modules/video_coding/codecs/vp8/include/vp8.h" #include "modules/video_coding/codecs/vp9/include/vp9.h" +#include "rtc_base/experiments/encoder_info_settings.h" #include "test/call_test.h" #include "test/field_trial.h" #include "test/frame_generator_capturer.h" namespace webrtc { namespace { -constexpr int kWidth = 1280; -constexpr int kHeight = 720; +constexpr int kInitialWidth = 1280; +constexpr int kInitialHeight = 720; constexpr int kLowStartBps = 100000; -constexpr int kHighStartBps = 600000; -constexpr size_t kTimeoutMs = 10000; // Some tests are expected to time out. +constexpr int kHighStartBps = 1000000; +constexpr int kDefaultVgaMinStartBps = 500000; // From video_stream_encoder.cc +constexpr int kTimeoutMs = 10000; // Some tests are expected to time out. void SetEncoderSpecific(VideoEncoderConfig* encoder_config, VideoCodecType type, bool automatic_resize, - bool frame_dropping) { + size_t num_spatial_layers) { if (type == kVideoCodecVP8) { VideoCodecVP8 vp8 = VideoEncoder::GetDefaultVp8Settings(); vp8.automaticResizeOn = automatic_resize; - vp8.frameDroppingOn = frame_dropping; - encoder_config->encoder_specific_settings = new rtc::RefCountedObject< - VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8); + encoder_config->encoder_specific_settings = + rtc::make_ref_counted( + vp8); } else if (type == kVideoCodecVP9) { VideoCodecVP9 vp9 = VideoEncoder::GetDefaultVp9Settings(); vp9.automaticResizeOn = automatic_resize; - vp9.frameDroppingOn = frame_dropping; - encoder_config->encoder_specific_settings = new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9); - } else if (type == kVideoCodecH264) { - VideoCodecH264 h264 = VideoEncoder::GetDefaultH264Settings(); - h264.frameDroppingOn = frame_dropping; - encoder_config->encoder_specific_settings = new rtc::RefCountedObject< - VideoEncoderConfig::H264EncoderSpecificSettings>(h264); + vp9.numberOfSpatialLayers = num_spatial_layers; + encoder_config->encoder_specific_settings = + rtc::make_ref_counted( + vp9); } } } // namespace class QualityScalingTest : public test::CallTest { protected: - void RunTest(VideoEncoderFactory* encoder_factory, - const std::string& payload_name, - const std::vector& streams_active, - int start_bps, - bool automatic_resize, - bool frame_dropping, - bool expect_adaptation); - const std::string kPrefix = "WebRTC-Video-QualityScaling/Enabled-"; const std::string kEnd = ",0,0,0.9995,0.9999,1/"; + const absl::optional + kSinglecastLimits720pVp8 = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP8, + 1280 * 720); + const absl::optional + kSinglecastLimits360pVp9 = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP9, + 640 * 360); }; -void QualityScalingTest::RunTest(VideoEncoderFactory* encoder_factory, - const std::string& payload_name, - const std::vector& streams_active, - int 
start_bps, - bool automatic_resize, - bool frame_dropping, - bool expect_adaptation) { - class ScalingObserver - : public test::SendTest, - public test::FrameGeneratorCapturer::SinkWantsObserver { - public: - ScalingObserver(VideoEncoderFactory* encoder_factory, - const std::string& payload_name, - const std::vector& streams_active, - int start_bps, - bool automatic_resize, - bool frame_dropping, - bool expect_adaptation) - : SendTest(expect_adaptation ? kDefaultTimeoutMs : kTimeoutMs), - encoder_factory_(encoder_factory), - payload_name_(payload_name), - streams_active_(streams_active), - start_bps_(start_bps), - automatic_resize_(automatic_resize), - frame_dropping_(frame_dropping), - expect_adaptation_(expect_adaptation) {} - - private: - void OnFrameGeneratorCapturerCreated( - test::FrameGeneratorCapturer* frame_generator_capturer) override { - frame_generator_capturer->SetSinkWantsObserver(this); - // Set initial resolution. - frame_generator_capturer->ChangeResolution(kWidth, kHeight); - } +class ScalingObserver : public test::SendTest { + protected: + ScalingObserver(const std::string& payload_name, + const std::vector& streams_active, + int start_bps, + bool automatic_resize, + bool expect_scaling) + : SendTest(expect_scaling ? kTimeoutMs * 4 : kTimeoutMs), + encoder_factory_( + [](const SdpVideoFormat& format) -> std::unique_ptr { + if (format.name == "VP8") + return VP8Encoder::Create(); + if (format.name == "VP9") + return VP9Encoder::Create(); + if (format.name == "H264") + return H264Encoder::Create(cricket::VideoCodec("H264")); + RTC_NOTREACHED() << format.name; + return nullptr; + }), + payload_name_(payload_name), + streams_active_(streams_active), + start_bps_(start_bps), + automatic_resize_(automatic_resize), + expect_scaling_(expect_scaling) {} + + DegradationPreference degradation_preference_ = + DegradationPreference::MAINTAIN_FRAMERATE; + + private: + void ModifySenderBitrateConfig(BitrateConstraints* bitrate_config) override { + bitrate_config->start_bitrate_bps = start_bps_; + } - // Called when FrameGeneratorCapturer::AddOrUpdateSink is called. - void OnSinkWantsChanged(rtc::VideoSinkInterface* sink, - const rtc::VideoSinkWants& wants) override { - if (wants.max_pixel_count < kWidth * kHeight) - observation_complete_.Set(); + void ModifyVideoDegradationPreference( + DegradationPreference* degradation_preference) override { + *degradation_preference = degradation_preference_; + } + + size_t GetNumVideoStreams() const override { + return (payload_name_ == "VP9") ? 1 : streams_active_.size(); + } + + void ModifyVideoConfigs( + VideoSendStream::Config* send_config, + std::vector* receive_configs, + VideoEncoderConfig* encoder_config) override { + send_config->encoder_settings.encoder_factory = &encoder_factory_; + send_config->rtp.payload_name = payload_name_; + send_config->rtp.payload_type = test::CallTest::kVideoSendPayloadType; + encoder_config->video_format.name = payload_name_; + const VideoCodecType codec_type = PayloadStringToCodecType(payload_name_); + encoder_config->codec_type = codec_type; + encoder_config->max_bitrate_bps = + std::max(start_bps_, encoder_config->max_bitrate_bps); + if (payload_name_ == "VP9") { + // Simulcast layers indicates which spatial layers are active. 
+ encoder_config->simulcast_layers.resize(streams_active_.size()); + encoder_config->simulcast_layers[0].max_bitrate_bps = + encoder_config->max_bitrate_bps; } - void ModifySenderBitrateConfig( - BitrateConstraints* bitrate_config) override { - bitrate_config->start_bitrate_bps = start_bps_; + double scale_factor = 1.0; + for (int i = streams_active_.size() - 1; i >= 0; --i) { + VideoStream& stream = encoder_config->simulcast_layers[i]; + stream.active = streams_active_[i]; + stream.scale_resolution_down_by = scale_factor; + scale_factor *= (payload_name_ == "VP9") ? 1.0 : 2.0; } + SetEncoderSpecific(encoder_config, codec_type, automatic_resize_, + streams_active_.size()); + } - size_t GetNumVideoStreams() const override { - return streams_active_.size(); - } + void PerformTest() override { EXPECT_EQ(expect_scaling_, Wait()); } - void ModifyVideoConfigs( - VideoSendStream::Config* send_config, - std::vector* receive_configs, - VideoEncoderConfig* encoder_config) override { - send_config->encoder_settings.encoder_factory = encoder_factory_; - send_config->rtp.payload_name = payload_name_; - send_config->rtp.payload_type = kVideoSendPayloadType; - const VideoCodecType codec_type = PayloadStringToCodecType(payload_name_); - encoder_config->codec_type = codec_type; - encoder_config->max_bitrate_bps = - std::max(start_bps_, encoder_config->max_bitrate_bps); - double scale_factor = 1.0; - for (int i = streams_active_.size() - 1; i >= 0; --i) { - VideoStream& stream = encoder_config->simulcast_layers[i]; - stream.active = streams_active_[i]; - stream.scale_resolution_down_by = scale_factor; - scale_factor *= 2.0; - } - SetEncoderSpecific(encoder_config, codec_type, automatic_resize_, - frame_dropping_); - } + test::FunctionVideoEncoderFactory encoder_factory_; + const std::string payload_name_; + const std::vector streams_active_; + const int start_bps_; + const bool automatic_resize_; + const bool expect_scaling_; +}; + +class DownscalingObserver + : public ScalingObserver, + public test::FrameGeneratorCapturer::SinkWantsObserver { + public: + DownscalingObserver(const std::string& payload_name, + const std::vector& streams_active, + int start_bps, + bool automatic_resize, + bool expect_downscale) + : ScalingObserver(payload_name, + streams_active, + start_bps, + automatic_resize, + expect_downscale) {} + + private: + void OnFrameGeneratorCapturerCreated( + test::FrameGeneratorCapturer* frame_generator_capturer) override { + frame_generator_capturer->SetSinkWantsObserver(this); + frame_generator_capturer->ChangeResolution(kInitialWidth, kInitialHeight); + } + + void OnSinkWantsChanged(rtc::VideoSinkInterface* sink, + const rtc::VideoSinkWants& wants) override { + if (wants.max_pixel_count < kInitialWidth * kInitialHeight) + observation_complete_.Set(); + } +}; + +class UpscalingObserver + : public ScalingObserver, + public test::FrameGeneratorCapturer::SinkWantsObserver { + public: + UpscalingObserver(const std::string& payload_name, + const std::vector& streams_active, + int start_bps, + bool automatic_resize, + bool expect_upscale) + : ScalingObserver(payload_name, + streams_active, + start_bps, + automatic_resize, + expect_upscale) {} + + void SetDegradationPreference(DegradationPreference preference) { + degradation_preference_ = preference; + } + + private: + void OnFrameGeneratorCapturerCreated( + test::FrameGeneratorCapturer* frame_generator_capturer) override { + frame_generator_capturer->SetSinkWantsObserver(this); + frame_generator_capturer->ChangeResolution(kInitialWidth, 
kInitialHeight); + } - void PerformTest() override { - EXPECT_EQ(expect_adaptation_, Wait()) - << "Timed out while waiting for a scale down."; + void OnSinkWantsChanged(rtc::VideoSinkInterface* sink, + const rtc::VideoSinkWants& wants) override { + if (wants.max_pixel_count > last_wants_.max_pixel_count) { + if (wants.max_pixel_count == std::numeric_limits::max()) + observation_complete_.Set(); } + last_wants_ = wants; + } - VideoEncoderFactory* const encoder_factory_; - const std::string payload_name_; - const std::vector streams_active_; - const int start_bps_; - const bool automatic_resize_; - const bool frame_dropping_; - const bool expect_adaptation_; - } test(encoder_factory, payload_name, streams_active, start_bps, - automatic_resize, frame_dropping, expect_adaptation); + rtc::VideoSinkWants last_wants_; +}; +TEST_F(QualityScalingTest, AdaptsDownForHighQp_Vp8) { + // qp_low:1, qp_high:1 -> kHighQp + test::ScopedFieldTrials field_trials(kPrefix + "1,1,0,0,0,0" + kEnd); + + DownscalingObserver test("VP8", /*streams_active=*/{true}, kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); RunBaseTest(&test); } -TEST_F(QualityScalingTest, AdaptsDownForHighQp_Vp8) { - // VP8 QP thresholds, low:1, high:1 -> high QP. +TEST_F(QualityScalingTest, NoAdaptDownForHighQpIfScalingOff_Vp8) { + // qp_low:1, qp_high:1 -> kHighQp test::ScopedFieldTrials field_trials(kPrefix + "1,1,0,0,0,0" + kEnd); - // QualityScaler enabled. - const bool kAutomaticResize = true; - const bool kFrameDropping = true; - const bool kExpectAdapt = true; - - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true}, kHighStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP8", /*streams_active=*/{true}, kHighStartBps, + /*automatic_resize=*/false, + /*expect_downscale=*/false); + RunBaseTest(&test); } -TEST_F(QualityScalingTest, NoAdaptDownForHighQpWithResizeOff_Vp8) { - // VP8 QP thresholds, low:1, high:1 -> high QP. - test::ScopedFieldTrials field_trials(kPrefix + "1,1,0,0,0,0" + kEnd); +TEST_F(QualityScalingTest, NoAdaptDownForNormalQp_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); + + DownscalingObserver test("VP8", /*streams_active=*/{true}, kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); +} - // QualityScaler disabled. - const bool kAutomaticResize = false; - const bool kFrameDropping = true; - const bool kExpectAdapt = false; +TEST_F(QualityScalingTest, AdaptsDownForLowStartBitrate_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true}, kHighStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP8", /*streams_active=*/{true}, kLowStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); } -// TODO(bugs.webrtc.org/10388): Fix and re-enable. -TEST_F(QualityScalingTest, - DISABLED_NoAdaptDownForHighQpWithFrameDroppingOff_Vp8) { - // VP8 QP thresholds, low:1, high:1 -> high QP. 
- test::ScopedFieldTrials field_trials(kPrefix + "1,1,0,0,0,0" + kEnd); +TEST_F(QualityScalingTest, AdaptsDownForLowStartBitrateAndThenUp) { + // qp_low:127, qp_high:127 -> kLowQp + test::ScopedFieldTrials field_trials( + kPrefix + "127,127,0,0,0,0" + kEnd + + "WebRTC-Video-BalancedDegradationSettings/" + "pixels:230400|921600,fps:20|30,kbps:300|500/"); // should not affect + + UpscalingObserver test("VP8", /*streams_active=*/{true}, + kDefaultVgaMinStartBps - 1, + /*automatic_resize=*/true, /*expect_upscale=*/true); + RunBaseTest(&test); +} - // QualityScaler disabled. - const bool kAutomaticResize = true; - const bool kFrameDropping = false; - const bool kExpectAdapt = false; +TEST_F(QualityScalingTest, AdaptsDownAndThenUpWithBalanced) { + // qp_low:127, qp_high:127 -> kLowQp + test::ScopedFieldTrials field_trials( + kPrefix + "127,127,0,0,0,0" + kEnd + + "WebRTC-Video-BalancedDegradationSettings/" + "pixels:230400|921600,fps:20|30,kbps:300|499/"); + + UpscalingObserver test("VP8", /*streams_active=*/{true}, + kDefaultVgaMinStartBps - 1, + /*automatic_resize=*/true, /*expect_upscale=*/true); + test.SetDegradationPreference(DegradationPreference::BALANCED); + RunBaseTest(&test); +} - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true}, kHighStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); +TEST_F(QualityScalingTest, AdaptsDownButNotUpWithBalancedIfBitrateNotEnough) { + // qp_low:127, qp_high:127 -> kLowQp + test::ScopedFieldTrials field_trials( + kPrefix + "127,127,0,0,0,0" + kEnd + + "WebRTC-Video-BalancedDegradationSettings/" + "pixels:230400|921600,fps:20|30,kbps:300|500/"); + + UpscalingObserver test("VP8", /*streams_active=*/{true}, + kDefaultVgaMinStartBps - 1, + /*automatic_resize=*/true, /*expect_upscale=*/false); + test.SetDegradationPreference(DegradationPreference::BALANCED); + RunBaseTest(&test); } -TEST_F(QualityScalingTest, NoAdaptDownForNormalQp_Vp8) { - // VP8 QP thresholds, low:1, high:127 -> normal QP. +TEST_F(QualityScalingTest, NoAdaptDownForLowStartBitrate_Simulcast) { + // qp_low:1, qp_high:127 -> kNormalQp test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - // QualityScaler enabled. - const bool kAutomaticResize = true; - const bool kFrameDropping = true; - const bool kExpectAdapt = false; + DownscalingObserver test("VP8", /*streams_active=*/{true, true}, kLowStartBps, + /*automatic_resize=*/false, + /*expect_downscale=*/false); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, AdaptsDownForHighQp_HighestStreamActive_Vp8) { + // qp_low:1, qp_high:1 -> kHighQp + test::ScopedFieldTrials field_trials(kPrefix + "1,1,0,0,0,0" + kEnd); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true}, kHighStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP8", /*streams_active=*/{false, false, true}, + kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); } -TEST_F(QualityScalingTest, AdaptsDownForLowStartBitrate) { - // VP8 QP thresholds, low:1, high:127 -> normal QP. +TEST_F(QualityScalingTest, + AdaptsDownForLowStartBitrate_HighestStreamActive_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - // QualityScaler enabled. 
- const bool kAutomaticResize = true; - const bool kFrameDropping = true; - const bool kExpectAdapt = true; + DownscalingObserver test("VP8", /*streams_active=*/{false, false, true}, + kSinglecastLimits720pVp8->min_start_bitrate_bps - 1, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, AdaptsDownButNotUpWithMinStartBitrateLimit) { + // qp_low:127, qp_high:127 -> kLowQp + test::ScopedFieldTrials field_trials(kPrefix + "127,127,0,0,0,0" + kEnd); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true}, kLowStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + UpscalingObserver test("VP8", /*streams_active=*/{false, true}, + kSinglecastLimits720pVp8->min_start_bitrate_bps - 1, + /*automatic_resize=*/true, /*expect_upscale=*/false); + RunBaseTest(&test); } -TEST_F(QualityScalingTest, NoAdaptDownForLowStartBitrate_Simulcast) { - // VP8 QP thresholds, low:1, high:127 -> normal QP. +TEST_F(QualityScalingTest, NoAdaptDownForLowStartBitrateIfBitrateEnough_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - // QualityScaler disabled. - const bool kAutomaticResize = false; - const bool kFrameDropping = true; - const bool kExpectAdapt = false; + DownscalingObserver test("VP8", /*streams_active=*/{false, false, true}, + kSinglecastLimits720pVp8->min_start_bitrate_bps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); +} - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true, true}, kLowStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); +TEST_F(QualityScalingTest, + NoAdaptDownForLowStartBitrateIfDefaultLimitsDisabled_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp + test::ScopedFieldTrials field_trials( + kPrefix + "1,127,0,0,0,0" + kEnd + + "WebRTC-DefaultBitrateLimitsKillSwitch/Enabled/"); + + DownscalingObserver test("VP8", /*streams_active=*/{false, false, true}, + kSinglecastLimits720pVp8->min_start_bitrate_bps - 1, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); } TEST_F(QualityScalingTest, - AdaptsDownForLowStartBitrate_SimulcastOneActiveHighRes) { - // VP8 QP thresholds, low:1, high:127 -> normal QP. + NoAdaptDownForLowStartBitrate_OneStreamSinglecastLimitsNotUsed_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - // QualityScaler enabled. 
- const bool kAutomaticResize = true; - const bool kFrameDropping = true; - const bool kExpectAdapt = true; + DownscalingObserver test("VP8", /*streams_active=*/{true}, + kSinglecastLimits720pVp8->min_start_bitrate_bps - 1, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, NoAdaptDownForHighQp_LowestStreamActive_Vp8) { + // qp_low:1, qp_high:1 -> kHighQp + test::ScopedFieldTrials field_trials(kPrefix + "1,1,0,0,0,0" + kEnd); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {false, false, true}, kLowStartBps, - kAutomaticResize, kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP8", /*streams_active=*/{true, false, false}, + kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); } TEST_F(QualityScalingTest, - NoAdaptDownForLowStartBitrate_SimulcastOneActiveLowRes) { - // VP8 QP thresholds, low:1, high:127 -> normal QP. + NoAdaptDownForLowStartBitrate_LowestStreamActive_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - // QualityScaler enabled. - const bool kAutomaticResize = true; - const bool kFrameDropping = true; - const bool kExpectAdapt = false; - - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true, false, false}, kLowStartBps, - kAutomaticResize, kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP8", /*streams_active=*/{true, false, false}, + kLowStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); } -TEST_F(QualityScalingTest, NoAdaptDownForLowStartBitrateWithScalingOff) { - // VP8 QP thresholds, low:1, high:127 -> normal QP. +TEST_F(QualityScalingTest, NoAdaptDownForLowStartBitrateIfScalingOff_Vp8) { + // qp_low:1, qp_high:127 -> kNormalQp test::ScopedFieldTrials field_trials(kPrefix + "1,127,0,0,0,0" + kEnd); - // QualityScaler disabled. - const bool kAutomaticResize = false; - const bool kFrameDropping = true; - const bool kExpectAdapt = false; + DownscalingObserver test("VP8", /*streams_active=*/{true}, kLowStartBps, + /*automatic_resize=*/false, + /*expect_downscale=*/false); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, AdaptsDownForHighQp_Vp9) { + // qp_low:1, qp_high:1 -> kHighQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,1,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP8Encoder::Create(); }); - RunTest(&encoder_factory, "VP8", {true}, kLowStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP9", /*streams_active=*/{true}, kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); } -TEST_F(QualityScalingTest, NoAdaptDownForHighQp_Vp9) { - // VP9 QP thresholds, low:1, high:1 -> high QP. +TEST_F(QualityScalingTest, NoAdaptDownForHighQpIfScalingOff_Vp9) { + // qp_low:1, qp_high:1 -> kHighQp test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,1,0,0" + kEnd + "WebRTC-VP9QualityScaler/Disabled/"); - // QualityScaler always disabled. 
- const bool kAutomaticResize = true; - const bool kFrameDropping = true; - const bool kExpectAdapt = false; + DownscalingObserver test("VP9", /*streams_active=*/{true}, kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, AdaptsDownForLowStartBitrate_Vp9) { + // qp_low:1, qp_high:255 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,255,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); + + DownscalingObserver test("VP9", /*streams_active=*/{true}, kLowStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, NoAdaptDownForHighQp_LowestStreamActive_Vp9) { + // qp_low:1, qp_high:1 -> kHighQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,1,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); + + DownscalingObserver test("VP9", /*streams_active=*/{true, false, false}, + kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, + NoAdaptDownForLowStartBitrate_LowestStreamActive_Vp9) { + // qp_low:1, qp_high:255 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,255,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); + + DownscalingObserver test("VP9", /*streams_active=*/{true, false, false}, + kLowStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, AdaptsDownForHighQp_MiddleStreamActive_Vp9) { + // qp_low:1, qp_high:1 -> kHighQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,1,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); + + DownscalingObserver test("VP9", /*streams_active=*/{false, true, false}, + kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, + AdaptsDownForLowStartBitrate_MiddleStreamActive_Vp9) { + // qp_low:1, qp_high:255 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,255,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); + + DownscalingObserver test("VP9", /*streams_active=*/{false, true, false}, + kSinglecastLimits360pVp9->min_start_bitrate_bps - 1, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, NoAdaptDownForLowStartBitrateIfBitrateEnough_Vp9) { + // qp_low:1, qp_high:255 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,1,255,0,0" + kEnd + + "WebRTC-VP9QualityScaler/Enabled/"); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return VP9Encoder::Create(); }); - RunTest(&encoder_factory, "VP9", {true}, kHighStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + DownscalingObserver test("VP9", /*streams_active=*/{false, true, false}, + kSinglecastLimits360pVp9->min_start_bitrate_bps, + /*automatic_resize=*/true, + /*expect_downscale=*/false); + RunBaseTest(&test); } #if defined(WEBRTC_USE_H264) TEST_F(QualityScalingTest, AdaptsDownForHighQp_H264) { - // H264 QP thresholds, low:1, high:1 -> high QP. + // qp_low:1, qp_high:1 -> kHighQp test::ScopedFieldTrials field_trials(kPrefix + "0,0,0,0,1,1" + kEnd); - // QualityScaler always enabled. 
- const bool kAutomaticResize = false; - const bool kFrameDropping = false; - const bool kExpectAdapt = true; + DownscalingObserver test("H264", /*streams_active=*/{true}, kHighStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); +} + +TEST_F(QualityScalingTest, AdaptsDownForLowStartBitrate_H264) { + // qp_low:1, qp_high:51 -> kNormalQp + test::ScopedFieldTrials field_trials(kPrefix + "0,0,0,0,1,51" + kEnd); - test::FunctionVideoEncoderFactory encoder_factory( - []() { return H264Encoder::Create(cricket::VideoCodec("H264")); }); - RunTest(&encoder_factory, "H264", {true}, kHighStartBps, kAutomaticResize, - kFrameDropping, kExpectAdapt); + DownscalingObserver test("H264", /*streams_active=*/{true}, kLowStartBps, + /*automatic_resize=*/true, + /*expect_downscale=*/true); + RunBaseTest(&test); } #endif // defined(WEBRTC_USE_H264) diff --git a/video/receive_statistics_proxy.h b/video/receive_statistics_proxy.h index 8b94c32b69..57738f29cf 100644 --- a/video/receive_statistics_proxy.h +++ b/video/receive_statistics_proxy.h @@ -17,6 +17,7 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "call/video_receive_stream.h" #include "modules/include/module_common_types.h" #include "modules/video_coding/include/video_coding_defines.h" @@ -27,7 +28,6 @@ #include "rtc_base/rate_tracker.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "video/quality_threshold.h" #include "video/stats_counter.h" #include "video/video_quality_observer.h" @@ -196,9 +196,9 @@ class ReceiveStatisticsProxy : public VCMReceiveStatisticsCallback, RTC_GUARDED_BY(&mutex_); absl::optional last_estimated_playout_time_ms_ RTC_GUARDED_BY(&mutex_); - rtc::ThreadChecker decode_thread_; - rtc::ThreadChecker network_thread_; - rtc::ThreadChecker main_thread_; + SequenceChecker decode_thread_; + SequenceChecker network_thread_; + SequenceChecker main_thread_; }; } // namespace webrtc diff --git a/video/receive_statistics_proxy2.cc b/video/receive_statistics_proxy2.cc index 3cce3c8ea4..af3cd221e7 100644 --- a/video/receive_statistics_proxy2.cc +++ b/video/receive_statistics_proxy2.cc @@ -946,26 +946,21 @@ void ReceiveStatisticsProxy::OnRenderedFrame( void ReceiveStatisticsProxy::OnSyncOffsetUpdated(int64_t video_playout_ntp_ms, int64_t sync_offset_ms, double estimated_freq_khz) { - RTC_DCHECK_RUN_ON(&incoming_render_queue_); - int64_t now_ms = clock_->TimeInMilliseconds(); - worker_thread_->PostTask( - ToQueuedTask(task_safety_, [video_playout_ntp_ms, sync_offset_ms, - estimated_freq_khz, now_ms, this]() { - RTC_DCHECK_RUN_ON(&main_thread_); - sync_offset_counter_.Add(std::abs(sync_offset_ms)); - stats_.sync_offset_ms = sync_offset_ms; - last_estimated_playout_ntp_timestamp_ms_ = video_playout_ntp_ms; - last_estimated_playout_time_ms_ = now_ms; - - const double kMaxFreqKhz = 10000.0; - int offset_khz = kMaxFreqKhz; - // Should not be zero or negative. If so, report max. 
- if (estimated_freq_khz < kMaxFreqKhz && estimated_freq_khz > 0.0) - offset_khz = - static_cast(std::fabs(estimated_freq_khz - 90.0) + 0.5); - - freq_offset_counter_.Add(offset_khz); - })); + RTC_DCHECK_RUN_ON(&main_thread_); + + const int64_t now_ms = clock_->TimeInMilliseconds(); + sync_offset_counter_.Add(std::abs(sync_offset_ms)); + stats_.sync_offset_ms = sync_offset_ms; + last_estimated_playout_ntp_timestamp_ms_ = video_playout_ntp_ms; + last_estimated_playout_time_ms_ = now_ms; + + const double kMaxFreqKhz = 10000.0; + int offset_khz = kMaxFreqKhz; + // Should not be zero or negative. If so, report max. + if (estimated_freq_khz < kMaxFreqKhz && estimated_freq_khz > 0.0) + offset_khz = static_cast(std::fabs(estimated_freq_khz - 90.0) + 0.5); + + freq_offset_counter_.Add(offset_khz); } void ReceiveStatisticsProxy::OnCompleteFrame(bool is_keyframe, diff --git a/video/receive_statistics_proxy2.h b/video/receive_statistics_proxy2.h index e9950c5e84..7797d93217 100644 --- a/video/receive_statistics_proxy2.h +++ b/video/receive_statistics_proxy2.h @@ -17,6 +17,7 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_base.h" #include "api/units/timestamp.h" #include "call/video_receive_stream.h" @@ -27,11 +28,9 @@ #include "rtc_base/numerics/sample_counter.h" #include "rtc_base/rate_statistics.h" #include "rtc_base/rate_tracker.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "video/quality_threshold.h" #include "video/stats_counter.h" #include "video/video_quality_observer2.h" @@ -215,7 +214,7 @@ class ReceiveStatisticsProxy : public VCMReceiveStatisticsCallback, ScopedTaskSafety task_safety_; RTC_NO_UNIQUE_ADDRESS SequenceChecker decode_queue_; - rtc::ThreadChecker main_thread_; + SequenceChecker main_thread_; RTC_NO_UNIQUE_ADDRESS SequenceChecker incoming_render_queue_; }; diff --git a/video/report_block_stats.cc b/video/report_block_stats.cc index e3e95f9aed..bf60364682 100644 --- a/video/report_block_stats.cc +++ b/video/report_block_stats.cc @@ -31,16 +31,13 @@ ReportBlockStats::ReportBlockStats() ReportBlockStats::~ReportBlockStats() {} -void ReportBlockStats::Store(uint32_t ssrc, const RtcpStatistics& rtcp_stats) { +void ReportBlockStats::Store(uint32_t ssrc, + int packets_lost, + uint32_t extended_highest_sequence_number) { Report report; - report.packets_lost = rtcp_stats.packets_lost; - report.extended_highest_sequence_number = - rtcp_stats.extended_highest_sequence_number; - StoreAndAddPacketIncrement(ssrc, report); -} + report.packets_lost = packets_lost; + report.extended_highest_sequence_number = extended_highest_sequence_number; -void ReportBlockStats::StoreAndAddPacketIncrement(uint32_t ssrc, - const Report& report) { // Get diff with previous report block. const auto prev_report = prev_reports_.find(ssrc); if (prev_report != prev_reports_.end()) { diff --git a/video/report_block_stats.h b/video/report_block_stats.h index de4a079032..1d1140295c 100644 --- a/video/report_block_stats.h +++ b/video/report_block_stats.h @@ -15,8 +15,6 @@ #include -#include "modules/rtp_rtcp/include/rtcp_statistics.h" - namespace webrtc { // TODO(nisse): Usefulness of this class is somewhat unclear. The inputs are @@ -32,7 +30,9 @@ class ReportBlockStats { ~ReportBlockStats(); // Updates stats and stores report block. 
- void Store(uint32_t ssrc, const RtcpStatistics& rtcp_stats); + void Store(uint32_t ssrc, + int packets_lost, + uint32_t extended_highest_sequence_number); // Returns the total fraction of lost packets (or -1 if less than two report // blocks have been stored). @@ -45,10 +45,6 @@ class ReportBlockStats { int32_t packets_lost; }; - // Updates the total number of packets/lost packets. - // Stores the report. - void StoreAndAddPacketIncrement(uint32_t ssrc, const Report& report); - // The total number of packets/lost packets. uint32_t num_sequence_numbers_; uint32_t num_lost_sequence_numbers_; diff --git a/video/report_block_stats_unittest.cc b/video/report_block_stats_unittest.cc index 0b0230941f..bd66e571a0 100644 --- a/video/report_block_stats_unittest.cc +++ b/video/report_block_stats_unittest.cc @@ -13,65 +13,51 @@ #include "test/gtest.h" namespace webrtc { +namespace { -class ReportBlockStatsTest : public ::testing::Test { - protected: - ReportBlockStatsTest() { - // kSsrc1: report 1-3. - stats1_1_.packets_lost = 10; - stats1_1_.extended_highest_sequence_number = 24000; - stats1_2_.packets_lost = 15; - stats1_2_.extended_highest_sequence_number = 24100; - stats1_3_.packets_lost = 50; - stats1_3_.extended_highest_sequence_number = 24200; - // kSsrc2: report 1,2. - stats2_1_.packets_lost = 111; - stats2_1_.extended_highest_sequence_number = 8500; - stats2_2_.packets_lost = 136; - stats2_2_.extended_highest_sequence_number = 8800; - } +constexpr uint32_t kSsrc1 = 123; +constexpr uint32_t kSsrc2 = 234; - const uint32_t kSsrc1 = 123; - const uint32_t kSsrc2 = 234; - RtcpStatistics stats1_1_; - RtcpStatistics stats1_2_; - RtcpStatistics stats1_3_; - RtcpStatistics stats2_1_; - RtcpStatistics stats2_2_; -}; - -TEST_F(ReportBlockStatsTest, StoreAndGetFractionLost) { +TEST(ReportBlockStatsTest, StoreAndGetFractionLost) { ReportBlockStats stats; EXPECT_EQ(-1, stats.FractionLostInPercent()); // First report. - stats.Store(kSsrc1, stats1_1_); + stats.Store(kSsrc1, /*packets_lost=*/10, + /*extended_highest_sequence_number=*/24'000); EXPECT_EQ(-1, stats.FractionLostInPercent()); // fl: 100 * (15-10) / (24100-24000) = 5% - stats.Store(kSsrc1, stats1_2_); + stats.Store(kSsrc1, /*packets_lost=*/15, + /*extended_highest_sequence_number=*/24'100); EXPECT_EQ(5, stats.FractionLostInPercent()); // fl: 100 * (50-10) / (24200-24000) = 20% - stats.Store(kSsrc1, stats1_3_); + stats.Store(kSsrc1, /*packets_lost=*/50, + /*extended_highest_sequence_number=*/24'200); EXPECT_EQ(20, stats.FractionLostInPercent()); } -TEST_F(ReportBlockStatsTest, StoreAndGetFractionLost_TwoSsrcs) { +TEST(ReportBlockStatsTest, StoreAndGetFractionLost_TwoSsrcs) { ReportBlockStats stats; EXPECT_EQ(-1, stats.FractionLostInPercent()); // First report. - stats.Store(kSsrc1, stats1_1_); + stats.Store(kSsrc1, /*packets_lost=*/10, + /*extended_highest_sequence_number=*/24'000); EXPECT_EQ(-1, stats.FractionLostInPercent()); // fl: 100 * (15-10) / (24100-24000) = 5% - stats.Store(kSsrc1, stats1_2_); + stats.Store(kSsrc1, /*packets_lost=*/15, + /*extended_highest_sequence_number=*/24'100); EXPECT_EQ(5, stats.FractionLostInPercent()); // First report, kSsrc2. 
- stats.Store(kSsrc2, stats2_1_); + stats.Store(kSsrc2, /*packets_lost=*/111, + /*extended_highest_sequence_number=*/8'500); EXPECT_EQ(5, stats.FractionLostInPercent()); // fl: 100 * ((15-10) + (136-111)) / ((24100-24000) + (8800-8500)) = 7% - stats.Store(kSsrc2, stats2_2_); + stats.Store(kSsrc2, /*packets_lost=*/136, + /*extended_highest_sequence_number=*/8'800); EXPECT_EQ(7, stats.FractionLostInPercent()); } +} // namespace } // namespace webrtc diff --git a/video/rtp_streams_synchronizer.h b/video/rtp_streams_synchronizer.h index 732c9a7d77..574ccba70b 100644 --- a/video/rtp_streams_synchronizer.h +++ b/video/rtp_streams_synchronizer.h @@ -16,9 +16,9 @@ #include +#include "api/sequence_checker.h" #include "modules/include/module.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/thread_checker.h" #include "video/stream_synchronization.h" namespace webrtc { @@ -57,7 +57,7 @@ class RtpStreamsSynchronizer : public Module { StreamSynchronization::Measurements audio_measurement_ RTC_GUARDED_BY(mutex_); StreamSynchronization::Measurements video_measurement_ RTC_GUARDED_BY(mutex_); - rtc::ThreadChecker process_thread_checker_; + SequenceChecker process_thread_checker_; int64_t last_sync_time_ RTC_GUARDED_BY(&process_thread_checker_); int64_t last_stats_log_ms_ RTC_GUARDED_BY(&process_thread_checker_); }; diff --git a/video/rtp_streams_synchronizer2.h b/video/rtp_streams_synchronizer2.h index 3d31738225..192378aba7 100644 --- a/video/rtp_streams_synchronizer2.h +++ b/video/rtp_streams_synchronizer2.h @@ -13,7 +13,7 @@ #include -#include "rtc_base/synchronization/sequence_checker.h" +#include "api/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" #include "rtc_base/task_utils/repeating_task.h" diff --git a/video/rtp_video_stream_receiver.cc b/video/rtp_video_stream_receiver.cc index ab60070d82..a0520cd350 100644 --- a/video/rtp_video_stream_receiver.cc +++ b/video/rtp_video_stream_receiver.cc @@ -210,7 +210,7 @@ RtpVideoStreamReceiver::RtpVideoStreamReceiver( ProcessThread* process_thread, NackSender* nack_sender, KeyFrameRequestSender* keyframe_request_sender, - video_coding::OnCompleteFrameCallback* complete_frame_callback, + OnCompleteFrameCallback* complete_frame_callback, rtc::scoped_refptr frame_decryptor, rtc::scoped_refptr frame_transformer) : RtpVideoStreamReceiver(clock, @@ -240,7 +240,7 @@ RtpVideoStreamReceiver::RtpVideoStreamReceiver( ProcessThread* process_thread, NackSender* nack_sender, KeyFrameRequestSender* keyframe_request_sender, - video_coding::OnCompleteFrameCallback* complete_frame_callback, + OnCompleteFrameCallback* complete_frame_callback, rtc::scoped_refptr frame_decryptor, rtc::scoped_refptr frame_transformer) : clock_(clock), @@ -271,10 +271,11 @@ RtpVideoStreamReceiver::RtpVideoStreamReceiver( // TODO(bugs.webrtc.org/10336): Let |rtcp_feedback_buffer_| communicate // directly with |rtp_rtcp_|. 
rtcp_feedback_buffer_(this, nack_sender, this), - packet_buffer_(clock_, kPacketBufferStartSize, PacketBufferMaxSize()), + packet_buffer_(kPacketBufferStartSize, PacketBufferMaxSize()), + reference_finder_(std::make_unique()), has_received_frame_(false), frames_decryptable_(false), - absolute_capture_time_receiver_(clock) { + absolute_capture_time_interpolator_(clock) { constexpr bool remb_candidate = true; if (packet_router_) packet_router_->AddReceiveRtpModule(rtp_rtcp_.get(), remb_candidate); @@ -321,9 +322,6 @@ RtpVideoStreamReceiver::RtpVideoStreamReceiver( process_thread_->RegisterModule(nack_module_.get(), RTC_FROM_HERE); } - reference_finder_ = - std::make_unique(this); - // Only construct the encrypted receiver if frame encryption is enabled. if (config_.crypto_options.sframe.require_frame_encryption) { buffered_frame_decryptor_ = @@ -334,10 +332,10 @@ RtpVideoStreamReceiver::RtpVideoStreamReceiver( } if (frame_transformer) { - frame_transformer_delegate_ = new rtc::RefCountedObject< - RtpVideoStreamReceiverFrameTransformerDelegate>( - this, std::move(frame_transformer), rtc::Thread::Current(), - config_.rtp.remote_ssrc); + frame_transformer_delegate_ = + rtc::make_ref_counted( + this, std::move(frame_transformer), rtc::Thread::Current(), + config_.rtp.remote_ssrc); frame_transformer_delegate_->Init(); } } @@ -365,6 +363,7 @@ void RtpVideoStreamReceiver::AddReceiveCodec( bool raw_payload) { if (codec_params.count(cricket::kH264FmtpSpsPpsIdrInKeyframe) || field_trial::IsEnabled("WebRTC-SpsPpsIdrIsH264Keyframe")) { + MutexLock lock(&packet_buffer_lock_); packet_buffer_.ForceSpsPpsIdrIsH264Keyframe(); } payload_type_map_.emplace( @@ -377,17 +376,19 @@ void RtpVideoStreamReceiver::AddReceiveCodec( absl::optional RtpVideoStreamReceiver::GetSyncInfo() const { Syncable::Info info; if (rtp_rtcp_->RemoteNTP(&info.capture_time_ntp_secs, - &info.capture_time_ntp_frac, nullptr, nullptr, + &info.capture_time_ntp_frac, + /*rtcp_arrival_time_secs=*/nullptr, + /*rtcp_arrival_time_frac=*/nullptr, &info.capture_time_source_clock) != 0) { return absl::nullopt; } { MutexLock lock(&sync_info_lock_); - if (!last_received_rtp_timestamp_ || !last_received_rtp_system_time_ms_) { + if (!last_received_rtp_timestamp_ || !last_received_rtp_system_time_) { return absl::nullopt; } info.latest_received_capture_timestamp = *last_received_rtp_timestamp_; - info.latest_receive_time_ms = *last_received_rtp_system_time_ms_; + info.latest_receive_time_ms = last_received_rtp_system_time_->ms(); } // Leaves info.current_delay_ms uninitialized. @@ -504,19 +505,9 @@ void RtpVideoStreamReceiver::OnReceivedPayloadData( const RtpPacketReceived& rtp_packet, const RTPVideoHeader& video) { RTC_DCHECK_RUN_ON(&worker_task_checker_); - auto packet = std::make_unique( - rtp_packet, video, ntp_estimator_.Estimate(rtp_packet.Timestamp()), - clock_->TimeInMilliseconds()); - - // Try to extrapolate absolute capture time if it is missing. - packet->packet_info.set_absolute_capture_time( - absolute_capture_time_receiver_.OnReceivePacket( - AbsoluteCaptureTimeReceiver::GetSource(packet->packet_info.ssrc(), - packet->packet_info.csrcs()), - packet->packet_info.rtp_timestamp(), - // Assume frequency is the same one for all video frames. 
- kVideoPayloadTypeFrequency, - packet->packet_info.absolute_capture_time())); + + auto packet = + std::make_unique(rtp_packet, video); RTPVideoHeader& video_header = packet->video_header; video_header.rotation = kVideoRotation_0; @@ -543,6 +534,12 @@ void RtpVideoStreamReceiver::OnReceivedPayloadData( ParseGenericDependenciesResult generic_descriptor_state = ParseGenericDependenciesExtension(rtp_packet, &video_header); + + if (!rtp_packet.recovered()) { + UpdatePacketReceiveTimestamps( + rtp_packet, video_header.frame_type == VideoFrameType::kVideoFrameKey); + } + if (generic_descriptor_state == kDropPacket) return; @@ -561,6 +558,8 @@ void RtpVideoStreamReceiver::OnReceivedPayloadData( video_header.color_space = last_color_space_; } } + video_header.video_frame_tracking_id = + rtp_packet.GetExtension(); if (loss_notification_controller_) { if (rtp_packet.recovered()) { @@ -636,7 +635,35 @@ void RtpVideoStreamReceiver::OnReceivedPayloadData( rtcp_feedback_buffer_.SendBufferedRtcpFeedback(); frame_counter_.Add(packet->timestamp); - OnInsertedPacket(packet_buffer_.InsertPacket(std::move(packet))); + video_coding::PacketBuffer::InsertResult insert_result; + { + MutexLock lock(&packet_buffer_lock_); + int64_t unwrapped_rtp_seq_num = + rtp_seq_num_unwrapper_.Unwrap(rtp_packet.SequenceNumber()); + auto& packet_info = + packet_infos_ + .emplace( + unwrapped_rtp_seq_num, + RtpPacketInfo( + rtp_packet.Ssrc(), rtp_packet.Csrcs(), + rtp_packet.Timestamp(), + /*audio_level=*/absl::nullopt, + rtp_packet.GetExtension(), + /*receive_time_ms=*/clock_->TimeInMilliseconds())) + .first->second; + + // Try to extrapolate absolute capture time if it is missing. + packet_info.set_absolute_capture_time( + absolute_capture_time_interpolator_.OnReceivePacket( + AbsoluteCaptureTimeInterpolator::GetSource(packet_info.ssrc(), + packet_info.csrcs()), + packet_info.rtp_timestamp(), + // Assume frequency is the same one for all video frames. + kVideoPayloadTypeFrequency, packet_info.absolute_capture_time())); + + insert_result = packet_buffer_.InsertPacket(std::move(packet)); + } + OnInsertedPacket(std::move(insert_result)); } void RtpVideoStreamReceiver::OnRecoveredPacket(const uint8_t* rtp_packet, @@ -669,35 +696,6 @@ void RtpVideoStreamReceiver::OnRtpPacket(const RtpPacketReceived& packet) { return; } - if (!packet.recovered()) { - // TODO(nisse): Exclude out-of-order packets? - int64_t now_ms = clock_->TimeInMilliseconds(); - { - MutexLock lock(&sync_info_lock_); - last_received_rtp_timestamp_ = packet.Timestamp(); - last_received_rtp_system_time_ms_ = now_ms; - } - // Periodically log the RTP header of incoming packets. - if (now_ms - last_packet_log_ms_ > kPacketLogIntervalMs) { - rtc::StringBuilder ss; - ss << "Packet received on SSRC: " << packet.Ssrc() - << " with payload type: " << static_cast(packet.PayloadType()) - << ", timestamp: " << packet.Timestamp() - << ", sequence number: " << packet.SequenceNumber() - << ", arrival time: " << packet.arrival_time_ms(); - int32_t time_offset; - if (packet.GetExtension(&time_offset)) { - ss << ", toffset: " << time_offset; - } - uint32_t send_time; - if (packet.GetExtension(&send_time)) { - ss << ", abs send time: " << send_time; - } - RTC_LOG(LS_INFO) << ss.str(); - last_packet_log_ms_ = now_ms; - } - } - ReceivePacket(packet); // Update receive statistics after ReceivePacket. 
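// Illustrative sketch, not part of the patch: the block removed from
// OnRtpPacket() above is not dropped; the same rate-limited header logging
// reappears in UpdatePacketReceiveTimestamps() further down in this diff.
// The helper below is a hypothetical, self-contained illustration of that
// throttling pattern only; the interval value is illustrative, not the
// receiver's own constant.
#include <cstdint>

#include "rtc_base/logging.h"

namespace {
constexpr int64_t kLogIntervalMs = 10000;

// Logs at most once per kLogIntervalMs, tracking state through |last_log_ms|.
void MaybeLogPacket(uint32_t ssrc, int64_t now_ms, int64_t* last_log_ms) {
  if (now_ms - *last_log_ms > kLogIntervalMs) {
    RTC_LOG(LS_INFO) << "Packet received on SSRC: " << ssrc;
    *last_log_ms = now_ms;
  }
}
}  // namespace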
@@ -752,76 +750,100 @@ bool RtpVideoStreamReceiver::IsDecryptable() const { void RtpVideoStreamReceiver::OnInsertedPacket( video_coding::PacketBuffer::InsertResult result) { - video_coding::PacketBuffer::Packet* first_packet = nullptr; - int max_nack_count; - int64_t min_recv_time; - int64_t max_recv_time; - std::vector> payloads; - RtpPacketInfos::vector_type packet_infos; - - bool frame_boundary = true; - for (auto& packet : result.packets) { - // PacketBuffer promisses frame boundaries are correctly set on each - // packet. Document that assumption with the DCHECKs. - RTC_DCHECK_EQ(frame_boundary, packet->is_first_packet_in_frame()); - if (packet->is_first_packet_in_frame()) { - first_packet = packet.get(); - max_nack_count = packet->times_nacked; - min_recv_time = packet->packet_info.receive_time_ms(); - max_recv_time = packet->packet_info.receive_time_ms(); - payloads.clear(); - packet_infos.clear(); - } else { - max_nack_count = std::max(max_nack_count, packet->times_nacked); - min_recv_time = - std::min(min_recv_time, packet->packet_info.receive_time_ms()); - max_recv_time = - std::max(max_recv_time, packet->packet_info.receive_time_ms()); - } - payloads.emplace_back(packet->video_payload); - packet_infos.push_back(packet->packet_info); - - frame_boundary = packet->is_last_packet_in_frame(); - if (packet->is_last_packet_in_frame()) { - auto depacketizer_it = payload_type_map_.find(first_packet->payload_type); - RTC_CHECK(depacketizer_it != payload_type_map_.end()); - - rtc::scoped_refptr bitstream = - depacketizer_it->second->AssembleFrame(payloads); - if (!bitstream) { - // Failed to assemble a frame. Discard and continue. - continue; + std::vector> assembled_frames; + { + MutexLock lock(&packet_buffer_lock_); + video_coding::PacketBuffer::Packet* first_packet = nullptr; + int max_nack_count; + int64_t min_recv_time; + int64_t max_recv_time; + std::vector> payloads; + RtpPacketInfos::vector_type packet_infos; + + bool frame_boundary = true; + for (auto& packet : result.packets) { + // PacketBuffer promisses frame boundaries are correctly set on each + // packet. Document that assumption with the DCHECKs. + RTC_DCHECK_EQ(frame_boundary, packet->is_first_packet_in_frame()); + int64_t unwrapped_rtp_seq_num = + rtp_seq_num_unwrapper_.Unwrap(packet->seq_num); + RTC_DCHECK(packet_infos_.count(unwrapped_rtp_seq_num) > 0); + RtpPacketInfo& packet_info = packet_infos_[unwrapped_rtp_seq_num]; + if (packet->is_first_packet_in_frame()) { + first_packet = packet.get(); + max_nack_count = packet->times_nacked; + min_recv_time = packet_info.receive_time().ms(); + max_recv_time = packet_info.receive_time().ms(); + payloads.clear(); + packet_infos.clear(); + } else { + max_nack_count = std::max(max_nack_count, packet->times_nacked); + min_recv_time = + std::min(min_recv_time, packet_info.receive_time().ms()); + max_recv_time = + std::max(max_recv_time, packet_info.receive_time().ms()); + } + payloads.emplace_back(packet->video_payload); + packet_infos.push_back(packet_info); + + frame_boundary = packet->is_last_packet_in_frame(); + if (packet->is_last_packet_in_frame()) { + auto depacketizer_it = + payload_type_map_.find(first_packet->payload_type); + RTC_CHECK(depacketizer_it != payload_type_map_.end()); + + rtc::scoped_refptr bitstream = + depacketizer_it->second->AssembleFrame(payloads); + if (!bitstream) { + // Failed to assemble a frame. Discard and continue. 
+ continue; + } + + const video_coding::PacketBuffer::Packet& last_packet = *packet; + assembled_frames.push_back(std::make_unique( + first_packet->seq_num, // + last_packet.seq_num, // + last_packet.marker_bit, // + max_nack_count, // + min_recv_time, // + max_recv_time, // + first_packet->timestamp, // + ntp_estimator_.Estimate(first_packet->timestamp), // + last_packet.video_header.video_timing, // + first_packet->payload_type, // + first_packet->codec(), // + last_packet.video_header.rotation, // + last_packet.video_header.content_type, // + first_packet->video_header, // + last_packet.video_header.color_space, // + RtpPacketInfos(std::move(packet_infos)), // + std::move(bitstream))); } + } + RTC_DCHECK(frame_boundary); - const video_coding::PacketBuffer::Packet& last_packet = *packet; - OnAssembledFrame(std::make_unique( - first_packet->seq_num, // - last_packet.seq_num, // - last_packet.marker_bit, // - max_nack_count, // - min_recv_time, // - max_recv_time, // - first_packet->timestamp, // - first_packet->ntp_time_ms, // - last_packet.video_header.video_timing, // - first_packet->payload_type, // - first_packet->codec(), // - last_packet.video_header.rotation, // - last_packet.video_header.content_type, // - first_packet->video_header, // - last_packet.video_header.color_space, // - RtpPacketInfos(std::move(packet_infos)), // - std::move(bitstream))); + if (result.buffer_cleared) { + packet_infos_.clear(); } - } - RTC_DCHECK(frame_boundary); + } // packet_buffer_lock_ + if (result.buffer_cleared) { + { + MutexLock lock(&sync_info_lock_); + last_received_rtp_system_time_.reset(); + last_received_keyframe_rtp_system_time_.reset(); + last_received_keyframe_rtp_timestamp_.reset(); + } RequestKeyFrame(); } + + for (auto& frame : assembled_frames) { + OnAssembledFrame(std::move(frame)); + } } void RtpVideoStreamReceiver::OnAssembledFrame( - std::unique_ptr frame) { + std::unique_ptr frame) { RTC_DCHECK_RUN_ON(&network_tc_); RTC_DCHECK(frame); @@ -860,12 +882,10 @@ void RtpVideoStreamReceiver::OnAssembledFrame( if (frame_is_newer) { // When we reset the |reference_finder_| we don't want new picture ids // to overlap with old picture ids. To ensure that doesn't happen we - // start from the |last_completed_picture_id_| and add an offset in case - // of reordering. - reference_finder_ = - std::make_unique( - this, last_completed_picture_id_ + - std::numeric_limits::max()); + // start from the |last_completed_picture_id_| and add an offset in + // case of reordering. + reference_finder_ = std::make_unique( + last_completed_picture_id_ + std::numeric_limits::max()); current_codec_ = frame->codec_type(); } else { // Old frame from before the codec switch, discard it. 
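// Illustrative sketch, not part of the patch: the rewritten OnInsertedPacket()
// above assembles frames into a local vector while packet_buffer_lock_ is
// held and only calls OnAssembledFrame() after the lock is released. A
// minimal, hypothetical example of that "collect under the lock, dispatch
// outside it" pattern, using a plain std::vector guarded by a webrtc::Mutex:
#include <vector>

#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/thread_annotations.h"

class Dispatcher {
 public:
  void OnItem(int item) {
    std::vector<int> ready;
    {
      webrtc::MutexLock lock(&lock_);
      pending_.push_back(item);
      ready.swap(pending_);  // Take the batch while still holding the lock.
    }
    // Deliver outside the critical section so callbacks can re-enter OnItem()
    // (or take other locks) without risking deadlock.
    for (int value : ready) {
      Deliver(value);
    }
  }

 private:
  void Deliver(int /*value*/) {}

  webrtc::Mutex lock_;
  std::vector<int> pending_ RTC_GUARDED_BY(lock_);
};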
@@ -886,28 +906,30 @@ void RtpVideoStreamReceiver::OnAssembledFrame( } else if (frame_transformer_delegate_) { frame_transformer_delegate_->TransformFrame(std::move(frame)); } else { - reference_finder_->ManageFrame(std::move(frame)); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } } -void RtpVideoStreamReceiver::OnCompleteFrame( - std::unique_ptr frame) { +void RtpVideoStreamReceiver::OnCompleteFrames( + RtpFrameReferenceFinder::ReturnVector frames) { { MutexLock lock(&last_seq_num_mutex_); - video_coding::RtpFrameObject* rtp_frame = - static_cast(frame.get()); - last_seq_num_for_pic_id_[rtp_frame->id.picture_id] = - rtp_frame->last_seq_num(); + for (const auto& frame : frames) { + RtpFrameObject* rtp_frame = static_cast(frame.get()); + last_seq_num_for_pic_id_[rtp_frame->Id()] = rtp_frame->last_seq_num(); + } + } + for (auto& frame : frames) { + last_completed_picture_id_ = + std::max(last_completed_picture_id_, frame->Id()); + complete_frame_callback_->OnCompleteFrame(std::move(frame)); } - last_completed_picture_id_ = - std::max(last_completed_picture_id_, frame->id.picture_id); - complete_frame_callback_->OnCompleteFrame(std::move(frame)); } void RtpVideoStreamReceiver::OnDecryptedFrame( - std::unique_ptr frame) { + std::unique_ptr frame) { MutexLock lock(&reference_finder_lock_); - reference_finder_->ManageFrame(std::move(frame)); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } void RtpVideoStreamReceiver::OnDecryptionStatusChange( @@ -931,7 +953,7 @@ void RtpVideoStreamReceiver::SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) { RTC_DCHECK_RUN_ON(&network_tc_); frame_transformer_delegate_ = - new rtc::RefCountedObject( + rtc::make_ref_counted( this, std::move(frame_transformer), rtc::Thread::Current(), config_.rtp.remote_ssrc); frame_transformer_delegate_->Init(); @@ -943,12 +965,21 @@ void RtpVideoStreamReceiver::UpdateRtt(int64_t max_rtt_ms) { } absl::optional RtpVideoStreamReceiver::LastReceivedPacketMs() const { - return packet_buffer_.LastReceivedPacketMs(); + MutexLock lock(&sync_info_lock_); + if (last_received_rtp_system_time_) { + return absl::optional(last_received_rtp_system_time_->ms()); + } + return absl::nullopt; } absl::optional RtpVideoStreamReceiver::LastReceivedKeyframePacketMs() const { - return packet_buffer_.LastReceivedKeyframePacketMs(); + MutexLock lock(&sync_info_lock_); + if (last_received_keyframe_rtp_system_time_) { + return absl::optional( + last_received_keyframe_rtp_system_time_->ms()); + } + return absl::nullopt; } void RtpVideoStreamReceiver::AddSecondarySink(RtpPacketSinkInterface* sink) { @@ -972,9 +1003,9 @@ void RtpVideoStreamReceiver::RemoveSecondarySink( } void RtpVideoStreamReceiver::ManageFrame( - std::unique_ptr frame) { + std::unique_ptr frame) { MutexLock lock(&reference_finder_lock_); - reference_finder_->ManageFrame(std::move(frame)); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } void RtpVideoStreamReceiver::ReceivePacket(const RtpPacketReceived& packet) { @@ -1029,9 +1060,16 @@ void RtpVideoStreamReceiver::ParseAndHandleEncapsulatingHeader( void RtpVideoStreamReceiver::NotifyReceiverOfEmptyPacket(uint16_t seq_num) { { MutexLock lock(&reference_finder_lock_); - reference_finder_->PaddingReceived(seq_num); + OnCompleteFrames(reference_finder_->PaddingReceived(seq_num)); + } + + video_coding::PacketBuffer::InsertResult insert_result; + { + MutexLock lock(&packet_buffer_lock_); + insert_result = packet_buffer_.InsertPadding(seq_num); } - 
OnInsertedPacket(packet_buffer_.InsertPadding(seq_num)); + OnInsertedPacket(std::move(insert_result)); + if (nack_module_) { nack_module_->OnReceivedPacket(seq_num, /* is_keyframe = */ false, /* is _recovered = */ false); @@ -1078,7 +1116,7 @@ bool RtpVideoStreamReceiver::DeliverRtcp(const uint8_t* rtcp_packet, absl::optional remote_to_local_clock_offset_ms = ntp_estimator_.EstimateRemoteToLocalClockOffsetMs(); if (remote_to_local_clock_offset_ms.has_value()) { - absolute_capture_time_receiver_.SetRemoteToLocalClockOffset( + capture_clock_offset_updater_.SetRemoteToLocalClockOffset( Int64MsToQ32x32(*remote_to_local_clock_offset_ms)); } } @@ -1113,7 +1151,13 @@ void RtpVideoStreamReceiver::FrameDecoded(int64_t picture_id) { } } if (seq_num != -1) { - packet_buffer_.ClearTo(seq_num); + { + MutexLock lock(&packet_buffer_lock_); + packet_buffer_.ClearTo(seq_num); + int64_t unwrapped_rtp_seq_num = rtp_seq_num_unwrapper_.Unwrap(seq_num); + packet_infos_.erase(packet_infos_.begin(), + packet_infos_.upper_bound(unwrapped_rtp_seq_num)); + } MutexLock lock(&reference_finder_lock_); reference_finder_->ClearTo(seq_num); } @@ -1184,4 +1228,40 @@ void RtpVideoStreamReceiver::InsertSpsPpsIntoTracker(uint8_t payload_type) { sprop_decoder.pps_nalu()); } +void RtpVideoStreamReceiver::UpdatePacketReceiveTimestamps( + const RtpPacketReceived& packet, + bool is_keyframe) { + Timestamp now = clock_->CurrentTime(); + { + MutexLock lock(&sync_info_lock_); + if (is_keyframe || + last_received_keyframe_rtp_timestamp_ == packet.Timestamp()) { + last_received_keyframe_rtp_timestamp_ = packet.Timestamp(); + last_received_keyframe_rtp_system_time_ = now; + } + last_received_rtp_system_time_ = now; + last_received_rtp_timestamp_ = packet.Timestamp(); + } + + // Periodically log the RTP header of incoming packets. 
+ if (now.ms() - last_packet_log_ms_ > kPacketLogIntervalMs) { + rtc::StringBuilder ss; + ss << "Packet received on SSRC: " << packet.Ssrc() + << " with payload type: " << static_cast(packet.PayloadType()) + << ", timestamp: " << packet.Timestamp() + << ", sequence number: " << packet.SequenceNumber() + << ", arrival time: " << ToString(packet.arrival_time()); + int32_t time_offset; + if (packet.GetExtension(&time_offset)) { + ss << ", toffset: " << time_offset; + } + uint32_t send_time; + if (packet.GetExtension(&send_time)) { + ss << ", abs send time: " << send_time; + } + RTC_LOG(LS_INFO) << ss.str(); + last_packet_log_ms_ = now.ms(); + } +} + } // namespace webrtc diff --git a/video/rtp_video_stream_receiver.h b/video/rtp_video_stream_receiver.h index 40958c48ec..b3d62f34a4 100644 --- a/video/rtp_video_stream_receiver.h +++ b/video/rtp_video_stream_receiver.h @@ -21,6 +21,8 @@ #include "absl/types/optional.h" #include "api/array_view.h" #include "api/crypto/frame_decryptor_interface.h" +#include "api/sequence_checker.h" +#include "api/units/timestamp.h" #include "api/video/color_space.h" #include "api/video_codecs/video_codec.h" #include "call/rtp_packet_sink_interface.h" @@ -31,7 +33,8 @@ #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/include/rtp_rtcp.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" -#include "modules/rtp_rtcp/source/absolute_capture_time_receiver.h" +#include "modules/rtp_rtcp/source/absolute_capture_time_interpolator.h" +#include "modules/rtp_rtcp/source/capture_clock_offset_updater.h" #include "modules/rtp_rtcp/source/rtp_dependency_descriptor_extension.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_video_header.h" @@ -45,10 +48,8 @@ #include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/numerics/sequence_number_util.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "video/buffered_frame_decryptor.h" #include "video/rtp_video_stream_receiver_frame_transformer_delegate.h" @@ -68,11 +69,18 @@ class RtpVideoStreamReceiver : public LossNotificationSender, public RecoveredPacketReceiver, public RtpPacketSinkInterface, public KeyFrameRequestSender, - public video_coding::OnCompleteFrameCallback, public OnDecryptedFrameCallback, public OnDecryptionStatusChangeCallback, public RtpVideoFrameReceiver { public: + // A complete frame is a frame which has received all its packets and all its + // references are known. + class OnCompleteFrameCallback { + public: + virtual ~OnCompleteFrameCallback() {} + virtual void OnCompleteFrame(std::unique_ptr frame) = 0; + }; + // DEPRECATED due to dependency on ReceiveStatisticsProxy. RtpVideoStreamReceiver( Clock* clock, @@ -90,7 +98,7 @@ class RtpVideoStreamReceiver : public LossNotificationSender, // The KeyFrameRequestSender is optional; if not provided, key frame // requests are sent via the internal RtpRtcp module. 
KeyFrameRequestSender* keyframe_request_sender, - video_coding::OnCompleteFrameCallback* complete_frame_callback, + OnCompleteFrameCallback* complete_frame_callback, rtc::scoped_refptr frame_decryptor, rtc::scoped_refptr frame_transformer); @@ -111,7 +119,7 @@ class RtpVideoStreamReceiver : public LossNotificationSender, // The KeyFrameRequestSender is optional; if not provided, key frame // requests are sent via the internal RtpRtcp module. KeyFrameRequestSender* keyframe_request_sender, - video_coding::OnCompleteFrameCallback* complete_frame_callback, + OnCompleteFrameCallback* complete_frame_callback, rtc::scoped_refptr frame_decryptor, rtc::scoped_refptr frame_transformer); ~RtpVideoStreamReceiver() override; @@ -172,13 +180,10 @@ class RtpVideoStreamReceiver : public LossNotificationSender, // Don't use, still experimental. void RequestPacketRetransmit(const std::vector& sequence_numbers); - // Implements OnCompleteFrameCallback. - void OnCompleteFrame( - std::unique_ptr frame) override; + void OnCompleteFrames(RtpFrameReferenceFinder::ReturnVector frames); // Implements OnDecryptedFrameCallback. - void OnDecryptedFrame( - std::unique_ptr frame) override; + void OnDecryptedFrame(std::unique_ptr frame) override; // Implements OnDecryptionStatusChangeCallback. void OnDecryptionStatusChange( @@ -209,8 +214,7 @@ class RtpVideoStreamReceiver : public LossNotificationSender, private: // Implements RtpVideoFrameReceiver. - void ManageFrame( - std::unique_ptr frame) override; + void ManageFrame(std::unique_ptr frame) override; // Used for buffering RTCP feedback messages and sending them all together. // Note: @@ -306,7 +310,11 @@ class RtpVideoStreamReceiver : public LossNotificationSender, ParseGenericDependenciesResult ParseGenericDependenciesExtension( const RtpPacketReceived& rtp_packet, RTPVideoHeader* video_header) RTC_RUN_ON(worker_task_checker_); - void OnAssembledFrame(std::unique_ptr frame); + void OnAssembledFrame(std::unique_ptr frame) + RTC_LOCKS_EXCLUDED(packet_buffer_lock_); + void UpdatePacketReceiveTimestamps(const RtpPacketReceived& packet, + bool is_keyframe) + RTC_RUN_ON(worker_task_checker_); Clock* const clock_; // Ownership of this object lies with VideoReceiveStream, which owns |this|. 
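// Illustrative sketch, not part of the patch: the header above leans on
// clang's thread-safety analysis. RTC_GUARDED_BY ties a member to the mutex
// that must be held to touch it, RTC_LOCKS_EXCLUDED marks methods that
// acquire that mutex themselves, and RTC_RUN_ON ties state to a
// SequenceChecker. A hypothetical, minimal class showing the first two
// annotations in the way the receiver uses them for packet_buffer_lock_:
#include <cstdint>

#include "rtc_base/synchronization/mutex.h"
#include "rtc_base/thread_annotations.h"

class PaddingCounter {
 public:
  // Callers must not hold |lock_|; it is acquired internally.
  void OnPadding(uint16_t seq_num) RTC_LOCKS_EXCLUDED(lock_) {
    webrtc::MutexLock lock(&lock_);
    last_padding_seq_num_ = seq_num;
    ++padding_packets_;
  }

  int padding_packets() const RTC_LOCKS_EXCLUDED(lock_) {
    webrtc::MutexLock lock(&lock_);
    return padding_packets_;
  }

 private:
  mutable webrtc::Mutex lock_;
  int padding_packets_ RTC_GUARDED_BY(lock_) = 0;
  uint16_t last_padding_seq_num_ RTC_GUARDED_BY(lock_) = 0;
};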
@@ -330,14 +338,15 @@ class RtpVideoStreamReceiver : public LossNotificationSender, const std::unique_ptr rtp_rtcp_; - video_coding::OnCompleteFrameCallback* complete_frame_callback_; + OnCompleteFrameCallback* complete_frame_callback_; KeyFrameRequestSender* const keyframe_request_sender_; RtcpFeedbackBuffer rtcp_feedback_buffer_; std::unique_ptr nack_module_; std::unique_ptr loss_notification_controller_; - video_coding::PacketBuffer packet_buffer_; + mutable Mutex packet_buffer_lock_; + video_coding::PacketBuffer packet_buffer_ RTC_GUARDED_BY(packet_buffer_lock_); UniqueTimestampCounter frame_counter_ RTC_GUARDED_BY(worker_task_checker_); SeqNumUnwrapper frame_id_unwrapper_ RTC_GUARDED_BY(worker_task_checker_); @@ -353,7 +362,7 @@ class RtpVideoStreamReceiver : public LossNotificationSender, RTC_GUARDED_BY(worker_task_checker_); Mutex reference_finder_lock_; - std::unique_ptr reference_finder_ + std::unique_ptr reference_finder_ RTC_GUARDED_BY(reference_finder_lock_); absl::optional current_codec_; uint32_t last_assembled_frame_rtp_timestamp_; @@ -382,12 +391,16 @@ class RtpVideoStreamReceiver : public LossNotificationSender, mutable Mutex sync_info_lock_; absl::optional last_received_rtp_timestamp_ RTC_GUARDED_BY(sync_info_lock_); - absl::optional last_received_rtp_system_time_ms_ + absl::optional last_received_keyframe_rtp_timestamp_ + RTC_GUARDED_BY(sync_info_lock_); + absl::optional last_received_rtp_system_time_ + RTC_GUARDED_BY(sync_info_lock_); + absl::optional last_received_keyframe_rtp_system_time_ RTC_GUARDED_BY(sync_info_lock_); // Used to validate the buffered frame decryptor is always run on the correct // thread. - rtc::ThreadChecker network_tc_; + SequenceChecker network_tc_; // Handles incoming encrypted frames and forwards them to the // rtp_reference_finder if they are decryptable. 
std::unique_ptr buffered_frame_decryptor_ @@ -395,13 +408,21 @@ class RtpVideoStreamReceiver : public LossNotificationSender, std::atomic frames_decryptable_; absl::optional last_color_space_; - AbsoluteCaptureTimeReceiver absolute_capture_time_receiver_ + AbsoluteCaptureTimeInterpolator absolute_capture_time_interpolator_ + RTC_GUARDED_BY(worker_task_checker_); + + CaptureClockOffsetUpdater capture_clock_offset_updater_ RTC_GUARDED_BY(worker_task_checker_); int64_t last_completed_picture_id_ = 0; rtc::scoped_refptr frame_transformer_delegate_; + + SeqNumUnwrapper rtp_seq_num_unwrapper_ + RTC_GUARDED_BY(packet_buffer_lock_); + std::map packet_infos_ + RTC_GUARDED_BY(packet_buffer_lock_); }; } // namespace webrtc diff --git a/video/rtp_video_stream_receiver2.cc b/video/rtp_video_stream_receiver2.cc index 63d8c3835d..4b43247b18 100644 --- a/video/rtp_video_stream_receiver2.cc +++ b/video/rtp_video_stream_receiver2.cc @@ -36,7 +36,6 @@ #include "modules/rtp_rtcp/source/rtp_rtcp_config.h" #include "modules/rtp_rtcp/source/video_rtp_depacketizer.h" #include "modules/rtp_rtcp/source/video_rtp_depacketizer_raw.h" -#include "modules/utility/include/process_thread.h" #include "modules/video_coding/frame_object.h" #include "modules/video_coding/h264_sprop_parameter_sets.h" #include "modules/video_coding/h264_sps_pps_tracker.h" @@ -49,7 +48,6 @@ #include "system_wrappers/include/field_trial.h" #include "system_wrappers/include/metrics.h" #include "system_wrappers/include/ntp_time.h" -#include "video/receive_statistics_proxy2.h" namespace webrtc { @@ -114,6 +112,7 @@ std::unique_ptr MaybeConstructNackModule( if (config.rtp.nack.rtp_history_ms == 0) return nullptr; + // TODO(bugs.webrtc.org/12420): pass rtp_history_ms to the nack module. return std::make_unique(current_queue, clock, nack_sender, keyframe_request_sender); } @@ -133,17 +132,18 @@ RtpVideoStreamReceiver2::RtcpFeedbackBuffer::RtcpFeedbackBuffer( RTC_DCHECK(key_frame_request_sender_); RTC_DCHECK(nack_sender_); RTC_DCHECK(loss_notification_sender_); + packet_sequence_checker_.Detach(); } void RtpVideoStreamReceiver2::RtcpFeedbackBuffer::RequestKeyFrame() { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); request_key_frame_ = true; } void RtpVideoStreamReceiver2::RtcpFeedbackBuffer::SendNack( const std::vector& sequence_numbers, bool buffering_allowed) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); RTC_DCHECK(!sequence_numbers.empty()); nack_sequence_numbers_.insert(nack_sequence_numbers_.end(), sequence_numbers.cbegin(), @@ -160,7 +160,7 @@ void RtpVideoStreamReceiver2::RtcpFeedbackBuffer::SendLossNotification( uint16_t last_received_seq_num, bool decodability_flag, bool buffering_allowed) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); RTC_DCHECK(buffering_allowed); RTC_DCHECK(!lntf_state_) << "SendLossNotification() called twice in a row with no call to " @@ -170,7 +170,7 @@ void RtpVideoStreamReceiver2::RtcpFeedbackBuffer::SendLossNotification( } void RtpVideoStreamReceiver2::RtcpFeedbackBuffer::SendBufferedRtcpFeedback() { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); bool request_key_frame = false; std::vector nack_sequence_numbers; @@ -210,16 +210,14 @@ RtpVideoStreamReceiver2::RtpVideoStreamReceiver2( ReceiveStatistics* rtp_receive_statistics, RtcpPacketTypeCounterObserver* rtcp_packet_type_counter_observer, RtcpCnameCallback* rtcp_cname_callback, - 
ProcessThread* process_thread, NackSender* nack_sender, KeyFrameRequestSender* keyframe_request_sender, - video_coding::OnCompleteFrameCallback* complete_frame_callback, + OnCompleteFrameCallback* complete_frame_callback, rtc::scoped_refptr frame_decryptor, rtc::scoped_refptr frame_transformer) : clock_(clock), config_(*config), packet_router_(packet_router), - process_thread_(process_thread), ntp_estimator_(clock), rtp_header_extensions_(config_.rtp.extensions), forced_playout_delay_max_ms_("max_ms", absl::nullopt), @@ -249,10 +247,12 @@ RtpVideoStreamReceiver2::RtpVideoStreamReceiver2( clock_, &rtcp_feedback_buffer_, &rtcp_feedback_buffer_)), - packet_buffer_(clock_, kPacketBufferStartSize, PacketBufferMaxSize()), + packet_buffer_(kPacketBufferStartSize, PacketBufferMaxSize()), + reference_finder_(std::make_unique()), has_received_frame_(false), frames_decryptable_(false), - absolute_capture_time_receiver_(clock) { + absolute_capture_time_interpolator_(clock) { + packet_sequence_checker_.Detach(); constexpr bool remb_candidate = true; if (packet_router_) packet_router_->AddReceiveRtpModule(rtp_rtcp_.get(), remb_candidate); @@ -286,17 +286,12 @@ RtpVideoStreamReceiver2::RtpVideoStreamReceiver2( {&forced_playout_delay_max_ms_, &forced_playout_delay_min_ms_}, field_trial::FindFullName("WebRTC-ForcePlayoutDelay")); - process_thread_->RegisterModule(rtp_rtcp_.get(), RTC_FROM_HERE); - if (config_.rtp.lntf.enabled) { loss_notification_controller_ = std::make_unique(&rtcp_feedback_buffer_, &rtcp_feedback_buffer_); } - reference_finder_ = - std::make_unique(this); - // Only construct the encrypted receiver if frame encryption is enabled. if (config_.crypto_options.sframe.require_frame_encryption) { buffered_frame_decryptor_ = @@ -307,19 +302,15 @@ RtpVideoStreamReceiver2::RtpVideoStreamReceiver2( } if (frame_transformer) { - frame_transformer_delegate_ = new rtc::RefCountedObject< - RtpVideoStreamReceiverFrameTransformerDelegate>( - this, std::move(frame_transformer), rtc::Thread::Current(), - config_.rtp.remote_ssrc); + frame_transformer_delegate_ = + rtc::make_ref_counted( + this, std::move(frame_transformer), rtc::Thread::Current(), + config_.rtp.remote_ssrc); frame_transformer_delegate_->Init(); } } RtpVideoStreamReceiver2::~RtpVideoStreamReceiver2() { - RTC_DCHECK(secondary_sinks_.empty()); - - process_thread_->DeRegisterModule(rtp_rtcp_.get()); - if (packet_router_) packet_router_->RemoveReceiveRtpModule(rtp_rtcp_.get()); UpdateHistograms(); @@ -332,7 +323,7 @@ void RtpVideoStreamReceiver2::AddReceiveCodec( const VideoCodec& video_codec, const std::map& codec_params, bool raw_payload) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); if (codec_params.count(cricket::kH264FmtpSpsPpsIdrInKeyframe) || field_trial::IsEnabled("WebRTC-SpsPpsIdrIsH264Keyframe")) { packet_buffer_.ForceSpsPpsIdrIsH264Keyframe(); @@ -345,24 +336,27 @@ void RtpVideoStreamReceiver2::AddReceiveCodec( } absl::optional RtpVideoStreamReceiver2::GetSyncInfo() const { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); Syncable::Info info; if (rtp_rtcp_->RemoteNTP(&info.capture_time_ntp_secs, - &info.capture_time_ntp_frac, nullptr, nullptr, + &info.capture_time_ntp_frac, + /*rtcp_arrival_time_secs=*/nullptr, + /*rtcp_arrival_time_frac=*/nullptr, &info.capture_time_source_clock) != 0) { return absl::nullopt; } - if (!last_received_rtp_timestamp_ || !last_received_rtp_system_time_ms_) { + if (!last_received_rtp_timestamp_ || 
!last_received_rtp_system_time_) { return absl::nullopt; } info.latest_received_capture_timestamp = *last_received_rtp_timestamp_; - info.latest_receive_time_ms = *last_received_rtp_system_time_ms_; + info.latest_receive_time_ms = last_received_rtp_system_time_->ms(); // Leaves info.current_delay_ms uninitialized. return info; } +// RTC_RUN_ON(packet_sequence_checker_) RtpVideoStreamReceiver2::ParseGenericDependenciesResult RtpVideoStreamReceiver2::ParseGenericDependenciesExtension( const RtpPacketReceived& rtp_packet, @@ -472,20 +466,32 @@ void RtpVideoStreamReceiver2::OnReceivedPayloadData( rtc::CopyOnWriteBuffer codec_payload, const RtpPacketReceived& rtp_packet, const RTPVideoHeader& video) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - auto packet = std::make_unique( - rtp_packet, video, ntp_estimator_.Estimate(rtp_packet.Timestamp()), - clock_->TimeInMilliseconds()); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + + auto packet = + std::make_unique(rtp_packet, video); + + int64_t unwrapped_rtp_seq_num = + rtp_seq_num_unwrapper_.Unwrap(rtp_packet.SequenceNumber()); + auto& packet_info = + packet_infos_ + .emplace( + unwrapped_rtp_seq_num, + RtpPacketInfo( + rtp_packet.Ssrc(), rtp_packet.Csrcs(), rtp_packet.Timestamp(), + /*audio_level=*/absl::nullopt, + rtp_packet.GetExtension(), + /*receive_time_ms=*/clock_->CurrentTime())) + .first->second; // Try to extrapolate absolute capture time if it is missing. - packet->packet_info.set_absolute_capture_time( - absolute_capture_time_receiver_.OnReceivePacket( - AbsoluteCaptureTimeReceiver::GetSource(packet->packet_info.ssrc(), - packet->packet_info.csrcs()), - packet->packet_info.rtp_timestamp(), + packet_info.set_absolute_capture_time( + absolute_capture_time_interpolator_.OnReceivePacket( + AbsoluteCaptureTimeInterpolator::GetSource(packet_info.ssrc(), + packet_info.csrcs()), + packet_info.rtp_timestamp(), // Assume frequency is the same one for all video frames. - kVideoPayloadTypeFrequency, - packet->packet_info.absolute_capture_time())); + kVideoPayloadTypeFrequency, packet_info.absolute_capture_time())); RTPVideoHeader& video_header = packet->video_header; video_header.rotation = kVideoRotation_0; @@ -512,6 +518,12 @@ void RtpVideoStreamReceiver2::OnReceivedPayloadData( ParseGenericDependenciesResult generic_descriptor_state = ParseGenericDependenciesExtension(rtp_packet, &video_header); + + if (!rtp_packet.recovered()) { + UpdatePacketReceiveTimestamps( + rtp_packet, video_header.frame_type == VideoFrameType::kVideoFrameKey); + } + if (generic_descriptor_state == kDropPacket) return; @@ -530,6 +542,8 @@ void RtpVideoStreamReceiver2::OnReceivedPayloadData( video_header.color_space = last_color_space_; } } + video_header.video_frame_tracking_id = + rtp_packet.GetExtension(); if (loss_notification_controller_) { if (rtp_packet.recovered()) { @@ -610,6 +624,8 @@ void RtpVideoStreamReceiver2::OnReceivedPayloadData( void RtpVideoStreamReceiver2::OnRecoveredPacket(const uint8_t* rtp_packet, size_t rtp_packet_length) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + RtpPacketReceived packet; if (!packet.Parse(rtp_packet, rtp_packet_length)) return; @@ -632,39 +648,10 @@ void RtpVideoStreamReceiver2::OnRecoveredPacket(const uint8_t* rtp_packet, // This method handles both regular RTP packets and packets recovered // via FlexFEC. 
void RtpVideoStreamReceiver2::OnRtpPacket(const RtpPacketReceived& packet) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); - if (!receiving_) { + if (!receiving_) return; - } - - if (!packet.recovered()) { - // TODO(nisse): Exclude out-of-order packets? - int64_t now_ms = clock_->TimeInMilliseconds(); - - last_received_rtp_timestamp_ = packet.Timestamp(); - last_received_rtp_system_time_ms_ = now_ms; - - // Periodically log the RTP header of incoming packets. - if (now_ms - last_packet_log_ms_ > kPacketLogIntervalMs) { - rtc::StringBuilder ss; - ss << "Packet received on SSRC: " << packet.Ssrc() - << " with payload type: " << static_cast(packet.PayloadType()) - << ", timestamp: " << packet.Timestamp() - << ", sequence number: " << packet.SequenceNumber() - << ", arrival time: " << packet.arrival_time_ms(); - int32_t time_offset; - if (packet.GetExtension(&time_offset)) { - ss << ", toffset: " << time_offset; - } - uint32_t send_time; - if (packet.GetExtension(&send_time)) { - ss << ", abs send time: " << send_time; - } - RTC_LOG(LS_INFO) << ss.str(); - last_packet_log_ms_ = now_ms; - } - } ReceivePacket(packet); @@ -675,8 +662,8 @@ void RtpVideoStreamReceiver2::OnRtpPacket(const RtpPacketReceived& packet) { rtp_receive_statistics_->OnRtpPacket(packet); } - for (RtpPacketSinkInterface* secondary_sink : secondary_sinks_) { - secondary_sink->OnRtpPacket(packet); + if (config_.rtp.packet_sink_) { + config_.rtp.packet_sink_->OnRtpPacket(packet); } } @@ -721,6 +708,7 @@ bool RtpVideoStreamReceiver2::IsDecryptable() const { return frames_decryptable_; } +// RTC_RUN_ON(packet_sequence_checker_) void RtpVideoStreamReceiver2::OnInsertedPacket( video_coding::PacketBuffer::InsertResult result) { RTC_DCHECK_RUN_ON(&worker_task_checker_); @@ -736,22 +724,24 @@ void RtpVideoStreamReceiver2::OnInsertedPacket( // PacketBuffer promisses frame boundaries are correctly set on each // packet. Document that assumption with the DCHECKs. 
RTC_DCHECK_EQ(frame_boundary, packet->is_first_packet_in_frame()); + int64_t unwrapped_rtp_seq_num = + rtp_seq_num_unwrapper_.Unwrap(packet->seq_num); + RTC_DCHECK(packet_infos_.count(unwrapped_rtp_seq_num) > 0); + RtpPacketInfo& packet_info = packet_infos_[unwrapped_rtp_seq_num]; if (packet->is_first_packet_in_frame()) { first_packet = packet.get(); max_nack_count = packet->times_nacked; - min_recv_time = packet->packet_info.receive_time_ms(); - max_recv_time = packet->packet_info.receive_time_ms(); + min_recv_time = packet_info.receive_time().ms(); + max_recv_time = packet_info.receive_time().ms(); payloads.clear(); packet_infos.clear(); } else { max_nack_count = std::max(max_nack_count, packet->times_nacked); - min_recv_time = - std::min(min_recv_time, packet->packet_info.receive_time_ms()); - max_recv_time = - std::max(max_recv_time, packet->packet_info.receive_time_ms()); + min_recv_time = std::min(min_recv_time, packet_info.receive_time().ms()); + max_recv_time = std::max(max_recv_time, packet_info.receive_time().ms()); } payloads.emplace_back(packet->video_payload); - packet_infos.push_back(packet->packet_info); + packet_infos.push_back(packet_info); frame_boundary = packet->is_last_packet_in_frame(); if (packet->is_last_packet_in_frame()) { @@ -766,35 +756,39 @@ void RtpVideoStreamReceiver2::OnInsertedPacket( } const video_coding::PacketBuffer::Packet& last_packet = *packet; - OnAssembledFrame(std::make_unique( - first_packet->seq_num, // - last_packet.seq_num, // - last_packet.marker_bit, // - max_nack_count, // - min_recv_time, // - max_recv_time, // - first_packet->timestamp, // - first_packet->ntp_time_ms, // - last_packet.video_header.video_timing, // - first_packet->payload_type, // - first_packet->codec(), // - last_packet.video_header.rotation, // - last_packet.video_header.content_type, // - first_packet->video_header, // - last_packet.video_header.color_space, // - RtpPacketInfos(std::move(packet_infos)), // + OnAssembledFrame(std::make_unique( + first_packet->seq_num, // + last_packet.seq_num, // + last_packet.marker_bit, // + max_nack_count, // + min_recv_time, // + max_recv_time, // + first_packet->timestamp, // + ntp_estimator_.Estimate(first_packet->timestamp), // + last_packet.video_header.video_timing, // + first_packet->payload_type, // + first_packet->codec(), // + last_packet.video_header.rotation, // + last_packet.video_header.content_type, // + first_packet->video_header, // + last_packet.video_header.color_space, // + RtpPacketInfos(std::move(packet_infos)), // std::move(bitstream))); } } RTC_DCHECK(frame_boundary); if (result.buffer_cleared) { + last_received_rtp_system_time_.reset(); + last_received_keyframe_rtp_system_time_.reset(); + last_received_keyframe_rtp_timestamp_.reset(); + packet_infos_.clear(); RequestKeyFrame(); } } +// RTC_RUN_ON(packet_sequence_checker_) void RtpVideoStreamReceiver2::OnAssembledFrame( - std::unique_ptr frame) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + std::unique_ptr frame) { RTC_DCHECK(frame); const absl::optional& descriptor = @@ -833,10 +827,8 @@ void RtpVideoStreamReceiver2::OnAssembledFrame( // to overlap with old picture ids. To ensure that doesn't happen we // start from the |last_completed_picture_id_| and add an offset in case // of reordering. 
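// Illustrative sketch, not from this patch: OnInsertedPacket() above now
// looks packet metadata up in a std::map keyed by the unwrapped RTP sequence
// number. Keying on the 64-bit unwrapped value keeps the map ordered across
// 16-bit wrap-around, so ranges can later be erased with upper_bound(), as
// the FrameDecoded() cleanup elsewhere in this patch does. The
// PacketInfoStore class below is hypothetical.

#include <cstdint>
#include <map>
#include <string>
#include <utility>

#include "rtc_base/numerics/sequence_number_util.h"

class PacketInfoStore {
 public:
  void OnPacket(uint16_t seq_num, std::string info) {
    // 0xFFFF followed by 0x0000 unwraps to adjacent 64-bit keys.
    infos_.emplace(unwrapper_.Unwrap(seq_num), std::move(info));
  }

  void EraseUpTo(uint16_t seq_num) {
    // Drop every entry at or below seq_num (e.g. once a frame is decoded).
    infos_.erase(infos_.begin(),
                 infos_.upper_bound(unwrapper_.Unwrap(seq_num)));
  }

 private:
  webrtc::SeqNumUnwrapper<uint16_t> unwrapper_;
  std::map<int64_t, std::string> infos_;
};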
- reference_finder_ = - std::make_unique( - this, last_completed_picture_id_ + - std::numeric_limits::max()); + reference_finder_ = std::make_unique( + last_completed_picture_id_ + std::numeric_limits::max()); current_codec_ = frame->codec_type(); } else { // Old frame from before the codec switch, discard it. @@ -857,27 +849,27 @@ void RtpVideoStreamReceiver2::OnAssembledFrame( } else if (frame_transformer_delegate_) { frame_transformer_delegate_->TransformFrame(std::move(frame)); } else { - reference_finder_->ManageFrame(std::move(frame)); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } } -void RtpVideoStreamReceiver2::OnCompleteFrame( - std::unique_ptr frame) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - video_coding::RtpFrameObject* rtp_frame = - static_cast(frame.get()); - last_seq_num_for_pic_id_[rtp_frame->id.picture_id] = - rtp_frame->last_seq_num(); - - last_completed_picture_id_ = - std::max(last_completed_picture_id_, frame->id.picture_id); - complete_frame_callback_->OnCompleteFrame(std::move(frame)); +// RTC_RUN_ON(packet_sequence_checker_) +void RtpVideoStreamReceiver2::OnCompleteFrames( + RtpFrameReferenceFinder::ReturnVector frames) { + for (auto& frame : frames) { + RtpFrameObject* rtp_frame = static_cast(frame.get()); + last_seq_num_for_pic_id_[rtp_frame->Id()] = rtp_frame->last_seq_num(); + + last_completed_picture_id_ = + std::max(last_completed_picture_id_, frame->Id()); + complete_frame_callback_->OnCompleteFrame(std::move(frame)); + } } void RtpVideoStreamReceiver2::OnDecryptedFrame( - std::unique_ptr frame) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - reference_finder_->ManageFrame(std::move(frame)); + std::unique_ptr frame) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } void RtpVideoStreamReceiver2::OnDecryptionStatusChange( @@ -891,7 +883,9 @@ void RtpVideoStreamReceiver2::OnDecryptionStatusChange( void RtpVideoStreamReceiver2::SetFrameDecryptor( rtc::scoped_refptr frame_decryptor) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + // TODO(bugs.webrtc.org/11993): Update callers or post the operation over to + // the network thread. 
+ RTC_DCHECK_RUN_ON(&packet_sequence_checker_); if (buffered_frame_decryptor_ == nullptr) { buffered_frame_decryptor_ = std::make_unique(this, this); @@ -903,7 +897,7 @@ void RtpVideoStreamReceiver2::SetDepacketizerToDecoderFrameTransformer( rtc::scoped_refptr frame_transformer) { RTC_DCHECK_RUN_ON(&worker_task_checker_); frame_transformer_delegate_ = - new rtc::RefCountedObject( + rtc::make_ref_counted( this, std::move(frame_transformer), rtc::Thread::Current(), config_.rtp.remote_ssrc); frame_transformer_delegate_->Init(); @@ -916,40 +910,30 @@ void RtpVideoStreamReceiver2::UpdateRtt(int64_t max_rtt_ms) { } absl::optional RtpVideoStreamReceiver2::LastReceivedPacketMs() const { - return packet_buffer_.LastReceivedPacketMs(); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + if (last_received_rtp_system_time_) { + return absl::optional(last_received_rtp_system_time_->ms()); + } + return absl::nullopt; } absl::optional RtpVideoStreamReceiver2::LastReceivedKeyframePacketMs() const { - return packet_buffer_.LastReceivedKeyframePacketMs(); -} - -void RtpVideoStreamReceiver2::AddSecondarySink(RtpPacketSinkInterface* sink) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - RTC_DCHECK(!absl::c_linear_search(secondary_sinks_, sink)); - secondary_sinks_.push_back(sink); -} - -void RtpVideoStreamReceiver2::RemoveSecondarySink( - const RtpPacketSinkInterface* sink) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - auto it = absl::c_find(secondary_sinks_, sink); - if (it == secondary_sinks_.end()) { - // We might be rolling-back a call whose setup failed mid-way. In such a - // case, it's simpler to remove "everything" rather than remember what - // has already been added. - RTC_LOG(LS_WARNING) << "Removal of unknown sink."; - return; + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + if (last_received_keyframe_rtp_system_time_) { + return absl::optional( + last_received_keyframe_rtp_system_time_->ms()); } - secondary_sinks_.erase(it); + return absl::nullopt; } void RtpVideoStreamReceiver2::ManageFrame( - std::unique_ptr frame) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - reference_finder_->ManageFrame(std::move(frame)); + std::unique_ptr frame) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + OnCompleteFrames(reference_finder_->ManageFrame(std::move(frame))); } +// RTC_RUN_ON(packet_sequence_checker_) void RtpVideoStreamReceiver2::ReceivePacket(const RtpPacketReceived& packet) { RTC_DCHECK_RUN_ON(&worker_task_checker_); if (packet.payload_size() == 0) { @@ -979,9 +963,9 @@ void RtpVideoStreamReceiver2::ReceivePacket(const RtpPacketReceived& packet) { parsed_payload->video_header); } +// RTC_RUN_ON(packet_sequence_checker_) void RtpVideoStreamReceiver2::ParseAndHandleEncapsulatingHeader( const RtpPacketReceived& packet) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); if (packet.PayloadType() == config_.rtp.red_payload_type && packet.payload_size() > 0) { if (packet.payload()[0] == config_.rtp.ulpfec_payload_type) { @@ -1000,10 +984,11 @@ void RtpVideoStreamReceiver2::ParseAndHandleEncapsulatingHeader( // In the case of a video stream without picture ids and no rtx the // RtpFrameReferenceFinder will need to know about padding to // correctly calculate frame references. 
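// Illustrative sketch, not from this patch: the patch consistently swaps
// "new rtc::RefCountedObject<T>(...)" for "rtc::make_ref_counted<T>(...)".
// The Thing class below is hypothetical, and the include paths are my
// assumption for this revision (make_ref_counted declared alongside
// RefCountedObject).

#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/ref_counted_object.h"

class Thing : public rtc::RefCountInterface {
 public:
  explicit Thing(int id) : id_(id) {}
  int id() const { return id_; }

 private:
  const int id_;
};

int Demo() {
  // Old spelling, as removed throughout this patch:
  rtc::scoped_refptr<Thing> a(new rtc::RefCountedObject<Thing>(1));
  // New spelling; the RefCountedObject wrapper is deduced internally:
  auto b = rtc::make_ref_counted<Thing>(2);
  return a->id() + b->id();
}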
+// RTC_RUN_ON(packet_sequence_checker_) void RtpVideoStreamReceiver2::NotifyReceiverOfEmptyPacket(uint16_t seq_num) { RTC_DCHECK_RUN_ON(&worker_task_checker_); - reference_finder_->PaddingReceived(seq_num); + OnCompleteFrames(reference_finder_->PaddingReceived(seq_num)); OnInsertedPacket(packet_buffer_.InsertPadding(seq_num)); if (nack_module_) { @@ -1019,7 +1004,7 @@ void RtpVideoStreamReceiver2::NotifyReceiverOfEmptyPacket(uint16_t seq_num) { bool RtpVideoStreamReceiver2::DeliverRtcp(const uint8_t* rtcp_packet, size_t rtcp_packet_length) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); if (!receiving_) { return false; @@ -1052,7 +1037,7 @@ bool RtpVideoStreamReceiver2::DeliverRtcp(const uint8_t* rtcp_packet, absl::optional remote_to_local_clock_offset_ms = ntp_estimator_.EstimateRemoteToLocalClockOffsetMs(); if (remote_to_local_clock_offset_ms.has_value()) { - absolute_capture_time_receiver_.SetRemoteToLocalClockOffset( + capture_clock_offset_updater_.SetRemoteToLocalClockOffset( Int64MsToQ32x32(*remote_to_local_clock_offset_ms)); } } @@ -1061,7 +1046,7 @@ bool RtpVideoStreamReceiver2::DeliverRtcp(const uint8_t* rtcp_packet, } void RtpVideoStreamReceiver2::FrameContinuous(int64_t picture_id) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); if (!nack_module_) return; @@ -1074,8 +1059,7 @@ void RtpVideoStreamReceiver2::FrameContinuous(int64_t picture_id) { } void RtpVideoStreamReceiver2::FrameDecoded(int64_t picture_id) { - RTC_DCHECK_RUN_ON(&worker_task_checker_); - // Running on the decoder thread. + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); int seq_num = -1; auto seq_num_it = last_seq_num_for_pic_id_.find(picture_id); if (seq_num_it != last_seq_num_for_pic_id_.end()) { @@ -1085,6 +1069,9 @@ void RtpVideoStreamReceiver2::FrameDecoded(int64_t picture_id) { } if (seq_num != -1) { + int64_t unwrapped_rtp_seq_num = rtp_seq_num_unwrapper_.Unwrap(seq_num); + packet_infos_.erase(packet_infos_.begin(), + packet_infos_.upper_bound(unwrapped_rtp_seq_num)); packet_buffer_.ClearTo(seq_num); reference_finder_->ClearTo(seq_num); } @@ -1097,12 +1084,12 @@ void RtpVideoStreamReceiver2::SignalNetworkState(NetworkState state) { } void RtpVideoStreamReceiver2::StartReceive() { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); receiving_ = true; } void RtpVideoStreamReceiver2::StopReceive() { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); receiving_ = false; } @@ -1133,6 +1120,7 @@ void RtpVideoStreamReceiver2::UpdateHistograms() { } } +// RTC_RUN_ON(packet_sequence_checker_) void RtpVideoStreamReceiver2::InsertSpsPpsIntoTracker(uint8_t payload_type) { RTC_DCHECK_RUN_ON(&worker_task_checker_); @@ -1158,4 +1146,37 @@ void RtpVideoStreamReceiver2::InsertSpsPpsIntoTracker(uint8_t payload_type) { sprop_decoder.pps_nalu()); } +void RtpVideoStreamReceiver2::UpdatePacketReceiveTimestamps( + const RtpPacketReceived& packet, + bool is_keyframe) { + Timestamp now = clock_->CurrentTime(); + if (is_keyframe || + last_received_keyframe_rtp_timestamp_ == packet.Timestamp()) { + last_received_keyframe_rtp_timestamp_ = packet.Timestamp(); + last_received_keyframe_rtp_system_time_ = now; + } + last_received_rtp_system_time_ = now; + last_received_rtp_timestamp_ = packet.Timestamp(); + + // Periodically log the RTP header of incoming packets. 
+ if (now.ms() - last_packet_log_ms_ > kPacketLogIntervalMs) { + rtc::StringBuilder ss; + ss << "Packet received on SSRC: " << packet.Ssrc() + << " with payload type: " << static_cast(packet.PayloadType()) + << ", timestamp: " << packet.Timestamp() + << ", sequence number: " << packet.SequenceNumber() + << ", arrival time: " << ToString(packet.arrival_time()); + int32_t time_offset; + if (packet.GetExtension(&time_offset)) { + ss << ", toffset: " << time_offset; + } + uint32_t send_time; + if (packet.GetExtension(&send_time)) { + ss << ", abs send time: " << send_time; + } + RTC_LOG(LS_INFO) << ss.str(); + last_packet_log_ms_ = now.ms(); + } +} + } // namespace webrtc diff --git a/video/rtp_video_stream_receiver2.h b/video/rtp_video_stream_receiver2.h index 40e7ef6f1b..ddff26b3bd 100644 --- a/video/rtp_video_stream_receiver2.h +++ b/video/rtp_video_stream_receiver2.h @@ -18,6 +18,8 @@ #include "absl/types/optional.h" #include "api/crypto/frame_decryptor_interface.h" +#include "api/sequence_checker.h" +#include "api/units/timestamp.h" #include "api/video/color_space.h" #include "api/video_codecs/video_codec.h" #include "call/rtp_packet_sink_interface.h" @@ -27,7 +29,8 @@ #include "modules/rtp_rtcp/include/remote_ntp_time_estimator.h" #include "modules/rtp_rtcp/include/rtp_header_extension_map.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" -#include "modules/rtp_rtcp/source/absolute_capture_time_receiver.h" +#include "modules/rtp_rtcp/source/absolute_capture_time_interpolator.h" +#include "modules/rtp_rtcp/source/capture_clock_offset_updater.h" #include "modules/rtp_rtcp/source/rtp_dependency_descriptor_extension.h" #include "modules/rtp_rtcp/source/rtp_packet_received.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" @@ -42,7 +45,6 @@ #include "rtc_base/constructor_magic.h" #include "rtc_base/experiments/field_trial_parser.h" #include "rtc_base/numerics/sequence_number_util.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" #include "video/buffered_frame_decryptor.h" @@ -52,7 +54,6 @@ namespace webrtc { class NackModule2; class PacketRouter; -class ProcessThread; class ReceiveStatistics; class RtcpRttStats; class RtpPacketReceived; @@ -63,11 +64,18 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, public RecoveredPacketReceiver, public RtpPacketSinkInterface, public KeyFrameRequestSender, - public video_coding::OnCompleteFrameCallback, public OnDecryptedFrameCallback, public OnDecryptionStatusChangeCallback, public RtpVideoFrameReceiver { public: + // A complete frame is a frame which has received all its packets and all its + // references are known. + class OnCompleteFrameCallback { + public: + virtual ~OnCompleteFrameCallback() {} + virtual void OnCompleteFrame(std::unique_ptr frame) = 0; + }; + RtpVideoStreamReceiver2( TaskQueueBase* current_queue, Clock* clock, @@ -81,12 +89,11 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, ReceiveStatistics* rtp_receive_statistics, RtcpPacketTypeCounterObserver* rtcp_packet_type_counter_observer, RtcpCnameCallback* rtcp_cname_callback, - ProcessThread* process_thread, NackSender* nack_sender, // The KeyFrameRequestSender is optional; if not provided, key frame // requests are sent via the internal RtpRtcp module. 
KeyFrameRequestSender* keyframe_request_sender, - video_coding::OnCompleteFrameCallback* complete_frame_callback, + OnCompleteFrameCallback* complete_frame_callback, rtc::scoped_refptr frame_decryptor, rtc::scoped_refptr frame_transformer); ~RtpVideoStreamReceiver2() override; @@ -112,7 +119,7 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, // Returns number of different frames seen. int GetUniqueFramesSeen() const { - RTC_DCHECK_RUN_ON(&worker_task_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); return frame_counter_.GetUniqueSeen(); } @@ -144,16 +151,14 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, // Decryption not SRTP. bool IsDecryptable() const; - // Don't use, still experimental. + // Request packet retransmits via NACK. Called via + // VideoReceiveStream2::SendNack, which gets called when + // RtpVideoStreamReceiver2::RtcpFeedbackBuffer's SendNack and + // SendBufferedRtcpFeedback methods (see `rtcp_feedback_buffer_` below). void RequestPacketRetransmit(const std::vector& sequence_numbers); - // Implements OnCompleteFrameCallback. - void OnCompleteFrame( - std::unique_ptr frame) override; - // Implements OnDecryptedFrameCallback. - void OnDecryptedFrame( - std::unique_ptr frame) override; + void OnDecryptedFrame(std::unique_ptr frame) override; // Implements OnDecryptionStatusChangeCallback. void OnDecryptionStatusChange( @@ -175,17 +180,12 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, absl::optional LastReceivedPacketMs() const; absl::optional LastReceivedKeyframePacketMs() const; - // RtpDemuxer only forwards a given RTP packet to one sink. However, some - // sinks, such as FlexFEC, might wish to be informed of all of the packets - // a given sink receives (or any set of sinks). They may do so by registering - // themselves as secondary sinks. - void AddSecondarySink(RtpPacketSinkInterface* sink); - void RemoveSecondarySink(const RtpPacketSinkInterface* sink); - private: // Implements RtpVideoFrameReceiver. - void ManageFrame( - std::unique_ptr frame) override; + void ManageFrame(std::unique_ptr frame) override; + + void OnCompleteFrames(RtpFrameReferenceFinder::ReturnVector frame) + RTC_RUN_ON(packet_sequence_checker_); // Used for buffering RTCP feedback messages and sending them all together. // Note: @@ -234,20 +234,20 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, bool decodability_flag; }; - RTC_NO_UNIQUE_ADDRESS SequenceChecker worker_task_checker_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_sequence_checker_; KeyFrameRequestSender* const key_frame_request_sender_; NackSender* const nack_sender_; LossNotificationSender* const loss_notification_sender_; // Key-frame-request-related state. - bool request_key_frame_ RTC_GUARDED_BY(worker_task_checker_); + bool request_key_frame_ RTC_GUARDED_BY(packet_sequence_checker_); // NACK-related state. std::vector nack_sequence_numbers_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); absl::optional lntf_state_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); }; enum ParseGenericDependenciesResult { kDropPacket, @@ -257,25 +257,34 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, // Entry point doing non-stats work for a received packet. Called // for the same packet both before and after RED decapsulation. 
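// Illustrative sketch, not from this patch: OnCompleteFrameCallback is now a
// nested interface on RtpVideoStreamReceiver2 rather than the old
// video_coding:: one. The exact frame type isn't visible here; judging by
// the unit-test changes it is webrtc::EncodedFrame, which the hypothetical
// FrameSink below assumes.

#include <memory>
#include <utility>

#include "api/video/encoded_frame.h"
#include "video/rtp_video_stream_receiver2.h"

class FrameSink
    : public webrtc::RtpVideoStreamReceiver2::OnCompleteFrameCallback {
 public:
  void OnCompleteFrame(std::unique_ptr<webrtc::EncodedFrame> frame) override {
    // A "complete" frame has all of its packets and all of its referenced
    // frames; from here it would normally go on to the frame buffer/decoder.
    last_frame_ = std::move(frame);
  }

 private:
  std::unique_ptr<webrtc::EncodedFrame> last_frame_;
};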
- void ReceivePacket(const RtpPacketReceived& packet); + void ReceivePacket(const RtpPacketReceived& packet) + RTC_RUN_ON(packet_sequence_checker_); + // Parses and handles RED headers. // This function assumes that it's being called from only one thread. - void ParseAndHandleEncapsulatingHeader(const RtpPacketReceived& packet); - void NotifyReceiverOfEmptyPacket(uint16_t seq_num); + void ParseAndHandleEncapsulatingHeader(const RtpPacketReceived& packet) + RTC_RUN_ON(packet_sequence_checker_); + void NotifyReceiverOfEmptyPacket(uint16_t seq_num) + RTC_RUN_ON(packet_sequence_checker_); void UpdateHistograms(); bool IsRedEnabled() const; - void InsertSpsPpsIntoTracker(uint8_t payload_type); - void OnInsertedPacket(video_coding::PacketBuffer::InsertResult result); + void InsertSpsPpsIntoTracker(uint8_t payload_type) + RTC_RUN_ON(packet_sequence_checker_); + void OnInsertedPacket(video_coding::PacketBuffer::InsertResult result) + RTC_RUN_ON(packet_sequence_checker_); ParseGenericDependenciesResult ParseGenericDependenciesExtension( const RtpPacketReceived& rtp_packet, - RTPVideoHeader* video_header) RTC_RUN_ON(worker_task_checker_); - void OnAssembledFrame(std::unique_ptr frame); + RTPVideoHeader* video_header) RTC_RUN_ON(packet_sequence_checker_); + void OnAssembledFrame(std::unique_ptr frame) + RTC_RUN_ON(packet_sequence_checker_); + void UpdatePacketReceiveTimestamps(const RtpPacketReceived& packet, + bool is_keyframe) + RTC_RUN_ON(packet_sequence_checker_); Clock* const clock_; // Ownership of this object lies with VideoReceiveStream, which owns |this|. const VideoReceiveStream::Config& config_; PacketRouter* const packet_router_; - ProcessThread* const process_thread_; RemoteNtpTimeEstimator ntp_estimator_; @@ -288,79 +297,99 @@ class RtpVideoStreamReceiver2 : public LossNotificationSender, std::unique_ptr ulpfec_receiver_; RTC_NO_UNIQUE_ADDRESS SequenceChecker worker_task_checker_; - bool receiving_ RTC_GUARDED_BY(worker_task_checker_); - int64_t last_packet_log_ms_ RTC_GUARDED_BY(worker_task_checker_); + // TODO(bugs.webrtc.org/11993): This checker conceptually represents + // operations that belong to the network thread. The Call class is currently + // moving towards handling network packets on the network thread and while + // that work is ongoing, this checker may in practice represent the worker + // thread, but still serves as a mechanism of grouping together concepts + // that belong to the network thread. Once the packets are fully delivered + // on the network thread, this comment will be deleted. 
+ RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_sequence_checker_; + bool receiving_ RTC_GUARDED_BY(packet_sequence_checker_); + int64_t last_packet_log_ms_ RTC_GUARDED_BY(packet_sequence_checker_); const std::unique_ptr rtp_rtcp_; - video_coding::OnCompleteFrameCallback* complete_frame_callback_; + OnCompleteFrameCallback* complete_frame_callback_; KeyFrameRequestSender* const keyframe_request_sender_; RtcpFeedbackBuffer rtcp_feedback_buffer_; const std::unique_ptr nack_module_; std::unique_ptr loss_notification_controller_; - video_coding::PacketBuffer packet_buffer_; - UniqueTimestampCounter frame_counter_ RTC_GUARDED_BY(worker_task_checker_); + video_coding::PacketBuffer packet_buffer_ + RTC_GUARDED_BY(packet_sequence_checker_); + UniqueTimestampCounter frame_counter_ + RTC_GUARDED_BY(packet_sequence_checker_); SeqNumUnwrapper frame_id_unwrapper_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); // Video structure provided in the dependency descriptor in a first packet // of a key frame. It is required to parse dependency descriptor in the // following delta packets. std::unique_ptr video_structure_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); // Frame id of the last frame with the attached video structure. // absl::nullopt when `video_structure_ == nullptr`; absl::optional video_structure_frame_id_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); - std::unique_ptr reference_finder_ - RTC_GUARDED_BY(worker_task_checker_); + std::unique_ptr reference_finder_ + RTC_GUARDED_BY(packet_sequence_checker_); absl::optional current_codec_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); uint32_t last_assembled_frame_rtp_timestamp_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); std::map last_seq_num_for_pic_id_ - RTC_GUARDED_BY(worker_task_checker_); - video_coding::H264SpsPpsTracker tracker_ RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); + video_coding::H264SpsPpsTracker tracker_ + RTC_GUARDED_BY(packet_sequence_checker_); // Maps payload id to the depacketizer. std::map> payload_type_map_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); // TODO(johan): Remove pt_codec_params_ once // https://bugs.chromium.org/p/webrtc/issues/detail?id=6883 is resolved. // Maps a payload type to a map of out-of-band supplied codec parameters. 
std::map> pt_codec_params_ - RTC_GUARDED_BY(worker_task_checker_); - int16_t last_payload_type_ RTC_GUARDED_BY(worker_task_checker_) = -1; - - bool has_received_frame_ RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); + int16_t last_payload_type_ RTC_GUARDED_BY(packet_sequence_checker_) = -1; - std::vector secondary_sinks_ - RTC_GUARDED_BY(worker_task_checker_); + bool has_received_frame_ RTC_GUARDED_BY(packet_sequence_checker_); absl::optional last_received_rtp_timestamp_ - RTC_GUARDED_BY(worker_task_checker_); - absl::optional last_received_rtp_system_time_ms_ - RTC_GUARDED_BY(worker_task_checker_); + RTC_GUARDED_BY(packet_sequence_checker_); + absl::optional last_received_keyframe_rtp_timestamp_ + RTC_GUARDED_BY(packet_sequence_checker_); + absl::optional last_received_rtp_system_time_ + RTC_GUARDED_BY(packet_sequence_checker_); + absl::optional last_received_keyframe_rtp_system_time_ + RTC_GUARDED_BY(packet_sequence_checker_); // Handles incoming encrypted frames and forwards them to the // rtp_reference_finder if they are decryptable. std::unique_ptr buffered_frame_decryptor_ - RTC_PT_GUARDED_BY(worker_task_checker_); + RTC_PT_GUARDED_BY(packet_sequence_checker_); bool frames_decryptable_ RTC_GUARDED_BY(worker_task_checker_); absl::optional last_color_space_; - AbsoluteCaptureTimeReceiver absolute_capture_time_receiver_ - RTC_GUARDED_BY(worker_task_checker_); + AbsoluteCaptureTimeInterpolator absolute_capture_time_interpolator_ + RTC_GUARDED_BY(packet_sequence_checker_); + + CaptureClockOffsetUpdater capture_clock_offset_updater_ + RTC_GUARDED_BY(packet_sequence_checker_); int64_t last_completed_picture_id_ = 0; rtc::scoped_refptr frame_transformer_delegate_; + + SeqNumUnwrapper rtp_seq_num_unwrapper_ + RTC_GUARDED_BY(packet_sequence_checker_); + std::map packet_infos_ + RTC_GUARDED_BY(packet_sequence_checker_); }; } // namespace webrtc diff --git a/video/rtp_video_stream_receiver2_unittest.cc b/video/rtp_video_stream_receiver2_unittest.cc index dabd9ffae0..7ccf0a5faa 100644 --- a/video/rtp_video_stream_receiver2_unittest.cc +++ b/video/rtp_video_stream_receiver2_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "api/task_queue/task_queue_base.h" #include "api/video/video_codec_type.h" #include "api/video/video_frame_type.h" #include "common_video/h264/h264_common.h" @@ -38,6 +39,7 @@ #include "test/gtest.h" #include "test/mock_frame_transformer.h" #include "test/time_controller/simulated_task_queue.h" +#include "test/time_controller/simulated_time_controller.h" using ::testing::_; using ::testing::ElementsAre; @@ -51,8 +53,7 @@ namespace { const uint8_t kH264StartCode[] = {0x00, 0x00, 0x00, 0x01}; -std::vector GetAbsoluteCaptureTimestamps( - const video_coding::EncodedFrame* frame) { +std::vector GetAbsoluteCaptureTimestamps(const EncodedFrame* frame) { std::vector result; for (const auto& packet_info : frame->PacketInfos()) { if (packet_info.absolute_capture_time()) { @@ -96,23 +97,13 @@ class MockKeyFrameRequestSender : public KeyFrameRequestSender { }; class MockOnCompleteFrameCallback - : public video_coding::OnCompleteFrameCallback { + : public RtpVideoStreamReceiver2::OnCompleteFrameCallback { public: - MOCK_METHOD(void, DoOnCompleteFrame, (video_coding::EncodedFrame*), ()); - MOCK_METHOD(void, - DoOnCompleteFrameFailNullptr, - (video_coding::EncodedFrame*), - ()); - MOCK_METHOD(void, - DoOnCompleteFrameFailLength, - (video_coding::EncodedFrame*), - ()); - MOCK_METHOD(void, - DoOnCompleteFrameFailBitstream, - 
(video_coding::EncodedFrame*), - ()); - void OnCompleteFrame( - std::unique_ptr frame) override { + MOCK_METHOD(void, DoOnCompleteFrame, (EncodedFrame*), ()); + MOCK_METHOD(void, DoOnCompleteFrameFailNullptr, (EncodedFrame*), ()); + MOCK_METHOD(void, DoOnCompleteFrameFailLength, (EncodedFrame*), ()); + MOCK_METHOD(void, DoOnCompleteFrameFailBitstream, (EncodedFrame*), ()); + void OnCompleteFrame(std::unique_ptr frame) override { if (!frame) { DoOnCompleteFrameFailNullptr(nullptr); return; @@ -145,11 +136,11 @@ class MockRtpPacketSink : public RtpPacketSinkInterface { }; constexpr uint32_t kSsrc = 111; -constexpr uint16_t kSequenceNumber = 222; constexpr int kPayloadType = 100; constexpr int kRedPayloadType = 125; std::unique_ptr CreateRtpPacketReceived() { + constexpr uint16_t kSequenceNumber = 222; auto packet = std::make_unique(); packet->SetSsrc(kSsrc); packet->SetSequenceNumber(kSequenceNumber); @@ -164,21 +155,25 @@ MATCHER_P(SamePacketAs, other, "") { } // namespace -class RtpVideoStreamReceiver2Test : public ::testing::Test { +class RtpVideoStreamReceiver2Test : public ::testing::Test, + public RtpPacketSinkInterface { public: RtpVideoStreamReceiver2Test() : RtpVideoStreamReceiver2Test("") {} explicit RtpVideoStreamReceiver2Test(std::string field_trials) - : override_field_trials_(field_trials), - config_(CreateConfig()), - process_thread_(ProcessThread::Create("TestThread")) { + : time_controller_(Timestamp::Millis(100)), + task_queue_(time_controller_.GetTaskQueueFactory()->CreateTaskQueue( + "RtpVideoStreamReceiver2Test", + TaskQueueFactory::Priority::NORMAL)), + task_queue_setter_(task_queue_.get()), + override_field_trials_(field_trials), + config_(CreateConfig()) { rtp_receive_statistics_ = ReceiveStatistics::Create(Clock::GetRealTimeClock()); rtp_video_stream_receiver_ = std::make_unique( TaskQueueBase::Current(), Clock::GetRealTimeClock(), &mock_transport_, nullptr, nullptr, &config_, rtp_receive_statistics_.get(), nullptr, - nullptr, process_thread_.get(), &mock_nack_sender_, - &mock_key_frame_request_sender_, &mock_on_complete_frame_callback_, - nullptr, nullptr); + nullptr, &mock_nack_sender_, &mock_key_frame_request_sender_, + &mock_on_complete_frame_callback_, nullptr, nullptr); VideoCodec codec; codec.codecType = kVideoCodecGeneric; rtp_video_stream_receiver_->AddReceiveCodec(kPayloadType, codec, {}, @@ -228,17 +223,24 @@ class RtpVideoStreamReceiver2Test : public ::testing::Test { h264.nalus[h264.nalus_length++] = info; } + void OnRtpPacket(const RtpPacketReceived& packet) override { + if (test_packet_sink_) + test_packet_sink_->OnRtpPacket(packet); + } + protected: - static VideoReceiveStream::Config CreateConfig() { + VideoReceiveStream::Config CreateConfig() { VideoReceiveStream::Config config(nullptr); config.rtp.remote_ssrc = 1111; config.rtp.local_ssrc = 2222; config.rtp.red_payload_type = kRedPayloadType; + config.rtp.packet_sink_ = this; return config; } - TokenTaskQueue task_queue_; - TokenTaskQueue::CurrentTaskQueueSetter task_queue_setter_{&task_queue_}; + GlobalSimulatedTimeController time_controller_; + std::unique_ptr task_queue_; + TokenTaskQueue::CurrentTaskQueueSetter task_queue_setter_; const webrtc::test::ScopedFieldTrials override_field_trials_; VideoReceiveStream::Config config_; @@ -246,9 +248,9 @@ class RtpVideoStreamReceiver2Test : public ::testing::Test { MockKeyFrameRequestSender mock_key_frame_request_sender_; MockTransport mock_transport_; MockOnCompleteFrameCallback mock_on_complete_frame_callback_; - std::unique_ptr process_thread_; 
std::unique_ptr rtp_receive_statistics_; std::unique_ptr rtp_video_stream_receiver_; + RtpPacketSinkInterface* test_packet_sink_ = nullptr; }; TEST_F(RtpVideoStreamReceiver2Test, CacheColorSpaceFromLastPacketOfKeyframe) { @@ -344,7 +346,7 @@ TEST_F(RtpVideoStreamReceiver2Test, CacheColorSpaceFromLastPacketOfKeyframe) { EXPECT_TRUE(key_frame_packet2.GetExtension()); rtp_video_stream_receiver_->OnRtpPacket(key_frame_packet1); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([kColorSpace](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kColorSpace](EncodedFrame* frame) { ASSERT_TRUE(frame->EncodedImage().ColorSpace()); EXPECT_EQ(*frame->EncodedImage().ColorSpace(), kColorSpace); })); @@ -360,7 +362,7 @@ TEST_F(RtpVideoStreamReceiver2Test, CacheColorSpaceFromLastPacketOfKeyframe) { // included in the RTP packet. EXPECT_FALSE(delta_frame_packet.GetExtension()); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([kColorSpace](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kColorSpace](EncodedFrame* frame) { ASSERT_TRUE(frame->EncodedImage().ColorSpace()); EXPECT_EQ(*frame->EncodedImage().ColorSpace(), kColorSpace); })); @@ -402,11 +404,10 @@ TEST_F(RtpVideoStreamReceiver2Test, PacketInfoIsPropagatedIntoVideoFrames) { mock_on_complete_frame_callback_.AppendExpectedBitstream(data.data(), data.size()); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke( - [kAbsoluteCaptureTimestamp](video_coding::EncodedFrame* frame) { - EXPECT_THAT(GetAbsoluteCaptureTimestamps(frame), - ElementsAre(kAbsoluteCaptureTimestamp)); - })); + .WillOnce(Invoke([kAbsoluteCaptureTimestamp](EncodedFrame* frame) { + EXPECT_THAT(GetAbsoluteCaptureTimestamps(frame), + ElementsAre(kAbsoluteCaptureTimestamp)); + })); rtp_video_stream_receiver_->OnReceivedPayloadData(data, rtp_packet, video_header); } @@ -450,7 +451,7 @@ TEST_F(RtpVideoStreamReceiver2Test, // Expect rtp video stream receiver to extrapolate it for the resulting video // frame using absolute capture time from the previous packet. 
EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([](EncodedFrame* frame) { EXPECT_THAT(GetAbsoluteCaptureTimestamps(frame), SizeIs(1)); })); rtp_video_stream_receiver_->OnReceivedPayloadData(data, rtp_packet, @@ -663,9 +664,8 @@ TEST_P(RtpVideoStreamReceiver2TestH264, ForceSpsPpsIdrIsKeyframe) { mock_on_complete_frame_callback_.AppendExpectedBitstream(idr_data.data(), idr_data.size()); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_TRUE(frame->is_keyframe()); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_TRUE(frame->is_keyframe()); }); rtp_video_stream_receiver_->OnReceivedPayloadData(idr_data, rtp_packet, idr_video_header); mock_on_complete_frame_callback_.ClearExpectedBitstream(); @@ -675,9 +675,8 @@ TEST_P(RtpVideoStreamReceiver2TestH264, ForceSpsPpsIdrIsKeyframe) { idr_data.size()); rtp_packet.SetSequenceNumber(3); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_FALSE(frame->is_keyframe()); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_FALSE(frame->is_keyframe()); }); rtp_video_stream_receiver_->OnReceivedPayloadData(idr_data, rtp_packet, idr_video_header); } @@ -755,83 +754,36 @@ TEST_F(RtpVideoStreamReceiver2Test, RequestKeyframeWhenPacketBufferGetsFull) { video_header); } -TEST_F(RtpVideoStreamReceiver2Test, SecondarySinksGetRtpNotifications) { - rtp_video_stream_receiver_->StartReceive(); - - MockRtpPacketSink secondary_sink_1; - MockRtpPacketSink secondary_sink_2; - - rtp_video_stream_receiver_->AddSecondarySink(&secondary_sink_1); - rtp_video_stream_receiver_->AddSecondarySink(&secondary_sink_2); - - auto rtp_packet = CreateRtpPacketReceived(); - EXPECT_CALL(secondary_sink_1, OnRtpPacket(SamePacketAs(*rtp_packet))); - EXPECT_CALL(secondary_sink_2, OnRtpPacket(SamePacketAs(*rtp_packet))); - - rtp_video_stream_receiver_->OnRtpPacket(*rtp_packet); - - // Test tear-down. - rtp_video_stream_receiver_->StopReceive(); - rtp_video_stream_receiver_->RemoveSecondarySink(&secondary_sink_1); - rtp_video_stream_receiver_->RemoveSecondarySink(&secondary_sink_2); -} - -TEST_F(RtpVideoStreamReceiver2Test, - RemovedSecondarySinksGetNoRtpNotifications) { +TEST_F(RtpVideoStreamReceiver2Test, SinkGetsRtpNotifications) { rtp_video_stream_receiver_->StartReceive(); - MockRtpPacketSink secondary_sink; - - rtp_video_stream_receiver_->AddSecondarySink(&secondary_sink); - rtp_video_stream_receiver_->RemoveSecondarySink(&secondary_sink); + MockRtpPacketSink test_sink; + test_packet_sink_ = &test_sink; auto rtp_packet = CreateRtpPacketReceived(); - - EXPECT_CALL(secondary_sink, OnRtpPacket(_)).Times(0); + EXPECT_CALL(test_sink, OnRtpPacket(SamePacketAs(*rtp_packet))); rtp_video_stream_receiver_->OnRtpPacket(*rtp_packet); // Test tear-down. 
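// Illustrative sketch, not from this patch: with AddSecondarySink()/
// RemoveSecondarySink() gone, extra packet observers hang off the single
// RtpPacketSinkInterface pointer in the receive-stream config
// (config.rtp.packet_sink_), as the rewritten tests above do. The
// LoggingSink class below is hypothetical.

#include "call/rtp_packet_sink_interface.h"
#include "modules/rtp_rtcp/source/rtp_packet_received.h"
#include "rtc_base/logging.h"

class LoggingSink : public webrtc::RtpPacketSinkInterface {
 public:
  void OnRtpPacket(const webrtc::RtpPacketReceived& packet) override {
    RTC_LOG(LS_INFO) << "RTP packet: ssrc=" << packet.Ssrc()
                     << " seq=" << packet.SequenceNumber();
  }
};

// Usage, assuming a webrtc::VideoReceiveStream::Config named `config`:
//   LoggingSink sink;
//   config.rtp.packet_sink_ = &sink;  // The receiver forwards every packet.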
rtp_video_stream_receiver_->StopReceive(); + test_packet_sink_ = nullptr; } -TEST_F(RtpVideoStreamReceiver2Test, - OnlyRemovedSecondarySinksExcludedFromNotifications) { - rtp_video_stream_receiver_->StartReceive(); - - MockRtpPacketSink kept_secondary_sink; - MockRtpPacketSink removed_secondary_sink; - - rtp_video_stream_receiver_->AddSecondarySink(&kept_secondary_sink); - rtp_video_stream_receiver_->AddSecondarySink(&removed_secondary_sink); - rtp_video_stream_receiver_->RemoveSecondarySink(&removed_secondary_sink); - - auto rtp_packet = CreateRtpPacketReceived(); - EXPECT_CALL(kept_secondary_sink, OnRtpPacket(SamePacketAs(*rtp_packet))); - - rtp_video_stream_receiver_->OnRtpPacket(*rtp_packet); - - // Test tear-down. - rtp_video_stream_receiver_->StopReceive(); - rtp_video_stream_receiver_->RemoveSecondarySink(&kept_secondary_sink); -} - -TEST_F(RtpVideoStreamReceiver2Test, - SecondariesOfNonStartedStreamGetNoNotifications) { +TEST_F(RtpVideoStreamReceiver2Test, NonStartedStreamGetsNoRtpCallbacks) { // Explicitly showing that the stream is not in the |started| state, // regardless of whether streams start out |started| or |stopped|. rtp_video_stream_receiver_->StopReceive(); - MockRtpPacketSink secondary_sink; - rtp_video_stream_receiver_->AddSecondarySink(&secondary_sink); + MockRtpPacketSink test_sink; + test_packet_sink_ = &test_sink; auto rtp_packet = CreateRtpPacketReceived(); - EXPECT_CALL(secondary_sink, OnRtpPacket(_)).Times(0); + EXPECT_CALL(test_sink, OnRtpPacket(_)).Times(0); rtp_video_stream_receiver_->OnRtpPacket(*rtp_packet); - // Test tear-down. - rtp_video_stream_receiver_->RemoveSecondarySink(&secondary_sink); + test_packet_sink_ = nullptr; } TEST_F(RtpVideoStreamReceiver2Test, ParseGenericDescriptorOnePacket) { @@ -866,10 +818,10 @@ TEST_F(RtpVideoStreamReceiver2Test, ParseGenericDescriptorOnePacket) { rtp_packet.SetSequenceNumber(1); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce(Invoke([kSpatialIndex](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kSpatialIndex](EncodedFrame* frame) { EXPECT_EQ(frame->num_references, 2U); - EXPECT_EQ(frame->references[0], frame->id.picture_id - 90); - EXPECT_EQ(frame->references[1], frame->id.picture_id - 80); + EXPECT_EQ(frame->references[0], frame->Id() - 90); + EXPECT_EQ(frame->references[1], frame->Id() - 80); EXPECT_EQ(frame->SpatialIndex(), kSpatialIndex); EXPECT_THAT(frame->PacketInfos(), SizeIs(1)); })); @@ -924,7 +876,7 @@ TEST_F(RtpVideoStreamReceiver2Test, ParseGenericDescriptorTwoPackets) { data.size() - 1); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce(Invoke([kSpatialIndex](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kSpatialIndex](EncodedFrame* frame) { EXPECT_EQ(frame->num_references, 0U); EXPECT_EQ(frame->SpatialIndex(), kSpatialIndex); EXPECT_EQ(frame->EncodedImage()._encodedWidth, 480u); @@ -1003,14 +955,12 @@ TEST_F(RtpVideoStreamReceiver2Test, UnwrapsFrameId) { int64_t first_picture_id; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - first_picture_id = frame->id.picture_id; - }); + .WillOnce([&](EncodedFrame* frame) { first_picture_id = frame->Id(); }); inject_packet(/*wrapped_frame_id=*/0xffff); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_EQ(frame->id.picture_id - first_picture_id, 3); + .WillOnce([&](EncodedFrame* frame) { + EXPECT_EQ(frame->Id() - first_picture_id, 3); }); 
inject_packet(/*wrapped_frame_id=*/0x0002); } @@ -1074,9 +1024,7 @@ TEST_F(RtpVideoStreamReceiver2DependencyDescriptorTest, UnwrapsFrameId) { // keyframe. Thus feed a key frame first, then test reodered delta frames. int64_t first_picture_id; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - first_picture_id = frame->id.picture_id; - }); + .WillOnce([&](EncodedFrame* frame) { first_picture_id = frame->Id(); }); InjectPacketWith(stream_structure, keyframe_descriptor); DependencyDescriptor deltaframe1_descriptor; @@ -1090,13 +1038,13 @@ TEST_F(RtpVideoStreamReceiver2DependencyDescriptorTest, UnwrapsFrameId) { // Parser should unwrap frame ids correctly even if packets were reordered by // the network. EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { + .WillOnce([&](EncodedFrame* frame) { // 0x0002 - 0xfff0 - EXPECT_EQ(frame->id.picture_id - first_picture_id, 18); + EXPECT_EQ(frame->Id() - first_picture_id, 18); }) - .WillOnce([&](video_coding::EncodedFrame* frame) { + .WillOnce([&](EncodedFrame* frame) { // 0xfffe - 0xfff0 - EXPECT_EQ(frame->id.picture_id - first_picture_id, 14); + EXPECT_EQ(frame->Id() - first_picture_id, 14); }); InjectPacketWith(stream_structure, deltaframe2_descriptor); InjectPacketWith(stream_structure, deltaframe1_descriptor); @@ -1160,9 +1108,8 @@ TEST_F(RtpVideoStreamReceiver2DependencyDescriptorTest, keyframe2_descriptor.frame_number = 3; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_EQ(frame->id.picture_id & 0xFFFF, 3); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_EQ(frame->Id() & 0xFFFF, 3); }); InjectPacketWith(stream_structure2, keyframe2_descriptor); InjectPacketWith(stream_structure1, keyframe1_descriptor); @@ -1172,36 +1119,21 @@ TEST_F(RtpVideoStreamReceiver2DependencyDescriptorTest, deltaframe_descriptor.frame_dependencies = stream_structure2.templates[0]; deltaframe_descriptor.frame_number = 4; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_EQ(frame->id.picture_id & 0xFFFF, 4); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_EQ(frame->Id() & 0xFFFF, 4); }); InjectPacketWith(stream_structure2, deltaframe_descriptor); } -#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) -using RtpVideoStreamReceiver2DeathTest = RtpVideoStreamReceiver2Test; -TEST_F(RtpVideoStreamReceiver2DeathTest, RepeatedSecondarySinkDisallowed) { - MockRtpPacketSink secondary_sink; - - rtp_video_stream_receiver_->AddSecondarySink(&secondary_sink); - EXPECT_DEATH(rtp_video_stream_receiver_->AddSecondarySink(&secondary_sink), - ""); - - // Test tear-down. 
- rtp_video_stream_receiver_->RemoveSecondarySink(&secondary_sink); -} -#endif - TEST_F(RtpVideoStreamReceiver2Test, TransformFrame) { rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); + rtc::make_ref_counted>(); EXPECT_CALL(*mock_frame_transformer, RegisterTransformedFrameSinkCallback(_, config_.rtp.remote_ssrc)); auto receiver = std::make_unique( TaskQueueBase::Current(), Clock::GetRealTimeClock(), &mock_transport_, nullptr, nullptr, &config_, rtp_receive_statistics_.get(), nullptr, - nullptr, process_thread_.get(), &mock_nack_sender_, nullptr, - &mock_on_complete_frame_callback_, nullptr, mock_frame_transformer); + nullptr, &mock_nack_sender_, nullptr, &mock_on_complete_frame_callback_, + nullptr, mock_frame_transformer); VideoCodec video_codec; video_codec.codecType = kVideoCodecGeneric; receiver->AddReceiveCodec(kPayloadType, video_codec, {}, @@ -1272,8 +1204,8 @@ TEST_P(RtpVideoStreamReceiver2TestPlayoutDelay, PlayoutDelay) { // Expect the playout delay of encoded frame to be the same as the transmitted // playout delay unless it was overridden by a field trial. EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([expected_playout_delay = GetParam().expected_delay]( - video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([expected_playout_delay = + GetParam().expected_delay](EncodedFrame* frame) { EXPECT_EQ(frame->EncodedImage().playout_delay_, expected_playout_delay); })); rtp_video_stream_receiver_->OnReceivedPayloadData( diff --git a/video/rtp_video_stream_receiver_frame_transformer_delegate.cc b/video/rtp_video_stream_receiver_frame_transformer_delegate.cc index 31eb344d5b..f2f81df3ee 100644 --- a/video/rtp_video_stream_receiver_frame_transformer_delegate.cc +++ b/video/rtp_video_stream_receiver_frame_transformer_delegate.cc @@ -24,9 +24,8 @@ namespace { class TransformableVideoReceiverFrame : public TransformableVideoFrameInterface { public: - TransformableVideoReceiverFrame( - std::unique_ptr frame, - uint32_t ssrc) + TransformableVideoReceiverFrame(std::unique_ptr frame, + uint32_t ssrc) : frame_(std::move(frame)), metadata_(frame_->GetRtpVideoHeader()), ssrc_(ssrc) {} @@ -55,12 +54,12 @@ class TransformableVideoReceiverFrame const VideoFrameMetadata& GetMetadata() const override { return metadata_; } - std::unique_ptr ExtractFrame() && { + std::unique_ptr ExtractFrame() && { return std::move(frame_); } private: - std::unique_ptr frame_; + std::unique_ptr frame_; const VideoFrameMetadata metadata_; const uint32_t ssrc_; }; @@ -91,7 +90,7 @@ void RtpVideoStreamReceiverFrameTransformerDelegate::Reset() { } void RtpVideoStreamReceiverFrameTransformerDelegate::TransformFrame( - std::unique_ptr frame) { + std::unique_ptr frame) { RTC_DCHECK_RUN_ON(&network_sequence_checker_); frame_transformer_->Transform( std::make_unique(std::move(frame), diff --git a/video/rtp_video_stream_receiver_frame_transformer_delegate.h b/video/rtp_video_stream_receiver_frame_transformer_delegate.h index 2ae8e63bba..ef05d91fd3 100644 --- a/video/rtp_video_stream_receiver_frame_transformer_delegate.h +++ b/video/rtp_video_stream_receiver_frame_transformer_delegate.h @@ -14,8 +14,8 @@ #include #include "api/frame_transformer_interface.h" +#include "api/sequence_checker.h" #include "modules/video_coding/frame_object.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread.h" @@ -25,8 +25,7 @@ namespace webrtc { // thread after transformation. 
class RtpVideoFrameReceiver { public: - virtual void ManageFrame( - std::unique_ptr frame) = 0; + virtual void ManageFrame(std::unique_ptr frame) = 0; protected: virtual ~RtpVideoFrameReceiver() = default; @@ -47,7 +46,7 @@ class RtpVideoStreamReceiverFrameTransformerDelegate void Reset(); // Delegates the call to FrameTransformerInterface::TransformFrame. - void TransformFrame(std::unique_ptr frame); + void TransformFrame(std::unique_ptr frame); // Implements TransformedFrameCallback. Can be called on any thread. Posts // the transformed frame to be managed on the |network_thread_|. diff --git a/video/rtp_video_stream_receiver_frame_transformer_delegate_unittest.cc b/video/rtp_video_stream_receiver_frame_transformer_delegate_unittest.cc index a411ca6e9a..0d85cc08e2 100644 --- a/video/rtp_video_stream_receiver_frame_transformer_delegate_unittest.cc +++ b/video/rtp_video_stream_receiver_frame_transformer_delegate_unittest.cc @@ -35,15 +35,15 @@ using ::testing::ElementsAre; using ::testing::NiceMock; using ::testing::SaveArg; -std::unique_ptr CreateRtpFrameObject( +std::unique_ptr CreateRtpFrameObject( const RTPVideoHeader& video_header) { - return std::make_unique( + return std::make_unique( 0, 0, true, 0, 0, 0, 0, 0, VideoSendTiming(), 0, video_header.codec, kVideoRotation_0, VideoContentType::UNSPECIFIED, video_header, absl::nullopt, RtpPacketInfos(), EncodedImageBuffer::Create(0)); } -std::unique_ptr CreateRtpFrameObject() { +std::unique_ptr CreateRtpFrameObject() { return CreateRtpFrameObject(RTPVideoHeader()); } @@ -54,17 +54,16 @@ class TestRtpVideoFrameReceiver : public RtpVideoFrameReceiver { MOCK_METHOD(void, ManageFrame, - (std::unique_ptr frame), + (std::unique_ptr frame), (override)); }; TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, RegisterTransformedFrameCallbackSinkOnInit) { TestRtpVideoFrameReceiver receiver; - rtc::scoped_refptr frame_transformer( - new rtc::RefCountedObject()); - rtc::scoped_refptr delegate( - new rtc::RefCountedObject( + auto frame_transformer(rtc::make_ref_counted()); + auto delegate( + rtc::make_ref_counted( &receiver, frame_transformer, rtc::Thread::Current(), /*remote_ssrc*/ 1111)); EXPECT_CALL(*frame_transformer, @@ -75,10 +74,9 @@ TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, UnregisterTransformedFrameSinkCallbackOnReset) { TestRtpVideoFrameReceiver receiver; - rtc::scoped_refptr frame_transformer( - new rtc::RefCountedObject()); - rtc::scoped_refptr delegate( - new rtc::RefCountedObject( + auto frame_transformer(rtc::make_ref_counted()); + auto delegate( + rtc::make_ref_counted( &receiver, frame_transformer, rtc::Thread::Current(), /*remote_ssrc*/ 1111)); EXPECT_CALL(*frame_transformer, UnregisterTransformedFrameSinkCallback(1111)); @@ -87,10 +85,10 @@ TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, TransformFrame) { TestRtpVideoFrameReceiver receiver; - rtc::scoped_refptr frame_transformer( - new rtc::RefCountedObject>()); - rtc::scoped_refptr delegate( - new rtc::RefCountedObject( + auto frame_transformer( + rtc::make_ref_counted>()); + auto delegate( + rtc::make_ref_counted( &receiver, frame_transformer, rtc::Thread::Current(), /*remote_ssrc*/ 1111)); auto frame = CreateRtpFrameObject(); @@ -101,10 +99,10 @@ TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, TransformFrame) { TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, ManageFrameOnTransformedFrame) { TestRtpVideoFrameReceiver 
receiver; - rtc::scoped_refptr mock_frame_transformer( - new rtc::RefCountedObject>()); - rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + auto mock_frame_transformer( + rtc::make_ref_counted>()); + auto delegate = + rtc::make_ref_counted( &receiver, mock_frame_transformer, rtc::Thread::Current(), /*remote_ssrc*/ 1111); @@ -127,10 +125,10 @@ TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, TEST(RtpVideoStreamReceiverFrameTransformerDelegateTest, TransformableFrameMetadataHasCorrectValue) { TestRtpVideoFrameReceiver receiver; - rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); - rtc::scoped_refptr delegate = - new rtc::RefCountedObject( + auto mock_frame_transformer = + rtc::make_ref_counted>(); + auto delegate = + rtc::make_ref_counted( &receiver, mock_frame_transformer, rtc::Thread::Current(), 1111); delegate->Init(); RTPVideoHeader video_header; diff --git a/video/rtp_video_stream_receiver_unittest.cc b/video/rtp_video_stream_receiver_unittest.cc index 2f24dcfcb1..765e1e1716 100644 --- a/video/rtp_video_stream_receiver_unittest.cc +++ b/video/rtp_video_stream_receiver_unittest.cc @@ -50,8 +50,7 @@ namespace { const uint8_t kH264StartCode[] = {0x00, 0x00, 0x00, 0x01}; -std::vector GetAbsoluteCaptureTimestamps( - const video_coding::EncodedFrame* frame) { +std::vector GetAbsoluteCaptureTimestamps(const EncodedFrame* frame) { std::vector result; for (const auto& packet_info : frame->PacketInfos()) { if (packet_info.absolute_capture_time()) { @@ -95,23 +94,13 @@ class MockKeyFrameRequestSender : public KeyFrameRequestSender { }; class MockOnCompleteFrameCallback - : public video_coding::OnCompleteFrameCallback { + : public RtpVideoStreamReceiver::OnCompleteFrameCallback { public: - MOCK_METHOD(void, DoOnCompleteFrame, (video_coding::EncodedFrame*), ()); - MOCK_METHOD(void, - DoOnCompleteFrameFailNullptr, - (video_coding::EncodedFrame*), - ()); - MOCK_METHOD(void, - DoOnCompleteFrameFailLength, - (video_coding::EncodedFrame*), - ()); - MOCK_METHOD(void, - DoOnCompleteFrameFailBitstream, - (video_coding::EncodedFrame*), - ()); - void OnCompleteFrame( - std::unique_ptr frame) override { + MOCK_METHOD(void, DoOnCompleteFrame, (EncodedFrame*), ()); + MOCK_METHOD(void, DoOnCompleteFrameFailNullptr, (EncodedFrame*), ()); + MOCK_METHOD(void, DoOnCompleteFrameFailLength, (EncodedFrame*), ()); + MOCK_METHOD(void, DoOnCompleteFrameFailBitstream, (EncodedFrame*), ()); + void OnCompleteFrame(std::unique_ptr frame) override { if (!frame) { DoOnCompleteFrameFailNullptr(nullptr); return; @@ -339,7 +328,7 @@ TEST_F(RtpVideoStreamReceiverTest, CacheColorSpaceFromLastPacketOfKeyframe) { EXPECT_TRUE(key_frame_packet2.GetExtension()); rtp_video_stream_receiver_->OnRtpPacket(key_frame_packet1); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([kColorSpace](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kColorSpace](EncodedFrame* frame) { ASSERT_TRUE(frame->EncodedImage().ColorSpace()); EXPECT_EQ(*frame->EncodedImage().ColorSpace(), kColorSpace); })); @@ -355,7 +344,7 @@ TEST_F(RtpVideoStreamReceiverTest, CacheColorSpaceFromLastPacketOfKeyframe) { // included in the RTP packet. 
EXPECT_FALSE(delta_frame_packet.GetExtension()); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([kColorSpace](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kColorSpace](EncodedFrame* frame) { ASSERT_TRUE(frame->EncodedImage().ColorSpace()); EXPECT_EQ(*frame->EncodedImage().ColorSpace(), kColorSpace); })); @@ -397,11 +386,10 @@ TEST_F(RtpVideoStreamReceiverTest, PacketInfoIsPropagatedIntoVideoFrames) { mock_on_complete_frame_callback_.AppendExpectedBitstream(data.data(), data.size()); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke( - [kAbsoluteCaptureTimestamp](video_coding::EncodedFrame* frame) { - EXPECT_THAT(GetAbsoluteCaptureTimestamps(frame), - ElementsAre(kAbsoluteCaptureTimestamp)); - })); + .WillOnce(Invoke([kAbsoluteCaptureTimestamp](EncodedFrame* frame) { + EXPECT_THAT(GetAbsoluteCaptureTimestamps(frame), + ElementsAre(kAbsoluteCaptureTimestamp)); + })); rtp_video_stream_receiver_->OnReceivedPayloadData(data, rtp_packet, video_header); } @@ -445,7 +433,7 @@ TEST_F(RtpVideoStreamReceiverTest, // Expect rtp video stream receiver to extrapolate it for the resulting video // frame using absolute capture time from the previous packet. EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([](EncodedFrame* frame) { EXPECT_THAT(GetAbsoluteCaptureTimestamps(frame), SizeIs(1)); })); rtp_video_stream_receiver_->OnReceivedPayloadData(data, rtp_packet, @@ -657,9 +645,8 @@ TEST_P(RtpVideoStreamReceiverTestH264, ForceSpsPpsIdrIsKeyframe) { mock_on_complete_frame_callback_.AppendExpectedBitstream(idr_data.data(), idr_data.size()); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_TRUE(frame->is_keyframe()); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_TRUE(frame->is_keyframe()); }); rtp_video_stream_receiver_->OnReceivedPayloadData(idr_data, rtp_packet, idr_video_header); mock_on_complete_frame_callback_.ClearExpectedBitstream(); @@ -669,9 +656,8 @@ TEST_P(RtpVideoStreamReceiverTestH264, ForceSpsPpsIdrIsKeyframe) { idr_data.size()); rtp_packet.SetSequenceNumber(3); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_FALSE(frame->is_keyframe()); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_FALSE(frame->is_keyframe()); }); rtp_video_stream_receiver_->OnReceivedPayloadData(idr_data, rtp_packet, idr_video_header); } @@ -859,10 +845,10 @@ TEST_F(RtpVideoStreamReceiverTest, ParseGenericDescriptorOnePacket) { rtp_packet.SetSequenceNumber(1); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce(Invoke([kSpatialIndex](video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([kSpatialIndex](EncodedFrame* frame) { EXPECT_EQ(frame->num_references, 2U); - EXPECT_EQ(frame->references[0], frame->id.picture_id - 90); - EXPECT_EQ(frame->references[1], frame->id.picture_id - 80); + EXPECT_EQ(frame->references[0], frame->Id() - 90); + EXPECT_EQ(frame->references[1], frame->Id() - 80); EXPECT_EQ(frame->SpatialIndex(), kSpatialIndex); EXPECT_THAT(frame->PacketInfos(), SizeIs(1)); })); @@ -917,7 +903,7 @@ TEST_F(RtpVideoStreamReceiverTest, ParseGenericDescriptorTwoPackets) { data.size() - 1); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce(Invoke([kSpatialIndex](video_coding::EncodedFrame* frame) { + 
.WillOnce(Invoke([kSpatialIndex](EncodedFrame* frame) { EXPECT_EQ(frame->num_references, 0U); EXPECT_EQ(frame->SpatialIndex(), kSpatialIndex); EXPECT_EQ(frame->EncodedImage()._encodedWidth, 480u); @@ -996,14 +982,12 @@ TEST_F(RtpVideoStreamReceiverTest, UnwrapsFrameId) { int64_t first_picture_id; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - first_picture_id = frame->id.picture_id; - }); + .WillOnce([&](EncodedFrame* frame) { first_picture_id = frame->Id(); }); inject_packet(/*wrapped_frame_id=*/0xffff); EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_EQ(frame->id.picture_id - first_picture_id, 3); + .WillOnce([&](EncodedFrame* frame) { + EXPECT_EQ(frame->Id() - first_picture_id, 3); }); inject_packet(/*wrapped_frame_id=*/0x0002); } @@ -1067,9 +1051,7 @@ TEST_F(RtpVideoStreamReceiverDependencyDescriptorTest, UnwrapsFrameId) { // keyframe. Thus feed a key frame first, then test reodered delta frames. int64_t first_picture_id; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - first_picture_id = frame->id.picture_id; - }); + .WillOnce([&](EncodedFrame* frame) { first_picture_id = frame->Id(); }); InjectPacketWith(stream_structure, keyframe_descriptor); DependencyDescriptor deltaframe1_descriptor; @@ -1083,13 +1065,13 @@ TEST_F(RtpVideoStreamReceiverDependencyDescriptorTest, UnwrapsFrameId) { // Parser should unwrap frame ids correctly even if packets were reordered by // the network. EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { + .WillOnce([&](EncodedFrame* frame) { // 0x0002 - 0xfff0 - EXPECT_EQ(frame->id.picture_id - first_picture_id, 18); + EXPECT_EQ(frame->Id() - first_picture_id, 18); }) - .WillOnce([&](video_coding::EncodedFrame* frame) { + .WillOnce([&](EncodedFrame* frame) { // 0xfffe - 0xfff0 - EXPECT_EQ(frame->id.picture_id - first_picture_id, 14); + EXPECT_EQ(frame->Id() - first_picture_id, 14); }); InjectPacketWith(stream_structure, deltaframe2_descriptor); InjectPacketWith(stream_structure, deltaframe1_descriptor); @@ -1153,9 +1135,8 @@ TEST_F(RtpVideoStreamReceiverDependencyDescriptorTest, keyframe2_descriptor.frame_number = 3; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_EQ(frame->id.picture_id & 0xFFFF, 3); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_EQ(frame->Id() & 0xFFFF, 3); }); InjectPacketWith(stream_structure2, keyframe2_descriptor); InjectPacketWith(stream_structure1, keyframe1_descriptor); @@ -1165,9 +1146,8 @@ TEST_F(RtpVideoStreamReceiverDependencyDescriptorTest, deltaframe_descriptor.frame_dependencies = stream_structure2.templates[0]; deltaframe_descriptor.frame_number = 4; EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame) - .WillOnce([&](video_coding::EncodedFrame* frame) { - EXPECT_EQ(frame->id.picture_id & 0xFFFF, 4); - }); + .WillOnce( + [&](EncodedFrame* frame) { EXPECT_EQ(frame->Id() & 0xFFFF, 4); }); InjectPacketWith(stream_structure2, deltaframe_descriptor); } @@ -1186,8 +1166,8 @@ TEST_F(RtpVideoStreamReceiverDeathTest, RepeatedSecondarySinkDisallowed) { #endif TEST_F(RtpVideoStreamReceiverTest, TransformFrame) { - rtc::scoped_refptr mock_frame_transformer = - new rtc::RefCountedObject>(); + auto mock_frame_transformer = + rtc::make_ref_counted>(); 
EXPECT_CALL(*mock_frame_transformer, RegisterTransformedFrameSinkCallback(_, config_.rtp.remote_ssrc)); auto receiver = std::make_unique( @@ -1265,8 +1245,8 @@ TEST_P(RtpVideoStreamReceiverTestPlayoutDelay, PlayoutDelay) { // Expect the playout delay of encoded frame to be the same as the transmitted // playout delay unless it was overridden by a field trial. EXPECT_CALL(mock_on_complete_frame_callback_, DoOnCompleteFrame(_)) - .WillOnce(Invoke([expected_playout_delay = GetParam().expected_delay]( - video_coding::EncodedFrame* frame) { + .WillOnce(Invoke([expected_playout_delay = + GetParam().expected_delay](EncodedFrame* frame) { EXPECT_EQ(frame->EncodedImage().playout_delay_, expected_playout_delay); })); rtp_video_stream_receiver_->OnReceivedPayloadData( diff --git a/video/send_delay_stats.h b/video/send_delay_stats.h index 20f9804d64..fa76a1e39c 100644 --- a/video/send_delay_stats.h +++ b/video/send_delay_stats.h @@ -27,6 +27,12 @@ namespace webrtc { +// Used to collect delay stats for video streams. The class gets callbacks +// from more than one threads and internally uses a mutex for data access +// synchronization. +// TODO(bugs.webrtc.org/11993): OnSendPacket and OnSentPacket will eventually +// be called consistently on the same thread. Once we're there, we should be +// able to avoid locking (at least for the fast path). class SendDelayStats : public SendPacketObserver { public: explicit SendDelayStats(Clock* clock); diff --git a/video/send_statistics_proxy.cc b/video/send_statistics_proxy.cc index 3b3f69d4e2..1b968ef8f7 100644 --- a/video/send_statistics_proxy.cc +++ b/video/send_statistics_proxy.cc @@ -670,6 +670,7 @@ void SendStatisticsProxy::UmaSamplesContainer::UpdateHistograms( void SendStatisticsProxy::OnEncoderReconfigured( const VideoEncoderConfig& config, const std::vector& streams) { + // Called on VideoStreamEncoder's encoder_queue_. MutexLock lock(&mutex_); if (content_type_ != config.content_type) { @@ -737,6 +738,8 @@ VideoSendStream::Stats SendStatisticsProxy::GetStats() { PurgeOldStats(); stats_.input_frame_rate = round(uma_container_->input_frame_rate_tracker_.ComputeRate()); + stats_.frames = + uma_container_->input_frame_rate_tracker_.TotalSampleCount(); stats_.content_type = content_type_ == VideoEncoderConfig::ContentType::kRealtimeVideo ? 
VideoContentType::UNSPECIFIED @@ -1282,17 +1285,6 @@ void SendStatisticsProxy::RtcpPacketTypesCounterUpdated( uma_container_->first_rtcp_stats_time_ms_ = clock_->TimeInMilliseconds(); } -void SendStatisticsProxy::StatisticsUpdated(const RtcpStatistics& statistics, - uint32_t ssrc) { - MutexLock lock(&mutex_); - VideoSendStream::StreamStats* stats = GetStatsEntry(ssrc); - if (!stats) - return; - - stats->rtcp_stats = statistics; - uma_container_->report_block_stats_.Store(ssrc, statistics); -} - void SendStatisticsProxy::OnReportBlockDataUpdated( ReportBlockData report_block_data) { MutexLock lock(&mutex_); @@ -1300,6 +1292,13 @@ void SendStatisticsProxy::OnReportBlockDataUpdated( GetStatsEntry(report_block_data.report_block().source_ssrc); if (!stats) return; + const RTCPReportBlock& report_block = report_block_data.report_block(); + uma_container_->report_block_stats_.Store( + /*ssrc=*/report_block.source_ssrc, + /*packets_lost=*/report_block.packets_lost, + /*extended_highest_sequence_number=*/ + report_block.extended_highest_sequence_number); + stats->report_block_data = std::move(report_block_data); } diff --git a/video/send_statistics_proxy.h b/video/send_statistics_proxy.h index 0de7df290e..bfb221f65c 100644 --- a/video/send_statistics_proxy.h +++ b/video/send_statistics_proxy.h @@ -37,7 +37,6 @@ namespace webrtc { class SendStatisticsProxy : public VideoStreamEncoderObserver, - public RtcpStatisticsCallback, public ReportBlockDataObserver, public RtcpPacketTypeCounterObserver, public StreamDataCountersCallback, @@ -106,9 +105,6 @@ class SendStatisticsProxy : public VideoStreamEncoderObserver, int GetSendFrameRate() const; protected: - // From RtcpStatisticsCallback. - void StatisticsUpdated(const RtcpStatistics& statistics, - uint32_t ssrc) override; // From ReportBlockDataObserver. void OnReportBlockDataUpdated(ReportBlockData report_block_data) override; // From RtcpPacketTypeCounterObserver. 
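For reference, a minimal sketch of the replacement reporting path: RTCP receiver-report data now reaches SendStatisticsProxy through ReportBlockDataObserver::OnReportBlockDataUpdated() rather than the removed RtcpStatisticsCallback::StatisticsUpdated(). The snippet mirrors the updated unit test below; the field values are arbitrary and the proxy construction is elided.

    RTCPReportBlock report_block;
    report_block.source_ssrc = ssrc;  // SSRC of the media substream being reported on.
    report_block.packets_lost = 10;
    report_block.extended_highest_sequence_number = 1234;
    report_block.fraction_lost = 5;
    report_block.jitter = 3;

    ReportBlockData data;
    data.SetReportBlock(report_block, 0);  // Second argument: report timestamp, 0 as in the test below.

    // SendStatisticsProxy implements ReportBlockDataObserver.
    ReportBlockDataObserver* observer = statistics_proxy.get();
    observer->OnReportBlockDataUpdated(data);
    // The data later surfaces via GetStats() as
    // VideoSendStream::Stats::substreams[ssrc].report_block_data.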
diff --git a/video/send_statistics_proxy_unittest.cc b/video/send_statistics_proxy_unittest.cc index 33107d4c2f..d4a7a49e39 100644 --- a/video/send_statistics_proxy_unittest.cc +++ b/video/send_statistics_proxy_unittest.cc @@ -64,9 +64,7 @@ class SendStatisticsProxyTest : public ::testing::Test { explicit SendStatisticsProxyTest(const std::string& field_trials) : override_field_trials_(field_trials), fake_clock_(1234), - config_(GetTestConfig()), - avg_delay_ms_(0), - max_delay_ms_(0) {} + config_(GetTestConfig()) {} virtual ~SendStatisticsProxyTest() {} protected: @@ -126,6 +124,7 @@ class SendStatisticsProxyTest : public ::testing::Test { } void ExpectEqual(VideoSendStream::Stats one, VideoSendStream::Stats other) { + EXPECT_EQ(one.frames, other.frames); EXPECT_EQ(one.input_frame_rate, other.input_frame_rate); EXPECT_EQ(one.encode_frame_rate, other.encode_frame_rate); EXPECT_EQ(one.media_bitrate_bps, other.media_bitrate_bps); @@ -160,11 +159,19 @@ class SendStatisticsProxyTest : public ::testing::Test { b.rtp_stats.retransmitted.packets); EXPECT_EQ(a.rtp_stats.fec.packets, b.rtp_stats.fec.packets); - EXPECT_EQ(a.rtcp_stats.fraction_lost, b.rtcp_stats.fraction_lost); - EXPECT_EQ(a.rtcp_stats.packets_lost, b.rtcp_stats.packets_lost); - EXPECT_EQ(a.rtcp_stats.extended_highest_sequence_number, - b.rtcp_stats.extended_highest_sequence_number); - EXPECT_EQ(a.rtcp_stats.jitter, b.rtcp_stats.jitter); + EXPECT_EQ(a.report_block_data.has_value(), + b.report_block_data.has_value()); + if (a.report_block_data.has_value()) { + const RTCPReportBlock& a_rtcp_stats = + a.report_block_data->report_block(); + const RTCPReportBlock& b_rtcp_stats = + b.report_block_data->report_block(); + EXPECT_EQ(a_rtcp_stats.fraction_lost, b_rtcp_stats.fraction_lost); + EXPECT_EQ(a_rtcp_stats.packets_lost, b_rtcp_stats.packets_lost); + EXPECT_EQ(a_rtcp_stats.extended_highest_sequence_number, + b_rtcp_stats.extended_highest_sequence_number); + EXPECT_EQ(a_rtcp_stats.jitter, b_rtcp_stats.jitter); + } } } @@ -172,36 +179,40 @@ class SendStatisticsProxyTest : public ::testing::Test { SimulatedClock fake_clock_; std::unique_ptr statistics_proxy_; VideoSendStream::Config config_; - int avg_delay_ms_; - int max_delay_ms_; VideoSendStream::Stats expected_; - typedef std::map::const_iterator - StreamIterator; }; -TEST_F(SendStatisticsProxyTest, RtcpStatistics) { - RtcpStatisticsCallback* callback = statistics_proxy_.get(); - for (const auto& ssrc : config_.rtp.ssrcs) { - VideoSendStream::StreamStats& ssrc_stats = expected_.substreams[ssrc]; - +TEST_F(SendStatisticsProxyTest, ReportBlockDataObserver) { + ReportBlockDataObserver* callback = statistics_proxy_.get(); + for (uint32_t ssrc : config_.rtp.ssrcs) { // Add statistics with some arbitrary, but unique, numbers. 
- uint32_t offset = ssrc * sizeof(RtcpStatistics); - ssrc_stats.rtcp_stats.packets_lost = offset; - ssrc_stats.rtcp_stats.extended_highest_sequence_number = offset + 1; - ssrc_stats.rtcp_stats.fraction_lost = offset + 2; - ssrc_stats.rtcp_stats.jitter = offset + 3; - callback->StatisticsUpdated(ssrc_stats.rtcp_stats, ssrc); + uint32_t offset = ssrc * 4; + RTCPReportBlock report_block; + report_block.source_ssrc = ssrc; + report_block.packets_lost = offset; + report_block.extended_highest_sequence_number = offset + 1; + report_block.fraction_lost = offset + 2; + report_block.jitter = offset + 3; + ReportBlockData data; + data.SetReportBlock(report_block, 0); + expected_.substreams[ssrc].report_block_data = data; + + callback->OnReportBlockDataUpdated(data); } - for (const auto& ssrc : config_.rtp.rtx.ssrcs) { - VideoSendStream::StreamStats& ssrc_stats = expected_.substreams[ssrc]; - + for (uint32_t ssrc : config_.rtp.rtx.ssrcs) { // Add statistics with some arbitrary, but unique, numbers. - uint32_t offset = ssrc * sizeof(RtcpStatistics); - ssrc_stats.rtcp_stats.packets_lost = offset; - ssrc_stats.rtcp_stats.extended_highest_sequence_number = offset + 1; - ssrc_stats.rtcp_stats.fraction_lost = offset + 2; - ssrc_stats.rtcp_stats.jitter = offset + 3; - callback->StatisticsUpdated(ssrc_stats.rtcp_stats, ssrc); + uint32_t offset = ssrc * 4; + RTCPReportBlock report_block; + report_block.source_ssrc = ssrc; + report_block.packets_lost = offset; + report_block.extended_highest_sequence_number = offset + 1; + report_block.fraction_lost = offset + 2; + report_block.jitter = offset + 3; + ReportBlockData data; + data.SetReportBlock(report_block, 0); + expected_.substreams[ssrc].report_block_data = data; + + callback->OnReportBlockDataUpdated(data); } VideoSendStream::Stats stats = statistics_proxy_->GetStats(); ExpectEqual(expected_, stats); @@ -283,21 +294,17 @@ TEST_F(SendStatisticsProxyTest, DataCounters) { TEST_F(SendStatisticsProxyTest, Bitrate) { BitrateStatisticsObserver* observer = statistics_proxy_.get(); for (const auto& ssrc : config_.rtp.ssrcs) { - uint32_t total; - uint32_t retransmit; // Use ssrc as bitrate_bps to get a unique value for each stream. - total = ssrc; - retransmit = ssrc + 1; + uint32_t total = ssrc; + uint32_t retransmit = ssrc + 1; observer->Notify(total, retransmit, ssrc); expected_.substreams[ssrc].total_bitrate_bps = total; expected_.substreams[ssrc].retransmit_bitrate_bps = retransmit; } for (const auto& ssrc : config_.rtp.rtx.ssrcs) { - uint32_t total; - uint32_t retransmit; // Use ssrc as bitrate_bps to get a unique value for each stream. - total = ssrc; - retransmit = ssrc + 1; + uint32_t total = ssrc; + uint32_t retransmit = ssrc + 1; observer->Notify(total, retransmit, ssrc); expected_.substreams[ssrc].total_bitrate_bps = total; expected_.substreams[ssrc].retransmit_bitrate_bps = retransmit; @@ -2180,10 +2187,13 @@ TEST_F(SendStatisticsProxyTest, NoSubstreams) { std::max(*absl::c_max_element(config_.rtp.ssrcs), *absl::c_max_element(config_.rtp.rtx.ssrcs)) + 1; - // From RtcpStatisticsCallback. - RtcpStatistics rtcp_stats; - RtcpStatisticsCallback* rtcp_callback = statistics_proxy_.get(); - rtcp_callback->StatisticsUpdated(rtcp_stats, excluded_ssrc); + // From ReportBlockDataObserver. 
+ ReportBlockDataObserver* rtcp_callback = statistics_proxy_.get(); + RTCPReportBlock report_block; + report_block.source_ssrc = excluded_ssrc; + ReportBlockData data; + data.SetReportBlock(report_block, 0); + rtcp_callback->OnReportBlockDataUpdated(data); // From BitrateStatisticsObserver. uint32_t total = 0; @@ -2230,9 +2240,12 @@ TEST_F(SendStatisticsProxyTest, EncodedResolutionTimesOut) { // Update the first SSRC with bogus RTCP stats to make sure that encoded // resolution still times out (no global timeout for all stats). - RtcpStatistics rtcp_statistics; - RtcpStatisticsCallback* rtcp_stats = statistics_proxy_.get(); - rtcp_stats->StatisticsUpdated(rtcp_statistics, config_.rtp.ssrcs[0]); + ReportBlockDataObserver* rtcp_callback = statistics_proxy_.get(); + RTCPReportBlock report_block; + report_block.source_ssrc = config_.rtp.ssrcs[0]; + ReportBlockData data; + data.SetReportBlock(report_block, 0); + rtcp_callback->OnReportBlockDataUpdated(data); // Report stats for second SSRC to make sure it's not outdated along with the // first SSRC. diff --git a/video/video_analyzer.cc b/video/video_analyzer.cc index c16c3b383b..b90ba2973a 100644 --- a/video/video_analyzer.cc +++ b/video/video_analyzer.cc @@ -18,11 +18,13 @@ #include "common_video/libyuv/include/webrtc_libyuv.h" #include "modules/rtp_rtcp/source/create_video_rtp_depacketizer.h" #include "modules/rtp_rtcp/source/rtp_packet.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "rtc_base/cpu_time.h" #include "rtc_base/format_macros.h" #include "rtc_base/memory_usage.h" #include "rtc_base/task_queue_for_test.h" #include "rtc_base/task_utils/repeating_task.h" +#include "rtc_base/time_utils.h" #include "system_wrappers/include/cpu_info.h" #include "test/call_test.h" #include "test/testsupport/file_utils.h" @@ -136,10 +138,12 @@ VideoAnalyzer::VideoAnalyzer(test::LayerFilteringTransport* transport, } for (uint32_t i = 0; i < num_cores; ++i) { - rtc::PlatformThread* thread = - new rtc::PlatformThread(&FrameComparisonThread, this, "Analyzer"); - thread->Start(); - comparison_thread_pool_.push_back(thread); + comparison_thread_pool_.push_back(rtc::PlatformThread::SpawnJoinable( + [this] { + while (CompareFrames()) { + } + }, + "Analyzer")); } if (!rtp_dump_name.empty()) { @@ -154,10 +158,8 @@ VideoAnalyzer::~VideoAnalyzer() { MutexLock lock(&comparison_lock_); quit_ = true; } - for (rtc::PlatformThread* thread : comparison_thread_pool_) { - thread->Stop(); - delete thread; - } + // Joins all threads. + comparison_thread_pool_.clear(); } void VideoAnalyzer::SetReceiver(PacketReceiver* receiver) { @@ -211,7 +213,7 @@ PacketReceiver::DeliveryStatus VideoAnalyzer::DeliverPacket( int64_t packet_time_us) { // Ignore timestamps of RTCP packets. They're not synchronized with // RTP packet timestamps and so they would confuse wrap_handler_. 
- if (RtpHeaderParser::IsRtcp(packet.cdata(), packet.size())) { + if (IsRtcpPacket(packet)) { return receiver_->DeliverPacket(media_type, std::move(packet), packet_time_us); } @@ -532,12 +534,6 @@ void VideoAnalyzer::PollStats() { memory_usage_.AddSample(rtc::GetProcessResidentSizeBytes()); } -void VideoAnalyzer::FrameComparisonThread(void* obj) { - VideoAnalyzer* analyzer = static_cast(obj); - while (analyzer->CompareFrames()) { - } -} - bool VideoAnalyzer::CompareFrames() { if (AllFramesRecorded()) return false; @@ -605,7 +601,7 @@ bool VideoAnalyzer::AllFramesRecordedLocked() { bool VideoAnalyzer::FrameProcessed() { MutexLock lock(&comparison_lock_); ++frames_processed_; - assert(frames_processed_ <= frames_to_process_); + RTC_DCHECK_LE(frames_processed_, frames_to_process_); return frames_processed_ == frames_to_process_ || (clock_->CurrentTime() > test_end_ && comparisons_.empty()); } diff --git a/video/video_analyzer.h b/video/video_analyzer.h index 18bacc16fc..68861d1b5f 100644 --- a/video/video_analyzer.h +++ b/video/video_analyzer.h @@ -302,7 +302,7 @@ class VideoAnalyzer : public PacketReceiver, const double avg_ssim_threshold_; bool is_quick_test_enabled_; - std::vector comparison_thread_pool_; + std::vector comparison_thread_pool_; rtc::Event comparison_available_event_; std::deque comparisons_ RTC_GUARDED_BY(comparison_lock_); bool quit_ RTC_GUARDED_BY(comparison_lock_); diff --git a/video/video_quality_test.cc b/video/video_quality_test.cc index a58aa1f33f..b77a4759a2 100644 --- a/video/video_quality_test.cc +++ b/video/video_quality_test.cc @@ -626,7 +626,7 @@ void VideoQualityTest::FillScalabilitySettings( encoder_config.spatial_layers = params->ss[video_idx].spatial_layers; encoder_config.simulcast_layers = std::vector(num_streams); encoder_config.video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( params->video[video_idx].codec, kDefaultMaxQp, params->screenshare[video_idx].enabled, true); params->ss[video_idx].streams = @@ -800,7 +800,7 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, params_.ss[video_idx].streams; } video_encoder_configs_[video_idx].video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( params_.video[video_idx].codec, params_.ss[video_idx].streams[0].max_qp, params_.screenshare[video_idx].enabled, true); @@ -829,7 +829,7 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, vp8_settings.numberOfTemporalLayers = static_cast( params_.video[video_idx].num_temporal_layers); video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8_settings); } else if (params_.video[video_idx].codec == "VP9") { VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); @@ -846,7 +846,7 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, vp9_settings.flexibleMode = true; } video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); } } else if (params_.ss[video_idx].num_spatial_layers > 1) { @@ -860,8 +860,8 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, vp9_settings.interLayerPred = params_.ss[video_idx].inter_layer_pred; vp9_settings.automaticResizeOn = false; video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); + rtc::make_ref_counted( + 
vp9_settings); RTC_DCHECK_EQ(video_encoder_configs_[video_idx].simulcast_layers.size(), 1); // Min bitrate will be enforced by spatial layer config instead. @@ -871,7 +871,7 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, VideoCodecVP8 vp8_settings = VideoEncoder::GetDefaultVp8Settings(); vp8_settings.automaticResizeOn = true; video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8_settings); } else if (params_.video[video_idx].codec == "VP9") { VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); @@ -879,7 +879,7 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, vp9_settings.automaticResizeOn = params_.ss[video_idx].num_spatial_layers == 1; video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); } else if (params_.video[video_idx].codec == "H264") { // Quality scaling is always on for H.264. @@ -898,18 +898,18 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, VideoCodecVP8 vp8_settings = VideoEncoder::GetDefaultVp8Settings(); vp8_settings.automaticResizeOn = false; video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::Vp8EncoderSpecificSettings>(vp8_settings); } else if (params_.video[video_idx].codec == "VP9") { VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); vp9_settings.automaticResizeOn = false; video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); } else if (params_.video[video_idx].codec == "H264") { VideoCodecH264 h264_settings = VideoEncoder::GetDefaultH264Settings(); video_encoder_configs_[video_idx].encoder_specific_settings = - new rtc::RefCountedObject< + rtc::make_ref_counted< VideoEncoderConfig::H264EncoderSpecificSettings>(h264_settings); } } @@ -925,13 +925,13 @@ void VideoQualityTest::SetupVideo(Transport* send_transport, } CreateMatchingFecConfig(recv_transport, *GetVideoSendConfig()); - GetFlexFecConfig()->transport_cc = params_.call.send_side_bwe; + GetFlexFecConfig()->rtp.transport_cc = params_.call.send_side_bwe; if (params_.call.send_side_bwe) { - GetFlexFecConfig()->rtp_header_extensions.push_back( + GetFlexFecConfig()->rtp.extensions.push_back( RtpExtension(RtpExtension::kTransportSequenceNumberUri, kTransportSequenceNumberExtensionId)); } else { - GetFlexFecConfig()->rtp_header_extensions.push_back( + GetFlexFecConfig()->rtp.extensions.push_back( RtpExtension(RtpExtension::kAbsSendTimeUri, kAbsSendTimeExtensionId)); } } @@ -986,7 +986,7 @@ void VideoQualityTest::SetupThumbnails(Transport* send_transport, thumbnail_encoder_config.max_bitrate_bps = 50000; std::vector streams{params_.ss[0].streams[0]}; thumbnail_encoder_config.video_stream_factory = - new rtc::RefCountedObject(streams); + rtc::make_ref_counted(streams); thumbnail_encoder_config.spatial_layers = params_.ss[0].spatial_layers; thumbnail_encoder_configs_.push_back(thumbnail_encoder_config.Copy()); diff --git a/video/video_receive_stream.cc b/video/video_receive_stream.cc index c46868d749..da8eb7de60 100644 --- a/video/video_receive_stream.cc +++ b/video/video_receive_stream.cc @@ -24,6 +24,7 @@ #include "api/array_view.h" #include "api/crypto/frame_decryptor_interface.h" #include 
"api/video/encoded_image.h" +#include "api/video_codecs/h264_profile_level_id.h" #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_decoder_factory.h" @@ -31,7 +32,6 @@ #include "call/rtp_stream_receiver_controller_interface.h" #include "call/rtx_receive_stream.h" #include "common_video/include/incoming_video_stream.h" -#include "media/base/h264_profile_level_id.h" #include "modules/utility/include/process_thread.h" #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/include/video_coding_defines.h" @@ -39,7 +39,6 @@ #include "modules/video_coding/timing.h" #include "modules/video_coding/utility/vp8_header_parser.h" #include "rtc_base/checks.h" -#include "rtc_base/experiments/keyframe_interval_settings.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" @@ -60,7 +59,6 @@ constexpr int VideoReceiveStream::kMaxWaitForKeyFrameMs; namespace { -using video_coding::EncodedFrame; using ReturnReason = video_coding::FrameBuffer::ReturnReason; constexpr int kMinBaseMinimumDelayMs = 0; @@ -69,7 +67,7 @@ constexpr int kMaxBaseMinimumDelayMs = 10000; constexpr int kMaxWaitForFrameMs = 3000; // Concrete instance of RecordableEncodedFrame wrapping needed content -// from video_coding::EncodedFrame. +// from EncodedFrame. class WebRtcRecordableEncodedFrame : public RecordableEncodedFrame { public: explicit WebRtcRecordableEncodedFrame(const EncodedFrame& frame) @@ -221,12 +219,8 @@ VideoReceiveStream::VideoReceiveStream( config_.frame_decryptor, config_.frame_transformer), rtp_stream_sync_(this), - max_wait_for_keyframe_ms_(KeyframeIntervalSettings::ParseFromFieldTrials() - .MaxWaitForKeyframeMs() - .value_or(kMaxWaitForKeyFrameMs)), - max_wait_for_frame_ms_(KeyframeIntervalSettings::ParseFromFieldTrials() - .MaxWaitForFrameMs() - .value_or(kMaxWaitForFrameMs)), + max_wait_for_keyframe_ms_(kMaxWaitForKeyFrameMs), + max_wait_for_frame_ms_(kMaxWaitForFrameMs), decode_queue_(task_queue_factory_->CreateTaskQueue( "DecodingQueue", TaskQueueFactory::Priority::HIGH)) { @@ -338,8 +332,7 @@ void VideoReceiveStream::Start() { for (const Decoder& decoder : config_.decoders) { std::unique_ptr video_decoder = - config_.decoder_factory->LegacyCreateVideoDecoder(decoder.video_format, - config_.stream_id); + config_.decoder_factory->CreateVideoDecoder(decoder.video_format); // If we still have no valid decoder, we have to create a "Null" decoder // that ignores all calls. The reason we can get into this state is that the // old decoder factory interface doesn't have a way to query supported @@ -513,6 +506,10 @@ void VideoReceiveStream::OnFrame(const VideoFrame& video_frame) { int64_t video_playout_ntp_ms; int64_t sync_offset_ms; double estimated_freq_khz; + + // TODO(bugs.webrtc.org/10739): we should set local capture clock offset for + // |video_frame.packet_infos|. But VideoFrame is const qualified here. + // TODO(tommi): GetStreamSyncOffsetInMs grabs three locks. One inside the // function itself, another in GetChannel() and a third in // GetPlayoutTimestamp. Seems excessive. 
Anyhow, I'm assuming the function @@ -554,8 +551,7 @@ void VideoReceiveStream::RequestKeyFrame(int64_t timestamp_ms) { last_keyframe_request_ms_ = timestamp_ms; } -void VideoReceiveStream::OnCompleteFrame( - std::unique_ptr frame) { +void VideoReceiveStream::OnCompleteFrame(std::unique_ptr frame) { RTC_DCHECK_RUN_ON(&network_sequence_checker_); // TODO(https://bugs.webrtc.org/9974): Consider removing this workaround. int64_t time_now_ms = clock_->TimeInMilliseconds(); @@ -670,7 +666,7 @@ void VideoReceiveStream::HandleEncodedFrame( decode_result == WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME) { keyframe_required_ = false; frame_decoded_ = true; - rtp_video_stream_receiver_.FrameDecoded(frame->id.picture_id); + rtp_video_stream_receiver_.FrameDecoded(frame->Id()); if (decode_result == WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME) RequestKeyFrame(now_ms); @@ -683,7 +679,6 @@ void VideoReceiveStream::HandleEncodedFrame( } if (encoded_frame_buffer_function_) { - frame->Retain(); encoded_frame_buffer_function_(WebRtcRecordableEncodedFrame(*frame)); } } @@ -767,7 +762,6 @@ VideoReceiveStream::RecordingState VideoReceiveStream::SetAndGetRecordingState( RTC_DCHECK_RUN_ON(&decode_queue_); // Save old state. old_state.callback = std::move(encoded_frame_buffer_function_); - old_state.keyframe_needed = keyframe_generation_requested_; old_state.last_keyframe_request_ms = last_keyframe_request_ms_; // Set new state. @@ -776,7 +770,7 @@ VideoReceiveStream::RecordingState VideoReceiveStream::SetAndGetRecordingState( RequestKeyFrame(clock_->TimeInMilliseconds()); keyframe_generation_requested_ = true; } else { - keyframe_generation_requested_ = state.keyframe_needed; + keyframe_generation_requested_ = false; last_keyframe_request_ms_ = state.last_keyframe_request_ms.value_or(0); } event.Set(); diff --git a/video/video_receive_stream.h b/video/video_receive_stream.h index 5e52063536..c778d74558 100644 --- a/video/video_receive_stream.h +++ b/video/video_receive_stream.h @@ -14,6 +14,7 @@ #include #include +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/video/recordable_encoded_frame.h" #include "call/rtp_packet_sink_interface.h" @@ -24,7 +25,6 @@ #include "modules/video_coding/frame_buffer2.h" #include "modules/video_coding/video_receiver2.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" #include "system_wrappers/include/clock.h" @@ -45,12 +45,13 @@ class VCMTiming; namespace internal { -class VideoReceiveStream : public webrtc::VideoReceiveStream, - public rtc::VideoSinkInterface, - public NackSender, - public video_coding::OnCompleteFrameCallback, - public Syncable, - public CallStatsObserver { +class VideoReceiveStream + : public webrtc::DEPRECATED_VideoReceiveStream, + public rtc::VideoSinkInterface, + public NackSender, + public RtpVideoStreamReceiver::OnCompleteFrameCallback, + public Syncable, + public CallStatsObserver { public: // The default number of milliseconds to pass before re-requesting a key frame // to be sent. 
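The video_receive_stream.cc hunk above drops the LegacyCreateVideoDecoder() call in favor of CreateVideoDecoder(); the fallback behavior is unchanged. Condensed from the diff (NullVideoDecoder is the no-op decoder defined in the same file; other names as in the surrounding code):

    std::unique_ptr<VideoDecoder> video_decoder =
        config_.decoder_factory->CreateVideoDecoder(decoder.video_format);
    if (!video_decoder) {
      // Old decoder factory interfaces cannot be queried for supported codecs,
      // so a decoder that ignores all calls keeps the stream running.
      video_decoder = std::make_unique<NullVideoDecoder>();
    }
    video_decoders_.push_back(std::move(video_decoder));
    video_receiver_.RegisterExternalDecoder(video_decoders_.back().get(),
                                            decoder.payload_type);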
@@ -86,6 +87,8 @@ class VideoReceiveStream : public webrtc::VideoReceiveStream, void Start() override; void Stop() override; + const RtpConfig& rtp_config() const override { return config_.rtp; } + webrtc::VideoReceiveStream::Stats GetStats() const override; void AddSecondarySink(RtpPacketSinkInterface* sink) override; @@ -111,9 +114,8 @@ class VideoReceiveStream : public webrtc::VideoReceiveStream, void SendNack(const std::vector& sequence_numbers, bool buffering_allowed) override; - // Implements video_coding::OnCompleteFrameCallback. - void OnCompleteFrame( - std::unique_ptr frame) override; + // Implements RtpVideoStreamReceiver::OnCompleteFrameCallback. + void OnCompleteFrame(std::unique_ptr frame) override; // Implements CallStatsObserver::OnRttUpdate void OnRttUpdate(int64_t avg_rtt_ms, int64_t max_rtt_ms) override; @@ -138,7 +140,7 @@ class VideoReceiveStream : public webrtc::VideoReceiveStream, private: int64_t GetWaitMs() const; void StartNextDecode() RTC_RUN_ON(decode_queue_); - void HandleEncodedFrame(std::unique_ptr frame) + void HandleEncodedFrame(std::unique_ptr frame) RTC_RUN_ON(decode_queue_); void HandleFrameBufferTimeout() RTC_RUN_ON(decode_queue_); void UpdatePlayoutDelays() const diff --git a/video/video_receive_stream2.cc b/video/video_receive_stream2.cc index 5431ae853d..72257f01cc 100644 --- a/video/video_receive_stream2.cc +++ b/video/video_receive_stream2.cc @@ -24,6 +24,7 @@ #include "api/array_view.h" #include "api/crypto/frame_decryptor_interface.h" #include "api/video/encoded_image.h" +#include "api/video_codecs/h264_profile_level_id.h" #include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_codec.h" #include "api/video_codecs/video_decoder_factory.h" @@ -31,17 +32,16 @@ #include "call/rtp_stream_receiver_controller_interface.h" #include "call/rtx_receive_stream.h" #include "common_video/include/incoming_video_stream.h" -#include "media/base/h264_profile_level_id.h" #include "modules/video_coding/include/video_codec_interface.h" #include "modules/video_coding/include/video_coding_defines.h" #include "modules/video_coding/include/video_error_codes.h" #include "modules/video_coding/timing.h" #include "modules/video_coding/utility/vp8_header_parser.h" #include "rtc_base/checks.h" -#include "rtc_base/experiments/keyframe_interval_settings.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" +#include "rtc_base/synchronization/mutex.h" #include "rtc_base/system/thread_registry.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" @@ -58,7 +58,6 @@ constexpr int VideoReceiveStream2::kMaxWaitForKeyFrameMs; namespace { -using video_coding::EncodedFrame; using ReturnReason = video_coding::FrameBuffer::ReturnReason; constexpr int kMinBaseMinimumDelayMs = 0; @@ -66,17 +65,20 @@ constexpr int kMaxBaseMinimumDelayMs = 10000; constexpr int kMaxWaitForFrameMs = 3000; +constexpr int kDefaultMaximumPreStreamDecoders = 100; + // Concrete instance of RecordableEncodedFrame wrapping needed content -// from video_coding::EncodedFrame. +// from EncodedFrame. 
class WebRtcRecordableEncodedFrame : public RecordableEncodedFrame { public: - explicit WebRtcRecordableEncodedFrame(const EncodedFrame& frame) + explicit WebRtcRecordableEncodedFrame( + const EncodedFrame& frame, + RecordableEncodedFrame::EncodedResolution resolution) : buffer_(frame.GetEncodedData()), render_time_ms_(frame.RenderTime()), codec_(frame.CodecSpecific()->codecType), is_key_frame_(frame.FrameType() == VideoFrameType::kVideoFrameKey), - resolution_{frame.EncodedImage()._encodedWidth, - frame.EncodedImage()._encodedHeight} { + resolution_(resolution) { if (frame.ColorSpace()) { color_space_ = *frame.ColorSpace(); } @@ -179,6 +181,12 @@ class NullVideoDecoder : public webrtc::VideoDecoder { const char* ImplementationName() const override { return "NullVideoDecoder"; } }; +bool IsKeyFrameAndUnspecifiedResolution(const EncodedFrame& frame) { + return frame.FrameType() == VideoFrameType::kVideoFrameKey && + frame.EncodedImage()._encodedWidth == 0 && + frame.EncodedImage()._encodedHeight == 0; +} + // TODO(https://bugs.webrtc.org/9974): Consider removing this workaround. // Maximum time between frames before resetting the FrameBuffer to avoid RTP // timestamps wraparound to affect FrameBuffer. @@ -186,30 +194,44 @@ constexpr int kInactiveStreamThresholdMs = 600000; // 10 minutes. } // namespace -VideoReceiveStream2::VideoReceiveStream2( - TaskQueueFactory* task_queue_factory, - TaskQueueBase* current_queue, - RtpStreamReceiverControllerInterface* receiver_controller, - int num_cpu_cores, - PacketRouter* packet_router, - VideoReceiveStream::Config config, - ProcessThread* process_thread, - CallStats* call_stats, - Clock* clock, - VCMTiming* timing) +int DetermineMaxWaitForFrame(const VideoReceiveStream::Config& config, + bool is_keyframe) { + // A (arbitrary) conversion factor between the remotely signalled NACK buffer + // time (if not present defaults to 1000ms) and the maximum time we wait for a + // remote frame. Chosen to not change existing defaults when using not + // rtx-time. + const int conversion_factor = 3; + + if (config.rtp.nack.rtp_history_ms > 0 && + conversion_factor * config.rtp.nack.rtp_history_ms < kMaxWaitForFrameMs) { + return is_keyframe ? config.rtp.nack.rtp_history_ms + : conversion_factor * config.rtp.nack.rtp_history_ms; + } + return is_keyframe ? 
VideoReceiveStream2::kMaxWaitForKeyFrameMs + : kMaxWaitForFrameMs; +} + +VideoReceiveStream2::VideoReceiveStream2(TaskQueueFactory* task_queue_factory, + Call* call, + int num_cpu_cores, + PacketRouter* packet_router, + VideoReceiveStream::Config config, + CallStats* call_stats, + Clock* clock, + VCMTiming* timing) : task_queue_factory_(task_queue_factory), transport_adapter_(config.rtcp_send_transport), config_(std::move(config)), num_cpu_cores_(num_cpu_cores), - worker_thread_(current_queue), + call_(call), clock_(clock), call_stats_(call_stats), source_tracker_(clock_), - stats_proxy_(&config_, clock_, worker_thread_), + stats_proxy_(&config_, clock_, call->worker_thread()), rtp_receive_statistics_(ReceiveStatistics::Create(clock_)), timing_(timing), video_receiver_(clock_, timing_.get()), - rtp_video_stream_receiver_(worker_thread_, + rtp_video_stream_receiver_(call->worker_thread(), clock_, &transport_adapter_, call_stats->AsRtcpRttStats(), @@ -218,32 +240,27 @@ VideoReceiveStream2::VideoReceiveStream2( rtp_receive_statistics_.get(), &stats_proxy_, &stats_proxy_, - process_thread, this, // NackSender nullptr, // Use default KeyFrameRequestSender this, // OnCompleteFrameCallback config_.frame_decryptor, config_.frame_transformer), - rtp_stream_sync_(current_queue, this), - max_wait_for_keyframe_ms_(KeyframeIntervalSettings::ParseFromFieldTrials() - .MaxWaitForKeyframeMs() - .value_or(kMaxWaitForKeyFrameMs)), - max_wait_for_frame_ms_(KeyframeIntervalSettings::ParseFromFieldTrials() - .MaxWaitForFrameMs() - .value_or(kMaxWaitForFrameMs)), + rtp_stream_sync_(call->worker_thread(), this), + max_wait_for_keyframe_ms_(DetermineMaxWaitForFrame(config, true)), + max_wait_for_frame_ms_(DetermineMaxWaitForFrame(config, false)), low_latency_renderer_enabled_("enabled", true), low_latency_renderer_include_predecode_buffer_("include_predecode_buffer", true), + maximum_pre_stream_decoders_("max", kDefaultMaximumPreStreamDecoders), decode_queue_(task_queue_factory_->CreateTaskQueue( "DecodingQueue", TaskQueueFactory::Priority::HIGH)) { RTC_LOG(LS_INFO) << "VideoReceiveStream2: " << config_.ToString(); - RTC_DCHECK(worker_thread_); + RTC_DCHECK(call_->worker_thread()); RTC_DCHECK(config_.renderer); RTC_DCHECK(call_stats_); - - module_process_sequence_checker_.Detach(); + packet_sequence_checker_.Detach(); RTC_DCHECK(!config_.decoders.empty()); RTC_CHECK(config_.decoder_factory); @@ -261,15 +278,10 @@ VideoReceiveStream2::VideoReceiveStream2( frame_buffer_.reset( new video_coding::FrameBuffer(clock_, timing_.get(), &stats_proxy_)); - // Register with RtpStreamReceiverController. 
- media_receiver_ = receiver_controller->CreateReceiver( - config_.rtp.remote_ssrc, &rtp_video_stream_receiver_); if (config_.rtp.rtx_ssrc) { rtx_receive_stream_ = std::make_unique( &rtp_video_stream_receiver_, config.rtp.rtx_associated_payload_types, config_.rtp.remote_ssrc, rtp_receive_statistics_.get()); - rtx_receiver_ = receiver_controller->CreateReceiver( - config_.rtp.rtx_ssrc, rtx_receive_stream_.get()); } else { rtp_receive_statistics_->EnableRetransmitDetection(config.rtp.remote_ssrc, true); @@ -278,25 +290,55 @@ VideoReceiveStream2::VideoReceiveStream2( ParseFieldTrial({&low_latency_renderer_enabled_, &low_latency_renderer_include_predecode_buffer_}, field_trial::FindFullName("WebRTC-LowLatencyRenderer")); + ParseFieldTrial( + { + &maximum_pre_stream_decoders_, + }, + field_trial::FindFullName("WebRTC-PreStreamDecoders")); } VideoReceiveStream2::~VideoReceiveStream2() { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); RTC_LOG(LS_INFO) << "~VideoReceiveStream2: " << config_.ToString(); + RTC_DCHECK(!media_receiver_); + RTC_DCHECK(!rtx_receiver_); Stop(); } +void VideoReceiveStream2::RegisterWithTransport( + RtpStreamReceiverControllerInterface* receiver_controller) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + RTC_DCHECK(!media_receiver_); + RTC_DCHECK(!rtx_receiver_); + + // Register with RtpStreamReceiverController. + media_receiver_ = receiver_controller->CreateReceiver( + config_.rtp.remote_ssrc, &rtp_video_stream_receiver_); + if (config_.rtp.rtx_ssrc) { + RTC_DCHECK(rtx_receive_stream_); + rtx_receiver_ = receiver_controller->CreateReceiver( + config_.rtp.rtx_ssrc, rtx_receive_stream_.get()); + } +} + +void VideoReceiveStream2::UnregisterFromTransport() { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + media_receiver_.reset(); + rtx_receiver_.reset(); +} + void VideoReceiveStream2::SignalNetworkState(NetworkState state) { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); rtp_video_stream_receiver_.SignalNetworkState(state); } bool VideoReceiveStream2::DeliverRtcp(const uint8_t* packet, size_t length) { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); return rtp_video_stream_receiver_.DeliverRtcp(packet, length); } void VideoReceiveStream2::SetSync(Syncable* audio_syncable) { - RTC_DCHECK_RUN_ON(&worker_sequence_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); rtp_stream_sync_.ConfigureSync(audio_syncable); } @@ -325,48 +367,27 @@ void VideoReceiveStream2::Start() { renderer = this; } + int decoders_count = 0; for (const Decoder& decoder : config_.decoders) { - std::unique_ptr video_decoder = - config_.decoder_factory->LegacyCreateVideoDecoder(decoder.video_format, - config_.stream_id); - // If we still have no valid decoder, we have to create a "Null" decoder - // that ignores all calls. The reason we can get into this state is that the - // old decoder factory interface doesn't have a way to query supported - // codecs. - if (!video_decoder) { - video_decoder = std::make_unique(); - } - - std::string decoded_output_file = - field_trial::FindFullName("WebRTC-DecoderDataDumpDirectory"); - // Because '/' can't be used inside a field trial parameter, we use ';' - // instead. - // This is only relevant to WebRTC-DecoderDataDumpDirectory - // field trial. ';' is chosen arbitrary. Even though it's a legal character - // in some file systems, we can sacrifice ability to use it in the path to - // dumped video, since it's developers-only feature for debugging. 
- absl::c_replace(decoded_output_file, ';', '/'); - if (!decoded_output_file.empty()) { - char filename_buffer[256]; - rtc::SimpleStringBuilder ssb(filename_buffer); - ssb << decoded_output_file << "/webrtc_receive_stream_" - << this->config_.rtp.remote_ssrc << "-" << rtc::TimeMicros() - << ".ivf"; - video_decoder = CreateFrameDumpingDecoderWrapper( - std::move(video_decoder), FileWrapper::OpenWriteOnly(ssb.str())); + // Create up to maximum_pre_stream_decoders_ up front, wait the the other + // decoders until they are requested (i.e., we receive the corresponding + // payload). + if (decoders_count < maximum_pre_stream_decoders_) { + CreateAndRegisterExternalDecoder(decoder); + ++decoders_count; } - video_decoders_.push_back(std::move(video_decoder)); - - video_receiver_.RegisterExternalDecoder(video_decoders_.back().get(), - decoder.payload_type); VideoCodec codec = CreateDecoderVideoCodec(decoder); const bool raw_payload = config_.rtp.raw_payload_types.count(decoder.payload_type) > 0; - rtp_video_stream_receiver_.AddReceiveCodec(decoder.payload_type, codec, - decoder.video_format.parameters, - raw_payload); + { + // TODO(bugs.webrtc.org/11993): Make this call on the network thread. + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtp_video_stream_receiver_.AddReceiveCodec( + decoder.payload_type, codec, decoder.video_format.parameters, + raw_payload); + } RTC_CHECK_EQ(VCM_OK, video_receiver_.RegisterReceiveCodec( decoder.payload_type, &codec, num_cpu_cores_)); } @@ -388,12 +409,23 @@ void VideoReceiveStream2::Start() { StartNextDecode(); }); decoder_running_ = true; - rtp_video_stream_receiver_.StartReceive(); + + { + // TODO(bugs.webrtc.org/11993): Make this call on the network thread. + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtp_video_stream_receiver_.StartReceive(); + } } void VideoReceiveStream2::Stop() { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); - rtp_video_stream_receiver_.StopReceive(); + { + // TODO(bugs.webrtc.org/11993): Make this call on the network thread. + // Also call `GetUniqueFramesSeen()` at the same time (since it's a counter + // that's updated on the network thread). + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtp_video_stream_receiver_.StopReceive(); + } stats_proxy_.OnUniqueFramesCounted( rtp_video_stream_receiver_.GetUniqueFramesSeen()); @@ -429,6 +461,43 @@ void VideoReceiveStream2::Stop() { transport_adapter_.Disable(); } +void VideoReceiveStream2::CreateAndRegisterExternalDecoder( + const Decoder& decoder) { + TRACE_EVENT0("webrtc", + "VideoReceiveStream2::CreateAndRegisterExternalDecoder"); + std::unique_ptr video_decoder = + config_.decoder_factory->CreateVideoDecoder(decoder.video_format); + // If we still have no valid decoder, we have to create a "Null" decoder + // that ignores all calls. The reason we can get into this state is that the + // old decoder factory interface doesn't have a way to query supported + // codecs. + if (!video_decoder) { + video_decoder = std::make_unique(); + } + + std::string decoded_output_file = + field_trial::FindFullName("WebRTC-DecoderDataDumpDirectory"); + // Because '/' can't be used inside a field trial parameter, we use ';' + // instead. + // This is only relevant to WebRTC-DecoderDataDumpDirectory + // field trial. ';' is chosen arbitrary. Even though it's a legal character + // in some file systems, we can sacrifice ability to use it in the path to + // dumped video, since it's developers-only feature for debugging. 
+ absl::c_replace(decoded_output_file, ';', '/'); + if (!decoded_output_file.empty()) { + char filename_buffer[256]; + rtc::SimpleStringBuilder ssb(filename_buffer); + ssb << decoded_output_file << "/webrtc_receive_stream_" + << this->config_.rtp.remote_ssrc << "-" << rtc::TimeMicros() << ".ivf"; + video_decoder = CreateFrameDumpingDecoderWrapper( + std::move(video_decoder), FileWrapper::OpenWriteOnly(ssb.str())); + } + + video_decoders_.push_back(std::move(video_decoder)); + video_receiver_.RegisterExternalDecoder(video_decoders_.back().get(), + decoder.payload_type); +} + VideoReceiveStream::Stats VideoReceiveStream2::GetStats() const { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); VideoReceiveStream2::Stats stats = stats_proxy_.GetStats(); @@ -471,15 +540,6 @@ void VideoReceiveStream2::UpdateHistograms() { stats_proxy_.UpdateHistograms(fraction_lost, rtp_stats, nullptr); } -void VideoReceiveStream2::AddSecondarySink(RtpPacketSinkInterface* sink) { - rtp_video_stream_receiver_.AddSecondarySink(sink); -} - -void VideoReceiveStream2::RemoveSecondarySink( - const RtpPacketSinkInterface* sink) { - rtp_video_stream_receiver_.RemoveSecondarySink(sink); -} - bool VideoReceiveStream2::SetBaseMinimumPlayoutDelayMs(int delay_ms) { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); if (delay_ms < kMinBaseMinimumDelayMs || delay_ms > kMaxBaseMinimumDelayMs) { @@ -499,7 +559,10 @@ int VideoReceiveStream2::GetBaseMinimumPlayoutDelayMs() const { void VideoReceiveStream2::OnFrame(const VideoFrame& video_frame) { VideoFrameMetaData frame_meta(video_frame, clock_->CurrentTime()); - worker_thread_->PostTask( + // TODO(bugs.webrtc.org/10739): we should set local capture clock offset for + // |video_frame.packet_infos|. But VideoFrame is const qualified here. + + call_->worker_thread()->PostTask( ToQueuedTask(task_safety_, [frame_meta, this]() { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); int64_t video_playout_ntp_ms; @@ -516,6 +579,22 @@ void VideoReceiveStream2::OnFrame(const VideoFrame& video_frame) { source_tracker_.OnFrameDelivered(video_frame.packet_infos()); config_.renderer->OnFrame(video_frame); + webrtc::MutexLock lock(&pending_resolution_mutex_); + if (pending_resolution_.has_value()) { + if (!pending_resolution_->empty() && + (video_frame.width() != static_cast(pending_resolution_->width) || + video_frame.height() != + static_cast(pending_resolution_->height))) { + RTC_LOG(LS_WARNING) + << "Recordable encoded frame stream resolution was reported as " + << pending_resolution_->width << "x" << pending_resolution_->height + << " but the stream is now " << video_frame.width() + << video_frame.height(); + } + pending_resolution_ = RecordableEncodedFrame::EncodedResolution{ + static_cast(video_frame.width()), + static_cast(video_frame.height())}; + } } void VideoReceiveStream2::SetFrameDecryptor( @@ -548,8 +627,7 @@ void VideoReceiveStream2::RequestKeyFrame(int64_t timestamp_ms) { }); } -void VideoReceiveStream2::OnCompleteFrame( - std::unique_ptr frame) { +void VideoReceiveStream2::OnCompleteFrame(std::unique_ptr frame) { RTC_DCHECK_RUN_ON(&worker_sequence_checker_); // TODO(https://bugs.webrtc.org/9974): Consider removing this workaround. @@ -572,8 +650,13 @@ void VideoReceiveStream2::OnCompleteFrame( } int64_t last_continuous_pid = frame_buffer_->InsertFrame(std::move(frame)); - if (last_continuous_pid != -1) - rtp_video_stream_receiver_.FrameContinuous(last_continuous_pid); + if (last_continuous_pid != -1) { + { + // TODO(bugs.webrtc.org/11993): Call on the network thread. 
+ RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + rtp_video_stream_receiver_.FrameContinuous(last_continuous_pid); + } + } } void VideoReceiveStream2::OnRttUpdate(int64_t avg_rtt_ms, int64_t max_rtt_ms) { @@ -589,7 +672,7 @@ uint32_t VideoReceiveStream2::id() const { } absl::optional VideoReceiveStream2::GetInfo() const { - RTC_DCHECK_RUN_ON(&worker_sequence_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); absl::optional info = rtp_video_stream_receiver_.GetSyncInfo(); @@ -640,9 +723,10 @@ void VideoReceiveStream2::StartNextDecode() { HandleEncodedFrame(std::move(frame)); } else { int64_t now_ms = clock_->TimeInMilliseconds(); - worker_thread_->PostTask(ToQueuedTask( + // TODO(bugs.webrtc.org/11993): PostTask to the network thread. + call_->worker_thread()->PostTask(ToQueuedTask( task_safety_, [this, now_ms, wait_ms = GetMaxWaitMs()]() { - RTC_DCHECK_RUN_ON(&worker_sequence_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); HandleFrameBufferTimeout(now_ms, wait_ms); })); } @@ -670,13 +754,26 @@ void VideoReceiveStream2::HandleEncodedFrame( const bool keyframe_request_is_due = now_ms >= (last_keyframe_request_ms_ + max_wait_for_keyframe_ms_); - int decode_result = video_receiver_.Decode(frame.get()); + if (!video_receiver_.IsExternalDecoderRegistered(frame->PayloadType())) { + // Look for the decoder with this payload type. + for (const Decoder& decoder : config_.decoders) { + if (decoder.payload_type == frame->PayloadType()) { + CreateAndRegisterExternalDecoder(decoder); + break; + } + } + } + + int64_t frame_id = frame->Id(); + bool received_frame_is_keyframe = + frame->FrameType() == VideoFrameType::kVideoFrameKey; + int decode_result = DecodeAndMaybeDispatchEncodedFrame(std::move(frame)); if (decode_result == WEBRTC_VIDEO_CODEC_OK || decode_result == WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME) { keyframe_required_ = false; frame_decoded_ = true; - decoded_frame_picture_id = frame->id.picture_id; + decoded_frame_picture_id = frame_id; if (decode_result == WEBRTC_VIDEO_CODEC_OK_REQUEST_KEYFRAME) force_request_key_frame = true; @@ -688,36 +785,90 @@ void VideoReceiveStream2::HandleEncodedFrame( force_request_key_frame = true; } - bool received_frame_is_keyframe = - frame->FrameType() == VideoFrameType::kVideoFrameKey; + { + // TODO(bugs.webrtc.org/11993): Make this PostTask to the network thread. + call_->worker_thread()->PostTask(ToQueuedTask( + task_safety_, + [this, now_ms, received_frame_is_keyframe, force_request_key_frame, + decoded_frame_picture_id, keyframe_request_is_due]() { + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); - worker_thread_->PostTask(ToQueuedTask( - task_safety_, - [this, now_ms, received_frame_is_keyframe, force_request_key_frame, - decoded_frame_picture_id, keyframe_request_is_due]() { - RTC_DCHECK_RUN_ON(&worker_sequence_checker_); + if (decoded_frame_picture_id != -1) + rtp_video_stream_receiver_.FrameDecoded(decoded_frame_picture_id); - if (decoded_frame_picture_id != -1) - rtp_video_stream_receiver_.FrameDecoded(decoded_frame_picture_id); - - HandleKeyFrameGeneration(received_frame_is_keyframe, now_ms, - force_request_key_frame, - keyframe_request_is_due); - })); + HandleKeyFrameGeneration(received_frame_is_keyframe, now_ms, + force_request_key_frame, + keyframe_request_is_due); + })); + } +} - if (encoded_frame_buffer_function_) { - frame->Retain(); - encoded_frame_buffer_function_(WebRtcRecordableEncodedFrame(*frame)); +int VideoReceiveStream2::DecodeAndMaybeDispatchEncodedFrame( + std::unique_ptr frame) { + // Running on decode_queue_. 
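The calls above lean heavily on RTC_DCHECK_RUN_ON(&packet_sequence_checker_) to document and enforce the context from which a piece of state is touched. As a rough illustration of the underlying idea only (the real SequenceChecker also understands task queues and feeds the thread-safety annotations such as RTC_GUARDED_BY), here is a minimal stand-in; MiniSequenceChecker is a made-up name.

#include <cassert>
#include <thread>

// A sequence checker in miniature: it binds to the first thread that calls
// IsCurrent() and reports whether later calls happen on that same thread.
class MiniSequenceChecker {
 public:
  bool IsCurrent() {
    const std::thread::id current = std::this_thread::get_id();
    if (attached_thread_ == std::thread::id())
      attached_thread_ = current;  // Attach lazily, like a detached checker.
    return attached_thread_ == current;
  }

 private:
  std::thread::id attached_thread_;
};

int main() {
  MiniSequenceChecker checker;
  assert(checker.IsCurrent());  // First call attaches to this thread.

  bool ok_on_other_thread = true;
  std::thread t([&] { ok_on_other_thread = checker.IsCurrent(); });
  t.join();
  assert(!ok_on_other_thread);  // A different thread fails the check.
}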
+ + // If |buffered_encoded_frames_| grows out of control (=60 queued frames), + // maybe due to a stuck decoder, we just halt the process here and log the + // error. + const bool encoded_frame_output_enabled = + encoded_frame_buffer_function_ != nullptr && + buffered_encoded_frames_.size() < kBufferedEncodedFramesMaxSize; + EncodedFrame* frame_ptr = frame.get(); + if (encoded_frame_output_enabled) { + // If we receive a key frame with unset resolution, hold on dispatching the + // frame and following ones until we know a resolution of the stream. + // NOTE: The code below has a race where it can report the wrong + // resolution for keyframes after an initial keyframe of other resolution. + // However, the only known consumer of this information is the W3C + // MediaRecorder and it will only use the resolution in the first encoded + // keyframe from WebRTC, so misreporting is fine. + buffered_encoded_frames_.push_back(std::move(frame)); + if (buffered_encoded_frames_.size() == kBufferedEncodedFramesMaxSize) + RTC_LOG(LS_ERROR) << "About to halt recordable encoded frame output due " + "to too many buffered frames."; + + webrtc::MutexLock lock(&pending_resolution_mutex_); + if (IsKeyFrameAndUnspecifiedResolution(*frame_ptr) && + !pending_resolution_.has_value()) + pending_resolution_.emplace(); + } + + int decode_result = video_receiver_.Decode(frame_ptr); + if (encoded_frame_output_enabled) { + absl::optional + pending_resolution; + { + // Fish out |pending_resolution_| to avoid taking the mutex on every lap + // or dispatching under the mutex in the flush loop. + webrtc::MutexLock lock(&pending_resolution_mutex_); + if (pending_resolution_.has_value()) + pending_resolution = *pending_resolution_; + } + if (!pending_resolution.has_value() || !pending_resolution->empty()) { + // Flush the buffered frames. + for (const auto& frame : buffered_encoded_frames_) { + RecordableEncodedFrame::EncodedResolution resolution{ + frame->EncodedImage()._encodedWidth, + frame->EncodedImage()._encodedHeight}; + if (IsKeyFrameAndUnspecifiedResolution(*frame)) { + RTC_DCHECK(!pending_resolution->empty()); + resolution = *pending_resolution; + } + encoded_frame_buffer_function_( + WebRtcRecordableEncodedFrame(*frame, resolution)); + } + buffered_encoded_frames_.clear(); + } } + return decode_result; } +// RTC_RUN_ON(packet_sequence_checker_) void VideoReceiveStream2::HandleKeyFrameGeneration( bool received_frame_is_keyframe, int64_t now_ms, bool always_request_key_frame, bool keyframe_request_is_due) { - // Running on worker_sequence_checker_. - bool request_key_frame = always_request_key_frame; // Repeat sending keyframe requests if we've requested a keyframe. @@ -741,9 +892,9 @@ void VideoReceiveStream2::HandleKeyFrameGeneration( } } +// RTC_RUN_ON(packet_sequence_checker_) void VideoReceiveStream2::HandleFrameBufferTimeout(int64_t now_ms, int64_t wait_ms) { - // Running on |worker_sequence_checker_|. absl::optional last_packet_ms = rtp_video_stream_receiver_.LastReceivedPacketMs(); @@ -763,8 +914,8 @@ void VideoReceiveStream2::HandleFrameBufferTimeout(int64_t now_ms, } } +// RTC_RUN_ON(packet_sequence_checker_) bool VideoReceiveStream2::IsReceivingKeyFrame(int64_t timestamp_ms) const { - // Running on worker_sequence_checker_. 
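The buffering logic above holds encoded frames (up to kBufferedEncodedFramesMaxSize = 60) until a usable resolution is known, then flushes them, patching the resolution into key frames that arrived without one. A simplified approximation of that flow, with made-up types (EncodedFrameDispatcher, Frame, Resolution) and none of the real threading, might look like this.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct Resolution {
  uint32_t width = 0;
  uint32_t height = 0;
  bool empty() const { return width == 0 || height == 0; }
};

struct Frame {
  bool is_keyframe = false;
  Resolution resolution;  // May be 0x0 for key frames on some paths.
};

constexpr size_t kMaxBuffered = 60;  // Mirrors kBufferedEncodedFramesMaxSize.

class EncodedFrameDispatcher {
 public:
  // Buffer the frame; flush everything once a usable resolution is known.
  void OnEncodedFrame(Frame frame, std::optional<Resolution> decoded) {
    if (buffer_.size() < kMaxBuffered)
      buffer_.push_back(frame);
    else
      std::cerr << "Halting encoded frame output: buffer full\n";

    if (decoded.has_value() && !decoded->empty())
      known_resolution_ = *decoded;
    if (!known_resolution_.has_value())
      return;  // Still waiting for a reported resolution.

    for (Frame& f : buffer_) {
      if (f.is_keyframe && f.resolution.empty())
        f.resolution = *known_resolution_;  // Patch in the decoded size.
      Dispatch(f);
    }
    buffer_.clear();
  }

 private:
  void Dispatch(const Frame& frame) {
    std::cout << "dispatch " << frame.resolution.width << "x"
              << frame.resolution.height << "\n";
  }

  std::vector<Frame> buffer_;
  std::optional<Resolution> known_resolution_;
};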
absl::optional last_keyframe_packet_ms = rtp_video_stream_receiver_.LastReceivedKeyframePacketMs(); @@ -836,13 +987,13 @@ VideoReceiveStream2::SetAndGetRecordingState(RecordingState state, event.Set(); }); - old_state.keyframe_needed = keyframe_generation_requested_; - if (generate_key_frame) { rtp_video_stream_receiver_.RequestKeyFrame(); - keyframe_generation_requested_ = true; - } else { - keyframe_generation_requested_ = state.keyframe_needed; + { + // TODO(bugs.webrtc.org/11993): Post this to the network thread. + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); + keyframe_generation_requested_ = true; + } } event.Wait(rtc::Event::kForever); @@ -850,7 +1001,7 @@ VideoReceiveStream2::SetAndGetRecordingState(RecordingState state, } void VideoReceiveStream2::GenerateKeyFrame() { - RTC_DCHECK_RUN_ON(&worker_sequence_checker_); + RTC_DCHECK_RUN_ON(&packet_sequence_checker_); RequestKeyFrame(clock_->TimeInMilliseconds()); keyframe_generation_requested_ = true; } diff --git a/video/video_receive_stream2.h b/video/video_receive_stream2.h index 658fab510c..9557044277 100644 --- a/video/video_receive_stream2.h +++ b/video/video_receive_stream2.h @@ -14,9 +14,11 @@ #include #include +#include "api/sequence_checker.h" #include "api/task_queue/task_queue_factory.h" #include "api/units/timestamp.h" #include "api/video/recordable_encoded_frame.h" +#include "call/call.h" #include "call/rtp_packet_sink_interface.h" #include "call/syncable.h" #include "call/video_receive_stream.h" @@ -24,10 +26,10 @@ #include "modules/rtp_rtcp/source/source_tracker.h" #include "modules/video_coding/frame_buffer2.h" #include "modules/video_coding/video_receiver2.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "rtc_base/thread_annotations.h" #include "system_wrappers/include/clock.h" #include "video/receive_statistics_proxy2.h" #include "video/rtp_streams_synchronizer2.h" @@ -37,7 +39,6 @@ namespace webrtc { -class ProcessThread; class RtpStreamReceiverInterface; class RtpStreamReceiverControllerInterface; class RtxReceiveStream; @@ -74,29 +75,45 @@ struct VideoFrameMetaData { const Timestamp decode_timestamp; }; -class VideoReceiveStream2 : public webrtc::VideoReceiveStream, - public rtc::VideoSinkInterface, - public NackSender, - public video_coding::OnCompleteFrameCallback, - public Syncable, - public CallStatsObserver { +class VideoReceiveStream2 + : public webrtc::VideoReceiveStream, + public rtc::VideoSinkInterface, + public NackSender, + public RtpVideoStreamReceiver2::OnCompleteFrameCallback, + public Syncable, + public CallStatsObserver { public: // The default number of milliseconds to pass before re-requesting a key frame // to be sent. static constexpr int kMaxWaitForKeyFrameMs = 200; + // The maximum number of buffered encoded frames when encoded output is + // configured. + static constexpr size_t kBufferedEncodedFramesMaxSize = 60; VideoReceiveStream2(TaskQueueFactory* task_queue_factory, - TaskQueueBase* current_queue, - RtpStreamReceiverControllerInterface* receiver_controller, + Call* call, int num_cpu_cores, PacketRouter* packet_router, VideoReceiveStream::Config config, - ProcessThread* process_thread, CallStats* call_stats, Clock* clock, VCMTiming* timing); + // Destruction happens on the worker thread. Prior to destruction the caller + // must ensure that a registration with the transport has been cleared. 
See + // `RegisterWithTransport` for details. + // TODO(tommi): As a further improvement to this, performing the full + // destruction on the network thread could be made the default. ~VideoReceiveStream2() override; + // Called on `packet_sequence_checker_` to register/unregister with the + // network transport. + void RegisterWithTransport( + RtpStreamReceiverControllerInterface* receiver_controller); + // If registration has previously been done (via `RegisterWithTransport`) then + // `UnregisterFromTransport` must be called prior to destruction, on the + // network thread. + void UnregisterFromTransport(); + const Config& config() const { return config_; } void SignalNetworkState(NetworkState state); @@ -108,10 +125,9 @@ class VideoReceiveStream2 : public webrtc::VideoReceiveStream, void Start() override; void Stop() override; - webrtc::VideoReceiveStream::Stats GetStats() const override; + const RtpConfig& rtp_config() const override { return config_.rtp; } - void AddSecondarySink(RtpPacketSinkInterface* sink) override; - void RemoveSecondarySink(const RtpPacketSinkInterface* sink) override; + webrtc::VideoReceiveStream::Stats GetStats() const override; // SetBaseMinimumPlayoutDelayMs and GetBaseMinimumPlayoutDelayMs are called // from webrtc/api level and requested by user code. For e.g. blink/js layer @@ -133,9 +149,8 @@ class VideoReceiveStream2 : public webrtc::VideoReceiveStream, void SendNack(const std::vector& sequence_numbers, bool buffering_allowed) override; - // Implements video_coding::OnCompleteFrameCallback. - void OnCompleteFrame( - std::unique_ptr frame) override; + // Implements RtpVideoStreamReceiver2::OnCompleteFrameCallback. + void OnCompleteFrame(std::unique_ptr frame) override; // Implements CallStatsObserver::OnRttUpdate void OnRttUpdate(int64_t avg_rtt_ms, int64_t max_rtt_ms) override; @@ -158,35 +173,45 @@ class VideoReceiveStream2 : public webrtc::VideoReceiveStream, void GenerateKeyFrame() override; private: + void CreateAndRegisterExternalDecoder(const Decoder& decoder); int64_t GetMaxWaitMs() const RTC_RUN_ON(decode_queue_); void StartNextDecode() RTC_RUN_ON(decode_queue_); - void HandleEncodedFrame(std::unique_ptr frame) + void HandleEncodedFrame(std::unique_ptr frame) RTC_RUN_ON(decode_queue_); void HandleFrameBufferTimeout(int64_t now_ms, int64_t wait_ms) - RTC_RUN_ON(worker_sequence_checker_); + RTC_RUN_ON(packet_sequence_checker_); void UpdatePlayoutDelays() const RTC_EXCLUSIVE_LOCKS_REQUIRED(worker_sequence_checker_); void RequestKeyFrame(int64_t timestamp_ms) - RTC_RUN_ON(worker_sequence_checker_); + RTC_RUN_ON(packet_sequence_checker_); void HandleKeyFrameGeneration(bool received_frame_is_keyframe, int64_t now_ms, bool always_request_key_frame, bool keyframe_request_is_due) - RTC_RUN_ON(worker_sequence_checker_); + RTC_RUN_ON(packet_sequence_checker_); bool IsReceivingKeyFrame(int64_t timestamp_ms) const - RTC_RUN_ON(worker_sequence_checker_); + RTC_RUN_ON(packet_sequence_checker_); + int DecodeAndMaybeDispatchEncodedFrame(std::unique_ptr frame) + RTC_RUN_ON(decode_queue_); void UpdateHistograms(); RTC_NO_UNIQUE_ADDRESS SequenceChecker worker_sequence_checker_; - RTC_NO_UNIQUE_ADDRESS SequenceChecker module_process_sequence_checker_; + // TODO(bugs.webrtc.org/11993): This checker conceptually represents + // operations that belong to the network thread. 
The Call class is currently + // moving towards handling network packets on the network thread and while + // that work is ongoing, this checker may in practice represent the worker + // thread, but still serves as a mechanism of grouping together concepts + // that belong to the network thread. Once the packets are fully delivered + // on the network thread, this comment will be deleted. + RTC_NO_UNIQUE_ADDRESS SequenceChecker packet_sequence_checker_; TaskQueueFactory* const task_queue_factory_; TransportAdapter transport_adapter_; const VideoReceiveStream::Config config_; const int num_cpu_cores_; - TaskQueueBase* const worker_thread_; + Call* const call_; Clock* const clock_; CallStats* const call_stats_; @@ -214,9 +239,12 @@ class VideoReceiveStream2 : public webrtc::VideoReceiveStream, // Members for the new jitter buffer experiment. std::unique_ptr frame_buffer_; - std::unique_ptr media_receiver_; - std::unique_ptr rtx_receive_stream_; - std::unique_ptr rtx_receiver_; + std::unique_ptr media_receiver_ + RTC_GUARDED_BY(packet_sequence_checker_); + std::unique_ptr rtx_receive_stream_ + RTC_GUARDED_BY(packet_sequence_checker_); + std::unique_ptr rtx_receiver_ + RTC_GUARDED_BY(packet_sequence_checker_); // Whenever we are in an undecodable state (stream has just started or due to // a decoding error) we require a keyframe to restart the stream. @@ -255,8 +283,18 @@ class VideoReceiveStream2 : public webrtc::VideoReceiveStream, std::function encoded_frame_buffer_function_ RTC_GUARDED_BY(decode_queue_); // Set to true while we're requesting keyframes but not yet received one. - bool keyframe_generation_requested_ RTC_GUARDED_BY(worker_sequence_checker_) = + bool keyframe_generation_requested_ RTC_GUARDED_BY(packet_sequence_checker_) = false; + // Lock to avoid unnecessary per-frame idle wakeups in the code. + webrtc::Mutex pending_resolution_mutex_; + // Signal from decode queue to OnFrame callback to fill pending_resolution_. + // absl::nullopt - no resolution needed. 0x0 - next OnFrame to fill with + // received resolution. Not 0x0 - OnFrame has filled a resolution. + absl::optional pending_resolution_ + RTC_GUARDED_BY(pending_resolution_mutex_); + // Buffered encoded frames held while waiting for decoded resolution. + std::vector> buffered_encoded_frames_ + RTC_GUARDED_BY(decode_queue_); // Set by the field trial WebRTC-LowLatencyRenderer. The parameter |enabled| // determines if the low-latency renderer algorithm should be used for the @@ -268,6 +306,11 @@ class VideoReceiveStream2 : public webrtc::VideoReceiveStream, // queue. FieldTrialParameter low_latency_renderer_include_predecode_buffer_; + // Set by the field trial WebRTC-PreStreamDecoders. The parameter |max| + // determines the maximum number of decoders that are created up front before + // any video frame has been received. + FieldTrialParameter maximum_pre_stream_decoders_; + // Defined last so they are destroyed before all other members. 
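The pending_resolution_ member above implements a small tri-state handshake between the decode queue and the OnFrame callback: nullopt means nothing is requested, 0x0 means the next rendered frame should fill in its size, and a non-empty value means it has been filled. A self-contained sketch of that protocol (PendingResolution and Size are illustrative names; the real code uses absl::optional, webrtc::Mutex and RecordableEncodedFrame::EncodedResolution) could be:

#include <cassert>
#include <mutex>
#include <optional>

struct Size {
  int width = 0;
  int height = 0;
  bool empty() const { return width == 0 && height == 0; }
};

class PendingResolution {
 public:
  // Decode side: ask for the next rendered frame's resolution.
  void Request() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!value_.has_value())
      value_ = Size{};  // 0x0 marks "please fill me in".
  }

  // Render side: called for every rendered frame.
  void MaybeFill(int width, int height) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (value_.has_value())
      value_ = Size{width, height};
  }

  // Decode side: a non-empty value means the handshake completed.
  std::optional<Size> Get() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return value_;
  }

 private:
  mutable std::mutex mutex_;
  std::optional<Size> value_;
};

int main() {
  PendingResolution pending;
  assert(!pending.Get().has_value());  // Nothing requested yet.
  pending.Request();                   // Decode side asks.
  assert(pending.Get()->empty());      // Still 0x0.
  pending.MaybeFill(640, 480);         // Render side answers.
  assert(!pending.Get()->empty());
}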
rtc::TaskQueue decode_queue_; diff --git a/video/video_receive_stream2_unittest.cc b/video/video_receive_stream2_unittest.cc index 3f10686db7..850fd0dbb5 100644 --- a/video/video_receive_stream2_unittest.cc +++ b/video/video_receive_stream2_unittest.cc @@ -11,16 +11,19 @@ #include "video/video_receive_stream2.h" #include +#include #include #include #include #include "api/task_queue/default_task_queue_factory.h" #include "api/test/video/function_video_decoder_factory.h" +#include "api/video/video_frame.h" #include "api/video_codecs/video_decoder.h" #include "call/rtp_stream_receiver_controller.h" #include "common_video/test/utilities.h" #include "media/base/fake_video_renderer.h" +#include "media/engine/fake_webrtc_call.h" #include "modules/pacing/packet_router.h" #include "modules/rtp_rtcp/source/rtp_packet_to_send.h" #include "modules/utility/include/process_thread.h" @@ -40,9 +43,13 @@ namespace webrtc { namespace { using ::testing::_; +using ::testing::AllOf; using ::testing::ElementsAreArray; +using ::testing::Field; +using ::testing::InSequence; using ::testing::Invoke; using ::testing::IsEmpty; +using ::testing::Property; using ::testing::SizeIs; constexpr int kDefaultTimeOutMs = 50; @@ -76,7 +83,15 @@ class MockVideoDecoder : public VideoDecoder { const char* ImplementationName() const { return "MockVideoDecoder"; } }; -class FrameObjectFake : public video_coding::EncodedFrame { +class MockVideoDecoderFactory : public VideoDecoderFactory { + public: + MOCK_CONST_METHOD0(GetSupportedFormats, std::vector()); + + MOCK_METHOD1(CreateVideoDecoder, + std::unique_ptr(const SdpVideoFormat& format)); +}; + +class FrameObjectFake : public EncodedFrame { public: void SetPayloadType(uint8_t payload_type) { _payloadType = payload_type; } @@ -94,23 +109,26 @@ class FrameObjectFake : public video_coding::EncodedFrame { class VideoReceiveStream2Test : public ::testing::Test { public: VideoReceiveStream2Test() - : process_thread_(ProcessThread::Create("TestThread")), - task_queue_factory_(CreateDefaultTaskQueueFactory()), - config_(&mock_transport_), - call_stats_(Clock::GetRealTimeClock(), loop_.task_queue()), - h264_decoder_factory_(&mock_h264_video_decoder_) {} + : task_queue_factory_(CreateDefaultTaskQueueFactory()), + h264_decoder_factory_(&mock_h264_video_decoder_), + config_(&mock_transport_, &h264_decoder_factory_), + call_stats_(Clock::GetRealTimeClock(), loop_.task_queue()) {} + ~VideoReceiveStream2Test() override { + if (video_receive_stream_) + video_receive_stream_->UnregisterFromTransport(); + } - void SetUp() { + void SetUp() override { constexpr int kDefaultNumCpuCores = 2; config_.rtp.remote_ssrc = 1111; config_.rtp.local_ssrc = 2222; config_.renderer = &fake_renderer_; - config_.decoder_factory = &h264_decoder_factory_; VideoReceiveStream::Decoder h264_decoder; h264_decoder.payload_type = 99; h264_decoder.video_format = SdpVideoFormat("H264"); h264_decoder.video_format.parameters.insert( {"sprop-parameter-sets", "Z0IACpZTBYmI,aMljiA=="}); + config_.decoders.clear(); config_.decoders.push_back(h264_decoder); clock_ = Clock::GetRealTimeClock(); @@ -118,21 +136,21 @@ class VideoReceiveStream2Test : public ::testing::Test { video_receive_stream_ = std::make_unique( - task_queue_factory_.get(), loop_.task_queue(), - &rtp_stream_receiver_controller_, kDefaultNumCpuCores, - &packet_router_, config_.Copy(), process_thread_.get(), - &call_stats_, clock_, timing_); + task_queue_factory_.get(), &fake_call_, kDefaultNumCpuCores, + &packet_router_, config_.Copy(), &call_stats_, clock_, 
timing_); + video_receive_stream_->RegisterWithTransport( + &rtp_stream_receiver_controller_); } protected: test::RunLoop loop_; - std::unique_ptr process_thread_; const std::unique_ptr task_queue_factory_; + test::VideoDecoderProxyFactory h264_decoder_factory_; VideoReceiveStream::Config config_; internal::CallStats call_stats_; MockVideoDecoder mock_h264_video_decoder_; - test::VideoDecoderProxyFactory h264_decoder_factory_; cricket::FakeVideoRenderer fake_renderer_; + cricket::FakeCall fake_call_; MockTransport mock_transport_; PacketRouter packet_router_; RtpStreamReceiverController rtp_stream_receiver_controller_; @@ -172,7 +190,7 @@ TEST_F(VideoReceiveStream2Test, CreateFrameFromH264FmtpSpropAndIdr) { TEST_F(VideoReceiveStream2Test, PlayoutDelay) { const VideoPlayoutDelay kPlayoutDelayMs = {123, 321}; std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); @@ -203,7 +221,7 @@ TEST_F(VideoReceiveStream2Test, PlayoutDelayPreservesDefaultMaxValue) { const VideoPlayoutDelay kPlayoutDelayMs = {123, -1}; std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); @@ -219,7 +237,7 @@ TEST_F(VideoReceiveStream2Test, PlayoutDelayPreservesDefaultMinValue) { const VideoPlayoutDelay kPlayoutDelayMs = {-1, 321}; std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); @@ -233,20 +251,20 @@ TEST_F(VideoReceiveStream2Test, PlayoutDelayPreservesDefaultMinValue) { TEST_F(VideoReceiveStream2Test, MaxCompositionDelayNotSetByDefault) { // Default with no playout delay set. std::unique_ptr test_frame0(new FrameObjectFake()); - test_frame0->id.picture_id = 0; + test_frame0->SetId(0); video_receive_stream_->OnCompleteFrame(std::move(test_frame0)); EXPECT_FALSE(timing_->MaxCompositionDelayInFrames()); // Max composition delay not set for playout delay 0,0. std::unique_ptr test_frame1(new FrameObjectFake()); - test_frame1->id.picture_id = 1; + test_frame1->SetId(1); test_frame1->SetPlayoutDelay({0, 0}); video_receive_stream_->OnCompleteFrame(std::move(test_frame1)); EXPECT_FALSE(timing_->MaxCompositionDelayInFrames()); // Max composition delay not set for playout delay X,Y, where X,Y>0. std::unique_ptr test_frame2(new FrameObjectFake()); - test_frame2->id.picture_id = 2; + test_frame2->SetId(2); test_frame2->SetPlayoutDelay({10, 30}); video_receive_stream_->OnCompleteFrame(std::move(test_frame2)); EXPECT_FALSE(timing_->MaxCompositionDelayInFrames()); @@ -257,7 +275,7 @@ TEST_F(VideoReceiveStream2Test, MaxCompositionDelaySetFromMaxPlayoutDelay) { const VideoPlayoutDelay kPlayoutDelayMs = {0, 50}; const int kExpectedMaxCompositionDelayInFrames = 3; // ~50 ms at 60 fps. 
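The expectation above (3 frames for a 50 ms maximum playout delay at 60 fps) is just the delay budget divided by the frame interval, rounded up: 50 ms / ~16.7 ms ≈ 3. A tiny sketch of that arithmetic, assuming this rounding rule (the exact rule used inside WebRTC may differ):

#include <cassert>
#include <cmath>

// Convert a maximum playout delay in milliseconds into a frame count at a
// given frame rate, rounding up.
int MaxCompositionDelayInFrames(int max_playout_delay_ms, double fps) {
  return static_cast<int>(std::ceil(max_playout_delay_ms * fps / 1000.0));
}

int main() {
  // Matches kExpectedMaxCompositionDelayInFrames in the test above.
  assert(MaxCompositionDelayInFrames(50, 60.0) == 3);
}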
std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); EXPECT_EQ(kExpectedMaxCompositionDelayInFrames, @@ -269,16 +287,18 @@ class VideoReceiveStream2TestWithFakeDecoder : public ::testing::Test { VideoReceiveStream2TestWithFakeDecoder() : fake_decoder_factory_( []() { return std::make_unique(); }), - process_thread_(ProcessThread::Create("TestThread")), task_queue_factory_(CreateDefaultTaskQueueFactory()), - config_(&mock_transport_), + config_(&mock_transport_, &fake_decoder_factory_), call_stats_(Clock::GetRealTimeClock(), loop_.task_queue()) {} + ~VideoReceiveStream2TestWithFakeDecoder() override { + if (video_receive_stream_) + video_receive_stream_->UnregisterFromTransport(); + } - void SetUp() { + void SetUp() override { config_.rtp.remote_ssrc = 1111; config_.rtp.local_ssrc = 2222; config_.renderer = &fake_renderer_; - config_.decoder_factory = &fake_decoder_factory_; VideoReceiveStream::Decoder fake_decoder; fake_decoder.payload_type = 99; fake_decoder.video_format = SdpVideoFormat("VP8"); @@ -289,19 +309,22 @@ class VideoReceiveStream2TestWithFakeDecoder : public ::testing::Test { void ReCreateReceiveStream(VideoReceiveStream::RecordingState state) { constexpr int kDefaultNumCpuCores = 2; - video_receive_stream_ = nullptr; + if (video_receive_stream_) { + video_receive_stream_->UnregisterFromTransport(); + video_receive_stream_ = nullptr; + } timing_ = new VCMTiming(clock_); video_receive_stream_.reset(new webrtc::internal::VideoReceiveStream2( - task_queue_factory_.get(), loop_.task_queue(), - &rtp_stream_receiver_controller_, kDefaultNumCpuCores, &packet_router_, - config_.Copy(), process_thread_.get(), &call_stats_, clock_, timing_)); + task_queue_factory_.get(), &fake_call_, kDefaultNumCpuCores, + &packet_router_, config_.Copy(), &call_stats_, clock_, timing_)); + video_receive_stream_->RegisterWithTransport( + &rtp_stream_receiver_controller_); video_receive_stream_->SetAndGetRecordingState(std::move(state), false); } protected: test::RunLoop loop_; test::FunctionVideoDecoderFactory fake_decoder_factory_; - std::unique_ptr process_thread_; const std::unique_ptr task_queue_factory_; VideoReceiveStream::Config config_; internal::CallStats call_stats_; @@ -309,6 +332,7 @@ class VideoReceiveStream2TestWithFakeDecoder : public ::testing::Test { MockTransport mock_transport_; PacketRouter packet_router_; RtpStreamReceiverController rtp_stream_receiver_controller_; + cricket::FakeCall fake_call_; std::unique_ptr video_receive_stream_; Clock* clock_; VCMTiming* timing_; @@ -318,7 +342,7 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, PassesNtpTime) { const int64_t kNtpTimestamp = 12345; auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetNtpTime(kNtpTimestamp); video_receive_stream_->Start(); @@ -331,7 +355,7 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, PassesRotation) { const webrtc::VideoRotation kRotation = webrtc::kVideoRotation_180; auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetRotation(kRotation); video_receive_stream_->Start(); @@ -344,7 +368,7 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, PassesRotation) { TEST_F(VideoReceiveStream2TestWithFakeDecoder, PassesPacketInfos) { auto test_frame = std::make_unique(); 
test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); RtpPacketInfos packet_infos = CreatePacketInfos(3); test_frame->SetPacketInfos(packet_infos); @@ -363,7 +387,7 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, RenderedFrameUpdatesGetSources) { // Prepare one video frame with per-packet information. auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); RtpPacketInfos packet_infos; { RtpPacketInfos::vector_type infos; @@ -373,16 +397,16 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, RenderedFrameUpdatesGetSources) { info.set_csrcs({kCsrc}); info.set_rtp_timestamp(kRtpTimestamp); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 5000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(5000)); infos.push_back(info); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 3000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(3000)); infos.push_back(info); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 2000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(2000)); infos.push_back(info); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 4000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(1000)); infos.push_back(info); packet_infos = RtpPacketInfos(std::move(infos)); @@ -433,15 +457,25 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, RenderedFrameUpdatesGetSources) { } } -std::unique_ptr MakeFrame(VideoFrameType frame_type, - int picture_id) { +std::unique_ptr MakeFrameWithResolution( + VideoFrameType frame_type, + int picture_id, + int width, + int height) { auto frame = std::make_unique(); frame->SetPayloadType(99); - frame->id.picture_id = picture_id; + frame->SetId(picture_id); frame->SetFrameType(frame_type); + frame->_encodedWidth = width; + frame->_encodedHeight = height; return frame; } +std::unique_ptr MakeFrame(VideoFrameType frame_type, + int picture_id) { + return MakeFrameWithResolution(frame_type, picture_id, 320, 240); +} + TEST_F(VideoReceiveStream2TestWithFakeDecoder, PassesFrameWhenEncodedFramesCallbackSet) { testing::MockFunction callback; @@ -473,8 +507,30 @@ TEST_F(VideoReceiveStream2TestWithFakeDecoder, video_receive_stream_->Stop(); } -class VideoReceiveStream2TestWithSimulatedClock : public ::testing::Test { +class VideoReceiveStream2TestWithSimulatedClock + : public ::testing::TestWithParam { public: + class FakeRenderer : public rtc::VideoSinkInterface { + public: + void SignalDoneAfterFrames(int num_frames_received) { + signal_after_frame_count_ = num_frames_received; + if (frame_count_ == signal_after_frame_count_) + event_.Set(); + } + + void OnFrame(const webrtc::VideoFrame& frame) override { + if (++frame_count_ == signal_after_frame_count_) + event_.Set(); + } + + void WaitUntilDone() { event_.Wait(rtc::Event::kForever); } + + private: + int signal_after_frame_count_ = std::numeric_limits::max(); + int frame_count_ = 0; + rtc::Event event_; + }; + class FakeDecoder2 : public test::FakeDecoder { public: explicit FakeDecoder2(std::function decode_callback) @@ -497,11 +553,11 @@ class VideoReceiveStream2TestWithSimulatedClock : public ::testing::Test { Transport* transport, VideoDecoderFactory* decoder_factory, rtc::VideoSinkInterface* renderer) { - VideoReceiveStream::Config config(transport); + VideoReceiveStream::Config config(transport, decoder_factory); config.rtp.remote_ssrc = 1111; config.rtp.local_ssrc = 2222; + config.rtp.nack.rtp_history_ms = 
GetParam(); // rtx-time. config.renderer = renderer; - config.decoder_factory = decoder_factory; VideoReceiveStream::Decoder fake_decoder; fake_decoder.payload_type = 99; fake_decoder.video_format = SdpVideoFormat("VP8"); @@ -514,28 +570,30 @@ class VideoReceiveStream2TestWithSimulatedClock : public ::testing::Test { fake_decoder_factory_([this] { return std::make_unique([this] { OnFrameDecoded(); }); }), - process_thread_(time_controller_.CreateProcessThread("ProcessThread")), config_(GetConfig(&mock_transport_, &fake_decoder_factory_, &fake_renderer_)), call_stats_(time_controller_.GetClock(), loop_.task_queue()), video_receive_stream_(time_controller_.GetTaskQueueFactory(), - loop_.task_queue(), - &rtp_stream_receiver_controller_, + &fake_call_, /*num_cores=*/2, &packet_router_, config_.Copy(), - process_thread_.get(), &call_stats_, time_controller_.GetClock(), new VCMTiming(time_controller_.GetClock())) { + video_receive_stream_.RegisterWithTransport( + &rtp_stream_receiver_controller_); video_receive_stream_.Start(); } + ~VideoReceiveStream2TestWithSimulatedClock() override { + video_receive_stream_.UnregisterFromTransport(); + } + void OnFrameDecoded() { event_->Set(); } - void PassEncodedFrameAndWait( - std::unique_ptr frame) { + void PassEncodedFrameAndWait(std::unique_ptr frame) { event_ = std::make_unique(); // This call will eventually end up in the Decoded method where the // event is set. @@ -547,9 +605,9 @@ class VideoReceiveStream2TestWithSimulatedClock : public ::testing::Test { GlobalSimulatedTimeController time_controller_; test::RunLoop loop_; test::FunctionVideoDecoderFactory fake_decoder_factory_; - std::unique_ptr process_thread_; MockTransport mock_transport_; - cricket::FakeVideoRenderer fake_renderer_; + FakeRenderer fake_renderer_; + cricket::FakeCall fake_call_; VideoReceiveStream::Config config_; internal::CallStats call_stats_; PacketRouter packet_router_; @@ -558,10 +616,9 @@ class VideoReceiveStream2TestWithSimulatedClock : public ::testing::Test { std::unique_ptr event_; }; -TEST_F(VideoReceiveStream2TestWithSimulatedClock, +TEST_P(VideoReceiveStream2TestWithSimulatedClock, RequestsKeyFramesUntilKeyFrameReceived) { - auto tick = TimeDelta::Millis( - internal::VideoReceiveStream2::kMaxWaitForKeyFrameMs / 2); + auto tick = TimeDelta::Millis(GetParam() / 2); EXPECT_CALL(mock_transport_, SendRtcp).Times(1).WillOnce(Invoke([this]() { loop_.Quit(); return 0; @@ -573,7 +630,8 @@ TEST_F(VideoReceiveStream2TestWithSimulatedClock, loop_.Run(); testing::Mock::VerifyAndClearExpectations(&mock_transport_); - // T+200ms: still no key frame received, expect key frame request sent again. + // T+keyframetimeout: still no key frame received, expect key frame request + // sent again. EXPECT_CALL(mock_transport_, SendRtcp).Times(1).WillOnce(Invoke([this]() { loop_.Quit(); return 0; @@ -583,8 +641,8 @@ TEST_F(VideoReceiveStream2TestWithSimulatedClock, loop_.Run(); testing::Mock::VerifyAndClearExpectations(&mock_transport_); - // T+200ms: now send a key frame - we should not observe new key frame - // requests after this. + // T+keyframetimeout: now send a key frame - we should not observe new key + // frame requests after this. 
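The parameterized test above exercises repeated key frame requests: a request is sent, repeated once the wait (rtx-time or kMaxWaitForKeyFrameMs) expires, and stops once a key frame arrives. The behaviour can be sketched with a small standalone class; KeyFrameRequester is a made-up name and the 200 ms value simply mirrors kMaxWaitForKeyFrameMs.

#include <cassert>
#include <cstdint>

class KeyFrameRequester {
 public:
  explicit KeyFrameRequester(int64_t timeout_ms) : timeout_ms_(timeout_ms) {}

  // Returns true if a key frame request should be sent at `now_ms`.
  bool OnTick(int64_t now_ms) {
    if (!waiting_for_keyframe_)
      return false;
    if (now_ms - last_request_ms_ < timeout_ms_)
      return false;
    last_request_ms_ = now_ms;
    return true;
  }

  void OnFrame(bool is_keyframe) {
    if (is_keyframe)
      waiting_for_keyframe_ = false;
  }

 private:
  const int64_t timeout_ms_;
  int64_t last_request_ms_ = -(int64_t{1} << 40);  // Far past: request at once.
  bool waiting_for_keyframe_ = true;
};

int main() {
  KeyFrameRequester requester(/*timeout_ms=*/200);
  assert(requester.OnTick(0));     // First request goes out immediately.
  assert(!requester.OnTick(100));  // Too early to repeat.
  assert(requester.OnTick(200));   // Timed out: repeat the request.
  requester.OnFrame(/*is_keyframe=*/true);
  assert(!requester.OnTick(400));  // Key frame received: no more requests.
}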
EXPECT_CALL(mock_transport_, SendRtcp).Times(0); PassEncodedFrameAndWait(MakeFrame(VideoFrameType::kVideoFrameKey, 3)); time_controller_.AdvanceTime(2 * tick); @@ -593,4 +651,171 @@ TEST_F(VideoReceiveStream2TestWithSimulatedClock, loop_.Run(); } +TEST_P(VideoReceiveStream2TestWithSimulatedClock, + DispatchesEncodedFrameSequenceStartingWithKeyframeWithoutResolution) { + video_receive_stream_.Start(); + testing::MockFunction callback; + video_receive_stream_.SetAndGetRecordingState( + VideoReceiveStream::RecordingState(callback.AsStdFunction()), + /*generate_key_frame=*/false); + + InSequence s; + EXPECT_CALL( + callback, + Call(AllOf( + Property(&RecordableEncodedFrame::resolution, + Field(&RecordableEncodedFrame::EncodedResolution::width, + test::FakeDecoder::kDefaultWidth)), + Property(&RecordableEncodedFrame::resolution, + Field(&RecordableEncodedFrame::EncodedResolution::height, + test::FakeDecoder::kDefaultHeight))))); + EXPECT_CALL(callback, Call); + + fake_renderer_.SignalDoneAfterFrames(2); + PassEncodedFrameAndWait( + MakeFrameWithResolution(VideoFrameType::kVideoFrameKey, 0, 0, 0)); + PassEncodedFrameAndWait( + MakeFrameWithResolution(VideoFrameType::kVideoFrameDelta, 1, 0, 0)); + fake_renderer_.WaitUntilDone(); + + video_receive_stream_.Stop(); +} + +TEST_P(VideoReceiveStream2TestWithSimulatedClock, + DispatchesEncodedFrameSequenceStartingWithKeyframeWithResolution) { + video_receive_stream_.Start(); + testing::MockFunction callback; + video_receive_stream_.SetAndGetRecordingState( + VideoReceiveStream::RecordingState(callback.AsStdFunction()), + /*generate_key_frame=*/false); + + InSequence s; + EXPECT_CALL( + callback, + Call(AllOf( + Property( + &RecordableEncodedFrame::resolution, + Field(&RecordableEncodedFrame::EncodedResolution::width, 1080)), + Property(&RecordableEncodedFrame::resolution, + Field(&RecordableEncodedFrame::EncodedResolution::height, + 720))))); + EXPECT_CALL(callback, Call); + + fake_renderer_.SignalDoneAfterFrames(2); + PassEncodedFrameAndWait( + MakeFrameWithResolution(VideoFrameType::kVideoFrameKey, 0, 1080, 720)); + PassEncodedFrameAndWait( + MakeFrameWithResolution(VideoFrameType::kVideoFrameDelta, 1, 0, 0)); + fake_renderer_.WaitUntilDone(); + + video_receive_stream_.Stop(); +} + +INSTANTIATE_TEST_SUITE_P( + RtxTime, + VideoReceiveStream2TestWithSimulatedClock, + ::testing::Values(internal::VideoReceiveStream2::kMaxWaitForKeyFrameMs, + 50 /*ms*/)); + +class VideoReceiveStream2TestWithLazyDecoderCreation : public ::testing::Test { + public: + VideoReceiveStream2TestWithLazyDecoderCreation() + : task_queue_factory_(CreateDefaultTaskQueueFactory()), + config_(&mock_transport_, &mock_h264_decoder_factory_), + call_stats_(Clock::GetRealTimeClock(), loop_.task_queue()) {} + + ~VideoReceiveStream2TestWithLazyDecoderCreation() override { + video_receive_stream_->UnregisterFromTransport(); + } + + void SetUp() override { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-PreStreamDecoders/max:0/"); + constexpr int kDefaultNumCpuCores = 2; + config_.rtp.remote_ssrc = 1111; + config_.rtp.local_ssrc = 2222; + config_.renderer = &fake_renderer_; + VideoReceiveStream::Decoder h264_decoder; + h264_decoder.payload_type = 99; + h264_decoder.video_format = SdpVideoFormat("H264"); + h264_decoder.video_format.parameters.insert( + {"sprop-parameter-sets", "Z0IACpZTBYmI,aMljiA=="}); + config_.decoders.clear(); + config_.decoders.push_back(h264_decoder); + + clock_ = Clock::GetRealTimeClock(); + timing_ = new VCMTiming(clock_); + + video_receive_stream_ = + 
std::make_unique( + task_queue_factory_.get(), &fake_call_, kDefaultNumCpuCores, + &packet_router_, config_.Copy(), &call_stats_, clock_, timing_); + video_receive_stream_->RegisterWithTransport( + &rtp_stream_receiver_controller_); + } + + protected: + test::RunLoop loop_; + const std::unique_ptr task_queue_factory_; + MockVideoDecoderFactory mock_h264_decoder_factory_; + VideoReceiveStream::Config config_; + internal::CallStats call_stats_; + MockVideoDecoder mock_h264_video_decoder_; + cricket::FakeVideoRenderer fake_renderer_; + cricket::FakeCall fake_call_; + MockTransport mock_transport_; + PacketRouter packet_router_; + RtpStreamReceiverController rtp_stream_receiver_controller_; + std::unique_ptr video_receive_stream_; + Clock* clock_; + VCMTiming* timing_; +}; + +TEST_F(VideoReceiveStream2TestWithLazyDecoderCreation, LazyDecoderCreation) { + constexpr uint8_t idr_nalu[] = {0x05, 0xFF, 0xFF, 0xFF}; + RtpPacketToSend rtppacket(nullptr); + uint8_t* payload = rtppacket.AllocatePayload(sizeof(idr_nalu)); + memcpy(payload, idr_nalu, sizeof(idr_nalu)); + rtppacket.SetMarker(true); + rtppacket.SetSsrc(1111); + rtppacket.SetPayloadType(99); + rtppacket.SetSequenceNumber(1); + rtppacket.SetTimestamp(0); + + // No decoder is created here. + EXPECT_CALL(mock_h264_decoder_factory_, CreateVideoDecoder(_)).Times(0); + video_receive_stream_->Start(); + + EXPECT_CALL(mock_h264_decoder_factory_, CreateVideoDecoder(_)) + .WillOnce(Invoke([this](const SdpVideoFormat& format) { + test::VideoDecoderProxyFactory h264_decoder_factory( + &mock_h264_video_decoder_); + return h264_decoder_factory.CreateVideoDecoder(format); + })); + rtc::Event init_decode_event_; + EXPECT_CALL(mock_h264_video_decoder_, InitDecode(_, _)) + .WillOnce(Invoke([&init_decode_event_](const VideoCodec* config, + int32_t number_of_cores) { + init_decode_event_.Set(); + return 0; + })); + EXPECT_CALL(mock_h264_video_decoder_, RegisterDecodeCompleteCallback(_)); + EXPECT_CALL(mock_h264_video_decoder_, Decode(_, false, _)); + RtpPacketReceived parsed_packet; + ASSERT_TRUE(parsed_packet.Parse(rtppacket.data(), rtppacket.size())); + rtp_stream_receiver_controller_.OnRtpPacket(parsed_packet); + EXPECT_CALL(mock_h264_video_decoder_, Release()); + + // Make sure the decoder thread had a chance to run. + init_decode_event_.Wait(kDefaultTimeOutMs); +} + +TEST_F(VideoReceiveStream2TestWithLazyDecoderCreation, + DeregisterDecoderThatsNotCreated) { + // No decoder is created here. 
+ EXPECT_CALL(mock_h264_decoder_factory_, CreateVideoDecoder(_)).Times(0); + video_receive_stream_->Start(); + video_receive_stream_->Stop(); +} + } // namespace webrtc diff --git a/video/video_receive_stream_unittest.cc b/video/video_receive_stream_unittest.cc index 9ac640ba1b..cb14f7dc06 100644 --- a/video/video_receive_stream_unittest.cc +++ b/video/video_receive_stream_unittest.cc @@ -76,7 +76,7 @@ class MockVideoDecoder : public VideoDecoder { const char* ImplementationName() const { return "MockVideoDecoder"; } }; -class FrameObjectFake : public video_coding::EncodedFrame { +class FrameObjectFake : public EncodedFrame { public: void SetPayloadType(uint8_t payload_type) { _payloadType = payload_type; } @@ -96,16 +96,15 @@ class VideoReceiveStreamTest : public ::testing::Test { VideoReceiveStreamTest() : process_thread_(ProcessThread::Create("TestThread")), task_queue_factory_(CreateDefaultTaskQueueFactory()), - config_(&mock_transport_), - call_stats_(Clock::GetRealTimeClock(), process_thread_.get()), - h264_decoder_factory_(&mock_h264_video_decoder_) {} + h264_decoder_factory_(&mock_h264_video_decoder_), + config_(&mock_transport_, &h264_decoder_factory_), + call_stats_(Clock::GetRealTimeClock(), process_thread_.get()) {} void SetUp() { constexpr int kDefaultNumCpuCores = 2; config_.rtp.remote_ssrc = 1111; config_.rtp.local_ssrc = 2222; config_.renderer = &fake_renderer_; - config_.decoder_factory = &h264_decoder_factory_; VideoReceiveStream::Decoder h264_decoder; h264_decoder.payload_type = 99; h264_decoder.video_format = SdpVideoFormat("H264"); @@ -126,10 +125,10 @@ class VideoReceiveStreamTest : public ::testing::Test { protected: std::unique_ptr process_thread_; const std::unique_ptr task_queue_factory_; + test::VideoDecoderProxyFactory h264_decoder_factory_; VideoReceiveStream::Config config_; CallStats call_stats_; MockVideoDecoder mock_h264_video_decoder_; - test::VideoDecoderProxyFactory h264_decoder_factory_; cricket::FakeVideoRenderer fake_renderer_; MockTransport mock_transport_; PacketRouter packet_router_; @@ -170,7 +169,7 @@ TEST_F(VideoReceiveStreamTest, CreateFrameFromH264FmtpSpropAndIdr) { TEST_F(VideoReceiveStreamTest, PlayoutDelay) { const VideoPlayoutDelay kPlayoutDelayMs = {123, 321}; std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); @@ -201,7 +200,7 @@ TEST_F(VideoReceiveStreamTest, PlayoutDelayPreservesDefaultMaxValue) { const VideoPlayoutDelay kPlayoutDelayMs = {123, -1}; std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); @@ -217,7 +216,7 @@ TEST_F(VideoReceiveStreamTest, PlayoutDelayPreservesDefaultMinValue) { const VideoPlayoutDelay kPlayoutDelayMs = {-1, 321}; std::unique_ptr test_frame(new FrameObjectFake()); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetPlayoutDelay(kPlayoutDelayMs); video_receive_stream_->OnCompleteFrame(std::move(test_frame)); @@ -235,14 +234,13 @@ class VideoReceiveStreamTestWithFakeDecoder : public ::testing::Test { []() { return std::make_unique(); }), process_thread_(ProcessThread::Create("TestThread")), task_queue_factory_(CreateDefaultTaskQueueFactory()), - config_(&mock_transport_), + config_(&mock_transport_, &fake_decoder_factory_), call_stats_(Clock::GetRealTimeClock(), 
process_thread_.get()) {} void SetUp() { config_.rtp.remote_ssrc = 1111; config_.rtp.local_ssrc = 2222; config_.renderer = &fake_renderer_; - config_.decoder_factory = &fake_decoder_factory_; VideoReceiveStream::Decoder fake_decoder; fake_decoder.payload_type = 99; fake_decoder.video_format = SdpVideoFormat("VP8"); @@ -281,7 +279,7 @@ TEST_F(VideoReceiveStreamTestWithFakeDecoder, PassesNtpTime) { const int64_t kNtpTimestamp = 12345; auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetNtpTime(kNtpTimestamp); video_receive_stream_->Start(); @@ -294,7 +292,7 @@ TEST_F(VideoReceiveStreamTestWithFakeDecoder, PassesRotation) { const webrtc::VideoRotation kRotation = webrtc::kVideoRotation_180; auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); test_frame->SetRotation(kRotation); video_receive_stream_->Start(); @@ -307,7 +305,7 @@ TEST_F(VideoReceiveStreamTestWithFakeDecoder, PassesRotation) { TEST_F(VideoReceiveStreamTestWithFakeDecoder, PassesPacketInfos) { auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); RtpPacketInfos packet_infos = CreatePacketInfos(3); test_frame->SetPacketInfos(packet_infos); @@ -326,7 +324,7 @@ TEST_F(VideoReceiveStreamTestWithFakeDecoder, RenderedFrameUpdatesGetSources) { // Prepare one video frame with per-packet information. auto test_frame = std::make_unique(); test_frame->SetPayloadType(99); - test_frame->id.picture_id = 0; + test_frame->SetId(0); RtpPacketInfos packet_infos; { RtpPacketInfos::vector_type infos; @@ -336,16 +334,16 @@ TEST_F(VideoReceiveStreamTestWithFakeDecoder, RenderedFrameUpdatesGetSources) { info.set_csrcs({kCsrc}); info.set_rtp_timestamp(kRtpTimestamp); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 5000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(5000)); infos.push_back(info); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 3000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(3000)); infos.push_back(info); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 2000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(2000)); infos.push_back(info); - info.set_receive_time_ms(clock_->TimeInMilliseconds() - 4000); + info.set_receive_time(clock_->CurrentTime() - TimeDelta::Millis(4000)); infos.push_back(info); packet_infos = RtpPacketInfos(std::move(infos)); @@ -400,7 +398,7 @@ std::unique_ptr MakeFrame(VideoFrameType frame_type, int picture_id) { auto frame = std::make_unique(); frame->SetPayloadType(99); - frame->id.picture_id = picture_id; + frame->SetId(picture_id); frame->SetFrameType(frame_type); return frame; } @@ -496,13 +494,12 @@ class VideoReceiveStreamTestWithSimulatedClock : public ::testing::Test { void OnFrameDecoded() { event_->Set(); } - void PassEncodedFrameAndWait( - std::unique_ptr frame) { - event_ = std::make_unique(); - // This call will eventually end up in the Decoded method where the - // event is set. - video_receive_stream_.OnCompleteFrame(std::move(frame)); - event_->Wait(rtc::Event::kForever); + void PassEncodedFrameAndWait(std::unique_ptr frame) { + event_ = std::make_unique(); + // This call will eventually end up in the Decoded method where the + // event is set. 
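PassEncodedFrameAndWait above blocks the test until the decoder's completion callback sets an rtc::Event. The same hand-off can be expressed with the standard library alone, which is all this sketch does; it is not the WebRTC implementation.

#include <future>
#include <thread>

int main() {
  std::promise<void> decoded;
  std::future<void> done = decoded.get_future();

  // "Decoder thread": signals completion once the frame has been processed.
  std::thread decoder_thread([&decoded] {
    // ... decode the frame here ...
    decoded.set_value();
  });

  done.wait();  // Equivalent of event_->Wait(rtc::Event::kForever).
  decoder_thread.join();
  return 0;
}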
+ video_receive_stream_.OnCompleteFrame(std::move(frame)); + event_->Wait(rtc::Event::kForever); } protected: diff --git a/video/video_send_stream.cc b/video/video_send_stream.cc index 91c246c66e..8c0f8f6f72 100644 --- a/video/video_send_stream.cc +++ b/video/video_send_stream.cc @@ -23,7 +23,6 @@ #include "system_wrappers/include/clock.h" #include "system_wrappers/include/field_trial.h" #include "video/adaptation/overuse_frame_detector.h" -#include "video/video_send_stream_impl.h" #include "video/video_stream_encoder.h" namespace webrtc { @@ -65,7 +64,10 @@ VideoStreamEncoder::BitrateAllocationCallbackType GetBitrateAllocationCallbackType(const VideoSendStream::Config& config) { if (webrtc::RtpExtension::FindHeaderExtensionByUri( config.rtp.extensions, - webrtc::RtpExtension::kVideoLayersAllocationUri)) { + webrtc::RtpExtension::kVideoLayersAllocationUri, + config.crypto_options.srtp.enable_encrypted_rtp_header_extensions + ? RtpExtension::Filter::kPreferEncryptedExtension + : RtpExtension::Filter::kDiscardEncryptedExtension)) { return VideoStreamEncoder::BitrateAllocationCallbackType:: kVideoLayersAllocation; } @@ -77,6 +79,32 @@ GetBitrateAllocationCallbackType(const VideoSendStream::Config& config) { kVideoBitrateAllocationWhenScreenSharing; } +RtpSenderFrameEncryptionConfig CreateFrameEncryptionConfig( + const VideoSendStream::Config* config) { + RtpSenderFrameEncryptionConfig frame_encryption_config; + frame_encryption_config.frame_encryptor = config->frame_encryptor; + frame_encryption_config.crypto_options = config->crypto_options; + return frame_encryption_config; +} + +RtpSenderObservers CreateObservers(RtcpRttStats* call_stats, + EncoderRtcpFeedback* encoder_feedback, + SendStatisticsProxy* stats_proxy, + SendDelayStats* send_delay_stats) { + RtpSenderObservers observers; + observers.rtcp_rtt_stats = call_stats; + observers.intra_frame_callback = encoder_feedback; + observers.rtcp_loss_notification_observer = encoder_feedback; + observers.report_block_data_observer = stats_proxy; + observers.rtp_stats = stats_proxy; + observers.bitrate_observer = stats_proxy; + observers.frame_count_observer = stats_proxy; + observers.rtcp_type_observer = stats_proxy; + observers.send_delay_observer = stats_proxy; + observers.send_packet_observer = send_delay_stats; + return observers; +} + } // namespace namespace internal { @@ -84,7 +112,6 @@ namespace internal { VideoSendStream::VideoSendStream( Clock* clock, int num_cpu_cores, - ProcessThread* module_process_thread, TaskQueueFactory* task_queue_factory, RtcpRttStats* call_stats, RtpTransportControllerSendInterface* transport, @@ -96,56 +123,79 @@ VideoSendStream::VideoSendStream( const std::map& suspended_ssrcs, const std::map& suspended_payload_states, std::unique_ptr fec_controller) - : worker_queue_(transport->GetWorkerQueue()), + : rtp_transport_queue_(transport->GetWorkerQueue()), + transport_(transport), stats_proxy_(clock, config, encoder_config.content_type), config_(std::move(config)), - content_type_(encoder_config.content_type) { + content_type_(encoder_config.content_type), + video_stream_encoder_(std::make_unique( + clock, + num_cpu_cores, + &stats_proxy_, + config_.encoder_settings, + std::make_unique(&stats_proxy_), + task_queue_factory, + GetBitrateAllocationCallbackType(config_))), + encoder_feedback_( + clock, + config_.rtp.ssrcs, + video_stream_encoder_.get(), + [this](uint32_t ssrc, const std::vector& seq_nums) { + return rtp_video_sender_->GetSentRtpPacketInfos(ssrc, seq_nums); + }), + rtp_video_sender_( + 
transport->CreateRtpVideoSender(suspended_ssrcs, + suspended_payload_states, + config_.rtp, + config_.rtcp_report_interval_ms, + config_.send_transport, + CreateObservers(call_stats, + &encoder_feedback_, + &stats_proxy_, + send_delay_stats), + event_log, + std::move(fec_controller), + CreateFrameEncryptionConfig(&config_), + config_.frame_transformer)), + send_stream_(clock, + &stats_proxy_, + rtp_transport_queue_, + transport, + bitrate_allocator, + video_stream_encoder_.get(), + &config_, + encoder_config.max_bitrate_bps, + encoder_config.bitrate_priority, + encoder_config.content_type, + rtp_video_sender_) { RTC_DCHECK(config_.encoder_settings.encoder_factory); RTC_DCHECK(config_.encoder_settings.bitrate_allocator_factory); - video_stream_encoder_ = std::make_unique( - clock, num_cpu_cores, &stats_proxy_, config_.encoder_settings, - std::make_unique(&stats_proxy_), task_queue_factory, - GetBitrateAllocationCallbackType(config_)); - - // TODO(srte): Initialization should not be done posted on a task queue. - // Note that the posted task must not outlive this scope since the closure - // references local variables. - worker_queue_->PostTask(ToQueuedTask( - [this, clock, call_stats, transport, bitrate_allocator, send_delay_stats, - event_log, &suspended_ssrcs, &encoder_config, &suspended_payload_states, - &fec_controller]() { - send_stream_.reset(new VideoSendStreamImpl( - clock, &stats_proxy_, worker_queue_, call_stats, transport, - bitrate_allocator, send_delay_stats, video_stream_encoder_.get(), - event_log, &config_, encoder_config.max_bitrate_bps, - encoder_config.bitrate_priority, suspended_ssrcs, - suspended_payload_states, encoder_config.content_type, - std::move(fec_controller))); - }, - [this]() { thread_sync_event_.Set(); })); - - // Wait for ConstructionTask to complete so that |send_stream_| can be used. - // |module_process_thread| must be registered and deregistered on the thread - // it was created on. - thread_sync_event_.Wait(rtc::Event::kForever); - send_stream_->RegisterProcessThread(module_process_thread); + video_stream_encoder_->SetFecControllerOverride(rtp_video_sender_); + ReconfigureVideoEncoder(std::move(encoder_config)); } VideoSendStream::~VideoSendStream() { RTC_DCHECK_RUN_ON(&thread_checker_); - RTC_DCHECK(!send_stream_); + RTC_DCHECK(!running_); + transport_->DestroyRtpVideoSender(rtp_video_sender_); } void VideoSendStream::UpdateActiveSimulcastLayers( const std::vector active_layers) { RTC_DCHECK_RUN_ON(&thread_checker_); + // Keep our `running_` flag expected state in sync with active layers since + // the `send_stream_` will be implicitly stopped/started depending on the + // state of the layers. 
+ bool running = false; + rtc::StringBuilder active_layers_string; active_layers_string << "{"; for (size_t i = 0; i < active_layers.size(); ++i) { if (active_layers[i]) { + running = true; active_layers_string << "1"; } else { active_layers_string << "0"; @@ -158,35 +208,53 @@ void VideoSendStream::UpdateActiveSimulcastLayers( RTC_LOG(LS_INFO) << "UpdateActiveSimulcastLayers: " << active_layers_string.str(); - VideoSendStreamImpl* send_stream = send_stream_.get(); - worker_queue_->PostTask([this, send_stream, active_layers] { - send_stream->UpdateActiveSimulcastLayers(active_layers); - thread_sync_event_.Set(); - }); + rtp_transport_queue_->PostTask( + ToQueuedTask(transport_queue_safety_, [this, active_layers] { + send_stream_.UpdateActiveSimulcastLayers(active_layers); + })); - thread_sync_event_.Wait(rtc::Event::kForever); + running_ = running; } void VideoSendStream::Start() { RTC_DCHECK_RUN_ON(&thread_checker_); - RTC_LOG(LS_INFO) << "VideoSendStream::Start"; - VideoSendStreamImpl* send_stream = send_stream_.get(); - worker_queue_->PostTask([this, send_stream] { - send_stream->Start(); + RTC_DLOG(LS_INFO) << "VideoSendStream::Start"; + if (running_) + return; + + running_ = true; + + rtp_transport_queue_->PostTask(ToQueuedTask([this] { + transport_queue_safety_->SetAlive(); + send_stream_.Start(); thread_sync_event_.Set(); - }); + })); // It is expected that after VideoSendStream::Start has been called, incoming // frames are not dropped in VideoStreamEncoder. To ensure this, Start has to // be synchronized. + // TODO(tommi): ^^^ Validate if this still holds. thread_sync_event_.Wait(rtc::Event::kForever); } void VideoSendStream::Stop() { RTC_DCHECK_RUN_ON(&thread_checker_); - RTC_LOG(LS_INFO) << "VideoSendStream::Stop"; - VideoSendStreamImpl* send_stream = send_stream_.get(); - worker_queue_->PostTask([send_stream] { send_stream->Stop(); }); + if (!running_) + return; + RTC_DLOG(LS_INFO) << "VideoSendStream::Stop"; + running_ = false; + rtp_transport_queue_->PostTask(ToQueuedTask(transport_queue_safety_, [this] { + // As the stream can get re-used and implicitly restarted via changing + // the state of the active layers, we do not mark the + // `transport_queue_safety_` flag with `SetNotAlive()` here. That's only + // done when we stop permanently via `StopPermanentlyAndGetRtpStates()`. + send_stream_.Stop(); + })); +} + +bool VideoSendStream::started() { + RTC_DCHECK_RUN_ON(&thread_checker_); + return running_; } void VideoSendStream::AddAdaptationResource( @@ -209,10 +277,8 @@ void VideoSendStream::SetSource( } void VideoSendStream::ReconfigureVideoEncoder(VideoEncoderConfig config) { - // TODO(perkj): Some test cases in VideoSendStreamTest call - // ReconfigureVideoEncoder from the network thread. 
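Start(), Stop() and StopPermanentlyAndGetRtpStates() above post closures onto the RTP transport queue guarded by transport_queue_safety_, so tasks that run after the owner has gone away become no-ops. A single-threaded sketch of that safety-flag idea, with made-up names (TaskQueueSketch, Sender) and a plain shared bool instead of the real PendingTaskSafetyFlag:

#include <functional>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

class TaskQueueSketch {
 public:
  void PostTask(std::function<void()> task) {
    tasks_.push_back(std::move(task));
  }
  void RunPending() {
    for (auto& task : tasks_)
      task();
    tasks_.clear();
  }

 private:
  std::vector<std::function<void()>> tasks_;
};

class Sender {
 public:
  explicit Sender(TaskQueueSketch* queue) : queue_(queue) {}
  ~Sender() { *alive_ = false; }  // Rough equivalent of SetNotAlive().

  void Stop() {
    // Capture the flag alongside `this` so a late-running task can bail out.
    queue_->PostTask([alive = alive_, this] {
      if (!*alive)
        return;  // Owner is gone; skip the work.
      std::cout << "stopping stream\n";
    });
  }

 private:
  TaskQueueSketch* const queue_;
  std::shared_ptr<bool> alive_ = std::make_shared<bool>(true);
};

int main() {
  TaskQueueSketch queue;
  {
    Sender sender(&queue);
    sender.Stop();  // Task is queued but not yet run.
  }                 // Sender destroyed; its flag is now false.
  queue.RunPending();  // The stale task runs but does nothing.
}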
- // RTC_DCHECK_RUN_ON(&thread_checker_); - RTC_DCHECK(content_type_ == config.content_type); + RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_DCHECK_EQ(content_type_, config.content_type); video_stream_encoder_->ConfigureEncoder( std::move(config), config_.rtp.max_packet_size - CalculateMaxHeaderSize(config_.rtp)); @@ -226,7 +292,7 @@ VideoSendStream::Stats VideoSendStream::GetStats() { } absl::optional VideoSendStream::GetPacingFactorOverride() const { - return send_stream_->configured_pacing_factor_; + return send_stream_.configured_pacing_factor(); } void VideoSendStream::StopPermanentlyAndGetRtpStates( @@ -234,12 +300,16 @@ void VideoSendStream::StopPermanentlyAndGetRtpStates( VideoSendStream::RtpPayloadStateMap* payload_state_map) { RTC_DCHECK_RUN_ON(&thread_checker_); video_stream_encoder_->Stop(); - send_stream_->DeRegisterProcessThread(); - worker_queue_->PostTask([this, rtp_state_map, payload_state_map]() { - send_stream_->Stop(); - *rtp_state_map = send_stream_->GetRtpStates(); - *payload_state_map = send_stream_->GetRtpPayloadStates(); - send_stream_.reset(); + + running_ = false; + // Always run these cleanup steps regardless of whether running_ was set + // or not. This will unregister callbacks before destruction. + // See `VideoSendStreamImpl::StopVideoSendStream` for more. + rtp_transport_queue_->PostTask([this, rtp_state_map, payload_state_map]() { + transport_queue_safety_->SetNotAlive(); + send_stream_.Stop(); + *rtp_state_map = send_stream_.GetRtpStates(); + *payload_state_map = send_stream_.GetRtpPayloadStates(); thread_sync_event_.Set(); }); thread_sync_event_.Wait(rtc::Event::kForever); @@ -247,7 +317,7 @@ void VideoSendStream::StopPermanentlyAndGetRtpStates( void VideoSendStream::DeliverRtcp(const uint8_t* packet, size_t length) { // Called on a network thread. - send_stream_->DeliverRtcp(packet, length); + send_stream_.DeliverRtcp(packet, length); } } // namespace internal diff --git a/video/video_send_stream.h b/video/video_send_stream.h index e10f4ad59b..0d132dd666 100644 --- a/video/video_send_stream.h +++ b/video/video_send_stream.h @@ -16,15 +16,19 @@ #include #include "api/fec_controller.h" +#include "api/sequence_checker.h" #include "api/video/video_stream_encoder_interface.h" #include "call/bitrate_allocator.h" #include "call/video_receive_stream.h" #include "call/video_send_stream.h" #include "rtc_base/event.h" +#include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" +#include "video/encoder_rtcp_feedback.h" #include "video/send_delay_stats.h" #include "video/send_statistics_proxy.h" +#include "video/video_send_stream_impl.h" namespace webrtc { namespace test { @@ -33,7 +37,6 @@ class VideoSendStreamPeer; class CallStats; class IvfFileWriter; -class ProcessThread; class RateLimiter; class RtpRtcp; class RtpTransportControllerSendInterface; @@ -45,8 +48,7 @@ class VideoSendStreamImpl; // VideoSendStream implements webrtc::VideoSendStream. // Internally, it delegates all public methods to VideoSendStreamImpl and / or -// VideoStreamEncoder. VideoSendStreamInternal is created and deleted on -// |worker_queue|. +// VideoStreamEncoder. 
class VideoSendStream : public webrtc::VideoSendStream { public: using RtpStateMap = std::map; @@ -55,7 +57,6 @@ class VideoSendStream : public webrtc::VideoSendStream { VideoSendStream( Clock* clock, int num_cpu_cores, - ProcessThread* module_process_thread, TaskQueueFactory* task_queue_factory, RtcpRttStats* call_stats, RtpTransportControllerSendInterface* transport, @@ -77,6 +78,7 @@ class VideoSendStream : public webrtc::VideoSendStream { const std::vector active_layers) override; void Start() override; void Stop() override; + bool started() override; void AddAdaptationResource(rtc::scoped_refptr resource) override; std::vector> GetAdaptationResources() override; @@ -93,19 +95,23 @@ class VideoSendStream : public webrtc::VideoSendStream { private: friend class test::VideoSendStreamPeer; - class ConstructionTask; - absl::optional GetPacingFactorOverride() const; - rtc::ThreadChecker thread_checker_; - rtc::TaskQueue* const worker_queue_; + RTC_NO_UNIQUE_ADDRESS SequenceChecker thread_checker_; + rtc::TaskQueue* const rtp_transport_queue_; + RtpTransportControllerSendInterface* const transport_; rtc::Event thread_sync_event_; + rtc::scoped_refptr transport_queue_safety_ = + PendingTaskSafetyFlag::CreateDetached(); SendStatisticsProxy stats_proxy_; const VideoSendStream::Config config_; const VideoEncoderConfig::ContentType content_type_; - std::unique_ptr send_stream_; std::unique_ptr video_stream_encoder_; + EncoderRtcpFeedback encoder_feedback_; + RtpVideoSenderInterface* const rtp_video_sender_; + VideoSendStreamImpl send_stream_; + bool running_ RTC_GUARDED_BY(thread_checker_) = false; }; } // namespace internal diff --git a/video/video_send_stream_impl.cc b/video/video_send_stream_impl.cc index aeb197c223..3fc6b676dc 100644 --- a/video/video_send_stream_impl.cc +++ b/video/video_send_stream_impl.cc @@ -20,6 +20,7 @@ #include "api/crypto/crypto_options.h" #include "api/rtp_parameters.h" #include "api/scoped_refptr.h" +#include "api/sequence_checker.h" #include "api/video_codecs/video_codec.h" #include "call/rtp_transport_controller_send_interface.h" #include "call/video_send_stream.h" @@ -32,8 +33,7 @@ #include "rtc_base/experiments/rate_control_settings.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/synchronization/sequence_checker.h" -#include "rtc_base/thread_checker.h" +#include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/trace_event.h" #include "system_wrappers/include/clock.h" #include "system_wrappers/include/field_trial.h" @@ -131,33 +131,6 @@ int CalculateMaxPadBitrateBps(const std::vector& streams, return pad_up_to_bitrate_bps; } -RtpSenderFrameEncryptionConfig CreateFrameEncryptionConfig( - const VideoSendStream::Config* config) { - RtpSenderFrameEncryptionConfig frame_encryption_config; - frame_encryption_config.frame_encryptor = config->frame_encryptor; - frame_encryption_config.crypto_options = config->crypto_options; - return frame_encryption_config; -} - -RtpSenderObservers CreateObservers(RtcpRttStats* call_stats, - EncoderRtcpFeedback* encoder_feedback, - SendStatisticsProxy* stats_proxy, - SendDelayStats* send_delay_stats) { - RtpSenderObservers observers; - observers.rtcp_rtt_stats = call_stats; - observers.intra_frame_callback = encoder_feedback; - observers.rtcp_loss_notification_observer = encoder_feedback; - observers.rtcp_stats = stats_proxy; - observers.report_block_data_observer = stats_proxy; - observers.rtp_stats = stats_proxy; - observers.bitrate_observer = stats_proxy; - 
observers.frame_count_observer = stats_proxy_;
-  observers.rtcp_type_observer = stats_proxy_;
-  observers.send_delay_observer = stats_proxy_;
-  observers.send_packet_observer = send_delay_stats;
-  return observers;
-}
-
 absl::optional GetAlrSettings(
     VideoEncoderConfig::ContentType content_type) {
   if (content_type == VideoEncoderConfig::ContentType::kScreen) {
@@ -179,6 +152,44 @@ bool SameStreamsEnabled(const VideoBitrateAllocation& lhs,
   }
   return true;
 }
+
+// Returns an optional that has value iff TransportSeqNumExtensionConfigured
+// is `true` for the given video send stream config.
+absl::optional GetConfiguredPacingFactor(
+    const VideoSendStream::Config& config,
+    VideoEncoderConfig::ContentType content_type,
+    const PacingConfig& default_pacing_config) {
+  if (!TransportSeqNumExtensionConfigured(config))
+    return absl::nullopt;
+
+  absl::optional alr_settings =
+      GetAlrSettings(content_type);
+  if (alr_settings)
+    return alr_settings->pacing_factor;
+
+  RateControlSettings rate_control_settings =
+      RateControlSettings::ParseFromFieldTrials();
+  return rate_control_settings.GetPacingFactor().value_or(
+      default_pacing_config.pacing_factor);
+}
+
+uint32_t GetInitialEncoderMaxBitrate(int initial_encoder_max_bitrate) {
+  if (initial_encoder_max_bitrate > 0)
+    return rtc::dchecked_cast(initial_encoder_max_bitrate);
+
+  // TODO(srte): Make sure max bitrate is not set to negative values. We don't
+  // have any way to handle unset values in downstream code, such as the
+  // bitrate allocator. Previously -1 was implicitly cast to UINT32_MAX, a
+  // behaviour that is not safe. Converting to 10 Mbps should be safe for
+  // reasonable use cases as it allows adding the max of multiple streams
+  // without wrapping around.
+  const int kFallbackMaxBitrateBps = 10000000;
+  RTC_DLOG(LS_ERROR) << "ERROR: Initial encoder max bitrate = "
+                     << initial_encoder_max_bitrate << " which is <= 0!";
+  RTC_DLOG(LS_INFO) << "Using default encoder max bitrate = 10 Mbps";
+  return kFallbackMaxBitrateBps;
+}
+
 }  // namespace

 PacingConfig::PacingConfig()
@@ -194,162 +205,109 @@ PacingConfig::~PacingConfig() = default;
 VideoSendStreamImpl::VideoSendStreamImpl(
     Clock* clock,
     SendStatisticsProxy* stats_proxy,
-    rtc::TaskQueue* worker_queue,
-    RtcpRttStats* call_stats,
+    rtc::TaskQueue* rtp_transport_queue,
     RtpTransportControllerSendInterface* transport,
     BitrateAllocatorInterface* bitrate_allocator,
-    SendDelayStats* send_delay_stats,
     VideoStreamEncoderInterface* video_stream_encoder,
-    RtcEventLog* event_log,
     const VideoSendStream::Config* config,
    int initial_encoder_max_bitrate,
    double initial_encoder_bitrate_priority,
-    std::map suspended_ssrcs,
-    std::map suspended_payload_states,
    VideoEncoderConfig::ContentType content_type,
-    std::unique_ptr fec_controller)
+    RtpVideoSenderInterface* rtp_video_sender)
    : clock_(clock),
      has_alr_probing_(config->periodic_alr_bandwidth_probing ||
                       GetAlrSettings(content_type)),
      pacing_config_(PacingConfig()),
      stats_proxy_(stats_proxy),
      config_(config),
-      worker_queue_(worker_queue),
+      rtp_transport_queue_(rtp_transport_queue),
      timed_out_(false),
      transport_(transport),
      bitrate_allocator_(bitrate_allocator),
      disable_padding_(true),
      max_padding_bitrate_(0),
      encoder_min_bitrate_bps_(0),
+      encoder_max_bitrate_bps_(
+          GetInitialEncoderMaxBitrate(initial_encoder_max_bitrate)),
      encoder_target_rate_bps_(0),
      encoder_bitrate_priority_(initial_encoder_bitrate_priority),
-      has_packet_feedback_(false),
      video_stream_encoder_(video_stream_encoder),
-      encoder_feedback_(clock, config_->rtp.ssrcs,
video_stream_encoder), bandwidth_observer_(transport->GetBandwidthObserver()), - rtp_video_sender_( - transport_->CreateRtpVideoSender(suspended_ssrcs, - suspended_payload_states, - config_->rtp, - config_->rtcp_report_interval_ms, - config_->send_transport, - CreateObservers(call_stats, - &encoder_feedback_, - stats_proxy_, - send_delay_stats), - event_log, - std::move(fec_controller), - CreateFrameEncryptionConfig(config_), - config->frame_transformer)), - weak_ptr_factory_(this) { - video_stream_encoder->SetFecControllerOverride(rtp_video_sender_); - RTC_DCHECK_RUN_ON(worker_queue_); - RTC_LOG(LS_INFO) << "VideoSendStreamInternal: " << config_->ToString(); - weak_ptr_ = weak_ptr_factory_.GetWeakPtr(); - - encoder_feedback_.SetRtpVideoSender(rtp_video_sender_); - + rtp_video_sender_(rtp_video_sender), + configured_pacing_factor_( + GetConfiguredPacingFactor(*config_, content_type, pacing_config_)) { + RTC_DCHECK_GE(config_->rtp.payload_type, 0); + RTC_DCHECK_LE(config_->rtp.payload_type, 127); RTC_DCHECK(!config_->rtp.ssrcs.empty()); RTC_DCHECK(transport_); RTC_DCHECK_NE(initial_encoder_max_bitrate, 0); - - if (initial_encoder_max_bitrate > 0) { - encoder_max_bitrate_bps_ = - rtc::dchecked_cast(initial_encoder_max_bitrate); - } else { - // TODO(srte): Make sure max bitrate is not set to negative values. We don't - // have any way to handle unset values in downstream code, such as the - // bitrate allocator. Previously -1 was implicitly casted to UINT32_MAX, a - // behaviour that is not safe. Converting to 10 Mbps should be safe for - // reasonable use cases as it allows adding the max of multiple streams - // without wrappping around. - const int kFallbackMaxBitrateBps = 10000000; - RTC_DLOG(LS_ERROR) << "ERROR: Initial encoder max bitrate = " - << initial_encoder_max_bitrate << " which is <= 0!"; - RTC_DLOG(LS_INFO) << "Using default encoder max bitrate = 10 Mbps"; - encoder_max_bitrate_bps_ = kFallbackMaxBitrateBps; - } + RTC_LOG(LS_INFO) << "VideoSendStreamImpl: " << config_->ToString(); RTC_CHECK(AlrExperimentSettings::MaxOneFieldTrialEnabled()); + + // Only request rotation at the source when we positively know that the remote + // side doesn't support the rotation extension. This allows us to prepare the + // encoder in the expectation that rotation is supported - which is the common + // case. + bool rotation_applied = absl::c_none_of( + config_->rtp.extensions, [](const RtpExtension& extension) { + return extension.uri == RtpExtension::kVideoRotationUri; + }); + + video_stream_encoder_->SetSink(this, rotation_applied); + + absl::optional enable_alr_bw_probing; + // If send-side BWE is enabled, check if we should apply updated probing and // pacing settings. 
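// Summary of the pacing decision implemented by GetConfiguredPacingFactor()
// above and the branch below (an illustrative restatement, not new behaviour):
// without the transport-wide sequence number extension there is no pacing
// factor override and no queue time limit; with it, ALR experiment settings
// for the content type take precedence, otherwise the WebRTC-VideoRateControl
// field trial (or PacingConfig's default) supplies the pacing factor, and
// `periodic_alr_bandwidth_probing` can still force ALR probing on.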
- if (TransportSeqNumExtensionConfigured(*config_)) { - has_packet_feedback_ = true; - + if (configured_pacing_factor_) { absl::optional alr_settings = GetAlrSettings(content_type); + int queue_time_limit_ms; if (alr_settings) { - transport->EnablePeriodicAlrProbing(true); - transport->SetPacingFactor(alr_settings->pacing_factor); - configured_pacing_factor_ = alr_settings->pacing_factor; - transport->SetQueueTimeLimit(alr_settings->max_paced_queue_time); + enable_alr_bw_probing = true; + queue_time_limit_ms = alr_settings->max_paced_queue_time; } else { RateControlSettings rate_control_settings = RateControlSettings::ParseFromFieldTrials(); - - transport->EnablePeriodicAlrProbing( - rate_control_settings.UseAlrProbing()); - const double pacing_factor = - rate_control_settings.GetPacingFactor().value_or( - pacing_config_.pacing_factor); - transport->SetPacingFactor(pacing_factor); - configured_pacing_factor_ = pacing_factor; - transport->SetQueueTimeLimit(pacing_config_.max_pacing_delay.Get().ms()); + enable_alr_bw_probing = rate_control_settings.UseAlrProbing(); + queue_time_limit_ms = pacing_config_.max_pacing_delay.Get().ms(); } + + transport->SetQueueTimeLimit(queue_time_limit_ms); } if (config_->periodic_alr_bandwidth_probing) { - transport->EnablePeriodicAlrProbing(true); + enable_alr_bw_probing = config_->periodic_alr_bandwidth_probing; } - RTC_DCHECK_GE(config_->rtp.payload_type, 0); - RTC_DCHECK_LE(config_->rtp.payload_type, 127); - - video_stream_encoder_->SetStartBitrate( - bitrate_allocator_->GetStartBitrate(this)); -} - -VideoSendStreamImpl::~VideoSendStreamImpl() { - RTC_DCHECK_RUN_ON(worker_queue_); - RTC_DCHECK(!rtp_video_sender_->IsActive()) - << "VideoSendStreamImpl::Stop not called"; - RTC_LOG(LS_INFO) << "~VideoSendStreamInternal: " << config_->ToString(); - transport_->DestroyRtpVideoSender(rtp_video_sender_); -} - -void VideoSendStreamImpl::RegisterProcessThread( - ProcessThread* module_process_thread) { - // Called on libjingle's worker thread (not worker_queue_), as part of the - // initialization steps. That's also the correct thread/queue for setting the - // state for |video_stream_encoder_|. - - // Only request rotation at the source when we positively know that the remote - // side doesn't support the rotation extension. This allows us to prepare the - // encoder in the expectation that rotation is supported - which is the common - // case. - bool rotation_applied = absl::c_none_of( - config_->rtp.extensions, [](const RtpExtension& extension) { - return extension.uri == RtpExtension::kVideoRotationUri; - }); + if (enable_alr_bw_probing) { + transport->EnablePeriodicAlrProbing(*enable_alr_bw_probing); + } - video_stream_encoder_->SetSink(this, rotation_applied); + rtp_transport_queue_->PostTask(ToQueuedTask(transport_queue_safety_, [this] { + if (configured_pacing_factor_) + transport_->SetPacingFactor(*configured_pacing_factor_); - rtp_video_sender_->RegisterProcessThread(module_process_thread); + video_stream_encoder_->SetStartBitrate( + bitrate_allocator_->GetStartBitrate(this)); + })); } -void VideoSendStreamImpl::DeRegisterProcessThread() { - rtp_video_sender_->DeRegisterProcessThread(); +VideoSendStreamImpl::~VideoSendStreamImpl() { + RTC_DCHECK_RUN_ON(&thread_checker_); + RTC_LOG(LS_INFO) << "~VideoSendStreamImpl: " << config_->ToString(); } void VideoSendStreamImpl::DeliverRtcp(const uint8_t* packet, size_t length) { // Runs on a network thread. 
- RTC_DCHECK(!worker_queue_->IsCurrent()); + RTC_DCHECK(!rtp_transport_queue_->IsCurrent()); rtp_video_sender_->DeliverRtcp(packet, length); } void VideoSendStreamImpl::UpdateActiveSimulcastLayers( const std::vector active_layers) { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); bool previously_active = rtp_video_sender_->IsActive(); rtp_video_sender_->SetActiveModules(active_layers); if (!rtp_video_sender_->IsActive() && previously_active) { @@ -362,17 +320,21 @@ void VideoSendStreamImpl::UpdateActiveSimulcastLayers( } void VideoSendStreamImpl::Start() { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); RTC_LOG(LS_INFO) << "VideoSendStream::Start"; if (rtp_video_sender_->IsActive()) return; + TRACE_EVENT_INSTANT0("webrtc", "VideoSendStream::Start"); rtp_video_sender_->SetActive(true); StartupVideoSendStream(); } void VideoSendStreamImpl::StartupVideoSendStream() { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); + + transport_queue_safety_->SetAlive(); + bitrate_allocator_->AddObserver(this, GetAllocationConfig()); // Start monitoring encoder activity. { @@ -381,8 +343,8 @@ void VideoSendStreamImpl::StartupVideoSendStream() { activity_ = false; timed_out_ = false; check_encoder_activity_task_ = RepeatingTaskHandle::DelayedStart( - worker_queue_->Get(), kEncoderTimeOut, [this] { - RTC_DCHECK_RUN_ON(worker_queue_); + rtp_transport_queue_->Get(), kEncoderTimeOut, [this] { + RTC_DCHECK_RUN_ON(rtp_transport_queue_); if (!activity_) { if (!timed_out_) { SignalEncoderTimedOut(); @@ -402,25 +364,29 @@ void VideoSendStreamImpl::StartupVideoSendStream() { } void VideoSendStreamImpl::Stop() { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); RTC_LOG(LS_INFO) << "VideoSendStreamImpl::Stop"; if (!rtp_video_sender_->IsActive()) return; + + RTC_DCHECK(transport_queue_safety_->alive()); TRACE_EVENT_INSTANT0("webrtc", "VideoSendStream::Stop"); rtp_video_sender_->SetActive(false); StopVideoSendStream(); } +// RTC_RUN_ON(rtp_transport_queue_) void VideoSendStreamImpl::StopVideoSendStream() { bitrate_allocator_->RemoveObserver(this); check_encoder_activity_task_.Stop(); video_stream_encoder_->OnBitrateUpdated(DataRate::Zero(), DataRate::Zero(), DataRate::Zero(), 0, 0, 0); stats_proxy_->OnSetEncoderTargetRate(0); + transport_queue_safety_->SetNotAlive(); } void VideoSendStreamImpl::SignalEncoderTimedOut() { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); // If the encoder has not produced anything the last kEncoderTimeOut and it // is supposed to, deregister as BitrateAllocatorObserver. This can happen // if a camera stops producing frames. 
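The encoder-activity check in StartupVideoSendStream() above polls on the transport queue through RepeatingTaskHandle. A self-contained sketch of that scheduling pattern, with hypothetical names (only the RepeatingTaskHandle/TimeDelta calls are real WebRTC API; StartWatchdog and on_check are illustrative):

  #include <functional>
  #include <utility>

  #include "api/task_queue/task_queue_base.h"
  #include "api/units/time_delta.h"
  #include "rtc_base/task_utils/repeating_task.h"

  // Runs `on_check` on `queue` every `interval`; the closure's return value is
  // the delay until the next run. The returned handle must be stopped on the
  // same queue.
  webrtc::RepeatingTaskHandle StartWatchdog(webrtc::TaskQueueBase* queue,
                                            webrtc::TimeDelta interval,
                                            std::function<void()> on_check) {
    return webrtc::RepeatingTaskHandle::DelayedStart(
        queue, interval, [interval, on_check = std::move(on_check)] {
          on_check();
          return interval;  // Reschedule the next check.
        });
  }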
@@ -432,17 +398,14 @@ void VideoSendStreamImpl::SignalEncoderTimedOut() { void VideoSendStreamImpl::OnBitrateAllocationUpdated( const VideoBitrateAllocation& allocation) { - if (!worker_queue_->IsCurrent()) { - auto ptr = weak_ptr_; - worker_queue_->PostTask([=] { - if (!ptr.get()) - return; - ptr->OnBitrateAllocationUpdated(allocation); - }); + if (!rtp_transport_queue_->IsCurrent()) { + rtp_transport_queue_->PostTask(ToQueuedTask(transport_queue_safety_, [=] { + OnBitrateAllocationUpdated(allocation); + })); return; } - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); int64_t now_ms = clock_->TimeInMilliseconds(); if (encoder_target_rate_bps_ != 0) { @@ -487,7 +450,7 @@ void VideoSendStreamImpl::OnVideoLayersAllocationUpdated( } void VideoSendStreamImpl::SignalEncoderActive() { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); if (rtp_video_sender_->IsActive()) { RTC_LOG(LS_INFO) << "SignalEncoderActive, Encoder is active."; bitrate_allocator_->AddObserver(this, GetAllocationConfig()); @@ -509,21 +472,20 @@ void VideoSendStreamImpl::OnEncoderConfigurationChanged( bool is_svc, VideoEncoderConfig::ContentType content_type, int min_transmit_bitrate_bps) { - if (!worker_queue_->IsCurrent()) { - rtc::WeakPtr send_stream = weak_ptr_; - worker_queue_->PostTask([send_stream, streams, is_svc, content_type, - min_transmit_bitrate_bps]() mutable { - if (send_stream) { - send_stream->OnEncoderConfigurationChanged( - std::move(streams), is_svc, content_type, min_transmit_bitrate_bps); - } - }); + if (!rtp_transport_queue_->IsCurrent()) { + rtp_transport_queue_->PostTask(ToQueuedTask( + transport_queue_safety_, + [this, streams = std::move(streams), is_svc, content_type, + min_transmit_bitrate_bps]() mutable { + OnEncoderConfigurationChanged(std::move(streams), is_svc, + content_type, min_transmit_bitrate_bps); + })); return; } RTC_DCHECK_GE(config_->rtp.ssrcs.size(), streams.size()); TRACE_EVENT0("webrtc", "VideoSendStream::OnEncoderConfigurationChanged"); - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); const VideoCodecType codec_type = PayloadStringToCodecType(config_->rtp.payload_name); @@ -586,14 +548,15 @@ EncodedImageCallback::Result VideoSendStreamImpl::OnEncodedImage( auto enable_padding_task = [this]() { if (disable_padding_) { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); disable_padding_ = false; // To ensure that padding bitrate is propagated to the bitrate allocator. SignalEncoderActive(); } }; - if (!worker_queue_->IsCurrent()) { - worker_queue_->PostTask(enable_padding_task); + if (!rtp_transport_queue_->IsCurrent()) { + rtp_transport_queue_->PostTask( + ToQueuedTask(transport_queue_safety_, std::move(enable_padding_task))); } else { enable_padding_task(); } @@ -603,18 +566,16 @@ EncodedImageCallback::Result VideoSendStreamImpl::OnEncodedImage( rtp_video_sender_->OnEncodedImage(encoded_image, codec_specific_info); // Check if there's a throttled VideoBitrateAllocation that we should try // sending. 
- rtc::WeakPtr send_stream = weak_ptr_; - auto update_task = [send_stream]() { - if (send_stream) { - RTC_DCHECK_RUN_ON(send_stream->worker_queue_); - auto& context = send_stream->video_bitrate_allocation_context_; - if (context && context->throttled_allocation) { - send_stream->OnBitrateAllocationUpdated(*context->throttled_allocation); - } + auto update_task = [this]() { + RTC_DCHECK_RUN_ON(rtp_transport_queue_); + auto& context = video_bitrate_allocation_context_; + if (context && context->throttled_allocation) { + OnBitrateAllocationUpdated(*context->throttled_allocation); } }; - if (!worker_queue_->IsCurrent()) { - worker_queue_->PostTask(update_task); + if (!rtp_transport_queue_->IsCurrent()) { + rtp_transport_queue_->PostTask( + ToQueuedTask(transport_queue_safety_, std::move(update_task))); } else { update_task(); } @@ -637,7 +598,7 @@ std::map VideoSendStreamImpl::GetRtpPayloadStates() } uint32_t VideoSendStreamImpl::OnBitrateUpdated(BitrateAllocationUpdate update) { - RTC_DCHECK_RUN_ON(worker_queue_); + RTC_DCHECK_RUN_ON(rtp_transport_queue_); RTC_DCHECK(rtp_video_sender_->IsActive()) << "VideoSendStream::Start has not been called."; diff --git a/video/video_send_stream_impl.h b/video/video_send_stream_impl.h index 41a7859a77..babf1dcfe5 100644 --- a/video/video_send_stream_impl.h +++ b/video/video_send_stream_impl.h @@ -19,8 +19,6 @@ #include #include "absl/types/optional.h" -#include "api/fec_controller.h" -#include "api/rtc_event_log/rtc_event_log.h" #include "api/video/encoded_image.h" #include "api/video/video_bitrate_allocation.h" #include "api/video/video_bitrate_allocator.h" @@ -33,18 +31,14 @@ #include "call/rtp_video_sender_interface.h" #include "modules/include/module_common_types.h" #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h" -#include "modules/utility/include/process_thread.h" #include "modules/video_coding/include/video_codec_interface.h" #include "rtc_base/experiments/field_trial_parser.h" -#include "rtc_base/synchronization/mutex.h" +#include "rtc_base/system/no_unique_address.h" #include "rtc_base/task_queue.h" +#include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/task_utils/repeating_task.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/weak_ptr.h" -#include "video/encoder_rtcp_feedback.h" -#include "video/send_delay_stats.h" #include "video/send_statistics_proxy.h" -#include "video/video_send_stream.h" namespace webrtc { namespace internal { @@ -60,42 +54,28 @@ struct PacingConfig { }; // VideoSendStreamImpl implements internal::VideoSendStream. -// It is created and destroyed on |worker_queue|. The intent is to decrease the -// need for locking and to ensure methods are called in sequence. -// Public methods except |DeliverRtcp| must be called on |worker_queue|. +// It is created and destroyed on `rtp_transport_queue`. The intent is to +// decrease the need for locking and to ensure methods are called in sequence. +// Public methods except `DeliverRtcp` must be called on `rtp_transport_queue`. // DeliverRtcp is called on the libjingle worker thread or a network thread. // An encoder may deliver frames through the EncodedImageCallback on an // arbitrary thread. 
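The threading contract in the comment above is enforced with the ToQueuedTask + PendingTaskSafetyFlag pattern that replaces the WeakPtr/WeakPtrFactory plumbing removed in this file. A minimal, hypothetical sketch of the pattern (QueueBound and its members are illustrative, not WebRTC API; the flag, ToQueuedTask and TaskQueue calls are the ones used in this patch):

  #include "api/scoped_refptr.h"
  #include "rtc_base/task_queue.h"
  #include "rtc_base/task_utils/pending_task_safety_flag.h"
  #include "rtc_base/task_utils/to_queued_task.h"

  class QueueBound {
   public:
    explicit QueueBound(rtc::TaskQueue* transport_queue)
        : transport_queue_(transport_queue) {}

    void PostGuarded(int value) {
      transport_queue_->PostTask(
          webrtc::ToQueuedTask(safety_, [this, value] {
            last_value_ = value;  // Runs on transport_queue_, only while alive.
          }));
    }

    void Shutdown() {
      // Flip the flag on the queue itself, as StopVideoSendStream() does, so
      // tasks that were posted earlier but have not yet run become no-ops.
      transport_queue_->PostTask([flag = safety_] { flag->SetNotAlive(); });
    }

   private:
    rtc::TaskQueue* const transport_queue_;
    rtc::scoped_refptr<webrtc::PendingTaskSafetyFlag> safety_ =
        webrtc::PendingTaskSafetyFlag::CreateDetached();
    int last_value_ = 0;
  };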
class VideoSendStreamImpl : public webrtc::BitrateAllocatorObserver, public VideoStreamEncoderInterface::EncoderSink { public: - VideoSendStreamImpl( - Clock* clock, - SendStatisticsProxy* stats_proxy, - rtc::TaskQueue* worker_queue, - RtcpRttStats* call_stats, - RtpTransportControllerSendInterface* transport, - BitrateAllocatorInterface* bitrate_allocator, - SendDelayStats* send_delay_stats, - VideoStreamEncoderInterface* video_stream_encoder, - RtcEventLog* event_log, - const VideoSendStream::Config* config, - int initial_encoder_max_bitrate, - double initial_encoder_bitrate_priority, - std::map suspended_ssrcs, - std::map suspended_payload_states, - VideoEncoderConfig::ContentType content_type, - std::unique_ptr fec_controller); + VideoSendStreamImpl(Clock* clock, + SendStatisticsProxy* stats_proxy, + rtc::TaskQueue* rtp_transport_queue, + RtpTransportControllerSendInterface* transport, + BitrateAllocatorInterface* bitrate_allocator, + VideoStreamEncoderInterface* video_stream_encoder, + const VideoSendStream::Config* config, + int initial_encoder_max_bitrate, + double initial_encoder_bitrate_priority, + VideoEncoderConfig::ContentType content_type, + RtpVideoSenderInterface* rtp_video_sender); ~VideoSendStreamImpl() override; - // RegisterProcessThread register |module_process_thread| with those objects - // that use it. Registration has to happen on the thread were - // |module_process_thread| was created (libjingle's worker thread). - // TODO(perkj): Replace the use of |module_process_thread| with a TaskQueue, - // maybe |worker_queue|. - void RegisterProcessThread(ProcessThread* module_process_thread); - void DeRegisterProcessThread(); - void DeliverRtcp(const uint8_t* packet, size_t length); void UpdateActiveSimulcastLayers(const std::vector active_layers); void Start(); @@ -106,7 +86,9 @@ class VideoSendStreamImpl : public webrtc::BitrateAllocatorObserver, std::map GetRtpPayloadStates() const; - absl::optional configured_pacing_factor_; + const absl::optional& configured_pacing_factor() const { + return configured_pacing_factor_; + } private: // Implements BitrateAllocatorObserver. @@ -138,14 +120,16 @@ class VideoSendStreamImpl : public webrtc::BitrateAllocatorObserver, void StartupVideoSendStream(); // Removes the bitrate observer, stops monitoring and notifies the video // encoder of the bitrate update. 
- void StopVideoSendStream() RTC_RUN_ON(worker_queue_); + void StopVideoSendStream() RTC_RUN_ON(rtp_transport_queue_); void ConfigureProtection(); void ConfigureSsrcs(); void SignalEncoderTimedOut(); void SignalEncoderActive(); MediaStreamAllocationConfig GetAllocationConfig() const - RTC_RUN_ON(worker_queue_); + RTC_RUN_ON(rtp_transport_queue_); + + RTC_NO_UNIQUE_ADDRESS SequenceChecker thread_checker_; Clock* const clock_; const bool has_alr_probing_; const PacingConfig pacing_config_; @@ -153,40 +137,31 @@ class VideoSendStreamImpl : public webrtc::BitrateAllocatorObserver, SendStatisticsProxy* const stats_proxy_; const VideoSendStream::Config* const config_; - rtc::TaskQueue* const worker_queue_; + rtc::TaskQueue* const rtp_transport_queue_; RepeatingTaskHandle check_encoder_activity_task_ - RTC_GUARDED_BY(worker_queue_); + RTC_GUARDED_BY(rtp_transport_queue_); std::atomic_bool activity_; - bool timed_out_ RTC_GUARDED_BY(worker_queue_); + bool timed_out_ RTC_GUARDED_BY(rtp_transport_queue_); RtpTransportControllerSendInterface* const transport_; BitrateAllocatorInterface* const bitrate_allocator_; - Mutex ivf_writers_mutex_; - bool disable_padding_; int max_padding_bitrate_; int encoder_min_bitrate_bps_; uint32_t encoder_max_bitrate_bps_; uint32_t encoder_target_rate_bps_; double encoder_bitrate_priority_; - bool has_packet_feedback_; VideoStreamEncoderInterface* const video_stream_encoder_; - EncoderRtcpFeedback encoder_feedback_; RtcpBandwidthObserver* const bandwidth_observer_; RtpVideoSenderInterface* const rtp_video_sender_; - // |weak_ptr_| to our self. This is used since we can not call - // |weak_ptr_factory_.GetWeakPtr| from multiple sequences but it is ok to copy - // an existing WeakPtr. - rtc::WeakPtr weak_ptr_; - // |weak_ptr_factory_| must be declared last to make sure all WeakPtr's are - // invalidated before any other members are destroyed. - rtc::WeakPtrFactory weak_ptr_factory_; + rtc::scoped_refptr transport_queue_safety_ = + PendingTaskSafetyFlag::CreateDetached(); // Context for the most recent and last sent video bitrate allocation. Used to // throttle sending of similar bitrate allocations. 
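Several of the members above are annotated directly against `rtp_transport_queue_` rather than a mutex. A small sketch of how RTC_GUARDED_BY and RTC_DCHECK_RUN_ON are typically combined for that (hypothetical class, assuming the usual rtc_base headers; the annotation documents the invariant for clang's thread-safety analysis, the macro enforces it at runtime in debug builds):

  #include "api/sequence_checker.h"
  #include "rtc_base/task_queue.h"
  #include "rtc_base/thread_annotations.h"

  class QueueGuardedCounter {
   public:
    explicit QueueGuardedCounter(rtc::TaskQueue* queue) : queue_(queue) {}

    void Increment() {
      queue_->PostTask([this] {
        RTC_DCHECK_RUN_ON(queue_);  // Must execute on queue_.
        ++counter_;
      });
    }

   private:
    rtc::TaskQueue* const queue_;
    int counter_ RTC_GUARDED_BY(queue_) = 0;
  };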
@@ -196,7 +171,8 @@ class VideoSendStreamImpl : public webrtc::BitrateAllocatorObserver, int64_t last_send_time_ms; }; absl::optional video_bitrate_allocation_context_ - RTC_GUARDED_BY(worker_queue_); + RTC_GUARDED_BY(rtp_transport_queue_); + const absl::optional configured_pacing_factor_; }; } // namespace internal } // namespace webrtc diff --git a/video/video_send_stream_impl_unittest.cc b/video/video_send_stream_impl_unittest.cc index ee303b4eac..30a4aacd92 100644 --- a/video/video_send_stream_impl_unittest.cc +++ b/video/video_send_stream_impl_unittest.cc @@ -31,6 +31,7 @@ #include "test/mock_transport.h" #include "video/call_stats.h" #include "video/test/mock_video_stream_encoder.h" +#include "video/video_send_stream.h" namespace webrtc { @@ -61,8 +62,6 @@ std::string GetAlrProbingExperimentString() { } class MockRtpVideoSender : public RtpVideoSenderInterface { public: - MOCK_METHOD(void, RegisterProcessThread, (ProcessThread*), (override)); - MOCK_METHOD(void, DeRegisterProcessThread, (), (override)); MOCK_METHOD(void, SetActive, (bool), (override)); MOCK_METHOD(void, SetActiveModules, (const std::vector), (override)); MOCK_METHOD(bool, IsActive, (), (override)); @@ -145,17 +144,24 @@ class VideoSendStreamImplTest : public ::testing::Test { int initial_encoder_max_bitrate, double initial_encoder_bitrate_priority, VideoEncoderConfig::ContentType content_type) { + RTC_DCHECK(!test_queue_.IsCurrent()); + EXPECT_CALL(bitrate_allocator_, GetStartBitrate(_)) .WillOnce(Return(123000)); + std::map suspended_ssrcs; std::map suspended_payload_states; - return std::make_unique( - &clock_, &stats_proxy_, &test_queue_, &call_stats_, - &transport_controller_, &bitrate_allocator_, &send_delay_stats_, - &video_stream_encoder_, &event_log_, &config_, + auto ret = std::make_unique( + &clock_, &stats_proxy_, &test_queue_, &transport_controller_, + &bitrate_allocator_, &video_stream_encoder_, &config_, initial_encoder_max_bitrate, initial_encoder_bitrate_priority, - suspended_ssrcs, suspended_payload_states, content_type, - std::make_unique(&clock_)); + content_type, &rtp_video_sender_); + + // The call to GetStartBitrate() executes asynchronously on the tq. 
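// (Roughly: WaitForPreviouslyPostedTasks() blocks until every task already
// posted to `test_queue_` has run, so the GetStartBitrate() call made from the
// constructor is guaranteed to have executed before the expectation is
// verified below.)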
+ test_queue_.WaitForPreviouslyPostedTasks(); + testing::Mock::VerifyAndClearExpectations(&bitrate_allocator_); + + return ret; } protected: @@ -179,22 +185,22 @@ class VideoSendStreamImplTest : public ::testing::Test { }; TEST_F(VideoSendStreamImplTest, RegistersAsBitrateObserverOnStart) { + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kRealtimeVideo); + const bool kSuspend = false; + config_.suspend_below_min_bitrate = kSuspend; + EXPECT_CALL(bitrate_allocator_, AddObserver(vss_impl.get(), _)) + .WillOnce(Invoke( + [&](BitrateAllocatorObserver*, MediaStreamAllocationConfig config) { + EXPECT_EQ(config.min_bitrate_bps, 0u); + EXPECT_EQ(config.max_bitrate_bps, kDefaultInitialBitrateBps); + EXPECT_EQ(config.pad_up_bitrate_bps, 0u); + EXPECT_EQ(config.enforce_min_bitrate, !kSuspend); + EXPECT_EQ(config.bitrate_priority, kDefaultBitratePriority); + })); test_queue_.SendTask( - [this] { - const bool kSuspend = false; - config_.suspend_below_min_bitrate = kSuspend; - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kRealtimeVideo); - EXPECT_CALL(bitrate_allocator_, AddObserver(vss_impl.get(), _)) - .WillOnce(Invoke([&](BitrateAllocatorObserver*, - MediaStreamAllocationConfig config) { - EXPECT_EQ(config.min_bitrate_bps, 0u); - EXPECT_EQ(config.max_bitrate_bps, kDefaultInitialBitrateBps); - EXPECT_EQ(config.pad_up_bitrate_bps, 0u); - EXPECT_EQ(config.enforce_min_bitrate, !kSuspend); - EXPECT_EQ(config.bitrate_priority, kDefaultBitratePriority); - })); + [&] { vss_impl->Start(); EXPECT_CALL(bitrate_allocator_, RemoveObserver(vss_impl.get())) .Times(1); @@ -204,15 +210,16 @@ TEST_F(VideoSendStreamImplTest, RegistersAsBitrateObserverOnStart) { } TEST_F(VideoSendStreamImplTest, UpdatesObserverOnConfigurationChange) { + const bool kSuspend = false; + config_.suspend_below_min_bitrate = kSuspend; + config_.rtp.extensions.emplace_back(RtpExtension::kTransportSequenceNumberUri, + 1); + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kRealtimeVideo); + test_queue_.SendTask( - [this] { - const bool kSuspend = false; - config_.suspend_below_min_bitrate = kSuspend; - config_.rtp.extensions.emplace_back( - RtpExtension::kTransportSequenceNumberUri, 1); - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kRealtimeVideo); + [&] { vss_impl->Start(); // QVGA + VGA configuration matching defaults in @@ -269,16 +276,16 @@ TEST_F(VideoSendStreamImplTest, UpdatesObserverOnConfigurationChange) { } TEST_F(VideoSendStreamImplTest, UpdatesObserverOnConfigurationChangeWithAlr) { + const bool kSuspend = false; + config_.suspend_below_min_bitrate = kSuspend; + config_.rtp.extensions.emplace_back(RtpExtension::kTransportSequenceNumberUri, + 1); + config_.periodic_alr_bandwidth_probing = true; + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { - const bool kSuspend = false; - config_.suspend_below_min_bitrate = kSuspend; - config_.rtp.extensions.emplace_back( - RtpExtension::kTransportSequenceNumberUri, 1); - config_.periodic_alr_bandwidth_probing = true; - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - 
VideoEncoderConfig::ContentType::kScreen); + [&] { vss_impl->Start(); // Simulcast screenshare. @@ -341,11 +348,12 @@ TEST_F(VideoSendStreamImplTest, test::ScopedFieldTrials hysteresis_experiment( "WebRTC-VideoRateControl/video_hysteresis:1.25/"); + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kRealtimeVideo); + test_queue_.SendTask( - [this] { - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kRealtimeVideo); + [&] { vss_impl->Start(); // 2-layer video simulcast. @@ -401,17 +409,17 @@ TEST_F(VideoSendStreamImplTest, TEST_F(VideoSendStreamImplTest, SetsScreensharePacingFactorWithFeedback) { test::ScopedFieldTrials alr_experiment(GetAlrProbingExperimentString()); + constexpr int kId = 1; + config_.rtp.extensions.emplace_back(RtpExtension::kTransportSequenceNumberUri, + kId); + EXPECT_CALL(transport_controller_, + SetPacingFactor(kAlrProbingExperimentPaceMultiplier)) + .Times(1); + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { - constexpr int kId = 1; - config_.rtp.extensions.emplace_back( - RtpExtension::kTransportSequenceNumberUri, kId); - EXPECT_CALL(transport_controller_, - SetPacingFactor(kAlrProbingExperimentPaceMultiplier)) - .Times(1); - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kScreen); + [&] { vss_impl->Start(); vss_impl->Stop(); }, @@ -420,12 +428,12 @@ TEST_F(VideoSendStreamImplTest, SetsScreensharePacingFactorWithFeedback) { TEST_F(VideoSendStreamImplTest, DoesNotSetPacingFactorWithoutFeedback) { test::ScopedFieldTrials alr_experiment(GetAlrProbingExperimentString()); + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { + [&] { EXPECT_CALL(transport_controller_, SetPacingFactor(_)).Times(0); - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kScreen); vss_impl->Start(); vss_impl->Stop(); }, @@ -433,12 +441,12 @@ TEST_F(VideoSendStreamImplTest, DoesNotSetPacingFactorWithoutFeedback) { } TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationWhenEnabled) { + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { + [&] { EXPECT_CALL(transport_controller_, SetPacingFactor(_)).Times(0); - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kScreen); VideoStreamEncoderInterface::EncoderSink* const sink = static_cast( vss_impl.get()); @@ -483,11 +491,11 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationWhenEnabled) { } TEST_F(VideoSendStreamImplTest, ThrottlesVideoBitrateAllocationWhenTooSimilar) { + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kScreen); + [&] { vss_impl->Start(); // Unpause encoder, to allows allocations 
to be passed through. const uint32_t kBitrateBps = 100000; @@ -529,8 +537,8 @@ TEST_F(VideoSendStreamImplTest, ThrottlesVideoBitrateAllocationWhenTooSimilar) { .Times(1); sink->OnBitrateAllocationUpdated(updated_alloc); - // This is now a decrease compared to last forward allocation, forward - // immediately. + // This is now a decrease compared to last forward allocation, + // forward immediately. updated_alloc.SetBitrate(0, 0, base_layer_min_update_bitrate_bps - 1); EXPECT_CALL(rtp_video_sender_, OnBitrateAllocationUpdated(updated_alloc)) @@ -543,11 +551,11 @@ TEST_F(VideoSendStreamImplTest, ThrottlesVideoBitrateAllocationWhenTooSimilar) { } TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationOnLayerChange) { + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kScreen); + [&] { vss_impl->Start(); // Unpause encoder, to allows allocations to be passed through. const uint32_t kBitrateBps = 100000; @@ -572,8 +580,8 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationOnLayerChange) { .Times(1); sink->OnBitrateAllocationUpdated(alloc); - // Move some bitrate from one layer to a new one, but keep sum the same. - // Since layout has changed, immediately trigger forward. + // Move some bitrate from one layer to a new one, but keep sum the + // same. Since layout has changed, immediately trigger forward. VideoBitrateAllocation updated_alloc = alloc; updated_alloc.SetBitrate(2, 0, 10000); updated_alloc.SetBitrate(1, 1, alloc.GetBitrate(1, 1) - 10000); @@ -589,11 +597,11 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationOnLayerChange) { } TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationAfterTimeout) { + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kScreen); test_queue_.SendTask( - [this] { - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kScreen); + [&] { vss_impl->Start(); const uint32_t kBitrateBps = 100000; // Unpause encoder, to allows allocations to be passed through. @@ -639,7 +647,8 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationAfterTimeout) { clock_.AdvanceTimeMicroseconds(kMaxVbaThrottleTimeMs * 1000); { - // Sending similar allocation again after timeout, should forward. + // Sending similar allocation again after timeout, should + // forward. EXPECT_CALL(rtp_video_sender_, OnBitrateAllocationUpdated(alloc)) .Times(1); sink->OnBitrateAllocationUpdated(alloc); @@ -661,8 +670,8 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationAfterTimeout) { } { - // Advance time and send encoded image, this should wake up and send - // cached bitrate allocation. + // Advance time and send encoded image, this should wake up and + // send cached bitrate allocation. clock_.AdvanceTimeMicroseconds(kMaxVbaThrottleTimeMs * 1000); EXPECT_CALL(rtp_video_sender_, OnBitrateAllocationUpdated(alloc)) .Times(1); @@ -671,8 +680,8 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationAfterTimeout) { } { - // Advance time and send encoded image, there should be no cached - // allocation to send. + // Advance time and send encoded image, there should be no + // cached allocation to send. 
clock_.AdvanceTimeMicroseconds(kMaxVbaThrottleTimeMs * 1000); EXPECT_CALL(rtp_video_sender_, OnBitrateAllocationUpdated(alloc)) .Times(0); @@ -686,15 +695,15 @@ TEST_F(VideoSendStreamImplTest, ForwardsVideoBitrateAllocationAfterTimeout) { } TEST_F(VideoSendStreamImplTest, CallsVideoStreamEncoderOnBitrateUpdate) { + const bool kSuspend = false; + config_.suspend_below_min_bitrate = kSuspend; + config_.rtp.extensions.emplace_back(RtpExtension::kTransportSequenceNumberUri, + 1); + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kRealtimeVideo); test_queue_.SendTask( - [this] { - const bool kSuspend = false; - config_.suspend_below_min_bitrate = kSuspend; - config_.rtp.extensions.emplace_back( - RtpExtension::kTransportSequenceNumberUri, 1); - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kRealtimeVideo); + [&] { vss_impl->Start(); VideoStream qvga_stream; @@ -733,8 +742,8 @@ TEST_F(VideoSendStreamImplTest, CallsVideoStreamEncoderOnBitrateUpdate) { static_cast(vss_impl.get()) ->OnBitrateUpdated(update); - // Test allocation where the link allocation is larger than the target, - // meaning we have some headroom on the link. + // Test allocation where the link allocation is larger than the + // target, meaning we have some headroom on the link. const DataRate qvga_max_bitrate = DataRate::BitsPerSec(qvga_stream.max_bitrate_bps); const DataRate headroom = DataRate::BitsPerSec(50000); @@ -750,8 +759,8 @@ TEST_F(VideoSendStreamImplTest, CallsVideoStreamEncoderOnBitrateUpdate) { static_cast(vss_impl.get()) ->OnBitrateUpdated(update); - // Add protection bitrate to the mix, this should be subtracted from the - // headroom. + // Add protection bitrate to the mix, this should be subtracted + // from the headroom. const uint32_t protection_bitrate_bps = 10000; EXPECT_CALL(rtp_video_sender_, GetProtectionBitrateBps()) .WillOnce(Return(protection_bitrate_bps)); @@ -791,14 +800,11 @@ TEST_F(VideoSendStreamImplTest, CallsVideoStreamEncoderOnBitrateUpdate) { TEST_F(VideoSendStreamImplTest, DisablesPaddingOnPausedEncoder) { int padding_bitrate = 0; - std::unique_ptr vss_impl; - + std::unique_ptr vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kRealtimeVideo); test_queue_.SendTask( [&] { - vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kRealtimeVideo); - // Capture padding bitrate for testing. 
EXPECT_CALL(bitrate_allocator_, AddObserver(vss_impl.get(), _)) .WillRepeatedly(Invoke([&](BitrateAllocatorObserver*, @@ -871,7 +877,6 @@ TEST_F(VideoSendStreamImplTest, DisablesPaddingOnPausedEncoder) { EXPECT_EQ(0, padding_bitrate); testing::Mock::VerifyAndClearExpectations(&bitrate_allocator_); vss_impl->Stop(); - vss_impl.reset(); done.Set(); }, 5000); @@ -881,12 +886,11 @@ TEST_F(VideoSendStreamImplTest, DisablesPaddingOnPausedEncoder) { } TEST_F(VideoSendStreamImplTest, KeepAliveOnDroppedFrame) { - std::unique_ptr vss_impl; + std::unique_ptr vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + VideoEncoderConfig::ContentType::kRealtimeVideo); test_queue_.SendTask( [&] { - vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - VideoEncoderConfig::ContentType::kRealtimeVideo); vss_impl->Start(); const uint32_t kBitrateBps = 100000; EXPECT_CALL(rtp_video_sender_, GetPayloadBitrateBps()) @@ -909,7 +913,6 @@ TEST_F(VideoSendStreamImplTest, KeepAliveOnDroppedFrame) { [&] { testing::Mock::VerifyAndClearExpectations(&bitrate_allocator_); vss_impl->Stop(); - vss_impl.reset(); done.Set(); }, 2000); @@ -933,18 +936,18 @@ TEST_F(VideoSendStreamImplTest, ConfiguresBitratesForSvc) { } for (const TestConfig& test_config : test_variants) { + const bool kSuspend = false; + config_.suspend_below_min_bitrate = kSuspend; + config_.rtp.extensions.emplace_back( + RtpExtension::kTransportSequenceNumberUri, 1); + config_.periodic_alr_bandwidth_probing = test_config.alr; + auto vss_impl = CreateVideoSendStreamImpl( + kDefaultInitialBitrateBps, kDefaultBitratePriority, + test_config.screenshare + ? VideoEncoderConfig::ContentType::kScreen + : VideoEncoderConfig::ContentType::kRealtimeVideo); test_queue_.SendTask( - [this, test_config] { - const bool kSuspend = false; - config_.suspend_below_min_bitrate = kSuspend; - config_.rtp.extensions.emplace_back( - RtpExtension::kTransportSequenceNumberUri, 1); - config_.periodic_alr_bandwidth_probing = test_config.alr; - auto vss_impl = CreateVideoSendStreamImpl( - kDefaultInitialBitrateBps, kDefaultBitratePriority, - test_config.screenshare - ? 
VideoEncoderConfig::ContentType::kScreen - : VideoEncoderConfig::ContentType::kRealtimeVideo); + [&] { vss_impl->Start(); // Svc diff --git a/video/video_send_stream_tests.cc b/video/video_send_stream_tests.cc index 52e4ddbc42..42963cb8ee 100644 --- a/video/video_send_stream_tests.cc +++ b/video/video_send_stream_tests.cc @@ -12,6 +12,7 @@ #include #include "absl/algorithm/container.h" +#include "api/sequence_checker.h" #include "api/task_queue/default_task_queue_factory.h" #include "api/task_queue/task_queue_base.h" #include "api/test/simulated_network.h" @@ -29,6 +30,7 @@ #include "modules/rtp_rtcp/source/rtp_header_extensions.h" #include "modules/rtp_rtcp/source/rtp_packet.h" #include "modules/rtp_rtcp/source/rtp_rtcp_impl2.h" +#include "modules/rtp_rtcp/source/rtp_util.h" #include "modules/rtp_rtcp/source/video_rtp_depacketizer_vp9.h" #include "modules/video_coding/codecs/vp8/include/vp8.h" #include "modules/video_coding/codecs/vp9/include/vp9.h" @@ -39,7 +41,6 @@ #include "rtc_base/platform_thread.h" #include "rtc_base/rate_limiter.h" #include "rtc_base/synchronization/mutex.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/task_queue_for_test.h" #include "rtc_base/task_utils/to_queued_task.h" #include "rtc_base/time_utils.h" @@ -57,7 +58,6 @@ #include "test/gtest.h" #include "test/null_transport.h" #include "test/rtcp_packet_parser.h" -#include "test/rtp_header_parser.h" #include "test/testsupport/perf_test.h" #include "test/video_encoder_proxy_factory.h" #include "video/send_statistics_proxy.h" @@ -90,6 +90,9 @@ enum : int { // The first valid value is 1. kVideoTimingExtensionId, }; +// Readability convenience enum for `WaitBitrateChanged()`. +enum class WaitUntil : bool { kZero = false, kNonZero = true }; + constexpr int64_t kRtcpIntervalMs = 1000; enum VideoFormat { @@ -948,10 +951,10 @@ void VideoSendStreamTest::TestNackRetransmission( non_padding_sequence_numbers_.end() - kNackedPacketsAtOnceCount, non_padding_sequence_numbers_.end()); - RtpRtcpInterface::Configuration config; + RTCPSender::Configuration config; config.clock = Clock::GetRealTimeClock(); config.outgoing_transport = transport_adapter_.get(); - config.rtcp_report_interval_ms = kRtcpIntervalMs; + config.rtcp_report_interval = TimeDelta::Millis(kRtcpIntervalMs); config.local_media_ssrc = kReceiverLocalVideoSsrc; RTCPSender rtcp_sender(config); @@ -1164,11 +1167,11 @@ void VideoSendStreamTest::TestPacketFragmentationSize(VideoFormat format, kVideoSendSsrcs[0], rtp_packet.SequenceNumber(), packets_lost_, // Cumulative lost. loss_ratio); // Loss percent. 
- RtpRtcpInterface::Configuration config; + RTCPSender::Configuration config; config.clock = Clock::GetRealTimeClock(); config.receive_statistics = &lossy_receive_stats; config.outgoing_transport = transport_adapter_.get(); - config.rtcp_report_interval_ms = kRtcpIntervalMs; + config.rtcp_report_interval = TimeDelta::Millis(kRtcpIntervalMs); config.local_media_ssrc = kVideoSendSsrcs[0]; RTCPSender rtcp_sender(config); @@ -1467,14 +1470,16 @@ TEST_F(VideoSendStreamTest, MinTransmitBitrateRespectsRemb) { private: Action OnSendRtp(const uint8_t* packet, size_t length) override { - if (RtpHeaderParser::IsRtcp(packet, length)) + if (IsRtcpPacket(rtc::MakeArrayView(packet, length))) return DROP_PACKET; RtpPacket rtp_packet; if (!rtp_packet.Parse(packet, length)) return DROP_PACKET; RTC_DCHECK(stream_); - VideoSendStream::Stats stats = stream_->GetStats(); + VideoSendStream::Stats stats; + SendTask(RTC_FROM_HERE, task_queue_, + [&]() { stats = stream_->GetStats(); }); if (!stats.substreams.empty()) { EXPECT_EQ(1u, stats.substreams.size()); int total_bitrate_bps = @@ -1484,7 +1489,6 @@ TEST_F(VideoSendStreamTest, MinTransmitBitrateRespectsRemb) { "bps", false); if (total_bitrate_bps > kHighBitrateBps) { rtp_rtcp_->SetRemb(kRembBitrateBps, {rtp_packet.Ssrc()}); - rtp_rtcp_->Process(); bitrate_capped_ = true; } else if (bitrate_capped_ && total_bitrate_bps < kRembRespectedBitrateBps) { @@ -1982,7 +1986,6 @@ TEST_F(VideoSendStreamTest, public: EncoderObserver() : FakeEncoder(Clock::GetRealTimeClock()), - number_of_initializations_(0), last_initialized_frame_width_(0), last_initialized_frame_height_(0) {} @@ -2009,7 +2012,6 @@ TEST_F(VideoSendStreamTest, MutexLock lock(&mutex_); last_initialized_frame_width_ = config->width; last_initialized_frame_height_ = config->height; - ++number_of_initializations_; init_encode_called_.Set(); return FakeEncoder::InitEncode(config, settings); } @@ -2023,7 +2025,6 @@ TEST_F(VideoSendStreamTest, Mutex mutex_; rtc::Event init_encode_called_; - size_t number_of_initializations_ RTC_GUARDED_BY(&mutex_); int last_initialized_frame_width_ RTC_GUARDED_BY(&mutex_); int last_initialized_frame_height_ RTC_GUARDED_BY(&mutex_); }; @@ -2155,7 +2156,7 @@ class StartStopBitrateObserver : public test::FakeEncoder { return encoder_init_.Wait(VideoSendStreamTest::kDefaultTimeoutMs); } - bool WaitBitrateChanged(bool non_zero) { + bool WaitBitrateChanged(WaitUntil until) { do { absl::optional bitrate_kbps; { @@ -2165,8 +2166,8 @@ class StartStopBitrateObserver : public test::FakeEncoder { if (!bitrate_kbps) continue; - if ((non_zero && *bitrate_kbps > 0) || - (!non_zero && *bitrate_kbps == 0)) { + if ((until == WaitUntil::kNonZero && *bitrate_kbps > 0) || + (until == WaitUntil::kZero && *bitrate_kbps == 0)) { return true; } } while (bitrate_changed_.Wait(VideoSendStreamTest::kDefaultTimeoutMs)); @@ -2213,15 +2214,15 @@ TEST_F(VideoSendStreamTest, VideoSendStreamStopSetEncoderRateToZero) { SendTask(RTC_FROM_HERE, task_queue(), [this]() { GetVideoSendStream()->Start(); }); - EXPECT_TRUE(encoder.WaitBitrateChanged(true)); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kNonZero)); SendTask(RTC_FROM_HERE, task_queue(), [this]() { GetVideoSendStream()->Stop(); }); - EXPECT_TRUE(encoder.WaitBitrateChanged(false)); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kZero)); SendTask(RTC_FROM_HERE, task_queue(), [this]() { GetVideoSendStream()->Start(); }); - EXPECT_TRUE(encoder.WaitBitrateChanged(true)); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kNonZero)); 
SendTask(RTC_FROM_HERE, task_queue(), [this]() { DestroyStreams(); @@ -2253,6 +2254,8 @@ TEST_F(VideoSendStreamTest, VideoSendStreamUpdateActiveSimulcastLayers) { CreateVideoStreams(); + EXPECT_FALSE(GetVideoSendStream()->started()); + // Inject a frame, to force encoder creation. GetVideoSendStream()->Start(); GetVideoSendStream()->SetSource(&forwarder, @@ -2266,8 +2269,9 @@ TEST_F(VideoSendStreamTest, VideoSendStreamUpdateActiveSimulcastLayers) { // which in turn updates the VideoEncoder's bitrate. SendTask(RTC_FROM_HERE, task_queue(), [this]() { GetVideoSendStream()->UpdateActiveSimulcastLayers({true, true}); + EXPECT_TRUE(GetVideoSendStream()->started()); }); - EXPECT_TRUE(encoder.WaitBitrateChanged(true)); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kNonZero)); GetVideoEncoderConfig()->simulcast_layers[0].active = true; GetVideoEncoderConfig()->simulcast_layers[1].active = false; @@ -2275,22 +2279,40 @@ TEST_F(VideoSendStreamTest, VideoSendStreamUpdateActiveSimulcastLayers) { GetVideoSendStream()->ReconfigureVideoEncoder( GetVideoEncoderConfig()->Copy()); }); - // TODO(bugs.webrtc.org/8807): Currently we require a hard reconfiguration to - // update the VideoBitrateAllocator and BitrateAllocator of which layers are - // active. Once the change is made for a "soft" reconfiguration we can remove - // the expecation for an encoder init. We can also test that bitrate changes - // when just updating individual active layers, which should change the - // bitrate set to the video encoder. - EXPECT_TRUE(encoder.WaitForEncoderInit()); - EXPECT_TRUE(encoder.WaitBitrateChanged(true)); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kNonZero)); // Turning off both simulcast layers should trigger a bitrate change of 0. GetVideoEncoderConfig()->simulcast_layers[0].active = false; GetVideoEncoderConfig()->simulcast_layers[1].active = false; SendTask(RTC_FROM_HERE, task_queue(), [this]() { GetVideoSendStream()->UpdateActiveSimulcastLayers({false, false}); + EXPECT_FALSE(GetVideoSendStream()->started()); + }); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kZero)); + + // Re-activating a layer should resume sending and trigger a bitrate change. + GetVideoEncoderConfig()->simulcast_layers[0].active = true; + SendTask(RTC_FROM_HERE, task_queue(), [this]() { + GetVideoSendStream()->UpdateActiveSimulcastLayers({true, false}); + EXPECT_TRUE(GetVideoSendStream()->started()); + }); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kNonZero)); + + // Stop the stream and make sure the bit rate goes to zero again. + SendTask(RTC_FROM_HERE, task_queue(), [this]() { + GetVideoSendStream()->Stop(); + EXPECT_FALSE(GetVideoSendStream()->started()); + }); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kZero)); + + // One last test to verify that after `Stop()` we can still implicitly start + // the stream if needed. This is what will happen when a send stream gets + // re-used. See crbug.com/1241213. 
+ SendTask(RTC_FROM_HERE, task_queue(), [this]() { + GetVideoSendStream()->UpdateActiveSimulcastLayers({true, true}); + EXPECT_TRUE(GetVideoSendStream()->started()); }); - EXPECT_TRUE(encoder.WaitBitrateChanged(false)); + EXPECT_TRUE(encoder.WaitBitrateChanged(WaitUntil::kNonZero)); SendTask(RTC_FROM_HERE, task_queue(), [this]() { DestroyStreams(); @@ -2432,14 +2454,16 @@ class VideoCodecConfigObserver : public test::SendTest, public test::FakeEncoder { public: VideoCodecConfigObserver(VideoCodecType video_codec_type, - const char* codec_name) + const char* codec_name, + TaskQueueBase* task_queue) : SendTest(VideoSendStreamTest::kDefaultTimeoutMs), FakeEncoder(Clock::GetRealTimeClock()), video_codec_type_(video_codec_type), codec_name_(codec_name), num_initializations_(0), stream_(nullptr), - encoder_factory_(this) { + encoder_factory_(this), + task_queue_(task_queue) { InitCodecSpecifics(); } @@ -2487,7 +2511,9 @@ class VideoCodecConfigObserver : public test::SendTest, // Change encoder settings to actually trigger reconfiguration. encoder_settings_.frameDroppingOn = !encoder_settings_.frameDroppingOn; encoder_config_.encoder_specific_settings = GetEncoderSpecificSettings(); - stream_->ReconfigureVideoEncoder(std::move(encoder_config_)); + SendTask(RTC_FROM_HERE, task_queue_, [&]() { + stream_->ReconfigureVideoEncoder(std::move(encoder_config_)); + }); ASSERT_TRUE( init_encode_event_.Wait(VideoSendStreamTest::kDefaultTimeoutMs)); EXPECT_EQ(2u, num_initializations_) @@ -2509,6 +2535,7 @@ class VideoCodecConfigObserver : public test::SendTest, VideoSendStream* stream_; test::VideoEncoderProxyFactory encoder_factory_; VideoEncoderConfig encoder_config_; + TaskQueueBase* task_queue_; }; template <> @@ -2541,8 +2568,8 @@ void VideoCodecConfigObserver::VerifyCodecSpecifics( template <> rtc::scoped_refptr VideoCodecConfigObserver::GetEncoderSpecificSettings() const { - return new rtc::RefCountedObject< - VideoEncoderConfig::H264EncoderSpecificSettings>(encoder_settings_); + return rtc::make_ref_counted( + encoder_settings_); } template <> @@ -2575,8 +2602,8 @@ void VideoCodecConfigObserver::VerifyCodecSpecifics( template <> rtc::scoped_refptr VideoCodecConfigObserver::GetEncoderSpecificSettings() const { - return new rtc::RefCountedObject< - VideoEncoderConfig::Vp8EncoderSpecificSettings>(encoder_settings_); + return rtc::make_ref_counted( + encoder_settings_); } template <> @@ -2609,17 +2636,19 @@ void VideoCodecConfigObserver::VerifyCodecSpecifics( template <> rtc::scoped_refptr VideoCodecConfigObserver::GetEncoderSpecificSettings() const { - return new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(encoder_settings_); + return rtc::make_ref_counted( + encoder_settings_); } TEST_F(VideoSendStreamTest, EncoderSetupPropagatesVp8Config) { - VideoCodecConfigObserver test(kVideoCodecVP8, "VP8"); + VideoCodecConfigObserver test(kVideoCodecVP8, "VP8", + task_queue()); RunBaseTest(&test); } TEST_F(VideoSendStreamTest, EncoderSetupPropagatesVp9Config) { - VideoCodecConfigObserver test(kVideoCodecVP9, "VP9"); + VideoCodecConfigObserver test(kVideoCodecVP9, "VP9", + task_queue()); RunBaseTest(&test); } @@ -2631,7 +2660,8 @@ TEST_F(VideoSendStreamTest, EncoderSetupPropagatesVp9Config) { #define MAYBE_EncoderSetupPropagatesH264Config EncoderSetupPropagatesH264Config #endif TEST_F(VideoSendStreamTest, MAYBE_EncoderSetupPropagatesH264Config) { - VideoCodecConfigObserver test(kVideoCodecH264, "H264"); + VideoCodecConfigObserver test(kVideoCodecH264, "H264", + task_queue()); 
RunBaseTest(&test); } @@ -2736,7 +2766,7 @@ TEST_F(VideoSendStreamTest, TranslatesTwoLayerScreencastToTargetBitrate) { send_config->encoder_settings.encoder_factory = &encoder_factory_; EXPECT_EQ(1u, encoder_config->number_of_streams); encoder_config->video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); EXPECT_EQ(1u, encoder_config->simulcast_layers.size()); encoder_config->simulcast_layers[0].num_temporal_layers = 2; encoder_config->content_type = VideoEncoderConfig::ContentType::kScreen; @@ -2914,7 +2944,9 @@ TEST_F(VideoSendStreamTest, ReconfigureBitratesSetsEncoderBitratesCorrectly) { // Encoder rate is capped by EncoderConfig max_bitrate_bps. WaitForSetRates(kMaxBitrateKbps); encoder_config_.max_bitrate_bps = kLowerMaxBitrateKbps * 1000; - send_stream_->ReconfigureVideoEncoder(encoder_config_.Copy()); + SendTask(RTC_FROM_HERE, task_queue_, [&]() { + send_stream_->ReconfigureVideoEncoder(encoder_config_.Copy()); + }); ASSERT_TRUE(create_rate_allocator_event_.Wait( VideoSendStreamTest::kDefaultTimeoutMs)); EXPECT_EQ(2, num_rate_allocator_creations_) @@ -2924,7 +2956,9 @@ TEST_F(VideoSendStreamTest, ReconfigureBitratesSetsEncoderBitratesCorrectly) { EXPECT_EQ(1, num_encoder_initializations_); encoder_config_.max_bitrate_bps = kIncreasedMaxBitrateKbps * 1000; - send_stream_->ReconfigureVideoEncoder(encoder_config_.Copy()); + SendTask(RTC_FROM_HERE, task_queue_, [&]() { + send_stream_->ReconfigureVideoEncoder(encoder_config_.Copy()); + }); ASSERT_TRUE(create_rate_allocator_event_.Wait( VideoSendStreamTest::kDefaultTimeoutMs)); EXPECT_EQ(3, num_rate_allocator_creations_) @@ -2965,11 +2999,12 @@ TEST_F(VideoSendStreamTest, ReportsSentResolution) { class ScreencastTargetBitrateTest : public test::SendTest, public test::FakeEncoder { public: - ScreencastTargetBitrateTest() + explicit ScreencastTargetBitrateTest(TaskQueueBase* task_queue) : SendTest(kDefaultTimeoutMs), test::FakeEncoder(Clock::GetRealTimeClock()), send_stream_(nullptr), - encoder_factory_(this) {} + encoder_factory_(this), + task_queue_(task_queue) {} private: int32_t Encode(const VideoFrame& input_image, @@ -3017,7 +3052,9 @@ TEST_F(VideoSendStreamTest, ReportsSentResolution) { void PerformTest() override { EXPECT_TRUE(Wait()) << "Timed out while waiting for the encoder to send one frame."; - VideoSendStream::Stats stats = send_stream_->GetStats(); + VideoSendStream::Stats stats; + SendTask(RTC_FROM_HERE, task_queue_, + [&]() { stats = send_stream_->GetStats(); }); for (size_t i = 0; i < kNumStreams; ++i) { ASSERT_TRUE(stats.substreams.find(kVideoSendSsrcs[i]) != @@ -3039,7 +3076,8 @@ TEST_F(VideoSendStreamTest, ReportsSentResolution) { VideoSendStream* send_stream_; test::VideoEncoderProxyFactory encoder_factory_; - } test; + TaskQueueBase* const task_queue_; + } test(task_queue()); RunBaseTest(&test); } @@ -3074,8 +3112,9 @@ class Vp9HeaderObserver : public test::SendTest { send_config->rtp.payload_name = "VP9"; send_config->rtp.payload_type = kVp9PayloadType; ModifyVideoConfigsHook(send_config, receive_configs, encoder_config); - encoder_config->encoder_specific_settings = new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings_); + encoder_config->encoder_specific_settings = + rtc::make_ref_counted( + vp9_settings_); EXPECT_EQ(1u, encoder_config->number_of_streams); EXPECT_EQ(1u, encoder_config->simulcast_layers.size()); encoder_config->simulcast_layers[0].num_temporal_layers = @@ -3809,14 +3848,15 @@ class ContentSwitchTest : public test::SendTest { }; static 
const uint32_t kMinPacketsToSend = 50; - explicit ContentSwitchTest(T* stream_reset_fun) + explicit ContentSwitchTest(T* stream_reset_fun, TaskQueueBase* task_queue) : SendTest(test::CallTest::kDefaultTimeoutMs), call_(nullptr), state_(StreamState::kBeforeSwitch), send_stream_(nullptr), send_stream_config_(nullptr), packets_sent_(0), - stream_resetter_(stream_reset_fun) { + stream_resetter_(stream_reset_fun), + task_queue_(task_queue) { RTC_DCHECK(stream_resetter_); } @@ -3850,8 +3890,10 @@ class ContentSwitchTest : public test::SendTest { float pacing_factor = internal_send_peer.GetPacingFactorOverride().value_or(0.0f); float expected_pacing_factor = 1.1; // Strict pacing factor. - if (send_stream_->GetStats().content_type == - webrtc::VideoContentType::SCREENSHARE) { + VideoSendStream::Stats stats; + SendTask(RTC_FROM_HERE, task_queue_, + [&stats, stream = send_stream_]() { stats = stream->GetStats(); }); + if (stats.content_type == webrtc::VideoContentType::SCREENSHARE) { expected_pacing_factor = 1.0f; // Currently used pacing factor in ALR. } @@ -3919,6 +3961,7 @@ class ContentSwitchTest : public test::SendTest { VideoEncoderConfig encoder_config_; uint32_t packets_sent_ RTC_GUARDED_BY(mutex_); T* stream_resetter_; + TaskQueueBase* task_queue_; }; TEST_F(VideoSendStreamTest, SwitchesToScreenshareAndBack) { @@ -3938,7 +3981,7 @@ TEST_F(VideoSendStreamTest, SwitchesToScreenshareAndBack) { Start(); }); }; - ContentSwitchTest test(&reset_fun); + ContentSwitchTest test(&reset_fun, task_queue()); RunBaseTest(&test); } diff --git a/video/video_source_sink_controller.cc b/video/video_source_sink_controller.cc index 376eb85eae..4cd12d8a27 100644 --- a/video/video_source_sink_controller.cc +++ b/video/video_source_sink_controller.cc @@ -29,7 +29,14 @@ std::string WantsToString(const rtc::VideoSinkWants& wants) { << " max_pixel_count=" << wants.max_pixel_count << " target_pixel_count=" << (wants.target_pixel_count.has_value() ? std::to_string(wants.target_pixel_count.value()) - : "null"); + : "null") + << " resolutions={"; + for (size_t i = 0; i < wants.resolutions.size(); ++i) { + if (i != 0) + ss << ","; + ss << wants.resolutions[i].width << "x" << wants.resolutions[i].height; + } + ss << "}"; return ss.Release(); } @@ -104,6 +111,12 @@ int VideoSourceSinkController::resolution_alignment() const { return resolution_alignment_; } +const std::vector& +VideoSourceSinkController::resolutions() const { + RTC_DCHECK_RUN_ON(&sequence_checker_); + return resolutions_; +} + void VideoSourceSinkController::SetRestrictions( VideoSourceRestrictions restrictions) { RTC_DCHECK_RUN_ON(&sequence_checker_); @@ -133,6 +146,12 @@ void VideoSourceSinkController::SetResolutionAlignment( resolution_alignment_ = resolution_alignment; } +void VideoSourceSinkController::SetResolutions( + std::vector resolutions) { + RTC_DCHECK_RUN_ON(&sequence_checker_); + resolutions_ = std::move(resolutions); +} + // RTC_EXCLUSIVE_LOCKS_REQUIRED(sequence_checker_) rtc::VideoSinkWants VideoSourceSinkController::CurrentSettingsToSinkWants() const { @@ -161,6 +180,7 @@ rtc::VideoSinkWants VideoSourceSinkController::CurrentSettingsToSinkWants() frame_rate_upper_limit_.has_value() ? 
static_cast(frame_rate_upper_limit_.value()) : std::numeric_limits::max()); + wants.resolutions = resolutions_; return wants; } diff --git a/video/video_source_sink_controller.h b/video/video_source_sink_controller.h index 134366cfd0..c61084f99a 100644 --- a/video/video_source_sink_controller.h +++ b/video/video_source_sink_controller.h @@ -12,13 +12,14 @@ #define VIDEO_VIDEO_SOURCE_SINK_CONTROLLER_H_ #include +#include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/video/video_frame.h" #include "api/video/video_sink_interface.h" #include "api/video/video_source_interface.h" #include "call/adaptation/video_source_restrictions.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" namespace webrtc { @@ -46,6 +47,7 @@ class VideoSourceSinkController { absl::optional frame_rate_upper_limit() const; bool rotation_applied() const; int resolution_alignment() const; + const std::vector& resolutions() const; // Updates the settings stored internally. In order for these settings to be // applied to the sink, PushSourceSinkSettings() must subsequently be called. @@ -55,6 +57,7 @@ class VideoSourceSinkController { void SetFrameRateUpperLimit(absl::optional frame_rate_upper_limit); void SetRotationApplied(bool rotation_applied); void SetResolutionAlignment(int resolution_alignment); + void SetResolutions(std::vector resolutions); private: rtc::VideoSinkWants CurrentSettingsToSinkWants() const @@ -79,6 +82,8 @@ class VideoSourceSinkController { RTC_GUARDED_BY(&sequence_checker_); bool rotation_applied_ RTC_GUARDED_BY(&sequence_checker_) = false; int resolution_alignment_ RTC_GUARDED_BY(&sequence_checker_) = 1; + std::vector resolutions_ + RTC_GUARDED_BY(&sequence_checker_); }; } // namespace webrtc diff --git a/video/video_stream_decoder_impl.cc b/video/video_stream_decoder_impl.cc index f5b0f5f787..b6d754e8be 100644 --- a/video/video_stream_decoder_impl.cc +++ b/video/video_stream_decoder_impl.cc @@ -50,8 +50,7 @@ VideoStreamDecoderImpl::~VideoStreamDecoderImpl() { shut_down_ = true; } -void VideoStreamDecoderImpl::OnFrame( - std::unique_ptr frame) { +void VideoStreamDecoderImpl::OnFrame(std::unique_ptr frame) { if (!bookkeeping_queue_.IsCurrent()) { bookkeeping_queue_.PostTask([this, frame = std::move(frame)]() mutable { OnFrame(std::move(frame)); @@ -63,11 +62,10 @@ void VideoStreamDecoderImpl::OnFrame( RTC_DCHECK_RUN_ON(&bookkeeping_queue_); - uint64_t continuous_pid = frame_buffer_.InsertFrame(std::move(frame)); - video_coding::VideoLayerFrameId continuous_id(continuous_pid, 0); - if (last_continuous_id_ < continuous_id) { - last_continuous_id_ = continuous_id; - callbacks_->OnContinuousUntil(last_continuous_id_); + int64_t continuous_frame_id = frame_buffer_.InsertFrame(std::move(frame)); + if (last_continuous_frame_id_ < continuous_frame_id) { + last_continuous_frame_id_ = continuous_frame_id; + callbacks_->OnContinuousUntil(last_continuous_frame_id_); } } @@ -124,8 +122,7 @@ VideoDecoder* VideoStreamDecoderImpl::GetDecoder(int payload_type) { return decoder_.get(); } -void VideoStreamDecoderImpl::SaveFrameInfo( - const video_coding::EncodedFrame& frame) { +void VideoStreamDecoderImpl::SaveFrameInfo(const EncodedFrame& frame) { FrameInfo* frame_info = &frame_info_[next_frame_info_index_]; frame_info->timestamp = frame.Timestamp(); frame_info->decode_start_time_ms = rtc::TimeMillis(); @@ -140,7 +137,7 @@ void VideoStreamDecoderImpl::StartNextDecode() { frame_buffer_.NextFrame( max_wait_time, keyframe_required_, 
&bookkeeping_queue_, - [this](std::unique_ptr frame, + [this](std::unique_ptr frame, video_coding::FrameBuffer::ReturnReason res) mutable { RTC_DCHECK_RUN_ON(&bookkeeping_queue_); OnNextFrameCallback(std::move(frame), res); @@ -148,7 +145,7 @@ void VideoStreamDecoderImpl::StartNextDecode() { } void VideoStreamDecoderImpl::OnNextFrameCallback( - std::unique_ptr frame, + std::unique_ptr frame, video_coding::FrameBuffer::ReturnReason result) { switch (result) { case video_coding::FrameBuffer::kFrameFound: { @@ -205,7 +202,7 @@ void VideoStreamDecoderImpl::OnNextFrameCallback( } VideoStreamDecoderImpl::DecodeResult VideoStreamDecoderImpl::DecodeFrame( - std::unique_ptr frame) { + std::unique_ptr frame) { RTC_DCHECK(frame); VideoDecoder* decoder = GetDecoder(frame->PayloadType()); diff --git a/video/video_stream_decoder_impl.h b/video/video_stream_decoder_impl.h index 69a8195054..106f38340a 100644 --- a/video/video_stream_decoder_impl.h +++ b/video/video_stream_decoder_impl.h @@ -16,13 +16,13 @@ #include #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/video/video_stream_decoder.h" #include "modules/video_coding/frame_buffer2.h" #include "modules/video_coding/timing.h" #include "rtc_base/platform_thread.h" #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" namespace webrtc { @@ -37,7 +37,7 @@ class VideoStreamDecoderImpl : public VideoStreamDecoderInterface { ~VideoStreamDecoderImpl() override; - void OnFrame(std::unique_ptr frame) override; + void OnFrame(std::unique_ptr frame) override; void SetMinPlayoutDelay(TimeDelta min_delay) override; void SetMaxPlayoutDelay(TimeDelta max_delay) override; @@ -69,11 +69,10 @@ class VideoStreamDecoderImpl : public VideoStreamDecoderInterface { VideoContentType content_type; }; - void SaveFrameInfo(const video_coding::EncodedFrame& frame) - RTC_RUN_ON(bookkeeping_queue_); + void SaveFrameInfo(const EncodedFrame& frame) RTC_RUN_ON(bookkeeping_queue_); FrameInfo* GetFrameInfo(int64_t timestamp) RTC_RUN_ON(bookkeeping_queue_); void StartNextDecode() RTC_RUN_ON(bookkeeping_queue_); - void OnNextFrameCallback(std::unique_ptr frame, + void OnNextFrameCallback(std::unique_ptr frame, video_coding::FrameBuffer::ReturnReason res) RTC_RUN_ON(bookkeeping_queue_); void OnDecodedFrameCallback(VideoFrame& decodedImage, // NOLINT @@ -82,8 +81,7 @@ class VideoStreamDecoderImpl : public VideoStreamDecoderInterface { VideoDecoder* GetDecoder(int payload_type) RTC_RUN_ON(decode_queue_); VideoStreamDecoderImpl::DecodeResult DecodeFrame( - std::unique_ptr frame) - RTC_RUN_ON(decode_queue_); + std::unique_ptr frame) RTC_RUN_ON(decode_queue_); VCMTiming timing_; DecodeCallbacks decode_callbacks_; @@ -96,8 +94,7 @@ class VideoStreamDecoderImpl : public VideoStreamDecoderInterface { int next_frame_info_index_ RTC_GUARDED_BY(bookkeeping_queue_); VideoStreamDecoderInterface::Callbacks* const callbacks_ RTC_PT_GUARDED_BY(bookkeeping_queue_); - video_coding::VideoLayerFrameId last_continuous_id_ - RTC_GUARDED_BY(bookkeeping_queue_); + int64_t last_continuous_frame_id_ RTC_GUARDED_BY(bookkeeping_queue_) = -1; bool keyframe_required_ RTC_GUARDED_BY(bookkeeping_queue_); absl::optional current_payload_type_ RTC_GUARDED_BY(decode_queue_); diff --git a/video/video_stream_decoder_impl_unittest.cc b/video/video_stream_decoder_impl_unittest.cc index a957f01ead..a3e258976a 100644 --- a/video/video_stream_decoder_impl_unittest.cc +++ 
b/video/video_stream_decoder_impl_unittest.cc @@ -28,10 +28,7 @@ class MockVideoStreamDecoderCallbacks : public VideoStreamDecoderInterface::Callbacks { public: MOCK_METHOD(void, OnNonDecodableState, (), (override)); - MOCK_METHOD(void, - OnContinuousUntil, - (const video_coding::VideoLayerFrameId& key), - (override)); + MOCK_METHOD(void, OnContinuousUntil, (int64_t frame_id), (override)); MOCK_METHOD( void, OnDecodedFrame, @@ -130,7 +127,7 @@ class FakeVideoDecoderFactory : public VideoDecoderFactory { NiceMock av1_decoder_; }; -class FakeEncodedFrame : public video_coding::EncodedFrame { +class FakeEncodedFrame : public EncodedFrame { public: int64_t ReceivedTime() const override { return 0; } int64_t RenderTime() const override { return 0; } @@ -149,7 +146,7 @@ class FrameBuilder { } FrameBuilder& WithPictureId(int picture_id) { - frame_->id.picture_id = picture_id; + frame_->SetId(picture_id); return *this; } diff --git a/video/video_stream_encoder.cc b/video/video_stream_encoder.cc index 1cfb280208..107110987b 100644 --- a/video/video_stream_encoder.cc +++ b/video/video_stream_encoder.cc @@ -19,6 +19,7 @@ #include "absl/algorithm/container.h" #include "absl/types/optional.h" +#include "api/sequence_checker.h" #include "api/task_queue/queued_task.h" #include "api/task_queue/task_queue_base.h" #include "api/video/encoded_image.h" @@ -37,11 +38,11 @@ #include "rtc_base/constructor_magic.h" #include "rtc_base/event.h" #include "rtc_base/experiments/alr_experiment.h" +#include "rtc_base/experiments/encoder_info_settings.h" #include "rtc_base/experiments/rate_control_settings.h" #include "rtc_base/location.h" #include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/sequence_checker.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/thread_annotations.h" #include "rtc_base/trace_event.h" @@ -260,8 +261,9 @@ VideoLayersAllocation CreateVideoLayersAllocation( // Encoder may drop frames internally if `maxFramerate` is set. spatial_layer.frame_rate_fps = std::min( encoder_config.simulcastStream[si].maxFramerate, - (current_rate.framerate_fps * frame_rate_fraction) / - VideoEncoder::EncoderInfo::kMaxFramerateFraction); + rtc::saturated_cast( + (current_rate.framerate_fps * frame_rate_fraction) / + VideoEncoder::EncoderInfo::kMaxFramerateFraction)); } } else if (encoder_config.numberOfSimulcastStreams == 1) { // TODO(bugs.webrtc.org/12000): Implement support for AV1 with @@ -329,14 +331,171 @@ VideoLayersAllocation CreateVideoLayersAllocation( // Encoder may drop frames internally if `maxFramerate` is set. spatial_layer.frame_rate_fps = std::min( encoder_config.spatialLayers[si].maxFramerate, - (current_rate.framerate_fps * frame_rate_fraction) / - VideoEncoder::EncoderInfo::kMaxFramerateFraction); + rtc::saturated_cast( + (current_rate.framerate_fps * frame_rate_fraction) / + VideoEncoder::EncoderInfo::kMaxFramerateFraction)); } } return layers_allocation; } +VideoEncoder::EncoderInfo GetEncoderInfoWithBitrateLimitUpdate( + const VideoEncoder::EncoderInfo& info, + const VideoEncoderConfig& encoder_config, + bool default_limits_allowed) { + if (!default_limits_allowed || !info.resolution_bitrate_limits.empty() || + encoder_config.simulcast_layers.size() <= 1) { + return info; + } + // Bitrate limits are not configured and more than one layer is used, use + // the default limits (bitrate limits are not used for simulcast). 
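In short, the default singlecast limits act only as a fallback: they are used when the kill switch has not disabled them, the encoder reports no resolution bitrate limits of its own, and more than one simulcast layer is configured. A minimal sketch of that predicate (helper name and parameters are assumed for illustration, not taken from the patch):

#include <cstddef>

// Illustrative only: the real check is inlined at the top of
// GetEncoderInfoWithBitrateLimitUpdate().
bool ShouldUseDefaultSinglecastLimits(bool default_limits_allowed,
                                      bool encoder_reports_limits,
                                      std::size_t num_simulcast_layers) {
  // The defaults never override limits the encoder already reports, and they
  // only matter when more than one layer is configured.
  return default_limits_allowed && !encoder_reports_limits &&
         num_simulcast_layers > 1;
}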
+  VideoEncoder::EncoderInfo new_info = info;
+  new_info.resolution_bitrate_limits =
+      EncoderInfoSettings::GetDefaultSinglecastBitrateLimits(
+          encoder_config.codec_type);
+  return new_info;
+}
+
+int NumActiveStreams(const std::vector<VideoStream>& streams) {
+  int num_active = 0;
+  for (const auto& stream : streams) {
+    if (stream.active)
+      ++num_active;
+  }
+  return num_active;
+}
+
+void ApplyVp9BitrateLimits(const VideoEncoder::EncoderInfo& encoder_info,
+                           const VideoEncoderConfig& encoder_config,
+                           VideoCodec* codec) {
+  if (codec->codecType != VideoCodecType::kVideoCodecVP9 ||
+      encoder_config.simulcast_layers.size() <= 1 ||
+      VideoStreamEncoderResourceManager::IsSimulcast(encoder_config)) {
+    // Resolution bitrate limits usage is restricted to singlecast.
+    return;
+  }
+
+  // Get bitrate limits for active stream.
+  absl::optional<uint32_t> pixels =
+      VideoStreamAdapter::GetSingleActiveLayerPixels(*codec);
+  if (!pixels.has_value()) {
+    return;
+  }
+  absl::optional<VideoEncoder::ResolutionBitrateLimits> bitrate_limits =
+      encoder_info.GetEncoderBitrateLimitsForResolution(*pixels);
+  if (!bitrate_limits.has_value()) {
+    return;
+  }
+
+  // Index for the active stream.
+  absl::optional<size_t> index;
+  for (size_t i = 0; i < encoder_config.simulcast_layers.size(); ++i) {
+    if (encoder_config.simulcast_layers[i].active)
+      index = i;
+  }
+  if (!index.has_value()) {
+    return;
+  }
+
+  int min_bitrate_bps;
+  if (encoder_config.simulcast_layers[*index].min_bitrate_bps <= 0) {
+    min_bitrate_bps = bitrate_limits->min_bitrate_bps;
+  } else {
+    min_bitrate_bps =
+        std::max(bitrate_limits->min_bitrate_bps,
+                 encoder_config.simulcast_layers[*index].min_bitrate_bps);
+  }
+  int max_bitrate_bps;
+  if (encoder_config.simulcast_layers[*index].max_bitrate_bps <= 0) {
+    max_bitrate_bps = bitrate_limits->max_bitrate_bps;
+  } else {
+    max_bitrate_bps =
+        std::min(bitrate_limits->max_bitrate_bps,
+                 encoder_config.simulcast_layers[*index].max_bitrate_bps);
+  }
+  if (min_bitrate_bps >= max_bitrate_bps) {
+    RTC_LOG(LS_WARNING) << "Bitrate limits not used, min_bitrate_bps "
+                        << min_bitrate_bps << " >= max_bitrate_bps "
+                        << max_bitrate_bps;
+    return;
+  }
+
+  for (int i = 0; i < codec->VP9()->numberOfSpatialLayers; ++i) {
+    if (codec->spatialLayers[i].active) {
+      codec->spatialLayers[i].minBitrate = min_bitrate_bps / 1000;
+      codec->spatialLayers[i].maxBitrate = max_bitrate_bps / 1000;
+      codec->spatialLayers[i].targetBitrate =
+          std::min(codec->spatialLayers[i].targetBitrate,
+                   codec->spatialLayers[i].maxBitrate);
+      break;
+    }
+  }
+}
+
+void ApplyEncoderBitrateLimitsIfSingleActiveStream(
+    const VideoEncoder::EncoderInfo& encoder_info,
+    const std::vector<VideoStream>& encoder_config_layers,
+    std::vector<VideoStream>* streams) {
+  // Apply limits if simulcast with one active stream (except lowest).
+  bool single_active_stream =
+      streams->size() > 1 && NumActiveStreams(*streams) == 1 &&
+      !streams->front().active && NumActiveStreams(encoder_config_layers) == 1;
+  if (!single_active_stream) {
+    return;
+  }
+
+  // Index for the active stream.
+  size_t index = 0;
+  for (size_t i = 0; i < encoder_config_layers.size(); ++i) {
+    if (encoder_config_layers[i].active)
+      index = i;
+  }
+  if (streams->size() < (index + 1) || !(*streams)[index].active) {
+    return;
+  }
+
+  // Get bitrate limits for active stream.
+  absl::optional<VideoEncoder::ResolutionBitrateLimits> encoder_bitrate_limits =
+      encoder_info.GetEncoderBitrateLimitsForResolution(
+          (*streams)[index].width * (*streams)[index].height);
+  if (!encoder_bitrate_limits) {
+    return;
+  }
+
+  // If bitrate limits are set by RtpEncodingParameters, use intersection.
+ int min_bitrate_bps; + if (encoder_config_layers[index].min_bitrate_bps <= 0) { + min_bitrate_bps = encoder_bitrate_limits->min_bitrate_bps; + } else { + min_bitrate_bps = std::max(encoder_bitrate_limits->min_bitrate_bps, + (*streams)[index].min_bitrate_bps); + } + int max_bitrate_bps; + if (encoder_config_layers[index].max_bitrate_bps <= 0) { + max_bitrate_bps = encoder_bitrate_limits->max_bitrate_bps; + } else { + max_bitrate_bps = std::min(encoder_bitrate_limits->max_bitrate_bps, + (*streams)[index].max_bitrate_bps); + } + if (min_bitrate_bps >= max_bitrate_bps) { + RTC_LOG(LS_WARNING) << "Encoder bitrate limits" + << " (min=" << encoder_bitrate_limits->min_bitrate_bps + << ", max=" << encoder_bitrate_limits->max_bitrate_bps + << ") do not intersect with stream limits" + << " (min=" << (*streams)[index].min_bitrate_bps + << ", max=" << (*streams)[index].max_bitrate_bps + << "). Encoder bitrate limits not used."; + return; + } + + (*streams)[index].min_bitrate_bps = min_bitrate_bps; + (*streams)[index].max_bitrate_bps = max_bitrate_bps; + (*streams)[index].target_bitrate_bps = + std::min((*streams)[index].target_bitrate_bps, + encoder_bitrate_limits->max_bitrate_bps); +} + } // namespace VideoStreamEncoder::EncoderRateSettings::EncoderRateSettings() @@ -433,7 +592,6 @@ VideoStreamEncoder::VideoStreamEncoder( BitrateAllocationCallbackType allocation_cb_type) : main_queue_(TaskQueueBase::Current()), number_of_cores_(number_of_cores), - quality_scaling_experiment_enabled_(QualityScalingExperiment::Enabled()), sink_(nullptr), settings_(settings), allocation_cb_type_(allocation_cb_type), @@ -474,10 +632,8 @@ VideoStreamEncoder::VideoStreamEncoder( next_frame_types_(1, VideoFrameType::kVideoFrameDelta), frame_encode_metadata_writer_(this), experiment_groups_(GetExperimentGroups()), - encoder_switch_experiment_(ParseEncoderSwitchFieldTrial()), automatic_animation_detection_experiment_( ParseAutomatincAnimationDetectionFieldTrial()), - encoder_switch_requested_(false), input_state_provider_(encoder_stats_observer), video_stream_adapter_( std::make_unique(&input_state_provider_, @@ -497,9 +653,14 @@ VideoStreamEncoder::VideoStreamEncoder( degradation_preference_manager_.get()), video_source_sink_controller_(/*sink=*/this, /*source=*/nullptr), + default_limits_allowed_( + !field_trial::IsEnabled("WebRTC-DefaultBitrateLimitsKillSwitch")), + qp_parsing_allowed_( + !field_trial::IsEnabled("WebRTC-QpParsingKillSwitch")), encoder_queue_(task_queue_factory->CreateTaskQueue( "EncoderQueue", TaskQueueFactory::Priority::NORMAL)) { + TRACE_EVENT0("webrtc", "VideoStreamEncoder::VideoStreamEncoder"); RTC_DCHECK(main_queue_); RTC_DCHECK(encoder_stats_observer); RTC_DCHECK_GE(number_of_cores, 1); @@ -582,11 +743,16 @@ void VideoStreamEncoder::SetFecControllerOverride( void VideoStreamEncoder::AddAdaptationResource( rtc::scoped_refptr resource) { RTC_DCHECK_RUN_ON(main_queue_); + TRACE_EVENT0("webrtc", "VideoStreamEncoder::AddAdaptationResource"); // Map any externally added resources as kCpu for the sake of stats reporting. // TODO(hbos): Make the manager map any unknown resources to kCpu and get rid // of this MapResourceToReason() call. 
+ TRACE_EVENT_ASYNC_BEGIN0( + "webrtc", "VideoStreamEncoder::AddAdaptationResource(latency)", this); rtc::Event map_resource_event; encoder_queue_.PostTask([this, resource, &map_resource_event] { + TRACE_EVENT_ASYNC_END0( + "webrtc", "VideoStreamEncoder::AddAdaptationResource(latency)", this); RTC_DCHECK_RUN_ON(&encoder_queue_); additional_resources_.push_back(resource); stream_resource_manager_.AddResource(resource, VideoAdaptationReason::kCpu); @@ -686,19 +852,6 @@ void VideoStreamEncoder::ReconfigureEncoder() { // Running on the encoder queue. RTC_DCHECK(pending_encoder_reconfiguration_); - if (!encoder_selector_ && - encoder_switch_experiment_.IsPixelCountBelowThreshold( - last_frame_info_->width * last_frame_info_->height) && - !encoder_switch_requested_ && settings_.encoder_switch_request_callback) { - EncoderSwitchRequestCallback::Config conf; - conf.codec_name = encoder_switch_experiment_.to_codec; - conf.param = encoder_switch_experiment_.to_param; - conf.value = encoder_switch_experiment_.to_value; - QueueRequestEncoderSwitch(conf); - - encoder_switch_requested_ = true; - } - bool encoder_reset_required = false; if (pending_encoder_creation_) { // Destroy existing encoder instance before creating a new one. Otherwise @@ -726,13 +879,17 @@ void VideoStreamEncoder::ReconfigureEncoder() { // Possibly adjusts scale_resolution_down_by in |encoder_config_| to limit the // alignment value. - int alignment = AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( - encoder_->GetEncoderInfo(), &encoder_config_); + AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( + encoder_->GetEncoderInfo(), &encoder_config_, absl::nullopt); std::vector streams = encoder_config_.video_stream_factory->CreateEncoderStreams( last_frame_info_->width, last_frame_info_->height, encoder_config_); + // Get alignment when actual number of layers are known. + int alignment = AlignmentAdjuster::GetAlignmentAndMaybeAdjustScaleFactors( + encoder_->GetEncoderInfo(), &encoder_config_, streams.size()); + // Check that the higher layers do not try to set number of temporal layers // to less than 1. // TODO(brandtr): Get rid of the wrapping optional as it serves no purpose @@ -761,53 +918,59 @@ void VideoStreamEncoder::ReconfigureEncoder() { crop_width_ = last_frame_info_->width - highest_stream_width; crop_height_ = last_frame_info_->height - highest_stream_height; - encoder_bitrate_limits_ = + absl::optional encoder_bitrate_limits = encoder_->GetEncoderInfo().GetEncoderBitrateLimitsForResolution( last_frame_info_->width * last_frame_info_->height); - if (streams.size() == 1 && encoder_bitrate_limits_) { - // Bitrate limits can be set by app (in SDP or RtpEncodingParameters) or/and - // can be provided by encoder. In presence of both set of limits, the final - // set is derived as their intersection. - int min_bitrate_bps; - if (encoder_config_.simulcast_layers.empty() || - encoder_config_.simulcast_layers[0].min_bitrate_bps <= 0) { - min_bitrate_bps = encoder_bitrate_limits_->min_bitrate_bps; - } else { - min_bitrate_bps = std::max(encoder_bitrate_limits_->min_bitrate_bps, - streams.back().min_bitrate_bps); - } + if (encoder_bitrate_limits) { + if (streams.size() == 1 && encoder_config_.simulcast_layers.size() == 1) { + // Bitrate limits can be set by app (in SDP or RtpEncodingParameters) + // or/and can be provided by encoder. In presence of both set of limits, + // the final set is derived as their intersection. 
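Put differently, the intersection computed in the lines that follow keeps the tighter bound on each side and is rejected when it comes out empty. A minimal self-contained sketch under that reading (the struct and function names are assumed, not taken from the patch):

#include <algorithm>

#include "absl/types/optional.h"

// Illustrative helper, assuming both limits are plain {min, max} ranges in
// bps.
struct BitrateRange {
  int min_bps = 0;
  int max_bps = 0;
};

// The tighter bound wins on each side. An empty result means the two ranges
// do not overlap, in which case the caller keeps the app-configured limits.
absl::optional<BitrateRange> IntersectBitrateLimits(BitrateRange encoder,
                                                    BitrateRange app) {
  BitrateRange out;
  out.min_bps = std::max(encoder.min_bps, app.min_bps);
  out.max_bps = std::min(encoder.max_bps, app.max_bps);
  if (out.min_bps >= out.max_bps)
    return absl::nullopt;
  return out;
}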
+ int min_bitrate_bps; + if (encoder_config_.simulcast_layers.empty() || + encoder_config_.simulcast_layers[0].min_bitrate_bps <= 0) { + min_bitrate_bps = encoder_bitrate_limits->min_bitrate_bps; + } else { + min_bitrate_bps = std::max(encoder_bitrate_limits->min_bitrate_bps, + streams.back().min_bitrate_bps); + } - int max_bitrate_bps; - // We don't check encoder_config_.simulcast_layers[0].max_bitrate_bps - // here since encoder_config_.max_bitrate_bps is derived from it (as - // well as from other inputs). - if (encoder_config_.max_bitrate_bps <= 0) { - max_bitrate_bps = encoder_bitrate_limits_->max_bitrate_bps; - } else { - max_bitrate_bps = std::min(encoder_bitrate_limits_->max_bitrate_bps, - streams.back().max_bitrate_bps); - } + int max_bitrate_bps; + // We don't check encoder_config_.simulcast_layers[0].max_bitrate_bps + // here since encoder_config_.max_bitrate_bps is derived from it (as + // well as from other inputs). + if (encoder_config_.max_bitrate_bps <= 0) { + max_bitrate_bps = encoder_bitrate_limits->max_bitrate_bps; + } else { + max_bitrate_bps = std::min(encoder_bitrate_limits->max_bitrate_bps, + streams.back().max_bitrate_bps); + } - if (min_bitrate_bps < max_bitrate_bps) { - streams.back().min_bitrate_bps = min_bitrate_bps; - streams.back().max_bitrate_bps = max_bitrate_bps; - streams.back().target_bitrate_bps = - std::min(streams.back().target_bitrate_bps, - encoder_bitrate_limits_->max_bitrate_bps); - } else { - RTC_LOG(LS_WARNING) << "Bitrate limits provided by encoder" - << " (min=" - << encoder_bitrate_limits_->min_bitrate_bps - << ", max=" - << encoder_bitrate_limits_->min_bitrate_bps - << ") do not intersect with limits set by app" - << " (min=" << streams.back().min_bitrate_bps - << ", max=" << encoder_config_.max_bitrate_bps - << "). The app bitrate limits will be used."; + if (min_bitrate_bps < max_bitrate_bps) { + streams.back().min_bitrate_bps = min_bitrate_bps; + streams.back().max_bitrate_bps = max_bitrate_bps; + streams.back().target_bitrate_bps = + std::min(streams.back().target_bitrate_bps, + encoder_bitrate_limits->max_bitrate_bps); + } else { + RTC_LOG(LS_WARNING) + << "Bitrate limits provided by encoder" + << " (min=" << encoder_bitrate_limits->min_bitrate_bps + << ", max=" << encoder_bitrate_limits->max_bitrate_bps + << ") do not intersect with limits set by app" + << " (min=" << streams.back().min_bitrate_bps + << ", max=" << encoder_config_.max_bitrate_bps + << "). The app bitrate limits will be used."; + } } } + ApplyEncoderBitrateLimitsIfSingleActiveStream( + GetEncoderInfoWithBitrateLimitUpdate( + encoder_->GetEncoderInfo(), encoder_config_, default_limits_allowed_), + encoder_config_.simulcast_layers, &streams); + VideoCodec codec; if (!VideoCodecInitializer::SetupCodec(encoder_config_, streams, &codec)) { RTC_LOG(LS_ERROR) << "Failed to create encoder configuration."; @@ -818,6 +981,10 @@ void VideoStreamEncoder::ReconfigureEncoder() { // thus some cropping might be needed. 
crop_width_ = last_frame_info_->width - codec.width; crop_height_ = last_frame_info_->height - codec.height; + ApplyVp9BitrateLimits(GetEncoderInfoWithBitrateLimitUpdate( + encoder_->GetEncoderInfo(), encoder_config_, + default_limits_allowed_), + encoder_config_, &codec); } char log_stream_buf[4 * 1024]; @@ -869,14 +1036,29 @@ void VideoStreamEncoder::ReconfigureEncoder() { max_framerate = std::max(stream.max_framerate, max_framerate); } - main_queue_->PostTask( - ToQueuedTask(task_safety_, [this, max_framerate, alignment]() { + // The resolutions that we're actually encoding with. + std::vector encoder_resolutions; + // TODO(hbos): For the case of SVC, also make use of |codec.spatialLayers|. + // For now, SVC layers are handled by the VP9 encoder. + for (const auto& simulcastStream : codec.simulcastStream) { + if (!simulcastStream.active) + continue; + encoder_resolutions.emplace_back(simulcastStream.width, + simulcastStream.height); + } + main_queue_->PostTask(ToQueuedTask( + task_safety_, [this, max_framerate, alignment, + encoder_resolutions = std::move(encoder_resolutions)]() { RTC_DCHECK_RUN_ON(main_queue_); if (max_framerate != video_source_sink_controller_.frame_rate_upper_limit() || - alignment != video_source_sink_controller_.resolution_alignment()) { + alignment != video_source_sink_controller_.resolution_alignment() || + encoder_resolutions != + video_source_sink_controller_.resolutions()) { video_source_sink_controller_.SetFrameRateUpperLimit(max_framerate); video_source_sink_controller_.SetResolutionAlignment(alignment); + video_source_sink_controller_.SetResolutions( + std::move(encoder_resolutions)); video_source_sink_controller_.PushSourceSinkSettings(); } })); @@ -911,8 +1093,6 @@ void VideoStreamEncoder::ReconfigureEncoder() { } send_codec_ = codec; - encoder_switch_experiment_.SetCodec(send_codec_.codecType); - // Keep the same encoder, as long as the video_format is unchanged. // Encoder creation block is split in two since EncoderInfo needed to start // CPU adaptation with the correct settings should be polled after @@ -963,7 +1143,7 @@ void VideoStreamEncoder::ReconfigureEncoder() { } if (pending_encoder_creation_) { - stream_resource_manager_.EnsureEncodeUsageResourceStarted(); + stream_resource_manager_.ConfigureEncodeUsageResource(); pending_encoder_creation_ = false; } @@ -1040,8 +1220,10 @@ void VideoStreamEncoder::ReconfigureEncoder() { } void VideoStreamEncoder::OnEncoderSettingsChanged() { - EncoderSettings encoder_settings(encoder_->GetEncoderInfo(), - encoder_config_.Copy(), send_codec_); + EncoderSettings encoder_settings( + GetEncoderInfoWithBitrateLimitUpdate( + encoder_->GetEncoderInfo(), encoder_config_, default_limits_allowed_), + encoder_config_.Copy(), send_codec_); stream_resource_manager_.SetEncoderSettings(encoder_settings); input_state_provider_.OnEncoderSettingsChanged(encoder_settings); bool is_screenshare = encoder_settings.encoder_config().content_type == @@ -1121,7 +1303,7 @@ void VideoStreamEncoder::OnFrame(const VideoFrame& video_frame) { MaybeEncodeVideoFrame(incoming_frame, post_time_us); } else { if (cwnd_frame_drop) { - // Frame drop by congestion window pusback. Do not encode this + // Frame drop by congestion window pushback. Do not encode this // frame. ++dropped_frame_cwnd_pushback_count_; encoder_stats_observer_->OnFrameDropped( @@ -1258,7 +1440,7 @@ void VideoStreamEncoder::SetEncoderRates( // |bitrate_allocation| is 0 it means that the network is down or the send // pacer is full. 
We currently only report this if the encoder has an internal // source. If the encoder does not have an internal source, higher levels - // are expected to not call AddVideoFrame. We do this since its unclear + // are expected to not call AddVideoFrame. We do this since it is unclear // how current encoder implementations behave when given a zero target // bitrate. // TODO(perkj): Make sure all known encoder implementations handle zero @@ -1319,7 +1501,7 @@ void VideoStreamEncoder::MaybeEncodeVideoFrame(const VideoFrame& video_frame, VideoFrame::UpdateRect{0, 0, video_frame.width(), video_frame.height()}; } - // We have to create then encoder before the frame drop logic, + // We have to create the encoder before the frame drop logic, // because the latter depends on encoder_->GetScalingSettings. // According to the testcase // InitialFrameDropOffWhenEncoderDisabledScaling, the return value @@ -1434,6 +1616,12 @@ void VideoStreamEncoder::EncodeVideoFrame(const VideoFrame& video_frame, if (encoder_failed_) return; + // It's possible that EncodeVideoFrame can be called after we've completed + // a Stop() operation. Check if the encoder_ is set before continuing. + // See: bugs.webrtc.org/12857 + if (!encoder_) + return; + TraceFrameDropEnd(); // Encoder metadata needs to be updated before encode complete callback. @@ -1449,6 +1637,7 @@ void VideoStreamEncoder::EncodeVideoFrame(const VideoFrame& video_frame, if (encoder_info_ != info) { OnEncoderSettingsChanged(); + stream_resource_manager_.ConfigureEncodeUsageResource(); RTC_LOG(LS_INFO) << "Encoder settings changed from " << encoder_info_.ToString() << " to " << info.ToString(); } @@ -1465,45 +1654,12 @@ void VideoStreamEncoder::EncodeVideoFrame(const VideoFrame& video_frame, last_encode_info_ms_ = clock_->TimeInMilliseconds(); VideoFrame out_frame(video_frame); - if (out_frame.video_frame_buffer()->type() == - VideoFrameBuffer::Type::kNative && - !info.supports_native_handle) { - // This module only supports software encoding. - rtc::scoped_refptr buffer = - out_frame.video_frame_buffer()->GetMappedFrameBuffer( - info.preferred_pixel_formats); - bool buffer_was_converted = false; - if (!buffer) { - buffer = out_frame.video_frame_buffer()->ToI420(); - // TODO(https://crbug.com/webrtc/12021): Once GetI420 is pure virtual, - // this just true as an I420 buffer would return from - // GetMappedFrameBuffer. - buffer_was_converted = - (out_frame.video_frame_buffer()->GetI420() == nullptr); - } - if (!buffer) { - RTC_LOG(LS_ERROR) << "Frame conversion failed, dropping frame."; - return; - } - - VideoFrame::UpdateRect update_rect = out_frame.update_rect(); - if (!update_rect.IsEmpty() && - out_frame.video_frame_buffer()->GetI420() == nullptr) { - // UpdatedRect is reset to full update if it's not empty, and buffer was - // converted, therefore we can't guarantee that pixels outside of - // UpdateRect didn't change comparing to the previous frame. - update_rect = - VideoFrame::UpdateRect{0, 0, out_frame.width(), out_frame.height()}; - } - out_frame.set_video_frame_buffer(buffer); - out_frame.set_update_rect(update_rect); - } - - // Crop frame if needed. + // Crop or scale the frame if needed. Dimension may be reduced to fit encoder + // requirements, e.g. some encoders may require them to be divisible by 4. if ((crop_width_ > 0 || crop_height_ > 0) && - out_frame.video_frame_buffer()->type() != - VideoFrameBuffer::Type::kNative) { - // If the frame can't be converted to I420, drop it. 
+ (out_frame.video_frame_buffer()->type() != + VideoFrameBuffer::Type::kNative || + !info.supports_native_handle)) { int cropped_width = video_frame.width() - crop_width_; int cropped_height = video_frame.height() - crop_height_; rtc::scoped_refptr cropped_buffer; @@ -1511,6 +1667,7 @@ void VideoStreamEncoder::EncodeVideoFrame(const VideoFrame& video_frame, // happen after SinkWants signaled correctly from ReconfigureEncoder. VideoFrame::UpdateRect update_rect = video_frame.update_rect(); if (crop_width_ < 4 && crop_height_ < 4) { + // The difference is small, crop without scaling. cropped_buffer = video_frame.video_frame_buffer()->CropAndScale( crop_width_ / 2, crop_height_ / 2, cropped_width, cropped_height, cropped_width, cropped_height); @@ -1520,6 +1677,7 @@ void VideoStreamEncoder::EncodeVideoFrame(const VideoFrame& video_frame, VideoFrame::UpdateRect{0, 0, cropped_width, cropped_height}); } else { + // The difference is large, scale it. cropped_buffer = video_frame.video_frame_buffer()->Scale(cropped_width, cropped_height); if (!update_rect.IsEmpty()) { @@ -1564,14 +1722,12 @@ void VideoStreamEncoder::EncodeVideoFrame(const VideoFrame& video_frame, stream_resource_manager_.OnEncodeStarted(out_frame, time_when_posted_us); - RTC_DCHECK_LE(send_codec_.width, out_frame.width()); - RTC_DCHECK_LE(send_codec_.height, out_frame.height()); - // Native frames should be scaled by the client. - // For internal encoders we scale everything in one place here. - RTC_DCHECK((out_frame.video_frame_buffer()->type() == - VideoFrameBuffer::Type::kNative) || - (send_codec_.width == out_frame.width() && - send_codec_.height == out_frame.height())); + // The encoder should get the size that it expects. + RTC_DCHECK(send_codec_.width <= out_frame.width() && + send_codec_.height <= out_frame.height()) + << "Encoder configured to " << send_codec_.width << "x" + << send_codec_.height << " received a too small frame " + << out_frame.width() << "x" << out_frame.height(); TRACE_EVENT1("webrtc", "VCMGenericEncoder::Encode", "timestamp", out_frame.timestamp()); @@ -1624,6 +1780,9 @@ void VideoStreamEncoder::SendKeyFrame() { TRACE_EVENT0("webrtc", "OnKeyFrameRequest"); RTC_DCHECK(!next_frame_types_.empty()); + if (!encoder_) + return; // Shutting down. + // TODO(webrtc:10615): Map keyframe request to spatial layer. std::fill(next_frame_types_.begin(), next_frame_types_.end(), VideoFrameType::kVideoFrameKey); @@ -1679,6 +1838,18 @@ EncodedImageCallback::Result VideoStreamEncoder::OnEncodedImage( frame_encode_metadata_writer_.UpdateBitstream(codec_specific_info, &image_copy); + VideoCodecType codec_type = codec_specific_info + ? codec_specific_info->codecType + : VideoCodecType::kVideoCodecGeneric; + + if (image_copy.qp_ < 0 && qp_parsing_allowed_) { + // Parse encoded frame QP if that was not provided by encoder. + image_copy.qp_ = qp_parser_ + .Parse(codec_type, spatial_idx, image_copy.data(), + image_copy.size()) + .value_or(-1); + } + // Piggyback ALR experiment group id and simulcast id into the content type. const uint8_t experiment_id = experiment_groups_[videocontenttypehelpers::IsScreenshare( @@ -1701,12 +1872,9 @@ EncodedImageCallback::Result VideoStreamEncoder::OnEncodedImage( // Post a task because |send_codec_| requires |encoder_queue_| lock. unsigned int image_width = image_copy._encodedWidth; unsigned int image_height = image_copy._encodedHeight; - VideoCodecType codec = codec_specific_info - ? 
codec_specific_info->codecType - : VideoCodecType::kVideoCodecGeneric; - encoder_queue_.PostTask([this, codec, image_width, image_height] { + encoder_queue_.PostTask([this, codec_type, image_width, image_height] { RTC_DCHECK_RUN_ON(&encoder_queue_); - if (codec == VideoCodecType::kVideoCodecVP9 && + if (codec_type == VideoCodecType::kVideoCodecVP9 && send_codec_.VP9()->automaticResizeOn) { unsigned int expected_width = send_codec_.width; unsigned int expected_height = send_codec_.height; @@ -1852,22 +2020,10 @@ void VideoStreamEncoder::OnBitrateUpdated(DataRate target_bitrate, const bool video_is_suspended = target_bitrate == DataRate::Zero(); const bool video_suspension_changed = video_is_suspended != EncoderPaused(); - if (!video_is_suspended && settings_.encoder_switch_request_callback) { - if (encoder_selector_) { - if (auto encoder = - encoder_selector_->OnAvailableBitrate(link_allocation)) { - QueueRequestEncoderSwitch(*encoder); - } - } else if (encoder_switch_experiment_.IsBitrateBelowThreshold( - target_bitrate) && - !encoder_switch_requested_) { - EncoderSwitchRequestCallback::Config conf; - conf.codec_name = encoder_switch_experiment_.to_codec; - conf.param = encoder_switch_experiment_.to_param; - conf.value = encoder_switch_experiment_.to_value; - QueueRequestEncoderSwitch(conf); - - encoder_switch_requested_ = true; + if (!video_is_suspended && settings_.encoder_switch_request_callback && + encoder_selector_) { + if (auto encoder = encoder_selector_->OnAvailableBitrate(link_allocation)) { + QueueRequestEncoderSwitch(*encoder); } } @@ -1932,19 +2088,24 @@ bool VideoStreamEncoder::DropDueToSize(uint32_t pixel_count) const { } } + uint32_t bitrate_bps = + stream_resource_manager_.UseBandwidthAllocationBps().value_or( + encoder_target_bitrate_bps_.value()); + absl::optional encoder_bitrate_limits = - encoder_->GetEncoderInfo().GetEncoderBitrateLimitsForResolution( - pixel_count); + GetEncoderInfoWithBitrateLimitUpdate( + encoder_->GetEncoderInfo(), encoder_config_, default_limits_allowed_) + .GetEncoderBitrateLimitsForResolution(pixel_count); if (encoder_bitrate_limits.has_value()) { // Use bitrate limits provided by encoder. 
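When the encoder reports no limits for the requested resolution, the code a few lines below falls back to two fixed thresholds. A self-contained restatement of that fallback (the standalone function name is assumed for illustration):

#include <cstdint>

// Restates the QVGA/VGA fallback visible in DropDueToSize(): below ~300 kbps
// anything larger than 320x240 is dropped, below ~500 kbps anything larger
// than 640x480 is dropped, and above that no size-based drop is applied.
bool DropDueToSizeFallback(uint32_t bitrate_bps, uint32_t pixel_count) {
  if (bitrate_bps < 300000)  // Below roughly what QVGA needs.
    return pixel_count > 320u * 240u;
  if (bitrate_bps < 500000)  // Below roughly what VGA needs.
    return pixel_count > 640u * 480u;
  return false;
}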
- return encoder_target_bitrate_bps_.value() < + return bitrate_bps < static_cast(encoder_bitrate_limits->min_start_bitrate_bps); } - if (encoder_target_bitrate_bps_.value() < 300000 /* qvga */) { + if (bitrate_bps < 300000 /* qvga */) { return pixel_count > 320 * 240; - } else if (encoder_target_bitrate_bps_.value() < 500000 /* vga */) { + } else if (bitrate_bps < 500000 /* vga */) { return pixel_count > 640 * 480; } return false; @@ -2012,7 +2173,8 @@ void VideoStreamEncoder::RunPostEncode(const EncodedImage& encoded_image, stream_resource_manager_.OnEncodeCompleted(encoded_image, time_sent_us, encode_duration_us); if (bitrate_adjuster_) { - bitrate_adjuster_->OnEncodedFrame(encoded_image, temporal_index); + bitrate_adjuster_->OnEncodedFrame( + frame_size, encoded_image.SpatialIndex().value_or(0), temporal_index); } } @@ -2031,113 +2193,6 @@ void VideoStreamEncoder::ReleaseEncoder() { TRACE_EVENT0("webrtc", "VCMGenericEncoder::Release"); } -bool VideoStreamEncoder::EncoderSwitchExperiment::IsBitrateBelowThreshold( - const DataRate& target_bitrate) { - DataRate rate = DataRate::KilobitsPerSec( - bitrate_filter.Apply(1.0, target_bitrate.kbps())); - return current_thresholds.bitrate && rate < *current_thresholds.bitrate; -} - -bool VideoStreamEncoder::EncoderSwitchExperiment::IsPixelCountBelowThreshold( - int pixel_count) const { - return current_thresholds.pixel_count && - pixel_count < *current_thresholds.pixel_count; -} - -void VideoStreamEncoder::EncoderSwitchExperiment::SetCodec( - VideoCodecType codec) { - auto it = codec_thresholds.find(codec); - if (it == codec_thresholds.end()) { - current_thresholds = {}; - } else { - current_thresholds = it->second; - } -} - -VideoStreamEncoder::EncoderSwitchExperiment -VideoStreamEncoder::ParseEncoderSwitchFieldTrial() const { - EncoderSwitchExperiment result; - - // Each "codec threshold" have the format - // ";;", and are separated by the "|" - // character. 
- webrtc::FieldTrialOptional codec_thresholds_string{ - "codec_thresholds"}; - webrtc::FieldTrialOptional to_codec{"to_codec"}; - webrtc::FieldTrialOptional to_param{"to_param"}; - webrtc::FieldTrialOptional to_value{"to_value"}; - webrtc::FieldTrialOptional window{"window"}; - - webrtc::ParseFieldTrial( - {&codec_thresholds_string, &to_codec, &to_param, &to_value, &window}, - webrtc::field_trial::FindFullName( - "WebRTC-NetworkCondition-EncoderSwitch")); - - if (!codec_thresholds_string || !to_codec || !window) { - return {}; - } - - result.bitrate_filter.Reset(1.0 - 1.0 / *window); - result.to_codec = *to_codec; - result.to_param = to_param.GetOptional(); - result.to_value = to_value.GetOptional(); - - std::vector codecs_thresholds; - if (rtc::split(*codec_thresholds_string, '|', &codecs_thresholds) == 0) { - return {}; - } - - for (const std::string& codec_threshold : codecs_thresholds) { - std::vector thresholds_split; - if (rtc::split(codec_threshold, ';', &thresholds_split) != 3) { - return {}; - } - - VideoCodecType codec = PayloadStringToCodecType(thresholds_split[0]); - int bitrate_kbps; - rtc::FromString(thresholds_split[1], &bitrate_kbps); - int pixel_count; - rtc::FromString(thresholds_split[2], &pixel_count); - - if (bitrate_kbps > 0) { - result.codec_thresholds[codec].bitrate = - DataRate::KilobitsPerSec(bitrate_kbps); - } - - if (pixel_count > 0) { - result.codec_thresholds[codec].pixel_count = pixel_count; - } - - if (!result.codec_thresholds[codec].bitrate && - !result.codec_thresholds[codec].pixel_count) { - return {}; - } - } - - rtc::StringBuilder ss; - ss << "Successfully parsed WebRTC-NetworkCondition-EncoderSwitch field " - "trial." - " to_codec:" - << result.to_codec << " to_param:" << result.to_param.value_or("") - << " to_value:" << result.to_value.value_or("") - << " codec_thresholds:"; - - for (auto kv : result.codec_thresholds) { - std::string codec_name = CodecTypeToPayloadString(kv.first); - std::string bitrate = kv.second.bitrate - ? std::to_string(kv.second.bitrate->kbps()) - : ""; - std::string pixels = kv.second.pixel_count - ? 
std::to_string(*kv.second.pixel_count) - : ""; - ss << " (" << codec_name << ":" << bitrate << ":" << pixels << ")"; - } - - RTC_LOG(LS_INFO) << ss.str(); - - return result; -} - VideoStreamEncoder::AutomaticAnimationDetectionExperiment VideoStreamEncoder::ParseAutomatincAnimationDetectionFieldTrial() const { AutomaticAnimationDetectionExperiment result; diff --git a/video/video_stream_encoder.h b/video/video_stream_encoder.h index ff04329fdd..9e70203661 100644 --- a/video/video_stream_encoder.h +++ b/video/video_stream_encoder.h @@ -18,6 +18,7 @@ #include #include "api/adaptation/resource.h" +#include "api/sequence_checker.h" #include "api/units/data_rate.h" #include "api/video/video_bitrate_allocator.h" #include "api/video/video_rotation.h" @@ -33,6 +34,7 @@ #include "call/adaptation/video_source_restrictions.h" #include "call/adaptation/video_stream_input_state_provider.h" #include "modules/video_coding/utility/frame_dropper.h" +#include "modules/video_coding/utility/qp_parser.h" #include "rtc_base/experiments/rate_control_settings.h" #include "rtc_base/numerics/exp_filter.h" #include "rtc_base/race_checker.h" @@ -40,7 +42,6 @@ #include "rtc_base/task_queue.h" #include "rtc_base/task_utils/pending_task_safety_flag.h" #include "rtc_base/thread_annotations.h" -#include "rtc_base/thread_checker.h" #include "system_wrappers/include/clock.h" #include "video/adaptation/video_stream_encoder_resource_manager.h" #include "video/encoder_bitrate_adjuster.h" @@ -181,7 +182,7 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, void EncodeVideoFrame(const VideoFrame& frame, int64_t time_when_posted_in_ms); - // Indicates wether frame should be dropped because the pixel count is too + // Indicates whether frame should be dropped because the pixel count is too // large for the current bitrate configuration. bool DropDueToSize(uint32_t pixel_count) const RTC_RUN_ON(&encoder_queue_); @@ -229,8 +230,6 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, const uint32_t number_of_cores_; - const bool quality_scaling_experiment_enabled_; - EncoderSink* sink_; const VideoStreamEncoderSettings settings_; const BitrateAllocationCallbackType allocation_cb_type_; @@ -314,8 +313,6 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, absl::optional last_encode_info_ms_ RTC_GUARDED_BY(&encoder_queue_); VideoEncoder::EncoderInfo encoder_info_ RTC_GUARDED_BY(&encoder_queue_); - absl::optional encoder_bitrate_limits_ - RTC_GUARDED_BY(&encoder_queue_); VideoEncoderFactory::CodecInfo codec_info_ RTC_GUARDED_BY(&encoder_queue_); VideoCodec send_codec_ RTC_GUARDED_BY(&encoder_queue_); @@ -352,38 +349,6 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, // experiment group numbers incremented by 1. const std::array experiment_groups_; - struct EncoderSwitchExperiment { - struct Thresholds { - absl::optional bitrate; - absl::optional pixel_count; - }; - - // Codec --> switching thresholds - std::map codec_thresholds; - - // To smooth out the target bitrate so that we don't trigger a switch - // too easily. - rtc::ExpFilter bitrate_filter{1.0}; - - // Codec/implementation to switch to - std::string to_codec; - absl::optional to_param; - absl::optional to_value; - - // Thresholds for the currently used codecs. - Thresholds current_thresholds; - - // Updates the |bitrate_filter|, so not const. 
- bool IsBitrateBelowThreshold(const DataRate& target_bitrate); - bool IsPixelCountBelowThreshold(int pixel_count) const; - void SetCodec(VideoCodecType codec); - }; - - EncoderSwitchExperiment ParseEncoderSwitchFieldTrial() const; - - EncoderSwitchExperiment encoder_switch_experiment_ - RTC_GUARDED_BY(&encoder_queue_); - struct AutomaticAnimationDetectionExperiment { bool enabled = false; int min_duration_ms = 2000; @@ -404,11 +369,7 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, AutomaticAnimationDetectionExperiment automatic_animation_detection_experiment_ RTC_GUARDED_BY(&encoder_queue_); - // An encoder switch is only requested once, this variable is used to keep - // track of whether a request has been made or not. - bool encoder_switch_requested_ RTC_GUARDED_BY(&encoder_queue_); - - // Provies video stream input states: current resolution and frame rate. + // Provides video stream input states: current resolution and frame rate. VideoStreamInputStateProvider input_state_provider_; std::unique_ptr video_stream_adapter_ @@ -424,7 +385,7 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, RTC_GUARDED_BY(&encoder_queue_); // Handles input, output and stats reporting related to VideoStreamEncoder // specific resources, such as "encode usage percent" measurements and "QP - // scaling". Also involved with various mitigations such as inital frame + // scaling". Also involved with various mitigations such as initial frame // dropping. // The manager primarily operates on the |encoder_queue_| but its lifetime is // tied to the VideoStreamEncoder (which is destroyed off the encoder queue) @@ -440,6 +401,14 @@ class VideoStreamEncoder : public VideoStreamEncoderInterface, VideoSourceSinkController video_source_sink_controller_ RTC_GUARDED_BY(main_queue_); + // Default bitrate limits in EncoderInfoSettings allowed. + const bool default_limits_allowed_; + + // QP parser is used to extract QP value from encoded frame when that is not + // provided by encoder. + QpParser qp_parser_; + const bool qp_parsing_allowed_; + // Public methods are proxied to the task queues. The queues must be destroyed // first to make sure no tasks run that use other members. 
rtc::TaskQueue encoder_queue_; diff --git a/video/video_stream_encoder_unittest.cc b/video/video_stream_encoder_unittest.cc index 85be6951e5..cbfd93e9e2 100644 --- a/video/video_stream_encoder_unittest.cc +++ b/video/video_stream_encoder_unittest.cc @@ -13,17 +13,20 @@ #include #include #include +#include #include #include "absl/memory/memory.h" #include "api/task_queue/default_task_queue_factory.h" #include "api/test/mock_fec_controller_override.h" #include "api/test/mock_video_encoder.h" +#include "api/test/mock_video_encoder_factory.h" #include "api/video/builtin_video_bitrate_allocator_factory.h" #include "api/video/i420_buffer.h" #include "api/video/nv12_buffer.h" #include "api/video/video_adaptation_reason.h" #include "api/video/video_bitrate_allocation.h" +#include "api/video_codecs/sdp_video_format.h" #include "api/video_codecs/video_encoder.h" #include "api/video_codecs/vp8_temporal_layers.h" #include "api/video_codecs/vp8_temporal_layers_factory.h" @@ -33,10 +36,17 @@ #include "common_video/include/video_frame_buffer.h" #include "media/base/video_adapter.h" #include "media/engine/webrtc_video_engine.h" +#include "modules/video_coding/codecs/av1/libaom_av1_encoder.h" +#include "modules/video_coding/codecs/h264/include/h264.h" +#include "modules/video_coding/codecs/multiplex/include/multiplex_encoder_adapter.h" +#include "modules/video_coding/codecs/vp8/include/vp8.h" +#include "modules/video_coding/codecs/vp9/include/vp9.h" #include "modules/video_coding/codecs/vp9/include/vp9_globals.h" +#include "modules/video_coding/codecs/vp9/svc_config.h" #include "modules/video_coding/utility/quality_scaler.h" #include "modules/video_coding/utility/simulcast_rate_allocator.h" #include "rtc_base/event.h" +#include "rtc_base/experiments/encoder_info_settings.h" #include "rtc_base/gunit.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" @@ -49,6 +59,7 @@ #include "test/frame_forwarder.h" #include "test/gmock.h" #include "test/gtest.h" +#include "test/mappable_native_buffer.h" #include "test/time_controller/simulated_time_controller.h" #include "test/video_encoder_proxy_factory.h" #include "video/send_statistics_proxy.h" @@ -96,6 +107,11 @@ uint8_t optimal_sps[] = {0, 0, 0, 1, H264::NaluType::kSps, 0x05, 0x03, 0xC7, 0xE0, 0x1B, 0x41, 0x10, 0x8D, 0x00}; +const uint8_t kCodedFrameVp8Qp25[] = { + 0x10, 0x02, 0x00, 0x9d, 0x01, 0x2a, 0x10, 0x00, 0x10, 0x00, + 0x02, 0x47, 0x08, 0x85, 0x85, 0x88, 0x85, 0x84, 0x88, 0x0c, + 0x82, 0x00, 0x0c, 0x0d, 0x60, 0x00, 0xfe, 0xfc, 0x5c, 0xd0}; + class TestBuffer : public webrtc::I420Buffer { public: TestBuffer(rtc::Event* event, int width, int height) @@ -110,7 +126,8 @@ class TestBuffer : public webrtc::I420Buffer { rtc::Event* const event_; }; -// A fake native buffer that can't be converted to I420. +// A fake native buffer that can't be converted to I420. Upon scaling, it +// produces another FakeNativeBuffer. 
class FakeNativeBuffer : public webrtc::VideoFrameBuffer { public: FakeNativeBuffer(rtc::Event* event, int width, int height) @@ -121,6 +138,16 @@ class FakeNativeBuffer : public webrtc::VideoFrameBuffer { rtc::scoped_refptr ToI420() override { return nullptr; } + rtc::scoped_refptr CropAndScale( + int offset_x, + int offset_y, + int crop_width, + int crop_height, + int scaled_width, + int scaled_height) override { + return rtc::make_ref_counted(nullptr, scaled_width, + scaled_height); + } private: friend class rtc::RefCountedObject; @@ -461,6 +488,10 @@ class AdaptingFrameForwarder : public test::FrameForwarder { return adaptation_enabled_; } + // The "last wants" is a snapshot of the previous rtc::VideoSinkWants where + // the resolution or frame rate was different than it is currently. If + // something else is modified, such as encoder resolutions, but the resolution + // and frame rate stays the same, last wants is not updated. rtc::VideoSinkWants last_wants() const { MutexLock lock(&mutex_); return last_wants_; @@ -487,13 +518,12 @@ class AdaptingFrameForwarder : public test::FrameForwarder { &cropped_height, &out_width, &out_height)) { VideoFrame adapted_frame = VideoFrame::Builder() - .set_video_frame_buffer(new rtc::RefCountedObject( + .set_video_frame_buffer(rtc::make_ref_counted( nullptr, out_width, out_height)) - .set_timestamp_rtp(99) + .set_ntp_time_ms(video_frame.ntp_time_ms()) .set_timestamp_ms(99) .set_rotation(kVideoRotation_0) .build(); - adapted_frame.set_ntp_time_ms(video_frame.ntp_time_ms()); if (video_frame.has_update_rect()) { adapted_frame.set_update_rect( video_frame.update_rect().ScaleWithFrame( @@ -516,10 +546,26 @@ class AdaptingFrameForwarder : public test::FrameForwarder { } } + void OnOutputFormatRequest(int width, int height) { + absl::optional> target_aspect_ratio = + std::make_pair(width, height); + absl::optional max_pixel_count = width * height; + absl::optional max_fps; + adapter_.OnOutputFormatRequest(target_aspect_ratio, max_pixel_count, + max_fps); + } + void AddOrUpdateSink(rtc::VideoSinkInterface* sink, const rtc::VideoSinkWants& wants) override { MutexLock lock(&mutex_); - last_wants_ = sink_wants_locked(); + rtc::VideoSinkWants prev_wants = sink_wants_locked(); + bool did_adapt = + prev_wants.max_pixel_count != wants.max_pixel_count || + prev_wants.target_pixel_count != wants.target_pixel_count || + prev_wants.max_framerate_fps != wants.max_framerate_fps; + if (did_adapt) { + last_wants_ = prev_wants; + } adapter_.OnSinkWants(wants); test::FrameForwarder::AddOrUpdateSinkLocked(sink, wants); } @@ -679,99 +725,80 @@ class VideoStreamEncoderTest : public ::testing::Test { vp9_settings.numberOfSpatialLayers = num_spatial_layers; vp9_settings.automaticResizeOn = num_spatial_layers <= 1; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject< - VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings); + rtc::make_ref_counted( + vp9_settings); } ConfigureEncoder(std::move(video_encoder_config), allocation_callback_type); } VideoFrame CreateFrame(int64_t ntp_time_ms, rtc::Event* destruction_event) const { - VideoFrame frame = - VideoFrame::Builder() - .set_video_frame_buffer(new rtc::RefCountedObject( - destruction_event, codec_width_, codec_height_)) - .set_timestamp_rtp(99) - .set_timestamp_ms(99) - .set_rotation(kVideoRotation_0) - .build(); - frame.set_ntp_time_ms(ntp_time_ms); - return frame; + return VideoFrame::Builder() + .set_video_frame_buffer(rtc::make_ref_counted( + destruction_event, codec_width_, codec_height_)) 
+ .set_ntp_time_ms(ntp_time_ms) + .set_timestamp_ms(99) + .set_rotation(kVideoRotation_0) + .build(); } VideoFrame CreateFrameWithUpdatedPixel(int64_t ntp_time_ms, rtc::Event* destruction_event, int offset_x) const { - VideoFrame frame = - VideoFrame::Builder() - .set_video_frame_buffer(new rtc::RefCountedObject( - destruction_event, codec_width_, codec_height_)) - .set_timestamp_rtp(99) - .set_timestamp_ms(99) - .set_rotation(kVideoRotation_0) - .set_update_rect(VideoFrame::UpdateRect{offset_x, 0, 1, 1}) - .build(); - frame.set_ntp_time_ms(ntp_time_ms); - return frame; + return VideoFrame::Builder() + .set_video_frame_buffer(rtc::make_ref_counted( + destruction_event, codec_width_, codec_height_)) + .set_ntp_time_ms(ntp_time_ms) + .set_timestamp_ms(99) + .set_rotation(kVideoRotation_0) + .set_update_rect(VideoFrame::UpdateRect{offset_x, 0, 1, 1}) + .build(); } VideoFrame CreateFrame(int64_t ntp_time_ms, int width, int height) const { - VideoFrame frame = - VideoFrame::Builder() - .set_video_frame_buffer( - new rtc::RefCountedObject(nullptr, width, height)) - .set_timestamp_rtp(99) - .set_timestamp_ms(99) - .set_rotation(kVideoRotation_0) - .build(); - frame.set_ntp_time_ms(ntp_time_ms); - frame.set_timestamp_us(ntp_time_ms * 1000); - return frame; + auto buffer = rtc::make_ref_counted(nullptr, width, height); + I420Buffer::SetBlack(buffer.get()); + return VideoFrame::Builder() + .set_video_frame_buffer(std::move(buffer)) + .set_ntp_time_ms(ntp_time_ms) + .set_timestamp_ms(ntp_time_ms) + .set_rotation(kVideoRotation_0) + .build(); } VideoFrame CreateNV12Frame(int64_t ntp_time_ms, int width, int height) const { - VideoFrame frame = - VideoFrame::Builder() - .set_video_frame_buffer(NV12Buffer::Create(width, height)) - .set_timestamp_rtp(99) - .set_timestamp_ms(99) - .set_rotation(kVideoRotation_0) - .build(); - frame.set_ntp_time_ms(ntp_time_ms); - frame.set_timestamp_us(ntp_time_ms * 1000); - return frame; + return VideoFrame::Builder() + .set_video_frame_buffer(NV12Buffer::Create(width, height)) + .set_ntp_time_ms(ntp_time_ms) + .set_timestamp_ms(ntp_time_ms) + .set_rotation(kVideoRotation_0) + .build(); } VideoFrame CreateFakeNativeFrame(int64_t ntp_time_ms, rtc::Event* destruction_event, int width, int height) const { - VideoFrame frame = - VideoFrame::Builder() - .set_video_frame_buffer(new rtc::RefCountedObject( - destruction_event, width, height)) - .set_timestamp_rtp(99) - .set_timestamp_ms(99) - .set_rotation(kVideoRotation_0) - .build(); - frame.set_ntp_time_ms(ntp_time_ms); - return frame; + return VideoFrame::Builder() + .set_video_frame_buffer(rtc::make_ref_counted( + destruction_event, width, height)) + .set_ntp_time_ms(ntp_time_ms) + .set_timestamp_ms(99) + .set_rotation(kVideoRotation_0) + .build(); } VideoFrame CreateFakeNV12NativeFrame(int64_t ntp_time_ms, rtc::Event* destruction_event, int width, int height) const { - VideoFrame frame = VideoFrame::Builder() - .set_video_frame_buffer( - new rtc::RefCountedObject( - destruction_event, width, height)) - .set_timestamp_rtp(99) - .set_timestamp_ms(99) - .set_rotation(kVideoRotation_0) - .build(); - frame.set_ntp_time_ms(ntp_time_ms); - return frame; + return VideoFrame::Builder() + .set_video_frame_buffer(rtc::make_ref_counted( + destruction_event, width, height)) + .set_ntp_time_ms(ntp_time_ms) + .set_timestamp_ms(99) + .set_rotation(kVideoRotation_0) + .build(); } VideoFrame CreateFakeNativeFrame(int64_t ntp_time_ms, @@ -948,9 +975,10 @@ class VideoStreamEncoderTest : public ::testing::Test { FakeEncoder::Encode(input_image, 
&frame_type); } - void InjectEncodedImage(const EncodedImage& image) { + void InjectEncodedImage(const EncodedImage& image, + const CodecSpecificInfo* codec_specific_info) { MutexLock lock(&local_mutex_); - encoded_image_callback_->OnEncodedImage(image, nullptr); + encoded_image_callback_->OnEncodedImage(image, codec_specific_info); } void SetEncodedImageData( @@ -971,6 +999,16 @@ class VideoStreamEncoderTest : public ::testing::Test { return settings; } + int GetLastInputWidth() const { + MutexLock lock(&local_mutex_); + return last_input_width_; + } + + int GetLastInputHeight() const { + MutexLock lock(&local_mutex_); + return last_input_height_; + } + absl::optional GetLastInputPixelFormat() { MutexLock lock(&local_mutex_); return last_input_pixel_format_; @@ -1236,6 +1274,11 @@ class VideoStreamEncoderTest : public ::testing::Test { return last_capture_time_ms_; } + const EncodedImage& GetLastEncodedImage() { + MutexLock lock(&mutex_); + return last_encoded_image_; + } + std::vector GetLastEncodedImageData() { MutexLock lock(&mutex_); return std::move(last_encoded_image_data_); @@ -1267,18 +1310,21 @@ class VideoStreamEncoderTest : public ::testing::Test { const CodecSpecificInfo* codec_specific_info) override { MutexLock lock(&mutex_); EXPECT_TRUE(expect_frames_); + last_encoded_image_ = EncodedImage(encoded_image); last_encoded_image_data_ = std::vector( encoded_image.data(), encoded_image.data() + encoded_image.size()); uint32_t timestamp = encoded_image.Timestamp(); if (last_timestamp_ != timestamp) { num_received_layers_ = 1; + last_width_ = encoded_image._encodedWidth; + last_height_ = encoded_image._encodedHeight; } else { ++num_received_layers_; + last_width_ = std::max(encoded_image._encodedWidth, last_width_); + last_height_ = std::max(encoded_image._encodedHeight, last_height_); } last_timestamp_ = timestamp; last_capture_time_ms_ = encoded_image.capture_time_ms_; - last_width_ = encoded_image._encodedWidth; - last_height_ = encoded_image._encodedHeight; last_rotation_ = encoded_image.rotation_; if (num_received_layers_ == num_expected_layers_) { encoded_frame_event_.Set(); @@ -1325,6 +1371,7 @@ class VideoStreamEncoderTest : public ::testing::Test { mutable Mutex mutex_; TestEncoder* test_encoder_; rtc::Event encoded_frame_event_; + EncodedImage last_encoded_image_; std::vector last_encoded_image_data_; uint32_t last_timestamp_ = 0; int64_t last_capture_time_ms_ = 0; @@ -1530,7 +1577,7 @@ TEST_F(VideoStreamEncoderBlockedTest, DropsPendingFramesOnSlowEncode) { EXPECT_EQ(1, dropped_count); } -TEST_F(VideoStreamEncoderTest, DropFrameWithFailedI420Conversion) { +TEST_F(VideoStreamEncoderTest, NativeFrameWithoutI420SupportGetsDelivered) { video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), @@ -1539,15 +1586,21 @@ TEST_F(VideoStreamEncoderTest, DropFrameWithFailedI420Conversion) { rtc::Event frame_destroyed_event; video_source_.IncomingCapturedFrame( CreateFakeNativeFrame(1, &frame_destroyed_event)); - ExpectDroppedFrame(); - EXPECT_TRUE(frame_destroyed_event.Wait(kDefaultTimeoutMs)); + WaitForEncodedFrame(1); + EXPECT_EQ(VideoFrameBuffer::Type::kNative, + fake_encoder_.GetLastInputPixelFormat()); + EXPECT_EQ(fake_encoder_.codec_config().width, + fake_encoder_.GetLastInputWidth()); + EXPECT_EQ(fake_encoder_.codec_config().height, + fake_encoder_.GetLastInputHeight()); video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, DropFrameWithFailedI420ConversionWithCrop) { 
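// The sink bookkeeping changed just above now records, for each RTP
// timestamp, the largest encoded width/height seen across the simulcast or
// spatial layers that share that timestamp, instead of whichever layer
// happened to arrive last. A standalone sketch of that aggregation rule,
// using only the standard library; the struct and member names here are
// illustrative, not WebRTC API:
#include <algorithm>
#include <cstdint>

struct EncodedLayerSize {
  uint32_t rtp_timestamp;
  int width;
  int height;
};

struct LargestResolutionPerTimestamp {
  uint32_t last_timestamp = 0;
  int last_width = 0;
  int last_height = 0;

  // Called once per encoded layer, in arrival order.
  void OnLayer(const EncodedLayerSize& layer) {
    if (layer.rtp_timestamp != last_timestamp) {
      // New frame: start over from this layer's size.
      last_width = layer.width;
      last_height = layer.height;
    } else {
      // Another layer of the same frame: keep the maximum, so a test can
      // assert on the top layer's resolution regardless of delivery order.
      last_width = std::max(last_width, layer.width);
      last_height = std::max(last_height, layer.height);
    }
    last_timestamp = layer.rtp_timestamp;
  }
};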
+TEST_F(VideoStreamEncoderTest, + NativeFrameWithoutI420SupportGetsCroppedIfNecessary) { // Use the cropping factory. video_encoder_config_.video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); video_stream_encoder_->ConfigureEncoder(std::move(video_encoder_config_), kMaxPayloadLength); video_stream_encoder_->WaitUntilTaskQueueIsIdle(); @@ -1569,8 +1622,13 @@ TEST_F(VideoStreamEncoderTest, DropFrameWithFailedI420ConversionWithCrop) { rtc::Event frame_destroyed_event; video_source_.IncomingCapturedFrame(CreateFakeNativeFrame( 2, &frame_destroyed_event, codec_width_ + 1, codec_height_ + 1)); - ExpectDroppedFrame(); - EXPECT_TRUE(frame_destroyed_event.Wait(kDefaultTimeoutMs)); + WaitForEncodedFrame(2); + EXPECT_EQ(VideoFrameBuffer::Type::kNative, + fake_encoder_.GetLastInputPixelFormat()); + EXPECT_EQ(fake_encoder_.codec_config().width, + fake_encoder_.GetLastInputWidth()); + EXPECT_EQ(fake_encoder_.codec_config().height, + fake_encoder_.GetLastInputHeight()); video_stream_encoder_->Stop(); } @@ -1588,8 +1646,7 @@ TEST_F(VideoStreamEncoderTest, NonI420FramesShouldNotBeConvertedToI420) { video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, - NativeFrameIsConvertedToI420IfNoFrameTypePreference) { +TEST_F(VideoStreamEncoderTest, NativeFrameGetsDelivered_NoFrameTypePreference) { video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), @@ -1601,12 +1658,13 @@ TEST_F(VideoStreamEncoderTest, video_source_.IncomingCapturedFrame(CreateFakeNV12NativeFrame( 1, &frame_destroyed_event, codec_width_, codec_height_)); WaitForEncodedFrame(1); - EXPECT_EQ(VideoFrameBuffer::Type::kI420, + EXPECT_EQ(VideoFrameBuffer::Type::kNative, fake_encoder_.GetLastInputPixelFormat()); video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, NativeFrameMappedToPreferredPixelFormat) { +TEST_F(VideoStreamEncoderTest, + NativeFrameGetsDelivered_PixelFormatPreferenceMatches) { video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), @@ -1618,12 +1676,12 @@ TEST_F(VideoStreamEncoderTest, NativeFrameMappedToPreferredPixelFormat) { video_source_.IncomingCapturedFrame(CreateFakeNV12NativeFrame( 1, &frame_destroyed_event, codec_width_, codec_height_)); WaitForEncodedFrame(1); - EXPECT_EQ(VideoFrameBuffer::Type::kNV12, + EXPECT_EQ(VideoFrameBuffer::Type::kNative, fake_encoder_.GetLastInputPixelFormat()); video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, NativeFrameConvertedToI420IfMappingNotFeasible) { +TEST_F(VideoStreamEncoderTest, NativeFrameGetsDelivered_MappingIsNotFeasible) { video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), @@ -1636,12 +1694,12 @@ TEST_F(VideoStreamEncoderTest, NativeFrameConvertedToI420IfMappingNotFeasible) { video_source_.IncomingCapturedFrame(CreateFakeNV12NativeFrame( 1, &frame_destroyed_event, codec_width_, codec_height_)); WaitForEncodedFrame(1); - EXPECT_EQ(VideoFrameBuffer::Type::kI420, + EXPECT_EQ(VideoFrameBuffer::Type::kNative, fake_encoder_.GetLastInputPixelFormat()); video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, NativeFrameBackedByNV12FrameIsEncodedFromI420) { +TEST_F(VideoStreamEncoderTest, NativeFrameGetsDelivered_BackedByNV12) { video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), 
DataRate::BitsPerSec(kTargetBitrateBps), @@ -1651,7 +1709,7 @@ TEST_F(VideoStreamEncoderTest, NativeFrameBackedByNV12FrameIsEncodedFromI420) { video_source_.IncomingCapturedFrame(CreateFakeNV12NativeFrame( 1, &frame_destroyed_event, codec_width_, codec_height_)); WaitForEncodedFrame(1); - EXPECT_EQ(VideoFrameBuffer::Type::kI420, + EXPECT_EQ(VideoFrameBuffer::Type::kNative, fake_encoder_.GetLastInputPixelFormat()); video_stream_encoder_->Stop(); } @@ -2009,6 +2067,265 @@ TEST_F(VideoStreamEncoderTest, EncoderRecommendedMaxBitrateCapsTargetBitrate) { video_stream_encoder_->Stop(); } +TEST_F(VideoStreamEncoderTest, + EncoderMaxAndMinBitratesUsedForTwoStreamsHighestActive) { + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits270p( + 480 * 270, 34 * 1000, 12 * 1000, 1234 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits360p( + 640 * 360, 43 * 1000, 21 * 1000, 2345 * 1000); + fake_encoder_.SetResolutionBitrateLimits( + {kEncoderLimits270p, kEncoderLimits360p}); + + // Two streams, highest stream active. + VideoEncoderConfig config; + const int kNumStreams = 2; + test::FillEncoderConfiguration(kVideoCodecVP8, kNumStreams, &config); + config.max_bitrate_bps = 0; + config.simulcast_layers[0].active = false; + config.simulcast_layers[1].active = true; + config.video_stream_factory = + rtc::make_ref_counted( + "VP8", /*max qp*/ 56, /*screencast*/ false, + /*screenshare enabled*/ false); + video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength); + + // The encoder bitrate limits for 270p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(1, 480, 270)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, kNumStreams); + EXPECT_EQ(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // The encoder bitrate limits for 360p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(2, 640, 360)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kEncoderLimits360p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits360p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // Resolution b/w 270p and 360p. The encoder limits for 360p should be used. + video_source_.IncomingCapturedFrame( + CreateFrame(3, (640 + 480) / 2, (360 + 270) / 2)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kEncoderLimits360p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits360p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // Resolution higher than 360p. Encoder limits should be ignored. 
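// The expectations in this test imply a simple selection rule for the
// encoder-reported ResolutionBitrateLimits: the entry with the smallest
// configured frame size that is at least the encoded frame size applies, and
// if the frame is larger than every configured entry no limits apply at all.
// A minimal standalone sketch of that rule; PickLimits is an illustrative
// helper, not the production WebRTC function:
#include <algorithm>
#include <optional>
#include <vector>

struct BitrateLimits {
  int frame_size_pixels;
  int min_bitrate_bps;
  int max_bitrate_bps;
};

std::optional<BitrateLimits> PickLimits(std::vector<BitrateLimits> limits,
                                        int frame_size_pixels) {
  std::sort(limits.begin(), limits.end(),
            [](const BitrateLimits& a, const BitrateLimits& b) {
              return a.frame_size_pixels < b.frame_size_pixels;
            });
  for (const BitrateLimits& limit : limits) {
    if (frame_size_pixels <= limit.frame_size_pixels)
      return limit;  // E.g. 560x305 falls back to the 640x360 entry.
  }
  return std::nullopt;  // E.g. 960x540 when only 270p/360p are configured.
}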
+ video_source_.IncomingCapturedFrame(CreateFrame(4, 960, 540)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_NE(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_NE(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + EXPECT_NE(static_cast(kEncoderLimits360p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_NE(static_cast(kEncoderLimits360p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // Resolution lower than 270p. The encoder limits for 270p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(5, 320, 180)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + DefaultEncoderMaxAndMinBitratesUsedForTwoStreamsHighestActive) { + // Two streams, highest stream active. + VideoEncoderConfig config; + const int kNumStreams = 2; + test::FillEncoderConfiguration(kVideoCodecVP8, kNumStreams, &config); + config.max_bitrate_bps = 0; + config.simulcast_layers[0].active = false; + config.simulcast_layers[1].active = true; + config.video_stream_factory = + rtc::make_ref_counted( + "VP8", /*max qp*/ 56, /*screencast*/ false, + /*screenshare enabled*/ false); + video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength); + + // Default bitrate limits for 270p should be used. + const absl::optional + kDefaultLimits270p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP8, 480 * 270); + video_source_.IncomingCapturedFrame(CreateFrame(1, 480, 270)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, kNumStreams); + EXPECT_EQ(static_cast(kDefaultLimits270p->min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kDefaultLimits270p->max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // Default bitrate limits for 360p should be used. + const absl::optional + kDefaultLimits360p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP8, 640 * 360); + video_source_.IncomingCapturedFrame(CreateFrame(2, 640, 360)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kDefaultLimits360p->min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kDefaultLimits360p->max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // Resolution b/w 270p and 360p. The default limits for 360p should be used. + video_source_.IncomingCapturedFrame( + CreateFrame(3, (640 + 480) / 2, (360 + 270) / 2)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kDefaultLimits360p->min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kDefaultLimits360p->max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // Default bitrate limits for 540p should be used. 
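// Note on units in these EXPECTs: the ResolutionBitrateLimits values are
// given in bits per second, while the SimulcastStream fields on the
// VideoCodec struct (minBitrate/maxBitrate) are carried in kilobits per
// second, which is why every comparison multiplies the codec-side value by
// 1000. A tiny self-checking illustration with the 2345 kbps figure used in
// the test above:
#include <cstdint>

constexpr uint32_t kExampleMaxKbps = 2345;        // As stored on VideoCodec.
constexpr uint32_t kExampleMaxBps = 2345 * 1000;  // As configured in limits.
static_assert(kExampleMaxKbps * 1000 == kExampleMaxBps,
              "bps limits are compared against kbps fields times 1000");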
+ const absl::optional + kDefaultLimits540p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP8, 960 * 540); + video_source_.IncomingCapturedFrame(CreateFrame(4, 960, 540)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kDefaultLimits540p->min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kDefaultLimits540p->max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + EncoderMaxAndMinBitratesUsedForThreeStreamsMiddleActive) { + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits270p( + 480 * 270, 34 * 1000, 12 * 1000, 1234 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits360p( + 640 * 360, 43 * 1000, 21 * 1000, 2345 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits720p( + 1280 * 720, 54 * 1000, 31 * 1000, 3456 * 1000); + fake_encoder_.SetResolutionBitrateLimits( + {kEncoderLimits270p, kEncoderLimits360p, kEncoderLimits720p}); + + // Three streams, middle stream active. + VideoEncoderConfig config; + const int kNumStreams = 3; + test::FillEncoderConfiguration(kVideoCodecVP8, kNumStreams, &config); + config.simulcast_layers[0].active = false; + config.simulcast_layers[1].active = true; + config.simulcast_layers[2].active = false; + config.video_stream_factory = + rtc::make_ref_counted( + "VP8", /*max qp*/ 56, /*screencast*/ false, + /*screenshare enabled*/ false); + video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength); + + // The encoder bitrate limits for 360p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, kNumStreams); + EXPECT_EQ(static_cast(kEncoderLimits360p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits360p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // The encoder bitrate limits for 270p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(2, 960, 540)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + EncoderMaxAndMinBitratesNotUsedForThreeStreamsLowestActive) { + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits270p( + 480 * 270, 34 * 1000, 12 * 1000, 1234 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits360p( + 640 * 360, 43 * 1000, 21 * 1000, 2345 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits720p( + 1280 * 720, 54 * 1000, 31 * 1000, 3456 * 1000); + fake_encoder_.SetResolutionBitrateLimits( + {kEncoderLimits270p, kEncoderLimits360p, kEncoderLimits720p}); + + // Three streams, lowest stream active. 
+ VideoEncoderConfig config; + const int kNumStreams = 3; + test::FillEncoderConfiguration(kVideoCodecVP8, kNumStreams, &config); + config.simulcast_layers[0].active = true; + config.simulcast_layers[1].active = false; + config.simulcast_layers[2].active = false; + config.video_stream_factory = + rtc::make_ref_counted( + "VP8", /*max qp*/ 56, /*screencast*/ false, + /*screenshare enabled*/ false); + video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength); + + // Resolution on lowest stream lower than 270p. The encoder limits not applied + // on lowest stream, limits for 270p should not be used + video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, kNumStreams); + EXPECT_NE(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_NE(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + EncoderMaxBitrateCappedByConfigForTwoStreamsHighestActive) { + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits270p( + 480 * 270, 34 * 1000, 12 * 1000, 1234 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits360p( + 640 * 360, 43 * 1000, 21 * 1000, 2345 * 1000); + fake_encoder_.SetResolutionBitrateLimits( + {kEncoderLimits270p, kEncoderLimits360p}); + const int kMaxBitrateBps = kEncoderLimits360p.max_bitrate_bps - 100 * 1000; + + // Two streams, highest stream active. + VideoEncoderConfig config; + const int kNumStreams = 2; + test::FillEncoderConfiguration(kVideoCodecVP8, kNumStreams, &config); + config.simulcast_layers[0].active = false; + config.simulcast_layers[1].active = true; + config.simulcast_layers[1].max_bitrate_bps = kMaxBitrateBps; + config.video_stream_factory = + rtc::make_ref_counted( + "VP8", /*max qp*/ 56, /*screencast*/ false, + /*screenshare enabled*/ false); + video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength); + + // The encoder bitrate limits for 270p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(1, 480, 270)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, kNumStreams); + EXPECT_EQ(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + // The max configured bitrate is less than the encoder limit for 360p. 
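// The next expectation spells out the capping rule: the stream's maximum is
// the encoder-recommended 360p limit reduced to the layer's configured
// max_bitrate_bps, i.e. the smaller of the two values. Worked through with
// the constants defined above (the names here are illustrative):
constexpr int kEncoder360pMaxBps = 2345 * 1000;  // kEncoderLimits360p.
constexpr int kConfiguredLayerMaxBps = kEncoder360pMaxBps - 100 * 1000;
constexpr int kEffectiveMaxBps = kConfiguredLayerMaxBps < kEncoder360pMaxBps
                                     ? kConfiguredLayerMaxBps
                                     : kEncoder360pMaxBps;
static_assert(kEffectiveMaxBps == 2245 * 1000,
              "the configured per-layer cap wins over the encoder limit");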
+ video_source_.IncomingCapturedFrame(CreateFrame(2, 640, 360)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(static_cast(kEncoderLimits360p.min_bitrate_bps), + fake_encoder_.video_codec().simulcastStream[1].minBitrate * 1000); + EXPECT_EQ(static_cast(kMaxBitrateBps), + fake_encoder_.video_codec().simulcastStream[1].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + TEST_F(VideoStreamEncoderTest, SwitchSourceDeregisterEncoderAsSink) { EXPECT_TRUE(video_source_.has_sinks()); test::FrameForwarder new_video_source; @@ -2072,7 +2389,7 @@ TEST_P(ResolutionAlignmentTest, SinkWantsAlignmentApplied) { config.simulcast_layers[i].scale_resolution_down_by = scale_factors_[i]; } config.video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( "VP8", /*max qp*/ 56, /*screencast*/ false, /*screenshare enabled*/ false); video_stream_encoder_->ConfigureEncoder(std::move(config), kMaxPayloadLength); @@ -3036,78 +3353,329 @@ TEST_F(VideoStreamEncoderTest, SkipsSameOrLargerAdaptDownRequest_BalancedMode) { } TEST_F(VideoStreamEncoderTest, - NoChangeForInitialNormalUsage_MaintainFramerateMode) { - const int kWidth = 1280; - const int kHeight = 720; + FpsCountReturnsToZeroForFewerAdaptationsUpThanDown) { + const int kWidth = 640; + const int kHeight = 360; + const int64_t kFrameIntervalMs = 150; video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); - // Enable MAINTAIN_FRAMERATE preference, no initial limitation. - test::FrameForwarder source; - video_stream_encoder_->SetSource( - &source, webrtc::DegradationPreference::MAINTAIN_FRAMERATE); + // Enable BALANCED preference, no initial limitation. + AdaptingFrameForwarder source(&time_controller_); + source.set_adaptation_enabled(true); + video_stream_encoder_->SetSource(&source, + webrtc::DegradationPreference::BALANCED); - source.IncomingCapturedFrame(CreateFrame(1, kWidth, kHeight)); - WaitForEncodedFrame(kWidth, kHeight); + int64_t timestamp_ms = kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(kWidth, kHeight); EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); - EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_resolution); - EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_quality_adapt_changes); - // Trigger adapt up, expect no change. - video_stream_encoder_->TriggerCpuUnderuse(); - EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); - EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_resolution); - EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + // Trigger adapt down, expect reduced fps (640x360@15fps). 
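// In BALANCED mode each quality adapt-down step in this test lowers the frame
// rate to a cap that depends on the current resolution; the caps the comments
// in this test rely on are 15 fps at 640x360, 10 fps at 480x270 and 7 fps at
// 320x180. A small lookup sketch of that ladder, with thresholds read off the
// comments here rather than copied from the production
// BalancedDegradationSettings defaults:
#include <vector>

struct BalancedFpsStep {
  int max_pixels;
  int max_fps;
};

int BalancedFpsCap(int pixels) {
  static const std::vector<BalancedFpsStep> kLadder = {
      {320 * 180, 7}, {480 * 270, 10}, {640 * 360, 15}};
  for (const BalancedFpsStep& step : kLadder) {
    if (pixels <= step.max_pixels)
      return step.max_fps;
  }
  return -1;  // No cap above the ladder: run at the source's max frame rate.
}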
+ video_stream_encoder_->TriggerQualityLow(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), + FpsMatchesResolutionMax(Lt(kDefaultFramerate))); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(1, stats_proxy_->GetStats().number_of_quality_adapt_changes); - video_stream_encoder_->Stop(); -} + // Source requests 270p, expect reduced resolution (480x270@15fps). + source.OnOutputFormatRequest(480, 270); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(480, 270); + EXPECT_EQ(1, stats_proxy_->GetStats().number_of_quality_adapt_changes); -TEST_F(VideoStreamEncoderTest, - NoChangeForInitialNormalUsage_MaintainResolutionMode) { - const int kWidth = 1280; - const int kHeight = 720; - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - DataRate::BitsPerSec(kTargetBitrateBps), - DataRate::BitsPerSec(kTargetBitrateBps), - DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); + // Trigger adapt down, expect reduced fps (480x270@10fps). + video_stream_encoder_->TriggerQualityLow(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsLtResolutionEq(source.last_wants())); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(2, stats_proxy_->GetStats().number_of_quality_adapt_changes); - // Enable MAINTAIN_RESOLUTION preference, no initial limitation. - test::FrameForwarder source; - video_stream_encoder_->SetSource( - &source, webrtc::DegradationPreference::MAINTAIN_RESOLUTION); + // Source requests QVGA, expect reduced resolution (320x180@10fps). + source.OnOutputFormatRequest(320, 180); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(320, 180); + EXPECT_EQ(2, stats_proxy_->GetStats().number_of_quality_adapt_changes); - source.IncomingCapturedFrame(CreateFrame(1, kWidth, kHeight)); - WaitForEncodedFrame(kWidth, kHeight); - EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); - EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); - EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + // Trigger adapt down, expect reduced fps (320x180@7fps). + video_stream_encoder_->TriggerQualityLow(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsLtResolutionEq(source.last_wants())); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); - // Trigger adapt up, expect no change. - video_stream_encoder_->TriggerCpuUnderuse(); - EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); - EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); - EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + // Source requests VGA, expect increased resolution (640x360@7fps). 
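// The OnOutputFormatRequest(width, height) helper used in this test was added
// to AdaptingFrameForwarder earlier in this patch; it just translates a
// requested capture format into the three optional knobs of
// cricket::VideoAdapter::OnOutputFormatRequest: a width:height target aspect
// ratio, a max pixel count of width * height, and no frame-rate cap. Restated
// here with the template arguments written out (they are elided in the hunk
// above), forwarding to the adapter_ member shown in that hunk:
void OnOutputFormatRequest(int width, int height) {
  absl::optional<std::pair<int, int>> target_aspect_ratio =
      std::make_pair(width, height);
  absl::optional<int> max_pixel_count = width * height;
  absl::optional<int> max_fps;  // Unset: do not restrict the frame rate.
  adapter_.OnOutputFormatRequest(target_aspect_ratio, max_pixel_count,
                                 max_fps);
}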
+ source.OnOutputFormatRequest(640, 360); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Trigger adapt up, expect increased fps (640x360@(max-2)fps). + video_stream_encoder_->TriggerQualityHigh(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsGtResolutionEq(source.last_wants())); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(4, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Trigger adapt up, expect increased fps (640x360@(max-1)fps). + video_stream_encoder_->TriggerQualityHigh(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsGtResolutionEq(source.last_wants())); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(5, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Trigger adapt up, expect increased fps (640x360@maxfps). + video_stream_encoder_->TriggerQualityHigh(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsGtResolutionEq(source.last_wants())); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(6, stats_proxy_->GetStats().number_of_quality_adapt_changes); video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, NoChangeForInitialNormalUsage_BalancedMode) { +TEST_F(VideoStreamEncoderTest, + FpsCountReturnsToZeroForFewerAdaptationsUpThanDownWithTwoResources) { const int kWidth = 1280; const int kHeight = 720; + const int64_t kFrameIntervalMs = 150; video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); // Enable BALANCED preference, no initial limitation. - test::FrameForwarder source; + AdaptingFrameForwarder source(&time_controller_); + source.set_adaptation_enabled(true); video_stream_encoder_->SetSource(&source, webrtc::DegradationPreference::BALANCED); - source.IncomingCapturedFrame(CreateFrame(1, kWidth, kHeight)); - sink_.WaitForEncodedFrame(kWidth, kHeight); + int64_t timestamp_ms = kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(kWidth, kHeight); + EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Trigger adapt down, expect scaled down resolution (960x540@maxfps). 
+ video_stream_encoder_->TriggerQualityLow(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), + FpsMaxResolutionMatches(Lt(kWidth * kHeight))); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(1, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Trigger adapt down, expect scaled down resolution (640x360@maxfps). + video_stream_encoder_->TriggerQualityLow(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + sink_.WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsMaxResolutionLt(source.last_wants())); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(2, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Trigger adapt down, expect reduced fps (640x360@15fps). + video_stream_encoder_->TriggerQualityLow(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsLtResolutionEq(source.last_wants())); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); + + // Source requests QVGA, expect reduced resolution (320x180@15fps). + source.OnOutputFormatRequest(320, 180); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(320, 180); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt down, expect reduced fps (320x180@7fps). + video_stream_encoder_->TriggerCpuOveruse(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsLtResolutionEq(source.last_wants())); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_TRUE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(1, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Source requests HD, expect increased resolution (640x360@7fps). + source.OnOutputFormatRequest(1280, 720); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(1, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt up, expect increased fps (640x360@(max-1)fps). 
+ video_stream_encoder_->TriggerCpuUnderuse(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsGtResolutionEq(source.last_wants())); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_TRUE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_TRUE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(2, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt up, expect increased fps (640x360@maxfps). + video_stream_encoder_->TriggerQualityHigh(); + video_stream_encoder_->TriggerCpuUnderuse(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsGtResolutionEq(source.last_wants())); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_TRUE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(4, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(3, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt up, expect increased resolution (960x570@maxfps). + video_stream_encoder_->TriggerQualityHigh(); + video_stream_encoder_->TriggerCpuUnderuse(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsEqResolutionGt(source.last_wants())); + EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_TRUE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(5, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(4, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt up, expect increased resolution (1280x720@maxfps). + video_stream_encoder_->TriggerQualityHigh(); + video_stream_encoder_->TriggerCpuUnderuse(); + timestamp_ms += kFrameIntervalMs; + source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); + WaitForEncodedFrame(timestamp_ms); + EXPECT_THAT(source.sink_wants(), FpsEqResolutionGt(source.last_wants())); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_framerate); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(6, stats_proxy_->GetStats().number_of_quality_adapt_changes); + EXPECT_EQ(5, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + NoChangeForInitialNormalUsage_MaintainFramerateMode) { + const int kWidth = 1280; + const int kHeight = 720; + video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); + + // Enable MAINTAIN_FRAMERATE preference, no initial limitation. 
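// These NoChangeForInitialNormalUsage_* tests differ only in the degradation
// preference handed to SetSource(), which determines what an overuse signal
// would be allowed to reduce: MAINTAIN_FRAMERATE may lower resolution,
// MAINTAIN_RESOLUTION may lower frame rate, and BALANCED may lower either.
// That is why each variant inspects the matching cpu_limited_* stat. A small
// illustrative mapping (not production code):
struct AllowedAdaptations {
  bool resolution;
  bool framerate;
};

AllowedAdaptations WhatMayBeReduced(webrtc::DegradationPreference preference) {
  switch (preference) {
    case webrtc::DegradationPreference::MAINTAIN_FRAMERATE:
      return {true, false};  // Only resolution may be reduced.
    case webrtc::DegradationPreference::MAINTAIN_RESOLUTION:
      return {false, true};  // Only frame rate may be reduced.
    case webrtc::DegradationPreference::BALANCED:
      return {true, true};  // Either may be reduced.
    default:
      return {false, false};  // DISABLED: no adaptation at all.
  }
}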
+ test::FrameForwarder source; + video_stream_encoder_->SetSource( + &source, webrtc::DegradationPreference::MAINTAIN_FRAMERATE); + + source.IncomingCapturedFrame(CreateFrame(1, kWidth, kHeight)); + WaitForEncodedFrame(kWidth, kHeight); + EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt up, expect no change. + video_stream_encoder_->TriggerCpuUnderuse(); + EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_resolution); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + NoChangeForInitialNormalUsage_MaintainResolutionMode) { + const int kWidth = 1280; + const int kHeight = 720; + video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); + + // Enable MAINTAIN_RESOLUTION preference, no initial limitation. + test::FrameForwarder source; + video_stream_encoder_->SetSource( + &source, webrtc::DegradationPreference::MAINTAIN_RESOLUTION); + + source.IncomingCapturedFrame(CreateFrame(1, kWidth, kHeight)); + WaitForEncodedFrame(kWidth, kHeight); + EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + // Trigger adapt up, expect no change. + video_stream_encoder_->TriggerCpuUnderuse(); + EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); + EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); + EXPECT_EQ(0, stats_proxy_->GetStats().number_of_cpu_adapt_changes); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, NoChangeForInitialNormalUsage_BalancedMode) { + const int kWidth = 1280; + const int kHeight = 720; + video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); + + // Enable BALANCED preference, no initial limitation. 
+ test::FrameForwarder source; + video_stream_encoder_->SetSource(&source, + webrtc::DegradationPreference::BALANCED); + + source.IncomingCapturedFrame(CreateFrame(1, kWidth, kHeight)); + sink_.WaitForEncodedFrame(kWidth, kHeight); EXPECT_THAT(source.sink_wants(), FpsMaxResolutionMax()); EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution); EXPECT_FALSE(stats_proxy_->GetStats().cpu_limited_framerate); @@ -4091,7 +4659,7 @@ TEST_F(VideoStreamEncoderTest, ReportsVideoLayersAllocationForVP8Simulcast) { } TEST_F(VideoStreamEncoderTest, - ReportsVideoLayersAllocationForVP8WithMidleLayerDisabled) { + ReportsVideoLayersAllocationForVP8WithMiddleLayerDisabled) { fake_encoder_.SetTemporalLayersSupported(/*spatial_idx=*/0, true); fake_encoder_.SetTemporalLayersSupported(/*spatial_idx*/ 1, true); fake_encoder_.SetTemporalLayersSupported(/*spatial_idx*/ 2, true); @@ -4102,7 +4670,7 @@ TEST_F(VideoStreamEncoderTest, video_encoder_config.content_type = VideoEncoderConfig::ContentType::kRealtimeVideo; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( VideoEncoder::GetDefaultVp8Settings()); for (auto& layer : video_encoder_config.simulcast_layers) { layer.num_temporal_layers = 2; @@ -4136,7 +4704,7 @@ TEST_F(VideoStreamEncoderTest, } TEST_F(VideoStreamEncoderTest, - ReportsVideoLayersAllocationForVP8WithMidleAndHighestLayerDisabled) { + ReportsVideoLayersAllocationForVP8WithMiddleAndHighestLayerDisabled) { fake_encoder_.SetTemporalLayersSupported(/*spatial_idx=*/0, true); fake_encoder_.SetTemporalLayersSupported(/*spatial_idx*/ 1, true); fake_encoder_.SetTemporalLayersSupported(/*spatial_idx*/ 2, true); @@ -4147,7 +4715,7 @@ TEST_F(VideoStreamEncoderTest, video_encoder_config.content_type = VideoEncoderConfig::ContentType::kRealtimeVideo; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( VideoEncoder::GetDefaultVp8Settings()); for (auto& layer : video_encoder_config.simulcast_layers) { layer.num_temporal_layers = 2; @@ -4196,7 +4764,7 @@ TEST_F(VideoStreamEncoderTest, vp9_settings.interLayerPred = InterLayerPredMode::kOn; vp9_settings.automaticResizeOn = false; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); ConfigureEncoder(std::move(video_encoder_config), VideoStreamEncoder::BitrateAllocationCallbackType:: @@ -4251,7 +4819,7 @@ TEST_F(VideoStreamEncoderTest, vp9_settings.interLayerPred = InterLayerPredMode::kOn; vp9_settings.automaticResizeOn = false; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); ConfigureEncoder(std::move(video_encoder_config), VideoStreamEncoder::BitrateAllocationCallbackType:: @@ -4299,7 +4867,7 @@ TEST_F(VideoStreamEncoderTest, vp9_settings.interLayerPred = InterLayerPredMode::kOnKeyPic; vp9_settings.automaticResizeOn = false; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); ConfigureEncoder(std::move(video_encoder_config), VideoStreamEncoder::BitrateAllocationCallbackType:: @@ -4347,7 +4915,7 @@ TEST_F(VideoStreamEncoderTest, vp9_settings.interLayerPred = InterLayerPredMode::kOn; vp9_settings.automaticResizeOn = false; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); // Simulcast layers are used for enabling/disabling streams. 
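// Each "- new rtc::RefCountedObject(...)" / "+ rtc::make_ref_counted(...)"
// pair in these hunks is the same mechanical conversion; the template
// argument is the wrapped settings type (it survives in a removed line of a
// helper earlier in this file, where the old call happened to wrap). Written
// out in full, the VP9 case reads:
//
//   // Before:
//   video_encoder_config.encoder_specific_settings =
//       new rtc::RefCountedObject<
//           VideoEncoderConfig::Vp9EncoderSpecificSettings>(vp9_settings);
//
//   // After:
//   video_encoder_config.encoder_specific_settings =
//       rtc::make_ref_counted<VideoEncoderConfig::Vp9EncoderSpecificSettings>(
//           vp9_settings);
//
// rtc::make_ref_counted<T>(args...) returns an rtc::scoped_refptr<T>, so the
// new form also avoids the implicit raw-pointer-to-scoped_refptr conversion.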
video_encoder_config.simulcast_layers.resize(3); @@ -4406,7 +4974,7 @@ TEST_F(VideoStreamEncoderTest, vp9_settings.interLayerPred = InterLayerPredMode::kOn; vp9_settings.automaticResizeOn = false; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); // Simulcast layers are used for enabling/disabling streams. video_encoder_config.simulcast_layers.resize(3); @@ -4458,7 +5026,7 @@ TEST_F(VideoStreamEncoderTest, vp9_settings.interLayerPred = InterLayerPredMode::kOn; vp9_settings.automaticResizeOn = false; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); // Simulcast layers are used for enabling/disabling streams. video_encoder_config.simulcast_layers.resize(3); @@ -5057,6 +5625,10 @@ TEST_F(VideoStreamEncoderTest, InitialFrameDropActivatesWhenLayersChange) { VideoEncoderConfig video_encoder_config; test::FillEncoderConfiguration(PayloadStringToCodecType("VP8"), 3, &video_encoder_config); + video_encoder_config.video_stream_factory = + rtc::make_ref_counted( + "VP8", /*max qp*/ 56, /*screencast*/ false, + /*screenshare enabled*/ false); for (auto& layer : video_encoder_config.simulcast_layers) { layer.num_temporal_layers = 1; layer.max_framerate = kDefaultFramerate; @@ -5121,7 +5693,7 @@ TEST_F(VideoStreamEncoderTest, InitialFrameDropActivatesWhenSVCLayersChange) { // Since only one layer is active - automatic resize should be enabled. vp9_settings.automaticResizeOn = true; video_encoder_config.encoder_specific_settings = - new rtc::RefCountedObject( + rtc::make_ref_counted( vp9_settings); video_encoder_config.max_bitrate_bps = kSimulcastTargetBitrateBps; video_encoder_config.content_type = @@ -5162,53 +5734,304 @@ TEST_F(VideoStreamEncoderTest, InitialFrameDropActivatesWhenSVCLayersChange) { } TEST_F(VideoStreamEncoderTest, - InitialFrameDropActivatesWhenResolutionIncreases) { - const int kWidth = 640; - const int kHeight = 360; + EncoderMaxAndMinBitratesUsedIfMiddleStreamActive) { + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits270p( + 480 * 270, 34 * 1000, 12 * 1000, 1234 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits360p( + 640 * 360, 43 * 1000, 21 * 1000, 2345 * 1000); + const VideoEncoder::ResolutionBitrateLimits kEncoderLimits720p( + 1280 * 720, 54 * 1000, 31 * 1000, 2500 * 1000); + fake_encoder_.SetResolutionBitrateLimits( + {kEncoderLimits270p, kEncoderLimits360p, kEncoderLimits720p}); - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - DataRate::BitsPerSec(kTargetBitrateBps), - DataRate::BitsPerSec(kTargetBitrateBps), - DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); - video_source_.IncomingCapturedFrame(CreateFrame(1, kWidth / 2, kHeight / 2)); - // Frame should not be dropped. - WaitForEncodedFrame(1); + VideoEncoderConfig video_encoder_config; + test::FillEncoderConfiguration(PayloadStringToCodecType("VP9"), 1, + &video_encoder_config); + VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); + vp9_settings.numberOfSpatialLayers = 3; + // Since only one layer is active - automatic resize should be enabled. + vp9_settings.automaticResizeOn = true; + video_encoder_config.encoder_specific_settings = + rtc::make_ref_counted( + vp9_settings); + video_encoder_config.max_bitrate_bps = kSimulcastTargetBitrateBps; + video_encoder_config.content_type = + VideoEncoderConfig::ContentType::kRealtimeVideo; + // Simulcast layers are used to indicate which spatial layers are active. 
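// With three VP9 spatial layers configured for 1280x720 input, the layers
// scale down by a factor of two per layer (1280x720, 640x360, 320x180), so
// activating only the middle simulcast entry leaves the encoder with
// numberOfSpatialLayers == 2 and its top active layer at 640x360, which is
// what the expectations below assert. A small sketch of that halving rule
// (illustrative, not the production layer computation):
#include <utility>
#include <vector>

std::vector<std::pair<int, int>> SpatialLayerSizes(int width,
                                                   int height,
                                                   int num_layers) {
  // The highest layer matches the input; each lower layer halves both sides.
  std::vector<std::pair<int, int>> sizes(num_layers);
  for (int i = num_layers - 1; i >= 0; --i) {
    sizes[i] = {width, height};
    width /= 2;
    height /= 2;
  }
  return sizes;  // For (1280, 720, 3): {320x180, 640x360, 1280x720}.
}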
+ video_encoder_config.simulcast_layers.resize(3); + video_encoder_config.simulcast_layers[0].active = false; + video_encoder_config.simulcast_layers[1].active = true; + video_encoder_config.simulcast_layers[2].active = false; - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - DataRate::BitsPerSec(kLowTargetBitrateBps), - DataRate::BitsPerSec(kLowTargetBitrateBps), - DataRate::BitsPerSec(kLowTargetBitrateBps), 0, 0, 0); - video_source_.IncomingCapturedFrame(CreateFrame(2, kWidth / 2, kHeight / 2)); - // Frame should not be dropped, bitrate not too low for frame. - WaitForEncodedFrame(2); + video_stream_encoder_->ConfigureEncoder(video_encoder_config.Copy(), + kMaxPayloadLength); + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); - // Incoming resolution increases. - video_source_.IncomingCapturedFrame(CreateFrame(3, kWidth, kHeight)); - // Expect to drop this frame, bitrate too low for frame. - ExpectDroppedFrame(); + // The encoder bitrate limits for 360p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1); + EXPECT_EQ(fake_encoder_.video_codec().codecType, + VideoCodecType::kVideoCodecVP9); + EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 2); + EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active); + EXPECT_EQ(640, fake_encoder_.video_codec().spatialLayers[0].width); + EXPECT_EQ(360, fake_encoder_.video_codec().spatialLayers[0].height); + EXPECT_EQ(static_cast(kEncoderLimits360p.min_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits360p.max_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000); + + // The encoder bitrate limits for 270p should be used. + video_source_.IncomingCapturedFrame(CreateFrame(2, 960, 540)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1); + EXPECT_EQ(fake_encoder_.video_codec().codecType, + VideoCodecType::kVideoCodecVP9); + EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 2); + EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active); + EXPECT_EQ(480, fake_encoder_.video_codec().spatialLayers[0].width); + EXPECT_EQ(270, fake_encoder_.video_codec().spatialLayers[0].height); + EXPECT_EQ(static_cast(kEncoderLimits270p.min_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].minBitrate * 1000); + EXPECT_EQ(static_cast(kEncoderLimits270p.max_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000); - // Expect the sink_wants to specify a scaled frame. - EXPECT_TRUE_WAIT( - video_source_.sink_wants().max_pixel_count < kWidth * kHeight, 5000); video_stream_encoder_->Stop(); } -TEST_F(VideoStreamEncoderTest, InitialFrameDropIsNotReactivatedWhenAdaptingUp) { - const int kWidth = 640; - const int kHeight = 360; - // So that quality scaling doesn't happen by itself. 
- fake_encoder_.SetQp(kQpHigh); - - AdaptingFrameForwarder source(&time_controller_); - source.set_adaptation_enabled(true); - video_stream_encoder_->SetSource( - &source, webrtc::DegradationPreference::MAINTAIN_FRAMERATE); - - int timestamp = 1; - - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - DataRate::BitsPerSec(kTargetBitrateBps), - DataRate::BitsPerSec(kTargetBitrateBps), +TEST_F(VideoStreamEncoderTest, + DefaultMaxAndMinBitratesUsedIfMiddleStreamActive) { + VideoEncoderConfig video_encoder_config; + test::FillEncoderConfiguration(PayloadStringToCodecType("VP9"), 1, + &video_encoder_config); + VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); + vp9_settings.numberOfSpatialLayers = 3; + // Since only one layer is active - automatic resize should be enabled. + vp9_settings.automaticResizeOn = true; + video_encoder_config.encoder_specific_settings = + rtc::make_ref_counted( + vp9_settings); + video_encoder_config.max_bitrate_bps = kSimulcastTargetBitrateBps; + video_encoder_config.content_type = + VideoEncoderConfig::ContentType::kRealtimeVideo; + // Simulcast layers are used to indicate which spatial layers are active. + video_encoder_config.simulcast_layers.resize(3); + video_encoder_config.simulcast_layers[0].active = false; + video_encoder_config.simulcast_layers[1].active = true; + video_encoder_config.simulcast_layers[2].active = false; + + video_stream_encoder_->ConfigureEncoder(video_encoder_config.Copy(), + kMaxPayloadLength); + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); + + // The default bitrate limits for 360p should be used. + const absl::optional kLimits360p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP9, 640 * 360); + video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1); + EXPECT_EQ(fake_encoder_.video_codec().codecType, + VideoCodecType::kVideoCodecVP9); + EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 2); + EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active); + EXPECT_EQ(640, fake_encoder_.video_codec().spatialLayers[0].width); + EXPECT_EQ(360, fake_encoder_.video_codec().spatialLayers[0].height); + EXPECT_EQ(static_cast(kLimits360p->min_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].minBitrate * 1000); + EXPECT_EQ(static_cast(kLimits360p->max_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000); + + // The default bitrate limits for 270p should be used. 
+ const absl::optional kLimits270p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP9, 480 * 270); + video_source_.IncomingCapturedFrame(CreateFrame(2, 960, 540)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1); + EXPECT_EQ(fake_encoder_.video_codec().codecType, + VideoCodecType::kVideoCodecVP9); + EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 2); + EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active); + EXPECT_EQ(480, fake_encoder_.video_codec().spatialLayers[0].width); + EXPECT_EQ(270, fake_encoder_.video_codec().spatialLayers[0].height); + EXPECT_EQ(static_cast(kLimits270p->min_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].minBitrate * 1000); + EXPECT_EQ(static_cast(kLimits270p->max_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, DefaultMaxAndMinBitratesNotUsedIfDisabled) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-DefaultBitrateLimitsKillSwitch/Enabled/"); + VideoEncoderConfig video_encoder_config; + test::FillEncoderConfiguration(PayloadStringToCodecType("VP9"), 1, + &video_encoder_config); + VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings(); + vp9_settings.numberOfSpatialLayers = 3; + // Since only one layer is active - automatic resize should be enabled. + vp9_settings.automaticResizeOn = true; + video_encoder_config.encoder_specific_settings = + rtc::make_ref_counted( + vp9_settings); + video_encoder_config.max_bitrate_bps = kSimulcastTargetBitrateBps; + video_encoder_config.content_type = + VideoEncoderConfig::ContentType::kRealtimeVideo; + // Simulcast layers are used to indicate which spatial layers are active. + video_encoder_config.simulcast_layers.resize(3); + video_encoder_config.simulcast_layers[0].active = false; + video_encoder_config.simulcast_layers[1].active = true; + video_encoder_config.simulcast_layers[2].active = false; + + // Reset encoder for field trials to take effect. + ConfigureEncoder(video_encoder_config.Copy()); + + video_stream_encoder_->ConfigureEncoder(video_encoder_config.Copy(), + kMaxPayloadLength); + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); + + // The default bitrate limits for 360p should not be used. + const absl::optional kLimits360p = + EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution( + kVideoCodecVP9, 640 * 360); + video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720)); + EXPECT_FALSE(WaitForFrame(1000)); + EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1); + EXPECT_EQ(fake_encoder_.video_codec().codecType, kVideoCodecVP9); + EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 2); + EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active); + EXPECT_EQ(640, fake_encoder_.video_codec().spatialLayers[0].width); + EXPECT_EQ(360, fake_encoder_.video_codec().spatialLayers[0].height); + EXPECT_NE(static_cast(kLimits360p->max_bitrate_bps), + fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, SinglecastBitrateLimitsNotUsedForOneStream) { + ResetEncoder("VP9", /*num_streams=*/1, /*num_temporal_layers=*/1, + /*num_spatial_layers=*/1, /*screenshare=*/false); + + // The default singlecast bitrate limits for 720p should not be used. 
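// Taken together, the surrounding tests narrow down when the default
// singlecast limits from
// EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution are
// applied: when a single, non-lowest layer of a multi-layer configuration is
// active, and only while the "WebRTC-DefaultBitrateLimitsKillSwitch" field
// trial is not enabled; a plain one-stream configuration keeps its own
// limits, and the neighbouring lowest-stream-active tests show the
// encoder-provided limits being skipped in the same way. A hypothetical
// predicate summarising that reading of the tests, not the production
// decision logic:
bool ShouldApplyDefaultSinglecastLimits(int num_configured_layers,
                                        int num_active_layers,
                                        int active_layer_index,
                                        bool kill_switch_enabled) {
  if (kill_switch_enabled)
    return false;  // "WebRTC-DefaultBitrateLimitsKillSwitch/Enabled/".
  if (num_configured_layers <= 1)
    return false;  // Genuine singlecast: keep the configured limits.
  if (num_active_layers != 1)
    return false;  // Real simulcast/SVC: per-layer allocation applies.
  return active_layer_index > 0;  // Not applied to the lowest layer.
}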
+  const absl::optional<VideoEncoder::ResolutionBitrateLimits> kLimits720p =
+      EncoderInfoSettings::GetDefaultSinglecastBitrateLimitsForResolution(
+          kVideoCodecVP9, 1280 * 720);
+  video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720));
+  EXPECT_FALSE(WaitForFrame(1000));
+  EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1);
+  EXPECT_EQ(fake_encoder_.video_codec().codecType,
+            VideoCodecType::kVideoCodecVP9);
+  EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 1);
+  EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active);
+  EXPECT_EQ(1280, fake_encoder_.video_codec().spatialLayers[0].width);
+  EXPECT_EQ(720, fake_encoder_.video_codec().spatialLayers[0].height);
+  EXPECT_NE(static_cast<uint32_t>(kLimits720p->max_bitrate_bps),
+            fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000);
+
+  video_stream_encoder_->Stop();
+}
+
+TEST_F(VideoStreamEncoderTest,
+       EncoderMaxAndMinBitratesNotUsedIfLowestStreamActive) {
+  const VideoEncoder::ResolutionBitrateLimits kEncoderLimits180p(
+      320 * 180, 34 * 1000, 12 * 1000, 1234 * 1000);
+  const VideoEncoder::ResolutionBitrateLimits kEncoderLimits720p(
+      1280 * 720, 54 * 1000, 31 * 1000, 2500 * 1000);
+  fake_encoder_.SetResolutionBitrateLimits(
+      {kEncoderLimits180p, kEncoderLimits720p});
+
+  VideoEncoderConfig video_encoder_config;
+  test::FillEncoderConfiguration(PayloadStringToCodecType("VP9"), 1,
+                                 &video_encoder_config);
+  VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings();
+  vp9_settings.numberOfSpatialLayers = 3;
+  // Since only one layer is active - automatic resize should be enabled.
+  vp9_settings.automaticResizeOn = true;
+  video_encoder_config.encoder_specific_settings =
+      rtc::make_ref_counted<VideoEncoderConfig::Vp9EncoderSpecificSettings>(
+          vp9_settings);
+  video_encoder_config.max_bitrate_bps = kSimulcastTargetBitrateBps;
+  video_encoder_config.content_type =
+      VideoEncoderConfig::ContentType::kRealtimeVideo;
+  // Simulcast layers are used to indicate which spatial layers are active.
+  video_encoder_config.simulcast_layers.resize(3);
+  video_encoder_config.simulcast_layers[0].active = true;
+  video_encoder_config.simulcast_layers[1].active = false;
+  video_encoder_config.simulcast_layers[2].active = false;
+
+  video_stream_encoder_->ConfigureEncoder(video_encoder_config.Copy(),
+                                          kMaxPayloadLength);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+
+  // Limits not applied on lowest stream, limits for 180p should not be used.
+  video_source_.IncomingCapturedFrame(CreateFrame(1, 1280, 720));
+  EXPECT_FALSE(WaitForFrame(1000));
+  EXPECT_EQ(fake_encoder_.video_codec().numberOfSimulcastStreams, 1);
+  EXPECT_EQ(fake_encoder_.video_codec().codecType,
+            VideoCodecType::kVideoCodecVP9);
+  EXPECT_EQ(fake_encoder_.video_codec().VP9()->numberOfSpatialLayers, 3);
+  EXPECT_TRUE(fake_encoder_.video_codec().spatialLayers[0].active);
+  EXPECT_EQ(320, fake_encoder_.video_codec().spatialLayers[0].width);
+  EXPECT_EQ(180, fake_encoder_.video_codec().spatialLayers[0].height);
+  EXPECT_NE(static_cast<uint32_t>(kEncoderLimits180p.min_bitrate_bps),
+            fake_encoder_.video_codec().spatialLayers[0].minBitrate * 1000);
+  EXPECT_NE(static_cast<uint32_t>(kEncoderLimits180p.max_bitrate_bps),
+            fake_encoder_.video_codec().spatialLayers[0].maxBitrate * 1000);
+
+  video_stream_encoder_->Stop();
+}
+
+TEST_F(VideoStreamEncoderTest,
+       InitialFrameDropActivatesWhenResolutionIncreases) {
+  const int kWidth = 640;
+  const int kHeight = 360;
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0);
+  video_source_.IncomingCapturedFrame(CreateFrame(1, kWidth / 2, kHeight / 2));
+  // Frame should not be dropped.
+  WaitForEncodedFrame(1);
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kLowTargetBitrateBps),
+      DataRate::BitsPerSec(kLowTargetBitrateBps),
+      DataRate::BitsPerSec(kLowTargetBitrateBps), 0, 0, 0);
+  video_source_.IncomingCapturedFrame(CreateFrame(2, kWidth / 2, kHeight / 2));
+  // Frame should not be dropped, bitrate not too low for frame.
+  WaitForEncodedFrame(2);
+
+  // Incoming resolution increases.
+  video_source_.IncomingCapturedFrame(CreateFrame(3, kWidth, kHeight));
+  // Expect to drop this frame, bitrate too low for frame.
+  ExpectDroppedFrame();
+
+  // Expect the sink_wants to specify a scaled frame.
+  EXPECT_TRUE_WAIT(
+      video_source_.sink_wants().max_pixel_count < kWidth * kHeight, 5000);
+  video_stream_encoder_->Stop();
+}
+
+TEST_F(VideoStreamEncoderTest, InitialFrameDropIsNotReactivatedWhenAdaptingUp) {
+  const int kWidth = 640;
+  const int kHeight = 360;
+  // So that quality scaling doesn't happen by itself.
+  fake_encoder_.SetQp(kQpHigh);
+
+  AdaptingFrameForwarder source(&time_controller_);
+  source.set_adaptation_enabled(true);
+  video_stream_encoder_->SetSource(
+      &source, webrtc::DegradationPreference::MAINTAIN_FRAMERATE);
+
+  int timestamp = 1;
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps),
       DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0);
   source.IncomingCapturedFrame(CreateFrame(timestamp, kWidth, kHeight));
   WaitForEncodedFrame(timestamp);
@@ -5253,6 +6076,59 @@ TEST_F(VideoStreamEncoderTest, InitialFrameDropIsNotReactivatedWhenAdaptingUp) {
   video_stream_encoder_->Stop();
 }
 
+TEST_F(VideoStreamEncoderTest,
+       FrameDroppedWhenResolutionIncreasesAndLinkAllocationIsLow) {
+  const int kMinStartBps360p = 222000;
+  fake_encoder_.SetResolutionBitrateLimits(
+      {VideoEncoder::ResolutionBitrateLimits(320 * 180, 0, 30000, 400000),
+       VideoEncoder::ResolutionBitrateLimits(640 * 360, kMinStartBps360p, 30000,
+                                             800000)});
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kMinStartBps360p - 1),  // target_bitrate
+      DataRate::BitsPerSec(kMinStartBps360p - 1),  // stable_target_bitrate
+      DataRate::BitsPerSec(kMinStartBps360p - 1),  // link_allocation
+      0, 0, 0);
+  // Frame should not be dropped, bitrate not too low for frame.
+  video_source_.IncomingCapturedFrame(CreateFrame(1, 320, 180));
+  WaitForEncodedFrame(1);
+
+  // Incoming resolution increases, initial frame drop activates.
+  // Frame should be dropped, link allocation too low for frame.
+  video_source_.IncomingCapturedFrame(CreateFrame(2, 640, 360));
+  ExpectDroppedFrame();
+
+  // Expect sink_wants to specify a scaled frame.
+  EXPECT_TRUE_WAIT(video_source_.sink_wants().max_pixel_count < 640 * 360,
+                   5000);
+  video_stream_encoder_->Stop();
+}
+
+TEST_F(VideoStreamEncoderTest,
+       FrameNotDroppedWhenResolutionIncreasesAndLinkAllocationIsHigh) {
+  const int kMinStartBps360p = 222000;
+  fake_encoder_.SetResolutionBitrateLimits(
+      {VideoEncoder::ResolutionBitrateLimits(320 * 180, 0, 30000, 400000),
+       VideoEncoder::ResolutionBitrateLimits(640 * 360, kMinStartBps360p, 30000,
+                                             800000)});
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kMinStartBps360p - 1),  // target_bitrate
+      DataRate::BitsPerSec(kMinStartBps360p - 1),  // stable_target_bitrate
+      DataRate::BitsPerSec(kMinStartBps360p),      // link_allocation
+      0, 0, 0);
+  // Frame should not be dropped, bitrate not too low for frame.
+  video_source_.IncomingCapturedFrame(CreateFrame(1, 320, 180));
+  WaitForEncodedFrame(1);
+
+  // Incoming resolution increases, initial frame drop activates.
+  // Frame should not be dropped, link allocation not too low for frame.
+ video_source_.IncomingCapturedFrame(CreateFrame(2, 640, 360)); + WaitForEncodedFrame(2); + + video_stream_encoder_->Stop(); +} + TEST_F(VideoStreamEncoderTest, RampsUpInQualityWhenBwIsHigh) { webrtc::test::ScopedFieldTrials field_trials( "WebRTC-Video-QualityRampupSettings/min_pixels:1,min_duration_ms:2000/"); @@ -5322,6 +6198,8 @@ TEST_F(VideoStreamEncoderTest, RampsUpInQualityWhenBwIsHigh) { TEST_F(VideoStreamEncoderTest, QualityScalerAdaptationsRemovedWhenQualityScalingDisabled) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-Video-QualityScaling/Disabled/"); AdaptingFrameForwarder source(&time_controller_); source.set_adaptation_enabled(true); video_stream_encoder_->SetSource(&source, @@ -5734,7 +6612,7 @@ TEST_F(VideoStreamEncoderTest, EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_framerate); EXPECT_EQ(7, stats_proxy_->GetStats().number_of_quality_adapt_changes); - // Trigger adapt up, expect expect increased fps (320x180@10fps). + // Trigger adapt up, expect increased fps (320x180@10fps). video_stream_encoder_->TriggerQualityHigh(); timestamp_ms += kFrameIntervalMs; source.IncomingCapturedFrame(CreateFrame(timestamp_ms, kWidth, kHeight)); @@ -6071,7 +6949,7 @@ TEST_F(VideoStreamEncoderTest, AcceptsFullHdAdaptedDownSimulcastFrames) { video_encoder_config.simulcast_layers[0].max_framerate = kFramerate; video_encoder_config.max_bitrate_bps = kTargetBitrateBps; video_encoder_config.video_stream_factory = - new rtc::RefCountedObject(); + rtc::make_ref_counted(); video_stream_encoder_->ConfigureEncoder(std::move(video_encoder_config), kMaxPayloadLength); video_stream_encoder_->WaitUntilTaskQueueIsIdle(); @@ -6232,6 +7110,45 @@ TEST_F(VideoStreamEncoderTest, video_stream_encoder_->Stop(); } +TEST_F(VideoStreamEncoderTest, + CpuAdaptationThresholdsUpdatesWhenHardwareAccelerationChange) { + const int kFrameWidth = 1280; + const int kFrameHeight = 720; + + const CpuOveruseOptions default_options; + video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); + video_source_.IncomingCapturedFrame( + CreateFrame(1, kFrameWidth, kFrameHeight)); + WaitForEncodedFrame(1); + EXPECT_EQ(video_stream_encoder_->overuse_detector_proxy_->GetOptions() + .low_encode_usage_threshold_percent, + default_options.low_encode_usage_threshold_percent); + EXPECT_EQ(video_stream_encoder_->overuse_detector_proxy_->GetOptions() + .high_encode_usage_threshold_percent, + default_options.high_encode_usage_threshold_percent); + + CpuOveruseOptions hardware_options; + hardware_options.low_encode_usage_threshold_percent = 150; + hardware_options.high_encode_usage_threshold_percent = 200; + fake_encoder_.SetIsHardwareAccelerated(true); + + video_source_.IncomingCapturedFrame( + CreateFrame(2, kFrameWidth, kFrameHeight)); + WaitForEncodedFrame(2); + + EXPECT_EQ(video_stream_encoder_->overuse_detector_proxy_->GetOptions() + .low_encode_usage_threshold_percent, + hardware_options.low_encode_usage_threshold_percent); + EXPECT_EQ(video_stream_encoder_->overuse_detector_proxy_->GetOptions() + .high_encode_usage_threshold_percent, + hardware_options.high_encode_usage_threshold_percent); + + video_stream_encoder_->Stop(); +} + TEST_F(VideoStreamEncoderTest, DropsFramesWhenEncoderOvershoots) { const int kFrameWidth = 320; const int kFrameHeight = 240; @@ -6279,11 +7196,7 @@ TEST_F(VideoStreamEncoderTest, DropsFramesWhenEncoderOvershoots) { // doesn't push back as hard so we 
don't need quite as much overshoot. // These numbers are unfortunately a bit magical but there's not trivial // way to algebraically infer them. - if (trials.BitrateAdjusterCanUseNetworkHeadroom()) { - overshoot_factor = 2.4; - } else { - overshoot_factor = 4.0; - } + overshoot_factor = 3.0; } fake_encoder_.SimulateOvershoot(overshoot_factor); video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( @@ -6515,14 +7428,12 @@ TEST_F(VideoStreamEncoderTest, AdjustsTimestampInternalSource) { int64_t timestamp = 1; EncodedImage image; - image.SetEncodedData( - EncodedImageBuffer::Create(kTargetBitrateBps / kDefaultFramerate / 8)); image.capture_time_ms_ = ++timestamp; image.SetTimestamp(static_cast(timestamp * 90)); const int64_t kEncodeFinishDelayMs = 10; image.timing_.encode_start_ms = timestamp; image.timing_.encode_finish_ms = timestamp + kEncodeFinishDelayMs; - fake_encoder_.InjectEncodedImage(image); + fake_encoder_.InjectEncodedImage(image, /*codec_specific_info=*/nullptr); // Wait for frame without incrementing clock. EXPECT_TRUE(sink_.WaitForFrame(kDefaultTimeoutMs)); // Frame is captured kEncodeFinishDelayMs before it's encoded, so restored @@ -6720,125 +7631,6 @@ struct MockEncoderSwitchRequestCallback : public EncoderSwitchRequestCallback { (override)); }; -TEST_F(VideoStreamEncoderTest, BitrateEncoderSwitch) { - constexpr int kDontCare = 100; - - StrictMock switch_callback; - video_send_config_.encoder_settings.encoder_switch_request_callback = - &switch_callback; - VideoEncoderConfig encoder_config = video_encoder_config_.Copy(); - encoder_config.codec_type = kVideoCodecVP8; - webrtc::test::ScopedFieldTrials field_trial( - "WebRTC-NetworkCondition-EncoderSwitch/" - "codec_thresholds:VP8;100;-1|H264;-1;30000," - "to_codec:AV1,to_param:ping,to_value:pong,window:2.0/"); - - // Reset encoder for new configuration to take effect. - ConfigureEncoder(std::move(encoder_config)); - - // Send one frame to trigger ReconfigureEncoder. - video_source_.IncomingCapturedFrame( - CreateFrame(kDontCare, kDontCare, kDontCare)); - - using Config = EncoderSwitchRequestCallback::Config; - EXPECT_CALL(switch_callback, RequestEncoderSwitch(Matcher( - AllOf(Field(&Config::codec_name, "AV1"), - Field(&Config::param, "ping"), - Field(&Config::value, "pong"))))); - - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - /*target_bitrate=*/DataRate::KilobitsPerSec(50), - /*stable_target_bitrate=*/DataRate::KilobitsPerSec(kDontCare), - /*link_allocation=*/DataRate::KilobitsPerSec(kDontCare), - /*fraction_lost=*/0, - /*rtt_ms=*/0, - /*cwnd_reduce_ratio=*/0); - AdvanceTime(TimeDelta::Millis(0)); - - video_stream_encoder_->Stop(); -} - -TEST_F(VideoStreamEncoderTest, VideoSuspendedNoEncoderSwitch) { - constexpr int kDontCare = 100; - - StrictMock switch_callback; - video_send_config_.encoder_settings.encoder_switch_request_callback = - &switch_callback; - VideoEncoderConfig encoder_config = video_encoder_config_.Copy(); - encoder_config.codec_type = kVideoCodecVP8; - webrtc::test::ScopedFieldTrials field_trial( - "WebRTC-NetworkCondition-EncoderSwitch/" - "codec_thresholds:VP8;100;-1|H264;-1;30000," - "to_codec:AV1,to_param:ping,to_value:pong,window:2.0/"); - - // Reset encoder for new configuration to take effect. - ConfigureEncoder(std::move(encoder_config)); - - // Send one frame to trigger ReconfigureEncoder. 
- video_source_.IncomingCapturedFrame( - CreateFrame(kDontCare, kDontCare, kDontCare)); - - using Config = EncoderSwitchRequestCallback::Config; - EXPECT_CALL(switch_callback, RequestEncoderSwitch(Matcher(_))) - .Times(0); - - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - /*target_bitrate=*/DataRate::KilobitsPerSec(0), - /*stable_target_bitrate=*/DataRate::KilobitsPerSec(0), - /*link_allocation=*/DataRate::KilobitsPerSec(kDontCare), - /*fraction_lost=*/0, - /*rtt_ms=*/0, - /*cwnd_reduce_ratio=*/0); - - video_stream_encoder_->Stop(); -} - -TEST_F(VideoStreamEncoderTest, ResolutionEncoderSwitch) { - constexpr int kSufficientBitrateToNotDrop = 1000; - constexpr int kHighRes = 500; - constexpr int kLowRes = 100; - - StrictMock switch_callback; - video_send_config_.encoder_settings.encoder_switch_request_callback = - &switch_callback; - webrtc::test::ScopedFieldTrials field_trial( - "WebRTC-NetworkCondition-EncoderSwitch/" - "codec_thresholds:VP8;120;-1|H264;-1;30000," - "to_codec:AV1,to_param:ping,to_value:pong,window:2.0/"); - VideoEncoderConfig encoder_config = video_encoder_config_.Copy(); - encoder_config.codec_type = kVideoCodecH264; - - // Reset encoder for new configuration to take effect. - ConfigureEncoder(std::move(encoder_config)); - - // The VideoStreamEncoder needs some bitrate before it can start encoding, - // setting some bitrate so that subsequent calls to WaitForEncodedFrame does - // not fail. - video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( - /*target_bitrate=*/DataRate::KilobitsPerSec(kSufficientBitrateToNotDrop), - /*stable_target_bitrate=*/ - DataRate::KilobitsPerSec(kSufficientBitrateToNotDrop), - /*link_allocation=*/DataRate::KilobitsPerSec(kSufficientBitrateToNotDrop), - /*fraction_lost=*/0, - /*rtt_ms=*/0, - /*cwnd_reduce_ratio=*/0); - - // Send one frame to trigger ReconfigureEncoder. - video_source_.IncomingCapturedFrame(CreateFrame(1, kHighRes, kHighRes)); - WaitForEncodedFrame(1); - - using Config = EncoderSwitchRequestCallback::Config; - EXPECT_CALL(switch_callback, RequestEncoderSwitch(Matcher( - AllOf(Field(&Config::codec_name, "AV1"), - Field(&Config::param, "ping"), - Field(&Config::value, "pong"))))); - - video_source_.IncomingCapturedFrame(CreateFrame(2, kLowRes, kLowRes)); - WaitForEncodedFrame(2); - - video_stream_encoder_->Stop(); -} - TEST_F(VideoStreamEncoderTest, EncoderSelectorCurrentEncoderIsSignaled) { constexpr int kDontCare = 100; StrictMock encoder_selector; @@ -7122,7 +7914,7 @@ TEST_F(VideoStreamEncoderTest, EncoderResetAccordingToParameterChange) { config.simulcast_layers[i].active = true; } config.video_stream_factory = - new rtc::RefCountedObject( + rtc::make_ref_counted( "VP8", /*max qp*/ 56, /*screencast*/ false, /*screenshare enabled*/ false); video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( @@ -7201,4 +7993,543 @@ TEST_F(VideoStreamEncoderTest, EncoderResetAccordingToParameterChange) { video_stream_encoder_->Stop(); } +TEST_F(VideoStreamEncoderTest, EncoderResolutionsExposedInSinglecast) { + const int kFrameWidth = 1280; + const int kFrameHeight = 720; + + SetUp(); + video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources( + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), + DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0); + + // Capturing a frame should reconfigure the encoder and expose the encoder + // resolution, which is the same as the input frame. 
+  int64_t timestamp_ms = kFrameIntervalMs;
+  video_source_.IncomingCapturedFrame(
+      CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight));
+  WaitForEncodedFrame(timestamp_ms);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+  EXPECT_THAT(video_source_.sink_wants().resolutions,
+              ::testing::ElementsAreArray(
+                  {rtc::VideoSinkWants::FrameSize(kFrameWidth, kFrameHeight)}));
+
+  video_stream_encoder_->Stop();
+}
+
+TEST_F(VideoStreamEncoderTest, EncoderResolutionsExposedInSimulcast) {
+  // Pick downscale factors such that we never encode at full resolution - this
+  // is an interesting use case. The frame resolution influences the encoder
+  // resolutions, but if no layer has |scale_resolution_down_by| == 1 then the
+  // encoder should not ask for the frame resolution. This allows video frames
+  // to have the appearance of one resolution but optimize its internal buffers
+  // for what is actually encoded.
+  const size_t kNumSimulcastLayers = 3u;
+  const float kDownscaleFactors[] = {8.0, 4.0, 2.0};
+  const int kFrameWidth = 1280;
+  const int kFrameHeight = 720;
+  const rtc::VideoSinkWants::FrameSize kLayer0Size(
+      kFrameWidth / kDownscaleFactors[0], kFrameHeight / kDownscaleFactors[0]);
+  const rtc::VideoSinkWants::FrameSize kLayer1Size(
+      kFrameWidth / kDownscaleFactors[1], kFrameHeight / kDownscaleFactors[1]);
+  const rtc::VideoSinkWants::FrameSize kLayer2Size(
+      kFrameWidth / kDownscaleFactors[2], kFrameHeight / kDownscaleFactors[2]);
+
+  VideoEncoderConfig config;
+  test::FillEncoderConfiguration(kVideoCodecVP8, kNumSimulcastLayers, &config);
+  for (size_t i = 0; i < kNumSimulcastLayers; ++i) {
+    config.simulcast_layers[i].scale_resolution_down_by = kDownscaleFactors[i];
+    config.simulcast_layers[i].active = true;
+  }
+  config.video_stream_factory =
+      rtc::make_ref_counted<cricket::EncoderStreamFactory>(
+          "VP8", /*max qp*/ 56, /*screencast*/ false,
+          /*screenshare enabled*/ false);
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kSimulcastTargetBitrateBps),
+      DataRate::BitsPerSec(kSimulcastTargetBitrateBps),
+      DataRate::BitsPerSec(kSimulcastTargetBitrateBps), 0, 0, 0);
+
+  // Capture a frame with all layers active.
+  int64_t timestamp_ms = kFrameIntervalMs;
+  sink_.SetNumExpectedLayers(kNumSimulcastLayers);
+  video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength);
+  video_source_.IncomingCapturedFrame(
+      CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight));
+  WaitForEncodedFrame(timestamp_ms);
+  // Expect encoded resolutions to match the expected simulcast layers.
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+  EXPECT_THAT(
+      video_source_.sink_wants().resolutions,
+      ::testing::ElementsAreArray({kLayer0Size, kLayer1Size, kLayer2Size}));
+
+  // Capture a frame with one of the layers inactive.
+  timestamp_ms += kFrameIntervalMs;
+  config.simulcast_layers[2].active = false;
+  sink_.SetNumExpectedLayers(kNumSimulcastLayers - 1);
+  video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength);
+  video_source_.IncomingCapturedFrame(
+      CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight));
+  WaitForEncodedFrame(timestamp_ms);
+
+  // Expect encoded resolutions to match the expected simulcast layers.
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+  EXPECT_THAT(video_source_.sink_wants().resolutions,
+              ::testing::ElementsAreArray({kLayer0Size, kLayer1Size}));
+
+  // Capture a frame with all but one layer turned off.
+ timestamp_ms += kFrameIntervalMs; + config.simulcast_layers[1].active = false; + sink_.SetNumExpectedLayers(kNumSimulcastLayers - 2); + video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength); + video_source_.IncomingCapturedFrame( + CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight)); + WaitForEncodedFrame(timestamp_ms); + + // Expect encoded resolutions to match the expected simulcast layers. + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); + EXPECT_THAT(video_source_.sink_wants().resolutions, + ::testing::ElementsAreArray({kLayer0Size})); + + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, QpPresent_QpKept) { + ResetEncoder("VP8", 1, 1, 1, false); + + // Force encoder reconfig. + video_source_.IncomingCapturedFrame( + CreateFrame(1, codec_width_, codec_height_)); + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); + + // Set QP on encoded frame and pass the frame to encode complete callback. + // Since QP is present QP parsing won't be triggered and the original value + // should be kept. + EncodedImage encoded_image; + encoded_image.qp_ = 123; + encoded_image.SetEncodedData(EncodedImageBuffer::Create( + kCodedFrameVp8Qp25, sizeof(kCodedFrameVp8Qp25))); + CodecSpecificInfo codec_info; + codec_info.codecType = kVideoCodecVP8; + fake_encoder_.InjectEncodedImage(encoded_image, &codec_info); + EXPECT_TRUE(sink_.WaitForFrame(kDefaultTimeoutMs)); + EXPECT_EQ(sink_.GetLastEncodedImage().qp_, 123); + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, QpAbsent_QpParsed) { + ResetEncoder("VP8", 1, 1, 1, false); + + // Force encoder reconfig. + video_source_.IncomingCapturedFrame( + CreateFrame(1, codec_width_, codec_height_)); + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); + + // Pass an encoded frame without QP to encode complete callback. QP should be + // parsed and set. + EncodedImage encoded_image; + encoded_image.qp_ = -1; + encoded_image.SetEncodedData(EncodedImageBuffer::Create( + kCodedFrameVp8Qp25, sizeof(kCodedFrameVp8Qp25))); + CodecSpecificInfo codec_info; + codec_info.codecType = kVideoCodecVP8; + fake_encoder_.InjectEncodedImage(encoded_image, &codec_info); + EXPECT_TRUE(sink_.WaitForFrame(kDefaultTimeoutMs)); + EXPECT_EQ(sink_.GetLastEncodedImage().qp_, 25); + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, QpAbsentParsingDisabled_QpAbsent) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-QpParsingKillSwitch/Enabled/"); + + ResetEncoder("VP8", 1, 1, 1, false); + + // Force encoder reconfig. + video_source_.IncomingCapturedFrame( + CreateFrame(1, codec_width_, codec_height_)); + video_stream_encoder_->WaitUntilTaskQueueIsIdle(); + + EncodedImage encoded_image; + encoded_image.qp_ = -1; + encoded_image.SetEncodedData(EncodedImageBuffer::Create( + kCodedFrameVp8Qp25, sizeof(kCodedFrameVp8Qp25))); + CodecSpecificInfo codec_info; + codec_info.codecType = kVideoCodecVP8; + fake_encoder_.InjectEncodedImage(encoded_image, &codec_info); + EXPECT_TRUE(sink_.WaitForFrame(kDefaultTimeoutMs)); + EXPECT_EQ(sink_.GetLastEncodedImage().qp_, -1); + video_stream_encoder_->Stop(); +} + +TEST_F(VideoStreamEncoderTest, + QualityScalingNotAllowed_QualityScalingDisabled) { + VideoEncoderConfig video_encoder_config = video_encoder_config_.Copy(); + + // Disable scaling settings in encoder info. + fake_encoder_.SetQualityScaling(false); + // Disable quality scaling in encoder config. 
+  video_encoder_config.is_quality_scaling_allowed = false;
+  ConfigureEncoder(std::move(video_encoder_config));
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0);
+
+  test::FrameForwarder source;
+  video_stream_encoder_->SetSource(
+      &source, webrtc::DegradationPreference::MAINTAIN_FRAMERATE);
+  EXPECT_THAT(source.sink_wants(), UnlimitedSinkWants());
+  EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution);
+
+  source.IncomingCapturedFrame(CreateFrame(1, 1280, 720));
+  WaitForEncodedFrame(1);
+  video_stream_encoder_->TriggerQualityLow();
+  EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution);
+
+  video_stream_encoder_->Stop();
+}
+
+#if !defined(WEBRTC_IOS)
+// TODO(bugs.webrtc.org/12401): Disabled because WebRTC-Video-QualityScaling is
+// disabled by default on iOS.
+TEST_F(VideoStreamEncoderTest, QualityScalingAllowed_QualityScalingEnabled) {
+  VideoEncoderConfig video_encoder_config = video_encoder_config_.Copy();
+
+  // Disable scaling settings in encoder info.
+  fake_encoder_.SetQualityScaling(false);
+  // Enable quality scaling in encoder config.
+  video_encoder_config.is_quality_scaling_allowed = true;
+  ConfigureEncoder(std::move(video_encoder_config));
+
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps),
+      DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0);
+
+  test::FrameForwarder source;
+  video_stream_encoder_->SetSource(
+      &source, webrtc::DegradationPreference::MAINTAIN_FRAMERATE);
+  EXPECT_THAT(source.sink_wants(), UnlimitedSinkWants());
+  EXPECT_FALSE(stats_proxy_->GetStats().bw_limited_resolution);
+
+  source.IncomingCapturedFrame(CreateFrame(1, 1280, 720));
+  WaitForEncodedFrame(1);
+  video_stream_encoder_->TriggerQualityLow();
+  EXPECT_TRUE(stats_proxy_->GetStats().bw_limited_resolution);
+
+  video_stream_encoder_->Stop();
+}
+#endif
+
+// Test parameters: (VideoCodecType codec, bool allow_i420_conversion)
+class VideoStreamEncoderWithRealEncoderTest
+    : public VideoStreamEncoderTest,
+      public ::testing::WithParamInterface<std::pair<VideoCodecType, bool>> {
+ public:
+  VideoStreamEncoderWithRealEncoderTest()
+      : VideoStreamEncoderTest(),
+        codec_type_(std::get<0>(GetParam())),
+        allow_i420_conversion_(std::get<1>(GetParam())) {}
+
+  void SetUp() override {
+    VideoStreamEncoderTest::SetUp();
+    std::unique_ptr<VideoEncoder> encoder;
+    switch (codec_type_) {
+      case kVideoCodecVP8:
+        encoder = VP8Encoder::Create();
+        break;
+      case kVideoCodecVP9:
+        encoder = VP9Encoder::Create();
+        break;
+      case kVideoCodecAV1:
+        encoder = CreateLibaomAv1Encoder();
+        break;
+      case kVideoCodecH264:
+        encoder =
+            H264Encoder::Create(cricket::VideoCodec(cricket::kH264CodecName));
+        break;
+      case kVideoCodecMultiplex:
+        mock_encoder_factory_for_multiplex_ =
+            std::make_unique<MockVideoEncoderFactory>();
+        EXPECT_CALL(*mock_encoder_factory_for_multiplex_, Die);
+        EXPECT_CALL(*mock_encoder_factory_for_multiplex_, CreateVideoEncoder)
+            .WillRepeatedly([] { return VP8Encoder::Create(); });
+        encoder = std::make_unique<MultiplexEncoderAdapter>(
+            mock_encoder_factory_for_multiplex_.get(), SdpVideoFormat("VP8"),
+            false);
+        break;
+      default:
+        RTC_NOTREACHED();
+    }
+    ConfigureEncoderAndBitrate(codec_type_, std::move(encoder));
+  }
+
+  void TearDown() override {
+    video_stream_encoder_->Stop();
+    // Ensure |video_stream_encoder_| is destroyed before
+    // |encoder_proxy_factory_|.
+    video_stream_encoder_.reset();
+    VideoStreamEncoderTest::TearDown();
+  }
+
+ protected:
+  void ConfigureEncoderAndBitrate(VideoCodecType codec_type,
+                                  std::unique_ptr<VideoEncoder> encoder) {
+    // Configure VSE to use the encoder.
+    encoder_ = std::move(encoder);
+    encoder_proxy_factory_ = std::make_unique<test::VideoEncoderProxyFactory>(
+        encoder_.get(), &encoder_selector_);
+    video_send_config_.encoder_settings.encoder_factory =
+        encoder_proxy_factory_.get();
+    VideoEncoderConfig video_encoder_config;
+    test::FillEncoderConfiguration(codec_type, 1, &video_encoder_config);
+    video_encoder_config_ = video_encoder_config.Copy();
+    ConfigureEncoder(video_encoder_config_.Copy());
+
+    // Set bitrate to ensure frame is not dropped.
+    video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+        DataRate::BitsPerSec(kTargetBitrateBps),
+        DataRate::BitsPerSec(kTargetBitrateBps),
+        DataRate::BitsPerSec(kTargetBitrateBps), 0, 0, 0);
+  }
+
+  const VideoCodecType codec_type_;
+  const bool allow_i420_conversion_;
+  NiceMock<MockEncoderSelector> encoder_selector_;
+  std::unique_ptr<test::VideoEncoderProxyFactory> encoder_proxy_factory_;
+  std::unique_ptr<VideoEncoder> encoder_;
+  std::unique_ptr<MockVideoEncoderFactory> mock_encoder_factory_for_multiplex_;
+};
+
+TEST_P(VideoStreamEncoderWithRealEncoderTest, EncoderMapsNativeI420) {
+  auto native_i420_frame = test::CreateMappableNativeFrame(
+      1, VideoFrameBuffer::Type::kI420, codec_width_, codec_height_);
+  video_source_.IncomingCapturedFrame(native_i420_frame);
+  WaitForEncodedFrame(codec_width_, codec_height_);
+
+  auto mappable_native_buffer =
+      test::GetMappableNativeBufferFromVideoFrame(native_i420_frame);
+  std::vector<rtc::scoped_refptr<VideoFrameBuffer>> mapped_frame_buffers =
+      mappable_native_buffer->GetMappedFramedBuffers();
+  ASSERT_EQ(mapped_frame_buffers.size(), 1u);
+  EXPECT_EQ(mapped_frame_buffers[0]->width(), codec_width_);
+  EXPECT_EQ(mapped_frame_buffers[0]->height(), codec_height_);
+  EXPECT_EQ(mapped_frame_buffers[0]->type(), VideoFrameBuffer::Type::kI420);
+}
+
+TEST_P(VideoStreamEncoderWithRealEncoderTest, EncoderMapsNativeNV12) {
+  auto native_nv12_frame = test::CreateMappableNativeFrame(
+      1, VideoFrameBuffer::Type::kNV12, codec_width_, codec_height_);
+  video_source_.IncomingCapturedFrame(native_nv12_frame);
+  WaitForEncodedFrame(codec_width_, codec_height_);
+
+  auto mappable_native_buffer =
+      test::GetMappableNativeBufferFromVideoFrame(native_nv12_frame);
+  std::vector<rtc::scoped_refptr<VideoFrameBuffer>> mapped_frame_buffers =
+      mappable_native_buffer->GetMappedFramedBuffers();
+  ASSERT_EQ(mapped_frame_buffers.size(), 1u);
+  EXPECT_EQ(mapped_frame_buffers[0]->width(), codec_width_);
+  EXPECT_EQ(mapped_frame_buffers[0]->height(), codec_height_);
+  EXPECT_EQ(mapped_frame_buffers[0]->type(), VideoFrameBuffer::Type::kNV12);
+
+  if (!allow_i420_conversion_) {
+    EXPECT_FALSE(mappable_native_buffer->DidConvertToI420());
+  }
+}
+
+TEST_P(VideoStreamEncoderWithRealEncoderTest, HandlesLayerToggling) {
+  if (codec_type_ == kVideoCodecMultiplex) {
+    // Multiplex codec here uses wrapped mock codecs, ignore for this test.
+    return;
+  }
+
+  const size_t kNumSpatialLayers = 3u;
+  const float kDownscaleFactors[] = {4.0, 2.0, 1.0};
+  const int kFrameWidth = 1280;
+  const int kFrameHeight = 720;
+  const rtc::VideoSinkWants::FrameSize kLayer0Size(
+      kFrameWidth / kDownscaleFactors[0], kFrameHeight / kDownscaleFactors[0]);
+  const rtc::VideoSinkWants::FrameSize kLayer1Size(
+      kFrameWidth / kDownscaleFactors[1], kFrameHeight / kDownscaleFactors[1]);
+  const rtc::VideoSinkWants::FrameSize kLayer2Size(
+      kFrameWidth / kDownscaleFactors[2], kFrameHeight / kDownscaleFactors[2]);
+
+  VideoEncoderConfig config;
+  if (codec_type_ == VideoCodecType::kVideoCodecVP9) {
+    test::FillEncoderConfiguration(codec_type_, 1, &config);
+    config.max_bitrate_bps = kSimulcastTargetBitrateBps;
+    VideoCodecVP9 vp9_settings = VideoEncoder::GetDefaultVp9Settings();
+    vp9_settings.numberOfSpatialLayers = kNumSpatialLayers;
+    vp9_settings.numberOfTemporalLayers = 3;
+    vp9_settings.automaticResizeOn = false;
+    config.encoder_specific_settings =
+        rtc::make_ref_counted<VideoEncoderConfig::Vp9EncoderSpecificSettings>(
+            vp9_settings);
+    config.spatial_layers = GetSvcConfig(kFrameWidth, kFrameHeight,
+                                         /*fps=*/30.0,
+                                         /*first_active_layer=*/0,
+                                         /*num_spatial_layers=*/3,
+                                         /*num_temporal_layers=*/3,
+                                         /*is_screenshare=*/false);
+  } else if (codec_type_ == VideoCodecType::kVideoCodecAV1) {
+    test::FillEncoderConfiguration(codec_type_, 1, &config);
+    config.max_bitrate_bps = kSimulcastTargetBitrateBps;
+    config.spatial_layers = GetSvcConfig(kFrameWidth, kFrameHeight,
+                                         /*fps=*/30.0,
+                                         /*first_active_layer=*/0,
+                                         /*num_spatial_layers=*/3,
+                                         /*num_temporal_layers=*/3,
+                                         /*is_screenshare=*/false);
+    config.simulcast_layers[0].scalability_mode = "L3T3_KEY";
+  } else {
+    // Simulcast for VP8/H264.
+    test::FillEncoderConfiguration(codec_type_, kNumSpatialLayers, &config);
+    for (size_t i = 0; i < kNumSpatialLayers; ++i) {
+      config.simulcast_layers[i].scale_resolution_down_by =
+          kDownscaleFactors[i];
+      config.simulcast_layers[i].active = true;
+    }
+    if (codec_type_ == VideoCodecType::kVideoCodecH264) {
+      // Turn off frame dropping to prevent flakiness.
+      VideoCodecH264 h264_settings = VideoEncoder::GetDefaultH264Settings();
+      h264_settings.frameDroppingOn = false;
+      config.encoder_specific_settings = rtc::make_ref_counted<
+          VideoEncoderConfig::H264EncoderSpecificSettings>(h264_settings);
+    }
+  }
+
+  auto set_layer_active = [&](int layer_idx, bool active) {
+    if (codec_type_ == VideoCodecType::kVideoCodecVP9 ||
+        codec_type_ == VideoCodecType::kVideoCodecAV1) {
+      config.spatial_layers[layer_idx].active = active;
+    } else {
+      config.simulcast_layers[layer_idx].active = active;
+    }
+  };
+
+  config.video_stream_factory =
+      rtc::make_ref_counted<cricket::EncoderStreamFactory>(
+          CodecTypeToPayloadString(codec_type_), /*max qp*/ 56,
+          /*screencast*/ false,
+          /*screenshare enabled*/ false);
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::BitsPerSec(kSimulcastTargetBitrateBps),
+      DataRate::BitsPerSec(kSimulcastTargetBitrateBps),
+      DataRate::BitsPerSec(kSimulcastTargetBitrateBps), 0, 0, 0);
+
+  // Capture a frame with all layers active.
+  sink_.SetNumExpectedLayers(kNumSpatialLayers);
+  video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength);
+  int64_t timestamp_ms = kFrameIntervalMs;
+  video_source_.IncomingCapturedFrame(
+      CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight));
+
+  WaitForEncodedFrame(kLayer2Size.width, kLayer2Size.height);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+
+  // Capture a frame with one of the layers inactive.
+  set_layer_active(2, false);
+  sink_.SetNumExpectedLayers(kNumSpatialLayers - 1);
+  video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength);
+  timestamp_ms += kFrameIntervalMs;
+  video_source_.IncomingCapturedFrame(
+      CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight));
+  WaitForEncodedFrame(kLayer1Size.width, kLayer1Size.height);
+
+  // New target bitrates signaled based on lower resolution.
+  DataRate kTwoLayerBitrate = DataRate::KilobitsPerSec(833);
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      kTwoLayerBitrate, kTwoLayerBitrate, kTwoLayerBitrate, 0, 0, 0);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+
+  // Re-enable the top layer.
+  set_layer_active(2, true);
+  sink_.SetNumExpectedLayers(kNumSpatialLayers);
+  video_stream_encoder_->ConfigureEncoder(config.Copy(), kMaxPayloadLength);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+
+  // Bitrate target adjusted back up to enable HD layer...
+  video_stream_encoder_->OnBitrateUpdatedAndWaitForManagedResources(
+      DataRate::KilobitsPerSec(1800), DataRate::KilobitsPerSec(1800),
+      DataRate::KilobitsPerSec(1800), 0, 0, 0);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+
+  // ...then add a new frame.
+  timestamp_ms += kFrameIntervalMs;
+  video_source_.IncomingCapturedFrame(
+      CreateFrame(timestamp_ms, kFrameWidth, kFrameHeight));
+  WaitForEncodedFrame(kLayer2Size.width, kLayer2Size.height);
+  video_stream_encoder_->WaitUntilTaskQueueIsIdle();
+
+  video_stream_encoder_->Stop();
+}
+
+std::string TestParametersVideoCodecAndAllowI420ConversionToString(
+    testing::TestParamInfo<std::pair<VideoCodecType, bool>> info) {
+  VideoCodecType codec_type = std::get<0>(info.param);
+  bool allow_i420_conversion = std::get<1>(info.param);
+  std::string str;
+  switch (codec_type) {
+    case kVideoCodecGeneric:
+      str = "Generic";
+      break;
+    case kVideoCodecVP8:
+      str = "VP8";
+      break;
+    case kVideoCodecVP9:
+      str = "VP9";
+      break;
+    case kVideoCodecAV1:
+      str = "AV1";
+      break;
+    case kVideoCodecH264:
+      str = "H264";
+      break;
+    case kVideoCodecMultiplex:
+      str = "Multiplex";
+      break;
+    default:
+      RTC_NOTREACHED();
+  }
+  str += allow_i420_conversion ? "_AllowToI420" : "_DisallowToI420";
+  return str;
+}
+
+constexpr std::pair<VideoCodecType, bool> kVP8DisallowConversion =
+    std::make_pair(kVideoCodecVP8, /*allow_i420_conversion=*/false);
+constexpr std::pair<VideoCodecType, bool> kVP9DisallowConversion =
+    std::make_pair(kVideoCodecVP9, /*allow_i420_conversion=*/false);
+constexpr std::pair<VideoCodecType, bool> kAV1AllowConversion =
+    std::make_pair(kVideoCodecAV1, /*allow_i420_conversion=*/true);
+constexpr std::pair<VideoCodecType, bool> kMultiplexDisallowConversion =
+    std::make_pair(kVideoCodecMultiplex, /*allow_i420_conversion=*/false);
+#if defined(WEBRTC_USE_H264)
+constexpr std::pair<VideoCodecType, bool> kH264AllowConversion =
+    std::make_pair(kVideoCodecH264, /*allow_i420_conversion=*/true);
+
+// The windows compiler does not tolerate #if statements inside the
+// INSTANTIATE_TEST_SUITE_P() macro, so we have to have two definitions (with
+// and without H264).
+INSTANTIATE_TEST_SUITE_P( + All, + VideoStreamEncoderWithRealEncoderTest, + ::testing::Values(kVP8DisallowConversion, + kVP9DisallowConversion, + kAV1AllowConversion, + kMultiplexDisallowConversion, + kH264AllowConversion), + TestParametersVideoCodecAndAllowI420ConversionToString); +#else +INSTANTIATE_TEST_SUITE_P( + All, + VideoStreamEncoderWithRealEncoderTest, + ::testing::Values(kVP8DisallowConversion, + kVP9DisallowConversion, + kAV1AllowConversion, + kMultiplexDisallowConversion), + TestParametersVideoCodecAndAllowI420ConversionToString); +#endif + } // namespace webrtc diff --git a/webrtc.gni b/webrtc.gni index 05a230c4f1..c0ff14fe51 100644 --- a/webrtc.gni +++ b/webrtc.gni @@ -59,6 +59,12 @@ declare_args() { # provided. rtc_exclude_metrics_default = build_with_chromium + # Setting this to true will define WEBRTC_EXCLUDE_SYSTEM_TIME which + # will tell the pre-processor to remove the default definition of the + # SystemTimeNanos() which is defined in rtc_base/system_time.cc. In + # that case a new implementation needs to be provided. + rtc_exclude_system_time = build_with_chromium + # Setting this to false will require the API user to pass in their own # SSLCertificateVerifier to verify the certificates presented from a # TLS-TURN server. In return disabling this saves around 100kb in the binary. @@ -118,8 +124,8 @@ declare_args() { rtc_link_pipewire = false # Set this to use certain PipeWire version - # Currently we support PipeWire 0.2 (default) and PipeWire 0.3 - rtc_pipewire_version = "0.2" + # Currently WebRTC supports PipeWire 0.2 and PipeWire 0.3 (default) + rtc_pipewire_version = "0.3" # Enable to use the Mozilla internal settings. build_with_mozilla = false @@ -177,8 +183,9 @@ declare_args() { rtc_apprtcmobile_broadcast_extension = false } - # Determines whether Metal is available on iOS/macOS. - rtc_use_metal_rendering = is_mac || (is_ios && current_cpu == "arm64") + # Determines whether OpenGL is available on iOS/macOS. + rtc_ios_macos_use_opengl_rendering = + !(is_ios && target_environment == "catalyst") # When set to false, builtin audio encoder/decoder factories and all the # audio codecs they depend on will not be included in libwebrtc.{a|lib} @@ -201,9 +208,9 @@ declare_args() { rtc_win_undef_unicode = false # When set to true, a capturer implementation that uses the - # Windows.Graphics.Capture APIs will be available for use. These APIs are - # available in the Win 10 SDK v10.0.19041. - rtc_enable_win_wgc = false + # Windows.Graphics.Capture APIs will be available for use. This introduces a + # dependency on the Win 10 SDK v10.0.17763.0. + rtc_enable_win_wgc = is_win } if (!build_with_mozilla) { @@ -227,7 +234,6 @@ declare_args() { rtc_libvpx_build_vp9 = !build_with_mozilla rtc_build_opus = !build_with_mozilla rtc_build_ssl = !build_with_mozilla - rtc_build_usrsctp = !build_with_mozilla # Enable libevent task queues on platforms that support it. if (is_win || is_mac || is_ios || is_nacl || is_fuchsia || @@ -239,10 +245,6 @@ declare_args() { rtc_build_libevent = !build_with_mozilla } - # Build sources requiring GTK. NOTICE: This is not present in Chrome OS - # build environments, even if available for Chromium builds. - rtc_use_gtk = !build_with_chromium && !build_with_mozilla - # Excluded in Chromium since its prerequisites don't require Pulse Audio. rtc_include_pulse_audio = !build_with_chromium @@ -258,7 +260,8 @@ declare_args() { rtc_enable_avx2 = false } - # Include tests in standalone checkout. + # Set this to true to build the unit tests. 
+ # Disabled when building with Chromium or Mozilla. rtc_include_tests = !build_with_chromium && !build_with_mozilla # Set this to false to skip building code that also requires X11 extensions @@ -283,6 +286,14 @@ declare_args() { rtc_exclude_transient_suppressor = false } +declare_args() { + # Enable the dcsctp backend for DataChannels and related unittests + rtc_build_dcsctp = !build_with_mozilla && rtc_enable_sctp + + # Enable the usrsctp backend for DataChannels and related unittests + rtc_build_usrsctp = !build_with_mozilla && rtc_enable_sctp +} + # Make it possible to provide custom locations for some libraries (move these # up into declare_args should we need to actually use them for the GN build). rtc_libvpx_dir = "//third_party/libvpx" @@ -462,9 +473,13 @@ template("rtc_test") { } if (!build_with_chromium && is_android) { android_manifest = webrtc_root + "test/android/AndroidManifest.xml" + use_raw_android_executable = false min_sdk_version = 21 target_sdk_version = 23 - deps += [ webrtc_root + "test:native_test_java" ] + deps += [ + "//build/android/gtest_apk:native_test_instrumentation_test_runner_java", + webrtc_root + "test:native_test_java", + ] } # When not targeting a simulator, building //base/test:google_test_runner @@ -475,6 +490,30 @@ template("rtc_test") { xctest_module_target = "//base/test:google_test_runner" } } + + # If absl_deps is [], no action is needed. If not [], then it needs to be + # converted to //third_party/abseil-cpp:absl when build_with_chromium=true + # otherwise it just needs to be added to deps. + if (defined(absl_deps) && absl_deps != []) { + if (!defined(deps)) { + deps = [] + } + if (build_with_chromium) { + deps += [ "//third_party/abseil-cpp:absl" ] + } else { + deps += absl_deps + } + } + + if (using_sanitizer) { + if (is_linux) { + if (!defined(invoker.data)) { + data = [] + } + data += + [ "//third_party/llvm-build/Release+Asserts/lib/libstdc++.so.6" ] + } + } } } @@ -957,10 +996,16 @@ if (is_ios) { deps = [ ":create_bracket_include_headers_$this_target_name" ] } + if (target_environment == "catalyst") { + # Catalyst frameworks use the same layout as regular Mac frameworks. + headers_dir = "Versions/A/Headers" + } else { + headers_dir = "Headers" + } copy("copy_umbrella_header_$target_name") { sources = [ umbrella_header_path ] outputs = - [ "$root_out_dir/$output_name.framework/Headers/$output_name.h" ] + [ "$root_out_dir/$output_name.framework/$headers_dir/$output_name.h" ] deps = [ ":umbrella_header_$target_name" ] } diff --git a/webrtc_lib_link_test.cc b/webrtc_lib_link_test.cc index 37e1b14eae..055bd969ff 100644 --- a/webrtc_lib_link_test.cc +++ b/webrtc_lib_link_test.cc @@ -65,9 +65,10 @@ void TestCase1ModularFactory() { auto peer_connection_factory = webrtc::CreateModularPeerConnectionFactory(std::move(pcf_deps)); webrtc::PeerConnectionInterface::RTCConfiguration rtc_config; - auto peer_connection = peer_connection_factory->CreatePeerConnection( - rtc_config, nullptr, nullptr, nullptr); - printf("peer_connection=%s\n", peer_connection == nullptr ? "nullptr" : "ok"); + auto result = peer_connection_factory->CreatePeerConnectionOrError( + rtc_config, PeerConnectionDependencies(nullptr)); + // Creation will fail because of null observer, but that's OK. + printf("peer_connection creation=%s\n", result.ok() ? 
"succeeded" : "failed"); } void TestCase2RegularFactory() { @@ -81,9 +82,10 @@ void TestCase2RegularFactory() { std::move(media_deps.video_encoder_factory), std::move(media_deps.video_decoder_factory), nullptr, nullptr); webrtc::PeerConnectionInterface::RTCConfiguration rtc_config; - auto peer_connection = peer_connection_factory->CreatePeerConnection( - rtc_config, nullptr, nullptr, nullptr); - printf("peer_connection=%s\n", peer_connection == nullptr ? "nullptr" : "ok"); + auto result = peer_connection_factory->CreatePeerConnectionOrError( + rtc_config, PeerConnectionDependencies(nullptr)); + // Creation will fail because of null observer, but that's OK. + printf("peer_connection creation=%s\n", result.ok() ? "succeeded" : "failed"); } } // namespace webrtc